Compare commits
325 Commits
master
...
keren/inse
Author | SHA1 | Date | |
---|---|---|---|
|
43408fef5a | ||
|
e817fdf1b9 | ||
|
8dd099beef | ||
|
f20f48a255 | ||
|
3eff110fbc | ||
|
5f85b79718 | ||
|
bab7338965 | ||
|
74f3d7a80f | ||
|
115cd3ac47 | ||
|
532e10cf87 | ||
|
16e973edf2 | ||
|
b539e031e8 | ||
|
46fa29496c | ||
|
9490252261 | ||
|
e419781978 | ||
|
189491727a | ||
|
e0072d210a | ||
|
2fa17588f7 | ||
|
e057c65cf0 | ||
|
99c7e0e008 | ||
|
f2fcaeabf3 | ||
|
8edfe813a5 | ||
|
4d64589b22 | ||
|
521ff9ad74 | ||
|
c280ebda1b | ||
|
9def1bcebf | ||
|
7d90a07d0b | ||
|
6461254fb5 | ||
|
4e6a8209ed | ||
|
9bb54402b3 | ||
|
66c36c4378 | ||
|
661be523c0 | ||
|
c87fbf886e | ||
|
0c1d4d764e | ||
|
9d31998a9d | ||
|
04ec5deb41 | ||
|
630dc315ee | ||
|
35c9ec1103 | ||
|
f63be0e9b5 | ||
|
153aecb339 | ||
|
f98aed1258 | ||
|
ace7d28736 | ||
|
b688f7b7b8 | ||
|
8925c2cd11 | ||
|
2e33352419 | ||
|
037f9efa95 | ||
|
07786dc932 | ||
|
2afebcd79b | ||
|
136668bac3 | ||
|
04b852e031 | ||
|
85cccfb81f | ||
|
23f71daa27 | ||
|
4d64ffb5fe | ||
|
6c5f646f4e | ||
|
e8994209f4 | ||
|
8a5647782d | ||
|
afaf59b0c9 | ||
|
dab4855bdf | ||
|
9ea6135eb5 | ||
|
5eee738df7 | ||
|
37f5846280 | ||
|
a22ff39017 | ||
|
4c4159c6fa | ||
|
c28cfd821b | ||
|
1eedaf7bec | ||
|
516a241234 | ||
|
f40c63fb03 | ||
|
2aa538ec2e | ||
|
57fd1864a7 | ||
|
4946167241 | ||
|
8832e32683 | ||
|
4640023d9b | ||
|
0c87360657 | ||
|
de5b84c476 | ||
|
e517b58d59 | ||
|
2da71b2aaa | ||
|
080b4addf8 | ||
|
303790da88 | ||
|
137344946f | ||
|
976cf12af1 | ||
|
b6f15e214b | ||
|
84ad215268 | ||
|
fdd59900f7 | ||
|
a4ff0c362c | ||
|
b6dbe959f0 | ||
|
4218e68d74 | ||
|
61f2ff98df | ||
|
91a9773b38 | ||
|
847a318a03 | ||
|
5feb6e24f9 | ||
|
12d60cb4a3 | ||
|
c9d84237e8 | ||
|
cdc0ec5077 | ||
|
031c2ae77b | ||
|
cb1b87a688 | ||
|
e61dc75942 | ||
|
71428194a1 | ||
|
7dfab26a39 | ||
|
82834d34f9 | ||
|
f2106d0aa2 | ||
|
ac0f6793cc | ||
|
3685194456 | ||
|
3b80801dff | ||
|
42db3538e4 | ||
|
3e6cc6d66c | ||
|
bb7008651a | ||
|
4dc2396ca0 | ||
|
a2cbe7af91 | ||
|
fcb228d1d4 | ||
|
877844de4f | ||
|
3aa8296b06 | ||
|
1bf59d315c | ||
|
bb0f9235d1 | ||
|
c4726333bf | ||
|
dc0588a898 | ||
|
0d22d2bc03 | ||
|
4464646efb | ||
|
38a80664b5 | ||
|
e948a618b3 | ||
|
5898352f97 | ||
|
963d031247 | ||
|
1baa4e125f | ||
|
623c99609f | ||
|
b6e5a231e5 | ||
|
555f94f9b9 | ||
|
ccc5ab6ac9 | ||
|
89f6e1db5e | ||
|
863578a7fa | ||
|
448d14a598 | ||
|
1d772cd843 | ||
|
498c685b46 | ||
|
e843257295 | ||
|
289ff293cc | ||
|
f9d7f2f126 | ||
|
baba98ad69 | ||
|
9ddf0921fb | ||
|
df8d276089 | ||
|
3a84278530 | ||
|
61b61755e5 | ||
|
1e91ed30d0 | ||
|
8bb09f83ee | ||
|
22ec22c257 | ||
|
ecd1bc33df | ||
|
c56f0198dd | ||
|
922155f1d2 | ||
|
23f424c660 | ||
|
940ef3f0ac | ||
|
15bfd0cb79 | ||
|
13669b46a6 | ||
|
e9e1a4e682 | ||
|
80e3fb5270 | ||
|
43be75ad42 | ||
|
2e08450c80 | ||
|
297d27e1c8 | ||
|
c14dff2190 | ||
|
16aed94ff5 | ||
|
9bd5a3dcd2 | ||
|
2a852044d9 | ||
|
a9464f4993 | ||
|
35e346bcff | ||
|
a0bab9748e | ||
|
ea175f689e | ||
|
d0b4c67b05 | ||
|
3c635449e5 | ||
|
328b87aec6 | ||
|
d01353de07 | ||
|
02ebf24d35 | ||
|
83287d7193 | ||
|
bedbf221c0 | ||
|
84aa7d025a | ||
|
1b513c9866 | ||
|
0ebef11c77 | ||
|
de2dd04c8a | ||
|
92ef552a54 | ||
|
10ba51c3bb | ||
|
9aa00249a6 | ||
|
192be76b3c | ||
|
e0bedeb44c | ||
|
8776ad1a0e | ||
|
d69ce77b19 | ||
|
fc58250a06 | ||
|
b1673caaf6 | ||
|
95bbac41e7 | ||
|
993ba7035a | ||
|
e5ec8e16ea | ||
|
d5856435d7 | ||
|
2ba9a83465 | ||
|
3a48ca0d4d | ||
|
83ef74f248 | ||
|
920723cf3d | ||
|
490d34e0d5 | ||
|
78ebbe24c7 | ||
|
a7b49b3227 | ||
|
b988bae813 | ||
|
3236642e8f | ||
|
d1593e6ca8 | ||
|
e02c82c765 | ||
|
432c3df265 | ||
|
6d62d88d4f | ||
|
25357083e6 | ||
|
3265e0df5a | ||
|
96cc6fb563 | ||
|
27c9f3d8cb | ||
|
7eda373a12 | ||
|
a633d2b403 | ||
|
df940aaab0 | ||
|
63e6a85901 | ||
|
65237f6117 | ||
|
9d1b5e3f79 | ||
|
53cf93ce6a | ||
|
64d0b87ef0 | ||
|
9feb256b71 | ||
|
35736aa44e | ||
|
22c65a53d9 | ||
|
0ee6e486f8 | ||
|
117a402c1b | ||
|
49d1821149 | ||
|
26fcc12afd | ||
|
7b09b5f9e9 | ||
|
560e29229b | ||
|
0e11435448 | ||
|
366dddc3bc | ||
|
7807f64ef3 | ||
|
bbf75b492f | ||
|
a4a2c72173 | ||
|
d5eca56cf3 | ||
|
55cf9a0a97 | ||
|
830fe19d58 | ||
|
935390dc03 | ||
|
e36a54eb86 | ||
|
41d338d848 | ||
|
c529b462f5 | ||
|
71d1c10e19 | ||
|
9308e9c90c | ||
|
441fd7c3cc | ||
|
e6f89a5777 | ||
|
9b670cfb9f | ||
|
a2c9f919a8 | ||
|
36c45ec687 | ||
|
39b1235082 | ||
|
79298d61bc | ||
|
c3c4ac3733 | ||
|
e3916c3a46 | ||
|
0e68e6eb59 | ||
|
7027af9666 | ||
|
7e0e7ec365 | ||
|
978463ba39 | ||
|
d23d7b244c | ||
|
1a4fbed25b | ||
|
96876a46d1 | ||
|
0c5319eed9 | ||
|
26c59e4718 | ||
|
a96fe07e1c | ||
|
2d281cbc0a | ||
|
b9279d2e3b | ||
|
3ad7bee35e | ||
|
5f08e2fdae | ||
|
75d32e2442 | ||
|
1428185c9c | ||
|
4ece9fd1f3 | ||
|
d9017f8593 | ||
|
2c6a213131 | ||
|
2239ac1998 | ||
|
012e8c5b2b | ||
|
513bcaee50 | ||
|
29859605ee | ||
|
38d13ae618 | ||
|
edca91bf8f | ||
|
8dfe78f6cf | ||
|
c70f6b666e | ||
|
74585fb970 | ||
|
81001d318c | ||
|
62a64ff29b | ||
|
9e304cf79d | ||
|
1c52bd587d | ||
|
44d75cf9bb | ||
|
9be2d655a3 | ||
|
f51e0b1be4 | ||
|
7e0fd97965 | ||
|
4eb062f313 | ||
|
fcbbb3c10e | ||
|
19f81b7dea | ||
|
9c7b3d5173 | ||
|
aa6e086881 | ||
|
f1cc67bbc3 | ||
|
28e96bbfd1 | ||
|
a3d0812d27 | ||
|
62f7609612 | ||
|
13aead4808 | ||
|
6002340456 | ||
|
0864b253bb | ||
|
62f772123c | ||
|
040a2b6c75 | ||
|
6b4da6f016 | ||
|
16d44e5c4c | ||
|
9cf4107990 | ||
|
39fad2b18a | ||
|
d7fbddc7d4 | ||
|
c7ad928e60 | ||
|
76d9249724 | ||
|
0f96da336a | ||
|
9df899b291 | ||
|
c71c50cd0c | ||
|
61413b8a97 | ||
|
9dafa0e2e3 | ||
|
bde103fab0 | ||
|
4ad432f1fc | ||
|
2041b67fbf | ||
|
e381dc72c5 | ||
|
e95d98a886 | ||
|
38e67b4293 | ||
|
0d139ec460 | ||
|
c53f3486e4 | ||
|
ba16116f96 | ||
|
fed9925bbd | ||
|
a17fba86b1 | ||
|
5e117966d0 | ||
|
d5612333c0 | ||
|
07881b4d41 | ||
|
cf7fc8d642 | ||
|
78c3480c85 | ||
|
14a71dcb6f | ||
|
f2ab318614 | ||
|
419bbe0f6e | ||
|
a2c31ff434 |
1
.clang-format
Normal file
1
.clang-format
Normal file
@@ -0,0 +1 @@
|
||||
BasedOnStyle: LLVM
|
57
.github/CODEOWNERS
vendored
Normal file
57
.github/CODEOWNERS
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
# These owners will be the default owners for everything in
|
||||
# the repo. Unless a later match takes precedence,
|
||||
# @global-owner1 and @global-owner2 will be requested for
|
||||
# review when someone opens a pull request.
|
||||
* @ptillet
|
||||
|
||||
# --------
|
||||
# Analyses
|
||||
# --------
|
||||
# Alias analysis
|
||||
include/triton/Analysis/Alias.h @Jokeren
|
||||
lib/Analysis/Alias.cpp @Jokeren
|
||||
# Allocation analysis
|
||||
include/triton/Analysis/Allocation.h @Jokeren
|
||||
lib/Analysis/Allocation.cpp @Jokeren
|
||||
# Membar analysis
|
||||
include/triton/Analysis/Membar.h @Jokeren
|
||||
lib/Analysis/Membar.cpp @Jokeren
|
||||
# AxisInfo analysis
|
||||
include/triton/Analysis/AxisInfo.h @ptillet
|
||||
lib/Analysis/AxisInfo.cpp @ptillet
|
||||
# Utilities
|
||||
include/triton/Analysis/Utility.h @Jokeren
|
||||
lib/Analysis/Utility.cpp @Jokeren
|
||||
|
||||
# ----------
|
||||
# Dialects
|
||||
# ----------
|
||||
# Pipeline pass
|
||||
lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @daadaada
|
||||
# Prefetch pass
|
||||
lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @daadaada
|
||||
# Coalesce pass
|
||||
lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @ptillet
|
||||
# Layout simplification pass
|
||||
lib/Dialect/TritonGPU/Transforms/Combine.cpp @ptillet
|
||||
|
||||
# -----------
|
||||
# Conversions
|
||||
# -----------
|
||||
# TritonGPUToLLVM
|
||||
include/triton/Conversion/TritonGPUToLLVM/ @goostavz @Superjomn
|
||||
lib/Conversions/TritonGPUToLLVM @goostavz @Superjomn
|
||||
# TritonToTritonGPU
|
||||
include/triton/Conversion/TritonToTritonGPU/ @daadaada
|
||||
lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @daadaada
|
||||
|
||||
|
||||
# -------
|
||||
# Targets
|
||||
# -------
|
||||
# LLVMIR
|
||||
include/triton/Target/LLVMIR/ @goostavz @Superjomn
|
||||
lib/Target/LLVMIR @goostavz @Superjomn
|
||||
# PTX
|
||||
include/triton/Target/PTX/ @goostavz @Superjomn
|
||||
lib/Target/PTX @goostavz @Superjomn
|
51
.github/workflows/documentation.yml
vendored
51
.github/workflows/documentation.yml
vendored
@@ -1,51 +0,0 @@
|
||||
name: Documentation
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 0 * * *"
|
||||
|
||||
jobs:
|
||||
|
||||
Build-Documentation:
|
||||
|
||||
runs-on: self-hosted
|
||||
|
||||
steps:
|
||||
|
||||
|
||||
- name: Checkout gh-pages
|
||||
uses: actions/checkout@v1
|
||||
with:
|
||||
ref: 'gh-pages'
|
||||
|
||||
- name: Checkout branch
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Build docs
|
||||
run: |
|
||||
git fetch origin master:master
|
||||
cd docs
|
||||
sphinx-multiversion . _build/html/
|
||||
|
||||
- name: Publish docs
|
||||
run: |
|
||||
git branch
|
||||
# update docs
|
||||
rm -r /tmp/triton-docs;
|
||||
mkdir /tmp/triton-docs;
|
||||
mv docs/_build/html/* /tmp/triton-docs/
|
||||
git checkout gh-pages
|
||||
cp -r CNAME /tmp/triton-docs/
|
||||
cp -r index.html /tmp/triton-docs/
|
||||
cp -r .nojekyll /tmp/triton-docs/
|
||||
rm -r *
|
||||
cp -r /tmp/triton-docs/* .
|
||||
# ln -s master/index.html .
|
||||
# mv master docs
|
||||
git add .
|
||||
git commit -am "[GH-PAGES] Updated website"
|
||||
# publish docs
|
||||
eval `ssh-agent -s`
|
||||
DISPLAY=:0 SSH_ASKPASS=~/.ssh/give_pass.sh ssh-add ${{ secrets.SSH_KEY }} <<< ${{ secrets.SSH_PASS }}
|
||||
git remote set-url origin git@github.com:openai/triton.git
|
||||
git push
|
103
.github/workflows/integration-tests.yml
vendored
103
.github/workflows/integration-tests.yml
vendored
@@ -4,52 +4,95 @@ on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
- v2.0
|
||||
|
||||
- main
|
||||
- triton-mlir
|
||||
|
||||
jobs:
|
||||
Runner-Preparation:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- name: Prepare runner matrix
|
||||
id: set-matrix
|
||||
run: |
|
||||
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
|
||||
echo '::set-output name=matrix::[["self-hosted", "A10"], ["self-hosted", "V100"], "macos-10.15"]'
|
||||
else
|
||||
echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
|
||||
fi
|
||||
|
||||
Integration-Tests:
|
||||
|
||||
runs-on: self-hosted
|
||||
needs: Runner-Preparation
|
||||
|
||||
runs-on: ${{ matrix.runner }}
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix)}}
|
||||
|
||||
steps:
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Clear cache
|
||||
run: |
|
||||
rm -r /tmp/triton/
|
||||
continue-on-error: true
|
||||
rm -rf ~/.triton/cache/
|
||||
|
||||
- name: Check imports
|
||||
if: startsWith(matrix.runner, 'ubuntu')
|
||||
run: |
|
||||
pip install isort
|
||||
isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )
|
||||
|
||||
- name: Check python style
|
||||
if: startsWith(matrix.runner, 'ubuntu')
|
||||
run: |
|
||||
pip install autopep8
|
||||
autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )
|
||||
|
||||
- name: Check cpp style
|
||||
if: startsWith(matrix.runner, 'ubuntu')
|
||||
run: |
|
||||
pip install clang-format
|
||||
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
|
||||
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)
|
||||
|
||||
- name: Flake8
|
||||
if: startsWith(matrix.runner, 'ubuntu')
|
||||
run: |
|
||||
pip install flake8
|
||||
flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )
|
||||
|
||||
- name: Install Triton
|
||||
run: |
|
||||
alias python='python3'
|
||||
cd python
|
||||
pip3 install -e '.[tests]'
|
||||
TRITON_USE_ASSERT_ENABLED_LLVM=TRUE pip3 install -e '.[tests]'
|
||||
|
||||
- name: Check imports
|
||||
run: "isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )"
|
||||
|
||||
- name: Check style
|
||||
run: "autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )"
|
||||
|
||||
- name: Flake8
|
||||
run: "flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )"
|
||||
|
||||
- name: Unit tests
|
||||
- name: Run lit tests
|
||||
run: |
|
||||
cd python/test/unit
|
||||
pytest -vs .
|
||||
cd python
|
||||
LIT_TEST_DIR="build/$(ls build)/test"
|
||||
if [ ! -d "$LIT_TEST_DIR" ]; then
|
||||
echo "Not found `$LIT_TEST_DIR`. Did you change an installation method?" ; exit -1
|
||||
fi
|
||||
lit -v "$LIT_TEST_DIR"
|
||||
|
||||
- name: Regression tests
|
||||
- name: Run python tests
|
||||
if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'A10'}}
|
||||
run: |
|
||||
cd python/test/regression
|
||||
sudo nvidia-smi -i 0 -pm 1
|
||||
sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350
|
||||
sudo nvidia-smi -i 0 --lock-memory-clocks=877,877
|
||||
pytest -vs .
|
||||
sudo nvidia-smi -i 0 -rgc
|
||||
sudo nvidia-smi -i 0 -rmc
|
||||
cd python/tests
|
||||
pytest
|
||||
|
||||
# TODO[Superjomn] Enable all the tests on V100 if available
|
||||
- name: Run python tests on V100
|
||||
if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
|
||||
run: |
|
||||
cd python/tests
|
||||
pytest test_gemm.py::test_gemm_no_scf_for_mmav1
|
||||
|
||||
- name: Run CXX unittests
|
||||
run: |
|
||||
cd python/
|
||||
cd "build/$(ls build)"
|
||||
ctest
|
||||
|
4
.github/workflows/wheels.yml
vendored
4
.github/workflows/wheels.yml
vendored
@@ -8,7 +8,7 @@ jobs:
|
||||
|
||||
Build-Wheels:
|
||||
|
||||
runs-on: self-hosted
|
||||
runs-on: [self-hosted, V100]
|
||||
|
||||
steps:
|
||||
|
||||
@@ -18,7 +18,7 @@ jobs:
|
||||
- name: Patch setup.py
|
||||
run: |
|
||||
#sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
|
||||
export LATEST_DATE=$(git show -s --format=%ci `git rev-parse HEAD` | cut -d ' ' -f 1 | sed 's/-//g')
|
||||
export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d' --format="%cd")
|
||||
sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py
|
||||
echo "" >> python/setup.cfg
|
||||
echo "[build_ext]" >> python/setup.cfg
|
||||
|
17
.gitignore
vendored
17
.gitignore
vendored
@@ -1,9 +1,20 @@
|
||||
# Triton builds
|
||||
build/
|
||||
|
||||
__pycache__
|
||||
.pytest_cache
|
||||
|
||||
# Triton Python module builds
|
||||
python/build/
|
||||
python/triton.egg-info/
|
||||
python/triton/_C/libtriton.pyd
|
||||
python/triton/_C/libtriton.so
|
||||
|
||||
# Python caches
|
||||
__pycache__
|
||||
.pytest_cache
|
||||
|
||||
# VS Code project files
|
||||
.vscode
|
||||
.vs
|
||||
|
||||
# JetBrains project files
|
||||
.idea
|
||||
cmake-build-*
|
||||
|
242
CMakeLists.txt
242
CMakeLists.txt
@@ -3,10 +3,7 @@ include(ExternalProject)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
|
||||
if(NOT TRITON_LLVM_BUILD_DIR)
|
||||
set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
|
||||
endif()
|
||||
|
||||
set(CMAKE_INCLUDE_CURRENT_DIR ON)
|
||||
|
||||
project(triton)
|
||||
include(CTest)
|
||||
@@ -15,8 +12,12 @@ if(NOT WIN32)
|
||||
endif()
|
||||
|
||||
# Options
|
||||
option(BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
|
||||
option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
|
||||
option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
|
||||
option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
|
||||
|
||||
# Ensure Python3 vars are set correctly
|
||||
# used conditionally in this file and by lit tests
|
||||
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
|
||||
|
||||
# Default build type
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
@@ -31,19 +32,27 @@ endif()
|
||||
# Compiler flags
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||
|
||||
# Third-party
|
||||
include_directories(${PYBIND11_INCLUDE_DIR})
|
||||
|
||||
if(WIN32)
|
||||
SET(BUILD_SHARED_LIBS OFF)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/deps/dlfcn-win32/src)
|
||||
add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32)
|
||||
endif()
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")
|
||||
if(APPLE)
|
||||
set(CMAKE_OSX_DEPLOYMENT_TARGET 11.6)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
##########
|
||||
# LLVM
|
||||
##########
|
||||
if("${LLVM_LIBRARY_DIR}" STREQUAL "")
|
||||
if (NOT MLIR_DIR)
|
||||
if(NOT LLVM_LIBRARY_DIR)
|
||||
if(WIN32)
|
||||
find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu)
|
||||
|
||||
@@ -62,95 +71,148 @@ if("${LLVM_LIBRARY_DIR}" STREQUAL "")
|
||||
if(APPLE)
|
||||
set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14")
|
||||
endif()
|
||||
# sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
|
||||
else()
|
||||
# sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
|
||||
else()
|
||||
set(LLVM_LDFLAGS "-L${LLVM_LIBRARY_DIR}")
|
||||
set(LLVM_LIBRARIES
|
||||
libLLVMNVPTXCodeGen.a
|
||||
libLLVMNVPTXDesc.a
|
||||
libLLVMNVPTXInfo.a
|
||||
libLLVMAMDGPUDisassembler.a
|
||||
libLLVMMCDisassembler.a
|
||||
libLLVMAMDGPUCodeGen.a
|
||||
libLLVMMIRParser.a
|
||||
libLLVMGlobalISel.a
|
||||
libLLVMSelectionDAG.a
|
||||
libLLVMipo.a
|
||||
libLLVMInstrumentation.a
|
||||
libLLVMVectorize.a
|
||||
libLLVMLinker.a
|
||||
libLLVMIRReader.a
|
||||
libLLVMAsmParser.a
|
||||
libLLVMFrontendOpenMP.a
|
||||
libLLVMAsmPrinter.a
|
||||
libLLVMDebugInfoDWARF.a
|
||||
libLLVMCodeGen.a
|
||||
libLLVMTarget.a
|
||||
libLLVMScalarOpts.a
|
||||
libLLVMInstCombine.a
|
||||
libLLVMAggressiveInstCombine.a
|
||||
libLLVMTransformUtils.a
|
||||
libLLVMBitWriter.a
|
||||
libLLVMAnalysis.a
|
||||
libLLVMProfileData.a
|
||||
libLLVMObject.a
|
||||
libLLVMTextAPI.a
|
||||
libLLVMBitReader.a
|
||||
libLLVMAMDGPUAsmParser.a
|
||||
libLLVMMCParser.a
|
||||
libLLVMAMDGPUDesc.a
|
||||
libLLVMAMDGPUUtils.a
|
||||
libLLVMMC.a
|
||||
libLLVMDebugInfoCodeView.a
|
||||
libLLVMDebugInfoMSF.a
|
||||
libLLVMCore.a
|
||||
libLLVMRemarks.a
|
||||
libLLVMBitstreamReader.a
|
||||
libLLVMBinaryFormat.a
|
||||
libLLVMAMDGPUInfo.a
|
||||
libLLVMSupport.a
|
||||
libLLVMDemangle.a
|
||||
libLLVMPasses.a
|
||||
libLLVMAnalysis.a
|
||||
libLLVMTransformUtils.a
|
||||
libLLVMScalarOpts.a
|
||||
libLLVMTransformUtils.a
|
||||
libLLVMipo.a
|
||||
libLLVMObjCARCOpts.a
|
||||
libLLVMCoroutines.a
|
||||
libLLVMAnalysis.a
|
||||
)
|
||||
set(LLVM_LIBRARIES
|
||||
libLLVMNVPTXCodeGen.a
|
||||
libLLVMNVPTXDesc.a
|
||||
libLLVMNVPTXInfo.a
|
||||
libLLVMAMDGPUDisassembler.a
|
||||
libLLVMMCDisassembler.a
|
||||
libLLVMAMDGPUCodeGen.a
|
||||
libLLVMMIRParser.a
|
||||
libLLVMGlobalISel.a
|
||||
libLLVMSelectionDAG.a
|
||||
libLLVMipo.a
|
||||
libLLVMInstrumentation.a
|
||||
libLLVMVectorize.a
|
||||
libLLVMLinker.a
|
||||
libLLVMIRReader.a
|
||||
libLLVMAsmParser.a
|
||||
libLLVMFrontendOpenMP.a
|
||||
libLLVMAsmPrinter.a
|
||||
libLLVMDebugInfoDWARF.a
|
||||
libLLVMCodeGen.a
|
||||
libLLVMTarget.a
|
||||
libLLVMScalarOpts.a
|
||||
libLLVMInstCombine.a
|
||||
libLLVMAggressiveInstCombine.a
|
||||
libLLVMTransformUtils.a
|
||||
libLLVMBitWriter.a
|
||||
libLLVMAnalysis.a
|
||||
libLLVMProfileData.a
|
||||
libLLVMObject.a
|
||||
libLLVMTextAPI.a
|
||||
libLLVMBitReader.a
|
||||
libLLVMAMDGPUAsmParser.a
|
||||
libLLVMMCParser.a
|
||||
libLLVMAMDGPUDesc.a
|
||||
libLLVMAMDGPUUtils.a
|
||||
libLLVMMC.a
|
||||
libLLVMDebugInfoCodeView.a
|
||||
libLLVMDebugInfoMSF.a
|
||||
libLLVMCore.a
|
||||
libLLVMRemarks.a
|
||||
libLLVMBitstreamReader.a
|
||||
libLLVMBinaryFormat.a
|
||||
libLLVMAMDGPUInfo.a
|
||||
libLLVMSupport.a
|
||||
libLLVMDemangle.a
|
||||
libLLVMPasses.a
|
||||
libLLVMAnalysis.a
|
||||
libLLVMTransformUtils.a
|
||||
libLLVMScalarOpts.a
|
||||
libLLVMTransformUtils.a
|
||||
libLLVMipo.a
|
||||
libLLVMObjCARCOpts.a
|
||||
libLLVMCoroutines.a
|
||||
libLLVMAnalysis.a
|
||||
)
|
||||
endif()
|
||||
set (MLIR_DIR ${LLVM_LIBRARY_DIR}/cmake/mlir)
|
||||
endif()
|
||||
include_directories("${LLVM_INCLUDE_DIRS}")
|
||||
|
||||
# Python module
|
||||
if(BUILD_PYTHON_MODULE)
|
||||
if(TRITON_BUILD_PYTHON_MODULE)
|
||||
message(STATUS "Adding Python module")
|
||||
# Build CUTLASS python wrapper if requested
|
||||
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
|
||||
set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}")
|
||||
set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}")
|
||||
if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL ""))
|
||||
set(CUTLASS_SRC ${PYTHON_SRC_PATH}/cutlass.cc)
|
||||
add_definitions(-DWITH_CUTLASS_BINDINGS)
|
||||
set(CUTLASS_LIBRARIES "cutlass.a")
|
||||
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc)
|
||||
include_directories("." ${PYTHON_SRC_PATH})
|
||||
if (PYTHON_INCLUDE_DIRS)
|
||||
include_directories(${PYTHON_INCLUDE_DIRS})
|
||||
else()
|
||||
include_directories(${Python3_INCLUDE_DIRS})
|
||||
link_directories(${Python3_LIBRARY_DIRS})
|
||||
link_libraries(${Python3_LIBRARIES})
|
||||
add_link_options(${Python3_LINK_OPTIONS})
|
||||
endif()
|
||||
include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
|
||||
link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
|
||||
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC})
|
||||
endif()
|
||||
|
||||
|
||||
# Triton
|
||||
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
||||
if (WIN32 AND BUILD_PYTHON_MODULE)
|
||||
find_package(Python3 REQUIRED COMPONENTS Development)
|
||||
Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
set_target_properties(triton PROPERTIES SUFFIX ".pyd")
|
||||
set_target_properties(triton PROPERTIES PREFIX "lib")
|
||||
else()
|
||||
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
endif()
|
||||
# # Triton
|
||||
# file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
|
||||
# if (WIN32 AND TRITON_BUILD_PYTHON_MODULE)
|
||||
# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
# set_target_properties(triton PROPERTIES SUFFIX ".pyd")
|
||||
# set_target_properties(triton PROPERTIES PREFIX "lib")
|
||||
# else()
|
||||
# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
|
||||
# endif()
|
||||
|
||||
|
||||
# MLIR
|
||||
find_package(MLIR REQUIRED CONFIG PATHS ${MLIR_DIR})
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
|
||||
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
|
||||
|
||||
include(TableGen) # required by AddMLIR
|
||||
include(AddLLVM)
|
||||
include(AddMLIR)
|
||||
|
||||
# Disable warnings that show up in external code (gtest;pybind11)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default")
|
||||
|
||||
include_directories(${MLIR_INCLUDE_DIRS})
|
||||
include_directories(${LLVM_INCLUDE_DIRS})
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
|
||||
# link_directories(${LLVM_LIBRARY_DIR})
|
||||
|
||||
add_subdirectory(include)
|
||||
add_subdirectory(lib)
|
||||
add_subdirectory(bin)
|
||||
|
||||
add_library(triton SHARED ${PYTHON_SRC})
|
||||
|
||||
# find_package(PythonLibs REQUIRED)
|
||||
|
||||
set(TRITON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
|
||||
set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
|
||||
|
||||
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
||||
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
|
||||
|
||||
target_link_libraries(triton
|
||||
TritonAnalysis
|
||||
TritonTransforms
|
||||
TritonGPUTransforms
|
||||
TritonLLVMIR
|
||||
TritonPTX
|
||||
${dialect_libs}
|
||||
${conversion_libs}
|
||||
# optimizations
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
MLIRLLVMIR
|
||||
MLIRSupport
|
||||
MLIRTargetLLVMIRExport
|
||||
MLIRExecutionEngine
|
||||
MLIRMathToLLVM
|
||||
MLIRNVVMToLLVMIRTranslation
|
||||
MLIRIR
|
||||
)
|
||||
|
||||
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
||||
|
||||
@@ -161,7 +223,7 @@ else()
|
||||
endif()
|
||||
|
||||
|
||||
if(BUILD_PYTHON_MODULE AND NOT WIN32)
|
||||
if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)
|
||||
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
|
||||
# Check if the platform is MacOS
|
||||
if(APPLE)
|
||||
@@ -169,3 +231,7 @@ if(BUILD_PYTHON_MODULE AND NOT WIN32)
|
||||
endif()
|
||||
target_link_libraries(triton ${CUTLASS_LIBRARIES} ${PYTHON_LDFLAGS})
|
||||
endif()
|
||||
|
||||
add_subdirectory(test)
|
||||
|
||||
add_subdirectory(unittest)
|
||||
|
4
LICENSE
4
LICENSE
@@ -1,6 +1,6 @@
|
||||
/*
|
||||
* Copyright 2018-2020 Philippe Tillet
|
||||
* Copyright 2020-2021 OpenAI
|
||||
* Copyright 2020-2022 OpenAI
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
@@ -20,4 +20,4 @@
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
*/
|
||||
|
60
bin/CMakeLists.txt
Normal file
60
bin/CMakeLists.txt
Normal file
@@ -0,0 +1,60 @@
|
||||
add_subdirectory(FileCheck)
|
||||
# add_llvm_executable(FileCheck FileCheck/FileCheck.cpp)
|
||||
# target_link_libraries(FileCheck PRIVATE LLVMFileCheck LLVMSupport)
|
||||
|
||||
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
||||
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
|
||||
|
||||
add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED)
|
||||
|
||||
# TODO: what's this?
|
||||
llvm_update_compile_flags(triton-opt)
|
||||
target_link_libraries(triton-opt PRIVATE
|
||||
TritonAnalysis
|
||||
TritonTransforms
|
||||
TritonGPUTransforms
|
||||
${dialect_libs}
|
||||
${conversion_libs}
|
||||
# tests
|
||||
TritonTestAnalysis
|
||||
# MLIR core
|
||||
MLIROptLib
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
)
|
||||
|
||||
mlir_check_all_link_libraries(triton-opt)
|
||||
|
||||
|
||||
# add_llvm_executable(triton-translate triton-translate.cpp PARTIAL_SOURCES_INTENDED)
|
||||
#llvm_update_compile_flags(triton-translate)
|
||||
# target_link_libraries(triton-translate PRIVATE
|
||||
# TritonAnalysis
|
||||
# TritonTransforms
|
||||
# TritonGPUTransforms
|
||||
# TritonLLVMIR
|
||||
# TritonDriver
|
||||
# ${dialect_libs}
|
||||
# ${conversion_libs}
|
||||
# # tests
|
||||
# TritonTestAnalysis
|
||||
|
||||
# LLVMCore
|
||||
# LLVMSupport
|
||||
# LLVMOption
|
||||
# LLVMCodeGen
|
||||
# LLVMAsmParser
|
||||
|
||||
# # MLIR core
|
||||
# MLIROptLib
|
||||
# MLIRIR
|
||||
# MLIRPass
|
||||
# MLIRSupport
|
||||
# MLIRTransforms
|
||||
# MLIRExecutionEngine
|
||||
# MLIRMathToLLVM
|
||||
# MLIRTransformUtils
|
||||
# MLIRLLVMToLLVMIRTranslation
|
||||
# MLIRNVVMToLLVMIRTranslation
|
||||
# )
|
||||
# mlir_check_all_link_libraries(triton-translate)
|
2
bin/FileCheck/CMakeLists.txt
Normal file
2
bin/FileCheck/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_llvm_executable(FileCheck FileCheck.cpp)
|
||||
target_link_libraries(FileCheck PRIVATE LLVMFileCheck LLVMSupport)
|
882
bin/FileCheck/FileCheck.cpp
Normal file
882
bin/FileCheck/FileCheck.cpp
Normal file
@@ -0,0 +1,882 @@
|
||||
//===- FileCheck.cpp - Check that File's Contents match what is expected --===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// FileCheck does a line-by line check of a file that validates whether it
|
||||
// contains the expected content. This is useful for regression tests etc.
|
||||
//
|
||||
// This program exits with an exit status of 2 on error, exit status of 0 if
|
||||
// the file matched the expected contents, and exit status of 1 if it did not
|
||||
// contain the expected contents.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "llvm/FileCheck/FileCheck.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/Process.h"
|
||||
#include "llvm/Support/WithColor.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <cmath>
|
||||
#include <map>
|
||||
using namespace llvm;
|
||||
|
||||
static cl::extrahelp FileCheckOptsEnv(
|
||||
"\nOptions are parsed from the environment variable FILECHECK_OPTS and\n"
|
||||
"from the command line.\n");
|
||||
|
||||
static cl::opt<std::string>
|
||||
CheckFilename(cl::Positional, cl::desc("<check-file>"), cl::Optional);
|
||||
|
||||
static cl::opt<std::string>
|
||||
InputFilename("input-file", cl::desc("File to check (defaults to stdin)"),
|
||||
cl::init("-"), cl::value_desc("filename"));
|
||||
|
||||
static cl::list<std::string> CheckPrefixes(
|
||||
"check-prefix",
|
||||
cl::desc("Prefix to use from check file (defaults to 'CHECK')"));
|
||||
static cl::alias CheckPrefixesAlias(
|
||||
"check-prefixes", cl::aliasopt(CheckPrefixes), cl::CommaSeparated,
|
||||
cl::NotHidden,
|
||||
cl::desc(
|
||||
"Alias for -check-prefix permitting multiple comma separated values"));
|
||||
|
||||
static cl::list<std::string> CommentPrefixes(
|
||||
"comment-prefixes", cl::CommaSeparated, cl::Hidden,
|
||||
cl::desc("Comma-separated list of comment prefixes to use from check file\n"
|
||||
"(defaults to 'COM,RUN'). Please avoid using this feature in\n"
|
||||
"LLVM's LIT-based test suites, which should be easier to\n"
|
||||
"maintain if they all follow a consistent comment style. This\n"
|
||||
"feature is meant for non-LIT test suites using FileCheck."));
|
||||
|
||||
static cl::opt<bool> NoCanonicalizeWhiteSpace(
|
||||
"strict-whitespace",
|
||||
cl::desc("Do not treat all horizontal whitespace as equivalent"));
|
||||
|
||||
static cl::opt<bool> IgnoreCase("ignore-case",
|
||||
cl::desc("Use case-insensitive matching"));
|
||||
|
||||
static cl::list<std::string> ImplicitCheckNot(
|
||||
"implicit-check-not",
|
||||
cl::desc("Add an implicit negative check with this pattern to every\n"
|
||||
"positive check. This can be used to ensure that no instances of\n"
|
||||
"this pattern occur which are not matched by a positive pattern"),
|
||||
cl::value_desc("pattern"));
|
||||
|
||||
static cl::list<std::string>
|
||||
GlobalDefines("D", cl::AlwaysPrefix,
|
||||
cl::desc("Define a variable to be used in capture patterns."),
|
||||
cl::value_desc("VAR=VALUE"));
|
||||
|
||||
static cl::opt<bool> AllowEmptyInput(
|
||||
"allow-empty", cl::init(false),
|
||||
cl::desc("Allow the input file to be empty. This is useful when making\n"
|
||||
"checks that some error message does not occur, for example."));
|
||||
|
||||
static cl::opt<bool> AllowUnusedPrefixes(
|
||||
"allow-unused-prefixes", cl::init(false), cl::ZeroOrMore,
|
||||
cl::desc("Allow prefixes to be specified but not appear in the test."));
|
||||
|
||||
static cl::opt<bool> MatchFullLines(
|
||||
"match-full-lines", cl::init(false),
|
||||
cl::desc("Require all positive matches to cover an entire input line.\n"
|
||||
"Allows leading and trailing whitespace if --strict-whitespace\n"
|
||||
"is not also passed."));
|
||||
|
||||
static cl::opt<bool> EnableVarScope(
|
||||
"enable-var-scope", cl::init(false),
|
||||
cl::desc("Enables scope for regex variables. Variables with names that\n"
|
||||
"do not start with '$' will be reset at the beginning of\n"
|
||||
"each CHECK-LABEL block."));
|
||||
|
||||
static cl::opt<bool> AllowDeprecatedDagOverlap(
|
||||
"allow-deprecated-dag-overlap", cl::init(false),
|
||||
cl::desc("Enable overlapping among matches in a group of consecutive\n"
|
||||
"CHECK-DAG directives. This option is deprecated and is only\n"
|
||||
"provided for convenience as old tests are migrated to the new\n"
|
||||
"non-overlapping CHECK-DAG implementation.\n"));
|
||||
|
||||
static cl::opt<bool> Verbose(
|
||||
"v", cl::init(false), cl::ZeroOrMore,
|
||||
cl::desc("Print directive pattern matches, or add them to the input dump\n"
|
||||
"if enabled.\n"));
|
||||
|
||||
static cl::opt<bool> VerboseVerbose(
|
||||
"vv", cl::init(false), cl::ZeroOrMore,
|
||||
cl::desc("Print information helpful in diagnosing internal FileCheck\n"
|
||||
"issues, or add it to the input dump if enabled. Implies\n"
|
||||
"-v.\n"));
|
||||
|
||||
// The order of DumpInputValue members affects their precedence, as documented
|
||||
// for -dump-input below.
|
||||
enum DumpInputValue {
|
||||
DumpInputNever,
|
||||
DumpInputFail,
|
||||
DumpInputAlways,
|
||||
DumpInputHelp
|
||||
};
|
||||
|
||||
static cl::list<DumpInputValue> DumpInputs(
|
||||
"dump-input",
|
||||
cl::desc("Dump input to stderr, adding annotations representing\n"
|
||||
"currently enabled diagnostics. When there are multiple\n"
|
||||
"occurrences of this option, the <value> that appears earliest\n"
|
||||
"in the list below has precedence. The default is 'fail'.\n"),
|
||||
cl::value_desc("mode"),
|
||||
cl::values(clEnumValN(DumpInputHelp, "help", "Explain input dump and quit"),
|
||||
clEnumValN(DumpInputAlways, "always", "Always dump input"),
|
||||
clEnumValN(DumpInputFail, "fail", "Dump input on failure"),
|
||||
clEnumValN(DumpInputNever, "never", "Never dump input")));
|
||||
|
||||
// The order of DumpInputFilterValue members affects their precedence, as
|
||||
// documented for -dump-input-filter below.
|
||||
enum DumpInputFilterValue {
|
||||
DumpInputFilterError,
|
||||
DumpInputFilterAnnotation,
|
||||
DumpInputFilterAnnotationFull,
|
||||
DumpInputFilterAll
|
||||
};
|
||||
|
||||
static cl::list<DumpInputFilterValue> DumpInputFilters(
|
||||
"dump-input-filter",
|
||||
cl::desc("In the dump requested by -dump-input, print only input lines of\n"
|
||||
"kind <value> plus any context specified by -dump-input-context.\n"
|
||||
"When there are multiple occurrences of this option, the <value>\n"
|
||||
"that appears earliest in the list below has precedence. The\n"
|
||||
"default is 'error' when -dump-input=fail, and it's 'all' when\n"
|
||||
"-dump-input=always.\n"),
|
||||
cl::values(clEnumValN(DumpInputFilterAll, "all", "All input lines"),
|
||||
clEnumValN(DumpInputFilterAnnotationFull, "annotation-full",
|
||||
"Input lines with annotations"),
|
||||
clEnumValN(DumpInputFilterAnnotation, "annotation",
|
||||
"Input lines with starting points of annotations"),
|
||||
clEnumValN(DumpInputFilterError, "error",
|
||||
"Input lines with starting points of error "
|
||||
"annotations")));
|
||||
|
||||
static cl::list<unsigned> DumpInputContexts(
|
||||
"dump-input-context", cl::value_desc("N"),
|
||||
cl::desc("In the dump requested by -dump-input, print <N> input lines\n"
|
||||
"before and <N> input lines after any lines specified by\n"
|
||||
"-dump-input-filter. When there are multiple occurrences of\n"
|
||||
"this option, the largest specified <N> has precedence. The\n"
|
||||
"default is 5.\n"));
|
||||
|
||||
typedef cl::list<std::string>::const_iterator prefix_iterator;
|
||||
|
||||
static void DumpCommandLine(int argc, char **argv) {
|
||||
errs() << "FileCheck command line: ";
|
||||
for (int I = 0; I < argc; I++)
|
||||
errs() << " " << argv[I];
|
||||
errs() << "\n";
|
||||
}
|
||||
|
||||
struct MarkerStyle {
|
||||
/// The starting char (before tildes) for marking the line.
|
||||
char Lead;
|
||||
/// What color to use for this annotation.
|
||||
raw_ostream::Colors Color;
|
||||
/// A note to follow the marker, or empty string if none.
|
||||
std::string Note;
|
||||
/// Does this marker indicate inclusion by -dump-input-filter=error?
|
||||
bool FiltersAsError;
|
||||
MarkerStyle() {}
|
||||
MarkerStyle(char Lead, raw_ostream::Colors Color,
|
||||
const std::string &Note = "", bool FiltersAsError = false)
|
||||
: Lead(Lead), Color(Color), Note(Note), FiltersAsError(FiltersAsError) {
|
||||
assert((!FiltersAsError || !Note.empty()) &&
|
||||
"expected error diagnostic to have note");
|
||||
}
|
||||
};
|
||||
|
||||
static MarkerStyle GetMarker(FileCheckDiag::MatchType MatchTy) {
|
||||
switch (MatchTy) {
|
||||
case FileCheckDiag::MatchFoundAndExpected:
|
||||
return MarkerStyle('^', raw_ostream::GREEN);
|
||||
case FileCheckDiag::MatchFoundButExcluded:
|
||||
return MarkerStyle('!', raw_ostream::RED, "error: no match expected",
|
||||
/*FiltersAsError=*/true);
|
||||
case FileCheckDiag::MatchFoundButWrongLine:
|
||||
return MarkerStyle('!', raw_ostream::RED, "error: match on wrong line",
|
||||
/*FiltersAsError=*/true);
|
||||
case FileCheckDiag::MatchFoundButDiscarded:
|
||||
return MarkerStyle('!', raw_ostream::CYAN,
|
||||
"discard: overlaps earlier match");
|
||||
case FileCheckDiag::MatchFoundErrorNote:
|
||||
// Note should always be overridden within the FileCheckDiag.
|
||||
return MarkerStyle('!', raw_ostream::RED,
|
||||
"error: unknown error after match",
|
||||
/*FiltersAsError=*/true);
|
||||
case FileCheckDiag::MatchNoneAndExcluded:
|
||||
return MarkerStyle('X', raw_ostream::GREEN);
|
||||
case FileCheckDiag::MatchNoneButExpected:
|
||||
return MarkerStyle('X', raw_ostream::RED, "error: no match found",
|
||||
/*FiltersAsError=*/true);
|
||||
case FileCheckDiag::MatchNoneForInvalidPattern:
|
||||
return MarkerStyle('X', raw_ostream::RED,
|
||||
"error: match failed for invalid pattern",
|
||||
/*FiltersAsError=*/true);
|
||||
case FileCheckDiag::MatchFuzzy:
|
||||
return MarkerStyle('?', raw_ostream::MAGENTA, "possible intended match",
|
||||
/*FiltersAsError=*/true);
|
||||
}
|
||||
llvm_unreachable_internal("unexpected match type");
|
||||
}
|
||||
|
||||
static void DumpInputAnnotationHelp(raw_ostream &OS) {
|
||||
OS << "The following description was requested by -dump-input=help to\n"
|
||||
<< "explain the input dump printed by FileCheck.\n"
|
||||
<< "\n"
|
||||
<< "Related command-line options:\n"
|
||||
<< "\n"
|
||||
<< " - -dump-input=<value> enables or disables the input dump\n"
|
||||
<< " - -dump-input-filter=<value> filters the input lines\n"
|
||||
<< " - -dump-input-context=<N> adjusts the context of filtered lines\n"
|
||||
<< " - -v and -vv add more annotations\n"
|
||||
<< " - -color forces colors to be enabled both in the dump and below\n"
|
||||
<< " - -help documents the above options in more detail\n"
|
||||
<< "\n"
|
||||
<< "These options can also be set via FILECHECK_OPTS. For example, for\n"
|
||||
<< "maximum debugging output on failures:\n"
|
||||
<< "\n"
|
||||
<< " $ FILECHECK_OPTS='-dump-input-filter=all -vv -color' ninja check\n"
|
||||
<< "\n"
|
||||
<< "Input dump annotation format:\n"
|
||||
<< "\n";
|
||||
|
||||
// Labels for input lines.
|
||||
OS << " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "L:";
|
||||
OS << " labels line number L of the input file\n"
|
||||
<< " An extra space is added after each input line to represent"
|
||||
<< " the\n"
|
||||
<< " newline character\n";
|
||||
|
||||
// Labels for annotation lines.
|
||||
OS << " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "T:L";
|
||||
OS << " labels the only match result for either (1) a pattern of type T"
|
||||
<< " from\n"
|
||||
<< " line L of the check file if L is an integer or (2) the"
|
||||
<< " I-th implicit\n"
|
||||
<< " pattern if L is \"imp\" followed by an integer "
|
||||
<< "I (index origin one)\n";
|
||||
OS << " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "T:L'N";
|
||||
OS << " labels the Nth match result for such a pattern\n";
|
||||
|
||||
// Markers on annotation lines.
|
||||
OS << " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "^~~";
|
||||
OS << " marks good match (reported if -v)\n"
|
||||
<< " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "!~~";
|
||||
OS << " marks bad match, such as:\n"
|
||||
<< " - CHECK-NEXT on same line as previous match (error)\n"
|
||||
<< " - CHECK-NOT found (error)\n"
|
||||
<< " - CHECK-DAG overlapping match (discarded, reported if "
|
||||
<< "-vv)\n"
|
||||
<< " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "X~~";
|
||||
OS << " marks search range when no match is found, such as:\n"
|
||||
<< " - CHECK-NEXT not found (error)\n"
|
||||
<< " - CHECK-NOT not found (success, reported if -vv)\n"
|
||||
<< " - CHECK-DAG not found after discarded matches (error)\n"
|
||||
<< " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "?";
|
||||
OS << " marks fuzzy match when no match is found\n";
|
||||
|
||||
// Elided lines.
|
||||
OS << " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "...";
|
||||
OS << " indicates elided input lines and annotations, as specified by\n"
|
||||
<< " -dump-input-filter and -dump-input-context\n";
|
||||
|
||||
// Colors.
|
||||
OS << " - colors ";
|
||||
WithColor(OS, raw_ostream::GREEN, true) << "success";
|
||||
OS << ", ";
|
||||
WithColor(OS, raw_ostream::RED, true) << "error";
|
||||
OS << ", ";
|
||||
WithColor(OS, raw_ostream::MAGENTA, true) << "fuzzy match";
|
||||
OS << ", ";
|
||||
WithColor(OS, raw_ostream::CYAN, true, false) << "discarded match";
|
||||
OS << ", ";
|
||||
WithColor(OS, raw_ostream::CYAN, true, true) << "unmatched input";
|
||||
OS << "\n";
|
||||
}
|
||||
|
||||
/// An annotation for a single input line.
|
||||
struct InputAnnotation {
|
||||
/// The index of the match result across all checks
|
||||
unsigned DiagIndex;
|
||||
/// The label for this annotation.
|
||||
std::string Label;
|
||||
/// Is this the initial fragment of a diagnostic that has been broken across
|
||||
/// multiple lines?
|
||||
bool IsFirstLine;
|
||||
/// What input line (one-origin indexing) this annotation marks. This might
|
||||
/// be different from the starting line of the original diagnostic if
|
||||
/// !IsFirstLine.
|
||||
unsigned InputLine;
|
||||
/// The column range (one-origin indexing, open end) in which to mark the
|
||||
/// input line. If InputEndCol is UINT_MAX, treat it as the last column
|
||||
/// before the newline.
|
||||
unsigned InputStartCol, InputEndCol;
|
||||
/// The marker to use.
|
||||
MarkerStyle Marker;
|
||||
/// Whether this annotation represents a good match for an expected pattern.
|
||||
bool FoundAndExpectedMatch;
|
||||
};
|
||||
|
||||
/// Get an abbreviation for the check type.
|
||||
static std::string GetCheckTypeAbbreviation(Check::FileCheckType Ty) {
|
||||
switch (Ty) {
|
||||
case Check::CheckPlain:
|
||||
if (Ty.getCount() > 1)
|
||||
return "count";
|
||||
return "check";
|
||||
case Check::CheckNext:
|
||||
return "next";
|
||||
case Check::CheckSame:
|
||||
return "same";
|
||||
case Check::CheckNot:
|
||||
return "not";
|
||||
case Check::CheckDAG:
|
||||
return "dag";
|
||||
case Check::CheckLabel:
|
||||
return "label";
|
||||
case Check::CheckEmpty:
|
||||
return "empty";
|
||||
case Check::CheckComment:
|
||||
return "com";
|
||||
case Check::CheckEOF:
|
||||
return "eof";
|
||||
case Check::CheckBadNot:
|
||||
return "bad-not";
|
||||
case Check::CheckBadCount:
|
||||
return "bad-count";
|
||||
case Check::CheckNone:
|
||||
llvm_unreachable("invalid FileCheckType");
|
||||
}
|
||||
llvm_unreachable("unknown FileCheckType");
|
||||
}
|
||||
|
||||
static void
|
||||
BuildInputAnnotations(const SourceMgr &SM, unsigned CheckFileBufferID,
|
||||
const std::pair<unsigned, unsigned> &ImpPatBufferIDRange,
|
||||
const std::vector<FileCheckDiag> &Diags,
|
||||
std::vector<InputAnnotation> &Annotations,
|
||||
unsigned &LabelWidth) {
|
||||
struct CompareSMLoc {
|
||||
bool operator()(const SMLoc &LHS, const SMLoc &RHS) const {
|
||||
return LHS.getPointer() < RHS.getPointer();
|
||||
}
|
||||
};
|
||||
// How many diagnostics does each pattern have?
|
||||
std::map<SMLoc, unsigned, CompareSMLoc> DiagCountPerPattern;
|
||||
for (auto Diag : Diags)
|
||||
++DiagCountPerPattern[Diag.CheckLoc];
|
||||
// How many diagnostics have we seen so far per pattern?
|
||||
std::map<SMLoc, unsigned, CompareSMLoc> DiagIndexPerPattern;
|
||||
// How many total diagnostics have we seen so far?
|
||||
unsigned DiagIndex = 0;
|
||||
// What's the widest label?
|
||||
LabelWidth = 0;
|
||||
for (auto DiagItr = Diags.begin(), DiagEnd = Diags.end(); DiagItr != DiagEnd;
|
||||
++DiagItr) {
|
||||
InputAnnotation A;
|
||||
A.DiagIndex = DiagIndex++;
|
||||
|
||||
// Build label, which uniquely identifies this check result.
|
||||
unsigned CheckBufferID = SM.FindBufferContainingLoc(DiagItr->CheckLoc);
|
||||
auto CheckLineAndCol =
|
||||
SM.getLineAndColumn(DiagItr->CheckLoc, CheckBufferID);
|
||||
llvm::raw_string_ostream Label(A.Label);
|
||||
Label << GetCheckTypeAbbreviation(DiagItr->CheckTy) << ":";
|
||||
if (CheckBufferID == CheckFileBufferID)
|
||||
Label << CheckLineAndCol.first;
|
||||
else if (ImpPatBufferIDRange.first <= CheckBufferID &&
|
||||
CheckBufferID < ImpPatBufferIDRange.second)
|
||||
Label << "imp" << (CheckBufferID - ImpPatBufferIDRange.first + 1);
|
||||
else
|
||||
llvm_unreachable("expected diagnostic's check location to be either in "
|
||||
"the check file or for an implicit pattern");
|
||||
if (DiagCountPerPattern[DiagItr->CheckLoc] > 1)
|
||||
Label << "'" << DiagIndexPerPattern[DiagItr->CheckLoc]++;
|
||||
LabelWidth = std::max((std::string::size_type)LabelWidth, A.Label.size());
|
||||
|
||||
A.Marker = GetMarker(DiagItr->MatchTy);
|
||||
if (!DiagItr->Note.empty()) {
|
||||
A.Marker.Note = DiagItr->Note;
|
||||
// It's less confusing if notes that don't actually have ranges don't have
|
||||
// markers. For example, a marker for 'with "VAR" equal to "5"' would
|
||||
// seem to indicate where "VAR" matches, but the location we actually have
|
||||
// for the marker simply points to the start of the match/search range for
|
||||
// the full pattern of which the substitution is potentially just one
|
||||
// component.
|
||||
if (DiagItr->InputStartLine == DiagItr->InputEndLine &&
|
||||
DiagItr->InputStartCol == DiagItr->InputEndCol)
|
||||
A.Marker.Lead = ' ';
|
||||
}
|
||||
if (DiagItr->MatchTy == FileCheckDiag::MatchFoundErrorNote) {
|
||||
assert(!DiagItr->Note.empty() &&
|
||||
"expected custom note for MatchFoundErrorNote");
|
||||
A.Marker.Note = "error: " + A.Marker.Note;
|
||||
}
|
||||
A.FoundAndExpectedMatch =
|
||||
DiagItr->MatchTy == FileCheckDiag::MatchFoundAndExpected;
|
||||
|
||||
// Compute the mark location, and break annotation into multiple
|
||||
// annotations if it spans multiple lines.
|
||||
A.IsFirstLine = true;
|
||||
A.InputLine = DiagItr->InputStartLine;
|
||||
A.InputStartCol = DiagItr->InputStartCol;
|
||||
if (DiagItr->InputStartLine == DiagItr->InputEndLine) {
|
||||
// Sometimes ranges are empty in order to indicate a specific point, but
|
||||
// that would mean nothing would be marked, so adjust the range to
|
||||
// include the following character.
|
||||
A.InputEndCol =
|
||||
std::max(DiagItr->InputStartCol + 1, DiagItr->InputEndCol);
|
||||
Annotations.push_back(A);
|
||||
} else {
|
||||
assert(DiagItr->InputStartLine < DiagItr->InputEndLine &&
|
||||
"expected input range not to be inverted");
|
||||
A.InputEndCol = UINT_MAX;
|
||||
Annotations.push_back(A);
|
||||
for (unsigned L = DiagItr->InputStartLine + 1, E = DiagItr->InputEndLine;
|
||||
L <= E; ++L) {
|
||||
// If a range ends before the first column on a line, then it has no
|
||||
// characters on that line, so there's nothing to render.
|
||||
if (DiagItr->InputEndCol == 1 && L == E)
|
||||
break;
|
||||
InputAnnotation B;
|
||||
B.DiagIndex = A.DiagIndex;
|
||||
B.Label = A.Label;
|
||||
B.IsFirstLine = false;
|
||||
B.InputLine = L;
|
||||
B.Marker = A.Marker;
|
||||
B.Marker.Lead = '~';
|
||||
B.Marker.Note = "";
|
||||
B.InputStartCol = 1;
|
||||
if (L != E)
|
||||
B.InputEndCol = UINT_MAX;
|
||||
else
|
||||
B.InputEndCol = DiagItr->InputEndCol;
|
||||
B.FoundAndExpectedMatch = A.FoundAndExpectedMatch;
|
||||
Annotations.push_back(B);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static unsigned FindInputLineInFilter(
|
||||
DumpInputFilterValue DumpInputFilter, unsigned CurInputLine,
|
||||
const std::vector<InputAnnotation>::iterator &AnnotationBeg,
|
||||
const std::vector<InputAnnotation>::iterator &AnnotationEnd) {
|
||||
if (DumpInputFilter == DumpInputFilterAll)
|
||||
return CurInputLine;
|
||||
for (auto AnnotationItr = AnnotationBeg; AnnotationItr != AnnotationEnd;
|
||||
++AnnotationItr) {
|
||||
switch (DumpInputFilter) {
|
||||
case DumpInputFilterAll:
|
||||
llvm_unreachable("unexpected DumpInputFilterAll");
|
||||
break;
|
||||
case DumpInputFilterAnnotationFull:
|
||||
return AnnotationItr->InputLine;
|
||||
case DumpInputFilterAnnotation:
|
||||
if (AnnotationItr->IsFirstLine)
|
||||
return AnnotationItr->InputLine;
|
||||
break;
|
||||
case DumpInputFilterError:
|
||||
if (AnnotationItr->IsFirstLine && AnnotationItr->Marker.FiltersAsError)
|
||||
return AnnotationItr->InputLine;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return UINT_MAX;
|
||||
}
|
||||
|
||||
/// To OS, print a vertical ellipsis (right-justified at LabelWidth) if it would
|
||||
/// occupy less lines than ElidedLines, but print ElidedLines otherwise. Either
|
||||
/// way, clear ElidedLines. Thus, if ElidedLines is empty, do nothing.
|
||||
static void DumpEllipsisOrElidedLines(raw_ostream &OS, std::string &ElidedLines,
|
||||
unsigned LabelWidth) {
|
||||
if (ElidedLines.empty())
|
||||
return;
|
||||
unsigned EllipsisLines = 3;
|
||||
if (EllipsisLines < StringRef(ElidedLines).count('\n')) {
|
||||
for (unsigned i = 0; i < EllipsisLines; ++i) {
|
||||
WithColor(OS, raw_ostream::BLACK, /*Bold=*/true)
|
||||
<< right_justify(".", LabelWidth);
|
||||
OS << '\n';
|
||||
}
|
||||
} else
|
||||
OS << ElidedLines;
|
||||
ElidedLines.clear();
|
||||
}
|
||||
|
||||
static void DumpAnnotatedInput(raw_ostream &OS, const FileCheckRequest &Req,
|
||||
DumpInputFilterValue DumpInputFilter,
|
||||
unsigned DumpInputContext,
|
||||
StringRef InputFileText,
|
||||
std::vector<InputAnnotation> &Annotations,
|
||||
unsigned LabelWidth) {
|
||||
OS << "Input was:\n<<<<<<\n";
|
||||
|
||||
// Sort annotations.
|
||||
llvm::sort(Annotations,
|
||||
[](const InputAnnotation &A, const InputAnnotation &B) {
|
||||
// 1. Sort annotations in the order of the input lines.
|
||||
//
|
||||
// This makes it easier to find relevant annotations while
|
||||
// iterating input lines in the implementation below. FileCheck
|
||||
// does not always produce diagnostics in the order of input
|
||||
// lines due to, for example, CHECK-DAG and CHECK-NOT.
|
||||
if (A.InputLine != B.InputLine)
|
||||
return A.InputLine < B.InputLine;
|
||||
// 2. Sort annotations in the temporal order FileCheck produced
|
||||
// their associated diagnostics.
|
||||
//
|
||||
// This sort offers several benefits:
|
||||
//
|
||||
// A. On a single input line, the order of annotations reflects
|
||||
// the FileCheck logic for processing directives/patterns.
|
||||
// This can be helpful in understanding cases in which the
|
||||
// order of the associated directives/patterns in the check
|
||||
// file or on the command line either (i) does not match the
|
||||
// temporal order in which FileCheck looks for matches for the
|
||||
// directives/patterns (due to, for example, CHECK-LABEL,
|
||||
// CHECK-NOT, or `--implicit-check-not`) or (ii) does match
|
||||
// that order but does not match the order of those
|
||||
// diagnostics along an input line (due to, for example,
|
||||
// CHECK-DAG).
|
||||
//
|
||||
// On the other hand, because our presentation format presents
|
||||
// input lines in order, there's no clear way to offer the
|
||||
// same benefit across input lines. For consistency, it might
|
||||
// then seem worthwhile to have annotations on a single line
|
||||
// also sorted in input order (that is, by input column).
|
||||
// However, in practice, this appears to be more confusing
|
||||
// than helpful. Perhaps it's intuitive to expect annotations
|
||||
// to be listed in the temporal order in which they were
|
||||
// produced except in cases the presentation format obviously
|
||||
// and inherently cannot support it (that is, across input
|
||||
// lines).
|
||||
//
|
||||
// B. When diagnostics' annotations are split among multiple
|
||||
// input lines, the user must track them from one input line
|
||||
// to the next. One property of the sort chosen here is that
|
||||
// it facilitates the user in this regard by ensuring the
|
||||
// following: when comparing any two input lines, a
|
||||
// diagnostic's annotations are sorted in the same position
|
||||
// relative to all other diagnostics' annotations.
|
||||
return A.DiagIndex < B.DiagIndex;
|
||||
});
|
||||
|
||||
// Compute the width of the label column.
|
||||
const unsigned char *InputFilePtr = InputFileText.bytes_begin(),
|
||||
*InputFileEnd = InputFileText.bytes_end();
|
||||
unsigned LineCount = InputFileText.count('\n');
|
||||
if (InputFileEnd[-1] != '\n')
|
||||
++LineCount;
|
||||
unsigned LineNoWidth = std::log10(LineCount) + 1;
|
||||
// +3 below adds spaces (1) to the left of the (right-aligned) line numbers
|
||||
// on input lines and (2) to the right of the (left-aligned) labels on
|
||||
// annotation lines so that input lines and annotation lines are more
|
||||
// visually distinct. For example, the spaces on the annotation lines ensure
|
||||
// that input line numbers and check directive line numbers never align
|
||||
// horizontally. Those line numbers might not even be for the same file.
|
||||
// One space would be enough to achieve that, but more makes it even easier
|
||||
// to see.
|
||||
LabelWidth = std::max(LabelWidth, LineNoWidth) + 3;
|
||||
|
||||
// Print annotated input lines.
|
||||
unsigned PrevLineInFilter = 0; // 0 means none so far
|
||||
unsigned NextLineInFilter = 0; // 0 means uncomputed, UINT_MAX means none
|
||||
std::string ElidedLines;
|
||||
raw_string_ostream ElidedLinesOS(ElidedLines);
|
||||
ColorMode TheColorMode =
|
||||
WithColor(OS).colorsEnabled() ? ColorMode::Enable : ColorMode::Disable;
|
||||
if (TheColorMode == ColorMode::Enable)
|
||||
ElidedLinesOS.enable_colors(true);
|
||||
auto AnnotationItr = Annotations.begin(), AnnotationEnd = Annotations.end();
|
||||
for (unsigned Line = 1;
|
||||
InputFilePtr != InputFileEnd || AnnotationItr != AnnotationEnd; ++Line) {
|
||||
const unsigned char *InputFileLine = InputFilePtr;
|
||||
|
||||
// Compute the previous and next line included by the filter.
|
||||
if (NextLineInFilter < Line)
|
||||
NextLineInFilter = FindInputLineInFilter(DumpInputFilter, Line,
|
||||
AnnotationItr, AnnotationEnd);
|
||||
assert(NextLineInFilter && "expected NextLineInFilter to be computed");
|
||||
if (NextLineInFilter == Line)
|
||||
PrevLineInFilter = Line;
|
||||
|
||||
// Elide this input line and its annotations if it's not within the
|
||||
// context specified by -dump-input-context of an input line included by
|
||||
// -dump-input-filter. However, in case the resulting ellipsis would occupy
|
||||
// more lines than the input lines and annotations it elides, buffer the
|
||||
// elided lines and annotations so we can print them instead.
|
||||
raw_ostream *LineOS = &OS;
|
||||
if ((!PrevLineInFilter || PrevLineInFilter + DumpInputContext < Line) &&
|
||||
(NextLineInFilter == UINT_MAX ||
|
||||
Line + DumpInputContext < NextLineInFilter))
|
||||
LineOS = &ElidedLinesOS;
|
||||
else {
|
||||
LineOS = &OS;
|
||||
DumpEllipsisOrElidedLines(OS, ElidedLinesOS.str(), LabelWidth);
|
||||
}
|
||||
|
||||
// Print right-aligned line number.
|
||||
WithColor(*LineOS, raw_ostream::BLACK, /*Bold=*/true, /*BF=*/false,
|
||||
TheColorMode)
|
||||
<< format_decimal(Line, LabelWidth) << ": ";
|
||||
|
||||
// For the case where -v and colors are enabled, find the annotations for
|
||||
// good matches for expected patterns in order to highlight everything
|
||||
// else in the line. There are no such annotations if -v is disabled.
|
||||
std::vector<InputAnnotation> FoundAndExpectedMatches;
|
||||
if (Req.Verbose && TheColorMode == ColorMode::Enable) {
|
||||
for (auto I = AnnotationItr; I != AnnotationEnd && I->InputLine == Line;
|
||||
++I) {
|
||||
if (I->FoundAndExpectedMatch)
|
||||
FoundAndExpectedMatches.push_back(*I);
|
||||
}
|
||||
}
|
||||
|
||||
// Print numbered line with highlighting where there are no matches for
|
||||
// expected patterns.
|
||||
bool Newline = false;
|
||||
{
|
||||
WithColor COS(*LineOS, raw_ostream::SAVEDCOLOR, /*Bold=*/false,
|
||||
/*BG=*/false, TheColorMode);
|
||||
bool InMatch = false;
|
||||
if (Req.Verbose)
|
||||
COS.changeColor(raw_ostream::CYAN, true, true);
|
||||
for (unsigned Col = 1; InputFilePtr != InputFileEnd && !Newline; ++Col) {
|
||||
bool WasInMatch = InMatch;
|
||||
InMatch = false;
|
||||
for (auto M : FoundAndExpectedMatches) {
|
||||
if (M.InputStartCol <= Col && Col < M.InputEndCol) {
|
||||
InMatch = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!WasInMatch && InMatch)
|
||||
COS.resetColor();
|
||||
else if (WasInMatch && !InMatch)
|
||||
COS.changeColor(raw_ostream::CYAN, true, true);
|
||||
if (*InputFilePtr == '\n') {
|
||||
Newline = true;
|
||||
COS << ' ';
|
||||
} else
|
||||
COS << *InputFilePtr;
|
||||
++InputFilePtr;
|
||||
}
|
||||
}
|
||||
*LineOS << '\n';
|
||||
unsigned InputLineWidth = InputFilePtr - InputFileLine;
|
||||
|
||||
// Print any annotations.
|
||||
while (AnnotationItr != AnnotationEnd && AnnotationItr->InputLine == Line) {
|
||||
WithColor COS(*LineOS, AnnotationItr->Marker.Color, /*Bold=*/true,
|
||||
/*BG=*/false, TheColorMode);
|
||||
// The two spaces below are where the ": " appears on input lines.
|
||||
COS << left_justify(AnnotationItr->Label, LabelWidth) << " ";
|
||||
unsigned Col;
|
||||
for (Col = 1; Col < AnnotationItr->InputStartCol; ++Col)
|
||||
COS << ' ';
|
||||
COS << AnnotationItr->Marker.Lead;
|
||||
// If InputEndCol=UINT_MAX, stop at InputLineWidth.
|
||||
for (++Col; Col < AnnotationItr->InputEndCol && Col <= InputLineWidth;
|
||||
++Col)
|
||||
COS << '~';
|
||||
const std::string &Note = AnnotationItr->Marker.Note;
|
||||
if (!Note.empty()) {
|
||||
// Put the note at the end of the input line. If we were to instead
|
||||
// put the note right after the marker, subsequent annotations for the
|
||||
// same input line might appear to mark this note instead of the input
|
||||
// line.
|
||||
for (; Col <= InputLineWidth; ++Col)
|
||||
COS << ' ';
|
||||
COS << ' ' << Note;
|
||||
}
|
||||
COS << '\n';
|
||||
++AnnotationItr;
|
||||
}
|
||||
}
|
||||
DumpEllipsisOrElidedLines(OS, ElidedLinesOS.str(), LabelWidth);
|
||||
|
||||
OS << ">>>>>>\n";
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
// Enable use of ANSI color codes because FileCheck is using them to
|
||||
// highlight text.
|
||||
llvm::sys::Process::UseANSIEscapeCodes(true);
|
||||
|
||||
InitLLVM X(argc, argv);
|
||||
cl::ParseCommandLineOptions(argc, argv, /*Overview*/ "", /*Errs*/ nullptr,
|
||||
"FILECHECK_OPTS");
|
||||
|
||||
// Select -dump-input* values. The -help documentation specifies the default
|
||||
// value and which value to choose if an option is specified multiple times.
|
||||
// In the latter case, the general rule of thumb is to choose the value that
|
||||
// provides the most information.
|
||||
DumpInputValue DumpInput =
|
||||
DumpInputs.empty()
|
||||
? DumpInputFail
|
||||
: *std::max_element(DumpInputs.begin(), DumpInputs.end());
|
||||
DumpInputFilterValue DumpInputFilter;
|
||||
if (DumpInputFilters.empty())
|
||||
DumpInputFilter = DumpInput == DumpInputAlways ? DumpInputFilterAll
|
||||
: DumpInputFilterError;
|
||||
else
|
||||
DumpInputFilter =
|
||||
*std::max_element(DumpInputFilters.begin(), DumpInputFilters.end());
|
||||
unsigned DumpInputContext = DumpInputContexts.empty()
|
||||
? 5
|
||||
: *std::max_element(DumpInputContexts.begin(),
|
||||
DumpInputContexts.end());
|
||||
|
||||
if (DumpInput == DumpInputHelp) {
|
||||
DumpInputAnnotationHelp(outs());
|
||||
return 0;
|
||||
}
|
||||
if (CheckFilename.empty()) {
|
||||
errs() << "<check-file> not specified\n";
|
||||
return 2;
|
||||
}
|
||||
|
||||
FileCheckRequest Req;
|
||||
append_range(Req.CheckPrefixes, CheckPrefixes);
|
||||
|
||||
append_range(Req.CommentPrefixes, CommentPrefixes);
|
||||
|
||||
append_range(Req.ImplicitCheckNot, ImplicitCheckNot);
|
||||
|
||||
bool GlobalDefineError = false;
|
||||
for (StringRef G : GlobalDefines) {
|
||||
size_t EqIdx = G.find('=');
|
||||
if (EqIdx == std::string::npos) {
|
||||
errs() << "Missing equal sign in command-line definition '-D" << G
|
||||
<< "'\n";
|
||||
GlobalDefineError = true;
|
||||
continue;
|
||||
}
|
||||
if (EqIdx == 0) {
|
||||
errs() << "Missing variable name in command-line definition '-D" << G
|
||||
<< "'\n";
|
||||
GlobalDefineError = true;
|
||||
continue;
|
||||
}
|
||||
Req.GlobalDefines.push_back(G);
|
||||
}
|
||||
if (GlobalDefineError)
|
||||
return 2;
|
||||
|
||||
Req.AllowEmptyInput = AllowEmptyInput;
|
||||
Req.AllowUnusedPrefixes = AllowUnusedPrefixes;
|
||||
Req.EnableVarScope = EnableVarScope;
|
||||
Req.AllowDeprecatedDagOverlap = AllowDeprecatedDagOverlap;
|
||||
Req.Verbose = Verbose;
|
||||
Req.VerboseVerbose = VerboseVerbose;
|
||||
Req.NoCanonicalizeWhiteSpace = NoCanonicalizeWhiteSpace;
|
||||
Req.MatchFullLines = MatchFullLines;
|
||||
Req.IgnoreCase = IgnoreCase;
|
||||
|
||||
if (VerboseVerbose)
|
||||
Req.Verbose = true;
|
||||
|
||||
FileCheck FC(Req);
|
||||
if (!FC.ValidateCheckPrefixes())
|
||||
return 2;
|
||||
|
||||
Regex PrefixRE = FC.buildCheckPrefixRegex();
|
||||
std::string REError;
|
||||
if (!PrefixRE.isValid(REError)) {
|
||||
errs() << "Unable to combine check-prefix strings into a prefix regular "
|
||||
"expression! This is likely a bug in FileCheck's verification of "
|
||||
"the check-prefix strings. Regular expression parsing failed "
|
||||
"with the following error: "
|
||||
<< REError << "\n";
|
||||
return 2;
|
||||
}
|
||||
|
||||
SourceMgr SM;
|
||||
|
||||
// Read the expected strings from the check file.
|
||||
ErrorOr<std::unique_ptr<MemoryBuffer>> CheckFileOrErr =
|
||||
MemoryBuffer::getFileOrSTDIN(CheckFilename, /*IsText=*/true);
|
||||
if (std::error_code EC = CheckFileOrErr.getError()) {
|
||||
errs() << "Could not open check file '" << CheckFilename
|
||||
<< "': " << EC.message() << '\n';
|
||||
return 2;
|
||||
}
|
||||
MemoryBuffer &CheckFile = *CheckFileOrErr.get();
|
||||
|
||||
SmallString<4096> CheckFileBuffer;
|
||||
StringRef CheckFileText = FC.CanonicalizeFile(CheckFile, CheckFileBuffer);
|
||||
|
||||
unsigned CheckFileBufferID =
|
||||
SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(
|
||||
CheckFileText, CheckFile.getBufferIdentifier()),
|
||||
SMLoc());
|
||||
|
||||
std::pair<unsigned, unsigned> ImpPatBufferIDRange;
|
||||
if (FC.readCheckFile(SM, CheckFileText, PrefixRE, &ImpPatBufferIDRange))
|
||||
return 2;
|
||||
|
||||
// Open the file to check and add it to SourceMgr.
|
||||
ErrorOr<std::unique_ptr<MemoryBuffer>> InputFileOrErr =
|
||||
MemoryBuffer::getFileOrSTDIN(InputFilename, /*IsText=*/true);
|
||||
if (InputFilename == "-")
|
||||
InputFilename = "<stdin>"; // Overwrite for improved diagnostic messages
|
||||
if (std::error_code EC = InputFileOrErr.getError()) {
|
||||
errs() << "Could not open input file '" << InputFilename
|
||||
<< "': " << EC.message() << '\n';
|
||||
return 2;
|
||||
}
|
||||
MemoryBuffer &InputFile = *InputFileOrErr.get();
|
||||
|
||||
if (InputFile.getBufferSize() == 0 && !AllowEmptyInput) {
|
||||
errs() << "FileCheck error: '" << InputFilename << "' is empty.\n";
|
||||
DumpCommandLine(argc, argv);
|
||||
return 2;
|
||||
}
|
||||
|
||||
SmallString<4096> InputFileBuffer;
|
||||
StringRef InputFileText = FC.CanonicalizeFile(InputFile, InputFileBuffer);
|
||||
|
||||
SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(
|
||||
InputFileText, InputFile.getBufferIdentifier()),
|
||||
SMLoc());
|
||||
|
||||
std::vector<FileCheckDiag> Diags;
|
||||
int ExitCode = FC.checkInput(SM, InputFileText,
|
||||
DumpInput == DumpInputNever ? nullptr : &Diags)
|
||||
? EXIT_SUCCESS
|
||||
: 1;
|
||||
if (DumpInput == DumpInputAlways ||
|
||||
(ExitCode == 1 && DumpInput == DumpInputFail)) {
|
||||
errs() << "\n"
|
||||
<< "Input file: " << InputFilename << "\n"
|
||||
<< "Check file: " << CheckFilename << "\n"
|
||||
<< "\n"
|
||||
<< "-dump-input=help explains the following input dump.\n"
|
||||
<< "\n";
|
||||
std::vector<InputAnnotation> Annotations;
|
||||
unsigned LabelWidth;
|
||||
BuildInputAnnotations(SM, CheckFileBufferID, ImpPatBufferIDRange, Diags,
|
||||
Annotations, LabelWidth);
|
||||
DumpAnnotatedInput(errs(), Req, DumpInputFilter, DumpInputContext,
|
||||
InputFileText, Annotations, LabelWidth);
|
||||
}
|
||||
|
||||
return ExitCode;
|
||||
}
|
42
bin/triton-opt.cpp
Normal file
42
bin/triton-opt.cpp
Normal file
@@ -0,0 +1,42 @@
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
#include "triton/Conversion/Passes.h"
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/InitAllPasses.h"
|
||||
#include "mlir/Support/MlirOptMain.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestAliasPass();
|
||||
void registerTestAlignmentPass();
|
||||
void registerTestAllocationPass();
|
||||
void registerTestMembarPass();
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
mlir::registerAllPasses();
|
||||
mlir::registerTritonPasses();
|
||||
mlir::registerTritonGPUPasses();
|
||||
mlir::test::registerTestAliasPass();
|
||||
mlir::test::registerTestAlignmentPass();
|
||||
mlir::test::registerTestAllocationPass();
|
||||
mlir::test::registerTestMembarPass();
|
||||
mlir::triton::registerConvertTritonToTritonGPUPass();
|
||||
mlir::triton::registerConvertTritonGPUToLLVMPass();
|
||||
|
||||
// TODO: register Triton & TritonGPU passes
|
||||
mlir::DialectRegistry registry;
|
||||
registry.insert<mlir::triton::TritonDialect,
|
||||
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
|
||||
mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect,
|
||||
mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();
|
||||
|
||||
return mlir::asMainReturnCode(mlir::MlirOptMain(
|
||||
argc, argv, "Triton (GPU) optimizer driver\n", registry));
|
||||
}
|
131
bin/triton-translate.cpp
Normal file
131
bin/triton-translate.cpp
Normal file
@@ -0,0 +1,131 @@
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/IR/AsmState.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/FileUtilities.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
||||
#include "mlir/Target/LLVMIR/Export.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||
#include "triton/driver/llvm.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
|
||||
MLIRContext &context) {
|
||||
std::string errorMessage;
|
||||
auto input = openInputFile(inputFilename, &errorMessage);
|
||||
if (!input) {
|
||||
llvm::errs() << errorMessage << "\n";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
mlir::DialectRegistry registry;
|
||||
registry.insert<TritonDialect, triton::gpu::TritonGPUDialect,
|
||||
mlir::math::MathDialect, arith::ArithmeticDialect,
|
||||
StandardOpsDialect, scf::SCFDialect>();
|
||||
|
||||
context.appendDialectRegistry(registry);
|
||||
|
||||
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer)
|
||||
-> OwningOpRef<ModuleOp> {
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
|
||||
|
||||
context.loadAllAvailableDialects();
|
||||
context.allowUnregisteredDialects();
|
||||
|
||||
OwningOpRef<ModuleOp> module(parseSourceFile(sourceMgr, &context));
|
||||
if (!module) {
|
||||
llvm::errs() << "Parse MLIR file failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return module;
|
||||
};
|
||||
|
||||
auto module = processBuffer(std::move(input));
|
||||
if (!module) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return module;
|
||||
}
|
||||
|
||||
LogicalResult tritonTranslateMain(int argc, char **argv,
|
||||
llvm::StringRef toolName) {
|
||||
static llvm::cl::opt<std::string> inputFilename(
|
||||
llvm::cl::Positional, llvm::cl::desc("<input file>"),
|
||||
llvm::cl::init("-"));
|
||||
|
||||
static llvm::cl::opt<std::string> outputFilename(
|
||||
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
|
||||
llvm::cl::init("-"));
|
||||
|
||||
static llvm::cl::opt<std::string> targetKind(
|
||||
"target", llvm::cl::desc("<translation target, options: llvmir/ptx>"),
|
||||
llvm::cl::value_desc("target"), llvm::cl::init("llvmir"));
|
||||
|
||||
static llvm::cl::opt<int> SMArch("sm", llvm::cl::desc("sm arch"),
|
||||
llvm::cl::init(80));
|
||||
|
||||
static llvm::cl::opt<int> ptxVersion(
|
||||
"ptx-version", llvm::cl::desc("PTX version"), llvm::cl::init(10000));
|
||||
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
|
||||
registerAsmPrinterCLOptions();
|
||||
registerMLIRContextCLOptions();
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv, toolName);
|
||||
|
||||
mlir::MLIRContext context;
|
||||
auto module = loadMLIRModule(inputFilename, context);
|
||||
if (!module) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
std::string errorMessage;
|
||||
auto output = openOutputFile(outputFilename, &errorMessage);
|
||||
if (!output) {
|
||||
llvm::errs() << errorMessage << "\n";
|
||||
return failure();
|
||||
}
|
||||
|
||||
llvm::LLVMContext llvmContext;
|
||||
auto llvmir =
|
||||
translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue());
|
||||
if (!llvmir) {
|
||||
llvm::errs() << "Translate to LLVM IR failed";
|
||||
}
|
||||
|
||||
if (targetKind == "llvmir")
|
||||
llvm::outs() << *llvmir << '\n';
|
||||
else if (targetKind == "ptx")
|
||||
llvm::outs() << ::triton::driver::llir_to_ptx(
|
||||
llvmir.get(), SMArch.getValue(), ptxVersion.getValue());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
return failed(mlir::triton::tritonTranslateMain(
|
||||
argc, argv, "Triton Translate Testing Tool."));
|
||||
}
|
@@ -25,7 +25,7 @@
|
||||
# LLVM_VERSION_STRING - Full LLVM version string (e.g. 6.0.0svn).
|
||||
# LLVM_VERSION_BASE_STRING - Base LLVM version string without git/svn suffix (e.g. 6.0.0).
|
||||
#
|
||||
# Note: The variable names were chosen in conformance with the offical CMake
|
||||
# Note: The variable names were chosen in conformance with the official CMake
|
||||
# guidelines, see ${CMAKE_ROOT}/Modules/readme.txt.
|
||||
|
||||
# Try suffixed versions to pick up the newest LLVM install available on Debian
|
||||
@@ -196,4 +196,4 @@ include(FindPackageHandleStandardArgs)
|
||||
|
||||
find_package_handle_standard_args(LLVM
|
||||
REQUIRED_VARS LLVM_ROOT_DIR
|
||||
VERSION_VAR LLVM_VERSION_STRING)
|
||||
VERSION_VAR LLVM_VERSION_STRING)
|
||||
|
1
deps/dlfcn-win32
vendored
1
deps/dlfcn-win32
vendored
Submodule deps/dlfcn-win32 deleted from 522c301ec3
@@ -14,7 +14,7 @@ Traditional compilers typically rely on intermediate representations, such as LL
|
||||
Program Representation
|
||||
+++++++++++++++++++++++
|
||||
|
||||
Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample litterature on linear and integer programming.
|
||||
Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample literature on linear and integer programming.
|
||||
|
||||
.. table::
|
||||
:widths: 50 50
|
||||
@@ -168,7 +168,7 @@ Scheduling languages are, without a doubt, one of the most popular approaches fo
|
||||
Limitations
|
||||
++++++++++++
|
||||
|
||||
This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indice without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular.
|
||||
This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indices without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular.
|
||||
|
||||
.. table::
|
||||
:widths: 50 50
|
||||
|
1
include/CMakeLists.txt
Normal file
1
include/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_subdirectory(triton)
|
80
include/triton/Analysis/Alias.h
Normal file
80
include/triton/Analysis/Alias.h
Normal file
@@ -0,0 +1,80 @@
|
||||
#ifndef TRITON_ANALYSIS_ALIAS_H
|
||||
#define TRITON_ANALYSIS_ALIAS_H
|
||||
|
||||
#include "mlir/Analysis/AliasAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class AliasInfo {
|
||||
public:
|
||||
AliasInfo() = default;
|
||||
AliasInfo(Value value) { insert(value); }
|
||||
|
||||
void insert(Value value) { allocs.insert(value); }
|
||||
|
||||
const DenseSet<Value> &getAllocs() const { return allocs; }
|
||||
|
||||
bool operator==(const AliasInfo &other) const {
|
||||
return allocs == other.allocs;
|
||||
}
|
||||
|
||||
/// The pessimistic value state of a value without alias
|
||||
static AliasInfo getPessimisticValueState(MLIRContext *context) {
|
||||
return AliasInfo();
|
||||
}
|
||||
static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); }
|
||||
|
||||
/// The union of both arguments
|
||||
static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs);
|
||||
|
||||
private:
|
||||
/// The set of allocated values that are aliased by this lattice.
|
||||
/// For now, we only consider aliased value produced by the following
|
||||
/// situations:
|
||||
/// 1. values returned by scf.yield
|
||||
/// 2. block arguments in scf.for
|
||||
/// Example:
|
||||
/// alloc v1 alloc v2
|
||||
/// | |
|
||||
/// |--------------| |------------|
|
||||
/// scf.for v3 scf.for v4 scf.for v5
|
||||
/// |
|
||||
/// scf.yield v6
|
||||
///
|
||||
/// v1's alloc [v1]
|
||||
/// v2's alloc [v2]
|
||||
/// v3's alloc [v1]
|
||||
/// v4's alloc [v1, v2]
|
||||
/// v5's alloc [v2]
|
||||
/// v6's alloc [v1]
|
||||
///
|
||||
/// Therefore, v1's liveness range is the union of v3, v4, and v6
|
||||
/// v2's liveness range is the union of v4 and v5.
|
||||
DenseSet<Value> allocs;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared Memory Alias Analysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis<AliasInfo> {
|
||||
public:
|
||||
using ForwardDataFlowAnalysis<AliasInfo>::ForwardDataFlowAnalysis;
|
||||
|
||||
/// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use.
|
||||
/// Given two values, returns their aliasing behavior.
|
||||
AliasResult alias(Value lhs, Value rhs);
|
||||
|
||||
/// Returns the modify-reference behavior of `op` on `location`.
|
||||
ModRefResult getModRef(Operation *op, Value location);
|
||||
|
||||
/// Computes if the alloc set of the results are changed.
|
||||
ChangeResult
|
||||
visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AliasInfo> *> operands) override;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSIS_ALIAS_H
|
192
include/triton/Analysis/Allocation.h
Normal file
192
include/triton/Analysis/Allocation.h
Normal file
@@ -0,0 +1,192 @@
|
||||
#ifndef TRITON_ANALYSIS_ALLOCATION_H
|
||||
#define TRITON_ANALYSIS_ALLOCATION_H
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <atomic>
|
||||
#include <limits>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace triton {
|
||||
class AllocationAnalysis;
|
||||
|
||||
SmallVector<unsigned>
|
||||
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
unsigned &outVec);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
|
||||
/// A class that represents an interval, specified using a start and an end
|
||||
/// values: [Start, End).
|
||||
template <typename T> class Interval {
|
||||
public:
|
||||
Interval() {}
|
||||
Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); }
|
||||
T start() const { return Start; }
|
||||
T end() const { return End; }
|
||||
T size() const { return End - Start; }
|
||||
bool contains(T Addr) const { return Start <= Addr && Addr < End; }
|
||||
bool intersects(const Interval &R) const {
|
||||
return Start < R.End && R.Start < End;
|
||||
}
|
||||
bool operator==(const Interval &R) const {
|
||||
return Start == R.Start && End == R.End;
|
||||
}
|
||||
bool operator!=(const Interval &R) const { return !(*this == R); }
|
||||
bool operator<(const Interval &R) const {
|
||||
return std::make_pair(Start, End) < std::make_pair(R.Start, R.End);
|
||||
}
|
||||
|
||||
private:
|
||||
T Start = std::numeric_limits<T>::min();
|
||||
T End = std::numeric_limits<T>::max();
|
||||
};
|
||||
|
||||
class Allocation {
|
||||
public:
|
||||
/// A unique identifier for shared memory buffers
|
||||
using BufferId = size_t;
|
||||
using BufferIdSetT = DenseSet<BufferId>;
|
||||
|
||||
static constexpr BufferId InvalidBufferId =
|
||||
std::numeric_limits<BufferId>::max();
|
||||
|
||||
/// Creates a new Allocation analysis that computes the shared memory
|
||||
/// information for all associated shared memory values.
|
||||
Allocation(Operation *operation) : operation(operation) { run(); }
|
||||
|
||||
/// Returns the operation this analysis was constructed from.
|
||||
Operation *getOperation() const { return operation; }
|
||||
|
||||
/// Returns the offset of the given buffer in the shared memory.
|
||||
size_t getOffset(BufferId bufferId) const {
|
||||
return bufferSet.lookup(bufferId).offset;
|
||||
}
|
||||
|
||||
/// Returns the size of the given buffer in the shared memory.
|
||||
size_t getAllocatedSize(BufferId bufferId) const {
|
||||
return bufferSet.lookup(bufferId).size;
|
||||
}
|
||||
|
||||
/// Returns the buffer id of the given value.
|
||||
/// This interface only returns the allocated buffer id.
|
||||
/// If you want to get all the buffer ids that are associated with the given
|
||||
/// value, including alias buffers, use getBufferIds.
|
||||
BufferId getBufferId(Value value) const {
|
||||
if (valueBuffer.count(value)) {
|
||||
return valueBuffer.lookup(value)->id;
|
||||
} else {
|
||||
return InvalidBufferId;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns all the buffer ids of the given value, including alias buffers.
|
||||
BufferIdSetT getBufferIds(Value value) const {
|
||||
BufferIdSetT bufferIds;
|
||||
auto allocBufferId = getBufferId(value);
|
||||
if (allocBufferId != InvalidBufferId)
|
||||
bufferIds.insert(allocBufferId);
|
||||
for (auto *buffer : aliasBuffer.lookup(value)) {
|
||||
if (buffer->id != InvalidBufferId)
|
||||
bufferIds.insert(buffer->id);
|
||||
}
|
||||
return bufferIds;
|
||||
}
|
||||
|
||||
/// Returns the scratch buffer id of the given value.
|
||||
BufferId getBufferId(Operation *operation) const {
|
||||
if (opScratch.count(operation)) {
|
||||
return opScratch.lookup(operation)->id;
|
||||
} else {
|
||||
return InvalidBufferId;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the size of total shared memory allocated
|
||||
size_t getSharedMemorySize() const { return sharedMemorySize; }
|
||||
|
||||
bool isIntersected(BufferId lhsId, BufferId rhsId) const {
|
||||
if (lhsId == InvalidBufferId || rhsId == InvalidBufferId)
|
||||
return false;
|
||||
auto lhsBuffer = bufferSet.lookup(lhsId);
|
||||
auto rhsBuffer = bufferSet.lookup(rhsId);
|
||||
return lhsBuffer.intersects(rhsBuffer);
|
||||
}
|
||||
|
||||
private:
|
||||
/// A class that represents a shared memory buffer
|
||||
struct BufferT {
|
||||
enum class BufferKind { Explicit, Scratch };
|
||||
|
||||
/// MT: thread-safe
|
||||
inline static std::atomic<BufferId> nextId = 0;
|
||||
|
||||
BufferKind kind;
|
||||
BufferId id;
|
||||
size_t size;
|
||||
size_t offset;
|
||||
|
||||
bool operator==(const BufferT &other) const { return id == other.id; }
|
||||
bool operator<(const BufferT &other) const { return id < other.id; }
|
||||
|
||||
BufferT() : BufferT(BufferKind::Explicit) {}
|
||||
BufferT(BufferKind kind) : BufferT(kind, 0, 0) {}
|
||||
BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
|
||||
BufferT(BufferKind kind, size_t size, size_t offset)
|
||||
: kind(kind), id(nextId++), size(size), offset(offset) {}
|
||||
|
||||
bool intersects(const BufferT &other) const {
|
||||
return Interval<size_t>(offset, offset + size)
|
||||
.intersects(
|
||||
Interval<size_t>(other.offset, other.offset + other.size));
|
||||
}
|
||||
};
|
||||
|
||||
/// Op -> Scratch Buffer
|
||||
using OpScratchMapT = DenseMap<Operation *, BufferT *>;
|
||||
/// Value -> Explicit Buffer
|
||||
using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
|
||||
/// Value -> Alias Buffer
|
||||
using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
|
||||
/// BufferId -> Buffer
|
||||
using BufferSetT = DenseMap<BufferId, BufferT>;
|
||||
/// Runs allocation analysis on the given top-level operation.
|
||||
void run();
|
||||
|
||||
private:
|
||||
template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
|
||||
void addBuffer(KeyType &key, Args &&...args) {
|
||||
auto buffer = BufferT(Kind, std::forward<Args>(args)...);
|
||||
bufferSet[buffer.id] = std::move(buffer);
|
||||
if constexpr (Kind == BufferT::BufferKind::Explicit) {
|
||||
valueBuffer[key] = &bufferSet[buffer.id];
|
||||
} else {
|
||||
opScratch[key] = &bufferSet[buffer.id];
|
||||
}
|
||||
}
|
||||
|
||||
void addAlias(Value value, Value alloc) {
|
||||
aliasBuffer[value].insert(valueBuffer[alloc]);
|
||||
}
|
||||
|
||||
private:
|
||||
Operation *operation;
|
||||
OpScratchMapT opScratch;
|
||||
ValueBufferMapT valueBuffer;
|
||||
AliasBufferMapT aliasBuffer;
|
||||
BufferSetT bufferSet;
|
||||
size_t sharedMemorySize = 0;
|
||||
|
||||
friend class triton::AllocationAnalysis;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSIS_ALLOCATION_H
|
144
include/triton/Analysis/AxisInfo.h
Normal file
144
include/triton/Analysis/AxisInfo.h
Normal file
@@ -0,0 +1,144 @@
|
||||
#ifndef TRITON_ANALYSIS_AXISINFO_H
|
||||
#define TRITON_ANALYSIS_AXISINFO_H
|
||||
|
||||
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <iostream>
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This lattice value represents known information on the axes of a lattice.
|
||||
/// Axis information is represented by a std::map<int, int>
|
||||
class AxisInfo {
|
||||
public:
|
||||
typedef SmallVector<int, 4> DimVectorT;
|
||||
|
||||
public:
|
||||
// Default constructor
|
||||
AxisInfo() : AxisInfo({}, {}, {}) {}
|
||||
// Construct contiguity info with known contiguity
|
||||
AxisInfo(DimVectorT knownContiguity, DimVectorT knownDivisibility,
|
||||
DimVectorT knownConstancy)
|
||||
: contiguity(knownContiguity), divisibility(knownDivisibility),
|
||||
constancy(knownConstancy), rank(contiguity.size()) {
|
||||
assert(knownDivisibility.size() == (size_t)rank);
|
||||
assert(knownConstancy.size() == (size_t)rank);
|
||||
}
|
||||
|
||||
// Accessors
|
||||
int getContiguity(size_t d) const { return contiguity[d]; }
|
||||
const DimVectorT &getContiguity() const { return contiguity; }
|
||||
|
||||
int getDivisibility(size_t d) const { return divisibility[d]; }
|
||||
const DimVectorT &getDivisibility() const { return divisibility; }
|
||||
|
||||
int getConstancy(size_t d) const { return constancy[d]; }
|
||||
const DimVectorT &getConstancy() const { return constancy; }
|
||||
|
||||
int getRank() const { return rank; }
|
||||
|
||||
// Comparison
|
||||
bool operator==(const AxisInfo &other) const {
|
||||
return (contiguity == other.contiguity) &&
|
||||
(divisibility == other.divisibility) &&
|
||||
(constancy == other.constancy);
|
||||
}
|
||||
|
||||
/// The pessimistic value state of the contiguity is unknown.
|
||||
static AxisInfo getPessimisticValueState(MLIRContext *context) {
|
||||
return AxisInfo();
|
||||
}
|
||||
static AxisInfo getPessimisticValueState(Value value);
|
||||
|
||||
// The gcd of both arguments for each dimension
|
||||
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
|
||||
|
||||
private:
|
||||
/// The _contiguity_ information maps the `d`-th
|
||||
/// dimension to the length of the shortest
|
||||
/// sequence of contiguous integers along it
|
||||
/// For example:
|
||||
/// [10, 11, 12, 13, 18, 19, 20, 21]
|
||||
/// [20, 21, 22, 23, 28, 29, 30, 31]
|
||||
/// Would have contiguity [1, 4].
|
||||
/// and
|
||||
/// [12, 16, 20, 24]
|
||||
/// [13, 17, 21, 25]
|
||||
/// [14, 18, 22, 26]
|
||||
/// [15, 19, 23, 27]
|
||||
/// [18, 22, 26, 30]
|
||||
/// [19, 23, 27, 31]
|
||||
/// Would have contiguity [2, 1].
|
||||
DimVectorT contiguity;
|
||||
|
||||
/// The _divisibility_ information maps the `d`-th
|
||||
/// dimension to the largest power-of-two that
|
||||
/// divides the first element of all the values along it
|
||||
/// For example:
|
||||
/// [10, 11, 12, 13, 18, 19, 20, 21]
|
||||
/// [20, 21, 22, 23, 28, 29, 30, 31]
|
||||
// would have divisibility [1, 2]
|
||||
// and
|
||||
/// [12, 16, 20, 24]
|
||||
/// [13, 17, 21, 25]
|
||||
/// [14, 18, 22, 26]
|
||||
/// [15, 19, 23, 27]
|
||||
// would have divisibility [4, 1]
|
||||
DimVectorT divisibility;
|
||||
|
||||
/// The _constancy_ information maps the `d`-th
|
||||
/// dimension to the length of the shortest
|
||||
/// sequence of constant integer along it. This is
|
||||
/// particularly useful to infer the contiguity
|
||||
/// of operations (e.g., add) involving a constant
|
||||
/// For example
|
||||
/// [8, 8, 8, 8, 12, 12, 12, 12]
|
||||
/// [16, 16, 16, 16, 20, 20, 20, 20]
|
||||
/// would have constancy [1, 4]
|
||||
DimVectorT constancy;
|
||||
|
||||
// number of dimensions of the lattice
|
||||
int rank;
|
||||
};
|
||||
|
||||
class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
|
||||
|
||||
private:
|
||||
static const int maxPow2Divisor = 65536;
|
||||
|
||||
int highestPowOf2Divisor(int n) {
|
||||
if (n == 0)
|
||||
return maxPow2Divisor;
|
||||
return (n & (~(n - 1)));
|
||||
}
|
||||
|
||||
AxisInfo visitBinaryOp(
|
||||
Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
|
||||
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy);
|
||||
|
||||
public:
|
||||
using ForwardDataFlowAnalysis<AxisInfo>::ForwardDataFlowAnalysis;
|
||||
|
||||
ChangeResult
|
||||
visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
|
||||
|
||||
unsigned getPtrVectorSize(Value ptr);
|
||||
|
||||
unsigned getPtrAlignment(Value ptr);
|
||||
|
||||
unsigned getMaskAlignment(Value mask);
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
119
include/triton/Analysis/Membar.h
Normal file
119
include/triton/Analysis/Membar.h
Normal file
@@ -0,0 +1,119 @@
|
||||
#ifndef TRITON_ANALYSIS_MEMBAR_H
|
||||
#define TRITON_ANALYSIS_MEMBAR_H
|
||||
|
||||
#include "Allocation.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class OpBuilder;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared Memory Barrier Analysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
class MembarAnalysis {
|
||||
public:
|
||||
/// Creates a new Membar analysis that generates the shared memory barrier
|
||||
/// in the following circumstances:
|
||||
/// - RAW: If a shared memory write is followed by a shared memory read, and
|
||||
/// their addresses are intersected, a barrier is inserted.
|
||||
/// - WAR: If a shared memory read is followed by a shared memory read, and
|
||||
/// their addresses are intersected, a barrier is inserted.
|
||||
/// The following circumstances do not require a barrier:
|
||||
/// - WAW: not possible because overlapped memory allocation is not allowed.
|
||||
/// - RAR: no write is performed.
|
||||
/// Temporary storage of operations such as Reduce are considered as both
|
||||
/// a shared memory read. If the temporary storage is written but not read,
|
||||
/// it is considered as the problem of the operation itself but not the membar
|
||||
/// analysis.
|
||||
/// The following circumstances are not considered yet:
|
||||
/// - Double buffers
|
||||
/// - N buffers
|
||||
MembarAnalysis(Allocation *allocation) : allocation(allocation) {}
|
||||
|
||||
/// Runs the membar analysis to the given operation, inserts a barrier if
|
||||
/// necessary.
|
||||
void run();
|
||||
|
||||
private:
|
||||
struct RegionInfo {
|
||||
using BufferIdSetT = Allocation::BufferIdSetT;
|
||||
|
||||
BufferIdSetT syncReadBuffers;
|
||||
BufferIdSetT syncWriteBuffers;
|
||||
|
||||
RegionInfo() = default;
|
||||
RegionInfo(const BufferIdSetT &syncReadBuffers,
|
||||
const BufferIdSetT &syncWriteBuffers)
|
||||
: syncReadBuffers(syncReadBuffers), syncWriteBuffers(syncWriteBuffers) {
|
||||
}
|
||||
|
||||
/// Unions two RegionInfo objects.
|
||||
void join(const RegionInfo &other) {
|
||||
syncReadBuffers.insert(other.syncReadBuffers.begin(),
|
||||
other.syncReadBuffers.end());
|
||||
syncWriteBuffers.insert(other.syncWriteBuffers.begin(),
|
||||
other.syncWriteBuffers.end());
|
||||
}
|
||||
|
||||
/// Returns true if buffers in two RegionInfo objects are intersected.
|
||||
bool isIntersected(const RegionInfo &other, Allocation *allocation) const {
|
||||
return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers,
|
||||
allocation) ||
|
||||
/*WAR*/
|
||||
isIntersected(syncReadBuffers, other.syncWriteBuffers,
|
||||
allocation) ||
|
||||
/*WAW*/
|
||||
isIntersected(syncWriteBuffers, other.syncWriteBuffers,
|
||||
allocation);
|
||||
}
|
||||
|
||||
/// Clears the buffers because a barrier is inserted.
|
||||
void sync() {
|
||||
syncReadBuffers.clear();
|
||||
syncWriteBuffers.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
/// Returns true if buffers in two sets are intersected.
|
||||
bool isIntersected(const BufferIdSetT &lhs, const BufferIdSetT &rhs,
|
||||
Allocation *allocation) const {
|
||||
return std::any_of(lhs.begin(), lhs.end(), [&](auto lhsId) {
|
||||
return std::any_of(rhs.begin(), rhs.end(), [&](auto rhsId) {
|
||||
return allocation->isIntersected(lhsId, rhsId);
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/// Applies the barrier analysis based on the SCF dialect, in which each
|
||||
/// region has a single basic block only.
|
||||
/// Example:
|
||||
/// region1
|
||||
/// op1
|
||||
/// op2 (scf.if)
|
||||
/// region2
|
||||
/// op3
|
||||
/// op4
|
||||
/// region3
|
||||
/// op5
|
||||
/// op6
|
||||
/// op7
|
||||
/// region2 and region3 started with the information of region1.
|
||||
/// Each region is analyzed separately and keeps their own copy of the
|
||||
/// information. At op7, we union the information of the region2 and region3
|
||||
/// and update the information of region1.
|
||||
void dfsOperation(Operation *operation, RegionInfo *blockInfo,
|
||||
OpBuilder *builder);
|
||||
|
||||
/// Updates the RegionInfo operation based on the operation.
|
||||
void transfer(Operation *operation, RegionInfo *blockInfo,
|
||||
OpBuilder *builder);
|
||||
|
||||
private:
|
||||
Allocation *allocation;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSIS_MEMBAR_H
|
76
include/triton/Analysis/Utility.h
Normal file
76
include/triton/Analysis/Utility.h
Normal file
@@ -0,0 +1,76 @@
|
||||
#ifndef TRITON_ANALYSIS_UTILITY_H
|
||||
#define TRITON_ANALYSIS_UTILITY_H
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ReduceOpHelper {
|
||||
public:
|
||||
explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
|
||||
srcTy = op.operand().getType().cast<RankedTensorType>();
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
|
||||
|
||||
Attribute getSrcLayout() { return srcTy.getEncoding(); }
|
||||
|
||||
bool isFastReduction();
|
||||
|
||||
unsigned getInterWarpSize();
|
||||
|
||||
unsigned getIntraWarpSize();
|
||||
|
||||
unsigned getThreadsReductionAxis();
|
||||
|
||||
SmallVector<unsigned> getScratchConfigBasic();
|
||||
|
||||
SmallVector<SmallVector<unsigned>> getScratchConfigsFast();
|
||||
|
||||
unsigned getScratchSizeInBytes();
|
||||
|
||||
private:
|
||||
triton::ReduceOp op;
|
||||
RankedTensorType srcTy{};
|
||||
};
|
||||
|
||||
bool isSharedEncoding(Value value);
|
||||
|
||||
bool maybeSharedAllocationOp(Operation *op);
|
||||
|
||||
bool maybeAliasOp(Operation *op);
|
||||
|
||||
std::string getValueOperandName(Value value, AsmState &state);
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
|
||||
SmallVector<T_OUT> out;
|
||||
for (const T_IN &i : in)
|
||||
out.push_back(T_OUT(i));
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
|
||||
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
|
||||
}
|
||||
|
||||
template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }
|
||||
|
||||
// output[i] = input[order[i]]
|
||||
template <typename T, typename RES_T = T>
|
||||
SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
|
||||
size_t rank = order.size();
|
||||
assert(input.size() == rank);
|
||||
SmallVector<RES_T> result(rank);
|
||||
for (auto it : llvm::enumerate(order)) {
|
||||
result[it.index()] = input[it.value()];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSIS_UTILITY_H
|
2
include/triton/CMakeLists.txt
Normal file
2
include/triton/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Dialect)
|
4
include/triton/Conversion/CMakeLists.txt
Normal file
4
include/triton/Conversion/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
add_public_tablegen_target(TritonConversionPassIncGen)
|
38
include/triton/Conversion/MLIRTypes.h
Normal file
38
include/triton/Conversion/MLIRTypes.h
Normal file
@@ -0,0 +1,38 @@
|
||||
#ifndef TRITON_CONVERSION_MLIR_TYPES_H_
|
||||
#define TRITON_CONVERSION_MLIR_TYPES_H_
|
||||
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
// This file redefines some common MLIR types for easy usage.
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
namespace type {
|
||||
|
||||
// Integer types
|
||||
Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); }
|
||||
Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); }
|
||||
Type u32Ty(MLIRContext *ctx) {
|
||||
return IntegerType::get(ctx, 32, IntegerType::Unsigned);
|
||||
}
|
||||
Type u1Ty(MLIRContext *ctx) {
|
||||
return IntegerType::get(ctx, 1, IntegerType::Unsigned);
|
||||
}
|
||||
|
||||
// Float types
|
||||
Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
|
||||
Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
|
||||
Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
|
||||
Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
|
||||
|
||||
static bool isFloat(Type type) {
|
||||
return type.isF32() || type.isF64() || type.isF16() || type.isF128();
|
||||
}
|
||||
|
||||
static bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
|
||||
|
||||
} // namespace type
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_CONVERSION_MLIR_TYPES_H_
|
17
include/triton/Conversion/Passes.h
Normal file
17
include/triton/Conversion/Passes.h
Normal file
@@ -0,0 +1,17 @@
|
||||
#ifndef TRITON_CONVERSION_PASSES_H
|
||||
#define TRITON_CONVERSION_PASSES_H
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Conversion/Passes.h.inc"
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
54
include/triton/Conversion/Passes.td
Normal file
54
include/triton/Conversion/Passes.td
Normal file
@@ -0,0 +1,54 @@
|
||||
#ifndef TRITON_CONVERSION_PASSES
|
||||
#define TRITON_CONVERSION_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> {
|
||||
let summary = "Convert Triton to TritonGPU";
|
||||
let description = [{
|
||||
|
||||
}];
|
||||
let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()";
|
||||
|
||||
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
||||
"mlir::math::MathDialect",
|
||||
"mlir::StandardOpsDialect",
|
||||
// TODO: Does this pass depend on SCF?
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::triton::TritonDialect",
|
||||
"mlir::triton::gpu::TritonGPUDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"numWarps", "num-warps",
|
||||
"int32_t", /*default*/"4",
|
||||
"number of warps">
|
||||
];
|
||||
}
|
||||
|
||||
|
||||
def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"> {
|
||||
let summary = "Convert TritonGPU to LLVM";
|
||||
let description = [{
|
||||
|
||||
}];
|
||||
let constructor = "mlir::triton::createConvertTritonGPUToLLVMPass()";
|
||||
|
||||
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
||||
"mlir::math::MathDialect",
|
||||
"mlir::gpu::GPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::LLVM::LLVMDialect",
|
||||
"mlir::tensor::TensorDialect",
|
||||
"mlir::triton::TritonDialect",
|
||||
"mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::NVVM::NVVMDialect",
|
||||
"mlir::StandardOpsDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
#endif
|
327
include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h
Normal file
327
include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h
Normal file
@@ -0,0 +1,327 @@
|
||||
#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
|
||||
#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
|
||||
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace mlir {
|
||||
class ConversionPatternRewriter;
|
||||
class Location;
|
||||
|
||||
namespace triton {
|
||||
using llvm::StringRef;
|
||||
|
||||
struct PTXInstr;
|
||||
struct PTXInstrCommon;
|
||||
struct PTXInstrExecution;
|
||||
|
||||
// PTXBuilder helps to manage a PTX asm program consists of one or multiple
|
||||
// instructions.
|
||||
//
|
||||
// A helper for building an ASM program, the objective of PTXBuilder is to give
|
||||
// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
|
||||
// Currently, several factors are introduced to reduce the need for mixing
|
||||
// string and C++ if-else code.
|
||||
//
|
||||
// Usage:
|
||||
// To build: @$3 asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k),
|
||||
// "b"(p));
|
||||
//
|
||||
// PTXBuilder builder;
|
||||
// auto& add = builder.create<>();
|
||||
// add.predicate(pVal).o("lo").o("u32"); // add any suffix
|
||||
// // predicate here binds %0 to pVal, pVal is a mlir::Value
|
||||
//
|
||||
// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal
|
||||
// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal
|
||||
// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal
|
||||
// add(iOpr, jOpr, kOpr).predicate(predVal); // set operands and predicate
|
||||
//
|
||||
// To get the asm code:
|
||||
// builder.dump()
|
||||
//
|
||||
// To get all the mlir::Value used in the PTX code,
|
||||
//
|
||||
// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal}
|
||||
//
|
||||
// To get the string containing all the constraints with "," separated,
|
||||
// builder.getConstraints() // get "=r,r,k"
|
||||
//
|
||||
// PTXBuilder can build a PTX asm with multiple instructions, sample code:
|
||||
//
|
||||
// PTXBuilder builder;
|
||||
// auto& mov = builder.create("mov");
|
||||
// auto& cp = builder.create("cp");
|
||||
// mov(...);
|
||||
// cp(...);
|
||||
// This will get a PTX code with two instructions.
|
||||
//
|
||||
// Similar to a C function, a declared PTXInstr instance can be launched
|
||||
// multiple times with different operands, e.g.
|
||||
//
|
||||
// auto& mov = builder.create("mov");
|
||||
// mov(... some operands ...);
|
||||
// mov(... some different operands ...);
|
||||
//
|
||||
// Finally, we will get a PTX code with two mov instructions.
|
||||
//
|
||||
// There are several derived instruction type for typical instructions, for
|
||||
// example, the PtxIOInstr for ld and st instructions.
|
||||
struct PTXBuilder {
|
||||
struct Operand {
|
||||
std::string constraint;
|
||||
Value value;
|
||||
int idx{-1};
|
||||
llvm::SmallVector<Operand *> list;
|
||||
std::function<std::string(int idx)> repr;
|
||||
|
||||
// for list
|
||||
Operand() = default;
|
||||
Operand(const Operation &) = delete;
|
||||
Operand(Value value, StringRef constraint)
|
||||
: constraint(constraint), value(value) {}
|
||||
|
||||
bool isList() const { return !value && constraint.empty(); }
|
||||
|
||||
Operand *listAppend(Operand *arg) {
|
||||
list.push_back(arg);
|
||||
return this;
|
||||
}
|
||||
|
||||
Operand *listGet(size_t nth) const {
|
||||
assert(nth < list.size());
|
||||
return list[nth];
|
||||
}
|
||||
|
||||
std::string dump() const;
|
||||
};
|
||||
|
||||
template <typename INSTR = PTXInstr, typename... Args>
|
||||
INSTR *create(Args &&...args) {
|
||||
instrs.emplace_back(std::make_unique<INSTR>(this, args...));
|
||||
return static_cast<INSTR *>(instrs.back().get());
|
||||
}
|
||||
|
||||
// Create a list of operands.
|
||||
Operand *newListOperand() { return newOperand(); }
|
||||
|
||||
Operand *newListOperand(ArrayRef<std::pair<mlir::Value, std::string>> items) {
|
||||
auto *list = newOperand();
|
||||
for (auto &item : items) {
|
||||
list->listAppend(newOperand(item.first, item.second));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
Operand *newListOperand(unsigned count, mlir::Value val,
|
||||
const std::string &constraint) {
|
||||
auto *list = newOperand();
|
||||
for (unsigned i = 0; i < count; ++i) {
|
||||
list->listAppend(newOperand(val, constraint));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
Operand *newListOperand(unsigned count, const std::string &constraint) {
|
||||
auto *list = newOperand();
|
||||
for (unsigned i = 0; i < count; ++i) {
|
||||
list->listAppend(newOperand(constraint));
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
// Create a new operand. It will not add to operand list.
|
||||
// @value: the MLIR value bind to this operand.
|
||||
// @constraint: ASM operand constraint, .e.g. "=r"
|
||||
// @formatter: extra format to represent this operand in ASM code, default is
|
||||
// "%{0}".format(operand.idx).
|
||||
Operand *newOperand(mlir::Value value, StringRef constraint,
|
||||
std::function<std::string(int idx)> formatter = nullptr);
|
||||
|
||||
// Create a new operand which is written to, that is, the constraint starts
|
||||
// with "=", e.g. "=r".
|
||||
Operand *newOperand(StringRef constraint);
|
||||
|
||||
// Create a constant integer operand.
|
||||
Operand *newConstantOperand(int64_t v);
|
||||
// Create a constant operand with explicit code specified.
|
||||
Operand *newConstantOperand(const std::string &v);
|
||||
|
||||
Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0);
|
||||
|
||||
llvm::SmallVector<Operand *, 4> getAllArgs() const;
|
||||
|
||||
llvm::SmallVector<Value, 4> getAllMLIRArgs() const;
|
||||
|
||||
std::string getConstraints() const;
|
||||
|
||||
std::string dump() const;
|
||||
|
||||
mlir::Value launch(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Type resTy, bool hasSideEffect = true,
|
||||
bool isAlignStack = false,
|
||||
ArrayRef<Attribute> attrs = {}) const;
|
||||
|
||||
private:
|
||||
Operand *newOperand() {
|
||||
argArchive.emplace_back(std::make_unique<Operand>());
|
||||
return argArchive.back().get();
|
||||
}
|
||||
|
||||
// Make the oprands in argArchive follow the provided \param order.
|
||||
void reorderArgArchive(ArrayRef<Operand *> order) {
|
||||
assert(order.size() == argArchive.size());
|
||||
// The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but
|
||||
// it do necessary when onlyAttachMLIRArgs is true for the $0,$1.. are
|
||||
// determined by PTX code snippet passed from external.
|
||||
sort(argArchive.begin(), argArchive.end(),
|
||||
[&](std::unique_ptr<Operand> &a, std::unique_ptr<Operand> &b) {
|
||||
auto ida = std::find(order.begin(), order.end(), a.get());
|
||||
auto idb = std::find(order.begin(), order.end(), b.get());
|
||||
assert(ida != order.end());
|
||||
assert(idb != order.end());
|
||||
return ida < idb;
|
||||
});
|
||||
}
|
||||
|
||||
friend struct PTXInstr;
|
||||
friend struct PTXInstrCommon;
|
||||
|
||||
protected:
|
||||
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
|
||||
llvm::SmallVector<std::unique_ptr<PTXInstrCommon>, 2> instrs;
|
||||
llvm::SmallVector<std::unique_ptr<PTXInstrExecution>, 4> executions;
|
||||
int oprCounter{};
|
||||
};
|
||||
|
||||
// PTX instruction common interface.
|
||||
// Put the generic logic for all the instructions here.
|
||||
struct PTXInstrCommon {
|
||||
explicit PTXInstrCommon(PTXBuilder *builder) : builder(builder) {}
|
||||
|
||||
using Operand = PTXBuilder::Operand;
|
||||
|
||||
// clang-format off
|
||||
PTXInstrExecution& operator()() { return call({}); }
|
||||
PTXInstrExecution& operator()(Operand* a) { return call({a}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d) { return call({a, b, c, d}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { return call({a, b, c, d, e}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { return call({a, b, c, d, e, f}); }
|
||||
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { return call({a, b, c, d, e, f, g}); }
|
||||
// clang-format on
|
||||
|
||||
// Set operands of this instruction.
|
||||
PTXInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs = false);
|
||||
|
||||
protected:
|
||||
// "Call" the instruction with operands.
|
||||
// \param oprs The operands of this instruction.
|
||||
// \param onlyAttachMLIRArgs Indicate that it simply attach the MLIR Arguments
|
||||
// to the inline Asm without generating the operand ids(such as $0, $1) in PTX
|
||||
// code.
|
||||
PTXInstrExecution &call(llvm::ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs = false);
|
||||
|
||||
PTXBuilder *builder{};
|
||||
llvm::SmallVector<std::string, 4> instrParts;
|
||||
|
||||
friend struct PTXInstrExecution;
|
||||
};
|
||||
|
||||
template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
|
||||
using Operand = PTXBuilder::Operand;
|
||||
|
||||
explicit PTXInstrBase(PTXBuilder *builder, const std::string &name)
|
||||
: PTXInstrCommon(builder) {
|
||||
o(name);
|
||||
}
|
||||
|
||||
// Append a suffix to the instruction.
|
||||
// e.g. PTXInstr("add").o("s32") get a add.s32.
|
||||
// A predicate is used to tell whether to apply the suffix, so that no if-else
|
||||
// code needed. e.g. `PTXInstr("add").o("s32", isS32).o("u32", !isS32);` will
|
||||
// get a `add.s32` if isS32 is true.
|
||||
ConcreteT &o(const std::string &suffix, bool predicate = true) {
|
||||
if (predicate)
|
||||
instrParts.push_back(suffix);
|
||||
return *static_cast<ConcreteT *>(this);
|
||||
}
|
||||
};
|
||||
|
||||
struct PTXInstr : public PTXInstrBase<PTXInstr> {
|
||||
using PTXInstrBase<PTXInstr>::PTXInstrBase;
|
||||
|
||||
// Append a ".global" to the instruction.
|
||||
PTXInstr &global();
|
||||
|
||||
// Append a ".shared" to the instruction.
|
||||
PTXInstr &shared();
|
||||
|
||||
// Append a ".v[0-9]+" to the instruction
|
||||
PTXInstr &v(int vecWidth, bool predicate = true);
|
||||
|
||||
// Append a".b[0-9]+" to the instruction
|
||||
PTXInstr &b(int width);
|
||||
};
|
||||
|
||||
// Record the operands and context for "launching" a PtxInstr.
|
||||
struct PTXInstrExecution {
|
||||
using Operand = PTXBuilder::Operand;
|
||||
|
||||
llvm::SmallVector<Operand *> argsInOrder;
|
||||
|
||||
PTXInstrExecution() = default;
|
||||
explicit PTXInstrExecution(PTXInstrCommon *instr,
|
||||
llvm::ArrayRef<Operand *> oprs,
|
||||
bool onlyAttachMLIRArgs)
|
||||
: argsInOrder(oprs.begin(), oprs.end()), instr(instr),
|
||||
onlyAttachMLIRArgs(onlyAttachMLIRArgs) {}
|
||||
|
||||
// Prefix a predicate to the instruction.
|
||||
PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
|
||||
pred = instr->builder->newOperand(value, constraint);
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Prefix a !predicate to the instruction.
|
||||
PTXInstrExecution &predicateNot(mlir::Value value, StringRef constraint) {
|
||||
pred = instr->builder->newOperand(value, constraint);
|
||||
pred->repr = [](int idx) { return "@!$" + std::to_string(idx); };
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::string dump() const;
|
||||
|
||||
SmallVector<Operand *> getArgList() const;
|
||||
|
||||
PTXInstrCommon *instr{};
|
||||
Operand *pred{};
|
||||
bool onlyAttachMLIRArgs{};
|
||||
};
|
||||
|
||||
//// =============================== Some instruction wrappers
|
||||
///===============================
|
||||
// We add the wrappers to make the usage more intuitive by avoiding mixing the
|
||||
// PTX code with some trivial C++ code.
|
||||
|
||||
struct PTXCpAsyncLoadInstr : PTXInstrBase<PTXCpAsyncLoadInstr> {
|
||||
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
|
||||
triton::CacheModifier modifier)
|
||||
: PTXInstrBase(builder, "cp.async") {
|
||||
o(triton::stringifyCacheModifier(modifier).str());
|
||||
o("shared");
|
||||
o("global");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
|
43
include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h
Normal file
43
include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h
Normal file
@@ -0,0 +1,43 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
|
||||
#define TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ModuleOp;
|
||||
template <typename T> class OperationPass;
|
||||
|
||||
class TritonLLVMConversionTarget : public ConversionTarget {
|
||||
public:
|
||||
explicit TritonLLVMConversionTarget(MLIRContext &ctx,
|
||||
mlir::LLVMTypeConverter &typeConverter);
|
||||
};
|
||||
|
||||
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
|
||||
public:
|
||||
explicit TritonLLVMFunctionConversionTarget(
|
||||
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter);
|
||||
};
|
||||
|
||||
namespace triton {
|
||||
|
||||
// Names for identifying different NVVM annotations. It is used as attribute
|
||||
// names in MLIR modules. Refer to
|
||||
// https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#supported-properties for
|
||||
// the full list.
|
||||
struct NVVMMetadataField {
|
||||
static constexpr char MaxNTid[] = "nvvm.maxntid";
|
||||
static constexpr char Kernel[] = "nvvm.kernel";
|
||||
};
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonGPUToLLVMPass(int computeCapability = 80);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
@@ -0,0 +1,25 @@
|
||||
#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_
|
||||
#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ModuleOp;
|
||||
template <typename T> class OperationPass;
|
||||
|
||||
namespace triton {
|
||||
|
||||
constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
|
||||
|
||||
// Create the pass with numWarps passed from cl::opt.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
|
||||
|
||||
// Create the pass with numWarps set explicitly.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertTritonToTritonGPUPass(int numWarps);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
2
include/triton/Dialect/CMakeLists.txt
Normal file
2
include/triton/Dialect/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(Triton)
|
||||
add_subdirectory(TritonGPU)
|
2
include/triton/Dialect/Triton/CMakeLists.txt
Normal file
2
include/triton/Dialect/Triton/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
19
include/triton/Dialect/Triton/IR/CMakeLists.txt
Normal file
19
include/triton/Dialect/Triton/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,19 @@
|
||||
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TritonDialect.td)
|
||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
|
||||
mlir_tablegen(Types.h.inc -gen-typedef-decls)
|
||||
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td)
|
||||
mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
|
||||
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
|
||||
|
||||
add_public_tablegen_target(TritonTableGen)
|
48
include/triton/Dialect/Triton/IR/Dialect.h
Normal file
48
include/triton/Dialect/Triton/IR/Dialect.h
Normal file
@@ -0,0 +1,48 @@
|
||||
#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
|
||||
#define TRITON_DIALECT_TRITON_IR_DIALECT_H_
|
||||
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
|
||||
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
|
||||
#include "triton/Dialect/Triton/IR/Traits.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/Triton/IR/Ops.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
class DialectInferLayoutInterface
|
||||
: public DialectInterface::Base<DialectInferLayoutInterface> {
|
||||
public:
|
||||
DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}
|
||||
|
||||
virtual LogicalResult
|
||||
inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
|
||||
Attribute &resultEncoding) const = 0;
|
||||
|
||||
virtual LogicalResult
|
||||
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
|
||||
Attribute &resultEncoding,
|
||||
Optional<Location> location) const = 0;
|
||||
|
||||
// Note: this function only verify operand encoding but doesn't infer result
|
||||
// encoding
|
||||
virtual LogicalResult
|
||||
inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
|
||||
Attribute retEncoding,
|
||||
Optional<Location> location) const = 0;
|
||||
};
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_IR_DIALECT_H_
|
9
include/triton/Dialect/Triton/IR/Interfaces.h
Normal file
9
include/triton/Dialect/Triton/IR/Interfaces.h
Normal file
@@ -0,0 +1,9 @@
|
||||
#ifndef TRITON_IR_INTERFACES_H_
|
||||
#define TRITON_IR_INTERFACES_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
|
||||
|
||||
#endif // TRITON_IR_TYPES_H_
|
60
include/triton/Dialect/Triton/IR/Traits.h
Normal file
60
include/triton/Dialect/Triton/IR/Traits.h
Normal file
@@ -0,0 +1,60 @@
|
||||
#ifndef TRITON_IR_TRAITS_H_
|
||||
#define TRITON_IR_TRAITS_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
|
||||
// These functions are out-of-line implementations of the methods in the
|
||||
// corresponding trait classes. This avoids them being template
|
||||
// instantiated/duplicated.
|
||||
namespace impl {
|
||||
LogicalResult verifySameOperandsAndResultEncoding(Operation *op);
|
||||
LogicalResult verifySameOperandsEncoding(Operation *op);
|
||||
// The rationale for this trait is to prevent users from creating programs
|
||||
// that would have catastrophic register pressure and cause the compiler to
|
||||
// hang.
|
||||
// Since H100 has 256KB registers, we should allow users to create tensors
|
||||
// of size up to 256K elements. It will spill for datatypes wider than 1B,
|
||||
// but we probably should limit number of elements (rather than bytes) to
|
||||
// keep specs simple
|
||||
int constexpr maxTensorNumElements = 1048576;
|
||||
LogicalResult verifyTensorSize(Operation *op);
|
||||
} // namespace impl
|
||||
|
||||
template <class ConcreteType>
|
||||
class TensorSizeTrait : public TraitBase<ConcreteType, TensorSizeTrait> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifyTensorSize(op);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ConcreteType>
|
||||
class SameOperandsAndResultEncoding
|
||||
: public TraitBase<ConcreteType, SameOperandsAndResultEncoding> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifySameOperandsAndResultEncoding(op);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ConcreteType>
|
||||
class SameOperandsEncoding
|
||||
: public TraitBase<ConcreteType, SameOperandsEncoding> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifySameOperandsEncoding(op);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
68
include/triton/Dialect/Triton/IR/TritonAttrDefs.td
Normal file
68
include/triton/Dialect/Triton/IR/TritonAttrDefs.td
Normal file
@@ -0,0 +1,68 @@
|
||||
#ifndef TRITON_ATTR_DEFS
|
||||
#define TRITON_ATTR_DEFS
|
||||
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
|
||||
// Attrs for LoadOp
|
||||
def TT_CacheModifierAttr : I32EnumAttr<
|
||||
"CacheModifier", "",
|
||||
[
|
||||
I32EnumAttrCase<"NONE", 1, "none">,
|
||||
I32EnumAttrCase<"CA", 2, "ca">,
|
||||
I32EnumAttrCase<"CG", 3, "cg">,
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
def TT_EvictionPolicyAttr : I32EnumAttr<
|
||||
"EvictionPolicy", "",
|
||||
[
|
||||
I32EnumAttrCase<"NORMAL", 1, "evict_normal">,
|
||||
I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">,
|
||||
I32EnumAttrCase<"EVICT_LAST", 3, "evict_last">
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
// reduction
|
||||
def TT_RedOpAttr : I32EnumAttr<
|
||||
/*name*/"RedOp", /*summary*/"",
|
||||
/*case*/
|
||||
[
|
||||
I32EnumAttrCase</*sym*/"ADD", 1, /*str*/"add">,
|
||||
I32EnumAttrCase<"FADD", 2, "fadd">,
|
||||
I32EnumAttrCase<"MIN", 3, "min">,
|
||||
I32EnumAttrCase<"MAX", 4, "max">,
|
||||
I32EnumAttrCase<"UMIN", 5, "umin">,
|
||||
I32EnumAttrCase<"UMAX", 6, "umax">,
|
||||
I32EnumAttrCase<"ARGMIN", 7, "argmin">,
|
||||
I32EnumAttrCase<"ARGMAX", 8, "argmax">,
|
||||
I32EnumAttrCase<"ARGUMIN", 9, "argumin">,
|
||||
I32EnumAttrCase<"ARGUMAX", 10, "argumax">,
|
||||
I32EnumAttrCase<"FMIN", 11, "fmin">,
|
||||
I32EnumAttrCase<"FMAX", 12, "fmax">,
|
||||
I32EnumAttrCase<"ARGFMIN", 13, "argfmin">,
|
||||
I32EnumAttrCase<"ARGFMAX", 14, "argfmax">,
|
||||
I32EnumAttrCase<"XOR", 15, "xor">
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
// atomic
|
||||
def TT_AtomicRMWAttr : I32EnumAttr<
|
||||
"RMWOp", "",
|
||||
[
|
||||
I32EnumAttrCase<"AND", 1, "and">,
|
||||
I32EnumAttrCase<"OR", 2, "or">,
|
||||
I32EnumAttrCase<"XOR", 3, "xor">,
|
||||
I32EnumAttrCase<"ADD", 4, "add">,
|
||||
I32EnumAttrCase<"FADD", 5, "fadd">,
|
||||
I32EnumAttrCase<"MAX", 6, "max">,
|
||||
I32EnumAttrCase<"MIN", 7, "min">,
|
||||
I32EnumAttrCase<"UMAX", 8, "umax">,
|
||||
I32EnumAttrCase<"UMIN", 9, "umin">,
|
||||
I32EnumAttrCase<"XCHG", 10, "exch">
|
||||
]> {
|
||||
let cppNamespace = "::mlir::triton";
|
||||
}
|
||||
|
||||
#endif
|
46
include/triton/Dialect/Triton/IR/TritonDialect.td
Normal file
46
include/triton/Dialect/Triton/IR/TritonDialect.td
Normal file
@@ -0,0 +1,46 @@
|
||||
#ifndef TRITON_DIALECT
|
||||
#define TRITON_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Triton_Dialect : Dialect {
|
||||
let name = "tt";
|
||||
|
||||
let cppNamespace = "::mlir::triton";
|
||||
|
||||
let summary = "The Triton IR in MLIR";
|
||||
|
||||
let description = [{
|
||||
Triton Dialect.
|
||||
|
||||
Dependent Dialects:
|
||||
* Arithmetic:
|
||||
* addf, addi, andi, cmpf, cmpi, divf, fptosi, ...
|
||||
* Math:
|
||||
* exp, sin, cos, log, ...
|
||||
* StructuredControlFlow:
|
||||
* ForOp, IfOp, WhileOp, YieldOp, ConditionOp
|
||||
}];
|
||||
|
||||
let dependentDialects = [
|
||||
"arith::ArithmeticDialect",
|
||||
"math::MathDialect",
|
||||
"StandardOpsDialect",
|
||||
"scf::SCFDialect",
|
||||
|
||||
// Since LLVM 15
|
||||
// "cf::ControlFlowDialect",
|
||||
// "func::FuncDialect"
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
void registerTypes();
|
||||
}];
|
||||
|
||||
let hasConstantMaterializer = 1;
|
||||
}
|
||||
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
|
||||
|
||||
#endif // TRITON_DIALECT
|
6
include/triton/Dialect/Triton/IR/TritonInterfaces.td
Normal file
6
include/triton/Dialect/Triton/IR/TritonInterfaces.td
Normal file
@@ -0,0 +1,6 @@
|
||||
#ifndef TRITON_INTERFACES
|
||||
#define TRITON_INTERFACES
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
#endif // TRITON_INTERFACES
|
430
include/triton/Dialect/Triton/IR/TritonOps.td
Normal file
430
include/triton/Dialect/Triton/IR/TritonOps.td
Normal file
@@ -0,0 +1,430 @@
|
||||
#ifndef TRITON_OPS
|
||||
#define TRITON_OPS
|
||||
|
||||
include "triton/Dialect/Triton/IR/TritonDialect.td"
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
|
||||
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
|
||||
|
||||
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
|
||||
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
|
||||
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
|
||||
|
||||
//
|
||||
// Op Base
|
||||
//
|
||||
class TT_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])> {
|
||||
}
|
||||
|
||||
//
|
||||
// CastOps
|
||||
//
|
||||
// Use cast ops in arith:
|
||||
// bitcast
|
||||
// fptoui, fptosi, uitofp, sitofp,
|
||||
// extf, tructf,
|
||||
// extui, extsi, tructi
|
||||
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Cast int64 to pointer";
|
||||
|
||||
let arguments = (ins TT_I64Like:$from);
|
||||
|
||||
let results = (outs TT_PtrLike:$result);
|
||||
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
}
|
||||
|
||||
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Cast pointer to int64";
|
||||
|
||||
let arguments = (ins TT_PtrLike:$from);
|
||||
|
||||
let results = (outs TT_I64Like:$result);
|
||||
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
}
|
||||
|
||||
// arith.bitcast doesn't support pointers
|
||||
def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Cast between types of the same bitwidth";
|
||||
|
||||
let arguments = (ins TT_Type:$from);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
|
||||
// TODO: Add verifier
|
||||
}
|
||||
|
||||
def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
NoSideEffect,
|
||||
DeclareOpInterfaceMethods<CastOpInterface>]> {
|
||||
let summary = "Floating point casting for custom types";
|
||||
|
||||
let description = [{
|
||||
Floating point casting for custom types (F8).
|
||||
|
||||
F8 <-> FP16, BF16, FP32, FP64
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_FloatLike:$from);
|
||||
|
||||
let results = (outs TT_FloatLike:$result);
|
||||
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
|
||||
// TODO: We need a verifier here.
|
||||
}
|
||||
|
||||
//
|
||||
// Pointer Arith Ops
|
||||
//
|
||||
|
||||
def TT_AddPtrOp : TT_Op<"addptr",
|
||||
[NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
TypesMatchWith<"result type matches ptr type",
|
||||
"result", "ptr", "$_self">,
|
||||
TypesMatchWith<"result shape matches offset shape",
|
||||
"result", "offset",
|
||||
"getI32SameShape($_self)">]> {
|
||||
let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
|
||||
|
||||
let results = (outs TT_PtrLike:$result);
|
||||
|
||||
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Load/Store Ops
|
||||
//
|
||||
def TT_LoadOp : TT_Op<"load",
|
||||
[SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
AttrSizedOperandSegments,
|
||||
MemoryEffects<[MemRead]>,
|
||||
TypesMatchWith<"infer ptr type from result type",
|
||||
"result", "ptr", "getPointerTypeSameShape($_self)">,
|
||||
TypesMatchWith<"infer mask type from result type or none",
|
||||
"result", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
|
||||
TypesMatchWith<"infer other type from result type or none",
|
||||
"result", "other", "$_self",
|
||||
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
|
||||
let summary = "load";
|
||||
|
||||
let arguments = (ins TT_PtrLike:$ptr, Optional<TT_BoolLike>:$mask, Optional<TT_Type>:$other,
|
||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||
BoolAttr:$isVolatile);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
|
||||
];
|
||||
|
||||
// let assemblyFormat = "operands attr-dict `:` type($result)";
|
||||
let parser = [{ return mlir::triton::parseLoadOp(parser, result); }];
|
||||
|
||||
let printer = [{ return mlir::triton::printLoadOp(p, *this); }];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TT_StoreOp : TT_Op<"store",
|
||||
[SameOperandsShape,
|
||||
SameOperandsEncoding,
|
||||
MemoryEffects<[MemWrite]>,
|
||||
TypesMatchWith<"infer ptr type from value type",
|
||||
"value", "ptr",
|
||||
"getPointerTypeSameShape($_self)">,
|
||||
TypesMatchWith<"infer mask type from value type",
|
||||
"value", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
|
||||
let summary = "store";
|
||||
|
||||
let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
|
||||
];
|
||||
|
||||
// let assemblyFormat = "operands attr-dict `:` type($value)";
|
||||
let parser = [{ return mlir::triton::parseStoreOp(parser, result); }];
|
||||
|
||||
let printer = [{ return mlir::triton::printStoreOp(p, *this); }];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
//
|
||||
// Atomic Op
|
||||
//
|
||||
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
MemoryEffects<[MemRead]>,
|
||||
MemoryEffects<[MemWrite]>,
|
||||
TypesMatchWith<"infer ptr type from value type",
|
||||
"val", "ptr",
|
||||
"getPointerTypeSameShape($_self)">,
|
||||
TypesMatchWith<"infer mask type from value type",
|
||||
"val", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
|
||||
let summary = "atomic rmw";
|
||||
|
||||
let description = [{
|
||||
load data at $ptr, do $rmw_op with $val, and store result to $ptr.
|
||||
|
||||
return old value at $ptr
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr,
|
||||
TT_Type:$val, Optional<TT_BoolLike>:$mask);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>,
|
||||
MemoryEffects<[MemWrite]>,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding]> {
|
||||
let summary = "atomic cas";
|
||||
|
||||
let description = [{
|
||||
compare $cmp with data $old at location $ptr,
|
||||
|
||||
if $old == $cmp, store $val to $ptr,
|
||||
|
||||
else store $old to $ptr,
|
||||
|
||||
return $old
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_Ptr:$ptr, TT_Type:$cmp, TT_Type:$val);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Shape Manipulation Ops
|
||||
//
|
||||
def TT_SplatOp : TT_Op<"splat", [NoSideEffect,
|
||||
SameOperandsAndResultElementType]> {
|
||||
let summary = "splat";
|
||||
|
||||
let arguments = (ins TT_Type:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
SameOperandsAndResultElementType]> {
|
||||
let summary = "expand_dims";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src, I32Attr:$axis);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TT_ViewOp : TT_Op<"view", [NoSideEffect,
|
||||
SameOperandsAndResultElementType]> {
|
||||
let summary = "view";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
}
|
||||
|
||||
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
|
||||
SameOperandsAndResultElementType]> {
|
||||
let summary = "broadcast. No left-padding as of now.";
|
||||
|
||||
let arguments = (ins TT_Type:$src);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TT_CatOp : TT_Op<"cat", [NoSideEffect,
|
||||
SameOperandsAndResultElementType]> {
|
||||
let summary = "concatenate 2 tensors";
|
||||
|
||||
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TT_TransOp : TT_Op<"trans", [NoSideEffect,
|
||||
SameOperandsAndResultElementType]> {
|
||||
|
||||
let summary = "transpose a tensor";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
//
|
||||
// SPMD Ops
|
||||
//
|
||||
def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoSideEffect]> {
|
||||
let arguments = (ins I32Attr:$axis);
|
||||
|
||||
let results = (outs I32:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoSideEffect]> {
|
||||
let arguments = (ins I32Attr:$axis);
|
||||
|
||||
let results = (outs I32:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
//
|
||||
// Dot Op
|
||||
//
|
||||
def TT_DotOp : TT_Op<"dot", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
TypesMatchWith<"result's type matches accumulator's type",
|
||||
"d", "c", "$_self">]> {
|
||||
let summary = "dot";
|
||||
|
||||
let description = [{
|
||||
$d = matrix_multiply($a, $b) + $c
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
|
||||
|
||||
let results = (outs TT_FpIntTensor:$d);
|
||||
|
||||
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
|
||||
}
|
||||
|
||||
//
|
||||
// Reduce Op
|
||||
//
|
||||
def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "reduce";
|
||||
|
||||
let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
|
||||
];
|
||||
|
||||
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// This member function is marked static because we need to call it before the ReduceOp
|
||||
// is constructed, see the implementation of create_reduce in triton.cc.
|
||||
static bool withIndex(mlir::triton::RedOp redOp);
|
||||
}];
|
||||
}
|
||||
|
||||
//
|
||||
// External elementwise op
|
||||
//
|
||||
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
SameVariadicOperandSize]> {
|
||||
let summary = "ext_elemwise";
|
||||
|
||||
let description = [{
|
||||
call an external function $symbol implemented in $libpath/$libname with $args
|
||||
|
||||
return $libpath/$libname:$symbol($args...)
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<TT_Type>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
|
||||
}
|
||||
|
||||
//
|
||||
// Make Range Op
|
||||
//
|
||||
// TODO: should have ConstantLike as Trait
|
||||
def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
|
||||
let summary = "make range";
|
||||
|
||||
let description = [{
|
||||
Returns an 1D int32 tensor.
|
||||
|
||||
Values span from $start to $end (exclusive), with step = 1
|
||||
}];
|
||||
|
||||
let arguments = (ins I32Attr:$start, I32Attr:$end);
|
||||
|
||||
let results = (outs TT_IntTensor:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
//
|
||||
// Make PrintfOp
|
||||
//
|
||||
def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
|
||||
Arguments<(ins StrAttr:$prefix,
|
||||
Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
|
||||
let summary = "Device-side printf, as in CUDA for debugging";
|
||||
let description = [{
|
||||
`tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
|
||||
format are generated automatically from the arguments.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
$prefix attr-dict ($args^ `:` type($args))?
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // Triton_OPS
|
71
include/triton/Dialect/Triton/IR/TritonTypes.td
Normal file
71
include/triton/Dialect/Triton/IR/TritonTypes.td
Normal file
@@ -0,0 +1,71 @@
|
||||
#ifndef TRITON_TYPES
|
||||
#define TRITON_TYPES
|
||||
|
||||
include "triton/Dialect/Triton/IR/TritonDialect.td"
|
||||
|
||||
//
|
||||
// Types
|
||||
//
|
||||
class TritonTypeDef<string name, string _mnemonic>
|
||||
: TypeDef<Triton_Dialect, name> {
|
||||
// Used by printer/parser
|
||||
let mnemonic = _mnemonic;
|
||||
}
|
||||
|
||||
// Floating-point Type
|
||||
def F8 : TritonTypeDef<"Float8", "f8">;
|
||||
|
||||
def TT_Float : AnyTypeOf<[F8, F16, BF16, F32, F64], "floating-point">;
|
||||
def TT_FloatTensor : TensorOf<[TT_Float]>;
|
||||
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
|
||||
|
||||
// Boolean Type
|
||||
// TT_Bool -> I1
|
||||
def TT_BoolTensor : TensorOf<[I1]>;
|
||||
def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;
|
||||
|
||||
// Integer Type
|
||||
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
|
||||
def TT_IntTensor : TensorOf<[TT_Int]>;
|
||||
def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>;
|
||||
|
||||
// I32 Type
|
||||
// TT_I32 -> I32
|
||||
// TT_I32Tensor -> I32Tensor
|
||||
def TT_I32Like: AnyTypeOf<[I32, I32Tensor]>;
|
||||
|
||||
// I64 Type
|
||||
// TT_I64 -> I64
|
||||
// TT_I64Tensor -> I64Tensor
|
||||
def TT_I64Like: AnyTypeOf<[I64, I64Tensor]>;
|
||||
|
||||
// Pointer Type
|
||||
def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
|
||||
let summary = "pointer type";
|
||||
|
||||
let description = [{
|
||||
Triton PointerType
|
||||
}];
|
||||
|
||||
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
|
||||
|
||||
let builders = [
|
||||
TypeBuilderWithInferredContext<(ins
|
||||
"Type":$pointeeType,
|
||||
"int":$addressSpace
|
||||
), [{
|
||||
return $_get(pointeeType.getContext(), pointeeType, addressSpace);
|
||||
}]>
|
||||
];
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
|
||||
def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>;
|
||||
|
||||
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntTensor]>;
|
||||
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
|
||||
|
||||
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike]>;
|
||||
|
||||
#endif
|
10
include/triton/Dialect/Triton/IR/Types.h
Normal file
10
include/triton/Dialect/Triton/IR/Types.h
Normal file
@@ -0,0 +1,10 @@
|
||||
#ifndef TRITON_IR_TYPES_H_
|
||||
#define TRITON_IR_TYPES_H_
|
||||
|
||||
#include "mlir/IR/TypeSupport.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "triton/Dialect/Triton/IR/Types.h.inc"
|
||||
|
||||
#endif // TRITON_IR_TYPES_H_
|
3
include/triton/Dialect/Triton/Transforms/CMakeLists.txt
Normal file
3
include/triton/Dialect/Triton/Transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton)
|
||||
add_public_tablegen_target(TritonTransformsIncGen)
|
18
include/triton/Dialect/Triton/Transforms/Passes.h
Normal file
18
include/triton/Dialect/Triton/Transforms/Passes.h
Normal file
@@ -0,0 +1,18 @@
|
||||
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
|
||||
#define TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<Pass> createCombineOpsPass();
|
||||
|
||||
} // namespace triton
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
23
include/triton/Dialect/Triton/Transforms/Passes.td
Normal file
23
include/triton/Dialect/Triton/Transforms/Passes.td
Normal file
@@ -0,0 +1,23 @@
|
||||
#ifndef TRITON_PASSES
|
||||
#define TRITON_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp"> {
|
||||
let summary = "combine ops";
|
||||
let description = [{
|
||||
dot(a, b, 0) + c => dot(a, b, c)
|
||||
|
||||
addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1))
|
||||
|
||||
select(cond, load(ptrs, broadcast(cond), ???), other) =>
|
||||
load(ptrs, broadcast(cond), other)
|
||||
}];
|
||||
|
||||
let constructor = "mlir::triton::createCombineOpsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
||||
/*SelectOp*/"mlir::StandardOpsDialect"];
|
||||
}
|
||||
|
||||
#endif
|
2
include/triton/Dialect/TritonGPU/CMakeLists.txt
Normal file
2
include/triton/Dialect/TritonGPU/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
12
include/triton/Dialect/TritonGPU/IR/CMakeLists.txt
Normal file
12
include/triton/Dialect/TritonGPU/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td)
|
||||
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu)
|
||||
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
add_public_tablegen_target(TritonGPUTableGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
|
||||
mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls)
|
||||
mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs)
|
||||
add_public_tablegen_target(TritonGPUAttrDefsIncGen)
|
||||
|
46
include/triton/Dialect/TritonGPU/IR/Dialect.h
Normal file
46
include/triton/Dialect/TritonGPU/IR/Dialect.h
Normal file
@@ -0,0 +1,46 @@
|
||||
#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
||||
#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
// TritonGPU depends on Triton
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
|
||||
#include "triton/Dialect/TritonGPU/IR/Traits.h"
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
|
||||
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
namespace gpu {
|
||||
|
||||
unsigned getElemsPerThread(Type type);
|
||||
|
||||
SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getSizePerThread(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getContigPerThread(Attribute layout);
|
||||
|
||||
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
|
||||
|
||||
SmallVector<unsigned> getOrder(const Attribute &layout);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
31
include/triton/Dialect/TritonGPU/IR/Traits.h
Normal file
31
include/triton/Dialect/TritonGPU/IR/Traits.h
Normal file
@@ -0,0 +1,31 @@
|
||||
#ifndef TRITON_GPU_IR_TRAITS_H_
|
||||
#define TRITON_GPU_IR_TRAITS_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace OpTrait {
|
||||
|
||||
// These functions are out-of-line implementations of the methods in the
|
||||
// corresponding trait classes. This avoids them being template
|
||||
// instantiated/duplicated.
|
||||
namespace impl {
|
||||
LogicalResult verifyResultsAreSharedEncoding(Operation *op);
|
||||
} // namespace impl
|
||||
|
||||
template <typename ConcreteType>
|
||||
class ResultsAreSharedEncoding
|
||||
: public TraitBase<ConcreteType, ResultsAreSharedEncoding> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifyResultsAreSharedEncoding(op);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
430
include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Normal file
430
include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Normal file
@@ -0,0 +1,430 @@
|
||||
#ifndef TRITONGPU_ATTRDEFS
|
||||
#define TRITONGPU_ATTRDEFS
|
||||
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
||||
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TritonGPU Attribute Definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class TritonGPU_Attr<string name, list<Trait> traits = [],
|
||||
string baseCppClass = "::mlir::Attribute">
|
||||
: AttrDef<TritonGPU_Dialect, name, traits, baseCppClass> {
|
||||
|
||||
let description = [{
|
||||
TritonGPU Tensors differ from usual tensors in that they contain a _layout_ attribute which determines
|
||||
how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function
|
||||
\mathcal{L} that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers T corresponding
|
||||
to the indices of the CUDA threads allowed to access some data at index $i$.
|
||||
|
||||
For example, let us consider the layout function:
|
||||
\mathcal{L}(0, 0) = {0, 4}
|
||||
\mathcal{L}(0, 1) = {1, 5}
|
||||
\mathcal{L}(1, 0) = {2, 6}
|
||||
\mathcal{L}(1, 1) = {3, 7}
|
||||
|
||||
Then, attaching $\mathcal{L} to a tensor $T$ would mean that:
|
||||
- T[0,0] is owned by both cuda thread 0 and 4
|
||||
- T[0,1] is owned by both cuda thread 1 and 5
|
||||
- T[1,0] is owned by both cuda thread 2 and 6
|
||||
- T[1,1] is owned by both cuda thread 3 and 7
|
||||
|
||||
Right now, Triton implements two classes of layouts: shared, and distributed.
|
||||
}];
|
||||
|
||||
code extraBaseClassDeclaration = [{
|
||||
unsigned getElemsPerThread(ArrayRef<int64_t> shape) const;
|
||||
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding"> {
|
||||
let mnemonic = "shared";
|
||||
|
||||
let description = [{
|
||||
An encoding for tensors whose elements may be simultaneously accessed by
|
||||
different cuda threads in the programs, via shared memory. In other words,
|
||||
for all indices i \in R^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.
|
||||
|
||||
In order to avoid shared memory bank conflicts, elements may be swizzled
|
||||
in memory. For example, a swizzled row-major layout could store its data
|
||||
as follows:
|
||||
|
||||
A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2
|
||||
A_{1, 0} A_{1, 1} A_{1, 2} A_{1, 3} ... [phase 0] /
|
||||
groups of vec=2 elements
|
||||
are stored contiguously
|
||||
_ _ _ _ /\_ _ _ _
|
||||
A_{2, 2} A_{2, 3} A_{2, 0} A_{2, 1} ... [phase 1] \ per phase = 2
|
||||
A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
// swizzle info
|
||||
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase,
|
||||
ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
|
||||
);
|
||||
|
||||
let builders = [
|
||||
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
|
||||
"ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"Type":$eltTy), [{
|
||||
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
|
||||
|
||||
if(!mmaEnc)
|
||||
return $_get(context, 1, 1, 1, order);
|
||||
|
||||
int version = mmaEnc.getVersion();
|
||||
int opIdx = dotOpEnc.getOpIdx();
|
||||
|
||||
// number of rows per phase
|
||||
int perPhase = 128 / (shape[order[0]] * (eltTy.getIntOrFloatBitWidth() / 8));
|
||||
perPhase = std::max<int>(perPhase, 1);
|
||||
|
||||
// index of the inner dimension in `order`
|
||||
unsigned inner = (opIdx == 0) ? 0 : 1;
|
||||
|
||||
// ---- begin version 1 ----
|
||||
if (version == 1) {
|
||||
bool is_row = order[0] != 0;
|
||||
bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
|
||||
is_row && (shape[order[0]] <= 16);
|
||||
// TODO[Superjomn]: Support the case when is_vec4=false later
|
||||
// Currently, we only support ld.v2, for the mma layout varies with different ld vector width.
|
||||
is_vec4 = true;
|
||||
int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
|
||||
((is_row && !is_vec4) ? 2 : 1);
|
||||
int rep = 2 * pack_size;
|
||||
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
|
||||
int vec = 2 * rep;
|
||||
return $_get(context, vec, perPhase, maxPhase, order);
|
||||
}
|
||||
|
||||
// ---- begin version 2 ----
|
||||
if (version == 2) {
|
||||
std::vector<size_t> matShape = {8, 8,
|
||||
2 * 64 / eltTy.getIntOrFloatBitWidth()};
|
||||
// for now, disable swizzle when using transposed int8 tensor cores
|
||||
if (eltTy.isInteger(8) && order[0] == inner)
|
||||
return $_get(context, 1, 1, 1, order);
|
||||
|
||||
// --- handle A operand ---
|
||||
if (opIdx == 0) { // compute swizzling for A operand
|
||||
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
|
||||
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return $_get(context, vec, perPhase, maxPhase, order);
|
||||
}
|
||||
|
||||
// --- handle B operand ---
|
||||
if (opIdx == 1) {
|
||||
int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
|
||||
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
|
||||
int maxPhase = mmaStride / perPhase;
|
||||
return $_get(context, vec, perPhase, maxPhase, order);
|
||||
}
|
||||
|
||||
llvm_unreachable("invalid operand index");
|
||||
}
|
||||
|
||||
// ---- not implemented ----
|
||||
llvm_unreachable("unsupported swizzling for provided MMA version");
|
||||
|
||||
|
||||
}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Distributed Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class DistributedEncoding<string name> : TritonGPU_Attr<name> {
|
||||
let description = [{
|
||||
Distributed encodings have a layout function that is entirely characterized
|
||||
by a d-dimensional tensor L. Note that L doesn't need to have the same shape
|
||||
(or even the same rank) as the tensor it is encoding.
|
||||
|
||||
The layout function \mathcal{L} of this layout is then defined, for an
|
||||
index `i` \in R^D, as follows:
|
||||
|
||||
\mathcal{L}(A)[i_d] = L[(i_d + k_d*A.shape[d]) % L.shape[d]] \forall k_d such as i_d + k_d*A.shape[d] < L.shape[d]
|
||||
|
||||
For example, for a tensor/layout pair
|
||||
A = [x x x x x x x x]
|
||||
[x x x x x x x x]
|
||||
L = [0 1 2 3 ]
|
||||
[4 5 6 7 ]
|
||||
[8 9 10 11]
|
||||
[12 13 14 15]
|
||||
|
||||
Then the data of A would be distributed as follow between the 16 CUDA threads:
|
||||
L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
|
||||
{4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ]
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Blocked Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def BlockedEncodingAttr : DistributedEncoding<"BlockedEncoding"> {
|
||||
let mnemonic = "blocked";
|
||||
|
||||
let description = [{
|
||||
An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout
|
||||
used to promote memory coalescing in LoadInst and StoreInst.
|
||||
It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which
|
||||
specify the amount of elements owned by each CUDA thread, warp and CTA respectively.
|
||||
|
||||
For example, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows.
|
||||
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
|
||||
...
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
|
||||
|
||||
for
|
||||
|
||||
#triton_gpu.blocked_layout<{
|
||||
sizePerThread = {2, 2}
|
||||
threadsPerWarp = {8, 4}
|
||||
warpsPerCTA = {1, 2}
|
||||
}>
|
||||
}];
|
||||
|
||||
|
||||
let builders = [
|
||||
// Custom builder initializes sizePerWarp and sizePerCTA automatically
|
||||
// TODO: compiles on MacOS but not linux?
|
||||
// AttrBuilder<(ins "ArrayRef<unsigned>":$sizePerThread,
|
||||
// "ArrayRef<unsigned>":$threadsPerWarp,
|
||||
// "ArrayRef<unsigned>":$warpsPerCTA,
|
||||
// "ArrayRef<unsigned>":$order), [{
|
||||
// int rank = threadsPerWarp.size();
|
||||
// SmallVector<unsigned, 4> sizePerWarp(rank);
|
||||
// SmallVector<unsigned, 4> sizePerCTA(rank);
|
||||
// for (unsigned i = 0; i < rank; i++) {
|
||||
// sizePerWarp.push_back(sizePerThread[i] * threadsPerWarp[i]);
|
||||
// sizePerCTA.push_back(sizePerWarp[i] * warpsPerCTA[i]);
|
||||
// }
|
||||
// return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, sizePerWarp, sizePerCTA);
|
||||
// }]>,
|
||||
// Custom builder initializes sizePerWarp and sizePerCTA automatically
|
||||
// Default builder takes sizePerThread, order and numWarps, and tries to
|
||||
// pack numWarps*32 threads in the provided order for use in a type
|
||||
// of the given shape.
|
||||
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
|
||||
"ArrayRef<unsigned>":$sizePerThread,
|
||||
"ArrayRef<unsigned>":$order,
|
||||
"unsigned":$numWarps), [{
|
||||
int rank = sizePerThread.size();
|
||||
unsigned remainingLanes = 32;
|
||||
unsigned remainingThreads = numWarps*32;
|
||||
unsigned remainingWarps = numWarps;
|
||||
unsigned prevLanes = 1;
|
||||
unsigned prevWarps = 1;
|
||||
SmallVector<unsigned, 4> threadsPerWarp(rank);
|
||||
SmallVector<unsigned, 4> warpsPerCTA(rank);
|
||||
for (int _dim = 0; _dim < rank - 1; ++_dim) {
|
||||
int i = order[_dim];
|
||||
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
|
||||
threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
|
||||
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps);
|
||||
remainingWarps /= warpsPerCTA[i];
|
||||
remainingLanes /= threadsPerWarp[i];
|
||||
remainingThreads /= threadsPerCTA;
|
||||
prevLanes *= threadsPerWarp[i];
|
||||
prevWarps *= warpsPerCTA[i];
|
||||
}
|
||||
// Expand the last dimension to fill the remaining lanes and warps
|
||||
threadsPerWarp[order[rank-1]] = 32 / prevLanes;
|
||||
warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
|
||||
|
||||
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
|
||||
|
||||
}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
SliceEncodingAttr squeeze(int axis);
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
ArrayRefParameter<"unsigned">:$sizePerThread,
|
||||
ArrayRefParameter<"unsigned">:$threadsPerWarp,
|
||||
ArrayRefParameter<"unsigned">:$warpsPerCTA,
|
||||
// fastest-changing axis first
|
||||
ArrayRefParameter<
|
||||
"unsigned",
|
||||
"order of axes by the rate of changing"
|
||||
>:$order
|
||||
// These attributes can be inferred from the rest
|
||||
// ArrayRefParameter<"unsigned">:$sizePerWarp,
|
||||
// ArrayRefParameter<"unsigned">:$sizePerCTA
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MMA Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TODO: MMAv1 and MMAv2 should be two instances of the same class
|
||||
|
||||
def MmaEncodingAttr : DistributedEncoding<"MmaEncoding"> {
|
||||
let mnemonic = "mma";
|
||||
|
||||
let description = [{
|
||||
An encoding for tensors that have been produced by tensor cores.
|
||||
It is characterized by two parameters:
|
||||
- A 'version' which specifies the generation the tensor cores
|
||||
whose output is being partitioned: 1 for first-gen tensor cores (Volta),
|
||||
and 2 for second-gen tensor cores (Turing/Ampere).
|
||||
- A `blockTileSize` to indicate how data should be
|
||||
partitioned between warps.
|
||||
|
||||
// -------------------------------- version = 1 --------------------------- //
|
||||
|
||||
For first-gen tensor cores, the implicit warpTileSize is [16, 16].
|
||||
Note: the layout is different from the recommended in PTX ISA
|
||||
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
|
||||
(mma.884 section, FP32 accumulator).
|
||||
|
||||
For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
|
||||
warp 0
|
||||
--------------------------------/\-------------------------------
|
||||
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
|
||||
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
|
||||
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
|
||||
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
|
||||
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
|
||||
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
|
||||
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
|
||||
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
|
||||
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
|
||||
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
|
||||
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
|
||||
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
|
||||
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
|
||||
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
|
||||
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
|
||||
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
|
||||
|
||||
warp 1 = warp0 + 32
|
||||
--------------------------------/\-------------------------------
|
||||
[ 32 32 34 34 40 40 42 42 32 32 34 34 40 40 42 42 ]
|
||||
[ 33 33 35 35 41 41 43 43 33 33 35 35 41 41 43 43 ]
|
||||
[ ............................................................... ]
|
||||
|
||||
|
||||
// -------------------------------- version = 2 --------------------------- //
|
||||
|
||||
For second-gen tensor cores, the implicit warpTileSize is [16, 8].
|
||||
Information about this layout can be found in the official PTX documentation
|
||||
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
|
||||
(mma.16816 section, FP32 accumulator).
|
||||
|
||||
For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
warp 0 warp 1
|
||||
-----------------/\------------- ----------------/\-------------
|
||||
[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35
|
||||
[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39
|
||||
[ .............................. ..............................
|
||||
[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63
|
||||
[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35
|
||||
[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39
|
||||
[ .............................. ..............................
|
||||
[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63
|
||||
|
||||
warp 3 warp 4
|
||||
----------------/\------------- ----------------/\-------------
|
||||
[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99
|
||||
[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103
|
||||
[ .............................. ...............................
|
||||
[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127
|
||||
[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99
|
||||
[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103
|
||||
[ .............................. ...............................
|
||||
[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127
|
||||
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
"unsigned":$version,
|
||||
ArrayRefParameter<"unsigned">:$warpsPerCTA
|
||||
);
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration;
|
||||
}
|
||||
|
||||
def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
|
||||
let mnemonic = "slice";
|
||||
|
||||
let description = [{
|
||||
TODO: improve docs
|
||||
|
||||
A = [x x x x x x x x]
|
||||
|
||||
parent = [0 1 2 3 ]
|
||||
[4 5 6 7 ]
|
||||
[8 9 10 11]
|
||||
[12 13 14 15]
|
||||
dim = 0
|
||||
|
||||
Then the data of A would be distributed as follow between the 16 CUDA threads:
|
||||
L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ]
|
||||
|
||||
This is useful for constructing the inverse layout of an expand_dims operation during some optimization passes.
|
||||
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
"unsigned":$dim,
|
||||
// TODO: constraint here to only take distributed encodings
|
||||
"Attribute":$parent
|
||||
);
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration # [{
|
||||
template<class T>
|
||||
SmallVector<T> paddedShape(ArrayRef<T> shape) const;
|
||||
}];
|
||||
}
|
||||
|
||||
def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
|
||||
let mnemonic = "dot_op";
|
||||
|
||||
let description = [{
|
||||
In TritonGPU dialect, considering `d = tt.dot a, b, c`
|
||||
tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
|
||||
a's opIdx is 0, b's opIdx is 1.
|
||||
The parend field in DotOperandEncodingAttr is the layout of d.
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
"unsigned":$opIdx,
|
||||
"Attribute":$parent
|
||||
);
|
||||
|
||||
let extraClassDeclaration = extraBaseClassDeclaration;
|
||||
}
|
||||
|
||||
#endif
|
36
include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
Normal file
36
include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
Normal file
@@ -0,0 +1,36 @@
|
||||
#ifndef TRITONGPU_DIALECT
|
||||
#define TRITONGPU_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def TritonGPU_Dialect : Dialect {
|
||||
let name = "triton_gpu";
|
||||
|
||||
let cppNamespace = "::mlir::triton::gpu";
|
||||
|
||||
let hasOperationAttrVerify = 1;
|
||||
|
||||
let description = [{
|
||||
Triton GPU Dialect.
|
||||
}];
|
||||
|
||||
let dependentDialects = [
|
||||
"triton::TritonDialect",
|
||||
"mlir::gpu::GPUDialect",
|
||||
"tensor::TensorDialect",
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
|
||||
static int getNumWarps(ModuleOp mod) {
|
||||
if(!mod->hasAttr("triton_gpu.num-warps"))
|
||||
llvm::report_fatal_error(
|
||||
"TritonGPU module should contain a triton_gpu.num-warps attribute");
|
||||
return mod->getAttr("triton_gpu.num-warps").cast<IntegerAttr>().getInt();
|
||||
}
|
||||
}];
|
||||
|
||||
|
||||
}
|
||||
|
||||
#endif
|
180
include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Normal file
180
include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Normal file
@@ -0,0 +1,180 @@
|
||||
#ifndef TRITONGPU_OPS
|
||||
#define TRITONGPU_OPS
|
||||
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
|
||||
include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
|
||||
include "triton/Dialect/Triton/IR/TritonTypes.td"
|
||||
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
|
||||
def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;
|
||||
|
||||
class TTG_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<TritonGPU_Dialect, mnemonic, traits>;
|
||||
|
||||
def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
|
||||
[SameOperandsAndResultShape, NoSideEffect]> {
|
||||
let summary = "convert layout";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
|
||||
let summary = "async wait";
|
||||
|
||||
let arguments = (ins I32Attr:$num);
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
|
||||
// This is needed because these ops don't
|
||||
// handle encodings
|
||||
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
|
||||
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
|
||||
let summary = "integer comparison operation";
|
||||
|
||||
let description = [{}];
|
||||
|
||||
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
|
||||
TT_IntLike:$lhs,
|
||||
TT_IntLike:$rhs);
|
||||
|
||||
let results = (outs TT_BoolLike:$result);
|
||||
}
|
||||
|
||||
def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect]> {
|
||||
let summary = "floating-point comparison operation";
|
||||
|
||||
let description = [{}];
|
||||
|
||||
let arguments = (ins Arith_CmpFPredicateAttr:$predicate,
|
||||
TT_FloatLike:$lhs,
|
||||
TT_FloatLike:$rhs);
|
||||
|
||||
let results = (outs TT_BoolLike:$result);
|
||||
}
|
||||
|
||||
// TODO: migrate to arith::SelectOp on LLVM16
|
||||
def TTG_SelectOp : TTG_Op<"select", [NoSideEffect]> {
|
||||
let summary = "select operation";
|
||||
|
||||
let description = [{}];
|
||||
|
||||
let arguments = (ins TT_BoolLike:$condition,
|
||||
TT_Tensor:$true_value,
|
||||
TT_Tensor:$false_value);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
}
|
||||
|
||||
|
||||
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
|
||||
[AttrSizedOperandSegments,
|
||||
ResultsAreSharedEncoding,
|
||||
MemoryEffects<[MemRead]>,
|
||||
TypesMatchWith<"infer mask type from src type",
|
||||
"src", "mask", "getI1SameShape($_self)",
|
||||
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
|
||||
TypesMatchWith<"infer other type from src type",
|
||||
"src", "other", "getPointeeType($_self)",
|
||||
"($_op.getOperands().size() <= 4) || std::equal_to<>()">]> {
|
||||
let summary = "insert slice async";
|
||||
|
||||
let description = [{
|
||||
This operation inserts a tensor `$src` into another tensor `$dst` as specified by the operation’s
|
||||
`$index` argument and `$axis` attribute.
|
||||
|
||||
It returns a copy of `$dst` with the proper slice updated asynchronously with the value of `$src`.
|
||||
This operation is non-blocking, and `$results` will have the updated value after the corresponding async_wait.
|
||||
|
||||
When converting from `tt.load` to `triton_gpu.insert_slice_async`, the `$evict`, `$cache`, and `$isVolatile` fields
|
||||
might be ignored on certain hardware. For example, on NVIDIA GPUs, the cache policy is determined by the backend,
|
||||
and `$evict` and `$isVolatile` are ignored because they apply to L1 cache only.
|
||||
|
||||
The insert_slice_async operation supports the following arguments:
|
||||
|
||||
* src: the tensor that is inserted.
|
||||
* dst: the tensor into which the `$src` tensor is inserted.
|
||||
* index: the index of the `$src` tensor at the given `$axis` from which the `$dst` tensor is inserted into
|
||||
* mask: optional tensor-rank number of boolean masks which specify which
|
||||
elements of the `$src` tensor are inserted into the `$dst` tensor.
|
||||
* other: optional tensor-rank number of other tensors which specify what
|
||||
values are inserted into the `$dst` tensor if the corresponding
|
||||
element of the `$mask` tensor is false.
|
||||
|
||||
In the future, we may decompose this operation into a sequence of:
|
||||
|
||||
* `async` operation to specify a sequence of asynchronous operations
|
||||
* `load` operation to load a tensor from global memory
|
||||
* `insert_slice` operations to insert the `$src` tensor into the `$dst` tensor
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
%1 = triton_gpu.alloc_tensor : tensor<2x32xf32>
|
||||
%2 = triton_gpu.insert_slice_async %0, %1, %index { axis = 0 } : tensor<32x!tt.ptr<f32>, #AL> -> tensor<2x32xf32, #A>
|
||||
triiton_gpu.async_wait { num = 0 : i32 }
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$src, TT_Tensor:$dst, I32:$index,
|
||||
Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
|
||||
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
|
||||
BoolAttr:$isVolatile, I32Attr:$axis);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, "Value":$mask,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
|
||||
"Value":$mask, "Value":$other,
|
||||
"triton::CacheModifier":$cache,
|
||||
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
|
||||
];
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
//let assemblyFormat = [{
|
||||
// $src `,` $dst ``
|
||||
// $index, $mask, $other
|
||||
// attr-dict `:` type($src) `->` type($dst)
|
||||
//}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability);
|
||||
}];
|
||||
|
||||
// The custom parser could be replaced with oilist in LLVM-16
|
||||
let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
|
||||
|
||||
let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
|
||||
}
|
||||
|
||||
def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [MemoryEffects<[MemAlloc]>, // Allocate shared memory
|
||||
ResultsAreSharedEncoding]> {
|
||||
let summary = "allocate tensor";
|
||||
|
||||
let description = [{
|
||||
This operation defines a tensor of a particular shape.
|
||||
The contents of the tensor are supposed to be in shared memory.
|
||||
|
||||
Note: This op can be repalced to a `bufferization.alloc_tensor` in LLVM 16.
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{attr-dict `:` type($result)}];
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
}
|
||||
|
||||
#endif
|
@@ -0,0 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonGPU)
|
||||
add_public_tablegen_target(TritonGPUTransformsIncGen)
|
25
include/triton/Dialect/TritonGPU/Transforms/Passes.h
Normal file
25
include/triton/Dialect/TritonGPU/Transforms/Passes.h
Normal file
@@ -0,0 +1,25 @@
|
||||
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_
|
||||
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
|
||||
|
||||
// TODO(Keren): prefetch pass not working yet
|
||||
std::unique_ptr<Pass> createTritonGPUPrefetchPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCoalescePass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUVerifier();
|
||||
|
||||
/// Generate the code for registering passes.
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
} // namespace mlir
|
||||
#endif
|
87
include/triton/Dialect/TritonGPU/Transforms/Passes.td
Normal file
87
include/triton/Dialect/TritonGPU/Transforms/Passes.td
Normal file
@@ -0,0 +1,87 @@
|
||||
#ifndef TRITONGPU_PASSES
|
||||
#define TRITONGPU_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
|
||||
let summary = "pipeline";
|
||||
|
||||
let description = [{
|
||||
Unroll loops to hide global memory -> shared memory latency.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUPipelinePass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::arith::ArithmeticDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"numStages", "num-stages",
|
||||
"int32_t", /*default*/"2",
|
||||
"number of pipeline stages">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
|
||||
let summary = "prefetch";
|
||||
|
||||
let description = [{
|
||||
Prefetch operands (a and b) of tt.dot into shared memory to hide shared memory -> register latency.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUPrefetchPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::arith::ArithmeticDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
|
||||
let summary = "coalesce";
|
||||
|
||||
let description = [{
|
||||
TODO
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUCoalescePass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
|
||||
let summary = "combine triton gpu ops";
|
||||
|
||||
let description = [{
|
||||
convert_layout(convert_layout(%src, #LAYOUT_0), #LAYOUT_1) =>
|
||||
convert_layout(%src, #LAYOUT_1)
|
||||
|
||||
convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUCombineOpsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||
"mlir::triton::TritonDialect"];
|
||||
|
||||
let options = [
|
||||
Option<"computeCapability", "compute-capability",
|
||||
"int32_t", /*default*/"80",
|
||||
"device compute capability">
|
||||
];
|
||||
}
|
||||
|
||||
def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
|
||||
let summary = "canonicalize scf.ForOp ops";
|
||||
|
||||
let description = [{
|
||||
This implements some optimizations that are missing in the standard scf.ForOp
|
||||
canonicalizer.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUCanonicalizeLoopsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
|
||||
}
|
||||
|
||||
#endif
|
@@ -0,0 +1,32 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Defines utilities to use while converting to the TritonGPU dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
|
||||
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
|
||||
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class TritonGPUTypeConverter : public TypeConverter {
|
||||
public:
|
||||
TritonGPUTypeConverter(MLIRContext *context, int numWarps);
|
||||
|
||||
private:
|
||||
MLIRContext *context;
|
||||
int numWarps;
|
||||
};
|
||||
|
||||
class TritonGPUConversionTarget : public ConversionTarget {
|
||||
|
||||
public:
|
||||
explicit TritonGPUConversionTarget(MLIRContext &ctx,
|
||||
TritonGPUTypeConverter &typeConverter);
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
|
40
include/triton/Target/LLVMIR/LLVMIRTranslation.h
Normal file
40
include/triton/Target/LLVMIR/LLVMIRTranslation.h
Normal file
@@ -0,0 +1,40 @@
|
||||
#ifndef TRITON_TARGET_LLVMIRTRANSLATION_H
|
||||
#define TRITON_TARGET_LLVMIRTRANSLATION_H
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
class LLVMContext;
|
||||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
class ModuleOp;
|
||||
} // namespace mlir
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
// add external dependent libs
|
||||
void addExternalLibs(mlir::ModuleOp &module,
|
||||
const std::vector<std::string> &names,
|
||||
const std::vector<std::string> &paths);
|
||||
|
||||
// Translate TritonGPU dialect to LLVMIR, return null if failed.
|
||||
std::unique_ptr<llvm::Module>
|
||||
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
mlir::ModuleOp module,
|
||||
int computeCapability);
|
||||
|
||||
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
|
||||
std::unique_ptr<llvm::Module>
|
||||
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
|
||||
|
||||
bool linkExternLib(llvm::Module &module, llvm::StringRef path);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_TARGET_LLVMIRTRANSLATION_H
|
17
include/triton/Target/PTX/PTXTranslation.h
Normal file
17
include/triton/Target/PTX/PTXTranslation.h
Normal file
@@ -0,0 +1,17 @@
|
||||
#ifndef TRITON_TARGET_PTXTRANSLATION_H
|
||||
#define TRITON_TARGET_PTXTRANSLATION_H
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
} // namespace llvm
|
||||
|
||||
namespace triton {
|
||||
|
||||
// Translate TritonGPU IR to PTX code.
|
||||
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version);
|
||||
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
@@ -1,80 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
|
||||
#define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class value;
|
||||
class module;
|
||||
class phi_node;
|
||||
class splat_inst;
|
||||
class cast_inst;
|
||||
class reshape_inst;
|
||||
class broadcast_inst;
|
||||
class binary_operator;
|
||||
class getelementptr_inst;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class align {
|
||||
private:
|
||||
struct cst_info {
|
||||
unsigned num_cst;
|
||||
unsigned value;
|
||||
};
|
||||
// helpers
|
||||
std::vector<unsigned> get_shapes(ir::value *v);
|
||||
// populate is_constant
|
||||
std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
|
||||
std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
|
||||
std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_default(ir::value* v);
|
||||
std::vector<cst_info> populate_is_constant(ir::value *v);
|
||||
// populate max_contiguous
|
||||
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
|
||||
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
|
||||
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_cast(ir::cast_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
|
||||
std::vector<unsigned> populate_max_contiguous(ir::value *v);
|
||||
// populate starting_multiple
|
||||
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
|
||||
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
|
||||
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_cast(ir::cast_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
|
||||
std::vector<unsigned> populate_starting_multiple(ir::value *v);
|
||||
// populate all maps
|
||||
void populate(ir::value *v);
|
||||
|
||||
public:
|
||||
void run(ir::module &mod);
|
||||
unsigned get(ir::value* v, unsigned ax) const;
|
||||
std::vector<unsigned> contiguous(ir::value* v) const;
|
||||
|
||||
private:
|
||||
std::map<ir::value*, std::vector<cst_info>> is_constant_;
|
||||
std::map<ir::value*, std::vector<unsigned>> max_contiguous_;
|
||||
std::map<ir::value*, std::vector<unsigned>> starting_multiple_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,47 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
|
||||
#define TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <iostream>
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class function;
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class tiles;
|
||||
|
||||
class liveness;
|
||||
class cts;
|
||||
|
||||
class allocation {
|
||||
public:
|
||||
allocation(liveness *live)
|
||||
: liveness_(live) { }
|
||||
// accessors
|
||||
bool has_offset(const data_layout *x) const { return offsets_.find(x) != offsets_.end(); }
|
||||
unsigned offset(const data_layout *x) const { return offsets_.at(x); }
|
||||
unsigned allocated_size() const { return allocated_size_; }
|
||||
// run
|
||||
void run(ir::module& mod);
|
||||
|
||||
private:
|
||||
std::map<const data_layout*, unsigned> offsets_;
|
||||
size_t allocated_size_;
|
||||
// dependences
|
||||
liveness *liveness_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,52 +0,0 @@
|
||||
#ifndef _TRITON_CODEGEN_ANALYSIS_AXES_H_
|
||||
#define _TRITON_CODEGEN_ANALYSIS_AXES_H_
|
||||
|
||||
#include "triton/tools/graph.h"
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class module;
|
||||
class instruction;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class axes {
|
||||
typedef std::pair<ir::value*, unsigned> node_t;
|
||||
|
||||
private:
|
||||
// update graph
|
||||
void update_graph_store(ir::instruction *i);
|
||||
void update_graph_reduce(ir::instruction *i);
|
||||
void update_graph_reshape(ir::instruction *i);
|
||||
void update_graph_trans(ir::instruction *i);
|
||||
void update_graph_broadcast(ir::instruction *i);
|
||||
void update_graph_dot(ir::instruction *i);
|
||||
void update_graph_elementwise(ir::instruction *i,
|
||||
bool is_masked_load_async=false);
|
||||
void update_graph_no_edge(ir::instruction *i);
|
||||
void update_graph(ir::instruction *i);
|
||||
|
||||
public:
|
||||
axes();
|
||||
void run(ir::module &mod);
|
||||
// accessors
|
||||
int get(ir::value *value, unsigned dim);
|
||||
std::vector<int> get(ir::value *value);
|
||||
|
||||
private:
|
||||
tools::graph<node_t> graph_;
|
||||
std::map<node_t, size_t> axes_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,345 +0,0 @@
|
||||
#ifndef _TRITON_CODEGEN_ANALYSIS_GRID_H_
|
||||
#define _TRITON_CODEGEN_ANALYSIS_GRID_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "triton/tools/graph.h"
|
||||
#include "triton/codegen/target.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class type;
|
||||
class module;
|
||||
class instruction;
|
||||
class phi_node;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
class axes;
|
||||
class align;
|
||||
class layout_visitor;
|
||||
class data_layout;
|
||||
class mma_layout;
|
||||
class scanline_layout;
|
||||
class shared_layout;
|
||||
|
||||
|
||||
class layout_visitor {
|
||||
public:
|
||||
virtual void visit_layout(data_layout *);
|
||||
virtual void visit_layout_mma(mma_layout*) = 0;
|
||||
virtual void visit_layout_scanline(scanline_layout*) = 0;
|
||||
virtual void visit_layout_shared(shared_layout*) = 0;
|
||||
};
|
||||
|
||||
class data_layout {
|
||||
protected:
|
||||
enum id_t {
|
||||
MMA,
|
||||
SCANLINE,
|
||||
SHARED
|
||||
};
|
||||
|
||||
typedef std::vector<int> axes_t;
|
||||
typedef std::vector<unsigned> shape_t;
|
||||
typedef std::vector<int> order_t;
|
||||
typedef std::vector<ir::value*> values_t;
|
||||
|
||||
private:
|
||||
template<typename T>
|
||||
T* downcast(id_t id) {
|
||||
if(id_ == id)
|
||||
return static_cast<T*>(this);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
public:
|
||||
data_layout(id_t id,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned> &shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align);
|
||||
// visitor
|
||||
virtual void accept(layout_visitor* vst) = 0;
|
||||
// downcast
|
||||
mma_layout* to_mma() { return downcast<mma_layout>(MMA); }
|
||||
scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
|
||||
shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
|
||||
// accessors
|
||||
size_t get_rank() { return shape_.size(); }
|
||||
const shape_t& get_shape() const { return shape_; }
|
||||
const order_t& get_order() const { return order_; }
|
||||
const values_t& get_values() const { return values_;}
|
||||
int get_axis(size_t k) const { return axes_.at(k); }
|
||||
std::vector<int> get_axes() const { return axes_; }
|
||||
const int get_order(size_t k) const { return order_.at(k); }
|
||||
// find the position of given axis
|
||||
int find_axis(int to_find) const;
|
||||
|
||||
|
||||
private:
|
||||
id_t id_;
|
||||
axes_t axes_;
|
||||
values_t values_;
|
||||
|
||||
protected:
|
||||
order_t order_;
|
||||
shape_t shape_;
|
||||
};
|
||||
|
||||
class distributed_layout: public data_layout{
|
||||
public:
|
||||
distributed_layout(id_t id,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value*>& values,
|
||||
analysis::align* align);
|
||||
|
||||
int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
|
||||
int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
|
||||
virtual int contig_per_thread(size_t k) = 0;
|
||||
|
||||
protected:
|
||||
std::vector<int> shape_per_cta_;
|
||||
};
|
||||
|
||||
class mma_layout: public distributed_layout {
|
||||
public:
|
||||
enum TensorCoreType : uint8_t {
|
||||
// floating-point tensor core instr
|
||||
FP32_FP16_FP16_FP32 = 0, // default
|
||||
FP32_BF16_BF16_FP32,
|
||||
FP32_TF32_TF32_FP32,
|
||||
// integer tensor core instr
|
||||
INT32_INT1_INT1_INT32, // Not implemented
|
||||
INT32_INT4_INT4_INT32, // Not implemented
|
||||
INT32_INT8_INT8_INT32, // Not implemented
|
||||
//
|
||||
NOT_APPLICABLE,
|
||||
};
|
||||
|
||||
// Used on nvidia GPUs with sm >= 80
|
||||
inline static const std::map<TensorCoreType, std::vector<int>> mma_instr_shape_ = {
|
||||
{FP32_FP16_FP16_FP32, {16, 8, 16}},
|
||||
{FP32_BF16_BF16_FP32, {16, 8, 16}},
|
||||
{FP32_TF32_TF32_FP32, {16, 8, 8}},
|
||||
|
||||
{INT32_INT1_INT1_INT32, {16, 8, 256}},
|
||||
{INT32_INT4_INT4_INT32, {16, 8, 64}},
|
||||
{INT32_INT8_INT8_INT32, {16, 8, 32}},
|
||||
};
|
||||
|
||||
// shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices)
|
||||
inline static const std::map<TensorCoreType, std::vector<int>> mma_mat_shape_ = {
|
||||
{FP32_FP16_FP16_FP32, {8, 8, 8}},
|
||||
{FP32_BF16_BF16_FP32, {8, 8, 8}},
|
||||
{FP32_TF32_TF32_FP32, {8, 8, 4}},
|
||||
|
||||
{INT32_INT1_INT1_INT32, {8, 8, 64}},
|
||||
{INT32_INT4_INT4_INT32, {8, 8, 32}},
|
||||
{INT32_INT8_INT8_INT32, {8, 8, 16}},
|
||||
};
|
||||
|
||||
inline static const std::map<TensorCoreType, std::string> mma_instr_ptx_ = {
|
||||
{FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
|
||||
{FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
|
||||
{FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},
|
||||
|
||||
{INT32_INT1_INT1_INT32, "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"},
|
||||
{INT32_INT4_INT4_INT32, "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"},
|
||||
{INT32_INT8_INT8_INT32, "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
|
||||
};
|
||||
|
||||
// vector length per ldmatrix (16*8/elelment_size_in_bits)
|
||||
inline static const std::map<TensorCoreType, int> mma_instr_vec_ = {
|
||||
{FP32_FP16_FP16_FP32, 8},
|
||||
{FP32_BF16_BF16_FP32, 8},
|
||||
{FP32_TF32_TF32_FP32, 4},
|
||||
|
||||
{INT32_INT1_INT1_INT32, 128},
|
||||
{INT32_INT4_INT4_INT32, 32},
|
||||
{INT32_INT8_INT8_INT32, 16},
|
||||
};
|
||||
|
||||
public:
|
||||
mma_layout(size_t num_warps,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shapes,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align, target *tgt,
|
||||
shared_layout* layout_a,
|
||||
shared_layout* layout_b,
|
||||
ir::value *dot);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
|
||||
// accessor
|
||||
int fpw(size_t k) { return fpw_.at(k); }
|
||||
int wpt(size_t k) { return wpt_.at(k); }
|
||||
int spw(size_t k) { return spw_.at(k); }
|
||||
int rep(size_t k) { return rep_.at(k); }
|
||||
int contig_per_thread(size_t k) { return contig_per_thread_.at(k); }
|
||||
|
||||
// helpers for generator.cc
|
||||
std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); }
|
||||
std::vector<int> get_mma_instr_shape() const { return mma_instr_shape_.at(tensor_core_type_); }
|
||||
std::vector<int> get_mma_mat_shape() const { return mma_mat_shape_.at(tensor_core_type_); }
|
||||
int get_vec_a() const { return mma_instr_vec_.at(tensor_core_type_); }
|
||||
int get_vec_b() const { return mma_instr_vec_.at(tensor_core_type_); }
|
||||
|
||||
// setter
|
||||
void set_tensor_core_type(TensorCoreType type) { tensor_core_type_ = type; }
|
||||
|
||||
private:
|
||||
// fragment per warp
|
||||
std::vector<int> fpw_;
|
||||
// shape per warp
|
||||
std::vector<int> spw_;
|
||||
// warp per tile
|
||||
std::vector<int> wpt_;
|
||||
// shape per tile
|
||||
std::vector<int> spt_;
|
||||
// repetitions
|
||||
std::vector<int> rep_;
|
||||
// contiguous per thread
|
||||
std::vector<int> contig_per_thread_;
|
||||
|
||||
TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32;
|
||||
};
|
||||
|
||||
struct scanline_layout: public distributed_layout {
|
||||
scanline_layout(size_t num_warps,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shape,
|
||||
const std::vector<ir::value *> &values,
|
||||
analysis::align* align,
|
||||
target* tgt);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
|
||||
// accessor
|
||||
int mts(size_t k) { return mts_.at(k); }
|
||||
int nts(size_t k) { return nts_.at(k); }
|
||||
int contig_per_thread(size_t k) { return nts_.at(k); }
|
||||
|
||||
public:
|
||||
// micro tile size. The size of a tile held by a thread block.
|
||||
std::vector<int> mts_;
|
||||
// nano tile size. The size of a tile held by a thread.
|
||||
std::vector<int> nts_;
|
||||
};
|
||||
|
||||
struct double_buffer_info_t {
|
||||
ir::value* first;
|
||||
ir::value* latch;
|
||||
ir::phi_node* phi;
|
||||
};
|
||||
|
||||
struct N_buffer_info_t {
|
||||
std::vector<ir::value*> firsts; // not necessarily ordered as input order
|
||||
ir::value* latch;
|
||||
ir::phi_node* phi;
|
||||
std::map<ir::value*, int> firsts_idx;
|
||||
};
|
||||
|
||||
// abstract for dot and coresponding smem values
|
||||
class shared_layout: public data_layout {
|
||||
private:
|
||||
static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
|
||||
static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
|
||||
static void extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t>& res, int &prev_stages);
|
||||
|
||||
public:
|
||||
shared_layout(data_layout *arg,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<unsigned>& shapes,
|
||||
const std::vector<ir::value *> &values_,
|
||||
ir::type *ty,
|
||||
analysis::align* align, target *tgt);
|
||||
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
|
||||
// accessors
|
||||
size_t get_size() { return size_; }
|
||||
ir::type* get_type() { return ty_; }
|
||||
double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
|
||||
N_buffer_info_t* get_N_buffer() { return N_buffer_.get(); }
|
||||
int get_num_stages() const;
|
||||
size_t get_per_stage_size() const { return size_ / get_num_stages(); }
|
||||
size_t get_per_stage_elements() const;
|
||||
size_t get_num_per_phase() { return num_per_phase_; }
|
||||
ir::value* hmma_dot_a() { return hmma_dot_a_; }
|
||||
ir::value* hmma_dot_b() { return hmma_dot_b_; }
|
||||
void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
|
||||
int get_mma_vec() { return mma_vec_;}
|
||||
int get_mma_strided() { return mma_strided_; }
|
||||
bool allow_swizzle() const { return allow_swizzle_; }
|
||||
data_layout* get_arg_layout() { return arg_layout_; }
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
ir::type *ty_;
|
||||
std::shared_ptr<double_buffer_info_t> double_buffer_;
|
||||
std::shared_ptr<N_buffer_info_t> N_buffer_;
|
||||
size_t num_per_phase_;
|
||||
ir::value* hmma_dot_a_;
|
||||
ir::value* hmma_dot_b_;
|
||||
data_layout* arg_layout_;
|
||||
int mma_vec_;
|
||||
int mma_strided_;
|
||||
bool allow_swizzle_ = true;
|
||||
target *tgt_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class layouts {
|
||||
typedef ir::value* node_t;
|
||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||
|
||||
private:
|
||||
// graph creation
|
||||
void connect(ir::value *x, ir::value *y);
|
||||
void make_graph(ir::instruction *i);
|
||||
|
||||
void init_hmma_tile(data_layout& layouts);
|
||||
void init_scanline_tile(data_layout &layouts);
|
||||
|
||||
void create(size_t id, const std::vector<ir::value*>& values);
|
||||
|
||||
public:
|
||||
// constructor
|
||||
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
|
||||
|
||||
// accessors
|
||||
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
|
||||
bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
|
||||
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
|
||||
size_t num_layouts() const { return values_.size();}
|
||||
data_layout* get(size_t id) { return layouts_.at(id); }
|
||||
data_layout* get(ir::value *v) { return get(layout_of(v));}
|
||||
std::map<size_t, data_layout*> &get_all() { return layouts_; }
|
||||
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
|
||||
int tmp(ir::value* i) { return tmp_.at(i);}
|
||||
void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
|
||||
// execution
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::axes* axes_;
|
||||
analysis::align* align_;
|
||||
size_t num_warps_;
|
||||
target* tgt_;
|
||||
tools::graph<ir::value*> graph_;
|
||||
std::map<ir::value*, size_t> groups_;
|
||||
std::map<size_t, std::vector<ir::value*>> values_;
|
||||
std::map<size_t, data_layout*> layouts_;
|
||||
std::map<ir::value*, size_t> tmp_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,67 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
|
||||
#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/tools/graph.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class value;
|
||||
class phi_node;
|
||||
class function;
|
||||
class module;
|
||||
class instruction;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
typedef unsigned slot_index;
|
||||
|
||||
class tiles;
|
||||
class layouts;
|
||||
class data_layout;
|
||||
|
||||
struct segment {
|
||||
slot_index start;
|
||||
slot_index end;
|
||||
|
||||
bool contains(slot_index idx) const {
|
||||
return start <= idx && idx < end;
|
||||
}
|
||||
|
||||
bool intersect(const segment &Other){
|
||||
return contains(Other.start) || Other.contains(start);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class liveness {
|
||||
private:
|
||||
typedef std::map<shared_layout*, segment> intervals_map_t;
|
||||
|
||||
public:
|
||||
// constructor
|
||||
liveness(layouts *l): layouts_(l){ }
|
||||
// accessors
|
||||
const intervals_map_t& get() const { return intervals_; }
|
||||
segment get(shared_layout* v) const { return intervals_.at(v); }
|
||||
// run
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
// analysis
|
||||
layouts *layouts_;
|
||||
intervals_map_t intervals_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif
|
@@ -1,43 +0,0 @@
|
||||
#ifndef TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
|
||||
#define TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
|
||||
|
||||
#include <map>
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
class target;
|
||||
|
||||
namespace analysis{
|
||||
|
||||
class layouts;
|
||||
class data_layout;
|
||||
|
||||
class swizzle {
|
||||
public:
|
||||
// constructor
|
||||
swizzle(layouts *l, target* tgt): layouts_(l), tgt_(tgt){ }
|
||||
// accessors
|
||||
int get_per_phase(data_layout* layout) { return per_phase_.at(layout); }
|
||||
int get_max_phase(data_layout* layout) { return max_phase_.at(layout); }
|
||||
int get_vec (data_layout* layout) { return vec_.at(layout); }
|
||||
// run
|
||||
void run(ir::module &mod);
|
||||
private:
|
||||
layouts* layouts_;
|
||||
target* tgt_;
|
||||
std::map<data_layout*, int> per_phase_;
|
||||
std::map<data_layout*, int> max_phase_;
|
||||
std::map<data_layout*, int> vec_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif
|
@@ -1,42 +0,0 @@
|
||||
#ifndef _TRITON_CODEGEN_PASS_H_
|
||||
#define _TRITON_CODEGEN_PASS_H_
|
||||
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace llvm{
|
||||
class Module;
|
||||
class LLVMContext;
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace codegen {
|
||||
class target;
|
||||
}
|
||||
|
||||
namespace ir{
|
||||
class module;
|
||||
}
|
||||
namespace driver{
|
||||
class device;
|
||||
class module;
|
||||
class kernel;
|
||||
}
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
// TODO:
|
||||
// There should be a proper pass manager there!
|
||||
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx,
|
||||
codegen::target* target,
|
||||
int sm, int num_warps,
|
||||
int num_stages, int &shared_static);
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,258 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_SELECTION_GENERATOR_H_
|
||||
#define _TRITON_SELECTION_GENERATOR_H_
|
||||
|
||||
#include "triton/ir/visitor.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include <functional>
|
||||
|
||||
// forward
|
||||
namespace llvm{
|
||||
class Type;
|
||||
class Value;
|
||||
class PHINode;
|
||||
class BasicBlock;
|
||||
class Attribute;
|
||||
class Instruction;
|
||||
class Constant;
|
||||
class LLVMContext;
|
||||
class Module;
|
||||
class ConstantFolder;
|
||||
class IRBuilderDefaultInserter;
|
||||
template <typename T, typename Inserter>
|
||||
class IRBuilder;
|
||||
class ArrayType;
|
||||
class Function;
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class attribute;
|
||||
class load_inst;
|
||||
class store_inst;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
// forward
|
||||
namespace analysis{
|
||||
class liveness;
|
||||
class tiles;
|
||||
class align;
|
||||
class allocation;
|
||||
class cts;
|
||||
class axes;
|
||||
class layouts;
|
||||
class swizzle;
|
||||
}
|
||||
// typedef
|
||||
typedef llvm::IRBuilder<llvm::ConstantFolder,
|
||||
llvm::IRBuilderDefaultInserter> Builder;
|
||||
typedef llvm::LLVMContext LLVMContext;
|
||||
typedef llvm::Type Type;
|
||||
typedef llvm::Value Value;
|
||||
typedef llvm::Attribute Attribute;
|
||||
typedef llvm::BasicBlock BasicBlock;
|
||||
typedef llvm::Module Module;
|
||||
typedef llvm::Instruction Instruction;
|
||||
typedef llvm::Constant Constant;
|
||||
typedef llvm::ArrayType ArrayType;
|
||||
typedef llvm::Function Function;
|
||||
typedef std::vector<Value*> indices_t;
|
||||
class target;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
struct distributed_axis {
|
||||
int contiguous;
|
||||
std::vector<Value*> values;
|
||||
Value* thread_id;
|
||||
};
|
||||
|
||||
class adder{
|
||||
public:
|
||||
adder(Builder** builder): builder_(builder) { }
|
||||
Value* operator()(Value* x, Value* y, const std::string& name = "");
|
||||
|
||||
private:
|
||||
Builder** builder_;
|
||||
};
|
||||
|
||||
class multiplier{
|
||||
public:
|
||||
multiplier(Builder** builder): builder_(builder) { }
|
||||
Value* operator()(Value* x, Value* y, const std::string& name = "");
|
||||
private:
|
||||
Builder** builder_;
|
||||
};
|
||||
|
||||
class geper{
|
||||
public:
|
||||
geper(Builder** builder): builder_(builder) { }
|
||||
Value* operator()(Value *ptr, Value* off, const std::string& name = "");
|
||||
Value* operator()(Type* ty, Value*ptr, std::vector<Value*> vals, const std::string& name = "");
|
||||
|
||||
private:
|
||||
Builder** builder_;
|
||||
};
|
||||
|
||||
class generator: public ir::visitor, public analysis::layout_visitor {
|
||||
private:
|
||||
void init_idx(ir::value *x);
|
||||
Instruction* add_barrier();
|
||||
Value* shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx);
|
||||
void finalize_shared_layout(analysis::shared_layout*);
|
||||
void finalize_function(ir::function*);
|
||||
void finalize_phi_node(ir::phi_node*);
|
||||
|
||||
private:
|
||||
Type *cvt(ir::type *ty);
|
||||
llvm::Attribute cvt(ir::attribute attr);
|
||||
|
||||
public:
|
||||
generator(analysis::axes *a_axes,
|
||||
analysis::layouts *layouts,
|
||||
analysis::align *alignment,
|
||||
analysis::allocation *alloc,
|
||||
analysis::swizzle *swizzle,
|
||||
target *tgt,
|
||||
unsigned num_warps);
|
||||
|
||||
void visit_value(ir::value* v);
|
||||
void visit_phi_node(ir::phi_node*);
|
||||
void visit_binary_operator(ir::binary_operator*);
|
||||
void visit_getelementptr_inst(ir::getelementptr_inst*);
|
||||
void visit_icmp_inst(ir::icmp_inst*);
|
||||
void visit_fcmp_inst(ir::fcmp_inst*);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
|
||||
Value* bf16_to_fp32(Value *in0);
|
||||
Value* fp32_to_bf16(Value *in0);
|
||||
|
||||
void visit_cast_inst(ir::cast_inst*);
|
||||
void visit_return_inst(ir::return_inst*);
|
||||
void visit_cond_branch_inst(ir::cond_branch_inst*);
|
||||
void visit_uncond_branch_inst(ir::uncond_branch_inst*);
|
||||
void visit_load_inst(ir::load_inst*);
|
||||
void visit_unmasked_load_inst(ir::unmasked_load_inst*);
|
||||
void visit_masked_load_inst(ir::masked_load_inst*);
|
||||
void visit_store_inst(ir::store_inst*);
|
||||
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
|
||||
void visit_masked_store_inst(ir::masked_store_inst*);
|
||||
void visit_cat_inst(ir::cat_inst*);
|
||||
void visit_reshape_inst(ir::reshape_inst*);
|
||||
void visit_splat_inst(ir::splat_inst*);
|
||||
void visit_broadcast_inst(ir::broadcast_inst*);
|
||||
void visit_downcast_inst(ir::downcast_inst*);
|
||||
void visit_exp_inst(ir::exp_inst*);
|
||||
void visit_cos_inst(ir::cos_inst*);
|
||||
void visit_umulhi_inst(ir::umulhi_inst* x);
|
||||
void visit_sin_inst(ir::sin_inst*);
|
||||
void visit_log_inst(ir::log_inst*);
|
||||
void visit_get_program_id_inst(ir::get_program_id_inst*);
|
||||
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
|
||||
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
|
||||
void visit_atomic_rmw_inst(ir::atomic_rmw_inst*);
|
||||
void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
|
||||
void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
|
||||
void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
|
||||
void visit_dot_inst(ir::dot_inst*);
|
||||
void visit_trans_inst(ir::trans_inst*);
|
||||
void visit_sqrt_inst(ir::sqrt_inst*);
|
||||
Value* shfl_sync(Value* acc, int32_t i);
|
||||
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
|
||||
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
|
||||
void visit_reduce_inst(ir::reduce_inst*);
|
||||
void visit_select_inst(ir::select_inst*);
|
||||
void visit_layout_convert(ir::value *out, ir::value *in);
|
||||
void visit_cvt_layout_inst(ir::cvt_layout_inst*);
|
||||
void visit_masked_load_async_inst(ir::masked_load_async_inst*);
|
||||
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
|
||||
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
|
||||
void visit_barrier_inst(ir::barrier_inst*);
|
||||
void visit_prefetch_s_inst(ir::prefetch_s_inst*);
|
||||
void visit_async_wait_inst(ir::async_wait_inst*);
|
||||
// void visit_make_range_dyn(ir::make_range_dyn*);
|
||||
void visit_make_range(ir::make_range*);
|
||||
// void visit_make_range_sta(ir::make_range_sta*);
|
||||
void visit_undef_value(ir::undef_value*);
|
||||
void visit_constant_int(ir::constant_int*);
|
||||
void visit_constant_fp(ir::constant_fp*);
|
||||
void visit_alloc_const(ir::alloc_const*);
|
||||
void visit_function(ir::function*);
|
||||
void visit_basic_block(ir::basic_block*);
|
||||
void visit_argument(ir::argument*);
|
||||
void visit(ir::module &, llvm::Module &);
|
||||
|
||||
// layouts
|
||||
void visit_layout_mma(analysis::mma_layout*);
|
||||
void visit_layout_scanline(analysis::scanline_layout*);
|
||||
void visit_layout_shared(analysis::shared_layout*);
|
||||
|
||||
|
||||
private:
|
||||
LLVMContext *ctx_;
|
||||
Builder* builder_;
|
||||
Module *mod_;
|
||||
|
||||
analysis::axes *a_axes_;
|
||||
analysis::swizzle *swizzle_;
|
||||
std::map<unsigned, distributed_axis> axes_;
|
||||
target *tgt_;
|
||||
analysis::layouts *layouts_;
|
||||
analysis::align *alignment_;
|
||||
analysis::allocation *alloc_;
|
||||
Value *shmem_;
|
||||
std::set<ir::value*> seen_;
|
||||
|
||||
unsigned num_warps_;
|
||||
|
||||
std::map<analysis::data_layout*, Value*> offset_a_m_;
|
||||
std::map<analysis::data_layout*, Value*> offset_a_k_;
|
||||
std::map<analysis::data_layout*, Value*> offset_b_k_;
|
||||
std::map<analysis::data_layout*, Value*> offset_b_n_;
|
||||
|
||||
/// layout -> base ptr
|
||||
std::map<analysis::data_layout*, Value*> shared_ptr_;
|
||||
std::map<analysis::data_layout*, Value*> shared_pre_ptr_;
|
||||
std::map<analysis::data_layout*, Value*> shared_next_ptr_;
|
||||
/// offset for double-buffered layout
|
||||
std::map<analysis::data_layout*, Value*> shared_off_;
|
||||
|
||||
/// Base shmem pointer of ir value
|
||||
std::map<ir::value*, Value*> shmems_;
|
||||
std::map<ir::value*, Value*> shoffs_;
|
||||
std::map<ir::value*, std::vector<indices_t>> idxs_;
|
||||
std::map<ir::value*, std::map<indices_t, Value*>> vals_;
|
||||
/// idx for multi-stage pipeline
|
||||
std::map<analysis::data_layout*, Value*> read_smem_idx_;
|
||||
std::map<analysis::data_layout*, Value*> write_smem_idx_;
|
||||
|
||||
/// triton bb -> llvm bb
|
||||
std::map<ir::value*, BasicBlock *> bbs_;
|
||||
std::map<ir::value*, std::vector<int>> ords_;
|
||||
|
||||
// helper for creating llvm values
|
||||
adder add;
|
||||
multiplier mul;
|
||||
geper gep;
|
||||
|
||||
/// PHI nodes
|
||||
std::vector<std::tuple<llvm::PHINode*, Value*, ir::basic_block*>> lazy_phi_incs_;
|
||||
|
||||
/// Record prefetch instrs that needs to be moved
|
||||
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,105 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_IR_CODEGEN_TARGET_H
|
||||
#define TDL_INCLUDE_IR_CODEGEN_TARGET_H
|
||||
|
||||
namespace llvm{
|
||||
class Type;
|
||||
class Value;
|
||||
class Instruction;
|
||||
class Constant;
|
||||
class LLVMContext;
|
||||
class Module;
|
||||
class ConstantFolder;
|
||||
class IRBuilderDefaultInserter;
|
||||
template <typename T, typename Inserter>
|
||||
class IRBuilder;
|
||||
class ArrayType;
|
||||
class Function;
|
||||
}
|
||||
|
||||
// typedefs
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
typedef llvm::IRBuilder<llvm::ConstantFolder,
|
||||
llvm::IRBuilderDefaultInserter> Builder;
|
||||
typedef llvm::LLVMContext LLVMContext;
|
||||
typedef llvm::Type Type;
|
||||
typedef llvm::Value Value;
|
||||
typedef llvm::Module Module;
|
||||
typedef llvm::Instruction Instruction;
|
||||
typedef llvm::Constant Constant;
|
||||
typedef llvm::ArrayType ArrayType;
|
||||
typedef llvm::Function Function;
|
||||
}
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
class nvidia_cu_target;
|
||||
|
||||
class target {
|
||||
public:
|
||||
target(bool is_gpu): is_gpu_(is_gpu){}
|
||||
virtual ~target() {}
|
||||
virtual void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn) = 0;
|
||||
virtual Instruction* add_barrier(Module *module, Builder& builder) = 0;
|
||||
virtual Instruction* add_memfence(Module *module, Builder& builder) = 0;
|
||||
virtual Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax) = 0;
|
||||
virtual Value* get_local_id(Module *module, Builder& builder, unsigned ax) = 0;
|
||||
virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
|
||||
virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
|
||||
virtual unsigned guaranteed_alignment() = 0;
|
||||
nvidia_cu_target* as_nvidia();
|
||||
bool is_gpu() const;
|
||||
|
||||
private:
|
||||
bool is_gpu_;
|
||||
};
|
||||
|
||||
class amd_cl_target: public target {
|
||||
public:
|
||||
amd_cl_target(): target(true){}
|
||||
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
|
||||
Instruction* add_barrier(Module *module, Builder& builder);
|
||||
Instruction* add_memfence(Module *module, Builder& builder);
|
||||
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
|
||||
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
|
||||
unsigned guaranteed_alignment() { return 16; }
|
||||
};
|
||||
|
||||
class nvidia_cu_target: public target {
|
||||
public:
|
||||
nvidia_cu_target(int sm): target(true), sm_(sm){}
|
||||
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
|
||||
Instruction* add_barrier(Module *module, Builder& builder);
|
||||
Instruction* add_memfence(Module *module, Builder& builder);
|
||||
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
|
||||
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
|
||||
int sm() { return sm_; }
|
||||
unsigned guaranteed_alignment() { return 16; }
|
||||
|
||||
private:
|
||||
int sm_;
|
||||
};
|
||||
|
||||
class cpu_target: public target {
|
||||
public:
|
||||
cpu_target(): target(false){}
|
||||
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
|
||||
Instruction* add_barrier(Module *module, Builder& builder);
|
||||
Instruction* add_memfence(Module *module, Builder& builder);
|
||||
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
|
||||
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
|
||||
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
|
||||
unsigned guaranteed_alignment() { return 1; }
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,48 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
|
||||
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class value;
|
||||
class io_inst;
|
||||
class instruction;
|
||||
class builder;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
class align;
|
||||
class layouts;
|
||||
class cts;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class coalesce {
|
||||
private:
|
||||
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result);
|
||||
void extract_ld(ir::io_inst *i, std::map<int, std::vector<triton::ir::io_inst *> > &result);
|
||||
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
|
||||
|
||||
public:
|
||||
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
|
||||
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::align* align_;
|
||||
analysis::layouts* layout_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,36 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
|
||||
#define TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class value;
|
||||
class phi_node;
|
||||
class instruction;
|
||||
class builder;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
class cts {
|
||||
private:
|
||||
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);
|
||||
|
||||
public:
|
||||
cts(bool use_async = false): use_async_(use_async) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
bool use_async_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,24 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
|
||||
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
|
||||
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
class dce {
|
||||
public:
|
||||
dce() {}
|
||||
void run(ir::module &mod);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,22 +0,0 @@
|
||||
#ifndef _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
|
||||
#define _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
|
||||
|
||||
|
||||
namespace triton {
|
||||
namespace ir {
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
class disassociate {
|
||||
public:
|
||||
void run(ir::module &mod);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,72 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_BARRIERS_H
|
||||
#define TDL_INCLUDE_CODEGEN_BARRIERS_H
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <list>
|
||||
#include <set>
|
||||
#include "triton/codegen/target.h"
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class basic_block;
|
||||
class instruction;
|
||||
class masked_load_async_inst;
|
||||
class value;
|
||||
class builder;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
|
||||
class allocation;
|
||||
class liveness;
|
||||
class layouts;
|
||||
class cts;
|
||||
class shared_layout;
|
||||
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class prefetch;
|
||||
|
||||
class membar {
|
||||
private:
|
||||
typedef std::pair<unsigned, unsigned> interval_t;
|
||||
typedef std::set<ir::value*> val_set_t;
|
||||
typedef std::vector<ir::value*> val_vec_t;
|
||||
|
||||
private:
|
||||
bool intersect(const val_set_t &X, const val_set_t &Y);
|
||||
bool check_safe_war(ir::instruction* i);
|
||||
int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
|
||||
bool intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout);
|
||||
val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
|
||||
void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
|
||||
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
|
||||
|
||||
public:
|
||||
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc,
|
||||
transform::prefetch *prefetch, target* tgt):
|
||||
liveness_(liveness), layouts_(layouts), alloc_(alloc), prefetch_(prefetch), tgt_(tgt) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::liveness *liveness_;
|
||||
analysis::layouts *layouts_;
|
||||
analysis::allocation *alloc_;
|
||||
transform::prefetch *prefetch_;
|
||||
|
||||
target* tgt_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,53 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
|
||||
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
|
||||
|
||||
#include "triton/codegen/target.h"
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class value;
|
||||
class instruction;
|
||||
class trans_inst;
|
||||
class builder;
|
||||
class constant_int;
|
||||
class dot_inst;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
class layouts;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class peephole {
|
||||
private:
|
||||
// bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder);
|
||||
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
|
||||
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
||||
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
||||
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_mult(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
|
||||
bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder);
|
||||
|
||||
public:
|
||||
peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
target* tgt_;
|
||||
analysis::layouts* layouts_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,30 +0,0 @@
|
||||
#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
|
||||
#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
|
||||
|
||||
// forward declaration
|
||||
namespace triton {
|
||||
namespace ir {
|
||||
class module;
|
||||
}
|
||||
} // namespace triton
|
||||
|
||||
namespace triton {
|
||||
namespace codegen {
|
||||
namespace transform {
|
||||
|
||||
class pipeline {
|
||||
public:
|
||||
pipeline(bool has_copy_async, int num_stages)
|
||||
: has_copy_async_(has_copy_async), num_stages_(num_stages) {}
|
||||
void run(ir::module &module);
|
||||
|
||||
private:
|
||||
bool has_copy_async_;
|
||||
int num_stages_;
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
} // namespace codegen
|
||||
} // namespace triton
|
||||
|
||||
#endif
|
@@ -1,27 +0,0 @@
|
||||
#ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
|
||||
#define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
|
||||
|
||||
#include <set>
|
||||
|
||||
// forward dclaration
|
||||
namespace triton::ir{
|
||||
class module;
|
||||
class value;
|
||||
}
|
||||
|
||||
namespace triton::codegen {
|
||||
class target;
|
||||
}
|
||||
|
||||
namespace triton::codegen::transform {
|
||||
class prefetch {
|
||||
target* tgt_;
|
||||
std::set<ir::value*> prefetched_vals_;
|
||||
public:
|
||||
prefetch(target *tgt) : tgt_(tgt) {}
|
||||
void run(ir::module &module);
|
||||
bool is_prefetched(ir::value* v) { return prefetched_vals_.find(v) != prefetched_vals_.end(); }
|
||||
};
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,26 +0,0 @@
|
||||
#ifndef TRITON_INCLUDE_IR_CODEGEN_REORDER_H
|
||||
#define TRITON_INCLUDE_IR_CODEGEN_REORDER_H
|
||||
|
||||
namespace triton {
|
||||
|
||||
// forward declaration
|
||||
namespace ir {
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace transform{
|
||||
|
||||
class reorder {
|
||||
public:
|
||||
void run(ir::module& module);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,316 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_DRIVER_DISPATCH_H_
|
||||
#define _TRITON_DRIVER_DISPATCH_H_
|
||||
|
||||
#include <type_traits>
|
||||
#include <dlfcn.h>
|
||||
|
||||
//CUDA Backend
|
||||
#include "triton/external/CUDA/cuda.h"
|
||||
#include "triton/external/CUDA/nvml.h"
|
||||
|
||||
//// HIP backend
|
||||
//#define __HIP_PLATFORM_AMD__
|
||||
#include "triton/external/hip.h"
|
||||
|
||||
//Exceptions
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace llvm {
|
||||
class PassRegistry;
|
||||
class Module;
|
||||
}
|
||||
|
||||
namespace triton
|
||||
{
|
||||
namespace driver
|
||||
{
|
||||
|
||||
class cu_context;
|
||||
|
||||
template<class T> void check(T){}
|
||||
void check(CUresult err);
|
||||
void check(hipError_t err);
|
||||
|
||||
class dispatch
|
||||
{
|
||||
protected:
|
||||
template <class F>
|
||||
struct return_type;
|
||||
|
||||
template <class R, class... A>
|
||||
struct return_type<R (*)(A...)>
|
||||
{ typedef R type; };
|
||||
|
||||
typedef bool (*f_init_t)();
|
||||
|
||||
template<f_init_t initializer, typename FunPtrT, typename... Args>
|
||||
static typename return_type<FunPtrT>::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args)
|
||||
{
|
||||
initializer();
|
||||
if(cache == nullptr){
|
||||
cache = dlsym(lib_h, name);
|
||||
if(cache == 0)
|
||||
throw std::runtime_error("dlsym unable to load function");
|
||||
}
|
||||
FunPtrT fptr;
|
||||
*reinterpret_cast<void **>(&fptr) = cache;
|
||||
typename return_type<FunPtrT>::type res = (*fptr)(args...);
|
||||
check(res);
|
||||
return res;
|
||||
}
|
||||
|
||||
public:
|
||||
static void release();
|
||||
// Nvidia
|
||||
static bool nvmlinit();
|
||||
static bool cuinit();
|
||||
// AMD
|
||||
static bool hipinit();
|
||||
|
||||
/* ------------------- *
|
||||
* CUDA
|
||||
* ------------------- */
|
||||
// context management
|
||||
static CUresult cuInit(unsigned int Flags);
|
||||
static CUresult cuCtxDestroy_v2(CUcontext ctx);
|
||||
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
|
||||
static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
|
||||
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
|
||||
static CUresult cuCtxGetDevice(CUdevice* result);
|
||||
static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
|
||||
static CUresult cuDriverGetVersion(int *driverVersion);
|
||||
// device management
|
||||
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
|
||||
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
|
||||
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
|
||||
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
|
||||
static CUresult cuDeviceGetCount(int *count);
|
||||
// link management
|
||||
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
|
||||
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
|
||||
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
|
||||
static CUresult cuLinkDestroy(CUlinkState state);
|
||||
// module management
|
||||
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
|
||||
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
|
||||
static CUresult cuModuleLoadData(CUmodule* module, const void* image);
|
||||
static CUresult cuModuleUnload(CUmodule hmod);
|
||||
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
|
||||
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
|
||||
// stream management
|
||||
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
|
||||
static CUresult cuStreamSynchronize(CUstream hStream);
|
||||
static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
|
||||
static CUresult cuStreamDestroy_v2(CUstream hStream);
|
||||
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
|
||||
// function management
|
||||
static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
|
||||
static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
|
||||
static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config);
|
||||
// memory management
|
||||
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
|
||||
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
|
||||
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
|
||||
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
|
||||
static CUresult cuMemFree_v2(CUdeviceptr dptr);
|
||||
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
|
||||
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
|
||||
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
|
||||
// event management
|
||||
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
|
||||
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
|
||||
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
|
||||
static CUresult cuEventDestroy_v2(CUevent hEvent);
|
||||
|
||||
|
||||
/* ------------------- *
|
||||
* NVML
|
||||
* ------------------- */
|
||||
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
|
||||
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
|
||||
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
|
||||
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
|
||||
|
||||
/* ------------------- *
|
||||
* HIP
|
||||
* ------------------- */
|
||||
// context management
|
||||
static hipError_t hipInit(unsigned int Flags);
|
||||
static hipError_t hipCtxDestroy(hipCtx_t ctx);
|
||||
static hipError_t hipCtxCreate(hipCtx_t *pctx, unsigned int flags, hipDevice_t dev);
|
||||
static hipError_t hipCtxPushCurrent(hipCtx_t ctx);
|
||||
static hipError_t hipCtxPopCurrent(hipCtx_t *pctx);
|
||||
static hipError_t hipCtxGetDevice(hipDevice_t* result);
|
||||
static hipError_t hipCtxEnablePeerAccess(hipCtx_t peerContext, unsigned int flags);
|
||||
static hipError_t hipDriverGetVersion(int *driverVersion);
|
||||
// device management
|
||||
static hipError_t hipGetDevice(hipDevice_t *device, int ordinal);
|
||||
static hipError_t hipDeviceGetName(char *name, int len, hipDevice_t dev);
|
||||
static hipError_t hipDeviceGetPCIBusId(char *id, int len, hipDevice_t dev);
|
||||
static hipError_t hipDeviceGetAttribute(int *pi, hipDeviceAttribute_t attrib, hipDevice_t dev);
|
||||
static hipError_t hipGetDeviceCount(int *count);
|
||||
// module management
|
||||
static hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t* bytes, hipModule_t hmod, const char *name);
|
||||
static hipError_t hipModuleLoad(hipModule_t *module, const char *fname);
|
||||
static hipError_t hipModuleLoadData(hipModule_t* module, const void* image);
|
||||
static hipError_t hipModuleUnload(hipModule_t hmod);
|
||||
static hipError_t hipModuleLoadDataEx(hipModule_t *module, const void *image, unsigned int numOptions, hipJitOption *options, void **optionValues);
|
||||
static hipError_t hipModuleGetFunction(hipFunction_t *hfunc, hipModule_t hmod, const char *name);
|
||||
// stream management
|
||||
static hipError_t hipStreamCreate(hipStream_t *phStream, unsigned int Flags);
|
||||
static hipError_t hipStreamSynchronize(hipStream_t hStream);
|
||||
static hipError_t hipStreamDestroy(hipStream_t hStream);
|
||||
static hipError_t hipModuleLaunchKernel(hipFunction_t f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, hipStream_t hStream, void **kernelParams, void **extra);
|
||||
// function management
|
||||
static hipError_t hipFuncGetAttributes(hipFuncAttributes* attrib, void* hfunc);
|
||||
static hipError_t hipFuncSetAttribute(hipFunction_t hfunc, hipFuncAttribute attrib, int value);
|
||||
static hipError_t hipFuncSetCacheConfig(hipFunction_t hfunc, hipFuncCache_t config);
|
||||
// memory management
|
||||
static hipError_t hipMalloc(hipDeviceptr_t *dptr, size_t bytesize);
|
||||
static hipError_t hipPointerGetAttribute(void * data, CUpointer_attribute attribute, hipDeviceptr_t ptr);
|
||||
static hipError_t hipMemsetD8Async(hipDeviceptr_t dst, unsigned char x, size_t N, hipStream_t stream);
|
||||
static hipError_t hipMemcpyDtoH(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount);
|
||||
static hipError_t hipFree(hipDeviceptr_t dptr);
|
||||
static hipError_t hipMemcpyDtoHAsync(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount, hipStream_t hStream);
|
||||
static hipError_t hipMemcpyHtoDAsync(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount, hipStream_t hStream);
|
||||
static hipError_t hipMemcpyHtoD(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount);
|
||||
// event management
|
||||
static hipError_t hipEventCreate(hipEvent_t *phEvent, unsigned int Flags);
|
||||
static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart, hipEvent_t hEnd);
|
||||
static hipError_t hipEventRecord(hipEvent_t hEvent, hipStream_t hStream);
|
||||
static hipError_t hipEventDestroy(hipEvent_t hEvent);
|
||||
|
||||
|
||||
|
||||
private:
|
||||
|
||||
// Libraries
|
||||
static void* cuda_;
|
||||
static void* nvml_;
|
||||
static void* hip_;
|
||||
|
||||
|
||||
/* ------------------- *
|
||||
* CUDA
|
||||
* ------------------- */
|
||||
// context management
|
||||
static void* cuCtxGetCurrent_;
|
||||
static void* cuCtxSetCurrent_;
|
||||
static void* cuCtxDestroy_v2_;
|
||||
static void* cuCtxCreate_v2_;
|
||||
static void* cuCtxGetDevice_;
|
||||
static void* cuCtxPushCurrent_v2_;
|
||||
static void* cuCtxPopCurrent_v2_;
|
||||
static void* cuCtxEnablePeerAccess_;
|
||||
static void* cuDriverGetVersion_;
|
||||
static void* cuInit_;
|
||||
// device management
|
||||
static void* cuDeviceGet_;
|
||||
static void* cuDeviceGetName_;
|
||||
static void* cuDeviceGetPCIBusId_;
|
||||
static void* cuDeviceGetAttribute_;
|
||||
static void* cuDeviceGetCount_;
|
||||
// link management
|
||||
static void* cuLinkAddData_v2_;
|
||||
static void* cuLinkCreate_v2_;
|
||||
static void* cuLinkDestroy_;
|
||||
static void* cuLinkComplete_;
|
||||
// module management
|
||||
static void* cuModuleGetGlobal_v2_;
|
||||
static void* cuModuleLoad_;
|
||||
static void* cuModuleUnload_;
|
||||
static void* cuModuleLoadDataEx_;
|
||||
static void* cuModuleLoadData_;
|
||||
static void* cuModuleGetFunction_;
|
||||
// stream management
|
||||
static void* cuStreamCreate_;
|
||||
static void* cuStreamSynchronize_;
|
||||
static void* cuStreamDestroy_v2_;
|
||||
static void* cuStreamGetCtx_;
|
||||
static void* cuLaunchKernel_;
|
||||
// function management
|
||||
static void* cuFuncGetAttribute_;
|
||||
static void* cuFuncSetAttribute_;
|
||||
static void* cuFuncSetCacheConfig_;
|
||||
// memory management
|
||||
static void* cuMemcpyDtoH_v2_;
|
||||
static void* cuMemFree_v2_;
|
||||
static void* cuMemcpyDtoHAsync_v2_;
|
||||
static void* cuMemcpyHtoDAsync_v2_;
|
||||
static void* cuMemcpyHtoD_v2_;
|
||||
static void* cuMemAlloc_v2_;
|
||||
static void* cuMemsetD8Async_;
|
||||
static void* cuPointerGetAttribute_;
|
||||
// event management
|
||||
static void* cuEventCreate_;
|
||||
static void* cuEventElapsedTime_;
|
||||
static void* cuEventRecord_;
|
||||
static void* cuEventDestroy_v2_;
|
||||
|
||||
/* ------------------- *
|
||||
* NVML
|
||||
* ------------------- */
|
||||
static void* nvmlInit_v2_;
|
||||
static void* nvmlDeviceGetHandleByPciBusId_v2_;
|
||||
static void* nvmlDeviceGetClockInfo_;
|
||||
static void* nvmlDeviceGetMaxClockInfo_;
|
||||
static void* nvmlDeviceSetApplicationsClocks_;
|
||||
|
||||
/* ------------------- *
|
||||
* HIP
|
||||
* ------------------- */
|
||||
// context management
|
||||
static void* hipInit_;
|
||||
static void* hipCtxDestroy_;
|
||||
static void* hipCtxCreate_;
|
||||
static void* hipCtxPushCurrent_;
|
||||
static void* hipCtxPopCurrent_;
|
||||
static void* hipCtxGetDevice_;
|
||||
static void* hipCtxEnablePeerAccess_;
|
||||
static void* hipDriverGetVersion_;
|
||||
// device management
|
||||
static void* hipGetDevice_;
|
||||
static void* hipDeviceGetName_;
|
||||
static void* hipDeviceGetPCIBusId_;
|
||||
static void* hipDeviceGetAttribute_;
|
||||
static void* hipGetDeviceCount_;
|
||||
// module management
|
||||
static void* hipModuleGetGlobal_;
|
||||
static void* hipModuleLoad_;
|
||||
static void* hipModuleLoadData_;
|
||||
static void* hipModuleUnload_;
|
||||
static void* hipModuleLoadDataEx_;
|
||||
static void* hipModuleGetFunction_;
|
||||
// stream management
|
||||
static void* hipStreamCreate_;
|
||||
static void* hipStreamSynchronize_;
|
||||
static void* hipStreamDestroy_;
|
||||
static void* hipModuleLaunchKernel_;;
|
||||
// function management
|
||||
static void* hipFuncGetAttributes_;
|
||||
static void* hipFuncSetAttribute_;
|
||||
static void* hipFuncSetCacheConfig_;
|
||||
// memory management
|
||||
static void* hipMalloc_;
|
||||
static void* hipPointerGetAttribute_;
|
||||
static void* hipMemsetD8Async_;
|
||||
static void* hipMemcpyDtoH_;
|
||||
static void* hipFree_;
|
||||
static void* hipMemcpyDtoHAsync_;
|
||||
static void* hipMemcpyHtoDAsync_;
|
||||
static void* hipMemcpyHtoD_;
|
||||
// event management
|
||||
static void* hipEventCreate_;
|
||||
static void* hipEventElapsedTime_;
|
||||
static void* hipEventRecord_;
|
||||
static void* hipEventDestroy_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif
|
@@ -1,220 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_DRIVER_ERROR_H_
|
||||
#define _TRITON_DRIVER_ERROR_H_
|
||||
|
||||
#include <exception>
|
||||
#include "triton/driver/dispatch.h"
|
||||
|
||||
|
||||
namespace triton
|
||||
{
|
||||
|
||||
namespace driver
|
||||
{
|
||||
|
||||
namespace exception
|
||||
{
|
||||
|
||||
namespace nvrtc
|
||||
{
|
||||
|
||||
#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "NVRTC: Error- " msg; } }
|
||||
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(compilation ,"compilation");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure");
|
||||
TRITON_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error");
|
||||
|
||||
#undef TRITON_CREATE_NVRTC_EXCEPTION
|
||||
}
|
||||
|
||||
|
||||
namespace cuda
|
||||
{
|
||||
class base: public std::exception{};
|
||||
|
||||
#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "CUDA: Error- " msg; } }
|
||||
|
||||
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(no_device ,"no device");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(map_failed ,"map failed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(operating_system ,"operating system");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_found ,"not found");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_ready ,"not ready");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(assert_error ,"assert");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(not_supported ,"not supported");
|
||||
TRITON_CREATE_CUDA_EXCEPTION(unknown ,"unknown");
|
||||
|
||||
#undef TRITON_CREATE_CUDA_EXCEPTION
|
||||
}
|
||||
|
||||
namespace cublas
|
||||
{
|
||||
class base: public std::exception{};
|
||||
|
||||
#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "CUBLAS: Error- " msg; } }
|
||||
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(license_error ,"license error");
|
||||
TRITON_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown");
|
||||
|
||||
#undef TRITON_CREATE_CUBLAS_EXCEPTION
|
||||
}
|
||||
|
||||
namespace cudnn
|
||||
{
|
||||
#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "CUDNN: Error- " msg; } }
|
||||
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(license_error ,"license error");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress");
|
||||
TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
namespace hip
|
||||
{
|
||||
class base: public std::exception{};
|
||||
|
||||
#define TRITON_CREATE_HIP_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "HIP: Error- " msg; } }
|
||||
|
||||
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_value ,"invalid value");
|
||||
TRITON_CREATE_HIP_EXCEPTION(out_of_memory ,"out of memory");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_initialized ,"not initialized");
|
||||
TRITON_CREATE_HIP_EXCEPTION(deinitialized ,"deinitialized");
|
||||
TRITON_CREATE_HIP_EXCEPTION(profiler_disabled ,"profiler disabled");
|
||||
TRITON_CREATE_HIP_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
|
||||
TRITON_CREATE_HIP_EXCEPTION(profiler_already_started ,"profiler already started");
|
||||
TRITON_CREATE_HIP_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
|
||||
TRITON_CREATE_HIP_EXCEPTION(no_device ,"no device");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_device ,"invalid device");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_image ,"invalid image");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_context ,"invalid context");
|
||||
TRITON_CREATE_HIP_EXCEPTION(context_already_current ,"context already current");
|
||||
TRITON_CREATE_HIP_EXCEPTION(map_failed ,"map failed");
|
||||
TRITON_CREATE_HIP_EXCEPTION(unmap_failed ,"unmap failed");
|
||||
TRITON_CREATE_HIP_EXCEPTION(array_is_mapped ,"array is mapped");
|
||||
TRITON_CREATE_HIP_EXCEPTION(already_mapped ,"already mapped");
|
||||
TRITON_CREATE_HIP_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
|
||||
TRITON_CREATE_HIP_EXCEPTION(already_acquired ,"already acquired");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_mapped ,"not mapped");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_array ,"not mapped as array");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
|
||||
TRITON_CREATE_HIP_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
|
||||
TRITON_CREATE_HIP_EXCEPTION(unsupported_limit ,"unsupported limit");
|
||||
TRITON_CREATE_HIP_EXCEPTION(context_already_in_use ,"context already in use");
|
||||
TRITON_CREATE_HIP_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_ptx ,"invalid ptx");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_source ,"invalid source");
|
||||
TRITON_CREATE_HIP_EXCEPTION(file_not_found ,"file not found");
|
||||
TRITON_CREATE_HIP_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
|
||||
TRITON_CREATE_HIP_EXCEPTION(shared_object_init_failed ,"shared object init failed");
|
||||
TRITON_CREATE_HIP_EXCEPTION(operating_system ,"operating system");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_handle ,"invalid handle");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_found ,"not found");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_ready ,"not ready");
|
||||
TRITON_CREATE_HIP_EXCEPTION(illegal_address ,"illegal address");
|
||||
TRITON_CREATE_HIP_EXCEPTION(launch_out_of_resources ,"launch out of resources");
|
||||
TRITON_CREATE_HIP_EXCEPTION(launch_timeout ,"launch timeout");
|
||||
TRITON_CREATE_HIP_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
|
||||
TRITON_CREATE_HIP_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
|
||||
TRITON_CREATE_HIP_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
|
||||
TRITON_CREATE_HIP_EXCEPTION(primary_context_active ,"primary context active");
|
||||
TRITON_CREATE_HIP_EXCEPTION(context_is_destroyed ,"context is destroyed");
|
||||
TRITON_CREATE_HIP_EXCEPTION(assert_error ,"assert");
|
||||
TRITON_CREATE_HIP_EXCEPTION(too_many_peers ,"too many peers");
|
||||
TRITON_CREATE_HIP_EXCEPTION(host_memory_already_registered ,"host memory already registered");
|
||||
TRITON_CREATE_HIP_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
|
||||
TRITON_CREATE_HIP_EXCEPTION(hardware_stack_error ,"hardware stack error");
|
||||
TRITON_CREATE_HIP_EXCEPTION(illegal_instruction ,"illegal instruction");
|
||||
TRITON_CREATE_HIP_EXCEPTION(misaligned_address ,"misaligned address");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_address_space ,"invalid address space");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_pc ,"invalid pc");
|
||||
TRITON_CREATE_HIP_EXCEPTION(launch_failed ,"launch failed");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_permitted ,"not permitted");
|
||||
TRITON_CREATE_HIP_EXCEPTION(not_supported ,"not supported");
|
||||
TRITON_CREATE_HIP_EXCEPTION(invalid_symbol ,"invalid symbol");
|
||||
TRITON_CREATE_HIP_EXCEPTION(unknown ,"unknown");
|
||||
|
||||
#undef TRITON_CREATE_CUDA_EXCEPTION
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,20 +0,0 @@
|
||||
#include <string>
|
||||
#include "triton/driver/dispatch.h"
|
||||
|
||||
namespace llvm{
|
||||
class Module;
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace driver{
|
||||
|
||||
void init_llvm();
|
||||
std::string path_to_ptxas(int& version);
|
||||
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
|
||||
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
|
||||
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
|
||||
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
|
||||
hipModule_t amdgpu_to_hipmodule(const std::string& path);
|
||||
|
||||
}
|
||||
}
|
18994
include/triton/external/CUDA/cuda.h
vendored
18994
include/triton/external/CUDA/cuda.h
vendored
File diff suppressed because it is too large
Load Diff
6281
include/triton/external/CUDA/nvml.h
vendored
6281
include/triton/external/CUDA/nvml.h
vendored
File diff suppressed because it is too large
Load Diff
3067
include/triton/external/half.hpp
vendored
3067
include/triton/external/half.hpp
vendored
File diff suppressed because it is too large
Load Diff
288
include/triton/external/hip.h
vendored
288
include/triton/external/hip.h
vendored
@@ -1,288 +0,0 @@
|
||||
/*
|
||||
* @brief hipError_t
|
||||
* @enum
|
||||
* @ingroup Enumerations
|
||||
*/
|
||||
// Developer note - when updating these, update the hipErrorName and hipErrorString functions in
|
||||
// NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path.
|
||||
|
||||
// Ignoring error-code return values from hip APIs is discouraged. On C++17,
|
||||
// we can make that yield a warning
|
||||
|
||||
/*
|
||||
* @brief hipError_t
|
||||
* @enum
|
||||
* @ingroup Enumerations
|
||||
*/
|
||||
// Developer note - when updating these, update the hipErrorName and hipErrorString functions in
|
||||
// NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path.
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
typedef enum hipError_t {
|
||||
hipSuccess = 0, ///< Successful completion.
|
||||
hipErrorInvalidValue = 1, ///< One or more of the parameters passed to the API call is NULL
|
||||
///< or not in an acceptable range.
|
||||
hipErrorOutOfMemory = 2,
|
||||
// Deprecated
|
||||
hipErrorMemoryAllocation = 2, ///< Memory allocation error.
|
||||
hipErrorNotInitialized = 3,
|
||||
// Deprecated
|
||||
hipErrorInitializationError = 3,
|
||||
hipErrorDeinitialized = 4,
|
||||
hipErrorProfilerDisabled = 5,
|
||||
hipErrorProfilerNotInitialized = 6,
|
||||
hipErrorProfilerAlreadyStarted = 7,
|
||||
hipErrorProfilerAlreadyStopped = 8,
|
||||
hipErrorInvalidConfiguration = 9,
|
||||
hipErrorInvalidPitchValue = 12,
|
||||
hipErrorInvalidSymbol = 13,
|
||||
hipErrorInvalidDevicePointer = 17, ///< Invalid Device Pointer
|
||||
hipErrorInvalidMemcpyDirection = 21, ///< Invalid memory copy direction
|
||||
hipErrorInsufficientDriver = 35,
|
||||
hipErrorMissingConfiguration = 52,
|
||||
hipErrorPriorLaunchFailure = 53,
|
||||
hipErrorInvalidDeviceFunction = 98,
|
||||
hipErrorNoDevice = 100, ///< Call to hipGetDeviceCount returned 0 devices
|
||||
hipErrorInvalidDevice = 101, ///< DeviceID must be in range 0...#compute-devices.
|
||||
hipErrorInvalidImage = 200,
|
||||
hipErrorInvalidContext = 201, ///< Produced when input context is invalid.
|
||||
hipErrorContextAlreadyCurrent = 202,
|
||||
hipErrorMapFailed = 205,
|
||||
// Deprecated
|
||||
hipErrorMapBufferObjectFailed = 205, ///< Produced when the IPC memory attach failed from ROCr.
|
||||
hipErrorUnmapFailed = 206,
|
||||
hipErrorArrayIsMapped = 207,
|
||||
hipErrorAlreadyMapped = 208,
|
||||
hipErrorNoBinaryForGpu = 209,
|
||||
hipErrorAlreadyAcquired = 210,
|
||||
hipErrorNotMapped = 211,
|
||||
hipErrorNotMappedAsArray = 212,
|
||||
hipErrorNotMappedAsPointer = 213,
|
||||
hipErrorECCNotCorrectable = 214,
|
||||
hipErrorUnsupportedLimit = 215,
|
||||
hipErrorContextAlreadyInUse = 216,
|
||||
hipErrorPeerAccessUnsupported = 217,
|
||||
hipErrorInvalidKernelFile = 218, ///< In CUDA DRV, it is CUDA_ERROR_INVALID_PTX
|
||||
hipErrorInvalidGraphicsContext = 219,
|
||||
hipErrorInvalidSource = 300,
|
||||
hipErrorFileNotFound = 301,
|
||||
hipErrorSharedObjectSymbolNotFound = 302,
|
||||
hipErrorSharedObjectInitFailed = 303,
|
||||
hipErrorOperatingSystem = 304,
|
||||
hipErrorInvalidHandle = 400,
|
||||
// Deprecated
|
||||
hipErrorInvalidResourceHandle = 400, ///< Resource handle (hipEvent_t or hipStream_t) invalid.
|
||||
hipErrorNotFound = 500,
|
||||
hipErrorNotReady = 600, ///< Indicates that asynchronous operations enqueued earlier are not
|
||||
///< ready. This is not actually an error, but is used to distinguish
|
||||
///< from hipSuccess (which indicates completion). APIs that return
|
||||
///< this error include hipEventQuery and hipStreamQuery.
|
||||
hipErrorIllegalAddress = 700,
|
||||
hipErrorLaunchOutOfResources = 701, ///< Out of resources error.
|
||||
hipErrorLaunchTimeOut = 702,
|
||||
hipErrorPeerAccessAlreadyEnabled =
|
||||
704, ///< Peer access was already enabled from the current device.
|
||||
hipErrorPeerAccessNotEnabled =
|
||||
705, ///< Peer access was never enabled from the current device.
|
||||
hipErrorSetOnActiveProcess = 708,
|
||||
hipErrorAssert = 710, ///< Produced when the kernel calls assert.
|
||||
hipErrorHostMemoryAlreadyRegistered =
|
||||
712, ///< Produced when trying to lock a page-locked memory.
|
||||
hipErrorHostMemoryNotRegistered =
|
||||
713, ///< Produced when trying to unlock a non-page-locked memory.
|
||||
hipErrorLaunchFailure =
|
||||
719, ///< An exception occurred on the device while executing a kernel.
|
||||
hipErrorCooperativeLaunchTooLarge =
|
||||
720, ///< This error indicates that the number of blocks launched per grid for a kernel
|
||||
///< that was launched via cooperative launch APIs exceeds the maximum number of
|
||||
///< allowed blocks for the current device
|
||||
hipErrorNotSupported = 801, ///< Produced when the hip API is not supported/implemented
|
||||
hipErrorUnknown = 999, //< Unknown error.
|
||||
// HSA Runtime Error Codes start here.
|
||||
hipErrorRuntimeMemory = 1052, ///< HSA runtime memory call returned error. Typically not seen
|
||||
///< in production systems.
|
||||
hipErrorRuntimeOther = 1053, ///< HSA runtime call other than memory returned error. Typically
|
||||
///< not seen in production systems.
|
||||
hipErrorTbd ///< Marker that more error codes are needed.
|
||||
} hipError_t;
|
||||
|
||||
|
||||
typedef struct ihipCtx_t* hipCtx_t;
|
||||
|
||||
// Note many APIs also use integer deviceIds as an alternative to the device pointer:
|
||||
typedef int hipDevice_t;
|
||||
|
||||
typedef enum hipDeviceP2PAttr {
|
||||
hipDevP2PAttrPerformanceRank = 0,
|
||||
hipDevP2PAttrAccessSupported,
|
||||
hipDevP2PAttrNativeAtomicSupported,
|
||||
hipDevP2PAttrHipArrayAccessSupported
|
||||
} hipDeviceP2PAttr;
|
||||
|
||||
typedef struct ihipStream_t* hipStream_t;
|
||||
|
||||
#define hipIpcMemLazyEnablePeerAccess 0
|
||||
|
||||
#define HIP_IPC_HANDLE_SIZE 64
|
||||
|
||||
typedef struct hipIpcMemHandle_st {
|
||||
char reserved[HIP_IPC_HANDLE_SIZE];
|
||||
} hipIpcMemHandle_t;
|
||||
|
||||
typedef struct hipIpcEventHandle_st {
|
||||
char reserved[HIP_IPC_HANDLE_SIZE];
|
||||
} hipIpcEventHandle_t;
|
||||
|
||||
typedef struct ihipModule_t* hipModule_t;
|
||||
|
||||
typedef struct ihipModuleSymbol_t* hipFunction_t;
|
||||
|
||||
typedef struct hipFuncAttributes {
|
||||
int binaryVersion;
|
||||
int cacheModeCA;
|
||||
size_t constSizeBytes;
|
||||
size_t localSizeBytes;
|
||||
int maxDynamicSharedSizeBytes;
|
||||
int maxThreadsPerBlock;
|
||||
int numRegs;
|
||||
int preferredShmemCarveout;
|
||||
int ptxVersion;
|
||||
size_t sharedSizeBytes;
|
||||
} hipFuncAttributes;
|
||||
|
||||
typedef struct ihipEvent_t* hipEvent_t;
|
||||
|
||||
/*
|
||||
* @brief hipDeviceAttribute_t
|
||||
* @enum
|
||||
* @ingroup Enumerations
|
||||
*/
|
||||
typedef enum hipDeviceAttribute_t {
|
||||
hipDeviceAttributeMaxThreadsPerBlock, ///< Maximum number of threads per block.
|
||||
hipDeviceAttributeMaxBlockDimX, ///< Maximum x-dimension of a block.
|
||||
hipDeviceAttributeMaxBlockDimY, ///< Maximum y-dimension of a block.
|
||||
hipDeviceAttributeMaxBlockDimZ, ///< Maximum z-dimension of a block.
|
||||
hipDeviceAttributeMaxGridDimX, ///< Maximum x-dimension of a grid.
|
||||
hipDeviceAttributeMaxGridDimY, ///< Maximum y-dimension of a grid.
|
||||
hipDeviceAttributeMaxGridDimZ, ///< Maximum z-dimension of a grid.
|
||||
hipDeviceAttributeMaxSharedMemoryPerBlock, ///< Maximum shared memory available per block in
|
||||
///< bytes.
|
||||
hipDeviceAttributeTotalConstantMemory, ///< Constant memory size in bytes.
|
||||
hipDeviceAttributeWarpSize, ///< Warp size in threads.
|
||||
hipDeviceAttributeMaxRegistersPerBlock, ///< Maximum number of 32-bit registers available to a
|
||||
///< thread block. This number is shared by all thread
|
||||
///< blocks simultaneously resident on a
|
||||
///< multiprocessor.
|
||||
hipDeviceAttributeClockRate, ///< Peak clock frequency in kilohertz.
|
||||
hipDeviceAttributeMemoryClockRate, ///< Peak memory clock frequency in kilohertz.
|
||||
hipDeviceAttributeMemoryBusWidth, ///< Global memory bus width in bits.
|
||||
hipDeviceAttributeMultiprocessorCount, ///< Number of multiprocessors on the device.
|
||||
hipDeviceAttributeComputeMode, ///< Compute mode that device is currently in.
|
||||
hipDeviceAttributeL2CacheSize, ///< Size of L2 cache in bytes. 0 if the device doesn't have L2
|
||||
///< cache.
|
||||
hipDeviceAttributeMaxThreadsPerMultiProcessor, ///< Maximum resident threads per
|
||||
///< multiprocessor.
|
||||
hipDeviceAttributeComputeCapabilityMajor, ///< Major compute capability version number.
|
||||
hipDeviceAttributeComputeCapabilityMinor, ///< Minor compute capability version number.
|
||||
hipDeviceAttributeConcurrentKernels, ///< Device can possibly execute multiple kernels
|
||||
///< concurrently.
|
||||
hipDeviceAttributePciBusId, ///< PCI Bus ID.
|
||||
hipDeviceAttributePciDeviceId, ///< PCI Device ID.
|
||||
hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, ///< Maximum Shared Memory Per
|
||||
///< Multiprocessor.
|
||||
hipDeviceAttributeIsMultiGpuBoard, ///< Multiple GPU devices.
|
||||
hipDeviceAttributeIntegrated, ///< iGPU
|
||||
hipDeviceAttributeCooperativeLaunch, ///< Support cooperative launch
|
||||
hipDeviceAttributeCooperativeMultiDeviceLaunch, ///< Support cooperative launch on multiple devices
|
||||
hipDeviceAttributeMaxTexture1DWidth, ///< Maximum number of elements in 1D images
|
||||
hipDeviceAttributeMaxTexture2DWidth, ///< Maximum dimension width of 2D images in image elements
|
||||
hipDeviceAttributeMaxTexture2DHeight, ///< Maximum dimension height of 2D images in image elements
|
||||
hipDeviceAttributeMaxTexture3DWidth, ///< Maximum dimension width of 3D images in image elements
|
||||
hipDeviceAttributeMaxTexture3DHeight, ///< Maximum dimensions height of 3D images in image elements
|
||||
hipDeviceAttributeMaxTexture3DDepth, ///< Maximum dimensions depth of 3D images in image elements
|
||||
|
||||
hipDeviceAttributeHdpMemFlushCntl, ///< Address of the HDP_MEM_COHERENCY_FLUSH_CNTL register
|
||||
hipDeviceAttributeHdpRegFlushCntl, ///< Address of the HDP_REG_COHERENCY_FLUSH_CNTL register
|
||||
|
||||
hipDeviceAttributeMaxPitch, ///< Maximum pitch in bytes allowed by memory copies
|
||||
hipDeviceAttributeTextureAlignment, ///<Alignment requirement for textures
|
||||
hipDeviceAttributeTexturePitchAlignment, ///<Pitch alignment requirement for 2D texture references bound to pitched memory;
|
||||
hipDeviceAttributeKernelExecTimeout, ///<Run time limit for kernels executed on the device
|
||||
hipDeviceAttributeCanMapHostMemory, ///<Device can map host memory into device address space
|
||||
hipDeviceAttributeEccEnabled, ///<Device has ECC support enabled
|
||||
|
||||
hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc, ///< Supports cooperative launch on multiple
|
||||
///devices with unmatched functions
|
||||
hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim, ///< Supports cooperative launch on multiple
|
||||
///devices with unmatched grid dimensions
|
||||
hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim, ///< Supports cooperative launch on multiple
|
||||
///devices with unmatched block dimensions
|
||||
hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem, ///< Supports cooperative launch on multiple
|
||||
///devices with unmatched shared memories
|
||||
hipDeviceAttributeAsicRevision, ///< Revision of the GPU in this device
|
||||
hipDeviceAttributeManagedMemory, ///< Device supports allocating managed memory on this system
|
||||
hipDeviceAttributeDirectManagedMemAccessFromHost, ///< Host can directly access managed memory on
|
||||
/// the device without migration
|
||||
hipDeviceAttributeConcurrentManagedAccess, ///< Device can coherently access managed memory
|
||||
/// concurrently with the CPU
|
||||
hipDeviceAttributePageableMemoryAccess, ///< Device supports coherently accessing pageable memory
|
||||
/// without calling hipHostRegister on it
|
||||
hipDeviceAttributePageableMemoryAccessUsesHostPageTables, ///< Device accesses pageable memory via
|
||||
/// the host's page tables
|
||||
hipDeviceAttributeCanUseStreamWaitValue ///< '1' if Device supports hipStreamWaitValue32() and
|
||||
///< hipStreamWaitValue64() , '0' otherwise.
|
||||
|
||||
} hipDeviceAttribute_t;
|
||||
|
||||
typedef void* hipDeviceptr_t;
|
||||
|
||||
/*
|
||||
* @brief hipJitOption
|
||||
* @enum
|
||||
* @ingroup Enumerations
|
||||
*/
|
||||
typedef enum hipJitOption {
|
||||
hipJitOptionMaxRegisters = 0,
|
||||
hipJitOptionThreadsPerBlock,
|
||||
hipJitOptionWallTime,
|
||||
hipJitOptionInfoLogBuffer,
|
||||
hipJitOptionInfoLogBufferSizeBytes,
|
||||
hipJitOptionErrorLogBuffer,
|
||||
hipJitOptionErrorLogBufferSizeBytes,
|
||||
hipJitOptionOptimizationLevel,
|
||||
hipJitOptionTargetFromContext,
|
||||
hipJitOptionTarget,
|
||||
hipJitOptionFallbackStrategy,
|
||||
hipJitOptionGenerateDebugInfo,
|
||||
hipJitOptionLogVerbose,
|
||||
hipJitOptionGenerateLineInfo,
|
||||
hipJitOptionCacheMode,
|
||||
hipJitOptionSm3xOpt,
|
||||
hipJitOptionFastCompile,
|
||||
hipJitOptionNumOptions
|
||||
} hipJitOption;
|
||||
|
||||
/**
|
||||
* @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
|
||||
*/
|
||||
typedef enum hipFuncAttribute {
|
||||
hipFuncAttributeMaxDynamicSharedMemorySize = 8,
|
||||
hipFuncAttributePreferredSharedMemoryCarveout = 9,
|
||||
hipFuncAttributeMax
|
||||
} hipFuncAttribute;
|
||||
|
||||
/**
|
||||
* @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
|
||||
*/
|
||||
typedef enum hipFuncCache_t {
|
||||
hipFuncCachePreferNone, ///< no preference for shared memory or L1 (default)
|
||||
hipFuncCachePreferShared, ///< prefer larger shared memory and smaller L1 cache
|
||||
hipFuncCachePreferL1, ///< prefer larger L1 cache and smaller shared memory
|
||||
hipFuncCachePreferEqual, ///< prefer equal size L1 cache and shared memory
|
||||
} hipFuncCache_t;
|
||||
|
||||
|
||||
#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void*)0x01)
|
||||
#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void*)0x02)
|
||||
#define HIP_LAUNCH_PARAM_END ((void*)0x03)
|
@@ -1,88 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_BASIC_BLOCK_H_
|
||||
#define _TRITON_IR_BASIC_BLOCK_H_
|
||||
|
||||
#include <string>
|
||||
#include <list>
|
||||
#include "value.h"
|
||||
#include "visitor.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class context;
|
||||
class function;
|
||||
class instruction;
|
||||
|
||||
/* Basic Block */
|
||||
class basic_block: public value{
|
||||
public:
|
||||
// instruction iterator types
|
||||
typedef std::list<instruction*> inst_list_t;
|
||||
typedef inst_list_t::iterator iterator;
|
||||
typedef inst_list_t::const_iterator const_iterator;
|
||||
typedef inst_list_t::reverse_iterator reverse_iterator;
|
||||
typedef inst_list_t::const_reverse_iterator const_reverse_iterator;
|
||||
|
||||
private:
|
||||
// constructors
|
||||
basic_block(context &ctx, const std::string &name, function *parent);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
function* get_parent() { return parent_; }
|
||||
context& get_context() { return ctx_; }
|
||||
|
||||
// get iterator to first instruction that is not a phi
|
||||
iterator get_first_non_phi();
|
||||
|
||||
// get instruction list
|
||||
inst_list_t &get_inst_list() { return inst_list_; }
|
||||
const inst_list_t &get_inst_list() const { return inst_list_; }
|
||||
void erase(instruction *i) { inst_list_.remove(i); }
|
||||
|
||||
// instruction iterator functions
|
||||
inline iterator begin() { return inst_list_.begin(); }
|
||||
inline const_iterator begin() const { return inst_list_.begin(); }
|
||||
inline iterator end () { return inst_list_.end(); }
|
||||
inline const_iterator end () const { return inst_list_.end(); }
|
||||
|
||||
inline reverse_iterator rbegin() { return inst_list_.rbegin(); }
|
||||
inline const_reverse_iterator rbegin() const { return inst_list_.rbegin(); }
|
||||
inline reverse_iterator rend () { return inst_list_.rend(); }
|
||||
inline const_reverse_iterator rend () const { return inst_list_.rend(); }
|
||||
|
||||
inline size_t size() const { return inst_list_.size(); }
|
||||
inline bool empty() const { return inst_list_.empty(); }
|
||||
inline const instruction &front() const { return *inst_list_.front(); }
|
||||
inline instruction &front() { return *inst_list_.front(); }
|
||||
inline const instruction &back() const { return *inst_list_.back(); }
|
||||
inline instruction &back() { return *inst_list_.back(); }
|
||||
|
||||
// predecessors
|
||||
const std::vector<basic_block*>& get_predecessors() const { return preds_; }
|
||||
const std::vector<basic_block*>& get_successors() const { return succs_; }
|
||||
void add_predecessor(basic_block* pred);
|
||||
|
||||
// factory functions
|
||||
static basic_block* create(context &ctx, const std::string &name, function *parent);
|
||||
|
||||
void print(std::ostream &os);
|
||||
|
||||
// visitor
|
||||
void accept(visitor *v) { v->visit_basic_block(this); }
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
std::string name_;
|
||||
function *parent_;
|
||||
std::vector<basic_block*> preds_;
|
||||
std::vector<basic_block*> succs_;
|
||||
inst_list_t inst_list_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,191 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_BUILDER_H_
|
||||
#define _TRITON_IR_BUILDER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "instructions.h"
|
||||
#include "basic_block.h"
|
||||
#include "type.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class basic_block;
|
||||
class value;
|
||||
class type;
|
||||
class constant_int;
|
||||
class instruction;
|
||||
class context;
|
||||
class phi_node;
|
||||
|
||||
/* Builder */
|
||||
class builder{
|
||||
typedef basic_block::iterator iterator;
|
||||
|
||||
public:
|
||||
// Constructor
|
||||
builder(context &ctx);
|
||||
// Getters
|
||||
const context& get_context() { return ctx_; }
|
||||
// Setters
|
||||
void set_insert_point(iterator instr);
|
||||
void set_insert_point(instruction* i);
|
||||
void set_insert_point_after(instruction* i);
|
||||
void set_insert_point(basic_block* block);
|
||||
basic_block* get_insert_block() { return block_; }
|
||||
iterator get_insert_point() { return insert_point_;}
|
||||
// Constants
|
||||
value *get_int1(bool val);
|
||||
value *get_int32(uint32_t val);
|
||||
value *get_int64(uint64_t val);
|
||||
value *get_float16(float val);
|
||||
value *get_float32(float val);
|
||||
value *get_range(int32_t lo, int32_t hi);
|
||||
// Types
|
||||
type *get_void_ty();
|
||||
type *get_int1_ty();
|
||||
type *get_int8_ty();
|
||||
type *get_int16_ty();
|
||||
type *get_int32_ty();
|
||||
type *get_int64_ty();
|
||||
type *get_fp8_ty();
|
||||
type *get_half_ty();
|
||||
type *get_bf16_ty();
|
||||
type *get_float_ty();
|
||||
type *get_double_ty();
|
||||
// Insert
|
||||
template<typename InstTy>
|
||||
InstTy* insert(InstTy *inst){
|
||||
assert(block_);
|
||||
block_->get_inst_list().insert(insert_point_, inst);
|
||||
inst->set_parent(block_);
|
||||
// for(ir::value* op: inst->ops())
|
||||
// op->add_use(inst);
|
||||
return inst;
|
||||
}
|
||||
// terminator instructions
|
||||
value* create_br(basic_block *dest);
|
||||
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
|
||||
value* create_ret_void();
|
||||
// Cast instructions
|
||||
value* create_bitcast(value *src, type *dest_ty);
|
||||
value *create_cast(cast_op_t op, value *v, type *dst_ty);
|
||||
value* create_int_to_ptr(value *src, type *dst_ty);
|
||||
value* create_ptr_to_int(value *src, type *dst_ty);
|
||||
value* create_si_to_fp(value *src, type *dst_ty);
|
||||
value* create_ui_to_fp(value *src, type *dst_ty);
|
||||
value* create_fp_to_si(value *src, type *dst_ty);
|
||||
value* create_fp_to_ui(value *src, type *dst_ty);
|
||||
value* create_fp_ext(value *src, type *dst_ty);
|
||||
value* create_fp_trunc(value *src, type *dst_ty);
|
||||
value* create_int_cast(value *src, type *dst_ty, bool is_signed);
|
||||
value *create_downcast(value *arg);
|
||||
// Phi instruction
|
||||
phi_node* create_phi(type *ty, unsigned num_reserved);
|
||||
// Binary instructions
|
||||
value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, bool has_nuw, bool has_nsw);
|
||||
value *create_fmul(value *lhs, value *rhs);
|
||||
value *create_fdiv(value *lhs, value *rhs);
|
||||
value *create_frem(value *lhs, value *rhs);
|
||||
value *create_fadd(value *lhs, value *rhs);
|
||||
value *create_fsub(value *lhs, value *rhs);
|
||||
value *create_sdiv(value *lhs, value *rhs);
|
||||
value *create_udiv(value *lhs, value *rhs);
|
||||
value *create_srem(value *lhs, value *rhs);
|
||||
value *create_urem(value *lhs, value *rhs);
|
||||
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_lshr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_ashr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
// GEP
|
||||
value *create_gep(value *ptr, const std::vector<value*>& idx_list);
|
||||
// Comparison (int)
|
||||
value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs);
|
||||
value *create_icmpSLE(value *lhs, value *rhs);
|
||||
value *create_icmpSLT(value *lhs, value *rhs);
|
||||
value *create_icmpSGE(value *lhs, value *rhs);
|
||||
value *create_icmpSGT(value *lhs, value *rhs);
|
||||
value *create_icmpULE(value *lhs, value *rhs);
|
||||
value *create_icmpULT(value *lhs, value *rhs);
|
||||
value *create_icmpUGE(value *lhs, value *rhs);
|
||||
value *create_icmpUGT(value *lhs, value *rhs);
|
||||
value *create_icmpEQ(value *lhs, value *rhs);
|
||||
value *create_icmpNE(value *lhs, value *rhs);
|
||||
// Comparison (float)
|
||||
value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs);
|
||||
value *create_fcmpOLT(value *lhs, value *rhs);
|
||||
value *create_fcmpOGT(value *lhs, value *rhs);
|
||||
value *create_fcmpOLE(value *lhs, value *rhs);
|
||||
value *create_fcmpOGE(value *lhs, value *rhs);
|
||||
value *create_fcmpOEQ(value *lhs, value *rhs);
|
||||
value *create_fcmpONE(value *lhs, value *rhs);
|
||||
value *create_fcmpULT(value *lhs, value *rhs);
|
||||
value *create_fcmpUGT(value *lhs, value *rhs);
|
||||
value *create_fcmpULE(value *lhs, value *rhs);
|
||||
value *create_fcmpUGE(value *lhs, value *rhs);
|
||||
value *create_fcmpUEQ(value *lhs, value *rhs);
|
||||
value *create_fcmpUNE(value *lhs, value *rhs);
|
||||
// Logical
|
||||
value *create_and(value *lhs, value *rhs);
|
||||
value *create_xor(value *lhs, value *rhs);
|
||||
value *create_or(value *lhs, value *rhs);
|
||||
// Input/Output
|
||||
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
|
||||
value *create_store(value *ptr, value *val);
|
||||
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
|
||||
value *create_masked_store(value *ptr, value *val, value *mask);
|
||||
// Block instruction
|
||||
value *create_splat(value *arg, const type::block_shapes_t &shapes);
|
||||
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
|
||||
value *create_cat(value *lhs, value *rhs);
|
||||
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
|
||||
// Atomic instruction
|
||||
value *create_atomic_cas(value *ptr, value *cmp, value *val);
|
||||
value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk);
|
||||
value *create_atomic_max(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_umax(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_min(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_umin(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_fadd(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_add(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_and(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_or(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_xor(value *ptr, value *val, value *msk);
|
||||
value *create_atomic_xchg(value *ptr, value *val, value *msk);
|
||||
// Built-in instruction
|
||||
value *create_get_program_id(unsigned axis);
|
||||
value *create_get_num_programs(unsigned axis);
|
||||
value *create_exp(value* arg);
|
||||
value *create_cos(value* arg);
|
||||
value *create_sin(value* arg);
|
||||
value *create_log(value* arg);
|
||||
value *create_dot(value *A, value *B, value *C, bool allow_tf32);
|
||||
value *create_trans(value *A, const std::vector<int> &perm = {});
|
||||
value *create_sqrt(value *A);
|
||||
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
|
||||
value *create_select(value *pred, value *if_value, value *else_value);
|
||||
// Intrinsics
|
||||
// These have no place in the IR, and hopefully they can be removed at some point
|
||||
value *create_umulhi(value* lhs, value* rhs);
|
||||
value *create_copy_to_shared(value *arg);
|
||||
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY);
|
||||
value *create_copy_from_shared(value *arg);
|
||||
value *create_barrier(const std::string &name = "");
|
||||
value *create_async_wait(int N);
|
||||
value *create_prefetch_s(value *arg, int inc);
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
basic_block *block_;
|
||||
iterator insert_point_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,113 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_CONSTANT_H_
|
||||
#define _TRITON_IR_CONSTANT_H_
|
||||
|
||||
#include "enums.h"
|
||||
#include "value.h"
|
||||
#include <cassert>
|
||||
#include "visitor.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class type;
|
||||
class context;
|
||||
|
||||
/* Constant */
|
||||
class constant: public user{
|
||||
protected:
|
||||
using user::user;
|
||||
|
||||
public:
|
||||
static constant* get_all_ones_value(type *ty);
|
||||
static constant* get_null_value(type *ty);
|
||||
virtual std::string repr() const = 0;
|
||||
};
|
||||
|
||||
/* Undef value */
|
||||
class undef_value: public constant{
|
||||
private:
|
||||
undef_value(type *ty);
|
||||
|
||||
public:
|
||||
static undef_value* get(type* ty);
|
||||
std::string repr() const { return "undef"; }
|
||||
void accept(visitor* vst) { vst->visit_undef_value(this); }
|
||||
};
|
||||
|
||||
|
||||
/* Constant int */
|
||||
class constant_int: public constant{
|
||||
protected:
|
||||
constant_int(type *ty, uint64_t value);
|
||||
|
||||
public:
|
||||
virtual uint64_t get_value() const { return value_; }
|
||||
static constant_int *get(type *ty, uint64_t value);
|
||||
std::string repr() const { return std::to_string(value_); }
|
||||
void accept(visitor* vst) { vst->visit_constant_int(this); }
|
||||
|
||||
protected:
|
||||
uint64_t value_;
|
||||
};
|
||||
|
||||
/* Constant fp */
|
||||
class constant_fp: public constant{
|
||||
constant_fp(type *ty, double value);
|
||||
|
||||
public:
|
||||
double get_value() { return value_; }
|
||||
static constant* get_negative_zero(type *ty);
|
||||
static constant* get_zero_value_for_negation(type *ty);
|
||||
static constant* get(context &ctx, double v);
|
||||
static constant* get(type *ty, double v);
|
||||
std::string repr() const { return std::to_string(value_); }
|
||||
void accept(visitor* vst) { vst->visit_constant_fp(this); }
|
||||
|
||||
private:
|
||||
double value_;
|
||||
};
|
||||
|
||||
|
||||
/* Global Value */
|
||||
class global_value: public constant {
|
||||
public:
|
||||
enum linkage_types_t {
|
||||
external
|
||||
};
|
||||
|
||||
public:
|
||||
global_value(type *ty, unsigned num_ops,
|
||||
linkage_types_t linkage, const std::string &name,
|
||||
unsigned addr_space);
|
||||
std::string repr() const { return get_name(); }
|
||||
|
||||
private:
|
||||
linkage_types_t linkage_;
|
||||
};
|
||||
|
||||
/* global object */
|
||||
class global_object: public global_value {
|
||||
public:
|
||||
global_object(type *ty, unsigned num_ops,
|
||||
linkage_types_t linkage, const std::string &name,
|
||||
unsigned addr_space = 0);
|
||||
std::string repr() const { return get_name(); }
|
||||
};
|
||||
|
||||
/* global variable */
|
||||
class alloc_const: public global_object {
|
||||
public:
|
||||
alloc_const(type *ty, constant_int *size,
|
||||
const std::string &name = "");
|
||||
std::string repr() const { return get_name(); }
|
||||
void accept(visitor* vst) { vst->visit_alloc_const(this); }
|
||||
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,29 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_CONTEXT_H_
|
||||
#define _TRITON_IR_CONTEXT_H_
|
||||
|
||||
#include <memory>
|
||||
#include "triton/ir/type.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class type;
|
||||
class context_impl;
|
||||
|
||||
/* Context */
|
||||
class context {
|
||||
public:
|
||||
context();
|
||||
context(const context&) = delete;
|
||||
context& operator=(const context&) = delete;
|
||||
|
||||
public:
|
||||
std::shared_ptr<context_impl> p_impl;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,46 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_CONTEXT_IMPL_H_
|
||||
#define _TRITON_IR_CONTEXT_IMPL_H_
|
||||
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class context;
|
||||
|
||||
/* Context impl */
|
||||
class context_impl {
|
||||
public:
|
||||
// constructors
|
||||
context_impl(context &ctx);
|
||||
|
||||
public:
|
||||
// non-numeric types
|
||||
type void_ty, label_ty;
|
||||
// floating point types
|
||||
type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
|
||||
// integer types
|
||||
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
|
||||
// Pointer types
|
||||
std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
|
||||
// Block types
|
||||
std::map<std::pair<type*, type::block_shapes_t>, std::unique_ptr<block_type>> block_tys;
|
||||
|
||||
// Int constants
|
||||
std::map<std::pair<type*, uint64_t>, std::unique_ptr<constant_int>> int_constants_;
|
||||
// Float constants
|
||||
std::map<std::pair<type*, double>, std::unique_ptr<constant_fp>> fp_constants_;
|
||||
// undef values
|
||||
std::map<type*, std::unique_ptr<undef_value>> uv_constants_;
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,175 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_ENUMS_H_
|
||||
#define _TRITON_IR_ENUMS_H_
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
|
||||
enum binary_op_t: unsigned int{
|
||||
Add,
|
||||
FAdd,
|
||||
Sub,
|
||||
FSub,
|
||||
Mul,
|
||||
FMul,
|
||||
UDiv,
|
||||
SDiv,
|
||||
FDiv,
|
||||
URem,
|
||||
SRem,
|
||||
FRem,
|
||||
Shl,
|
||||
LShr,
|
||||
AShr,
|
||||
And,
|
||||
Or,
|
||||
Xor
|
||||
};
|
||||
|
||||
enum class atomic_rmw_op_t: unsigned int{
|
||||
And,
|
||||
Or,
|
||||
Xor,
|
||||
Add,
|
||||
Max,
|
||||
Min,
|
||||
UMax,
|
||||
UMin,
|
||||
FAdd,
|
||||
Xchg,
|
||||
};
|
||||
|
||||
enum cast_op_t: unsigned int {
|
||||
Trunc,
|
||||
ZExt,
|
||||
SExt,
|
||||
FPTrunc,
|
||||
FPExt,
|
||||
UIToFP,
|
||||
SIToFP,
|
||||
FPToUI,
|
||||
FPToSI,
|
||||
PtrToInt,
|
||||
IntToPtr,
|
||||
BitCast,
|
||||
AddrSpaceCast
|
||||
};
|
||||
|
||||
enum cmp_pred_t: unsigned int {
|
||||
FIRST_FCMP_PREDICATE,
|
||||
FCMP_FALSE,
|
||||
FCMP_OEQ,
|
||||
FCMP_OGT,
|
||||
FCMP_OGE,
|
||||
FCMP_OLT,
|
||||
FCMP_OLE,
|
||||
FCMP_ONE,
|
||||
FCMP_ORD,
|
||||
FCMP_UNO,
|
||||
FCMP_UEQ,
|
||||
FCMP_UGT,
|
||||
FCMP_UGE,
|
||||
FCMP_ULT,
|
||||
FCMP_ULE,
|
||||
FCMP_UNE,
|
||||
FCMP_TRUE,
|
||||
LAST_FCMP_PREDICATE,
|
||||
FIRST_ICMP_PREDICATE,
|
||||
ICMP_EQ,
|
||||
ICMP_NE,
|
||||
ICMP_UGT,
|
||||
ICMP_UGE,
|
||||
ICMP_ULT,
|
||||
ICMP_ULE,
|
||||
ICMP_SGT,
|
||||
ICMP_SGE,
|
||||
ICMP_SLT,
|
||||
ICMP_SLE,
|
||||
LAST_ICMP_PREDICATE
|
||||
};
|
||||
|
||||
enum value_id_t: unsigned {
|
||||
/* ------------ *
|
||||
INSTRUCTIONS
|
||||
* ------------ */
|
||||
INST_BEGIN,
|
||||
// phi
|
||||
INST_PHI,
|
||||
// arithmetic
|
||||
INST_BINOP,
|
||||
INST_GETELEMENTPTR,
|
||||
INST_SELECT,
|
||||
INST_SQRT,
|
||||
// cmp
|
||||
INST_ICMP,
|
||||
INST_FCMP,
|
||||
// cast
|
||||
INST_CAST_TRUNC,
|
||||
INST_CAST_ZEXT,
|
||||
INST_CAST_SEXT,
|
||||
INST_CAST_FP_TRUNC,
|
||||
INST_CAST_FP_EXT,
|
||||
INST_CAST_UI_TO_FP,
|
||||
INST_CAST_SI_TO_FP,
|
||||
INST_CAST_FP_TO_UI,
|
||||
INST_CAST_FP_TO_SI,
|
||||
INST_CAST_PTR_TO_INT,
|
||||
INST_CAST_INT_TO_PTR,
|
||||
INST_CAST_BIT_CAST,
|
||||
INST_CAST_ADDR_SPACE_CAST,
|
||||
// terminators
|
||||
INST_RETURN,
|
||||
INST_COND_BRANCH,
|
||||
INST_UNCOND_BRANCH,
|
||||
// io
|
||||
INST_UNMASKED_LOAD,
|
||||
INST_MASKED_LOAD,
|
||||
INST_MASKED_LOAD_ASYNC,
|
||||
INST_UNMASKED_STORE,
|
||||
INST_MASKED_STORE,
|
||||
// retile
|
||||
INST_RESHAPE,
|
||||
INST_SPLAT,
|
||||
INST_CAT,
|
||||
INST_BROADCAST,
|
||||
INST_DOWNCAST,
|
||||
// builtin
|
||||
INST_GET_PROGRAM_ID,
|
||||
INST_GET_NUM_PROGRAMS,
|
||||
// atomics
|
||||
INST_ATOMIC_CAS,
|
||||
INST_ATOMIC_EXCH,
|
||||
INST_ATOMIC_RMW,
|
||||
// math
|
||||
INST_UMULHI,
|
||||
INST_EXP,
|
||||
INST_COS,
|
||||
INST_SIN,
|
||||
INST_LOG,
|
||||
// array arithmetic
|
||||
INST_TRANS,
|
||||
INST_REDUCE,
|
||||
INST_DOT,
|
||||
// intrinsics
|
||||
INST_COPY_TO_SHARED,
|
||||
INST_COPY_FROM_SHARED,
|
||||
INST_CVT_LAYOUT,
|
||||
INST_CVT_SCANLINE,
|
||||
INST_DECOALESCE,
|
||||
INST_RECOALESCE,
|
||||
INST_BARRIER,
|
||||
INST_ASYNC_WAIT,
|
||||
INST_MAKE_RANGE_DYN,
|
||||
INST_MAKE_RANGE_STA,
|
||||
INST_MAKE_RANGE,
|
||||
INST_PREFETCH_S,
|
||||
};
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,142 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_FUNCTION_H_
|
||||
#define _TRITON_IR_FUNCTION_H_
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "value.h"
|
||||
#include "constant.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class function;
|
||||
class function_type;
|
||||
class module;
|
||||
class basic_block;
|
||||
|
||||
/* Argument */
|
||||
class argument: public value{
|
||||
argument(type *ty, const std::string &name, function *parent, unsigned arg_no);
|
||||
|
||||
public:
|
||||
static argument* create(type *ty, const std::string &name,
|
||||
function *parent = nullptr, unsigned arg_no = 0);
|
||||
function* get_parent() const;
|
||||
unsigned get_arg_no() const;
|
||||
|
||||
void accept(visitor *v);
|
||||
|
||||
private:
|
||||
function *parent_;
|
||||
unsigned arg_no_;
|
||||
};
|
||||
|
||||
/* Attribute */
|
||||
enum attribute_kind_t {
|
||||
readonly = 0,
|
||||
writeonly,
|
||||
noalias,
|
||||
aligned,
|
||||
multiple_of,
|
||||
retune,
|
||||
not_implemented
|
||||
};
|
||||
|
||||
class attribute {
|
||||
public:
|
||||
attribute(attribute_kind_t kind, unsigned value = 0):
|
||||
kind_(kind), value_(value){}
|
||||
|
||||
bool operator<(const attribute& other) const {
|
||||
return std::make_pair(kind_, value_) < std::make_pair(other.kind_, other.value_);
|
||||
}
|
||||
|
||||
attribute_kind_t get_kind() const {
|
||||
return kind_;
|
||||
}
|
||||
|
||||
unsigned get_value() const {
|
||||
return value_;
|
||||
}
|
||||
|
||||
bool is_llvm_attr() const {
|
||||
return kind_ != multiple_of;
|
||||
}
|
||||
|
||||
std::string repr() const {
|
||||
switch(kind_){
|
||||
case readonly: return ".readonly";
|
||||
case writeonly: return ".writeonly";
|
||||
case noalias: return ".noalias";
|
||||
case aligned: return ".aligned(" + std::to_string(value_) + ")";
|
||||
case multiple_of: return ".multipleof(" + std::to_string(value_) + ")";
|
||||
case retune: return ".retunr";
|
||||
default: break;
|
||||
}
|
||||
assert(false);
|
||||
return "";
|
||||
}
|
||||
|
||||
private:
|
||||
attribute_kind_t kind_;
|
||||
unsigned value_;
|
||||
};
|
||||
|
||||
/* Function */
|
||||
class function: public global_object{
|
||||
typedef std::vector<argument*> args_t;
|
||||
typedef args_t::iterator arg_iterator;
|
||||
typedef args_t::const_iterator const_arg_iterator;
|
||||
|
||||
typedef std::vector<basic_block*> blocks_t;
|
||||
typedef blocks_t::iterator block_iterator;
|
||||
typedef blocks_t::const_iterator const_block_iterator;
|
||||
|
||||
typedef std::map<unsigned, std::set<attribute>> attr_map_t;
|
||||
|
||||
private:
|
||||
function(function_type *ty, linkage_types_t linkage,
|
||||
const std::string &name = "", module *parent = nullptr);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
const args_t &args() const { return args_; }
|
||||
function_type* get_fn_type() { return fn_ty_; }
|
||||
const function_type* get_fn_type() const { return fn_ty_; }
|
||||
module *get_parent() { return parent_; }
|
||||
const module *get_parent() const { return parent_; }
|
||||
|
||||
// factory methods
|
||||
static function *create(function_type *ty, linkage_types_t linkage,
|
||||
const std::string &name, module *mod);
|
||||
// blocks
|
||||
const blocks_t &blocks() { return blocks_; }
|
||||
const blocks_t &blocks() const { return blocks_; }
|
||||
void insert_block(basic_block* block, basic_block *next = nullptr);
|
||||
|
||||
// attributes
|
||||
void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
|
||||
const attr_map_t &attrs() { return attrs_; }
|
||||
bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); }
|
||||
std::set<attribute> get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
|
||||
|
||||
void print(std::ostream &os);
|
||||
|
||||
// visitor
|
||||
void accept(visitor *v) { v->visit_function(this); }
|
||||
|
||||
private:
|
||||
module *parent_;
|
||||
bool init_;
|
||||
function_type *fn_ty_;
|
||||
args_t args_;
|
||||
blocks_t blocks_;
|
||||
attr_map_t attrs_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,978 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_INSTRUCTIONS_H_
|
||||
#define _TRITON_IR_INSTRUCTIONS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "triton/ir/enums.h"
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/value.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/metadata.h"
|
||||
#include "triton/ir/visitor.h"
|
||||
|
||||
#define _TRITON_DEFINE_CLONE(name) \
|
||||
ir::instruction* clone_impl() const { return new name(*this); }
|
||||
|
||||
#define _TRITON_DEFINE_ACCEPT(name) \
|
||||
void accept(visitor* v) { v->visit_ ## name (this); }
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class constant_int;
|
||||
class constant;
|
||||
class make_range;
|
||||
class basic_block;
|
||||
class context;
|
||||
class visitor;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// instruction classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class result_reference;
|
||||
|
||||
|
||||
class instruction: public user{
|
||||
public:
|
||||
virtual std::string repr_impl() const = 0;
|
||||
|
||||
private:
|
||||
virtual ir::instruction* clone_impl() const = 0;
|
||||
|
||||
protected:
|
||||
// constructors
|
||||
instruction(type *ty, value_id_t ity, unsigned num_ops,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
// parent
|
||||
void set_parent(basic_block *block) { parent_ = block; }
|
||||
const basic_block *get_parent() const { return parent_; }
|
||||
basic_block *get_parent() { return parent_; }
|
||||
void erase_from_parent();
|
||||
// helpers
|
||||
bool has_tile_result_or_op();
|
||||
// repr
|
||||
std::string repr() const { return repr_impl(); }
|
||||
// metadata
|
||||
void set_metadata(ir::metadata::kind_t kind,
|
||||
unsigned value) { metadatas_[kind] = value;}
|
||||
unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
|
||||
// cloning
|
||||
ir::instruction* clone() {
|
||||
ir::instruction* res = clone_impl();
|
||||
// for(auto it = op_begin(); it != op_end(); it++)
|
||||
// (*it)->add_use(res);
|
||||
res->parent_ = nullptr;
|
||||
res->users_.clear();
|
||||
return res;
|
||||
}
|
||||
// instruction id
|
||||
value_id_t get_id() const { return id_; }
|
||||
|
||||
void print(std::ostream &os);
|
||||
|
||||
private:
|
||||
basic_block *parent_;
|
||||
std::map<ir::metadata::kind_t, unsigned> metadatas_;
|
||||
value_id_t id_;
|
||||
};
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// phi_node classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class phi_node: public instruction {
|
||||
private:
|
||||
phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "phi"; }
|
||||
|
||||
public:
|
||||
void set_incoming_value(unsigned i, value *v);
|
||||
void set_incoming_block(unsigned i, basic_block *block);
|
||||
value *get_value_for_block(basic_block *block);
|
||||
value *get_incoming_value(unsigned i) { return get_operand(i); }
|
||||
basic_block *get_incoming_block(unsigned i) { return blocks_[i]; }
|
||||
unsigned get_num_incoming() { return get_num_operands(); }
|
||||
void add_incoming(value *v, basic_block *block);
|
||||
|
||||
// Type
|
||||
void set_type(type *ty) { ty_ = ty; }
|
||||
|
||||
// Factory methods
|
||||
static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_CLONE(phi_node)
|
||||
_TRITON_DEFINE_ACCEPT(phi_node)
|
||||
|
||||
private:
|
||||
unsigned num_reserved_;
|
||||
std::vector<basic_block*> blocks_;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// binary_operator classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class binary_operator: public instruction {
|
||||
public:
|
||||
typedef binary_op_t op_t;
|
||||
|
||||
private:
|
||||
std::string repr_impl() const;
|
||||
|
||||
protected:
|
||||
// Constructors
|
||||
binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// Get operand
|
||||
binary_op_t get_op() const { return op_; }
|
||||
|
||||
// Bool
|
||||
bool is_terminator() const;
|
||||
bool is_binary_op() const;
|
||||
bool is_int_div_rem() const;
|
||||
bool is_shift() const;
|
||||
bool is_cast() const;
|
||||
bool is_int_mult() const;
|
||||
bool is_int_add_sub() const;
|
||||
bool is_int_div() const;
|
||||
bool is_int_rem() const;
|
||||
bool is_shl() const;
|
||||
bool is_shr() const;
|
||||
|
||||
// Approx
|
||||
void set_fdiv_ieee_rounding(bool rnd) { fdiv_ieee_rnd_ = rnd; }
|
||||
bool get_fdiv_ieee_rounding() { return fdiv_ieee_rnd_; }
|
||||
|
||||
// Wraps
|
||||
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
|
||||
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
|
||||
|
||||
// Factory methods
|
||||
static binary_operator *create(binary_op_t op, value *lhs, value *rhs,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
// static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
// static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
// static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_CLONE(binary_operator)
|
||||
_TRITON_DEFINE_ACCEPT(binary_operator)
|
||||
|
||||
public:
|
||||
binary_op_t op_;
|
||||
bool has_no_unsigned_wrap_;
|
||||
bool has_no_signed_wrap_;
|
||||
|
||||
bool fdiv_ieee_rnd_;
|
||||
};
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cmp_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class cmp_inst: public instruction{
|
||||
public:
|
||||
typedef cmp_pred_t pred_t;
|
||||
|
||||
private:
|
||||
std::string repr_impl() const;
|
||||
|
||||
protected:
|
||||
cmp_inst(type *ty, value_id_t id, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next);
|
||||
static bool is_fp_predicate(cmp_pred_t pred);
|
||||
static bool is_int_predicate(cmp_pred_t pred);
|
||||
static type* make_cmp_result_type(type *ty);
|
||||
|
||||
public:
|
||||
cmp_pred_t get_pred() const { return pred_; }
|
||||
|
||||
private:
|
||||
cmp_pred_t pred_;
|
||||
};
|
||||
|
||||
class icmp_inst: public cmp_inst {
|
||||
icmp_inst(type *ty, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(icmp_inst)
|
||||
_TRITON_DEFINE_ACCEPT(icmp_inst)
|
||||
};
|
||||
|
||||
class fcmp_inst: public cmp_inst {
|
||||
fcmp_inst(type *ty, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(fcmp_inst)
|
||||
_TRITON_DEFINE_ACCEPT(fcmp_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// unary_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class unary_inst: public instruction {
|
||||
protected:
|
||||
unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next);
|
||||
};
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cast_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class cast_inst: public unary_inst{
|
||||
private:
|
||||
std::string repr_impl() const;
|
||||
|
||||
protected:
|
||||
cast_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next, cast_op_t op)
|
||||
: unary_inst(ty, id, v, name, next), op_(op) { }
|
||||
|
||||
private:
|
||||
static bool is_valid(cast_op_t op, value *arg, type *ty);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
cast_op_t get_op() const { return op_; }
|
||||
|
||||
// factory methods
|
||||
static cast_inst *create(cast_op_t op, value *arg, type *ty,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
static cast_inst *create_integer_cast(value *arg, type *ty, bool is_signed,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_ACCEPT(cast_inst)
|
||||
|
||||
private:
|
||||
cast_op_t op_;
|
||||
};
|
||||
|
||||
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, id, op) \
|
||||
class name : public cast_inst { \
|
||||
_TRITON_DEFINE_CLONE(name) \
|
||||
friend class cast_inst; \
|
||||
name(type *ty, value *v, const std::string &name, instruction *next) \
|
||||
: cast_inst(ty, id, v, name, next, op){ } \
|
||||
};
|
||||
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, INST_CAST_TRUNC, cast_op_t::Trunc)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, INST_CAST_ZEXT, cast_op_t::ZExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, INST_CAST_SEXT, cast_op_t::SExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, INST_CAST_FP_TRUNC, cast_op_t::FPTrunc)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, INST_CAST_FP_EXT, cast_op_t::FPExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, INST_CAST_UI_TO_FP, cast_op_t::UIToFP)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, INST_CAST_SI_TO_FP, cast_op_t::SIToFP)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, INST_CAST_FP_TO_UI, cast_op_t::FPToUI)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, INST_CAST_FP_TO_SI, cast_op_t::FPToSI)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, INST_CAST_PTR_TO_INT, cast_op_t::PtrToInt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, INST_CAST_INT_TO_PTR, cast_op_t::IntToPtr)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, INST_CAST_BIT_CAST, cast_op_t::BitCast)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, INST_CAST_ADDR_SPACE_CAST, cast_op_t::AddrSpaceCast)
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// terminator_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class terminator_inst: public instruction{
|
||||
using instruction::instruction;
|
||||
};
|
||||
|
||||
// return instruction
|
||||
class return_inst: public terminator_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "ret"; }
|
||||
return_inst(context &ctx, value *ret_val, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_return_value()
|
||||
{ return get_num_operands() ? get_operand(0) : nullptr; }
|
||||
|
||||
unsigned get_num_successors() const { return 0; }
|
||||
|
||||
// factory methods
|
||||
static return_inst* create(context &ctx, value *ret_val = nullptr, instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_CLONE(return_inst)
|
||||
_TRITON_DEFINE_ACCEPT(return_inst)
|
||||
};
|
||||
|
||||
// base branch instruction
|
||||
class branch_inst: public terminator_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "br"; }
|
||||
|
||||
protected:
|
||||
using terminator_inst::terminator_inst;
|
||||
|
||||
public:
|
||||
static branch_inst* create(basic_block *dest,
|
||||
instruction *next = nullptr);
|
||||
static branch_inst* create(value *cond, basic_block *if_dest, basic_block *else_dest,
|
||||
instruction *next = nullptr);
|
||||
};
|
||||
|
||||
// conditional branch
|
||||
class cond_branch_inst: public branch_inst {
|
||||
private:
|
||||
friend class branch_inst;
|
||||
cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next);
|
||||
|
||||
public:
|
||||
basic_block *get_true_dest() { return (basic_block*)get_operand(0); }
|
||||
basic_block *get_false_dest() { return (basic_block*)get_operand(1); }
|
||||
value *get_cond() { return get_operand(2); }
|
||||
_TRITON_DEFINE_CLONE(cond_branch_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cond_branch_inst)
|
||||
};
|
||||
|
||||
// unconditional branch
|
||||
class uncond_branch_inst: public branch_inst {
|
||||
private:
|
||||
friend class branch_inst;
|
||||
uncond_branch_inst(basic_block *dst, instruction *next);
|
||||
|
||||
public:
|
||||
basic_block *get_dest() { return (basic_block*)get_operand(0); }
|
||||
_TRITON_DEFINE_CLONE(uncond_branch_inst)
|
||||
_TRITON_DEFINE_ACCEPT(uncond_branch_inst)
|
||||
};
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// getelementptr_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class getelementptr_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "getelementptr"; }
|
||||
getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value*> &idx, const std::string &name, instruction *next);
|
||||
|
||||
private:
|
||||
static type *get_return_type(type *ty, value *ptr, const std::vector<value*> &idx);
|
||||
static type *get_indexed_type_impl(type *ty, const std::vector<value *> &idx);
|
||||
static type *get_indexed_type(type *ty, const std::vector<value*> &idx);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
type *get_source_elt_ty() { return source_elt_ty; }
|
||||
op_iterator idx_begin() { return op_begin() + 1; }
|
||||
op_iterator idx_end() { return op_end(); }
|
||||
value *get_pointer_operand() { return *op_begin(); }
|
||||
|
||||
// factory methods
|
||||
static getelementptr_inst* create(value *ptr, const std::vector<value*> &idx,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(getelementptr_inst)
|
||||
_TRITON_DEFINE_ACCEPT(getelementptr_inst)
|
||||
|
||||
private:
|
||||
type *source_elt_ty;
|
||||
type *res_elt_ty;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// load_inst/store_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class io_inst: public instruction {
|
||||
protected:
|
||||
io_inst(type *ty, value_id_t id, unsigned num_ops,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_pointer_operand() { return get_operand(0); }
|
||||
};
|
||||
|
||||
// load
|
||||
class load_inst: public io_inst {
|
||||
public:
|
||||
enum CACHE_MODIFIER : uint32_t {
|
||||
NONE=0,
|
||||
CA,
|
||||
CG,
|
||||
};
|
||||
|
||||
enum EVICTION_POLICY : uint32_t {
|
||||
NORMAL=0,
|
||||
EVICT_FIRST,
|
||||
EVICT_LAST,
|
||||
};
|
||||
|
||||
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
|
||||
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
|
||||
bool get_is_volatile() const { return is_volatile_; }
|
||||
|
||||
protected:
|
||||
load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
std::string get_cache_modifier_repr() const {
|
||||
if (cache_ == CA) return ".ca";
|
||||
if (cache_ == CG) return ".cg";
|
||||
return "";
|
||||
}
|
||||
std::string get_eviction_policy_repr() const {
|
||||
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
|
||||
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
|
||||
}
|
||||
EVICTION_POLICY eviction_;
|
||||
CACHE_MODIFIER cache_;
|
||||
|
||||
std::string get_volatile_repr() {
|
||||
return is_volatile_ ? ".volatile" : "";
|
||||
}
|
||||
bool is_volatile_;
|
||||
|
||||
private:
|
||||
static type *get_pointee_type(type *ty);
|
||||
};
|
||||
|
||||
// unmasked load
|
||||
class unmasked_load_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
|
||||
unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static unmasked_load_inst* create(value *ptr,
|
||||
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(unmasked_load_inst)
|
||||
_TRITON_DEFINE_ACCEPT(unmasked_load_inst)
|
||||
};
|
||||
|
||||
// masked load
|
||||
class masked_load_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
|
||||
masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_mask_operand() { return get_operand(1); }
|
||||
value *get_false_value_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
|
||||
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
bool is_volatile,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_load_inst)
|
||||
_TRITON_DEFINE_ACCEPT(masked_load_inst)
|
||||
};
|
||||
|
||||
// masked load async
|
||||
class masked_load_async_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
|
||||
masked_load_async_inst(value *ptr, value *mask, value *false_value,
|
||||
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_mask_operand() { return get_operand(1); }
|
||||
value *get_false_value_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
|
||||
load_inst::CACHE_MODIFIER cache,
|
||||
EVICTION_POLICY eviction,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_load_async_inst)
|
||||
_TRITON_DEFINE_ACCEPT(masked_load_async_inst)
|
||||
};
|
||||
|
||||
|
||||
|
||||
// store
|
||||
class store_inst: public io_inst {
|
||||
protected:
|
||||
store_inst(value *ptr, value_id_t id, unsigned num_ops,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
value *get_value_operand() { return get_operand(1); }
|
||||
};
|
||||
|
||||
// unmasked_store
|
||||
class unmasked_store_inst: public store_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "unmasked_store"; }
|
||||
unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// factory method
|
||||
static unmasked_store_inst* create(value* ptr, value *v,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(unmasked_store_inst)
|
||||
_TRITON_DEFINE_ACCEPT(unmasked_store_inst)
|
||||
};
|
||||
|
||||
class masked_store_inst: public store_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_store"; }
|
||||
masked_store_inst(value *ptr, value *v, value *mask,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_mask_operand() { return get_operand(2); }
|
||||
// factory method
|
||||
static masked_store_inst* create(value *ptr, value *v, value *mask,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_store_inst)
|
||||
_TRITON_DEFINE_ACCEPT(masked_store_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// retile_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// cat
|
||||
|
||||
class cat_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "cat"; }
|
||||
cat_inst(value *x, value *y, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static instruction* create(value *lhs, value *rhs,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(cat_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cat_inst)
|
||||
};
|
||||
|
||||
// retile
|
||||
|
||||
class retile_inst: public unary_inst {
|
||||
protected:
|
||||
retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes, const std::string &name, instruction *next);
|
||||
};
|
||||
|
||||
// reshape
|
||||
|
||||
class reshape_inst: public retile_inst {
|
||||
private:
|
||||
using retile_inst::retile_inst;
|
||||
std::string repr_impl() const { return "reshape"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(reshape_inst)
|
||||
_TRITON_DEFINE_ACCEPT(reshape_inst)
|
||||
};
|
||||
|
||||
// splat
|
||||
|
||||
class splat_inst: public retile_inst {
|
||||
private:
|
||||
using retile_inst::retile_inst;
|
||||
std::string repr_impl() const { return "splat"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(splat_inst)
|
||||
_TRITON_DEFINE_ACCEPT(splat_inst)
|
||||
};
|
||||
|
||||
// broadcast
|
||||
|
||||
class broadcast_inst: public retile_inst {
|
||||
private:
|
||||
using retile_inst::retile_inst;
|
||||
std::string repr_impl() const { return "broadcast"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(broadcast_inst)
|
||||
_TRITON_DEFINE_ACCEPT(broadcast_inst)
|
||||
};
|
||||
|
||||
|
||||
// downcast
|
||||
|
||||
class downcast_inst: public unary_inst {
|
||||
private:
|
||||
using unary_inst::unary_inst;
|
||||
std::string repr_impl() const { return "downcast"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(downcast_inst)
|
||||
_TRITON_DEFINE_ACCEPT(downcast_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// builtin_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class builtin_inst: public instruction{
|
||||
protected:
|
||||
using instruction::instruction;
|
||||
};
|
||||
|
||||
class get_program_id_inst: public builtin_inst {
|
||||
private:
|
||||
get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "get_program_id(" + std::to_string(axis_) + ")"; }
|
||||
|
||||
public:
|
||||
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
_TRITON_DEFINE_CLONE(get_program_id_inst)
|
||||
_TRITON_DEFINE_ACCEPT(get_program_id_inst)
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
};
|
||||
|
||||
class get_num_programs_inst: public builtin_inst {
|
||||
private:
|
||||
get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "get_num_programs(" + std::to_string(axis_) + ")"; }
|
||||
|
||||
public:
|
||||
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
_TRITON_DEFINE_CLONE(get_num_programs_inst)
|
||||
_TRITON_DEFINE_ACCEPT(get_num_programs_inst)
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
};
|
||||
|
||||
|
||||
class atomic_inst: public io_inst {
|
||||
public:
|
||||
using io_inst::io_inst;
|
||||
};
|
||||
|
||||
class atomic_rmw_inst: public atomic_inst {
|
||||
private:
|
||||
atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "atomic_rmw"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_rmw_inst)
|
||||
_TRITON_DEFINE_ACCEPT(atomic_rmw_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
|
||||
atomic_rmw_op_t get_op() { return op_; }
|
||||
|
||||
private:
|
||||
atomic_rmw_op_t op_;
|
||||
};
|
||||
|
||||
class atomic_cas_inst: public atomic_inst {
|
||||
private:
|
||||
atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "atomic_cas"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_cas_inst)
|
||||
_TRITON_DEFINE_ACCEPT(atomic_cas_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class umulhi_inst: public builtin_inst {
|
||||
private:
|
||||
umulhi_inst(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "umulhi"; }
|
||||
_TRITON_DEFINE_CLONE(umulhi_inst)
|
||||
_TRITON_DEFINE_ACCEPT(umulhi_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class exp_inst: public builtin_inst {
|
||||
private:
|
||||
exp_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "exp"; }
|
||||
_TRITON_DEFINE_CLONE(exp_inst)
|
||||
_TRITON_DEFINE_ACCEPT(exp_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class cos_inst: public builtin_inst {
|
||||
private:
|
||||
cos_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "cos"; }
|
||||
_TRITON_DEFINE_CLONE(cos_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cos_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class sin_inst: public builtin_inst {
|
||||
private:
|
||||
sin_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "sin"; }
|
||||
_TRITON_DEFINE_CLONE(sin_inst)
|
||||
_TRITON_DEFINE_ACCEPT(sin_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class log_inst: public builtin_inst {
|
||||
private:
|
||||
log_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "log"; }
|
||||
_TRITON_DEFINE_CLONE(log_inst)
|
||||
_TRITON_DEFINE_ACCEPT(log_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
|
||||
class dot_inst: public builtin_inst {
|
||||
public:
|
||||
enum TransT { NoTrans, Trans };
|
||||
enum DataType {
|
||||
FP8, FP16, BF16, TF32, FP32,
|
||||
INT1, INT4, INT8, INT32,
|
||||
UNKNOWN,
|
||||
};
|
||||
|
||||
private:
|
||||
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "dot"; }
|
||||
|
||||
public:
|
||||
bool is_prefetched() const { return is_prefetched_; }
|
||||
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
|
||||
bool allow_tf32() const { return allow_tf32_; }
|
||||
|
||||
public:
|
||||
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_nn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_nt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_tn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create_tt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(dot_inst)
|
||||
_TRITON_DEFINE_ACCEPT(dot_inst)
|
||||
|
||||
private:
|
||||
bool is_prefetched_ = false;
|
||||
bool allow_tf32_ = false;
|
||||
DataType C_type_ = DataType::FP32;
|
||||
DataType A_type_ = DataType::FP16;
|
||||
DataType B_type_ = DataType::FP16;
|
||||
};
|
||||
|
||||
//class outer_inst: public builtin_inst {
|
||||
//private:
|
||||
// outer_inst(value *A, value *B, value *C, const std::string &name, instruction *next);
|
||||
//public:
|
||||
// static instruction* create(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||
//};
|
||||
|
||||
class trans_inst: public builtin_inst {
|
||||
public:
|
||||
ir::type* get_res_ty(ir::type* in, std::vector<int> perm);
|
||||
std::vector<int> init_perm(ir::type* ty, const std::vector<int>& perm);
|
||||
|
||||
private:
|
||||
trans_inst(value *arg, const std::vector<int>& perm, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "trans"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const std::vector<int> &perm = {}, const std::string &name = "", instruction *next = nullptr);
|
||||
const std::vector<int> get_perm() const;
|
||||
_TRITON_DEFINE_CLONE(trans_inst)
|
||||
_TRITON_DEFINE_ACCEPT(trans_inst)
|
||||
|
||||
private:
|
||||
std::vector<int> perm_;
|
||||
};
|
||||
|
||||
class sqrt_inst: public builtin_inst {
|
||||
private:
|
||||
sqrt_inst(value *arg, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "sqrt"; }
|
||||
public:
|
||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(sqrt_inst)
|
||||
_TRITON_DEFINE_ACCEPT(sqrt_inst)
|
||||
};
|
||||
|
||||
class reduce_inst: public builtin_inst {
|
||||
public:
|
||||
enum op_t{
|
||||
ADD, SUB, MAX, MIN,
|
||||
FADD, FSUB, FMAX, FMIN,
|
||||
XOR
|
||||
};
|
||||
|
||||
private:
|
||||
static type* get_res_type(value *arg, unsigned axis);
|
||||
static std::string to_str(op_t op);
|
||||
|
||||
private:
|
||||
reduce_inst(value* arg, op_t op, unsigned axis, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "reduce"; }
|
||||
_TRITON_DEFINE_CLONE(reduce_inst)
|
||||
_TRITON_DEFINE_ACCEPT(reduce_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
op_t get_op() const { return op_; }
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
op_t op_;
|
||||
};
|
||||
|
||||
class select_inst: public builtin_inst {
|
||||
private:
|
||||
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "select"; }
|
||||
_TRITON_DEFINE_CLONE(select_inst)
|
||||
_TRITON_DEFINE_ACCEPT(select_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr);
|
||||
value* get_pred_op() { return get_operand(0); }
|
||||
value* get_if_value_op() { return get_operand(1); }
|
||||
value* get_else_value_op() { return get_operand(2); }
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// intrinsics classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
class copy_to_shared_inst: public unary_inst{
|
||||
private:
|
||||
using unary_inst::unary_inst;
|
||||
std::string repr_impl() const { return "copy_to_shared"; }
|
||||
|
||||
public:
|
||||
static copy_to_shared_inst* create(value *arg, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(copy_to_shared_inst)
|
||||
_TRITON_DEFINE_ACCEPT(copy_to_shared_inst)
|
||||
};
|
||||
|
||||
class copy_from_shared_inst: public unary_inst{
|
||||
private:
|
||||
using unary_inst::unary_inst;
|
||||
std::string repr_impl() const { return "copy_from_shared"; }
|
||||
|
||||
public:
|
||||
static copy_from_shared_inst* create(value *arg, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(copy_from_shared_inst)
|
||||
_TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
|
||||
};
|
||||
|
||||
class cvt_layout_inst: public unary_inst {
|
||||
private:
|
||||
using unary_inst::unary_inst;
|
||||
std::string repr_impl() const { return "cvt_layout_inst"; }
|
||||
|
||||
public:
|
||||
static cvt_layout_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(cvt_layout_inst)
|
||||
_TRITON_DEFINE_ACCEPT(cvt_layout_inst)
|
||||
};
|
||||
|
||||
class barrier_inst: public instruction{
|
||||
private:
|
||||
barrier_inst(context &ctx, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "barrier"; }
|
||||
_TRITON_DEFINE_CLONE(barrier_inst)
|
||||
_TRITON_DEFINE_ACCEPT(barrier_inst)
|
||||
|
||||
public:
|
||||
static barrier_inst* create(context &ctx, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class async_wait_inst: public instruction{
|
||||
private:
|
||||
async_wait_inst(context &ctx, int N, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "async_wait_group " + std::to_string(N_) ; }
|
||||
_TRITON_DEFINE_CLONE(async_wait_inst)
|
||||
_TRITON_DEFINE_ACCEPT(async_wait_inst)
|
||||
|
||||
public:
|
||||
static async_wait_inst* create(context &ctx, int N,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
int get_N() { return N_; }
|
||||
void set_N(int n) { N_ = n; }
|
||||
|
||||
private:
|
||||
int N_;
|
||||
};
|
||||
|
||||
class prefetch_s_inst : public instruction {
|
||||
std::string repr_impl() const { return "prefetch_s"; }
|
||||
_TRITON_DEFINE_CLONE(prefetch_s_inst)
|
||||
_TRITON_DEFINE_ACCEPT(prefetch_s_inst)
|
||||
|
||||
/// inc_: 0->first, 1->latch
|
||||
int inc_ = 0;
|
||||
public:
|
||||
prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
|
||||
: instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) {
|
||||
set_operand(0, arg);
|
||||
}
|
||||
int get_inc() const { return inc_; }
|
||||
static prefetch_s_inst *create(context &ctx, value *arg, int inc, const std::string &name = "",
|
||||
instruction *next=nullptr);
|
||||
};
|
||||
|
||||
/* constant range */
|
||||
class make_range: public instruction{
|
||||
make_range(type *ty, constant_int* first, constant_int* last);
|
||||
std::string repr_impl() const { return "make_range[" + first_->repr() + " : " + last_->repr() + "]"; }
|
||||
_TRITON_DEFINE_CLONE(make_range)
|
||||
_TRITON_DEFINE_ACCEPT(make_range)
|
||||
|
||||
public:
|
||||
static make_range *create(constant_int *first, constant_int *last);
|
||||
const constant_int* get_first() const;
|
||||
const constant_int* get_last() const;
|
||||
|
||||
private:
|
||||
constant_int* first_;
|
||||
constant_int* last_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,32 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_METADATA_H_
|
||||
#define _TRITON_IR_METADATA_H_
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
|
||||
/* Metadata */
|
||||
class metadata{
|
||||
public:
|
||||
enum kind_t{
|
||||
multiple_of,
|
||||
max_contiguous
|
||||
};
|
||||
|
||||
private:
|
||||
metadata(kind_t kind, unsigned value);
|
||||
|
||||
public:
|
||||
static metadata* get(kind_t kind, unsigned value);
|
||||
|
||||
private:
|
||||
kind_t kind_;
|
||||
unsigned value_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,92 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_MODULE_H_
|
||||
#define _TRITON_IR_MODULE_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include "triton/ir/builder.h"
|
||||
#include "triton/ir/metadata.h"
|
||||
#include "triton/ir/context.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace lang{
|
||||
|
||||
class iteration_statement;
|
||||
class compound_statement;
|
||||
|
||||
}
|
||||
|
||||
namespace ir{
|
||||
|
||||
class basic_block;
|
||||
class phi_node;
|
||||
class value;
|
||||
class context;
|
||||
class function;
|
||||
class attribute;
|
||||
class function_type;
|
||||
class constant;
|
||||
class global_value;
|
||||
class alloc_const;
|
||||
|
||||
/* Module */
|
||||
|
||||
class module {
|
||||
typedef std::pair<std::string, basic_block*> val_key_t;
|
||||
friend class function;
|
||||
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
|
||||
|
||||
public:
|
||||
typedef std::map<std::string, global_value*> symbols_map_t;
|
||||
typedef std::vector<function*> functions_list_t;
|
||||
struct current_iteration_info_t{
|
||||
lang::iteration_statement *statement;
|
||||
basic_block *block;
|
||||
};
|
||||
|
||||
private:
|
||||
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
|
||||
value *try_remove_trivial_phis(ir::phi_node *&phi);
|
||||
value *add_phi_operands(const std::string& name, phi_node *&phi);
|
||||
value *get_value_recursive(const std::string& name, basic_block *block);
|
||||
void push_function(function *fn) { functions_.push_back(fn); }
|
||||
|
||||
public:
|
||||
module(const std::string &name, builder &builder): name_(name), builder_(builder) {}
|
||||
builder &get_builder() { return builder_; };
|
||||
const std::string& get_name() { return name_; };
|
||||
|
||||
// Functions
|
||||
const functions_list_t &get_function_list() const { return functions_; }
|
||||
functions_list_t &get_function_list() { return functions_; }
|
||||
function *get_or_insert_function(const std::string &name, function_type *ty);
|
||||
// Const allocation
|
||||
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
|
||||
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
|
||||
// Register global
|
||||
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
|
||||
const std::map<std::string, ir::value*>& globals() const { return globals_; }
|
||||
// Metadata
|
||||
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
|
||||
const std::map<std::string, md_pair_t> &get_metadatas() const { return metadatas_; }
|
||||
void print(std::ostream &os);
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
builder &builder_;
|
||||
functions_list_t functions_;
|
||||
symbols_map_t symbols_;
|
||||
std::vector<ir::alloc_const*> allocs_;
|
||||
std::map<std::string, ir::value*> globals_;
|
||||
std::map<std::string, md_pair_t> metadatas_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,22 +0,0 @@
|
||||
#ifndef _TRITON_IR_PRINT_H_
|
||||
#define _TRITON_IR_PRINT_H_
|
||||
|
||||
#include "builder.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class module;
|
||||
class function;
|
||||
class basic_block;
|
||||
class instruction;
|
||||
|
||||
void print(module &mod, std::ostream& os);
|
||||
void print(function &func, std::ostream& os);
|
||||
void print(basic_block &bb, std::ostream& os);
|
||||
void print(instruction &instr, std::ostream& os);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,239 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_TYPE_H_
|
||||
#define _TRITON_IR_TYPE_H_
|
||||
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class context;
|
||||
class value;
|
||||
class integer_type;
|
||||
class constant_int;
|
||||
|
||||
/* Type */
|
||||
class type {
|
||||
public:
|
||||
typedef std::vector<unsigned> block_shapes_t;
|
||||
|
||||
protected:
|
||||
typedef std::vector<type*> contained_tys_vec_t;
|
||||
typedef contained_tys_vec_t::iterator ty_iterator;
|
||||
typedef contained_tys_vec_t::const_iterator const_ty_iterator;
|
||||
|
||||
public:
|
||||
enum id_t {
|
||||
// primitive types
|
||||
VoidTyID = 0, ///< type with no size
|
||||
FP8TyID, ///< 8-bit floating point type (3 bits mantissa)
|
||||
FP16TyID, ///< 16-bit floating point type (10 bits mantissa)
|
||||
BF16TyID, ///< 16-bit floating point type (7 bits mantissa)
|
||||
FP32TyID, ///< 32-bit floating point type
|
||||
FP64TyID, ///< 64-bit floating point type
|
||||
LabelTyID, ///< Labels
|
||||
MetadataTyID, ///< Metadata
|
||||
TokenTyID, ///< Token
|
||||
// derived types
|
||||
IntegerTyID, ///< Arbitrary bit width integers
|
||||
FunctionTyID, ///< Functions
|
||||
PointerTyID, ///< Pointers
|
||||
StructTyID, ///< Struct
|
||||
BlockTyID, ///< Block
|
||||
};
|
||||
|
||||
public:
|
||||
//constructors
|
||||
type(context &ctx, id_t id) : ctx_(ctx), id_(id) { }
|
||||
|
||||
//destructor
|
||||
virtual ~type(){}
|
||||
|
||||
// accessors
|
||||
context &get_context() const { return ctx_; }
|
||||
id_t get_type_id() const { return id_; }
|
||||
// type attributes
|
||||
unsigned get_fp_mantissa_width() const;
|
||||
unsigned get_integer_bitwidth() const;
|
||||
unsigned get_tile_bitwidth() const;
|
||||
unsigned get_primitive_size_in_bits() const;
|
||||
type *get_scalar_ty() const;
|
||||
block_shapes_t get_block_shapes() const;
|
||||
const size_t get_tile_rank() const;
|
||||
const size_t get_tile_ranks1() const;
|
||||
unsigned get_tile_num_elements() const;
|
||||
type *get_tile_element_ty() const;
|
||||
unsigned get_pointer_address_space() const;
|
||||
type *get_pointer_element_ty() const;
|
||||
|
||||
// primitive predicates
|
||||
bool is_void_ty() const { return id_ == VoidTyID; }
|
||||
bool is_fp8_ty() const { return id_ == FP8TyID; }
|
||||
bool is_fp16_ty() const { return id_ == FP16TyID; }
|
||||
bool is_bf16_ty() const { return id_ == BF16TyID; }
|
||||
bool is_fp32_ty() const { return id_ == FP32TyID; }
|
||||
bool is_fp64_ty() const { return id_ == FP64TyID; }
|
||||
bool is_label_ty() const { return id_ == LabelTyID;}
|
||||
bool is_metadata_ty() const { return id_ == MetadataTyID; }
|
||||
bool is_token_ty() const { return id_ == TokenTyID; }
|
||||
bool is_integer_ty() const { return id_ == IntegerTyID; }
|
||||
bool is_bool_ty() const { return is_integer_ty(1); }
|
||||
bool is_pointer_ty() const { return id_ == PointerTyID; }
|
||||
bool is_block_ty() const { return id_ == BlockTyID; }
|
||||
|
||||
// Composite predicates
|
||||
bool is_int_or_tileint_ty();
|
||||
bool is_integer_ty(unsigned width) const;
|
||||
bool is_floating_point_ty() const;
|
||||
bool is_sized() const ;
|
||||
|
||||
// Factory methods
|
||||
// primitive types
|
||||
static type *get_void_ty(context &ctx);
|
||||
static type *get_label_ty(context &ctx);
|
||||
// half
|
||||
static type *get_fp8_ty(context &ctx);
|
||||
static type *get_fp16_ty(context &ctx);
|
||||
static type *get_bf16_ty(context &ctx);
|
||||
static type *get_fp32_ty(context &ctx);
|
||||
static type *get_fp64_ty(context &ctx);
|
||||
// integer types
|
||||
static integer_type *get_int1_ty(context &ctx);
|
||||
static integer_type *get_int8_ty(context &ctx);
|
||||
static integer_type *get_int16_ty(context &ctx);
|
||||
static integer_type *get_int32_ty(context &ctx);
|
||||
static integer_type *get_int64_ty(context &ctx);
|
||||
static integer_type *get_int128_ty(context &ctx);
|
||||
|
||||
// repr
|
||||
std::string tile_repr() const {
|
||||
std::string res = get_tile_element_ty()->repr();
|
||||
auto shapes = get_block_shapes();
|
||||
res += "<";
|
||||
for(size_t i = 0; i < shapes.size(); i++){
|
||||
if(i > 0)
|
||||
res += ", ";
|
||||
res += std::to_string(shapes[i]);
|
||||
}
|
||||
res+= ">";
|
||||
return res;
|
||||
}
|
||||
|
||||
std::string repr() const {
|
||||
switch(id_) {
|
||||
case VoidTyID: return "void";
|
||||
case FP8TyID: return "fp8";
|
||||
case FP16TyID: return "f16";
|
||||
case FP32TyID: return "f32";
|
||||
case FP64TyID: return "f64";
|
||||
case BF16TyID: return "bf16";
|
||||
case LabelTyID: return "label";
|
||||
case MetadataTyID: return "md";
|
||||
case TokenTyID: return "tok";
|
||||
case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth());
|
||||
case FunctionTyID: return "fn";
|
||||
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
|
||||
case StructTyID: return "struct";
|
||||
case BlockTyID: return tile_repr();
|
||||
default: break;
|
||||
}
|
||||
throw std::logic_error("unknown type id '" + std::to_string(id_) + "'");
|
||||
};
|
||||
|
||||
private:
|
||||
context &ctx_;
|
||||
id_t id_;
|
||||
|
||||
protected:
|
||||
contained_tys_vec_t contained_tys_;
|
||||
};
|
||||
|
||||
class integer_type: public type {
|
||||
friend class context_impl;
|
||||
|
||||
private:
|
||||
// constructors
|
||||
integer_type(context &ctx, unsigned bitwidth)
|
||||
: type(ctx, IntegerTyID), bitwidth_(bitwidth) {}
|
||||
|
||||
public:
|
||||
// accessors
|
||||
unsigned get_bitwidth() const { return bitwidth_; }
|
||||
|
||||
// factory methods
|
||||
static integer_type* get(context &ctx, unsigned width);
|
||||
|
||||
private:
|
||||
unsigned bitwidth_;
|
||||
};
|
||||
|
||||
class composite_type: public type{
|
||||
protected:
|
||||
using type::type;
|
||||
|
||||
public:
|
||||
bool index_valid(value *idx) const;
|
||||
type* get_type_at_index(value *idx) const;
|
||||
};
|
||||
|
||||
class block_type: public composite_type {
|
||||
private:
|
||||
block_type(type *ty, const block_shapes_t &shapes);
|
||||
static bool is_valid_elt_ty(type *ty);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
const block_shapes_t& get_shapes() const { return shapes_; }
|
||||
unsigned get_num_elements() const;
|
||||
unsigned get_bitwidth() const;
|
||||
|
||||
// factory methods
|
||||
static block_type* get(type *ty, const block_shapes_t &shapes);
|
||||
static block_type* get_same_shapes(type *ty, type *ref);
|
||||
|
||||
private:
|
||||
block_shapes_t shapes_;
|
||||
};
|
||||
|
||||
class pointer_type: public type {
|
||||
private:
|
||||
pointer_type(type *ty, unsigned address_space);
|
||||
static bool is_valid_elt_ty(type *ty);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
unsigned get_address_space() const { return address_space_; }
|
||||
type *get_element_ty() const { return contained_tys_[0]; }
|
||||
// factory methods
|
||||
static pointer_type* get(type *ty, unsigned address_space);
|
||||
|
||||
private:
|
||||
unsigned address_space_;
|
||||
};
|
||||
|
||||
class function_type: public type {
|
||||
private:
|
||||
function_type(type *ret_ty, const std::vector<type *> ¶m_tys);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
unsigned get_num_params() const { return contained_tys_.size() - 1; }
|
||||
const_ty_iterator params_begin() const { return contained_tys_.begin() + 1; }
|
||||
const_ty_iterator params_end() const { return contained_tys_.end(); }
|
||||
ty_iterator params_begin() { return contained_tys_.begin() + 1; }
|
||||
ty_iterator params_end() { return contained_tys_.end(); }
|
||||
type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); }
|
||||
type* get_return_ty() const { return contained_tys_.at(0); }
|
||||
// factory methods
|
||||
static function_type* get(type *ret_ty, const std::vector<type*>& param_tys);
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,30 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_CFG_H_
|
||||
#define _TRITON_IR_CFG_H_
|
||||
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class module;
|
||||
class function;
|
||||
class basic_block;
|
||||
class instruction;
|
||||
class value;
|
||||
|
||||
class cfg {
|
||||
public:
|
||||
static std::vector<basic_block *> post_order(function* fn);
|
||||
static std::vector<basic_block *> reverse_post_order(function* fn);
|
||||
};
|
||||
|
||||
void for_each_instruction(ir::module& mod, const std::function<void(triton::ir::instruction*)> &fn);
|
||||
void for_each_value(ir::module& mod, const std::function<void(triton::ir::value *)> &fn);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,95 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_VALUE_H_
|
||||
#define _TRITON_IR_VALUE_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class type;
|
||||
class use;
|
||||
class user;
|
||||
class visitor;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// value class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class value {
|
||||
public:
|
||||
typedef std::set<user*> users_t;
|
||||
|
||||
public:
|
||||
// constructor
|
||||
value(type *ty, const std::string &name = "");
|
||||
virtual ~value(){ }
|
||||
// uses
|
||||
void add_use(user* arg);
|
||||
users_t::iterator erase_use(user* arg);
|
||||
const std::set<user*> &get_users() { return users_; }
|
||||
void replace_all_uses_with(value *target);
|
||||
// name
|
||||
void set_name(const std::string &name);
|
||||
const std::string &get_name() const { return name_; }
|
||||
bool has_name() const { return !name_.empty(); }
|
||||
type* get_type() const { return ty_; }
|
||||
// visitor
|
||||
virtual void accept(visitor *v) = 0;
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
|
||||
protected:
|
||||
type *ty_;
|
||||
users_t users_;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// user class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class user: public value{
|
||||
public:
|
||||
typedef std::vector<value*> ops_t;
|
||||
typedef ops_t::iterator op_iterator;
|
||||
typedef ops_t::const_iterator const_op_iterator;
|
||||
|
||||
protected:
|
||||
void resize_ops(unsigned num_ops) { ops_.resize(num_ops + num_hidden_); num_ops_ = num_ops; }
|
||||
void resize_hidden(unsigned num_hidden) { ops_.resize(num_ops_ + num_hidden); num_hidden_ = num_hidden; }
|
||||
|
||||
public:
|
||||
// Constructor
|
||||
user(type *ty, unsigned num_ops, const std::string &name = "")
|
||||
: value(ty, name), ops_(num_ops), num_ops_(num_ops), num_hidden_(0){
|
||||
}
|
||||
virtual ~user() { }
|
||||
|
||||
// Operands
|
||||
const ops_t& ops() { return ops_; }
|
||||
const ops_t& ops() const { return ops_; }
|
||||
op_iterator op_begin() { return ops_.begin(); }
|
||||
op_iterator op_end() { return ops_.end(); }
|
||||
void set_operand(unsigned i, value *x);
|
||||
value *get_operand(unsigned i) const;
|
||||
unsigned get_num_operands() const ;
|
||||
unsigned get_num_hidden() const;
|
||||
|
||||
// Utils
|
||||
value::users_t::iterator replace_uses_of_with(value *before, value *after);
|
||||
|
||||
|
||||
private:
|
||||
ops_t ops_;
|
||||
unsigned num_ops_;
|
||||
unsigned num_hidden_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,170 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_VISITOR_H_
|
||||
#define _TRITON_IR_VISITOR_H_
|
||||
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class value;
|
||||
|
||||
class instruction;
|
||||
|
||||
class phi_node;
|
||||
class binary_operator;
|
||||
class getelementptr_inst;
|
||||
|
||||
class icmp_inst;
|
||||
class fcmp_inst;
|
||||
class cast_inst;
|
||||
class trunc_inst;
|
||||
class z_ext_inst;
|
||||
class s_ext_inst;
|
||||
class fp_trunc_inst;
|
||||
class fp_ext_inst;
|
||||
class ui_to_fp_inst;
|
||||
class si_to_fp_inst;
|
||||
class fp_to_ui_inst;
|
||||
class fp_to_si_inst;
|
||||
class ptr_to_int_inst;
|
||||
class int_to_ptr_inst;
|
||||
class bit_cast_inst;
|
||||
class addr_space_cast_inst;
|
||||
|
||||
class return_inst;
|
||||
class cond_branch_inst;
|
||||
class uncond_branch_inst;
|
||||
|
||||
|
||||
class unmasked_load_inst;
|
||||
class masked_load_inst;
|
||||
class unmasked_store_inst;
|
||||
class masked_store_inst;
|
||||
|
||||
class retile_inst;
|
||||
class reshape_inst;
|
||||
class splat_inst;
|
||||
class cat_inst;
|
||||
class broadcast_inst;
|
||||
class downcast_inst;
|
||||
|
||||
class umulhi_inst;
|
||||
class exp_inst;
|
||||
class cos_inst;
|
||||
class sin_inst;
|
||||
class log_inst;
|
||||
|
||||
class get_program_id_inst;
|
||||
class get_num_programs_inst;
|
||||
class atomic_inst;
|
||||
class atomic_cas_inst;
|
||||
class atomic_rmw_inst;
|
||||
class dot_inst;
|
||||
class trans_inst;
|
||||
class sqrt_inst;
|
||||
class reduce_inst;
|
||||
class select_inst;
|
||||
|
||||
class cvt_layout_inst;
|
||||
class copy_to_shared_inst;
|
||||
class copy_from_shared_inst;
|
||||
class masked_load_async_inst;
|
||||
class barrier_inst;
|
||||
class async_wait_inst;
|
||||
class make_range_dyn;
|
||||
class make_range;
|
||||
class prefetch_s_inst;
|
||||
|
||||
class make_range_sta;
|
||||
class undef_value;
|
||||
class constant_int;
|
||||
class constant_fp;
|
||||
class global_value;
|
||||
class global_object;
|
||||
class alloc_const;
|
||||
|
||||
class constant_fp;
|
||||
class undef_value;
|
||||
class constant_int;
|
||||
class constant_fp;
|
||||
class global_value;
|
||||
class global_object;
|
||||
class alloc_const;
|
||||
|
||||
class function;
|
||||
|
||||
class basic_block;
|
||||
|
||||
class argument;
|
||||
|
||||
class visitor {
|
||||
public:
|
||||
virtual ~visitor() {}
|
||||
|
||||
virtual void visit_value(ir::value*);
|
||||
|
||||
virtual void visit_basic_block(basic_block*) = 0;
|
||||
virtual void visit_argument(argument*) = 0;
|
||||
virtual void visit_phi_node(phi_node*) = 0;
|
||||
virtual void visit_binary_operator(binary_operator*) = 0;
|
||||
virtual void visit_getelementptr_inst(getelementptr_inst*) = 0;
|
||||
|
||||
virtual void visit_icmp_inst(icmp_inst*) = 0;
|
||||
virtual void visit_fcmp_inst(fcmp_inst*) = 0;
|
||||
virtual void visit_cast_inst(cast_inst*) = 0;
|
||||
|
||||
virtual void visit_return_inst(return_inst*) = 0;
|
||||
virtual void visit_cond_branch_inst(cond_branch_inst*) = 0;
|
||||
virtual void visit_uncond_branch_inst(uncond_branch_inst*) = 0;
|
||||
|
||||
|
||||
virtual void visit_unmasked_load_inst(unmasked_load_inst*) = 0;
|
||||
virtual void visit_masked_load_inst(masked_load_inst*) = 0;
|
||||
virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0;
|
||||
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
|
||||
|
||||
virtual void visit_umulhi_inst(umulhi_inst*) = 0;
|
||||
virtual void visit_exp_inst(exp_inst*) = 0;
|
||||
virtual void visit_cos_inst(cos_inst*) = 0;
|
||||
virtual void visit_sin_inst(sin_inst*) = 0;
|
||||
virtual void visit_log_inst(log_inst*) = 0;
|
||||
|
||||
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
||||
virtual void visit_splat_inst(splat_inst*) = 0;
|
||||
virtual void visit_cat_inst(cat_inst*) = 0;
|
||||
virtual void visit_broadcast_inst(broadcast_inst*) = 0;
|
||||
virtual void visit_downcast_inst(downcast_inst*) = 0;
|
||||
|
||||
virtual void visit_get_program_id_inst(get_program_id_inst*) = 0;
|
||||
virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0;
|
||||
virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0;
|
||||
virtual void visit_atomic_rmw_inst(atomic_rmw_inst*) = 0;
|
||||
virtual void visit_dot_inst(dot_inst*) = 0;
|
||||
virtual void visit_trans_inst(trans_inst*) = 0;
|
||||
virtual void visit_sqrt_inst(sqrt_inst*) = 0;
|
||||
virtual void visit_reduce_inst(reduce_inst*) = 0;
|
||||
virtual void visit_select_inst(select_inst*) = 0;
|
||||
|
||||
virtual void visit_cvt_layout_inst(cvt_layout_inst*) = 0;
|
||||
virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0;
|
||||
virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
|
||||
|
||||
|
||||
virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
|
||||
virtual void visit_barrier_inst(barrier_inst*) = 0;
|
||||
virtual void visit_async_wait_inst(async_wait_inst*) = 0;
|
||||
virtual void visit_make_range(make_range*) = 0;
|
||||
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
|
||||
virtual void visit_function(function*) = 0;
|
||||
|
||||
virtual void visit_undef_value(undef_value*) = 0;
|
||||
virtual void visit_constant_int(constant_int*) = 0;
|
||||
virtual void visit_constant_fp(constant_fp*) = 0;
|
||||
virtual void visit_alloc_const(alloc_const*) = 0;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,54 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_TOOLS_BENCH_H_
|
||||
#define _TRITON_TOOLS_BENCH_H_
|
||||
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/stream.h"
|
||||
|
||||
namespace triton{
|
||||
namespace tools{
|
||||
|
||||
class timer{
|
||||
typedef std::chrono::high_resolution_clock high_resolution_clock;
|
||||
typedef std::chrono::nanoseconds nanoseconds;
|
||||
|
||||
public:
|
||||
explicit timer(bool run = false)
|
||||
{ if (run) start(); }
|
||||
|
||||
void start()
|
||||
{ _start = high_resolution_clock::now(); }
|
||||
|
||||
nanoseconds get() const
|
||||
{ return std::chrono::duration_cast<nanoseconds>(high_resolution_clock::now() - _start); }
|
||||
|
||||
private:
|
||||
high_resolution_clock::time_point _start;
|
||||
};
|
||||
|
||||
inline double bench(std::function<void()> const & op, driver::stream * stream, size_t warmup = 10, size_t repeat = 200)
|
||||
{
|
||||
timer tmr;
|
||||
std::vector<size_t> times;
|
||||
double total_time = 0;
|
||||
for(size_t i = 0; i < warmup; i++)
|
||||
op();
|
||||
stream->synchronize();
|
||||
tmr.start();
|
||||
for(size_t i = 0; i < repeat; i++){
|
||||
op();
|
||||
}
|
||||
stream->synchronize();
|
||||
return (float)tmr.get().count() / repeat;
|
||||
|
||||
// return *std::min_element(times.begin(), times.end());
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,69 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_TOOLS_THREAD_GRAPH_H_
|
||||
#define _TRITON_TOOLS_THREAD_GRAPH_H_
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
namespace triton {
|
||||
namespace tools{
|
||||
|
||||
template<class node_t>
|
||||
class graph {
|
||||
typedef std::map<node_t, std::set<node_t>> edges_t;
|
||||
|
||||
public:
|
||||
typedef std::map<size_t, std::vector<node_t>> cmap_t;
|
||||
typedef std::map<node_t, size_t> nmap_t;
|
||||
|
||||
private:
|
||||
void connected_components_impl(node_t x, std::set<node_t> &nodes,
|
||||
nmap_t* nmap, cmap_t* cmap, int id) const {
|
||||
if(nmap)
|
||||
(*nmap)[x] = id;
|
||||
if(cmap)
|
||||
(*cmap)[id].push_back(x);
|
||||
if(nodes.find(x) != nodes.end()) {
|
||||
nodes.erase(x);
|
||||
for(const node_t &y: edges_.at(x))
|
||||
connected_components_impl(y, nodes, nmap, cmap, id);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
void connected_components(cmap_t *cmap, nmap_t *nmap) const {
|
||||
if(cmap)
|
||||
cmap->clear();
|
||||
if(nmap)
|
||||
nmap->clear();
|
||||
std::set<node_t> nodes = nodes_;
|
||||
unsigned id = 0;
|
||||
while(!nodes.empty()){
|
||||
connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++);
|
||||
}
|
||||
}
|
||||
|
||||
void add_edge(node_t x, node_t y) {
|
||||
nodes_.insert(x);
|
||||
nodes_.insert(y);
|
||||
edges_[x].insert(y);
|
||||
edges_[y].insert(x);
|
||||
}
|
||||
|
||||
void clear() {
|
||||
nodes_.clear();
|
||||
edges_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
std::set<node_t> nodes_;
|
||||
edges_t edges_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user