Compare commits
341 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
0f5c6e619c | ||
|
c20215dad1 | ||
|
733301ff31 | ||
|
ff399fbc20 | ||
|
4023149ee3 | ||
|
2193bee94e | ||
|
411bacb2a8 | ||
|
bc73bbb12c | ||
|
8460ea3df1 | ||
|
678b9f53a2 | ||
|
0e8590f1c9 | ||
|
194ba103b1 | ||
|
1d3029faf8 | ||
|
2ba74d2729 | ||
|
fd2da4aff6 | ||
|
925d3d7f98 | ||
|
b5aafb0dab | ||
|
20100a7254 | ||
|
8650b4d1cb | ||
|
44f577984d | ||
|
0e4691e6dd | ||
|
0d7e753227 | ||
|
77bc5187b5 | ||
|
f16138d447 | ||
|
578ada7740 | ||
|
6311d70406 | ||
|
584086f08c | ||
|
3ca667dfa8 | ||
|
5ca1ed0101 | ||
|
db3aa1d1fb | ||
|
ddae106c0e | ||
|
bc98aead33 | ||
|
71b46acc42 | ||
|
33e6f0df7f | ||
|
af76c989eb | ||
|
09cc2d454b | ||
|
5d4b26d380 | ||
|
9a11a567ce | ||
|
11345e9b74 | ||
|
bdfdb9a1d2 | ||
|
77c752dc78 | ||
|
d3c925db8a | ||
|
2b0f877fad | ||
|
4a2d3b7d79 | ||
|
f55960e773 | ||
|
b244db06da | ||
|
7b61303ea1 | ||
|
ae59f51c2d | ||
|
f45e31ba7c | ||
|
dad97528b2 | ||
|
998fd5f9af | ||
|
3ac929b48b | ||
|
579c03615d | ||
|
25e1b36785 | ||
|
61d104ab3a | ||
|
8c3d4d5749 | ||
|
df67068bb0 | ||
|
677ddae618 | ||
|
6abe813d1c | ||
|
e318185eb4 | ||
|
7dc2a70edb | ||
|
48f30550f1 | ||
|
93b1adc53b | ||
|
82956e5d6b | ||
|
2baf333d44 | ||
|
49f6bc3f2b | ||
|
00f4ef6958 | ||
|
e647402fd3 | ||
|
4a77dfb042 | ||
|
889d9e34a1 | ||
|
c668d6596e | ||
|
4580a04710 | ||
|
cfbbc7b43a | ||
|
59a8e25f43 | ||
|
437ced38c2 | ||
|
210a296699 | ||
|
fe0c29b9ec | ||
|
7394d732ad | ||
|
3e2953f357 | ||
|
cc79376222 | ||
|
7b91c7befd | ||
|
968f59027e | ||
|
923d468187 | ||
|
027321cdcf | ||
|
e02e56dc63 | ||
|
ab56d310dd | ||
|
f28caddbf8 | ||
|
af85f5fa46 | ||
|
9b2bc88d11 | ||
|
86cab58d89 | ||
|
5b04331dd2 | ||
|
0a3f3d5f25 | ||
|
4912916c11 | ||
|
971f5782b4 | ||
|
d5eb9bc230 | ||
|
c9a2b9c7d4 | ||
|
4a399a7e40 | ||
|
22105bc33b | ||
|
4bf509889b | ||
|
a74cce375f | ||
|
f733327ba4 | ||
|
1bbb2430d9 | ||
|
1895ceaa2d | ||
|
feb7a2a0dc | ||
|
5b4c8f221e | ||
|
87413bc925 | ||
|
d345ddf837 | ||
|
b02bac41ba | ||
|
a428cf0bb2 | ||
|
b5e728cb14 | ||
|
6b9756532f | ||
|
8ce2c12e33 | ||
|
93209c07e0 | ||
|
58c8889235 | ||
|
7094657aa9 | ||
|
38573d1261 | ||
|
2cdc6d35c4 | ||
|
f13cbaab9f | ||
|
751e325d2e | ||
|
801c8a4c92 | ||
|
8876e53206 | ||
|
a60374a597 | ||
|
efa04cac1f | ||
|
3e7500dfe6 | ||
|
37037bb3be | ||
|
c82a206684 | ||
|
0e2883020a | ||
|
43fec2adca | ||
|
011bc83c1b | ||
|
96bff90471 | ||
|
d5eaa8dfa0 | ||
|
80f6a2698b | ||
|
205a493b10 | ||
|
abea3dc2c6 | ||
|
d35617bea1 | ||
|
d1a22a94e6 | ||
|
d954a05989 | ||
|
0835a4fb05 | ||
|
c736ba7c3e | ||
|
cd30a99aa2 | ||
|
d87435e536 | ||
|
7c9bc5a47b | ||
|
95feb10ec9 | ||
|
11a908655d | ||
|
cd78ce4888 | ||
|
ae2a1ab225 | ||
|
7d544799a0 | ||
|
3ca792043f | ||
|
bda209002e | ||
|
0cc3b1129b | ||
|
7d6c504e8d | ||
|
073be1d2ee | ||
|
5c7122004c | ||
|
dc4d40faec | ||
|
25f6689508 | ||
|
76bfac9f15 | ||
|
14b0fd4cfb | ||
|
6424771f55 | ||
|
9f08ecd684 | ||
|
2bed6fc850 | ||
|
e85c7a7fc7 | ||
|
bace26143d | ||
|
e0cc488055 | ||
|
76a9ee50a8 | ||
|
ea6d1f1b85 | ||
|
a4f68165cd | ||
|
539961072c | ||
|
0dd2ec2e3a | ||
|
d4d8eaf6c0 | ||
|
21f8a0646d | ||
|
a50a47a85b | ||
|
bb5765df5c | ||
|
d9dd97492f | ||
|
98ed7db8c1 | ||
|
a9dfdcaaa9 | ||
|
9b100302d3 | ||
|
40093a9878 | ||
|
4941bc7001 | ||
|
2fdf0a4fe8 | ||
|
077d6c8ff0 | ||
|
822ddcd14b | ||
|
7b48340ffd | ||
|
5a8a544d10 | ||
|
69ff52ea1f | ||
|
137bb67fad | ||
|
3b20170fa3 | ||
|
b0d6e2f322 | ||
|
2922dc141c | ||
|
807d8a1945 | ||
|
bef76b142a | ||
|
bd52e530a0 | ||
|
e68d6a7776 | ||
|
59d371c6eb | ||
|
3a23c1dd33 | ||
|
ccf9abe0ba | ||
|
4c97d1ecd7 | ||
|
e0c5709cc8 | ||
|
2a944ded53 | ||
|
4c94359199 | ||
|
bbc78f6516 | ||
|
bf32205edc | ||
|
94a2e10fe5 | ||
|
efdabe6073 | ||
|
a70acfec77 | ||
|
9801aa7b56 | ||
|
8bf551ae7a | ||
|
6f7acad48f | ||
|
120cda015e | ||
|
001fb757fe | ||
|
0ab9d67bad | ||
|
d8db0308cb | ||
|
03f1256f60 | ||
|
3edc2633e9 | ||
|
985798f101 | ||
|
d8fce83e7a | ||
|
a425f24d54 | ||
|
2509124dd0 | ||
|
39d4bfed83 | ||
|
5cdb948c05 | ||
|
4a8953efa3 | ||
|
fa62b4a8f6 | ||
|
4e93b41c52 | ||
|
e062812969 | ||
|
eb077fc993 | ||
|
e0b92c1380 | ||
|
558555630f | ||
|
e575ae3443 | ||
|
9def2424ab | ||
|
e31b9b4e66 | ||
|
73b04d71b2 | ||
|
0ff1a26b70 | ||
|
f23bf55f15 | ||
|
8ec9f037bb | ||
|
c86ad9c9ab | ||
|
1296eb877b | ||
|
5693b582ea | ||
|
edd4b0c8b7 | ||
|
5b7ba3eb96 | ||
|
791b953b21 | ||
|
b908095872 | ||
|
01cc3d4503 | ||
|
e66bf76354 | ||
|
f7ab96cfd7 | ||
|
9a02dddf29 | ||
|
5d54352164 | ||
|
2acaa4d0dd | ||
|
b7f0e87dc2 | ||
|
770ea96cca | ||
|
969d6de8a2 | ||
|
2d6df9b518 | ||
|
1b842f8e5e | ||
|
d3e584d4ba | ||
|
d35014ba47 | ||
|
5ce1b726dc | ||
|
858dec8372 | ||
|
90ded16c32 | ||
|
abbc554838 | ||
|
9b32075062 | ||
|
c2e6b90ff1 | ||
|
bfacc191b3 | ||
|
f5ad168686 | ||
|
c3c0ff0552 | ||
|
9e9d781912 | ||
|
d5f20dbce0 | ||
|
d4baad426d | ||
|
5123db0b7d | ||
|
12b6158c5c | ||
|
b352b16567 | ||
|
d132b7442b | ||
|
44442db96e | ||
|
bfcfad7abe | ||
|
2c287544cb | ||
|
c3756d1c33 | ||
|
83da3febf2 | ||
|
0735061fce | ||
|
2066ccd87e | ||
|
e22d92c63c | ||
|
87f8d9f163 | ||
|
ec2e7b8f48 | ||
|
d253eb8719 | ||
|
5211f23a63 | ||
|
2849e7a773 | ||
|
41dbaf3b3f | ||
|
c151e0f6aa | ||
|
e96edc16ff | ||
|
b53f5f3803 | ||
|
a12827848d | ||
|
6e5b0b4301 | ||
|
bd855ac13d | ||
|
313d6488f6 | ||
|
da5063d898 | ||
|
8fdd7e7ed6 | ||
|
3e395bc84e | ||
|
cecca90bea | ||
|
4163d32c49 | ||
|
34369906b4 | ||
|
ac10551d55 | ||
|
43723ccb95 | ||
|
585e5cd0ec | ||
|
94c83d30ce | ||
|
8bedcce9be | ||
|
c069ef907e | ||
|
8a882b215f | ||
|
768e0ded28 | ||
|
c0daffc625 | ||
|
274d613488 | ||
|
4ff3714d61 | ||
|
85426dbaf7 | ||
|
5b29da719d | ||
|
6aa5720d75 | ||
|
f26a48a3b4 | ||
|
226fde6ea1 | ||
|
64b8e7222d | ||
|
a714b6b856 | ||
|
bb1eebb4b4 | ||
|
6e7593b446 | ||
|
c45c2e9684 | ||
|
c7a272cb91 | ||
|
b120d70a0a | ||
|
70e28ff380 | ||
|
398d4b4aeb | ||
|
83da7065da | ||
|
298da78058 | ||
|
6cd1ec3955 | ||
|
68f7eeba92 | ||
|
4e6f667c2f | ||
|
23c71538fc | ||
|
3cb77aa126 | ||
|
9967e9d4b4 | ||
|
e8031fe61f | ||
|
b7cdf670c3 | ||
|
c7060eadb2 | ||
|
c0bb895d9d | ||
|
a34c57402f | ||
|
2293afece7 | ||
|
cb5c280691 | ||
|
2322d6df2a | ||
|
2f0f51be50 | ||
|
41ecd96300 | ||
|
d3851d8989 | ||
|
4b9df06568 |
1
.clang-format
Normal file
1
.clang-format
Normal file
@@ -0,0 +1 @@
|
||||
BasedOnStyle: LLVM
|
57
.github/CODEOWNERS
vendored
Normal file
57
.github/CODEOWNERS
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
# These owners will be the default owners for everything in
|
||||
# the repo. Unless a later match takes precedence,
|
||||
# @global-owner1 and @global-owner2 will be requested for
|
||||
# review when someone opens a pull request.
|
||||
* @ptillet
|
||||
|
||||
# --------
|
||||
# Analyses
|
||||
# --------
|
||||
# Alias analysis
|
||||
include/triton/Analysis/Alias.h @Jokeren
|
||||
lib/Analysis/Alias.cpp @Jokeren
|
||||
# Allocation analysis
|
||||
include/triton/Analysis/Allocation.h @Jokeren
|
||||
lib/Analysis/Allocation.cpp @Jokeren
|
||||
# Membar analysis
|
||||
include/triton/Analysis/Membar.h @Jokeren
|
||||
lib/Analysis/Membar.cpp @Jokeren
|
||||
# AxisInfo analysis
|
||||
include/triton/Analysis/AxisInfo.h @ptillet
|
||||
lib/Analysis/AxisInfo.cpp @ptillet
|
||||
# Utilities
|
||||
include/triton/Analysis/Utility.h @Jokeren
|
||||
lib/Analysis/Utility.cpp @Jokeren
|
||||
|
||||
# ----------
|
||||
# Dialects
|
||||
# ----------
|
||||
# Pipeline pass
|
||||
lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @daadaada
|
||||
# Prefetch pass
|
||||
lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @daadaada
|
||||
# Coalesce pass
|
||||
lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @ptillet
|
||||
# Layout simplification pass
|
||||
lib/Dialect/TritonGPU/Transforms/Combine.cpp @ptillet
|
||||
|
||||
# -----------
|
||||
# Conversions
|
||||
# -----------
|
||||
# TritonGPUToLLVM
|
||||
include/triton/Conversion/TritonGPUToLLVM/ @goostavz @Superjomn
|
||||
lib/Conversions/TritonGPUToLLVM @goostavz @Superjomn
|
||||
# TritonToTritonGPU
|
||||
include/triton/Conversion/TritonToTritonGPU/ @daadaada
|
||||
lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @daadaada
|
||||
|
||||
|
||||
# -------
|
||||
# Targets
|
||||
# -------
|
||||
# LLVMIR
|
||||
include/triton/Target/LLVMIR/ @goostavz @Superjomn
|
||||
lib/Target/LLVMIR @goostavz @Superjomn
|
||||
# PTX
|
||||
include/triton/Target/PTX/ @goostavz @Superjomn
|
||||
lib/Target/PTX @goostavz @Superjomn
|
40
.github/workflows/documentation.yml
vendored
40
.github/workflows/documentation.yml
vendored
@@ -1,40 +0,0 @@
|
||||
name: Documentation
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 0 * * *"
|
||||
|
||||
jobs:
|
||||
|
||||
Build-Documentation:
|
||||
|
||||
runs-on: self-hosted
|
||||
|
||||
steps:
|
||||
|
||||
|
||||
- name: Checkout gh-pages
|
||||
uses: actions/checkout@v1
|
||||
with:
|
||||
ref: 'gh-pages'
|
||||
|
||||
- name: Checkout branch
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Install Triton
|
||||
run: |
|
||||
alias python='python3'
|
||||
cd python
|
||||
pip3 install -e .
|
||||
|
||||
- name: Build docs
|
||||
run: |
|
||||
cd docs
|
||||
make html
|
||||
|
||||
- name: Publish docs
|
||||
run: |
|
||||
git checkout gh-pages
|
||||
sh ./update-website.sh
|
||||
git remote set-url origin git@github.com:ptillet/triton.git
|
||||
git push
|
78
.github/workflows/integration-tests.yml
vendored
78
.github/workflows/integration-tests.yml
vendored
@@ -5,30 +5,88 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
|
||||
- triton-mlir
|
||||
|
||||
jobs:
|
||||
Runner-Preparation:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- name: Prepare runner matrix
|
||||
id: set-matrix
|
||||
run: |
|
||||
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
|
||||
echo '::set-output name=matrix::[["self-hosted", "A10"], ["self-hosted", "V100"], "macos-10.15"]'
|
||||
else
|
||||
echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
|
||||
fi
|
||||
|
||||
Integration-Tests:
|
||||
needs: Runner-Preparation
|
||||
|
||||
runs-on: self-hosted
|
||||
runs-on: ${{ matrix.runner }}
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix)}}
|
||||
|
||||
steps:
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Clear cache
|
||||
run: |
|
||||
rm -rf ~/.triton/cache/
|
||||
|
||||
- name: Check imports
|
||||
if: ${{ matrix.runner != 'macos-10.15' }}
|
||||
run: |
|
||||
pip install isort
|
||||
isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )
|
||||
|
||||
- name: Check python style
|
||||
if: ${{ matrix.runner != 'macos-10.15' }}
|
||||
run: |
|
||||
pip install autopep8
|
||||
autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )
|
||||
|
||||
- name: Check cpp style
|
||||
if: ${{ matrix.runner != 'macos-10.15' }}
|
||||
run: |
|
||||
pip install clang-format
|
||||
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
|
||||
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)
|
||||
|
||||
- name: Flake8
|
||||
if: ${{ matrix.runner != 'macos-10.15' }}
|
||||
run: |
|
||||
pip install flake8
|
||||
flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )
|
||||
|
||||
- name: Install Triton
|
||||
run: |
|
||||
alias python='python3'
|
||||
cd python
|
||||
pip3 install -e .
|
||||
TRITON_USE_ASSERT_ENABLED_LLVM=TRUE pip3 install -e '.[tests]'
|
||||
|
||||
- name: Run benchmarks
|
||||
- name: Run lit tests
|
||||
run: |
|
||||
cd python/bench
|
||||
python3 -m run
|
||||
cd python
|
||||
LIT_TEST_DIR="build/$(ls build)/test"
|
||||
if [ ! -d "$LIT_TEST_DIR" ]; then
|
||||
echo "Not found `$LIT_TEST_DIR`. Did you change an installation method?" ; exit -1
|
||||
fi
|
||||
lit -v "$LIT_TEST_DIR"
|
||||
|
||||
- name: Run unit tests
|
||||
- name: Run python tests
|
||||
if: ${{matrix.runner[0] == 'self-hosted'}}
|
||||
run: |
|
||||
pytest .
|
||||
cd python/test/unit/
|
||||
pytest
|
||||
|
||||
|
||||
- name: Run CXX unittests
|
||||
run: |
|
||||
cd python/
|
||||
cd "build/$(ls build)"
|
||||
ctest
|
||||
|
4
.github/workflows/wheels.yml
vendored
4
.github/workflows/wheels.yml
vendored
@@ -8,7 +8,7 @@ jobs:
|
||||
|
||||
Build-Wheels:
|
||||
|
||||
runs-on: self-hosted
|
||||
runs-on: [self-hosted, V100]
|
||||
|
||||
steps:
|
||||
|
||||
@@ -18,7 +18,7 @@ jobs:
|
||||
- name: Patch setup.py
|
||||
run: |
|
||||
#sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
|
||||
export LATEST_DATE=$(git show -s --format=%ci `git rev-parse HEAD` | cut -d ' ' -f 1 | sed 's/-//g')
|
||||
export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d' --format="%cd")
|
||||
sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py
|
||||
echo "" >> python/setup.cfg
|
||||
echo "[build_ext]" >> python/setup.cfg
|
||||
|
20
.gitignore
vendored
Normal file
20
.gitignore
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
# Triton builds
|
||||
build/
|
||||
|
||||
# Triton Python module builds
|
||||
python/build/
|
||||
python/triton.egg-info/
|
||||
python/triton/_C/libtriton.pyd
|
||||
python/triton/_C/libtriton.so
|
||||
|
||||
# Python caches
|
||||
__pycache__
|
||||
.pytest_cache
|
||||
|
||||
# VS Code project files
|
||||
.vscode
|
||||
.vs
|
||||
|
||||
# JetBrains project files
|
||||
.idea
|
||||
cmake-build-*
|
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[submodule "deps/dlfcn-win32"]
|
||||
path = deps/dlfcn-win32
|
||||
url = https://github.com/dlfcn-win32/dlfcn-win32.git
|
4
.isort.cfg
Normal file
4
.isort.cfg
Normal file
@@ -0,0 +1,4 @@
|
||||
[settings]
|
||||
known_local_folder=triton
|
||||
line_length=88
|
||||
py_version=36
|
232
CMakeLists.txt
232
CMakeLists.txt
@@ -1,18 +1,27 @@
|
||||
cmake_minimum_required(VERSION 3.6)
|
||||
include(ExternalProject)
|
||||
|
||||
if(NOT TRITON_LLVM_BUILD_DIR)
|
||||
set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
|
||||
endif()
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
|
||||
set(CMAKE_INCLUDE_CURRENT_DIR ON)
|
||||
|
||||
project(triton)
|
||||
include(CTest)
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
|
||||
if(NOT WIN32)
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
|
||||
endif()
|
||||
|
||||
# Options
|
||||
option(BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
|
||||
option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
|
||||
option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
|
||||
option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
|
||||
|
||||
# Ensure Python3 vars are set correctly
|
||||
# used conditionally in this file and by lit tests
|
||||
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
|
||||
|
||||
# Customized release build type with assertions: TritonRelBuildWithAsserts
|
||||
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
|
||||
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
|
||||
|
||||
# Default build type
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
@@ -20,60 +29,207 @@ if(NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE "Release")
|
||||
endif()
|
||||
|
||||
if(NOT WIN32)
|
||||
find_library(TERMINFO_LIBRARY tinfo)
|
||||
endif()
|
||||
|
||||
# Compiler flags
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
|
||||
|
||||
# Third-party
|
||||
include_directories(${PYBIND11_INCLUDE_DIR})
|
||||
|
||||
if(WIN32)
|
||||
SET(BUILD_SHARED_LIBS OFF)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/deps/dlfcn-win32/src)
|
||||
add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32)
|
||||
endif()
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")
|
||||
if(APPLE)
|
||||
set(CMAKE_OSX_DEPLOYMENT_TARGET 11.6)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
##########
|
||||
# LLVM
|
||||
##########
|
||||
if("${LLVM_LIBRARY_DIR}" STREQUAL "")
|
||||
find_package(LLVM 11 REQUIRED COMPONENTS "nvptx")
|
||||
if (NOT MLIR_DIR)
|
||||
if(NOT LLVM_LIBRARY_DIR)
|
||||
if(WIN32)
|
||||
find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu)
|
||||
|
||||
include_directories(${LLVM_INCLUDE_DIRS})
|
||||
separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS})
|
||||
add_definitions(${LLVM_DEFINITIONS_LIST})
|
||||
|
||||
llvm_map_components_to_libnames(LLVM_LIBRARIES support core
|
||||
NVPTXInfo nvptxcodegen
|
||||
AMDGPUInfo AMDGPUcodegen
|
||||
)
|
||||
else()
|
||||
find_package(LLVM 11 REQUIRED COMPONENTS "nvptx;amdgpu")
|
||||
endif()
|
||||
message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
|
||||
if(APPLE)
|
||||
set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14")
|
||||
endif()
|
||||
# sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
|
||||
else()
|
||||
# sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
|
||||
else()
|
||||
set(LLVM_LDFLAGS "-L${LLVM_LIBRARY_DIR}")
|
||||
set(LLVM_LIBRARIES libLLVMNVPTXCodeGen.a libLLVMSelectionDAG.a libLLVMipo.a libLLVMInstrumentation.a
|
||||
libLLVMVectorize.a libLLVMLinker.a libLLVMIRReader.a libLLVMAsmParser.a libLLVMFrontendOpenMP.a
|
||||
libLLVMAsmPrinter.a libLLVMDebugInfoDWARF.a libLLVMCodeGen.a libLLVMTarget.a libLLVMScalarOpts.a
|
||||
libLLVMInstCombine.a libLLVMAggressiveInstCombine.a libLLVMTransformUtils.a libLLVMBitWriter.a
|
||||
libLLVMAnalysis.a libLLVMProfileData.a libLLVMObject.a libLLVMTextAPI.a libLLVMMCParser.a
|
||||
libLLVMBitReader.a libLLVMCore.a libLLVMRemarks.a libLLVMBitstreamReader.a libLLVMNVPTXDesc.a
|
||||
libLLVMMC.a libLLVMDebugInfoCodeView.a libLLVMDebugInfoMSF.a libLLVMBinaryFormat.a libLLVMNVPTXInfo.a
|
||||
libLLVMSupport.a libLLVMDemangle.a)
|
||||
set(LLVM_LIBRARIES
|
||||
libLLVMNVPTXCodeGen.a
|
||||
libLLVMNVPTXDesc.a
|
||||
libLLVMNVPTXInfo.a
|
||||
libLLVMAMDGPUDisassembler.a
|
||||
libLLVMMCDisassembler.a
|
||||
libLLVMAMDGPUCodeGen.a
|
||||
libLLVMMIRParser.a
|
||||
libLLVMGlobalISel.a
|
||||
libLLVMSelectionDAG.a
|
||||
libLLVMipo.a
|
||||
libLLVMInstrumentation.a
|
||||
libLLVMVectorize.a
|
||||
libLLVMLinker.a
|
||||
libLLVMIRReader.a
|
||||
libLLVMAsmParser.a
|
||||
libLLVMFrontendOpenMP.a
|
||||
libLLVMAsmPrinter.a
|
||||
libLLVMDebugInfoDWARF.a
|
||||
libLLVMCodeGen.a
|
||||
libLLVMTarget.a
|
||||
libLLVMScalarOpts.a
|
||||
libLLVMInstCombine.a
|
||||
libLLVMAggressiveInstCombine.a
|
||||
libLLVMTransformUtils.a
|
||||
libLLVMBitWriter.a
|
||||
libLLVMAnalysis.a
|
||||
libLLVMProfileData.a
|
||||
libLLVMObject.a
|
||||
libLLVMTextAPI.a
|
||||
libLLVMBitReader.a
|
||||
libLLVMAMDGPUAsmParser.a
|
||||
libLLVMMCParser.a
|
||||
libLLVMAMDGPUDesc.a
|
||||
libLLVMAMDGPUUtils.a
|
||||
libLLVMMC.a
|
||||
libLLVMDebugInfoCodeView.a
|
||||
libLLVMDebugInfoMSF.a
|
||||
libLLVMCore.a
|
||||
libLLVMRemarks.a
|
||||
libLLVMBitstreamReader.a
|
||||
libLLVMBinaryFormat.a
|
||||
libLLVMAMDGPUInfo.a
|
||||
libLLVMSupport.a
|
||||
libLLVMDemangle.a
|
||||
libLLVMPasses.a
|
||||
libLLVMAnalysis.a
|
||||
libLLVMTransformUtils.a
|
||||
libLLVMScalarOpts.a
|
||||
libLLVMTransformUtils.a
|
||||
libLLVMipo.a
|
||||
libLLVMObjCARCOpts.a
|
||||
libLLVMCoroutines.a
|
||||
libLLVMAnalysis.a
|
||||
)
|
||||
endif()
|
||||
set (MLIR_DIR ${LLVM_LIBRARY_DIR}/cmake/mlir)
|
||||
endif()
|
||||
include_directories("${LLVM_INCLUDE_DIRS}")
|
||||
|
||||
# Python module
|
||||
if(BUILD_PYTHON_MODULE)
|
||||
if(TRITON_BUILD_PYTHON_MODULE)
|
||||
message(STATUS "Adding Python module")
|
||||
# Build CUTLASS python wrapper if requested
|
||||
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
|
||||
set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}")
|
||||
set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}")
|
||||
if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL ""))
|
||||
set(CUTLASS_SRC ${PYTHON_SRC_PATH}/cutlass.cc)
|
||||
add_definitions(-DWITH_CUTLASS_BINDINGS)
|
||||
set(CUTLASS_LIBRARIES "cutlass.a")
|
||||
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc)
|
||||
include_directories("." ${PYTHON_SRC_PATH})
|
||||
if (PYTHON_INCLUDE_DIRS)
|
||||
include_directories(${PYTHON_INCLUDE_DIRS})
|
||||
else()
|
||||
include_directories(${Python3_INCLUDE_DIRS})
|
||||
link_directories(${Python3_LIBRARY_DIRS})
|
||||
link_libraries(${Python3_LIBRARIES})
|
||||
add_link_options(${Python3_LINK_OPTIONS})
|
||||
endif()
|
||||
include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
|
||||
link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
|
||||
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC})
|
||||
endif()
|
||||
|
||||
|
||||
# Triton
|
||||
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
||||
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
# # Triton
|
||||
# file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
||||
# if (WIN32 AND TRITON_BUILD_PYTHON_MODULE)
|
||||
# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
# set_target_properties(triton PROPERTIES SUFFIX ".pyd")
|
||||
# set_target_properties(triton PROPERTIES PREFIX "lib")
|
||||
# else()
|
||||
# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
# endif()
|
||||
|
||||
|
||||
# MLIR
|
||||
find_package(MLIR REQUIRED CONFIG PATHS ${MLIR_DIR})
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
|
||||
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
|
||||
|
||||
include(TableGen) # required by AddMLIR
|
||||
include(AddLLVM)
|
||||
include(AddMLIR)
|
||||
|
||||
# Disable warnings that show up in external code (gtest;pybind11)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default")
|
||||
|
||||
include_directories(${MLIR_INCLUDE_DIRS})
|
||||
include_directories(${LLVM_INCLUDE_DIRS})
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
|
||||
# link_directories(${LLVM_LIBRARY_DIR})
|
||||
|
||||
add_subdirectory(include)
|
||||
add_subdirectory(lib)
|
||||
add_subdirectory(bin)
|
||||
|
||||
add_library(triton SHARED ${PYTHON_SRC})
|
||||
|
||||
# find_package(PythonLibs REQUIRED)
|
||||
|
||||
set(TRITON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
|
||||
set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
|
||||
|
||||
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
||||
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
|
||||
|
||||
target_link_libraries(triton
|
||||
TritonAnalysis
|
||||
TritonTransforms
|
||||
TritonGPUTransforms
|
||||
TritonLLVMIR
|
||||
TritonPTX
|
||||
${dialect_libs}
|
||||
${conversion_libs}
|
||||
# optimizations
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
MLIRLLVMIR
|
||||
MLIRSupport
|
||||
MLIRTargetLLVMIRExport
|
||||
MLIRExecutionEngine
|
||||
MLIRMathToLLVM
|
||||
MLIRNVVMToLLVMIRTranslation
|
||||
MLIRIR
|
||||
)
|
||||
|
||||
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
||||
target_link_libraries(triton ${LLVM_LIBRARIES} z tinfo)
|
||||
|
||||
if(WIN32)
|
||||
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
|
||||
elseif(APPLE)
|
||||
target_link_libraries(triton ${LLVM_LIBRARIES} z)
|
||||
else()
|
||||
target_link_libraries(triton ${LLVM_LIBRARIES} z stdc++fs)
|
||||
endif()
|
||||
|
||||
|
||||
if(BUILD_PYTHON_MODULE)
|
||||
if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)
|
||||
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
|
||||
# Check if the platform is MacOS
|
||||
if(APPLE)
|
||||
@@ -81,3 +237,7 @@ if(BUILD_PYTHON_MODULE)
|
||||
endif()
|
||||
target_link_libraries(triton ${CUTLASS_LIBRARIES} ${PYTHON_LDFLAGS})
|
||||
endif()
|
||||
|
||||
add_subdirectory(test)
|
||||
|
||||
add_subdirectory(unittest)
|
||||
|
9
LICENSE
9
LICENSE
@@ -1,4 +1,6 @@
|
||||
/* Copyright 2018-2021 Philippe Tillet
|
||||
/*
|
||||
* Copyright 2018-2020 Philippe Tillet
|
||||
* Copyright 2020-2022 OpenAI
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
@@ -19,8 +21,3 @@
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
|
||||
// The compiler front-end is based on a modified version of WGTCC
|
||||
// https://github.com/wgtdkp/wgtcc
|
||||
// Copyright (c) 2016 wgtdkp
|
40
README.md
40
README.md
@@ -18,6 +18,46 @@ The foundations of this project are described in the following MAPL2019 publicat
|
||||
|
||||
The [official documentation](https://triton-lang.org) contains installation instructions and tutorials.
|
||||
|
||||
# Quick Installation
|
||||
|
||||
You can install the latest stable release of Triton from pip:
|
||||
|
||||
```bash
|
||||
pip install triton
|
||||
```
|
||||
Binary wheels are available for CPython 3.6-3.9 and PyPy 3.6-3.7.
|
||||
|
||||
And the latest nightly release:
|
||||
|
||||
```bash
|
||||
pip install -U --pre triton
|
||||
```
|
||||
|
||||
# Install from source
|
||||
|
||||
```
|
||||
git clone https://github.com/openai/triton.git;
|
||||
cd triton/python;
|
||||
pip install cmake; # build time dependency
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
# Changelog
|
||||
|
||||
Version 1.1 is out! New features include:
|
||||
- Many, many bugfixes
|
||||
- More documentation
|
||||
- Automatic on-disk caching of compiled binary objects
|
||||
- Random Number Generation
|
||||
- Faster (up to 2x on A100), cleaner blocksparse ops
|
||||
|
||||
# Contributing
|
||||
|
||||
Community contributions are more than welcome, whether it be to fix bugs or to add new features. Feel free to open GitHub issues about your contribution ideas, and we will review them. A contributor's guide containing general guidelines is coming soon!
|
||||
|
||||
If you’re interested in joining our team and working on Triton & GPU kernels, [we’re hiring](https://openai.com/jobs/#acceleration)!
|
||||
|
||||
|
||||
# Compatibility
|
||||
|
||||
Supported Platforms:
|
||||
|
60
bin/CMakeLists.txt
Normal file
60
bin/CMakeLists.txt
Normal file
@@ -0,0 +1,60 @@
|
||||
add_subdirectory(FileCheck)
|
||||
# add_llvm_executable(FileCheck FileCheck/FileCheck.cpp)
|
||||
# target_link_libraries(FileCheck PRIVATE LLVMFileCheck LLVMSupport)
|
||||
|
||||
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
||||
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
|
||||
|
||||
add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED)
|
||||
|
||||
# TODO: what's this?
|
||||
llvm_update_compile_flags(triton-opt)
|
||||
target_link_libraries(triton-opt PRIVATE
|
||||
TritonAnalysis
|
||||
TritonTransforms
|
||||
TritonGPUTransforms
|
||||
${dialect_libs}
|
||||
${conversion_libs}
|
||||
# tests
|
||||
TritonTestAnalysis
|
||||
# MLIR core
|
||||
MLIROptLib
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
)
|
||||
|
||||
mlir_check_all_link_libraries(triton-opt)
|
||||
|
||||
|
||||
# add_llvm_executable(triton-translate triton-translate.cpp PARTIAL_SOURCES_INTENDED)
|
||||
#llvm_update_compile_flags(triton-translate)
|
||||
# target_link_libraries(triton-translate PRIVATE
|
||||
# TritonAnalysis
|
||||
# TritonTransforms
|
||||
# TritonGPUTransforms
|
||||
# TritonLLVMIR
|
||||
# TritonDriver
|
||||
# ${dialect_libs}
|
||||
# ${conversion_libs}
|
||||
# # tests
|
||||
# TritonTestAnalysis
|
||||
|
||||
# LLVMCore
|
||||
# LLVMSupport
|
||||
# LLVMOption
|
||||
# LLVMCodeGen
|
||||
# LLVMAsmParser
|
||||
|
||||
# # MLIR core
|
||||
# MLIROptLib
|
||||
# MLIRIR
|
||||
# MLIRPass
|
||||
# MLIRSupport
|
||||
# MLIRTransforms
|
||||
# MLIRExecutionEngine
|
||||
# MLIRMathToLLVM
|
||||
# MLIRTransformUtils
|
||||
# MLIRLLVMToLLVMIRTranslation
|
||||
# MLIRNVVMToLLVMIRTranslation
|
||||
# )
|
||||
# mlir_check_all_link_libraries(triton-translate)
|
2
bin/FileCheck/CMakeLists.txt
Normal file
2
bin/FileCheck/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_llvm_executable(FileCheck FileCheck.cpp)
|
||||
target_link_libraries(FileCheck PRIVATE LLVMFileCheck LLVMSupport)
|
882
bin/FileCheck/FileCheck.cpp
Normal file
882
bin/FileCheck/FileCheck.cpp
Normal file
@@ -0,0 +1,882 @@
|
||||
//===- FileCheck.cpp - Check that File's Contents match what is expected --===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// FileCheck does a line-by line check of a file that validates whether it
|
||||
// contains the expected content. This is useful for regression tests etc.
|
||||
//
|
||||
// This program exits with an exit status of 2 on error, exit status of 0 if
|
||||
// the file matched the expected contents, and exit status of 1 if it did not
|
||||
// contain the expected contents.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "llvm/FileCheck/FileCheck.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/Process.h"
|
||||
#include "llvm/Support/WithColor.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <cmath>
|
||||
#include <map>
|
||||
using namespace llvm;
|
||||
|
||||
static cl::extrahelp FileCheckOptsEnv(
|
||||
"\nOptions are parsed from the environment variable FILECHECK_OPTS and\n"
|
||||
"from the command line.\n");
|
||||
|
||||
static cl::opt<std::string>
|
||||
CheckFilename(cl::Positional, cl::desc("<check-file>"), cl::Optional);
|
||||
|
||||
static cl::opt<std::string>
|
||||
InputFilename("input-file", cl::desc("File to check (defaults to stdin)"),
|
||||
cl::init("-"), cl::value_desc("filename"));
|
||||
|
||||
static cl::list<std::string> CheckPrefixes(
|
||||
"check-prefix",
|
||||
cl::desc("Prefix to use from check file (defaults to 'CHECK')"));
|
||||
static cl::alias CheckPrefixesAlias(
|
||||
"check-prefixes", cl::aliasopt(CheckPrefixes), cl::CommaSeparated,
|
||||
cl::NotHidden,
|
||||
cl::desc(
|
||||
"Alias for -check-prefix permitting multiple comma separated values"));
|
||||
|
||||
static cl::list<std::string> CommentPrefixes(
|
||||
"comment-prefixes", cl::CommaSeparated, cl::Hidden,
|
||||
cl::desc("Comma-separated list of comment prefixes to use from check file\n"
|
||||
"(defaults to 'COM,RUN'). Please avoid using this feature in\n"
|
||||
"LLVM's LIT-based test suites, which should be easier to\n"
|
||||
"maintain if they all follow a consistent comment style. This\n"
|
||||
"feature is meant for non-LIT test suites using FileCheck."));
|
||||
|
||||
static cl::opt<bool> NoCanonicalizeWhiteSpace(
|
||||
"strict-whitespace",
|
||||
cl::desc("Do not treat all horizontal whitespace as equivalent"));
|
||||
|
||||
static cl::opt<bool> IgnoreCase("ignore-case",
|
||||
cl::desc("Use case-insensitive matching"));
|
||||
|
||||
static cl::list<std::string> ImplicitCheckNot(
|
||||
"implicit-check-not",
|
||||
cl::desc("Add an implicit negative check with this pattern to every\n"
|
||||
"positive check. This can be used to ensure that no instances of\n"
|
||||
"this pattern occur which are not matched by a positive pattern"),
|
||||
cl::value_desc("pattern"));
|
||||
|
||||
static cl::list<std::string>
|
||||
GlobalDefines("D", cl::AlwaysPrefix,
|
||||
cl::desc("Define a variable to be used in capture patterns."),
|
||||
cl::value_desc("VAR=VALUE"));
|
||||
|
||||
static cl::opt<bool> AllowEmptyInput(
|
||||
"allow-empty", cl::init(false),
|
||||
cl::desc("Allow the input file to be empty. This is useful when making\n"
|
||||
"checks that some error message does not occur, for example."));
|
||||
|
||||
static cl::opt<bool> AllowUnusedPrefixes(
|
||||
"allow-unused-prefixes", cl::init(false), cl::ZeroOrMore,
|
||||
cl::desc("Allow prefixes to be specified but not appear in the test."));
|
||||
|
||||
static cl::opt<bool> MatchFullLines(
|
||||
"match-full-lines", cl::init(false),
|
||||
cl::desc("Require all positive matches to cover an entire input line.\n"
|
||||
"Allows leading and trailing whitespace if --strict-whitespace\n"
|
||||
"is not also passed."));
|
||||
|
||||
static cl::opt<bool> EnableVarScope(
|
||||
"enable-var-scope", cl::init(false),
|
||||
cl::desc("Enables scope for regex variables. Variables with names that\n"
|
||||
"do not start with '$' will be reset at the beginning of\n"
|
||||
"each CHECK-LABEL block."));
|
||||
|
||||
static cl::opt<bool> AllowDeprecatedDagOverlap(
|
||||
"allow-deprecated-dag-overlap", cl::init(false),
|
||||
cl::desc("Enable overlapping among matches in a group of consecutive\n"
|
||||
"CHECK-DAG directives. This option is deprecated and is only\n"
|
||||
"provided for convenience as old tests are migrated to the new\n"
|
||||
"non-overlapping CHECK-DAG implementation.\n"));
|
||||
|
||||
static cl::opt<bool> Verbose(
|
||||
"v", cl::init(false), cl::ZeroOrMore,
|
||||
cl::desc("Print directive pattern matches, or add them to the input dump\n"
|
||||
"if enabled.\n"));
|
||||
|
||||
static cl::opt<bool> VerboseVerbose(
|
||||
"vv", cl::init(false), cl::ZeroOrMore,
|
||||
cl::desc("Print information helpful in diagnosing internal FileCheck\n"
|
||||
"issues, or add it to the input dump if enabled. Implies\n"
|
||||
"-v.\n"));
|
||||
|
||||
// The order of DumpInputValue members affects their precedence, as documented
|
||||
// for -dump-input below.
|
||||
enum DumpInputValue {
|
||||
DumpInputNever,
|
||||
DumpInputFail,
|
||||
DumpInputAlways,
|
||||
DumpInputHelp
|
||||
};
|
||||
|
||||
static cl::list<DumpInputValue> DumpInputs(
|
||||
"dump-input",
|
||||
cl::desc("Dump input to stderr, adding annotations representing\n"
|
||||
"currently enabled diagnostics. When there are multiple\n"
|
||||
"occurrences of this option, the <value> that appears earliest\n"
|
||||
"in the list below has precedence. The default is 'fail'.\n"),
|
||||
cl::value_desc("mode"),
|
||||
cl::values(clEnumValN(DumpInputHelp, "help", "Explain input dump and quit"),
|
||||
clEnumValN(DumpInputAlways, "always", "Always dump input"),
|
||||
clEnumValN(DumpInputFail, "fail", "Dump input on failure"),
|
||||
clEnumValN(DumpInputNever, "never", "Never dump input")));
|
||||
|
||||
// The order of DumpInputFilterValue members affects their precedence, as
|
||||
// documented for -dump-input-filter below.
|
||||
enum DumpInputFilterValue {
|
||||
DumpInputFilterError,
|
||||
DumpInputFilterAnnotation,
|
||||
DumpInputFilterAnnotationFull,
|
||||
DumpInputFilterAll
|
||||
};
|
||||
|
||||
static cl::list<DumpInputFilterValue> DumpInputFilters(
|
||||
"dump-input-filter",
|
||||
cl::desc("In the dump requested by -dump-input, print only input lines of\n"
|
||||
"kind <value> plus any context specified by -dump-input-context.\n"
|
||||
"When there are multiple occurrences of this option, the <value>\n"
|
||||
"that appears earliest in the list below has precedence. The\n"
|
||||
"default is 'error' when -dump-input=fail, and it's 'all' when\n"
|
||||
"-dump-input=always.\n"),
|
||||
cl::values(clEnumValN(DumpInputFilterAll, "all", "All input lines"),
|
||||
clEnumValN(DumpInputFilterAnnotationFull, "annotation-full",
|
||||
"Input lines with annotations"),
|
||||
clEnumValN(DumpInputFilterAnnotation, "annotation",
|
||||
"Input lines with starting points of annotations"),
|
||||
clEnumValN(DumpInputFilterError, "error",
|
||||
"Input lines with starting points of error "
|
||||
"annotations")));
|
||||
|
||||
static cl::list<unsigned> DumpInputContexts(
|
||||
"dump-input-context", cl::value_desc("N"),
|
||||
cl::desc("In the dump requested by -dump-input, print <N> input lines\n"
|
||||
"before and <N> input lines after any lines specified by\n"
|
||||
"-dump-input-filter. When there are multiple occurrences of\n"
|
||||
"this option, the largest specified <N> has precedence. The\n"
|
||||
"default is 5.\n"));
|
||||
|
||||
typedef cl::list<std::string>::const_iterator prefix_iterator;
|
||||
|
||||
static void DumpCommandLine(int argc, char **argv) {
|
||||
errs() << "FileCheck command line: ";
|
||||
for (int I = 0; I < argc; I++)
|
||||
errs() << " " << argv[I];
|
||||
errs() << "\n";
|
||||
}
|
||||
|
||||
struct MarkerStyle {
|
||||
/// The starting char (before tildes) for marking the line.
|
||||
char Lead;
|
||||
/// What color to use for this annotation.
|
||||
raw_ostream::Colors Color;
|
||||
/// A note to follow the marker, or empty string if none.
|
||||
std::string Note;
|
||||
/// Does this marker indicate inclusion by -dump-input-filter=error?
|
||||
bool FiltersAsError;
|
||||
MarkerStyle() {}
|
||||
MarkerStyle(char Lead, raw_ostream::Colors Color,
|
||||
const std::string &Note = "", bool FiltersAsError = false)
|
||||
: Lead(Lead), Color(Color), Note(Note), FiltersAsError(FiltersAsError) {
|
||||
assert((!FiltersAsError || !Note.empty()) &&
|
||||
"expected error diagnostic to have note");
|
||||
}
|
||||
};
|
||||
|
||||
static MarkerStyle GetMarker(FileCheckDiag::MatchType MatchTy) {
|
||||
switch (MatchTy) {
|
||||
case FileCheckDiag::MatchFoundAndExpected:
|
||||
return MarkerStyle('^', raw_ostream::GREEN);
|
||||
case FileCheckDiag::MatchFoundButExcluded:
|
||||
return MarkerStyle('!', raw_ostream::RED, "error: no match expected",
|
||||
/*FiltersAsError=*/true);
|
||||
case FileCheckDiag::MatchFoundButWrongLine:
|
||||
return MarkerStyle('!', raw_ostream::RED, "error: match on wrong line",
|
||||
/*FiltersAsError=*/true);
|
||||
case FileCheckDiag::MatchFoundButDiscarded:
|
||||
return MarkerStyle('!', raw_ostream::CYAN,
|
||||
"discard: overlaps earlier match");
|
||||
case FileCheckDiag::MatchFoundErrorNote:
|
||||
// Note should always be overridden within the FileCheckDiag.
|
||||
return MarkerStyle('!', raw_ostream::RED,
|
||||
"error: unknown error after match",
|
||||
/*FiltersAsError=*/true);
|
||||
case FileCheckDiag::MatchNoneAndExcluded:
|
||||
return MarkerStyle('X', raw_ostream::GREEN);
|
||||
case FileCheckDiag::MatchNoneButExpected:
|
||||
return MarkerStyle('X', raw_ostream::RED, "error: no match found",
|
||||
/*FiltersAsError=*/true);
|
||||
case FileCheckDiag::MatchNoneForInvalidPattern:
|
||||
return MarkerStyle('X', raw_ostream::RED,
|
||||
"error: match failed for invalid pattern",
|
||||
/*FiltersAsError=*/true);
|
||||
case FileCheckDiag::MatchFuzzy:
|
||||
return MarkerStyle('?', raw_ostream::MAGENTA, "possible intended match",
|
||||
/*FiltersAsError=*/true);
|
||||
}
|
||||
llvm_unreachable_internal("unexpected match type");
|
||||
}
|
||||
|
||||
static void DumpInputAnnotationHelp(raw_ostream &OS) {
|
||||
OS << "The following description was requested by -dump-input=help to\n"
|
||||
<< "explain the input dump printed by FileCheck.\n"
|
||||
<< "\n"
|
||||
<< "Related command-line options:\n"
|
||||
<< "\n"
|
||||
<< " - -dump-input=<value> enables or disables the input dump\n"
|
||||
<< " - -dump-input-filter=<value> filters the input lines\n"
|
||||
<< " - -dump-input-context=<N> adjusts the context of filtered lines\n"
|
||||
<< " - -v and -vv add more annotations\n"
|
||||
<< " - -color forces colors to be enabled both in the dump and below\n"
|
||||
<< " - -help documents the above options in more detail\n"
|
||||
<< "\n"
|
||||
<< "These options can also be set via FILECHECK_OPTS. For example, for\n"
|
||||
<< "maximum debugging output on failures:\n"
|
||||
<< "\n"
|
||||
<< " $ FILECHECK_OPTS='-dump-input-filter=all -vv -color' ninja check\n"
|
||||
<< "\n"
|
||||
<< "Input dump annotation format:\n"
|
||||
<< "\n";
|
||||
|
||||
// Labels for input lines.
|
||||
OS << " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "L:";
|
||||
OS << " labels line number L of the input file\n"
|
||||
<< " An extra space is added after each input line to represent"
|
||||
<< " the\n"
|
||||
<< " newline character\n";
|
||||
|
||||
// Labels for annotation lines.
|
||||
OS << " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "T:L";
|
||||
OS << " labels the only match result for either (1) a pattern of type T"
|
||||
<< " from\n"
|
||||
<< " line L of the check file if L is an integer or (2) the"
|
||||
<< " I-th implicit\n"
|
||||
<< " pattern if L is \"imp\" followed by an integer "
|
||||
<< "I (index origin one)\n";
|
||||
OS << " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "T:L'N";
|
||||
OS << " labels the Nth match result for such a pattern\n";
|
||||
|
||||
// Markers on annotation lines.
|
||||
OS << " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "^~~";
|
||||
OS << " marks good match (reported if -v)\n"
|
||||
<< " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "!~~";
|
||||
OS << " marks bad match, such as:\n"
|
||||
<< " - CHECK-NEXT on same line as previous match (error)\n"
|
||||
<< " - CHECK-NOT found (error)\n"
|
||||
<< " - CHECK-DAG overlapping match (discarded, reported if "
|
||||
<< "-vv)\n"
|
||||
<< " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "X~~";
|
||||
OS << " marks search range when no match is found, such as:\n"
|
||||
<< " - CHECK-NEXT not found (error)\n"
|
||||
<< " - CHECK-NOT not found (success, reported if -vv)\n"
|
||||
<< " - CHECK-DAG not found after discarded matches (error)\n"
|
||||
<< " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "?";
|
||||
OS << " marks fuzzy match when no match is found\n";
|
||||
|
||||
// Elided lines.
|
||||
OS << " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "...";
|
||||
OS << " indicates elided input lines and annotations, as specified by\n"
|
||||
<< " -dump-input-filter and -dump-input-context\n";
|
||||
|
||||
// Colors.
|
||||
OS << " - colors ";
|
||||
WithColor(OS, raw_ostream::GREEN, true) << "success";
|
||||
OS << ", ";
|
||||
WithColor(OS, raw_ostream::RED, true) << "error";
|
||||
OS << ", ";
|
||||
WithColor(OS, raw_ostream::MAGENTA, true) << "fuzzy match";
|
||||
OS << ", ";
|
||||
WithColor(OS, raw_ostream::CYAN, true, false) << "discarded match";
|
||||
OS << ", ";
|
||||
WithColor(OS, raw_ostream::CYAN, true, true) << "unmatched input";
|
||||
OS << "\n";
|
||||
}
|
||||
|
||||
/// An annotation for a single input line.
|
||||
struct InputAnnotation {
|
||||
/// The index of the match result across all checks
|
||||
unsigned DiagIndex;
|
||||
/// The label for this annotation.
|
||||
std::string Label;
|
||||
/// Is this the initial fragment of a diagnostic that has been broken across
|
||||
/// multiple lines?
|
||||
bool IsFirstLine;
|
||||
/// What input line (one-origin indexing) this annotation marks. This might
|
||||
/// be different from the starting line of the original diagnostic if
|
||||
/// !IsFirstLine.
|
||||
unsigned InputLine;
|
||||
/// The column range (one-origin indexing, open end) in which to mark the
|
||||
/// input line. If InputEndCol is UINT_MAX, treat it as the last column
|
||||
/// before the newline.
|
||||
unsigned InputStartCol, InputEndCol;
|
||||
/// The marker to use.
|
||||
MarkerStyle Marker;
|
||||
/// Whether this annotation represents a good match for an expected pattern.
|
||||
bool FoundAndExpectedMatch;
|
||||
};
|
||||
|
||||
/// Get an abbreviation for the check type.
|
||||
static std::string GetCheckTypeAbbreviation(Check::FileCheckType Ty) {
|
||||
switch (Ty) {
|
||||
case Check::CheckPlain:
|
||||
if (Ty.getCount() > 1)
|
||||
return "count";
|
||||
return "check";
|
||||
case Check::CheckNext:
|
||||
return "next";
|
||||
case Check::CheckSame:
|
||||
return "same";
|
||||
case Check::CheckNot:
|
||||
return "not";
|
||||
case Check::CheckDAG:
|
||||
return "dag";
|
||||
case Check::CheckLabel:
|
||||
return "label";
|
||||
case Check::CheckEmpty:
|
||||
return "empty";
|
||||
case Check::CheckComment:
|
||||
return "com";
|
||||
case Check::CheckEOF:
|
||||
return "eof";
|
||||
case Check::CheckBadNot:
|
||||
return "bad-not";
|
||||
case Check::CheckBadCount:
|
||||
return "bad-count";
|
||||
case Check::CheckNone:
|
||||
llvm_unreachable("invalid FileCheckType");
|
||||
}
|
||||
llvm_unreachable("unknown FileCheckType");
|
||||
}
|
||||
|
||||
static void
|
||||
BuildInputAnnotations(const SourceMgr &SM, unsigned CheckFileBufferID,
|
||||
const std::pair<unsigned, unsigned> &ImpPatBufferIDRange,
|
||||
const std::vector<FileCheckDiag> &Diags,
|
||||
std::vector<InputAnnotation> &Annotations,
|
||||
unsigned &LabelWidth) {
|
||||
struct CompareSMLoc {
|
||||
bool operator()(const SMLoc &LHS, const SMLoc &RHS) const {
|
||||
return LHS.getPointer() < RHS.getPointer();
|
||||
}
|
||||
};
|
||||
// How many diagnostics does each pattern have?
|
||||
std::map<SMLoc, unsigned, CompareSMLoc> DiagCountPerPattern;
|
||||
for (auto Diag : Diags)
|
||||
++DiagCountPerPattern[Diag.CheckLoc];
|
||||
// How many diagnostics have we seen so far per pattern?
|
||||
std::map<SMLoc, unsigned, CompareSMLoc> DiagIndexPerPattern;
|
||||
// How many total diagnostics have we seen so far?
|
||||
unsigned DiagIndex = 0;
|
||||
// What's the widest label?
|
||||
LabelWidth = 0;
|
||||
for (auto DiagItr = Diags.begin(), DiagEnd = Diags.end(); DiagItr != DiagEnd;
|
||||
++DiagItr) {
|
||||
InputAnnotation A;
|
||||
A.DiagIndex = DiagIndex++;
|
||||
|
||||
// Build label, which uniquely identifies this check result.
|
||||
unsigned CheckBufferID = SM.FindBufferContainingLoc(DiagItr->CheckLoc);
|
||||
auto CheckLineAndCol =
|
||||
SM.getLineAndColumn(DiagItr->CheckLoc, CheckBufferID);
|
||||
llvm::raw_string_ostream Label(A.Label);
|
||||
Label << GetCheckTypeAbbreviation(DiagItr->CheckTy) << ":";
|
||||
if (CheckBufferID == CheckFileBufferID)
|
||||
Label << CheckLineAndCol.first;
|
||||
else if (ImpPatBufferIDRange.first <= CheckBufferID &&
|
||||
CheckBufferID < ImpPatBufferIDRange.second)
|
||||
Label << "imp" << (CheckBufferID - ImpPatBufferIDRange.first + 1);
|
||||
else
|
||||
llvm_unreachable("expected diagnostic's check location to be either in "
|
||||
"the check file or for an implicit pattern");
|
||||
if (DiagCountPerPattern[DiagItr->CheckLoc] > 1)
|
||||
Label << "'" << DiagIndexPerPattern[DiagItr->CheckLoc]++;
|
||||
LabelWidth = std::max((std::string::size_type)LabelWidth, A.Label.size());
|
||||
|
||||
A.Marker = GetMarker(DiagItr->MatchTy);
|
||||
if (!DiagItr->Note.empty()) {
|
||||
A.Marker.Note = DiagItr->Note;
|
||||
// It's less confusing if notes that don't actually have ranges don't have
|
||||
// markers. For example, a marker for 'with "VAR" equal to "5"' would
|
||||
// seem to indicate where "VAR" matches, but the location we actually have
|
||||
// for the marker simply points to the start of the match/search range for
|
||||
// the full pattern of which the substitution is potentially just one
|
||||
// component.
|
||||
if (DiagItr->InputStartLine == DiagItr->InputEndLine &&
|
||||
DiagItr->InputStartCol == DiagItr->InputEndCol)
|
||||
A.Marker.Lead = ' ';
|
||||
}
|
||||
if (DiagItr->MatchTy == FileCheckDiag::MatchFoundErrorNote) {
|
||||
assert(!DiagItr->Note.empty() &&
|
||||
"expected custom note for MatchFoundErrorNote");
|
||||
A.Marker.Note = "error: " + A.Marker.Note;
|
||||
}
|
||||
A.FoundAndExpectedMatch =
|
||||
DiagItr->MatchTy == FileCheckDiag::MatchFoundAndExpected;
|
||||
|
||||
// Compute the mark location, and break annotation into multiple
|
||||
// annotations if it spans multiple lines.
|
||||
A.IsFirstLine = true;
|
||||
A.InputLine = DiagItr->InputStartLine;
|
||||
A.InputStartCol = DiagItr->InputStartCol;
|
||||
if (DiagItr->InputStartLine == DiagItr->InputEndLine) {
|
||||
// Sometimes ranges are empty in order to indicate a specific point, but
|
||||
// that would mean nothing would be marked, so adjust the range to
|
||||
// include the following character.
|
||||
A.InputEndCol =
|
||||
std::max(DiagItr->InputStartCol + 1, DiagItr->InputEndCol);
|
||||
Annotations.push_back(A);
|
||||
} else {
|
||||
assert(DiagItr->InputStartLine < DiagItr->InputEndLine &&
|
||||
"expected input range not to be inverted");
|
||||
A.InputEndCol = UINT_MAX;
|
||||
Annotations.push_back(A);
|
||||
for (unsigned L = DiagItr->InputStartLine + 1, E = DiagItr->InputEndLine;
|
||||
L <= E; ++L) {
|
||||
// If a range ends before the first column on a line, then it has no
|
||||
// characters on that line, so there's nothing to render.
|
||||
if (DiagItr->InputEndCol == 1 && L == E)
|
||||
break;
|
||||
InputAnnotation B;
|
||||
B.DiagIndex = A.DiagIndex;
|
||||
B.Label = A.Label;
|
||||
B.IsFirstLine = false;
|
||||
B.InputLine = L;
|
||||
B.Marker = A.Marker;
|
||||
B.Marker.Lead = '~';
|
||||
B.Marker.Note = "";
|
||||
B.InputStartCol = 1;
|
||||
if (L != E)
|
||||
B.InputEndCol = UINT_MAX;
|
||||
else
|
||||
B.InputEndCol = DiagItr->InputEndCol;
|
||||
B.FoundAndExpectedMatch = A.FoundAndExpectedMatch;
|
||||
Annotations.push_back(B);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static unsigned FindInputLineInFilter(
|
||||
DumpInputFilterValue DumpInputFilter, unsigned CurInputLine,
|
||||
const std::vector<InputAnnotation>::iterator &AnnotationBeg,
|
||||
const std::vector<InputAnnotation>::iterator &AnnotationEnd) {
|
||||
if (DumpInputFilter == DumpInputFilterAll)
|
||||
return CurInputLine;
|
||||
for (auto AnnotationItr = AnnotationBeg; AnnotationItr != AnnotationEnd;
|
||||
++AnnotationItr) {
|
||||
switch (DumpInputFilter) {
|
||||
case DumpInputFilterAll:
|
||||
llvm_unreachable("unexpected DumpInputFilterAll");
|
||||
break;
|
||||
case DumpInputFilterAnnotationFull:
|
||||
return AnnotationItr->InputLine;
|
||||
case DumpInputFilterAnnotation:
|
||||
if (AnnotationItr->IsFirstLine)
|
||||
return AnnotationItr->InputLine;
|
||||
break;
|
||||
case DumpInputFilterError:
|
||||
if (AnnotationItr->IsFirstLine && AnnotationItr->Marker.FiltersAsError)
|
||||
return AnnotationItr->InputLine;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return UINT_MAX;
|
||||
}
|
||||
|
||||
/// To OS, print a vertical ellipsis (right-justified at LabelWidth) if it would
|
||||
/// occupy less lines than ElidedLines, but print ElidedLines otherwise. Either
|
||||
/// way, clear ElidedLines. Thus, if ElidedLines is empty, do nothing.
|
||||
static void DumpEllipsisOrElidedLines(raw_ostream &OS, std::string &ElidedLines,
|
||||
unsigned LabelWidth) {
|
||||
if (ElidedLines.empty())
|
||||
return;
|
||||
unsigned EllipsisLines = 3;
|
||||
if (EllipsisLines < StringRef(ElidedLines).count('\n')) {
|
||||
for (unsigned i = 0; i < EllipsisLines; ++i) {
|
||||
WithColor(OS, raw_ostream::BLACK, /*Bold=*/true)
|
||||
<< right_justify(".", LabelWidth);
|
||||
OS << '\n';
|
||||
}
|
||||
} else
|
||||
OS << ElidedLines;
|
||||
ElidedLines.clear();
|
||||
}
|
||||
|
||||
static void DumpAnnotatedInput(raw_ostream &OS, const FileCheckRequest &Req,
|
||||
DumpInputFilterValue DumpInputFilter,
|
||||
unsigned DumpInputContext,
|
||||
StringRef InputFileText,
|
||||
std::vector<InputAnnotation> &Annotations,
|
||||
unsigned LabelWidth) {
|
||||
OS << "Input was:\n<<<<<<\n";
|
||||
|
||||
// Sort annotations.
|
||||
llvm::sort(Annotations,
|
||||
[](const InputAnnotation &A, const InputAnnotation &B) {
|
||||
// 1. Sort annotations in the order of the input lines.
|
||||
//
|
||||
// This makes it easier to find relevant annotations while
|
||||
// iterating input lines in the implementation below. FileCheck
|
||||
// does not always produce diagnostics in the order of input
|
||||
// lines due to, for example, CHECK-DAG and CHECK-NOT.
|
||||
if (A.InputLine != B.InputLine)
|
||||
return A.InputLine < B.InputLine;
|
||||
// 2. Sort annotations in the temporal order FileCheck produced
|
||||
// their associated diagnostics.
|
||||
//
|
||||
// This sort offers several benefits:
|
||||
//
|
||||
// A. On a single input line, the order of annotations reflects
|
||||
// the FileCheck logic for processing directives/patterns.
|
||||
// This can be helpful in understanding cases in which the
|
||||
// order of the associated directives/patterns in the check
|
||||
// file or on the command line either (i) does not match the
|
||||
// temporal order in which FileCheck looks for matches for the
|
||||
// directives/patterns (due to, for example, CHECK-LABEL,
|
||||
// CHECK-NOT, or `--implicit-check-not`) or (ii) does match
|
||||
// that order but does not match the order of those
|
||||
// diagnostics along an input line (due to, for example,
|
||||
// CHECK-DAG).
|
||||
//
|
||||
// On the other hand, because our presentation format presents
|
||||
// input lines in order, there's no clear way to offer the
|
||||
// same benefit across input lines. For consistency, it might
|
||||
// then seem worthwhile to have annotations on a single line
|
||||
// also sorted in input order (that is, by input column).
|
||||
// However, in practice, this appears to be more confusing
|
||||
// than helpful. Perhaps it's intuitive to expect annotations
|
||||
// to be listed in the temporal order in which they were
|
||||
// produced except in cases the presentation format obviously
|
||||
// and inherently cannot support it (that is, across input
|
||||
// lines).
|
||||
//
|
||||
// B. When diagnostics' annotations are split among multiple
|
||||
// input lines, the user must track them from one input line
|
||||
// to the next. One property of the sort chosen here is that
|
||||
// it facilitates the user in this regard by ensuring the
|
||||
// following: when comparing any two input lines, a
|
||||
// diagnostic's annotations are sorted in the same position
|
||||
// relative to all other diagnostics' annotations.
|
||||
return A.DiagIndex < B.DiagIndex;
|
||||
});
|
||||
|
||||
// Compute the width of the label column.
|
||||
const unsigned char *InputFilePtr = InputFileText.bytes_begin(),
|
||||
*InputFileEnd = InputFileText.bytes_end();
|
||||
unsigned LineCount = InputFileText.count('\n');
|
||||
if (InputFileEnd[-1] != '\n')
|
||||
++LineCount;
|
||||
unsigned LineNoWidth = std::log10(LineCount) + 1;
|
||||
// +3 below adds spaces (1) to the left of the (right-aligned) line numbers
|
||||
// on input lines and (2) to the right of the (left-aligned) labels on
|
||||
// annotation lines so that input lines and annotation lines are more
|
||||
// visually distinct. For example, the spaces on the annotation lines ensure
|
||||
// that input line numbers and check directive line numbers never align
|
||||
// horizontally. Those line numbers might not even be for the same file.
|
||||
// One space would be enough to achieve that, but more makes it even easier
|
||||
// to see.
|
||||
LabelWidth = std::max(LabelWidth, LineNoWidth) + 3;
|
||||
|
||||
// Print annotated input lines.
|
||||
unsigned PrevLineInFilter = 0; // 0 means none so far
|
||||
unsigned NextLineInFilter = 0; // 0 means uncomputed, UINT_MAX means none
|
||||
std::string ElidedLines;
|
||||
raw_string_ostream ElidedLinesOS(ElidedLines);
|
||||
ColorMode TheColorMode =
|
||||
WithColor(OS).colorsEnabled() ? ColorMode::Enable : ColorMode::Disable;
|
||||
if (TheColorMode == ColorMode::Enable)
|
||||
ElidedLinesOS.enable_colors(true);
|
||||
auto AnnotationItr = Annotations.begin(), AnnotationEnd = Annotations.end();
|
||||
for (unsigned Line = 1;
|
||||
InputFilePtr != InputFileEnd || AnnotationItr != AnnotationEnd; ++Line) {
|
||||
const unsigned char *InputFileLine = InputFilePtr;
|
||||
|
||||
// Compute the previous and next line included by the filter.
|
||||
if (NextLineInFilter < Line)
|
||||
NextLineInFilter = FindInputLineInFilter(DumpInputFilter, Line,
|
||||
AnnotationItr, AnnotationEnd);
|
||||
assert(NextLineInFilter && "expected NextLineInFilter to be computed");
|
||||
if (NextLineInFilter == Line)
|
||||
PrevLineInFilter = Line;
|
||||
|
||||
// Elide this input line and its annotations if it's not within the
|
||||
// context specified by -dump-input-context of an input line included by
|
||||
// -dump-input-filter. However, in case the resulting ellipsis would occupy
|
||||
// more lines than the input lines and annotations it elides, buffer the
|
||||
// elided lines and annotations so we can print them instead.
|
||||
raw_ostream *LineOS = &OS;
|
||||
if ((!PrevLineInFilter || PrevLineInFilter + DumpInputContext < Line) &&
|
||||
(NextLineInFilter == UINT_MAX ||
|
||||
Line + DumpInputContext < NextLineInFilter))
|
||||
LineOS = &ElidedLinesOS;
|
||||
else {
|
||||
LineOS = &OS;
|
||||
DumpEllipsisOrElidedLines(OS, ElidedLinesOS.str(), LabelWidth);
|
||||
}
|
||||
|
||||
// Print right-aligned line number.
|
||||
WithColor(*LineOS, raw_ostream::BLACK, /*Bold=*/true, /*BF=*/false,
|
||||
TheColorMode)
|
||||
<< format_decimal(Line, LabelWidth) << ": ";
|
||||
|
||||
// For the case where -v and colors are enabled, find the annotations for
|
||||
// good matches for expected patterns in order to highlight everything
|
||||
// else in the line. There are no such annotations if -v is disabled.
|
||||
std::vector<InputAnnotation> FoundAndExpectedMatches;
|
||||
if (Req.Verbose && TheColorMode == ColorMode::Enable) {
|
||||
for (auto I = AnnotationItr; I != AnnotationEnd && I->InputLine == Line;
|
||||
++I) {
|
||||
if (I->FoundAndExpectedMatch)
|
||||
FoundAndExpectedMatches.push_back(*I);
|
||||
}
|
||||
}
|
||||
|
||||
// Print numbered line with highlighting where there are no matches for
|
||||
// expected patterns.
|
||||
bool Newline = false;
|
||||
{
|
||||
WithColor COS(*LineOS, raw_ostream::SAVEDCOLOR, /*Bold=*/false,
|
||||
/*BG=*/false, TheColorMode);
|
||||
bool InMatch = false;
|
||||
if (Req.Verbose)
|
||||
COS.changeColor(raw_ostream::CYAN, true, true);
|
||||
for (unsigned Col = 1; InputFilePtr != InputFileEnd && !Newline; ++Col) {
|
||||
bool WasInMatch = InMatch;
|
||||
InMatch = false;
|
||||
for (auto M : FoundAndExpectedMatches) {
|
||||
if (M.InputStartCol <= Col && Col < M.InputEndCol) {
|
||||
InMatch = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!WasInMatch && InMatch)
|
||||
COS.resetColor();
|
||||
else if (WasInMatch && !InMatch)
|
||||
COS.changeColor(raw_ostream::CYAN, true, true);
|
||||
if (*InputFilePtr == '\n') {
|
||||
Newline = true;
|
||||
COS << ' ';
|
||||
} else
|
||||
COS << *InputFilePtr;
|
||||
++InputFilePtr;
|
||||
}
|
||||
}
|
||||
*LineOS << '\n';
|
||||
unsigned InputLineWidth = InputFilePtr - InputFileLine;
|
||||
|
||||
// Print any annotations.
|
||||
while (AnnotationItr != AnnotationEnd && AnnotationItr->InputLine == Line) {
|
||||
WithColor COS(*LineOS, AnnotationItr->Marker.Color, /*Bold=*/true,
|
||||
/*BG=*/false, TheColorMode);
|
||||
// The two spaces below are where the ": " appears on input lines.
|
||||
COS << left_justify(AnnotationItr->Label, LabelWidth) << " ";
|
||||
unsigned Col;
|
||||
for (Col = 1; Col < AnnotationItr->InputStartCol; ++Col)
|
||||
COS << ' ';
|
||||
COS << AnnotationItr->Marker.Lead;
|
||||
// If InputEndCol=UINT_MAX, stop at InputLineWidth.
|
||||
for (++Col; Col < AnnotationItr->InputEndCol && Col <= InputLineWidth;
|
||||
++Col)
|
||||
COS << '~';
|
||||
const std::string &Note = AnnotationItr->Marker.Note;
|
||||
if (!Note.empty()) {
|
||||
// Put the note at the end of the input line. If we were to instead
|
||||
// put the note right after the marker, subsequent annotations for the
|
||||
// same input line might appear to mark this note instead of the input
|
||||
// line.
|
||||
for (; Col <= InputLineWidth; ++Col)
|
||||
COS << ' ';
|
||||
COS << ' ' << Note;
|
||||
}
|
||||
COS << '\n';
|
||||
++AnnotationItr;
|
||||
}
|
||||
}
|
||||
DumpEllipsisOrElidedLines(OS, ElidedLinesOS.str(), LabelWidth);
|
||||
|
||||
OS << ">>>>>>\n";
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
// Enable use of ANSI color codes because FileCheck is using them to
|
||||
// highlight text.
|
||||
llvm::sys::Process::UseANSIEscapeCodes(true);
|
||||
|
||||
InitLLVM X(argc, argv);
|
||||
cl::ParseCommandLineOptions(argc, argv, /*Overview*/ "", /*Errs*/ nullptr,
|
||||
"FILECHECK_OPTS");
|
||||
|
||||
// Select -dump-input* values. The -help documentation specifies the default
|
||||
// value and which value to choose if an option is specified multiple times.
|
||||
// In the latter case, the general rule of thumb is to choose the value that
|
||||
// provides the most information.
|
||||
DumpInputValue DumpInput =
|
||||
DumpInputs.empty()
|
||||
? DumpInputFail
|
||||
: *std::max_element(DumpInputs.begin(), DumpInputs.end());
|
||||
DumpInputFilterValue DumpInputFilter;
|
||||
if (DumpInputFilters.empty())
|
||||
DumpInputFilter = DumpInput == DumpInputAlways ? DumpInputFilterAll
|
||||
: DumpInputFilterError;
|
||||
else
|
||||
DumpInputFilter =
|
||||
*std::max_element(DumpInputFilters.begin(), DumpInputFilters.end());
|
||||
unsigned DumpInputContext = DumpInputContexts.empty()
|
||||
? 5
|
||||
: *std::max_element(DumpInputContexts.begin(),
|
||||
DumpInputContexts.end());
|
||||
|
||||
if (DumpInput == DumpInputHelp) {
|
||||
DumpInputAnnotationHelp(outs());
|
||||
return 0;
|
||||
}
|
||||
if (CheckFilename.empty()) {
|
||||
errs() << "<check-file> not specified\n";
|
||||
return 2;
|
||||
}
|
||||
|
||||
FileCheckRequest Req;
|
||||
append_range(Req.CheckPrefixes, CheckPrefixes);
|
||||
|
||||
append_range(Req.CommentPrefixes, CommentPrefixes);
|
||||
|
||||
append_range(Req.ImplicitCheckNot, ImplicitCheckNot);
|
||||
|
||||
bool GlobalDefineError = false;
|
||||
for (StringRef G : GlobalDefines) {
|
||||
size_t EqIdx = G.find('=');
|
||||
if (EqIdx == std::string::npos) {
|
||||
errs() << "Missing equal sign in command-line definition '-D" << G
|
||||
<< "'\n";
|
||||
GlobalDefineError = true;
|
||||
continue;
|
||||
}
|
||||
if (EqIdx == 0) {
|
||||
errs() << "Missing variable name in command-line definition '-D" << G
|
||||
<< "'\n";
|
||||
GlobalDefineError = true;
|
||||
continue;
|
||||
}
|
||||
Req.GlobalDefines.push_back(G);
|
||||
}
|
||||
if (GlobalDefineError)
|
||||
return 2;
|
||||
|
||||
Req.AllowEmptyInput = AllowEmptyInput;
|
||||
Req.AllowUnusedPrefixes = AllowUnusedPrefixes;
|
||||
Req.EnableVarScope = EnableVarScope;
|
||||
Req.AllowDeprecatedDagOverlap = AllowDeprecatedDagOverlap;
|
||||
Req.Verbose = Verbose;
|
||||
Req.VerboseVerbose = VerboseVerbose;
|
||||
Req.NoCanonicalizeWhiteSpace = NoCanonicalizeWhiteSpace;
|
||||
Req.MatchFullLines = MatchFullLines;
|
||||
Req.IgnoreCase = IgnoreCase;
|
||||
|
||||
if (VerboseVerbose)
|
||||
Req.Verbose = true;
|
||||
|
||||
FileCheck FC(Req);
|
||||
if (!FC.ValidateCheckPrefixes())
|
||||
return 2;
|
||||
|
||||
Regex PrefixRE = FC.buildCheckPrefixRegex();
|
||||
std::string REError;
|
||||
if (!PrefixRE.isValid(REError)) {
|
||||
errs() << "Unable to combine check-prefix strings into a prefix regular "
|
||||
"expression! This is likely a bug in FileCheck's verification of "
|
||||
"the check-prefix strings. Regular expression parsing failed "
|
||||
"with the following error: "
|
||||
<< REError << "\n";
|
||||
return 2;
|
||||
}
|
||||
|
||||
SourceMgr SM;
|
||||
|
||||
// Read the expected strings from the check file.
|
||||
ErrorOr<std::unique_ptr<MemoryBuffer>> CheckFileOrErr =
|
||||
MemoryBuffer::getFileOrSTDIN(CheckFilename, /*IsText=*/true);
|
||||
if (std::error_code EC = CheckFileOrErr.getError()) {
|
||||
errs() << "Could not open check file '" << CheckFilename
|
||||
<< "': " << EC.message() << '\n';
|
||||
return 2;
|
||||
}
|
||||
MemoryBuffer &CheckFile = *CheckFileOrErr.get();
|
||||
|
||||
SmallString<4096> CheckFileBuffer;
|
||||
StringRef CheckFileText = FC.CanonicalizeFile(CheckFile, CheckFileBuffer);
|
||||
|
||||
unsigned CheckFileBufferID =
|
||||
SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(
|
||||
CheckFileText, CheckFile.getBufferIdentifier()),
|
||||
SMLoc());
|
||||
|
||||
std::pair<unsigned, unsigned> ImpPatBufferIDRange;
|
||||
if (FC.readCheckFile(SM, CheckFileText, PrefixRE, &ImpPatBufferIDRange))
|
||||
return 2;
|
||||
|
||||
// Open the file to check and add it to SourceMgr.
|
||||
ErrorOr<std::unique_ptr<MemoryBuffer>> InputFileOrErr =
|
||||
MemoryBuffer::getFileOrSTDIN(InputFilename, /*IsText=*/true);
|
||||
if (InputFilename == "-")
|
||||
InputFilename = "<stdin>"; // Overwrite for improved diagnostic messages
|
||||
if (std::error_code EC = InputFileOrErr.getError()) {
|
||||
errs() << "Could not open input file '" << InputFilename
|
||||
<< "': " << EC.message() << '\n';
|
||||
return 2;
|
||||
}
|
||||
MemoryBuffer &InputFile = *InputFileOrErr.get();
|
||||
|
||||
if (InputFile.getBufferSize() == 0 && !AllowEmptyInput) {
|
||||
errs() << "FileCheck error: '" << InputFilename << "' is empty.\n";
|
||||
DumpCommandLine(argc, argv);
|
||||
return 2;
|
||||
}
|
||||
|
||||
SmallString<4096> InputFileBuffer;
|
||||
StringRef InputFileText = FC.CanonicalizeFile(InputFile, InputFileBuffer);
|
||||
|
||||
SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(
|
||||
InputFileText, InputFile.getBufferIdentifier()),
|
||||
SMLoc());
|
||||
|
||||
std::vector<FileCheckDiag> Diags;
|
||||
int ExitCode = FC.checkInput(SM, InputFileText,
|
||||
DumpInput == DumpInputNever ? nullptr : &Diags)
|
||||
? EXIT_SUCCESS
|
||||
: 1;
|
||||
if (DumpInput == DumpInputAlways ||
|
||||
(ExitCode == 1 && DumpInput == DumpInputFail)) {
|
||||
errs() << "\n"
|
||||
<< "Input file: " << InputFilename << "\n"
|
||||
<< "Check file: " << CheckFilename << "\n"
|
||||
<< "\n"
|
||||
<< "-dump-input=help explains the following input dump.\n"
|
||||
<< "\n";
|
||||
std::vector<InputAnnotation> Annotations;
|
||||
unsigned LabelWidth;
|
||||
BuildInputAnnotations(SM, CheckFileBufferID, ImpPatBufferIDRange, Diags,
|
||||
Annotations, LabelWidth);
|
||||
DumpAnnotatedInput(errs(), Req, DumpInputFilter, DumpInputContext,
|
||||
InputFileText, Annotations, LabelWidth);
|
||||
}
|
||||
|
||||
return ExitCode;
|
||||
}
|
42
bin/triton-opt.cpp
Normal file
42
bin/triton-opt.cpp
Normal file
@@ -0,0 +1,42 @@
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
#include "triton/Conversion/Passes.h"
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/InitAllPasses.h"
|
||||
#include "mlir/Support/MlirOptMain.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestAliasPass();
|
||||
void registerTestAlignmentPass();
|
||||
void registerTestAllocationPass();
|
||||
void registerTestMembarPass();
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
mlir::registerAllPasses();
|
||||
mlir::registerTritonPasses();
|
||||
mlir::registerTritonGPUPasses();
|
||||
mlir::test::registerTestAliasPass();
|
||||
mlir::test::registerTestAlignmentPass();
|
||||
mlir::test::registerTestAllocationPass();
|
||||
mlir::test::registerTestMembarPass();
|
||||
mlir::triton::registerConvertTritonToTritonGPUPass();
|
||||
mlir::triton::registerConvertTritonGPUToLLVMPass();
|
||||
|
||||
// TODO: register Triton & TritonGPU passes
|
||||
mlir::DialectRegistry registry;
|
||||
registry.insert<mlir::triton::TritonDialect,
|
||||
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
|
||||
mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect,
|
||||
mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();
|
||||
|
||||
return mlir::asMainReturnCode(mlir::MlirOptMain(
|
||||
argc, argv, "Triton (GPU) optimizer driver\n", registry));
|
||||
}
|
131
bin/triton-translate.cpp
Normal file
131
bin/triton-translate.cpp
Normal file
@@ -0,0 +1,131 @@
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/IR/AsmState.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/FileUtilities.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
||||
#include "mlir/Target/LLVMIR/Export.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||
#include "triton/driver/llvm.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
|
||||
MLIRContext &context) {
|
||||
std::string errorMessage;
|
||||
auto input = openInputFile(inputFilename, &errorMessage);
|
||||
if (!input) {
|
||||
llvm::errs() << errorMessage << "\n";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
mlir::DialectRegistry registry;
|
||||
registry.insert<TritonDialect, triton::gpu::TritonGPUDialect,
|
||||
mlir::math::MathDialect, arith::ArithmeticDialect,
|
||||
StandardOpsDialect, scf::SCFDialect>();
|
||||
|
||||
context.appendDialectRegistry(registry);
|
||||
|
||||
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer)
|
||||
-> OwningOpRef<ModuleOp> {
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
|
||||
|
||||
context.loadAllAvailableDialects();
|
||||
context.allowUnregisteredDialects();
|
||||
|
||||
OwningOpRef<ModuleOp> module(parseSourceFile(sourceMgr, &context));
|
||||
if (!module) {
|
||||
llvm::errs() << "Parse MLIR file failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return module;
|
||||
};
|
||||
|
||||
auto module = processBuffer(std::move(input));
|
||||
if (!module) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return module;
|
||||
}
|
||||
|
||||
LogicalResult tritonTranslateMain(int argc, char **argv,
|
||||
llvm::StringRef toolName) {
|
||||
static llvm::cl::opt<std::string> inputFilename(
|
||||
llvm::cl::Positional, llvm::cl::desc("<input file>"),
|
||||
llvm::cl::init("-"));
|
||||
|
||||
static llvm::cl::opt<std::string> outputFilename(
|
||||
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
|
||||
llvm::cl::init("-"));
|
||||
|
||||
static llvm::cl::opt<std::string> targetKind(
|
||||
"target", llvm::cl::desc("<translation target, options: llvmir/ptx>"),
|
||||
llvm::cl::value_desc("target"), llvm::cl::init("llvmir"));
|
||||
|
||||
static llvm::cl::opt<int> SMArch("sm", llvm::cl::desc("sm arch"),
|
||||
llvm::cl::init(80));
|
||||
|
||||
static llvm::cl::opt<int> ptxVersion(
|
||||
"ptx-version", llvm::cl::desc("PTX version"), llvm::cl::init(10000));
|
||||
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
|
||||
registerAsmPrinterCLOptions();
|
||||
registerMLIRContextCLOptions();
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv, toolName);
|
||||
|
||||
mlir::MLIRContext context;
|
||||
auto module = loadMLIRModule(inputFilename, context);
|
||||
if (!module) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
std::string errorMessage;
|
||||
auto output = openOutputFile(outputFilename, &errorMessage);
|
||||
if (!output) {
|
||||
llvm::errs() << errorMessage << "\n";
|
||||
return failure();
|
||||
}
|
||||
|
||||
llvm::LLVMContext llvmContext;
|
||||
auto llvmir =
|
||||
translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue());
|
||||
if (!llvmir) {
|
||||
llvm::errs() << "Translate to LLVM IR failed";
|
||||
}
|
||||
|
||||
if (targetKind == "llvmir")
|
||||
llvm::outs() << *llvmir << '\n';
|
||||
else if (targetKind == "ptx")
|
||||
llvm::outs() << ::triton::driver::llir_to_ptx(
|
||||
llvmir.get(), SMArch.getValue(), ptxVersion.getValue());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
return failed(mlir::triton::tritonTranslateMain(
|
||||
argc, argv, "Triton Translate Testing Tool."));
|
||||
}
|
@@ -25,7 +25,7 @@
|
||||
# LLVM_VERSION_STRING - Full LLVM version string (e.g. 6.0.0svn).
|
||||
# LLVM_VERSION_BASE_STRING - Base LLVM version string without git/svn suffix (e.g. 6.0.0).
|
||||
#
|
||||
# Note: The variable names were chosen in conformance with the offical CMake
|
||||
# Note: The variable names were chosen in conformance with the official CMake
|
||||
# guidelines, see ${CMAKE_ROOT}/Modules/readme.txt.
|
||||
|
||||
# Try suffixed versions to pick up the newest LLVM install available on Debian
|
||||
|
27
docs/_templates/versions.html
vendored
Normal file
27
docs/_templates/versions.html
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
{%- if current_version %}
|
||||
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
|
||||
<span class="rst-current-version" data-toggle="rst-current-version">
|
||||
<span class="fa fa-book"> Other Versions</span>
|
||||
v: {{ current_version.name }}
|
||||
<span class="fa fa-caret-down"></span>
|
||||
</span>
|
||||
<div class="rst-other-versions">
|
||||
{%- if versions.tags %}
|
||||
<dl>
|
||||
<dt>Tags</dt>
|
||||
{%- for item in versions.tags %}
|
||||
<dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
|
||||
{%- endfor %}
|
||||
</dl>
|
||||
{%- endif %}
|
||||
{%- if versions.branches %}
|
||||
<dl>
|
||||
<dt>Branches</dt>
|
||||
{%- for item in versions.branches %}
|
||||
<dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
|
||||
{%- endfor %}
|
||||
</dl>
|
||||
{%- endif %}
|
||||
</div>
|
||||
</div>
|
||||
{%- endif %}
|
34
docs/conf.py
34
docs/conf.py
@@ -24,24 +24,38 @@
|
||||
# -- General configuration ------------------------------------------------
|
||||
|
||||
|
||||
|
||||
|
||||
def process_sig(app, what, name, obj, options, signature, return_annotation):
|
||||
if signature and '_builder' in signature:
|
||||
signature = signature.split('_builder')[0] + ")"
|
||||
return (signature, return_annotation)
|
||||
|
||||
def setup(app):
|
||||
"""Customize function args retrieving to get args under decorator."""
|
||||
import sphinx
|
||||
import triton
|
||||
import os
|
||||
|
||||
app.connect("autodoc-process-signature", process_sig)
|
||||
os.system("pip install -e ../python")
|
||||
|
||||
|
||||
def forward_jit_fn(func):
|
||||
old = func
|
||||
|
||||
def wrapped(obj, **kwargs):
|
||||
import triton
|
||||
if isinstance(obj, triton.code_gen.JITFunction):
|
||||
obj = obj.fn
|
||||
return old(obj)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
old_documenter = sphinx.ext.autosummary.get_documenter
|
||||
|
||||
def documenter(app, obj, parent):
|
||||
import triton
|
||||
if isinstance(obj, triton.code_gen.JITFunction):
|
||||
obj = obj.fn
|
||||
return old_documenter(app, obj, parent)
|
||||
@@ -56,9 +70,17 @@ def setup(app):
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.abspath('../python/'))
|
||||
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon']
|
||||
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon', 'sphinx_multiversion']
|
||||
autosummary_generate = True
|
||||
|
||||
# versioning config
|
||||
smv_tag_whitelist = r'^(v1.1.2)$'
|
||||
smv_branch_whitelist = r'^master$'
|
||||
smv_remote_whitelist = None
|
||||
smv_released_pattern = r'^tags/.*$'
|
||||
smv_outputdir_format = '{ref.name}'
|
||||
smv_prefer_remote_refs = False
|
||||
|
||||
# Sphinx gallery
|
||||
extensions += ['sphinx_gallery.gen_gallery']
|
||||
from sphinx_gallery.sorting import FileNameSortKey
|
||||
@@ -68,10 +90,18 @@ sphinx_gallery_conf = {
|
||||
'filename_pattern': '',
|
||||
'ignore_pattern': r'__init__\.py',
|
||||
'within_subsection_order': FileNameSortKey,
|
||||
'reference_url': {
|
||||
'sphinx_gallery': None,
|
||||
}
|
||||
}
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ['_templates']
|
||||
html_sidebars = {
|
||||
'**': [
|
||||
'_templates/versions.html',
|
||||
],
|
||||
}
|
||||
|
||||
# The suffix(es) of source filenames.
|
||||
# You can specify multiple suffix as a list of string:
|
||||
|
@@ -8,6 +8,8 @@ Binary Distributions
|
||||
|
||||
You can install the latest stable release of Triton from pip:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install triton
|
||||
|
||||
Binary wheels are available for CPython 3.6-3.9 and PyPy 3.6-3.7.
|
||||
@@ -31,18 +33,19 @@ You can install the Python package from source by running the following commands
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
git clone https://github.com/ptillet/triton.git;
|
||||
git clone https://github.com/openai/triton.git;
|
||||
cd triton/python;
|
||||
pip install cmake; # build time dependency
|
||||
pip install -e .
|
||||
|
||||
Note that, if llvm-11 is not present on your system, the setup.py script will download LLVM static libraries on the web and link against that.
|
||||
Note that, if llvm-11 is not present on your system, the setup.py script will download the official LLVM11 static libraries link against that.
|
||||
|
||||
You can then test your installation by running the unit tests:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pytest -vs .
|
||||
pip install -e '.[tests]'
|
||||
pytest -vs test/unit/
|
||||
|
||||
and the benchmarks
|
||||
|
||||
|
BIN
docs/getting-started/tutorials/grouped_vs_row_major_ordering.png
Normal file
BIN
docs/getting-started/tutorials/grouped_vs_row_major_ordering.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 465 KiB |
BIN
docs/getting-started/tutorials/random_bits.png
Normal file
BIN
docs/getting-started/tutorials/random_bits.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 41 KiB |
@@ -1,7 +1,7 @@
|
||||
Welcome to Triton's documentation!
|
||||
==================================
|
||||
|
||||
Triton is an language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.
|
||||
Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.
|
||||
|
||||
Getting Started
|
||||
---------------
|
||||
|
@@ -2,7 +2,7 @@
|
||||
Related Work
|
||||
==============
|
||||
|
||||
At first sight, Triton may seem like just yet another DSL for DNNs. The purpose of this section is to contextualize Triton and highlights its differences with the two leading approaches in this domain: polyhedral compilation and scheduling languages.
|
||||
At first sight, Triton may seem like just yet another DSL for DNNs. The purpose of this section is to contextualize Triton and highlight its differences with the two leading approaches in this domain: polyhedral compilation and scheduling languages.
|
||||
|
||||
-----------------------
|
||||
Polyhedral Compilation
|
||||
@@ -14,7 +14,7 @@ Traditional compilers typically rely on intermediate representations, such as LL
|
||||
Program Representation
|
||||
+++++++++++++++++++++++
|
||||
|
||||
Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample litterature on linear and integer programming.
|
||||
Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample literature on linear and integer programming.
|
||||
|
||||
.. table::
|
||||
:widths: 50 50
|
||||
@@ -121,7 +121,7 @@ Limitations
|
||||
|
||||
Unfortunately, polyhedral compilers suffer from two major limitations that have prevented its adoption as a universal method for code generation in neural networks.
|
||||
|
||||
First, the set of possible program transformations $\Omega = \{ \Theta_S ~|~ S \in \text{program} \}$ is large, and grows with the number of statements in the program as well as with the size of their iteration domain. Verifying the legality of each transformation can also require the resolution of complex integer linear programs, making polyhedral compilation very computationally expensive. To make matters worse, hardware properties (e.g., cache size, number of SMs) and contextual characteristics (e.g., input tensor shapes) also have to be taken into account by this framework, leading to expensive auto-tuning procedures [SATO2019]_.
|
||||
First, the set of possible program transformations :math:`\Omega = \{ \Theta_S ~|~ S \in \text{program} \}` is large, and grows with the number of statements in the program as well as with the size of their iteration domain. Verifying the legality of each transformation can also require the resolution of complex integer linear programs, making polyhedral compilation very computationally expensive. To make matters worse, hardware properties (e.g., cache size, number of SMs) and contextual characteristics (e.g., input tensor shapes) also have to be taken into account by this framework, leading to expensive auto-tuning procedures [SATO2019]_.
|
||||
|
||||
Second, the polyhedral framework is not very generally applicable; SCoPs are relatively common [GIRBAL2006]_ but require loop bounds and array subscripts to be affine functions of loop indices, which typically only occurs in regular, dense computations. For this reason, this framework still has to be successfully applied to sparse -- or even structured-sparse -- neural networks, whose importance has been rapidly rising over the past few years.
|
||||
|
||||
@@ -131,7 +131,7 @@ On the other hand, blocked program representations advocated by this dissertatio
|
||||
Scheduling Languages
|
||||
-----------------------
|
||||
|
||||
Separation of concerns \cite{dijkstra82} is a well-known design principle in computer science: programs should be decomposed into modular layers of abstraction that separate the semantics of their algorithms from the details of their implementation. Systems like Halide and TVM push this philosophy one step further, and enforce this separation at the grammatical level through the use of a **scheduling language**. The benefits of this methodology are particularly visible in the case of matrix multiplication, where, as one can see below, the definition of the algorithm (Line 1-7) is completely disjoint from its implementation (Line 8-16), meaning that both can be maintained, optimized and distributed independently.
|
||||
Separation of concerns [DIJKSTRA82]_ is a well-known design principle in computer science: programs should be decomposed into modular layers of abstraction that separate the semantics of their algorithms from the details of their implementation. Systems like Halide and TVM push this philosophy one step further, and enforce this separation at the grammatical level through the use of a **scheduling language**. The benefits of this methodology are particularly visible in the case of matrix multiplication, where, as one can see below, the definition of the algorithm (Line 1-7) is completely disjoint from its implementation (Line 8-16), meaning that both can be maintained, optimized and distributed independently.
|
||||
|
||||
.. code-block:: python
|
||||
:linenos:
|
||||
@@ -168,7 +168,7 @@ Scheduling languages are, without a doubt, one of the most popular approaches fo
|
||||
Limitations
|
||||
++++++++++++
|
||||
|
||||
This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indice without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse com-putations, whose iteration spaces may be irregular.
|
||||
This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indices without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular.
|
||||
|
||||
.. table::
|
||||
:widths: 50 50
|
||||
@@ -206,4 +206,5 @@ References
|
||||
.. [GROSSER2012] T. Grosser et al., "Polly - Performing Polyhedral Optimizations on a Low-Level Intermediate Representation", Parallel Processing Letters 2012
|
||||
.. [SATO2019] Y. Sato et al., "An Autotuning Framework for Scalable Execution of Tiled Code via Iterative Polyhedral Compilation", TACO 2019
|
||||
.. [GIRBAL2006] S. Girbal et al., "Semi-Automatic Composition of Loop Transformations for Deep Parallelism and Memory Hierarchies", International Journal of Parallel Programming 2006
|
||||
.. [DIJKSTRA82] E. W. Dijkstra et al., "On the role of scientific thought", Selected writings on computing: a personal perspective 1982
|
||||
.. [MULLAPUDI2016] R. Mullapudi et al., "Automatically scheduling halide image processing pipelines", TOG 2016
|
@@ -80,6 +80,9 @@ Math Ops
|
||||
|
||||
exp
|
||||
log
|
||||
cos
|
||||
sin
|
||||
sqrt
|
||||
sigmoid
|
||||
softmax
|
||||
|
||||
@@ -95,6 +98,18 @@ Reduction Ops
|
||||
min
|
||||
sum
|
||||
|
||||
Atomic Ops
|
||||
---------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
atomic_cas
|
||||
atomic_add
|
||||
atomic_max
|
||||
atomic_min
|
||||
|
||||
|
||||
Comparison ops
|
||||
---------------
|
||||
@@ -106,6 +121,19 @@ Comparison ops
|
||||
minimum
|
||||
maximum
|
||||
|
||||
.. _Random Number Generation:
|
||||
|
||||
Random Number Generation
|
||||
-------------------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
randint4x
|
||||
randint
|
||||
rand
|
||||
randn
|
||||
|
||||
Compiler Hint Ops
|
||||
-------------------
|
||||
|
@@ -8,3 +8,6 @@ triton
|
||||
:nosignatures:
|
||||
|
||||
jit
|
||||
autotune
|
||||
heuristics
|
||||
Config
|
1
include/CMakeLists.txt
Normal file
1
include/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_subdirectory(triton)
|
80
include/triton/Analysis/Alias.h
Normal file
80
include/triton/Analysis/Alias.h
Normal file
@@ -0,0 +1,80 @@
|
||||
#ifndef TRITON_ANALYSIS_ALIAS_H
|
||||
#define TRITON_ANALYSIS_ALIAS_H
|
||||
|
||||
#include "mlir/Analysis/AliasAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class AliasInfo {
|
||||
public:
|
||||
AliasInfo() = default;
|
||||
AliasInfo(Value value) { insert(value); }
|
||||
|
||||
void insert(Value value) { allocs.insert(value); }
|
||||
|
||||
const DenseSet<Value> &getAllocs() const { return allocs; }
|
||||
|
||||
bool operator==(const AliasInfo &other) const {
|
||||
return allocs == other.allocs;
|
||||
}
|
||||
|
||||
/// The pessimistic value state of a value without alias
|
||||
static AliasInfo getPessimisticValueState(MLIRContext *context) {
|
||||
return AliasInfo();
|
||||
}
|
||||
static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); }
|
||||
|
||||
/// The union of both arguments
|
||||
static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs);
|
||||
|
||||
private:
|
||||
/// The set of allocated values that are aliased by this lattice.
|
||||
/// For now, we only consider aliased value produced by the following
|
||||
/// situations:
|
||||
/// 1. values returned by scf.yield
|
||||
/// 2. block arguments in scf.for
|
||||
/// Example:
|
||||
/// alloc v1 alloc v2
|
||||
/// | |
|
||||
/// |--------------| |------------|
|
||||
/// scf.for v3 scf.for v4 scf.for v5
|
||||
/// |
|
||||
/// scf.yield v6
|
||||
///
|
||||
/// v1's alloc [v1]
|
||||
/// v2's alloc [v2]
|
||||
/// v3's alloc [v1]
|
||||
/// v4's alloc [v1, v2]
|
||||
/// v5's alloc [v2]
|
||||
/// v6's alloc [v1]
|
||||
///
|
||||
/// Therefore, v1's liveness range is the union of v3, v4, and v6
|
||||
/// v2's liveness range is the union of v4 and v5.
|
||||
DenseSet<Value> allocs;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared Memory Alias Analysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis<AliasInfo> {
|
||||
public:
|
||||
using ForwardDataFlowAnalysis<AliasInfo>::ForwardDataFlowAnalysis;
|
||||
|
||||
/// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use.
|
||||
/// Given two values, returns their aliasing behavior.
|
||||
AliasResult alias(Value lhs, Value rhs);
|
||||
|
||||
/// Returns the modify-reference behavior of `op` on `location`.
|
||||
ModRefResult getModRef(Operation *op, Value location);
|
||||
|
||||
/// Computes if the alloc set of the results are changed.
|
||||
ChangeResult
|
||||
visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AliasInfo> *> operands) override;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSIS_ALIAS_H
|
192
include/triton/Analysis/Allocation.h
Normal file
192
include/triton/Analysis/Allocation.h
Normal file
@@ -0,0 +1,192 @@
|
||||
#ifndef TRITON_ANALYSIS_ALLOCATION_H
|
||||
#define TRITON_ANALYSIS_ALLOCATION_H
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <atomic>
|
||||
#include <limits>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace triton {
|
||||
class AllocationAnalysis;
|
||||
|
||||
SmallVector<unsigned>
|
||||
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
unsigned &outVec);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
|
||||
/// A class that represents an interval, specified using a start and an end
|
||||
/// values: [Start, End).
|
||||
template <typename T> class Interval {
|
||||
public:
|
||||
Interval() {}
|
||||
Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); }
|
||||
T start() const { return Start; }
|
||||
T end() const { return End; }
|
||||
T size() const { return End - Start; }
|
||||
bool contains(T Addr) const { return Start <= Addr && Addr < End; }
|
||||
bool intersects(const Interval &R) const {
|
||||
return Start < R.End && R.Start < End;
|
||||
}
|
||||
bool operator==(const Interval &R) const {
|
||||
return Start == R.Start && End == R.End;
|
||||
}
|
||||
bool operator!=(const Interval &R) const { return !(*this == R); }
|
||||
bool operator<(const Interval &R) const {
|
||||
return std::make_pair(Start, End) < std::make_pair(R.Start, R.End);
|
||||
}
|
||||
|
||||
private:
|
||||
T Start = std::numeric_limits<T>::min();
|
||||
T End = std::numeric_limits<T>::max();
|
||||
};
|
||||
|
||||
class Allocation {
|
||||
public:
|
||||
/// A unique identifier for shared memory buffers
|
||||
using BufferId = size_t;
|
||||
using BufferIdSetT = DenseSet<BufferId>;
|
||||
|
||||
static constexpr BufferId InvalidBufferId =
|
||||
std::numeric_limits<BufferId>::max();
|
||||
|
||||
/// Creates a new Allocation analysis that computes the shared memory
|
||||
/// information for all associated shared memory values.
|
||||
Allocation(Operation *operation) : operation(operation) { run(); }
|
||||
|
||||
/// Returns the operation this analysis was constructed from.
|
||||
Operation *getOperation() const { return operation; }
|
||||
|
||||
/// Returns the offset of the given buffer in the shared memory.
|
||||
size_t getOffset(BufferId bufferId) const {
|
||||
return bufferSet.lookup(bufferId).offset;
|
||||
}
|
||||
|
||||
/// Returns the size of the given buffer in the shared memory.
|
||||
size_t getAllocatedSize(BufferId bufferId) const {
|
||||
return bufferSet.lookup(bufferId).size;
|
||||
}
|
||||
|
||||
/// Returns the buffer id of the given value.
|
||||
/// This interface only returns the allocated buffer id.
|
||||
/// If you want to get all the buffer ids that are associated with the given
|
||||
/// value, including alias buffers, use getBufferIds.
|
||||
BufferId getBufferId(Value value) const {
|
||||
if (valueBuffer.count(value)) {
|
||||
return valueBuffer.lookup(value)->id;
|
||||
} else {
|
||||
return InvalidBufferId;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns all the buffer ids of the given value, including alias buffers.
|
||||
BufferIdSetT getBufferIds(Value value) const {
|
||||
BufferIdSetT bufferIds;
|
||||
auto allocBufferId = getBufferId(value);
|
||||
if (allocBufferId != InvalidBufferId)
|
||||
bufferIds.insert(allocBufferId);
|
||||
for (auto *buffer : aliasBuffer.lookup(value)) {
|
||||
if (buffer->id != InvalidBufferId)
|
||||
bufferIds.insert(buffer->id);
|
||||
}
|
||||
return bufferIds;
|
||||
}
|
||||
|
||||
/// Returns the scratch buffer id of the given value.
|
||||
BufferId getBufferId(Operation *operation) const {
|
||||
if (opScratch.count(operation)) {
|
||||
return opScratch.lookup(operation)->id;
|
||||
} else {
|
||||
return InvalidBufferId;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the size of total shared memory allocated
|
||||
size_t getSharedMemorySize() const { return sharedMemorySize; }
|
||||
|
||||
bool isIntersected(BufferId lhsId, BufferId rhsId) const {
|
||||
if (lhsId == InvalidBufferId || rhsId == InvalidBufferId)
|
||||
return false;
|
||||
auto lhsBuffer = bufferSet.lookup(lhsId);
|
||||
auto rhsBuffer = bufferSet.lookup(rhsId);
|
||||
return lhsBuffer.intersects(rhsBuffer);
|
||||
}
|
||||
|
||||
private:
|
||||
/// A class that represents a shared memory buffer
|
||||
struct BufferT {
|
||||
enum class BufferKind { Explicit, Scratch };
|
||||
|
||||
/// MT: thread-safe
|
||||
inline static std::atomic<BufferId> nextId = 0;
|
||||
|
||||
BufferKind kind;
|
||||
BufferId id;
|
||||
size_t size;
|
||||
size_t offset;
|
||||
|
||||
bool operator==(const BufferT &other) const { return id == other.id; }
|
||||
bool operator<(const BufferT &other) const { return id < other.id; }
|
||||
|
||||
BufferT() : BufferT(BufferKind::Explicit) {}
|
||||
BufferT(BufferKind kind) : BufferT(kind, 0, 0) {}
|
||||
BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
|
||||
BufferT(BufferKind kind, size_t size, size_t offset)
|
||||
: kind(kind), id(nextId++), size(size), offset(offset) {}
|
||||
|
||||
bool intersects(const BufferT &other) const {
|
||||
return Interval<size_t>(offset, offset + size)
|
||||
.intersects(
|
||||
Interval<size_t>(other.offset, other.offset + other.size));
|
||||
}
|
||||
};
|
||||
|
||||
/// Op -> Scratch Buffer
|
||||
using OpScratchMapT = DenseMap<Operation *, BufferT *>;
|
||||
/// Value -> Explicit Buffer
|
||||
using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
|
||||
/// Value -> Alias Buffer
|
||||
using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
|
||||
/// BufferId -> Buffer
|
||||
using BufferSetT = DenseMap<BufferId, BufferT>;
|
||||
/// Runs allocation analysis on the given top-level operation.
|
||||
void run();
|
||||
|
||||
private:
|
||||
template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
|
||||
void addBuffer(KeyType &key, Args &&...args) {
|
||||
auto buffer = BufferT(Kind, std::forward<Args>(args)...);
|
||||
bufferSet[buffer.id] = std::move(buffer);
|
||||
if constexpr (Kind == BufferT::BufferKind::Explicit) {
|
||||
valueBuffer[key] = &bufferSet[buffer.id];
|
||||
} else {
|
||||
opScratch[key] = &bufferSet[buffer.id];
|
||||
}
|
||||
}
|
||||
|
||||
void addAlias(Value value, Value alloc) {
|
||||
aliasBuffer[value].insert(valueBuffer[alloc]);
|
||||
}
|
||||
|
||||
private:
|
||||
Operation *operation;
|
||||
OpScratchMapT opScratch;
|
||||
ValueBufferMapT valueBuffer;
|
||||
AliasBufferMapT aliasBuffer;
|
||||
BufferSetT bufferSet;
|
||||
size_t sharedMemorySize = 0;
|
||||
|
||||
friend class triton::AllocationAnalysis;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSIS_ALLOCATION_H
|
144
include/triton/Analysis/AxisInfo.h
Normal file
144
include/triton/Analysis/AxisInfo.h
Normal file
@@ -0,0 +1,144 @@
|
||||
#ifndef TRITON_ANALYSIS_AXISINFO_H
|
||||
#define TRITON_ANALYSIS_AXISINFO_H
|
||||
|
||||
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <iostream>
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This lattice value represents known information on the axes of a lattice.
|
||||
/// Axis information is represented by a std::map<int, int>
|
||||
class AxisInfo {
|
||||
public:
|
||||
typedef SmallVector<int, 4> DimVectorT;
|
||||
|
||||
public:
|
||||
// Default constructor
|
||||
AxisInfo() : AxisInfo({}, {}, {}) {}
|
||||
// Construct contiguity info with known contiguity
|
||||
AxisInfo(DimVectorT knownContiguity, DimVectorT knownDivisibility,
|
||||
DimVectorT knownConstancy)
|
||||
: contiguity(knownContiguity), divisibility(knownDivisibility),
|
||||
constancy(knownConstancy), rank(contiguity.size()) {
|
||||
assert(knownDivisibility.size() == (size_t)rank);
|
||||
assert(knownConstancy.size() == (size_t)rank);
|
||||
}
|
||||
|
||||
// Accessors
|
||||
int getContiguity(size_t d) const { return contiguity[d]; }
|
||||
const DimVectorT &getContiguity() const { return contiguity; }
|
||||
|
||||
int getDivisibility(size_t d) const { return divisibility[d]; }
|
||||
const DimVectorT &getDivisibility() const { return divisibility; }
|
||||
|
||||
int getConstancy(size_t d) const { return constancy[d]; }
|
||||
const DimVectorT &getConstancy() const { return constancy; }
|
||||
|
||||
int getRank() const { return rank; }
|
||||
|
||||
// Comparison
|
||||
bool operator==(const AxisInfo &other) const {
|
||||
return (contiguity == other.contiguity) &&
|
||||
(divisibility == other.divisibility) &&
|
||||
(constancy == other.constancy);
|
||||
}
|
||||
|
||||
/// The pessimistic value state of the contiguity is unknown.
|
||||
static AxisInfo getPessimisticValueState(MLIRContext *context) {
|
||||
return AxisInfo();
|
||||
}
|
||||
static AxisInfo getPessimisticValueState(Value value);
|
||||
|
||||
// The gcd of both arguments for each dimension
|
||||
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
|
||||
|
||||
private:
|
||||
/// The _contiguity_ information maps the `d`-th
|
||||
/// dimension to the length of the shortest
|
||||
/// sequence of contiguous integers along it
|
||||
/// For example:
|
||||
/// [10, 11, 12, 13, 18, 19, 20, 21]
|
||||
/// [20, 21, 22, 23, 28, 29, 30, 31]
|
||||
/// Would have contiguity [1, 4].
|
||||
/// and
|
||||
/// [12, 16, 20, 24]
|
||||
/// [13, 17, 21, 25]
|
||||
/// [14, 18, 22, 26]
|
||||
/// [15, 19, 23, 27]
|
||||
/// [18, 22, 26, 30]
|
||||
/// [19, 23, 27, 31]
|
||||
/// Would have contiguity [2, 1].
|
||||
DimVectorT contiguity;
|
||||
|
||||
/// The _divisibility_ information maps the `d`-th
|
||||
/// dimension to the largest power-of-two that
|
||||
/// divides the first element of all the values along it
|
||||
/// For example:
|
||||
/// [10, 11, 12, 13, 18, 19, 20, 21]
|
||||
/// [20, 21, 22, 23, 28, 29, 30, 31]
|
||||
// would have divisibility [1, 2]
|
||||
// and
|
||||
/// [12, 16, 20, 24]
|
||||
/// [13, 17, 21, 25]
|
||||
/// [14, 18, 22, 26]
|
||||
/// [15, 19, 23, 27]
|
||||
// would have divisibility [4, 1]
|
||||
DimVectorT divisibility;
|
||||
|
||||
/// The _constancy_ information maps the `d`-th
|
||||
/// dimension to the length of the shortest
|
||||
/// sequence of constant integer along it. This is
|
||||
/// particularly useful to infer the contiguity
|
||||
/// of operations (e.g., add) involving a constant
|
||||
/// For example
|
||||
/// [8, 8, 8, 8, 12, 12, 12, 12]
|
||||
/// [16, 16, 16, 16, 20, 20, 20, 20]
|
||||
/// would have constancy [1, 4]
|
||||
DimVectorT constancy;
|
||||
|
||||
// number of dimensions of the lattice
|
||||
int rank;
|
||||
};
|
||||
|
||||
class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
|
||||
|
||||
private:
|
||||
static const int maxPow2Divisor = 65536;
|
||||
|
||||
int highestPowOf2Divisor(int n) {
|
||||
if (n == 0)
|
||||
return maxPow2Divisor;
|
||||
return (n & (~(n - 1)));
|
||||
}
|
||||
|
||||
AxisInfo visitBinaryOp(
|
||||
Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy);
|
||||
|
||||
public:
|
||||
using ForwardDataFlowAnalysis<AxisInfo>::ForwardDataFlowAnalysis;
|
||||
|
||||
ChangeResult
|
||||
visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
|
||||
|
||||
unsigned getPtrVectorSize(Value ptr);
|
||||
|
||||
unsigned getPtrAlignment(Value ptr);
|
||||
|
||||
unsigned getMaskAlignment(Value mask);
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
119
include/triton/Analysis/Membar.h
Normal file
119
include/triton/Analysis/Membar.h
Normal file
@@ -0,0 +1,119 @@
|
||||
#ifndef TRITON_ANALYSIS_MEMBAR_H
|
||||
#define TRITON_ANALYSIS_MEMBAR_H
|
||||
|
||||
#include "Allocation.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class OpBuilder;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared Memory Barrier Analysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
class MembarAnalysis {
|
||||
public:
|
||||
/// Creates a new Membar analysis that generates the shared memory barrier
|
||||
/// in the following circumstances:
|
||||
/// - RAW: If a shared memory write is followed by a shared memory read, and
|
||||
/// their addresses are intersected, a barrier is inserted.
|
||||
/// - WAR: If a shared memory read is followed by a shared memory read, and
|
||||
/// their addresses are intersected, a barrier is inserted.
|
||||
/// The following circumstances do not require a barrier:
|
||||
/// - WAW: not possible because overlapped memory allocation is not allowed.
|
||||
/// - RAR: no write is performed.
|
||||
/// Temporary storage of operations such as Reduce are considered as both
|
||||
/// a shared memory read. If the temporary storage is written but not read,
|
||||
/// it is considered as the problem of the operation itself but not the membar
|
||||
/// analysis.
|
||||
/// The following circumstances are not considered yet:
|
||||
/// - Double buffers
|
||||
/// - N buffers
|
||||
MembarAnalysis(Allocation *allocation) : allocation(allocation) {}
|
||||
|
||||
/// Runs the membar analysis to the given operation, inserts a barrier if
|
||||
/// necessary.
|
||||
void run();
|
||||
|
||||
private:
|
||||
struct RegionInfo {
|
||||
using BufferIdSetT = Allocation::BufferIdSetT;
|
||||
|
||||
BufferIdSetT syncReadBuffers;
|
||||
BufferIdSetT syncWriteBuffers;
|
||||
|
||||
RegionInfo() = default;
|
||||
RegionInfo(const BufferIdSetT &syncReadBuffers,
|
||||
const BufferIdSetT &syncWriteBuffers)
|
||||
: syncReadBuffers(syncReadBuffers), syncWriteBuffers(syncWriteBuffers) {
|
||||
}
|
||||
|
||||
/// Unions two RegionInfo objects.
|
||||
void join(const RegionInfo &other) {
|
||||
syncReadBuffers.insert(other.syncReadBuffers.begin(),
|
||||
other.syncReadBuffers.end());
|
||||
syncWriteBuffers.insert(other.syncWriteBuffers.begin(),
|
||||
other.syncWriteBuffers.end());
|
||||
}
|
||||
|
||||
/// Returns true if buffers in two RegionInfo objects are intersected.
|
||||
bool isIntersected(const RegionInfo &other, Allocation *allocation) const {
|
||||
return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers,
|
||||
allocation) ||
|
||||
/*WAR*/
|
||||
isIntersected(syncReadBuffers, other.syncWriteBuffers,
|
||||
allocation) ||
|
||||
/*WAW*/
|
||||
isIntersected(syncWriteBuffers, other.syncWriteBuffers,
|
||||
allocation);
|
||||
}
|
||||
|
||||
/// Clears the buffers because a barrier is inserted.
|
||||
void sync() {
|
||||
syncReadBuffers.clear();
|
||||
syncWriteBuffers.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
/// Returns true if buffers in two sets are intersected.
|
||||
bool isIntersected(const BufferIdSetT &lhs, const BufferIdSetT &rhs,
|
||||
Allocation *allocation) const {
|
||||
return std::any_of(lhs.begin(), lhs.end(), [&](auto lhsId) {
|
||||
return std::any_of(rhs.begin(), rhs.end(), [&](auto rhsId) {
|
||||
return allocation->isIntersected(lhsId, rhsId);
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/// Applies the barrier analysis based on the SCF dialect, in which each
|
||||
/// region has a single basic block only.
|
||||
/// Example:
|
||||
/// region1
|
||||
/// op1
|
||||
/// op2 (scf.if)
|
||||
/// region2
|
||||
/// op3
|
||||
/// op4
|
||||
/// region3
|
||||
/// op5
|
||||
/// op6
|
||||
/// op7
|
||||
/// region2 and region3 started with the information of region1.
|
||||
/// Each region is analyzed separately and keeps their own copy of the
|
||||
/// information. At op7, we union the information of the region2 and region3
|
||||
/// and update the information of region1.
|
||||
void dfsOperation(Operation *operation, RegionInfo *blockInfo,
|
||||
OpBuilder *builder);
|
||||
|
||||
/// Updates the RegionInfo operation based on the operation.
|
||||
void transfer(Operation *operation, RegionInfo *blockInfo,
|
||||
OpBuilder *builder);
|
||||
|
||||
private:
|
||||
Allocation *allocation;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSIS_MEMBAR_H
|
82
include/triton/Analysis/Utility.h
Normal file
82
include/triton/Analysis/Utility.h
Normal file
@@ -0,0 +1,82 @@
|
||||
#ifndef TRITON_ANALYSIS_UTILITY_H
|
||||
#define TRITON_ANALYSIS_UTILITY_H
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ReduceOpHelper {
|
||||
public:
|
||||
explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
|
||||
srcTy = op.operand().getType().cast<RankedTensorType>();
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
|
||||
|
||||
Attribute getSrcLayout() { return srcTy.getEncoding(); }
|
||||
|
||||
bool isFastReduction();
|
||||
|
||||
unsigned getInterWarpSize();
|
||||
|
||||
unsigned getIntraWarpSize();
|
||||
|
||||
unsigned getThreadsReductionAxis();
|
||||
|
||||
SmallVector<unsigned> getScratchConfigBasic();
|
||||
|
||||
SmallVector<SmallVector<unsigned>> getScratchConfigsFast();
|
||||
|
||||
unsigned getScratchSizeInBytes();
|
||||
|
||||
private:
|
||||
triton::ReduceOp op;
|
||||
RankedTensorType srcTy{};
|
||||
};
|
||||
|
||||
bool isSharedEncoding(Value value);
|
||||
|
||||
bool maybeSharedAllocationOp(Operation *op);
|
||||
|
||||
bool maybeAliasOp(Operation *op);
|
||||
|
||||
bool supportMMA(triton::DotOp op, int version);
|
||||
|
||||
bool supportMMA(Value value, int version);
|
||||
|
||||
Type getElementType(Value value);
|
||||
|
||||
std::string getValueOperandName(Value value, AsmState &state);
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
|
||||
SmallVector<T_OUT> out;
|
||||
for (const T_IN &i : in)
|
||||
out.push_back(T_OUT(i));
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
|
||||
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
|
||||
}
|
||||
|
||||
template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }
|
||||
|
||||
// output[i] = input[order[i]]
|
||||
template <typename T, typename RES_T = T>
|
||||
SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
|
||||
size_t rank = order.size();
|
||||
assert(input.size() == rank);
|
||||
SmallVector<RES_T> result(rank);
|
||||
for (auto it : llvm::enumerate(order)) {
|
||||
result[it.index()] = input[it.value()];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSIS_UTILITY_H
|
2
include/triton/CMakeLists.txt
Normal file
2
include/triton/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Dialect)
|
4
include/triton/Conversion/CMakeLists.txt
Normal file
4
include/triton/Conversion/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
add_public_tablegen_target(TritonConversionPassIncGen)
|
40
include/triton/Conversion/MLIRTypes.h
Normal file
40
include/triton/Conversion/MLIRTypes.h
Normal file
@@ -0,0 +1,40 @@
|
||||
#ifndef TRITON_CONVERSION_MLIR_TYPES_H_
|
||||
#define TRITON_CONVERSION_MLIR_TYPES_H_
|
||||
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
// This file redefines some common MLIR types for easy usage.
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
namespace type {
|
||||
|
||||
// Integer types
|
||||
// TODO(Superjomn): may change `static` into better implementations
|
||||
static Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); }
|
||||
static Type i16Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 16); }
|
||||
static Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); }
|
||||
static Type u32Ty(MLIRContext *ctx) {
|
||||
return IntegerType::get(ctx, 32, IntegerType::Unsigned);
|
||||
}
|
||||
static Type u1Ty(MLIRContext *ctx) {
|
||||
return IntegerType::get(ctx, 1, IntegerType::Unsigned);
|
||||
}
|
||||
|
||||
// Float types
|
||||
static Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
|
||||
static Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
|
||||
static Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
|
||||
static Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
|
||||
|
||||
static bool isFloat(Type type) {
|
||||
return type.isF32() || type.isF64() || type.isF16() || type.isF128();
|
||||
}
|
||||
|
||||
static bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
|
||||
|
||||
} // namespace type
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_CONVERSION_MLIR_TYPES_H_
|
17
include/triton/Conversion/Passes.h
Normal file
17
include/triton/Conversion/Passes.h
Normal file
@@ -0,0 +1,17 @@
|
||||
#ifndef TRITON_CONVERSION_PASSES_H
|
||||
#define TRITON_CONVERSION_PASSES_H
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Conversion/Passes.h.inc"
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
54
include/triton/Conversion/Passes.td
Normal file
54
include/triton/Conversion/Passes.td
Normal file
@@ -0,0 +1,54 @@
|
||||
#ifndef TRITON_CONVERSION_PASSES
|
||||
#define TRITON_CONVERSION_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> {
|
||||
let summary = "Convert Triton to TritonGPU";
|
||||
let description = [{
|
||||
|
||||
}];
|
||||
let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()";
|
||||
|
||||
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
||||
"mlir::math::MathDialect",
|
||||
"mlir::StandardOpsDialect",
|
||||
// TODO: Does this pass depend on SCF?
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::triton::TritonDialect",
|
||||
"mlir::triton::gpu::TritonGPUDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"numWarps", "num-warps",
|
||||
"int32_t", /*default*/"4",
|
||||
"number of warps">
|
||||
];
|
||||
}
|
||||
|
||||
|
||||
def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"> {
|
||||
let summary = "Convert TritonGPU to LLVM";
|
||||
let description = [{
|
||||
|
||||
}];
|
||||
let constructor = "mlir::triton::createConvertTritonGPUToLLVMPass()";
|
||||
|
||||
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
||||
"mlir::math::MathDialect",
|
||||
"mlir::gpu::GPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::LLVM::LLVMDialect",
|
||||
"mlir::tensor::TensorDialect",
|
||||
"mlir::triton::TritonDialect",
|
||||
"mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::NVVM::NVVMDialect",
|
||||
"mlir::StandardOpsDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
#endif
|
326
include/triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h
Normal file
326
include/triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h
Normal file
@@ -0,0 +1,326 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_ASM_FORMAT_H
|
||||
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_ASM_FORMAT_H
|
||||
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace mlir {
|
||||
class ConversionPatternRewriter;
|
||||
class Location;
|
||||
|
||||
namespace triton {
|
||||
using llvm::StringRef;
|
||||
|
||||
struct PTXInstr;
|
||||
struct PTXInstrCommon;
|
||||
struct PTXInstrExecution;
|
||||
|
||||
// PTXBuilder helps to manage a PTX asm program consists of one or multiple
|
||||
// instructions.
|
||||
//
|
||||
// A helper for building an ASM program, the objective of PTXBuilder is to give
|
||||
// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
|
||||
// Currently, several factors are introduced to reduce the need for mixing
|
||||
// string and C++ if-else code.
|
||||
//
|
||||
// Usage:
|
||||
// To build: @$3 asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k),
|
||||
// "b"(p));
|
||||
//
|
||||
// PTXBuilder builder;
|
||||
// auto& add = builder.create<>();
|
||||
// add.predicate(pVal).o("lo").o("u32"); // add any suffix
|
||||
// // predicate here binds %0 to pVal, pVal is a mlir::Value
|
||||
//
|
||||
// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal
|
||||
// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal
|
||||
// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal
|
||||
// add(iOpr, jOpr, kOpr).predicate(predVal); // set operands and predicate
|
||||
//
|
||||
// To get the asm code:
|
||||
// builder.dump()
|
||||
//
|
||||
// To get all the mlir::Value used in the PTX code,
|
||||
//
|
||||
// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal}
|
||||
//
|
||||
// To get the string containing all the constraints with "," separated,
|
||||
// builder.getConstraints() // get "=r,r,k"
|
||||
//
|
||||
// PTXBuilder can build a PTX asm with multiple instructions, sample code:
|
||||
//
|
||||
// PTXBuilder builder;
|
||||
// auto& mov = builder.create("mov");
|
||||
// auto& cp = builder.create("cp");
|
||||
// mov(...);
|
||||
// cp(...);
|
||||
// This will get a PTX code with two instructions.
|
||||
//
|
||||
// Similar to a C function, a declared PTXInstr instance can be launched
|
||||
// multiple times with different operands, e.g.
|
||||
//
|
||||
// auto& mov = builder.create("mov");
|
||||
// mov(... some operands ...);
|
||||
// mov(... some different operands ...);
|
||||
//
|
||||
// Finally, we will get a PTX code with two mov instructions.
|
||||
//
|
||||
// There are several derived instruction type for typical instructions, for
|
||||
// example, the PtxIOInstr for ld and st instructions.
|
||||
struct PTXBuilder {
|
||||
struct Operand {
|
||||
std::string constraint;
|
||||
Value value;
|
||||
int idx{-1};
|
||||
llvm::SmallVector<Operand *> list;
|
||||
std::function<std::string(int idx)> repr;
|
||||
|
||||
// for list
|
||||
Operand() = default;
|
||||
Operand(const Operation &) = delete;
|
||||
Operand(Value value, StringRef constraint)
|
||||
: constraint(constraint), value(value) {}
|
||||
|
||||
bool isList() const { return !value && constraint.empty(); }
|
||||
|
||||
Operand *listAppend(Operand *arg) {
|
||||
list.push_back(arg);
|
||||
return this;
|
||||
}
|
||||
|
||||
Operand *listGet(size_t nth) const {
|
||||
assert(nth < list.size());
|
||||
return list[nth];
|
||||
}
|
||||
|
||||
std::string dump() const;
|
||||
};
|
||||
|
||||
template <typename INSTR = PTXInstr, typename... Args>
|
||||
INSTR *create(Args &&...args) {
|
||||
instrs.emplace_back(std::make_unique<INSTR>(this, args...));
|
||||
return static_cast<INSTR *>(instrs.back().get());
|
||||
}
|
||||
|
||||
// Create a list of operands.
|
||||
Operand *newListOperand() { return newOperand(); }
|
||||
|
||||
Operand *newListOperand(ArrayRef<std::pair<mlir::Value, std::string>> items) {
|
||||
auto *list = newOperand();
|
||||
for (auto &item : items) {
|
||||
list->listAppend(newOperand(item.first, item.second));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
Operand *newListOperand(unsigned count, mlir::Value val,
|
||||
const std::string &constraint) {
|
||||
auto *list = newOperand();
|
||||
for (unsigned i = 0; i < count; ++i) {
|
||||
list->listAppend(newOperand(val, constraint));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
Operand *newListOperand(unsigned count, const std::string &constraint) {
|
||||
auto *list = newOperand();
|
||||
for (unsigned i = 0; i < count; ++i) {
|
||||
list->listAppend(newOperand(constraint));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
// Create a new operand. It will not add to operand list.
|
||||
// @value: the MLIR value bind to this operand.
|
||||
// @constraint: ASM operand constraint, .e.g. "=r"
|
||||
// @formatter: extra format to represent this operand in ASM code, default is
|
||||
// "%{0}".format(operand.idx).
|
||||
Operand *newOperand(mlir::Value value, StringRef constraint,
|
||||
std::function<std::string(int idx)> formatter = nullptr);
|
||||
|
||||
// Create a new operand which is written to, that is, the constraint starts
|
||||
// with "=", e.g. "=r".
|
||||
Operand *newOperand(StringRef constraint);
|
||||
|
||||
// Create a constant integer operand.
|
||||
Operand *newConstantOperand(int64_t v);
|
||||
// Create a constant operand with explicit code specified.
|
||||
Operand *newConstantOperand(const std::string &v);
|
||||
|
||||
Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0);
|
||||
|
||||
llvm::SmallVector<Operand *, 4> getAllArgs() const;
|
||||
|
||||
llvm::SmallVector<Value, 4> getAllMLIRArgs() const;
|
||||
|
||||
std::string getConstraints() const;
|
||||
|
||||
std::string dump() const;
|
||||
|
||||
mlir::Value launch(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Type resTy, bool hasSideEffect = true,
|
||||
bool isAlignStack = false,
|
||||
ArrayRef<Attribute> attrs = {}) const;
|
||||
|
||||
private:
|
||||
Operand *newOperand() {
|
||||
argArchive.emplace_back(std::make_unique<Operand>());
|
||||
return argArchive.back().get();
|
||||
}
|
||||
|
||||
// Make the operands in argArchive follow the provided \param order.
|
||||
void reorderArgArchive(ArrayRef<Operand *> order) {
|
||||
assert(order.size() == argArchive.size());
|
||||
// The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but
|
||||
// it does necessary when onlyAttachMLIRArgs is true for the $0, $1... are
|
||||
// determined by PTX code snippet passed from external.
|
||||
sort(argArchive.begin(), argArchive.end(),
|
||||
[&](std::unique_ptr<Operand> &a, std::unique_ptr<Operand> &b) {
|
||||
auto ida = std::find(order.begin(), order.end(), a.get());
|
||||
auto idb = std::find(order.begin(), order.end(), b.get());
|
||||
assert(ida != order.end());
|
||||
assert(idb != order.end());
|
||||
return ida < idb;
|
||||
});
|
||||
}
|
||||
|
||||
friend struct PTXInstr;
|
||||
friend struct PTXInstrCommon;
|
||||
|
||||
protected:
|
||||
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
|
||||
llvm::SmallVector<std::unique_ptr<PTXInstrCommon>, 2> instrs;
|
||||
llvm::SmallVector<std::unique_ptr<PTXInstrExecution>, 4> executions;
|
||||
int oprCounter{};
|
||||
};
|
||||
|
||||
// PTX instruction common interface.
|
||||
// Put the generic logic for all the instructions here.
|
||||
struct PTXInstrCommon {
|
||||
explicit PTXInstrCommon(PTXBuilder *builder) : builder(builder) {}
|
||||
|
||||
using Operand = PTXBuilder::Operand;
|
||||
|
||||
// clang-format off
|
||||
PTXInstrExecution& operator()() { return call({}); }
|
||||
PTXInstrExecution& operator()(Operand* a) { return call({a}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d) { return call({a, b, c, d}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { return call({a, b, c, d, e}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { return call({a, b, c, d, e, f}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { return call({a, b, c, d, e, f, g}); }
|
||||
// clang-format on
|
||||
|
||||
// Set operands of this instruction.
|
||||
PTXInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs = false);
|
||||
|
||||
protected:
|
||||
// "Call" the instruction with operands.
|
||||
// \param oprs The operands of this instruction.
|
||||
// \param onlyAttachMLIRArgs Indicate that it simply attach the MLIR Arguments
|
||||
// to the inline Asm without generating the operand ids(such as $0, $1) in PTX
|
||||
// code.
|
||||
PTXInstrExecution &call(llvm::ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs = false);
|
||||
|
||||
PTXBuilder *builder{};
|
||||
llvm::SmallVector<std::string, 4> instrParts;
|
||||
|
||||
friend struct PTXInstrExecution;
|
||||
};
|
||||
|
||||
template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
|
||||
using Operand = PTXBuilder::Operand;
|
||||
|
||||
explicit PTXInstrBase(PTXBuilder *builder, const std::string &name)
|
||||
: PTXInstrCommon(builder) {
|
||||
o(name);
|
||||
}
|
||||
|
||||
// Append a suffix to the instruction.
|
||||
// e.g. PTXInstr("add").o("s32") get a add.s32.
|
||||
// A predicate is used to tell whether to apply the suffix, so that no if-else
|
||||
// code needed. e.g. `PTXInstr("add").o("s32", isS32).o("u32", !isS32);` will
|
||||
// get a `add.s32` if isS32 is true.
|
||||
ConcreteT &o(const std::string &suffix, bool predicate = true) {
|
||||
if (predicate)
|
||||
instrParts.push_back(suffix);
|
||||
return *static_cast<ConcreteT *>(this);
|
||||
}
|
||||
};
|
||||
|
||||
struct PTXInstr : public PTXInstrBase<PTXInstr> {
|
||||
using PTXInstrBase<PTXInstr>::PTXInstrBase;
|
||||
|
||||
// Append a ".global" to the instruction.
|
||||
PTXInstr &global();
|
||||
|
||||
// Append a ".shared" to the instruction.
|
||||
PTXInstr &shared();
|
||||
|
||||
// Append a ".v[0-9]+" to the instruction
|
||||
PTXInstr &v(int vecWidth, bool predicate = true);
|
||||
|
||||
// Append a".b[0-9]+" to the instruction
|
||||
PTXInstr &b(int width);
|
||||
};
|
||||
|
||||
// Record the operands and context for "launching" a PtxInstr.
|
||||
struct PTXInstrExecution {
|
||||
using Operand = PTXBuilder::Operand;
|
||||
|
||||
llvm::SmallVector<Operand *> argsInOrder;
|
||||
|
||||
PTXInstrExecution() = default;
|
||||
explicit PTXInstrExecution(PTXInstrCommon *instr,
|
||||
llvm::ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs)
|
||||
: argsInOrder(oprs.begin(), oprs.end()), instr(instr),
|
||||
onlyAttachMLIRArgs(onlyAttachMLIRArgs) {}
|
||||
|
||||
// Prefix a predicate to the instruction.
|
||||
PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
|
||||
pred = instr->builder->newOperand(value, constraint);
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Prefix a !predicate to the instruction.
|
||||
PTXInstrExecution &predicateNot(mlir::Value value, StringRef constraint) {
|
||||
pred = instr->builder->newOperand(value, constraint);
|
||||
pred->repr = [](int idx) { return "@!$" + std::to_string(idx); };
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::string dump() const;
|
||||
|
||||
SmallVector<Operand *> getArgList() const;
|
||||
|
||||
PTXInstrCommon *instr{};
|
||||
Operand *pred{};
|
||||
bool onlyAttachMLIRArgs{};
|
||||
};
|
||||
|
||||
/// ====== Some instruction wrappers ======
|
||||
// We add the wrappers to make the usage more intuitive by avoiding mixing the
|
||||
// PTX code with some trivial C++ code.
|
||||
|
||||
struct PTXCpAsyncLoadInstr : PTXInstrBase<PTXCpAsyncLoadInstr> {
|
||||
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
|
||||
triton::CacheModifier modifier)
|
||||
: PTXInstrBase(builder, "cp.async") {
|
||||
o(triton::stringifyCacheModifier(modifier).str());
|
||||
o("shared");
|
||||
o("global");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
@@ -0,0 +1,22 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_PASS_H
|
||||
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_PASS_H
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ModuleOp;
|
||||
template <typename T> class OperationPass;
|
||||
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonGPUToLLVMPass(int computeCapability = 80);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
@@ -0,0 +1,25 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H
|
||||
#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ModuleOp;
|
||||
template <typename T> class OperationPass;
|
||||
|
||||
namespace triton {
|
||||
|
||||
constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
|
||||
|
||||
// Create the pass with numWarps passed from cl::opt.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
|
||||
|
||||
// Create the pass with numWarps set explicitly.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonToTritonGPUPass(int numWarps);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
2
include/triton/Dialect/CMakeLists.txt
Normal file
2
include/triton/Dialect/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(Triton)
|
||||
add_subdirectory(TritonGPU)
|
2
include/triton/Dialect/Triton/CMakeLists.txt
Normal file
2
include/triton/Dialect/Triton/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
19
include/triton/Dialect/Triton/IR/CMakeLists.txt
Normal file
19
include/triton/Dialect/Triton/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,19 @@
|
||||
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TritonDialect.td)
|
||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
|
||||
mlir_tablegen(Types.h.inc -gen-typedef-decls)
|
||||
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td)
|
||||
mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
|
||||
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
|
||||
|
||||
add_public_tablegen_target(TritonTableGen)
|
48
include/triton/Dialect/Triton/IR/Dialect.h
Normal file
48
include/triton/Dialect/Triton/IR/Dialect.h
Normal file
@@ -0,0 +1,48 @@
|
||||
#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
|
||||
#define TRITON_DIALECT_TRITON_IR_DIALECT_H_
|
||||
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
|
||||
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
|
||||
#include "triton/Dialect/Triton/IR/Traits.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/Triton/IR/Ops.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
class DialectInferLayoutInterface
|
||||
: public DialectInterface::Base<DialectInferLayoutInterface> {
|
||||
public:
|
||||
DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}
|
||||
|
||||
virtual LogicalResult
|
||||
inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
|
||||
Attribute &resultEncoding) const = 0;
|
||||
|
||||
virtual LogicalResult
|
||||
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
|
||||
Attribute &resultEncoding,
|
||||
Optional<Location> location) const = 0;
|
||||
|
||||
// Note: this function only verify operand encoding but doesn't infer result
|
||||
// encoding
|
||||
virtual LogicalResult
|
||||
inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
|
||||
Attribute retEncoding,
|
||||
Optional<Location> location) const = 0;
|
||||
};
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_IR_DIALECT_H_
|
9
include/triton/Dialect/Triton/IR/Interfaces.h
Normal file
9
include/triton/Dialect/Triton/IR/Interfaces.h
Normal file
@@ -0,0 +1,9 @@
|
||||
#ifndef TRITON_IR_INTERFACES_H_
|
||||
#define TRITON_IR_INTERFACES_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
|
||||
|
||||
#endif // TRITON_IR_TYPES_H_
|
60
include/triton/Dialect/Triton/IR/Traits.h
Normal file
60
include/triton/Dialect/Triton/IR/Traits.h
Normal file
@@ -0,0 +1,60 @@
|
||||
#ifndef TRITON_IR_TRAITS_H_
|
||||
#define TRITON_IR_TRAITS_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
|
||||
// These functions are out-of-line implementations of the methods in the
|
||||
// corresponding trait classes. This avoids them being template
|
||||
// instantiated/duplicated.
|
||||
namespace impl {
|
||||
LogicalResult verifySameOperandsAndResultEncoding(Operation *op);
|
||||
LogicalResult verifySameOperandsEncoding(Operation *op);
|
||||
// The rationale for this trait is to prevent users from creating programs
|
||||
// that would have catastrophic register pressure and cause the compiler to
|
||||
// hang.
|
||||
// Since H100 has 256KB registers, we should allow users to create tensors
|
||||
// of size up to 256K elements. It will spill for datatypes wider than 1B,
|
||||
// but we probably should limit number of elements (rather than bytes) to
|
||||
// keep specs simple
|
||||
int constexpr maxTensorNumElements = 1048576;
|
||||
LogicalResult verifyTensorSize(Operation *op);
|
||||
} // namespace impl
|
||||
|
||||
template <class ConcreteType>
|
||||
class TensorSizeTrait : public TraitBase<ConcreteType, TensorSizeTrait> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifyTensorSize(op);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ConcreteType>
|
||||
class SameOperandsAndResultEncoding
|
||||
: public TraitBase<ConcreteType, SameOperandsAndResultEncoding> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifySameOperandsAndResultEncoding(op);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ConcreteType>
|
||||
class SameOperandsEncoding
|
||||
: public TraitBase<ConcreteType, SameOperandsEncoding> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifySameOperandsEncoding(op);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
68
include/triton/Dialect/Triton/IR/TritonAttrDefs.td
Normal file
68
include/triton/Dialect/Triton/IR/TritonAttrDefs.td
Normal file
@@ -0,0 +1,68 @@
|
||||
#ifndef TRITON_ATTR_DEFS
|
||||
#define TRITON_ATTR_DEFS
|
||||
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
|
||||
// Attrs for LoadOp
|
||||
def TT_CacheModifierAttr : I32EnumAttr<
|
||||
"CacheModifier", "",
|
||||
[
|
||||
I32EnumAttrCase<"NONE", 1, "none">,
|
||||
I32EnumAttrCase<"CA", 2, "ca">,
|
||||
I32EnumAttrCase<"CG", 3, "cg">,
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
def TT_EvictionPolicyAttr : I32EnumAttr<
|
||||
"EvictionPolicy", "",
|
||||
[
|
||||
I32EnumAttrCase<"NORMAL", 1, "evict_normal">,
|
||||
I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">,
|
||||
I32EnumAttrCase<"EVICT_LAST", 3, "evict_last">
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
// reduction
|
||||
def TT_RedOpAttr : I32EnumAttr<
|
||||
/*name*/"RedOp", /*summary*/"",
|
||||
/*case*/
|
||||
[
|
||||
I32EnumAttrCase</*sym*/"ADD", 1, /*str*/"add">,
|
||||
I32EnumAttrCase<"FADD", 2, "fadd">,
|
||||
I32EnumAttrCase<"MIN", 3, "min">,
|
||||
I32EnumAttrCase<"MAX", 4, "max">,
|
||||
I32EnumAttrCase<"UMIN", 5, "umin">,
|
||||
I32EnumAttrCase<"UMAX", 6, "umax">,
|
||||
I32EnumAttrCase<"ARGMIN", 7, "argmin">,
|
||||
I32EnumAttrCase<"ARGMAX", 8, "argmax">,
|
||||
I32EnumAttrCase<"ARGUMIN", 9, "argumin">,
|
||||
I32EnumAttrCase<"ARGUMAX", 10, "argumax">,
|
||||
I32EnumAttrCase<"FMIN", 11, "fmin">,
|
||||
I32EnumAttrCase<"FMAX", 12, "fmax">,
|
||||
I32EnumAttrCase<"ARGFMIN", 13, "argfmin">,
|
||||
I32EnumAttrCase<"ARGFMAX", 14, "argfmax">,
|
||||
I32EnumAttrCase<"XOR", 15, "xor">
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
// atomic
|
||||
def TT_AtomicRMWAttr : I32EnumAttr<
|
||||
"RMWOp", "",
|
||||
[
|
||||
I32EnumAttrCase<"AND", 1, "and">,
|
||||
I32EnumAttrCase<"OR", 2, "or">,
|
||||
I32EnumAttrCase<"XOR", 3, "xor">,
|
||||
I32EnumAttrCase<"ADD", 4, "add">,
|
||||
I32EnumAttrCase<"FADD", 5, "fadd">,
|
||||
I32EnumAttrCase<"MAX", 6, "max">,
|
||||
I32EnumAttrCase<"MIN", 7, "min">,
|
||||
I32EnumAttrCase<"UMAX", 8, "umax">,
|
||||
I32EnumAttrCase<"UMIN", 9, "umin">,
|
||||
I32EnumAttrCase<"XCHG", 10, "exch">
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
#endif
|
46
include/triton/Dialect/Triton/IR/TritonDialect.td
Normal file
46
include/triton/Dialect/Triton/IR/TritonDialect.td
Normal file
@@ -0,0 +1,46 @@
|
||||
#ifndef TRITON_DIALECT
|
||||
#define TRITON_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Triton_Dialect : Dialect {
|
||||
let name = "tt";
|
||||
|
||||
let cppNamespace = "::mlir::triton";
|
||||
|
||||
let summary = "The Triton IR in MLIR";
|
||||
|
||||
let description = [{
|
||||
Triton Dialect.
|
||||
|
||||
Dependent Dialects:
|
||||
* Arithmetic:
|
||||
* addf, addi, andi, cmpf, cmpi, divf, fptosi, ...
|
||||
* Math:
|
||||
* exp, sin, cos, log, ...
|
||||
* StructuredControlFlow:
|
||||
* ForOp, IfOp, WhileOp, YieldOp, ConditionOp
|
||||
}];
|
||||
|
||||
let dependentDialects = [
|
||||
"arith::ArithmeticDialect",
|
||||
"math::MathDialect",
|
||||
"StandardOpsDialect",
|
||||
"scf::SCFDialect",
|
||||
|
||||
// Since LLVM 15
|
||||
// "cf::ControlFlowDialect",
|
||||
// "func::FuncDialect"
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
void registerTypes();
|
||||
}];
|
||||
|
||||
let hasConstantMaterializer = 1;
|
||||
}
|
||||
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
|
||||
|
||||
#endif // TRITON_DIALECT
|
11
include/triton/Dialect/Triton/IR/TritonInterfaces.td
Normal file
11
include/triton/Dialect/Triton/IR/TritonInterfaces.td
Normal file
@@ -0,0 +1,11 @@
|
||||
#ifndef TRITON_INTERFACES
|
||||
#define TRITON_INTERFACES
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
|
||||
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
|
||||
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
|
||||
|
||||
|
||||
#endif // TRITON_INTERFACES
|
423
include/triton/Dialect/Triton/IR/TritonOps.td
Normal file
423
include/triton/Dialect/Triton/IR/TritonOps.td
Normal file
@@ -0,0 +1,423 @@
|
||||
#ifndef TRITON_OPS
|
||||
#define TRITON_OPS
|
||||
|
||||
include "triton/Dialect/Triton/IR/TritonDialect.td"
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
|
||||
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
|
||||
|
||||
//
|
||||
// Op Base
|
||||
//
|
||||
class TT_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])> {
|
||||
}
|
||||
|
||||
//
|
||||
// CastOps
|
||||
//
|
||||
// Use cast ops in arith:
|
||||
// bitcast
|
||||
// fptoui, fptosi, uitofp, sitofp,
|
||||
// extf, tructf,
|
||||
// extui, extsi, tructi
|
||||
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Cast int64 to pointer";
|
||||
|
||||
let arguments = (ins TT_I64Like:$from);
|
||||
|
||||
let results = (outs TT_PtrLike:$result);
|
||||
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
}
|
||||
|
||||
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Cast pointer to int64";
|
||||
|
||||
let arguments = (ins TT_PtrLike:$from);
|
||||
|
||||
let results = (outs TT_I64Like:$result);
|
||||
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
}
|
||||
|
||||
// arith.bitcast doesn't support pointers
|
||||
def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Cast between types of the same bitwidth";
|
||||
|
||||
let arguments = (ins TT_Type:$from);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
|
||||
// TODO: Add verifier
|
||||
}
|
||||
|
||||
def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
NoSideEffect,
|
||||
DeclareOpInterfaceMethods<CastOpInterface>]> {
|
||||
let summary = "Floating point casting for custom types";
|
||||
|
||||
let description = [{
|
||||
Floating point casting for custom types (F8).
|
||||
|
||||
F8 <-> FP16, BF16, FP32, FP64
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_FloatLike:$from);
|
||||
|
||||
let results = (outs TT_FloatLike:$result);
|
||||
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
|
||||
// TODO: We need a verifier here.
|
||||
}
|
||||
|
||||
//
|
||||
// Pointer Arith Ops
|
||||
//
|
||||
|
||||
def TT_AddPtrOp : TT_Op<"addptr",
|
||||
[NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
TypesMatchWith<"result type matches ptr type",
|
||||
"result", "ptr", "$_self">]> {
|
||||
let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset);
|
||||
|
||||
let results = (outs TT_PtrLike:$result);
|
||||
|
||||
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)";
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Load/Store Ops
|
||||
//
|
||||
def TT_LoadOp : TT_Op<"load",
|
||||
[SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
AttrSizedOperandSegments,
|
||||
MemoryEffects<[MemRead]>,
|
||||
TypesMatchWith<"infer ptr type from result type",
|
||||
"result", "ptr", "getPointerTypeSameShape($_self)">,
|
||||
TypesMatchWith<"infer mask type from result type or none",
|
||||
"result", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
|
||||
TypesMatchWith<"infer other type from result type or none",
|
||||
"result", "other", "$_self",
|
||||
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
|
||||
let summary = "load";
|
||||
|
||||
let arguments = (ins TT_PtrLike:$ptr, Optional<TT_BoolLike>:$mask, Optional<TT_Type>:$other,
|
||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||
BoolAttr:$isVolatile);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||
];
|
||||
|
||||
// let assemblyFormat = "operands attr-dict `:` type($result)";
|
||||
let parser = [{ return mlir::triton::parseLoadOp(parser, result); }];
|
||||
|
||||
let printer = [{ return mlir::triton::printLoadOp(p, *this); }];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TT_StoreOp : TT_Op<"store",
|
||||
[SameOperandsShape,
|
||||
SameOperandsEncoding,
|
||||
MemoryEffects<[MemWrite]>,
|
||||
TypesMatchWith<"infer ptr type from value type",
|
||||
"value", "ptr",
|
||||
"getPointerTypeSameShape($_self)">,
|
||||
TypesMatchWith<"infer mask type from value type",
|
||||
"value", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
|
||||
let summary = "store";
|
||||
|
||||
let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
|
||||
];
|
||||
|
||||
// let assemblyFormat = "operands attr-dict `:` type($value)";
|
||||
let parser = [{ return mlir::triton::parseStoreOp(parser, result); }];
|
||||
|
||||
let printer = [{ return mlir::triton::printStoreOp(p, *this); }];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic Op
|
||||
//
|
||||
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
MemoryEffects<[MemRead]>,
|
||||
MemoryEffects<[MemWrite]>,
|
||||
TypesMatchWith<"infer ptr type from value type",
|
||||
"val", "ptr",
|
||||
"getPointerTypeSameShape($_self)">,
|
||||
TypesMatchWith<"infer mask type from value type",
|
||||
"val", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
|
||||
let summary = "atomic rmw";
|
||||
|
||||
let description = [{
|
||||
load data at $ptr, do $rmw_op with $val, and store result to $ptr.
|
||||
|
||||
return old value at $ptr
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr,
|
||||
TT_Type:$val, Optional<TT_BoolLike>:$mask);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>,
|
||||
MemoryEffects<[MemWrite]>,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding]> {
|
||||
let summary = "atomic cas";
|
||||
|
||||
let description = [{
|
||||
compare $cmp with data $old at location $ptr,
|
||||
|
||||
if $old == $cmp, store $val to $ptr,
|
||||
|
||||
else store $old to $ptr,
|
||||
|
||||
return $old
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_Ptr:$ptr, TT_Type:$cmp, TT_Type:$val);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Shape Manipulation Ops
|
||||
//
|
||||
def TT_SplatOp : TT_Op<"splat", [NoSideEffect,
|
||||
SameOperandsAndResultElementType]> {
|
||||
let summary = "splat";
|
||||
|
||||
let arguments = (ins TT_Type:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
SameOperandsAndResultElementType]> {
|
||||
let summary = "expand_dims";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src, I32Attr:$axis);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TT_ViewOp : TT_Op<"view", [NoSideEffect,
|
||||
SameOperandsAndResultElementType]> {
|
||||
let summary = "view";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
}
|
||||
|
||||
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
|
||||
SameOperandsAndResultElementType]> {
|
||||
let summary = "broadcast. No left-padding as of now.";
|
||||
|
||||
let arguments = (ins TT_Type:$src);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TT_CatOp : TT_Op<"cat", [NoSideEffect,
|
||||
SameOperandsAndResultElementType]> {
|
||||
let summary = "concatenate 2 tensors";
|
||||
|
||||
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TT_TransOp : TT_Op<"trans", [NoSideEffect,
|
||||
SameOperandsAndResultElementType]> {
|
||||
|
||||
let summary = "transpose a tensor";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
//
|
||||
// SPMD Ops
|
||||
//
|
||||
def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoSideEffect]> {
|
||||
let arguments = (ins I32Attr:$axis);
|
||||
|
||||
let results = (outs I32:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoSideEffect]> {
|
||||
let arguments = (ins I32Attr:$axis);
|
||||
|
||||
let results = (outs I32:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
//
|
||||
// Dot Op
|
||||
//
|
||||
def TT_DotOp : TT_Op<"dot", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
TypesMatchWith<"result's type matches accumulator's type",
|
||||
"d", "c", "$_self">]> {
|
||||
let summary = "dot";
|
||||
|
||||
let description = [{
|
||||
$d = matrix_multiply($a, $b) + $c
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
|
||||
|
||||
let results = (outs TT_FpIntTensor:$d);
|
||||
|
||||
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
|
||||
}
|
||||
|
||||
//
|
||||
// Reduce Op
|
||||
//
|
||||
def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "reduce";
|
||||
|
||||
let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
|
||||
];
|
||||
|
||||
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// This member function is marked static because we need to call it before the ReduceOp
|
||||
// is constructed, see the implementation of create_reduce in triton.cc.
|
||||
static bool withIndex(mlir::triton::RedOp redOp);
|
||||
}];
|
||||
}
|
||||
|
||||
//
|
||||
// External elementwise op
|
||||
//
|
||||
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
SameVariadicOperandSize]> {
|
||||
let summary = "ext_elemwise";
|
||||
|
||||
let description = [{
|
||||
call an external function $symbol implemented in $libpath/$libname with $args
|
||||
|
||||
return $libpath/$libname:$symbol($args...)
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<TT_Type>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
|
||||
}
|
||||
|
||||
//
|
||||
// Make Range Op
|
||||
//
|
||||
// TODO: should have ConstantLike as Trait
|
||||
def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
|
||||
let summary = "make range";
|
||||
|
||||
let description = [{
|
||||
Returns an 1D int32 tensor.
|
||||
|
||||
Values span from $start to $end (exclusive), with step = 1
|
||||
}];
|
||||
|
||||
let arguments = (ins I32Attr:$start, I32Attr:$end);
|
||||
|
||||
let results = (outs TT_IntTensor:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
//
|
||||
// Make PrintfOp
|
||||
//
|
||||
def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
|
||||
Arguments<(ins StrAttr:$prefix,
|
||||
Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
|
||||
let summary = "Device-side printf, as in CUDA for debugging";
|
||||
let description = [{
|
||||
`tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
|
||||
format are generated automatically from the arguments.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
$prefix attr-dict ($args^ `:` type($args))?
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // Triton_OPS
|
71
include/triton/Dialect/Triton/IR/TritonTypes.td
Normal file
71
include/triton/Dialect/Triton/IR/TritonTypes.td
Normal file
@@ -0,0 +1,71 @@
|
||||
#ifndef TRITON_TYPES
|
||||
#define TRITON_TYPES
|
||||
|
||||
include "triton/Dialect/Triton/IR/TritonDialect.td"
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
class TritonTypeDef<string name, string _mnemonic>
|
||||
: TypeDef<Triton_Dialect, name> {
|
||||
// Used by printer/parser
|
||||
let mnemonic = _mnemonic;
|
||||
}
|
||||
|
||||
// Floating-point Type
|
||||
def F8 : TritonTypeDef<"Float8", "f8">;
|
||||
|
||||
def TT_Float : AnyTypeOf<[F8, F16, BF16, F32, F64], "floating-point">;
|
||||
def TT_FloatTensor : TensorOf<[TT_Float]>;
|
||||
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
|
||||
|
||||
// Boolean Type
|
||||
// TT_Bool -> I1
|
||||
def TT_BoolTensor : TensorOf<[I1]>;
|
||||
def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;
|
||||
|
||||
// Integer Type
|
||||
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
|
||||
def TT_IntTensor : TensorOf<[TT_Int]>;
|
||||
def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>;
|
||||
|
||||
// I32 Type
|
||||
// TT_I32 -> I32
|
||||
// TT_I32Tensor -> I32Tensor
|
||||
def TT_I32Like: AnyTypeOf<[I32, I32Tensor]>;
|
||||
|
||||
// I64 Type
|
||||
// TT_I64 -> I64
|
||||
// TT_I64Tensor -> I64Tensor
|
||||
def TT_I64Like: AnyTypeOf<[I64, I64Tensor]>;
|
||||
|
||||
// Pointer Type
|
||||
def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
|
||||
let summary = "pointer type";
|
||||
|
||||
let description = [{
|
||||
Triton PointerType
|
||||
}];
|
||||
|
||||
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
|
||||
|
||||
let builders = [
|
||||
TypeBuilderWithInferredContext<(ins
|
||||
"Type":$pointeeType,
|
||||
"int":$addressSpace
|
||||
), [{
|
||||
return $_get(pointeeType.getContext(), pointeeType, addressSpace);
|
||||
}]>
|
||||
];
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
|
||||
def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>;
|
||||
|
||||
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntTensor]>;
|
||||
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
|
||||
|
||||
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike]>;
|
||||
|
||||
#endif
|
10
include/triton/Dialect/Triton/IR/Types.h
Normal file
10
include/triton/Dialect/Triton/IR/Types.h
Normal file
@@ -0,0 +1,10 @@
|
||||
#ifndef TRITON_IR_TYPES_H_
|
||||
#define TRITON_IR_TYPES_H_
|
||||
|
||||
#include "mlir/IR/TypeSupport.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "triton/Dialect/Triton/IR/Types.h.inc"
|
||||
|
||||
#endif // TRITON_IR_TYPES_H_
|
3
include/triton/Dialect/Triton/Transforms/CMakeLists.txt
Normal file
3
include/triton/Dialect/Triton/Transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton)
|
||||
add_public_tablegen_target(TritonTransformsIncGen)
|
18
include/triton/Dialect/Triton/Transforms/Passes.h
Normal file
18
include/triton/Dialect/Triton/Transforms/Passes.h
Normal file
@@ -0,0 +1,18 @@
|
||||
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
|
||||
#define TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<Pass> createCombineOpsPass();
|
||||
|
||||
} // namespace triton
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
23
include/triton/Dialect/Triton/Transforms/Passes.td
Normal file
23
include/triton/Dialect/Triton/Transforms/Passes.td
Normal file
@@ -0,0 +1,23 @@
|
||||
#ifndef TRITON_PASSES
|
||||
#define TRITON_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp"> {
|
||||
let summary = "combine ops";
|
||||
let description = [{
|
||||
dot(a, b, 0) + c => dot(a, b, c)
|
||||
|
||||
addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1))
|
||||
|
||||
select(cond, load(ptrs, broadcast(cond), ???), other) =>
|
||||
load(ptrs, broadcast(cond), other)
|
||||
}];
|
||||
|
||||
let constructor = "mlir::triton::createCombineOpsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
||||
/*SelectOp*/"mlir::StandardOpsDialect"];
|
||||
}
|
||||
|
||||
#endif
|
2
include/triton/Dialect/TritonGPU/CMakeLists.txt
Normal file
2
include/triton/Dialect/TritonGPU/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
12
include/triton/Dialect/TritonGPU/IR/CMakeLists.txt
Normal file
12
include/triton/Dialect/TritonGPU/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td)
|
||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
add_public_tablegen_target(TritonGPUTableGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
|
||||
mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls)
|
||||
mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs)
|
||||
add_public_tablegen_target(TritonGPUAttrDefsIncGen)
|
||||
|
48
include/triton/Dialect/TritonGPU/IR/Dialect.h
Normal file
48
include/triton/Dialect/TritonGPU/IR/Dialect.h
Normal file
@@ -0,0 +1,48 @@
|
||||
#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
||||
#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
// TritonGPU depends on Triton
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
|
||||
#include "triton/Dialect/TritonGPU/IR/Traits.h"
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
|
||||
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
namespace gpu {
|
||||
|
||||
unsigned getElemsPerThread(Type type);
|
||||
|
||||
SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getSizePerThread(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getContigPerThread(Attribute layout);
|
||||
|
||||
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getOrder(const Attribute &layout);
|
||||
|
||||
bool isaDistributedLayout(const Attribute &layout);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
31
include/triton/Dialect/TritonGPU/IR/Traits.h
Normal file
31
include/triton/Dialect/TritonGPU/IR/Traits.h
Normal file
@@ -0,0 +1,31 @@
|
||||
#ifndef TRITON_GPU_IR_TRAITS_H_
|
||||
#define TRITON_GPU_IR_TRAITS_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
|
||||
// These functions are out-of-line implementations of the methods in the
|
||||
// corresponding trait classes. This avoids them being template
|
||||
// instantiated/duplicated.
|
||||
namespace impl {
|
||||
LogicalResult verifyResultsAreSharedEncoding(Operation *op);
|
||||
} // namespace impl
|
||||
|
||||
template <typename ConcreteType>
|
||||
class ResultsAreSharedEncoding
|
||||
: public TraitBase<ConcreteType, ResultsAreSharedEncoding> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifyResultsAreSharedEncoding(op);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
481
include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Normal file
481
include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Normal file
@@ -0,0 +1,481 @@
|
||||
#ifndef TRITONGPU_ATTRDEFS
|
||||
#define TRITONGPU_ATTRDEFS
|
||||
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
||||
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TritonGPU Attribute Definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class TritonGPU_Attr<string name, list<Trait> traits = [],
|
||||
string baseCppClass = "::mlir::Attribute">
|
||||
: AttrDef<TritonGPU_Dialect, name, traits, baseCppClass> {
|
||||
|
||||
let description = [{
|
||||
TritonGPU Tensors differ from usual tensors in that they contain a _layout_ attribute which determines
|
||||
how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function
|
||||
\mathcal{L} that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers T corresponding
|
||||
to the indices of the CUDA threads allowed to access some data at index $i$.
|
||||
|
||||
For example, let us consider the layout function:
|
||||
\mathcal{L}(0, 0) = {0, 4}
|
||||
\mathcal{L}(0, 1) = {1, 5}
|
||||
\mathcal{L}(1, 0) = {2, 6}
|
||||
\mathcal{L}(1, 1) = {3, 7}
|
||||
|
||||
Then, attaching $\mathcal{L} to a tensor $T$ would mean that:
|
||||
- T[0,0] is owned by both cuda thread 0 and 4
|
||||
- T[0,1] is owned by both cuda thread 1 and 5
|
||||
- T[1,0] is owned by both cuda thread 2 and 6
|
||||
- T[1,1] is owned by both cuda thread 3 and 7
|
||||
|
||||
Right now, Triton implements two classes of layouts: shared, and distributed.
|
||||
}];
|
||||
|
||||
code extraBaseClassDeclaration = [{
|
||||
unsigned getElemsPerThread(ArrayRef<int64_t> shape) const;
|
||||
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding"> {
|
||||
let mnemonic = "shared";
|
||||
|
||||
let description = [{
|
||||
An encoding for tensors whose elements may be simultaneously accessed by
|
||||
different cuda threads in the programs, via shared memory. In other words,
|
||||
for all indices i \in R^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.
|
||||
|
||||
In order to avoid shared memory bank conflicts, elements may be swizzled
|
||||
in memory. For example, a swizzled row-major layout could store its data
|
||||
as follows:
|
||||
|
||||
A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2
|
||||
A_{1, 0} A_{1, 1} A_{1, 2} A_{1, 3} ... [phase 0] /
|
||||
groups of vec=2 elements
|
||||
are stored contiguously
|
||||
_ _ _ _ /\_ _ _ _
|
||||
A_{2, 2} A_{2, 3} A_{2, 0} A_{2, 1} ... [phase 1] \ per phase = 2
|
||||
A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
// swizzle info
|
||||
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase,
|
||||
ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
|
||||
);
|
||||
|
||||
let builders = [
|
||||
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
|
||||
"ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"Type":$eltTy), [{
|
||||
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
|
||||
|
||||
if(!mmaEnc)
|
||||
return $_get(context, 1, 1, 1, order);
|
||||
|
||||
int opIdx = dotOpEnc.getOpIdx();
|
||||
|
||||
// number of rows per phase
|
||||
int perPhase = 128 / (shape[order[0]] * (eltTy.getIntOrFloatBitWidth() / 8));
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
|
||||
// index of the inner dimension in `order`
|
||||
unsigned inner = (opIdx == 0) ? 0 : 1;
|
||||
|
||||
// ---- begin Volta ----
|
||||
if (mmaEnc.isVolta()) {
|
||||
bool is_row = order[0] != 0;
|
||||
bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
|
||||
is_row && (shape[order[0]] <= 16);
|
||||
// TODO[Superjomn]: Support the case when is_vec4=false later
|
||||
// Currently, we only support ld.v2, for the mma layout varies with different ld vector width.
|
||||
is_vec4 = true;
|
||||
int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
|
||||
((is_row && !is_vec4) ? 2 : 1);
|
||||
int rep = 2 * pack_size;
|
||||
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
|
||||
int vec = 2 * rep;
|
||||
return $_get(context, vec, perPhase, maxPhase, order);
|
||||
}
|
||||
|
||||
// ---- begin Ampere ----
|
||||
if (mmaEnc.isAmpere()) {
|
||||
std::vector<size_t> matShape = {8, 8,
|
||||
2 * 64 / eltTy.getIntOrFloatBitWidth()};
|
||||
// for now, disable swizzle when using transposed int8 tensor cores
|
||||
if (eltTy.isInteger(8) && order[0] == inner)
|
||||
return $_get(context, 1, 1, 1, order);
|
||||
|
||||
// --- handle A operand ---
|
||||
if (opIdx == 0) { // compute swizzling for A operand
|
||||
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
|
||||
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return $_get(context, vec, perPhase, maxPhase, order);
|
||||
}
|
||||
|
||||
// --- handle B operand ---
|
||||
if (opIdx == 1) {
|
||||
int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
|
||||
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return $_get(context, vec, perPhase, maxPhase, order);
|
||||
}
|
||||
|
||||
llvm_unreachable("invalid operand index");
|
||||
}
|
||||
|
||||
// ---- not implemented ----
|
||||
llvm_unreachable("unsupported swizzling for provided MMA version");
|
||||
|
||||
|
||||
}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Distributed Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class DistributedEncoding<string name> : TritonGPU_Attr<name> {
|
||||
let description = [{
|
||||
Distributed encodings have a layout function that is entirely characterized
|
||||
by a d-dimensional tensor L. Note that L doesn't need to have the same shape
|
||||
(or even the same rank) as the tensor it is encoding.
|
||||
|
||||
The layout function \mathcal{L} of this layout is then defined, for an
|
||||
index `i` \in R^D, as follows:
|
||||
|
||||
\mathcal{L}(A)[i_d] = L[(i_d + k_d*A.shape[d]) % L.shape[d]] \forall k_d such as i_d + k_d*A.shape[d] < L.shape[d]
|
||||
|
||||
For example, for a tensor/layout pair
|
||||
A = [x x x x x x x x]
|
||||
[x x x x x x x x]
|
||||
L = [0 1 2 3 ]
|
||||
[4 5 6 7 ]
|
||||
[8 9 10 11]
|
||||
[12 13 14 15]
|
||||
|
||||
Then the data of A would be distributed as follow between the 16 CUDA threads:
|
||||
L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
|
||||
{4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ]
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Blocked Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def BlockedEncodingAttr : DistributedEncoding<"BlockedEncoding"> {
|
||||
let mnemonic = "blocked";
|
||||
|
||||
let description = [{
|
||||
An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout
|
||||
used to promote memory coalescing in LoadInst and StoreInst.
|
||||
It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which
|
||||
specify the amount of elements owned by each CUDA thread, warp and CTA respectively.
|
||||
|
||||
For example, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows.
|
||||
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
...
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
|
||||
for
|
||||
|
||||
#triton_gpu.blocked_layout<{
|
||||
sizePerThread = {2, 2}
|
||||
threadsPerWarp = {8, 4}
|
||||
warpsPerCTA = {1, 2}
|
||||
}>
|
||||
}];
|
||||
|
||||
|
||||
let builders = [
|
||||
// Custom builder initializes sizePerWarp and sizePerCTA automatically
|
||||
// TODO: compiles on MacOS but not linux?
|
||||
// AttrBuilder<(ins "ArrayRef<unsigned>":$sizePerThread,
|
||||
// "ArrayRef<unsigned>":$threadsPerWarp,
|
||||
// "ArrayRef<unsigned>":$warpsPerCTA,
|
||||
// "ArrayRef<unsigned>":$order), [{
|
||||
// int rank = threadsPerWarp.size();
|
||||
// SmallVector<unsigned, 4> sizePerWarp(rank);
|
||||
// SmallVector<unsigned, 4> sizePerCTA(rank);
|
||||
// for (unsigned i = 0; i < rank; i++) {
|
||||
// sizePerWarp.push_back(sizePerThread[i] * threadsPerWarp[i]);
|
||||
// sizePerCTA.push_back(sizePerWarp[i] * warpsPerCTA[i]);
|
||||
// }
|
||||
// return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, sizePerWarp, sizePerCTA);
|
||||
// }]>,
|
||||
// Custom builder initializes sizePerWarp and sizePerCTA automatically
|
||||
// Default builder takes sizePerThread, order and numWarps, and tries to
|
||||
// pack numWarps*32 threads in the provided order for use in a type
|
||||
// of the given shape.
|
||||
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$sizePerThread,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"unsigned":$numWarps), [{
|
||||
int rank = sizePerThread.size();
|
||||
unsigned remainingLanes = 32;
|
||||
unsigned remainingThreads = numWarps*32;
|
||||
unsigned remainingWarps = numWarps;
|
||||
unsigned prevLanes = 1;
|
||||
unsigned prevWarps = 1;
|
||||
SmallVector<unsigned, 4> threadsPerWarp(rank);
|
||||
SmallVector<unsigned, 4> warpsPerCTA(rank);
|
||||
for (int _dim = 0; _dim < rank - 1; ++_dim) {
|
||||
int i = order[_dim];
|
||||
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
|
||||
threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
|
||||
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps);
|
||||
remainingWarps /= warpsPerCTA[i];
|
||||
remainingLanes /= threadsPerWarp[i];
|
||||
remainingThreads /= threadsPerCTA;
|
||||
prevLanes *= threadsPerWarp[i];
|
||||
prevWarps *= warpsPerCTA[i];
|
||||
}
|
||||
// Expand the last dimension to fill the remaining lanes and warps
|
||||
threadsPerWarp[order[rank-1]] = 32 / prevLanes;
|
||||
warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
|
||||
|
||||
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
|
||||
|
||||
}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
SliceEncodingAttr squeeze(int axis);
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
ArrayRefParameter<"unsigned">:$sizePerThread,
|
||||
ArrayRefParameter<"unsigned">:$threadsPerWarp,
|
||||
ArrayRefParameter<"unsigned">:$warpsPerCTA,
|
||||
// fastest-changing axis first
|
||||
ArrayRefParameter<
|
||||
"unsigned",
|
||||
"order of axes by the rate of changing"
|
||||
>:$order
|
||||
// These attributes can be inferred from the rest
|
||||
// ArrayRefParameter<"unsigned">:$sizePerWarp,
|
||||
// ArrayRefParameter<"unsigned">:$sizePerCTA
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MMA Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TODO: MMAv1 and MMAv2 should be two instances of the same class
|
||||
|
||||
def MmaEncodingAttr : DistributedEncoding<"MmaEncoding"> {
|
||||
let mnemonic = "mma";
|
||||
|
||||
let description = [{
|
||||
An encoding for tensors that have been produced by tensor cores.
|
||||
It is characterized by two parameters:
|
||||
- A 'versionMajor' which specifies the generation the tensor cores
|
||||
whose output is being partitioned: 1 for first-gen tensor cores (Volta),
|
||||
and 2 for second-gen tensor cores (Turing/Ampere).
|
||||
- A 'versionMinor' which indicates the specific layout of a tensor core
|
||||
generation, e.g. for Volta, there might be multiple kinds of layouts annotated
|
||||
by 0,1,2 and so on.
|
||||
- A `blockTileSize` to indicate how data should be
|
||||
partitioned between warps.
|
||||
|
||||
// -------------------------------- version = 1 --------------------------- //
|
||||
|
||||
For first-gen tensor cores, the implicit warpTileSize is [16, 16].
|
||||
Note: the layout is different from the recommended in PTX ISA
|
||||
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
|
||||
(mma.884 section, FP32 accumulator).
|
||||
|
||||
For example, when versionMinor=1, the matrix L corresponding to
|
||||
blockTileSize=[32,16] is:
|
||||
|
||||
warp 0
|
||||
--------------------------------/\-------------------------------
|
||||
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
|
||||
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
|
||||
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
|
||||
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
|
||||
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
|
||||
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
|
||||
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
|
||||
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
|
||||
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
|
||||
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
|
||||
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
|
||||
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
|
||||
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
|
||||
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
|
||||
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
|
||||
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
|
||||
|
||||
warp 1 = warp0 + 32
|
||||
--------------------------------/\-------------------------------
|
||||
[ 32 32 34 34 40 40 42 42 32 32 34 34 40 40 42 42 ]
|
||||
[ 33 33 35 35 41 41 43 43 33 33 35 35 41 41 43 43 ]
|
||||
[ ............................................................... ]
|
||||
|
||||
|
||||
// -------------------------------- version = 2 --------------------------- //
|
||||
|
||||
For second-gen tensor cores, the implicit warpTileSize is [16, 8].
|
||||
Information about this layout can be found in the official PTX documentation
|
||||
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
|
||||
(mma.16816 section, FP32 accumulator).
|
||||
|
||||
For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
warp 0 warp 1
|
||||
-----------------/\------------- ----------------/\-------------
|
||||
[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35
|
||||
[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39
|
||||
[ .............................. ..............................
|
||||
[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63
|
||||
[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35
|
||||
[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39
|
||||
[ .............................. ..............................
|
||||
[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63
|
||||
|
||||
warp 3 warp 4
|
||||
----------------/\------------- ----------------/\-------------
|
||||
[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99
|
||||
[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103
|
||||
[ .............................. ...............................
|
||||
[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127
|
||||
[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99
|
||||
[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103
|
||||
[ .............................. ...............................
|
||||
[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127
|
||||
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
"unsigned":$versionMajor,
|
||||
"unsigned":$versionMinor,
|
||||
ArrayRefParameter<"unsigned">:$warpsPerCTA
|
||||
);
|
||||
|
||||
let builders = [
|
||||
// specific for MMAV1(Volta)
|
||||
AttrBuilder<(ins "int":$versionMajor,
|
||||
"ArrayRef<unsigned>":$warpsPerCTA,
|
||||
"ArrayRef<int64_t>":$shapeA,
|
||||
"ArrayRef<int64_t>":$shapeB,
|
||||
"bool":$isARow,
|
||||
"bool":$isBRow), [{
|
||||
assert(versionMajor == 1 && "Only MMAv1 has multiple versionMinor.");
|
||||
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
|
||||
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
|
||||
// 4-bits to encode 4 booleans: [isARow, isBRow, isAVec4, isBVec4]
|
||||
int versionMinor = (isARow * (1<<0)) |\
|
||||
(isBRow * (1<<1)) |\
|
||||
(isAVec4 * (1<<2)) |\
|
||||
(isBVec4 * (1<<3));
|
||||
return $_get(context, versionMajor, versionMinor, warpsPerCTA);
|
||||
}]>
|
||||
|
||||
];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
bool isVolta() const;
|
||||
bool isAmpere() const;
|
||||
// Get [isARow, isBRow, isAVec4, isBVec4] from versionMinor
|
||||
std::tuple<bool, bool, bool, bool> decodeVoltaLayoutStates() const;
|
||||
}];
|
||||
|
||||
}
|
||||
|
||||
def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
|
||||
let mnemonic = "slice";
|
||||
|
||||
let description = [{
|
||||
TODO: improve docs
|
||||
|
||||
A = [x x x x x x x x]
|
||||
|
||||
parent = [0 1 2 3 ]
|
||||
[4 5 6 7 ]
|
||||
[8 9 10 11]
|
||||
[12 13 14 15]
|
||||
dim = 0
|
||||
|
||||
Then the data of A would be distributed as follow between the 16 CUDA threads:
|
||||
L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ]
|
||||
|
||||
This is useful for constructing the inverse layout of an expand_dims operation during some optimization passes.
|
||||
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
"unsigned":$dim,
|
||||
// TODO: constraint here to only take distributed encodings
|
||||
"Attribute":$parent
|
||||
);
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
template<class T>
|
||||
SmallVector<T> paddedShape(ArrayRef<T> shape) const;
|
||||
}];
|
||||
}
|
||||
|
||||
def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
|
||||
let mnemonic = "dot_op";
|
||||
|
||||
let description = [{
|
||||
In TritonGPU dialect, considering `d = tt.dot a, b, c`
|
||||
tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
|
||||
a's opIdx is 0, b's opIdx is 1.
|
||||
The parend field in DotOperandEncodingAttr is the layout of d.
|
||||
|
||||
For MMA v1, an additional attribute `isMMAv1Row` determines whether e.g. the a operand is used
|
||||
in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation
|
||||
section 9.7.13.4.1 for more details.
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
"unsigned":$opIdx,
|
||||
"Attribute":$parent,
|
||||
"Attribute":$isMMAv1Row
|
||||
);
|
||||
|
||||
let builders = [
|
||||
AttrBuilder<(ins "unsigned":$opIdx,
|
||||
"Attribute":$parent), [{
|
||||
Attribute isMMAv1Row;
|
||||
if(parent.isa<MmaEncodingAttr>() &&
|
||||
parent.cast<MmaEncodingAttr>().isVolta()){
|
||||
isMMAv1Row = BoolAttr::get(context, true);
|
||||
}
|
||||
return $_get(context, opIdx, parent, isMMAv1Row);
|
||||
}]>
|
||||
|
||||
];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration;
|
||||
}
|
||||
|
||||
|
||||
|
||||
#endif
|
36
include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
Normal file
36
include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
Normal file
@@ -0,0 +1,36 @@
|
||||
#ifndef TRITONGPU_DIALECT
|
||||
#define TRITONGPU_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def TritonGPU_Dialect : Dialect {
|
||||
let name = "triton_gpu";
|
||||
|
||||
let cppNamespace = "::mlir::triton::gpu";
|
||||
|
||||
let hasOperationAttrVerify = 1;
|
||||
|
||||
let description = [{
|
||||
Triton GPU Dialect.
|
||||
}];
|
||||
|
||||
let dependentDialects = [
|
||||
"triton::TritonDialect",
|
||||
"mlir::gpu::GPUDialect",
|
||||
"tensor::TensorDialect",
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
|
||||
static int getNumWarps(ModuleOp mod) {
|
||||
if(!mod->hasAttr("triton_gpu.num-warps"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.num-warps attribute");
|
||||
return mod->getAttr("triton_gpu.num-warps").cast<IntegerAttr>().getInt();
|
||||
}
|
||||
}];
|
||||
|
||||
|
||||
}
|
||||
|
||||
#endif
|
198
include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Normal file
198
include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Normal file
@@ -0,0 +1,198 @@
|
||||
#ifndef TRITONGPU_OPS
|
||||
#define TRITONGPU_OPS
|
||||
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
|
||||
include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
|
||||
def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;
|
||||
|
||||
class TTG_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<TritonGPU_Dialect, mnemonic, traits>;
|
||||
|
||||
def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
|
||||
[SameOperandsAndResultShape, NoSideEffect]> {
|
||||
let summary = "convert layout";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
|
||||
let summary = "async wait";
|
||||
|
||||
let arguments = (ins I32Attr:$num);
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static bool isSupported(int computeCapability) {
|
||||
return computeCapability >= 80;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
|
||||
// This is needed because these ops don't
|
||||
// handle encodings
|
||||
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
|
||||
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding]> {
|
||||
let summary = "integer comparison operation";
|
||||
|
||||
let description = [{}];
|
||||
|
||||
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
|
||||
TT_IntLike:$lhs,
|
||||
TT_IntLike:$rhs);
|
||||
|
||||
let results = (outs TT_BoolLike:$result);
|
||||
}
|
||||
|
||||
def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding]> {
|
||||
let summary = "floating-point comparison operation";
|
||||
|
||||
let description = [{}];
|
||||
|
||||
let arguments = (ins Arith_CmpFPredicateAttr:$predicate,
|
||||
TT_FloatLike:$lhs,
|
||||
TT_FloatLike:$rhs);
|
||||
|
||||
let results = (outs TT_BoolLike:$result);
|
||||
}
|
||||
|
||||
// TODO: migrate to arith::SelectOp on LLVM16
|
||||
def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding]> {
|
||||
let summary = "select operation";
|
||||
|
||||
let description = [{}];
|
||||
|
||||
let arguments = (ins TT_BoolLike:$condition,
|
||||
TT_Tensor:$true_value,
|
||||
TT_Tensor:$false_value);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
}
|
||||
|
||||
|
||||
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
||||
[AttrSizedOperandSegments,
|
||||
ResultsAreSharedEncoding,
|
||||
MemoryEffects<[MemRead]>,
|
||||
TypesMatchWith<"infer mask type from src type",
|
||||
"src", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
|
||||
TypesMatchWith<"infer other type from src type",
|
||||
"src", "other", "getPointeeType($_self)",
|
||||
"($_op.getOperands().size() <= 4) || std::equal_to<>()">]> {
|
||||
let summary = "insert slice async";
|
||||
|
||||
let description = [{
|
||||
This operation inserts a tensor `$src` into another tensor `$dst` as specified by the operation’s
|
||||
`$index` argument and `$axis` attribute.
|
||||
|
||||
It returns a copy of `$dst` with the proper slice updated asynchronously with the value of `$src`.
|
||||
This operation is non-blocking, and `$results` will have the updated value after the corresponding async_wait.
|
||||
|
||||
When converting from `tt.load` to `triton_gpu.insert_slice_async`, the `$evict`, `$cache`, and `$isVolatile` fields
|
||||
might be ignored on certain hardware. For example, on NVIDIA GPUs, the cache policy is determined by the backend,
|
||||
and `$evict` and `$isVolatile` are ignored because they apply to L1 cache only.
|
||||
|
||||
The insert_slice_async operation supports the following arguments:
|
||||
|
||||
* src: the tensor that is inserted.
|
||||
* dst: the tensor into which the `$src` tensor is inserted.
|
||||
* index: the index of the `$src` tensor at the given `$axis` from which the `$dst` tensor is inserted into
|
||||
* mask: optional tensor-rank number of boolean masks which specify which
|
||||
elements of the `$src` tensor are inserted into the `$dst` tensor.
|
||||
* other: optional tensor-rank number of other tensors which specify what
|
||||
values are inserted into the `$dst` tensor if the corresponding
|
||||
element of the `$mask` tensor is false.
|
||||
|
||||
In the future, we may decompose this operation into a sequence of:
|
||||
|
||||
* `async` operation to specify a sequence of asynchronous operations
|
||||
* `load` operation to load a tensor from global memory
|
||||
* `insert_slice` operations to insert the `$src` tensor into the `$dst` tensor
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
%1 = triton_gpu.alloc_tensor : tensor<2x32xf32>
|
||||
%2 = triton_gpu.insert_slice_async %0, %1, %index { axis = 0 } : tensor<32x!tt.ptr<f32>, #AL> -> tensor<2x32xf32, #A>
|
||||
triiton_gpu.async_wait { num = 0 : i32 }
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$src, TT_Tensor:$dst, I32:$index,
|
||||
Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
|
||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||
BoolAttr:$isVolatile, I32Attr:$axis);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, "Value":$mask,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
|
||||
"Value":$mask, "Value":$other,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
];
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
//let assemblyFormat = [{
|
||||
// $src `,` $dst ``
|
||||
// $index, $mask, $other
|
||||
// attr-dict `:` type($src) `->` type($dst)
|
||||
//}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability) {
|
||||
DenseSet<unsigned> validLoadBytes;
|
||||
if (computeCapability >= 80) {
|
||||
validLoadBytes = {4, 8, 16};
|
||||
}
|
||||
return validLoadBytes;
|
||||
}
|
||||
}];
|
||||
|
||||
// The custom parser could be replaced with oilist in LLVM-16
|
||||
let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
|
||||
|
||||
let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
|
||||
}
|
||||
|
||||
def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [MemoryEffects<[MemAlloc]>, // Allocate shared memory
|
||||
ResultsAreSharedEncoding]> {
|
||||
let summary = "allocate tensor";
|
||||
|
||||
let description = [{
|
||||
This operation defines a tensor of a particular shape.
|
||||
The contents of the tensor are supposed to be in shared memory.
|
||||
|
||||
Note: This op can be repalced to a `bufferization.alloc_tensor` in LLVM 16.
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{attr-dict `:` type($result)}];
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
}
|
||||
|
||||
#endif
|
@@ -0,0 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonGPU)
|
||||
add_public_tablegen_target(TritonGPUTransformsIncGen)
|
25
include/triton/Dialect/TritonGPU/Transforms/Passes.h
Normal file
25
include/triton/Dialect/TritonGPU/Transforms/Passes.h
Normal file
@@ -0,0 +1,25 @@
|
||||
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_
|
||||
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
|
||||
|
||||
// TODO(Keren): prefetch pass not working yet
|
||||
std::unique_ptr<Pass> createTritonGPUPrefetchPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCoalescePass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUVerifier();
|
||||
|
||||
/// Generate the code for registering passes.
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
} // namespace mlir
|
||||
#endif
|
87
include/triton/Dialect/TritonGPU/Transforms/Passes.td
Normal file
87
include/triton/Dialect/TritonGPU/Transforms/Passes.td
Normal file
@@ -0,0 +1,87 @@
|
||||
#ifndef TRITONGPU_PASSES
|
||||
#define TRITONGPU_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
|
||||
let summary = "pipeline";
|
||||
|
||||
let description = [{
|
||||
Unroll loops to hide global memory -> shared memory latency.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUPipelinePass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::arith::ArithmeticDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"numStages", "num-stages",
|
||||
"int32_t", /*default*/"2",
|
||||
"number of pipeline stages">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
|
||||
let summary = "prefetch";
|
||||
|
||||
let description = [{
|
||||
Prefetch operands (a and b) of tt.dot into shared memory to hide shared memory -> register latency.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUPrefetchPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::arith::ArithmeticDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
|
||||
let summary = "coalesce";
|
||||
|
||||
let description = [{
|
||||
TODO
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUCoalescePass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
|
||||
let summary = "combine triton gpu ops";
|
||||
|
||||
let description = [{
|
||||
convert_layout(convert_layout(%src, #LAYOUT_0), #LAYOUT_1) =>
|
||||
convert_layout(%src, #LAYOUT_1)
|
||||
|
||||
convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUCombineOpsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
|
||||
let summary = "canonicalize scf.ForOp ops";
|
||||
|
||||
let description = [{
|
||||
This implements some optimizations that are missing in the standard scf.ForOp
|
||||
canonicalizer.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUCanonicalizeLoopsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
|
||||
}
|
||||
|
||||
#endif
|
@@ -0,0 +1,33 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Defines utilities to use while converting to the TritonGPU dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
|
||||
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
|
||||
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class TritonGPUTypeConverter : public TypeConverter {
|
||||
public:
|
||||
TritonGPUTypeConverter(MLIRContext *context, int numWarps);
|
||||
int getNumWarps() const { return numWarps; }
|
||||
|
||||
private:
|
||||
MLIRContext *context;
|
||||
int numWarps;
|
||||
};
|
||||
|
||||
class TritonGPUConversionTarget : public ConversionTarget {
|
||||
|
||||
public:
|
||||
explicit TritonGPUConversionTarget(MLIRContext &ctx,
|
||||
TritonGPUTypeConverter &typeConverter);
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
|
37
include/triton/Target/LLVMIR/LLVMIRTranslation.h
Normal file
37
include/triton/Target/LLVMIR/LLVMIRTranslation.h
Normal file
@@ -0,0 +1,37 @@
|
||||
#ifndef TRITON_TARGET_LLVMIRTRANSLATION_H
|
||||
#define TRITON_TARGET_LLVMIRTRANSLATION_H
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
class LLVMContext;
|
||||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
class ModuleOp;
|
||||
} // namespace mlir
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
// add external dependent libs
|
||||
void addExternalLibs(mlir::ModuleOp &module,
|
||||
const std::vector<std::string> &names,
|
||||
const std::vector<std::string> &paths);
|
||||
|
||||
// Translate TritonGPU dialect to LLVMIR, return null if failed.
|
||||
std::unique_ptr<llvm::Module>
|
||||
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
mlir::ModuleOp module, int computeCapability);
|
||||
|
||||
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
|
||||
std::unique_ptr<llvm::Module>
|
||||
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_TARGET_LLVMIRTRANSLATION_H
|
17
include/triton/Target/PTX/PTXTranslation.h
Normal file
17
include/triton/Target/PTX/PTXTranslation.h
Normal file
@@ -0,0 +1,17 @@
|
||||
#ifndef TRITON_TARGET_PTXTRANSLATION_H
|
||||
#define TRITON_TARGET_PTXTRANSLATION_H
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
} // namespace llvm
|
||||
|
||||
namespace triton {
|
||||
|
||||
// Translate TritonGPU IR to PTX code.
|
||||
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
37
include/triton/tools/sys/getenv.hpp → include/triton/Tools/Sys/GetEnv.hpp
Executable file → Normal file
37
include/triton/tools/sys/getenv.hpp → include/triton/Tools/Sys/GetEnv.hpp
Executable file → Normal file
@@ -22,35 +22,32 @@
|
||||
#ifndef TDL_TOOLS_SYS_GETENV_HPP
|
||||
#define TDL_TOOLS_SYS_GETENV_HPP
|
||||
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace triton {
|
||||
|
||||
namespace tools
|
||||
{
|
||||
namespace tools {
|
||||
|
||||
inline std::string getenv(const char * name)
|
||||
{
|
||||
#ifdef _MSC_VER
|
||||
char* cache_path = 0;
|
||||
std::size_t sz = 0;
|
||||
_dupenv_s(&cache_path, &sz, name);
|
||||
#else
|
||||
const char * cstr = std::getenv(name);
|
||||
#endif
|
||||
if(!cstr)
|
||||
inline std::string getenv(const char *name) {
|
||||
const char *cstr = std::getenv(name);
|
||||
if (!cstr)
|
||||
return "";
|
||||
std::string result(cstr);
|
||||
#ifdef _MSC_VER
|
||||
free(cache_path);
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
inline bool getBoolEnv(const std::string &env) {
|
||||
const char *s = std::getenv(env.c_str());
|
||||
std::string str(s ? s : "");
|
||||
std::transform(str.begin(), str.end(), str.begin(),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
return (str == "on" || str == "true" || str == "1");
|
||||
}
|
||||
|
||||
} // namespace tools
|
||||
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
@@ -1,80 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
|
||||
#define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class value;
|
||||
class module;
|
||||
class phi_node;
|
||||
class splat_inst;
|
||||
class cast_inst;
|
||||
class reshape_inst;
|
||||
class broadcast_inst;
|
||||
class binary_operator;
|
||||
class getelementptr_inst;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class align {
|
||||
private:
|
||||
struct cst_info {
|
||||
unsigned num_cst;
|
||||
unsigned value;
|
||||
};
|
||||
// helpers
|
||||
std::vector<unsigned> get_shapes(ir::value *v);
|
||||
// populate is_constant
|
||||
std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
|
||||
std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
|
||||
std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_default(ir::value* v);
|
||||
std::vector<cst_info> populate_is_constant(ir::value *v);
|
||||
// populate max_contiguous
|
||||
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
|
||||
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
|
||||
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_cast(ir::cast_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
|
||||
std::vector<unsigned> populate_max_contiguous(ir::value *v);
|
||||
// populate starting_multiple
|
||||
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
|
||||
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
|
||||
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_cast(ir::cast_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
|
||||
std::vector<unsigned> populate_starting_multiple(ir::value *v);
|
||||
// populate all maps
|
||||
void populate(ir::value *v);
|
||||
|
||||
public:
|
||||
void run(ir::module &mod);
|
||||
unsigned get(ir::value* v, unsigned ax) const;
|
||||
std::vector<unsigned> contiguous(ir::value* v) const;
|
||||
|
||||
private:
|
||||
std::map<ir::value*, std::vector<cst_info>> is_constant_;
|
||||
std::map<ir::value*, std::vector<unsigned>> max_contiguous_;
|
||||
std::map<ir::value*, std::vector<unsigned>> starting_multiple_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,47 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
|
||||
#define TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <iostream>
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class function;
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class tiles;
|
||||
|
||||
class liveness;
|
||||
class cts;
|
||||
|
||||
class allocation {
|
||||
public:
|
||||
allocation(liveness *live)
|
||||
: liveness_(live) { }
|
||||
// accessors
|
||||
bool has_offset(const data_layout *x) const { return offsets_.find(x) != offsets_.end(); }
|
||||
unsigned offset(const data_layout *x) const { return offsets_.at(x); }
|
||||
unsigned allocated_size() const { return allocated_size_; }
|
||||
// run
|
||||
void run(ir::module& mod);
|
||||
|
||||
private:
|
||||
std::map<const data_layout*, unsigned> offsets_;
|
||||
size_t allocated_size_;
|
||||
// dependences
|
||||
liveness *liveness_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,51 +0,0 @@
|
||||
#ifndef _TRITON_CODEGEN_ANALYSIS_AXES_H_
|
||||
#define _TRITON_CODEGEN_ANALYSIS_AXES_H_
|
||||
|
||||
#include "triton/tools/graph.h"
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class module;
|
||||
class instruction;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class axes {
|
||||
typedef std::pair<ir::value*, unsigned> node_t;
|
||||
|
||||
private:
|
||||
// update graph
|
||||
void update_graph_store(ir::instruction *i);
|
||||
void update_graph_reduce(ir::instruction *i);
|
||||
void update_graph_reshape(ir::instruction *i);
|
||||
void update_graph_trans(ir::instruction *i);
|
||||
void update_graph_broadcast(ir::instruction *i);
|
||||
void update_graph_dot(ir::instruction *i);
|
||||
void update_graph_elementwise(ir::instruction *i, bool connect_ret=true);
|
||||
void update_graph_no_edge(ir::instruction *i);
|
||||
void update_graph(ir::instruction *i);
|
||||
|
||||
public:
|
||||
axes();
|
||||
void run(ir::module &mod);
|
||||
// accessors
|
||||
int get(ir::value *value, unsigned dim);
|
||||
std::vector<int> get(ir::value *value);
|
||||
|
||||
private:
|
||||
tools::graph<node_t> graph_;
|
||||
std::map<node_t, size_t> axes_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,243 +0,0 @@
|
||||
#ifndef _TRITON_CODEGEN_ANALYSIS_GRID_H_
|
||||
#define _TRITON_CODEGEN_ANALYSIS_GRID_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "triton/tools/graph.h"
|
||||
#include "triton/codegen/target.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class type;
|
||||
class module;
|
||||
class instruction;
|
||||
class phi_node;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class axes;
|
||||
class align;
|
||||
class layout_visitor;
|
||||
class data_layout;
|
||||
class mma_layout;
|
||||
class scanline_layout;
|
||||
class shared_layout;
|
||||
|
||||
|
||||
class layout_visitor {
|
||||
public:
|
||||
virtual void visit_layout(data_layout *);
|
||||
virtual void visit_layout_mma(mma_layout*) = 0;
|
||||
virtual void visit_layout_scanline(scanline_layout*) = 0;
|
||||
virtual void visit_layout_shared(shared_layout*) = 0;
|
||||
};
|
||||
|
||||
class data_layout {
|
||||
protected:
|
||||
enum id_t {
|
||||
MMA,
|
||||
SCANLINE,
|
||||
SHARED
|
||||
};
|
||||
|
||||
typedef std::vector<int> axes_t;
|
||||
typedef std::vector<unsigned> shape_t;
|
||||
typedef std::vector<int> order_t;
|
||||
typedef std::vector<ir::value*> values_t;
|
||||
|
||||
private:
|
||||
template<typename T>
|
||||
T* downcast(id_t id) {
|
||||
if(id_ == id)
|
||||
return static_cast<T*>(this);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
public:
|
||||
data_layout(id_t id,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned> &shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align);
|
||||
// visitor
|
||||
virtual void accept(layout_visitor* vst) = 0;
|
||||
// downcast
|
||||
mma_layout* to_mma() { return downcast<mma_layout>(MMA); }
|
||||
scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
|
||||
shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
|
||||
// accessors
|
||||
size_t get_rank() { return shape_.size(); }
|
||||
const shape_t& get_shape() const { return shape_; }
|
||||
const order_t& get_order() const { return order_; }
|
||||
const values_t& get_values() const { return values_;}
|
||||
int get_axis(size_t k) const { return axes_.at(k); }
|
||||
std::vector<int> get_axes() const { return axes_; }
|
||||
const int get_order(size_t k) const { return order_.at(k); }
|
||||
// find the position of given axis
|
||||
int find_axis(int to_find) const;
|
||||
|
||||
|
||||
private:
|
||||
id_t id_;
|
||||
axes_t axes_;
|
||||
values_t values_;
|
||||
|
||||
protected:
|
||||
order_t order_;
|
||||
shape_t shape_;
|
||||
};
|
||||
|
||||
class mma_layout: public data_layout {
|
||||
public:
|
||||
mma_layout(size_t num_warps,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shapes,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align, target *tgt,
|
||||
shared_layout* layout_a,
|
||||
shared_layout* layout_b);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
|
||||
// accessor
|
||||
int fpw(size_t k) { return fpw_.at(k); }
|
||||
int wpt(size_t k) { return wpt_.at(k); }
|
||||
int spw(size_t k) { return spw_.at(k); }
|
||||
int spt(size_t k) { return spt_.at(k); }
|
||||
int rep(size_t k) { return rep_.at(k); }
|
||||
|
||||
private:
|
||||
std::vector<int> fpw_;
|
||||
std::vector<int> spw_;
|
||||
std::vector<int> wpt_;
|
||||
std::vector<int> spt_;
|
||||
std::vector<int> rep_;
|
||||
};
|
||||
|
||||
struct scanline_layout: public data_layout {
|
||||
scanline_layout(size_t num_warps,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align,
|
||||
target* tgt);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
|
||||
// accessor
|
||||
int mts(size_t k) { return mts_.at(k); }
|
||||
int nts(size_t k) { return nts_.at(k); }
|
||||
|
||||
public:
|
||||
std::vector<int> mts_;
|
||||
std::vector<int> nts_;
|
||||
};
|
||||
|
||||
struct double_buffer_info_t {
|
||||
ir::value* first;
|
||||
ir::value* latch;
|
||||
ir::phi_node* phi;
|
||||
};
|
||||
|
||||
struct N_buffer_info_t {
|
||||
std::vector<ir::value*> firsts; // not necessarily ordered as input order
|
||||
ir::value* latch;
|
||||
ir::phi_node* phi;
|
||||
std::map<ir::value*, int> firsts_idx;
|
||||
};
|
||||
|
||||
// abstract for dot and coresponding smem values
|
||||
class shared_layout: public data_layout {
|
||||
private:
|
||||
static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
|
||||
static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
|
||||
static void extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t>& res, int &prev_stages);
|
||||
|
||||
public:
|
||||
shared_layout(data_layout *arg,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shapes,
|
||||
const std::vector<ir::value *> &values_,
|
||||
ir::type *ty,
|
||||
analysis::align* align);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
|
||||
// accessors
|
||||
size_t get_size() { return size_; }
|
||||
ir::type* get_type() { return ty_; }
|
||||
double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
|
||||
N_buffer_info_t* get_N_buffer() { return N_buffer_.get(); }
|
||||
int get_num_stages() const;
|
||||
size_t get_per_stage_size() const { return size_ / get_num_stages(); }
|
||||
size_t get_per_stage_elements() const;
|
||||
size_t get_num_per_phase() { return num_per_phase_; }
|
||||
ir::value* hmma_dot_a() { return hmma_dot_a_; }
|
||||
ir::value* hmma_dot_b() { return hmma_dot_b_; }
|
||||
void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
|
||||
int get_mma_vec() { return mma_vec_;}
|
||||
data_layout* get_arg_layout() { return arg_layout_; }
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
ir::type *ty_;
|
||||
std::shared_ptr<double_buffer_info_t> double_buffer_;
|
||||
std::shared_ptr<N_buffer_info_t> N_buffer_;
|
||||
size_t num_per_phase_;
|
||||
ir::value* hmma_dot_a_;
|
||||
ir::value* hmma_dot_b_;
|
||||
data_layout* arg_layout_;
|
||||
int mma_vec_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class layouts {
|
||||
typedef ir::value* node_t;
|
||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||
|
||||
private:
|
||||
// graph creation
|
||||
void connect(ir::value *x, ir::value *y);
|
||||
void make_graph(ir::instruction *i);
|
||||
|
||||
void init_hmma_tile(data_layout& layouts);
|
||||
void init_scanline_tile(data_layout &layouts);
|
||||
|
||||
void create(size_t id, const std::vector<ir::value*>& values);
|
||||
|
||||
public:
|
||||
// constructor
|
||||
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
|
||||
|
||||
// accessors
|
||||
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
|
||||
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
|
||||
size_t num_layouts() const { return values_.size();}
|
||||
data_layout* get(size_t id) { return layouts_.at(id); }
|
||||
data_layout* get(ir::value *v) { return get(layout_of(v));}
|
||||
std::map<size_t, data_layout*> &get_all() { return layouts_; }
|
||||
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
|
||||
int tmp(ir::value* i) { return tmp_.at(i);}
|
||||
|
||||
// execution
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::axes* axes_;
|
||||
analysis::align* align_;
|
||||
size_t num_warps_;
|
||||
target* tgt_;
|
||||
tools::graph<ir::value*> graph_;
|
||||
std::map<ir::value*, size_t> groups_;
|
||||
std::map<size_t, std::vector<ir::value*>> values_;
|
||||
std::map<size_t, data_layout*> layouts_;
|
||||
std::map<ir::value*, size_t> tmp_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,67 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
|
||||
#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/tools/graph.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class phi_node;
|
||||
class function;
|
||||
class module;
|
||||
class instruction;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
typedef unsigned slot_index;
|
||||
|
||||
class tiles;
|
||||
class layouts;
|
||||
class data_layout;
|
||||
|
||||
struct segment {
|
||||
slot_index start;
|
||||
slot_index end;
|
||||
|
||||
bool contains(slot_index idx) const {
|
||||
return start <= idx && idx < end;
|
||||
}
|
||||
|
||||
bool intersect(const segment &Other){
|
||||
return contains(Other.start) || Other.contains(start);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class liveness {
|
||||
private:
|
||||
typedef std::map<shared_layout*, segment> intervals_map_t;
|
||||
|
||||
public:
|
||||
// constructor
|
||||
liveness(layouts *l): layouts_(l){ }
|
||||
// accessors
|
||||
const intervals_map_t& get() const { return intervals_; }
|
||||
segment get(shared_layout* v) const { return intervals_.at(v); }
|
||||
// run
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
// analysis
|
||||
layouts *layouts_;
|
||||
intervals_map_t intervals_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif
|
@@ -1,43 +0,0 @@
|
||||
#ifndef TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
|
||||
#define TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
|
||||
|
||||
#include <map>
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
class target;
|
||||
|
||||
namespace analysis{
|
||||
|
||||
class layouts;
|
||||
class data_layout;
|
||||
|
||||
class swizzle {
|
||||
public:
|
||||
// constructor
|
||||
swizzle(layouts *l, target* tgt): layouts_(l), tgt_(tgt){ }
|
||||
// accessors
|
||||
int get_per_phase(data_layout* layout) { return per_phase_.at(layout); }
|
||||
int get_max_phase(data_layout* layout) { return max_phase_.at(layout); }
|
||||
int get_vec (data_layout* layout) { return vec_.at(layout); }
|
||||
// run
|
||||
void run(ir::module &mod);
|
||||
private:
|
||||
layouts* layouts_;
|
||||
target* tgt_;
|
||||
std::map<data_layout*, int> per_phase_;
|
||||
std::map<data_layout*, int> max_phase_;
|
||||
std::map<data_layout*, int> vec_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif
|
@@ -1,31 +0,0 @@
|
||||
#ifndef _TRITON_CODEGEN_PASS_H_
|
||||
#define _TRITON_CODEGEN_PASS_H_
|
||||
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class module;
|
||||
}
|
||||
namespace driver{
|
||||
class device;
|
||||
class module;
|
||||
class kernel;
|
||||
}
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
// TODO:
|
||||
// There should be a proper pass manager there!
|
||||
void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages, bool force_nc_cache,
|
||||
driver::module*& mod, driver::kernel*& ker, size_t& shared_mem);
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,257 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_SELECTION_GENERATOR_H_
|
||||
#define _TRITON_SELECTION_GENERATOR_H_
|
||||
|
||||
#include "triton/ir/visitor.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include <functional>
|
||||
|
||||
// forward
|
||||
namespace llvm{
|
||||
class Type;
|
||||
class Value;
|
||||
class PHINode;
|
||||
class BasicBlock;
|
||||
class Attribute;
|
||||
class Instruction;
|
||||
class Constant;
|
||||
class LLVMContext;
|
||||
class Module;
|
||||
class ConstantFolder;
|
||||
class IRBuilderDefaultInserter;
|
||||
template <typename T, typename Inserter>
|
||||
class IRBuilder;
|
||||
class ArrayType;
|
||||
class Function;
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class attribute;
|
||||
class load_inst;
|
||||
class store_inst;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
// forward
|
||||
namespace analysis{
|
||||
class liveness;
|
||||
class tiles;
|
||||
class align;
|
||||
class allocation;
|
||||
class cts;
|
||||
class axes;
|
||||
class layouts;
|
||||
class swizzle;
|
||||
}
|
||||
// typedef
|
||||
typedef llvm::IRBuilder<llvm::ConstantFolder,
|
||||
llvm::IRBuilderDefaultInserter> Builder;
|
||||
typedef llvm::LLVMContext LLVMContext;
|
||||
typedef llvm::Type Type;
|
||||
typedef llvm::Value Value;
|
||||
typedef llvm::Attribute Attribute;
|
||||
typedef llvm::BasicBlock BasicBlock;
|
||||
typedef llvm::Module Module;
|
||||
typedef llvm::Instruction Instruction;
|
||||
typedef llvm::Constant Constant;
|
||||
typedef llvm::ArrayType ArrayType;
|
||||
typedef llvm::Function Function;
|
||||
typedef std::vector<Value*> indices_t;
|
||||
class target;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
struct distributed_axis {
|
||||
int contiguous;
|
||||
std::vector<Value*> values;
|
||||
Value* thread_id;
|
||||
};
|
||||
|
||||
class adder{
|
||||
public:
|
||||
adder(Builder** builder): builder_(builder) { }
|
||||
Value* operator()(Value* x, Value* y, const std::string& name = "");
|
||||
|
||||
private:
|
||||
Builder** builder_;
|
||||
};
|
||||
|
||||
class multiplier{
|
||||
public:
|
||||
multiplier(Builder** builder): builder_(builder) { }
|
||||
Value* operator()(Value* x, Value* y, const std::string& name = "");
|
||||
private:
|
||||
Builder** builder_;
|
||||
};
|
||||
|
||||
class geper{
|
||||
public:
|
||||
geper(Builder** builder): builder_(builder) { }
|
||||
Value* operator()(Value *ptr, Value* off, const std::string& name = "");
|
||||
Value* operator()(Type* ty, Value*ptr, std::vector<Value*> vals, const std::string& name = "");
|
||||
|
||||
private:
|
||||
Builder** builder_;
|
||||
};
|
||||
|
||||
class generator: public ir::visitor, public analysis::layout_visitor {
|
||||
private:
|
||||
void init_idx(ir::value *x);
|
||||
Instruction* add_barrier();
|
||||
Value* shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx);
|
||||
void finalize_shared_layout(analysis::shared_layout*);
|
||||
void finalize_function(ir::function*);
|
||||
void finalize_phi_node(ir::phi_node*);
|
||||
|
||||
private:
|
||||
Type *cvt(ir::type *ty);
|
||||
llvm::Attribute cvt(ir::attribute attr);
|
||||
|
||||
public:
|
||||
generator(analysis::axes *a_axes,
|
||||
analysis::layouts *layouts,
|
||||
analysis::align *alignment,
|
||||
analysis::allocation *alloc,
|
||||
analysis::swizzle *swizzle,
|
||||
target *tgt,
|
||||
unsigned num_warps,
|
||||
bool force_nc_cache = false);
|
||||
|
||||
void visit_value(ir::value* v);
|
||||
void visit_phi_node(ir::phi_node*);
|
||||
void visit_binary_operator(ir::binary_operator*);
|
||||
void visit_getelementptr_inst(ir::getelementptr_inst*);
|
||||
void visit_icmp_inst(ir::icmp_inst*);
|
||||
void visit_fcmp_inst(ir::fcmp_inst*);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
Value* bf16_to_fp32(Value *in0);
|
||||
Value* fp32_to_bf16(Value *in0);
|
||||
|
||||
void visit_cast_inst(ir::cast_inst*);
|
||||
void visit_return_inst(ir::return_inst*);
|
||||
void visit_cond_branch_inst(ir::cond_branch_inst*);
|
||||
void visit_uncond_branch_inst(ir::uncond_branch_inst*);
|
||||
void visit_load_inst(ir::load_inst*);
|
||||
void visit_unmasked_load_inst(ir::unmasked_load_inst*);
|
||||
void visit_masked_load_inst(ir::masked_load_inst*);
|
||||
void visit_store_inst(ir::store_inst*);
|
||||
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
|
||||
void visit_masked_store_inst(ir::masked_store_inst*);
|
||||
void visit_reshape_inst(ir::reshape_inst*);
|
||||
void visit_splat_inst(ir::splat_inst*);
|
||||
void visit_broadcast_inst(ir::broadcast_inst*);
|
||||
void visit_downcast_inst(ir::downcast_inst*);
|
||||
void visit_exp_inst(ir::exp_inst*);
|
||||
void visit_cos_inst(ir::cos_inst*);
|
||||
void visit_sin_inst(ir::sin_inst*);
|
||||
void visit_log_inst(ir::log_inst*);
|
||||
void visit_get_program_id_inst(ir::get_program_id_inst*);
|
||||
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
|
||||
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
|
||||
void visit_atomic_exch_inst(ir::atomic_exch_inst*);
|
||||
void visit_atomic_rmw_inst(ir::atomic_rmw_inst*);
|
||||
void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
|
||||
void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
|
||||
void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
|
||||
void visit_dot_inst(ir::dot_inst*);
|
||||
void visit_trans_inst(ir::trans_inst*);
|
||||
void visit_sqrt_inst(ir::sqrt_inst*);
|
||||
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
|
||||
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
|
||||
void visit_reduce_inst(ir::reduce_inst*);
|
||||
void visit_select_inst(ir::select_inst*);
|
||||
void visit_recoalesce_inst(ir::recoalesce_inst*);
|
||||
void visit_masked_load_async_inst(ir::masked_load_async_inst*);
|
||||
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
|
||||
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
|
||||
void visit_barrier_inst(ir::barrier_inst*);
|
||||
void visit_prefetch_s_inst(ir::prefetch_s_inst*);
|
||||
void visit_async_wait_inst(ir::async_wait_inst*);
|
||||
// void visit_make_range_dyn(ir::make_range_dyn*);
|
||||
void visit_make_range(ir::make_range*);
|
||||
// void visit_make_range_sta(ir::make_range_sta*);
|
||||
void visit_undef_value(ir::undef_value*);
|
||||
void visit_constant_int(ir::constant_int*);
|
||||
void visit_constant_fp(ir::constant_fp*);
|
||||
void visit_alloc_const(ir::alloc_const*);
|
||||
void visit_function(ir::function*);
|
||||
void visit_basic_block(ir::basic_block*);
|
||||
void visit_argument(ir::argument*);
|
||||
void visit(ir::module &, llvm::Module &);
|
||||
|
||||
// layouts
|
||||
void visit_layout_mma(analysis::mma_layout*);
|
||||
void visit_layout_scanline(analysis::scanline_layout*);
|
||||
void visit_layout_shared(analysis::shared_layout*);
|
||||
|
||||
|
||||
private:
|
||||
LLVMContext *ctx_;
|
||||
Builder* builder_;
|
||||
Module *mod_;
|
||||
|
||||
analysis::axes *a_axes_;
|
||||
analysis::swizzle *swizzle_;
|
||||
std::map<unsigned, distributed_axis> axes_;
|
||||
target *tgt_;
|
||||
analysis::layouts *layouts_;
|
||||
analysis::align *alignment_;
|
||||
analysis::allocation *alloc_;
|
||||
Value *shmem_;
|
||||
std::set<ir::value*> seen_;
|
||||
|
||||
unsigned num_warps_;
|
||||
bool force_nc_cache_;
|
||||
|
||||
std::map<analysis::data_layout*, Value*> offset_a_m_;
|
||||
std::map<analysis::data_layout*, Value*> offset_a_k_;
|
||||
std::map<analysis::data_layout*, Value*> offset_b_k_;
|
||||
std::map<analysis::data_layout*, Value*> offset_b_n_;
|
||||
|
||||
/// layout -> base ptr
|
||||
std::map<analysis::data_layout*, Value*> shared_ptr_;
|
||||
std::map<analysis::data_layout*, Value*> shared_pre_ptr_;
|
||||
std::map<analysis::data_layout*, Value*> shared_next_ptr_;
|
||||
/// offset for double-buffered layout
|
||||
std::map<analysis::data_layout*, Value*> shared_off_;
|
||||
|
||||
/// Base shmem pointer of ir value
|
||||
std::map<ir::value*, Value*> shmems_;
|
||||
std::map<ir::value*, Value*> shoffs_;
|
||||
std::map<ir::value*, std::vector<indices_t>> idxs_;
|
||||
std::map<ir::value*, std::map<indices_t, Value*>> vals_;
|
||||
/// idx for multi-stage pipeline
|
||||
std::map<analysis::data_layout*, Value*> read_smem_idx_;
|
||||
std::map<analysis::data_layout*, Value*> write_smem_idx_;
|
||||
|
||||
/// triton bb -> llvm bb
|
||||
std::map<ir::value*, BasicBlock *> bbs_;
|
||||
std::map<ir::value*, std::vector<int>> ords_;
|
||||
|
||||
// helper for creating llvm values
|
||||
adder add;
|
||||
multiplier mul;
|
||||
geper gep;
|
||||
|
||||
/// PHI nodes
|
||||
std::vector<std::tuple<llvm::PHINode*, Value*, ir::basic_block*>> lazy_phi_incs_;
|
||||
|
||||
/// Record prefetch instrs that needs to be moved
|
||||
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,105 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_IR_CODEGEN_TARGET_H
|
||||
#define TDL_INCLUDE_IR_CODEGEN_TARGET_H
|
||||
|
||||
namespace llvm{
|
||||
class Type;
|
||||
class Value;
|
||||
class Instruction;
|
||||
class Constant;
|
||||
class LLVMContext;
|
||||
class Module;
|
||||
class ConstantFolder;
|
||||
class IRBuilderDefaultInserter;
|
||||
template <typename T, typename Inserter>
|
||||
class IRBuilder;
|
||||
class ArrayType;
|
||||
class Function;
|
||||
}
|
||||
|
||||
// typedefs
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
typedef llvm::IRBuilder<llvm::ConstantFolder,
|
||||
llvm::IRBuilderDefaultInserter> Builder;
|
||||
typedef llvm::LLVMContext LLVMContext;
|
||||
typedef llvm::Type Type;
|
||||
typedef llvm::Value Value;
|
||||
typedef llvm::Module Module;
|
||||
typedef llvm::Instruction Instruction;
|
||||
typedef llvm::Constant Constant;
|
||||
typedef llvm::ArrayType ArrayType;
|
||||
typedef llvm::Function Function;
|
||||
}
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
class nvidia_cu_target;
|
||||
|
||||
class target {
|
||||
public:
|
||||
target(bool is_gpu): is_gpu_(is_gpu){}
|
||||
virtual ~target() {}
|
||||
virtual void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn) = 0;
|
||||
virtual Instruction* add_barrier(Module *module, Builder& builder) = 0;
|
||||
virtual Instruction* add_memfence(Module *module, Builder& builder) = 0;
|
||||
virtual Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax) = 0;
|
||||
virtual Value* get_local_id(Module *module, Builder& builder, unsigned ax) = 0;
|
||||
virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
|
||||
virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
|
||||
virtual unsigned guaranteed_alignment() = 0;
|
||||
nvidia_cu_target* as_nvidia();
|
||||
bool is_gpu() const;
|
||||
|
||||
private:
|
||||
bool is_gpu_;
|
||||
};
|
||||
|
||||
class amd_cl_target: public target {
|
||||
public:
|
||||
amd_cl_target(): target(true){}
|
||||
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
|
||||
Instruction* add_barrier(Module *module, Builder& builder);
|
||||
Instruction* add_memfence(Module *module, Builder& builder);
|
||||
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
|
||||
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
|
||||
unsigned guaranteed_alignment() { return 16; }
|
||||
};
|
||||
|
||||
class nvidia_cu_target: public target {
|
||||
public:
|
||||
nvidia_cu_target(int sm): target(true), sm_(sm){}
|
||||
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
|
||||
Instruction* add_barrier(Module *module, Builder& builder);
|
||||
Instruction* add_memfence(Module *module, Builder& builder);
|
||||
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
|
||||
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
|
||||
int sm() { return sm_; }
|
||||
unsigned guaranteed_alignment() { return 16; }
|
||||
|
||||
private:
|
||||
int sm_;
|
||||
};
|
||||
|
||||
class cpu_target: public target {
|
||||
public:
|
||||
cpu_target(): target(false){}
|
||||
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
|
||||
Instruction* add_barrier(Module *module, Builder& builder);
|
||||
Instruction* add_memfence(Module *module, Builder& builder);
|
||||
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
|
||||
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
|
||||
unsigned guaranteed_alignment() { return 1; }
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,47 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
|
||||
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class value;
|
||||
class io_inst;
|
||||
class instruction;
|
||||
class builder;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
class align;
|
||||
class layouts;
|
||||
class cts;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class coalesce {
|
||||
private:
|
||||
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result);
|
||||
void extract_ld(ir::io_inst *i, std::map<int, std::vector<triton::ir::io_inst *> > &result);
|
||||
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
|
||||
|
||||
public:
|
||||
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::align* align_;
|
||||
analysis::layouts* layout_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,36 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
|
||||
#define TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class value;
|
||||
class phi_node;
|
||||
class instruction;
|
||||
class builder;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
class cts {
|
||||
private:
|
||||
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);
|
||||
|
||||
public:
|
||||
cts(bool use_async = false): use_async_(use_async) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
bool use_async_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,24 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
|
||||
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
|
||||
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
class dce {
|
||||
public:
|
||||
dce() {}
|
||||
void run(ir::module &mod);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,22 +0,0 @@
|
||||
#ifndef _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
|
||||
#define _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
|
||||
|
||||
|
||||
namespace triton {
|
||||
namespace ir {
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
class disassociate {
|
||||
public:
|
||||
void run(ir::module &mod);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,72 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_BARRIERS_H
|
||||
#define TDL_INCLUDE_CODEGEN_BARRIERS_H
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <list>
|
||||
#include <set>
|
||||
#include "triton/codegen/target.h"
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class basic_block;
|
||||
class instruction;
|
||||
class masked_load_async_inst;
|
||||
class value;
|
||||
class builder;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
|
||||
class allocation;
|
||||
class liveness;
|
||||
class layouts;
|
||||
class cts;
|
||||
class shared_layout;
|
||||
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class prefetch;
|
||||
|
||||
class membar {
|
||||
private:
|
||||
typedef std::pair<unsigned, unsigned> interval_t;
|
||||
typedef std::set<ir::value*> val_set_t;
|
||||
typedef std::vector<ir::value*> val_vec_t;
|
||||
|
||||
private:
|
||||
bool intersect(const val_set_t &X, const val_set_t &Y);
|
||||
bool check_safe_war(ir::instruction* i);
|
||||
int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
|
||||
bool intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout);
|
||||
val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
|
||||
void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
|
||||
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
|
||||
|
||||
public:
|
||||
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc,
|
||||
transform::prefetch *prefetch, target* tgt):
|
||||
liveness_(liveness), layouts_(layouts), alloc_(alloc), prefetch_(prefetch), tgt_(tgt) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::liveness *liveness_;
|
||||
analysis::layouts *layouts_;
|
||||
analysis::allocation *alloc_;
|
||||
transform::prefetch *prefetch_;
|
||||
|
||||
target* tgt_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,54 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
|
||||
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
|
||||
|
||||
#include "triton/codegen/target.h"
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class value;
|
||||
class instruction;
|
||||
class trans_inst;
|
||||
class builder;
|
||||
class constant_int;
|
||||
class dot_inst;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
class layouts;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class peephole {
|
||||
private:
|
||||
// bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder);
|
||||
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
|
||||
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
||||
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
||||
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_mult(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
|
||||
|
||||
private:
|
||||
|
||||
public:
|
||||
peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
target* tgt_;
|
||||
analysis::layouts* layouts_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,30 +0,0 @@
|
||||
#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
|
||||
#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
|
||||
|
||||
// forward declaration
|
||||
namespace triton {
|
||||
namespace ir {
|
||||
class module;
|
||||
}
|
||||
} // namespace triton
|
||||
|
||||
namespace triton {
|
||||
namespace codegen {
|
||||
namespace transform {
|
||||
|
||||
class pipeline {
|
||||
public:
|
||||
pipeline(bool has_copy_async, int num_stages)
|
||||
: has_copy_async_(has_copy_async), num_stages_(num_stages) {}
|
||||
void run(ir::module &module);
|
||||
|
||||
private:
|
||||
bool has_copy_async_;
|
||||
int num_stages_;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace codegen
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
@@ -1,27 +0,0 @@
|
||||
#ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
|
||||
#define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
|
||||
|
||||
#include <set>
|
||||
|
||||
// forward dclaration
|
||||
namespace triton::ir{
|
||||
class module;
|
||||
class value;
|
||||
}
|
||||
|
||||
namespace triton::codegen {
|
||||
class target;
|
||||
}
|
||||
|
||||
namespace triton::codegen::transform {
|
||||
class prefetch {
|
||||
target* tgt_;
|
||||
std::set<ir::value*> prefetched_vals_;
|
||||
public:
|
||||
prefetch(target *tgt) : tgt_(tgt) {}
|
||||
void run(ir::module &module);
|
||||
bool is_prefetched(ir::value* v) { return prefetched_vals_.find(v) != prefetched_vals_.end(); }
|
||||
};
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,26 +0,0 @@
|
||||
#ifndef TRITON_INCLUDE_IR_CODEGEN_REORDER_H
|
||||
#define TRITON_INCLUDE_IR_CODEGEN_REORDER_H
|
||||
|
||||
namespace triton {
|
||||
|
||||
// forward declaration
|
||||
namespace ir {
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace transform{
|
||||
|
||||
class reorder {
|
||||
public:
|
||||
void run(ir::module& module);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,137 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_DRIVER_BACKEND_H_
|
||||
#define _TRITON_DRIVER_BACKEND_H_
|
||||
|
||||
|
||||
#include <map>
|
||||
#include <list>
|
||||
#include <vector>
|
||||
#include "triton/driver/context.h"
|
||||
|
||||
namespace llvm
|
||||
{
|
||||
class Module;
|
||||
}
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class buffer;
|
||||
class stream;
|
||||
class device;
|
||||
class context;
|
||||
class platform;
|
||||
class module;
|
||||
class kernel;
|
||||
|
||||
struct backend
|
||||
{
|
||||
|
||||
// platforms
|
||||
class platforms
|
||||
{
|
||||
friend class backend;
|
||||
private:
|
||||
static void init();
|
||||
|
||||
public:
|
||||
static void get(std::vector<driver::platform*> &results);
|
||||
|
||||
private:
|
||||
static std::vector<driver::platform*> cache_;
|
||||
};
|
||||
|
||||
// devices
|
||||
class devices
|
||||
{
|
||||
friend class backend;
|
||||
|
||||
private:
|
||||
static void init(const std::vector<platform *> &platforms);
|
||||
|
||||
public:
|
||||
static void get(std::vector<driver::device*>& devs);
|
||||
|
||||
private:
|
||||
static std::vector<driver::device*> cache_;
|
||||
};
|
||||
|
||||
// modules
|
||||
class modules
|
||||
{
|
||||
friend class backend;
|
||||
|
||||
public:
|
||||
static void release();
|
||||
|
||||
private:
|
||||
static std::map<std::tuple<driver::stream*, std::string>, driver::module*> cache_;
|
||||
};
|
||||
|
||||
// kernels
|
||||
class kernels
|
||||
{
|
||||
friend class backend;
|
||||
public:
|
||||
static void release();
|
||||
static driver::kernel* get(driver::module* mod, const std::string & name);
|
||||
private:
|
||||
static std::map<std::tuple<module*, std::string>, driver::kernel*> cache_;
|
||||
};
|
||||
|
||||
// contexts
|
||||
class contexts
|
||||
{
|
||||
friend class backend;
|
||||
private:
|
||||
static void init(const std::vector<device *> &);
|
||||
static void release();
|
||||
public:
|
||||
static driver::context* get_default();
|
||||
|
||||
static driver::context* import(CUcontext ctx)
|
||||
{
|
||||
for(driver::context* x: cache_){
|
||||
driver::cu_context* cu_x = (driver::cu_context*)x;
|
||||
if(*cu_x->cu()==ctx)
|
||||
return x;
|
||||
}
|
||||
cache_.emplace_back(new driver::cu_context(ctx, false));
|
||||
return cache_.back();
|
||||
}
|
||||
|
||||
static void get(std::list<driver::context*> &);
|
||||
|
||||
private:
|
||||
static std::list<driver::context*> cache_;
|
||||
};
|
||||
|
||||
// streams
|
||||
class streams
|
||||
{
|
||||
friend class backend;
|
||||
private:
|
||||
static void init(std::list<context*> const &);
|
||||
static void release();
|
||||
public:
|
||||
static void get(driver::context*, std::vector<driver::stream *> &streams);
|
||||
static driver::stream* get(driver::context*, unsigned int id = 0);
|
||||
static driver::stream* get_default();
|
||||
private:
|
||||
static std::map<driver::context*, std::vector<driver::stream*> > cache_;
|
||||
};
|
||||
|
||||
static void init();
|
||||
static void release();
|
||||
static void synchronize(triton::driver::context *);
|
||||
|
||||
static unsigned int default_device;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,48 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_DRIVER_BUFFER_H_
|
||||
#define _TRITON_DRIVER_BUFFER_H_
|
||||
|
||||
#include "triton/driver/handle.h"
|
||||
#include "triton/driver/context.h"
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class stream;
|
||||
|
||||
// Base
|
||||
class buffer : public polymorphic_resource<CUdeviceptr, host_buffer_t> {
|
||||
public:
|
||||
buffer(size_t size, CUdeviceptr cl, bool take_ownership);
|
||||
buffer(size_t size, host_buffer_t hst, bool take_ownership);
|
||||
uintptr_t addr_as_uintptr_t();
|
||||
static buffer* create(driver::context* ctx, size_t size);
|
||||
size_t size();
|
||||
|
||||
protected:
|
||||
size_t size_;
|
||||
};
|
||||
|
||||
// CPU
|
||||
class host_buffer: public buffer
|
||||
{
|
||||
public:
|
||||
host_buffer(size_t size);
|
||||
};
|
||||
|
||||
// CUDA
|
||||
class cu_buffer: public buffer
|
||||
{
|
||||
public:
|
||||
cu_buffer(size_t size);
|
||||
cu_buffer(size_t size, CUdeviceptr cu, bool take_ownership);
|
||||
void set_zero(triton::driver::stream *queue, size_t size);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,50 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_DRIVER_CONTEXT_H_
|
||||
#define _TRITON_DRIVER_CONTEXT_H_
|
||||
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/handle.h"
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class context: public polymorphic_resource<CUcontext, host_context_t>{
|
||||
protected:
|
||||
static std::string get_cache_path();
|
||||
|
||||
public:
|
||||
context(driver::device *dev, CUcontext cu, bool take_ownership);
|
||||
context(driver::device *dev, host_context_t hst, bool take_ownership);
|
||||
driver::device* device() const;
|
||||
std::string const & cache_path() const;
|
||||
// factory methods
|
||||
static context* create(driver::device *dev);
|
||||
|
||||
protected:
|
||||
driver::device* dev_;
|
||||
std::string cache_path_;
|
||||
};
|
||||
|
||||
// Host
|
||||
class host_context: public context {
|
||||
public:
|
||||
host_context(driver::device* dev);
|
||||
};
|
||||
|
||||
// CUDA
|
||||
class cu_context: public context {
|
||||
private:
|
||||
static CUdevice get_device_of(CUcontext);
|
||||
public:
|
||||
//Constructors
|
||||
cu_context(CUcontext cu, bool take_ownership = true);
|
||||
cu_context(driver::device* dev);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,81 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_DRIVER_DEVICE_H_
|
||||
#define _TRITON_DRIVER_DEVICE_H_
|
||||
|
||||
#include "triton/driver/platform.h"
|
||||
#include "triton/driver/handle.h"
|
||||
|
||||
namespace triton
|
||||
{
|
||||
|
||||
namespace codegen
|
||||
{
|
||||
class target;
|
||||
}
|
||||
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class context;
|
||||
|
||||
// Base device
|
||||
class device: public polymorphic_resource<CUdevice, host_device_t>{
|
||||
public:
|
||||
using polymorphic_resource::polymorphic_resource;
|
||||
virtual size_t max_threads_per_block() const = 0;
|
||||
virtual size_t max_shared_memory() const = 0;
|
||||
virtual std::unique_ptr<codegen::target> make_target() const = 0;
|
||||
};
|
||||
|
||||
// Host device
|
||||
class host_device: public device {
|
||||
public:
|
||||
host_device(): device(host_device_t(), true){ }
|
||||
size_t max_threads_per_block() const { return 1; }
|
||||
size_t max_shared_memory() const { return 0; }
|
||||
std::unique_ptr<codegen::target> make_target() const;
|
||||
};
|
||||
|
||||
// CUDA device
|
||||
class cu_device: public device {
|
||||
private:
|
||||
//Metaprogramming elper to get cuda info from attribute
|
||||
template<CUdevice_attribute attr>
|
||||
int cuGetInfo() const;
|
||||
|
||||
inline nvmlDevice_t nvml_device() const;
|
||||
|
||||
public:
|
||||
cu_device(CUdevice cu = CUdevice(), bool take_ownership = true): device(cu, take_ownership){}
|
||||
// Informations
|
||||
std::string infos() const;
|
||||
size_t address_bits() const;
|
||||
std::vector<size_t> max_block_dim() const;
|
||||
size_t warp_size() const;
|
||||
// Compute Capability
|
||||
void interpret_as(int cc);
|
||||
int compute_capability() const;
|
||||
// Identifier
|
||||
std::string name() const;
|
||||
std::string pci_bus_id() const;
|
||||
// Clocks
|
||||
size_t current_sm_clock() const;
|
||||
size_t current_mem_clock() const;
|
||||
size_t max_threads_per_block() const;
|
||||
size_t max_shared_memory() const;
|
||||
size_t max_sm_clock() const;
|
||||
size_t max_mem_clock() const;
|
||||
void set_max_clock();
|
||||
// Target
|
||||
std::unique_ptr<codegen::target> make_target() const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<int> interpreted_as_;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,197 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_DRIVER_DISPATCH_H_
|
||||
#define _TRITON_DRIVER_DISPATCH_H_
|
||||
|
||||
#include <type_traits>
|
||||
#include <dlfcn.h>
|
||||
|
||||
//CUDA Backend
|
||||
#include "triton/external/CUDA/cuda.h"
|
||||
#include "triton/external/CUDA/nvml.h"
|
||||
|
||||
//Exceptions
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace llvm {
|
||||
class PassRegistry;
|
||||
class Module;
|
||||
}
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class cu_context;
|
||||
|
||||
template<class T> void check(T){}
|
||||
void check(CUresult err);
|
||||
|
||||
class dispatch
|
||||
{
|
||||
protected:
|
||||
template <class F>
|
||||
struct return_type;
|
||||
|
||||
template <class R, class... A>
|
||||
struct return_type<R (*)(A...)>
|
||||
{ typedef R type; };
|
||||
|
||||
typedef bool (*f_init_t)();
|
||||
|
||||
template<f_init_t initializer, typename FunPtrT, typename... Args>
|
||||
static typename return_type<FunPtrT>::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args)
|
||||
{
|
||||
initializer();
|
||||
if(cache == nullptr){
|
||||
cache = dlsym(lib_h, name);
|
||||
if(cache == 0)
|
||||
throw std::runtime_error("dlsym unable to load function");
|
||||
}
|
||||
FunPtrT fptr;
|
||||
*reinterpret_cast<void **>(&fptr) = cache;
|
||||
typename return_type<FunPtrT>::type res = (*fptr)(args...);
|
||||
check(res);
|
||||
return res;
|
||||
}
|
||||
|
||||
public:
|
||||
static bool nvmlinit();
|
||||
static bool cuinit();
|
||||
static bool spvllvminit();
|
||||
static void release();
|
||||
|
||||
// CUDA
|
||||
static CUresult cuCtxGetCurrent(CUcontext *pctx);
|
||||
static CUresult cuCtxSetCurrent(CUcontext ctx);
|
||||
static CUresult cuCtxDestroy_v2(CUcontext ctx);
|
||||
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
|
||||
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
|
||||
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
|
||||
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
|
||||
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
|
||||
static CUresult cuMemFree_v2(CUdeviceptr dptr);
|
||||
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
|
||||
static CUresult cuDriverGetVersion(int *driverVersion);
|
||||
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
|
||||
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
|
||||
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
|
||||
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
|
||||
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
|
||||
static CUresult cuModuleLoadData(CUmodule* module, const void* image);
|
||||
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
|
||||
static CUresult cuModuleUnload(CUmodule hmod);
|
||||
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
|
||||
|
||||
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
|
||||
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
|
||||
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
|
||||
static CUresult cuLinkDestroy(CUlinkState state);
|
||||
|
||||
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
|
||||
static CUresult cuDeviceGetCount(int *count);
|
||||
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
|
||||
static CUresult cuInit(unsigned int Flags);
|
||||
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
|
||||
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
|
||||
static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
|
||||
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
|
||||
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
|
||||
static CUresult cuStreamSynchronize(CUstream hStream);
|
||||
static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
|
||||
static CUresult cuStreamDestroy_v2(CUstream hStream);
|
||||
static CUresult cuEventDestroy_v2(CUevent hEvent);
|
||||
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
|
||||
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
|
||||
static CUresult cuCtxGetDevice(CUdevice* result);
|
||||
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
|
||||
static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
|
||||
static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
|
||||
static CUresult cuFuncSetCacheConfig (CUfunction hfunc, CUfunc_cache config);
|
||||
// NVML
|
||||
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
|
||||
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
|
||||
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
|
||||
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
|
||||
|
||||
|
||||
// SPIR-V libraries
|
||||
static int initializeLLVMToSPIRVPass(llvm::PassRegistry &);
|
||||
static bool writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg);
|
||||
|
||||
|
||||
private:
|
||||
|
||||
// Libraries
|
||||
static void* cuda_;
|
||||
static void* nvml_;
|
||||
static void* vulkan_;
|
||||
static void* spvllvm_;
|
||||
static void* spvcross_;
|
||||
static void* opengl_;
|
||||
|
||||
|
||||
// CUDA functions
|
||||
static void* cuCtxGetCurrent_;
|
||||
static void* cuCtxSetCurrent_;
|
||||
static void* cuCtxDestroy_v2_;
|
||||
static void* cuEventCreate_;
|
||||
static void* cuDeviceGet_;
|
||||
static void* cuMemcpyDtoH_v2_;
|
||||
static void* cuStreamCreate_;
|
||||
static void* cuEventElapsedTime_;
|
||||
static void* cuMemFree_v2_;
|
||||
static void* cuMemcpyDtoHAsync_v2_;
|
||||
static void* cuDriverGetVersion_;
|
||||
static void* cuDeviceGetName_;
|
||||
static void* cuDeviceGetPCIBusId_;
|
||||
static void* cuModuleGetGlobal_v2_;
|
||||
static void* cuMemcpyHtoDAsync_v2_;
|
||||
static void* cuModuleLoad_;
|
||||
static void* cuLaunchKernel_;
|
||||
static void* cuModuleUnload_;
|
||||
static void* cuModuleLoadDataEx_;
|
||||
static void* cuLinkAddData_v2_;
|
||||
static void* cuLinkCreate_v2_;
|
||||
static void* cuLinkDestroy_;
|
||||
static void* cuModuleLoadData_;
|
||||
static void* cuLinkComplete_;
|
||||
static void* cuDeviceGetAttribute_;
|
||||
static void* cuDeviceGetCount_;
|
||||
static void* cuMemcpyHtoD_v2_;
|
||||
static void* cuInit_;
|
||||
static void* cuEventRecord_;
|
||||
static void* cuCtxCreate_v2_;
|
||||
static void* cuModuleGetFunction_;
|
||||
static void* cuStreamSynchronize_;
|
||||
static void* cuStreamDestroy_v2_;
|
||||
static void* cuStreamGetCtx_;
|
||||
static void* cuEventDestroy_v2_;
|
||||
static void* cuMemAlloc_v2_;
|
||||
static void* cuPointerGetAttribute_;
|
||||
static void* cuCtxGetDevice_;
|
||||
static void* cuMemsetD8Async_;
|
||||
static void* cuCtxPushCurrent_v2_;
|
||||
static void* cuCtxPopCurrent_v2_;
|
||||
static void* cuFuncGetAttribute_;
|
||||
static void* cuFuncSetAttribute_;
|
||||
static void* cuFuncSetCacheConfig_;
|
||||
// NVML
|
||||
static void* nvmlInit_v2_;
|
||||
static void* nvmlDeviceGetHandleByPciBusId_v2_;
|
||||
static void* nvmlDeviceGetClockInfo_;
|
||||
static void* nvmlDeviceGetMaxClockInfo_;
|
||||
static void* nvmlDeviceSetApplicationsClocks_;
|
||||
|
||||
// LLVM to SPIR-V
|
||||
static void* initializeLLVMToSPIRVPass_;
|
||||
static void* writeSpirv_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif
|
@@ -1,148 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_DRIVER_ERROR_H_
|
||||
#define _TRITON_DRIVER_ERROR_H_
|
||||
|
||||
#include <exception>
|
||||
#include "triton/driver/dispatch.h"
|
||||
|
||||
|
||||
namespace triton
|
||||
{
|
||||
|
||||
namespace driver
|
||||
{
|
||||
|
||||
namespace exception
|
||||
{
|
||||
|
||||
namespace nvrtc
|
||||
{
|
||||
|
||||
#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "NVRTC: Error- " msg; } }
|
||||
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(compilation ,"compilation");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error");
|
||||
|
||||
#undef TRITON_CREATE_NVRTC_EXCEPTION
|
||||
}
|
||||
|
||||
|
||||
namespace cuda
|
||||
{
|
||||
class base: public std::exception{};
|
||||
|
||||
#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "CUDA: Error- " msg; } }
|
||||
|
||||
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(no_device ,"no device");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(map_failed ,"map failed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(operating_system ,"operating system");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_found ,"not found");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_ready ,"not ready");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(assert_error ,"assert");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_supported ,"not supported");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(unknown ,"unknown");
|
||||
|
||||
#undef TRITON_CREATE_CUDA_EXCEPTION
|
||||
}
|
||||
|
||||
namespace cublas
|
||||
{
|
||||
class base: public std::exception{};
|
||||
|
||||
#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "CUBLAS: Error- " msg; } }
|
||||
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(license_error ,"license error");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown");
|
||||
|
||||
#undef TRITON_CREATE_CUBLAS_EXCEPTION
|
||||
}
|
||||
|
||||
namespace cudnn
|
||||
{
|
||||
#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "CUDNN: Error- " msg; } }
|
||||
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(license_error ,"license error");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,146 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_DRIVER_HANDLE_H_
|
||||
#define _TRITON_DRIVER_HANDLE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
#include <type_traits>
|
||||
#include "triton/driver/dispatch.h"
|
||||
#include "llvm/ExecutionEngine/JITSymbol.h"
|
||||
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
|
||||
#include "llvm/ExecutionEngine/Orc/Core.h"
|
||||
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
|
||||
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
|
||||
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
|
||||
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
|
||||
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
|
||||
#include "triton/tools/thread_pool.h"
|
||||
|
||||
namespace llvm
|
||||
{
|
||||
class ExecutionEngine;
|
||||
class Function;
|
||||
}
|
||||
|
||||
namespace triton
|
||||
{
|
||||
|
||||
namespace driver
|
||||
{
|
||||
|
||||
enum backend_t {
|
||||
CUDA,
|
||||
Host
|
||||
};
|
||||
|
||||
// Host handles
|
||||
struct host_platform_t{
|
||||
|
||||
};
|
||||
|
||||
struct host_device_t{
|
||||
|
||||
};
|
||||
|
||||
struct host_context_t{
|
||||
|
||||
};
|
||||
|
||||
struct host_stream_t{
|
||||
std::shared_ptr<ThreadPool> pool;
|
||||
std::shared_ptr<std::vector<std::future<void>>> futures;
|
||||
std::vector<std::shared_ptr<char*>> args;
|
||||
};
|
||||
|
||||
struct host_module_t{
|
||||
std::string error;
|
||||
llvm::ExecutionEngine* engine;
|
||||
std::map<std::string, llvm::Function*> functions;
|
||||
void(*fn)(char**, int32_t, int32_t, int32_t);
|
||||
llvm::orc::ExecutionSession* ES;
|
||||
llvm::orc::RTDyldObjectLinkingLayer* ObjectLayer;
|
||||
llvm::orc::IRCompileLayer* CompileLayer;
|
||||
llvm::DataLayout* DL;
|
||||
llvm::orc::MangleAndInterner* Mangle;
|
||||
llvm::orc::ThreadSafeContext* Ctx;
|
||||
llvm::orc::JITDylib *MainJD;
|
||||
};
|
||||
|
||||
struct host_function_t{
|
||||
llvm::Function* fn;
|
||||
};
|
||||
|
||||
struct host_buffer_t{
|
||||
char* data;
|
||||
};
|
||||
|
||||
|
||||
// Extra CUDA handles
|
||||
struct cu_event_t{
|
||||
operator bool() const { return first && second; }
|
||||
CUevent first;
|
||||
CUevent second;
|
||||
};
|
||||
|
||||
struct CUPlatform{
|
||||
CUPlatform() : status_(dispatch::cuInit(0)) { }
|
||||
operator bool() const { return status_; }
|
||||
private:
|
||||
CUresult status_;
|
||||
};
|
||||
|
||||
template<class T, class CUType>
|
||||
class handle_interface{
|
||||
public:
|
||||
//Accessors
|
||||
operator CUType() const { return *(((T*)this)->cu().h_); }
|
||||
//Comparison
|
||||
bool operator==(handle_interface const & y) { return (CUType)(*this) == (CUType)(y); }
|
||||
bool operator!=(handle_interface const & y) { return (CUType)(*this) != (CUType)(y); }
|
||||
bool operator<(handle_interface const & y) { return (CUType)(*this) < (CUType)(y); }
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class handle{
|
||||
public:
|
||||
template<class, class> friend class handle_interface;
|
||||
public:
|
||||
//Constructors
|
||||
handle(T h, bool take_ownership = true);
|
||||
handle();
|
||||
~handle();
|
||||
T& operator*() { return *h_; }
|
||||
T const & operator*() const { return *h_; }
|
||||
T* operator->() const { return h_.get(); }
|
||||
|
||||
protected:
|
||||
std::shared_ptr<T> h_;
|
||||
bool has_ownership_;
|
||||
};
|
||||
|
||||
template<class CUType, class HostType>
|
||||
class polymorphic_resource {
|
||||
public:
|
||||
polymorphic_resource(CUType cu, bool take_ownership): cu_(cu, take_ownership), backend_(CUDA){}
|
||||
polymorphic_resource(HostType hst, bool take_ownership): hst_(hst, take_ownership), backend_(Host){}
|
||||
virtual ~polymorphic_resource() { }
|
||||
|
||||
handle<CUType> cu() { return cu_; }
|
||||
handle<HostType> hst() { return hst_; }
|
||||
const handle<CUType>& cu() const { return cu_; }
|
||||
const handle<HostType>& hst() const { return hst_; }
|
||||
backend_t backend() { return backend_; }
|
||||
|
||||
protected:
|
||||
handle<CUType> cu_;
|
||||
handle<HostType> hst_;
|
||||
backend_t backend_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,53 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_DRIVER_KERNEL_H_
|
||||
#define _TRITON_DRIVER_KERNEL_H_
|
||||
|
||||
#include "triton/driver/module.h"
|
||||
#include "triton/driver/handle.h"
|
||||
#include <memory>
|
||||
|
||||
namespace llvm
|
||||
{
|
||||
class GenericValue;
|
||||
}
|
||||
|
||||
namespace triton
|
||||
{
|
||||
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class cu_buffer;
|
||||
|
||||
// Base
|
||||
class kernel: public polymorphic_resource<CUfunction, host_function_t> {
|
||||
public:
|
||||
kernel(driver::module* program, CUfunction fn, bool has_ownership);
|
||||
kernel(driver::module* program, host_function_t fn, bool has_ownership);
|
||||
driver::module* module();
|
||||
static kernel* create(driver::module* program, const char* name);
|
||||
private:
|
||||
driver::module* program_;
|
||||
};
|
||||
|
||||
// Host
|
||||
class host_kernel: public kernel {
|
||||
public:
|
||||
//Constructors
|
||||
host_kernel(driver::module* program, const char* name);
|
||||
};
|
||||
|
||||
// CUDA
|
||||
class cu_kernel: public kernel {
|
||||
public:
|
||||
//Constructors
|
||||
cu_kernel(driver::module* program, const char * name);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
@@ -1,84 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_DRIVER_MODULE_H_
|
||||
#define _TRITON_DRIVER_MODULE_H_
|
||||
|
||||
#include <map>
|
||||
#include "triton/driver/handle.h"
|
||||
#include "triton/driver/context.h"
|
||||
#include "triton/driver/buffer.h"
|
||||
|
||||
namespace llvm
|
||||
{
|
||||
class Module;
|
||||
template<class T>
|
||||
class SmallVectorImpl;
|
||||
}
|
||||
|
||||
namespace triton
|
||||
{
|
||||
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class cu_context;
|
||||
class cu_device;
|
||||
|
||||
// Base
|
||||
class module: public polymorphic_resource<CUmodule, host_module_t> {
|
||||
protected:
|
||||
void init_llvm();
|
||||
|
||||
enum file_type_t{
|
||||
Object,
|
||||
Assembly
|
||||
};
|
||||
|
||||
public:
|
||||
module(CUmodule mod, bool has_ownership);
|
||||
module(host_module_t mod, bool has_ownership);
|
||||
static module* create(driver::device* device, std::unique_ptr<llvm::Module> src);
|
||||
void compile_llvm_module(std::unique_ptr<llvm::Module> module, const std::string& triple,
|
||||
const std::string &proc, std::string layout,
|
||||
llvm::SmallVectorImpl<char> &buffer,
|
||||
const std::string &features,
|
||||
file_type_t file_type);
|
||||
virtual std::unique_ptr<buffer> symbol(const char * name) const = 0;
|
||||
int spilled() const { return spilled_; }
|
||||
|
||||
protected:
|
||||
int spilled_;
|
||||
};
|
||||
|
||||
// CPU
|
||||
class host_module: public module{
|
||||
public:
|
||||
host_module(std::unique_ptr<llvm::Module> module);
|
||||
std::unique_ptr<buffer> symbol(const char * name) const;
|
||||
};
|
||||
|
||||
// CUDA
|
||||
class cu_module: public module {
|
||||
std::string compile_llvm_module(llvm::Module* module, driver::device* device);
|
||||
void init_from_ptx(const std::string& ptx, cu_device *device);
|
||||
|
||||
public:
|
||||
cu_module(driver::device* device, std::unique_ptr<llvm::Module> module);
|
||||
cu_module(driver::device* device, const std::string& source);
|
||||
std::unique_ptr<buffer> symbol(const char * name) const;
|
||||
std::string llir() const { return llir_; }
|
||||
const std::string& ptx() const { return ptx_; }
|
||||
const std::string& cubin() const { return cubin_; }
|
||||
|
||||
private:
|
||||
std::string ptx_;
|
||||
std::string cubin_;
|
||||
std::string llir_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,58 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_DRIVER_PLATFORM_H_
|
||||
#define _TRITON_DRIVER_PLATFORM_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "triton/driver/handle.h"
|
||||
|
||||
namespace triton
|
||||
{
|
||||
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class device;
|
||||
|
||||
class platform
|
||||
{
|
||||
public:
|
||||
// Constructor
|
||||
platform(const std::string& name): name_(name){ }
|
||||
// Accessors
|
||||
std::string name() const { return name_; }
|
||||
// Virtual methods
|
||||
virtual std::string version() const = 0;
|
||||
virtual void devices(std::vector<driver::device *> &devices) const = 0;
|
||||
private:
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
// CUDA
|
||||
class cu_platform: public platform
|
||||
{
|
||||
public:
|
||||
cu_platform(): platform("CUDA") { }
|
||||
std::string version() const;
|
||||
void devices(std::vector<driver::device*> &devices) const;
|
||||
|
||||
private:
|
||||
handle<CUPlatform> cu_;
|
||||
};
|
||||
|
||||
// Host
|
||||
class host_platform: public platform
|
||||
{
|
||||
public:
|
||||
host_platform(): platform("CPU") { }
|
||||
std::string version() const;
|
||||
void devices(std::vector<driver::device*> &devices) const;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,68 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_DRIVER_STREAM_H_
|
||||
#define _TRITON_DRIVER_STREAM_H_
|
||||
|
||||
#include <map>
|
||||
#include "triton/driver/context.h"
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/handle.h"
|
||||
#include "triton/driver/buffer.h"
|
||||
|
||||
namespace triton
|
||||
{
|
||||
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class kernel;
|
||||
class event;
|
||||
class Range;
|
||||
class cu_buffer;
|
||||
|
||||
// Base
|
||||
class stream: public polymorphic_resource<CUstream, host_stream_t> {
|
||||
public:
|
||||
stream(CUstream, bool has_ownership);
|
||||
stream(host_stream_t, bool has_ownership);
|
||||
// factory
|
||||
static driver::stream* create(backend_t backend);
|
||||
// methods
|
||||
virtual void synchronize() = 0;
|
||||
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem = 0) = 0;
|
||||
virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
|
||||
virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
|
||||
// template helpers
|
||||
template<class T> void write(driver::buffer* buf, bool blocking, std::size_t offset, std::vector<T> const & x)
|
||||
{ write(buf, blocking, offset, x.size()*sizeof(T), x.data()); }
|
||||
template<class T> void read(driver::buffer* buf, bool blocking, std::size_t offset, std::vector<T>& x)
|
||||
{ read(buf, blocking, offset, x.size()*sizeof(T), x.data()); }
|
||||
};
|
||||
|
||||
// Host
|
||||
class host_stream: public stream {
|
||||
public:
|
||||
host_stream();
|
||||
void synchronize();
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
|
||||
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
};
|
||||
|
||||
// CUDA
|
||||
class cu_stream: public stream {
|
||||
public:
|
||||
cu_stream(CUstream str, bool take_ownership);
|
||||
cu_stream();
|
||||
void synchronize();
|
||||
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
|
||||
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
|
||||
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
1468
include/triton/external/CL/cl.h
vendored
1468
include/triton/external/CL/cl.h
vendored
File diff suppressed because it is too large
Load Diff
12947
include/triton/external/CL/cl.hpp
vendored
12947
include/triton/external/CL/cl.hpp
vendored
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user