Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)

This PR merges the `triton-mlir` branch, in which we have been quietly
rewriting the Triton backend from scratch to increase maintainability,
stability and ultimately performance. Changes to the runtime are
minimal, and this new version aims to remain backward-compatible with
the previous commit. The legacy backend is now officially deprecated,
but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
Philippe Tillet
2022-12-21 01:30:50 -08:00
committed by GitHub
parent 8650b4d1cb
commit 20100a7254
285 changed files with 26312 additions and 50143 deletions

1
.clang-format Normal file
View File

@ -0,0 +1 @@
BasedOnStyle: LLVM

57
.github/CODEOWNERS vendored Normal file
View File

@ -0,0 +1,57 @@
# These owners will be the default owners for everything in
# the repo. Unless a later match takes precedence,
# @global-owner1 and @global-owner2 will be requested for
# review when someone opens a pull request.
* @ptillet
# --------
# Analyses
# --------
# Alias analysis
include/triton/Analysis/Alias.h @Jokeren
lib/Analysis/Alias.cpp @Jokeren
# Allocation analysis
include/triton/Analysis/Allocation.h @Jokeren
lib/Analysis/Allocation.cpp @Jokeren
# Membar analysis
include/triton/Analysis/Membar.h @Jokeren
lib/Analysis/Membar.cpp @Jokeren
# AxisInfo analysis
include/triton/Analysis/AxisInfo.h @ptillet
lib/Analysis/AxisInfo.cpp @ptillet
# Utilities
include/triton/Analysis/Utility.h @Jokeren
lib/Analysis/Utility.cpp @Jokeren
# ----------
# Dialects
# ----------
# Pipeline pass
lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @daadaada
# Prefetch pass
lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @daadaada
# Coalesce pass
lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @ptillet
# Layout simplification pass
lib/Dialect/TritonGPU/Transforms/Combine.cpp @ptillet
# -----------
# Conversions
# -----------
# TritonGPUToLLVM
include/triton/Conversion/TritonGPUToLLVM/ @goostavz @Superjomn
lib/Conversions/TritonGPUToLLVM @goostavz @Superjomn
# TritonToTritonGPU
include/triton/Conversion/TritonToTritonGPU/ @daadaada
lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @daadaada
# -------
# Targets
# -------
# LLVMIR
include/triton/Target/LLVMIR/ @goostavz @Superjomn
lib/Target/LLVMIR @goostavz @Superjomn
# PTX
include/triton/Target/PTX/ @goostavz @Superjomn
lib/Target/PTX @goostavz @Superjomn

View File

@ -1,55 +0,0 @@
name: Documentation
on:
workflow_dispatch:
schedule:
- cron: "0 0 * * *"
jobs:
Build-Documentation:
runs-on: [self-hosted, V100]
steps:
- name: Checkout gh-pages
uses: actions/checkout@v1
with:
ref: 'gh-pages'
- name: Clear docs
run: |
rm -r /tmp/triton-docs
continue-on-error: true
- name: Checkout branch
uses: actions/checkout@v1
- name: Build docs
run: |
git fetch origin master:master
cd docs
sphinx-multiversion . _build/html/
- name: Publish docs
run: |
git branch
# update docs
mkdir /tmp/triton-docs;
mv docs/_build/html/* /tmp/triton-docs/
git checkout gh-pages
cp -r CNAME /tmp/triton-docs/
cp -r index.html /tmp/triton-docs/
cp -r .nojekyll /tmp/triton-docs/
rm -r *
cp -r /tmp/triton-docs/* .
# ln -s master/index.html .
# mv master docs
git add .
git commit -am "[GH-PAGES] Updated website"
# publish docs
eval `ssh-agent -s`
DISPLAY=:0 SSH_ASKPASS=~/.ssh/give_pass.sh ssh-add ${{ secrets.SSH_KEY }} <<< ${{ secrets.SSH_PASS }}
git remote set-url origin git@github.com:openai/triton.git
git push

View File

@ -5,50 +5,88 @@ on:
pull_request:
branches:
- master
- triton-mlir
jobs:
Runner-Preparation:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- name: Prepare runner matrix
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix::[["self-hosted", "A10"], ["self-hosted", "V100"], "macos-10.15"]'
else
echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
fi
Integration-Tests:
needs: Runner-Preparation
runs-on: [self-hosted, V100]
runs-on: ${{ matrix.runner }}
strategy:
matrix:
runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix)}}
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Clear cache
run: |
rm -r ~/.triton/
continue-on-error: true
rm -rf ~/.triton/cache/
- name: Check imports
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install isort
isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )
- name: Check python style
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install autopep8
autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )
- name: Check cpp style
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install clang-format
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)
- name: Flake8
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install flake8
flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )
- name: Install Triton
run: |
alias python='python3'
cd python
pip3 install -e '.[tests]'
TRITON_USE_ASSERT_ENABLED_LLVM=TRUE pip3 install -e '.[tests]'
- name: Check imports
run: "isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )"
- name: Check style
run: "autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )"
- name: Flake8
run: "flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )"
- name: Unit tests
- name: Run lit tests
run: |
cd python/test/unit
pytest -vs .
cd python
LIT_TEST_DIR="build/$(ls build)/test"
if [ ! -d "$LIT_TEST_DIR" ]; then
echo "Not found `$LIT_TEST_DIR`. Did you change an installation method?" ; exit -1
fi
lit -v "$LIT_TEST_DIR"
- name: Regression tests
- name: Run python tests
if: ${{matrix.runner[0] == 'self-hosted'}}
run: |
cd python/test/regression
sudo nvidia-smi -i 0 -pm 1
sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350
sudo nvidia-smi -i 0 --lock-memory-clocks=877,877
pytest -vs .
sudo nvidia-smi -i 0 -rgc
sudo nvidia-smi -i 0 -rmc
cd python/test/unit/
pytest
- name: Run CXX unittests
run: |
cd python/
cd "build/$(ls build)"
ctest

14
.gitignore vendored
View File

@ -1,12 +1,20 @@
# Triton builds
build/
__pycache__
.pytest_cache
# Triton Python module builds
python/build/
python/triton.egg-info/
python/triton/_C/libtriton.pyd
python/triton/_C/libtriton.so
# Python caches
__pycache__
.pytest_cache
# VS Code project files
.vscode
.vs
# JetBrains project files
.idea
cmake-build-*

View File

@ -3,6 +3,8 @@ include(ExternalProject)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_INCLUDE_CURRENT_DIR ON)
project(triton)
include(CTest)
if(NOT WIN32)
@ -10,8 +12,16 @@ if(NOT WIN32)
endif()
# Options
option(BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
# Ensure Python3 vars are set correctly
# used conditionally in this file and by lit tests
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
# Customized release build type with assertions: TritonRelBuildWithAsserts
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
# Default build type
if(NOT CMAKE_BUILD_TYPE)
@ -35,13 +45,18 @@ if(WIN32)
add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")
if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET 11.6)
endif()
##########
# LLVM
##########
if("${LLVM_LIBRARY_DIR}" STREQUAL "")
if (NOT MLIR_DIR)
if(NOT LLVM_LIBRARY_DIR)
if(WIN32)
find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu)
@ -60,95 +75,148 @@ if("${LLVM_LIBRARY_DIR}" STREQUAL "")
if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14")
endif()
# sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
else()
# sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
else()
set(LLVM_LDFLAGS "-L${LLVM_LIBRARY_DIR}")
set(LLVM_LIBRARIES
libLLVMNVPTXCodeGen.a
libLLVMNVPTXDesc.a
libLLVMNVPTXInfo.a
libLLVMAMDGPUDisassembler.a
libLLVMMCDisassembler.a
libLLVMAMDGPUCodeGen.a
libLLVMMIRParser.a
libLLVMGlobalISel.a
libLLVMSelectionDAG.a
libLLVMipo.a
libLLVMInstrumentation.a
libLLVMVectorize.a
libLLVMLinker.a
libLLVMIRReader.a
libLLVMAsmParser.a
libLLVMFrontendOpenMP.a
libLLVMAsmPrinter.a
libLLVMDebugInfoDWARF.a
libLLVMCodeGen.a
libLLVMTarget.a
libLLVMScalarOpts.a
libLLVMInstCombine.a
libLLVMAggressiveInstCombine.a
libLLVMTransformUtils.a
libLLVMBitWriter.a
libLLVMAnalysis.a
libLLVMProfileData.a
libLLVMObject.a
libLLVMTextAPI.a
libLLVMBitReader.a
libLLVMAMDGPUAsmParser.a
libLLVMMCParser.a
libLLVMAMDGPUDesc.a
libLLVMAMDGPUUtils.a
libLLVMMC.a
libLLVMDebugInfoCodeView.a
libLLVMDebugInfoMSF.a
libLLVMCore.a
libLLVMRemarks.a
libLLVMBitstreamReader.a
libLLVMBinaryFormat.a
libLLVMAMDGPUInfo.a
libLLVMSupport.a
libLLVMDemangle.a
libLLVMPasses.a
libLLVMAnalysis.a
libLLVMTransformUtils.a
libLLVMScalarOpts.a
libLLVMTransformUtils.a
libLLVMipo.a
libLLVMObjCARCOpts.a
libLLVMCoroutines.a
libLLVMAnalysis.a
)
libLLVMNVPTXCodeGen.a
libLLVMNVPTXDesc.a
libLLVMNVPTXInfo.a
libLLVMAMDGPUDisassembler.a
libLLVMMCDisassembler.a
libLLVMAMDGPUCodeGen.a
libLLVMMIRParser.a
libLLVMGlobalISel.a
libLLVMSelectionDAG.a
libLLVMipo.a
libLLVMInstrumentation.a
libLLVMVectorize.a
libLLVMLinker.a
libLLVMIRReader.a
libLLVMAsmParser.a
libLLVMFrontendOpenMP.a
libLLVMAsmPrinter.a
libLLVMDebugInfoDWARF.a
libLLVMCodeGen.a
libLLVMTarget.a
libLLVMScalarOpts.a
libLLVMInstCombine.a
libLLVMAggressiveInstCombine.a
libLLVMTransformUtils.a
libLLVMBitWriter.a
libLLVMAnalysis.a
libLLVMProfileData.a
libLLVMObject.a
libLLVMTextAPI.a
libLLVMBitReader.a
libLLVMAMDGPUAsmParser.a
libLLVMMCParser.a
libLLVMAMDGPUDesc.a
libLLVMAMDGPUUtils.a
libLLVMMC.a
libLLVMDebugInfoCodeView.a
libLLVMDebugInfoMSF.a
libLLVMCore.a
libLLVMRemarks.a
libLLVMBitstreamReader.a
libLLVMBinaryFormat.a
libLLVMAMDGPUInfo.a
libLLVMSupport.a
libLLVMDemangle.a
libLLVMPasses.a
libLLVMAnalysis.a
libLLVMTransformUtils.a
libLLVMScalarOpts.a
libLLVMTransformUtils.a
libLLVMipo.a
libLLVMObjCARCOpts.a
libLLVMCoroutines.a
libLLVMAnalysis.a
)
endif()
set (MLIR_DIR ${LLVM_LIBRARY_DIR}/cmake/mlir)
endif()
include_directories("${LLVM_INCLUDE_DIRS}")
# Python module
if(BUILD_PYTHON_MODULE)
if(TRITON_BUILD_PYTHON_MODULE)
message(STATUS "Adding Python module")
# Build CUTLASS python wrapper if requested
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}")
set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}")
if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL ""))
set(CUTLASS_SRC ${PYTHON_SRC_PATH}/cutlass.cc)
add_definitions(-DWITH_CUTLASS_BINDINGS)
set(CUTLASS_LIBRARIES "cutlass.a")
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc)
include_directories("." ${PYTHON_SRC_PATH})
if (PYTHON_INCLUDE_DIRS)
include_directories(${PYTHON_INCLUDE_DIRS})
else()
include_directories(${Python3_INCLUDE_DIRS})
link_directories(${Python3_LIBRARY_DIRS})
link_libraries(${Python3_LIBRARIES})
add_link_options(${Python3_LINK_OPTIONS})
endif()
include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC})
endif()
# Triton
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
if (WIN32 AND BUILD_PYTHON_MODULE)
find_package(Python3 REQUIRED COMPONENTS Development)
Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
set_target_properties(triton PROPERTIES SUFFIX ".pyd")
set_target_properties(triton PROPERTIES PREFIX "lib")
else()
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
endif()
# # Triton
# file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
# if (WIN32 AND TRITON_BUILD_PYTHON_MODULE)
# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# set_target_properties(triton PROPERTIES SUFFIX ".pyd")
# set_target_properties(triton PROPERTIES PREFIX "lib")
# else()
# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# endif()
# MLIR
find_package(MLIR REQUIRED CONFIG PATHS ${MLIR_DIR})
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
include(TableGen) # required by AddMLIR
include(AddLLVM)
include(AddMLIR)
# Disable warnings that show up in external code (gtest;pybind11)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default")
include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
# link_directories(${LLVM_LIBRARY_DIR})
add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(bin)
add_library(triton SHARED ${PYTHON_SRC})
# find_package(PythonLibs REQUIRED)
set(TRITON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
target_link_libraries(triton
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonLLVMIR
TritonPTX
${dialect_libs}
${conversion_libs}
# optimizations
MLIRPass
MLIRTransforms
MLIRLLVMIR
MLIRSupport
MLIRTargetLLVMIRExport
MLIRExecutionEngine
MLIRMathToLLVM
MLIRNVVMToLLVMIRTranslation
MLIRIR
)
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
@ -159,7 +227,7 @@ else()
endif()
if(BUILD_PYTHON_MODULE AND NOT WIN32)
if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
# Check if the platform is MacOS
if(APPLE)
@ -167,3 +235,7 @@ if(BUILD_PYTHON_MODULE AND NOT WIN32)
endif()
target_link_libraries(triton ${CUTLASS_LIBRARIES} ${PYTHON_LDFLAGS})
endif()
add_subdirectory(test)
add_subdirectory(unittest)

View File

@ -12,7 +12,7 @@
# Triton
This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment for expressing tensor math workloads that offers high flexibility, developer productivity and end to end performance.
This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs.
The foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing this work if you use Triton!

60
bin/CMakeLists.txt Normal file
View File

@ -0,0 +1,60 @@
add_subdirectory(FileCheck)
# add_llvm_executable(FileCheck FileCheck/FileCheck.cpp)
# target_link_libraries(FileCheck PRIVATE LLVMFileCheck LLVMSupport)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED)
# TODO: what's this?
llvm_update_compile_flags(triton-opt)
target_link_libraries(triton-opt PRIVATE
TritonAnalysis
TritonTransforms
TritonGPUTransforms
${dialect_libs}
${conversion_libs}
# tests
TritonTestAnalysis
# MLIR core
MLIROptLib
MLIRPass
MLIRTransforms
)
mlir_check_all_link_libraries(triton-opt)
# add_llvm_executable(triton-translate triton-translate.cpp PARTIAL_SOURCES_INTENDED)
#llvm_update_compile_flags(triton-translate)
# target_link_libraries(triton-translate PRIVATE
# TritonAnalysis
# TritonTransforms
# TritonGPUTransforms
# TritonLLVMIR
# TritonDriver
# ${dialect_libs}
# ${conversion_libs}
# # tests
# TritonTestAnalysis
# LLVMCore
# LLVMSupport
# LLVMOption
# LLVMCodeGen
# LLVMAsmParser
# # MLIR core
# MLIROptLib
# MLIRIR
# MLIRPass
# MLIRSupport
# MLIRTransforms
# MLIRExecutionEngine
# MLIRMathToLLVM
# MLIRTransformUtils
# MLIRLLVMToLLVMIRTranslation
# MLIRNVVMToLLVMIRTranslation
# )
# mlir_check_all_link_libraries(triton-translate)

View File

@ -0,0 +1,2 @@
add_llvm_executable(FileCheck FileCheck.cpp)
target_link_libraries(FileCheck PRIVATE LLVMFileCheck LLVMSupport)

882
bin/FileCheck/FileCheck.cpp Normal file
View File

@ -0,0 +1,882 @@
//===- FileCheck.cpp - Check that File's Contents match what is expected --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// FileCheck does a line-by line check of a file that validates whether it
// contains the expected content. This is useful for regression tests etc.
//
// This program exits with an exit status of 2 on error, exit status of 0 if
// the file matched the expected contents, and exit status of 1 if it did not
// contain the expected contents.
//
//===----------------------------------------------------------------------===//
#include "llvm/FileCheck/FileCheck.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Process.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Support/raw_ostream.h"
#include <cmath>
#include <map>
using namespace llvm;
static cl::extrahelp FileCheckOptsEnv(
"\nOptions are parsed from the environment variable FILECHECK_OPTS and\n"
"from the command line.\n");
static cl::opt<std::string>
CheckFilename(cl::Positional, cl::desc("<check-file>"), cl::Optional);
static cl::opt<std::string>
InputFilename("input-file", cl::desc("File to check (defaults to stdin)"),
cl::init("-"), cl::value_desc("filename"));
static cl::list<std::string> CheckPrefixes(
"check-prefix",
cl::desc("Prefix to use from check file (defaults to 'CHECK')"));
static cl::alias CheckPrefixesAlias(
"check-prefixes", cl::aliasopt(CheckPrefixes), cl::CommaSeparated,
cl::NotHidden,
cl::desc(
"Alias for -check-prefix permitting multiple comma separated values"));
static cl::list<std::string> CommentPrefixes(
"comment-prefixes", cl::CommaSeparated, cl::Hidden,
cl::desc("Comma-separated list of comment prefixes to use from check file\n"
"(defaults to 'COM,RUN'). Please avoid using this feature in\n"
"LLVM's LIT-based test suites, which should be easier to\n"
"maintain if they all follow a consistent comment style. This\n"
"feature is meant for non-LIT test suites using FileCheck."));
static cl::opt<bool> NoCanonicalizeWhiteSpace(
"strict-whitespace",
cl::desc("Do not treat all horizontal whitespace as equivalent"));
static cl::opt<bool> IgnoreCase("ignore-case",
cl::desc("Use case-insensitive matching"));
static cl::list<std::string> ImplicitCheckNot(
"implicit-check-not",
cl::desc("Add an implicit negative check with this pattern to every\n"
"positive check. This can be used to ensure that no instances of\n"
"this pattern occur which are not matched by a positive pattern"),
cl::value_desc("pattern"));
static cl::list<std::string>
GlobalDefines("D", cl::AlwaysPrefix,
cl::desc("Define a variable to be used in capture patterns."),
cl::value_desc("VAR=VALUE"));
static cl::opt<bool> AllowEmptyInput(
"allow-empty", cl::init(false),
cl::desc("Allow the input file to be empty. This is useful when making\n"
"checks that some error message does not occur, for example."));
static cl::opt<bool> AllowUnusedPrefixes(
"allow-unused-prefixes", cl::init(false), cl::ZeroOrMore,
cl::desc("Allow prefixes to be specified but not appear in the test."));
static cl::opt<bool> MatchFullLines(
"match-full-lines", cl::init(false),
cl::desc("Require all positive matches to cover an entire input line.\n"
"Allows leading and trailing whitespace if --strict-whitespace\n"
"is not also passed."));
static cl::opt<bool> EnableVarScope(
"enable-var-scope", cl::init(false),
cl::desc("Enables scope for regex variables. Variables with names that\n"
"do not start with '$' will be reset at the beginning of\n"
"each CHECK-LABEL block."));
static cl::opt<bool> AllowDeprecatedDagOverlap(
"allow-deprecated-dag-overlap", cl::init(false),
cl::desc("Enable overlapping among matches in a group of consecutive\n"
"CHECK-DAG directives. This option is deprecated and is only\n"
"provided for convenience as old tests are migrated to the new\n"
"non-overlapping CHECK-DAG implementation.\n"));
static cl::opt<bool> Verbose(
"v", cl::init(false), cl::ZeroOrMore,
cl::desc("Print directive pattern matches, or add them to the input dump\n"
"if enabled.\n"));
static cl::opt<bool> VerboseVerbose(
"vv", cl::init(false), cl::ZeroOrMore,
cl::desc("Print information helpful in diagnosing internal FileCheck\n"
"issues, or add it to the input dump if enabled. Implies\n"
"-v.\n"));
// The order of DumpInputValue members affects their precedence, as documented
// for -dump-input below.
enum DumpInputValue {
DumpInputNever,
DumpInputFail,
DumpInputAlways,
DumpInputHelp
};
static cl::list<DumpInputValue> DumpInputs(
"dump-input",
cl::desc("Dump input to stderr, adding annotations representing\n"
"currently enabled diagnostics. When there are multiple\n"
"occurrences of this option, the <value> that appears earliest\n"
"in the list below has precedence. The default is 'fail'.\n"),
cl::value_desc("mode"),
cl::values(clEnumValN(DumpInputHelp, "help", "Explain input dump and quit"),
clEnumValN(DumpInputAlways, "always", "Always dump input"),
clEnumValN(DumpInputFail, "fail", "Dump input on failure"),
clEnumValN(DumpInputNever, "never", "Never dump input")));
// The order of DumpInputFilterValue members affects their precedence, as
// documented for -dump-input-filter below.
enum DumpInputFilterValue {
DumpInputFilterError,
DumpInputFilterAnnotation,
DumpInputFilterAnnotationFull,
DumpInputFilterAll
};
static cl::list<DumpInputFilterValue> DumpInputFilters(
"dump-input-filter",
cl::desc("In the dump requested by -dump-input, print only input lines of\n"
"kind <value> plus any context specified by -dump-input-context.\n"
"When there are multiple occurrences of this option, the <value>\n"
"that appears earliest in the list below has precedence. The\n"
"default is 'error' when -dump-input=fail, and it's 'all' when\n"
"-dump-input=always.\n"),
cl::values(clEnumValN(DumpInputFilterAll, "all", "All input lines"),
clEnumValN(DumpInputFilterAnnotationFull, "annotation-full",
"Input lines with annotations"),
clEnumValN(DumpInputFilterAnnotation, "annotation",
"Input lines with starting points of annotations"),
clEnumValN(DumpInputFilterError, "error",
"Input lines with starting points of error "
"annotations")));
static cl::list<unsigned> DumpInputContexts(
"dump-input-context", cl::value_desc("N"),
cl::desc("In the dump requested by -dump-input, print <N> input lines\n"
"before and <N> input lines after any lines specified by\n"
"-dump-input-filter. When there are multiple occurrences of\n"
"this option, the largest specified <N> has precedence. The\n"
"default is 5.\n"));
typedef cl::list<std::string>::const_iterator prefix_iterator;
static void DumpCommandLine(int argc, char **argv) {
errs() << "FileCheck command line: ";
for (int I = 0; I < argc; I++)
errs() << " " << argv[I];
errs() << "\n";
}
struct MarkerStyle {
/// The starting char (before tildes) for marking the line.
char Lead;
/// What color to use for this annotation.
raw_ostream::Colors Color;
/// A note to follow the marker, or empty string if none.
std::string Note;
/// Does this marker indicate inclusion by -dump-input-filter=error?
bool FiltersAsError;
MarkerStyle() {}
MarkerStyle(char Lead, raw_ostream::Colors Color,
const std::string &Note = "", bool FiltersAsError = false)
: Lead(Lead), Color(Color), Note(Note), FiltersAsError(FiltersAsError) {
assert((!FiltersAsError || !Note.empty()) &&
"expected error diagnostic to have note");
}
};
static MarkerStyle GetMarker(FileCheckDiag::MatchType MatchTy) {
switch (MatchTy) {
case FileCheckDiag::MatchFoundAndExpected:
return MarkerStyle('^', raw_ostream::GREEN);
case FileCheckDiag::MatchFoundButExcluded:
return MarkerStyle('!', raw_ostream::RED, "error: no match expected",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchFoundButWrongLine:
return MarkerStyle('!', raw_ostream::RED, "error: match on wrong line",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchFoundButDiscarded:
return MarkerStyle('!', raw_ostream::CYAN,
"discard: overlaps earlier match");
case FileCheckDiag::MatchFoundErrorNote:
// Note should always be overridden within the FileCheckDiag.
return MarkerStyle('!', raw_ostream::RED,
"error: unknown error after match",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchNoneAndExcluded:
return MarkerStyle('X', raw_ostream::GREEN);
case FileCheckDiag::MatchNoneButExpected:
return MarkerStyle('X', raw_ostream::RED, "error: no match found",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchNoneForInvalidPattern:
return MarkerStyle('X', raw_ostream::RED,
"error: match failed for invalid pattern",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchFuzzy:
return MarkerStyle('?', raw_ostream::MAGENTA, "possible intended match",
/*FiltersAsError=*/true);
}
llvm_unreachable_internal("unexpected match type");
}
static void DumpInputAnnotationHelp(raw_ostream &OS) {
OS << "The following description was requested by -dump-input=help to\n"
<< "explain the input dump printed by FileCheck.\n"
<< "\n"
<< "Related command-line options:\n"
<< "\n"
<< " - -dump-input=<value> enables or disables the input dump\n"
<< " - -dump-input-filter=<value> filters the input lines\n"
<< " - -dump-input-context=<N> adjusts the context of filtered lines\n"
<< " - -v and -vv add more annotations\n"
<< " - -color forces colors to be enabled both in the dump and below\n"
<< " - -help documents the above options in more detail\n"
<< "\n"
<< "These options can also be set via FILECHECK_OPTS. For example, for\n"
<< "maximum debugging output on failures:\n"
<< "\n"
<< " $ FILECHECK_OPTS='-dump-input-filter=all -vv -color' ninja check\n"
<< "\n"
<< "Input dump annotation format:\n"
<< "\n";
// Labels for input lines.
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "L:";
OS << " labels line number L of the input file\n"
<< " An extra space is added after each input line to represent"
<< " the\n"
<< " newline character\n";
// Labels for annotation lines.
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "T:L";
OS << " labels the only match result for either (1) a pattern of type T"
<< " from\n"
<< " line L of the check file if L is an integer or (2) the"
<< " I-th implicit\n"
<< " pattern if L is \"imp\" followed by an integer "
<< "I (index origin one)\n";
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "T:L'N";
OS << " labels the Nth match result for such a pattern\n";
// Markers on annotation lines.
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "^~~";
OS << " marks good match (reported if -v)\n"
<< " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "!~~";
OS << " marks bad match, such as:\n"
<< " - CHECK-NEXT on same line as previous match (error)\n"
<< " - CHECK-NOT found (error)\n"
<< " - CHECK-DAG overlapping match (discarded, reported if "
<< "-vv)\n"
<< " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "X~~";
OS << " marks search range when no match is found, such as:\n"
<< " - CHECK-NEXT not found (error)\n"
<< " - CHECK-NOT not found (success, reported if -vv)\n"
<< " - CHECK-DAG not found after discarded matches (error)\n"
<< " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "?";
OS << " marks fuzzy match when no match is found\n";
// Elided lines.
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "...";
OS << " indicates elided input lines and annotations, as specified by\n"
<< " -dump-input-filter and -dump-input-context\n";
// Colors.
OS << " - colors ";
WithColor(OS, raw_ostream::GREEN, true) << "success";
OS << ", ";
WithColor(OS, raw_ostream::RED, true) << "error";
OS << ", ";
WithColor(OS, raw_ostream::MAGENTA, true) << "fuzzy match";
OS << ", ";
WithColor(OS, raw_ostream::CYAN, true, false) << "discarded match";
OS << ", ";
WithColor(OS, raw_ostream::CYAN, true, true) << "unmatched input";
OS << "\n";
}
/// An annotation for a single input line.
struct InputAnnotation {
/// The index of the match result across all checks
unsigned DiagIndex;
/// The label for this annotation.
std::string Label;
/// Is this the initial fragment of a diagnostic that has been broken across
/// multiple lines?
bool IsFirstLine;
/// What input line (one-origin indexing) this annotation marks. This might
/// be different from the starting line of the original diagnostic if
/// !IsFirstLine.
unsigned InputLine;
/// The column range (one-origin indexing, open end) in which to mark the
/// input line. If InputEndCol is UINT_MAX, treat it as the last column
/// before the newline.
unsigned InputStartCol, InputEndCol;
/// The marker to use.
MarkerStyle Marker;
/// Whether this annotation represents a good match for an expected pattern.
bool FoundAndExpectedMatch;
};
/// Get an abbreviation for the check type.
static std::string GetCheckTypeAbbreviation(Check::FileCheckType Ty) {
switch (Ty) {
case Check::CheckPlain:
if (Ty.getCount() > 1)
return "count";
return "check";
case Check::CheckNext:
return "next";
case Check::CheckSame:
return "same";
case Check::CheckNot:
return "not";
case Check::CheckDAG:
return "dag";
case Check::CheckLabel:
return "label";
case Check::CheckEmpty:
return "empty";
case Check::CheckComment:
return "com";
case Check::CheckEOF:
return "eof";
case Check::CheckBadNot:
return "bad-not";
case Check::CheckBadCount:
return "bad-count";
case Check::CheckNone:
llvm_unreachable("invalid FileCheckType");
}
llvm_unreachable("unknown FileCheckType");
}
static void
BuildInputAnnotations(const SourceMgr &SM, unsigned CheckFileBufferID,
const std::pair<unsigned, unsigned> &ImpPatBufferIDRange,
const std::vector<FileCheckDiag> &Diags,
std::vector<InputAnnotation> &Annotations,
unsigned &LabelWidth) {
struct CompareSMLoc {
bool operator()(const SMLoc &LHS, const SMLoc &RHS) const {
return LHS.getPointer() < RHS.getPointer();
}
};
// How many diagnostics does each pattern have?
std::map<SMLoc, unsigned, CompareSMLoc> DiagCountPerPattern;
for (auto Diag : Diags)
++DiagCountPerPattern[Diag.CheckLoc];
// How many diagnostics have we seen so far per pattern?
std::map<SMLoc, unsigned, CompareSMLoc> DiagIndexPerPattern;
// How many total diagnostics have we seen so far?
unsigned DiagIndex = 0;
// What's the widest label?
LabelWidth = 0;
for (auto DiagItr = Diags.begin(), DiagEnd = Diags.end(); DiagItr != DiagEnd;
++DiagItr) {
InputAnnotation A;
A.DiagIndex = DiagIndex++;
// Build label, which uniquely identifies this check result.
unsigned CheckBufferID = SM.FindBufferContainingLoc(DiagItr->CheckLoc);
auto CheckLineAndCol =
SM.getLineAndColumn(DiagItr->CheckLoc, CheckBufferID);
llvm::raw_string_ostream Label(A.Label);
Label << GetCheckTypeAbbreviation(DiagItr->CheckTy) << ":";
if (CheckBufferID == CheckFileBufferID)
Label << CheckLineAndCol.first;
else if (ImpPatBufferIDRange.first <= CheckBufferID &&
CheckBufferID < ImpPatBufferIDRange.second)
Label << "imp" << (CheckBufferID - ImpPatBufferIDRange.first + 1);
else
llvm_unreachable("expected diagnostic's check location to be either in "
"the check file or for an implicit pattern");
if (DiagCountPerPattern[DiagItr->CheckLoc] > 1)
Label << "'" << DiagIndexPerPattern[DiagItr->CheckLoc]++;
LabelWidth = std::max((std::string::size_type)LabelWidth, A.Label.size());
A.Marker = GetMarker(DiagItr->MatchTy);
if (!DiagItr->Note.empty()) {
A.Marker.Note = DiagItr->Note;
// It's less confusing if notes that don't actually have ranges don't have
// markers. For example, a marker for 'with "VAR" equal to "5"' would
// seem to indicate where "VAR" matches, but the location we actually have
// for the marker simply points to the start of the match/search range for
// the full pattern of which the substitution is potentially just one
// component.
if (DiagItr->InputStartLine == DiagItr->InputEndLine &&
DiagItr->InputStartCol == DiagItr->InputEndCol)
A.Marker.Lead = ' ';
}
if (DiagItr->MatchTy == FileCheckDiag::MatchFoundErrorNote) {
assert(!DiagItr->Note.empty() &&
"expected custom note for MatchFoundErrorNote");
A.Marker.Note = "error: " + A.Marker.Note;
}
A.FoundAndExpectedMatch =
DiagItr->MatchTy == FileCheckDiag::MatchFoundAndExpected;
// Compute the mark location, and break annotation into multiple
// annotations if it spans multiple lines.
A.IsFirstLine = true;
A.InputLine = DiagItr->InputStartLine;
A.InputStartCol = DiagItr->InputStartCol;
if (DiagItr->InputStartLine == DiagItr->InputEndLine) {
// Sometimes ranges are empty in order to indicate a specific point, but
// that would mean nothing would be marked, so adjust the range to
// include the following character.
A.InputEndCol =
std::max(DiagItr->InputStartCol + 1, DiagItr->InputEndCol);
Annotations.push_back(A);
} else {
assert(DiagItr->InputStartLine < DiagItr->InputEndLine &&
"expected input range not to be inverted");
A.InputEndCol = UINT_MAX;
Annotations.push_back(A);
for (unsigned L = DiagItr->InputStartLine + 1, E = DiagItr->InputEndLine;
L <= E; ++L) {
// If a range ends before the first column on a line, then it has no
// characters on that line, so there's nothing to render.
if (DiagItr->InputEndCol == 1 && L == E)
break;
InputAnnotation B;
B.DiagIndex = A.DiagIndex;
B.Label = A.Label;
B.IsFirstLine = false;
B.InputLine = L;
B.Marker = A.Marker;
B.Marker.Lead = '~';
B.Marker.Note = "";
B.InputStartCol = 1;
if (L != E)
B.InputEndCol = UINT_MAX;
else
B.InputEndCol = DiagItr->InputEndCol;
B.FoundAndExpectedMatch = A.FoundAndExpectedMatch;
Annotations.push_back(B);
}
}
}
}
static unsigned FindInputLineInFilter(
DumpInputFilterValue DumpInputFilter, unsigned CurInputLine,
const std::vector<InputAnnotation>::iterator &AnnotationBeg,
const std::vector<InputAnnotation>::iterator &AnnotationEnd) {
if (DumpInputFilter == DumpInputFilterAll)
return CurInputLine;
for (auto AnnotationItr = AnnotationBeg; AnnotationItr != AnnotationEnd;
++AnnotationItr) {
switch (DumpInputFilter) {
case DumpInputFilterAll:
llvm_unreachable("unexpected DumpInputFilterAll");
break;
case DumpInputFilterAnnotationFull:
return AnnotationItr->InputLine;
case DumpInputFilterAnnotation:
if (AnnotationItr->IsFirstLine)
return AnnotationItr->InputLine;
break;
case DumpInputFilterError:
if (AnnotationItr->IsFirstLine && AnnotationItr->Marker.FiltersAsError)
return AnnotationItr->InputLine;
break;
}
}
return UINT_MAX;
}
/// To OS, print a vertical ellipsis (right-justified at LabelWidth) if it would
/// occupy less lines than ElidedLines, but print ElidedLines otherwise. Either
/// way, clear ElidedLines. Thus, if ElidedLines is empty, do nothing.
static void DumpEllipsisOrElidedLines(raw_ostream &OS, std::string &ElidedLines,
unsigned LabelWidth) {
if (ElidedLines.empty())
return;
unsigned EllipsisLines = 3;
if (EllipsisLines < StringRef(ElidedLines).count('\n')) {
for (unsigned i = 0; i < EllipsisLines; ++i) {
WithColor(OS, raw_ostream::BLACK, /*Bold=*/true)
<< right_justify(".", LabelWidth);
OS << '\n';
}
} else
OS << ElidedLines;
ElidedLines.clear();
}
static void DumpAnnotatedInput(raw_ostream &OS, const FileCheckRequest &Req,
DumpInputFilterValue DumpInputFilter,
unsigned DumpInputContext,
StringRef InputFileText,
std::vector<InputAnnotation> &Annotations,
unsigned LabelWidth) {
OS << "Input was:\n<<<<<<\n";
// Sort annotations.
llvm::sort(Annotations,
[](const InputAnnotation &A, const InputAnnotation &B) {
// 1. Sort annotations in the order of the input lines.
//
// This makes it easier to find relevant annotations while
// iterating input lines in the implementation below. FileCheck
// does not always produce diagnostics in the order of input
// lines due to, for example, CHECK-DAG and CHECK-NOT.
if (A.InputLine != B.InputLine)
return A.InputLine < B.InputLine;
// 2. Sort annotations in the temporal order FileCheck produced
// their associated diagnostics.
//
// This sort offers several benefits:
//
// A. On a single input line, the order of annotations reflects
// the FileCheck logic for processing directives/patterns.
// This can be helpful in understanding cases in which the
// order of the associated directives/patterns in the check
// file or on the command line either (i) does not match the
// temporal order in which FileCheck looks for matches for the
// directives/patterns (due to, for example, CHECK-LABEL,
// CHECK-NOT, or `--implicit-check-not`) or (ii) does match
// that order but does not match the order of those
// diagnostics along an input line (due to, for example,
// CHECK-DAG).
//
// On the other hand, because our presentation format presents
// input lines in order, there's no clear way to offer the
// same benefit across input lines. For consistency, it might
// then seem worthwhile to have annotations on a single line
// also sorted in input order (that is, by input column).
// However, in practice, this appears to be more confusing
// than helpful. Perhaps it's intuitive to expect annotations
// to be listed in the temporal order in which they were
// produced except in cases the presentation format obviously
// and inherently cannot support it (that is, across input
// lines).
//
// B. When diagnostics' annotations are split among multiple
// input lines, the user must track them from one input line
// to the next. One property of the sort chosen here is that
// it facilitates the user in this regard by ensuring the
// following: when comparing any two input lines, a
// diagnostic's annotations are sorted in the same position
// relative to all other diagnostics' annotations.
return A.DiagIndex < B.DiagIndex;
});
// Compute the width of the label column.
const unsigned char *InputFilePtr = InputFileText.bytes_begin(),
*InputFileEnd = InputFileText.bytes_end();
unsigned LineCount = InputFileText.count('\n');
if (InputFileEnd[-1] != '\n')
++LineCount;
unsigned LineNoWidth = std::log10(LineCount) + 1;
// +3 below adds spaces (1) to the left of the (right-aligned) line numbers
// on input lines and (2) to the right of the (left-aligned) labels on
// annotation lines so that input lines and annotation lines are more
// visually distinct. For example, the spaces on the annotation lines ensure
// that input line numbers and check directive line numbers never align
// horizontally. Those line numbers might not even be for the same file.
// One space would be enough to achieve that, but more makes it even easier
// to see.
LabelWidth = std::max(LabelWidth, LineNoWidth) + 3;
// Print annotated input lines.
unsigned PrevLineInFilter = 0; // 0 means none so far
unsigned NextLineInFilter = 0; // 0 means uncomputed, UINT_MAX means none
std::string ElidedLines;
raw_string_ostream ElidedLinesOS(ElidedLines);
ColorMode TheColorMode =
WithColor(OS).colorsEnabled() ? ColorMode::Enable : ColorMode::Disable;
if (TheColorMode == ColorMode::Enable)
ElidedLinesOS.enable_colors(true);
auto AnnotationItr = Annotations.begin(), AnnotationEnd = Annotations.end();
for (unsigned Line = 1;
InputFilePtr != InputFileEnd || AnnotationItr != AnnotationEnd; ++Line) {
const unsigned char *InputFileLine = InputFilePtr;
// Compute the previous and next line included by the filter.
if (NextLineInFilter < Line)
NextLineInFilter = FindInputLineInFilter(DumpInputFilter, Line,
AnnotationItr, AnnotationEnd);
assert(NextLineInFilter && "expected NextLineInFilter to be computed");
if (NextLineInFilter == Line)
PrevLineInFilter = Line;
// Elide this input line and its annotations if it's not within the
// context specified by -dump-input-context of an input line included by
// -dump-input-filter. However, in case the resulting ellipsis would occupy
// more lines than the input lines and annotations it elides, buffer the
// elided lines and annotations so we can print them instead.
raw_ostream *LineOS = &OS;
if ((!PrevLineInFilter || PrevLineInFilter + DumpInputContext < Line) &&
(NextLineInFilter == UINT_MAX ||
Line + DumpInputContext < NextLineInFilter))
LineOS = &ElidedLinesOS;
else {
LineOS = &OS;
DumpEllipsisOrElidedLines(OS, ElidedLinesOS.str(), LabelWidth);
}
// Print right-aligned line number.
WithColor(*LineOS, raw_ostream::BLACK, /*Bold=*/true, /*BF=*/false,
TheColorMode)
<< format_decimal(Line, LabelWidth) << ": ";
// For the case where -v and colors are enabled, find the annotations for
// good matches for expected patterns in order to highlight everything
// else in the line. There are no such annotations if -v is disabled.
std::vector<InputAnnotation> FoundAndExpectedMatches;
if (Req.Verbose && TheColorMode == ColorMode::Enable) {
for (auto I = AnnotationItr; I != AnnotationEnd && I->InputLine == Line;
++I) {
if (I->FoundAndExpectedMatch)
FoundAndExpectedMatches.push_back(*I);
}
}
// Print numbered line with highlighting where there are no matches for
// expected patterns.
bool Newline = false;
{
WithColor COS(*LineOS, raw_ostream::SAVEDCOLOR, /*Bold=*/false,
/*BG=*/false, TheColorMode);
bool InMatch = false;
if (Req.Verbose)
COS.changeColor(raw_ostream::CYAN, true, true);
for (unsigned Col = 1; InputFilePtr != InputFileEnd && !Newline; ++Col) {
bool WasInMatch = InMatch;
InMatch = false;
for (auto M : FoundAndExpectedMatches) {
if (M.InputStartCol <= Col && Col < M.InputEndCol) {
InMatch = true;
break;
}
}
if (!WasInMatch && InMatch)
COS.resetColor();
else if (WasInMatch && !InMatch)
COS.changeColor(raw_ostream::CYAN, true, true);
if (*InputFilePtr == '\n') {
Newline = true;
COS << ' ';
} else
COS << *InputFilePtr;
++InputFilePtr;
}
}
*LineOS << '\n';
unsigned InputLineWidth = InputFilePtr - InputFileLine;
// Print any annotations.
while (AnnotationItr != AnnotationEnd && AnnotationItr->InputLine == Line) {
WithColor COS(*LineOS, AnnotationItr->Marker.Color, /*Bold=*/true,
/*BG=*/false, TheColorMode);
// The two spaces below are where the ": " appears on input lines.
COS << left_justify(AnnotationItr->Label, LabelWidth) << " ";
unsigned Col;
for (Col = 1; Col < AnnotationItr->InputStartCol; ++Col)
COS << ' ';
COS << AnnotationItr->Marker.Lead;
// If InputEndCol=UINT_MAX, stop at InputLineWidth.
for (++Col; Col < AnnotationItr->InputEndCol && Col <= InputLineWidth;
++Col)
COS << '~';
const std::string &Note = AnnotationItr->Marker.Note;
if (!Note.empty()) {
// Put the note at the end of the input line. If we were to instead
// put the note right after the marker, subsequent annotations for the
// same input line might appear to mark this note instead of the input
// line.
for (; Col <= InputLineWidth; ++Col)
COS << ' ';
COS << ' ' << Note;
}
COS << '\n';
++AnnotationItr;
}
}
DumpEllipsisOrElidedLines(OS, ElidedLinesOS.str(), LabelWidth);
OS << ">>>>>>\n";
}
int main(int argc, char **argv) {
// Enable use of ANSI color codes because FileCheck is using them to
// highlight text.
llvm::sys::Process::UseANSIEscapeCodes(true);
InitLLVM X(argc, argv);
cl::ParseCommandLineOptions(argc, argv, /*Overview*/ "", /*Errs*/ nullptr,
"FILECHECK_OPTS");
// Select -dump-input* values. The -help documentation specifies the default
// value and which value to choose if an option is specified multiple times.
// In the latter case, the general rule of thumb is to choose the value that
// provides the most information.
DumpInputValue DumpInput =
DumpInputs.empty()
? DumpInputFail
: *std::max_element(DumpInputs.begin(), DumpInputs.end());
DumpInputFilterValue DumpInputFilter;
if (DumpInputFilters.empty())
DumpInputFilter = DumpInput == DumpInputAlways ? DumpInputFilterAll
: DumpInputFilterError;
else
DumpInputFilter =
*std::max_element(DumpInputFilters.begin(), DumpInputFilters.end());
unsigned DumpInputContext = DumpInputContexts.empty()
? 5
: *std::max_element(DumpInputContexts.begin(),
DumpInputContexts.end());
if (DumpInput == DumpInputHelp) {
DumpInputAnnotationHelp(outs());
return 0;
}
if (CheckFilename.empty()) {
errs() << "<check-file> not specified\n";
return 2;
}
FileCheckRequest Req;
append_range(Req.CheckPrefixes, CheckPrefixes);
append_range(Req.CommentPrefixes, CommentPrefixes);
append_range(Req.ImplicitCheckNot, ImplicitCheckNot);
bool GlobalDefineError = false;
for (StringRef G : GlobalDefines) {
size_t EqIdx = G.find('=');
if (EqIdx == std::string::npos) {
errs() << "Missing equal sign in command-line definition '-D" << G
<< "'\n";
GlobalDefineError = true;
continue;
}
if (EqIdx == 0) {
errs() << "Missing variable name in command-line definition '-D" << G
<< "'\n";
GlobalDefineError = true;
continue;
}
Req.GlobalDefines.push_back(G);
}
if (GlobalDefineError)
return 2;
Req.AllowEmptyInput = AllowEmptyInput;
Req.AllowUnusedPrefixes = AllowUnusedPrefixes;
Req.EnableVarScope = EnableVarScope;
Req.AllowDeprecatedDagOverlap = AllowDeprecatedDagOverlap;
Req.Verbose = Verbose;
Req.VerboseVerbose = VerboseVerbose;
Req.NoCanonicalizeWhiteSpace = NoCanonicalizeWhiteSpace;
Req.MatchFullLines = MatchFullLines;
Req.IgnoreCase = IgnoreCase;
if (VerboseVerbose)
Req.Verbose = true;
FileCheck FC(Req);
if (!FC.ValidateCheckPrefixes())
return 2;
Regex PrefixRE = FC.buildCheckPrefixRegex();
std::string REError;
if (!PrefixRE.isValid(REError)) {
errs() << "Unable to combine check-prefix strings into a prefix regular "
"expression! This is likely a bug in FileCheck's verification of "
"the check-prefix strings. Regular expression parsing failed "
"with the following error: "
<< REError << "\n";
return 2;
}
SourceMgr SM;
// Read the expected strings from the check file.
ErrorOr<std::unique_ptr<MemoryBuffer>> CheckFileOrErr =
MemoryBuffer::getFileOrSTDIN(CheckFilename, /*IsText=*/true);
if (std::error_code EC = CheckFileOrErr.getError()) {
errs() << "Could not open check file '" << CheckFilename
<< "': " << EC.message() << '\n';
return 2;
}
MemoryBuffer &CheckFile = *CheckFileOrErr.get();
SmallString<4096> CheckFileBuffer;
StringRef CheckFileText = FC.CanonicalizeFile(CheckFile, CheckFileBuffer);
unsigned CheckFileBufferID =
SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(
CheckFileText, CheckFile.getBufferIdentifier()),
SMLoc());
std::pair<unsigned, unsigned> ImpPatBufferIDRange;
if (FC.readCheckFile(SM, CheckFileText, PrefixRE, &ImpPatBufferIDRange))
return 2;
// Open the file to check and add it to SourceMgr.
ErrorOr<std::unique_ptr<MemoryBuffer>> InputFileOrErr =
MemoryBuffer::getFileOrSTDIN(InputFilename, /*IsText=*/true);
if (InputFilename == "-")
InputFilename = "<stdin>"; // Overwrite for improved diagnostic messages
if (std::error_code EC = InputFileOrErr.getError()) {
errs() << "Could not open input file '" << InputFilename
<< "': " << EC.message() << '\n';
return 2;
}
MemoryBuffer &InputFile = *InputFileOrErr.get();
if (InputFile.getBufferSize() == 0 && !AllowEmptyInput) {
errs() << "FileCheck error: '" << InputFilename << "' is empty.\n";
DumpCommandLine(argc, argv);
return 2;
}
SmallString<4096> InputFileBuffer;
StringRef InputFileText = FC.CanonicalizeFile(InputFile, InputFileBuffer);
SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(
InputFileText, InputFile.getBufferIdentifier()),
SMLoc());
std::vector<FileCheckDiag> Diags;
int ExitCode = FC.checkInput(SM, InputFileText,
DumpInput == DumpInputNever ? nullptr : &Diags)
? EXIT_SUCCESS
: 1;
if (DumpInput == DumpInputAlways ||
(ExitCode == 1 && DumpInput == DumpInputFail)) {
errs() << "\n"
<< "Input file: " << InputFilename << "\n"
<< "Check file: " << CheckFilename << "\n"
<< "\n"
<< "-dump-input=help explains the following input dump.\n"
<< "\n";
std::vector<InputAnnotation> Annotations;
unsigned LabelWidth;
BuildInputAnnotations(SM, CheckFileBufferID, ImpPatBufferIDRange, Diags,
Annotations, LabelWidth);
DumpAnnotatedInput(errs(), Req, DumpInputFilter, DumpInputContext,
InputFileText, Annotations, LabelWidth);
}
return ExitCode;
}

42
bin/triton-opt.cpp Normal file
View File

@ -0,0 +1,42 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Conversion/Passes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Support/MlirOptMain.h"
namespace mlir {
namespace test {
void registerTestAliasPass();
void registerTestAlignmentPass();
void registerTestAllocationPass();
void registerTestMembarPass();
} // namespace test
} // namespace mlir
int main(int argc, char **argv) {
mlir::registerAllPasses();
mlir::registerTritonPasses();
mlir::registerTritonGPUPasses();
mlir::test::registerTestAliasPass();
mlir::test::registerTestAlignmentPass();
mlir::test::registerTestAllocationPass();
mlir::test::registerTestMembarPass();
mlir::triton::registerConvertTritonToTritonGPUPass();
mlir::triton::registerConvertTritonGPUToLLVMPass();
// TODO: register Triton & TritonGPU passes
mlir::DialectRegistry registry;
registry.insert<mlir::triton::TritonDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect,
mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();
return mlir::asMainReturnCode(mlir::MlirOptMain(
argc, argv, "Triton (GPU) optimizer driver\n", registry));
}

131
bin/triton-translate.cpp Normal file
View File

@ -0,0 +1,131 @@
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/driver/llvm.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include <iostream>
namespace mlir {
namespace triton {
OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
MLIRContext &context) {
std::string errorMessage;
auto input = openInputFile(inputFilename, &errorMessage);
if (!input) {
llvm::errs() << errorMessage << "\n";
return nullptr;
}
mlir::DialectRegistry registry;
registry.insert<TritonDialect, triton::gpu::TritonGPUDialect,
mlir::math::MathDialect, arith::ArithmeticDialect,
StandardOpsDialect, scf::SCFDialect>();
context.appendDialectRegistry(registry);
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer)
-> OwningOpRef<ModuleOp> {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
context.loadAllAvailableDialects();
context.allowUnregisteredDialects();
OwningOpRef<ModuleOp> module(parseSourceFile(sourceMgr, &context));
if (!module) {
llvm::errs() << "Parse MLIR file failed.";
return nullptr;
}
return module;
};
auto module = processBuffer(std::move(input));
if (!module) {
return nullptr;
}
return module;
}
LogicalResult tritonTranslateMain(int argc, char **argv,
llvm::StringRef toolName) {
static llvm::cl::opt<std::string> inputFilename(
llvm::cl::Positional, llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
static llvm::cl::opt<std::string> outputFilename(
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
static llvm::cl::opt<std::string> targetKind(
"target", llvm::cl::desc("<translation target, options: llvmir/ptx>"),
llvm::cl::value_desc("target"), llvm::cl::init("llvmir"));
static llvm::cl::opt<int> SMArch("sm", llvm::cl::desc("sm arch"),
llvm::cl::init(80));
static llvm::cl::opt<int> ptxVersion(
"ptx-version", llvm::cl::desc("PTX version"), llvm::cl::init(10000));
llvm::InitLLVM y(argc, argv);
registerAsmPrinterCLOptions();
registerMLIRContextCLOptions();
llvm::cl::ParseCommandLineOptions(argc, argv, toolName);
mlir::MLIRContext context;
auto module = loadMLIRModule(inputFilename, context);
if (!module) {
return failure();
}
std::string errorMessage;
auto output = openOutputFile(outputFilename, &errorMessage);
if (!output) {
llvm::errs() << errorMessage << "\n";
return failure();
}
llvm::LLVMContext llvmContext;
auto llvmir =
translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue());
if (!llvmir) {
llvm::errs() << "Translate to LLVM IR failed";
}
if (targetKind == "llvmir")
llvm::outs() << *llvmir << '\n';
else if (targetKind == "ptx")
llvm::outs() << ::triton::driver::llir_to_ptx(
llvmir.get(), SMArch.getValue(), ptxVersion.getValue());
return success();
}
} // namespace triton
} // namespace mlir
int main(int argc, char **argv) {
return failed(mlir::triton::tritonTranslateMain(
argc, argv, "Triton Translate Testing Tool."));
}

1
deps/dlfcn-win32 vendored

Submodule deps/dlfcn-win32 deleted from 522c301ec3

View File

@ -45,7 +45,7 @@ def setup(app):
def wrapped(obj, **kwargs):
import triton
if isinstance(obj, triton.runtime.JITFunction):
if isinstance(obj, triton.code_gen.JITFunction):
obj = obj.fn
return old(obj)
@ -56,7 +56,7 @@ def setup(app):
def documenter(app, obj, parent):
import triton
if isinstance(obj, triton.runtime.JITFunction):
if isinstance(obj, triton.code_gen.JITFunction):
obj = obj.fn
return old_documenter(app, obj, parent)

View File

@ -34,13 +34,11 @@ You can install the Python package from source by running the following commands
.. code-block:: bash
git clone https://github.com/openai/triton.git;
cd triton;
git submodule update --init --recursive;
cd python;
cd triton/python;
pip install cmake; # build time dependency
pip install -e .
Note that, if llvm-11 is not present on your system and you are on linux, the setup.py script will download the official LLVM11 static libraries link against that. For windows users, LLVM must be installed and configured in PATH.
Note that, if llvm-11 is not present on your system, the setup.py script will download the official LLVM11 static libraries link against that.
You can then test your installation by running the unit tests:

View File

@ -168,7 +168,7 @@ Scheduling languages are, without a doubt, one of the most popular approaches fo
Limitations
++++++++++++
This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indice without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular.
This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indices without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular.
.. table::
:widths: 50 50

View File

@ -106,13 +106,9 @@ Atomic Ops
:nosignatures:
atomic_cas
atomic_xchg
atomic_add
atomic_max
atomic_min
atomic_and
atomic_or
atomic_xor
Comparison ops

1
include/CMakeLists.txt Normal file
View File

@ -0,0 +1 @@
add_subdirectory(triton)

View File

@ -0,0 +1,80 @@
#ifndef TRITON_ANALYSIS_ALIAS_H
#define TRITON_ANALYSIS_ALIAS_H
#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "llvm/ADT/DenseSet.h"
namespace mlir {
class AliasInfo {
public:
AliasInfo() = default;
AliasInfo(Value value) { insert(value); }
void insert(Value value) { allocs.insert(value); }
const DenseSet<Value> &getAllocs() const { return allocs; }
bool operator==(const AliasInfo &other) const {
return allocs == other.allocs;
}
/// The pessimistic value state of a value without alias
static AliasInfo getPessimisticValueState(MLIRContext *context) {
return AliasInfo();
}
static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); }
/// The union of both arguments
static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs);
private:
/// The set of allocated values that are aliased by this lattice.
/// For now, we only consider aliased value produced by the following
/// situations:
/// 1. values returned by scf.yield
/// 2. block arguments in scf.for
/// Example:
/// alloc v1 alloc v2
/// | |
/// |--------------| |------------|
/// scf.for v3 scf.for v4 scf.for v5
/// |
/// scf.yield v6
///
/// v1's alloc [v1]
/// v2's alloc [v2]
/// v3's alloc [v1]
/// v4's alloc [v1, v2]
/// v5's alloc [v2]
/// v6's alloc [v1]
///
/// Therefore, v1's liveness range is the union of v3, v4, and v6
/// v2's liveness range is the union of v4 and v5.
DenseSet<Value> allocs;
};
//===----------------------------------------------------------------------===//
// Shared Memory Alias Analysis
//===----------------------------------------------------------------------===//
class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis<AliasInfo> {
public:
using ForwardDataFlowAnalysis<AliasInfo>::ForwardDataFlowAnalysis;
/// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use.
/// Given two values, returns their aliasing behavior.
AliasResult alias(Value lhs, Value rhs);
/// Returns the modify-reference behavior of `op` on `location`.
ModRefResult getModRef(Operation *op, Value location);
/// Computes if the alloc set of the results are changed.
ChangeResult
visitOperation(Operation *op,
ArrayRef<LatticeElement<AliasInfo> *> operands) override;
};
} // namespace mlir
#endif // TRITON_ANALYSIS_ALIAS_H

View File

@ -0,0 +1,192 @@
#ifndef TRITON_ANALYSIS_ALLOCATION_H
#define TRITON_ANALYSIS_ALLOCATION_H
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/raw_ostream.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <atomic>
#include <limits>
namespace mlir {
namespace triton {
class AllocationAnalysis;
SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec);
} // namespace triton
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
/// A class that represents an interval, specified using a start and an end
/// values: [Start, End).
template <typename T> class Interval {
public:
Interval() {}
Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); }
T start() const { return Start; }
T end() const { return End; }
T size() const { return End - Start; }
bool contains(T Addr) const { return Start <= Addr && Addr < End; }
bool intersects(const Interval &R) const {
return Start < R.End && R.Start < End;
}
bool operator==(const Interval &R) const {
return Start == R.Start && End == R.End;
}
bool operator!=(const Interval &R) const { return !(*this == R); }
bool operator<(const Interval &R) const {
return std::make_pair(Start, End) < std::make_pair(R.Start, R.End);
}
private:
T Start = std::numeric_limits<T>::min();
T End = std::numeric_limits<T>::max();
};
class Allocation {
public:
/// A unique identifier for shared memory buffers
using BufferId = size_t;
using BufferIdSetT = DenseSet<BufferId>;
static constexpr BufferId InvalidBufferId =
std::numeric_limits<BufferId>::max();
/// Creates a new Allocation analysis that computes the shared memory
/// information for all associated shared memory values.
Allocation(Operation *operation) : operation(operation) { run(); }
/// Returns the operation this analysis was constructed from.
Operation *getOperation() const { return operation; }
/// Returns the offset of the given buffer in the shared memory.
size_t getOffset(BufferId bufferId) const {
return bufferSet.lookup(bufferId).offset;
}
/// Returns the size of the given buffer in the shared memory.
size_t getAllocatedSize(BufferId bufferId) const {
return bufferSet.lookup(bufferId).size;
}
/// Returns the buffer id of the given value.
/// This interface only returns the allocated buffer id.
/// If you want to get all the buffer ids that are associated with the given
/// value, including alias buffers, use getBufferIds.
BufferId getBufferId(Value value) const {
if (valueBuffer.count(value)) {
return valueBuffer.lookup(value)->id;
} else {
return InvalidBufferId;
}
}
/// Returns all the buffer ids of the given value, including alias buffers.
BufferIdSetT getBufferIds(Value value) const {
BufferIdSetT bufferIds;
auto allocBufferId = getBufferId(value);
if (allocBufferId != InvalidBufferId)
bufferIds.insert(allocBufferId);
for (auto *buffer : aliasBuffer.lookup(value)) {
if (buffer->id != InvalidBufferId)
bufferIds.insert(buffer->id);
}
return bufferIds;
}
/// Returns the scratch buffer id of the given value.
BufferId getBufferId(Operation *operation) const {
if (opScratch.count(operation)) {
return opScratch.lookup(operation)->id;
} else {
return InvalidBufferId;
}
}
/// Returns the size of total shared memory allocated
size_t getSharedMemorySize() const { return sharedMemorySize; }
bool isIntersected(BufferId lhsId, BufferId rhsId) const {
if (lhsId == InvalidBufferId || rhsId == InvalidBufferId)
return false;
auto lhsBuffer = bufferSet.lookup(lhsId);
auto rhsBuffer = bufferSet.lookup(rhsId);
return lhsBuffer.intersects(rhsBuffer);
}
private:
/// A class that represents a shared memory buffer
struct BufferT {
enum class BufferKind { Explicit, Scratch };
/// MT: thread-safe
inline static std::atomic<BufferId> nextId = 0;
BufferKind kind;
BufferId id;
size_t size;
size_t offset;
bool operator==(const BufferT &other) const { return id == other.id; }
bool operator<(const BufferT &other) const { return id < other.id; }
BufferT() : BufferT(BufferKind::Explicit) {}
BufferT(BufferKind kind) : BufferT(kind, 0, 0) {}
BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
BufferT(BufferKind kind, size_t size, size_t offset)
: kind(kind), id(nextId++), size(size), offset(offset) {}
bool intersects(const BufferT &other) const {
return Interval<size_t>(offset, offset + size)
.intersects(
Interval<size_t>(other.offset, other.offset + other.size));
}
};
/// Op -> Scratch Buffer
using OpScratchMapT = DenseMap<Operation *, BufferT *>;
/// Value -> Explicit Buffer
using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
/// Value -> Alias Buffer
using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
/// BufferId -> Buffer
using BufferSetT = DenseMap<BufferId, BufferT>;
/// Runs allocation analysis on the given top-level operation.
void run();
private:
template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
void addBuffer(KeyType &key, Args &&...args) {
auto buffer = BufferT(Kind, std::forward<Args>(args)...);
bufferSet[buffer.id] = std::move(buffer);
if constexpr (Kind == BufferT::BufferKind::Explicit) {
valueBuffer[key] = &bufferSet[buffer.id];
} else {
opScratch[key] = &bufferSet[buffer.id];
}
}
void addAlias(Value value, Value alloc) {
aliasBuffer[value].insert(valueBuffer[alloc]);
}
private:
Operation *operation;
OpScratchMapT opScratch;
ValueBufferMapT valueBuffer;
AliasBufferMapT aliasBuffer;
BufferSetT bufferSet;
size_t sharedMemorySize = 0;
friend class triton::AllocationAnalysis;
};
} // namespace mlir
#endif // TRITON_ANALYSIS_ALLOCATION_H

View File

@ -0,0 +1,144 @@
#ifndef TRITON_ANALYSIS_AXISINFO_H
#define TRITON_ANALYSIS_AXISINFO_H
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
namespace mlir {
//===----------------------------------------------------------------------===//
// AxisInfo
//===----------------------------------------------------------------------===//
/// This lattice value represents known information on the axes of a lattice.
/// Axis information is represented by a std::map<int, int>
class AxisInfo {
public:
typedef SmallVector<int, 4> DimVectorT;
public:
// Default constructor
AxisInfo() : AxisInfo({}, {}, {}) {}
// Construct contiguity info with known contiguity
AxisInfo(DimVectorT knownContiguity, DimVectorT knownDivisibility,
DimVectorT knownConstancy)
: contiguity(knownContiguity), divisibility(knownDivisibility),
constancy(knownConstancy), rank(contiguity.size()) {
assert(knownDivisibility.size() == (size_t)rank);
assert(knownConstancy.size() == (size_t)rank);
}
// Accessors
int getContiguity(size_t d) const { return contiguity[d]; }
const DimVectorT &getContiguity() const { return contiguity; }
int getDivisibility(size_t d) const { return divisibility[d]; }
const DimVectorT &getDivisibility() const { return divisibility; }
int getConstancy(size_t d) const { return constancy[d]; }
const DimVectorT &getConstancy() const { return constancy; }
int getRank() const { return rank; }
// Comparison
bool operator==(const AxisInfo &other) const {
return (contiguity == other.contiguity) &&
(divisibility == other.divisibility) &&
(constancy == other.constancy);
}
/// The pessimistic value state of the contiguity is unknown.
static AxisInfo getPessimisticValueState(MLIRContext *context) {
return AxisInfo();
}
static AxisInfo getPessimisticValueState(Value value);
// The gcd of both arguments for each dimension
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
private:
/// The _contiguity_ information maps the `d`-th
/// dimension to the length of the shortest
/// sequence of contiguous integers along it
/// For example:
/// [10, 11, 12, 13, 18, 19, 20, 21]
/// [20, 21, 22, 23, 28, 29, 30, 31]
/// Would have contiguity [1, 4].
/// and
/// [12, 16, 20, 24]
/// [13, 17, 21, 25]
/// [14, 18, 22, 26]
/// [15, 19, 23, 27]
/// [18, 22, 26, 30]
/// [19, 23, 27, 31]
/// Would have contiguity [2, 1].
DimVectorT contiguity;
/// The _divisibility_ information maps the `d`-th
/// dimension to the largest power-of-two that
/// divides the first element of all the values along it
/// For example:
/// [10, 11, 12, 13, 18, 19, 20, 21]
/// [20, 21, 22, 23, 28, 29, 30, 31]
// would have divisibility [1, 2]
// and
/// [12, 16, 20, 24]
/// [13, 17, 21, 25]
/// [14, 18, 22, 26]
/// [15, 19, 23, 27]
// would have divisibility [4, 1]
DimVectorT divisibility;
/// The _constancy_ information maps the `d`-th
/// dimension to the length of the shortest
/// sequence of constant integer along it. This is
/// particularly useful to infer the contiguity
/// of operations (e.g., add) involving a constant
/// For example
/// [8, 8, 8, 8, 12, 12, 12, 12]
/// [16, 16, 16, 16, 20, 20, 20, 20]
/// would have constancy [1, 4]
DimVectorT constancy;
// number of dimensions of the lattice
int rank;
};
class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
private:
static const int maxPow2Divisor = 65536;
int highestPowOf2Divisor(int n) {
if (n == 0)
return maxPow2Divisor;
return (n & (~(n - 1)));
}
AxisInfo visitBinaryOp(
Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy);
public:
using ForwardDataFlowAnalysis<AxisInfo>::ForwardDataFlowAnalysis;
ChangeResult
visitOperation(Operation *op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
unsigned getPtrVectorSize(Value ptr);
unsigned getPtrAlignment(Value ptr);
unsigned getMaskAlignment(Value mask);
};
} // namespace mlir
#endif

View File

@ -0,0 +1,119 @@
#ifndef TRITON_ANALYSIS_MEMBAR_H
#define TRITON_ANALYSIS_MEMBAR_H
#include "Allocation.h"
#include "llvm/ADT/SmallPtrSet.h"
namespace mlir {
class OpBuilder;
//===----------------------------------------------------------------------===//
// Shared Memory Barrier Analysis
//===----------------------------------------------------------------------===//
class MembarAnalysis {
public:
/// Creates a new Membar analysis that generates the shared memory barrier
/// in the following circumstances:
/// - RAW: If a shared memory write is followed by a shared memory read, and
/// their addresses are intersected, a barrier is inserted.
/// - WAR: If a shared memory read is followed by a shared memory read, and
/// their addresses are intersected, a barrier is inserted.
/// The following circumstances do not require a barrier:
/// - WAW: not possible because overlapped memory allocation is not allowed.
/// - RAR: no write is performed.
/// Temporary storage of operations such as Reduce are considered as both
/// a shared memory read. If the temporary storage is written but not read,
/// it is considered as the problem of the operation itself but not the membar
/// analysis.
/// The following circumstances are not considered yet:
/// - Double buffers
/// - N buffers
MembarAnalysis(Allocation *allocation) : allocation(allocation) {}
/// Runs the membar analysis to the given operation, inserts a barrier if
/// necessary.
void run();
private:
struct RegionInfo {
using BufferIdSetT = Allocation::BufferIdSetT;
BufferIdSetT syncReadBuffers;
BufferIdSetT syncWriteBuffers;
RegionInfo() = default;
RegionInfo(const BufferIdSetT &syncReadBuffers,
const BufferIdSetT &syncWriteBuffers)
: syncReadBuffers(syncReadBuffers), syncWriteBuffers(syncWriteBuffers) {
}
/// Unions two RegionInfo objects.
void join(const RegionInfo &other) {
syncReadBuffers.insert(other.syncReadBuffers.begin(),
other.syncReadBuffers.end());
syncWriteBuffers.insert(other.syncWriteBuffers.begin(),
other.syncWriteBuffers.end());
}
/// Returns true if buffers in two RegionInfo objects are intersected.
bool isIntersected(const RegionInfo &other, Allocation *allocation) const {
return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers,
allocation) ||
/*WAR*/
isIntersected(syncReadBuffers, other.syncWriteBuffers,
allocation) ||
/*WAW*/
isIntersected(syncWriteBuffers, other.syncWriteBuffers,
allocation);
}
/// Clears the buffers because a barrier is inserted.
void sync() {
syncReadBuffers.clear();
syncWriteBuffers.clear();
}
private:
/// Returns true if buffers in two sets are intersected.
bool isIntersected(const BufferIdSetT &lhs, const BufferIdSetT &rhs,
Allocation *allocation) const {
return std::any_of(lhs.begin(), lhs.end(), [&](auto lhsId) {
return std::any_of(rhs.begin(), rhs.end(), [&](auto rhsId) {
return allocation->isIntersected(lhsId, rhsId);
});
});
}
};
/// Applies the barrier analysis based on the SCF dialect, in which each
/// region has a single basic block only.
/// Example:
/// region1
/// op1
/// op2 (scf.if)
/// region2
/// op3
/// op4
/// region3
/// op5
/// op6
/// op7
/// region2 and region3 started with the information of region1.
/// Each region is analyzed separately and keeps their own copy of the
/// information. At op7, we union the information of the region2 and region3
/// and update the information of region1.
void dfsOperation(Operation *operation, RegionInfo *blockInfo,
OpBuilder *builder);
/// Updates the RegionInfo operation based on the operation.
void transfer(Operation *operation, RegionInfo *blockInfo,
OpBuilder *builder);
private:
Allocation *allocation;
};
} // namespace mlir
#endif // TRITON_ANALYSIS_MEMBAR_H

View File

@ -0,0 +1,82 @@
#ifndef TRITON_ANALYSIS_UTILITY_H
#define TRITON_ANALYSIS_UTILITY_H
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <numeric>
#include <string>
namespace mlir {
class ReduceOpHelper {
public:
explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
srcTy = op.operand().getType().cast<RankedTensorType>();
}
ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
Attribute getSrcLayout() { return srcTy.getEncoding(); }
bool isFastReduction();
unsigned getInterWarpSize();
unsigned getIntraWarpSize();
unsigned getThreadsReductionAxis();
SmallVector<unsigned> getScratchConfigBasic();
SmallVector<SmallVector<unsigned>> getScratchConfigsFast();
unsigned getScratchSizeInBytes();
private:
triton::ReduceOp op;
RankedTensorType srcTy{};
};
bool isSharedEncoding(Value value);
bool maybeSharedAllocationOp(Operation *op);
bool maybeAliasOp(Operation *op);
bool supportMMA(triton::DotOp op, int version);
bool supportMMA(Value value, int version);
Type getElementType(Value value);
std::string getValueOperandName(Value value, AsmState &state);
template <typename T_OUT, typename T_IN>
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
SmallVector<T_OUT> out;
for (const T_IN &i : in)
out.push_back(T_OUT(i));
return out;
}
template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}
template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }
// output[i] = input[order[i]]
template <typename T, typename RES_T = T>
SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
size_t rank = order.size();
assert(input.size() == rank);
SmallVector<RES_T> result(rank);
for (auto it : llvm::enumerate(order)) {
result[it.index()] = input[it.value()];
}
return result;
}
} // namespace mlir
#endif // TRITON_ANALYSIS_UTILITY_H

View File

@ -0,0 +1,2 @@
add_subdirectory(Conversion)
add_subdirectory(Dialect)

View File

@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(TritonConversionPassIncGen)

View File

@ -0,0 +1,40 @@
#ifndef TRITON_CONVERSION_MLIR_TYPES_H_
#define TRITON_CONVERSION_MLIR_TYPES_H_
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
// This file redefines some common MLIR types for easy usage.
namespace mlir {
namespace triton {
namespace type {
// Integer types
// TODO(Superjomn): may change `static` into better implementations
static Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); }
static Type i16Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 16); }
static Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); }
static Type u32Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 32, IntegerType::Unsigned);
}
static Type u1Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 1, IntegerType::Unsigned);
}
// Float types
static Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
static Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
static Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
static Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
static bool isFloat(Type type) {
return type.isF32() || type.isF64() || type.isF16() || type.isF128();
}
static bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
} // namespace type
} // namespace triton
} // namespace mlir
#endif // TRITON_CONVERSION_MLIR_TYPES_H_

View File

@ -0,0 +1,17 @@
#ifndef TRITON_CONVERSION_PASSES_H
#define TRITON_CONVERSION_PASSES_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
namespace mlir {
namespace triton {
#define GEN_PASS_REGISTRATION
#include "triton/Conversion/Passes.h.inc"
} // namespace triton
} // namespace mlir
#endif

View File

@ -0,0 +1,54 @@
#ifndef TRITON_CONVERSION_PASSES
#define TRITON_CONVERSION_PASSES
include "mlir/Pass/PassBase.td"
def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> {
let summary = "Convert Triton to TritonGPU";
let description = [{
}];
let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()";
let dependentDialects = ["mlir::arith::ArithmeticDialect",
"mlir::math::MathDialect",
"mlir::StandardOpsDialect",
// TODO: Does this pass depend on SCF?
"mlir::scf::SCFDialect",
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect"];
let options = [
Option<"numWarps", "num-warps",
"int32_t", /*default*/"4",
"number of warps">
];
}
def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert TritonGPU to LLVM";
let description = [{
}];
let constructor = "mlir::triton::createConvertTritonGPUToLLVMPass()";
let dependentDialects = ["mlir::arith::ArithmeticDialect",
"mlir::math::MathDialect",
"mlir::gpu::GPUDialect",
"mlir::scf::SCFDialect",
"mlir::LLVM::LLVMDialect",
"mlir::tensor::TensorDialect",
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect",
"mlir::NVVM::NVVMDialect",
"mlir::StandardOpsDialect"];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
#endif

View File

@ -0,0 +1,326 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_ASM_FORMAT_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_ASM_FORMAT_H
#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
#include <string>
namespace mlir {
class ConversionPatternRewriter;
class Location;
namespace triton {
using llvm::StringRef;
struct PTXInstr;
struct PTXInstrCommon;
struct PTXInstrExecution;
// PTXBuilder helps to manage a PTX asm program consists of one or multiple
// instructions.
//
// A helper for building an ASM program, the objective of PTXBuilder is to give
// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
// Currently, several factors are introduced to reduce the need for mixing
// string and C++ if-else code.
//
// Usage:
// To build: @$3 asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k),
// "b"(p));
//
// PTXBuilder builder;
// auto& add = builder.create<>();
// add.predicate(pVal).o("lo").o("u32"); // add any suffix
// // predicate here binds %0 to pVal, pVal is a mlir::Value
//
// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal
// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal
// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal
// add(iOpr, jOpr, kOpr).predicate(predVal); // set operands and predicate
//
// To get the asm code:
// builder.dump()
//
// To get all the mlir::Value used in the PTX code,
//
// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal}
//
// To get the string containing all the constraints with "," separated,
// builder.getConstraints() // get "=r,r,k"
//
// PTXBuilder can build a PTX asm with multiple instructions, sample code:
//
// PTXBuilder builder;
// auto& mov = builder.create("mov");
// auto& cp = builder.create("cp");
// mov(...);
// cp(...);
// This will get a PTX code with two instructions.
//
// Similar to a C function, a declared PTXInstr instance can be launched
// multiple times with different operands, e.g.
//
// auto& mov = builder.create("mov");
// mov(... some operands ...);
// mov(... some different operands ...);
//
// Finally, we will get a PTX code with two mov instructions.
//
// There are several derived instruction type for typical instructions, for
// example, the PtxIOInstr for ld and st instructions.
struct PTXBuilder {
struct Operand {
std::string constraint;
Value value;
int idx{-1};
llvm::SmallVector<Operand *> list;
std::function<std::string(int idx)> repr;
// for list
Operand() = default;
Operand(const Operation &) = delete;
Operand(Value value, StringRef constraint)
: constraint(constraint), value(value) {}
bool isList() const { return !value && constraint.empty(); }
Operand *listAppend(Operand *arg) {
list.push_back(arg);
return this;
}
Operand *listGet(size_t nth) const {
assert(nth < list.size());
return list[nth];
}
std::string dump() const;
};
template <typename INSTR = PTXInstr, typename... Args>
INSTR *create(Args &&...args) {
instrs.emplace_back(std::make_unique<INSTR>(this, args...));
return static_cast<INSTR *>(instrs.back().get());
}
// Create a list of operands.
Operand *newListOperand() { return newOperand(); }
Operand *newListOperand(ArrayRef<std::pair<mlir::Value, std::string>> items) {
auto *list = newOperand();
for (auto &item : items) {
list->listAppend(newOperand(item.first, item.second));
}
return list;
}
Operand *newListOperand(unsigned count, mlir::Value val,
const std::string &constraint) {
auto *list = newOperand();
for (unsigned i = 0; i < count; ++i) {
list->listAppend(newOperand(val, constraint));
}
return list;
}
Operand *newListOperand(unsigned count, const std::string &constraint) {
auto *list = newOperand();
for (unsigned i = 0; i < count; ++i) {
list->listAppend(newOperand(constraint));
}
return list;
}
// Create a new operand. It will not add to operand list.
// @value: the MLIR value bind to this operand.
// @constraint: ASM operand constraint, .e.g. "=r"
// @formatter: extra format to represent this operand in ASM code, default is
// "%{0}".format(operand.idx).
Operand *newOperand(mlir::Value value, StringRef constraint,
std::function<std::string(int idx)> formatter = nullptr);
// Create a new operand which is written to, that is, the constraint starts
// with "=", e.g. "=r".
Operand *newOperand(StringRef constraint);
// Create a constant integer operand.
Operand *newConstantOperand(int64_t v);
// Create a constant operand with explicit code specified.
Operand *newConstantOperand(const std::string &v);
Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0);
llvm::SmallVector<Operand *, 4> getAllArgs() const;
llvm::SmallVector<Value, 4> getAllMLIRArgs() const;
std::string getConstraints() const;
std::string dump() const;
mlir::Value launch(ConversionPatternRewriter &rewriter, Location loc,
Type resTy, bool hasSideEffect = true,
bool isAlignStack = false,
ArrayRef<Attribute> attrs = {}) const;
private:
Operand *newOperand() {
argArchive.emplace_back(std::make_unique<Operand>());
return argArchive.back().get();
}
// Make the operands in argArchive follow the provided \param order.
void reorderArgArchive(ArrayRef<Operand *> order) {
assert(order.size() == argArchive.size());
// The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but
// it does necessary when onlyAttachMLIRArgs is true for the $0, $1... are
// determined by PTX code snippet passed from external.
sort(argArchive.begin(), argArchive.end(),
[&](std::unique_ptr<Operand> &a, std::unique_ptr<Operand> &b) {
auto ida = std::find(order.begin(), order.end(), a.get());
auto idb = std::find(order.begin(), order.end(), b.get());
assert(ida != order.end());
assert(idb != order.end());
return ida < idb;
});
}
friend struct PTXInstr;
friend struct PTXInstrCommon;
protected:
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
llvm::SmallVector<std::unique_ptr<PTXInstrCommon>, 2> instrs;
llvm::SmallVector<std::unique_ptr<PTXInstrExecution>, 4> executions;
int oprCounter{};
};
// PTX instruction common interface.
// Put the generic logic for all the instructions here.
struct PTXInstrCommon {
explicit PTXInstrCommon(PTXBuilder *builder) : builder(builder) {}
using Operand = PTXBuilder::Operand;
// clang-format off
PTXInstrExecution& operator()() { return call({}); }
PTXInstrExecution& operator()(Operand* a) { return call({a}); }
PTXInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d) { return call({a, b, c, d}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { return call({a, b, c, d, e}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { return call({a, b, c, d, e, f}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { return call({a, b, c, d, e, f, g}); }
// clang-format on
// Set operands of this instruction.
PTXInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs = false);
protected:
// "Call" the instruction with operands.
// \param oprs The operands of this instruction.
// \param onlyAttachMLIRArgs Indicate that it simply attach the MLIR Arguments
// to the inline Asm without generating the operand ids(such as $0, $1) in PTX
// code.
PTXInstrExecution &call(llvm::ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs = false);
PTXBuilder *builder{};
llvm::SmallVector<std::string, 4> instrParts;
friend struct PTXInstrExecution;
};
template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
using Operand = PTXBuilder::Operand;
explicit PTXInstrBase(PTXBuilder *builder, const std::string &name)
: PTXInstrCommon(builder) {
o(name);
}
// Append a suffix to the instruction.
// e.g. PTXInstr("add").o("s32") get a add.s32.
// A predicate is used to tell whether to apply the suffix, so that no if-else
// code needed. e.g. `PTXInstr("add").o("s32", isS32).o("u32", !isS32);` will
// get a `add.s32` if isS32 is true.
ConcreteT &o(const std::string &suffix, bool predicate = true) {
if (predicate)
instrParts.push_back(suffix);
return *static_cast<ConcreteT *>(this);
}
};
struct PTXInstr : public PTXInstrBase<PTXInstr> {
using PTXInstrBase<PTXInstr>::PTXInstrBase;
// Append a ".global" to the instruction.
PTXInstr &global();
// Append a ".shared" to the instruction.
PTXInstr &shared();
// Append a ".v[0-9]+" to the instruction
PTXInstr &v(int vecWidth, bool predicate = true);
// Append a".b[0-9]+" to the instruction
PTXInstr &b(int width);
};
// Record the operands and context for "launching" a PtxInstr.
struct PTXInstrExecution {
using Operand = PTXBuilder::Operand;
llvm::SmallVector<Operand *> argsInOrder;
PTXInstrExecution() = default;
explicit PTXInstrExecution(PTXInstrCommon *instr,
llvm::ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs)
: argsInOrder(oprs.begin(), oprs.end()), instr(instr),
onlyAttachMLIRArgs(onlyAttachMLIRArgs) {}
// Prefix a predicate to the instruction.
PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
pred = instr->builder->newOperand(value, constraint);
return *this;
}
// Prefix a !predicate to the instruction.
PTXInstrExecution &predicateNot(mlir::Value value, StringRef constraint) {
pred = instr->builder->newOperand(value, constraint);
pred->repr = [](int idx) { return "@!$" + std::to_string(idx); };
return *this;
}
std::string dump() const;
SmallVector<Operand *> getArgList() const;
PTXInstrCommon *instr{};
Operand *pred{};
bool onlyAttachMLIRArgs{};
};
/// ====== Some instruction wrappers ======
// We add the wrappers to make the usage more intuitive by avoiding mixing the
// PTX code with some trivial C++ code.
struct PTXCpAsyncLoadInstr : PTXInstrBase<PTXCpAsyncLoadInstr> {
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
triton::CacheModifier modifier)
: PTXInstrBase(builder, "cp.async") {
o(triton::stringifyCacheModifier(modifier).str());
o("shared");
o("global");
}
};
} // namespace triton
} // namespace mlir
#endif

View File

@ -0,0 +1,22 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_PASS_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_PASS_H
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonGPUToLLVMPass(int computeCapability = 80);
} // namespace triton
} // namespace mlir
#endif

View File

@ -0,0 +1,25 @@
#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H
#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
namespace triton {
constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
// Create the pass with numWarps passed from cl::opt.
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
// Create the pass with numWarps set explicitly.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUPass(int numWarps);
} // namespace triton
} // namespace mlir
#endif

View File

@ -0,0 +1,2 @@
add_subdirectory(Triton)
add_subdirectory(TritonGPU)

View File

@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@ -0,0 +1,19 @@
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
set(LLVM_TARGET_DEFINITIONS TritonDialect.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
mlir_tablegen(Types.h.inc -gen-typedef-decls)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td)
mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
add_public_tablegen_target(TritonTableGen)

View File

@ -0,0 +1,48 @@
#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
#define TRITON_DIALECT_TRITON_IR_DIALECT_H_
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
#include "triton/Dialect/Triton/IR/Traits.h"
#include "triton/Dialect/Triton/IR/Types.h"
#define GET_OP_CLASSES
#include "triton/Dialect/Triton/IR/Ops.h.inc"
namespace mlir {
namespace triton {
class DialectInferLayoutInterface
: public DialectInterface::Base<DialectInferLayoutInterface> {
public:
DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}
virtual LogicalResult
inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
Attribute &resultEncoding) const = 0;
virtual LogicalResult
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
Attribute &resultEncoding,
Optional<Location> location) const = 0;
// Note: this function only verify operand encoding but doesn't infer result
// encoding
virtual LogicalResult
inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
Attribute retEncoding,
Optional<Location> location) const = 0;
};
} // namespace triton
} // namespace mlir
#endif // TRITON_IR_DIALECT_H_

View File

@ -0,0 +1,9 @@
#ifndef TRITON_IR_INTERFACES_H_
#define TRITON_IR_INTERFACES_H_
#include "mlir/IR/OpDefinition.h"
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
#endif // TRITON_IR_TYPES_H_

View File

@ -0,0 +1,60 @@
#ifndef TRITON_IR_TRAITS_H_
#define TRITON_IR_TRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include <iostream>
namespace mlir {
namespace OpTrait {
// These functions are out-of-line implementations of the methods in the
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
LogicalResult verifySameOperandsAndResultEncoding(Operation *op);
LogicalResult verifySameOperandsEncoding(Operation *op);
// The rationale for this trait is to prevent users from creating programs
// that would have catastrophic register pressure and cause the compiler to
// hang.
// Since H100 has 256KB registers, we should allow users to create tensors
// of size up to 256K elements. It will spill for datatypes wider than 1B,
// but we probably should limit number of elements (rather than bytes) to
// keep specs simple
int constexpr maxTensorNumElements = 1048576;
LogicalResult verifyTensorSize(Operation *op);
} // namespace impl
template <class ConcreteType>
class TensorSizeTrait : public TraitBase<ConcreteType, TensorSizeTrait> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyTensorSize(op);
}
};
template <typename ConcreteType>
class SameOperandsAndResultEncoding
: public TraitBase<ConcreteType, SameOperandsAndResultEncoding> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifySameOperandsAndResultEncoding(op);
}
};
template <typename ConcreteType>
class SameOperandsEncoding
: public TraitBase<ConcreteType, SameOperandsEncoding> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifySameOperandsEncoding(op);
}
};
} // namespace OpTrait
} // namespace mlir
#endif

View File

@ -0,0 +1,68 @@
#ifndef TRITON_ATTR_DEFS
#define TRITON_ATTR_DEFS
include "mlir/IR/EnumAttr.td"
// Attrs for LoadOp
def TT_CacheModifierAttr : I32EnumAttr<
"CacheModifier", "",
[
I32EnumAttrCase<"NONE", 1, "none">,
I32EnumAttrCase<"CA", 2, "ca">,
I32EnumAttrCase<"CG", 3, "cg">,
]> {
let cppNamespace = "::mlir::triton";
}
def TT_EvictionPolicyAttr : I32EnumAttr<
"EvictionPolicy", "",
[
I32EnumAttrCase<"NORMAL", 1, "evict_normal">,
I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">,
I32EnumAttrCase<"EVICT_LAST", 3, "evict_last">
]> {
let cppNamespace = "::mlir::triton";
}
// reduction
def TT_RedOpAttr : I32EnumAttr<
/*name*/"RedOp", /*summary*/"",
/*case*/
[
I32EnumAttrCase</*sym*/"ADD", 1, /*str*/"add">,
I32EnumAttrCase<"FADD", 2, "fadd">,
I32EnumAttrCase<"MIN", 3, "min">,
I32EnumAttrCase<"MAX", 4, "max">,
I32EnumAttrCase<"UMIN", 5, "umin">,
I32EnumAttrCase<"UMAX", 6, "umax">,
I32EnumAttrCase<"ARGMIN", 7, "argmin">,
I32EnumAttrCase<"ARGMAX", 8, "argmax">,
I32EnumAttrCase<"ARGUMIN", 9, "argumin">,
I32EnumAttrCase<"ARGUMAX", 10, "argumax">,
I32EnumAttrCase<"FMIN", 11, "fmin">,
I32EnumAttrCase<"FMAX", 12, "fmax">,
I32EnumAttrCase<"ARGFMIN", 13, "argfmin">,
I32EnumAttrCase<"ARGFMAX", 14, "argfmax">,
I32EnumAttrCase<"XOR", 15, "xor">
]> {
let cppNamespace = "::mlir::triton";
}
// atomic
def TT_AtomicRMWAttr : I32EnumAttr<
"RMWOp", "",
[
I32EnumAttrCase<"AND", 1, "and">,
I32EnumAttrCase<"OR", 2, "or">,
I32EnumAttrCase<"XOR", 3, "xor">,
I32EnumAttrCase<"ADD", 4, "add">,
I32EnumAttrCase<"FADD", 5, "fadd">,
I32EnumAttrCase<"MAX", 6, "max">,
I32EnumAttrCase<"MIN", 7, "min">,
I32EnumAttrCase<"UMAX", 8, "umax">,
I32EnumAttrCase<"UMIN", 9, "umin">,
I32EnumAttrCase<"XCHG", 10, "exch">
]> {
let cppNamespace = "::mlir::triton";
}
#endif

View File

@ -0,0 +1,46 @@
#ifndef TRITON_DIALECT
#define TRITON_DIALECT
include "mlir/IR/OpBase.td"
def Triton_Dialect : Dialect {
let name = "tt";
let cppNamespace = "::mlir::triton";
let summary = "The Triton IR in MLIR";
let description = [{
Triton Dialect.
Dependent Dialects:
* Arithmetic:
* addf, addi, andi, cmpf, cmpi, divf, fptosi, ...
* Math:
* exp, sin, cos, log, ...
* StructuredControlFlow:
* ForOp, IfOp, WhileOp, YieldOp, ConditionOp
}];
let dependentDialects = [
"arith::ArithmeticDialect",
"math::MathDialect",
"StandardOpsDialect",
"scf::SCFDialect",
// Since LLVM 15
// "cf::ControlFlowDialect",
// "func::FuncDialect"
];
let extraClassDeclaration = [{
void registerTypes();
}];
let hasConstantMaterializer = 1;
}
include "triton/Dialect/Triton/IR/TritonTypes.td"
#endif // TRITON_DIALECT

View File

@ -0,0 +1,11 @@
#ifndef TRITON_INTERFACES
#define TRITON_INTERFACES
include "mlir/IR/OpBase.td"
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
#endif // TRITON_INTERFACES

View File

@ -0,0 +1,423 @@
#ifndef TRITON_OPS
#define TRITON_OPS
include "triton/Dialect/Triton/IR/TritonDialect.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
//
// Op Base
//
class TT_Op<string mnemonic, list<Trait> traits = []> :
Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])> {
}
//
// CastOps
//
// Use cast ops in arith:
// bitcast
// fptoui, fptosi, uitofp, sitofp,
// extf, tructf,
// extui, extsi, tructi
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast int64 to pointer";
let arguments = (ins TT_I64Like:$from);
let results = (outs TT_PtrLike:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast pointer to int64";
let arguments = (ins TT_PtrLike:$from);
let results = (outs TT_I64Like:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}
// arith.bitcast doesn't support pointers
def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast between types of the same bitwidth";
let arguments = (ins TT_Type:$from);
let results = (outs TT_Type:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
// TODO: Add verifier
}
def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "Floating point casting for custom types";
let description = [{
Floating point casting for custom types (F8).
F8 <-> FP16, BF16, FP32, FP64
}];
let arguments = (ins TT_FloatLike:$from);
let results = (outs TT_FloatLike:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
// TODO: We need a verifier here.
}
//
// Pointer Arith Ops
//
def TT_AddPtrOp : TT_Op<"addptr",
[NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
TypesMatchWith<"result type matches ptr type",
"result", "ptr", "$_self">]> {
let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset);
let results = (outs TT_PtrLike:$result);
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)";
}
//
// Load/Store Ops
//
def TT_LoadOp : TT_Op<"load",
[SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
AttrSizedOperandSegments,
MemoryEffects<[MemRead]>,
TypesMatchWith<"infer ptr type from result type",
"result", "ptr", "getPointerTypeSameShape($_self)">,
TypesMatchWith<"infer mask type from result type or none",
"result", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
TypesMatchWith<"infer other type from result type or none",
"result", "other", "$_self",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "load";
let arguments = (ins TT_PtrLike:$ptr, Optional<TT_BoolLike>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile);
let results = (outs TT_Type:$result);
let builders = [
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
];
// let assemblyFormat = "operands attr-dict `:` type($result)";
let parser = [{ return mlir::triton::parseLoadOp(parser, result); }];
let printer = [{ return mlir::triton::printLoadOp(p, *this); }];
let hasCanonicalizer = 1;
}
def TT_StoreOp : TT_Op<"store",
[SameOperandsShape,
SameOperandsEncoding,
MemoryEffects<[MemWrite]>,
TypesMatchWith<"infer ptr type from value type",
"value", "ptr",
"getPointerTypeSameShape($_self)">,
TypesMatchWith<"infer mask type from value type",
"value", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "store";
let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask);
let builders = [
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
];
// let assemblyFormat = "operands attr-dict `:` type($value)";
let parser = [{ return mlir::triton::parseStoreOp(parser, result); }];
let printer = [{ return mlir::triton::printStoreOp(p, *this); }];
let hasCanonicalizer = 1;
}
//
// Atomic Op
//
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
MemoryEffects<[MemRead]>,
MemoryEffects<[MemWrite]>,
TypesMatchWith<"infer ptr type from value type",
"val", "ptr",
"getPointerTypeSameShape($_self)">,
TypesMatchWith<"infer mask type from value type",
"val", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "atomic rmw";
let description = [{
load data at $ptr, do $rmw_op with $val, and store result to $ptr.
return old value at $ptr
}];
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr,
TT_Type:$val, Optional<TT_BoolLike>:$mask);
let results = (outs TT_Type:$result);
}
def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>,
MemoryEffects<[MemWrite]>,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "atomic cas";
let description = [{
compare $cmp with data $old at location $ptr,
if $old == $cmp, store $val to $ptr,
else store $old to $ptr,
return $old
}];
let arguments = (ins TT_Ptr:$ptr, TT_Type:$cmp, TT_Type:$val);
let results = (outs TT_Type:$result);
}
//
// Shape Manipulation Ops
//
def TT_SplatOp : TT_Op<"splat", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "splat";
let arguments = (ins TT_Type:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
}
def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
SameOperandsAndResultElementType]> {
let summary = "expand_dims";
let arguments = (ins TT_Tensor:$src, I32Attr:$axis);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TT_ViewOp : TT_Op<"view", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "view";
let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "broadcast. No left-padding as of now.";
let arguments = (ins TT_Type:$src);
let results = (outs TT_Type:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
}
def TT_CatOp : TT_Op<"cat", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "concatenate 2 tensors";
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
}
def TT_TransOp : TT_Op<"trans", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "transpose a tensor";
let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
//
// SPMD Ops
//
def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoSideEffect]> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoSideEffect]> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
//
// Dot Op
//
def TT_DotOp : TT_Op<"dot", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "dot";
let description = [{
$d = matrix_multiply($a, $b) + $c
}];
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
let results = (outs TT_FpIntTensor:$d);
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
}
//
// Reduce Op
//
def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "reduce";
let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
let results = (outs TT_Type:$result);
let builders = [
OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
];
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
let extraClassDeclaration = [{
// This member function is marked static because we need to call it before the ReduceOp
// is constructed, see the implementation of create_reduce in triton.cc.
static bool withIndex(mlir::triton::RedOp redOp);
}];
}
//
// External elementwise op
//
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
SameVariadicOperandSize]> {
let summary = "ext_elemwise";
let description = [{
call an external function $symbol implemented in $libpath/$libname with $args
return $libpath/$libname:$symbol($args...)
}];
let arguments = (ins Variadic<TT_Type>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol);
let results = (outs TT_Type:$result);
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
}
//
// Make Range Op
//
// TODO: should have ConstantLike as Trait
def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
let summary = "make range";
let description = [{
Returns an 1D int32 tensor.
Values span from $start to $end (exclusive), with step = 1
}];
let arguments = (ins I32Attr:$start, I32Attr:$end);
let results = (outs TT_IntTensor:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
//
// Make PrintfOp
//
def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
Arguments<(ins StrAttr:$prefix,
Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
let summary = "Device-side printf, as in CUDA for debugging";
let description = [{
`tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
format are generated automatically from the arguments.
}];
let assemblyFormat = [{
$prefix attr-dict ($args^ `:` type($args))?
}];
}
#endif // Triton_OPS

View File

@ -0,0 +1,71 @@
#ifndef TRITON_TYPES
#define TRITON_TYPES
include "triton/Dialect/Triton/IR/TritonDialect.td"
//
// Types
//
class TritonTypeDef<string name, string _mnemonic>
: TypeDef<Triton_Dialect, name> {
// Used by printer/parser
let mnemonic = _mnemonic;
}
// Floating-point Type
def F8 : TritonTypeDef<"Float8", "f8">;
def TT_Float : AnyTypeOf<[F8, F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
// Boolean Type
// TT_Bool -> I1
def TT_BoolTensor : TensorOf<[I1]>;
def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;
// Integer Type
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
def TT_IntTensor : TensorOf<[TT_Int]>;
def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>;
// I32 Type
// TT_I32 -> I32
// TT_I32Tensor -> I32Tensor
def TT_I32Like: AnyTypeOf<[I32, I32Tensor]>;
// I64 Type
// TT_I64 -> I64
// TT_I64Tensor -> I64Tensor
def TT_I64Like: AnyTypeOf<[I64, I64Tensor]>;
// Pointer Type
def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
let summary = "pointer type";
let description = [{
Triton PointerType
}];
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
let builders = [
TypeBuilderWithInferredContext<(ins
"Type":$pointeeType,
"int":$addressSpace
), [{
return $_get(pointeeType.getContext(), pointeeType, addressSpace);
}]>
];
let skipDefaultBuilders = 1;
}
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>;
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntTensor]>;
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike]>;
#endif

View File

@ -0,0 +1,10 @@
#ifndef TRITON_IR_TYPES_H_
#define TRITON_IR_TYPES_H_
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/IR/Types.h.inc"
#endif // TRITON_IR_TYPES_H_

View File

@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton)
add_public_tablegen_target(TritonTransformsIncGen)

View File

@ -0,0 +1,18 @@
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
#define TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace triton {
std::unique_ptr<Pass> createCombineOpsPass();
} // namespace triton
#define GEN_PASS_REGISTRATION
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
} // namespace mlir
#endif

View File

@ -0,0 +1,23 @@
#ifndef TRITON_PASSES
#define TRITON_PASSES
include "mlir/Pass/PassBase.td"
def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp"> {
let summary = "combine ops";
let description = [{
dot(a, b, 0) + c => dot(a, b, c)
addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1))
select(cond, load(ptrs, broadcast(cond), ???), other) =>
load(ptrs, broadcast(cond), other)
}];
let constructor = "mlir::triton::createCombineOpsPass()";
let dependentDialects = ["mlir::arith::ArithmeticDialect",
/*SelectOp*/"mlir::StandardOpsDialect"];
}
#endif

View File

@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@ -0,0 +1,12 @@
set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
add_public_tablegen_target(TritonGPUTableGen)
set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(TritonGPUAttrDefsIncGen)

View File

@ -0,0 +1,46 @@
#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
// TritonGPU depends on Triton
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#include "triton/Dialect/TritonGPU/IR/Traits.h"
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"
namespace mlir {
namespace triton {
namespace gpu {
unsigned getElemsPerThread(Type type);
SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout);
SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout);
SmallVector<unsigned> getSizePerThread(const Attribute &layout);
SmallVector<unsigned> getContigPerThread(Attribute layout);
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);
SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
SmallVector<unsigned> getOrder(const Attribute &layout);
} // namespace gpu
} // namespace triton
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_

View File

@ -0,0 +1,31 @@
#ifndef TRITON_GPU_IR_TRAITS_H_
#define TRITON_GPU_IR_TRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace OpTrait {
// These functions are out-of-line implementations of the methods in the
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
LogicalResult verifyResultsAreSharedEncoding(Operation *op);
} // namespace impl
template <typename ConcreteType>
class ResultsAreSharedEncoding
: public TraitBase<ConcreteType, ResultsAreSharedEncoding> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyResultsAreSharedEncoding(op);
}
};
} // namespace OpTrait
} // namespace mlir
#endif

View File

@ -0,0 +1,481 @@
#ifndef TRITONGPU_ATTRDEFS
#define TRITONGPU_ATTRDEFS
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
//===----------------------------------------------------------------------===//
// TritonGPU Attribute Definitions
//===----------------------------------------------------------------------===//
class TritonGPU_Attr<string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<TritonGPU_Dialect, name, traits, baseCppClass> {
let description = [{
TritonGPU Tensors differ from usual tensors in that they contain a _layout_ attribute which determines
how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function
\mathcal{L} that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers T corresponding
to the indices of the CUDA threads allowed to access some data at index $i$.
For example, let us consider the layout function:
\mathcal{L}(0, 0) = {0, 4}
\mathcal{L}(0, 1) = {1, 5}
\mathcal{L}(1, 0) = {2, 6}
\mathcal{L}(1, 1) = {3, 7}
Then, attaching $\mathcal{L} to a tensor $T$ would mean that:
- T[0,0] is owned by both cuda thread 0 and 4
- T[0,1] is owned by both cuda thread 1 and 5
- T[1,0] is owned by both cuda thread 2 and 6
- T[1,1] is owned by both cuda thread 3 and 7
Right now, Triton implements two classes of layouts: shared, and distributed.
}];
code extraBaseClassDeclaration = [{
unsigned getElemsPerThread(ArrayRef<int64_t> shape) const;
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
}];
}
//===----------------------------------------------------------------------===//
// Shared Layout Encoding
//===----------------------------------------------------------------------===//
def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding"> {
let mnemonic = "shared";
let description = [{
An encoding for tensors whose elements may be simultaneously accessed by
different cuda threads in the programs, via shared memory. In other words,
for all indices i \in R^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.
In order to avoid shared memory bank conflicts, elements may be swizzled
in memory. For example, a swizzled row-major layout could store its data
as follows:
A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2
A_{1, 0} A_{1, 1} A_{1, 2} A_{1, 3} ... [phase 0] /
groups of vec=2 elements
are stored contiguously
_ _ _ _ /\_ _ _ _
A_{2, 2} A_{2, 3} A_{2, 0} A_{2, 1} ... [phase 1] \ per phase = 2
A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
}];
let parameters = (
ins
// swizzle info
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase,
ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
);
let builders = [
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"Type":$eltTy), [{
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
if(!mmaEnc)
return $_get(context, 1, 1, 1, order);
int opIdx = dotOpEnc.getOpIdx();
// number of rows per phase
int perPhase = 128 / (shape[order[0]] * (eltTy.getIntOrFloatBitWidth() / 8));
perPhase = std::max<int>(perPhase, 1);
// index of the inner dimension in `order`
unsigned inner = (opIdx == 0) ? 0 : 1;
// ---- begin Volta ----
if (mmaEnc.isVolta()) {
bool is_row = order[0] != 0;
bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
is_row && (shape[order[0]] <= 16);
// TODO[Superjomn]: Support the case when is_vec4=false later
// Currently, we only support ld.v2, for the mma layout varies with different ld vector width.
is_vec4 = true;
int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
((is_row && !is_vec4) ? 2 : 1);
int rep = 2 * pack_size;
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
int vec = 2 * rep;
return $_get(context, vec, perPhase, maxPhase, order);
}
// ---- begin Ampere ----
if (mmaEnc.isAmpere()) {
std::vector<size_t> matShape = {8, 8,
2 * 64 / eltTy.getIntOrFloatBitWidth()};
// for now, disable swizzle when using transposed int8 tensor cores
if (eltTy.isInteger(8) && order[0] == inner)
return $_get(context, 1, 1, 1, order);
// --- handle A operand ---
if (opIdx == 0) { // compute swizzling for A operand
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
int maxPhase = mmaStride / perPhase;
return $_get(context, vec, perPhase, maxPhase, order);
}
// --- handle B operand ---
if (opIdx == 1) {
int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
int maxPhase = mmaStride / perPhase;
return $_get(context, vec, perPhase, maxPhase, order);
}
llvm_unreachable("invalid operand index");
}
// ---- not implemented ----
llvm_unreachable("unsupported swizzling for provided MMA version");
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration;
}
//===----------------------------------------------------------------------===//
// Distributed Layout Encoding
//===----------------------------------------------------------------------===//
class DistributedEncoding<string name> : TritonGPU_Attr<name> {
let description = [{
Distributed encodings have a layout function that is entirely characterized
by a d-dimensional tensor L. Note that L doesn't need to have the same shape
(or even the same rank) as the tensor it is encoding.
The layout function \mathcal{L} of this layout is then defined, for an
index `i` \in R^D, as follows:
\mathcal{L}(A)[i_d] = L[(i_d + k_d*A.shape[d]) % L.shape[d]] \forall k_d such as i_d + k_d*A.shape[d] < L.shape[d]
For example, for a tensor/layout pair
A = [x x x x x x x x]
[x x x x x x x x]
L = [0 1 2 3 ]
[4 5 6 7 ]
[8 9 10 11]
[12 13 14 15]
Then the data of A would be distributed as follow between the 16 CUDA threads:
L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
{4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ]
}];
let extraClassDeclaration = extraBaseClassDeclaration;
}
//===----------------------------------------------------------------------===//
// Blocked Layout Encoding
//===----------------------------------------------------------------------===//
def BlockedEncodingAttr : DistributedEncoding<"BlockedEncoding"> {
let mnemonic = "blocked";
let description = [{
An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout
used to promote memory coalescing in LoadInst and StoreInst.
It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which
specify the amount of elements owned by each CUDA thread, warp and CTA respectively.
For example, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows.
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
...
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
for
#triton_gpu.blocked_layout<{
sizePerThread = {2, 2}
threadsPerWarp = {8, 4}
warpsPerCTA = {1, 2}
}>
}];
let builders = [
// Custom builder initializes sizePerWarp and sizePerCTA automatically
// TODO: compiles on MacOS but not linux?
// AttrBuilder<(ins "ArrayRef<unsigned>":$sizePerThread,
// "ArrayRef<unsigned>":$threadsPerWarp,
// "ArrayRef<unsigned>":$warpsPerCTA,
// "ArrayRef<unsigned>":$order), [{
// int rank = threadsPerWarp.size();
// SmallVector<unsigned, 4> sizePerWarp(rank);
// SmallVector<unsigned, 4> sizePerCTA(rank);
// for (unsigned i = 0; i < rank; i++) {
// sizePerWarp.push_back(sizePerThread[i] * threadsPerWarp[i]);
// sizePerCTA.push_back(sizePerWarp[i] * warpsPerCTA[i]);
// }
// return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, sizePerWarp, sizePerCTA);
// }]>,
// Custom builder initializes sizePerWarp and sizePerCTA automatically
// Default builder takes sizePerThread, order and numWarps, and tries to
// pack numWarps*32 threads in the provided order for use in a type
// of the given shape.
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$sizePerThread,
"ArrayRef<unsigned>":$order,
"unsigned":$numWarps), [{
int rank = sizePerThread.size();
unsigned remainingLanes = 32;
unsigned remainingThreads = numWarps*32;
unsigned remainingWarps = numWarps;
unsigned prevLanes = 1;
unsigned prevWarps = 1;
SmallVector<unsigned, 4> threadsPerWarp(rank);
SmallVector<unsigned, 4> warpsPerCTA(rank);
for (int _dim = 0; _dim < rank - 1; ++_dim) {
int i = order[_dim];
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps);
remainingWarps /= warpsPerCTA[i];
remainingLanes /= threadsPerWarp[i];
remainingThreads /= threadsPerCTA;
prevLanes *= threadsPerWarp[i];
prevWarps *= warpsPerCTA[i];
}
// Expand the last dimension to fill the remaining lanes and warps
threadsPerWarp[order[rank-1]] = 32 / prevLanes;
warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
SliceEncodingAttr squeeze(int axis);
}];
let parameters = (
ins
ArrayRefParameter<"unsigned">:$sizePerThread,
ArrayRefParameter<"unsigned">:$threadsPerWarp,
ArrayRefParameter<"unsigned">:$warpsPerCTA,
// fastest-changing axis first
ArrayRefParameter<
"unsigned",
"order of axes by the rate of changing"
>:$order
// These attributes can be inferred from the rest
// ArrayRefParameter<"unsigned">:$sizePerWarp,
// ArrayRefParameter<"unsigned">:$sizePerCTA
);
}
//===----------------------------------------------------------------------===//
// MMA Layout Encoding
//===----------------------------------------------------------------------===//
// TODO: MMAv1 and MMAv2 should be two instances of the same class
def MmaEncodingAttr : DistributedEncoding<"MmaEncoding"> {
let mnemonic = "mma";
let description = [{
An encoding for tensors that have been produced by tensor cores.
It is characterized by two parameters:
- A 'versionMajor' which specifies the generation the tensor cores
whose output is being partitioned: 1 for first-gen tensor cores (Volta),
and 2 for second-gen tensor cores (Turing/Ampere).
- A 'versionMinor' which indicates the specific layout of a tensor core
generation, e.g. for Volta, there might be multiple kinds of layouts annotated
by 0,1,2 and so on.
- A `blockTileSize` to indicate how data should be
partitioned between warps.
// -------------------------------- version = 1 --------------------------- //
For first-gen tensor cores, the implicit warpTileSize is [16, 16].
Note: the layout is different from the recommended in PTX ISA
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
(mma.884 section, FP32 accumulator).
For example, when versionMinor=1, the matrix L corresponding to
blockTileSize=[32,16] is:
warp 0
--------------------------------/\-------------------------------
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
warp 1 = warp0 + 32
--------------------------------/\-------------------------------
[ 32 32 34 34 40 40 42 42 32 32 34 34 40 40 42 42 ]
[ 33 33 35 35 41 41 43 43 33 33 35 35 41 41 43 43 ]
[ ............................................................... ]
// -------------------------------- version = 2 --------------------------- //
For second-gen tensor cores, the implicit warpTileSize is [16, 8].
Information about this layout can be found in the official PTX documentation
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
(mma.16816 section, FP32 accumulator).
For example, the matrix L corresponding to blockTileSize=[32,16] is:
warp 0 warp 1
-----------------/\------------- ----------------/\-------------
[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35
[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39
[ .............................. ..............................
[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63
[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35
[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39
[ .............................. ..............................
[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63
warp 3 warp 4
----------------/\------------- ----------------/\-------------
[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99
[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103
[ .............................. ...............................
[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127
[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99
[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103
[ .............................. ...............................
[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127
}];
let parameters = (
ins
"unsigned":$versionMajor,
"unsigned":$versionMinor,
ArrayRefParameter<"unsigned">:$warpsPerCTA
);
let builders = [
// specific for MMAV1(Volta)
AttrBuilder<(ins "int":$versionMajor,
"ArrayRef<unsigned>":$warpsPerCTA,
"ArrayRef<int64_t>":$shapeA,
"ArrayRef<int64_t>":$shapeB,
"bool":$isARow,
"bool":$isBRow), [{
assert(versionMajor == 1 && "Only MMAv1 has multiple versionMinor.");
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
// 4-bits to encode 4 booleans: [isARow, isBRow, isAVec4, isBVec4]
int versionMinor = (isARow * (1<<0)) |\
(isBRow * (1<<1)) |\
(isAVec4 * (1<<2)) |\
(isBVec4 * (1<<3));
return $_get(context, versionMajor, versionMinor, warpsPerCTA);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
bool isVolta() const;
bool isAmpere() const;
// Get [isARow, isBRow, isAVec4, isBVec4] from versionMinor
std::tuple<bool, bool, bool, bool> decodeVoltaLayoutStates() const;
}];
}
def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
let mnemonic = "slice";
let description = [{
TODO: improve docs
A = [x x x x x x x x]
parent = [0 1 2 3 ]
[4 5 6 7 ]
[8 9 10 11]
[12 13 14 15]
dim = 0
Then the data of A would be distributed as follow between the 16 CUDA threads:
L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ]
This is useful for constructing the inverse layout of an expand_dims operation during some optimization passes.
}];
let parameters = (
ins
"unsigned":$dim,
// TODO: constraint here to only take distributed encodings
"Attribute":$parent
);
let extraClassDeclaration = extraBaseClassDeclaration # [{
template<class T>
SmallVector<T> paddedShape(ArrayRef<T> shape) const;
}];
}
def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
let mnemonic = "dot_op";
let description = [{
In TritonGPU dialect, considering `d = tt.dot a, b, c`
tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
a's opIdx is 0, b's opIdx is 1.
The parend field in DotOperandEncodingAttr is the layout of d.
For MMA v1, an additional attribute `isMMAv1Row` determines whether e.g. the a operand is used
in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation
section 9.7.13.4.1 for more details.
}];
let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent,
"Attribute":$isMMAv1Row
);
let builders = [
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent), [{
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().isVolta()){
isMMAv1Row = BoolAttr::get(context, true);
}
return $_get(context, opIdx, parent, isMMAv1Row);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration;
}
#endif

View File

@ -0,0 +1,36 @@
#ifndef TRITONGPU_DIALECT
#define TRITONGPU_DIALECT
include "mlir/IR/OpBase.td"
def TritonGPU_Dialect : Dialect {
let name = "triton_gpu";
let cppNamespace = "::mlir::triton::gpu";
let hasOperationAttrVerify = 1;
let description = [{
Triton GPU Dialect.
}];
let dependentDialects = [
"triton::TritonDialect",
"mlir::gpu::GPUDialect",
"tensor::TensorDialect",
];
let extraClassDeclaration = [{
static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
static int getNumWarps(ModuleOp mod) {
if(!mod->hasAttr("triton_gpu.num-warps"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-warps attribute");
return mod->getAttr("triton_gpu.num-warps").cast<IntegerAttr>().getInt();
}
}];
}
#endif

View File

@ -0,0 +1,198 @@
#ifndef TRITONGPU_OPS
#define TRITONGPU_OPS
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;
class TTG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonGPU_Dialect, mnemonic, traits>;
def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
[SameOperandsAndResultShape, NoSideEffect]> {
let summary = "convert layout";
let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
let summary = "async wait";
let arguments = (ins I32Attr:$num);
let assemblyFormat = "attr-dict";
let extraClassDeclaration = [{
static bool isSupported(int computeCapability) {
return computeCapability >= 80;
}
}];
}
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
// This is needed because these ops don't
// handle encodings
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "integer comparison operation";
let description = [{}];
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
TT_IntLike:$lhs,
TT_IntLike:$rhs);
let results = (outs TT_BoolLike:$result);
}
def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "floating-point comparison operation";
let description = [{}];
let arguments = (ins Arith_CmpFPredicateAttr:$predicate,
TT_FloatLike:$lhs,
TT_FloatLike:$rhs);
let results = (outs TT_BoolLike:$result);
}
// TODO: migrate to arith::SelectOp on LLVM16
def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "select operation";
let description = [{}];
let arguments = (ins TT_BoolLike:$condition,
TT_Tensor:$true_value,
TT_Tensor:$false_value);
let results = (outs TT_Tensor:$result);
}
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
[AttrSizedOperandSegments,
ResultsAreSharedEncoding,
MemoryEffects<[MemRead]>,
TypesMatchWith<"infer mask type from src type",
"src", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
TypesMatchWith<"infer other type from src type",
"src", "other", "getPointeeType($_self)",
"($_op.getOperands().size() <= 4) || std::equal_to<>()">]> {
let summary = "insert slice async";
let description = [{
This operation inserts a tensor `$src` into another tensor `$dst` as specified by the operations
`$index` argument and `$axis` attribute.
It returns a copy of `$dst` with the proper slice updated asynchronously with the value of `$src`.
This operation is non-blocking, and `$results` will have the updated value after the corresponding async_wait.
When converting from `tt.load` to `triton_gpu.insert_slice_async`, the `$evict`, `$cache`, and `$isVolatile` fields
might be ignored on certain hardware. For example, on NVIDIA GPUs, the cache policy is determined by the backend,
and `$evict` and `$isVolatile` are ignored because they apply to L1 cache only.
The insert_slice_async operation supports the following arguments:
* src: the tensor that is inserted.
* dst: the tensor into which the `$src` tensor is inserted.
* index: the index of the `$src` tensor at the given `$axis` from which the `$dst` tensor is inserted into
* mask: optional tensor-rank number of boolean masks which specify which
elements of the `$src` tensor are inserted into the `$dst` tensor.
* other: optional tensor-rank number of other tensors which specify what
values are inserted into the `$dst` tensor if the corresponding
element of the `$mask` tensor is false.
In the future, we may decompose this operation into a sequence of:
* `async` operation to specify a sequence of asynchronous operations
* `load` operation to load a tensor from global memory
* `insert_slice` operations to insert the `$src` tensor into the `$dst` tensor
Example:
```
%1 = triton_gpu.alloc_tensor : tensor<2x32xf32>
%2 = triton_gpu.insert_slice_async %0, %1, %index { axis = 0 } : tensor<32x!tt.ptr<f32>, #AL> -> tensor<2x32xf32, #A>
triiton_gpu.async_wait { num = 0 : i32 }
```
}];
let arguments = (ins TT_PtrTensor:$src, TT_Tensor:$dst, I32:$index,
Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile, I32Attr:$axis);
let builders = [
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, "Value":$mask,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
"Value":$mask, "Value":$other,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
];
let results = (outs TT_Tensor:$result);
//let assemblyFormat = [{
// $src `,` $dst ``
// $index, $mask, $other
// attr-dict `:` type($src) `->` type($dst)
//}];
let extraClassDeclaration = [{
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability) {
DenseSet<unsigned> validLoadBytes;
if (computeCapability >= 80) {
validLoadBytes = {4, 8, 16};
}
return validLoadBytes;
}
}];
// The custom parser could be replaced with oilist in LLVM-16
let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
}
def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [MemoryEffects<[MemAlloc]>, // Allocate shared memory
ResultsAreSharedEncoding]> {
let summary = "allocate tensor";
let description = [{
This operation defines a tensor of a particular shape.
The contents of the tensor are supposed to be in shared memory.
Note: This op can be repalced to a `bufferization.alloc_tensor` in LLVM 16.
}];
let assemblyFormat = [{attr-dict `:` type($result)}];
let results = (outs TT_Tensor:$result);
}
#endif

View File

@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonGPU)
add_public_tablegen_target(TritonGPUTransformsIncGen)

View File

@ -0,0 +1,25 @@
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
// TODO(Keren): prefetch pass not working yet
std::unique_ptr<Pass> createTritonGPUPrefetchPass();
std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
std::unique_ptr<Pass> createTritonGPUCoalescePass();
std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
std::unique_ptr<Pass> createTritonGPUVerifier();
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
} // namespace mlir
#endif

View File

@ -0,0 +1,87 @@
#ifndef TRITONGPU_PASSES
#define TRITONGPU_PASSES
include "mlir/Pass/PassBase.td"
def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
let summary = "pipeline";
let description = [{
Unroll loops to hide global memory -> shared memory latency.
}];
let constructor = "mlir::createTritonGPUPipelinePass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithmeticDialect"];
let options = [
Option<"numStages", "num-stages",
"int32_t", /*default*/"2",
"number of pipeline stages">
];
}
def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
let summary = "prefetch";
let description = [{
Prefetch operands (a and b) of tt.dot into shared memory to hide shared memory -> register latency.
}];
let constructor = "mlir::createTritonGPUPrefetchPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithmeticDialect"];
}
def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
let summary = "coalesce";
let description = [{
TODO
}];
let constructor = "mlir::createTritonGPUCoalescePass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
let summary = "combine triton gpu ops";
let description = [{
convert_layout(convert_layout(%src, #LAYOUT_0), #LAYOUT_1) =>
convert_layout(%src, #LAYOUT_1)
convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT
}];
let constructor = "mlir::createTritonGPUCombineOpsPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
let summary = "canonicalize scf.ForOp ops";
let description = [{
This implements some optimizations that are missing in the standard scf.ForOp
canonicalizer.
}];
let constructor = "mlir::createTritonGPUCanonicalizeLoopsPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
#endif

View File

@ -0,0 +1,33 @@
//===----------------------------------------------------------------------===//
//
// Defines utilities to use while converting to the TritonGPU dialect.
//
//===----------------------------------------------------------------------===//
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
class TritonGPUTypeConverter : public TypeConverter {
public:
TritonGPUTypeConverter(MLIRContext *context, int numWarps);
int getNumWarps() const { return numWarps; }
private:
MLIRContext *context;
int numWarps;
};
class TritonGPUConversionTarget : public ConversionTarget {
public:
explicit TritonGPUConversionTarget(MLIRContext &ctx,
TritonGPUTypeConverter &typeConverter);
};
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_

View File

@ -0,0 +1,39 @@
#ifndef TRITON_TARGET_LLVMIRTRANSLATION_H
#define TRITON_TARGET_LLVMIRTRANSLATION_H
#include "llvm/ADT/StringRef.h"
#include <memory>
#include <string>
#include <vector>
namespace llvm {
class Module;
class LLVMContext;
} // namespace llvm
namespace mlir {
class ModuleOp;
} // namespace mlir
namespace mlir {
namespace triton {
// add external dependent libs
void addExternalLibs(mlir::ModuleOp &module,
const std::vector<std::string> &names,
const std::vector<std::string> &paths);
// Translate TritonGPU dialect to LLVMIR, return null if failed.
std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module, int computeCapability);
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
bool linkExternLib(llvm::Module &module, llvm::StringRef path);
} // namespace triton
} // namespace mlir
#endif // TRITON_TARGET_LLVMIRTRANSLATION_H

View File

@ -0,0 +1,17 @@
#ifndef TRITON_TARGET_PTXTRANSLATION_H
#define TRITON_TARGET_PTXTRANSLATION_H
#include <string>
namespace llvm {
class Module;
} // namespace llvm
namespace triton {
// Translate TritonGPU IR to PTX code.
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version);
} // namespace triton
#endif

View File

@ -22,26 +22,32 @@
#ifndef TDL_TOOLS_SYS_GETENV_HPP
#define TDL_TOOLS_SYS_GETENV_HPP
#include <string>
#include <algorithm>
#include <cstdlib>
#include <string>
namespace triton
{
namespace triton {
namespace tools
{
inline std::string getenv(const char * name)
{
const char * cstr = std::getenv(name);
if(!cstr)
return "";
std::string result(cstr);
return result;
}
namespace tools {
inline std::string getenv(const char *name) {
const char *cstr = std::getenv(name);
if (!cstr)
return "";
std::string result(cstr);
return result;
}
inline bool getBoolEnv(const std::string &env) {
const char *s = std::getenv(env.c_str());
std::string str(s ? s : "");
std::transform(str.begin(), str.end(), str.begin(),
[](unsigned char c) { return std::tolower(c); });
return (str == "on" || str == "true" || str == "1");
}
} // namespace tools
} // namespace triton
#endif

View File

@ -1,87 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
#define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
#include <map>
#include <vector>
namespace triton {
namespace ir {
class value;
class module;
class phi_node;
class splat_inst;
class cast_inst;
class cmp_inst;
class reshape_inst;
class dequantize_inst;
class broadcast_inst;
class binary_operator;
class getelementptr_inst;
}
namespace codegen{
namespace analysis{
class align {
private:
struct cst_info {
unsigned num_cst;
unsigned value;
};
// helpers
std::vector<unsigned> get_shapes(ir::value *v);
// populate is_constant
std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
std::vector<cst_info> populate_is_constant_dequantize(ir::dequantize_inst* x);
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
std::vector<cst_info> populate_is_constant_cmp(ir::cmp_inst* x);
std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
std::vector<cst_info> populate_is_constant_default(ir::value* v);
std::vector<cst_info> populate_is_constant(ir::value *v);
// populate max_contiguous
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_max_contiguous_dequantize(ir::dequantize_inst* x);
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_max_contiguous_cast(ir::cast_inst* x);
std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
std::vector<unsigned> populate_max_contiguous(ir::value *v);
// populate starting_multiple
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_starting_multiple_dequantize(ir::dequantize_inst* x);
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_starting_multiple_cast(ir::cast_inst* x);
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
std::vector<unsigned> populate_starting_multiple(ir::value *v);
// populate all maps
void populate(ir::value *v);
public:
void run(ir::module &mod);
unsigned get(ir::value* v, unsigned ax) const;
std::vector<unsigned> contiguous(ir::value* v) const;
std::vector<cst_info> get_cst_info(ir::value* v) const;
private:
std::map<ir::value*, std::vector<cst_info>> is_constant_;
std::map<ir::value*, std::vector<unsigned>> max_contiguous_;
std::map<ir::value*, std::vector<unsigned>> starting_multiple_;
};
}
}
}
#endif

View File

@ -1,47 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
#define TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
#include <map>
#include <set>
#include <iostream>
#include "triton/codegen/analysis/liveness.h"
namespace triton{
namespace ir{
class value;
class function;
class module;
}
namespace codegen{
namespace analysis{
class tiles;
class liveness;
class cts;
class allocation {
public:
allocation(liveness *live)
: liveness_(live) { }
// accessors
bool has_offset(const data_layout *x) const { return offsets_.find(x) != offsets_.end(); }
unsigned offset(const data_layout *x) const { return offsets_.at(x); }
unsigned allocated_size() const { return allocated_size_; }
// run
void run(ir::module& mod);
private:
std::map<const data_layout*, unsigned> offsets_;
size_t allocated_size_;
// dependences
liveness *liveness_;
};
}
}
}
#endif

View File

@ -1,53 +0,0 @@
#ifndef _TRITON_CODEGEN_ANALYSIS_AXES_H_
#define _TRITON_CODEGEN_ANALYSIS_AXES_H_
#include "triton/tools/graph.h"
#include <map>
#include <vector>
namespace triton{
namespace ir{
class value;
class module;
class instruction;
}
namespace codegen{
namespace analysis{
class axes {
typedef std::pair<ir::value*, unsigned> node_t;
private:
// update graph
void update_graph_store(ir::instruction *i);
void update_graph_reduce(ir::instruction *i);
void update_graph_reshape(ir::instruction *i);
void update_graph_trans(ir::instruction *i);
void update_graph_dequantize(ir::instruction *i);
void update_graph_broadcast(ir::instruction *i);
void update_graph_dot(ir::instruction *i);
void update_graph_elementwise(ir::instruction *i,
bool is_masked_load_async=false);
void update_graph_no_edge(ir::instruction *i);
void update_graph(ir::instruction *i);
public:
axes();
void run(ir::module &mod);
// accessors
int get(ir::value *value, unsigned dim);
std::vector<int> get(ir::value *value);
private:
tools::graph<node_t> graph_;
std::map<node_t, size_t> axes_;
};
}
}
}
#endif

View File

@ -1,370 +0,0 @@
#ifndef _TRITON_CODEGEN_ANALYSIS_GRID_H_
#define _TRITON_CODEGEN_ANALYSIS_GRID_H_
#include <map>
#include <set>
#include <vector>
#include <memory>
#include "triton/tools/graph.h"
#include "triton/codegen/target.h"
namespace triton{
namespace ir{
class value;
class type;
class module;
class instruction;
class phi_node;
}
namespace codegen{
namespace analysis{
class axes;
class align;
class layout_visitor;
class data_layout;
class mma_layout;
class scanline_layout;
class shared_layout;
class layout_visitor {
public:
virtual void visit_layout(data_layout *);
virtual void visit_layout_mma(mma_layout*) = 0;
virtual void visit_layout_scanline(scanline_layout*) = 0;
virtual void visit_layout_shared(shared_layout*) = 0;
};
class data_layout {
protected:
enum id_t {
MMA,
SCANLINE,
SHARED
};
typedef std::vector<int> axes_t;
typedef std::vector<unsigned> shape_t;
typedef std::vector<int> order_t;
typedef std::vector<ir::value*> values_t;
private:
template<typename T>
T* downcast(id_t id) {
if(id_ == id)
return static_cast<T*>(this);
return nullptr;
}
public:
data_layout(id_t id,
const std::vector<int>& axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align);
// visitor
virtual void accept(layout_visitor* vst) = 0;
// downcast
mma_layout* to_mma() { return downcast<mma_layout>(MMA); }
scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
// accessors
size_t get_rank() { return shape_.size(); }
const shape_t& get_shape() const { return shape_; }
const order_t& get_order() const { return order_; }
const values_t& get_values() const { return values_;}
int get_axis(size_t k) const { return axes_.at(k); }
std::vector<int> get_axes() const { return axes_; }
const int get_order(size_t k) const { return order_.at(k); }
// find the position of given axis
int find_axis(int to_find) const;
private:
id_t id_;
axes_t axes_;
values_t values_;
protected:
order_t order_;
shape_t shape_;
};
class distributed_layout: public data_layout{
public:
distributed_layout(id_t id,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value*>& values,
analysis::align* align);
int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
virtual int contig_per_thread(size_t k) = 0;
protected:
std::vector<int> shape_per_cta_;
};
class mma_layout: public distributed_layout {
public:
enum TensorCoreType : uint8_t {
// floating-point tensor core instr
FP32_FP16_FP16_FP32 = 0, // default
FP32_BF16_BF16_FP32,
FP32_TF32_TF32_FP32,
// integer tensor core instr
INT32_INT1_INT1_INT32, // Not implemented
INT32_INT4_INT4_INT32, // Not implemented
INT32_INT8_INT8_INT32, // Not implemented
//
NOT_APPLICABLE,
};
// Used on nvidia GPUs with sm >= 80
inline static const std::map<TensorCoreType, std::vector<int>> mma_instr_shape_ = {
{FP32_FP16_FP16_FP32, {16, 8, 16}},
{FP32_BF16_BF16_FP32, {16, 8, 16}},
{FP32_TF32_TF32_FP32, {16, 8, 8}},
{INT32_INT1_INT1_INT32, {16, 8, 256}},
{INT32_INT4_INT4_INT32, {16, 8, 64}},
{INT32_INT8_INT8_INT32, {16, 8, 32}},
};
// shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices)
inline static const std::map<TensorCoreType, std::vector<int>> mma_mat_shape_ = {
{FP32_FP16_FP16_FP32, {8, 8, 8}},
{FP32_BF16_BF16_FP32, {8, 8, 8}},
{FP32_TF32_TF32_FP32, {8, 8, 4}},
{INT32_INT1_INT1_INT32, {8, 8, 64}},
{INT32_INT4_INT4_INT32, {8, 8, 32}},
{INT32_INT8_INT8_INT32, {8, 8, 16}},
};
inline static const std::map<TensorCoreType, std::string> mma_instr_ptx_ = {
{FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
{FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
{FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},
{INT32_INT1_INT1_INT32, "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"},
{INT32_INT4_INT4_INT32, "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"},
{INT32_INT8_INT8_INT32, "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
};
// vector length per ldmatrix (16*8/elelment_size_in_bits)
inline static const std::map<TensorCoreType, int> mma_instr_vec_ = {
{FP32_FP16_FP16_FP32, 8},
{FP32_BF16_BF16_FP32, 8},
{FP32_TF32_TF32_FP32, 4},
{INT32_INT1_INT1_INT32, 128},
{INT32_INT4_INT4_INT32, 32},
{INT32_INT8_INT8_INT32, 16},
};
public:
mma_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values,
analysis::align* align, target *tgt,
shared_layout* layout_a,
shared_layout* layout_b,
ir::value *dot);
void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
// accessor
int fpw(size_t k) { return fpw_.at(k); }
int wpt(size_t k) { return wpt_.at(k); }
int spw(size_t k) { return spw_.at(k); }
int rep(size_t k) { return rep_.at(k); }
int contig_per_thread(size_t k) { return contig_per_thread_.at(k); }
// helpers for generator.cc
std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); }
std::vector<int> get_mma_instr_shape() const { return mma_instr_shape_.at(tensor_core_type_); }
std::vector<int> get_mma_mat_shape() const { return mma_mat_shape_.at(tensor_core_type_); }
int get_vec_a() const { return mma_instr_vec_.at(tensor_core_type_); }
int get_vec_b() const { return mma_instr_vec_.at(tensor_core_type_); }
// setter
void set_tensor_core_type(TensorCoreType type) { tensor_core_type_ = type; }
private:
// fragment per warp
std::vector<int> fpw_;
// shape per warp
std::vector<int> spw_;
// warp per tile
std::vector<int> wpt_;
// shape per tile
std::vector<int> spt_;
// repetitions
std::vector<int> rep_;
// contiguous per thread
std::vector<int> contig_per_thread_;
TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32;
};
class scanline_layout: public distributed_layout {
public:
scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align,
target* tgt);
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
// accessor
int mts(size_t k) { return mts_.at(k); }
int nts(size_t k) { return nts_.at(k); }
int contig_per_thread(size_t k) { return nts_.at(k); }
int per_thread(size_t k) { return contig_per_thread(k) * shape_[k] / shape_per_cta(k);}
private:
// micro tile size. The size of a tile held by a thread block.
std::vector<int> mts_;
// nano tile size. The size of a tile held by a thread.
std::vector<int> nts_;
};
struct double_buffer_info_t {
ir::value* first;
ir::value* latch;
ir::phi_node* phi;
};
struct N_buffer_info_t {
std::vector<ir::value*> firsts; // not necessarily ordered as input order
ir::value* latch;
ir::phi_node* phi;
std::map<ir::value*, int> firsts_idx;
};
// abstract for dot and corresponding smem values
class shared_layout: public data_layout {
private:
static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
static void extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t>& res, int &prev_stages);
public:
shared_layout(data_layout *arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values_,
ir::type *ty,
analysis::align* align, target *tgt,
bool is_tmp = false);
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
// accessors
size_t get_size() { return size_; }
ir::type* get_type() { return ty_; }
double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
N_buffer_info_t* get_N_buffer() { return N_buffer_.get(); }
int get_num_stages() const;
size_t get_per_stage_size() const { return size_ / get_num_stages(); }
size_t get_per_stage_elements() const;
size_t get_num_per_phase() { return num_per_phase_; }
ir::value* hmma_dot_a() { return hmma_dot_a_; }
ir::value* hmma_dot_b() { return hmma_dot_b_; }
void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
int get_mma_vec() { return mma_vec_;}
int get_mma_strided() { return mma_strided_; }
bool allow_swizzle() const { return allow_swizzle_; }
data_layout* get_arg_layout() { return arg_layout_; }
bool is_tmp() const { return is_tmp_; }
private:
size_t size_;
ir::type *ty_;
std::shared_ptr<double_buffer_info_t> double_buffer_;
std::shared_ptr<N_buffer_info_t> N_buffer_;
size_t num_per_phase_;
ir::value* hmma_dot_a_;
ir::value* hmma_dot_b_;
data_layout* arg_layout_;
int mma_vec_;
int mma_strided_;
bool allow_swizzle_ = true;
target *tgt_;
bool is_tmp_;
};
class layouts {
typedef ir::value* node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
private:
// graph creation
void connect(ir::value *x, ir::value *y);
void make_graph(ir::instruction *i);
void init_hmma_tile(data_layout& layouts);
void init_scanline_tile(data_layout &layouts);
void create(size_t id, const std::vector<ir::value*>& values);
void create_tmp_layout(size_t id, data_layout* arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
ir::instruction* i,
bool is_index = false);
public:
// constructor
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
// accessors
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
bool has(size_t id) { return layouts_.find(id) != layouts_.end(); }
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
size_t num_layouts() const { return values_.size();}
data_layout* get(size_t id) { return layouts_.at(id); }
data_layout* get(ir::value *v) { return get(layout_of(v));}
std::map<size_t, data_layout*> &get_all() { return layouts_; }
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
int tmp(ir::value* i) { return tmp_.at(i);}
int has_tmp_index(ir::value* i) { return tmp_index_.find(i) != tmp_index_.end(); }
int tmp_index(ir::value* i) { return tmp_index_.at(i);}
void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
// layout checkers
bool is_scanline(ir::instruction* i);
bool is_coalesced_scanline(ir::instruction* i);
bool is_mma(ir::instruction* i);
bool is_a100_mma(ir::instruction* i);
// execution
void run(ir::module &mod);
private:
analysis::axes* axes_;
analysis::align* align_;
size_t num_warps_;
target* tgt_;
tools::graph<ir::value*> graph_;
std::map<ir::value*, size_t> groups_;
std::map<size_t, std::vector<ir::value*>> values_;
std::map<size_t, data_layout*> layouts_;
std::map<ir::value*, size_t> tmp_;
std::map<ir::value*, size_t> tmp_index_;
};
}
}
}
#endif

View File

@ -1,69 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#include "triton/codegen/analysis/layout.h"
#include "triton/tools/graph.h"
#include "llvm/ADT/MapVector.h"
#include <set>
#include <vector>
namespace triton{
namespace ir{
class value;
class phi_node;
class function;
class module;
class instruction;
}
namespace codegen{
namespace analysis{
typedef unsigned slot_index;
class tiles;
class layouts;
class data_layout;
struct segment {
slot_index start;
slot_index end;
bool contains(slot_index idx) const {
return start <= idx && idx < end;
}
bool intersect(const segment &Other){
return contains(Other.start) || Other.contains(start);
}
};
class liveness {
private:
typedef llvm::MapVector<shared_layout*, segment> intervals_map_t;
public:
// constructor
liveness(layouts *l): layouts_(l){ }
// accessors
const intervals_map_t& get() const { return intervals_; }
segment get(shared_layout* v) const { return intervals_.lookup(v); }
// run
void run(ir::module &mod);
private:
// analysis
layouts *layouts_;
intervals_map_t intervals_;
};
}
}
}
#endif

View File

@ -1,43 +0,0 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
#define TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
#include <map>
namespace triton{
namespace ir{
class module;
}
namespace codegen{
class target;
namespace analysis{
class layouts;
class data_layout;
class swizzle {
public:
// constructor
swizzle(layouts *l, target* tgt): layouts_(l), tgt_(tgt){ }
// accessors
int get_per_phase(data_layout* layout) { return per_phase_.at(layout); }
int get_max_phase(data_layout* layout) { return max_phase_.at(layout); }
int get_vec (data_layout* layout) { return vec_.at(layout); }
// run
void run(ir::module &mod);
private:
layouts* layouts_;
target* tgt_;
std::map<data_layout*, int> per_phase_;
std::map<data_layout*, int> max_phase_;
std::map<data_layout*, int> vec_;
};
}
}
}
#endif

View File

@ -1,90 +0,0 @@
#ifndef _TRITON_CODE_GEN_EXTERN_LIB_H_
#define _TRITON_CODE_GEN_EXTERN_LIB_H_
#include <memory>
#include <string>
#include <map>
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
namespace triton {
namespace codegen {
///
/// \brief ExternLib is a class that represents a library of external functions.
///
class ExternLib {
public:
ExternLib(const std::string &name, const std::string &path)
: name_(name), path_(path) {}
virtual ~ExternLib() = default;
virtual const std::string &name() const { return name_; }
virtual const std::string &path() const { return path_; }
///
/// \brief Load the library and return the module.
///
std::unique_ptr<llvm::Module> load(llvm::LLVMContext &ctx);
///
/// \brief Link the module into the given module.
///
void link(std::unique_ptr<llvm::Module> &llvm,
std::unique_ptr<llvm::Module> &mod);
///
/// \brief Run load, link, and opt on the module.
///
virtual void install(llvm::LLVMContext &ctx,
std::unique_ptr<llvm::Module> &llvm) {
auto mod = load(ctx);
link(llvm, mod);
opt(ctx, llvm);
}
///
/// \brief Run opt on the module.
///
virtual void opt(llvm::LLVMContext &ctx,
std::unique_ptr<llvm::Module> &llvm) = 0;
private:
std::string name_;
std::string path_;
};
///
/// \brief ExternLibMap is a map of ExternLibs from their names to their paths.
///
typedef std::map<std::string, std::unique_ptr<ExternLib>> ExternLibMap;
///
/// \brief Concrete class for NVIDIA's libdevice library.
///
class LibDevice final : public ExternLib {
public:
LibDevice(const std::string &name, const std::string &path)
: ExternLib(name, path) {}
virtual ~LibDevice() = default;
virtual void opt(llvm::LLVMContext &ctx,
std::unique_ptr<llvm::Module> &llvm) override;
};
///
/// \brief Create an ExternLib instance based on the name and path.
///
std::unique_ptr<ExternLib> create_extern_lib(const std::string &lib_name,
const std::string &lib_path);
} // namespace codegen
} // namespace triton
#endif

View File

@ -1,41 +0,0 @@
#ifndef _TRITON_CODEGEN_PASS_H_
#define _TRITON_CODEGEN_PASS_H_
#include <memory>
#include "extern_lib.h"
namespace llvm{
class Module;
class LLVMContext;
}
namespace triton{
namespace codegen {
class target;
}
namespace ir{
class module;
}
namespace driver{
class device;
class module;
class kernel;
}
}
namespace triton{
namespace codegen{
// TODO:
// There should be a proper pass manager there!
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
ir::module &ir, llvm::LLVMContext &ctx, codegen::target *target,
int num_warps, int num_stages, int &shared_static,
const ExternLibMap &extern_libs);
}
}
#endif

View File

@ -1,300 +0,0 @@
#pragma once
#ifndef _TRITON_SELECTION_GENERATOR_H_
#define _TRITON_SELECTION_GENERATOR_H_
#include "triton/ir/visitor.h"
#include "triton/ir/instructions.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/extern_lib.h"
#include <functional>
// forward
namespace llvm{
class Type;
class Value;
class PHINode;
class BasicBlock;
class Attribute;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
class StructType;
}
namespace triton{
namespace ir{
class attribute;
class load_inst;
class store_inst;
}
namespace codegen{
// forward
namespace analysis{
class liveness;
class tiles;
class align;
class allocation;
class cts;
class axes;
class layouts;
class swizzle;
}
// typedef
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Attribute Attribute;
typedef llvm::BasicBlock BasicBlock;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
typedef std::vector<Value*> indices_t;
class target;
}
}
namespace triton{
namespace codegen{
struct distributed_axis {
int contiguous;
std::vector<Value*> values;
Value* thread_id;
};
class adder{
public:
adder(Builder** builder): builder_(builder) { }
Value* operator()(Value* x, Value* y, const std::string& name = "");
private:
Builder** builder_;
};
class multiplier{
public:
multiplier(Builder** builder): builder_(builder) { }
Value* operator()(Value* x, Value* y, const std::string& name = "");
private:
Builder** builder_;
};
class geper{
public:
geper(Builder** builder): builder_(builder) { }
Value* operator()(Value *ptr, Value* off, const std::string& name = "");
Value* operator()(Type* ty, Value*ptr, std::vector<Value*> vals, const std::string& name = "");
private:
Builder** builder_;
};
class generator: public ir::visitor, public analysis::layout_visitor {
private:
void init_idx(ir::value *x);
Instruction* add_barrier();
Value* shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx);
void finalize_shared_layout(analysis::shared_layout*);
void finalize_function(ir::function*);
void finalize_phi_node(ir::phi_node*);
private:
Type *cvt(ir::type *ty);
llvm::Attribute cvt(ir::attribute attr);
void packed_type(ir::value* i);
void forward_declare(ir::function* fn);
Value *cast_shared_layout_ptr(analysis::data_layout *layout, Type *ty);
private:
typedef std::function<void(
std::pair<Value *, Value *> &acc, std::function<Value *()> load_value_fn,
std::function<Value *()> load_index_fn, bool is_first)>
acc_fn_t;
public:
generator(analysis::axes *a_axes,
analysis::layouts *layouts,
analysis::align *alignment,
analysis::allocation *alloc,
analysis::swizzle *swizzle,
target *tgt,
unsigned num_warps);
void visit_value(ir::value* v);
void visit_call_inst(ir::call_inst*);
void visit_launch_inst(ir::launch_inst *);
void visit_phi_node(ir::phi_node*);
void visit_binary_operator(ir::binary_operator*);
void visit_getelementptr_inst(ir::getelementptr_inst*);
void visit_icmp_inst(ir::icmp_inst*);
void visit_fcmp_inst(ir::fcmp_inst*);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
Value* bf16_to_fp32(Value *in0);
Value* fp32_to_bf16(Value *in0);
std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int16_to_float16x8(
Value *in0, Value *scale_x512, Value *shift
);
std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int32_to_float16x8(
Value *in0, Value *scale_x512, Value *shift
);
std::tuple<Value*, Value*, Value*, Value*> int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift);
std::tuple<Value*, Value*> prepare_scale_shift(Value *scale, Value *shift);
void visit_dequantize_inst(ir::dequantize_inst*);
void visit_cast_inst(ir::cast_inst*);
void visit_return_inst(ir::return_inst*);
void visit_cond_branch_inst(ir::cond_branch_inst*);
void visit_uncond_branch_inst(ir::uncond_branch_inst*);
void visit_load_inst(ir::load_inst*);
void visit_unmasked_load_inst(ir::unmasked_load_inst*);
void visit_masked_load_inst(ir::masked_load_inst*);
void visit_store_inst(ir::store_inst*);
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
void visit_masked_store_inst(ir::masked_store_inst*);
void visit_cat_inst(ir::cat_inst*);
void visit_extract_value_inst(ir::extract_value_inst *);
void visit_insert_value_inst(ir::insert_value_inst *);
void visit_reshape_inst(ir::reshape_inst*);
void visit_splat_inst(ir::splat_inst*);
void visit_broadcast_inst(ir::broadcast_inst*);
void visit_downcast_inst(ir::downcast_inst*);
void visit_exp_inst(ir::exp_inst*);
void visit_cos_inst(ir::cos_inst*);
void visit_umulhi_inst(ir::umulhi_inst* x);
void visit_sin_inst(ir::sin_inst*);
void visit_log_inst(ir::log_inst*);
void visit_get_program_id_inst(ir::get_program_id_inst*);
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
void visit_atomic_rmw_inst(ir::atomic_rmw_inst*);
void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
void visit_dot_inst(ir::dot_inst*);
void visit_trans_inst(ir::trans_inst*);
void visit_sqrt_inst(ir::sqrt_inst*);
Value* shfl_sync(Value* acc, int32_t i);
void visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
void visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
void visit_reduce_inst(ir::reduce_inst*);
void visit_select_inst(ir::select_inst*);
void visit_layout_convert(ir::value *out, ir::value *in);
void visit_cvt_layout_inst(ir::cvt_layout_inst*);
void visit_masked_load_async_inst(ir::masked_load_async_inst*);
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
void visit_barrier_inst(ir::barrier_inst*);
void visit_prefetch_s_inst(ir::prefetch_s_inst*);
void visit_async_wait_inst(ir::async_wait_inst*);
// void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range(ir::make_range*);
void visit_clock_inst(ir::clock_inst*);
void visit_globaltimer_inst(ir::globaltimer_inst*);
void visit_extern_elementwise_inst(ir::extern_elementwise_inst*);
// void visit_make_range_sta(ir::make_range_sta*);
void visit_undef_value(ir::undef_value*);
void visit_constant_int(ir::constant_int*);
void visit_constant_fp(ir::constant_fp*);
void visit_alloc_const(ir::alloc_const*);
void visit_function(ir::function*);
void visit_basic_block(ir::basic_block*);
void visit_argument(ir::argument*);
void visit(ir::module &, llvm::Module &);
// layouts
void visit_layout_mma(analysis::mma_layout*);
void visit_layout_scanline(analysis::scanline_layout*);
void visit_layout_shared(analysis::shared_layout*);
// Add a new external library based on given name and path if it doesn't exist
void add_extern_lib(const std::string &lib_name, const std::string &lib_path);
// Get all external libraries
const ExternLibMap &get_extern_lib_map() {
return extern_lib_map_;
}
private:
LLVMContext *ctx_;
Builder* builder_;
Module *mod_;
std::map<std::string, std::unique_ptr<ExternLib>> extern_lib_map_;
analysis::axes *a_axes_;
analysis::swizzle *swizzle_;
std::map<unsigned, distributed_axis> axes_;
target *tgt_;
analysis::layouts *layouts_;
analysis::align *alignment_;
analysis::allocation *alloc_;
Value *shmem_;
std::set<ir::value*> seen_;
unsigned num_warps_;
std::map<analysis::data_layout*, Value*> offset_a_m_;
std::map<analysis::data_layout*, Value*> offset_a_k_;
std::map<analysis::data_layout*, Value*> offset_b_k_;
std::map<analysis::data_layout*, Value*> offset_b_n_;
/// layout -> base ptr
std::map<analysis::data_layout*, Value*> shared_ptr_;
std::map<analysis::data_layout*, Value*> shared_pre_ptr_;
std::map<analysis::data_layout*, Value*> shared_next_ptr_;
/// offset for double-buffered layout
std::map<analysis::data_layout*, Value*> shared_off_;
/// Base shmem pointer of ir value
std::map<ir::value*, Value*> shmems_;
std::map<ir::value*, Value*> shoffs_;
std::map<ir::value*, std::vector<indices_t>> idxs_;
std::map<ir::value*, std::map<indices_t, Value*>> vals_;
/// idx for multi-stage pipeline
std::map<analysis::data_layout*, Value*> read_smem_idx_;
std::map<analysis::data_layout*, Value*> write_smem_idx_;
/// triton bb -> llvm bb
std::map<ir::value*, BasicBlock *> bbs_;
std::map<ir::value*, std::vector<int>> ords_;
std::map<ir::value*, Function*> fns_;
// helper for creating llvm values
adder add;
multiplier mul;
geper gep;
/// PHI nodes
std::vector<std::tuple<llvm::PHINode*, Value*, ir::basic_block*>> lazy_phi_incs_;
/// Record prefetch instrs that needs to be moved
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
// Eviction policies
std::map<ir::load_inst::EVICTION_POLICY, Value*> policies_;
};
}
}
#endif

View File

@ -1,105 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_TARGET_H
#define TDL_INCLUDE_IR_CODEGEN_TARGET_H
namespace llvm{
class Type;
class Value;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
// typedefs
namespace triton{
namespace codegen{
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
}
}
namespace triton{
namespace codegen{
class nvidia_cu_target;
class target {
public:
target(bool is_gpu): is_gpu_(is_gpu){}
virtual ~target() {}
virtual void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn) = 0;
virtual Instruction* add_barrier(Module *module, Builder& builder) = 0;
virtual Instruction* add_memfence(Module *module, Builder& builder) = 0;
virtual Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax) = 0;
virtual Value* get_local_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
virtual unsigned guaranteed_alignment() = 0;
nvidia_cu_target* as_nvidia();
bool is_gpu() const;
private:
bool is_gpu_;
};
class amd_cl_target: public target {
public:
amd_cl_target(): target(true){}
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
unsigned guaranteed_alignment() { return 16; }
};
class nvidia_cu_target: public target {
public:
nvidia_cu_target(int sm): target(true), sm_(sm){}
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
int sm() { return sm_; }
unsigned guaranteed_alignment() { return 16; }
private:
int sm_;
};
class cpu_target: public target {
public:
cpu_target(): target(false){}
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
unsigned guaranteed_alignment() { return 1; }
};
}
}
#endif

View File

@ -1,49 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
#include <map>
#include <set>
#include <vector>
namespace triton {
namespace ir {
class module;
class value;
class io_inst;
class instruction;
class builder;
}
namespace codegen{
namespace analysis{
class align;
class layouts;
class cts;
}
namespace transform{
class coalesce {
private:
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result);
void extract_ld(ir::io_inst *i, std::map<int, std::vector<triton::ir::io_inst *> > &result);
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
public:
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts, bool has_sm80);
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
void run(ir::module &mod);
private:
bool has_sm80_;
analysis::align* align_;
analysis::layouts* layout_;
};
}
}
}
#endif

View File

@ -1,44 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
#define TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
#include <set>
#include <map>
namespace triton {
namespace ir {
class module;
class value;
class phi_node;
class instruction;
class builder;
}
namespace codegen{
namespace analysis{
class layouts;
}
namespace transform{
class cts {
private:
bool is_shmem_op(ir::instruction* i, int op);
bool is_shmem_res(ir::value* i);
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*,ir::value*>& copies);
public:
cts(analysis::layouts* layouts, bool has_sm80 = false): layouts_(layouts), has_sm80_(has_sm80) {}
void run(ir::module &mod);
private:
bool has_sm80_;
analysis::layouts* layouts_;
};
}
}
}
#endif

View File

@ -1,24 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
namespace triton {
namespace ir {
class module;
}
namespace codegen{
namespace transform{
class dce {
public:
dce() {}
void run(ir::module &mod);
};
}
}
}
#endif

View File

@ -1,22 +0,0 @@
#ifndef _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
#define _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
namespace triton {
namespace ir {
class module;
}
namespace codegen{
namespace transform{
class disassociate {
public:
void run(ir::module &mod);
};
}
}
}
#endif

View File

@ -1,31 +0,0 @@
#pragma once
#include <list>
namespace triton {
namespace ir {
class module;
class function;
class call_inst;
class builder;
}
namespace codegen{
namespace transform{
struct fncmp {
bool operator()(ir::function* x, ir::function* y) const;
};
class inliner {
public:
inliner() {}
void do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder, std::list<ir::call_inst*>& callsites);
void run(ir::module &mod);
};
}
}
}

View File

@ -1,72 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_BARRIERS_H
#define TDL_INCLUDE_CODEGEN_BARRIERS_H
#include <vector>
#include <map>
#include <list>
#include <set>
#include "triton/codegen/target.h"
namespace triton {
namespace ir {
class module;
class basic_block;
class instruction;
class masked_load_async_inst;
class value;
class builder;
}
namespace codegen{
namespace analysis{
class allocation;
class liveness;
class layouts;
class cts;
class shared_layout;
}
namespace transform{
class prefetch;
class membar {
private:
typedef std::pair<unsigned, unsigned> interval_t;
typedef std::set<ir::value*> val_set_t;
typedef std::vector<ir::value*> val_vec_t;
private:
bool intersect(const val_set_t &X, const val_set_t &Y);
bool check_safe_war(ir::instruction* i);
int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
bool intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout);
val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
public:
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc,
transform::prefetch *prefetch, target* tgt):
liveness_(liveness), layouts_(layouts), alloc_(alloc), prefetch_(prefetch), tgt_(tgt) {}
void run(ir::module &mod);
private:
analysis::liveness *liveness_;
analysis::layouts *layouts_;
analysis::allocation *alloc_;
transform::prefetch *prefetch_;
target* tgt_;
};
}
}
}
#endif

View File

@ -1,56 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
#include "triton/codegen/target.h"
namespace triton {
namespace ir {
class module;
class value;
class instruction;
class trans_inst;
class builder;
class constant_int;
class dot_inst;
}
namespace codegen{
namespace analysis{
class layouts;
}
namespace transform{
class peephole {
private:
// bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder);
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
bool rewrite_mult(ir::instruction *value, ir::builder& builder);
bool rewrite_insert_extract(ir::instruction *value, ir::builder& builder);
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder);
public:
peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
void run(ir::module &mod);
private:
target* tgt_;
analysis::layouts* layouts_;
};
}
}
}
#endif

View File

@ -1,30 +0,0 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
// forward declaration
namespace triton {
namespace ir {
class module;
}
} // namespace triton
namespace triton {
namespace codegen {
namespace transform {
class pipeline {
public:
pipeline(bool has_copy_async, int num_stages)
: has_copy_async_(has_copy_async), num_stages_(num_stages) {}
void run(ir::module &module);
private:
bool has_copy_async_;
int num_stages_;
};
} // namespace transform
} // namespace codegen
} // namespace triton
#endif

View File

@ -1,27 +0,0 @@
#ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
#define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
#include <set>
// forward dclaration
namespace triton::ir{
class module;
class value;
}
namespace triton::codegen {
class target;
}
namespace triton::codegen::transform {
class prefetch {
target* tgt_;
std::set<ir::value*> prefetched_vals_;
public:
prefetch(target *tgt) : tgt_(tgt) {}
void run(ir::module &module);
bool is_prefetched(ir::value* v) { return prefetched_vals_.find(v) != prefetched_vals_.end(); }
};
}
#endif

View File

@ -1,26 +0,0 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_REORDER_H
#define TRITON_INCLUDE_IR_CODEGEN_REORDER_H
namespace triton {
// forward declaration
namespace ir {
class module;
}
namespace codegen{
namespace transform{
class reorder {
public:
void run(ir::module& module);
};
}
}
}
#endif

View File

@ -1,318 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_DISPATCH_H_
#define _TRITON_DRIVER_DISPATCH_H_
#include <type_traits>
#include <dlfcn.h>
//CUDA Backend
#include "triton/external/CUDA/cuda.h"
#include "triton/external/CUDA/nvml.h"
//// HIP backend
//#define __HIP_PLATFORM_AMD__
#include "triton/external/hip.h"
//Exceptions
#include <iostream>
#include <stdexcept>
namespace llvm {
class PassRegistry;
class Module;
}
namespace triton
{
namespace driver
{
class cu_context;
template<class T> void check(T){}
void check(CUresult err);
void check(hipError_t err);
class dispatch
{
protected:
template <class F>
struct return_type;
template <class R, class... A>
struct return_type<R (*)(A...)>
{ typedef R type; };
typedef bool (*f_init_t)();
template<f_init_t initializer, typename FunPtrT, typename... Args>
static typename return_type<FunPtrT>::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args)
{
initializer();
if(cache == nullptr){
cache = dlsym(lib_h, name);
if(cache == 0)
throw std::runtime_error("dlsym unable to load function");
}
FunPtrT fptr;
*reinterpret_cast<void **>(&fptr) = cache;
typename return_type<FunPtrT>::type res = (*fptr)(args...);
check(res);
return res;
}
public:
static void release();
// Nvidia
static bool nvmlinit();
static bool cuinit();
// AMD
static bool hipinit();
/* ------------------- *
* CUDA
* ------------------- */
// context management
static CUresult cuInit(unsigned int Flags);
static CUresult cuCtxDestroy_v2(CUcontext ctx);
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
static CUresult cuCtxGetDevice(CUdevice* result);
static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
static CUresult cuDriverGetVersion(int *driverVersion);
// device management
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
static CUresult cuDeviceGetCount(int *count);
// link management
static CUresult cuLinkAddFile_v2(CUlinkState state, CUjitInputType type, const char *path, unsigned int numOptions, CUjit_option *options, void **optionValues);
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
static CUresult cuLinkDestroy(CUlinkState state);
// module management
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
static CUresult cuModuleLoadData(CUmodule* module, const void* image);
static CUresult cuModuleUnload(CUmodule hmod);
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
// stream management
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
static CUresult cuStreamSynchronize(CUstream hStream);
static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
static CUresult cuStreamDestroy_v2(CUstream hStream);
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
// function management
static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config);
// memory management
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
static CUresult cuMemFree_v2(CUdeviceptr dptr);
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
// event management
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
static CUresult cuEventDestroy_v2(CUevent hEvent);
/* ------------------- *
* NVML
* ------------------- */
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
/* ------------------- *
* HIP
* ------------------- */
// context management
static hipError_t hipInit(unsigned int Flags);
static hipError_t hipCtxDestroy(hipCtx_t ctx);
static hipError_t hipCtxCreate(hipCtx_t *pctx, unsigned int flags, hipDevice_t dev);
static hipError_t hipCtxPushCurrent(hipCtx_t ctx);
static hipError_t hipCtxPopCurrent(hipCtx_t *pctx);
static hipError_t hipCtxGetDevice(hipDevice_t* result);
static hipError_t hipCtxEnablePeerAccess(hipCtx_t peerContext, unsigned int flags);
static hipError_t hipDriverGetVersion(int *driverVersion);
// device management
static hipError_t hipGetDevice(hipDevice_t *device, int ordinal);
static hipError_t hipDeviceGetName(char *name, int len, hipDevice_t dev);
static hipError_t hipDeviceGetPCIBusId(char *id, int len, hipDevice_t dev);
static hipError_t hipDeviceGetAttribute(int *pi, hipDeviceAttribute_t attrib, hipDevice_t dev);
static hipError_t hipGetDeviceCount(int *count);
// module management
static hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t* bytes, hipModule_t hmod, const char *name);
static hipError_t hipModuleLoad(hipModule_t *module, const char *fname);
static hipError_t hipModuleLoadData(hipModule_t* module, const void* image);
static hipError_t hipModuleUnload(hipModule_t hmod);
static hipError_t hipModuleLoadDataEx(hipModule_t *module, const void *image, unsigned int numOptions, hipJitOption *options, void **optionValues);
static hipError_t hipModuleGetFunction(hipFunction_t *hfunc, hipModule_t hmod, const char *name);
// stream management
static hipError_t hipStreamCreate(hipStream_t *phStream, unsigned int Flags);
static hipError_t hipStreamSynchronize(hipStream_t hStream);
static hipError_t hipStreamDestroy(hipStream_t hStream);
static hipError_t hipModuleLaunchKernel(hipFunction_t f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, hipStream_t hStream, void **kernelParams, void **extra);
// function management
static hipError_t hipFuncGetAttributes(hipFuncAttributes* attrib, void* hfunc);
static hipError_t hipFuncSetAttribute(hipFunction_t hfunc, hipFuncAttribute attrib, int value);
static hipError_t hipFuncSetCacheConfig(hipFunction_t hfunc, hipFuncCache_t config);
// memory management
static hipError_t hipMalloc(hipDeviceptr_t *dptr, size_t bytesize);
static hipError_t hipPointerGetAttribute(void * data, CUpointer_attribute attribute, hipDeviceptr_t ptr);
static hipError_t hipMemsetD8Async(hipDeviceptr_t dst, unsigned char x, size_t N, hipStream_t stream);
static hipError_t hipMemcpyDtoH(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount);
static hipError_t hipFree(hipDeviceptr_t dptr);
static hipError_t hipMemcpyDtoHAsync(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount, hipStream_t hStream);
static hipError_t hipMemcpyHtoDAsync(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount, hipStream_t hStream);
static hipError_t hipMemcpyHtoD(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount);
// event management
static hipError_t hipEventCreate(hipEvent_t *phEvent, unsigned int Flags);
static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart, hipEvent_t hEnd);
static hipError_t hipEventRecord(hipEvent_t hEvent, hipStream_t hStream);
static hipError_t hipEventDestroy(hipEvent_t hEvent);
private:
// Libraries
static void* cuda_;
static void* nvml_;
static void* hip_;
/* ------------------- *
* CUDA
* ------------------- */
// context management
static void* cuCtxGetCurrent_;
static void* cuCtxSetCurrent_;
static void* cuCtxDestroy_v2_;
static void* cuCtxCreate_v2_;
static void* cuCtxGetDevice_;
static void* cuCtxPushCurrent_v2_;
static void* cuCtxPopCurrent_v2_;
static void* cuCtxEnablePeerAccess_;
static void* cuDriverGetVersion_;
static void* cuInit_;
// device management
static void* cuDeviceGet_;
static void* cuDeviceGetName_;
static void* cuDeviceGetPCIBusId_;
static void* cuDeviceGetAttribute_;
static void* cuDeviceGetCount_;
// link management
static void* cuLinkAddFile_v2_;
static void* cuLinkAddData_v2_;
static void* cuLinkCreate_v2_;
static void* cuLinkDestroy_;
static void* cuLinkComplete_;
// module management
static void* cuModuleGetGlobal_v2_;
static void* cuModuleLoad_;
static void* cuModuleUnload_;
static void* cuModuleLoadDataEx_;
static void* cuModuleLoadData_;
static void* cuModuleGetFunction_;
// stream management
static void* cuStreamCreate_;
static void* cuStreamSynchronize_;
static void* cuStreamDestroy_v2_;
static void* cuStreamGetCtx_;
static void* cuLaunchKernel_;
// function management
static void* cuFuncGetAttribute_;
static void* cuFuncSetAttribute_;
static void* cuFuncSetCacheConfig_;
// memory management
static void* cuMemcpyDtoH_v2_;
static void* cuMemFree_v2_;
static void* cuMemcpyDtoHAsync_v2_;
static void* cuMemcpyHtoDAsync_v2_;
static void* cuMemcpyHtoD_v2_;
static void* cuMemAlloc_v2_;
static void* cuMemsetD8Async_;
static void* cuPointerGetAttribute_;
// event management
static void* cuEventCreate_;
static void* cuEventElapsedTime_;
static void* cuEventRecord_;
static void* cuEventDestroy_v2_;
/* ------------------- *
* NVML
* ------------------- */
static void* nvmlInit_v2_;
static void* nvmlDeviceGetHandleByPciBusId_v2_;
static void* nvmlDeviceGetClockInfo_;
static void* nvmlDeviceGetMaxClockInfo_;
static void* nvmlDeviceSetApplicationsClocks_;
/* ------------------- *
* HIP
* ------------------- */
// context management
static void* hipInit_;
static void* hipCtxDestroy_;
static void* hipCtxCreate_;
static void* hipCtxPushCurrent_;
static void* hipCtxPopCurrent_;
static void* hipCtxGetDevice_;
static void* hipCtxEnablePeerAccess_;
static void* hipDriverGetVersion_;
// device management
static void* hipGetDevice_;
static void* hipDeviceGetName_;
static void* hipDeviceGetPCIBusId_;
static void* hipDeviceGetAttribute_;
static void* hipGetDeviceCount_;
// module management
static void* hipModuleGetGlobal_;
static void* hipModuleLoad_;
static void* hipModuleLoadData_;
static void* hipModuleUnload_;
static void* hipModuleLoadDataEx_;
static void* hipModuleGetFunction_;
// stream management
static void* hipStreamCreate_;
static void* hipStreamSynchronize_;
static void* hipStreamDestroy_;
static void* hipModuleLaunchKernel_;;
// function management
static void* hipFuncGetAttributes_;
static void* hipFuncSetAttribute_;
static void* hipFuncSetCacheConfig_;
// memory management
static void* hipMalloc_;
static void* hipPointerGetAttribute_;
static void* hipMemsetD8Async_;
static void* hipMemcpyDtoH_;
static void* hipFree_;
static void* hipMemcpyDtoHAsync_;
static void* hipMemcpyHtoDAsync_;
static void* hipMemcpyHtoD_;
// event management
static void* hipEventCreate_;
static void* hipEventElapsedTime_;
static void* hipEventRecord_;
static void* hipEventDestroy_;
};
}
}
#endif

View File

@ -1,220 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_ERROR_H_
#define _TRITON_DRIVER_ERROR_H_
#include <exception>
#include "triton/driver/dispatch.h"
namespace triton
{
namespace driver
{
namespace exception
{
namespace nvrtc
{
#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "NVRTC: Error- " msg; } }
TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option");
TRITON_CREATE_NVRTC_EXCEPTION(compilation ,"compilation");
TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure");
TRITON_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error");
#undef TRITON_CREATE_NVRTC_EXCEPTION
}
namespace cuda
{
class base: public std::exception{};
#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "CUDA: Error- " msg; } }
TRITON_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized");
TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled");
TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started");
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
TRITON_CREATE_CUDA_EXCEPTION(no_device ,"no device");
TRITON_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device");
TRITON_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image");
TRITON_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context");
TRITON_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current");
TRITON_CREATE_CUDA_EXCEPTION(map_failed ,"map failed");
TRITON_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed");
TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped");
TRITON_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped");
TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
TRITON_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit");
TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx");
TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
TRITON_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source");
TRITON_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found");
TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed");
TRITON_CREATE_CUDA_EXCEPTION(operating_system ,"operating system");
TRITON_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle");
TRITON_CREATE_CUDA_EXCEPTION(not_found ,"not found");
TRITON_CREATE_CUDA_EXCEPTION(not_ready ,"not ready");
TRITON_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address");
TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources");
TRITON_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout");
TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
TRITON_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active");
TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed");
TRITON_CREATE_CUDA_EXCEPTION(assert_error ,"assert");
TRITON_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers");
TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered");
TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error");
TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction");
TRITON_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address");
TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space");
TRITON_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc");
TRITON_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed");
TRITON_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted");
TRITON_CREATE_CUDA_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUDA_EXCEPTION(unknown ,"unknown");
#undef TRITON_CREATE_CUDA_EXCEPTION
}
namespace cublas
{
class base: public std::exception{};
#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "CUBLAS: Error- " msg; } }
TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed");
TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch");
TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error");
TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed");
TRITON_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error");
TRITON_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUBLAS_EXCEPTION(license_error ,"license error");
TRITON_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown");
#undef TRITON_CREATE_CUBLAS_EXCEPTION
}
namespace cudnn
{
#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "CUDNN: Error- " msg; } }
TRITON_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed");
TRITON_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param");
TRITON_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error");
TRITON_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch");
TRITON_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error");
TRITON_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed");
TRITON_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUDNN_EXCEPTION(license_error ,"license error");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
}
namespace hip
{
class base: public std::exception{};
#define TRITON_CREATE_HIP_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "HIP: Error- " msg; } }
TRITON_CREATE_HIP_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_HIP_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_HIP_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_HIP_EXCEPTION(deinitialized ,"deinitialized");
TRITON_CREATE_HIP_EXCEPTION(profiler_disabled ,"profiler disabled");
TRITON_CREATE_HIP_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
TRITON_CREATE_HIP_EXCEPTION(profiler_already_started ,"profiler already started");
TRITON_CREATE_HIP_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
TRITON_CREATE_HIP_EXCEPTION(no_device ,"no device");
TRITON_CREATE_HIP_EXCEPTION(invalid_device ,"invalid device");
TRITON_CREATE_HIP_EXCEPTION(invalid_image ,"invalid image");
TRITON_CREATE_HIP_EXCEPTION(invalid_context ,"invalid context");
TRITON_CREATE_HIP_EXCEPTION(context_already_current ,"context already current");
TRITON_CREATE_HIP_EXCEPTION(map_failed ,"map failed");
TRITON_CREATE_HIP_EXCEPTION(unmap_failed ,"unmap failed");
TRITON_CREATE_HIP_EXCEPTION(array_is_mapped ,"array is mapped");
TRITON_CREATE_HIP_EXCEPTION(already_mapped ,"already mapped");
TRITON_CREATE_HIP_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
TRITON_CREATE_HIP_EXCEPTION(already_acquired ,"already acquired");
TRITON_CREATE_HIP_EXCEPTION(not_mapped ,"not mapped");
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_array ,"not mapped as array");
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
TRITON_CREATE_HIP_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
TRITON_CREATE_HIP_EXCEPTION(unsupported_limit ,"unsupported limit");
TRITON_CREATE_HIP_EXCEPTION(context_already_in_use ,"context already in use");
TRITON_CREATE_HIP_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
TRITON_CREATE_HIP_EXCEPTION(invalid_ptx ,"invalid ptx");
TRITON_CREATE_HIP_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
TRITON_CREATE_HIP_EXCEPTION(invalid_source ,"invalid source");
TRITON_CREATE_HIP_EXCEPTION(file_not_found ,"file not found");
TRITON_CREATE_HIP_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
TRITON_CREATE_HIP_EXCEPTION(shared_object_init_failed ,"shared object init failed");
TRITON_CREATE_HIP_EXCEPTION(operating_system ,"operating system");
TRITON_CREATE_HIP_EXCEPTION(invalid_handle ,"invalid handle");
TRITON_CREATE_HIP_EXCEPTION(not_found ,"not found");
TRITON_CREATE_HIP_EXCEPTION(not_ready ,"not ready");
TRITON_CREATE_HIP_EXCEPTION(illegal_address ,"illegal address");
TRITON_CREATE_HIP_EXCEPTION(launch_out_of_resources ,"launch out of resources");
TRITON_CREATE_HIP_EXCEPTION(launch_timeout ,"launch timeout");
TRITON_CREATE_HIP_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
TRITON_CREATE_HIP_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
TRITON_CREATE_HIP_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
TRITON_CREATE_HIP_EXCEPTION(primary_context_active ,"primary context active");
TRITON_CREATE_HIP_EXCEPTION(context_is_destroyed ,"context is destroyed");
TRITON_CREATE_HIP_EXCEPTION(assert_error ,"assert");
TRITON_CREATE_HIP_EXCEPTION(too_many_peers ,"too many peers");
TRITON_CREATE_HIP_EXCEPTION(host_memory_already_registered ,"host memory already registered");
TRITON_CREATE_HIP_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
TRITON_CREATE_HIP_EXCEPTION(hardware_stack_error ,"hardware stack error");
TRITON_CREATE_HIP_EXCEPTION(illegal_instruction ,"illegal instruction");
TRITON_CREATE_HIP_EXCEPTION(misaligned_address ,"misaligned address");
TRITON_CREATE_HIP_EXCEPTION(invalid_address_space ,"invalid address space");
TRITON_CREATE_HIP_EXCEPTION(invalid_pc ,"invalid pc");
TRITON_CREATE_HIP_EXCEPTION(launch_failed ,"launch failed");
TRITON_CREATE_HIP_EXCEPTION(not_permitted ,"not permitted");
TRITON_CREATE_HIP_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_HIP_EXCEPTION(invalid_symbol ,"invalid symbol");
TRITON_CREATE_HIP_EXCEPTION(unknown ,"unknown");
#undef TRITON_CREATE_CUDA_EXCEPTION
}
}
}
}
#endif

View File

@ -1,20 +0,0 @@
#include <string>
#include "triton/driver/dispatch.h"
namespace llvm{
class Module;
}
namespace triton{
namespace driver{
void init_llvm();
std::string path_to_ptxas(int& version);
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
hipModule_t amdgpu_to_hipmodule(const std::string& path);
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,288 +0,0 @@
/*
* @brief hipError_t
* @enum
* @ingroup Enumerations
*/
// Developer note - when updating these, update the hipErrorName and hipErrorString functions in
// NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path.
// Ignoring error-code return values from hip APIs is discouraged. On C++17,
// we can make that yield a warning
/*
* @brief hipError_t
* @enum
* @ingroup Enumerations
*/
// Developer note - when updating these, update the hipErrorName and hipErrorString functions in
// NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path.
#include <cstddef>
typedef enum hipError_t {
hipSuccess = 0, ///< Successful completion.
hipErrorInvalidValue = 1, ///< One or more of the parameters passed to the API call is NULL
///< or not in an acceptable range.
hipErrorOutOfMemory = 2,
// Deprecated
hipErrorMemoryAllocation = 2, ///< Memory allocation error.
hipErrorNotInitialized = 3,
// Deprecated
hipErrorInitializationError = 3,
hipErrorDeinitialized = 4,
hipErrorProfilerDisabled = 5,
hipErrorProfilerNotInitialized = 6,
hipErrorProfilerAlreadyStarted = 7,
hipErrorProfilerAlreadyStopped = 8,
hipErrorInvalidConfiguration = 9,
hipErrorInvalidPitchValue = 12,
hipErrorInvalidSymbol = 13,
hipErrorInvalidDevicePointer = 17, ///< Invalid Device Pointer
hipErrorInvalidMemcpyDirection = 21, ///< Invalid memory copy direction
hipErrorInsufficientDriver = 35,
hipErrorMissingConfiguration = 52,
hipErrorPriorLaunchFailure = 53,
hipErrorInvalidDeviceFunction = 98,
hipErrorNoDevice = 100, ///< Call to hipGetDeviceCount returned 0 devices
hipErrorInvalidDevice = 101, ///< DeviceID must be in range 0...#compute-devices.
hipErrorInvalidImage = 200,
hipErrorInvalidContext = 201, ///< Produced when input context is invalid.
hipErrorContextAlreadyCurrent = 202,
hipErrorMapFailed = 205,
// Deprecated
hipErrorMapBufferObjectFailed = 205, ///< Produced when the IPC memory attach failed from ROCr.
hipErrorUnmapFailed = 206,
hipErrorArrayIsMapped = 207,
hipErrorAlreadyMapped = 208,
hipErrorNoBinaryForGpu = 209,
hipErrorAlreadyAcquired = 210,
hipErrorNotMapped = 211,
hipErrorNotMappedAsArray = 212,
hipErrorNotMappedAsPointer = 213,
hipErrorECCNotCorrectable = 214,
hipErrorUnsupportedLimit = 215,
hipErrorContextAlreadyInUse = 216,
hipErrorPeerAccessUnsupported = 217,
hipErrorInvalidKernelFile = 218, ///< In CUDA DRV, it is CUDA_ERROR_INVALID_PTX
hipErrorInvalidGraphicsContext = 219,
hipErrorInvalidSource = 300,
hipErrorFileNotFound = 301,
hipErrorSharedObjectSymbolNotFound = 302,
hipErrorSharedObjectInitFailed = 303,
hipErrorOperatingSystem = 304,
hipErrorInvalidHandle = 400,
// Deprecated
hipErrorInvalidResourceHandle = 400, ///< Resource handle (hipEvent_t or hipStream_t) invalid.
hipErrorNotFound = 500,
hipErrorNotReady = 600, ///< Indicates that asynchronous operations enqueued earlier are not
///< ready. This is not actually an error, but is used to distinguish
///< from hipSuccess (which indicates completion). APIs that return
///< this error include hipEventQuery and hipStreamQuery.
hipErrorIllegalAddress = 700,
hipErrorLaunchOutOfResources = 701, ///< Out of resources error.
hipErrorLaunchTimeOut = 702,
hipErrorPeerAccessAlreadyEnabled =
704, ///< Peer access was already enabled from the current device.
hipErrorPeerAccessNotEnabled =
705, ///< Peer access was never enabled from the current device.
hipErrorSetOnActiveProcess = 708,
hipErrorAssert = 710, ///< Produced when the kernel calls assert.
hipErrorHostMemoryAlreadyRegistered =
712, ///< Produced when trying to lock a page-locked memory.
hipErrorHostMemoryNotRegistered =
713, ///< Produced when trying to unlock a non-page-locked memory.
hipErrorLaunchFailure =
719, ///< An exception occurred on the device while executing a kernel.
hipErrorCooperativeLaunchTooLarge =
720, ///< This error indicates that the number of blocks launched per grid for a kernel
///< that was launched via cooperative launch APIs exceeds the maximum number of
///< allowed blocks for the current device
hipErrorNotSupported = 801, ///< Produced when the hip API is not supported/implemented
hipErrorUnknown = 999, //< Unknown error.
// HSA Runtime Error Codes start here.
hipErrorRuntimeMemory = 1052, ///< HSA runtime memory call returned error. Typically not seen
///< in production systems.
hipErrorRuntimeOther = 1053, ///< HSA runtime call other than memory returned error. Typically
///< not seen in production systems.
hipErrorTbd ///< Marker that more error codes are needed.
} hipError_t;
typedef struct ihipCtx_t* hipCtx_t;
// Note many APIs also use integer deviceIds as an alternative to the device pointer:
typedef int hipDevice_t;
typedef enum hipDeviceP2PAttr {
hipDevP2PAttrPerformanceRank = 0,
hipDevP2PAttrAccessSupported,
hipDevP2PAttrNativeAtomicSupported,
hipDevP2PAttrHipArrayAccessSupported
} hipDeviceP2PAttr;
typedef struct ihipStream_t* hipStream_t;
#define hipIpcMemLazyEnablePeerAccess 0
#define HIP_IPC_HANDLE_SIZE 64
typedef struct hipIpcMemHandle_st {
char reserved[HIP_IPC_HANDLE_SIZE];
} hipIpcMemHandle_t;
typedef struct hipIpcEventHandle_st {
char reserved[HIP_IPC_HANDLE_SIZE];
} hipIpcEventHandle_t;
typedef struct ihipModule_t* hipModule_t;
typedef struct ihipModuleSymbol_t* hipFunction_t;
typedef struct hipFuncAttributes {
int binaryVersion;
int cacheModeCA;
size_t constSizeBytes;
size_t localSizeBytes;
int maxDynamicSharedSizeBytes;
int maxThreadsPerBlock;
int numRegs;
int preferredShmemCarveout;
int ptxVersion;
size_t sharedSizeBytes;
} hipFuncAttributes;
typedef struct ihipEvent_t* hipEvent_t;
/*
* @brief hipDeviceAttribute_t
* @enum
* @ingroup Enumerations
*/
typedef enum hipDeviceAttribute_t {
hipDeviceAttributeMaxThreadsPerBlock, ///< Maximum number of threads per block.
hipDeviceAttributeMaxBlockDimX, ///< Maximum x-dimension of a block.
hipDeviceAttributeMaxBlockDimY, ///< Maximum y-dimension of a block.
hipDeviceAttributeMaxBlockDimZ, ///< Maximum z-dimension of a block.
hipDeviceAttributeMaxGridDimX, ///< Maximum x-dimension of a grid.
hipDeviceAttributeMaxGridDimY, ///< Maximum y-dimension of a grid.
hipDeviceAttributeMaxGridDimZ, ///< Maximum z-dimension of a grid.
hipDeviceAttributeMaxSharedMemoryPerBlock, ///< Maximum shared memory available per block in
///< bytes.
hipDeviceAttributeTotalConstantMemory, ///< Constant memory size in bytes.
hipDeviceAttributeWarpSize, ///< Warp size in threads.
hipDeviceAttributeMaxRegistersPerBlock, ///< Maximum number of 32-bit registers available to a
///< thread block. This number is shared by all thread
///< blocks simultaneously resident on a
///< multiprocessor.
hipDeviceAttributeClockRate, ///< Peak clock frequency in kilohertz.
hipDeviceAttributeMemoryClockRate, ///< Peak memory clock frequency in kilohertz.
hipDeviceAttributeMemoryBusWidth, ///< Global memory bus width in bits.
hipDeviceAttributeMultiprocessorCount, ///< Number of multiprocessors on the device.
hipDeviceAttributeComputeMode, ///< Compute mode that device is currently in.
hipDeviceAttributeL2CacheSize, ///< Size of L2 cache in bytes. 0 if the device doesn't have L2
///< cache.
hipDeviceAttributeMaxThreadsPerMultiProcessor, ///< Maximum resident threads per
///< multiprocessor.
hipDeviceAttributeComputeCapabilityMajor, ///< Major compute capability version number.
hipDeviceAttributeComputeCapabilityMinor, ///< Minor compute capability version number.
hipDeviceAttributeConcurrentKernels, ///< Device can possibly execute multiple kernels
///< concurrently.
hipDeviceAttributePciBusId, ///< PCI Bus ID.
hipDeviceAttributePciDeviceId, ///< PCI Device ID.
hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, ///< Maximum Shared Memory Per
///< Multiprocessor.
hipDeviceAttributeIsMultiGpuBoard, ///< Multiple GPU devices.
hipDeviceAttributeIntegrated, ///< iGPU
hipDeviceAttributeCooperativeLaunch, ///< Support cooperative launch
hipDeviceAttributeCooperativeMultiDeviceLaunch, ///< Support cooperative launch on multiple devices
hipDeviceAttributeMaxTexture1DWidth, ///< Maximum number of elements in 1D images
hipDeviceAttributeMaxTexture2DWidth, ///< Maximum dimension width of 2D images in image elements
hipDeviceAttributeMaxTexture2DHeight, ///< Maximum dimension height of 2D images in image elements
hipDeviceAttributeMaxTexture3DWidth, ///< Maximum dimension width of 3D images in image elements
hipDeviceAttributeMaxTexture3DHeight, ///< Maximum dimensions height of 3D images in image elements
hipDeviceAttributeMaxTexture3DDepth, ///< Maximum dimensions depth of 3D images in image elements
hipDeviceAttributeHdpMemFlushCntl, ///< Address of the HDP_MEM_COHERENCY_FLUSH_CNTL register
hipDeviceAttributeHdpRegFlushCntl, ///< Address of the HDP_REG_COHERENCY_FLUSH_CNTL register
hipDeviceAttributeMaxPitch, ///< Maximum pitch in bytes allowed by memory copies
hipDeviceAttributeTextureAlignment, ///<Alignment requirement for textures
hipDeviceAttributeTexturePitchAlignment, ///<Pitch alignment requirement for 2D texture references bound to pitched memory;
hipDeviceAttributeKernelExecTimeout, ///<Run time limit for kernels executed on the device
hipDeviceAttributeCanMapHostMemory, ///<Device can map host memory into device address space
hipDeviceAttributeEccEnabled, ///<Device has ECC support enabled
hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc, ///< Supports cooperative launch on multiple
///devices with unmatched functions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim, ///< Supports cooperative launch on multiple
///devices with unmatched grid dimensions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim, ///< Supports cooperative launch on multiple
///devices with unmatched block dimensions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem, ///< Supports cooperative launch on multiple
///devices with unmatched shared memories
hipDeviceAttributeAsicRevision, ///< Revision of the GPU in this device
hipDeviceAttributeManagedMemory, ///< Device supports allocating managed memory on this system
hipDeviceAttributeDirectManagedMemAccessFromHost, ///< Host can directly access managed memory on
/// the device without migration
hipDeviceAttributeConcurrentManagedAccess, ///< Device can coherently access managed memory
/// concurrently with the CPU
hipDeviceAttributePageableMemoryAccess, ///< Device supports coherently accessing pageable memory
/// without calling hipHostRegister on it
hipDeviceAttributePageableMemoryAccessUsesHostPageTables, ///< Device accesses pageable memory via
/// the host's page tables
hipDeviceAttributeCanUseStreamWaitValue ///< '1' if Device supports hipStreamWaitValue32() and
///< hipStreamWaitValue64() , '0' otherwise.
} hipDeviceAttribute_t;
typedef void* hipDeviceptr_t;
/*
* @brief hipJitOption
* @enum
* @ingroup Enumerations
*/
typedef enum hipJitOption {
hipJitOptionMaxRegisters = 0,
hipJitOptionThreadsPerBlock,
hipJitOptionWallTime,
hipJitOptionInfoLogBuffer,
hipJitOptionInfoLogBufferSizeBytes,
hipJitOptionErrorLogBuffer,
hipJitOptionErrorLogBufferSizeBytes,
hipJitOptionOptimizationLevel,
hipJitOptionTargetFromContext,
hipJitOptionTarget,
hipJitOptionFallbackStrategy,
hipJitOptionGenerateDebugInfo,
hipJitOptionLogVerbose,
hipJitOptionGenerateLineInfo,
hipJitOptionCacheMode,
hipJitOptionSm3xOpt,
hipJitOptionFastCompile,
hipJitOptionNumOptions
} hipJitOption;
/**
* @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
*/
typedef enum hipFuncAttribute {
hipFuncAttributeMaxDynamicSharedMemorySize = 8,
hipFuncAttributePreferredSharedMemoryCarveout = 9,
hipFuncAttributeMax
} hipFuncAttribute;
/**
* @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
*/
typedef enum hipFuncCache_t {
hipFuncCachePreferNone, ///< no preference for shared memory or L1 (default)
hipFuncCachePreferShared, ///< prefer larger shared memory and smaller L1 cache
hipFuncCachePreferL1, ///< prefer larger L1 cache and smaller shared memory
hipFuncCachePreferEqual, ///< prefer equal size L1 cache and shared memory
} hipFuncCache_t;
#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void*)0x01)
#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void*)0x02)
#define HIP_LAUNCH_PARAM_END ((void*)0x03)

View File

@ -1,92 +0,0 @@
#pragma once
#ifndef _TRITON_IR_BASIC_BLOCK_H_
#define _TRITON_IR_BASIC_BLOCK_H_
#include <string>
#include <list>
#include "value.h"
#include "visitor.h"
namespace triton{
namespace ir{
class context;
class function;
class instruction;
/* Basic Block */
class basic_block: public value{
public:
// instruction iterator types
typedef std::list<instruction*> inst_list_t;
typedef inst_list_t::iterator iterator;
typedef inst_list_t::const_iterator const_iterator;
typedef inst_list_t::reverse_iterator reverse_iterator;
typedef inst_list_t::const_reverse_iterator const_reverse_iterator;
private:
// constructors
basic_block(context &ctx, const std::string &name, function *parent, basic_block *next);
public:
// accessors
function* get_parent() { return parent_; }
context& get_context() { return ctx_; }
// get iterator to first instruction that is not a phi
void replace_phi_uses_with(basic_block* before, basic_block* after);
iterator get_first_non_phi();
// get instruction list
inst_list_t &get_inst_list() { return inst_list_; }
const inst_list_t &get_inst_list() const { return inst_list_; }
void erase(instruction *i) { inst_list_.remove(i); }
// instruction iterator functions
inline iterator begin() { return inst_list_.begin(); }
inline const_iterator begin() const { return inst_list_.begin(); }
inline iterator end () { return inst_list_.end(); }
inline const_iterator end () const { return inst_list_.end(); }
inline reverse_iterator rbegin() { return inst_list_.rbegin(); }
inline const_reverse_iterator rbegin() const { return inst_list_.rbegin(); }
inline reverse_iterator rend () { return inst_list_.rend(); }
inline const_reverse_iterator rend () const { return inst_list_.rend(); }
inline size_t size() const { return inst_list_.size(); }
inline bool empty() const { return inst_list_.empty(); }
inline const instruction &front() const { return *inst_list_.front(); }
inline instruction &front() { return *inst_list_.front(); }
inline const instruction &back() const { return *inst_list_.back(); }
inline instruction &back() { return *inst_list_.back(); }
void append_instruction(ir::instruction* i);
// split
basic_block* split_before(ir::instruction* loc, const std::string& name);
// predecessors
std::vector<basic_block*> get_predecessors() const;
std::vector<basic_block*> get_successors() const;
// factory functions
static basic_block* create(context &ctx, const std::string &name, function *parent, basic_block *next = nullptr);
void print(std::ostream &os);
// visitor
void accept(visitor *v) { v->visit_basic_block(this); }
private:
context &ctx_;
std::string name_;
function *parent_;
std::vector<basic_block*> preds_;
std::vector<basic_block*> succs_;
inst_list_t inst_list_;
};
}
}
#endif

View File

@ -1,212 +0,0 @@
#pragma once
#ifndef _TRITON_IR_BUILDER_H_
#define _TRITON_IR_BUILDER_H_
#include <vector>
#include <string>
#include "instructions.h"
#include "basic_block.h"
#include "type.h"
namespace triton{
namespace ir{
class basic_block;
class value;
class type;
class constant_int;
class instruction;
class context;
class phi_node;
/* Builder */
class builder{
public:
typedef basic_block::iterator iterator;
public:
// Constructor
builder(context &ctx);
// Getters
// const context& get_context() const { return ctx_; }
context& get_context() { return ctx_; }
// Setters
void set_insert_point(iterator instr);
void set_insert_point(instruction* i);
void set_insert_point_after(instruction* i);
void set_insert_point(basic_block* block);
basic_block* get_insert_block() { return block_; }
iterator get_insert_point() { return insert_point_;}
// Constants
value *get_int1(bool val);
value *get_int32(uint32_t val);
value *get_int64(uint64_t val);
value *get_float16(float val);
value *get_float32(float val);
value *get_range(int32_t lo, int32_t hi);
// Types
type *get_void_ty();
type *get_int1_ty();
type *get_int8_ty();
type *get_int16_ty();
type *get_int32_ty();
type *get_int64_ty();
type *get_fp8_ty();
type *get_half_ty();
type *get_bf16_ty();
type *get_float_ty();
type *get_double_ty();
// Insert
template<typename InstTy>
InstTy* insert(InstTy *inst){
assert(block_);
block_->get_inst_list().insert(insert_point_, inst);
inst->set_parent(block_);
// for(ir::value* op: inst->ops())
// op->add_use(inst);
return inst;
}
// terminator instructions
value* create_br(basic_block *dest);
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
value* create_ret_void();
value* create_ret(value *ret);
// Dequantize instructions
value* create_dequantize(value *src, value *scale, value *shift, type *dest_ty);
// Cast instructions
value* create_bitcast(value *src, type *dest_ty);
value *create_cast(cast_op_t op, value *v, type *dst_ty);
value* create_int_to_ptr(value *src, type *dst_ty);
value* create_ptr_to_int(value *src, type *dst_ty);
value* create_si_to_fp(value *src, type *dst_ty);
value* create_ui_to_fp(value *src, type *dst_ty);
value* create_fp_to_si(value *src, type *dst_ty);
value* create_fp_to_ui(value *src, type *dst_ty);
value* create_fp_ext(value *src, type *dst_ty);
value* create_fp_trunc(value *src, type *dst_ty);
value* create_int_cast(value *src, type *dst_ty, bool is_signed);
value *create_downcast(value *arg);
// Call instruction
value* create_call(function* fn, const std::vector<value*>& args);
value* create_launch(function* fn, const std::vector<value*>& args, const std::vector<value*>& grid, value* num_warps);
// Phi instruction
phi_node* create_phi(type *ty, unsigned num_reserved);
// Binary instructions
value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, bool has_nuw, bool has_nsw);
value *create_fmul(value *lhs, value *rhs);
value *create_fdiv(value *lhs, value *rhs);
value *create_frem(value *lhs, value *rhs);
value *create_fadd(value *lhs, value *rhs);
value *create_fsub(value *lhs, value *rhs);
value *create_sdiv(value *lhs, value *rhs);
value *create_udiv(value *lhs, value *rhs);
value *create_srem(value *lhs, value *rhs);
value *create_urem(value *lhs, value *rhs);
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_lshr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_ashr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
// GEP
value *create_gep(value *ptr, const std::vector<value*>& idx_list);
// Comparison (int)
value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs);
value *create_icmpSLE(value *lhs, value *rhs);
value *create_icmpSLT(value *lhs, value *rhs);
value *create_icmpSGE(value *lhs, value *rhs);
value *create_icmpSGT(value *lhs, value *rhs);
value *create_icmpULE(value *lhs, value *rhs);
value *create_icmpULT(value *lhs, value *rhs);
value *create_icmpUGE(value *lhs, value *rhs);
value *create_icmpUGT(value *lhs, value *rhs);
value *create_icmpEQ(value *lhs, value *rhs);
value *create_icmpNE(value *lhs, value *rhs);
// Comparison (float)
value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs);
value *create_fcmpOLT(value *lhs, value *rhs);
value *create_fcmpOGT(value *lhs, value *rhs);
value *create_fcmpOLE(value *lhs, value *rhs);
value *create_fcmpOGE(value *lhs, value *rhs);
value *create_fcmpOEQ(value *lhs, value *rhs);
value *create_fcmpONE(value *lhs, value *rhs);
value *create_fcmpULT(value *lhs, value *rhs);
value *create_fcmpUGT(value *lhs, value *rhs);
value *create_fcmpULE(value *lhs, value *rhs);
value *create_fcmpUGE(value *lhs, value *rhs);
value *create_fcmpUEQ(value *lhs, value *rhs);
value *create_fcmpUNE(value *lhs, value *rhs);
// Logical
value *create_and(value *lhs, value *rhs);
value *create_xor(value *lhs, value *rhs);
value *create_or(value *lhs, value *rhs);
// Input/Output
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
value *create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction);
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
value *create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction);
// Struct instructions
value *create_insert_value(value* val, value *elt, size_t idx);
value *create_extract_value(value* val, size_t idx);
// Block instruction
value *create_splat(value *arg, const type::block_shapes_t &shapes);
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
value *create_cat(value *lhs, value *rhs);
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
// Atomic instruction
value *create_atomic_cas(value *ptr, value *cmp, value *val);
value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk);
value *create_atomic_max(value *ptr, value *val, value *msk);
value *create_atomic_umax(value *ptr, value *val, value *msk);
value *create_atomic_min(value *ptr, value *val, value *msk);
value *create_atomic_umin(value *ptr, value *val, value *msk);
value *create_atomic_fadd(value *ptr, value *val, value *msk);
value *create_atomic_add(value *ptr, value *val, value *msk);
value *create_atomic_and(value *ptr, value *val, value *msk);
value *create_atomic_or(value *ptr, value *val, value *msk);
value *create_atomic_xor(value *ptr, value *val, value *msk);
value *create_atomic_xchg(value *ptr, value *val, value *msk);
// Utilities
value *create_clock();
value *create_globaltimer();
// Extern instruction
value *create_extern_elementwise(const std::string &lib_name,
const std::string &lib_path,
const std::string &symbol_name,
const std::vector<value *> &args,
type *ret_ty);
// Built-in instruction
value *create_get_program_id(unsigned axis);
value *create_get_num_programs(unsigned axis);
value *create_exp(value* arg);
value *create_cos(value* arg);
value *create_sin(value* arg);
value *create_log(value* arg);
value *create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32);
value *create_trans(value *A, const std::vector<int> &perm = {});
value *create_sqrt(value *A);
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
value *create_select(value *pred, value *if_value, value *else_value);
// Intrinsics
// These have no place in the IR, and hopefully they can be removed at some point
value *create_umulhi(value* lhs, value* rhs);
value *create_copy_to_shared(value *arg);
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY);
value *create_copy_from_shared(value *arg);
value *create_barrier(const std::string &name = "");
value *create_async_wait(int N);
value *create_prefetch_s(value *arg, int inc);
private:
context &ctx_;
basic_block *block_;
iterator insert_point_;
};
}
}
#endif

View File

@ -1,113 +0,0 @@
#pragma once
#ifndef _TRITON_IR_CONSTANT_H_
#define _TRITON_IR_CONSTANT_H_
#include "enums.h"
#include "value.h"
#include <cassert>
#include "visitor.h"
namespace triton{
namespace ir{
class type;
class context;
/* Constant */
class constant: public user{
protected:
using user::user;
public:
static constant* get_all_ones_value(type *ty);
static constant* get_null_value(type *ty);
virtual std::string repr() const = 0;
};
/* Undef value */
class undef_value: public constant{
private:
undef_value(type *ty);
public:
static undef_value* get(type* ty);
std::string repr() const { return "undef"; }
void accept(visitor* vst) { vst->visit_undef_value(this); }
};
/* Constant int */
class constant_int: public constant{
protected:
constant_int(type *ty, uint64_t value);
public:
virtual uint64_t get_value() const { return value_; }
static constant_int *get(type *ty, uint64_t value);
std::string repr() const { return std::to_string(value_); }
void accept(visitor* vst) { vst->visit_constant_int(this); }
protected:
uint64_t value_;
};
/* Constant fp */
class constant_fp: public constant{
constant_fp(type *ty, double value);
public:
double get_value() { return value_; }
static constant* get_negative_zero(type *ty);
static constant* get_zero_value_for_negation(type *ty);
static constant* get(context &ctx, double v);
static constant* get(type *ty, double v);
std::string repr() const { return std::to_string(value_); }
void accept(visitor* vst) { vst->visit_constant_fp(this); }
private:
double value_;
};
/* Global Value */
class global_value: public constant {
public:
enum linkage_types_t {
external
};
public:
global_value(type *ty, unsigned num_ops,
linkage_types_t linkage, const std::string &name,
unsigned addr_space);
std::string repr() const { return get_name(); }
private:
linkage_types_t linkage_;
};
/* global object */
class global_object: public global_value {
public:
global_object(type *ty, unsigned num_ops,
linkage_types_t linkage, const std::string &name,
unsigned addr_space = 0);
std::string repr() const { return get_name(); }
};
/* global variable */
class alloc_const: public global_object {
public:
alloc_const(type *ty, constant_int *size,
const std::string &name = "");
std::string repr() const { return get_name(); }
void accept(visitor* vst) { vst->visit_alloc_const(this); }
};
}
}
#endif

View File

@ -1,29 +0,0 @@
#pragma once
#ifndef _TRITON_IR_CONTEXT_H_
#define _TRITON_IR_CONTEXT_H_
#include <memory>
#include "triton/ir/type.h"
namespace triton{
namespace ir{
class type;
class context_impl;
/* Context */
class context {
public:
context();
context(const context&) = delete;
context& operator=(const context&) = delete;
public:
std::shared_ptr<context_impl> p_impl;
};
}
}
#endif

View File

@ -1,47 +0,0 @@
#pragma once
#ifndef _TRITON_IR_CONTEXT_IMPL_H_
#define _TRITON_IR_CONTEXT_IMPL_H_
#include "triton/ir/type.h"
#include "triton/ir/constant.h"
#include <map>
#include <memory>
namespace triton{
namespace ir{
class context;
/* Context impl */
class context_impl {
public:
// constructors
context_impl(context &ctx);
public:
// non-numeric types
type void_ty, label_ty;
// floating point types
type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
// integer types
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
// Pointer types
std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
// Block types
std::map<std::pair<type*, type::block_shapes_t>, std::unique_ptr<block_type>> block_tys;
// Struct types
std::map<type::contained_tys_vec_t, struct_type*> struct_tys;
// Int constants
std::map<std::pair<type*, uint64_t>, std::unique_ptr<constant_int>> int_constants_;
// Float constants
std::map<std::pair<type*, double>, std::unique_ptr<constant_fp>> fp_constants_;
// undef values
std::map<type*, std::unique_ptr<undef_value>> uv_constants_;
};
}
}
#endif

View File

@ -1,187 +0,0 @@
#pragma once
#ifndef _TRITON_IR_ENUMS_H_
#define _TRITON_IR_ENUMS_H_
namespace triton{
namespace ir{
enum binary_op_t: unsigned int{
Add,
FAdd,
Sub,
FSub,
Mul,
FMul,
UDiv,
SDiv,
FDiv,
URem,
SRem,
FRem,
Shl,
LShr,
AShr,
And,
Or,
Xor
};
enum class atomic_rmw_op_t: unsigned int{
And,
Or,
Xor,
Add,
Max,
Min,
UMax,
UMin,
FAdd,
Xchg,
};
enum cast_op_t: unsigned int {
Trunc,
ZExt,
SExt,
FPTrunc,
FPExt,
UIToFP,
SIToFP,
FPToUI,
FPToSI,
PtrToInt,
IntToPtr,
BitCast,
AddrSpaceCast
};
enum cmp_pred_t: unsigned int {
FIRST_FCMP_PREDICATE,
FCMP_FALSE,
FCMP_OEQ,
FCMP_OGT,
FCMP_OGE,
FCMP_OLT,
FCMP_OLE,
FCMP_ONE,
FCMP_ORD,
FCMP_UNO,
FCMP_UEQ,
FCMP_UGT,
FCMP_UGE,
FCMP_ULT,
FCMP_ULE,
FCMP_UNE,
FCMP_TRUE,
LAST_FCMP_PREDICATE,
FIRST_ICMP_PREDICATE,
ICMP_EQ,
ICMP_NE,
ICMP_UGT,
ICMP_UGE,
ICMP_ULT,
ICMP_ULE,
ICMP_SGT,
ICMP_SGE,
ICMP_SLT,
ICMP_SLE,
LAST_ICMP_PREDICATE
};
enum value_id_t: unsigned {
/* ------------ *
INSTRUCTIONS
* ------------ */
INST_BEGIN,
// call
INST_CALL,
INST_LAUNCH,
// phi
INST_PHI,
// arithmetic
INST_BINOP,
INST_GETELEMENTPTR,
INST_SELECT,
INST_SQRT,
// cmp
INST_ICMP,
INST_FCMP,
// dequantize
INST_DEQUANTIZE,
// cast
INST_CAST_TRUNC,
INST_CAST_ZEXT,
INST_CAST_SEXT,
INST_CAST_FP_TRUNC,
INST_CAST_FP_EXT,
INST_CAST_UI_TO_FP,
INST_CAST_SI_TO_FP,
INST_CAST_FP_TO_UI,
INST_CAST_FP_TO_SI,
INST_CAST_PTR_TO_INT,
INST_CAST_INT_TO_PTR,
INST_CAST_BIT_CAST,
INST_CAST_ADDR_SPACE_CAST,
// terminators
INST_RETURN,
INST_COND_BRANCH,
INST_UNCOND_BRANCH,
// io
INST_UNMASKED_LOAD,
INST_MASKED_LOAD,
INST_MASKED_LOAD_ASYNC,
INST_UNMASKED_STORE,
INST_MASKED_STORE,
// struct
INST_EXTRACT_VALUE,
INST_INSERT_VALUE,
// retile
INST_RESHAPE,
INST_SPLAT,
INST_CAT,
INST_BROADCAST,
INST_DOWNCAST,
// builtin
INST_GET_PROGRAM_ID,
INST_GET_NUM_PROGRAMS,
// atomics
INST_ATOMIC_CAS,
INST_ATOMIC_EXCH,
INST_ATOMIC_RMW,
// math
INST_UMULHI,
INST_EXP,
INST_COS,
INST_SIN,
INST_LOG,
// extern
INST_EXTERN_ELEMENTWISE,
// array arithmetic
INST_TRANS,
INST_REDUCE,
INST_DOT,
// intrinsics
INST_COPY_TO_SHARED,
INST_COPY_FROM_SHARED,
INST_CVT_LAYOUT,
INST_CVT_SCANLINE,
INST_DECOALESCE,
INST_RECOALESCE,
INST_BARRIER,
INST_ASYNC_WAIT,
INST_MAKE_RANGE_DYN,
INST_MAKE_RANGE_STA,
INST_MAKE_RANGE,
INST_PREFETCH_S,
INST_GLOBALTIMER,
INST_CLOCK,
};
}
}
#endif

View File

@ -1,145 +0,0 @@
#pragma once
#ifndef _TRITON_IR_FUNCTION_H_
#define _TRITON_IR_FUNCTION_H_
#include <string>
#include <map>
#include "value.h"
#include "constant.h"
namespace triton{
namespace ir{
class function;
class function_type;
class module;
class basic_block;
/* Argument */
class argument: public value{
argument(type *ty, const std::string &name, function *parent, unsigned arg_no);
public:
static argument* create(type *ty, const std::string &name,
function *parent = nullptr, unsigned arg_no = 0);
function* get_parent() const;
unsigned get_arg_no() const;
void accept(visitor *v);
private:
function *parent_;
unsigned arg_no_;
};
/* Attribute */
enum attribute_kind_t {
readonly = 0,
writeonly,
noalias,
aligned,
multiple_of,
retune,
not_implemented
};
class attribute {
public:
attribute(attribute_kind_t kind, unsigned value = 0):
kind_(kind), value_(value){}
bool operator<(const attribute& other) const {
return std::make_pair(kind_, value_) < std::make_pair(other.kind_, other.value_);
}
attribute_kind_t get_kind() const {
return kind_;
}
unsigned get_value() const {
return value_;
}
bool is_llvm_attr() const {
return kind_ != multiple_of;
}
std::string repr() const {
switch(kind_){
case readonly: return ".readonly";
case writeonly: return ".writeonly";
case noalias: return ".noalias";
case aligned: return ".aligned(" + std::to_string(value_) + ")";
case multiple_of: return ".multipleof(" + std::to_string(value_) + ")";
case retune: return ".retunr";
default: break;
}
assert(false);
return "";
}
private:
attribute_kind_t kind_;
unsigned value_;
};
/* Function */
class function: public global_object{
typedef std::vector<argument*> args_t;
typedef args_t::iterator arg_iterator;
typedef args_t::const_iterator const_arg_iterator;
typedef std::vector<basic_block*> blocks_t;
typedef blocks_t::iterator block_iterator;
typedef blocks_t::const_iterator const_block_iterator;
typedef std::map<unsigned, std::set<attribute>> attr_map_t;
private:
function(function_type *ty, linkage_types_t linkage,
const std::string &name = "", module *parent = nullptr);
public:
// accessors
const args_t &args() const { return args_; }
function_type* get_fn_type() { return fn_ty_; }
const function_type* get_fn_type() const { return fn_ty_; }
module *get_parent() { return parent_; }
const module *get_parent() const { return parent_; }
// factory methods
static function *create(function_type *ty, linkage_types_t linkage,
const std::string &name, module *mod);
// blocks
blocks_t &blocks() { return blocks_; }
const blocks_t &blocks() const { return blocks_; }
void insert_block(basic_block* block, basic_block *next = nullptr);
// attributes
void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
const attr_map_t &attrs() { return attrs_; }
bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); }
std::set<attribute> get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
void set_is_kernel(bool new_val) { is_kernel_ = new_val; }
bool get_is_kernel() { return is_kernel_; }
void print(std::ostream &os);
// visitor
void accept(visitor *v) { v->visit_function(this); }
private:
module *parent_;
bool init_;
function_type *fn_ty_;
args_t args_;
blocks_t blocks_;
attr_map_t attrs_;
bool is_kernel_;
};
}
}
#endif

File diff suppressed because it is too large Load Diff

View File

@ -1,34 +0,0 @@
#pragma once
#ifndef _TRITON_IR_METADATA_H_
#define _TRITON_IR_METADATA_H_
#include <vector>
namespace triton{
namespace ir{
/* Metadata */
class metadata{
public:
enum kind_t{
multiple_of,
max_contiguous
};
private:
metadata(kind_t kind, std::vector<unsigned> value);
public:
static metadata* get(kind_t kind, std::vector<unsigned> value);
private:
kind_t kind_;
std::vector<unsigned> value_;
};
}
}
#endif

View File

@ -1,129 +0,0 @@
#pragma once
#ifndef _TRITON_IR_MODULE_H_
#define _TRITON_IR_MODULE_H_
#include <map>
#include <set>
#include <stack>
#include <string>
#include <functional>
#include "triton/ir/builder.h"
#include "triton/ir/metadata.h"
#include "triton/ir/context.h"
namespace triton{
namespace lang{
class iteration_statement;
class compound_statement;
}
namespace ir{
class basic_block;
class phi_node;
class value;
class context;
class function;
class attribute;
class function_type;
class constant;
class global_value;
class alloc_const;
class value_constructor {
typedef std::pair<std::string, basic_block*> val_key_t;
private:
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
value *try_remove_trivial_phis(ir::phi_node *&phi);
value *add_phi_operands(const std::string& name, phi_node *&phi);
value *get_value_recursive(const std::string& name, basic_block *block);
public:
value_constructor(builder &builder);
void set_value(const std::string& name, basic_block* block, value *x);
void set_value(const std::string& name, value* x);
const std::map<val_key_t, value*>& get_values() { return values_; }
void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
value *get_value(const std::string& name, basic_block* block);
value *get_value(const std::string& name);
void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }
// Seal block -- no more predecessors will be added
void seal_block(basic_block *block);
// Metadata
private:
ir::builder& builder_;
std::map<val_key_t, value*> values_;
std::map<std::string, type*> types_;
std::set<basic_block*> sealed_blocks_;
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
std::map<value*, value**> current_phi_;
};
/* Module */
class module {
typedef std::pair<std::string, basic_block*> val_key_t;
typedef std::pair<ir::metadata::kind_t, std::vector<unsigned>> md_pair_t;
friend class function;
public:
typedef std::map<std::string, global_value*> symbols_map_t;
typedef std::vector<function*> functions_list_t;
private:
void push_function(function *fn) { functions_.push_back(fn); }
public:
module(const std::string &name, builder &builder): name_(name), builder_(builder) {}
builder &get_builder() { return builder_; };
const std::string& get_name() { return name_; };
// Functions
const functions_list_t &get_function_list() const { return functions_; }
function *get_function(const std::string& name) {
if(symbols_.find(name) == symbols_.end())
throw std::runtime_error("function " + name + " is not declared");
return (function*)symbols_.at(name);
}
function *get_or_insert_function(const std::string &name, function_type *ty);
bool has_function(const std::string& name){
return symbols_.find(name) != symbols_.end();
}
void remove_function(ir::function* fn){
functions_.erase(std::remove(functions_.begin(), functions_.end(), fn), functions_.end());
}
void reset_ret_ty(const std::string& name, type* ty);
// Const allocation
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
// Register global
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
const std::map<std::string, ir::value*>& globals() const { return globals_; }
// Metadata
void print(std::ostream &os);
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
const std::map<std::string, md_pair_t> &get_metadatas() const { return metadatas_; }
private:
std::string name_;
builder &builder_;
functions_list_t functions_;
symbols_map_t symbols_;
std::vector<ir::alloc_const*> allocs_;
std::map<std::string, ir::value*> globals_;
std::map<std::string, md_pair_t> metadatas_;
};
}
}
#endif

View File

@ -1,22 +0,0 @@
#ifndef _TRITON_IR_PRINT_H_
#define _TRITON_IR_PRINT_H_
#include "builder.h"
namespace triton{
namespace ir{
class module;
class function;
class basic_block;
class instruction;
void print(module &mod, std::ostream& os);
void print(function &func, std::ostream& os);
void print(basic_block &bb, std::ostream& os);
void print(instruction &instr, std::ostream& os);
}
}
#endif

View File

@ -1,252 +0,0 @@
#pragma once
#ifndef _TRITON_IR_TYPE_H_
#define _TRITON_IR_TYPE_H_
#include <cassert>
#include <vector>
#include <string>
#include <stdexcept>
namespace triton{
namespace ir{
class context;
class value;
class integer_type;
class constant_int;
/* Type */
class type {
public:
typedef std::vector<unsigned> block_shapes_t;
typedef std::vector<type*> contained_tys_vec_t;
typedef contained_tys_vec_t::iterator ty_iterator;
typedef contained_tys_vec_t::const_iterator const_ty_iterator;
public:
enum id_t {
// primitive types
VoidTyID = 0, ///< type with no size
FP8TyID, ///< 8-bit floating point type (3 bits mantissa)
FP16TyID, ///< 16-bit floating point type (10 bits mantissa)
BF16TyID, ///< 16-bit floating point type (7 bits mantissa)
FP32TyID, ///< 32-bit floating point type
FP64TyID, ///< 64-bit floating point type
LabelTyID, ///< Labels
MetadataTyID, ///< Metadata
TokenTyID, ///< Token
// derived types
IntegerTyID, ///< Arbitrary bit width integers
FunctionTyID, ///< Functions
PointerTyID, ///< Pointers
StructTyID, ///< Struct
BlockTyID, ///< Block
};
public:
//constructors
type(context &ctx, id_t id) : ctx_(ctx), id_(id) { }
//destructor
virtual ~type(){}
// accessors
context &get_context() const { return ctx_; }
id_t get_type_id() const { return id_; }
// type attributes
unsigned get_fp_mantissa_width() const;
unsigned get_integer_bitwidth() const;
unsigned get_tile_bitwidth() const;
unsigned get_primitive_size_in_bits() const;
type *get_scalar_ty() const;
block_shapes_t get_block_shapes() const;
const size_t get_tile_rank() const;
const size_t get_tile_ranks1() const;
unsigned get_tile_num_elements() const;
type *get_tile_element_ty() const;
unsigned get_pointer_address_space() const;
type *get_pointer_element_ty() const;
unsigned get_struct_numel() const { return contained_tys_.size(); }
type *get_struct_type(unsigned int i) const { return contained_tys_[i]; }
// primitive predicates
bool is_void_ty() const { return id_ == VoidTyID; }
bool is_fp8_ty() const { return id_ == FP8TyID; }
bool is_fp16_ty() const { return id_ == FP16TyID; }
bool is_bf16_ty() const { return id_ == BF16TyID; }
bool is_fp32_ty() const { return id_ == FP32TyID; }
bool is_fp64_ty() const { return id_ == FP64TyID; }
bool is_label_ty() const { return id_ == LabelTyID;}
bool is_metadata_ty() const { return id_ == MetadataTyID; }
bool is_token_ty() const { return id_ == TokenTyID; }
bool is_integer_ty() const { return id_ == IntegerTyID; }
bool is_bool_ty() const { return is_integer_ty(1); }
bool is_pointer_ty() const { return id_ == PointerTyID; }
bool is_block_ty() const { return id_ == BlockTyID; }
bool is_struct_ty() const { return id_ == StructTyID; }
// Composite predicates
bool is_int_or_tileint_ty();
bool is_integer_ty(unsigned width) const;
bool is_floating_point_ty() const;
bool is_sized() const ;
// Factory methods
// primitive types
static type *get_void_ty(context &ctx);
static type *get_label_ty(context &ctx);
// half
static type *get_fp8_ty(context &ctx);
static type *get_fp16_ty(context &ctx);
static type *get_bf16_ty(context &ctx);
static type *get_fp32_ty(context &ctx);
static type *get_fp64_ty(context &ctx);
// integer types
static integer_type *get_int1_ty(context &ctx);
static integer_type *get_int8_ty(context &ctx);
static integer_type *get_int16_ty(context &ctx);
static integer_type *get_int32_ty(context &ctx);
static integer_type *get_int64_ty(context &ctx);
static integer_type *get_int128_ty(context &ctx);
// repr
std::string tile_repr() const {
std::string res = get_tile_element_ty()->repr();
auto shapes = get_block_shapes();
res += "<";
for(size_t i = 0; i < shapes.size(); i++){
if(i > 0)
res += ", ";
res += std::to_string(shapes[i]);
}
res+= ">";
return res;
}
std::string repr() const {
switch(id_) {
case VoidTyID: return "void";
case FP8TyID: return "fp8";
case BF16TyID: return "bf16";
case FP16TyID: return "f16";
case FP32TyID: return "f32";
case FP64TyID: return "f64";
case LabelTyID: return "label";
case MetadataTyID: return "md";
case TokenTyID: return "tok";
case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth());
case FunctionTyID: return "fn";
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
case StructTyID: return "struct";
case BlockTyID: return tile_repr();
default: break;
}
throw std::logic_error("unknown type id '" + std::to_string(id_) + "'");
};
private:
context &ctx_;
id_t id_;
protected:
contained_tys_vec_t contained_tys_;
};
class integer_type: public type {
friend class context_impl;
private:
// constructors
integer_type(context &ctx, unsigned bitwidth)
: type(ctx, IntegerTyID), bitwidth_(bitwidth) {}
public:
// accessors
unsigned get_bitwidth() const { return bitwidth_; }
// factory methods
static integer_type* get(context &ctx, unsigned width);
private:
unsigned bitwidth_;
};
class composite_type: public type{
protected:
using type::type;
public:
bool index_valid(value *idx) const;
type* get_type_at_index(value *idx) const;
};
class struct_type: public composite_type {
public:
struct_type(const contained_tys_vec_t& tys, bool is_packed);
unsigned get_num_types() const { return contained_tys_.size(); }
static struct_type* get(const contained_tys_vec_t& tys, bool is_packed);
private:
bool is_packed_;
};
class block_type: public composite_type {
private:
block_type(type *ty, const block_shapes_t &shapes);
static bool is_valid_elt_ty(type *ty);
public:
// accessors
const block_shapes_t& get_shapes() const { return shapes_; }
unsigned get_num_elements() const;
unsigned get_bitwidth() const;
// factory methods
static block_type* get(type *ty, const block_shapes_t &shapes);
static block_type* get_same_shapes(type *ty, type *ref);
private:
block_shapes_t shapes_;
};
class pointer_type: public type {
private:
pointer_type(type *ty, unsigned address_space);
static bool is_valid_elt_ty(type *ty);
public:
// accessors
unsigned get_address_space() const { return address_space_; }
type *get_element_ty() const { return contained_tys_[0]; }
// factory methods
static pointer_type* get(type *ty, unsigned address_space);
private:
unsigned address_space_;
};
class function_type: public type {
private:
function_type(type *ret_ty, const std::vector<type *> &param_tys);
public:
// accessors
unsigned get_num_params() const { return contained_tys_.size() - 1; }
const_ty_iterator params_begin() const { return contained_tys_.begin() + 1; }
const_ty_iterator params_end() const { return contained_tys_.end(); }
ty_iterator params_begin() { return contained_tys_.begin() + 1; }
ty_iterator params_end() { return contained_tys_.end(); }
type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); }
type* get_return_ty() const { return contained_tys_.at(0); }
void reset_ret_ty(type* ty) { contained_tys_[0] = ty;}
// factory methods
static function_type* get(type *ret_ty, const std::vector<type*>& param_tys);
};
}
}
#endif

View File

@ -1,31 +0,0 @@
#pragma once
#ifndef _TRITON_IR_CFG_H_
#define _TRITON_IR_CFG_H_
#include <vector>
#include <functional>
namespace triton{
namespace ir{
class module;
class function;
class basic_block;
class instruction;
class value;
class cfg {
public:
static std::vector<basic_block *> post_order(function* fn);
static std::vector<basic_block *> reverse_post_order(function* fn);
};
void for_each_instruction(ir::module& mod, const std::function<void(triton::ir::instruction*)> &fn);
void for_each_instruction_backward(module &mod, const std::function<void (instruction *)> &do_work);
void for_each_value(ir::module& mod, const std::function<void(triton::ir::value *)> &fn);
}
}
#endif

Some files were not shown because too many files have changed in this diff Show More