Compare commits

105 commits:

925be2ec4e, bf7de6b4b2, 4d41796d61, 6c2e3d064d, 8e15a54d58, ce7c0a2b10,
9c05ec148f, a905fe6ec5, 0042d7e390, ec51a2e9a5, 2f8f0042a9, d1c0bf2bea,
134e246117, 58207d4647, d25b7bc115, 4781f979b2, 061ef3920e, dfa0d45ffe,
b8f2875d28, e78211c8f5, 85d1b02e16, 5dd4cfc077, 90f953931e, 2b9b284026,
a7437e14c5, 4e6fe6329f, 8e8e65023b, b0f37346b0, db6bf71564, bfe92a5d91,
15f8e8c3b7, b10e9b89e9, cf5a1ee79e, c43535c219, 9c7bf0b75d, f8846d95ff,
f07995d6f8, 0125ab1740, c847cc6320, 53fd9631ef, ae3c6a1022, 9ed392db9c,
db55ef4fa7, c8b5cb4ad5, 14fee16886, d5e1337782, 80f03f2a76, c7d4085f3d,
fbcf36d40a, dae6035b5a, 4165e574a4, d1e39d7f98, 4a784ff13a, 8ab5498d26,
7116df3a32, 1726197bb4, 8ab68f5424, 4a61e65fc9, 6e77538087, d60b989bec,
ffb1e14268, 7db9f56d61, 0b23f95b20, 6ecc40e2be, 112bca3b8f, 47acb85769,
79d098450f, aef1b2b3c9, e11077eab9, 299cfe743f, 4f80fea855, af080740f2,
0cf2d22ffc, 836173434e, f374d39cbe, 1b300bdcbf, 31839cd269, 70e6f38fe3,
8856b62af9, d913cbd916, 75131b4622, ec5c2ad571, 1d2b1b72fc, 069083e28a,
22fc1cef16, 7710e048f4, 947ed0d46c, bcc5745ea0, e69ed1bdb2, fa066b531c,
d73de44070, 28e19443d0, 51025ca2ad, 6c5284ed3b, a2d54b5ad3, 81000db9e9,
c1920cbabb, ced0f5f944, e2c1ac8d24, 37ee888d88, ef122ca9cf, 9be1d5afc2,
6eddf940d2, db7a72bfe3, 0cbee3ec56
.ci/azure-pipelines.yml (new file, 42 lines)
@@ -0,0 +1,42 @@
name: Triton CI
pool:
  name: default

# Some variables
variables:
- name: venv
  value: venv

# Run CI when something is pushed to master
trigger: none
# Run CI when a PR is created or updated from master
pr:
- master

# Pipeline
steps:
- script: |
    mkdir $(venv)
    python -m virtualenv --python=python3 $(venv)
    source $(venv)/bin/activate
    python -m pip install --upgrade pip
    pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio===0.7.2 \
      -f https://download.pytorch.org/whl/torch_stable.html
    cd python
    python setup.py install
  displayName: Setup python environment

- script: |
    source $(venv)/bin/activate
    pip install matplotlib pandas
    cd python/bench
    python -m run

- publish: python/bench/results
  artifact: Benchmarks

- script: |
    source $(venv)/bin/activate
    pip install pytest
    pytest .
  displayName: 'Run Python tests'
.ci/build-wheels.yml (new file, 35 lines)
@@ -0,0 +1,35 @@
trigger: none
pr: none

jobs:
- job: linux

  timeoutInMinutes: 180

  pool: default

  steps:
  - bash: |
      set -o errexit
      python3 --version
      python3 -m pip install --upgrade pip
      pip3 install cibuildwheel==1.10.0
      pip3 install twine
    displayName: Install dependencies
  - bash: |
      #sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
      sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev`date '+%Y%m%d'`\"/g" python/setup.py
      echo "" >> python/setup.cfg
      echo "[build_ext]" >> python/setup.cfg
      echo "base-dir=/project" >> python/setup.cfg
    displayName: Patch setup.py
  - bash: |
      export CIBW_BEFORE_BUILD="pip install cmake"
      export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64"
      python3 -m cibuildwheel python --output-dir wheelhouse
    displayName: Build wheels
  - task: PublishBuildArtifacts@1
    inputs: {pathtoPublish: 'wheelhouse'}
  - bash: |
      python3 -m twine upload wheelhouse/* --skip-existing -u $(PYPI_USERNAME) -p $(PYPI_PASSWORD)
    displayName: Upload wheels to PyPI
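Note: the `sed` command in the "Patch setup.py" step rewrites the package version so nightly wheels get a dated pre-release suffix. Here is a minimal Python sketch of the same substitution, shown purely for illustration (the version string below is hypothetical; the pipeline itself uses `sed`):

```python
import re
from datetime import date

line = 'version="1.0.0"'  # hypothetical line from python/setup.py
stamp = date.today().strftime("%Y%m%d")
# Same rewrite as the sed expression above: version="X" -> version="X-devYYYYMMDD"
patched = re.sub(r'version="(.*)"', rf'version="\1-dev{stamp}"', line)
print(patched)  # e.g. version="1.0.0-dev20210301"
```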
CMakeLists.txt
@@ -1,18 +1,19 @@
- cmake_minimum_required(VERSION 2.8)
+ cmake_minimum_required(VERSION 3.6)
include(ExternalProject)

if(NOT TRITON_LLVM_BUILD_DIR)
  set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
endif()

project(triton)
include(CTest)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")

# Options
option(BUILD_TESTS "Build C++ Triton tests" ON)
option(BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)

# LLVM
find_package(LLVM REQUIRED)
link_directories(${LLVM_LIBRARY_DIRS})
include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})

# Default build type
if(NOT CMAKE_BUILD_TYPE)
  message(STATUS "Default build type: Release")
@@ -21,22 +22,65 @@ endif()

# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++11")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fvisibility=default -std=gnu++14")

# Tests
if(BUILD_TESTS)
  message(STATUS "Adding C++ tests")
  add_subdirectory(tests)
endif()

##########
# LLVM
##########
get_cmake_property(_variableNames VARIABLES)
set(__variableNames ${_variableNames})

configure_file(cmake/DownloadLLVM.in ${TRITON_LLVM_BUILD_DIR}/llvm-download/CMakeLists.txt)
execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" .
  WORKING_DIRECTORY "${TRITON_LLVM_BUILD_DIR}/llvm-download")
execute_process(COMMAND "${CMAKE_COMMAND}" --build .
  WORKING_DIRECTORY "${TRITON_LLVM_BUILD_DIR}/llvm-download")
set(LLVM_TARGETS_TO_BUILD "NVPTX" CACHE INTERNAL "")
set(LLVM_BUILD_RUNTIME "OFF" CACHE INTERNAL "")
set(LLVM_BUILD_RUNTIMES "OFF" CACHE INTERNAL "")
set(LLVM_BUILD_TOOLS "OFF" CACHE INTERNAL "")
set(LLVM_BUILD_UTILS "OFF" CACHE INTERNAL "")
set(LLVM_INCLUDE_BENCHMARKS "OFF" CACHE INTERNAL "")
set(LLVM_INCLUDE_DOCS "OFF" CACHE INTERNAL "")
set(LLVM_INCLUDE_EXAMPLES "OFF" CACHE INTERNAL "")
set(LLVM_INCLUDE_GO_TESTS "OFF" CACHE INTERNAL "")
set(LLVM_INCLUDE_RUNTIME "OFF" CACHE INTERNAL "")
set(LLVM_INCLUDE_TESTS "OFF" CACHE INTERNAL "")
set(LLVM_INCLUDE_TOOLS "OFF" CACHE INTERNAL "")
set(LLVM_INCLUDE_UTILS "OFF" CACHE INTERNAL "")
add_subdirectory(${TRITON_LLVM_BUILD_DIR}/llvm-src
  ${TRITON_LLVM_BUILD_DIR}/llvm-build)
get_property(LLVM_LIBRARIES GLOBAL PROPERTY LLVM_COMPONENT_LIBS)
# remove LLVM-specific variables so we don't pollute GUI
get_cmake_property(_variableNames VARIABLES)
list(REMOVE_ITEM _variableNames ${__variableNames})
list(REMOVE_ITEM _variableNames ${LLVM_LIBRARIES})
foreach (_variableName ${_variableNames})
  unset(${_variableName} CACHE)
endforeach()
include_directories("${TRITON_LLVM_BUILD_DIR}/llvm-build/include/"
  "${TRITON_LLVM_BUILD_DIR}/llvm-src/include/")

# Python module
if(BUILD_PYTHON_MODULE)
  message(STATUS "Adding Python module")
  # PyBind11 wrapper source file
  set(PYTHON_SRC bindings.cc launch.cc)
  set_source_files_properties(launch.cc PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0")
  include_directories("." ${PYTHON_INCLUDE_DIRS})
  link_directories(${PYTHON_LINK_DIRS})
  # Build CUTLASS python wrapper if requested
  set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
  set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}")
  set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}")
  if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL ""))
    set(CUTLASS_SRC ${PYTHON_SRC_PATH}/cutlass.cc)
    add_definitions(-DWITH_CUTLASS_BINDINGS)
    set(CUTLASS_LIBRARIES "cutlass.a")
  endif()
  message(STATUS ${CUTLASS_INCLUDE_PATH})
  include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
  link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
  set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC})
endif()

@@ -46,5 +90,11 @@ add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
target_link_libraries(triton ${LLVM_LIBRARIES})

if(BUILD_PYTHON_MODULE)
-   target_link_libraries(triton ${TORCH_LIBRARIES})
+   target_link_libraries(triton ${TORCH_LIBRARIES} ${CUTLASS_LIBRARIES})
endif()

# Tutorials
if(BUILD_TUTORIALS)
  message(STATUS "Adding C++ tutorials")
  add_subdirectory(tutorials)
endif()
LICENSE
@@ -1,4 +1,4 @@
- /* Copyright 2018-2020 Philippe Tillet
+ /* Copyright 2018-2021 Philippe Tillet
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files
README.md
@@ -2,46 +2,17 @@
This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs.

The main components of Triton at the moment are:
- **Triton-C**: An imperative, single-threaded language for writing highly efficient compute-kernels at a relatively high abstraction level (think numpy-like array operations in a C-like language).
- **Triton-IR**: A special-purpose intermediate representation (Triton-IR) for aiding array-level program analysis and optimizations in Triton-C programs.
- **Triton-JIT**: An optimizing just-in-time compiler for Triton-IR, which generates GPU code on par with state-of-the-art CUDA-C (e.g., [CUTLASS](https://github.com/NVIDIA/cutlass)). This includes transparent support for mixed-precision and Tensor Cores.
[](https://dev.azure.com/triton-lang/Triton/_build/latest?definitionId=10&branchName=master)

Bindings for **automatic** PyTorch custom op generations are included in **PyTriton**, along with a small DSL based on einsum that supports convolutions, shift-convolutions, direct einsums, etc.
- The foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing us if you use our work!
+ The formal foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing us if you use our work!
The [official documentation](https://triton-lang.org) contains installation instructions and tutorials.

# Compatibility

## Installation
Supported Platforms:
* Linux

Triton is a fairly self-contained package and uses its own parser (forked from [wgtcc](https://github.com/wgtdkp/wgtcc)) and LLVM-8.0+ for code generation.

You can install the latest release with pip as follows:
```
sudo apt-get install llvm-9-dev
pip install triton
```

or the latest development version with:
```
pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"
```

for the C++ package:
```
git clone https://github.com/ptillet/triton.git;
mkdir build;
cd build;
cmake ../;
make -j8;
```

## Getting Started

Please visit the [documentation](https://docs.triton-lang.org) to get started with Triton

## Contributing

Please keep in mind that this is a project I have been carrying out completely on my own as part of my Ph.D. thesis. While I am confident in the approach, there are still many things to fix and to polish. Please contact me (ptillet AT g.harvard.edu) or raise an issue if you want to contribute!
Supported Hardware:
* NVIDIA GPUs (Compute Capability 7.0+)
* Under development: AMD GPUs, CPUs
cmake/DownloadLLVM.in (new file, 15 lines)
@@ -0,0 +1,15 @@
cmake_minimum_required(VERSION 3.6)

project(llvm-download NONE)
include(ExternalProject)

ExternalProject_Add(llvm
  URL "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz"
  SOURCE_DIR "${TRITON_LLVM_BUILD_DIR}/llvm-src"
  BINARY_DIR "${TRITON_LLVM_BUILD_DIR}/llvm-build"
  CONFIGURE_COMMAND ""
  BUILD_COMMAND ""
  INSTALL_COMMAND ""
  TEST_COMMAND ""
)
cmake/FindLLVM.cmake
@@ -28,7 +28,9 @@
# We also want an user-specified LLVM_ROOT_DIR to take precedence over the
# system default locations such as /usr/local/bin. Executing find_program()
# multiples times is the approach recommended in the docs.
- set(llvm_config_names llvm-config-9 llvm-config-9.0 llvm-config90
+ set(llvm_config_names llvm-config-11 llvm-config-11.0
+     llvm-config-10 llvm-config-10.0 llvm-config100
+     llvm-config-9 llvm-config-9.0 llvm-config90
      llvm-config-8 llvm-config-8.0 llvm-config80
      llvm-config)
find_program(LLVM_CONFIG
docs/conf.py
@@ -21,7 +21,6 @@
# import sys
# sys.path.insert(0, os.path.abspath('.'))

# -- General configuration ------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
@@ -33,6 +32,20 @@
# ones.
extensions = []

# Math Jax
extensions += ['sphinx.ext.mathjax']

# Sphinx gallery
extensions += ['sphinx_gallery.gen_gallery']
from sphinx_gallery.sorting import FileNameSortKey
sphinx_gallery_conf = {
    'examples_dirs': '../python/tutorials/',
    'gallery_dirs': 'getting-started/tutorials',
    'filename_pattern': '',
    'ignore_pattern': r'__init__\.py',
    'within_subsection_order': FileNameSortKey,
}

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

@@ -77,7 +90,6 @@ pygments_style = 'sphinx'
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False

# -- Options for HTML output ----------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
@@ -97,6 +109,9 @@ html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_css_files = [
    'css/custom.css',
]

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
@@ -110,13 +125,11 @@ html_sidebars = {
]
}

# -- Options for HTMLHelp output ------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = 'Tritondoc'

# -- Options for LaTeX output ---------------------------------------------

latex_elements = {
@@ -141,20 +154,14 @@ latex_elements = {
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
-     (master_doc, 'Triton.tex', 'Triton Documentation',
-      'Philippe Tillet', 'manual'),
+     (master_doc, 'Triton.tex', 'Triton Documentation', 'Philippe Tillet', 'manual'),
]

# -- Options for manual page output ---------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
- man_pages = [
-     (master_doc, 'triton', 'Triton Documentation',
-      [author], 1)
- ]
+ man_pages = [(master_doc, 'triton', 'Triton Documentation', [author], 1)]

# -- Options for Texinfo output -------------------------------------------

@@ -162,10 +169,5 @@
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
-     (master_doc, 'Triton', 'Triton Documentation',
-      author, 'Triton', 'One line description of project.',
-      'Miscellaneous'),
+     (master_doc, 'Triton', 'Triton Documentation', author, 'Triton', 'One line description of project.', 'Miscellaneous'),
]
docs/getting-started/installation.rst (new file, 61 lines)
@@ -0,0 +1,61 @@
==============
Installation
==============

---------------------
Binary Distributions
---------------------

You can install the latest nightly release of Triton from pip:

.. code-block:: bash

    pip install -U --pre triton

--------------
From Source
--------------

+++++++++++++++
Python Package
+++++++++++++++

You can install the Python package from source by running the following commands:

.. code-block:: bash

    git clone https://github.com/ptillet/triton.git;
    cd triton/python;
    pip install -e .

This may take a while (10-20 minutes) as it will download and compile LLVM from source.

You can then test your installation by running the unit tests:

.. code-block:: bash

    pytest -vs .

and the benchmarks:

.. code-block:: bash

    cd bench/
    python -m run --with-plots --result-dir /tmp/triton-bench

+++++++++++++++
C++ Package
+++++++++++++++

Those not interested in Python integration may want to use the internals of Triton (i.e., runtime, parser, codegen, driver, intermediate representation) directly. This can be done by running the following commands:

.. code-block:: bash

    git clone https://github.com/ptillet/triton.git;
    mkdir build;
    cd build;
    cmake ../;
    make -j8;

Note that while direct usage of the C++ API is not officially supported, a usage tutorial can be found `here <https://github.com/ptillet/triton/blob/master/tutorials/01-matmul.cc>`_.
docs/index.rst
@@ -1,22 +1,38 @@
.. Triton documentation master file, created by
   sphinx-quickstart on Mon Feb 10 01:01:37 2020.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

Welcome to Triton's documentation!
==================================

+ Triton is an imperative language and compiler for parallel programming. It aims to provide a programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.

+ Getting Started
+ ---------------

+ - Follow the :doc:`installation instructions <getting-started/installation>` for your platform of choice.
+ - Take a look at the :doc:`tutorials <getting-started/tutorials/index>` to learn how to write your first Triton program.

.. toctree::
   :maxdepth: 1
-    :caption: Contents:
+    :caption: Getting Started
+    :hidden:

-    installation/index
-    tutorials/index
+    getting-started/installation
+    getting-started/tutorials/index

+ Programming Guide
+ ------------------

- Indices and tables
- ==================
+ Check out the following documents to learn more about Triton and how it compares against other DSLs for DNNs:

- * :ref:`genindex`
- * :ref:`modindex`
- * :ref:`search`
+ - Chapter 1: :doc:`Introduction <programming-guide/chapter-1/introduction>`
+ - Chapter 2: :doc:`Related Work <programming-guide/chapter-2/related-work>`
+ - Chapter 3: :doc:`The Triton-C Language <programming-guide/chapter-3/triton-c>`
+ - Chapter 4: :doc:`The Triton-IR Intermediate Representation <programming-guide/chapter-4/triton-ir>`

+ .. toctree::
+    :maxdepth: 1
+    :caption: Programming Guide
+    :hidden:

+    programming-guide/chapter-1/introduction
+    programming-guide/chapter-2/related-work
+    programming-guide/chapter-3/triton-c
+    programming-guide/chapter-4/triton-ir
docs/installation/from-source.rst (deleted, 21 lines)
@@ -1,21 +0,0 @@
***************
From Source
***************

Triton is a fairly self-contained package and uses its own parser (forked from `wgtcc <https://github.com/wgtdkp/wgtcc>`_) and LLVM-8.0+ for code generation.

.. code-block:: bash

    sudo apt-get install llvm-8-dev
    git clone https://github.com/ptillet/triton.git;
    cd triton/python/;
    python setup.py develop;

This should take about 15-20 seconds to compile on a modern machine.

You can then test your installation by running the *einsum.py* example in an environment that contains pytorch:

.. code-block:: bash

    cd examples;
    python einsum.py
docs/installation/index.rst (deleted, 16 lines)
@@ -1,16 +0,0 @@
Installation
============

Triton can be installed directly from pip with the following command:

.. code-block:: bash

    pip install triton

See the information below for more detailed information on custom builds.

.. toctree::
   :maxdepth: 1

   from-source
docs/programming-guide/chapter-1/cuda-parallel-matmul.png (new binary file, 9.5 KiB; not shown)
docs/programming-guide/chapter-1/introduction.rst (new file, 69 lines)
@@ -0,0 +1,69 @@
==============
Introduction
==============

--------------
Motivations
--------------

Over the past decade, Deep Neural Networks (DNNs) have emerged as an important class of Machine Learning (ML) models, capable of achieving state-of-the-art performance across many domains ranging from natural language processing [SUTSKEVER2014]_ to computer vision [REDMON2016]_ to computational neuroscience [LEE2017]_. The strength of these models lies in their hierarchical structure, composed of a sequence of parametric (e.g., convolutional) and non-parametric (e.g., rectified linearity) *layers*. This pattern, though notoriously computationally expensive, also generates a large amount of highly parallelizable work particularly well suited for multi- and many-core processors.

As a consequence, Graphics Processing Units (GPUs) have become a cheap and accessible resource for exploring and/or deploying novel research ideas in the field. This trend has been accelerated by the release of several frameworks for General-Purpose GPU (GPGPU) computing, such as CUDA and OpenCL, which have made the development of high-performance programs easier. Yet, GPUs remain incredibly challenging to optimize for locality and parallelism, especially for computations that cannot be efficiently implemented using a combination of pre-existing optimized primitives. To make matters worse, GPU architectures are also rapidly evolving and specializing, as evidenced by the addition of tensor cores to NVIDIA (and more recently AMD) micro-architectures.

This tension between the computational opportunities offered by DNNs and the practical difficulty of GPU programming has created substantial academic and industrial interest in Domain-Specific Languages (DSLs) and compilers. Regrettably, these systems -- whether they be based on polyhedral machinery (*e.g.*, Tiramisu [BAGHDADI2021]_, Tensor Comprehensions [VASILACHE2018]_) or scheduling languages (*e.g.*, Halide [JRK2013]_, TVM [CHEN2018]_) -- remain less flexible and (for the same algorithm) markedly slower than the best handwritten compute kernels available in libraries like `cuBLAS <https://docs.nvidia.com/cuda/cublas/index.html>`_, `cuDNN <https://docs.nvidia.com/deeplearning/cudnn/api/index.html>`_ or `TensorRT <https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html>`_.

The main premise of this project is the following: programming paradigms based on blocked algorithms [LAM1991]_ can facilitate the construction of high-performance compute kernels for neural networks. We specifically revisit traditional "Single Program, Multiple Data" (SPMD [AUGUIN1983]_) execution models for GPUs, and propose a variant in which programs -- rather than threads -- are blocked. For example, in the case of matrix multiplication, CUDA and Triton differ as follows:

.. table::
    :widths: 50 50

    +--------------------------------------+--------------------------------------+
    | CUDA Programming Model               | Triton Programming Model             |
    |                                      |                                      |
    | (Scalar Program, Blocked Threads)    | (Blocked Program, Scalar Threads)    |
    +======================================+======================================+
    | .. code-block:: C                    | .. code-block:: C                    |
    |                                      |    :force:                           |
    |                                      |                                      |
    |    #pragma parallel                  |    #pragma parallel                  |
    |    for(int m = 0; m < M; m++)        |    for(int m = 0; m < M; m += MB)    |
    |    #pragma parallel                  |    #pragma parallel                  |
    |    for(int n = 0; n < N; n++){       |    for(int n = 0; n < N; n += NB){   |
    |      float acc = 0;                  |      float acc[MB, NB] = 0;          |
    |      for(int k = 0; k < K; k++)      |      for(int k = 0; k < K; k += KB)  |
    |        acc += A[m, k] * B[k, n];     |        acc += A[m:m+MB, k:k+KB]      |
    |                                      |               @ B[k:k+KB, n:n+NB];   |
    |      C[m, n] = acc;                  |      C[m:m+MB, n:n+NB] = acc;        |
    |    }                                 |    }                                 |
    +--------------------------------------+--------------------------------------+
    | |pic1|                               | |pic2|                               |
    +--------------------------------------+--------------------------------------+

.. |pic1| image:: cuda-parallel-matmul.png

.. |pic2| image:: triton-parallel-matmul.png

A key benefit of this approach is that it leads to block-structured iteration spaces that offer programmers more flexibility than existing DSLs when implementing sparse operations, all while allowing compilers to aggressively optimize programs for data locality and parallelism.
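To make the contrast above concrete, the following NumPy sketch spells out the blocked (Triton-style) iteration pattern; the shapes and block sizes are illustrative only, and the sequential loops merely stand in for what would be independent parallel program instances:

.. code-block:: python

    import numpy as np

    M, N, K = 64, 64, 64        # illustrative matrix shapes
    MB, NB, KB = 16, 16, 16     # illustrative block sizes (must divide M, N, K here)
    A = np.random.rand(M, K).astype(np.float32)
    B = np.random.rand(K, N).astype(np.float32)
    C = np.empty((M, N), dtype=np.float32)

    # Each (m, n) pair below corresponds to one Triton program instance,
    # which manipulates whole (MB, NB) tiles rather than scalars.
    for m in range(0, M, MB):
        for n in range(0, N, NB):
            acc = np.zeros((MB, NB), dtype=np.float32)
            for k in range(0, K, KB):
                acc += A[m:m+MB, k:k+KB] @ B[k:k+KB, n:n+NB]
            C[m:m+MB, n:n+NB] = acc

    assert np.allclose(C, A @ B, atol=1e-3)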
--------------
Challenges
--------------

The main challenge posed by our proposed paradigm is that of work scheduling, i.e., how the work done by each program instance should be partitioned for efficient execution on modern GPUs. To address this issue, the Triton compiler makes heavy use of *block-level data-flow analysis*, a technique for scheduling iteration blocks statically based on the control- and data-flow structure of the target program. The resulting system actually works surprisingly well: our compiler manages to apply a broad range of interesting optimizations automatically (e.g., automatic coalescing, thread swizzling, pre-fetching, automatic vectorization, tensor-core-aware instruction selection, shared memory allocation/synchronization, asynchronous copy scheduling). Of course, doing all this is not trivial; one of the purposes of this guide is to give you a sense of how it works.

--------------
References
--------------

.. [SUTSKEVER2014] I. Sutskever et al., "Sequence to Sequence Learning with Neural Networks", NIPS 2014
.. [REDMON2016] J. Redmon et al., "You Only Look Once: Unified, Real-Time Object Detection", CVPR 2016
.. [LEE2017] K. Lee et al., "Superhuman Accuracy on the SNEMI3D Connectomics Challenge", ArXiv 2017
.. [BAGHDADI2021] R. Baghdadi et al., "Tiramisu: A Polyhedral Compiler for Expressing Fast and Portable Code", CGO 2021
.. [VASILACHE2018] N. Vasilache et al., "Tensor Comprehensions: Framework-Agnostic High-Performance Machine Learning Abstractions", ArXiv 2018
.. [JRK2013] J. Ragan-Kelley et al., "Halide: A Language and Compiler for Optimizing Parallelism, Locality, and Recomputation in Image Processing Pipelines", PLDI 2013
.. [CHEN2018] T. Chen et al., "TVM: An Automated End-to-End Optimizing Compiler for Deep Learning", OSDI 2018
.. [LAM1991] M. Lam et al., "The Cache Performance and Optimizations of Blocked Algorithms", ASPLOS 1991
.. [AUGUIN1983] M. Auguin et al., "Opsila: an advanced SIMD for numerical analysis and signal processing", EUROMICRO 1983
docs/programming-guide/chapter-1/triton-parallel-matmul.png (new binary file, 3.0 KiB; not shown)

docs/programming-guide/chapter-2/halide-iteration.png (new binary file, 12 KiB; not shown)

docs/programming-guide/chapter-2/polyhedral-iteration.png (new binary file, 59 KiB; not shown)
docs/programming-guide/chapter-2/related-work.rst (new file, 209 lines)
@@ -0,0 +1,209 @@
==============
Related Work
==============

At first sight, Triton may seem like just yet another DSL for DNNs. The purpose of this section is to contextualize Triton and highlight its differences with the two leading approaches in this domain: polyhedral compilation and scheduling languages.

-----------------------
Polyhedral Compilation
-----------------------

Traditional compilers typically rely on intermediate representations, such as LLVM-IR [LATTNER2004]_, that encode control flow information using (un)conditional branches. This relatively low-level format makes it difficult to statically analyze the runtime behavior (e.g., cache misses) of input programs, and to automatically optimize loops accordingly through the use of tiling [WOLFE1989]_, fusion [DARTE1999]_ and interchange [ALLEN1984]_. To solve this issue, polyhedral compilers [ANCOURT1991]_ rely on program representations that have statically predictable control flow, thereby enabling aggressive compile-time program transformations for data locality and parallelism. Though this strategy has been adopted by many languages and compilers for DNNs such as Tiramisu [BAGHDADI2021]_, Tensor Comprehensions [VASILACHE2018]_, Diesel [ELANGO2018]_ and the Affine dialect in MLIR [LATTNER2019]_, it also comes with a number of limitations that will be described later in this section.

+++++++++++++++++++++++
Program Representation
+++++++++++++++++++++++

Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic; readers interested in the solid mathematical foundations underneath may refer to the ample literature on linear and integer programming.

.. table::
    :widths: 50 50

    +--------------------------------------+--------------------------------------+
    | .. code-block:: C                    | |pic1|                               |
    |                                      |                                      |
    |    for(int i = 0; i < 3; i++)        |                                      |
    |      for(int j = i; j < 5; j++)      |                                      |
    |        A[i][j] = 0;                  |                                      |
    +--------------------------------------+--------------------------------------+

.. |pic1| image:: polyhedral-iteration.png
    :width: 300

Polyhedral compilers focus on a class of programs commonly known as **Static Control Parts** (SCoP), *i.e.*, maximal sets of consecutive statements in which conditionals and loop bounds are affine functions of surrounding loop indices and global invariant parameters. As shown above, programs in this format always lead to iteration domains that are bounded by affine inequalities, i.e., polyhedral. These polyhedra can also be defined algebraically; for the above example:

.. math::

    \mathcal{P} = \{ i, j \in \mathbb{Z}^2
    ~|~
    \begin{pmatrix}
    1 & 0 \\
    -1 & 0 \\
    -1 & 1 \\
    0 & -1 \\
    \end{pmatrix}
    \begin{pmatrix}
    i \\
    j
    \end{pmatrix}
    +
    \begin{pmatrix}
    0 \\
    2 \\
    0 \\
    4
    \end{pmatrix}
    \geq
    0
    \}

Each point :math:`(i, j)` in :math:`\mathcal{P}` represents a *polyhedral statement*, that is, a program statement which (1) does not induce control-flow side effects (e.g., :code:`for`, :code:`if`, :code:`break`) and (2) contains only affine functions of loop indices and global parameters in array accesses. To facilitate alias analysis, array accesses are also mathematically abstracted, using so-called *access functions*. In other words, :code:`A[i][j]` is simply :code:`A[f(i,j)]` where the access function :math:`f` is defined by:

.. math::

    f(i, j) = \begin{pmatrix}
    1 & 0\\
    0 & 1\\
    \end{pmatrix}
    \begin{pmatrix}
    i\\
    j
    \end{pmatrix}
    =
    (i, j)

Note that the iteration domains of an SCoP do not specify the order in which its statements shall execute. In fact, this iteration domain may be traversed in many different possible legal orders, i.e., *schedules*. Formally, a schedule is defined as a p-dimensional affine transformation :math:`\Theta` of loop indices :math:`\mathbf{x}` and global invariant parameters :math:`\mathbf{g}`:

.. math::

    \Theta_S(\mathbf{x}) = T_S \begin{pmatrix}
    \vec{x}\\
    \vec{g}\\
    1
    \end{pmatrix}
    \qquad
    T_S \in \mathbb{Z} ^{p \times (\text{dim}(\mathbf{x}) + \text{dim}(\mathbf{g}) + 1)}

where :math:`\Theta_S(\mathbf{x})` is a p-dimensional vector representing the slowest to fastest growing indices (from left to right) when traversing the loop nest surrounding :math:`S`. For the code shown above, the original schedule defined by the loop nest in C can be retrieved by using:

.. math::

    \Theta_S(\mathbf{x}) = \begin{pmatrix}
    1 & 0 \\
    0 & 1 \\
    \end{pmatrix}
    \begin{pmatrix}
    i & j
    \end{pmatrix}^T
    =
    \begin{pmatrix}
    i & j
    \end{pmatrix}^T

where :math:`i` and :math:`j` are respectively the slowest and fastest growing loop indices in the nest. If :math:`T_S` is a vector (resp. tensor), then :math:`\Theta_S` is said to be one-dimensional (resp. multi-dimensional).
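For readers who prefer code to algebra, a short NumPy sketch (purely illustrative, not part of any polyhedral toolchain) can enumerate :math:`\mathcal{P}` directly from the inequalities above and check that it coincides with the points visited by the loop nest:

.. code-block:: python

    import numpy as np

    # A_mat @ (i, j) + b >= 0, row by row:
    A_mat = np.array([[ 1,  0],   # i >= 0
                      [-1,  0],   # i <= 2
                      [-1,  1],   # j >= i
                      [ 0, -1]])  # j <= 4
    b = np.array([0, 2, 0, 4])

    # Points visited by: for(i = 0; i < 3; i++) for(j = i; j < 5; j++)
    loop_points = {(i, j) for i in range(3) for j in range(i, 5)}

    # Points of P, scanned over a small bounding box
    poly_points = {(i, j)
                   for i in range(-5, 10) for j in range(-5, 10)
                   if (A_mat @ np.array([i, j]) + b >= 0).all()}

    assert loop_points == poly_points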
+++++++++++
Advantages
+++++++++++

Programs amenable to polyhedral compilation can be aggressively transformed and optimized. Most of these transformations actually boil down to the production of schedules and iteration domains that enable loop transformations promoting parallelism and spatial/temporal data locality (e.g., fusion, interchange, tiling, parallelization).

Polyhedral compilers can also automatically go through complex verification processes to ensure that the semantics of their input program is preserved throughout this optimization phase. Note that polyhedral optimizers are not incompatible with more standard optimization techniques. In fact, it is not uncommon for these systems to be implemented as a set of LLVM passes that can be run ahead of more traditional compilation techniques [GROSSER2012]_.

All in all, polyhedral machinery is extremely powerful, when applicable. It has been shown to support most common loop transformations, and has indeed achieved performance comparable to state-of-the-art GPU libraries for dense matrix multiplication [ELANGO2018]_. Additionally, it is also fully automatic and doesn't require any hint from programmers apart from source code in a C-like format.

++++++++++++
Limitations
++++++++++++

Unfortunately, polyhedral compilers suffer from two major limitations that have prevented their adoption as a universal method for code generation in neural networks.

First, the set of possible program transformations :math:`\Omega = \{ \Theta_S ~|~ S \in \text{program} \}` is large, and grows with the number of statements in the program as well as with the size of their iteration domain. Verifying the legality of each transformation can also require the resolution of complex integer linear programs, making polyhedral compilation very computationally expensive. To make matters worse, hardware properties (e.g., cache size, number of SMs) and contextual characteristics (e.g., input tensor shapes) also have to be taken into account by this framework, leading to expensive auto-tuning procedures [SATO2019]_.

Second, the polyhedral framework is not very generally applicable; SCoPs are relatively common [GIRBAL2006]_ but require loop bounds and array subscripts to be affine functions of loop indices, which typically only occurs in regular, dense computations. For this reason, this framework has yet to be successfully applied to sparse -- or even structured-sparse -- neural networks, whose importance has been rapidly rising over the past few years.

On the other hand, blocked program representations advocated by this dissertation are less restricted in scope and can achieve close to peak performance using standard dataflow analysis.

-----------------------
Scheduling Languages
-----------------------

Separation of concerns is a well-known design principle in computer science: programs should be decomposed into modular layers of abstraction that separate the semantics of their algorithms from the details of their implementation. Systems like Halide and TVM push this philosophy one step further, and enforce this separation at the grammatical level through the use of a **scheduling language**. The benefits of this methodology are particularly visible in the case of matrix multiplication, where, as one can see below, the definition of the algorithm (Lines 1-7) is completely disjoint from its implementation (Lines 8-16), meaning that both can be maintained, optimized and distributed independently.

.. code-block:: c++
    :linenos:

    // algorithm
    Var x("x"), y("y");
    Func matmul("matmul");
    RDom k(0, matrix_size);
    RVar ki;
    matmul(x, y) = 0.0f;
    matmul(x, y) += A(k, y) * B(x, k);
    // schedule
    Var xi("xi"), xo("xo"), yo("yo"), yi("yi"), yii("yii"), xii("xii");
    matmul.vectorize(x, 8);
    matmul.update(0)
        .split(x, x, xi, block_size).split(xi, xi, xii, 8)
        .split(y, y, yi, block_size).split(yi, yi, yii, 4)
        .split(k, k, ki, block_size)
        .reorder(xii, yii, xi, ki, yi, k, x, y)
        .parallel(y).vectorize(xii).unroll(xi).unroll(yii);

The resulting code may however not be completely portable, as schedules can sometimes rely on execution models (e.g., SPMD) or hardware intrinsics (e.g., matrix-multiply-accumulate) that are not widely available. This issue can be mitigated by auto-scheduling mechanisms [MULLAPUDI2016]_.

+++++++++++
Advantages
+++++++++++

The main advantage of this approach is that it allows programmers to write an algorithm *only once*, and focus on performance optimization separately. It makes it possible to manually specify optimizations that a polyhedral compiler wouldn't be able to figure out automatically using static data-flow analysis.

Scheduling languages are, without a doubt, one of the most popular approaches for neural network code generation. The most popular system for this purpose is probably TVM, which provides good performance across a wide range of platforms as well as built-in automatic scheduling mechanisms.

++++++++++++
Limitations
++++++++++++

This ease of development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more effort -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indices without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular.

.. table::
    :widths: 50 50

    +--------------------------------------------+--------------------------------------------+
    | .. code-block:: C                          | |pic2|                                     |
    |                                            |                                            |
    |    for(int i = 0; i < 4; i++)              |                                            |
    |      for(int j = 0; j < 4; j++){           |                                            |
    |        float acc = 0;                      |                                            |
    |        for(int k = 0; k < K[i]; k++)       |                                            |
    |          acc += A[i][col[i,k]] * B[k][j];  |                                            |
    |        C[i][j] = acc;                      |                                            |
    |      }                                     |                                            |
    +--------------------------------------------+--------------------------------------------+

.. |pic2| image:: halide-iteration.png
    :width: 300

On the other hand, the block-based program representation that we advocate in this work allows for block-structured iteration spaces and allows programmers to manually handle load-balancing as they wish.

--------------
References
--------------

.. [LATTNER2004] C. Lattner et al., "LLVM: a compilation framework for lifelong program analysis transformation", CGO 2004
.. [WOLFE1989] M. Wolfe, "More Iteration Space Tiling", SC 1989
.. [DARTE1999] A. Darte, "On the Complexity of Loop Fusion", PACT 1999
.. [ALLEN1984] J. Allen et al., "Automatic Loop Interchange", SIGPLAN Notices 1984
.. [ANCOURT1991] C. Ancourt et al., "Scanning Polyhedra with DO Loops", PPoPP 1991
.. [BAGHDADI2021] R. Baghdadi et al., "Tiramisu: A Polyhedral Compiler for Expressing Fast and Portable Code", CGO 2021
.. [VASILACHE2018] N. Vasilache et al., "Tensor Comprehensions: Framework-Agnostic High-Performance Machine Learning Abstractions", ArXiv 2018
.. [ELANGO2018] V. Elango et al., "Diesel: DSL for Linear Algebra and Neural Net Computations on GPUs", MAPL 2018
.. [LATTNER2019] C. Lattner et al., "MLIR Primer: A Compiler Infrastructure for the End of Moore's Law", ArXiv 2019
.. [GROSSER2012] T. Grosser et al., "Polly - Performing Polyhedral Optimizations on a Low-Level Intermediate Representation", Parallel Processing Letters 2012
.. [SATO2019] Y. Sato et al., "An Autotuning Framework for Scalable Execution of Tiled Code via Iterative Polyhedral Compilation", TACO 2019
.. [GIRBAL2006] S. Girbal et al., "Semi-Automatic Composition of Loop Transformations for Deep Parallelism and Memory Hierarchies", International Journal of Parallel Programming 2006
.. [MULLAPUDI2016] R. Mullapudi et al., "Automatically scheduling halide image processing pipelines", TOG 2016
docs/programming-guide/chapter-3/triton-c.rst (new file, 84 lines)
@@ -0,0 +1,84 @@
=======================
The Triton-C Language
=======================

In the introduction, we stressed the importance of blocked algorithms and described their core principles in pseudo-code. To facilitate their implementation on modern GPU hardware, we present Triton-C, a single-threaded imperative kernel language in which block variables are first-class citizens. This language may be used either directly by developers familiar with C, or as an intermediate language for existing (and future) transcompilers. In this chapter, we describe its differences with C, its Numpy-like semantics and its "Single-Program, Multiple-Data" (SPMD) programming model.

-------------------
Differences with C
-------------------

The syntax of Triton-C is based on that of ANSI C, but was modified and extended to accommodate the semantics and programming model described in the next two subsections. These changes fall into the following categories:

+++++++++++
Extensions
+++++++++++

**Variable declarations**: Triton adds special-purpose syntax for multi-dimensional array declarations (e.g., :code:`int block[16, 16]`), which purposely differs from that of nested arrays (i.e., arrays of pointers) found in ANSI C (e.g., :code:`int block[16][16]`). Block dimensions must be constant but can also be made parametric with the use of pre-processor macros. One-dimensional blocks of integers may be initialized using ellipses (e.g., :code:`int range[16] = 0 ... 16`).

**Primitive types**: Triton-C supports the following primitive data-types: :code:`bool`, :code:`uint8`, :code:`uint16`, :code:`uint32`, :code:`uint64`, :code:`int8`, :code:`int16`, :code:`int32`, :code:`int64`, :code:`half`, :code:`float`, :code:`double`.

**Operators and built-in functions**: The usual C operators were extended to support element-wise array operations (:code:`+`, :code:`-`, :code:`&&`, :code:`*`, etc.) and complex array operations (:code:`@` for matrix multiplication). Additionally, some built-in functions were added for concurrency (:code:`get_program_id`, :code:`atomic_add`).

**Slicing and broadcasting**: Multi-dimensional blocks can be broadcast along any particular dimension using numpy-like slicing syntax (e.g., :code:`int array[8, 8] = range[:, newaxis]` for stacking columns). Note that, as of now, slicing blocks to retrieve sub-blocks (or scalars) is forbidden as it is incompatible with the automatic parallelization methods used by our JIT. Reductions can be achieved using a syntax similar to slicing (e.g., :code:`array[+]` for summing an array, or :code:`array[:, max]` for row-wise maximum). Currently supported reduction operators are :code:`+`, :code:`min`, :code:`max`.

**Masked pointer dereferences**: Block-level operations in Triton-C are "atomic", in the sense that they execute either completely or not at all. Basic element-wise control flow for block-level operations can nonetheless be achieved using ternary operators and the *masked pointer dereference* operator exemplified below:

.. code-block:: C
    :force:

    // create mask
    bool mask[16, 16] = ...;
    // conditional addition
    float x[16, 16] = mask ? a + b : 0;
    // conditional load
    float y[16, 16] = mask ? *ptr : 0;
    // conditional store
    *?(mask)ptr = y;

+++++++++++++
Restrictions
+++++++++++++

The Triton project is still in its infancy. As such, there are quite a few features of ANSI C that are not supported:

**Non-kernel functions**: Right now, all function definitions must be kernels, i.e., be preceded by the :code:`__global__` attribute. We are aware that this is a severe limitation; it exists because our automatic parallelization engine would not be capable of handling array parameter arguments.

**Non-primitive types**: Non-primitive types defined with :code:`struct` and :code:`union` are currently not supported, again because it is unclear at this point how these constructs would hook into our block-level data-flow analysis passes.

**While loops**: We just haven't had time to implement those yet.

----------------
Semantics
----------------

The existence of built-in **blocked** types, variables and operations in Triton-C offers two main benefits. First, it simplifies the structure of blocked programs by hiding important details pertaining to concurrent programming such as memory coalescing, cache management and specialized tensor intrinsics. Second, it opens the door for compilers to perform these optimizations automatically. However, it also means that programs have some kind of *block-level semantics* that does not exist in C. Though some aspects of it (e.g., the :code:`@` operator) are pretty intuitive, one in particular might be puzzling to some GPU programmers: broadcasting semantics.

+++++++++++++++++++++++
Broadcasting Semantics
+++++++++++++++++++++++

Block variables in Triton are strongly typed, meaning that certain instructions statically require their operands to satisfy strict shape constraints. For example, a scalar may not be added to an array unless it is first appropriately broadcast. *Broadcasting semantics* (first introduced in `Numpy <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_) provides two formal rules for performing these conversions automatically in the case of binary operators: (1) the shape of the lower-dimensional operand is left-padded with ones until both operands have the same dimensionality; and (2) the content of both operands is replicated as many times as needed until their shape is identical. An error is emitted if this cannot be done.

.. code-block:: C

    int a[16], b[32, 16], c[16, 1];
    // a is first reshaped to [1, 16]
    // and then broadcast to [32, 16]
    int x_1[32, 16] = a[newaxis, :] + b;
    // Same as above, but implicitly
    int x_2[32, 16] = a + b;
    // a is first reshaped to [1, 16]
    // a is broadcast to [16, 16]
    // c is broadcast to [16, 16]
    int y[16, 16] = a + c;
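Since these rules are borrowed directly from NumPy, the example above can be replayed verbatim in NumPy to build intuition (an illustrative sketch only):

.. code-block:: python

    import numpy as np

    a = np.zeros(16, dtype=np.int32)        # shape [16]
    b = np.zeros((32, 16), dtype=np.int32)  # shape [32, 16]
    c = np.zeros((16, 1), dtype=np.int32)   # shape [16, 1]

    # a is first reshaped to [1, 16], then broadcast to [32, 16]
    x_1 = a[np.newaxis, :] + b
    # Same as above, but implicitly: rule (1) left-pads a's shape with ones
    x_2 = a + b
    # a -> [1, 16] -> [16, 16]; c -> [16, 1] -> [16, 16]
    y = a + c

    assert x_1.shape == (32, 16) and x_2.shape == (32, 16)
    assert y.shape == (16, 16)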
------------------
Programming Model
------------------

As discussed in the `CUDA documentation <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html>`_, the execution of CUDA code on GPUs is supported by an `SPMD <https://en.wikipedia.org/wiki/SPMD>`_ programming model in which each kernel instance is associated with an identifiable *thread-block*, itself decomposed into *warps* of 32 *threads*. The Triton programming model is similar, but each kernel is *single-threaded* -- though automatically parallelized -- and associated with a global :code:`program id` which varies from instance to instance. This approach leads to simpler kernels in which CUDA-like concurrency primitives (shared memory synchronization, inter-thread communication, etc.) do not exist. The global program ids associated with each kernel instance can be queried using the :code:`get_program_id(axis)` built-in function, where :code:`0 <= axis <= 2`. This is useful, for example, to create blocks of pointers as shown in the tutorials.
docs/programming-guide/chapter-4/broadcast-1.png (new binary file, 2.9 KiB; not shown)

docs/programming-guide/chapter-4/broadcast-2.png (new binary file, 3.6 KiB; not shown)
82
docs/programming-guide/chapter-4/triton-ir.rst
Normal file
82
docs/programming-guide/chapter-4/triton-ir.rst
Normal file
@@ -0,0 +1,82 @@
==========================================
The Triton-IR Intermediate Representation
==========================================

Triton-IR is an LLVM-based Intermediate Representation (IR) whose purpose is to provide an environment suitable for block-level program analysis, transformation and optimization.
In our implementation, Triton-IR programs are constructed directly from Triton-C after parsing, but they could also be generated directly by higher-level DSLs in the future.
Triton-IR and LLVM-IR programs share the same high-level structure, but the former also includes a number of extensions necessary for block-level data-flow analysis.
These extensions are crucial for carrying out the optimizations outlined in the next chapter of this document.

---------------------------------
Structure of a Triton-IR Program
---------------------------------

++++++++
Modules
++++++++

At the highest level, Triton-IR programs consist of one or more basic units of compilation known as *modules*. These modules are compiled independently from one another, and eventually aggregated by a linker whose role is to resolve forward declarations and adequately merge global definitions. Each module itself is composed of functions, global variables, constants and other miscellaneous symbols such as metadata and attributes.

++++++++++
Functions
++++++++++

Triton-IR function definitions consist of a return type, a name and a potentially empty argument list. Additional visibility, alignment and linkage specifiers can be added if desired. Function attributes (such as inlining hints) and parameter attributes (such as "readonly" or aliasing hints) can also be specified, allowing compiler backends to perform more aggressive optimizations by, for instance, making better use of the non-coherent caches found on NVIDIA GPUs. This header is followed by a body composed of a list of basic blocks whose interdependencies form the Control Flow Graph (CFG) of the function.

+++++++++++++
Basic Blocks
+++++++++++++

Basic blocks are straight-line code sequences that may only contain so-called *terminator* instructions (i.e., branching, return) at their end. To simplify program analysis, Triton-IR uses the Static Single Assignment (SSA) form, meaning that each variable in each basic block must be (1) assigned to only once and (2) defined before being used. In so doing, each basic block implicitly defines a Data-Flow Graph (DFG). In our case, the SSA form is created directly from Triton-C's Abstract Syntax Trees (ASTs) using an algorithm from the literature [BRAUN13]_.

---------------------------------
Block-Level Data-Flow Analysis
---------------------------------

+++++++
Types
+++++++

Multi-dimensional blocks are at the center of data-flow analysis in Triton-JIT. They can be declared using syntax similar to vector declarations in LLVM-IR. For example, :code:`i32<8, 8>` is the type corresponding to :math:`8 \times 8` blocks of 32-bit integers. Note that there is no preprocessor in Triton-IR, hence parametric shape values must be resolved before programs are generated. In our case, this is done by Triton-JIT's auto-tuner.

+++++++++++++
Instructions
+++++++++++++

Triton-IR introduces a set of *reblocking* instructions whose purpose is to support broadcasting semantics as described in the previous chapter. The :code:`reshape` instruction creates a block of the specified shape using the raw data from its input argument. This is particularly useful to re-interpret variables as higher-dimensional arrays by padding their input shapes with ones in preparation for broadcasting. The :code:`broadcast` instruction creates a block of the specified shape by replicating its input argument as many times as necessary along dimensions of size 1 -- as shown below for the :code:`broadcast<3,3>` instruction.

|pic1| and |pic2|

.. |pic1| image:: broadcast-1.png
   :width: 40%

.. |pic2| image:: broadcast-2.png
   :width: 40%
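
To make these semantics concrete, the effect of :code:`reshape` followed by :code:`broadcast` can be sketched in numpy (a minimal illustration only; Triton-IR itself is not exposed through Python):

.. code-block:: python

    import numpy as np

    # a 3-element vector, analogous to an i32<3> block
    x = np.array([1, 2, 3], dtype=np.int32)

    # reshape: re-interpret the same data as a 1 x 3 block (pad the shape with ones)
    x_row = x.reshape(1, 3)

    # broadcast<3,3>: replicate the input along the dimension of size 1
    y = np.broadcast_to(x_row, (3, 3))
    # [[1 2 3]
    #  [1 2 3]
    #  [1 2 3]]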

Usual scalar instructions (:code:`cmp`, :code:`getelementptr`, :code:`add`, :code:`load`...) were preserved and extended to signify element-wise operations when applicable. Finally, Triton-IR also exposes specialized arithmetic instructions for reductions (:code:`reduce`) and matrix multiplications (:code:`dot`).

----------------------------------
Block-Level Control Flow Analysis
----------------------------------

In Triton-IR, operations on block variables are atomic: they execute either in full or not at all. As a result, traditional control flow structures (e.g., conditionals, loops) are not applicable to individual block elements. This is problematic, since a program may need to, e.g., partially guard blocked loads against memory access violations.

This could potentially be solved through the use of the Predicated SSA (PSSA) [CARTER99]_ [STOUTCHININ01]_ form for Triton-IR. However, this would create a lot of unnecessary complexity for GPUs, where the benefits of PSSA are close to none as divergent program paths within warps are serialized anyway. Therefore, recent versions of Triton handle intra-block control flow in a much simpler way, using conditional instructions such as :code:`select`, :code:`masked_load` and :code:`masked_store`:

.. code-block:: C

    // For all indices [idx], return cond[idx] ? true_value[idx] : false_value[idx];
    select TYPE<TS1, ..., TSN> cond, true_value, false_value;
    // For all indices [idx], return cond[idx] ? *true_addr[idx] : false_value[idx];
    masked_load TYPE<TS1, ..., TSN> cond, true_addr, false_value;
    // For all indices [idx], execute *true_addr[idx] = true_value[idx] if cond[idx]
    masked_store TYPE<TS1, ..., TSN> cond, true_addr, true_value;

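The semantics of these three instructions can be mirrored in a short numpy sketch (illustrative only; the variable names follow the operands above):

.. code-block:: python

    import numpy as np

    cond        = np.array([True, False, True, False])
    true_value  = np.array([1., 2., 3., 4.])
    false_value = np.array([9., 9., 9., 9.])
    memory      = np.array([5., 6., 7., 8.])   # stands in for *true_addr

    # select: cond[idx] ? true_value[idx] : false_value[idx]
    out = np.where(cond, true_value, false_value)   # [1. 9. 3. 9.]

    # masked_load: cond[idx] ? *true_addr[idx] : false_value[idx]
    loaded = np.where(cond, memory, false_value)    # [5. 9. 7. 9.]

    # masked_store: *true_addr[idx] = true_value[idx] if cond[idx]
    memory[cond] = true_value[cond]                 # memory -> [1. 6. 3. 8.]
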
------------
References
------------

.. [BRAUN13] M. Braun et al., "Simple and Efficient Construction of Static Single Assignment Form", CC 2013
.. [CARTER99] L. Carter et al., "Predicated Static Single Assignment", PACT 1999
.. [STOUTCHININ01] A. Stoutchinin et al., "Efficient Static Single Assignment Form for Predication", MICRO 2001

@@ -1,102 +0,0 @@
===========================
Writing a Custom Operation
===========================

--------------
Compute Kernel
--------------

Let us start with something simple, and see how Triton can be used to create a custom vector addition for PyTorch. The Triton compute kernel for this operation is the following:

.. code-block:: C

    // Triton
    // launched on a grid of (N + TILE - 1) / TILE programs
    __global__ void add(float* z, float* x, float* y, int N){
      // program id
      int pid = get_program_id(0);
      // create arrays of pointers
      int offset[TILE] = pid * TILE + 0 ... TILE;
      float* pz[TILE] = z + offset;
      float* px[TILE] = x + offset;
      float* py[TILE] = y + offset;
      // bounds checking
      bool check[TILE] = offset < N;
      // write-back
      *?(check)pz = *?(check)px + *?(check)py;
    }

As you can see, arrays are first-class citizens in Triton. This has a number of important advantages that will be highlighted in the next tutorial. For now, let's keep it simple and see how to execute the above operation in PyTorch.

---------------
PyTorch Wrapper
---------------

As you will see, a wrapper for the above Triton function can be created in just a few lines of pure Python code.

.. code-block:: python

    import torch
    import triton

    class _add(torch.autograd.Function):
        # source code for the Triton compute kernel
        src = """
    __global__ void add(float* z, float* x, float* y, int N){
        // program id
        int pid = get_program_id(0);
        // create arrays of pointers
        int offset[TILE] = pid * TILE + 0 ... TILE;
        float* pz[TILE] = z + offset;
        float* px[TILE] = x + offset;
        float* py[TILE] = y + offset;
        // bounds checking
        bool check[TILE] = offset < N;
        // write-back
        *?(check)pz = *?(check)px + *?(check)py;
    }
        """
        # create a callable kernel from the source code
        # options: 4 warps and -DTILE=1024
        kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4])

        # Forward pass
        @staticmethod
        def forward(ctx, x, y):
            # type checking
            assert x.dtype == torch.float32
            # allocate output
            z = torch.empty_like(x).cuda()
            # create launch grid
            # this is a function of the launch parameters
            # triton.cdiv indicates ceil division
            N = x.numel()
            grid = lambda opt: (triton.cdiv(N, opt.d('TILE')), )
            # launch kernel
            _add.kernel(z, x, y, N, grid=grid)
            # return output
            return z

    # get a callable from the Triton function
    add = _add.apply

    # test
    torch.manual_seed(0)
    x = torch.rand(98432).cuda()
    y = torch.rand(98432).cuda()
    za = x + y
    zb = add(x, y)
    diff = (za - zb).abs().max()
    print(diff)
    print(torch.allclose(za, zb))

Executing the above code will:

- Generate a .cpp file containing PyTorch bindings for the Triton function
- Compile this .cpp file using distutils
- Cache the resulting custom op
- Call the resulting custom op

In other words, the first program run will generate and cache a bunch of files in $HOME/.triton/cache, but subsequent runs should be just as fast as using a handwritten custom operation.
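
This caching behavior is easy to observe by timing repeated calls (a minimal sketch reusing the `add`, `x` and `y` defined above; exact timings will vary):

.. code-block:: python

    import time

    for i in range(3):
        start = time.time()
        z = add(x, y)
        torch.cuda.synchronize()   # wait for the kernel to finish
        print(f'call {i}: {time.time() - start:.3f}s')
    # the first call pays the compilation cost; later calls hit the cache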

A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/vec_add.py>`_.
@@ -1,10 +0,0 @@
Tutorials
==========

.. toctree::
   :maxdepth: 1

   custom-operation
   triton-vs-cuda
   matrix-transposition
   matrix-multiplication
@@ -1,186 +0,0 @@
*********************
Matrix Multiplication
*********************

The purpose of this section is to present a Triton-C implementation of matrix multiplication that achieves performance competitive with the best existing hand-written CUDA kernels (see `CUTLASS <https://github.com/NVIDIA/cutlass>`_). We will also see how pre-processor macros can be leveraged to fuse transposition operations as well as to provide support for auto-tuning and FP16 Tensor Cores.

*Note: Bounds-checking is omitted throughout for the sake of clarity. This feature can easily be added to our kernel, but may result in a slight performance hit because LLVM and PTXAS have issues dealing with conditionals and predicates inside loops.*

==============
Compute Kernel
==============

Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fairly concisely, as shown below:

.. code-block:: C

    // Triton-C
    // launched on a grid of (M / TM) x (N / TN) programs
    __global__ void dot(TYPE * A, TYPE * B, TYPE * C, int M, int N, int K,
                        int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
      // prologue
      int pm = get_program_id(0); //(1)
      int pn = get_program_id(1); //(2)
      int rm[TM] = pm * TM + 0 ... TM; //(3)
      int rn[TN] = pn * TN + 0 ... TN; //(4)
      int rk[TK] = 0 ... TK; //(5)
      // initialize accumulator
      float c[TM, TN] = 0; //(6)
      // pointers to operands
      TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda; //(7)
      TYPE* pb[TK, TN] = B + rk[:, newaxis] * ldb + rn[newaxis, :] * 1; //(8)
      // reduction loop
      for(int k = K; k > 0; k -= TK){
        // fetch operands
        TYPE a[TM, TK] = *pa; //(9)
        TYPE b[TK, TN] = *pb; //(10)
        // matrix-multiply accumulate
        c += a @ b; //(11)
        // increment pointers
        pa = pa + TK * 1; //(12)
        pb = pb + TK * ldb; //(13)
      }
      // epilogue
      TYPE* pc[TM, TN] = C + rn[newaxis, :] + rm[:, newaxis] * ldc; //(14)
      *pc = c; //(15)
    }

Here, each kernel instance produces a :code:`TM x TN` tile of the output matrix C as follows (a numpy sketch of the same tiling scheme follows the list):

- Statements (1) - (2) fetch the id of the current program instance.
- Statements (3) - (4) construct ranges of indices to process for the vertical and horizontal axes of the output matrix :code:`C`.
- Statement (5) constructs a range of indices along the reduction axis: :code:`rk = [0, 1, ..., TK - 1]`.
- Statement (6) initializes a :code:`TM x TN` array of accumulators to hold the result of :code:`A[rm, :] x B[:, rn]`.
- Statements (7) - (8) initialize arrays of pointers :code:`pa` and :code:`pb` to the operands :code:`A` and :code:`B`, using logic similar to that of the transposition kernel from the previous tutorial.
- Statements (9) - (10) load tiles of operands by dereferencing :code:`pa` and :code:`pb`.
- Statement (11) updates the accumulator array using Triton-C's matrix multiplication operator :code:`@`.
- Statements (12) - (13) update :code:`pa` and :code:`pb`.
- Statement (14) creates an array of pointers :code:`pc` to the result matrix :code:`C`.
- Statement (15) writes back the accumulator to :code:`C`.
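
For reference, the same tiled reduction scheme can be sketched in plain numpy (purely illustrative; the shapes `M`, `N`, `K`, `TM`, `TN`, `TK` and the random matrices are made up for the example):

.. code-block:: python

    import numpy as np

    M, N, K = 64, 64, 32
    TM, TN, TK = 16, 16, 8
    A = np.random.rand(M, K).astype(np.float32)
    B = np.random.rand(K, N).astype(np.float32)
    C = np.empty((M, N), dtype=np.float32)

    # each iteration of this double loop corresponds to one Triton program instance
    for pm in range(M // TM):
        for pn in range(N // TN):
            rm = pm * TM + np.arange(TM)            # (3)
            rn = pn * TN + np.arange(TN)            # (4)
            c = np.zeros((TM, TN), np.float32)      # (6)
            # reduction loop over K in steps of TK
            for k in range(0, K, TK):
                rk = k + np.arange(TK)              # (5), shifted each iteration
                a = A[rm[:, np.newaxis], rk[np.newaxis, :]]   # (9)
                b = B[rk[:, np.newaxis], rn[np.newaxis, :]]   # (10)
                c += a @ b                          # (11)
            C[rm[:, np.newaxis], rn[np.newaxis, :]] = c       # (15)

    assert np.allclose(C, A @ B, atol=1e-4)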

Internally, the Triton compiler will perform quite a few optimizations that will ensure good performance for this kernel:

- Automatic coalescing of load/store operations
- Automatic vectorization of load/store operations
- Stashing `a` and `b` to shared memory
- Automatic allocation of shared memory
- Automatic synchronization of shared memory
- Automatic padding of shared memory to avoid bank conflicts
- Automatic usage of tensor cores when TYPE = half and TK % 4 = 0

==============
Optimizations
==============

Nonetheless, there are two important optimizations that the Triton compiler does not yet perform automatically but that are critical to achieving peak performance: pre-fetching and rematerialization. In this subsection we describe how these optimizations can be done manually by modifying the above source code.

-------------
Pre-Fetching
-------------

The purpose of pre-fetching is to overlap the update of the accumulator `c` with the memory loads for the next tiles that will need to be multiplied. This can be done by modifying the above reduction loop as follows:

.. code-block:: C

    // pre-fetch operands
    TYPE a[TM, TK] = *pa; //(9)
    TYPE b[TK, TN] = *pb; //(10)
    for(int k = K; k > 0; k -= TK){
      c += a @ b;
      pa = pa + TK * 1;
      pb = pb + TK * ldb;
      // don't prefetch on the last iteration
      bool check = k > TK;
      // pre-fetch operands
      a = check ? *pa : 0;
      b = check ? *pb : 0;
    }

Note that the Triton-C compiler will now also be able to use double-buffering techniques to make sure that the array `a` can be used and updated at the same time without any memory hazard.

-----------------
Rematerialization
-----------------

`Rematerialization <https://en.wikipedia.org/wiki/Rematerialization>`_ is a compiler optimization which consists in recomputing some values instead of storing and reloading them from (register) memory, so as to decrease register pressure in the compute kernel. Although LLVM does this automatically to some extent, it fails to find good heuristics for the above kernel -- thereby requiring some source code modification to achieve optimal performance. Fortunately, only :code:`rm` and :code:`rn` need to be rematerialized, leading to the following epilogue:

.. code-block:: C

    // epilogue
    int rcm[TM] = pm * TM + 0 ... TM;
    int rcn[TN] = pn * TN + 0 ... TN;
    TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
    *pc = c;

------------------------------------
Fused Transpositions and Auto-Tuning
------------------------------------

It is common for optimized matrix-multiplication implementations (e.g., BLAS) to provide variants in which one or both operands are transposed. Fortunately, this can be done by using pre-processor macros for tile shapes and broadcasting directives, leading to the following kernel:

.. code-block:: C

    // Triton-C
    // launched on a grid of (M / TM) x (N / TN) programs
    __global__ void dot(TYPE * A, TYPE * B, TYPE * C,
                        int M, int N, int K,
                        int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
      // prologue
      int pm = get_program_id(0);
      int pn = get_program_id(1);
      int rm[TM] = pm * TM + 0 ... TM;
      int rn[TN] = pn * TN + 0 ... TN;
      int rk[TK] = 0 ... TK;
      float c[TM, TN] = 0;
      // pointers to operands
      TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
      TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
      // prefetch operands
      TYPE a[SHAPE_A] = (*pa);
      TYPE b[SHAPE_B] = (*pb);
      // reduction loop
      for(int k = K; k > 0; k -= TK){
        c += USE_A @ USE_B;
        pa = pa + TK * STRIDE_AK;
        pb = pb + TK * STRIDE_BK;
        a = *pa;
        b = *pb;
      }
      // epilogue
      int rcm[TM] = pm * TM + 0 ... TM;
      int rcn[TN] = pn * TN + 0 ... TN;
      TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
      *pc = c;
    }

All matrix multiplication variants can then be retrieved using the following compilation options:

.. code-block:: C

    // A is not transposed
    -DUSE_A=a -DSTRIDE_AK=1 -DSTRIDE_AM=lda
    -DBROADCAST_AK=newaxis,: -DBROADCAST_AM=:,newaxis -DSHAPE_A=TM,TK
    // A is transposed
    -DUSE_A=^a -DSTRIDE_AK=lda -DSTRIDE_AM=1
    -DBROADCAST_AK=:,newaxis -DBROADCAST_AM=newaxis,: -DSHAPE_A=TK,TM
    // B is not transposed
    -DUSE_B=b -DSTRIDE_BK=ldb -DSTRIDE_BN=1
    -DBROADCAST_BK=:,newaxis -DBROADCAST_BN=newaxis,: -DSHAPE_B=TK,TN
    // B is transposed
    -DUSE_B=^b -DSTRIDE_BK=1 -DSTRIDE_BN=ldb
    -DBROADCAST_BK=newaxis,: -DBROADCAST_BN=:,newaxis -DSHAPE_B=TN,TK

Auto-tuning can also be handled using pre-processor macros:

.. code-block:: C

    // Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16}
    -DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]

A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/mat_mul.py>`_.
@@ -1,174 +0,0 @@
*********************
Matrix Transpositions
*********************

Transpositions are (relatively) hard to write efficiently in CUDA because naive implementations typically suffer from *uncoalesced* memory operations when writing back the transposed matrix to DRAM.

Of course, this can be fixed by using shared memory as shown `here <https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc>`_, but this comes at the cost of simplicity and interferes with auto-tuning.

==============
Compute Kernel
==============

In Triton, however, kernels are single-threaded and the compiler automatically detects if and when data should be temporarily stashed to shared memory. Therefore, an optimal Triton kernel for this operation would look like:

.. code-block:: C

    // launched on a grid of (M / TM) x (N / TN) programs of 1 thread each
    __global__ void transpose(TYPE * X, TYPE * Y,
                              int M, int N, int ldx, int ldy) {
      // extract program ID
      int pidm = get_program_id(0); //(1)
      int pidn = get_program_id(1); //(2)
      // create 1D ranges along the two matrix's axes
      int rm[TM] = pidm * TM + 0 ... TM; //(3)
      int rn[TN] = pidn * TN + 0 ... TN; //(4)
      // create 2D arrays of pointers
      TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5)
      TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6)
      // write back using the transposition operator '^'
      *py = ^(*px); //(7)
    }

At a high level, this kernel loads a :code:`TM x TN` tile from the input matrix :code:`X`, transposes it and writes the resulting :code:`TN x TM` tile to the output matrix :code:`Y`. Eventually, transposition of the full input matrix is achieved by launching a grid of :code:`(M / TM) x (N / TN)` programs decomposed as follows:

- Statements (1) and (2) extract the coordinates of the program in the above 2D launch grid. For example, the program producing the output tile `Y[TN:2TN-1, 2TM:3TM-1]` holds the values:

  .. code-block:: C

      pidm = 2
      pidn = 1

- Statements (3) and (4) construct the ranges of indices:

  .. code-block:: C

      rm = [pidm*TM + 0, pidm*TM + 1, ..., pidm*TM + (TM - 1)]
      rn = [pidn*TN + 0, pidn*TN + 1, ..., pidn*TN + (TN - 1)]

  which will be used in statements (5) and (6) to construct tiles of pointers.

- Statement (5) constructs the following array of pointers `px` using numpy-style broadcasting semantics:

  ::

      │ X + (pidm*TM + 0)*ldx + (pidn*TN + 0),      ..., X + (pidm*TM + 0)*ldx + (pidn*TN + TN - 1)      │
      │ ⋮                                                ⋮                                                │
      │ X + (pidm*TM + TM - 1)*ldx + (pidn*TN + 0), ..., X + (pidm*TM + TM - 1)*ldx + (pidn*TN + TN - 1) │

- Statement (6) constructs the following array of pointers `py` using numpy-style broadcasting semantics:

  ::

      │ Y + (pidn*TN + 0)*ldy + (pidm*TM + 0),      ..., Y + (pidn*TN + 0)*ldy + (pidm*TM + TM - 1)      │
      │ ⋮                                                ⋮                                                │
      │ Y + (pidn*TN + TN - 1)*ldy + (pidm*TM + 0), ..., Y + (pidn*TN + TN - 1)*ldy + (pidm*TM + TM - 1) │

- Statement (7) element-wise dereferences the above array of pointers `*px`, transposes it using the unary transposition operator `^`, and writes it back at the locations specified by `py`.

==================================
A Note on Numpy-style Broadcasting
==================================

The construction statements (5) and (6) are a little subtle. To help understand them, consider the following numpy example.

First, we create a row vector of numbers 0 to 11, which we reshape into a 4x3 matrix.

.. code-block:: python

    import numpy as np

    vec = np.linspace(0, 11, 12)
    mat = vec.reshape((4, 3))

Imagine that we would like to process this in two 2x3 tiles (i.e., tile 0 will cover the top half, and tile 1 the bottom).

::

    [[ 0,  1,  2],
     [ 3,  4,  5],
     [ 6,  7,  8],
     [ 9, 10, 11]]

Given `pidm=0`, `pidn=0`, `TM=2`, `TN=3`, we would like tile 0 to have the values:

::

    [ 0, 1, 2],
    [ 3, 4, 5],

We construct ranges `rm` and `rn` as:

::

    rm = [0, 1]
    rn = [0, 1, 2]

Using numpy-style broadcasting, we can add these together to create a matrix:

::

    rm[:, np.newaxis] + rn[np.newaxis, :]

    rn -> [0, 1, 2]
    rm -> [0.,    [[0, 1, 2],
           1.]     [1, 2, 3]]

The bottom row is incorrect. Notice that `rm` indexes the rows of the matrix; we need to offset it so that each element gives the index of the start of that row. For instance, to access row 1 column 0, we need to access location 3. To access row 2 column 0, we need to access location 6. To translate from row N, column 0, we need to multiply N by the number of columns in each row (the leading dimension). In this case this is 3, so what we really need is:

::

    ldx = 3
    px = rm[:, np.newaxis] * ldx + rn[np.newaxis, :]

`newaxis` is built into Triton, and pointer arrays can be constructed in just the same way (as in this example, and in the runnable check below).
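
Putting the pieces together, here is a small runnable check of the indexing logic above (purely illustrative; `TM`, `TN` and `ldx` take the values from the example):

.. code-block:: python

    import numpy as np

    mat = np.arange(12).reshape(4, 3)
    TM, TN, ldx = 2, 3, 3

    for pidm in range(2):  # two 2x3 tiles
        rm = pidm * TM + np.arange(TM)
        rn = np.arange(TN)
        # flat indices of the tile, built exactly like statement (5)
        idx = rm[:, np.newaxis] * ldx + rn[np.newaxis, :]
        tile = mat.ravel()[idx]
        assert (tile == mat[pidm * TM : (pidm + 1) * TM, :]).all()
        print(tile)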

==========================
The __multipleof Attribute
==========================

The memory loads and stores in our transposition kernel are not vectorizable by default, since `X + ldx` (and `Y + ldy`) may be misaligned when `ldx` (and `ldy`) are not multiples of, e.g., 4. This is unfortunate because in Deep Learning tensor dimensions can easily be made into nice powers of two, since batch sizes and layer widths are flexible.

For this reason, Triton provides a __multipleof(N) attribute for variables that are guaranteed to always be multiples of N. In the case of matrix transpositions, vector loads can be enabled by modifying the function's signature as follows:

.. code-block:: C

    __global__ void transpose(TYPE * X, TYPE * Y, int M, int N,
                              int ldx __multipleof(8),
                              int ldy __multipleof(8)) {
      // ...
    }
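
On the wrapper side, one way to honor this guarantee is to allocate tensors whose leading dimension is rounded up to a multiple of 8. The sketch below is a hypothetical helper, not part of the Triton API:

.. code-block:: python

    import torch

    def round_up(x, multiple=8):
        # smallest multiple of `multiple` that is >= x
        return (x + multiple - 1) // multiple * multiple

    M, N = 100, 57
    ldx = round_up(N)                  # 64, a multiple of 8
    storage = torch.empty(M, ldx)      # padded allocation
    X = storage[:, :N]                 # logical 100 x 57 view, stride (64, 1)
    assert X.stride(0) % 8 == 0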

==========================
Bounds Checking
==========================

You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle this situation, as shown below:

.. code-block:: C

    // launched on a grid of ((M + TM - 1) / TM) x ((N + TN - 1) / TN) programs
    __global__ void transpose(TYPE * X, TYPE * Y, int M, int N, int ldx, int ldy) {
      // ...
      // create bounds-checking masks
      bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a)
      bool checky[TN, TM] = (rm[newaxis, :] < M) && (rn[:, newaxis] < N); //(7b)
      // conditional write-back using the conditional dereferencing operator '*?()'
      *?(checky)py = ^(*?(checkx)px); //(7)
    }

Here, statement (7a) creates an array of booleans :code:`checkx[TM, TN]` such that :code:`checkx(i, j) = True` if and only if `px(i, j)` should be dereferenced. Statement (7b) does the same for `py`. Both `px` and `py` are then conditionally dereferenced using Triton-C's conditional dereferencing operator :code:`*?(predicate) pointer`.

A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/mat_transpose.py>`_.
@@ -1,180 +0,0 @@
====================================================
Putting It All Together
====================================================

In the previous tutorial, we saw how to write tensor-core-friendly matrix multiplication code competitive with cuBLAS in ~20 lines of Triton code. Here, we will see how to wrap it into an automatically differentiable PyTorch function for easy integration in your Deep Learning pipeline.

-----------------
PyTriton Function
-----------------

The PyTriton API provides a :code:`triton.function` class which automatically handles the interaction with automatic differentiation in whichever framework was detected. Therefore, every differentiable custom operation written with PyTriton should inherit from this class:

.. code-block:: python

    import triton

    # Entry point
    class _dot(torch.autograd.Function):

        @staticmethod
        # Forward Pass
        def forward(ctx, *args):
            #...

        @staticmethod
        # Backward Pass
        def backward(ctx, dy):
            #...

-----------------
PyTriton Kernels
-----------------

PyTriton also provides a :code:`triton.kernel` class which automatically takes care of the interaction with the Triton-JIT as well as the generation and compilation of C++ framework binding code. For our dot operation, we create a kernel from the Triton code shown at the end of the previous tutorial.

.. code-block:: python

    src = """
    __global__ void dot(TYPE * A, TYPE * B, TYPE * C,
                        int M, int N, int K,
                        int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
      // prologue
      int pm = get_program_id(0);
      int pn = get_program_id(1);
      int rm[TM] = pm * TM + 0 ... TM;
      int rn[TN] = pn * TN + 0 ... TN;
      int rk[TK] = 0 ... TK;
      float c[TM, TN] = 0;
      // pointers to operands
      TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
      TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
      // prefetch operands
      TYPE a[SHAPE_A] = (*pa);
      TYPE b[SHAPE_B] = (*pb);
      // reduction loop
      for(int k = K; k > 0; k -= TK){
        c += USE_A @ USE_B;
        pa = pa + TK * STRIDE_AK;
        pb = pb + TK * STRIDE_BK;
        a = *pa;
        b = *pb;
      }
      // epilogue
      int rcm[TM] = pm * TM + 0 ... TM;
      int rcn[TN] = pn * TN + 0 ... TN;
      TYPE* pc[TM, TN] = C + rcn[newaxis, :] + rcm[:, newaxis] * ldc;
      *pc = c;
    }
    """

    kernel = triton.kernel(src)

At this point, `kernel` is a callable object which takes the same signature as the :code:`dot` function in our source code, except that pointers are treated as tensors: :code:`[tensor, tensor, tensor, int, int, int, int, int, int]`.

-----------------------
Using PyTriton Kernels
-----------------------

However, in practice only A and B are provided by the user, and all the other :code:`int` arguments should be derived from these operands only. Hence, we create a helper function that extracts shapes from the :code:`A` and :code:`B` tensors, and then returns the result of a call to :code:`kernel`:

.. code:: python

    @staticmethod
    def _call(a, b, transpose_a, transpose_b):
        # extract shapes
        shape_a = a.shape
        shape_b = b.shape
        M, Ka = shape_a[0], shape_a[1]
        Kb, N = shape_b[0], shape_b[1]
        # transpose shapes
        if transpose_a:
            M, Ka = Ka, M
        if transpose_b:
            Kb, N = N, Kb
        # contiguous dimensions
        lda = M if transpose_a else Ka
        ldb = Kb if transpose_b else N
        ldc = N
        # data-type
        dtype = a.dtype
        # allocate output
        c = triton.empty([M, N], dtype=dtype)
        # launch grid
        grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
        # pre-processor definitions
        defines = {# tile sizes
                   'TYPE'        : dtype,
                   'AT'          : transpose_a,
                   'BT'          : transpose_b,
                   'TM'          : [32, 64, 128],
                   'TN'          : [32, 64, 128],
                   'TK'          : [8],
                   # handle A transposition
                   'USE_A'       : '^a' if transpose_a else 'a',
                   'STRIDE_AK'   : 'lda' if transpose_a else '1',
                   'STRIDE_AM'   : '1' if transpose_a else 'lda',
                   'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
                   'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
                   'SHAPE_A'     : 'TK, TM' if transpose_a else 'TM, TK',
                   # handle B transposition
                   'USE_B'       : '^b' if transpose_b else 'b',
                   'STRIDE_BK'   : '1' if transpose_b else 'ldb',
                   'STRIDE_BN'   : 'ldb' if transpose_b else '1',
                   'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
                   'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
                   'SHAPE_B'     : 'TN, TK' if transpose_b else 'TK, TN'}
        return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc,
                           grid=grid, num_warps=4, defines=defines)

--------------------------------------------
Automatic Differentiation
--------------------------------------------

At this point, our custom operation only takes two tensor arguments and transposition information, which is good. However, it is still not compatible with PyTorch's or TensorFlow's automatic differentiation engine, and a small amount of additional effort is needed.

Creating custom operations for Triton and PyTorch is very similar; programmers have to provide two static methods, :code:`forward` and :code:`backward`, that take a context as their first input:

.. code:: python

    @staticmethod
    def forward(ctx, a, b, transpose_a = False, transpose_b = False):
        ctx.save_for_backward(a, b)
        ctx.t_a = transpose_a
        ctx.t_b = transpose_b
        return _dot._call(a, b, transpose_a, transpose_b)

    @staticmethod
    def backward(ctx, dy):
        a, b = ctx.saved_tensors
        t_a, t_b = ctx.t_a, ctx.t_b
        if not t_a and not t_b:
            da = _dot._call(dy, b, False, True)
            db = _dot._call(a, dy, True, False)
        elif not t_a and t_b:
            da = _dot._call(dy, b, False, False)
            db = _dot._call(dy, a, True, False)
        elif t_a and not t_b:
            da = _dot._call(b, dy, False, True)
            db = _dot._call(a, dy, False, False)
        elif t_a and t_b:
            da = _dot._call(b, dy, True, True)
            db = _dot._call(dy, a, True, True)
        else:
            assert False
        # one gradient per forward argument: a, b, transpose_a, transpose_b
        return da, db, None, None

A callable operation can be created using the :code:`apply` method of the :code:`torch.autograd.Function` class.

.. code:: python

    dot = _dot.apply
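
As a quick sanity check, the resulting op can then be used like any other differentiable PyTorch function (an illustrative snippet, assuming a CUDA device and the `dot` defined above; shapes are made up for the example):

.. code-block:: python

    import torch

    a = torch.rand(256, 128, device='cuda', requires_grad=True)
    b = torch.rand(128, 512, device='cuda', requires_grad=True)

    c = dot(a, b)            # forward pass runs the Triton kernel
    c.sum().backward()       # backward pass reuses _dot._call

    # compare against PyTorch's reference implementation
    assert torch.allclose(c, a @ b, atol=1e-3)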

And that's it! In just ~100 lines of pure Python, we have written a fully functional matrix multiplication that will not only work with automatic differentiation but also provide performance very close to cuBLAS. And it's all open-source!
@@ -1,109 +0,0 @@
***************
Triton vs. CUDA
***************

The purpose of this tutorial is to explore in more depth the major differences between Triton and CUDA. To keep things simple, we will still be focusing on the following vector addition code:

.. code-block:: C

    // Triton
    // launched on a grid of (N + TILE - 1) / TILE programs
    __global__ void add(float* z, float* x, float* y, int N){
      int offset[TILE] = get_program_id(0) * TILE + 0 ... TILE;
      bool check[TILE] = offset < N;
      float* pz[TILE] = z + offset;
      float* px[TILE] = x + offset;
      float* py[TILE] = y + offset;
      *?(check)pz = *?(check)px + *?(check)py;
    }

And its CUDA equivalent:

.. code-block:: C

    // CUDA
    // launched on a grid of (N + TILE - 1) / TILE thread blocks
    __global__ void add(float *z, float *x, float *y, int N) {
      int off = blockIdx.x * TILE + threadIdx.x;
      if(off < N){
        float *pz = z + off;
        float *px = x + off;
        float *py = y + off;
        *pz = *px + *py;
      }
    }

==========================
Automatic parallelization
==========================

While the two above pieces of code may look similar at first sight, a closer look reveals one *fundamental* difference: while CUDA kernels are launched on a cooperative array of threads, **Triton kernels are single-threaded and automatically parallelized**.

This is a major difference in programming model, which not only makes your life much easier as a programmer, but also allows the Triton compiler to automatically do all sorts of nice optimizations:

- *Automatic shared memory allocation and synchronization*

  That's right; programmers don't need to worry about shared memory allocation, usage and synchronization. Instead, the Triton compiler will use complex program analysis techniques to determine when shared memory should be used, where it should be synchronized and how threads should access it to avoid memory bank conflicts.

- *Automatic memory coalescing*

  When you write Triton code, you also don't need to worry about memory coalescing. The compiler will arrange threads so that global memory accesses are coalesced when possible.

- *Automatic tensor core utilization*

  Using tensor cores on Volta and Turing is notoriously difficult. Code is hard to write and even harder to optimize. Fortunately, the Triton compiler can also generate very efficient tensor core instructions (e.g., :code:`mma.sync.m8n8k4`) when low-precision matrices are multiplied together:

  .. code-block:: C

      half A[16, 8] = ... // initialize A
      half B[8, 16] = ... // initialize B
      float C[16, 16] = dot(A, B); // uses Tensor Cores!

- *Automatic instruction predication*

  Contrary to CUDA, Triton directly exposes predicated instructions through masked load/store instructions. This enables the Triton compiler to generate predicated instructions in PTX directly, sometimes resulting in better performance than I/O operations wrapped inside conditionals.

===========================
Vector Addition - Revisited
===========================

In light of these optimizations, it turns out that the GPU code generated by our Triton-C vector addition code is actually more analogous to the following:

.. code-block:: C

    // CUDA
    // launched on a grid of (N + TILE - 1) / TILE thread blocks
    __global__ void add(float *z, float *x, float *y, int N) {
      int off[4];
      #pragma unroll
      for(int k = 0; k < 4; k++)
        off[k] = blockIdx.x * TILE + threadIdx.x + k * blockDim.x;
      #pragma unroll
      for(int k = 0; k < 4; k++)
        z[off[k]] = x[off[k]] + y[off[k]];
    }

This code is actually more complicated when x, y and z have :code:`half` type, because then the Triton compiler automatically vectorizes data accesses using :code:`half2` to ensure memory transactions of 32 bits per thread.

============================
Auto-Tuning
============================

Now assume that you want to tune the above code for different data types, tile sizes and thread block sizes. This is doable in CUDA but would require you to write cumbersome machinery to handle different vector sizes and loop unrolling factors. In Triton, this can be trivially done by adjusting some compilation parameters. For example:

.. code-block:: python

    kernel = triton.kernel(src, defines = {'TILE': [256, 512, 1024]}, num_warps = [2, 4, 8])

would benchmark our Triton source code above for tile sizes of 256, 512 and 1024 executed with 2, 4 or 8 warps -- and cache the fastest kernel.

=============================
Going Further
=============================

The benefits of Triton become more and more pronounced as compute kernels get more and more complex. In the next few tutorials, you will see how to implement transposition and a tensor-core-compatible matrix multiplication routine on par with cuBLAS and CUTLASS without having to know anything about GPU micro-architecture!
@@ -27,7 +27,7 @@ private:
  void update_graph_trans(ir::instruction *i);
  void update_graph_broadcast(ir::instruction *i);
  void update_graph_dot(ir::instruction *i);
  void update_graph_elementwise(ir::instruction *i);
  void update_graph_elementwise(ir::instruction *i, bool connect_ret=true);
  void update_graph_no_edge(ir::instruction *i);
  void update_graph(ir::instruction *i);

@@ -6,6 +6,7 @@
#include <vector>
#include <memory>
#include "triton/tools/graph.h"
#include "triton/codegen/target.h"

namespace triton{

@@ -24,7 +25,7 @@ class axes;
class align;
class layout_visitor;
class data_layout;
class mma884_layout;
class mma_layout;
class scanline_layout;
class shared_layout;

@@ -32,7 +33,7 @@ class shared_layout;
class layout_visitor {
public:
  virtual void visit_layout(data_layout *);
  virtual void visit_layout_hmma_884(mma884_layout*) = 0;
  virtual void visit_layout_mma(mma_layout*) = 0;
  virtual void visit_layout_scanline(scanline_layout*) = 0;
  virtual void visit_layout_shared(shared_layout*) = 0;
};

@@ -40,7 +41,7 @@ public:
class data_layout {
protected:
  enum id_t {
    HMMA_884,
    MMA,
    SCANLINE,
    SHARED
  };

@@ -67,7 +68,7 @@ public:
  // visitor
  virtual void accept(layout_visitor* vst) = 0;
  // downcast
  mma884_layout* to_mma884()     { return downcast<mma884_layout>(HMMA_884); }
  mma_layout* to_mma()           { return downcast<mma_layout>(MMA); }
  scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
  shared_layout* to_shared()     { return downcast<shared_layout>(SHARED); }
  // accessors

@@ -76,9 +77,10 @@ public:
  const order_t& get_order() const { return order_; }
  const values_t& get_values() const { return values_;}
  int get_axis(size_t k) const { return axes_.at(k); }
  std::vector<int> get_axes() const { return axes_; }
  const int get_order(size_t k) const { return order_.at(k); }
  // find the position of given axis
  size_t find_axis(int to_find) const;
  int find_axis(int to_find) const;


private:

@@ -91,21 +93,29 @@ protected:
  shape_t shape_;
};

class mma884_layout: public data_layout {
class mma_layout: public data_layout {
public:
  mma884_layout(size_t num_warps,
  mma_layout(size_t num_warps,
             const std::vector<int>& axes,
             const std::vector<unsigned>& shapes,
             const std::vector<ir::value *> &values,
             analysis::align* align);
             analysis::align* align, target *tgt,
             shared_layout* layout_a,
             shared_layout* layout_b);
  void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); }
  void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
  // accessor
  int fpw(size_t k) { return fpw_.at(k); }
  int wpt(size_t k) { return wpt_.at(k); }
  int spw(size_t k) { return spw_.at(k); }
  int spt(size_t k) { return spt_.at(k); }
  int rep(size_t k) { return rep_.at(k); }

private:
  std::vector<int> fpw_;
  std::vector<int> spw_;
  std::vector<int> wpt_;
  std::vector<int> spt_;
  std::vector<int> rep_;
};

struct scanline_layout: public data_layout {

@@ -113,7 +123,8 @@ struct scanline_layout: public data_layout {
                  const std::vector<int>& axes,
                  const std::vector<unsigned>& shape,
                  const std::vector<ir::value *> &values,
                  analysis::align* align);
                  analysis::align* align,
                  target* tgt);
  void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
  // accessor
  int mts(size_t k) { return mts_.at(k); }

@@ -136,7 +147,7 @@ private:
  static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);

public:
  shared_layout(const data_layout *arg,
  shared_layout(data_layout *arg,
                const std::vector<int>& axes,
                const std::vector<unsigned>& shapes,
                const std::vector<ir::value *> &values_,

@@ -147,11 +158,22 @@ public:
  size_t get_size() { return size_; }
  ir::type* get_type() { return ty_; }
  double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
  size_t get_num_per_phase() { return num_per_phase_; }
  ir::value* hmma_dot_a() { return hmma_dot_a_; }
  ir::value* hmma_dot_b() { return hmma_dot_b_; }
  void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
  int get_mma_vec() { return mma_vec_;}
  data_layout* get_arg_layout() { return arg_layout_; }

private:
  size_t size_;
  ir::type *ty_;
  std::shared_ptr<double_buffer_info_t> double_buffer_;
  size_t num_per_phase_;
  ir::value* hmma_dot_a_;
  ir::value* hmma_dot_b_;
  data_layout* arg_layout_;
  int mma_vec_;
};


@@ -172,7 +194,7 @@ private:

public:
  // constructor
  layouts(analysis::axes *axes, analysis::align *align, size_t num_warps);
  layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);

  // accessors
  unsigned layout_of(ir::value *value) const { return groups_.at(value); }

@@ -190,6 +212,7 @@ private:
  analysis::axes* axes_;
  analysis::align* align_;
  size_t num_warps_;
  target* tgt_;
  tools::graph<ir::value*> graph_;
  std::map<ir::value*, size_t> groups_;
  std::map<size_t, std::vector<ir::value*>> values_;

43
include/triton/codegen/analysis/swizzle.h
Normal file
@@ -0,0 +1,43 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
#define TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H

#include <map>

namespace triton{

namespace ir{
  class module;
}

namespace codegen{
class target;

namespace analysis{

class layouts;
class data_layout;

class swizzle {
public:
  // constructor
  swizzle(layouts *l, target* tgt): layouts_(l), tgt_(tgt){ }
  // accessors
  int get_per_phase(data_layout* layout) { return per_phase_.at(layout); }
  int get_max_phase(data_layout* layout) { return max_phase_.at(layout); }
  int get_vec      (data_layout* layout) { return vec_.at(layout); }
  // run
  void run(ir::module &mod);
private:
  layouts* layouts_;
  target* tgt_;
  std::map<data_layout*, int> per_phase_;
  std::map<data_layout*, int> max_phase_;
  std::map<data_layout*, int> vec_;
};

}
}
}

#endif
@@ -5,13 +5,14 @@
#include "triton/ir/visitor.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/selection/machine_value.h"
#include <functional>

// forward
namespace llvm{
  class Type;
  class Value;
  class BasicBlock;
  class Attribute;
  class Instruction;
  class Constant;
  class LLVMContext;

@@ -25,6 +26,13 @@ namespace llvm{
}

namespace triton{

namespace ir{
class attribute;
class load_inst;
class store_inst;
}

namespace codegen{

// forward
@@ -36,6 +44,7 @@ class allocation;
class cts;
class axes;
class layouts;
class swizzle;
}
// typedef
typedef llvm::IRBuilder<llvm::ConstantFolder,
@@ -43,17 +52,14 @@ typedef llvm::IRBuilder<llvm::ConstantFolder,
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Attribute Attribute;
typedef llvm::BasicBlock BasicBlock;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
typedef std::vector<Value*> indices_t;
// forward
class machine_data_layout;
class tile;
class shared_tile;
class distributed_tile;
class target;

}
@@ -62,109 +68,129 @@ class target;
namespace triton{
namespace codegen{

struct distributed_axis {
  int contiguous;
  std::vector<Value*> values;
  Value* thread_id;
};

class generator: public ir::visitor, public analysis::layout_visitor {
private:
  void for_each(ir::value *x, const std::function<void(indices_t)>& fn);
  Value* get_value(ir::value *x, const indices_t& idx);
  void set_value(ir::value *x, const indices_t& idx, Value* v);

  void visit_hmma_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK);
  void visit_scanline_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add);
  void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
                       Type *c_ty, Function *f_mul_add);

  void init_idx(ir::value *x);
  Instruction* add_barrier();
  Value* shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx);
  void finalize_shared_layout(analysis::shared_layout*);
  void finalize_function(ir::function*);
  void finalize_phi_node(ir::phi_node*);

private:
  Type *cvt(ir::type *ty);
  llvm::Attribute cvt(ir::attribute attr);

public:
  generator(analysis::axes *a_axes,
            analysis::layouts *layouts,
            analysis::align *alignment,
            analysis::allocation *alloc,
            analysis::swizzle *swizzle,
            target *tgt,
            unsigned num_warps);

  void visit_value(ir::value* v);

  void visit_phi_node(ir::phi_node*);
  void visit_binary_operator(ir::binary_operator*);
  void visit_getelementptr_inst(ir::getelementptr_inst*);

  void visit_icmp_inst(ir::icmp_inst*);
  void visit_fcmp_inst(ir::fcmp_inst*);
  void visit_cast_inst(ir::cast_inst*);

  void visit_return_inst(ir::return_inst*);
  void visit_cond_branch_inst(ir::cond_branch_inst*);
  void visit_uncond_branch_inst(ir::uncond_branch_inst*);


  void visit_load_inst(ir::load_inst*);
  void visit_unmasked_load_inst(ir::unmasked_load_inst*);
  void visit_masked_load_inst(ir::masked_load_inst*);
  void visit_store_inst(ir::store_inst*);
  void visit_unmasked_store_inst(ir::unmasked_store_inst*);
  void visit_masked_store_inst(ir::masked_store_inst*);

  void visit_reshape_inst(ir::reshape_inst*);
  void visit_splat_inst(ir::splat_inst*);
  void visit_broadcast_inst(ir::broadcast_inst*);
  void visit_downcast_inst(ir::downcast_inst*);

  void visit_exp_inst(ir::exp_inst*);

  void visit_log_inst(ir::log_inst*);
  void visit_get_program_id_inst(ir::get_program_id_inst*);
  void visit_get_num_program_inst(ir::get_num_program_inst*);
  void visit_atomic_cas_inst(ir::atomic_cas_inst*);
  void visit_atomic_exch_inst(ir::atomic_exch_inst*);
  void visit_atomic_add_inst(ir::atomic_add_inst*);
  void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
  void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
  void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
  void visit_dot_inst(ir::dot_inst*);
  void visit_trans_inst(ir::trans_inst*);
  void visit_sqrt_inst(ir::sqrt_inst*);
  void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
  void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
  void visit_reduce_inst(ir::reduce_inst*);
  void visit_select_inst(ir::select_inst*);

  void visit_recoalesce_inst(ir::recoalesce_inst*);
  void visit_masked_load_async_inst(ir::masked_load_async_inst*);
  void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
  void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
  void visit_barrier_inst(ir::barrier_inst*);
  void visit_async_wait_inst(ir::async_wait_inst*);
  void visit_make_range_dyn(ir::make_range_dyn*);
  void visit_make_range(ir::make_range*);

  void visit_make_range_sta(ir::make_range_sta*);
  void visit_undef_value(ir::undef_value*);
  void visit_constant_int(ir::constant_int*);
  void visit_constant_fp(ir::constant_fp*);
  void visit_alloc_const(ir::alloc_const*);

  void visit_function(ir::function*);
  void visit_basic_block(ir::basic_block*);
  void visit_argument(ir::argument*);
  void visit(ir::module &, llvm::Module &);

  void visit_layout_hmma_884(analysis::mma884_layout*);
  // layouts
  void visit_layout_mma(analysis::mma_layout*);
  void visit_layout_scanline(analysis::scanline_layout*);
  void visit_layout_shared(analysis::shared_layout*);

  void visit(ir::module &, llvm::Module &);

private:
  LLVMContext *ctx_;
  Builder* builder_;
  Module *mod_;

  std::map<const analysis::data_layout*, machine_data_layout*> machine_layouts_;
  analysis::axes *a_axes_;
  analysis::swizzle *swizzle_;
  std::map<unsigned, distributed_axis> axes_;
  std::map<ir::value *, Value *> vmap_;
  std::map<ir::value *, tile *> tmap_;
  target *tgt_;
  analysis::layouts *layouts_;
  analysis::align *alignment_;
  analysis::allocation *alloc_;
  Value *sh_mem_ptr_;
  Value *shmem_;
  unsigned num_warps_;

  std::set<ir::value*> seen_;

  std::map<analysis::data_layout*, Value*> offset_a_m_;
  std::map<analysis::data_layout*, Value*> offset_a_k_;
  std::map<analysis::data_layout*, Value*> offset_b_k_;
  std::map<analysis::data_layout*, Value*> offset_b_n_;

  std::map<analysis::data_layout*, Value*> shared_ptr_;
  std::map<analysis::data_layout*, Value*> shared_pre_ptr_;
  std::map<analysis::data_layout*, Value*> shared_next_ptr_;
  std::map<analysis::data_layout*, Value*> shared_off_;


  std::map<ir::value*, Value*> shmems_;
  std::map<ir::value*, Value*> shoffs_;
  std::map<ir::value*, std::vector<indices_t>> idxs_;
  std::map<ir::value*, std::map<indices_t, Value*>> vals_;
  std::map<ir::value*, BasicBlock *> bbs_;
  std::map<ir::value*, std::vector<int>> ords_;

};

}

@@ -1,138 +0,0 @@
#pragma once

#ifndef _TRITON_SELECTION_MACHINE_LAYOUT_H_
#define _TRITON_SELECTION_MACHINE_LAYOUT_H_

#include <map>
#include "triton/codegen/analysis/layout.h"

namespace llvm{
  class Type;
  class Value;
  class Instruction;
  class Constant;
  class LLVMContext;
  class Module;
  class ConstantFolder;
  class IRBuilderDefaultInserter;
  template <typename T, typename Inserter>
  class IRBuilder;
  class ArrayType;
  class Function;
}

namespace triton{

namespace ir{
class value;
}

namespace codegen{

namespace analysis{
class liveness;
class tiles;
class align;
class allocation;
class cts;
class axes;
class layouts;
}

typedef llvm::IRBuilder<llvm::ConstantFolder,
                        llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;

class distributed_axis;
class machine_data_layout;
class tile;
class shared_tile;
class distributed_tile;
class target;

}
}

namespace triton{
namespace codegen{


class machine_data_layout {
public:
  virtual tile* create(ir::value *v) = 0;
};

class machine_shared_layout: public machine_data_layout {
public:
  machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr,
                        analysis::shared_layout* layout,
                        std::map<ir::value *, Value *>& vmap,
                        std::map<ir::value *, tile *>& tmap);

  tile* create(ir::value *v);

  Module *mod_;
  Builder *builder_;
  target *tgt_;
  analysis::allocation* alloc_;
  Value *&sh_mem_ptr_;
  analysis::shared_layout* layout_;
  std::map<ir::value *, Value *>& vmap_;
  std::map<ir::value *, tile *>& tmap_;

  Value *offset_;
  Value *ptr_;
  Value *pre_ptr_;
  Value *next_ptr_;

};

class machine_distributed_layout: public machine_data_layout {
public:
  machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
                             analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
                             analysis::data_layout* layout);

  tile* create(ir::value *v);
  Module *mod_;
  Builder *builder_;
  target *tgt_;
  analysis::axes *a_axes_;
  std::map<unsigned, distributed_axis>& axes_;
  analysis::data_layout* layout_;
};


class machine_mma884_layout: public machine_distributed_layout {
public:
  machine_mma884_layout(Module *mod, Builder *builder,
                        target *tgt,
                        analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
                        analysis::mma884_layout* layout);
  Value *offset_a_i_, *offset_a_k_;
  Value *offset_b_j_, *offset_b_k_;
  unsigned pack_size_0_;
  unsigned pack_size_1_;
  unsigned num_packs_0_;
  unsigned num_packs_1_;
};

class machine_scanline_layout: public machine_distributed_layout {
public:
  machine_scanline_layout(Module *mod, Builder *builder,
                          target *tgt,
                          analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
                          analysis::scanline_layout* layout);
};

}
}

#endif
@@ -1,152 +0,0 @@
#pragma once

#ifndef _TRITON_SELECTION_MACHINE_VALUE_H_
#define _TRITON_SELECTION_MACHINE_VALUE_H_

#include <vector>
#include <map>
#include <functional>

namespace llvm{
  class Type;
  class Value;
  class Instruction;
  class Constant;
  class LLVMContext;
  class Module;
  class ConstantFolder;
  class IRBuilderDefaultInserter;
  template <typename T, typename Inserter>
  class IRBuilder;
  class ArrayType;
  class Function;
}

namespace triton{
namespace codegen{
  typedef llvm::IRBuilder<llvm::ConstantFolder,
                          llvm::IRBuilderDefaultInserter> Builder;
  typedef llvm::LLVMContext LLVMContext;
  typedef llvm::Type Type;
  typedef llvm::Value Value;
  typedef llvm::Module Module;
  typedef llvm::Instruction Instruction;
  typedef llvm::Constant Constant;
  typedef llvm::ArrayType ArrayType;
  typedef llvm::Function Function;
}
}

namespace triton{
namespace codegen{

namespace analysis{
  class liveness;
  class tiles;
  class align;
  class allocation;
  class cts;
  class axes;
  class layouts;
}

class distributed_axis;
class machine_data_layout;
class tile;
class shared_tile;
class distributed_tile;
class target;
typedef std::vector<Value*> indices_t;

}
}

namespace triton{
namespace codegen{

struct distributed_axis {
  int contiguous;
  std::vector<Value*> values;
  Value* thread_id;
};

class tile {
protected:
  typedef std::vector<unsigned> shapes_t;

public:
  tile(Type *ty, const shapes_t &shapes): ty_(ty), shapes_(shapes){ }
  virtual void set_value(indices_t idx, Value *v) = 0;
  virtual Value* get_value(indices_t idx) = 0;
  Type *get_ty() const { return ty_; }
  shapes_t get_shapes() const { return shapes_; }

protected:
  Type *ty_;
  shapes_t shapes_;
};

class shared_tile: public tile {
private:
  void extract_constant(Value *arg, Value *&non_cst, Value *&cst);
  void extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx);

public:
  shared_tile(Type* ty, const shapes_t &shapes, const std::vector<int> &order, Value* ptr, Builder &builder, Value* offset = nullptr, const std::vector<int>& perm = {});
  void set_vector_size(unsigned vector_size);
  void set_return_mode(bool return_vector);
  void set_value(indices_t, Value *);
  Value* get_ptr_to(indices_t idx);
  Value* get_value(indices_t idx);
  Value* get_pointer() { return ptr_; }
  Value* get_offset() { return offset_; }
  const std::vector<int>& get_perm() { return perm_; }
  const std::vector<int>& get_order() { return order_; }
  static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector<int>& perm, const std::vector<int>& order, indices_t idx);

private:
  Value *ptr_;
  bool return_vector_;
  Builder &builder_;
  Value *offset_;
  std::map<indices_t, Value*> ptr_cache_;
  unsigned vector_size_;
  std::vector<int> order_;
  std::vector<int> perm_;
};

// Distributed tile
class distributed_tile: public tile{
  typedef std::vector<distributed_axis> axes_t;
  typedef std::vector<indices_t> ordered_indices_vec_t;
  typedef std::map<indices_t, unsigned> indices_map_t;
  typedef std::map<indices_t, Value*> values_map_t;

private:
  void init_indices();

public:
  distributed_tile(Type *ty, const shapes_t& shapes, const std::vector<int>& order, const axes_t &axes, Builder &builder);
  void set_value(indices_t idx, Value *v);
  Value* get_value(indices_t idx);
  const std::vector<int>& get_order() { return order_; }
  unsigned get_linear_index(indices_t idx);
  indices_t get_ordered_indices(unsigned id);
  void for_each(std::function<void(indices_t)> fn, int start = 0, int end = -1);
  void for_each(std::function<void(indices_t)> fn, std::vector<int> start, std::vector<int> size);

  const distributed_axis &axis(unsigned dim) { return axes_.at(dim); }

private:
  axes_t axes_;
  std::vector<int> order_;
  indices_map_t indices_;
  values_map_t values_;
  ordered_indices_vec_t ordered_indices_;
  Builder &builder_;
};

}
}

#endif
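distributed_tile::for_each is how the selection pass visits every logical index owned by the current thread, invoking a callback per index tuple. A self-contained sketch of that enumeration, using ints in place of indices_t (which is a vector of LLVM values); the odometer-style traversal order is an assumption for illustration.

#include <functional>
#include <vector>

// Hypothetical stand-in for indices_t.
using indices = std::vector<int>;

// Emulates the shape of distributed_tile::for_each: enumerate the
// index space of `shape` and hand each index tuple to `fn`.
void for_each(const std::vector<int>& shape, std::function<void(indices)> fn) {
  indices idx(shape.size(), 0);
  while (true) {
    fn(idx);
    size_t d = 0;
    for (; d < shape.size(); ++d) {    // odometer-style increment
      if (++idx[d] < shape[d]) break;
      idx[d] = 0;
    }
    if (d == shape.size()) return;     // wrapped past the last dimension
  }
}

int main() {
  int count = 0;
  for_each({2, 3}, [&](indices) { ++count; });  // visits 6 index tuples
  return count == 6 ? 0 : 1;
}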
@@ -35,6 +35,8 @@ namespace codegen{
namespace triton{
namespace codegen{

class nvidia_cu_target;

class target {
public:
  target(bool is_gpu): is_gpu_(is_gpu){}

@@ -47,6 +49,7 @@ public:
  virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
  virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
  virtual unsigned guaranteed_alignment() = 0;
  nvidia_cu_target* as_nvidia();
  bool is_gpu() const;

private:

@@ -68,7 +71,7 @@ public:
class nvidia_cu_target: public target {
public:
  nvidia_cu_target(): target(true){}
  nvidia_cu_target(int sm): target(true), sm_(sm){}
  void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
  Instruction* add_barrier(Module *module, Builder& builder);
  Instruction* add_memfence(Module *module, Builder& builder);

@@ -76,7 +79,11 @@ public:
  Value* get_local_id(Module *module, Builder& builder, unsigned ax);
  Value* get_block_id(Module *module, Builder& builder, unsigned ax);
  Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
  int sm() { return sm_; }
  unsigned guaranteed_alignment() { return 16; }

private:
  int sm_;
};

class cpu_target: public target {
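The extended target interface is what lowering passes query at code-generation time; a hedged usage sketch follows (the pass member `tgt_` and the flag variable are hypothetical, the calls themselves are the ones added in these hunks).

// Sketch: querying the new target hooks from inside a codegen pass.
bool use_tensor_cores = false;
if (nvidia_cu_target* nvidia = tgt_->as_nvidia())
  use_tensor_cores = nvidia->sm() >= 70;          // sm() reports e.g. 70 for Volta
unsigned align = tgt_->guaranteed_alignment();    // 16 bytes on the NVIDIA target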
@@ -11,14 +11,22 @@ namespace ir {
  class value;
  class phi_node;
  class instruction;
  class builder;
}

namespace codegen{
namespace transform{

class cts {
private:
  void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);

public:
  cts(bool use_async = false): use_async_(use_async) {}
  void run(ir::module &mod);

private:
  bool use_async_;
};

}
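Both the constructor and run() are visible above, so usage is direct; a short sketch, assuming an ir::module `mod` produced by earlier compilation stages.

// Sketch: running the copy-to-shared pass with async copies enabled.
triton::codegen::transform::cts cts_pass(/*use_async=*/true);
cts_pass.run(mod);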
@@ -1,12 +1,18 @@
#ifndef TDL_INCLUDE_CODEGEN_BARRIERS_H
#define TDL_INCLUDE_CODEGEN_BARRIERS_H

#include <vector>
#include <map>
#include <list>
#include <set>

namespace triton {

namespace ir {
  class module;
  class basic_block;
  class instruction;
  class masked_load_async_inst;
  class value;
  class builder;
}

@@ -27,18 +33,15 @@ namespace transform{
class membar {
private:
  typedef std::pair<unsigned, unsigned> interval_t;
  typedef std::vector<interval_t> interval_vec_t;
  typedef std::set<ir::value*> val_set_t;
  typedef std::vector<ir::value*> val_vec_t;

private:
  interval_vec_t join(const std::vector<interval_vec_t>& intervals);
  void insert_barrier(ir::instruction *instr, ir::builder &builder);
  bool intersect(const interval_vec_t &X, interval_t x);
  bool intersect(const interval_vec_t &X, const interval_vec_t &Y);
  void add_reference(ir::value *v, interval_vec_t &res);
  void get_read_intervals(ir::instruction *i, interval_vec_t &res);
  void get_written_intervals(ir::instruction *i, interval_vec_t &res);
  std::pair<interval_vec_t, interval_vec_t> transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from,
                                                     std::set<ir::instruction *> &insert_loc, std::set<triton::ir::value *> &safe_war);
  bool intersect(const val_set_t &X, const val_set_t &Y);
  int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
  val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
  void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
                std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);

public:
  membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc):
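membar's analysis largely reduces to asking whether read and written byte-intervals of shared memory overlap. A self-contained sketch of the intersect overloads declared above; the half-open [begin, end) convention is an assumption for illustration, not taken from the source.

#include <utility>
#include <vector>

using interval_t = std::pair<unsigned, unsigned>;   // assumed [begin, end)
using interval_vec_t = std::vector<interval_t>;

// Two intervals overlap iff each one starts before the other ends.
static bool intersect(const interval_t& a, const interval_t& b) {
  return a.first < b.second && b.first < a.second;
}

// Mirrors `bool intersect(const interval_vec_t&, interval_t)` above.
static bool intersect(const interval_vec_t& X, interval_t x) {
  for (const interval_t& i : X)
    if (intersect(i, x)) return true;
  return false;
}

int main() {
  interval_vec_t written = {{0, 128}, {256, 512}};
  // A read of bytes [100, 200) overlaps the first written interval,
  // so a barrier would have to be inserted between the two accesses.
  return intersect(written, {100, 200}) ? 0 : 1;
}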
@@ -1,6 +1,7 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H

#include "triton/codegen/target.h"

namespace triton {

@@ -15,6 +16,10 @@ namespace ir {
}

namespace codegen{
namespace analysis{
  class layouts;
}

namespace transform{

class peephole {

@@ -27,12 +32,18 @@ private:
  bool rewrite_mult(ir::instruction *value, ir::builder& builder);
  bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
  bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
  bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
  bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);

private:

public:
  peephole() {}
  peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
  void run(ir::module &mod);

private:
  target* tgt_;
  analysis::layouts* layouts_;
};
28
include/triton/codegen/transform/pipeline.h
Normal file
@@ -0,0 +1,28 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H

// forward declaration
namespace triton {
namespace ir {
  class module;
}
} // namespace triton

namespace triton {
namespace codegen {
namespace transform {

class pipeline {
public:
  pipeline(bool has_copy_async): has_copy_async_(has_copy_async) {}
  void run(ir::module &module);

private:
  bool has_copy_async_;
};

} // namespace transform
} // namespace codegen
} // namespace triton

#endif
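pipeline rounds out the transform passes in this patch; a hedged sketch of chaining them over a module follows. The pass order shown is illustrative, not necessarily the order used by Triton's compile path, and `mod`, `tgt`, `layouts` are assumed to come from the surrounding compilation context.

// Sketch: running the new transform passes in sequence (order hypothetical).
triton::codegen::transform::pipeline pipeline(/*has_copy_async=*/true);
triton::codegen::transform::peephole peephole(tgt, layouts);
pipeline.run(mod);   // the software-pipelining transform
peephole.run(mod);   // local rewrites afterwards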
26
include/triton/codegen/transform/reorder.h
Normal file
@@ -0,0 +1,26 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_REORDER_H
#define TRITON_INCLUDE_IR_CODEGEN_REORDER_H

namespace triton {

// forward declaration
namespace ir {
  class module;
}

namespace codegen{

namespace transform{

class reorder {
public:
  void run(ir::module& module);
};

}

}

}

#endif
@@ -14,17 +14,15 @@ namespace driver
class stream;

// Base
class buffer : public polymorphic_resource<CUdeviceptr, cl_mem, host_buffer_t> {
class buffer : public polymorphic_resource<CUdeviceptr, host_buffer_t> {
public:
  buffer(driver::context* ctx, size_t size, CUdeviceptr cl, bool take_ownership);
  buffer(driver::context* ctx, size_t size, cl_mem cl, bool take_ownership);
  buffer(driver::context* ctx, size_t size, host_buffer_t hst, bool take_ownership);
  buffer(size_t size, CUdeviceptr cl, bool take_ownership);
  buffer(size_t size, host_buffer_t hst, bool take_ownership);
  uintptr_t addr_as_uintptr_t();
  static buffer* create(driver::context* ctx, size_t size);
  driver::context* context();
  size_t size();

protected:
  driver::context* context_;
  size_t size_;
};

@@ -32,22 +30,15 @@ protected:
class host_buffer: public buffer
{
public:
  host_buffer(driver::context* context, size_t size);
};

// OpenCL
class ocl_buffer: public buffer
{
public:
  ocl_buffer(driver::context* context, size_t size);
  host_buffer(size_t size);
};

// CUDA
class cu_buffer: public buffer
{
public:
  cu_buffer(driver::context* context, size_t size);
  cu_buffer(driver::context* context, size_t size, CUdeviceptr cu, bool take_ownership);
  cu_buffer(size_t size);
  cu_buffer(size_t size, CUdeviceptr cu, bool take_ownership);
  void set_zero(triton::driver::stream *queue, size_t size);
};
@@ -11,13 +11,12 @@ namespace triton
namespace driver
{

class context: public polymorphic_resource<CUcontext, cl_context, host_context_t>{
class context: public polymorphic_resource<CUcontext, host_context_t>{
protected:
  static std::string get_cache_path();

public:
  context(driver::device *dev, CUcontext cu, bool take_ownership);
  context(driver::device *dev, cl_context cl, bool take_ownership);
  context(driver::device *dev, host_context_t hst, bool take_ownership);
  driver::device* device() const;
  std::string const & cache_path() const;

@@ -37,33 +36,14 @@ public:

// CUDA
class cu_context: public context {
public:
  class context_switcher{
  public:
    context_switcher(driver::context const & ctx);
    ~context_switcher();
  private:
    driver::cu_context const & ctx_;
  };

private:
  static CUdevice get_device_of(CUcontext);

public:
  //Constructors
  cu_context(CUcontext cu, bool take_ownership = true);
  cu_context(driver::device* dev);
};

// OpenCL
class ocl_context: public context {
public:
  ocl_context(driver::device* dev);
};

}
}
@@ -1,229 +0,0 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

#ifndef TDL_INCLUDE_DRIVER_CUBLAS_H
#define TDL_INCLUDE_DRIVER_CUBLAS_H

#include "isaac/templates/common.hpp"
#include "triton/driver/dispatch.h"
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/driver/backend.h"
#include "triton/driver/error.h"
#include "triton/tools/bench.hpp"
#include "triton/tools/collections.hpp"

namespace triton
{
namespace driver
{

enum cublasStrategy_t{
  CUBLAS_PREFER_FASTEST,
  CUBLAS_HEURISTICS
};

static const std::vector<cublasGemmAlgo_t> cublasAlgorithms = {
  CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1, CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3,
  CUBLAS_GEMM_ALGO4, CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7
};

static const std::map<DType, cudaDataType> cudtype = {{FLOAT_TYPE, CUDA_R_32F}, {DOUBLE_TYPE,CUDA_R_64F}};
static const std::map<char, cublasOperation_t> cuop = {{'N', CUBLAS_OP_N}, {'T', CUBLAS_OP_T}};

inline cublasGemmAlgo_t cublasGemmFastest(stream& stream, cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K,
                                          void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
                                          void* beta, CUdeviceptr C, int32_t ldc){

  typedef std::tuple<cudaDataType_t, cublasOperation_t, cublasOperation_t, int32_t, int32_t, int32_t> key_t;
  // Benchmark fastest algorithm in cublasGemmEx
  auto benchmark_fastest = [&](key_t const &){
    std::vector<double> times;
    for(cublasGemmAlgo_t a: cublasAlgorithms){
      try{
        times.push_back(bench([&](){ dispatch::cublasGemmEx(handle, AT, BT, M, N, K, alpha, (const void*)A, cudt, lda, (const void*)B, cudt, ldb, beta, (void*)C, cudt, ldc, cudt, a); },
                              [&](){ stream.synchronize(); },
                              stream.context().device()));
      }catch(driver::exception::cublas::base const &){
        times.push_back(INFINITY);
      }
    }
    size_t argmin = std::min_element(times.begin(), times.end()) - times.begin();
    return cublasAlgorithms[argmin];
  };
  // Cache result
  static cpp::CachedMap<key_t, cublasGemmAlgo_t> cache(benchmark_fastest);
  return cache.get(std::make_tuple(cudt, AT, BT, M, N, K));
}

/* Wrapper for cublasGemmEx */
inline void cublasGemmEx(cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K,
                         void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
                         void* beta, CUdeviceptr C, int32_t ldc, cublasGemmAlgo_t algo)
{ dispatch::cublasGemmEx(handle, AT, BT, M, N, K, alpha, (const void*)A, cudt, lda, (const void*)B, cudt, ldb, beta, (void*)C, cudt, ldc, cudt, algo); }

/* Simplified API for default GEMM */
inline void cublasGemm(DType dtype, stream& stream, char cAT, char cBT, int32_t M, int32_t N, int32_t K, scalar alpha, cu_buffer const & A, int32_t lda, cu_buffer const & B, int32_t ldb, scalar beta, cu_buffer& C, int32_t ldc, cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT){
  ContextSwitcher ctx_switch(stream.context());
  cublasHandle_t handle = dispatch::cublasHandle(stream.context());
  dispatch::cublasSetStream_v2(handle, (CUstream)stream);
  if(fastest)
    *fastest = cublasGemmFastest(stream, handle, cudtype.at(dtype), cuop.at(cAT), cuop.at(cBT), M, N, K, alpha.data(), A, lda, B, ldb, beta.data(), C, ldc);
  else
    cublasGemmEx(handle, cudtype.at(dtype), cuop.at(cAT), cuop.at(cBT), M, N, K, alpha.data(), A, lda, B, ldb, beta.data(), C, ldc, algo);
}

inline cudnnDataType_t cudnnDtype(DType dtype){
  switch(dtype){
    case INT8X4_TYPE: return CUDNN_DATA_INT8x4;
    case INT32_TYPE: return CUDNN_DATA_INT32;
    case FLOAT_TYPE: return CUDNN_DATA_FLOAT;
    case DOUBLE_TYPE: return CUDNN_DATA_DOUBLE;
  }
  throw;
}

inline cudnnTensorFormat_t format(cudnnDataType_t cutype){
  switch(cutype){
    case CUDNN_DATA_INT8x4: return CUDNN_TENSOR_NCHW_VECT_C;
    default: return CUDNN_TENSOR_NCHW;
  }
}

inline void cudnnConv(DType dtype, stream& stream, int32_t D, int32_t H, int32_t W, int32_t N, int32_t K, int32_t M, int32_t P, int32_t Q, int32_t C, int32_t T, int32_t R, int32_t S,
                      int32_t pad_d, int32_t pad_h, int32_t pad_w, int32_t stride_d, int32_t stride_h, int32_t stride_w, scalar alpha, cu_buffer const & I, cu_buffer const & F, scalar beta, cu_buffer const & O){
  driver::driver::context const & ctx = stream.context();
  ContextSwitcher switch_ctx(ctx);

  std::vector<int> pad = {pad_d, pad_h, pad_w};
  std::vector<int> stride = {stride_d, stride_h, stride_w};
  std::vector<int> upscale = {1, 1, 1};
  std::vector<int> Oshapes = {N, K, M, P, Q};
  std::vector<int> Fshapes = {K, C, T, R, S};
  std::vector<int> Ishapes = {N, C, D, H, W};
  if(M == 1 && T == 1 && D == 1){
    pad.erase(pad.begin());
    stride.erase(stride.begin());
    upscale.erase(upscale.begin());
    Oshapes.erase(Oshapes.begin() + 2);
    Ishapes.erase(Ishapes.begin() + 2);
    Fshapes.erase(Fshapes.begin() + 2);
  }

  cudnnHandle_t handle = dispatch::cudnnHandle(ctx);
  cudnnDataType_t in_cutype = cudnnDtype(dtype);
  cudnnDataType_t conv_cutype = (dtype == INT8X4_TYPE)?CUDNN_DATA_INT32:in_cutype;

  dispatch::cudnnSetStream(handle, (CUstream)stream);
  cudnnTensorDescriptor_t tO, tI;
  cudnnFilterDescriptor_t tF;
  cudnnConvolutionDescriptor_t conv;
  cudnnConvolutionFwdAlgo_t algo;
  dispatch::cudnnCreateTensorDescriptor(&tO);
  dispatch::cudnnCreateTensorDescriptor(&tI);
  dispatch::cudnnCreateFilterDescriptor(&tF);

  dispatch::cudnnSetTensorNdDescriptorEx(tO, format(in_cutype), in_cutype, Oshapes.size(), Oshapes.data());
  dispatch::cudnnSetFilterNdDescriptor(tF, in_cutype, format(in_cutype), Fshapes.size(), Fshapes.data());
  dispatch::cudnnSetTensorNdDescriptorEx(tI, format(in_cutype), in_cutype, Ishapes.size(), Ishapes.data());

  dispatch::cudnnCreateConvolutionDescriptor(&conv);
  dispatch::cudnnSetConvolutionNdDescriptor(conv, pad.size(), pad.data(), stride.data(), upscale.data(), CUDNN_CROSS_CORRELATION, conv_cutype);
  dispatch::cudnnGetConvolutionForwardAlgorithm(handle, tI, tF, conv, tO, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, 1024*1024*64, &algo);

  size_t workspace_size;
  dispatch::cudnnGetConvolutionForwardWorkspaceSize(handle, tI, tF, conv, tO, algo, &workspace_size);
  static cu_buffer work(ctx, 1024*1024*64);
  CUdeviceptr twork = work;
  CUdeviceptr pI = I, pF = F, pO = O;
  dispatch::cudnnConvolutionForward(handle, alpha.data(), tI, (void*)pI, tF, (void*)pF, conv, algo, (void*)twork, workspace_size, beta.data(), tO, (void*)pO);
}

inline void cudnnPool(DType dtype, stream& stream, int32_t D, int32_t H, int32_t W, int32_t N, int32_t K, int32_t M, int32_t P, int32_t Q, int32_t T, int32_t R, int32_t S,
                      int32_t pad_d, int32_t pad_h, int32_t pad_w, int32_t stride_d, int32_t stride_h, int32_t stride_w, scalar alpha, cu_buffer const & I, scalar beta, cu_buffer const & O){
  driver::driver::context const & ctx = stream.context();
  ContextSwitcher switch_ctx(ctx);

  std::vector<int> pad = {pad_d, pad_h, pad_w};
  std::vector<int> stride = {stride_d, stride_h, stride_w};
  std::vector<int> upscale = {1, 1, 1};
  std::vector<int> Oshapes = {N, K, M, P, Q};
  std::vector<int> Ishapes = {N, K, D, H, W};
  std::vector<int> window = {T, R, S};
  if(M == 1 && T == 1 && D == 1){
    window.erase(window.begin());
    pad.erase(pad.begin());
    stride.erase(stride.begin());
    upscale.erase(upscale.begin());
    Oshapes.erase(Oshapes.begin() + 2);
    Ishapes.erase(Ishapes.begin() + 2);
  }

  cudnnHandle_t handle = dispatch::cudnnHandle(ctx);
  cudnnDataType_t cutype = cudnnDtype(dtype);

  dispatch::cudnnSetStream(handle, (CUstream)stream);
  cudnnTensorDescriptor_t tO, tI;
  cudnnPoolingDescriptor_t desc;
  dispatch::cudnnCreateTensorDescriptor(&tO);
  dispatch::cudnnCreateTensorDescriptor(&tI);

  dispatch::cudnnSetTensorNdDescriptorEx(tO, CUDNN_TENSOR_NCHW, cutype, Oshapes.size(), Oshapes.data());
  dispatch::cudnnSetTensorNdDescriptorEx(tI, CUDNN_TENSOR_NCHW, cutype, Ishapes.size(), Ishapes.data());

  dispatch::cudnnCreatePoolingDescriptor(&desc);
  dispatch::cudnnSetPoolingNdDescriptor(desc, CUDNN_POOLING_MAX, CUDNN_NOT_PROPAGATE_NAN, window.size(), window.data(), pad.data(), stride.data());

  CUdeviceptr pI = I, pO = O;
  dispatch::cudnnPoolingForward(handle, desc, alpha.data(), tI, (void*)pI, beta.data(), tO, (void*)pO);
}

inline void cudnnTransformTensor(driver::cu_stream & stream,
                                 DType in_dtype, DType out_dtype,
                                 cudnnTensorFormat_t in_layout, cudnnTensorFormat_t out_layout,
                                 int32_t N, int32_t C, int32_t D, int32_t H, int32_t W,
                                 scalar alpha, driver::cu_buffer const & I, scalar beta, driver::cu_buffer& O)
{
  cudnnHandle_t handle = dispatch::cudnnHandle(stream.context());
  dispatch::cudnnSetStream(handle, (CUstream)stream);

  cudnnTensorDescriptor_t tO, tI;
  std::vector<int> shapes = {N, C, D, H, W};
  dispatch::cudnnCreateTensorDescriptor(&tI);
  dispatch::cudnnSetTensorNdDescriptorEx(tI, in_layout, cudnnDtype(in_dtype), shapes.size(), shapes.data());
  dispatch::cudnnCreateTensorDescriptor(&tO);
  dispatch::cudnnSetTensorNdDescriptorEx(tO, out_layout, cudnnDtype(out_dtype), shapes.size(), shapes.data());

  CUdeviceptr pI = I, pO = O;
  dispatch::cudnnTransformTensor(handle, alpha.data(), tI, (void*)pI, beta.data(), tO, (void*)pO);
}

}
}

#endif
@@ -20,7 +20,7 @@ namespace driver
class context;

// Base device
class device: public polymorphic_resource<CUdevice, cl_device_id, host_device_t>{
class device: public polymorphic_resource<CUdevice, host_device_t>{
public:
  using polymorphic_resource::polymorphic_resource;
  virtual size_t max_threads_per_block() const = 0;

@@ -37,54 +37,25 @@ public:
  std::unique_ptr<codegen::target> make_target() const;
};

// OpenCL device
class ocl_device: public device {
public:
  ocl_device(cl_device_id cl, bool take_ownership = true): device(cl, take_ownership) { }
  size_t max_threads_per_block() const;
  size_t max_shared_memory() const;
  std::unique_ptr<codegen::target> make_target() const;
};

// CUDA device
class cu_device: public device {
public:
  //Supported architectures
  enum class Architecture{
    //NVidia
    SM_2_0,
    SM_2_1,
    SM_3_0,
    SM_3_5,
    SM_3_7,
    SM_5_0,
    SM_5_2,
    SM_6_0,
    SM_6_1,
    SM_7_0,
    UNKNOWN
  };

private:
  //Metaprogramming helper to get cuda info from attribute
  template<CUdevice_attribute attr>
  int cuGetInfo() const;

  inline Architecture nv_arch(std::pair<unsigned int, unsigned int> sm) const;
  inline nvmlDevice_t nvml_device() const;

public:
  cu_device(CUdevice cu = CUdevice(), bool take_ownership = true): device(cu, take_ownership){}
  // Accessors
  Architecture architecture() const;
  // Information
  std::string infos() const;
  size_t address_bits() const;
  std::vector<size_t> max_block_dim() const;
  size_t warp_size() const;
  // Compute Capability
  void interpret_as(std::pair<size_t, size_t> cc);
  std::pair<size_t, size_t> compute_capability() const;
  void interpret_as(int cc);
  int compute_capability() const;
  // Identifier
  std::string name() const;
  std::string pci_bus_id() const;

@@ -100,7 +71,7 @@ public:
  std::unique_ptr<codegen::target> make_target() const;

private:
  std::shared_ptr<std::pair<size_t, size_t>> interpreted_as_;
  std::shared_ptr<int> interpreted_as_;
};

}
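cuGetInfo is a metaprogramming helper that turns one CUDA device attribute into an int. A hedged sketch of its likely shape; the dispatch::cuDeviceGetAttribute wrapper is assumed here and does not appear in this hunk, and the dereference of cu() to the underlying CUdevice is likewise an assumption.

// Sketch (hypothetical implementation):
template<CUdevice_attribute attr>
int cu_device::cuGetInfo() const {
  int res;
  dispatch::cuDeviceGetAttribute(&res, attr, *cu());  // *cu() assumed to yield the CUdevice
  return res;
}
// e.g. warp_size() could then be cuGetInfo<CU_DEVICE_ATTRIBUTE_WARP_SIZE>().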
@@ -9,8 +9,6 @@
//CUDA Backend
#include "triton/external/CUDA/cuda.h"
#include "triton/external/CUDA/nvml.h"
#include "triton/external/CL/cl.h"
#include "triton/external/CL/cl_ext.h"

//Exceptions
#include <iostream>

@@ -30,7 +28,6 @@ class cu_context;

template<class T> void check(T){}
void check(CUresult err);
void check(cl_int err);

class dispatch
{

@@ -61,48 +58,11 @@ protected:
}

public:
  static bool clinit();
  static bool nvmlinit();
  static bool cuinit();
  static bool spvllvminit();
  static void release();

  // OpenCL
  static cl_int clBuildProgram(cl_program, cl_uint, const cl_device_id *, const char *, void (*)(cl_program, void *), void *);
  static cl_int clEnqueueNDRangeKernel(cl_command_queue, cl_kernel, cl_uint, const size_t *, const size_t *, const size_t *, cl_uint, const cl_event *, cl_event *);
  static cl_int clSetKernelArg(cl_kernel, cl_uint, size_t, const void *);
  static cl_int clReleaseMemObject(cl_mem);
  static cl_int clFinish(cl_command_queue);
  static cl_int clGetMemObjectInfo(cl_mem, cl_mem_info, size_t, void *, size_t *);
  static cl_int clGetCommandQueueInfo(cl_command_queue, cl_command_queue_info, size_t, void *, size_t *);
  static cl_int clReleaseContext(cl_context);
  static cl_int clReleaseEvent(cl_event);
  static cl_int clEnqueueWriteBuffer(cl_command_queue, cl_mem, cl_bool, size_t, size_t, const void *, cl_uint, const cl_event *, cl_event *);
  static cl_int clEnqueueReadBuffer(cl_command_queue, cl_mem, cl_bool, size_t, size_t, void *, cl_uint, const cl_event *, cl_event *);
  static cl_int clGetProgramBuildInfo(cl_program, cl_device_id, cl_program_build_info, size_t, void *, size_t *);
  static cl_int clReleaseDevice(cl_device_id);
  static cl_context clCreateContext(const cl_context_properties *, cl_uint, const cl_device_id *, void (*)(const char *, const void *, size_t, void *), void *, cl_int *);
  static cl_int clGetDeviceIDs(cl_platform_id, cl_device_type, cl_uint, cl_device_id *, cl_uint *);
  static cl_int clGetContextInfo(cl_context, cl_context_info, size_t, void *, size_t *);
  static cl_int clGetDeviceInfo(cl_device_id, cl_device_info, size_t, void *, size_t *);
  static cl_int clReleaseCommandQueue(cl_command_queue);
  static cl_int clGetPlatformIDs(cl_uint, cl_platform_id *, cl_uint *);
  static cl_int clGetPlatformInfo(cl_platform_id, cl_platform_info, size_t, void *, size_t *);
  static cl_int clGetEventProfilingInfo(cl_event, cl_profiling_info, size_t, void *, size_t *);
  static cl_program clCreateProgramWithBinary(cl_context, cl_uint, const cl_device_id *, const size_t *, const unsigned char **, cl_int *, cl_int *);
  static cl_command_queue clCreateCommandQueue(cl_context, cl_device_id, cl_command_queue_properties, cl_int *);
  static cl_int clRetainEvent(cl_event);
  static cl_int clReleaseProgram(cl_program);
  static cl_int clFlush(cl_command_queue);
  static cl_int clGetProgramInfo(cl_program, cl_program_info, size_t, void *, size_t *);
  static cl_int clGetKernelInfo(cl_kernel, cl_kernel_info, size_t, void *, size_t *);
  static cl_int clGetKernelWorkGroupInfo(cl_kernel, cl_device_id, cl_kernel_work_group_info, size_t, void *, size_t *);
  static cl_kernel clCreateKernel(cl_program, const char *, cl_int *);
  static cl_int clCreateKernelsInProgram(cl_program, cl_uint, cl_kernel*, cl_uint*);
  static cl_mem clCreateBuffer(cl_context, cl_mem_flags, size_t, void *, cl_int *);
  static cl_program clCreateProgramWithSource(cl_context, cl_uint, const char **, const size_t *, cl_int *);
  static cl_int clReleaseKernel(cl_kernel);

  // CUDA
  static CUresult cuCtxGetCurrent(CUcontext *pctx);
  static CUresult cuCtxSetCurrent(CUcontext ctx);

@@ -133,6 +93,7 @@ public:
  static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
  static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
  static CUresult cuStreamSynchronize(CUstream hStream);
  static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
  static CUresult cuStreamDestroy_v2(CUstream hStream);
  static CUresult cuEventDestroy_v2(CUevent hEvent);
  static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);

@@ -157,7 +118,6 @@ public:
private:

  // Libraries
  static void* opencl_;
  static void* cuda_;
  static void* nvml_;
  static void* vulkan_;

@@ -165,41 +125,6 @@ private:
  static void* spvcross_;
  static void* opengl_;

  // OpenCL functions
  static void* clBuildProgram_;
  static void* clEnqueueNDRangeKernel_;
  static void* clSetKernelArg_;
  static void* clReleaseMemObject_;
  static void* clFinish_;
  static void* clGetMemObjectInfo_;
  static void* clGetCommandQueueInfo_;
  static void* clReleaseContext_;
  static void* clReleaseEvent_;
  static void* clEnqueueWriteBuffer_;
  static void* clEnqueueReadBuffer_;
  static void* clGetProgramBuildInfo_;
  static void* clReleaseDevice_;
  static void* clCreateContext_;
  static void* clGetDeviceIDs_;
  static void* clGetContextInfo_;
  static void* clGetDeviceInfo_;
  static void* clReleaseCommandQueue_;
  static void* clGetPlatformIDs_;
  static void* clGetPlatformInfo_;
  static void* clGetEventProfilingInfo_;
  static void* clCreateProgramWithBinary_;
  static void* clCreateCommandQueue_;
  static void* clRetainEvent_;
  static void* clReleaseProgram_;
  static void* clFlush_;
  static void* clGetProgramInfo_;
  static void* clGetKernelInfo_;
  static void* clGetKernelWorkGroupInfo_;
  static void* clCreateKernel_;
  static void* clCreateKernelsInProgram_;
  static void* clCreateBuffer_;
  static void* clCreateProgramWithSource_;
  static void* clReleaseKernel_;

  // CUDA functions
  static void* cuCtxGetCurrent_;

@@ -230,6 +155,7 @@ private:
  static void* cuModuleGetFunction_;
  static void* cuStreamSynchronize_;
  static void* cuStreamDestroy_v2_;
  static void* cuStreamGetCtx_;
  static void* cuEventDestroy_v2_;
  static void* cuMemAlloc_v2_;
  static void* cuPointerGetAttribute_;
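The paired public wrappers and private void* slots above implement lazy dynamic binding: each driver entry point is resolved once with dlsym and then cached. A self-contained sketch of that mechanism against libm; the library and symbol names are illustrative, and Triton's actual resolution logic is not shown in these hunks.

#include <dlfcn.h>
#include <cstdio>

// One library handle and one cached symbol slot, as in the dispatch class.
static void* libm_handle_ = nullptr;
static void* cos_fn_     = nullptr;

// Public wrapper: resolve the symbol on first use, then call through the cache.
double my_cos(double x) {
  if (!cos_fn_) {
    if (!libm_handle_) libm_handle_ = dlopen("libm.so.6", RTLD_LAZY);
    cos_fn_ = dlsym(libm_handle_, "cos");
  }
  return ((double (*)(double))cos_fn_)(x);
}

int main() { std::printf("%f\n", my_cos(0.0)); }  // prints 1.000000 (link with -ldl)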
@@ -19,18 +19,18 @@ namespace triton
namespace nvrtc
{

#define ISAAC_CREATE_NVRTC_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "NVRTC: Error- " msg; } }
#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "NVRTC: Error- " msg; } }

ISAAC_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory");
ISAAC_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure");
ISAAC_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input");
ISAAC_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program");
ISAAC_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option");
ISAAC_CREATE_NVRTC_EXCEPTION(compilation ,"compilation");
ISAAC_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure");
ISAAC_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error");
TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option");
TRITON_CREATE_NVRTC_EXCEPTION(compilation ,"compilation");
TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure");
TRITON_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error");

#undef ISAAC_CREATE_NVRTC_EXCEPTION
#undef TRITON_CREATE_NVRTC_EXCEPTION
}
@@ -38,169 +38,109 @@ namespace triton
{
class base: public std::exception{};

#define ISAAC_CREATE_CUDA_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "CUDA: Error- " msg; } }
#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "CUDA: Error- " msg; } }

ISAAC_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value");
ISAAC_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory");
ISAAC_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized");
ISAAC_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized");
ISAAC_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled");
ISAAC_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
ISAAC_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started");
ISAAC_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
ISAAC_CREATE_CUDA_EXCEPTION(no_device ,"no device");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context");
ISAAC_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current");
ISAAC_CREATE_CUDA_EXCEPTION(map_failed ,"map failed");
ISAAC_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed");
ISAAC_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped");
ISAAC_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped");
ISAAC_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
ISAAC_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired");
ISAAC_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped");
ISAAC_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array");
ISAAC_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
ISAAC_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
ISAAC_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit");
ISAAC_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use");
ISAAC_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source");
ISAAC_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found");
ISAAC_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
ISAAC_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed");
ISAAC_CREATE_CUDA_EXCEPTION(operating_system ,"operating system");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle");
ISAAC_CREATE_CUDA_EXCEPTION(not_found ,"not found");
ISAAC_CREATE_CUDA_EXCEPTION(not_ready ,"not ready");
ISAAC_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address");
ISAAC_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources");
ISAAC_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout");
ISAAC_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
ISAAC_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
ISAAC_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
ISAAC_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active");
ISAAC_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed");
ISAAC_CREATE_CUDA_EXCEPTION(assert_error ,"assert");
ISAAC_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers");
ISAAC_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered");
ISAAC_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"host memory not registered");
ISAAC_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error");
ISAAC_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction");
ISAAC_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space");
ISAAC_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc");
ISAAC_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed");
ISAAC_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted");
ISAAC_CREATE_CUDA_EXCEPTION(not_supported ,"not supported");
ISAAC_CREATE_CUDA_EXCEPTION(unknown ,"unknown");
TRITON_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized");
TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled");
TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started");
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
TRITON_CREATE_CUDA_EXCEPTION(no_device ,"no device");
TRITON_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device");
TRITON_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image");
TRITON_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context");
TRITON_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current");
TRITON_CREATE_CUDA_EXCEPTION(map_failed ,"map failed");
TRITON_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed");
TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped");
TRITON_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped");
TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
TRITON_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit");
TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx");
TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
TRITON_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source");
TRITON_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found");
TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed");
TRITON_CREATE_CUDA_EXCEPTION(operating_system ,"operating system");
TRITON_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle");
TRITON_CREATE_CUDA_EXCEPTION(not_found ,"not found");
TRITON_CREATE_CUDA_EXCEPTION(not_ready ,"not ready");
TRITON_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address");
TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources");
TRITON_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout");
TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
TRITON_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active");
TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed");
TRITON_CREATE_CUDA_EXCEPTION(assert_error ,"assert");
TRITON_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers");
TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered");
TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"host memory not registered");
TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error");
TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction");
TRITON_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address");
TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space");
TRITON_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc");
TRITON_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed");
TRITON_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted");
TRITON_CREATE_CUDA_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUDA_EXCEPTION(unknown ,"unknown");

#undef ISAAC_CREATE_CUDA_EXCEPTION
#undef TRITON_CREATE_CUDA_EXCEPTION
}

namespace cublas
{
class base: public std::exception{};

#define ISAAC_CREATE_CUBLAS_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "CUBLAS: Error- " msg; } }
#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "CUBLAS: Error- " msg; } }

ISAAC_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized");
ISAAC_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed");
ISAAC_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value");
ISAAC_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch");
ISAAC_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error");
ISAAC_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed");
ISAAC_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error");
ISAAC_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported");
ISAAC_CREATE_CUBLAS_EXCEPTION(license_error ,"license error");
ISAAC_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown");
TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed");
TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch");
TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error");
TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed");
TRITON_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error");
TRITON_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUBLAS_EXCEPTION(license_error ,"license error");
TRITON_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown");

#undef ISAAC_CREATE_CUBLAS_EXCEPTION
#undef TRITON_CREATE_CUBLAS_EXCEPTION
}

namespace cudnn
{
#define ISAAC_CREATE_CUDNN_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "CUDNN: Error- " msg; } }
#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "CUDNN: Error- " msg; } }

ISAAC_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized");
ISAAC_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed");
ISAAC_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param");
ISAAC_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error");
ISAAC_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value");
ISAAC_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch");
ISAAC_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error");
ISAAC_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed");
ISAAC_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported");
ISAAC_CREATE_CUDNN_EXCEPTION(license_error ,"license error");
ISAAC_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing");
ISAAC_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress");
ISAAC_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
TRITON_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed");
TRITON_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param");
TRITON_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error");
TRITON_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch");
TRITON_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error");
TRITON_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed");
TRITON_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUDNN_EXCEPTION(license_error ,"license error");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
}

namespace ocl
{

class base: public std::exception{};

#define ISAAC_CREATE_CL_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "OpenCL: Error- " msg; } }

ISAAC_CREATE_CL_EXCEPTION(device_not_found, "device not found");
ISAAC_CREATE_CL_EXCEPTION(device_not_available, "device not available");
ISAAC_CREATE_CL_EXCEPTION(compiler_not_available, "compiler not available");
ISAAC_CREATE_CL_EXCEPTION(mem_object_allocation_failure, "object allocation failure");
ISAAC_CREATE_CL_EXCEPTION(out_of_resources, "launch out of resources");
ISAAC_CREATE_CL_EXCEPTION(out_of_host_memory, "out of host memory");
ISAAC_CREATE_CL_EXCEPTION(profiling_info_not_available, "profiling info not available");
ISAAC_CREATE_CL_EXCEPTION(mem_copy_overlap, "mem copy overlap");
ISAAC_CREATE_CL_EXCEPTION(image_format_mismatch, "image format mismatch");
ISAAC_CREATE_CL_EXCEPTION(image_format_not_supported, "image format not supported");
ISAAC_CREATE_CL_EXCEPTION(build_program_failure, "build program failure");
ISAAC_CREATE_CL_EXCEPTION(map_failure, "map failure");
ISAAC_CREATE_CL_EXCEPTION(invalid_value, "invalid value");
ISAAC_CREATE_CL_EXCEPTION(invalid_device_type, "invalid device type");
ISAAC_CREATE_CL_EXCEPTION(invalid_platform, "invalid platform");
ISAAC_CREATE_CL_EXCEPTION(invalid_device, "invalid device");
ISAAC_CREATE_CL_EXCEPTION(invalid_context, "invalid context");
ISAAC_CREATE_CL_EXCEPTION(invalid_queue_properties, "invalid queue properties");
ISAAC_CREATE_CL_EXCEPTION(invalid_command_queue, "invalid command queue");
ISAAC_CREATE_CL_EXCEPTION(invalid_host_ptr, "invalid host pointer");
ISAAC_CREATE_CL_EXCEPTION(invalid_mem_object, "invalid mem object");
ISAAC_CREATE_CL_EXCEPTION(invalid_image_format_descriptor, "invalid image format descriptor");
ISAAC_CREATE_CL_EXCEPTION(invalid_image_size, "invalid image size");
ISAAC_CREATE_CL_EXCEPTION(invalid_sampler, "invalid sampler");
ISAAC_CREATE_CL_EXCEPTION(invalid_binary, "invalid binary");
ISAAC_CREATE_CL_EXCEPTION(invalid_build_options, "invalid build options");
ISAAC_CREATE_CL_EXCEPTION(invalid_program, "invalid program");
ISAAC_CREATE_CL_EXCEPTION(invalid_program_executable, "invalid program executable");
ISAAC_CREATE_CL_EXCEPTION(invalid_kernel_name, "invalid kernel name");
ISAAC_CREATE_CL_EXCEPTION(invalid_kernel_definition, "invalid kernel definition");
ISAAC_CREATE_CL_EXCEPTION(invalid_kernel, "invalid kernel");
ISAAC_CREATE_CL_EXCEPTION(invalid_arg_index, "invalid arg index");
ISAAC_CREATE_CL_EXCEPTION(invalid_arg_value, "invalid arg value");
ISAAC_CREATE_CL_EXCEPTION(invalid_arg_size, "invalid arg size");
ISAAC_CREATE_CL_EXCEPTION(invalid_kernel_args, "invalid kernel args");
ISAAC_CREATE_CL_EXCEPTION(invalid_work_dimension, "invalid work dimension");
ISAAC_CREATE_CL_EXCEPTION(invalid_work_group_size, "invalid work group size");
ISAAC_CREATE_CL_EXCEPTION(invalid_work_item_size, "invalid work item size");
ISAAC_CREATE_CL_EXCEPTION(invalid_global_offset, "invalid global offset");
ISAAC_CREATE_CL_EXCEPTION(invalid_event_wait_list, "invalid event wait list");
ISAAC_CREATE_CL_EXCEPTION(invalid_event, "invalid event");
ISAAC_CREATE_CL_EXCEPTION(invalid_operation, "invalid operation");
ISAAC_CREATE_CL_EXCEPTION(invalid_gl_object, "invalid GL object");
ISAAC_CREATE_CL_EXCEPTION(invalid_buffer_size, "invalid buffer size");
ISAAC_CREATE_CL_EXCEPTION(invalid_mip_level, "invalid MIP level");
ISAAC_CREATE_CL_EXCEPTION(invalid_global_work_size, "invalid global work size");
#ifdef CL_INVALID_PROPERTY
ISAAC_CREATE_CL_EXCEPTION(invalid_property, "invalid property");
#endif
}

}
}
}
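The ISAAC to TRITON renames above are purely mechanical; each macro invocation still expands to a one-line exception class. For example, TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory, "out of memory") expands to:

// Expansion of the macro defined earlier in this file (string literals
// are concatenated by the preprocessor):
class out_of_memory: public std::exception {
public:
  const char * what() const throw() { return "NVRTC: Error- out of memory"; }
};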
@@ -1,29 +0,0 @@
#pragma once

#ifndef _TRITON_DRIVER_EVENT_H_
#define _TRITON_DRIVER_EVENT_H_

#include "triton/driver/handle.h"

namespace triton
{

namespace driver
{

// event
class event
{
public:
  float elapsed_time() const;
  handle<cu_event_t> const & cu() const;

private:
  handle<cu_event_t> cu_;
};

}

}

#endif
@@ -9,6 +9,15 @@
#include <functional>
#include <type_traits>
#include "triton/driver/dispatch.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "triton/tools/thread_pool.h"

namespace llvm
{

@@ -24,7 +33,6 @@ namespace driver

enum backend_t {
  CUDA,
  OpenCL,
  Host
};

@@ -42,13 +50,23 @@ struct host_context_t{
};

struct host_stream_t{
  std::shared_ptr<ThreadPool> pool;
  std::shared_ptr<std::vector<std::future<void>>> futures;
  std::vector<std::shared_ptr<char*>> args;
};

struct host_module_t{
  std::string error;
  llvm::ExecutionEngine* engine;
  std::map<std::string, llvm::Function*> functions;
  void(*fn)(char**, int32_t, int32_t, int32_t);
  llvm::orc::ExecutionSession* ES;
  llvm::orc::RTDyldObjectLinkingLayer* ObjectLayer;
  llvm::orc::IRCompileLayer* CompileLayer;
  llvm::DataLayout* DL;
  llvm::orc::MangleAndInterner* Mangle;
  llvm::orc::ThreadSafeContext* Ctx;
  llvm::orc::JITDylib *MainJD;
};

struct host_function_t{

@@ -103,24 +121,20 @@ protected:
  bool has_ownership_;
};

template<class CUType, class CLType, class HostType>
template<class CUType, class HostType>
class polymorphic_resource {
public:
  polymorphic_resource(CUType cu, bool take_ownership): cu_(cu, take_ownership), backend_(CUDA){}
  polymorphic_resource(CLType cl, bool take_ownership): cl_(cl, take_ownership), backend_(OpenCL){}
  polymorphic_resource(HostType hst, bool take_ownership): hst_(hst, take_ownership), backend_(Host){}
  virtual ~polymorphic_resource() { }

  handle<CUType> cu() { return cu_; }
  handle<CLType> cl() { return cl_; }
  handle<HostType> hst() { return hst_; }
  const handle<CUType>& cu() const { return cu_; }
  const handle<CLType>& cl() const { return cl_; }
  const handle<HostType>& hst() const { return hst_; }
  backend_t backend() { return backend_; }

protected:
  handle<CLType> cl_;
  handle<CUType> cu_;
  handle<HostType> hst_;
  backend_t backend_;
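With OpenCL removed, polymorphic_resource carries only a CUDA handle and a host handle, and callers branch on backend(); a hedged sketch of that dispatch follows (launch_cuda and launch_host are hypothetical helpers, only backend()/cu()/hst() come from the template above).

// Sketch: dispatching on the backend of a driver resource `buf`.
switch (buf->backend()) {
  case CUDA: launch_cuda(buf->cu());  break;  // handle<CUdeviceptr>
  case Host: launch_host(buf->hst()); break;  // handle<host_buffer_t>
}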
@@ -21,19 +21,12 @@ namespace driver
class cu_buffer;

// Base
class kernel: public polymorphic_resource<CUfunction, cl_kernel, host_function_t> {
class kernel: public polymorphic_resource<CUfunction, host_function_t> {
public:
  kernel(driver::module* program, CUfunction fn, bool has_ownership);
  kernel(driver::module* program, cl_kernel fn, bool has_ownership);
  kernel(driver::module* program, host_function_t fn, bool has_ownership);
  // Getters
  driver::module* module();
  // Factory methods
  static kernel* create(driver::module* program, const char* name);
  // Arguments setters
  virtual void setArg(unsigned int index, std::size_t size, void* ptr) = 0;
  virtual void setArg(unsigned int index, buffer *) = 0;
  template<class T> void setArg(unsigned int index, T value) { setArg(index, sizeof(T), (void*)&value); }
private:
  driver::module* program_;
};
@@ -43,25 +36,6 @@ class host_kernel: public kernel {
public:
  //Constructors
  host_kernel(driver::module* program, const char* name);
  // Arguments setters
  void setArg(unsigned int index, std::size_t size, void* ptr);
  void setArg(unsigned int index, driver::buffer* buffer);
  // Params
  const std::vector<void*>& params();
private:
  std::vector<std::shared_ptr<void> > params_store_;
  std::vector<void*> params_;
};

// OpenCL
class ocl_kernel: public kernel {
public:
  //Constructors
  ocl_kernel(driver::module* program, const char* name);
  // Arguments setters
  void setArg(unsigned int index, std::size_t size, void* ptr);
  void setArg(unsigned int index, driver::buffer* buffer);

};

// CUDA
@@ -69,15 +43,6 @@ class cu_kernel: public kernel {
public:
  //Constructors
  cu_kernel(driver::module* program, const char * name);
  // Arguments setters
  void setArg(unsigned int index, std::size_t size, void* ptr);
  void setArg(unsigned int index, driver::buffer* buffer);
  //Arguments getters
  void* const* cu_params() const;

private:
  std::vector<std::shared_ptr<void> > cu_params_store_;
  std::vector<void*> cu_params_;
};

}
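The templated setArg overload forwards any trivially copyable scalar by address; a hedged usage sketch (the kernel pointer `k` is assumed to come from kernel::create):

    int   n     = 1024;
    float alpha = 2.0f;
    k->setArg(0, n);      // expands to setArg(0, sizeof(int), &n)
    k->setArg(1, alpha);  // expands to setArg(1, sizeof(float), &alpha)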
@@ -25,7 +25,7 @@ class cu_context;
class cu_device;

// Base
class module: public polymorphic_resource<CUmodule, cl_program, host_module_t> {
class module: public polymorphic_resource<CUmodule, host_module_t> {
protected:
  void init_llvm();

@@ -35,49 +35,43 @@ protected:
  };

public:
  module(driver::context* ctx, CUmodule mod, bool has_ownership);
  module(driver::context* ctx, cl_program mod, bool has_ownership);
  module(driver::context* ctx, host_module_t mod, bool has_ownership);
  static module* create(driver::context* ctx, std::unique_ptr<llvm::Module> src);
  driver::context* context() const;
  module(CUmodule mod, bool has_ownership);
  module(host_module_t mod, bool has_ownership);
  static module* create(driver::device* device, std::unique_ptr<llvm::Module> src);
  void compile_llvm_module(std::unique_ptr<llvm::Module> module, const std::string& triple,
                           const std::string &proc, std::string layout,
                           llvm::SmallVectorImpl<char> &buffer,
                           const std::string &features,
                           file_type_t file_type);
  virtual std::unique_ptr<buffer> symbol(const char * name) const = 0;

  int spilled() const { return spilled_; }

protected:
  driver::context* ctx_;
  int spilled_;
};

// CPU
class host_module: public module{
public:
  host_module(driver::context* context, std::unique_ptr<llvm::Module> module);
  std::unique_ptr<buffer> symbol(const char * name) const;
};

// OpenCL
class ocl_module: public module{
public:
  ocl_module(driver::context* context, std::unique_ptr<llvm::Module> module);
  host_module(std::unique_ptr<llvm::Module> module);
  std::unique_ptr<buffer> symbol(const char * name) const;
};

// CUDA
class cu_module: public module {
  std::string compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device);
  void init_from_ptx(const std::string& ptx);

public:
  cu_module(driver::context* context, std::unique_ptr<llvm::Module> module);
  cu_module(driver::context* context, const std::string& source);
  cu_module(driver::device* device, std::unique_ptr<llvm::Module> module);
  cu_module(driver::device* device, const std::string& source);
  std::unique_ptr<buffer> symbol(const char * name) const;
  const std::string& source() const { return source_; }
  std::string llir() const { return llir_; }
  const std::string& ptx() const { return ptx_; }

private:
  std::string source_;
  std::string ptx_;
  std::string llir_;
};
@@ -42,18 +42,6 @@ private:
  handle<CUPlatform> cu_;
};

// OpenCL
class cl_platform: public platform
{
public:
  cl_platform(cl_platform_id cl): platform("OpenCL"), cl_(cl) { }
  std::string version() const;
  void devices(std::vector<driver::device*> &devices) const;

private:
  handle<cl_platform_id> cl_;
};

// Host
class host_platform: public platform
{
@@ -21,18 +21,15 @@ class Range;
class cu_buffer;

// Base
class stream: public polymorphic_resource<CUstream, cl_command_queue, host_stream_t> {
class stream: public polymorphic_resource<CUstream, host_stream_t> {
public:
  stream(driver::context *ctx, CUstream, bool has_ownership);
  stream(driver::context *ctx, cl_command_queue, bool has_ownership);
  stream(driver::context *ctx, host_stream_t, bool has_ownership);
  stream(CUstream, bool has_ownership);
  stream(host_stream_t, bool has_ownership);
  // factory
  static driver::stream* create(driver::context* ctx);
  // accessors
  driver::context* context() const;
  static driver::stream* create(backend_t backend);
  // methods
  virtual void synchronize() = 0;
  virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const * = NULL, event *event = NULL, void **extra = NULL) = 0;
  virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem = 0) = 0;
  virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
  virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
  // template helpers
@@ -40,33 +37,14 @@ public:
  { write(buf, blocking, offset, x.size()*sizeof(T), x.data()); }
  template<class T> void read(driver::buffer* buf, bool blocking, std::size_t offset, std::vector<T>& x)
  { read(buf, blocking, offset, x.size()*sizeof(T), x.data()); }

protected:
  driver::context *ctx_;
};

// Host
class host_stream: public stream {
public:
  // Constructors
  host_stream(driver::context *ctx);

  // Overridden
  host_stream();
  void synchronize();
  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
  void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
  void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};

// OpenCL
class cl_stream: public stream {
public:
  // Constructors
  cl_stream(driver::context *ctx);

  // Overridden
  void synchronize();
  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
  void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
  void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};
@@ -74,13 +52,10 @@ public:
// CUDA
class cu_stream: public stream {
public:
  // Constructors
  cu_stream(CUstream str, bool take_ownership);
  cu_stream(driver::context* context);

  // Overridden
  cu_stream();
  void synchronize();
  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event *event, void **extra);
  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
  void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
  void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};
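The replacement enqueue overload takes one packed byte buffer instead of per-argument setters. A sketch of a caller packing two parameters, assuming a kernel that expects (CUdeviceptr, int); `dev_ptr`, `n` and the launch shape are placeholders:

    #include <cstring>
    char args[sizeof(CUdeviceptr) + sizeof(int)];
    std::memcpy(args, &dev_ptr, sizeof(CUdeviceptr));
    std::memcpy(args + sizeof(CUdeviceptr), &n, sizeof(int));
    stream->enqueue(kernel, {grid0, 1, 1}, {128, 1, 1}, args, sizeof(args));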
@@ -35,6 +35,7 @@ public:
  basic_block* get_insert_block() { return block_; }
  iterator get_insert_point() { return insert_point_;}
  // Constants
  value *get_int1(bool val);
  value *get_int32(int32_t val);
  value *get_int64(int64_t val);
  // Types
@@ -136,8 +137,9 @@ public:
  value *create_get_num_program(unsigned axis, const std::string &name = "");
  value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
  value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
  value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
  value *create_atomic_add(value *ptr, value *val, value *msk, const std::string &name = "");
  value *create_exp(value* arg, const std::string &name = "");
  value *create_log(value* arg, const std::string &name = "");
  value *create_dot(value *A, value *B, value *C, const std::string &name = "");
  value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = "");
  value *create_sqrt(value *A, const std::string &name = "");
@@ -145,8 +147,10 @@ public:
  value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
  // Intrinsics
  value *create_copy_to_shared(value *arg, const std::string &name = "");
  value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = "");
  value *create_copy_from_shared(value *arg, const std::string &name = "");
  value *create_barrier(const std::string &name = "");
  value *create_async_wait(int N);

private:
  context &ctx_;
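create_atomic_add now threads a predicate mask through, matching the masked load/store builders. A hedged sketch of builder usage (ptr, val and msk stand for previously built IR values):

    ir::value *old = builder.create_atomic_add(ptr, val, msk, "acc");
    builder.create_async_wait(0);   // also new: fence on outstanding async copy groups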
@@ -7,7 +7,7 @@ namespace triton{
namespace ir{

enum binary_op_t {
enum binary_op_t: unsigned int{
  Add,
  FAdd,
  Sub,
@@ -28,7 +28,7 @@ enum binary_op_t {
  Xor
};

enum cast_op_t {
enum cast_op_t: unsigned int {
  Trunc,
  ZExt,
  SExt,
@@ -44,7 +44,7 @@ enum cast_op_t {
  AddrSpaceCast
};

enum cmp_pred_t {
enum cmp_pred_t: unsigned int {
  FIRST_FCMP_PREDICATE,
  FCMP_FALSE,
  FCMP_OEQ,
@@ -113,6 +113,7 @@ enum value_id_t: unsigned {
  // io
  INST_UNMASKED_LOAD,
  INST_MASKED_LOAD,
  INST_MASKED_LOAD_ASYNC,
  INST_UNMASKED_STORE,
  INST_MASKED_STORE,
  // retile
@@ -129,6 +130,7 @@ enum value_id_t: unsigned {
  INST_ATOMIC_ADD,
  // math
  INST_EXP,
  INST_LOG,
  // array arithmetic
  INST_TRANS,
  INST_REDUCE,
@@ -138,6 +140,7 @@ enum value_id_t: unsigned {
  INST_COPY_FROM_SHARED,
  INST_RECOALESCE,
  INST_BARRIER,
  INST_ASYNC_WAIT,
  INST_MAKE_RANGE_DYN,
  INST_MAKE_RANGE_STA,
  INST_MAKE_RANGE
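Pinning the underlying type to unsigned int makes the numeric value of these enums well-defined across translation units, so they can be stored or hashed as raw integers, e.g.:

    unsigned key = static_cast<unsigned>(ir::FAdd);  // stable, well-defined conversion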
@@ -72,6 +72,7 @@ public:
  case noalias: return ".noalias";
  case aligned: return ".aligned(" + std::to_string(value_) + ")";
  case multiple_of: return ".multipleof(" + std::to_string(value_) + ")";
  case retune: return ".retune";
  default: break;
}
assert(false);
@@ -64,9 +64,10 @@ public:
  // cloning
  ir::instruction* clone() {
    ir::instruction* res = clone_impl();
    for(auto it = op_begin(); it != op_end(); it++)
      (*it)->add_use(res);
    // for(auto it = op_begin(); it != op_end(); it++)
    //   (*it)->add_use(res);
    res->parent_ = nullptr;
    res->users_.clear();
    return res;
  }
  // instruction id
@@ -91,6 +92,7 @@ private:
public:
  void set_incoming_value(unsigned i, value *v);
  void set_incoming_block(unsigned i, basic_block *block);
  value *get_value_for_block(basic_block *block);
  value *get_incoming_value(unsigned i) { return get_operand(i); }
  basic_block *get_incoming_block(unsigned i) { return blocks_[i]; }
  unsigned get_num_incoming() { return get_num_operands(); }
@@ -431,6 +433,37 @@ public:
  _TRITON_DEFINE_ACCEPT(masked_load_inst)
};

// masked load async
class masked_load_async_inst: public load_inst {
private:
  std::string repr_impl() const { return "masked_load_async"; }
  masked_load_async_inst(value *ptr, value *mask, value *false_value,
                         const std::string &name, instruction *next);

public:
  // accessors
  value *get_mask_operand() { return get_operand(1); }
  value *get_false_value_operand() { return get_operand(2); }
  // factory method
  static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
                                        const std::string &name = "",
                                        instruction *next = nullptr);
  _TRITON_DEFINE_CLONE(masked_load_async_inst)
  _TRITON_DEFINE_ACCEPT(masked_load_async_inst)
};

class atomic_add_inst: public io_inst {
private:
  atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
  std::string repr_impl() const { return "atomic_add"; }
  _TRITON_DEFINE_CLONE(atomic_add_inst)
  _TRITON_DEFINE_ACCEPT(atomic_add_inst)

public:
  static instruction* create(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
};

// store
class store_inst: public io_inst {
protected:
@@ -601,16 +634,6 @@ public:
  static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
};

class atomic_add_inst: public builtin_inst {
private:
  atomic_add_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
  std::string repr_impl() const { return "atomic_add"; }
  _TRITON_DEFINE_CLONE(atomic_add_inst)
  _TRITON_DEFINE_ACCEPT(atomic_add_inst)

public:
  static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
};

class exp_inst: public builtin_inst {
private:
@@ -623,6 +646,18 @@ public:
  static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
};

class log_inst: public builtin_inst {
private:
  log_inst(value *val, const std::string &name = "", instruction *next = nullptr);
  std::string repr_impl() const { return "log"; }
  _TRITON_DEFINE_CLONE(log_inst)
  _TRITON_DEFINE_ACCEPT(log_inst)

public:
  static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
};

class dot_inst: public builtin_inst {
public:
  enum TransT { NoTrans, Trans };
@@ -713,6 +748,9 @@ private:

public:
  static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr);
  value* get_pred_op() { return get_operand(0); }
  value* get_if_value_op() { return get_operand(1); }
  value* get_else_value_op() { return get_operand(2); }
};

//===----------------------------------------------------------------------===//
@@ -743,6 +781,7 @@ public:
  _TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
};

class recoalesce_inst: public unary_inst{
private:
  using unary_inst::unary_inst;
@@ -766,6 +805,22 @@ public:
                     instruction *next = nullptr);
};

class async_wait_inst: public instruction{
private:
  async_wait_inst(context &ctx, int N, const std::string &name, instruction *next);
  std::string repr_impl() const { return "async_wait_group " + std::to_string(N_) ; }
  _TRITON_DEFINE_CLONE(async_wait_inst)
  _TRITON_DEFINE_ACCEPT(async_wait_inst)

public:
  static async_wait_inst* create(context &ctx, int N,
                                 const std::string &name = "", instruction *next = nullptr);
  int get_N() { return N_; }

private:
  int N_;
};

// On NVIDIA, implementation is such that
// constant_range = nv_dynamic_program_idx + nv_static_program_idx
// so as to enable re-association on nv_static_program_idx which is constant
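Taken together, masked_load_async_inst and async_wait_inst let a pipelined loop prefetch operands into shared memory and fence before use. A hedged IR-level sketch (builder and operands are stand-ins):

    ir::value *frag = builder.create_masked_load_async(ptr, mask, fallback);
    builder.create_async_wait(0);   // wait until the async copy group has landed
    builder.create_barrier();       // make the data visible to the whole block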
@@ -8,8 +8,9 @@
#include <stack>
#include <string>
#include <functional>
#include "builder.h"
#include "metadata.h"
#include "triton/ir/builder.h"
#include "triton/ir/metadata.h"
#include "triton/ir/context.h"

namespace triton{

@@ -60,7 +61,7 @@ private:
  void push_function(function *fn) { functions_.push_back(fn); }

public:
  module(const std::string &name, context &ctx);
  module(const std::string &name);
  context& get_context();
  builder& get_builder();
  // Setters
@@ -94,7 +95,7 @@ public:

private:
  std::string name_;
  context &context_;
  context context_;
  builder builder_;
  std::map<val_key_t, value*> values_;
  std::map<val_key_t, type*> types_;
@@ -49,6 +49,7 @@ class broadcast_inst;
class downcast_inst;

class exp_inst;
class log_inst;

class get_program_id_inst;
class get_num_program_inst;
@@ -64,7 +65,9 @@ class select_inst;
class recoalesce_inst;
class copy_to_shared_inst;
class copy_from_shared_inst;
class masked_load_async_inst;
class barrier_inst;
class async_wait_inst;
class make_range_dyn;
class make_range;

@@ -117,6 +120,7 @@ public:
  virtual void visit_masked_store_inst(masked_store_inst*) = 0;

  virtual void visit_exp_inst(exp_inst*) = 0;
  virtual void visit_log_inst(log_inst*) = 0;

  virtual void visit_reshape_inst(reshape_inst*) = 0;
  virtual void visit_splat_inst(splat_inst*) = 0;
@@ -137,7 +141,9 @@ public:
  virtual void visit_recoalesce_inst(recoalesce_inst*) = 0;
  virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0;
  virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
  virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
  virtual void visit_barrier_inst(barrier_inst*) = 0;
  virtual void visit_async_wait_inst(async_wait_inst*) = 0;
  virtual void visit_make_range_dyn(make_range_dyn*) = 0;
  virtual void visit_make_range(make_range*) = 0;
@@ -48,6 +48,7 @@ protected:

  void set_ret(ir::value* value);
  ir::value *GenUnaryMinus(ir::value* arg);
  ir::value *GenUnaryInc(UnaryOp* arg, bool is_postfix, bool is_inc);

public:
  Generator(Parser* parser) : parser_(parser) {}
@@ -83,7 +83,7 @@ public:
  Constant* ParseSizeof();
  Constant* ParseAlignof();
  UnaryOp* ParsePrefixIncDec(const Token* tok);
  UnaryOp* ParseUnaryIntrinsicOp(const Token* tok, int op);
  UnaryOp* ParseUnaryIntrinsicOp(int op);
  UnaryOp* ParseUnaryOp(const Token* tok, int op);
  Expr* ParseDerefOp(const Token* tok);

@@ -227,7 +227,7 @@ public:
  FuncDef* CurFunc() { return curFunc_; }
  const TokenSequence& ts() const { return ts_; }

private:
protected:
  static bool IsBuiltin(FuncType* type);
  static bool IsBuiltin(const std::string& name);
  static Identifier* GetBuiltin(const Token* tok);
@@ -167,6 +167,8 @@ public:
  // function keywords
  BITCAST,
  EXP,
  LOG,
  SQRTF,
  // KEYWORD END

  IDENTIFIER,
@@ -331,7 +331,6 @@ public:
  using ShapeInt = std::vector<int>;

public:
  static TileType* New(const ShapeExpr& expr, QualType eleType);
  static TileType* New(const ShapeInt& shape, QualType eleType);
  virtual ~TileType() { }

@@ -345,6 +344,7 @@ public:
  }

  ShapeInt Shape() { return shape_; }

  int NumEle() const {
    int ret = 1;
    for(int s: shape_)
@@ -352,24 +352,13 @@ public:
    return ret;
  }

protected:
  TileType(MemPool* pool, const ShapeExpr& expr, QualType derived)
      : DerivedType(pool, derived),
        shapeExpr_(expr) {
    bool isComplete = true;
    for(Expr* s: shapeExpr_)
      isComplete = isComplete && !s;
    SetComplete(isComplete);
  bool CheckPow2NumEl() const {
    int n = NumEle();
    return n && !(n & (n - 1));
  }

  TileType(MemPool* pool, const ShapeInt& shape, QualType derived)
      : DerivedType(pool, derived),
        shape_(shape) {
    bool isComplete = true;
    for(int s: shape_)
      isComplete = isComplete && (s>=0);
    SetComplete(isComplete);
  }
protected:
  TileType(MemPool* pool, const ShapeInt& shape, QualType derived);

protected:
  ShapeExpr shapeExpr_;
@@ -5,6 +5,7 @@

#include <string>
#include <stdexcept>
#include <sstream>

namespace triton{
namespace ir{
@@ -17,73 +18,8 @@ namespace driver{

namespace runtime {

enum arg_type {
  INT1_T,
  INT8_T,
  INT16_T,
  INT32_T,
  INT64_T,
  HALF_T,
  FLOAT_T,
  DOUBLE_T,
  BUFFER_T
};

arg_type convert(ir::type *ty);

inline size_t size_of(arg_type ty){
  switch(ty){
    case INT1_T: return 1;
    case INT8_T: return 1;
    case INT16_T: return 2;
    case INT32_T: return 4;
    case INT64_T: return 8;
    case HALF_T: return 2;
    case FLOAT_T: return 4;
    case DOUBLE_T: return 8;
    case BUFFER_T: return 8;
    default: throw std::runtime_error("unknown type");
  }
}

inline bool is_int_type(arg_type ty){
  return ty == INT1_T || ty == INT8_T || ty == INT16_T ||
         ty == INT32_T || ty == INT64_T;
}

class arg {
public:
  union value_t {
    bool int1;
    int8_t int8;
    int16_t int16;
    int32_t int32;
    int64_t int64;
    uint16_t fp16;
    float fp32;
    double fp64;
    driver::buffer* buf;
  };

public:
  // construct from primitive types
  arg(arg_type ty, value_t val): ty_(ty) { val_ = val; }
  arg(int32_t x): ty_(INT32_T) { val_.int32 = x; }
  arg(int64_t x): ty_(INT64_T) { val_.int64 = x; }
  arg(float x): ty_(FLOAT_T) { val_.fp32 = x; }
  arg(double x): ty_(DOUBLE_T) { val_.fp64 = x; }
  arg(driver::buffer* x): ty_(BUFFER_T) { val_.buf = x; }
  // accessors
  arg_type type() const { return ty_; }
  void* data() const { return (void*)&val_; }
  driver::buffer* buffer() const { return val_.buf; }

private:
  arg_type ty_;
  value_t val_;
};

}
}
34  include/triton/runtime/error.h  (new file)
@@ -0,0 +1,34 @@
#pragma once

#ifndef _TRITON_RUNTIME_ERROR_H_
#define _TRITON_RUNTIME_ERROR_H_

#include <exception>
#include <string>

namespace triton {
namespace runtime{
namespace exception {

class base: public std::exception {};
#define TRITON_CREATE_RUNTIME_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "Triton: Error - Runtime: " msg; } };

TRITON_CREATE_RUNTIME_EXCEPTION(out_of_shared_memory, "out of shared memory")
TRITON_CREATE_RUNTIME_EXCEPTION(out_of_registers, "out of registers")

class no_valid_configuration: public exception::base {
public:
  no_valid_configuration(const std::string& err): err_(err) { }
  const char * what() const throw(){ return err_.c_str(); }
private:
  std::string err_;
};

#undef TRITON_CREATE_RUNTIME_EXCEPTION

}
}
}

#endif
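These exceptions let the autotuner treat an over-budget configuration as skippable rather than fatal. A usage sketch (compile_candidate is a hypothetical helper, not part of this diff):

    try {
      compile_candidate(cfg);   // hypothetical: build one candidate configuration
    }
    catch(const triton::runtime::exception::out_of_shared_memory&) {
      // candidate needs more shared memory than the device has: drop it
    }
    catch(const triton::runtime::exception::out_of_registers&) {
      // register pressure too high for this configuration: drop it
    }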
@@ -4,25 +4,20 @@
#define _TRITON_RUNTIME_FUNCTION_H_

#include <map>
#include <unordered_map>
#include <vector>
#include <string>
#include <sstream>
#include <memory>
#include <functional>
#include <set>
// codegen
#include "triton/ir/function.h"
#include "triton/ir/context.h"
#include "triton/codegen/target.h"
#include "triton/runtime/arg.h"
#include "triton/runtime/error.h"

namespace llvm {
  class Module;
  class LLVMContext;
}

class Parser;

// driver forward declaration
namespace triton {

namespace driver{
  class module;
  class stream;
@@ -30,126 +25,132 @@ namespace driver{
  class context;
  class device;
}

namespace lang{
  class translation_unit;
}

namespace codegen{
namespace analysis{
  class tiles;
}
}

// ir forward declaration
namespace triton{
namespace ir {
  class module;
  class function;
  class context;
}
}

namespace triton{
namespace runtime{

typedef std::vector<size_t> grid_t;
typedef std::map<std::string, size_t> params_t;
template<typename T> inline T convert(const std::string& name);
template<> inline long convert<long>(const std::string& name) { return std::stol(name); }
template<> inline int convert<int>(const std::string& name) { return std::stoi(name); }
/* ------------------------- */
/* Compilation options       */
/* ------------------------- */

struct options_t {
  template<class T>
  T D(const std::string& name) const {
    return std::stoi(defines.at(name));
  }
  std::unordered_map<std::string, std::string> defines;
  int num_warps;
};

/* ------------------------- */
/* Runtime arguments         */
/* ------------------------- */

enum arg_type {
  INT1_T,
  INT8_T,
  INT16_T,
  INT32_T,
  INT64_T,
  HALF_T,
  FLOAT_T,
  DOUBLE_T,
  BUFFER_T
};

inline size_t size_of(arg_type ty){
  switch(ty){
    case INT1_T  : return 1;
    case INT8_T  : return 1;
    case INT16_T : return 2;
    case INT32_T : return 4;
    case INT64_T : return 8;
    case HALF_T  : return 2;
    case FLOAT_T : return 4;
    case DOUBLE_T: return 8;
    case BUFFER_T: return 8;
    default: throw std::runtime_error("unknown type");
  }
}

template<class T>
void add_arg(std::stringstream& ss, T arg) {
  ss.write((char*)&arg, sizeof(T));
}

/* ------------------------- */
/* ------------------------- */

class kernel{
public:
  typedef std::vector<size_t> grid_t;

public:
  static std::shared_ptr<ir::module> src_to_ir(const std::string& src, const options_t& opt);
  static std::tuple<std::shared_ptr<driver::module>,
                    std::shared_ptr<driver::kernel>,
                    size_t> ir_to_bin(ir::module& ir, driver::device *dev, const options_t &opt);

public:
  kernel(const std::string& src, const options_t& opt, driver::device *device, const std::map<int, triton::ir::attribute> &attrs = {});
  void operator()(const std::string& args, driver::stream *stream, const grid_t& grid) const;
  std::string get_asm(const std::string &mode);

public:
  const options_t opt;

private:
  driver::device* dev_;
  // handles
  std::shared_ptr<ir::module> ir_;
  std::shared_ptr<driver::module> mod_;
  std::shared_ptr<driver::kernel> ker_;
  // shared mem
  size_t shared_mem_;
};

struct config {
  std::map<std::string, std::string> defines;
  int num_warps;
};

class function {
public:
  struct options_space_t {
    typedef std::pair<std::string, std::vector<std::string>> define_t;
    std::vector<define_t> defines;
    std::vector<int> num_warps;
    std::vector<int> recompile_key;
  };

  struct options_t {
    template<class T>
    T D(const std::string& name) const {
      return convert<T>(defines.at(name));
    }
    bool operator<(const options_t& other) const {
      return std::make_pair(defines, num_warps) <
             std::make_pair(other.defines, other.num_warps);
    }
    std::string to_str() const;

    std::map<std::string, std::string> defines;
    size_t num_warps;
  };

  typedef std::function<grid_t(const options_t&)> grid_fn_ty;

private:
  class caller {
  public:
    // constructors
    caller(driver::context* ctx, std::ifstream& ifs, const options_t& opt);
    caller(ir::function *ir, std::shared_ptr<driver::module> program, const options_t& opt);
    // serialization
    void write(std::ofstream& ofs);
    void read(driver::context* ctx, std::ifstream& ifs);
    // accessors
    const options_t opt() const { return opt_; }
    const driver::module* parent() const { return &*parent_; }
    const driver::kernel* bin() const { return &*bin_; }
    arg_type param_ty(size_t i) const { return param_tys_.at(i);}
    const std::vector<arg_type>& param_tys() const { return param_tys_; }

    std::vector<int> retune() const { return retune_; }
    // entry points
    void operator()(driver::stream *stream, const grid_t& grid, void **args, size_t args_size) const;

  private:
    std::shared_ptr<driver::kernel> bin_;
    std::shared_ptr<driver::module> parent_;
    std::vector<arg_type> param_tys_;
    std::vector<int> retune_;
    options_t opt_;
    std::string name_;
  };

private:
  typedef std::pair<driver::device*, std::vector<int32_t>> cache_key_t;

private:
  // cache
  static std::string get_cache_prefix();
  // make
  triton::lang::translation_unit *make_ast(const std::string &src);
  std::unique_ptr<ir::module> make_ir(Parser &parser);
  std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options_t &opt);
  caller *make(driver::stream *stream, options_t opt);
  void precompile(driver::stream *stream, const options_space_t& tuning_space);
  // autotune
  caller* autotune(driver::stream *stream, const grid_fn_ty& grid, void **args, size_t args_size);
  typedef std::function<kernel::grid_t(const options_t&)> grid_fn_ty;
  typedef std::pair<options_t, std::shared_ptr<kernel>> kernel_pair_t;
  typedef std::map<std::vector<uint64_t>, kernel*> cache_t;
  typedef std::vector<config> autotune_confs_t;

public:
  static std::string preheader();

public:
  function(const std::string& src, const options_space_t& opt, const std::string &cache_ref = "");
  void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream);
  void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
  void set_cst(const std::string& name, void* data, size_t n_bytes);
  function(const std::string& src, const options_t& opt, driver::device *device,
           const std::vector<config>& tune_confs = {}, const std::vector<std::string> &tune_key = {});
  kernel* autotune(const std::string& args, const grid_fn_ty& grid, driver::stream *stream);
  void operator()(const std::string& args, const grid_fn_ty& grid, driver::stream *stream);
  const std::vector<arg_type> get_signature() { return sig_; }

private:
  std::map<std::string, std::vector<char>> cst_;
  // pre-compilation
  ir::context ctx_;
  std::map<std::vector<uint64_t>, std::vector<std::shared_ptr<kernel>>> kernels_;
  std::map<std::vector<uint64_t>, kernel*> cache_;
  std::vector<arg_type> sig_;
  std::vector<int> align_idxs_;
  std::vector<int> int_idxs_;
  std::vector<int> key_idxs_;
  std::vector<int> arg_size_;
  std::vector<int> arg_off_;
  std::vector<options_t> opts_;
  std::string src_;
  options_space_t opt_;
  std::set<options_t> compiled_;
  std::map<options_t, std::unique_ptr<caller>> callers_;
  std::vector<int> args_off_;
  size_t args_size_;
  // caching
  std::string cache_ref_;
  std::string cache_path_;
  std::map<cache_key_t, caller*> cache_;
  driver::device* device_;
};

}
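Under the string-serialized calling convention, arguments are byte-packed with add_arg and the blob is handed to function::operator(). A hedged sketch for a kernel taking (pointer, int); `fn`, `dev_ptr`, `n` and `stream` are placeholders:

    std::stringstream ss;
    triton::runtime::add_arg(ss, dev_ptr);  // raw bytes of the device pointer
    triton::runtime::add_arg(ss, n);        // raw bytes of the integer
    fn(ss.str(),
       [n](const triton::runtime::function::options_t& opt)
         { return triton::runtime::grid_t{(size_t)(n + 127) / 128}; },
       stream);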
@@ -30,28 +30,20 @@ private:
  high_resolution_clock::time_point _start;
};

inline double bench(std::function<void()> const & op, driver::stream * stream, bool normalize = false)
inline double bench(std::function<void()> const & op, driver::stream * stream, size_t warmup = 10, size_t repeat = 200)
{
  // const driver::device * device = stream->context()->device();
  timer tmr;
  std::vector<size_t> times;
  double total_time = 0;
  op();
  for(size_t i = 0; i < warmup; i++)
    op();
  stream->synchronize();
  tmr.start();
  for(size_t i = 0; i < 10; i++){
  // while(total_time*1e-9 < 1e-2){
  //   float norm = 1;
  //   normalize clock if possible to reduce noise in auto-tuning
  //   if(normalize)
  //     if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
  //       norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
  for(size_t i = 0; i < repeat; i++){
    op();
    // times.push_back(norm*tmr.get().count());
    // total_time+=times.back();
  }
  stream->synchronize();
  return (float)tmr.get().count() / 10;
  return (float)tmr.get().count() / repeat;

  // return *std::min_element(times.begin(), times.end());
}
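The only caller-visible change is the explicit warmup/repeat knobs; e.g.:

    double t = bench([&]{ stream->enqueue(kernel, grid, block, args, args_size); },
                     stream, /*warmup=*/10, /*repeat=*/200);
    // t is the timer's tick count averaged over `repeat` iterations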
@@ -54,7 +54,7 @@ namespace sha1
  }
}

void innerHash(unsigned int* result, unsigned int* w)
inline void innerHash(unsigned int* result, unsigned int* w)
{
  unsigned int a = result[0];
  unsigned int b = result[1];
@@ -114,7 +114,7 @@ namespace sha1
  }
} // namespace

void calc(const void* src, const int bytelength, unsigned char* hash)
inline void calc(const void* src, const int bytelength, unsigned char* hash)
{
  // Init the result array.
  unsigned int result[5] = { 0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0 };
@@ -170,7 +170,7 @@ namespace sha1
  }
}

void toHexString(const unsigned char* hash, char* hexstring)
inline void toHexString(const unsigned char* hash, char* hexstring)
{
  const char hexDigits[] = { "0123456789abcdef" };
@@ -15,11 +15,65 @@

class ThreadPool {
public:
  ThreadPool(size_t);
  ThreadPool(size_t threads)
    : stop(false) {
    for(size_t i = 0;i < threads;++i)
      workers.emplace_back(
        [this] {
          for(;;){
            std::function<void()> task;
            {
              std::unique_lock<std::mutex> lock(this->queue_mutex);
              this->condition.wait(lock,
                                   [this]{ return this->stop || !this->tasks.empty(); });
              if(this->stop && this->tasks.empty())
                return;
              task = std::move(this->tasks.front());
              this->tasks.pop();
            }
            task();
          }
        }
      );
  }

  template<class F, class... Args>
  auto enqueue(F&& f, Args&&... args)
    -> std::future<typename std::result_of<F(Args...)>::type>;
  ~ThreadPool();
    -> std::future<typename std::result_of<F(Args...)>::type>
  {
    using return_type = typename std::result_of<F(Args...)>::type;

    auto task = std::make_shared< std::packaged_task<return_type()> >(
      std::bind(std::forward<F>(f), std::forward<Args>(args)...)
    );

    std::future<return_type> res = task->get_future();
    {
      std::unique_lock<std::mutex> lock(queue_mutex);

      // don't allow enqueueing after stopping the pool
      if(stop)
        throw std::runtime_error("enqueue on stopped ThreadPool");

      tasks.emplace([task](){ (*task)(); });
    }
    condition.notify_one();
    return res;
  }

  ~ThreadPool() {
    {
      std::unique_lock<std::mutex> lock(queue_mutex);
      stop = true;
    }
    condition.notify_all();
    for(std::thread &worker: workers)
      worker.join();
  }

private:
  // need to keep track of threads so we can join them
  std::vector< std::thread > workers;
@@ -32,69 +86,5 @@ private:
  bool stop;
};

// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads)
  : stop(false)
{
  for(size_t i = 0;i<threads;++i)
    workers.emplace_back(
      [this]
      {
        for(;;)
        {
          std::function<void()> task;

          {
            std::unique_lock<std::mutex> lock(this->queue_mutex);
            this->condition.wait(lock,
                                 [this]{ return this->stop || !this->tasks.empty(); });
            if(this->stop && this->tasks.empty())
              return;
            task = std::move(this->tasks.front());
            this->tasks.pop();
          }

          task();
        }
      }
    );
}

// add new work item to the pool
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
  -> std::future<typename std::result_of<F(Args...)>::type>
{
  using return_type = typename std::result_of<F(Args...)>::type;

  auto task = std::make_shared< std::packaged_task<return_type()> >(
    std::bind(std::forward<F>(f), std::forward<Args>(args)...)
  );

  std::future<return_type> res = task->get_future();
  {
    std::unique_lock<std::mutex> lock(queue_mutex);

    // don't allow enqueueing after stopping the pool
    if(stop)
      throw std::runtime_error("enqueue on stopped ThreadPool");

    tasks.emplace([task](){ (*task)(); });
  }
  condition.notify_one();
  return res;
}

// the destructor joins all threads
inline ThreadPool::~ThreadPool()
{
  {
    std::unique_lock<std::mutex> lock(queue_mutex);
    stop = true;
  }
  condition.notify_all();
  for(std::thread &worker: workers)
    worker.join();
}

#endif
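A self-contained usage sketch of the now header-only pool:

    #include "triton/tools/thread_pool.h"
    #include <iostream>

    int main() {
      ThreadPool pool(4);   // four worker threads
      auto fut = pool.enqueue([](int a, int b) { return a + b; }, 2, 3);
      std::cout << fut.get() << std::endl;   // prints 5; workers join on destruction
    }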
@@ -416,8 +416,10 @@ std::vector<unsigned> align::populate_starting_multiple_gep(ir::getelementptr_in
  auto lhs = populate_starting_multiple(x->get_operand(0));
  auto rhs = populate_starting_multiple(x->get_operand(1));
  std::vector<unsigned> result(lhs.size(), 1);
  for(size_t d = 0; d < lhs.size(); d++)
  for(size_t d = 0; d < lhs.size(); d++){
    result[d] = gcd(lhs[d], rhs[d]);
    // std::cout << "starting multiple: " << x->get_name() << " " << d << " " << result[d] << std::endl;
  }
  return add_to_cache(x, result, starting_multiple_);
}

@@ -524,8 +526,7 @@ void align::run(ir::module &mod) {
  ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
  // ir::for_each_value(mod, [this](ir::value* v) {
  //   if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
  //     std::cout << "ALIGN: " << v->get_name() << " " << starting_multiple_.at(v)[0] << " " << max_contiguous_.at(v)[0]
  //               << " " << starting_multiple_.at(v)[1] << " " << max_contiguous_.at(v)[1] << std::endl;
  //   std::cout << "ALIGN: " << v->get_name() << " " << max_contiguous_.at(v)[0] << " " << max_contiguous_.at(v)[1] << std::endl;
  // });
}
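For example, if dimension d of the left operand is known to start at a multiple of 8 and the right at a multiple of 4, gcd(8, 4) = 4 is the guaranteed starting multiple of the GEP result in that dimension.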
@@ -79,7 +79,7 @@ void axes::update_graph_dot(ir::instruction *i) {
    graph_.add_edge({dot, d}, {D, d});
}

void axes::update_graph_elementwise(ir::instruction *i) {
void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
  if(i->get_num_operands() == 0)
    return;
  ir::value *op = i->get_operand(0);
@@ -89,7 +89,7 @@ void axes::update_graph_elementwise(ir::instruction *i) {
  for(unsigned d = 0; d < rank; d++)
    for(ir::value* opx: i->ops())
      for(ir::value* opy: i->ops()){
        if(!i->get_type()->is_void_ty())
        if(connect_ret && !i->get_type()->is_void_ty())
          graph_.add_edge({i, d}, {opx, d});
        graph_.add_edge({opx, d}, {opy, d});
      }
@@ -111,7 +111,8 @@ void axes::update_graph(ir::instruction *i) {
    case ir::INST_TRANS: return update_graph_trans(i);
    case ir::INST_BROADCAST: return update_graph_broadcast(i);
    case ir::INST_DOT: return update_graph_dot(i);
    case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);;
    case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
    case ir::INST_MASKED_LOAD_ASYNC:return update_graph_elementwise(i, false);
    case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
    case ir::INST_RECOALESCE: return update_graph_no_edge(i);
    default: return update_graph_elementwise(i);
@@ -55,7 +55,7 @@ inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
|
||||
for(ir::user* u: v->get_users()){
|
||||
auto i = dynamic_cast<ir::dot_inst*>(u);
|
||||
if(i && is_hmma_c(i) && i->get_operand(n) == v)
|
||||
result = v;
|
||||
result = i;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,21 +98,32 @@ data_layout::data_layout(id_t id,
|
||||
extract_io_use(v, ptr);
|
||||
order_.resize(axes_.size());
|
||||
std::iota(order_.begin(), order_.end(), 0);
|
||||
auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){
|
||||
std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
|
||||
std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
|
||||
return xx < yy;
|
||||
});
|
||||
if(*largest){
|
||||
auto max_contiguous = align->contiguous(*largest);
|
||||
std::vector<unsigned> max_contiguous;
|
||||
for(ir::value* p: ptr){
|
||||
std::vector<unsigned> curr = align->contiguous(p);
|
||||
if(curr.size() > max_contiguous.size())
|
||||
max_contiguous = curr;
|
||||
else if(curr.size() == max_contiguous.size()){
|
||||
if(*std::max_element(curr.begin(), curr.end()) > *std::max_element(max_contiguous.begin(), max_contiguous.end()))
|
||||
max_contiguous = curr;
|
||||
}
|
||||
}
|
||||
bool is_recoalesce = false;
|
||||
for(ir::value* v: values)
|
||||
is_recoalesce = is_recoalesce || dynamic_cast<ir::recoalesce_inst*>(v);
|
||||
if(max_contiguous.size() > 0){
|
||||
std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
|
||||
return max_contiguous[a] > max_contiguous[b];
|
||||
});
|
||||
// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
|
||||
// std::cout << order_[0] << " " << order_[1] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
size_t data_layout::find_axis(int to_find) const {
|
||||
int data_layout::find_axis(int to_find) const {
|
||||
auto it = std::find(axes_.begin(), axes_.end(), to_find);
|
||||
if(it == axes_.end())
|
||||
return -1;
|
||||
return std::distance(axes_.begin(), it);
|
||||
}
|
||||
|
||||
@@ -121,23 +132,41 @@ size_t data_layout::find_axis(int to_find) const {
|
||||
* MMA Layout *
|
||||
* -------------------------------- */
|
||||
|
||||
mma884_layout::mma884_layout(size_t num_warps,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align): data_layout(HMMA_884, axes, shape, values, align) {
|
||||
mma_layout::mma_layout(size_t num_warps,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align, target* tgt,
|
||||
shared_layout *layout_a, shared_layout *layout_b): data_layout(MMA, axes, shape, values, align) {
|
||||
/* fragments per warp */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
fpw_ = {1, 1, 1};
|
||||
std::vector<int> fpw_nm1;
|
||||
unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
|
||||
do {
|
||||
fpw_nm1 = fpw_;
|
||||
if(fpw_[0]*fpw_[1] < num_fragments)
|
||||
fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
|
||||
if(fpw_[0]*fpw_[1] < num_fragments)
|
||||
fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
|
||||
}while(fpw_nm1 != fpw_);
|
||||
if(tgt->as_nvidia()->sm() < 80){
|
||||
fpw_ = {2, 2, 1};
|
||||
// std::vector<int> fpw_nm1;
|
||||
// unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
|
||||
// do {
|
||||
// fpw_nm1 = fpw_;
|
||||
// if(fpw_[0]*fpw_[1] < num_fragments)
|
||||
// fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
|
||||
// if(fpw_[0]*fpw_[1] < num_fragments)
|
||||
// fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
|
||||
// }while(fpw_nm1 != fpw_);
|
||||
auto ord_a = layout_a->get_order();
|
||||
auto ord_b = layout_b->get_order();
|
||||
bool is_a_row = ord_a[0] != 0;
|
||||
bool is_b_row = ord_b[0] != 0;
|
||||
bool is_a_vec4 = !is_a_row && (layout_a->get_shape()[ord_a[0]] <= 16);
|
||||
bool is_b_vec4 = is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16);
|
||||
int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2;
|
||||
int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
|
||||
rep_ = {2*pack_size_0, 2*pack_size_1, 1};
|
||||
spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
|
||||
}
|
||||
else{
|
||||
fpw_ = {1, 1, 1};
|
||||
spw_ = {16, 8, 1};
|
||||
rep_ = {2, 2, 1};
|
||||
}
|
||||
|
||||
/* warps per tile */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
@@ -146,17 +175,13 @@ mma884_layout::mma884_layout(size_t num_warps,
|
||||
do{
|
||||
wpt_nm1 = wpt_;
|
||||
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
||||
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / (fpw_[0]*8));
|
||||
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
|
||||
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
|
||||
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / (fpw_[1]*8));
|
||||
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
|
||||
}while(wpt_nm1 != wpt_);
|
||||
|
||||
/* sanity check */
|
||||
unsigned effective_num_warps = 1;
|
||||
for(size_t d = 0; d < shape.size(); d++)
|
||||
effective_num_warps *= wpt_[d];
|
||||
// if(num_warps != effective_num_warps)
|
||||
// throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
/* shape per block */
|
||||
spt_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
|
||||
}
|
||||
|
||||
|
||||
@@ -168,9 +193,9 @@ scanline_layout::scanline_layout(size_t num_warps,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align): data_layout(SCANLINE, axes, shape, values, align){
|
||||
analysis::align* align, target *tgt): data_layout(SCANLINE, axes, shape, values, align){
|
||||
unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
|
||||
unsigned num_threads = num_warps * 32;
|
||||
unsigned num_threads = tgt->is_gpu() ? num_warps * 32 : 1;
|
||||
nts_.resize(shape_.size());
|
||||
mts_.resize(shape_.size());
|
||||
bool is_dot = std::any_of(values.begin(), values.end(),
|
||||
@@ -179,13 +204,17 @@ scanline_layout::scanline_layout(size_t num_warps,
|
||||
ir::value *ptr = nullptr;
|
||||
for(ir::value *v: values)
|
||||
for(ir::user *usr: v->get_users())
|
||||
if(auto *st = dynamic_cast<ir::store_inst*>(usr))
|
||||
ptr = st->get_pointer_operand();
|
||||
if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
|
||||
if(!ptr || ptr->get_type()->get_tile_rank() < io->get_pointer_operand()->get_type()->get_tile_rank())
|
||||
ptr = io->get_pointer_operand();
|
||||
}
|
||||
|
||||
unsigned i = order_[0];
|
||||
int contiguous = 4;
|
||||
if(ptr)
|
||||
contiguous = std::min<int>(align->contiguous(ptr)[i], 4);
|
||||
int contiguous = 1;
|
||||
if(ptr){
|
||||
int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
|
||||
contiguous = std::min<int>(align->get(ptr, i), 128 / nbits);
|
||||
}
|
||||
|
||||
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
|
||||
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
|
||||
@@ -200,14 +229,6 @@ scanline_layout::scanline_layout(size_t num_warps,
|
||||
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
|
||||
num_threads = num_threads / mts_[i];
|
||||
}
|
||||
/* sanity check */
|
||||
unsigned effective_num_threads = 1;
|
||||
for(size_t d = 0; d < shape_.size(); d++)
|
||||
effective_num_threads *= mts_[d];
|
||||
|
||||
// std::cout <<values.size() << " " << num_warps << " " << effective_num_threads << std::endl;
|
||||
// if(num_warps * 32 != effective_num_threads)
|
||||
// throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
|
||||
|
||||
@@ -242,9 +263,9 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
|
||||
ir::value *value_1 = phi->get_incoming_value(1);
|
||||
ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
|
||||
ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
|
||||
if(!i_0 || !i_1 ||
|
||||
!dynamic_cast<ir::copy_to_shared_inst*>(i_0) ||
|
||||
!dynamic_cast<ir::copy_to_shared_inst*>(i_1) )
|
||||
if(!(i_0 && !i_1) &&
|
||||
!(dynamic_cast<ir::copy_to_shared_inst*>(i_0) && dynamic_cast<ir::copy_to_shared_inst*>(i_1)) &&
|
||||
!(dynamic_cast<ir::masked_load_async_inst*>(i_0) && dynamic_cast<ir::masked_load_async_inst*>(i_1)))
|
||||
return;
|
||||
if(is_latch_1)
|
||||
res.reset(new double_buffer_info_t{value_0, value_1, phi});
|
||||
@@ -253,7 +274,7 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
|
||||
}
|
||||
|
||||
|
||||
shared_layout::shared_layout(const data_layout *arg,
|
||||
shared_layout::shared_layout(data_layout *arg,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
@@ -261,6 +282,7 @@ shared_layout::shared_layout(const data_layout *arg,
|
||||
analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
|
||||
|
||||
size_ = 0;
|
||||
arg_layout_ = arg;
|
||||
|
||||
// double-buffering
|
||||
for(ir::value *v: values)
|
||||
@@ -280,36 +302,8 @@ shared_layout::shared_layout(const data_layout *arg,
|
||||
extract_hmma_dot_use(v, hmma_dot_a, 0);
|
||||
extract_hmma_dot_use(v, hmma_dot_b, 1);
|
||||
}
|
||||
|
||||
|
||||
// non-mma ordering
|
||||
std::vector<int> col = {0, 1};
|
||||
std::vector<int> row = {1, 0};
|
||||
for(size_t s = 2; s < get_rank(); s++){
|
||||
col.push_back(s);
|
||||
row.push_back(s);
|
||||
}
|
||||
  bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
  bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
  if(is_nonhmma_dot_a)
    order_ = is_trans(dot_a) ? row : col;
  else if(is_nonhmma_dot_b)
    order_ = is_trans(dot_b) ? col : row;

  // padding
  size_t pad = 0;
  if(hmma_dot_a){
    bool row = is_trans(hmma_dot_a) ^ order_[0] != 0;
    pad = 24 - shape_[row ? 0 : 1] % 32;
  }
  else if(hmma_dot_b){
    bool row = is_trans(hmma_dot_b) ^ order_[0] != 0;
    pad = 24 - shape_[row ? 1 : 0] % 32;
  }
  else if(order_ != arg_order) {
    pad = 4;
  }
  shape_[order_[0]] += pad;
  hmma_dot_a_ = hmma_dot_a;
  hmma_dot_b_ = hmma_dot_b;

  // size
  size_ = ty_->get_primitive_size_in_bits() / 8;
@@ -324,8 +318,8 @@ shared_layout::shared_layout(const data_layout *arg,
 * ---- Layouts Inference Pass ---- *
 * -------------------------------- */

-layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps)
-  : axes_(axes), align_(align), num_warps_(num_warps) { }
+layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt)
+  : axes_(axes), align_(align), num_warps_(num_warps), tgt_(tgt){ }


void layouts::connect(ir::value *x, ir::value *y) {
@@ -358,6 +352,8 @@ void layouts::make_graph(ir::instruction *i) {
}

void layouts::create(size_t id, const std::vector<ir::value*>& values) {
+  // if(layouts_.find(id) != layouts_.end())
+  //   return;
  auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
  auto cmp = [](ir::value* x, ir::value *y) {
    std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
@@ -370,19 +366,27 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
  const auto& axes = axes_->get(largest);
  const auto& shapes = largest->get_type()->get_tile_shapes();
  auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
-    return dynamic_cast<ir::copy_to_shared_inst*>(v);
+    return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
+           dynamic_cast<ir::masked_load_async_inst*>(v);
  });
  // type
-  if(it_hmma_c != values.end())
-    layouts_[id] = new mma884_layout(num_warps_, axes, shapes, values, align_);
+  if(it_hmma_c != values.end()){
+    ir::instruction *dot = (ir::instruction*)*it_hmma_c;
+    ir::value *a = dot->get_operand(0);
+    ir::value *b = dot->get_operand(1);
+    create(groups_.at(a), values_.at(groups_.at(a)));
+    create(groups_.at(b), values_.at(groups_.at(b)));
+    layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b)));
+  }
  else if(it_cts != values.end()){
-    ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
+    ir::instruction *cts = (ir::instruction*)*it_cts;
    ir::value *arg = cts->get_operand(0);
    create(groups_.at(arg), values_.at(groups_.at(arg)));
    layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
  }
-  else
-    layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_);
+  else{
+    layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
+  }
}

void layouts::run(ir::module &mod) {
@@ -416,7 +420,7 @@ void layouts::run(ir::module &mod) {
  }
  if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
    ir::value *val = recoalasce->get_operand(0);
-    mma884_layout* in_layout = get(val)->to_mma884();
+    mma_layout* in_layout = get(val)->to_mma();
    scanline_layout* out_layout = get(i)->to_scanline();
    if(!in_layout || !out_layout)
      return;
@@ -427,7 +431,7 @@ void layouts::run(ir::module &mod) {
  shape[ld] = in_shape[ld];
  for(size_t k = 0; k < in_shape.size(); k++)
    if(k != ld)
-      shape[k] = 4*in_layout->to_mma884()->fpw(k)*in_layout->to_mma884()->wpt(k);
+      shape[k] = in_layout->to_mma()->spt(k);
  // create layout
  layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
  tmp_[recoalasce] = id;
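The padding rule in the shared_layout constructor above can be read as follows (illustrative sketch, not part of the diff; the helper name is invented):

// For tensor-core (hmma) operands, the contiguous dimension is padded by
// 24 - dim % 32 elements, so that consecutive rows start at staggered
// shared-memory offsets instead of aliasing the same banks every 32 elements.
size_t padded_leading_dim(size_t dim) {
  return dim + (24 - dim % 32);   // e.g. dim == 64  ->  64 + 24 == 88
}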
54
lib/codegen/analysis/swizzle.cc
Normal file
@@ -0,0 +1,54 @@
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/target.h"
#include "triton/ir/type.h"
#include <iostream>

namespace triton{
namespace codegen{
namespace analysis{


void swizzle::run(ir::module &) {
  per_phase_.clear();
  max_phase_.clear();

  for(auto &x: layouts_->get_all()){
    shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
    if(!layout)
      continue;
    ir::value* mma_dot_a = layout->hmma_dot_a();
    ir::value* mma_dot_b = layout->hmma_dot_b();
    if(!mma_dot_a && !mma_dot_b){
      per_phase_[layout] = 1;
      max_phase_[layout] = 1;
      vec_[layout] = 1;
      continue;
    }
    auto ord = layout->get_order();
    scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
    if(!in_layout)
      continue;
    int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
    if(tgt_->as_nvidia()->sm() < 80){
      int inner = mma_dot_a ? 0 : 1;
      per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
      max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
      if(mma_dot_a)
        vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
      else
        vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
    }
    else{
      per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
      max_phase_[layout] = 8 / per_phase_[layout];
      vec_[layout] = 8;
    }
  }
}

}
}
}
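A plausible reading of how per_phase_, max_phase_ and vec_ are later consumed when addressing shared memory (a sketch under assumptions; the code that consumes these values is not shown in this compare):

// Columns are XOR-swizzled in groups of 'vec' elements; rows that share a
// phase use the same permutation, and max_phase distinct permutations cycle
// every per_phase rows, spreading tensor-core reads across banks.
int swizzled_col(int row, int col, int per_phase, int max_phase, int vec) {
  int phase = (row / per_phase) % max_phase;  // which permutation this row uses
  int group = col / vec;                      // vec-wide column group
  return (group ^ phase) * vec + col % vec;   // permuted group, same lane
}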
File diff suppressed because it is too large
lib/codegen/selection/machine_layout.cc
@@ -1,325 +0,0 @@
#include <numeric>
#include "triton/codegen/selection/machine_layout.h"
#include "triton/codegen/selection/machine_value.h"
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/target.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "llvm/IR/IRBuilder.h"

namespace triton{
namespace codegen{

using namespace llvm;

inline Type *llvm_type(ir::type *ty, LLVMContext &ctx) {
  // function
  if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
    Type *return_ty = llvm_type(tt->get_return_ty(), ctx);
    std::vector<Type*> param_tys;
    std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
                   [&ctx](ir::type* t){ return llvm_type(t, ctx);});
    return FunctionType::get(return_ty, param_tys, false);
  }
  // pointer
  if(ty->is_pointer_ty()){
    Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx);
    unsigned addr_space = ty->get_pointer_address_space();
    return PointerType::get(elt_ty, addr_space);
  }
  // integer
  if(ty->is_integer_ty()){
    unsigned bitwidth = ty->get_integer_bitwidth();
    return IntegerType::get(ctx, bitwidth);
  }
  // primitive types
  switch(ty->get_type_id()){
    case ir::type::VoidTyID: return Type::getVoidTy(ctx);
    case ir::type::HalfTyID: return Type::getHalfTy(ctx);
    case ir::type::FloatTyID: return Type::getFloatTy(ctx);
    case ir::type::DoubleTyID: return Type::getDoubleTy(ctx);
    case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx);
    case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx);
    case ir::type::LabelTyID: return Type::getLabelTy(ctx);
    case ir::type::MetadataTyID: return Type::getMetadataTy(ctx);
    case ir::type::TokenTyID: return Type::getTokenTy(ctx);
    default: break;
  }
  // unknown type
  throw std::runtime_error("unknown conversion from ir::type to Type");
}

// Grid construction
inline std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<int> &shapes, IRBuilder<> &builder){
  size_t dim = shapes.size();
  std::vector<Value*> result(dim);
  for(unsigned k = 0; k < dim - 1; k++){
    Constant *dim_k = builder.getInt32(shapes[order[k]]);
    Value *rem = builder.CreateURem(trailing, dim_k);
    trailing = builder.CreateUDiv(trailing, dim_k);
    result[order[k]] = rem;
  }
  result[order[dim - 1]] = trailing;
  return result;
}

inline int32_t ceil(int32_t num, int32_t div){
  return (num + div - 1)/div;
}


machine_shared_layout::machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
                                             Value *&sh_mem_ptr, analysis::shared_layout *layout,
                                             std::map<ir::value *, Value *>& vmap,
                                             std::map<ir::value *, tile *>& tmap)
  : mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {

  Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
  PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
  // double-buffered
  if(layout_->get_double_buffer()) {
    BasicBlock *current = builder_->GetInsertBlock();
    auto info = *layout_->get_double_buffer();
    ir::phi_node *phi = info.phi;
    BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent()));
    if(parent->empty())
      builder_->SetInsertPoint(parent);
    else
      builder_->SetInsertPoint(&*parent->getFirstNonPHI());
    // create pointers
    ptr_ = builder_->CreatePHI(ptr_ty, 2);
    pre_ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layout_)));
    pre_ptr_ = builder_->CreateBitCast(pre_ptr_, ptr_->getType());
    offset_ = builder_->CreatePHI(builder_->getInt32Ty(), 2);
    next_ptr_ = builder_->CreateGEP(ptr_, offset_, "next_ptr");
    builder_->SetInsertPoint(current);
  }
  else{
    size_t offset = alloc_->offset(layout_);
    ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(offset));
    ptr_ = builder_->CreateBitCast(ptr_, ptr_ty);
  }
}


tile* machine_shared_layout::create(ir::value *v) {
  Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
  auto double_buffer = layout_->get_double_buffer();
  // offset
  Value *offset = nullptr;
  if(double_buffer && v == double_buffer->phi)
    offset = offset_;
  // base pointer
  Value *ptr = ptr_;
  if(double_buffer && v == double_buffer->latch)
    ptr = next_ptr_;
  else if(double_buffer && v == double_buffer->first)
    ptr = pre_ptr_;
  // create tile
  return new shared_tile(ty, layout_->get_shape(), layout_->get_order(), ptr, *builder_, offset);
}

machine_distributed_layout::machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
                                                       analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
                                                       analysis::data_layout *layout)
  : mod_(mod), builder_(builder), tgt_(tgt), a_axes_(a_axes), axes_(axes), layout_(layout) {

}

tile *machine_distributed_layout::create(ir::value *v) {
  Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext());
  const auto &shapes = v->get_type()->get_tile_shapes();
  size_t rank = shapes.size();
  std::vector<distributed_axis> axes(rank);
  std::vector<int> order(rank);
  // compute axes
  for(size_t d = 0; d < shapes.size(); d++){
    if(shapes[d] > 1){
      unsigned x = a_axes_->get(v, d);
      axes[d] = axes_.at(x);
    }
    else{
      axes[d].contiguous = 1;
      axes[d].values = {builder_->getInt32(0)};
    }
  }
  // compute order
  std::iota(order.begin(), order.end(), 0);
  auto cmp = [&](int x, int y) {
    unsigned axx = a_axes_->get(v, x);
    unsigned axy = a_axes_->get(v, y);
    size_t posx = layout_->find_axis(axx);
    size_t posy = layout_->find_axis(axy);
    if(posx < rank && posy < rank)
      return layout_->get_order(posx) < layout_->get_order(posy);
    return false;
  };
  std::sort(order.begin(), order.end(), cmp);
  return new distributed_tile(ty, shapes, order, axes, *builder_);
}

machine_mma884_layout::machine_mma884_layout(Module *mod, Builder *builder,
                                             target *tgt, analysis::axes *a_axes,
                                             std::map<unsigned, distributed_axis>& axes,
                                             analysis::mma884_layout* layout)
  : machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {

  Value *warp_size = builder_->getInt32(32);
  Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
  Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
  Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);

  const auto& shape = layout->get_shape();
  if(shape.size() > 3)
    throw std::runtime_error("unsupported");
  bool is_batched = shape.size() >= 3;

  Value *_1 = builder_->getInt32(1);
  Value *_2 = builder_->getInt32(2);
  Value *_3 = builder_->getInt32(3);
  Value *_4 = builder_->getInt32(4);
  Value *_16 = builder_->getInt32(16);

  // fragments per warp
  unsigned fpw_0 = layout->fpw(0);
  unsigned fpw_1 = layout->fpw(1);
  unsigned fpw_2 = is_batched ? layout->fpw(2) : 1;
  // warps per tile
  unsigned wpt_0 = layout->wpt(0);
  unsigned wpt_1 = layout->wpt(1);
  unsigned wpt_2 = is_batched ? layout->wpt(2) : 1;
  // mma warp tile size
  unsigned hmma_wts_0 = fpw_0 * 8;
  unsigned hmma_wts_1 = fpw_1 * 8;
  unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
  // mma block tile size
  unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
  unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
  unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
  // number of repetition
  unsigned num_rep_0 = shape[0] / hmma_bts_0;
  unsigned num_rep_1 = shape[1] / hmma_bts_1;
  unsigned num_rep_2 = is_batched ? shape[2] / hmma_bts_2 : 1;
  // size of each pack (interleaving)
  pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
  pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
  // number of packs (interleaving)
  num_packs_0_ = num_rep_0 / pack_size_0_;
  num_packs_1_ = num_rep_1 / pack_size_1_;

  /* intra warp offset */
  // offset of quad in pair
  Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
                                             builder_->getInt32(fpw_0 * pack_size_0_));
  Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
                                             builder_->getInt32(fpw_1 * pack_size_1_));

  // Quad pair id
  Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
  Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
  pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0));
  pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0));
  pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1));
  // Quad pair offset
  Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_));
  Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_));

  /* inter warp offset */
  Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0));
  Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0));
  Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1));
  Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1));
  Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_));
  Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_));

  /* offsets */
  // a offset
  offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a));
  offset_a_k_ = builder_->CreateAnd(u_thread_id, _3);
  // b offsets
  offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b));
  offset_b_k_ = builder_->CreateAnd(u_thread_id, _3);

  // c offsets
  Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_);
  Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2),
                                          builder_->CreateAdd(warp_offset_j, pair_b_off));

  /* indices */
  // i indices
  std::vector<Value*> idx_i;
  for(unsigned pack = 0; pack < num_packs_0_; pack++)
  for(unsigned ii = 0; ii < pack_size_0_; ii++)
  for(unsigned i = 0; i < 2; i++){
    idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
  }
  // j indices
  std::vector<Value*> idx_j;
  for(unsigned pack = 0; pack < num_packs_1_; pack++)
  for(unsigned jj = 0; jj < pack_size_1_; jj++)
  for(unsigned j = 0; j < 2; j++){
    idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
    idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
  }
  // z indices
  std::vector<Value*> idx_z;
  for(unsigned pack = 0; pack < num_rep_2; pack++)
    idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2)));

  /* axes */
  axes_[layout->get_axis(0)] = distributed_axis{1, idx_i, warp_id_0};
  axes_[layout->get_axis(1)] = distributed_axis{1, idx_j, warp_id_1};
  if(is_batched)
    axes_[layout->get_axis(2)] = distributed_axis{1, idx_z, warp_id_2};
}


machine_scanline_layout::machine_scanline_layout(Module *mod, Builder *builder,
                                                 target *tgt,
                                                 analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
                                                 analysis::scanline_layout* layout)
  : machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {

  Value *warp_size = builder_->getInt32(32);
  Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
  Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
  Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);

  auto order = layout->get_order();
  const auto& shape = layout->get_shape();
  Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
  // Delinearize
  size_t dim = shape.size();
  std::vector<Value*> thread_id(dim);
  for(unsigned k = 0; k < dim - 1; k++){
    Constant *dim_k = builder_->getInt32(layout->mts(order[k]));
    Value *rem = builder_->CreateURem(full_thread_id, dim_k);
    full_thread_id = builder_->CreateUDiv(full_thread_id, dim_k);
    thread_id[order[k]] = rem;
  }
  thread_id[order[dim - 1]] = full_thread_id;
  // Create axes
  for(unsigned k = 0; k < dim; k++) {
    int nts = layout->nts(k);
    int mts = layout->mts(k);
    std::string str_k = std::to_string(k);
    Value *contiguous_k = builder_->getInt32(nts);
    Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
    unsigned per_block = nts * mts;
    unsigned per_thread = nts * shape[k] / per_block;
    std::vector<Value*> idx_list(per_thread);
    for(unsigned n = 0 ; n < per_thread; n++){
      unsigned offset = n / nts * per_block + n % nts;
      idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
    }
    axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
  }
}


}
}
lib/codegen/selection/machine_value.cc
@@ -1,214 +0,0 @@
#include <numeric>
#include <iostream>
#include "llvm/IR/IRBuilder.h"
#include "triton/codegen/selection/machine_value.h"

namespace triton{
namespace codegen{

using namespace llvm;

/* Distributed Tile */
void distributed_tile::init_indices() {
  std::vector<size_t> id(axes_.size(), 0);
  // build
  size_t k = 0;
  while(true) {
    indices_t current;
    for(size_t d = 0; d < id.size(); d++)
      current.push_back(axes_[d].values[id[d]]);
    size_t sz = indices_.size();
    indices_[current] = sz;
    values_[current] = nullptr;
    ordered_indices_.push_back(current);
    id[order_[0]]++;
    while(id[order_[k]] == axes_[order_[k]].values.size()){
      if(k == id.size() - 1)
        return;
      id[order_[k++]] = 0;
      id[order_[k]]++;
    }
    k = 0;
  }
}


distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder)
  : tile(ty, shapes), axes_(axes), order_(order), builder_(builder) {
  init_indices();
}

void distributed_tile::set_value(indices_t idx, Value *x) {
  assert(x->getType() == ty_ && "cannot set a value of different type");
  Value *&result = values_[idx];
  assert(!result && "value cannot be set twice");
  result = x;
}

Value* distributed_tile::get_value(indices_t idx) {
  Value *result = values_.at(idx);
  assert(result && "value has not been set");
  return result;
}

unsigned distributed_tile::get_linear_index(indices_t idx) {
  return indices_[idx];
}

indices_t distributed_tile::get_ordered_indices(unsigned id) {
  return ordered_indices_.at(id);
}


void distributed_tile::for_each(std::function<void (indices_t)> fn, int start, int end) {
  if(end < 0)
    end = ordered_indices_.size() + end + 1;
  for(unsigned i = start; i < end; i++)
    fn(ordered_indices_[i]);
}

void distributed_tile::for_each(std::function<void(indices_t)> fn, std::vector<int> starts, std::vector<int> sizes){
  int rank = sizes.size();
  int len = 1;
  for(int s: sizes)
    len *= s;

  for(int i = 0; i < len; i++){
    indices_t idx(rank);
    int current = i;
    for(int k = 0; k < rank; k++){
      idx[k] = axes_[k].values.at(starts[k] + current % sizes[k]);
      current = current / sizes[k];
    }
    fn(idx);
  }
}


/* Shared Tile */
void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) {
  BinaryOperator *bin_op = dyn_cast<BinaryOperator>(arg);
  Constant *_0 = ConstantInt::get(Type::getInt32Ty(arg->getContext()), 0);
  if(dyn_cast<Constant>(arg)){
    cst = arg;
    non_cst = _0;
    return;
  }
  if(!bin_op || bin_op->getOpcode() != llvm::BinaryOperator::Add){
    non_cst = arg;
    cst = _0;
    return;
  }
  Constant *cst_lhs = dyn_cast<Constant>(bin_op->getOperand(0));
  Constant *cst_rhs = dyn_cast<Constant>(bin_op->getOperand(1));
  if(cst_lhs && cst_rhs){
    cst = arg;
    non_cst = _0;
  }
  else if(cst_lhs){
    cst = cst_lhs;
    non_cst = bin_op->getOperand(1);
  }
  else if(cst_rhs){
    cst = cst_rhs;
    non_cst = bin_op->getOperand(0);
  }
  else{
    non_cst = arg;
    cst = _0;
  }
}

void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx) {
  non_cst_idx.clear();
  cst_idx.clear();
  for(Value *idx: arg_idx){
    Value *non_cst, *cst;
    extract_constant(idx, non_cst, cst);
    non_cst_idx.push_back(non_cst);
    cst_idx.push_back(cst);
  }
}


Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes,
                                  const std::vector<int>& perm, const std::vector<int>& order,
                                  indices_t idx) {
  // strides
  std::vector<Value*> strides(shapes.size(), builder.getInt32(0));
  strides[order[0]] = builder.getInt32(1);
  for(size_t i = 1; i < idx.size(); i++)
    strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]]));
  // result
  Value *result = builder.getInt32(0);
  for(size_t i = 0; i < idx.size(); i++)
    result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i]));
  return result;
}

shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector<int>& perm):
  tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){
  return_vector_ = false;
  if(perm_.empty()){
    perm_.resize(shapes.size());
    std::iota(perm_.begin(), perm_.end(), 0);
  }
}

void shared_tile::set_value(indices_t idx, Value *value) {
  Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx));
  unsigned addr_space = ptr->getType()->getPointerAddressSpace();
  ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
  builder_.CreateStore(value, ptr);
}

void shared_tile::set_vector_size(unsigned vector_size) {
  vector_size_ = vector_size;
}

void shared_tile::set_return_mode(bool return_vector){
  return_vector_ = return_vector;
}


Value* shared_tile::get_value(indices_t idx) {
  indices_t non_cst_idx, cst_idx;
  extract_constant(idx, non_cst_idx, cst_idx);
  Value *&base_ptr = ptr_cache_[non_cst_idx];
  unsigned vector_size = vector_size_;
  Type *ty = ty_;
  if(ty->isHalfTy() && (vector_size % 2 == 0)){
    ty = IntegerType::get(ty->getContext(), 32);
    vector_size = vector_size / 2;
  }
  if(base_ptr == nullptr){
    // BasicBlock* store = builder_.GetInsertBlock();
    // if(!non_cst_idx.empty())
    //   if(isa<Instruction>(non_cst_idx.front())){
    //     builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
    //   }
    base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx));
    if(vector_size_ > 1){
      Type *vec_ty = VectorType::get(ty, vector_size);
      Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
      base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty);
    }
    // builder_.SetInsertPoint(store);
  }
  Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx);
  Value *div = offset;
  if(vector_size_ > 1)
    div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
  Value *ptr = builder_.CreateGEP(base_ptr, div);
  Value *result = builder_.CreateLoad(ptr);
  if(return_vector_ == false && vector_size_ > 1) {
    Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
    result = builder_.CreateExtractElement(result, rem);
  }
  return result;
}


}
}
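For intuition, the stride construction in shared_tile::shared_offset above reduces to the following plain-integer computation (illustrative; identity permutation assumed):

#include <vector>
int linear_offset(const std::vector<int>& shapes, const std::vector<int>& order,
                  const std::vector<int>& idx) {
  // the innermost dimension (order[0]) gets stride 1; each subsequent
  // dimension's stride is the previous stride times the previous extent
  std::vector<int> strides(shapes.size(), 0);
  strides[order[0]] = 1;
  for(size_t i = 1; i < idx.size(); i++)
    strides[order[i]] = strides[order[i-1]] * shapes[order[i-1]];
  int result = 0;
  for(size_t i = 0; i < idx.size(); i++)
    result += idx[i] * strides[i];
  return result;  // shapes={16,8}, order={1,0}: (i,j) -> 8*i + j
}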
lib/codegen/target.cc
@@ -2,6 +2,8 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
+#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/IRBuilder.h"
#include <iostream>
@@ -12,6 +14,12 @@ namespace triton{
namespace codegen{

// base

+nvidia_cu_target* target::as_nvidia() {
+  return dynamic_cast<nvidia_cu_target*>(this);
+}

bool target::is_gpu() const {
  return is_gpu_;
}
@@ -23,7 +31,7 @@ void amd_cl_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *m

Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) {
-  Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier);
-  return builder.CreateCall(barrier, {});
+  return builder.CreateIntrinsic(Intrinsic::amdgcn_s_barrier, {}, {});
}

Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
@@ -43,19 +51,12 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne
    Intrinsic::amdgcn_workgroup_id_y,
    Intrinsic::amdgcn_workgroup_id_z
  };
-  Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]);
-  Value* group_id = builder.CreateCall(get_group_id, {});
+  Value* group_id = builder.CreateIntrinsic(ids[ax], {}, {});
  return group_id;
}

Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
-  static std::array<Intrinsic::ID, 3> ids = {
-    Intrinsic::r600_read_ngroups_x,
-    Intrinsic::r600_read_ngroups_y,
-    Intrinsic::r600_read_ngroups_z
-  };
-  Value* get_num_group = Intrinsic::getDeclaration(module, ids[ax]);
-  return builder.CreateCall(get_num_group, {});
+  throw std::runtime_error("not implemented on AMD");
}

Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
@@ -103,8 +104,7 @@ Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsi
    Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
    Intrinsic::nvvm_read_ptx_sreg_ctaid_z
  };
-  Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]);
-  Value* cta_id = builder.CreateCall(get_cta_id, {});
+  Value* cta_id = builder.CreateIntrinsic(cta_ids[ax], {}, {});
  return cta_id;
}

@@ -124,8 +124,7 @@ Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, un
    Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
    Intrinsic::nvvm_read_ptx_sreg_nctaid_z
  };
-  Value* get_nctaid = Intrinsic::getDeclaration(module, ids[ax]);
-  return builder.CreateCall(get_nctaid, {});
+  return builder.CreateIntrinsic(ids[ax], {}, {});
}

// CPU
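The target.cc hunks above all perform the same mechanical migration; the two forms are equivalent for zero-operand, non-overloaded intrinsics (sketch; the helper name is invented):

llvm::Value* read_ctaid_x(llvm::Module* module, llvm::IRBuilder<>& builder) {
  // old two-step form: materialize the declaration, then call it
  //   llvm::Function* f = llvm::Intrinsic::getDeclaration(
  //       module, llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x);
  //   return builder.CreateCall(f, {});
  // new one-liner: the builder declares the intrinsic on demand
  (void)module;  // no longer needed by the call itself
  return builder.CreateIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {});
}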
lib/codegen/transform/coalesce.cc
@@ -66,7 +66,7 @@ void coalesce::run(ir::module &mod) {


  for(size_t id = 0; id < num_groups; id++) {
-    if(!layout_->get(id)->to_mma884())
+    if(!layout_->get(id)->to_mma())
      continue;
    // extract memory stores
    const auto& values = layout_->values_of(id);
lib/codegen/transform/cts.cc
@@ -28,12 +28,14 @@ inline bool is_shmem_res(ir::value* v){
    return true;
  if(i->get_id() == ir::INST_COPY_TO_SHARED)
    return true;
+  if(i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
+    return true;
  return false;
}


// run pass on module
-void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
+void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
  auto *i = dynamic_cast<ir::instruction*>(x);
  // not an instruction
  if(!i) {
@@ -58,8 +60,9 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool
  // copy
  builder.set_insert_point_after(i);
  ir::value *copy;
-  if(to_shared)
+  if(to_shared){
    copy = builder.create_copy_to_shared(x);
+  }
  else
    copy = builder.create_copy_from_shared(x);
  parent->replace_uses_of_with(x, copy);
lib/codegen/transform/membar.cc
@@ -15,103 +15,105 @@ namespace triton {
namespace codegen{
namespace transform{

-bool membar::intersect(const interval_vec_t &X, interval_t x) {
-  return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
-    bool left_intersect = y.first <= x.first && x.first < y.second;
-    bool right_intersect = y.first <= x.second && x.second < y.second;
-    return left_intersect || right_intersect;
-  });
-}
-
-bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
-  return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){
-    return intersect(X, y);
-  });
-}
-
-void membar::add_reference(ir::value *v, interval_vec_t &res){
-  auto *i = dynamic_cast<ir::instruction*>(v);
-  if(!i)
-    return;
-  if(!i->get_type()->is_tile_ty())
-    return;
-  analysis::shared_layout* layout = layouts_->get(v)->to_shared();
-  if(!layout)
-    return;
-  if(alloc_->has_offset(layout)){
-    unsigned offset = alloc_->offset(layout);
-    res.push_back(interval_t(offset, offset + layout->get_size()));
-  }
-}
-
-void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){
-  for(ir::value *op: i->ops())
-    add_reference(op, res);
-}
-
-void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
-  if(!dynamic_cast<ir::phi_node*>(i) && !dynamic_cast<ir::trans_inst*>(i))
-    add_reference(i, res);
-}
-
-void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) {
-  if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
-    std::set<ir::value*> incoming;
-    for(unsigned n = 0; n < phi->get_num_incoming(); n++){
-      ir::instruction *inc_val = dynamic_cast<ir::instruction*>(phi->get_incoming_value(n));
-      assert(inc_val);
-      if(incoming.insert(inc_val).second){
-        ir::basic_block *block = inc_val->get_parent();
-        builder.set_insert_point(block->get_inst_list().back());
-        builder.create_barrier();
-      }
-    }
-  }
-  else {
-    builder.set_insert_point(instr);
-    builder.create_barrier();
-  }
-}
-
-membar::interval_vec_t membar::join(const std::vector<interval_vec_t>& intervals) {
-  membar::interval_vec_t result;
-  for(auto x: intervals)
-    for(interval_t i: x)
-      result.push_back(i);
-  return result;
-}
-
-std::pair<membar::interval_vec_t,
-          membar::interval_vec_t> membar::transfer(ir::basic_block *block,
-                                                   const interval_vec_t &written_to,
-                                                   const interval_vec_t &read_from,
-                                                   std::set<ir::instruction*>& insert_loc,
-                                                   std::set<ir::value*>& safe_war) {
-  ir::basic_block::inst_list_t instructions = block->get_inst_list();
-  interval_vec_t new_written_to = written_to;
-  interval_vec_t new_read_from = read_from;
-
-  for(ir::instruction *i: instructions){
-    interval_vec_t read, written;
-    get_read_intervals(i, read);
-    get_written_intervals(i, written);
-    bool read_after_write = intersect(new_written_to, read);
-    bool write_after_read = intersect(new_read_from, written);
-    // double buffering
-    if(safe_war.find(i) != safe_war.end()){
-      write_after_read = false;
-      read_after_write = false;
-    }
-    // record hazards
-    if(read_after_write || write_after_read) {
-      insert_loc.insert(i);
-      new_written_to.clear();
-      new_read_from.clear();
-    }
-    std::copy(written.begin(), written.end(), std::back_inserter(new_written_to));
-    std::copy(read.begin(), read.end(), std::back_inserter(new_read_from));
-  }
-  return std::make_pair(new_written_to, new_read_from);
-}
+int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
+  if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
+    analysis::shared_layout* layout = layouts_->get(v)->to_shared();
+    analysis::double_buffer_info_t* info = layout->get_double_buffer();
+    if(info)
+      return group_of(info->first, async_write);
+    std::vector<int> groups(phi->get_num_operands());
+    std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
+    return *std::max_element(groups.begin(), groups.end());
+  }
+  else{
+    auto it = std::find(async_write.begin(), async_write.end(), v);
+    return std::distance(async_write.begin(), it);
+  }
+}
+
+membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
+  val_set_t ret;
+  for(ir::value* a: as){
+    if(!a->get_type()->is_tile_ty())
+      continue;
+    analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
+    if(!a_layout)
+      continue;
+    int a_start = alloc_->offset(a_layout);
+    int a_end = a_start + a_layout->get_size();
+    for(ir::value* b: bs){
+      if(!b->get_type()->is_tile_ty())
+        continue;
+      analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
+      if(!b_layout)
+        continue;
+      int b_start = alloc_->offset(b_layout);
+      int b_end = b_start + b_layout->get_size();
+      if(a_start < b_end || b_start < a_end)
+        ret.insert(b);
+    }
+  }
+  return ret;
+}
+
+void membar::transfer(ir::basic_block *block,
+                      val_vec_t& async_write,
+                      val_set_t& sync_write,
+                      val_set_t& sync_read,
+                      std::set<ir::value*>& safe_war,
+                      bool& inserted, ir::builder& builder) {
+  ir::basic_block::inst_list_t instructions = block->get_inst_list();
+  for(ir::instruction *i: instructions){
+    if(dynamic_cast<ir::phi_node*>(i))
+      continue;
+    if(std::find(async_write.begin(), async_write.end(), i) == async_write.end() &&
+       dynamic_cast<ir::masked_load_async_inst*>(i)){
+      async_write.push_back(i);
+    }
+    if(dynamic_cast<ir::copy_to_shared_inst*>(i))
+      sync_write.insert(i);
+    ir::barrier_inst* barrier = dynamic_cast<ir::barrier_inst*>(i);
+    ir::async_wait_inst* async_wait = dynamic_cast<ir::async_wait_inst*>(i);
+    // Get shared memory reads
+    std::set<ir::value*> read;
+    std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
+                 [&](ir::value* i){ return i->get_type()->is_tile_ty() && layouts_->get(i)->to_shared();});
+    // RAW (async)
+    val_set_t tmp;
+    std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
+    if(intersect_with(read, tmp).size()){
+      std::vector<int> groups(read.size());
+      std::transform(read.begin(), read.end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);});
+      int N = *std::max_element(groups.begin(), groups.end());
+      if(N < async_write.size()){
+        builder.set_insert_point(i);
+        async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
+        barrier = (ir::barrier_inst*)builder.create_barrier();
+        inserted = true;
+      }
+    }
+    // RAW, WAR
+    if(intersect_with(read, sync_write).size() || intersect_with({i}, sync_read).size()){
+      builder.set_insert_point(i);
+      barrier = (ir::barrier_inst*)builder.create_barrier();
+      inserted = true;
+    }
+    // update state of asynchronous copies
+    if(async_wait){
+      int N = async_write.size() - async_wait->get_N();
+      async_write.erase(async_write.begin(), async_write.begin() + N);
+    }
+    // all the copy_to_shared and read from shared are synchronized after barrier
+    if(barrier){
+      sync_write.clear();
+      sync_read.clear();
+    }
+    sync_read.insert(read.begin(), read.end());
+  }
+}

void membar::run(ir::module &mod) {
@@ -125,41 +127,40 @@ void membar::run(ir::module &mod) {
    if(!layout || !layout->get_double_buffer())
      continue;
    for(ir::value *v: layout->get_values())
-      if(v != layout->get_double_buffer()->phi)
+      if(v != layout->get_double_buffer()->phi){
        safe_war.insert(v);
+      }
  }


  for(ir::function *fn: mod.get_function_list()){
    std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
-    std::map<ir::basic_block*, interval_vec_t> written_to;
-    std::map<ir::basic_block*, interval_vec_t> read_from;
-    std::set<ir::instruction*> insert_locs;
-    size_t n_inserted_im1 = 0;
-    bool done = false;
+    std::map<ir::basic_block*, val_vec_t> async_writes;
+    std::map<ir::basic_block*, val_set_t> sync_writes;
+    std::map<ir::basic_block*, val_set_t> sync_reads;
+    std::list<ir::value *> pipelined;
+    bool inserted;
    do{
+      inserted = false;
      // find barrier location
      for(ir::basic_block *block: rpo){
-        // written to
-        std::vector<interval_vec_t> pred_written_to;
-        for(ir::basic_block* pred: block->get_predecessors())
-          pred_written_to.push_back(written_to[pred]);
-        // read from
-        std::vector<interval_vec_t> pred_read_from;
-        for(ir::basic_block* pred: block->get_predecessors())
-          pred_read_from.push_back(read_from[pred]);
-        // apply transfer function
-        auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war);
-        written_to[block] = result.first;
-        read_from[block] = result.second;
+        // join inputs
+        val_vec_t async_write;
+        val_set_t sync_write;
+        val_set_t sync_read;
+        val_set_t tmp;
+        for(ir::basic_block* pred: block->get_predecessors()){
+          for(ir::value* v: async_writes[pred])
+            if(tmp.insert(v).second)
+              async_write.push_back(v);
+          sync_write.insert(sync_writes[pred].begin(), sync_writes[pred].end());
+          sync_read.insert(sync_reads[pred].begin(), sync_reads[pred].end());
+        }
+        transfer(block, async_write, sync_write, sync_read, safe_war, inserted, builder);
+        async_writes[block] = async_write;
+        sync_writes[block] = sync_write;
+        sync_reads[block] = sync_read;
      }
-      size_t n_inserted_i = insert_locs.size();
-      done = (n_inserted_im1 == n_inserted_i);
-      n_inserted_im1 = n_inserted_i;
-    }while(!done);
-    for(ir::instruction* i: insert_locs)
-      insert_barrier(i, builder);
+    }while(inserted);
  }
}
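The async_wait arithmetic in the new transfer() deserves a worked example (illustrative numbers; semantics assumed to match the usual cp.async group-wait behavior, where async_wait(n) blocks until at most n copy groups remain in flight):

int wait_arg(int num_outstanding, int newest_dep_group) {
  // async_write = {c0, c1, c2, c3}, newest dependency is group 1:
  // 4 - 1 - 1 == 2, i.e. retire c0 and c1 but leave c2, c3 running;
  // the barrier emitted right after then publishes the retired data.
  return num_outstanding - 1 - newest_dep_group;
}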
lib/codegen/transform/peephole.cc
@@ -1,7 +1,9 @@
#include <algorithm>
+#include <iostream>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/codegen/transform/peephole.h"
+#include "triton/codegen/analysis/layout.h"

namespace triton {
namespace codegen{
@@ -97,6 +99,33 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){

//}

+bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& builder){
+  auto copy_to_shared = dynamic_cast<ir::copy_to_shared_inst*>(value);
+  if(!copy_to_shared)
+    return false;
+  ir::value *arg = copy_to_shared->get_operand(0);
+  ir::masked_load_inst* ld = dynamic_cast<ir::masked_load_inst*>(arg);
+  if(!ld)
+    return false;
+  builder.set_insert_point(copy_to_shared);
+  ir::value *ptr = ld->get_pointer_operand();
+  ir::value *msk = ld->get_mask_operand();
+  ir::value *val = ld->get_false_value_operand();
+  analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
+  int nts = layout->nts(layout->get_order()[0]);
+  int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
+  if(nts*dtsize >= 4){
+    ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
+    copy_to_shared->replace_all_uses_with(new_load);
+    return true;
+  }
+  return false;
+  // analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
+  // std::cout << layout->nts(layout->get_order(0)) << std::endl;
+  // return true;
+}

bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
  auto x = dynamic_cast<ir::reduce_inst*>(value);
  if(!x)
@@ -164,6 +193,22 @@ bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::buil
  return false;
}

+bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& builder){
+  auto select = dynamic_cast<ir::select_inst*>(value);
+  if(!select)
+    return false;
+  auto if_value = dynamic_cast<ir::masked_load_inst*>(select->get_if_value_op());
+  if(!if_value)
+    return false;
+  if(select->get_pred_op() != if_value->get_mask_operand())
+    return false;
+  builder.set_insert_point(select);
+  ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
+                                                   if_value->get_mask_operand(),
+                                                   select->get_else_value_op());
+  select->replace_all_uses_with(new_load);
+  return true;
+}

void peephole::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();
@@ -197,10 +242,13 @@ void peephole::run(ir::module &mod) {
      continue;
    bool was_modified = false;
    was_modified = was_modified || rewrite_mult(i, builder);
-    // was_modified = was_modified || rewrite_cts_cfs(i, builder);
-    was_modified = was_modified || rewrite_trans_phi(i, builder);
+    // was_modified = was_modified || rewrite_cts_cfs(i, builder);
+    // was_modified = was_modified || rewrite_trans_phi(i, builder);
    was_modified = was_modified || rewrite_unit_red(i, builder);
    was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
+    was_modified = was_modified || rewrite_select_masked_load(i, builder);
+    if(tgt_->as_nvidia()->sm() >= 80)
+      was_modified = was_modified || rewrite_load_to_shared(i, builder);
    if(was_modified)
      seen.insert(i);
  }
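In pseudo-IR, the two peephole rewrites added above do the following (sketch; value names invented):

// rewrite_select_masked_load:
//   %v = masked_load %ptr, %mask, %false0
//   %r = select %mask, %v, %other
// -> %r = masked_load %ptr, %mask, %other     ; both agree wherever %mask holds
//
// rewrite_load_to_shared (sm >= 80, >= 4 contiguous bytes per thread):
//   %v = masked_load %ptr, %mask, %false
//   %s = copy_to_shared %v
// -> %s = masked_load_async %ptr, %mask, %false  ; global -> shared, bypassing registers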
116
lib/codegen/transform/pipeline.cc
Normal file
@@ -0,0 +1,116 @@
#include <iostream>
#include <algorithm>
#include "triton/codegen/transform/pipeline.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"

namespace triton {
namespace codegen{
namespace transform{


void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instruction*>& ret){
  ir::instruction* i = dynamic_cast<ir::instruction*>(v);
  if(!i || i->get_parent() != block)
    return;
  if(i->get_id()==ir::INST_PHI)
    return;
  ret.push_back(i);
  for(ir::user* u: i->get_users())
    recursive_deps(u, block, ret);
}

void pipeline::run(ir::module &mod) {
  // *Very* conservative heuristics for pre-fetching.
  // A load instruction can be pipelined if:
  //   - the pointer is a phi node that references a value
  //     in its basic block (i.e., pointer induction variable)
  //   - the load has only a single use in a dot instruction
  // As more use cases become apparent, this pass will be improved
  std::vector<std::pair<ir::load_inst*, ir::phi_node*>> to_pipeline;
  ir::for_each_instruction(mod, [&](ir::instruction *i){
    if(auto* load = dynamic_cast<ir::load_inst*>(i)){
      ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
      auto users = load->get_users();
      if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
         && users.size() == 1 && dynamic_cast<ir::dot_inst*>(*users.begin()))
        to_pipeline.push_back({load, ptr});
    }});
  // do the pipelining
  std::vector<ir::phi_node*> new_loads;
  ir::builder &builder = mod.get_builder();
  for(auto info: to_pipeline){
    ir::load_inst* load = info.first;
    ir::phi_node* ptr = info.second;
    ir::basic_block* block = load->get_parent();
    ir::basic_block* header = block->get_predecessors()[0];
    auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
    auto* header_br = dynamic_cast<ir::cond_branch_inst*>(header->get_inst_list().back());
    assert(block_br);
    assert(header_br);
    ir::type* ty = load->get_type();
    // pre-fetch first iteration
    builder.set_insert_point(header->get_inst_list().back());
    ir::value* first_ptr = ptr->get_value_for_block(header);
    ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_tile_shapes());
    ir::value* false_value;
    if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
      first_mask = builder.create_and(first_mask, masked_load->get_mask_operand());
      false_value = masked_load->get_false_value_operand();
    }
    else
      false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_tile_shapes());
    ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
    // pre-fetch next iteration
    builder.set_insert_point(block->get_inst_list().back());
    ir::value* next_ptr = ptr->get_value_for_block(block);
    ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_tile_shapes());
    if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load))
      next_mask = builder.create_and(next_mask, masked_load->get_mask_operand());
    ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
    // phi node
    builder.set_insert_point(block->get_first_non_phi());
    ir::phi_node* new_load = builder.create_phi(ty, 2);
    new_load->add_incoming(first_load, header);
    new_load->add_incoming(next_load, block);
    load->replace_all_uses_with(new_load);
    new_loads.push_back(new_load);
  }


  // try to move dot_inst after loads
  // for better overlap of io and compute
  struct move_config_t{
    std::vector<ir::instruction*> insts;
    ir::load_inst* dst;
  };
  std::map<ir::basic_block*, move_config_t> to_move;

  if(has_copy_async_){
    for(ir::function* fn: mod.get_function_list())
    for(ir::basic_block* bb: fn->blocks())
    for(ir::instruction* inst: bb->get_inst_list()){
      if(auto* i = dynamic_cast<ir::dot_inst*>(inst))
        recursive_deps(i, bb, to_move[bb].insts);
      if(auto* i = dynamic_cast<ir::load_inst*>(inst))
        to_move[bb].dst = i;
    }

    for(auto& x: to_move){
      builder.set_insert_point_after(x.second.dst);
      for(ir::instruction* i: x.second.insts){
        x.first->erase(i);
        builder.insert(i);
      }
    }
  }


}

}
}
}
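The loop transformation performed by pipeline::run, sketched on a typical GEMM-style loop (illustrative):

// header:                        header:
//   ...                            x0 = masked_load(p0, splat(hdr_cond), f)
// loop:                          loop:
//   x   = load(p)        ==>       x  = phi [x0, header], [x1, loop]
//   acc = dot(a, b, acc)           acc = dot(a, b, acc)
//   p   = p + step                 p  = p + step
//   br cond, loop, exit            x1 = masked_load(p, splat(cond), f)
//                                  br cond, loop, exit
//
// Splatting the branch condition into the mask makes the pre-fetch issued on
// the last iteration read nothing instead of faulting.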
lib/codegen/transform/reassociate.cc
@@ -22,6 +22,8 @@ inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
inline bool is_cst(ir::value *x) {
  if(dynamic_cast<ir::constant*>(x))
    return true;
+  if(dynamic_cast<ir::make_range*>(x))
+    return true;
  if(auto *v = dynamic_cast<ir::retile_inst*>(x))
    return is_cst(v->get_operand(0));
  return false;
51
lib/codegen/transform/reorder.cc
Normal file
@@ -0,0 +1,51 @@
#include <iostream>
#include <algorithm>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/codegen/transform/reorder.h"

namespace triton {
namespace codegen{
namespace transform{

void reorder::run(ir::module& mod){
  ir::builder &builder = mod.get_builder();
  std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;

  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: fn->blocks())
  for(ir::instruction* i: block->get_inst_list()){
    if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
      ir::value* _ptr = ld->get_pointer_operand();
      ir::value* _msk = ld->get_mask_operand();
      ir::value* _val = ld->get_false_value_operand();
      auto ptr = std::find(block->begin(), block->end(), _ptr);
      auto msk = std::find(block->begin(), block->end(), _msk);
      auto val = std::find(block->begin(), block->end(), _val);
      if(ptr == block->end() || msk == block->end() || val == block->end())
        continue;
      auto it = std::find(block->begin(), block->end(), i);
      int dist_ptr = std::distance(ptr, it);
      int dist_msk = std::distance(msk, it);
      int dist_val = std::distance(val, it);
      if(dist_ptr < dist_msk && dist_ptr < dist_val)
        builder.set_insert_point(++ptr);
      if(dist_msk < dist_ptr && dist_msk < dist_val)
        builder.set_insert_point(++msk);
      if(dist_val < dist_ptr && dist_val < dist_msk)
        builder.set_insert_point(++val);
      ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
      to_replace.push_back(std::make_pair(ld, new_ld));
    }
  }

  for(auto& x: to_replace)
    x.first->replace_all_uses_with(x.second);

}

}
}
}
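// Intended effect of reorder::run (sketch): the masked load is re-created at
// the earliest point where its pointer, mask and false-value are all defined,
// i.e. right after whichever operand appears last in the block, so the memory
// access can issue sooner; the original load is only replaced, not erased,
// and a later dead-code pass is expected to clean it up.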
lib/driver/backend.cc
@@ -47,6 +47,12 @@ void backend::platforms::init() {
  if(dispatch::cuinit()){
    cache_.push_back(new cu_platform());
  }
+  //if host should be added
+  bool host_visible = true;
+  if(host_visible){
+    cache_.push_back(new host_platform());
+  }

  // //if OpenCL is here
  // if(dispatch::clinit()){
  //   cl_uint num_platforms;
@@ -56,11 +62,7 @@ void backend::platforms::init() {
  //   for(cl_platform_id id: ids)
  //     cache_.push_back(new cl_platform(id));
  // }
-  // //if host is here
-  // bool host_visible = true;
-  // if(host_visible){
-  //   cache_.push_back(new host_platform());
-  // }

  if(cache_.empty())
    throw std::runtime_error("Triton: No backend available. Make sure CUDA is available in your library path");
}
@@ -132,7 +134,7 @@ std::map<std::tuple<driver::module*, std::string>, driver::kernel*> backend::ker
void backend::streams::init(std::list<driver::context*> const & contexts){
  for(driver::context* ctx : contexts)
    if(cache_.find(ctx)==cache_.end())
-      cache_.insert(std::make_pair(ctx, std::vector<driver::stream*>{driver::stream::create(ctx)}));
+      cache_.insert(std::make_pair(ctx, std::vector<driver::stream*>{driver::stream::create(ctx->backend())}));
}

void backend::streams::release(){
lib/driver/buffer.cc
@@ -35,66 +35,53 @@ namespace driver

//

-buffer::buffer(driver::context* ctx, size_t size, CUdeviceptr cu, bool take_ownership)
-  : polymorphic_resource(cu, take_ownership), context_(ctx), size_(size) { }
+buffer::buffer(size_t size, CUdeviceptr cu, bool take_ownership)
+  : polymorphic_resource(cu, take_ownership), size_(size) { }

-buffer::buffer(driver::context* ctx, size_t size, cl_mem cl, bool take_ownership)
-  : polymorphic_resource(cl, take_ownership), context_(ctx), size_(size) { }
-
-buffer::buffer(driver::context* ctx, size_t size, host_buffer_t hst, bool take_ownership)
-  : polymorphic_resource(hst, take_ownership), context_(ctx), size_(size) { }
-
-
-driver::context* buffer::context() {
-  return context_;
-}
+buffer::buffer(size_t size, host_buffer_t hst, bool take_ownership)
+  : polymorphic_resource(hst, take_ownership), size_(size) { }

size_t buffer::size() {
  return size_;
}

uintptr_t buffer::addr_as_uintptr_t() {
  switch(backend_){
    case CUDA: return *cu_;
    case Host: return (uintptr_t)hst_->data;
    default: return 0;
  }
}


buffer* buffer::create(driver::context* ctx, size_t size) {
  switch(ctx->backend()){
-    case CUDA: return new cu_buffer(ctx, size);
-    case OpenCL: return new ocl_buffer(ctx, size);
-    case Host: return new host_buffer(ctx, size);
+    case CUDA: return new cu_buffer(size);
+    case Host: return new host_buffer(size);
    default: throw std::runtime_error("unknown backend");
  }
}

//

-host_buffer::host_buffer(driver::context *context, size_t size)
-  : buffer(context, size, host_buffer_t(), true){
+host_buffer::host_buffer(size_t size)
+  : buffer(size, host_buffer_t(), true){
  hst_->data = new char[size];
}

//

-ocl_buffer::ocl_buffer(driver::context* context, size_t size)
-  : buffer(context, size, cl_mem(), true){
-  cl_int err;
-  *cl_ = dispatch::clCreateBuffer(*context->cl(), CL_MEM_READ_WRITE, size, NULL, &err);
-  check(err);
-}
-
-
-//

-cu_buffer::cu_buffer(driver::context* context, size_t size)
-  : buffer(context, size, CUdeviceptr(), true) {
-  cu_context::context_switcher ctx_switch(*context_);
+cu_buffer::cu_buffer(size_t size)
+  : buffer(size, CUdeviceptr(), true) {
  dispatch::cuMemAlloc(&*cu_, size);
}

-cu_buffer::cu_buffer(driver::context* context, size_t size, CUdeviceptr cu, bool take_ownership)
-  : buffer(context, size, cu, take_ownership){
+cu_buffer::cu_buffer(size_t size, CUdeviceptr cu, bool take_ownership)
+  : buffer(size, cu, take_ownership){
}

-void cu_buffer::set_zero(driver::stream* queue, size_t size)
-{
-  cu_context::context_switcher ctx_switch(*context_);
+void cu_buffer::set_zero(driver::stream* queue, size_t size){
  dispatch::cuMemsetD8Async(*cu_, 0, size, *queue->cu());
}
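Call-site impact of the buffer refactor above (sketch; the helper name is invented):

driver::buffer* make_scratch(driver::context* ctx, size_t bytes) {
  // buffer::create still dispatches on the context's backend, but the
  // resulting cu_buffer/host_buffer no longer retains a driver::context*,
  // so allocation does not push/pop a CUDA context anymore.
  return driver::buffer::create(ctx, bytes);
}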
lib/driver/context.cc
@@ -41,11 +41,6 @@ context::context(driver::device *dev, CUcontext cu, bool take_ownership):
  dev_(dev), cache_path_(get_cache_path()) {
}

-context::context(driver::device *dev, cl_context cl, bool take_ownership):
-  polymorphic_resource(cl, take_ownership),
-  dev_(dev), cache_path_(get_cache_path()){
-}
-
context::context(driver::device *dev, host_context_t hst, bool take_ownership):
  polymorphic_resource(hst, take_ownership),
  dev_(dev), cache_path_(get_cache_path()){
@@ -54,7 +49,6 @@ context::context(driver::device *dev, host_context_t hst, bool take_ownership):
context* context::create(driver::device *dev){
  switch(dev->backend()){
    case CUDA: return new cu_context(dev);
-    case OpenCL: return new ocl_context(dev);
    case Host: return new host_context(dev);
    default: throw std::runtime_error("unknown backend");
  }
@@ -100,17 +94,6 @@ host_context::host_context(driver::device* dev): context(dev, host_context_t(),
// CUDA //
/* ------------------------ */

-// RAII context switcher
-cu_context::context_switcher::context_switcher(const context &ctx): ctx_((const cu_context&)ctx) {
-  dispatch::cuCtxPushCurrent_v2(*ctx_.cu());
-}
-
-cu_context::context_switcher::~context_switcher() {
-  CUcontext tmp;
-  dispatch::cuCtxPopCurrent_v2(&tmp);
-  assert(tmp==*ctx_.cu() && "Switching back to invalid context!");
-}
-
// import CUdevice
CUdevice cu_context::get_device_of(CUcontext context){
  dispatch::cuCtxPushCurrent_v2(context);
@@ -127,21 +110,9 @@ cu_context::cu_context(CUcontext context, bool take_ownership): driver::context(

cu_context::cu_context(driver::device* device): context(device, CUcontext(), true){
  dispatch::cuCtxCreate(&*cu_, CU_CTX_SCHED_AUTO, *((driver::cu_device*)dev_)->cu());
-  dispatch::cuCtxPopCurrent_v2(NULL);
+  // dispatch::cuCtxPopCurrent_v2(NULL);
}


-/* ------------------------ */
-// OpenCL //
-/* ------------------------ */
-
-ocl_context::ocl_context(driver::device* dev): context(dev, cl_context(), true) {
-  cl_int err;
-  *cl_ = dispatch::clCreateContext(nullptr, 1, &*dev->cl(), nullptr, nullptr, &err);
-  check(err);
-}
-

}
}
@@ -44,69 +44,10 @@ std::unique_ptr<codegen::target> host_device::make_target() const {
 }

-/* ------------------------ */
-// OpenCL //
-/* ------------------------ */
-
-// maximum amount of shared memory per block
-size_t ocl_device::max_shared_memory() const {
-  throw std::runtime_error("not implemented");
-  // return ocl::info<CL_DEVICE_LOCAL_MEM_SIZE>(*cl_);
-}
-
-size_t ocl_device::max_threads_per_block() const {
-  throw std::runtime_error("not implemented");
-  // return ocl::info<CL_DEVICE_MAX_WORK_ITEM_SIZES>(*cl_).at(0);
-}
-
-std::unique_ptr<codegen::target> ocl_device::make_target() const {
-  return std::unique_ptr<codegen::amd_cl_target>(new codegen::amd_cl_target());
-}
-
 /* ------------------------ */
 // CUDA //
 /* ------------------------ */

-// architecture
-cu_device::Architecture cu_device::nv_arch(std::pair<unsigned int, unsigned int> sm) const {
-  switch(sm.first) {
-    case 7:
-      switch(sm.second){
-        case 0: return Architecture::SM_7_0;
-      }
-    case 6:
-      switch(sm.second){
-        case 0: return Architecture::SM_6_0;
-        case 1: return Architecture::SM_6_1;
-      }
-    case 5:
-      switch(sm.second){
-        case 0: return Architecture::SM_5_0;
-        case 2: return Architecture::SM_5_2;
-        default: return Architecture::UNKNOWN;
-      }
-    case 3:
-      switch(sm.second){
-        case 0: return Architecture::SM_3_0;
-        case 5: return Architecture::SM_3_5;
-        case 7: return Architecture::SM_3_7;
-        default: return Architecture::UNKNOWN;
-      }
-    case 2:
-      switch(sm.second){
-        case 0: return Architecture::SM_2_0;
-        case 1: return Architecture::SM_2_1;
-        default: return Architecture::UNKNOWN;
-      }
-    default: return Architecture::UNKNOWN;
-  }
-}
-
 // information query
 template<CUdevice_attribute attr>
 int cu_device::cuGetInfo() const{
@@ -127,11 +68,6 @@ nvmlDevice_t cu_device::nvml_device() const{
   return map.at(key);
 }

-// architecture
-cu_device::Architecture cu_device::architecture() const{
-  return nv_arch(compute_capability());
-}
-
 // number of address bits
 size_t cu_device::address_bits() const{
   return sizeof(size_t)*8;
@@ -152,17 +88,17 @@ std::string cu_device::pci_bus_id() const{
 }

 // force the device to be interpreted as a particular cc
-void cu_device::interpret_as(std::pair<size_t, size_t> cc){
-  interpreted_as_ = std::make_shared<std::pair<size_t, size_t>>(cc);
+void cu_device::interpret_as(int cc){
+  interpreted_as_ = std::make_shared<int>(cc);
 }

 // compute capability
-std::pair<size_t, size_t> cu_device::compute_capability() const {
+int cu_device::compute_capability() const {
   if(interpreted_as_)
     return *interpreted_as_;
-  size_t _major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>();
-  size_t _minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>();
-  return std::make_pair(_major, _minor);
+  size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>();
+  size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>();
+  return major*10 + minor;
 }

 // maximum number of threads per block
@@ -237,7 +173,7 @@ std::string cu_device::infos() const{

 // target
 std::unique_ptr<codegen::target> cu_device::make_target() const {
-  return std::unique_ptr<codegen::nvidia_cu_target>(new codegen::nvidia_cu_target());
+  return std::unique_ptr<codegen::nvidia_cu_target>(new codegen::nvidia_cu_target(compute_capability()));
 }
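Note: compute_capability() now folds the (major, minor) pair into a single integer, major*10 + minor, so sm_70 becomes 70 and sm_61 becomes 61. A quick round-trip of the same encoding (the helper name is illustrative, not from the diff):

#include <cassert>

// Encode/decode compute capability the way the diff above does.
int encode_cc(int major, int minor) { return major * 10 + minor; }

int main() {
  int cc = encode_cc(7, 0);              // Volta: sm_70
  assert(cc == 70);
  assert(cc / 10 == 7 && cc % 10 == 0);  // recover major/minor
}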
@@ -72,17 +72,6 @@ namespace driver
 #define DEFINE19(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k, t12 l, t13 m, t14 n, t15 o, t16 p, t17 q, t18 r, t19 s)\
 {return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s); }

-//Specialized helpers for OpenCL
-#define OCL_DEFINE1(ret, fname, t1) DEFINE1(clinit, opencl_, ret, fname, t1)
-#define OCL_DEFINE2(ret, fname, t1, t2) DEFINE2(clinit, opencl_, ret, fname, t1, t2)
-#define OCL_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(clinit, opencl_, ret, fname, t1, t2, t3)
-#define OCL_DEFINE4(ret, fname, t1, t2, t3, t4) DEFINE4(clinit, opencl_, ret, fname, t1, t2, t3, t4)
-#define OCL_DEFINE5(ret, fname, t1, t2, t3, t4, t5) DEFINE5(clinit, opencl_, ret, fname, t1, t2, t3, t4, t5)
-#define OCL_DEFINE6(ret, fname, t1, t2, t3, t4, t5, t6) DEFINE6(clinit, opencl_, ret, fname, t1, t2, t3, t4, t5, t6)
-#define OCL_DEFINE7(ret, fname, t1, t2, t3, t4, t5, t6, t7) DEFINE7(clinit, opencl_, ret, fname, t1, t2, t3, t4, t5, t6, t7)
-#define OCL_DEFINE8(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) DEFINE8(clinit, opencl_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8)
-#define OCL_DEFINE9(ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) DEFINE9(clinit, opencl_, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9)
-
 //Specialized helpers for CUDA
 #define CUDA_DEFINE1(ret, fname, t1) DEFINE1(cuinit, cuda_, ret, fname, t1)
 #define CUDA_DEFINE2(ret, fname, t1, t2) DEFINE2(cuinit, cuda_, ret, fname, t1, t2)
@@ -101,15 +90,10 @@ namespace driver
 #define NVML_DEFINE2(ret, fname, t1, t2) DEFINE2(nvmlinit, nvml_, ret, fname, t1, t2)
 #define NVML_DEFINE3(ret, fname, t1, t2, t3) DEFINE3(nvmlinit, nvml_, ret, fname, t1, t2, t3)

-bool dispatch::clinit()
-{
-  if(opencl_==nullptr)
-    opencl_ = dlopen("libOpenCL.so", RTLD_LAZY);
-  return opencl_ != nullptr;
-}
-
 bool dispatch::cuinit(){
   if(cuda_==nullptr){
     putenv((char*)"CUDA_CACHE_DISABLE=1");
     std::string libcuda = tools::getenv("TRITON_LIBCUDA");
     if(libcuda.empty())
       cuda_ = dlopen("libcuda.so", RTLD_LAZY);
@@ -171,6 +155,7 @@ CUDA_DEFINE3(CUresult, cuCtxCreate_v2, CUcontext *, unsigned int, CUdevice)
 CUDA_DEFINE3(CUresult, cuModuleGetFunction, CUfunction *, CUmodule, const char *)
 CUDA_DEFINE1(CUresult, cuStreamSynchronize, CUstream)
 CUDA_DEFINE1(CUresult, cuStreamDestroy_v2, CUstream)
+CUDA_DEFINE2(CUresult, cuStreamGetCtx, CUstream, CUcontext*)
 CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent)
 CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr*, size_t)
 CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr)
@@ -189,46 +174,6 @@ NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, nvmlClockType_t
 NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
 NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t, unsigned int, unsigned int)

-// OpenCL
-cl_int dispatch::clBuildProgram(cl_program a, cl_uint b, const cl_device_id * c, const char * d, void (*e)(cl_program, void *), void * f)
-{ return f_impl<dispatch::clinit>(opencl_, clBuildProgram, clBuildProgram_, "clBuildProgram", a, b, c, d, e, f); }
-
-cl_context dispatch::clCreateContext(const cl_context_properties * a, cl_uint b, const cl_device_id * c, void (*d)(const char *, const void *, size_t, void *), void * e, cl_int * f)
-{ return f_impl<dispatch::clinit>(opencl_, dispatch::clCreateContext, dispatch::clCreateContext_, "clCreateContext", a, b, c, d, e, f); }
-
-OCL_DEFINE9(cl_int, clEnqueueNDRangeKernel, cl_command_queue, cl_kernel, cl_uint, const size_t*, const size_t*, const size_t*, cl_uint, const cl_event*, cl_event*)
-OCL_DEFINE4(cl_int, clSetKernelArg, cl_kernel, cl_uint, size_t, const void *)
-OCL_DEFINE1(cl_int, clReleaseMemObject, cl_mem)
-OCL_DEFINE1(cl_int, clFinish, cl_command_queue)
-OCL_DEFINE5(cl_int, clGetMemObjectInfo, cl_mem, cl_mem_info, size_t, void *, size_t *)
-OCL_DEFINE5(cl_int, clGetCommandQueueInfo, cl_command_queue, cl_command_queue_info, size_t, void *, size_t *)
-OCL_DEFINE1(cl_int, clReleaseContext, cl_context)
-OCL_DEFINE1(cl_int, clReleaseEvent, cl_event)
-OCL_DEFINE9(cl_int, clEnqueueWriteBuffer, cl_command_queue, cl_mem, cl_bool, size_t, size_t, const void *, cl_uint, const cl_event *, cl_event *)
-OCL_DEFINE9(cl_int, clEnqueueReadBuffer, cl_command_queue, cl_mem, cl_bool, size_t, size_t, void *, cl_uint, const cl_event *, cl_event *)
-OCL_DEFINE6(cl_int, clGetProgramBuildInfo, cl_program, cl_device_id, cl_program_build_info, size_t, void *, size_t *)
-OCL_DEFINE1(cl_int, clReleaseDevice, cl_device_id)
-OCL_DEFINE5(cl_int, clGetDeviceIDs, cl_platform_id, cl_device_type, cl_uint, cl_device_id *, cl_uint *)
-OCL_DEFINE5(cl_int, clGetContextInfo, cl_context, cl_context_info, size_t, void *, size_t *)
-OCL_DEFINE5(cl_int, clGetDeviceInfo, cl_device_id, cl_device_info, size_t, void *, size_t *)
-OCL_DEFINE1(cl_int, clReleaseCommandQueue, cl_command_queue)
-OCL_DEFINE3(cl_int, clGetPlatformIDs, cl_uint, cl_platform_id *, cl_uint *)
-OCL_DEFINE5(cl_int, clGetPlatformInfo, cl_platform_id, cl_platform_info, size_t, void *, size_t *)
-OCL_DEFINE5(cl_int, clGetEventProfilingInfo, cl_event, cl_profiling_info, size_t, void *, size_t *)
-OCL_DEFINE7(cl_program, clCreateProgramWithBinary, cl_context, cl_uint, const cl_device_id *, const size_t *, const unsigned char **, cl_int *, cl_int *)
-OCL_DEFINE4(cl_command_queue, clCreateCommandQueue, cl_context, cl_device_id, cl_command_queue_properties, cl_int *)
-OCL_DEFINE1(cl_int, clRetainEvent, cl_event)
-OCL_DEFINE1(cl_int, clReleaseProgram, cl_program)
-OCL_DEFINE1(cl_int, clFlush, cl_command_queue)
-OCL_DEFINE5(cl_int, clGetProgramInfo, cl_program, cl_program_info, size_t, void *, size_t *)
-OCL_DEFINE5(cl_int, clGetKernelInfo, cl_kernel, cl_kernel_info, size_t, void *, size_t *)
-OCL_DEFINE6(cl_int, clGetKernelWorkGroupInfo, cl_kernel, cl_device_id, cl_kernel_work_group_info, size_t, void *, size_t *)
-OCL_DEFINE3(cl_kernel, clCreateKernel, cl_program, const char *, cl_int *)
-OCL_DEFINE4(cl_int, clCreateKernelsInProgram, cl_program, cl_uint, cl_kernel*, cl_uint*)
-OCL_DEFINE5(cl_mem, clCreateBuffer, cl_context, cl_mem_flags, size_t, void *, cl_int *)
-OCL_DEFINE5(cl_program, clCreateProgramWithSource, cl_context, cl_uint, const char **, const size_t *, cl_int *)
-OCL_DEFINE1(cl_int, clReleaseKernel, cl_kernel)
-
 // LLVM to SPIR-V
 int dispatch::initializeLLVMToSPIRVPass(llvm::PassRegistry &registry){
   return f_impl<dispatch::spvllvminit>(spvllvm_, initializeLLVMToSPIRVPass, initializeLLVMToSPIRVPass_, "initializeLLVMToSPIRVPass", std::ref(registry));
@@ -246,47 +191,10 @@ void dispatch::release(){
 }
 }

-void * dispatch::opencl_;
 void* dispatch::cuda_;
 void* dispatch::nvml_;
 void* dispatch::spvllvm_;

-//OpenCL
-void* dispatch::clBuildProgram_;
-void* dispatch::clEnqueueNDRangeKernel_;
-void* dispatch::clSetKernelArg_;
-void* dispatch::clReleaseMemObject_;
-void* dispatch::clFinish_;
-void* dispatch::clGetMemObjectInfo_;
-void* dispatch::clGetCommandQueueInfo_;
-void* dispatch::clReleaseContext_;
-void* dispatch::clReleaseEvent_;
-void* dispatch::clEnqueueWriteBuffer_;
-void* dispatch::clEnqueueReadBuffer_;
-void* dispatch::clGetProgramBuildInfo_;
-void* dispatch::clReleaseDevice_;
-void* dispatch::clCreateContext_;
-void* dispatch::clGetDeviceIDs_;
-void* dispatch::clGetContextInfo_;
-void* dispatch::clGetDeviceInfo_;
-void* dispatch::clReleaseCommandQueue_;
-void* dispatch::clGetPlatformIDs_;
-void* dispatch::clGetPlatformInfo_;
-void* dispatch::clGetEventProfilingInfo_;
-void* dispatch::clCreateProgramWithBinary_;
-void* dispatch::clCreateCommandQueue_;
-void* dispatch::clRetainEvent_;
-void* dispatch::clReleaseProgram_;
-void* dispatch::clFlush_;
-void* dispatch::clGetProgramInfo_;
-void* dispatch::clGetKernelInfo_;
-void* dispatch::clGetKernelWorkGroupInfo_;
-void* dispatch::clCreateKernel_;
-void* dispatch::clCreateKernelsInProgram_;
-void* dispatch::clCreateBuffer_;
-void* dispatch::clCreateProgramWithSource_;
-void* dispatch::clReleaseKernel_;
-
 //CUDA
 void* dispatch::cuCtxGetCurrent_;
 void* dispatch::cuCtxSetCurrent_;
@@ -317,6 +225,7 @@ void* dispatch::cuCtxCreate_v2_;
 void* dispatch::cuModuleGetFunction_;
 void* dispatch::cuStreamSynchronize_;
 void* dispatch::cuStreamDestroy_v2_;
+void* dispatch::cuStreamGetCtx_;
 void* dispatch::cuEventDestroy_v2_;
 void* dispatch::cuMemAlloc_v2_;
 void* dispatch::cuPointerGetAttribute_;
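Note: the DEFINE*/CUDA_DEFINE* macros above all expand to the same pattern — lazily dlopen the driver library once, dlsym the symbol on first call, cache the function pointer, and forward the arguments. A hand-expanded sketch of what one generated wrapper boils down to (my_cuInit and the two statics are illustrative stand-ins; the real code routes through f_impl):

#include <dlfcn.h>
#include <stdexcept>

static void* cuda_lib = nullptr;   // analogous to dispatch::cuda_
static void* cuInit_ptr = nullptr; // analogous to dispatch::cuInit_

int my_cuInit(unsigned int flags) {
  if (!cuda_lib)
    cuda_lib = dlopen("libcuda.so", RTLD_LAZY);  // same call as cuinit()
  if (!cuda_lib)
    throw std::runtime_error("could not load libcuda.so");
  if (!cuInit_ptr)
    cuInit_ptr = dlsym(cuda_lib, "cuInit");      // resolve once, then cache
  return ((int (*)(unsigned int))cuInit_ptr)(flags);
}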
@@ -94,67 +94,6 @@ void check(CUresult err)
 }
 }

-void check(cl_int err)
-{
-  using namespace exception::ocl;
-  switch(err)
-  {
-    case CL_SUCCESS: break;
-    case CL_DEVICE_NOT_FOUND: throw device_not_found();
-    case CL_DEVICE_NOT_AVAILABLE: throw device_not_available();
-    case CL_COMPILER_NOT_AVAILABLE: throw compiler_not_available();
-    case CL_MEM_OBJECT_ALLOCATION_FAILURE: throw mem_object_allocation_failure();
-    case CL_OUT_OF_RESOURCES: throw out_of_resources();
-    case CL_OUT_OF_HOST_MEMORY: throw out_of_host_memory();
-    case CL_PROFILING_INFO_NOT_AVAILABLE: throw profiling_info_not_available();
-    case CL_MEM_COPY_OVERLAP: throw mem_copy_overlap();
-    case CL_IMAGE_FORMAT_MISMATCH: throw image_format_mismatch();
-    case CL_IMAGE_FORMAT_NOT_SUPPORTED: throw image_format_not_supported();
-    case CL_BUILD_PROGRAM_FAILURE: throw build_program_failure();
-    case CL_MAP_FAILURE: throw map_failure();
-
-    case CL_INVALID_VALUE: throw invalid_value();
-    case CL_INVALID_DEVICE_TYPE: throw invalid_device_type();
-    case CL_INVALID_PLATFORM: throw invalid_platform();
-    case CL_INVALID_DEVICE: throw invalid_device();
-    case CL_INVALID_CONTEXT: throw invalid_context();
-    case CL_INVALID_QUEUE_PROPERTIES: throw invalid_queue_properties();
-    case CL_INVALID_COMMAND_QUEUE: throw invalid_command_queue();
-    case CL_INVALID_HOST_PTR: throw invalid_host_ptr();
-    case CL_INVALID_MEM_OBJECT: throw invalid_mem_object();
-    case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR: throw invalid_image_format_descriptor();
-    case CL_INVALID_IMAGE_SIZE: throw invalid_image_size();
-    case CL_INVALID_SAMPLER: throw invalid_sampler();
-    case CL_INVALID_BINARY: throw invalid_binary();
-    case CL_INVALID_BUILD_OPTIONS: throw invalid_build_options();
-    case CL_INVALID_PROGRAM: throw invalid_program();
-    case CL_INVALID_PROGRAM_EXECUTABLE: throw invalid_program_executable();
-    case CL_INVALID_KERNEL_NAME: throw invalid_kernel_name();
-    case CL_INVALID_KERNEL_DEFINITION: throw invalid_kernel_definition();
-    case CL_INVALID_KERNEL: throw invalid_kernel();
-    case CL_INVALID_ARG_INDEX: throw invalid_arg_index();
-    case CL_INVALID_ARG_VALUE: throw invalid_arg_value();
-    case CL_INVALID_ARG_SIZE: throw invalid_arg_size();
-    case CL_INVALID_KERNEL_ARGS: throw invalid_kernel_args();
-    case CL_INVALID_WORK_DIMENSION: throw invalid_work_dimension();
-    case CL_INVALID_WORK_GROUP_SIZE: throw invalid_work_group_size();
-    case CL_INVALID_WORK_ITEM_SIZE: throw invalid_work_item_size();
-    case CL_INVALID_GLOBAL_OFFSET: throw invalid_global_offset();
-    case CL_INVALID_EVENT_WAIT_LIST: throw invalid_event_wait_list();
-    case CL_INVALID_EVENT: throw invalid_event();
-    case CL_INVALID_OPERATION: throw invalid_operation();
-    case CL_INVALID_GL_OBJECT: throw invalid_gl_object();
-    case CL_INVALID_BUFFER_SIZE: throw invalid_buffer_size();
-    case CL_INVALID_MIP_LEVEL: throw invalid_mip_level();
-    case CL_INVALID_GLOBAL_WORK_SIZE: throw invalid_global_work_size();
-#ifdef CL_INVALID_PROPERTY
-    case CL_INVALID_PROPERTY: throw invalid_property();
-#endif
-    default: throw;
-  }
-}
-

 }
 }
@@ -1,40 +0,0 @@
-/* Copyright 2015-2017 Philippe Tillet
- *
- * Permission is hereby granted, free of charge, to any person obtaining
- * a copy of this software and associated documentation files
- * (the "Software"), to deal in the Software without restriction,
- * including without limitation the rights to use, copy, modify, merge,
- * publish, distribute, sublicense, and/or sell copies of the Software,
- * and to permit persons to whom the Software is furnished to do so,
- * subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be
- * included in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- */
-
-#include "triton/driver/event.h"
-
-namespace triton
-{
-namespace driver
-{
-
-float event::elapsed_time() const{
-  float time;
-  dispatch::cuEventElapsedTime(&time, cu_->first, cu_->second);
-  return time;
-}
-
-handle<cu_event_t> const & event::cu() const
-{ return cu_; }
-
-}
-}
@@ -38,15 +38,6 @@ inline void _delete(host_stream_t) { }
 inline void _delete(host_buffer_t x) { if(x.data) delete[] x.data; }
 inline void _delete(host_function_t) { }

-//OpenCL
-inline void _delete(cl_platform_id) { }
-inline void _delete(cl_device_id x) { dispatch::clReleaseDevice(x); }
-inline void _delete(cl_context x) { dispatch::clReleaseContext(x); }
-inline void _delete(cl_program x) { dispatch::clReleaseProgram(x); }
-inline void _delete(cl_kernel x) { dispatch::clReleaseKernel(x); }
-inline void _delete(cl_command_queue x) { dispatch::clReleaseCommandQueue(x); }
-inline void _delete(cl_mem x) { dispatch::clReleaseMemObject(x); }
-
 //CUDA
 inline void _delete(CUcontext x) { dispatch::cuCtxDestroy(x); }
 inline void _delete(CUdeviceptr x) { dispatch::cuMemFree(x); }
@@ -87,14 +78,6 @@ template class handle<CUfunction>;
 template class handle<CUmodule>;
 template class handle<CUPlatform>;

-template class handle<cl_platform_id>;
-template class handle<cl_device_id>;
-template class handle<cl_context>;
-template class handle<cl_program>;
-template class handle<cl_command_queue>;
-template class handle<cl_mem>;
-template class handle<cl_kernel>;
-
 template class handle<host_platform_t>;
 template class handle<host_device_t>;
 template class handle<host_context_t>;
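Note: the _delete overloads above give handle<T> one uniform destruction hook per backend resource type. The idea, reduced to a self-contained sketch (simplified; the real handle<T> is explicitly instantiated per type as shown above):

#include <memory>
#include <functional>

// A simplified handle in the spirit of handle.cc: the last copy to go
// away invokes the backend-specific deleter (one _delete overload per
// resource type in the real code).
template<class T>
class simple_handle {
  std::shared_ptr<T> h_;
public:
  simple_handle(T x, bool own, std::function<void(T)> del)
    : h_(new T(x), [own, del](T* p){ if(own) del(*p); delete p; }) {}
  T& operator*() { return *h_; }
};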
@@ -39,9 +39,6 @@ kernel::kernel(driver::module *program, CUfunction fn, bool has_ownership):
   polymorphic_resource(fn, has_ownership), program_(program){
 }

-kernel::kernel(driver::module *program, cl_kernel fn, bool has_ownership):
-  polymorphic_resource(fn, has_ownership), program_(program){
-}
-
 kernel::kernel(driver::module *program, host_function_t fn, bool has_ownership):
   polymorphic_resource(fn, has_ownership), program_(program){
@@ -50,7 +47,6 @@ kernel::kernel(driver::module *program, host_function_t fn, bool has_ownership):
 kernel* kernel::create(driver::module* program, const char* name) {
   switch(program->backend()){
     case CUDA: return new cu_kernel(program, name);
-    case OpenCL: return new ocl_kernel(program, name);
     case Host: return new host_kernel(program, name);
     default: throw std::runtime_error("unknown backend");
   }
@@ -68,84 +64,29 @@ host_kernel::host_kernel(driver::module* program, const char *name): kernel(prog
   hst_->fn = program->hst()->functions.at(name);
 }

-void host_kernel::setArg(unsigned int index, std::size_t size, void* ptr){
-  if(index + 1 > params_store_.size()){
-    params_store_.resize(index+1);
-    params_.resize(index+1);
-  }
-  params_store_[index].reset(malloc(size), free);
-  memcpy(params_store_[index].get(), ptr, size);
-  params_[index] = params_store_[index].get();
-}
-
-void host_kernel::setArg(unsigned int index, driver::buffer* buffer){
-  if(buffer)
-    kernel::setArg(index, (void*)buffer->hst()->data);
-  else
-    kernel::setArg(index, (std::ptrdiff_t)0);
-}
-
-const std::vector<void *> &host_kernel::params(){
-  return params_;
-}
-
-/* ------------------------ */
-// OpenCL //
-/* ------------------------ */
-
-ocl_kernel::ocl_kernel(driver::module* program, const char* name): kernel(program, cl_kernel(), true) {
-  // cl_uint res;
-  // check(dispatch::clCreateKernelsInProgram(*program->cl(), 0, NULL, &res));
-  // std::cout << res << std::endl;
-  cl_int err;
-  *cl_ = dispatch::clCreateKernel(*program->cl(), "matmul", &err);
-  check(err);
-}
-
-void ocl_kernel::setArg(unsigned int index, std::size_t size, void* ptr) {
-  check(dispatch::clSetKernelArg(*cl_, index, size, ptr));
-}
-
-void ocl_kernel::setArg(unsigned int index, driver::buffer* buffer) {
-  if(buffer)
-    check(dispatch::clSetKernelArg(*cl_, index, sizeof(cl_mem), (void*)&*buffer->cl()));
-  else
-    kernel::setArg(index, (std::ptrdiff_t)0);
-}
-
 /* ------------------------ */
 // CUDA //
 /* ------------------------ */

 cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(program, CUfunction(), true) {
-  cu_params_store_.reserve(64);
-  cu_params_.reserve(64);
   dispatch::cuModuleGetFunction(&*cu_, *program->cu(), name);
-  // dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED);
-}
-
-void cu_kernel::setArg(unsigned int index, std::size_t size, void* ptr){
-  if(index + 1 > cu_params_store_.size()){
-    cu_params_store_.resize(index+1);
-    cu_params_.resize(index+1);
+  dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED);
+  // properties
+  int shared_total, shared_optin, shared_static;
+  int n_spills, n_reg;
+  CUdevice dev;
+  dispatch::cuCtxGetDevice(&dev);
+  dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev);
+  dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
+  dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cu_);
+  dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, *cu_);
+  dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, *cu_);
+  if (shared_optin > 49152){
+    // std::cout << "dynamic shared memory " << shared_optin << " " << shared_static << std::endl;
+    dispatch::cuFuncSetAttribute(*cu_, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
   }
-  cu_params_store_[index].reset(malloc(size), free);
-  memcpy(cu_params_store_[index].get(), ptr, size);
-  cu_params_[index] = cu_params_store_[index].get();
 }

-void cu_kernel::setArg(unsigned int index, driver::buffer* data){
-  if(data)
-    kernel::setArg(index, *data->cu());
-  else
-    kernel::setArg(index, (std::ptrdiff_t)0);
-}
-
-void* const* cu_kernel::cu_params() const
-{ return cu_params_.data(); }
-
 }

 }
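Note: the new cu_kernel constructor above replaces per-argument setArg bookkeeping with a one-time query of the function's resource usage, opting in to the full per-block shared memory on devices that expose more than the default 48 KB (49152 bytes). The same sequence in isolation, against the raw CUDA driver API (error handling elided; illustrative only):

#include <cuda.h>

// Let a kernel use all opt-in shared memory, as the diff above does:
// dynamic limit = per-block opt-in maximum minus what the kernel
// already uses statically.
void enable_max_shared(CUfunction fn) {
  CUdevice dev;
  int shared_optin, shared_static;
  cuCtxGetDevice(&dev);
  cuDeviceGetAttribute(&shared_optin,
      CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
  cuFuncGetAttribute(&shared_static,
      CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fn);
  if (shared_optin > 49152)  // more than the default 48 KB
    cuFuncSetAttribute(fn, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                       shared_optin - shared_static);
}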
@@ -20,14 +20,21 @@
  * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  */
+#include <fstream>
+#include <unistd.h>
+#include <memory>
+#include <regex>
 #include "triton/driver/module.h"
 #include "triton/driver/context.h"
 #include "triton/driver/error.h"
+#include "triton/tools/sha1.hpp"
+#include "triton/tools/sys/getenv.hpp"
+#include "triton/tools/sys/mkdir.hpp"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Verifier.h"
 #include "llvm/IR/IRPrintingPasses.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/CodeGen.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Support/TargetRegistry.h"
@@ -39,6 +46,19 @@
 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
 #include "llvm/Transforms/Utils/Cloning.h"

+std::string exec(const char* cmd) {
+  std::array<char, 128> buffer;
+  std::string result;
+  std::unique_ptr<FILE, decltype(&pclose)> pipe(popen(cmd, "r"), pclose);
+  if (!pipe) {
+    throw std::runtime_error("popen() failed!");
+  }
+  while (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) {
+    result += buffer.data();
+  }
+  return result;
+}
+
 namespace triton
 {
 namespace driver
@@ -60,27 +80,19 @@ void module::init_llvm() {
   }
 }

-module::module(driver::context* ctx, CUmodule mod, bool has_ownership)
-  : polymorphic_resource(mod, has_ownership), ctx_(ctx) {
+module::module(CUmodule mod, bool has_ownership)
+  : polymorphic_resource(mod, has_ownership), spilled_(0) {
 }

-module::module(driver::context* ctx, cl_program mod, bool has_ownership)
-  : polymorphic_resource(mod, has_ownership), ctx_(ctx) {
+module::module(host_module_t mod, bool has_ownership)
+  : polymorphic_resource(mod, has_ownership), spilled_(0) {
 }

-module::module(driver::context* ctx, host_module_t mod, bool has_ownership)
-  : polymorphic_resource(mod, has_ownership), ctx_(ctx) {
-}
-
-driver::context* module::context() const {
-  return ctx_;
-}
-
-module* module::create(driver::context* ctx, std::unique_ptr<llvm::Module> src) {
-  switch(ctx->backend()){
-    case CUDA: return new cu_module(ctx, std::move(src));
-    case OpenCL: return new ocl_module(ctx, std::move(src));
-    case Host: return new host_module(ctx, std::move(src));
+module* module::create(driver::device* device, std::unique_ptr<llvm::Module> src) {
+  switch(device->backend()){
+    case CUDA: return new cu_module(device, std::move(src));
+    case Host: return new host_module(std::move(src));
     default: throw std::runtime_error("unknown backend");
   }
 }
@@ -90,12 +102,150 @@ void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std
   llvm::SmallVectorImpl<char> &buffer,
   const std::string& features,
   file_type_t ft) {
 }

 /* ------------------------ */
 // Host //
 /* ------------------------ */

+host_module::host_module(std::unique_ptr<llvm::Module> src): module(host_module_t(), true) {
+  init_llvm();
+  // // debug
+  // create kernel wrapper
+  llvm::LLVMContext &ctx = src->getContext();
+  llvm::Type *void_ty = llvm::Type::getVoidTy(ctx);
+  llvm::Type *args_ty = llvm::Type::getInt8PtrTy(ctx)->getPointerTo();
+  llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx);
+  std::vector<llvm::Type*> tys = {args_ty, int32_ty, int32_ty, int32_ty};
+  llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, tys, false);
+  llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "_main", &*src);
+  llvm::Function* fn = &*src->getFunctionList().begin();
+  llvm::FunctionType *fn_ty = fn->getFunctionType();
+  std::vector<llvm::Value*> fn_args(fn_ty->getNumParams());
+  std::vector<llvm::Value*> ptrs(fn_args.size() - 3);
+  llvm::BasicBlock* entry = llvm::BasicBlock::Create(ctx, "entry", main);
+  llvm::IRBuilder<> ir_builder(ctx);
+  ir_builder.SetInsertPoint(entry);
+  auto get_size = [](llvm::Type* ty) { return ty->isPointerTy() ? sizeof(char*) : ty->getPrimitiveSizeInBits() / 8; };
+  llvm::Value* base = main->arg_begin();
+  llvm::Value* args_base = ir_builder.CreateBitCast(base, base->getType()->getPointerElementType());
+
+  size_t offset = 0;
+  for(unsigned i = 0; i < ptrs.size(); i++){
+    ptrs[i] = ir_builder.CreateGEP(args_base, ir_builder.getInt32(offset));
+    size_t nbytes = get_size(fn_ty->getParamType(i));
+    offset += nbytes;
+    if(i < ptrs.size() - 1){
+      size_t np1bytes = get_size(fn_ty->getParamType(i+1));
+      offset = (offset + np1bytes - 1) / np1bytes * np1bytes;
+    }
+  }
+  for(unsigned i = 0; i < ptrs.size(); i++)
+    ptrs[i] = ir_builder.CreateBitCast(ptrs[i], fn_ty->getParamType(i)->getPointerTo());
+  for(unsigned i = 0; i < ptrs.size(); i++)
+    fn_args[i] = ir_builder.CreateLoad(ptrs[i]);
+
+  fn_args[fn_args.size() - 3] = main->arg_begin() + 1;
+  fn_args[fn_args.size() - 2] = main->arg_begin() + 2;
+  fn_args[fn_args.size() - 1] = main->arg_begin() + 3;
+  ir_builder.CreateCall(fn, fn_args);
+  ir_builder.CreateRetVoid();
+
+  // llvm::legacy::PassManager pm;
+  // pm.add(llvm::createPrintModulePass(llvm::outs()));
+  // pm.add(llvm::createVerifierPass());
+  // pm.run(*module);
+  // pm.run(*src);
+
+  // create execution engine
+  for(llvm::Function& fn: src->functions())
+    hst_->functions[fn.getName().str()] = &fn;
+
+  // llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost();
+  // auto DL = JTMB.getDefaultDataLayoutForTarget();
+  // auto CIRC = std::unique_ptr<llvm::orc::ConcurrentIRCompiler>(new llvm::orc::ConcurrentIRCompiler(JTMB));
+  // hst_->ES = new llvm::orc::ExecutionSession();
+  // hst_->ObjectLayer = new llvm::orc::RTDyldObjectLinkingLayer(*hst_->ES, []() { return std::unique_ptr<llvm::SectionMemoryManager>(new llvm::SectionMemoryManager()); });
+  // hst_->CompileLayer = new llvm::orc::IRCompileLayer(*hst_->ES, *hst_->ObjectLayer, *CIRC);
+  // hst_->DL = new llvm::DataLayout(std::move(*DL));
+  // hst_->Mangle = new llvm::orc::MangleAndInterner(*hst_->ES, *hst_->DL);
+  // hst_->Ctx = new llvm::orc::ThreadSafeContext(std::unique_ptr<llvm::LLVMContext>(new llvm::LLVMContext()));
+  // hst_->MainJD = &hst_->ES->createJITDylib("<main>");
+  // hst_->MainJD->setGenerator(llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
+  //   hst_->DL->getGlobalPrefix())));
+  // llvm::cantFail(hst_->CompileLayer->add(*hst_->MainJD, llvm::orc::ThreadSafeModule(std::move(src), *hst_->Ctx)));
+  // hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->ES->lookup({hst_->MainJD}, (*hst_->Mangle)("_main"))->getAddress());
+
+  llvm::EngineBuilder builder(std::move(src));
+  builder.setErrorStr(&hst_->error);
+  builder.setMCJITMemoryManager(std::make_unique<llvm::SectionMemoryManager>());
+  builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
+  builder.setEngineKind(llvm::EngineKind::JIT);
+  hst_->engine = builder.create();
+  hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->engine->getFunctionAddress("_main"));
+}
+
+std::unique_ptr<buffer> host_module::symbol(const char *name) const {
+  throw std::runtime_error("not implemented");
+}
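Note: the wrapper built above walks a packed argument buffer, rounding each offset up to the size of the next parameter — `offset = (offset + np1bytes - 1) / np1bytes * np1bytes` is the usual round-up-to-multiple trick. A worked example of the same computation (the sizes chosen are illustrative):

#include <cstddef>
#include <cstdio>

// Round offset up to the next multiple of align, exactly as the
// host_module wrapper above does between consecutive parameters.
size_t align_up(size_t offset, size_t align) {
  return (offset + align - 1) / align * align;
}

int main() {
  // e.g. a char* (8 bytes), then an int32 (4 bytes), then another char*.
  size_t off = 0;
  off += 8;                    // pointer argument
  off = align_up(off, 4);      // next is int32: already aligned -> 8
  off += 4;                    // int32 argument -> 12
  off = align_up(off, 8);      // next is a pointer: rounds up to 16
  std::printf("%zu\n", off);   // prints 16
}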
 /* ------------------------ */
 // CUDA //
 /* ------------------------ */
+static bool find_and_replace(std::string& str, const std::string& begin, const std::string& end, const std::string& target){
+  size_t start_replace = str.find(begin);
+  size_t end_replace = str.find(end, start_replace);
+  if(start_replace == std::string::npos)
+    return false;
+  str.replace(start_replace, end_replace + 1 - start_replace, target);
+  return true;
+}
+
+static std::map<int, int> vptx = {
+  {10000, 63},
+  {10010, 64},
+  {10020, 65},
+  {11000, 70},
+  {11010, 71},
+  {11020, 72}
+};
+
+std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
+  // LLVM version in use may not officially support target hardware
+  int max_nvvm_cc = 75;
+  int max_nvvm_ptx = 64;
+  // options
+  auto options = llvm::cl::getRegisteredOptions();
+  auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
+  assert(short_ptr);
+  short_ptr->setValue(true);
+  // compute capability
+  int cc = ((driver::cu_device*)device)->compute_capability();
+  std::string sm = "sm_" + std::to_string(cc);
+  // driver version
+  int version;
+  dispatch::cuDriverGetVersion(&version);
+  int major = version / 1000;
+  int minor = (version - major*1000) / 10;
+  if(major < 10)
+    throw std::runtime_error("Triton requires CUDA 10+");
+  // PTX version
+  int ptx = vptx.at(version);
+  int ptx_major = ptx / 10;
+  int ptx_minor = ptx % 10;
+  // create
+  llvm::SmallVector<char, 0> buffer;
+  std::string triple = "nvptx64-nvidia-cuda";
+  std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
+  std::string layout = "";
+  std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
+  init_llvm();
+  // verify and store llvm
+  llvm::legacy::PassManager pm;
+  pm.add(llvm::createVerifierPass());
+  pm.run(*module);
+  // create machine
+  module->setTargetTriple(triple);
+  std::string error;
@@ -117,176 +267,126 @@ void module::compile_llvm_module(std::unique_ptr<llvm::Module> module, const std
   f.addFnAttr(llvm::Attribute::AlwaysInline);
   llvm::legacy::PassManager pass;
   llvm::raw_svector_ostream stream(buffer);
-  // convert triton file type to llvm file type
-  auto ll_file_type = [&](module::file_type_t type){
-    if(type == Object)
-      return llvm::TargetMachine::CGFT_ObjectFile;
-    return llvm::TargetMachine::CGFT_AssemblyFile;
-  };
   // emit
-  machine->addPassesToEmitFile(pass, stream, nullptr, ll_file_type(ft));
+  machine->addPassesToEmitFile(pass, stream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile);
   pass.run(*module);

+  // post-process
+  std::string result(buffer.begin(), buffer.end());
+  find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
+  find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
+  while(find_and_replace(result, "\t// begin inline asm", "\n", ""));
+  while(find_and_replace(result, "\t// end inline asm", "\n", ""));
+  return result;
+}
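Note: the vptx table above maps the integer reported by cuDriverGetVersion (e.g. 11020 for CUDA 11.2) to the highest PTX ISA version that driver can JIT (72, i.e. PTX 7.2); the emitted assembly then gets its .version directive rewritten to match. The lookup in isolation (mirroring the map in the diff; illustrative):

#include <map>
#include <string>

// CUDA driver version -> PTX ISA version, as in the vptx map above.
static const std::map<int, int> vptx_lookup = {
  {10000, 63}, {10010, 64}, {10020, 65},
  {11000, 70}, {11010, 71}, {11020, 72},
};

std::string ptx_version_directive(int driver_version) {
  int ptx = vptx_lookup.at(driver_version);  // e.g. 11020 -> 72
  return ".version " + std::to_string(ptx / 10) + "." +
         std::to_string(ptx % 10) + "\n";    // -> ".version 7.2"
}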
-/* ------------------------ */
-// Host //
-/* ------------------------ */
-
-host_module::host_module(driver::context * context, std::unique_ptr<llvm::Module> src): module(context, host_module_t(), true) {
-  init_llvm();
-  // host info
-  // std::string triple = llvm::sys::getDefaultTargetTriple();
-  // std::string cpu = llvm::sys::getHostCPUName();
-  // llvm::SmallVector<char, 0> buffer;
-  // module::compile_llvm_module(src, triple, cpu, "", buffer, "", Assembly);
-
-  // create kernel wrapper
-  llvm::LLVMContext &ctx = src->getContext();
-  llvm::Type *void_ty = llvm::Type::getVoidTy(ctx);
-  llvm::Type *args_ty = llvm::Type::getInt8PtrTy(ctx)->getPointerTo();
-  llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx);
-  std::vector<llvm::Type*> tys = {args_ty, int32_ty, int32_ty, int32_ty};
-  llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, tys, false);
-  llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "main", &*src);
-  llvm::Function* fn = src->getFunction("matmul");
-  llvm::FunctionType *fn_ty = fn->getFunctionType();
-  std::vector<llvm::Value*> fn_args(fn_ty->getNumParams());
-  std::vector<llvm::Value*> ptrs(fn_args.size() - 3);
-  llvm::BasicBlock* entry = llvm::BasicBlock::Create(ctx, "entry", main);
-  llvm::IRBuilder<> ir_builder(ctx);
-  ir_builder.SetInsertPoint(entry);
-  for(unsigned i = 0; i < ptrs.size(); i++)
-    ptrs[i] = ir_builder.CreateGEP(main->arg_begin(), ir_builder.getInt32(i));
-  for(unsigned i = 0; i < ptrs.size(); i++){
-    llvm::Value* addr = ir_builder.CreateBitCast(ir_builder.CreateLoad(ptrs[i]), fn_ty->getParamType(i)->getPointerTo());
-    fn_args[i] = ir_builder.CreateLoad(addr);
-  }
-  fn_args[fn_args.size() - 3] = main->arg_begin() + 1;
-  fn_args[fn_args.size() - 2] = main->arg_begin() + 2;
-  fn_args[fn_args.size() - 1] = main->arg_begin() + 3;
-  ir_builder.CreateCall(fn, fn_args);
-  ir_builder.CreateRetVoid();
-
-  // create execution engine
-  for(llvm::Function& fn: src->functions())
-    hst_->functions[fn.getName()] = &fn;
-  llvm::EngineBuilder builder(std::move(src));
-  builder.setErrorStr(&hst_->error);
-  builder.setMCJITMemoryManager(llvm::make_unique<llvm::SectionMemoryManager>());
-  builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
-  builder.setEngineKind(llvm::EngineKind::JIT);
-  builder.setUseOrcMCJITReplacement(true);
-  hst_->engine = builder.create();
-}
-
-std::unique_ptr<buffer> host_module::symbol(const char *name) const {
-  throw std::runtime_error("not implemented");
-}
-
-/* ------------------------ */
-// OpenCL //
-/* ------------------------ */
-
-ocl_module::ocl_module(driver::context * context, std::unique_ptr<llvm::Module> src): module(context, cl_program(), true) {
-  throw std::runtime_error("not supported");
-  // init_llvm();
-  // llvm::SmallVector<char, 0> buffer;
-  // module::compile_llvm_module(src, "amdgcn-amd-amdhsa-amdgizcl", "gfx902", "", buffer, "code-object-v3", Object);
-  // std::ofstream output("/tmp/tmp.o", std::ios::binary);
-  // std::copy(buffer.begin(), buffer.end(), std::ostreambuf_iterator<char>(output));
-  // system("ld.lld-8 /tmp/tmp.o -shared -o /tmp/tmp.o");
-  // std::ifstream input("/tmp/tmp.o", std::ios::in | std::ios::binary );
-  // std::vector<unsigned char> in_buffer(std::istreambuf_iterator<char>(input), {});
-  // size_t sizes[] = {in_buffer.size()};
-  // const unsigned char* data[] = {(unsigned char*)in_buffer.data()};
-  // cl_int status;
-  // cl_int err;
-  // *cl_ = dispatch::clCreateProgramWithBinary(*context->cl(), 1, &*context->device()->cl(), sizes, data, &status, &err);
-  // check(status);
-  // check(err);
-  // try{
-  //   dispatch::clBuildProgram(*cl_, 1, &*context->device()->cl(), NULL, NULL, NULL);
-  // }
-  // catch(...){
-  //   char log[2048];
-  //   dispatch::clGetProgramBuildInfo(*cl_, *context->device()->cl(), CL_PROGRAM_BUILD_LOG, 1024, log, NULL);
-  //   throw;
-  // }
-}
-
-std::unique_ptr<buffer> ocl_module::symbol(const char *name) const {
-  throw std::runtime_error("not implemented");
-}
-
-/* ------------------------ */
-// CUDA //
-/* ------------------------ */
-static bool find_and_replace(std::string& str, const std::string& begin, const std::string& end, const std::string& target){
-  size_t start_replace = str.find(begin);
-  size_t end_replace = str.find(end, start_replace);
-  if(start_replace == std::string::npos)
-    return false;
-  str.replace(start_replace, end_replace + 1 - start_replace, target);
-  return true;
-}
-
-std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device) {
-  // options
-  auto options = llvm::cl::getRegisteredOptions();
-  // for(auto& opt: options)
-  //   std::cout << opt.getKey().str() << std::endl;
-  auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
-  assert(short_ptr);
-  short_ptr->setValue(true);
-  // compute capability
-  auto cc = ((driver::cu_device*)device)->compute_capability();
-  std::string sm = "sm_" + std::to_string(cc.first) + std::to_string(cc.second);
-  // create
-  llvm::SmallVector<char, 0> buffer;
-  module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "ptx63", Assembly);
-  std::string result(buffer.begin(), buffer.end());
-  int version;
-  dispatch::cuDriverGetVersion(&version);
-  int major = version / 1000;
-  // int minor = (version - major*1000) / 10;
-  if(major < 10)
-    throw std::runtime_error("Triton requires CUDA 10+");
-  find_and_replace(result, ".version", "\n", ".version 6.4\n");
-  while(find_and_replace(result, "\t// begin inline asm", "\n", ""));
-  while(find_and_replace(result, "\t// end inline asm", "\n", ""));
-  return result;
-}
-
-cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
-
-cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
-  cu_context::context_switcher ctx(*context);
-  // std::cout << source << std::endl;
+void cu_module::init_from_ptx(const std::string& ptx) {
   // JIT compile source-code
-  CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
-  unsigned int errbufsize = 8096;
-  std::string errbuf(errbufsize, 0);
-  void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)errbuf.data()};
+  // std::cout << ptx << std::endl;

   try{
-    dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
-  }catch(exception::cuda::base const &){
-#ifdef TRITON_LOG_PTX_ERROR
-    std::cerr << "Compilation Failed! Log: " << std::endl;
-    std::cerr << errbuf << std::endl;
-#endif
+    // // compile ptx with ptxas
+    // char _fsrc[] = "/tmp/triton_k_XXXXXX";
+    // char _flog[] = "/tmp/triton_l_XXXXXX";
+    // int fdsrc = mkstemp(_fsrc);
+    // int fdlog = mkstemp(_flog);
+    // std::string fsrc = _fsrc;
+    // std::string flog = _flog;
+    // std::ofstream ofs(fsrc);
+    // ofs << ptx;
+    // ofs.close();
+    // std::string cmd;
+    // int err;
+    // driver::cu_device* cu_device = (driver::cu_device*)device;
+    // cmd = "ptxas -v --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
+    // err = system(cmd.c_str());
+    // dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str());
+    // std::ifstream file(flog);
+    // std::string log;
+    // if(file)
+    //   while (!file.eof()) log.push_back(file.get());
+    // unlink(_fsrc);
+    // unlink(_flog);
+
+    // std::smatch match;
+    // std::regex expr ("\\b([0-9]+) bytes spill");
+    // spilled_ = 0;
+    // while (std::regex_search (log,match,expr)){
+    //   spilled_ += std::stoi(match[1]);
+    //   log = match.suffix();
+    // }
+    // std::cout << log << std::endl;
+
+    CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER,
+                          CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER,
+                          CU_JIT_LOG_VERBOSE};
+    unsigned int errbufsize = 8192;
+    unsigned int logbufsize = 8192;
+    char _err[errbufsize];
+    char _log[logbufsize];
+    void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1};
+    dispatch::cuModuleLoadDataEx(&*cu_, ptx_.data(), 5, opt, optval);
+    std::string err(_err);
+    std::string log(_log);
+    // std::smatch match;
+    // std::regex expr ("\\b([0-9]+) bytes spill");
+    // spilled_ = 0;
+    // while (std::regex_search(log,match,expr)){
+    //   spilled_ += std::stoi(match[1]);
+    //   log = match.suffix();
+    // }
+  }
+  catch(exception::cuda::invalid_ptx const &){
+//#ifdef TRITON_LOG_PTX_ERROR
+    std::cout << ptx << std::endl;
+    std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
+    // exit(1);
+//#endif
+    throw;
+  }
+}
+
+cu_module::cu_module(driver::device* device, std::unique_ptr<llvm::Module> ll_module): module(CUmodule(), true) {
+  llvm::raw_string_ostream oss(llir_);
+  oss << *ll_module;
+  oss.flush();
+  std::string cache_path = tools::getenv("TRITON_DEBUG_CACHE_PATH");
+  if(cache_path.empty())
+    ptx_ = compile_llvm_module(std::move(ll_module), device);
+  else{
+    tools::mkdir(cache_path);
+    // update cache path to PTX file
+    unsigned char hash[20];
+    sha1::calc((void*)llir_.data(), llir_.size(), hash);
+    char _hex[40];
+    sha1::toHexString(hash, _hex);
+    std::string hex(_hex, _hex + 40);
+    cache_path += "/" + hex;
+    // read
+    std::ifstream ifs(cache_path);
+    std::ostringstream _ptx;
+    if(ifs)
+      _ptx << ifs.rdbuf();
+    ptx_ = _ptx.str();
+    // compile and write-back if read empty
+    if(ptx_.empty()){
+      ptx_ = compile_llvm_module(std::move(ll_module), device);
+      std::ofstream ofs(cache_path);
+      ofs << ptx_;
+    }
+  }
+  init_from_ptx(ptx_);
+}
+
+cu_module::cu_module(driver::device*, std::string const & source) : module(CUmodule(), true), ptx_(source){
+  init_from_ptx(ptx_);
+}

 std::unique_ptr<buffer> cu_module::symbol(const char *name) const{
   CUdeviceptr handle;
   size_t size;
   dispatch::cuModuleGetGlobal_v2(&handle, &size, *cu_, name);
-  std::unique_ptr<buffer> res(new cu_buffer(ctx_, size, handle, false));
+  std::unique_ptr<buffer> res(new cu_buffer(size, handle, false));
   return std::move(res);
 }
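Note: when TRITON_DEBUG_CACHE_PATH is set, the new cu_module constructor above keys its PTX cache by a SHA-1 of the LLVM IR text — read the file named by the hash if it exists, otherwise compile and write it back. The control flow, reduced to a sketch (compile_fn stands in for compile_llvm_module and hash_hex for the sha1 helpers used above; both are assumptions for illustration):

#include <fstream>
#include <sstream>
#include <string>
#include <functional>

// Content-addressed cache lookup, as in cu_module above: the key is a
// hash of the IR, the value is the compiled PTX.
std::string cached_compile(const std::string& cache_dir,
                           const std::string& llir,
                           std::function<std::string()> compile_fn,
                           std::function<std::string(const std::string&)> hash_hex) {
  std::string path = cache_dir + "/" + hash_hex(llir);
  std::ifstream ifs(path);
  std::ostringstream ptx;
  if (ifs) ptx << ifs.rdbuf();   // cache hit: reuse the stored PTX
  std::string result = ptx.str();
  if (result.empty()) {          // cache miss: compile, then write back
    result = compile_fn();
    std::ofstream(path) << result;
  }
  return result;
}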
@@ -51,27 +51,6 @@ void cu_platform::devices(std::vector<device *> &devices) const{
 }
 }

-/* ------------------------ */
-// OpenCL //
-/* ------------------------ */
-
-std::string cl_platform::version() const {
-  size_t size;
-  check(dispatch::clGetPlatformInfo(*cl_, CL_PLATFORM_VERSION, 0, nullptr, &size));
-  std::string result(size, 0);
-  check(dispatch::clGetPlatformInfo(*cl_, CL_PLATFORM_VERSION, size, (void*)&*result.begin(), nullptr));
-  return result;
-}
-
-void cl_platform::devices(std::vector<device*> &devices) const{
-  cl_uint num_devices;
-  check(dispatch::clGetDeviceIDs(*cl_, CL_DEVICE_TYPE_GPU, 0, nullptr, &num_devices));
-  std::vector<cl_device_id> ids(num_devices);
-  check(dispatch::clGetDeviceIDs(*cl_, CL_DEVICE_TYPE_GPU, num_devices, ids.data(), nullptr));
-  for(cl_device_id id: ids)
-    devices.push_back(new driver::ocl_device(id));
-}
-
 /* ------------------------ */
 // Host //
 /* ------------------------ */
@@ -21,12 +21,12 @@
 */

 #include <cassert>
 #include <unistd.h>
 #include <array>
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
 #include "triton/driver/context.h"
 #include "triton/driver/device.h"
 #include "triton/driver/event.h"
 #include "triton/driver/kernel.h"
 #include "triton/driver/buffer.h"
 #include "llvm/ExecutionEngine/ExecutionEngine.h"
@@ -42,51 +42,49 @@ namespace driver
 // Base //
 /* ------------------------ */

-stream::stream(driver::context *ctx, CUstream cu, bool has_ownership)
-  : polymorphic_resource(cu, has_ownership), ctx_(ctx) {
+stream::stream(CUstream cu, bool has_ownership)
+  : polymorphic_resource(cu, has_ownership) {
 }

-stream::stream(driver::context *ctx, cl_command_queue cl, bool has_ownership)
-  : polymorphic_resource(cl, has_ownership), ctx_(ctx) {
+stream::stream(host_stream_t cl, bool has_ownership)
+  : polymorphic_resource(cl, has_ownership) {
 }

-stream::stream(driver::context *ctx, host_stream_t cl, bool has_ownership)
-  : polymorphic_resource(cl, has_ownership), ctx_(ctx) {
-}
-
-driver::stream* stream::create(driver::context* ctx) {
-  switch(ctx->backend()){
-    case CUDA: return new cu_stream(ctx);
-    case OpenCL: return new cl_stream(ctx);
-    case Host: return new host_stream(ctx);
+driver::stream* stream::create(backend_t backend) {
+  switch(backend){
+    case CUDA: return new cu_stream();
+    case Host: return new host_stream();
     default: throw std::runtime_error("unknown backend");
   }
 }

-driver::context* stream::context() const {
-  return ctx_;
-}
-
 /* ------------------------ */
 // Host //
 /* ------------------------ */

-host_stream::host_stream(driver::context *ctx): stream(ctx, host_stream_t(), true) {
+host_stream::host_stream(): stream(host_stream_t(), true) {
+  hst_->pool.reset(new ThreadPool(1));
+  hst_->futures.reset(new std::vector<std::future<void>>());
 }

 void host_stream::synchronize() {
+  for(auto& x: *hst_->futures)
+    x.wait();
+  hst_->futures->clear();
+  hst_->args.clear();
 }

-void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **extra) {
-  driver::host_kernel* hst_kernel = (host_kernel*)kernel;
-  llvm::ExecutionEngine* engine = kernel->module()->hst()->engine;
-  void (*fn)(char**, int32_t, int32_t, int32_t) = (void(*)(char**, int32_t, int32_t, int32_t))engine->getFunctionAddress("main");
+void host_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t) {
+  auto hst = kernel->module()->hst();
+  hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]);
+  char* params = new char[args_size];
+  std::memcpy((void*)params, (void*)args, args_size);
   for(size_t i = 0; i < grid[0]; i++)
     for(size_t j = 0; j < grid[1]; j++)
       for(size_t k = 0; k < grid[2]; k++)
-        fn((char**)hst_kernel->params().data(), int32_t(i), int32_t(j), int32_t(k));
+        hst_->futures->emplace_back(hst_->pool->enqueue(hst->fn, (char**)params, int32_t(i), int32_t(j), int32_t(k)));
 }

 void host_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
@@ -98,68 +96,33 @@ void host_stream::read(driver::buffer* buffer, bool blocking, std::size_t offset
|
||||
}
|
||||
|
||||
|
||||
/* ------------------------ */
|
||||
// OpenCL //
|
||||
/* ------------------------ */
|
||||
|
||||
cl_stream::cl_stream(driver::context *ctx): stream(ctx, cl_command_queue(), true) {
|
||||
cl_int err;
|
||||
*cl_ = dispatch::clCreateCommandQueue(*ctx->cl(), *ctx->device()->cl(), 0, &err);
|
||||
check(err);
|
||||
}
|
||||
|
||||
void cl_stream::synchronize() {
|
||||
check(dispatch::clFinish(*cl_));
|
||||
}
|
||||
|
||||
void cl_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void **extra) {
|
||||
std::array<size_t, 3> global = {grid[0]*block[0], grid[1]*block[1], grid[2]*block[2]};
|
||||
check(dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)global.data(), (const size_t*)block.data(), 0, NULL, NULL));
|
||||
}
|
||||
|
||||
void cl_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
|
||||
check(dispatch::clEnqueueWriteBuffer(*cl_, *buffer->cl(), blocking?CL_TRUE:CL_FALSE, offset, size, ptr, 0, NULL, NULL));
|
||||
}
|
||||
|
||||
void cl_stream::read(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr) {
|
||||
check(dispatch::clEnqueueReadBuffer(*cl_, *buffer->cl(), blocking?CL_TRUE:CL_FALSE, offset, size, ptr, 0, NULL, NULL));
|
||||
}
|
||||
|
||||
/* ------------------------ */
// CUDA //
/* ------------------------ */

inline CUcontext get_context() {
  CUcontext result;
  dispatch::cuCtxGetCurrent(&result);
  return result;
}

cu_stream::cu_stream(CUstream str, bool take_ownership):
-  stream(backend::contexts::import(get_context()), str, take_ownership) {
+  stream(str, take_ownership) {
}

-cu_stream::cu_stream(driver::context *context): stream((driver::cu_context*)context, CUstream(), true) {
-  cu_context::context_switcher ctx_switch(*ctx_);
+cu_stream::cu_stream(): stream(CUstream(), true) {
  dispatch::cuStreamCreate(&*cu_, 0);
}

void cu_stream::synchronize() {
-  cu_context::context_switcher ctx_switch(*ctx_);
  dispatch::cuStreamSynchronize(*cu_);
}

-void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<event> const *, event* event, void** extra) {
-  cu_context::context_switcher ctx_switch(*ctx_);
-  if(event)
-    dispatch::cuEventRecord(event->cu()->first, *cu_);
-  dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, extra);
-  if(event)
-    dispatch::cuEventRecord(event->cu()->second, *cu_);
+void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem) {
+  void *config[] = {
+    CU_LAUNCH_PARAM_BUFFER_POINTER, args,
+    CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
+    CU_LAUNCH_PARAM_END
+  };
+  dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], shared_mem, *cu_, nullptr, config);
}

void cu_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
-  cu_context::context_switcher ctx_switch(*ctx_);
  if(blocking)
    dispatch::cuMemcpyHtoD(*buffer->cu() + offset, ptr, size);
  else
@@ -167,7 +130,6 @@ void cu_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset,
}

void cu_stream::read(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr) {
-  cu_context::context_switcher ctx_switch(*ctx_);
  if(blocking)
    dispatch::cuMemcpyDtoH(ptr, *buffer->cu() + offset, size);
  else
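The rewritten cu_stream::enqueue drops per-launch event recording and switches to the packed-argument protocol of cuLaunchKernel: instead of an array of per-argument pointers, the caller hands over one contiguous buffer plus its size via CU_LAUNCH_PARAM_BUFFER_POINTER / CU_LAUNCH_PARAM_BUFFER_SIZE. A standalone sketch of that protocol (illustrative helper, valid handles assumed):

    #include <cstddef>
    #include <cuda.h>

    // Illustrative helper (not the Triton API): launch a kernel whose
    // arguments were pre-packed, byte-for-byte, into a single host buffer.
    CUresult launch_packed(CUfunction fn, CUstream stream,
                           unsigned gx, unsigned gy, unsigned gz,
                           unsigned bx, unsigned by, unsigned bz,
                           void* args, size_t args_size, unsigned shared_mem) {
      void* config[] = {
        CU_LAUNCH_PARAM_BUFFER_POINTER, args,        // start of the packed buffer
        CU_LAUNCH_PARAM_BUFFER_SIZE,    &args_size,  // its size in bytes
        CU_LAUNCH_PARAM_END
      };
      // kernelParams must be NULL when the extra/config protocol is used.
      return cuLaunchKernel(fn, gx, gy, gz, bx, by, bz,
                            shared_mem, stream, /*kernelParams=*/nullptr, config);
    }

Packing arguments once into a flat buffer is convenient when the marshalling happens outside C++ (for instance in a Python launcher), since no per-argument pointer array has to be rebuilt on every launch.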
@@ -1,5 +1,6 @@
#include <string>
+#include <algorithm>
#include <iostream>
#include "triton/ir/basic_block.h"
#include "triton/ir/builder.h"
#include "triton/ir/constant.h"
@@ -10,7 +11,7 @@ namespace triton{
namespace ir{

builder::builder(context &ctx):
-  ctx_(ctx), block_(nullptr), insert_point_(nullptr) {}
+  ctx_(ctx), block_(nullptr) {}

//===----------------------------------------------------------------------===//
// utilities
//===----------------------------------------------------------------------===//
@@ -44,6 +45,9 @@ void builder::set_insert_point(basic_block *block){
// convenience functions
//===----------------------------------------------------------------------===//

+value *builder::get_int1(bool val)
+{ return constant_int::get(type::get_int1_ty(ctx_), val); }
+
value *builder::get_int32(int32_t val)
{ return constant_int::get(type::get_int32_ty(ctx_), val);}

@@ -253,6 +257,15 @@ DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE)

value *builder::create_load(value *ptr, const std::string &name){
  return insert(unmasked_load_inst::create(ptr, name));
+//  type *ty = ptr->get_type()->get_pointer_element_ty();
+//  value *mask = constant_int::get(get_int1_ty(), 1);
+//  value *undef = undef_value::get(ty);
+//  if(ptr->get_type()->is_tile_ty()){
+//    auto shapes = ptr->get_type()->get_tile_shapes();
+//    return insert(masked_load_inst::create(ptr, create_splat(mask, shapes), create_splat(undef, shapes), name));
+//  }
+//  return insert(masked_load_inst::create(ptr, mask, undef, name));
}

value *builder::create_store(value *ptr, value *val, const std::string &name){
@@ -263,6 +276,7 @@ value *builder::create_masked_load(value *ptr, value *mask, value *false_value,
  return insert(masked_load_inst::create(ptr, mask, false_value, name));
}

value *builder::create_masked_store(value *ptr, value *val, value *mask, const std::string &name){
  return insert(masked_store_inst::create(ptr, val, mask, name));
}
@@ -307,14 +321,18 @@ value *builder::create_atomic_exch(value *ptr, value *val, const std::string &na
  return insert(atomic_exch_inst::create(ptr, val, name));
}

-value *builder::create_atomic_add(value *ptr, value *val, const std::string &name){
-  return insert(atomic_add_inst::create(ptr, val, name));
+value *builder::create_atomic_add(value *ptr, value *val, value *msk, const std::string &name){
+  return insert(atomic_add_inst::create(ptr, val, msk, name));
}

value *builder::create_exp(value *arg, const std::string &name){
  return insert(exp_inst::create(arg, name));
}

+value *builder::create_log(value *arg, const std::string &name){
+  return insert(log_inst::create(arg, name));
+}
+
value *builder::create_dot(value *A, value *B, value *C, const std::string &name) {
  return insert(dot_inst::create_nn(A, B, C, name));
}
@@ -344,13 +362,22 @@ value *builder::create_copy_to_shared(value *arg, const std::string &name) {
  return insert(copy_to_shared_inst::create(arg, name));
}

value *builder::create_copy_from_shared(value *arg, const std::string &name) {
  return insert(copy_from_shared_inst::create(arg, name));
}

+value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, const std::string &name) {
+  return insert(masked_load_async_inst::create(ptr, mask, false_value, name));
+}
+
value *builder::create_barrier(const std::string &name) {
  return insert(barrier_inst::create(ctx_, name));
}

+value *builder::create_async_wait(int N) {
+  return insert(async_wait_inst::create(ctx_, N));
+}
+
}
}
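All of the create_* methods above share one mechanism: build an instruction node, then route it through insert(), which appends it at the builder's current insertion point and returns it so calls can be chained. A toy, self-contained skeleton of that pattern (Triton's real classes carry types, operands, and use-lists on top of this):

    #include <memory>
    #include <string>
    #include <vector>

    // Toy instruction: only a name, no operands or types.
    struct instruction {
      std::string name;
      explicit instruction(std::string n): name(std::move(n)) {}
      virtual ~instruction() = default;
    };

    struct basic_block {
      std::vector<std::unique_ptr<instruction>> insts;
    };

    class builder {
      basic_block *block_ = nullptr;   // current insertion point
    public:
      void set_insert_point(basic_block *b) { block_ = b; }
      // Every factory funnels through here; precondition: insert point is set.
      instruction *insert(std::unique_ptr<instruction> i) {
        block_->insts.push_back(std::move(i));
        return block_->insts.back().get();
      }
      instruction *create_barrier() {
        return insert(std::make_unique<instruction>("barrier"));
      }
      instruction *create_async_wait(int N) {
        return insert(std::make_unique<instruction>("async_wait " + std::to_string(N)));
      }
    };

    int main() {
      basic_block bb;
      builder b;
      b.set_insert_point(&bb);
      b.create_async_wait(4);   // instructions accumulate in bb.insts in order
      b.create_barrier();
    }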
@@ -43,6 +43,8 @@ constant_int::constant_int(type *ty, uint64_t value)
  : constant(ty, 0), value_(value){ }

constant_int *constant_int::get(type *ty, uint64_t value) {
+  if (!ty->is_integer_ty())
+    throw std::runtime_error("Cannot create constant_int with non integer ty");
  context_impl *impl = ty->get_context().p_impl.get();
  constant_int *& cst = impl->int_constants_[std::make_pair(ty, value)];
  if(cst == nullptr)
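The guard added to constant_int::get sits in front of a classic interning table: indexing the map default-inserts a null slot, and the reference lets that slot be filled lazily, so each (type, value) pair yields exactly one constant object. A compilable toy version of the idiom (simplified types, not the repository's):

    #include <cstdint>
    #include <map>
    #include <stdexcept>
    #include <utility>

    struct type { bool is_int; };
    struct constant_int { type *ty; uint64_t value; };

    std::map<std::pair<type*, uint64_t>, constant_int*> int_constants_;

    constant_int *get_constant(type *ty, uint64_t value) {
      if (!ty->is_int)
        throw std::runtime_error("Cannot create constant_int with non integer ty");
      // operator[] default-inserts nullptr; filling it through the reference
      // guarantees one interned object per (type, value) key.
      constant_int *&cst = int_constants_[std::make_pair(ty, value)];
      if (cst == nullptr)
        cst = new constant_int{ty, value};
      return cst;
    }

    int main() {
      type i32{true};
      // Same key, same interned object.
      return get_constant(&i32, 7) == get_constant(&i32, 7) ? 0 : 1;
    }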
@@ -45,6 +45,12 @@ phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, ins
  blocks_.reserve(num_reserved);
}

+value* phi_node::get_value_for_block(basic_block * block) {
+  auto it = std::find(blocks_.begin(), blocks_.end(), block);
+  size_t n = std::distance(blocks_.begin(), it);
+  return get_incoming_value(n);
+}
+
// Set incoming value
void phi_node::set_incoming_value(unsigned i, value *v){
  assert(v && "PHI node got a null value!");
@@ -463,6 +469,34 @@ masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false
  return new masked_load_inst(ptr, mask, false_value, name, next);
}

+// masked load async
+masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
+                                               const std::string &name, instruction *next)
+  : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, name, next) {
+  set_operand(0, ptr);
+  set_operand(1, mask);
+  set_operand(2, false_value);
+}
+
+masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
+                                                       const std::string &name, instruction *next) {
+  return new masked_load_async_inst(ptr, mask, false_value, name, next);
+}
+
+// atomic add
+
+atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next)
+  : io_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 3, name, next) {
+  set_operand(0, ptr);
+  set_operand(1, val);
+  set_operand(2, msk);
+}
+
+instruction* atomic_add_inst::create(value *ptr, value *val, value *msk, const std::string &name, instruction *next) {
+  return new atomic_add_inst(ptr, val, msk, name, next);
+}
+
// store

store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
  : io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, name, next)
@@ -734,17 +768,6 @@ instruction* atomic_exch_inst::create(value *ptr, value *val, const std::string
  return new atomic_exch_inst(ptr, val, name, next);
}

-// atomic add
-
-atomic_add_inst::atomic_add_inst(value *ptr, value *val, const std::string &name, instruction *next)
-  : builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 2, name, next) {
-  set_operand(0, ptr);
-  set_operand(1, val);
-}
-
-instruction* atomic_add_inst::create(value *ptr, value *val, const std::string &name, instruction *next) {
-  return new atomic_add_inst(ptr, val, name, next);
-}
-
// exp

@@ -757,6 +780,17 @@ instruction* exp_inst::create(value *val, const std::string& name, instruction *
  return new exp_inst(val, name, next);
}

+// log
+
+log_inst::log_inst(value *val, const std::string &name, instruction *next)
+  : builtin_inst(val->get_type(), INST_LOG, 1, name, next) {
+  set_operand(0, val);
+}
+
+instruction* log_inst::create(value *val, const std::string& name, instruction *next) {
+  return new log_inst(val, name, next);
+}
+

//===----------------------------------------------------------------------===//
// intrinsic instructions
@@ -790,6 +824,13 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru
  return new barrier_inst(ctx, name, next);
}

+async_wait_inst::async_wait_inst(context &ctx, int N, const std::string &name, instruction *next)
+  : instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next), N_(N) { }
+
+async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string &name, instruction *next) {
+  return new async_wait_inst(ctx, N, name, next);
+}
+

// nv_dynamic_program_idx
make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
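The new phi_node::get_value_for_block assumes the queried block really is a predecessor: if it is not, std::find returns end(), std::distance then equals blocks_.size(), and get_incoming_value would index out of range. A self-contained sketch of the lookup with that precondition asserted explicitly (toy types; the values_ storage is a stand-in for the real operand list):

    #include <algorithm>
    #include <cassert>
    #include <vector>

    struct basic_block {};
    struct value {};

    // Toy phi node: values_[i] is the incoming value for predecessor blocks_[i].
    struct phi_node {
      std::vector<basic_block*> blocks_;
      std::vector<value*> values_;

      value *get_value_for_block(basic_block *block) {
        auto it = std::find(blocks_.begin(), blocks_.end(), block);
        // Precondition: block is a registered predecessor; otherwise `it`
        // is end() and the index below would be out of range.
        assert(it != blocks_.end() && "block is not a predecessor of this phi");
        size_t n = std::distance(blocks_.begin(), it);
        return values_[n];
      }
    };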
@@ -10,8 +10,8 @@ namespace triton{
namespace ir{

/* Module */
-module::module(const std::string &name, context &ctx)
-  : name_(name), context_(ctx), builder_(ctx) {
+module::module(const std::string &name)
+  : name_(name), builder_(context_) {
  sealed_blocks_.insert(nullptr);
}

@@ -30,7 +30,7 @@ void module::set_value(const std::string& name, ir::basic_block *block, ir::valu
  if(it != metadatas_.end()){
    x->set_metadata(it->second.first, it->second.second);
  }
-  value->set_name(name);
+  // value->set_name(name);
}

void module::set_value(const std::string& name, ir::value *value){
@@ -65,7 +65,12 @@ void print(module &mod, std::ostream& os) {
      os << get_name(ops[i], cnt++);
      os << (i < num_ops - 1?", ":"");
    }
-    os << ";" << std::endl;
+    os << ";";
+    // os << " (";
+    // for(ir::user* usr: inst->get_users())
+    //   os << get_name(usr, cnt++) << ", " ;
+    // os << " )";
+    os << std::endl;
  }
}
os << "}" << std::endl;
Some files were not shown because too many files have changed in this diff.