Merge branch 'master' into rcom52_fixes

This commit is contained in:
Michael Melesse
2022-10-17 17:53:48 +00:00
151 changed files with 20150 additions and 19097 deletions

View File

@@ -8,7 +8,7 @@ jobs:
Build-Documentation: Build-Documentation:
runs-on: self-hosted runs-on: [self-hosted, V100]
steps: steps:
@@ -18,24 +18,37 @@ jobs:
with: with:
ref: 'gh-pages' ref: 'gh-pages'
- name: Clear docs
run: |
rm -r /tmp/triton-docs
continue-on-error: true
- name: Checkout branch - name: Checkout branch
uses: actions/checkout@v1 uses: actions/checkout@v1
- name: Install Triton
run: |
alias python='python3'
cd python
pip3 install -e .
- name: Build docs - name: Build docs
run: | run: |
git fetch origin master:master
cd docs cd docs
make html sphinx-multiversion . _build/html/
- name: Publish docs - name: Publish docs
run: | run: |
git branch
# update docs
mkdir /tmp/triton-docs;
mv docs/_build/html/* /tmp/triton-docs/
git checkout gh-pages git checkout gh-pages
sh ./update-website.sh cp -r CNAME /tmp/triton-docs/
cp -r index.html /tmp/triton-docs/
cp -r .nojekyll /tmp/triton-docs/
rm -r *
cp -r /tmp/triton-docs/* .
# ln -s master/index.html .
# mv master docs
git add .
git commit -am "[GH-PAGES] Updated website"
# publish docs
eval `ssh-agent -s` eval `ssh-agent -s`
DISPLAY=:0 SSH_ASKPASS=~/.ssh/give_pass.sh ssh-add ${{ secrets.SSH_KEY }} <<< ${{ secrets.SSH_PASS }} DISPLAY=:0 SSH_ASKPASS=~/.ssh/give_pass.sh ssh-add ${{ secrets.SSH_KEY }} <<< ${{ secrets.SSH_PASS }}
git remote set-url origin git@github.com:openai/triton.git git remote set-url origin git@github.com:openai/triton.git

View File

@@ -5,14 +5,13 @@ on:
pull_request: pull_request:
branches: branches:
- master - master
- v2.0
jobs: jobs:
Integration-Tests: Integration-Tests:
runs-on: self-hosted runs-on: [self-hosted, V100]
steps: steps:
@@ -21,14 +20,23 @@ jobs:
- name: Clear cache - name: Clear cache
run: | run: |
rm -r /tmp/triton/ rm -r ~/.triton/
continue-on-error: true continue-on-error: true
- name: Install Triton - name: Install Triton
run: | run: |
alias python='python3' alias python='python3'
cd python cd python
pip3 install -e . pip3 install -e '.[tests]'
- name: Check imports
run: "isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )"
- name: Check style
run: "autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )"
- name: Flake8
run: "flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )"
- name: Unit tests - name: Unit tests
run: | run: |
@@ -44,4 +52,3 @@ jobs:
pytest -vs . pytest -vs .
sudo nvidia-smi -i 0 -rgc sudo nvidia-smi -i 0 -rgc
sudo nvidia-smi -i 0 -rmc sudo nvidia-smi -i 0 -rmc

View File

@@ -8,7 +8,7 @@ jobs:
Build-Wheels: Build-Wheels:
runs-on: self-hosted runs-on: [self-hosted, V100]
steps: steps:
@@ -18,7 +18,7 @@ jobs:
- name: Patch setup.py - name: Patch setup.py
run: | run: |
#sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py #sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
export LATEST_DATE=$(git show -s --format=%ci `git rev-parse HEAD` | cut -d ' ' -f 1 | sed 's/-//g') export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d' --format="%cd")
sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py
echo "" >> python/setup.cfg echo "" >> python/setup.cfg
echo "[build_ext]" >> python/setup.cfg echo "[build_ext]" >> python/setup.cfg

3
.gitignore vendored
View File

@@ -7,3 +7,6 @@ python/build/
python/triton.egg-info/ python/triton.egg-info/
python/triton/_C/libtriton.pyd python/triton/_C/libtriton.pyd
python/triton/_C/libtriton.so python/triton/_C/libtriton.so
.vscode
.vs

4
.isort.cfg Normal file
View File

@@ -0,0 +1,4 @@
[settings]
known_local_folder=triton
line_length=88
py_version=36

View File

@@ -33,6 +33,9 @@ endif()
# Compiler flags # Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
# Third-party
include_directories(${PYBIND11_INCLUDE_DIR})
if(WIN32) if(WIN32)
SET(BUILD_SHARED_LIBS OFF) SET(BUILD_SHARED_LIBS OFF)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/deps/dlfcn-win32/src) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/deps/dlfcn-win32/src)
@@ -175,7 +178,7 @@ target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
if(WIN32) if(WIN32)
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32 target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
else() else()
target_link_libraries(triton ${LLVM_LIBRARIES} z ${TERMINFO_LIBRARY}) target_link_libraries(triton ${LLVM_LIBRARIES} z)
endif() endif()

View File

@@ -1,6 +1,6 @@
/* /*
* Copyright 2018-2020 Philippe Tillet * Copyright 2018-2020 Philippe Tillet
* Copyright 2020-2021 OpenAI * Copyright 2020-2022 OpenAI
* *
* Permission is hereby granted, free of charge, to any person obtaining * Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files * a copy of this software and associated documentation files

View File

@@ -12,12 +12,27 @@
# Triton # Triton
This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs. This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment for expressing tensor math workloads that offers high flexibility, developer productivity and end to end performance.
The foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing this work if you use Triton! The foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing this work if you use Triton!
The [official documentation](https://triton-lang.org) contains installation instructions and tutorials. The [official documentation](https://triton-lang.org) contains installation instructions and tutorials.
# Quick Installation
You can install the latest stable release of Triton from pip:
```bash
pip install triton
```
Binary wheels are available for CPython 3.6-3.9 and PyPy 3.6-3.7.
And the latest nightly release:
```bash
pip install -U --pre triton
```
# Changelog # Changelog
Version 1.1 is out! New features include: Version 1.1 is out! New features include:

View File

@@ -25,7 +25,7 @@
# LLVM_VERSION_STRING - Full LLVM version string (e.g. 6.0.0svn). # LLVM_VERSION_STRING - Full LLVM version string (e.g. 6.0.0svn).
# LLVM_VERSION_BASE_STRING - Base LLVM version string without git/svn suffix (e.g. 6.0.0). # LLVM_VERSION_BASE_STRING - Base LLVM version string without git/svn suffix (e.g. 6.0.0).
# #
# Note: The variable names were chosen in conformance with the offical CMake # Note: The variable names were chosen in conformance with the official CMake
# guidelines, see ${CMAKE_ROOT}/Modules/readme.txt. # guidelines, see ${CMAKE_ROOT}/Modules/readme.txt.
# Try suffixed versions to pick up the newest LLVM install available on Debian # Try suffixed versions to pick up the newest LLVM install available on Debian

27
docs/_templates/versions.html vendored Normal file
View File

@@ -0,0 +1,27 @@
{%- if current_version %}
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Other Versions</span>
v: {{ current_version.name }}
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">
{%- if versions.tags %}
<dl>
<dt>Tags</dt>
{%- for item in versions.tags %}
<dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
{%- endfor %}
</dl>
{%- endif %}
{%- if versions.branches %}
<dl>
<dt>Branches</dt>
{%- for item in versions.branches %}
<dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
{%- endfor %}
</dl>
{%- endif %}
</div>
</div>
{%- endif %}

View File

@@ -34,15 +34,18 @@ def process_sig(app, what, name, obj, options, signature, return_annotation):
def setup(app): def setup(app):
"""Customize function args retrieving to get args under decorator.""" """Customize function args retrieving to get args under decorator."""
import sphinx import sphinx
import triton import os
app.connect("autodoc-process-signature", process_sig) app.connect("autodoc-process-signature", process_sig)
os.system("pip install -e ../python")
def forward_jit_fn(func): def forward_jit_fn(func):
old = func old = func
def wrapped(obj, **kwargs): def wrapped(obj, **kwargs):
if isinstance(obj, triton.code_gen.JITFunction): import triton
if isinstance(obj, triton.runtime.JITFunction):
obj = obj.fn obj = obj.fn
return old(obj) return old(obj)
@@ -52,7 +55,8 @@ def setup(app):
old_documenter = sphinx.ext.autosummary.get_documenter old_documenter = sphinx.ext.autosummary.get_documenter
def documenter(app, obj, parent): def documenter(app, obj, parent):
if isinstance(obj, triton.code_gen.JITFunction): import triton
if isinstance(obj, triton.runtime.JITFunction):
obj = obj.fn obj = obj.fn
return old_documenter(app, obj, parent) return old_documenter(app, obj, parent)
@@ -66,9 +70,17 @@ def setup(app):
import sys import sys
import os import os
sys.path.insert(0, os.path.abspath('../python/')) sys.path.insert(0, os.path.abspath('../python/'))
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon'] extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon', 'sphinx_multiversion']
autosummary_generate = True autosummary_generate = True
# versioning config
smv_tag_whitelist = r'^(v1.1.2)$'
smv_branch_whitelist = r'^master$'
smv_remote_whitelist = None
smv_released_pattern = r'^tags/.*$'
smv_outputdir_format = '{ref.name}'
smv_prefer_remote_refs = False
# Sphinx gallery # Sphinx gallery
extensions += ['sphinx_gallery.gen_gallery'] extensions += ['sphinx_gallery.gen_gallery']
from sphinx_gallery.sorting import FileNameSortKey from sphinx_gallery.sorting import FileNameSortKey
@@ -85,6 +97,11 @@ sphinx_gallery_conf = {
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates'] templates_path = ['_templates']
html_sidebars = {
'**': [
'_templates/versions.html',
],
}
# The suffix(es) of source filenames. # The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string: # You can specify multiple suffix as a list of string:

View File

@@ -34,17 +34,19 @@ You can install the Python package from source by running the following commands
.. code-block:: bash .. code-block:: bash
git clone https://github.com/openai/triton.git; git clone https://github.com/openai/triton.git;
cd triton/python; cd triton;
git submodule update --init --recursive;
cd python;
pip install cmake; # build time dependency pip install cmake; # build time dependency
pip install -e . pip install -e .
Note that, if llvm-11 is not present on your system, the setup.py script will download the official LLVM11 static libraries link against that. Note that, if llvm-11 is not present on your system and you are on linux, the setup.py script will download the official LLVM11 static libraries link against that. For windows users, LLVM must be installed and configured in PATH.
You can then test your installation by running the unit tests: You can then test your installation by running the unit tests:
.. code-block:: bash .. code-block:: bash
pip install -r requirements-test.txt pip install -e '.[tests]'
pytest -vs test/unit/ pytest -vs test/unit/
and the benchmarks and the benchmarks

View File

@@ -14,7 +14,7 @@ Traditional compilers typically rely on intermediate representations, such as LL
Program Representation Program Representation
+++++++++++++++++++++++ +++++++++++++++++++++++
Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample litterature on linear and integer programming. Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample literature on linear and integer programming.
.. table:: .. table::
:widths: 50 50 :widths: 50 50

View File

@@ -106,9 +106,13 @@ Atomic Ops
:nosignatures: :nosignatures:
atomic_cas atomic_cas
atomic_xchg
atomic_add atomic_add
atomic_max atomic_max
atomic_min atomic_min
atomic_and
atomic_or
atomic_xor
Comparison ops Comparison ops

View File

@@ -12,7 +12,9 @@ namespace ir {
class phi_node; class phi_node;
class splat_inst; class splat_inst;
class cast_inst; class cast_inst;
class cmp_inst;
class reshape_inst; class reshape_inst;
class dequantize_inst;
class broadcast_inst; class broadcast_inst;
class binary_operator; class binary_operator;
class getelementptr_inst; class getelementptr_inst;
@@ -33,8 +35,10 @@ private:
std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x); std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x); std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x); std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
std::vector<cst_info> populate_is_constant_dequantize(ir::dequantize_inst* x);
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x); std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x); std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
std::vector<cst_info> populate_is_constant_cmp(ir::cmp_inst* x);
std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x); std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
std::vector<cst_info> populate_is_constant_default(ir::value* v); std::vector<cst_info> populate_is_constant_default(ir::value* v);
std::vector<cst_info> populate_is_constant(ir::value *v); std::vector<cst_info> populate_is_constant(ir::value *v);
@@ -42,6 +46,7 @@ private:
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x); std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x); std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x); std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_max_contiguous_dequantize(ir::dequantize_inst* x);
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x); std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x); std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x); std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
@@ -52,6 +57,7 @@ private:
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x); std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x); std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x); std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_starting_multiple_dequantize(ir::dequantize_inst* x);
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x); std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x); std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x); std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
@@ -65,6 +71,7 @@ public:
void run(ir::module &mod); void run(ir::module &mod);
unsigned get(ir::value* v, unsigned ax) const; unsigned get(ir::value* v, unsigned ax) const;
std::vector<unsigned> contiguous(ir::value* v) const; std::vector<unsigned> contiguous(ir::value* v) const;
std::vector<cst_info> get_cst_info(ir::value* v) const;
private: private:
std::map<ir::value*, std::vector<cst_info>> is_constant_; std::map<ir::value*, std::vector<cst_info>> is_constant_;

View File

@@ -25,6 +25,7 @@ private:
void update_graph_reduce(ir::instruction *i); void update_graph_reduce(ir::instruction *i);
void update_graph_reshape(ir::instruction *i); void update_graph_reshape(ir::instruction *i);
void update_graph_trans(ir::instruction *i); void update_graph_trans(ir::instruction *i);
void update_graph_dequantize(ir::instruction *i);
void update_graph_broadcast(ir::instruction *i); void update_graph_broadcast(ir::instruction *i);
void update_graph_dot(ir::instruction *i); void update_graph_dot(ir::instruction *i);
void update_graph_elementwise(ir::instruction *i, void update_graph_elementwise(ir::instruction *i,

View File

@@ -103,12 +103,70 @@ public:
int shape_per_cta(size_t k) { return shape_per_cta_.at(k); } int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; } int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
virtual int contig_per_thread(size_t k) = 0;
protected: protected:
std::vector<int> shape_per_cta_; std::vector<int> shape_per_cta_;
}; };
class mma_layout: public distributed_layout { class mma_layout: public distributed_layout {
public:
enum TensorCoreType : uint8_t {
// floating-point tensor core instr
FP32_FP16_FP16_FP32 = 0, // default
FP32_BF16_BF16_FP32,
FP32_TF32_TF32_FP32,
// integer tensor core instr
INT32_INT1_INT1_INT32, // Not implemented
INT32_INT4_INT4_INT32, // Not implemented
INT32_INT8_INT8_INT32, // Not implemented
//
NOT_APPLICABLE,
};
// Used on nvidia GPUs with sm >= 80
inline static const std::map<TensorCoreType, std::vector<int>> mma_instr_shape_ = {
{FP32_FP16_FP16_FP32, {16, 8, 16}},
{FP32_BF16_BF16_FP32, {16, 8, 16}},
{FP32_TF32_TF32_FP32, {16, 8, 8}},
{INT32_INT1_INT1_INT32, {16, 8, 256}},
{INT32_INT4_INT4_INT32, {16, 8, 64}},
{INT32_INT8_INT8_INT32, {16, 8, 32}},
};
// shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices)
inline static const std::map<TensorCoreType, std::vector<int>> mma_mat_shape_ = {
{FP32_FP16_FP16_FP32, {8, 8, 8}},
{FP32_BF16_BF16_FP32, {8, 8, 8}},
{FP32_TF32_TF32_FP32, {8, 8, 4}},
{INT32_INT1_INT1_INT32, {8, 8, 64}},
{INT32_INT4_INT4_INT32, {8, 8, 32}},
{INT32_INT8_INT8_INT32, {8, 8, 16}},
};
inline static const std::map<TensorCoreType, std::string> mma_instr_ptx_ = {
{FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
{FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
{FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},
{INT32_INT1_INT1_INT32, "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"},
{INT32_INT4_INT4_INT32, "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"},
{INT32_INT8_INT8_INT32, "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
};
// vector length per ldmatrix (16*8/elelment_size_in_bits)
inline static const std::map<TensorCoreType, int> mma_instr_vec_ = {
{FP32_FP16_FP16_FP32, 8},
{FP32_BF16_BF16_FP32, 8},
{FP32_TF32_TF32_FP32, 4},
{INT32_INT1_INT1_INT32, 128},
{INT32_INT4_INT4_INT32, 32},
{INT32_INT8_INT8_INT32, 16},
};
public: public:
mma_layout(size_t num_warps, mma_layout(size_t num_warps,
const std::vector<int>& axes, const std::vector<int>& axes,
@@ -116,13 +174,25 @@ public:
const std::vector<ir::value *> &values, const std::vector<ir::value *> &values,
analysis::align* align, target *tgt, analysis::align* align, target *tgt,
shared_layout* layout_a, shared_layout* layout_a,
shared_layout* layout_b); shared_layout* layout_b,
ir::value *dot);
void accept(layout_visitor* vst) { vst->visit_layout_mma(this); } void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
// accessor // accessor
int fpw(size_t k) { return fpw_.at(k); } int fpw(size_t k) { return fpw_.at(k); }
int wpt(size_t k) { return wpt_.at(k); } int wpt(size_t k) { return wpt_.at(k); }
int spw(size_t k) { return spw_.at(k); } int spw(size_t k) { return spw_.at(k); }
int rep(size_t k) { return rep_.at(k); } int rep(size_t k) { return rep_.at(k); }
int contig_per_thread(size_t k) { return contig_per_thread_.at(k); }
// helpers for generator.cc
std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); }
std::vector<int> get_mma_instr_shape() const { return mma_instr_shape_.at(tensor_core_type_); }
std::vector<int> get_mma_mat_shape() const { return mma_mat_shape_.at(tensor_core_type_); }
int get_vec_a() const { return mma_instr_vec_.at(tensor_core_type_); }
int get_vec_b() const { return mma_instr_vec_.at(tensor_core_type_); }
// setter
void set_tensor_core_type(TensorCoreType type) { tensor_core_type_ = type; }
private: private:
// fragment per warp // fragment per warp
@@ -135,6 +205,10 @@ private:
std::vector<int> spt_; std::vector<int> spt_;
// repetitions // repetitions
std::vector<int> rep_; std::vector<int> rep_;
// contiguous per thread
std::vector<int> contig_per_thread_;
TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32;
}; };
class scanline_layout: public distributed_layout { class scanline_layout: public distributed_layout {
@@ -149,7 +223,9 @@ public:
// accessor // accessor
int mts(size_t k) { return mts_.at(k); } int mts(size_t k) { return mts_.at(k); }
int nts(size_t k) { return nts_.at(k); } int nts(size_t k) { return nts_.at(k); }
int contig_per_thread(size_t k) { return nts_.at(k); }
int per_thread(size_t k) { return contig_per_thread(k) * shape_[k] / shape_per_cta(k);}
private: private:
// micro tile size. The size of a tile held by a thread block. // micro tile size. The size of a tile held by a thread block.
std::vector<int> mts_; std::vector<int> mts_;
@@ -170,7 +246,7 @@ struct N_buffer_info_t {
std::map<ir::value*, int> firsts_idx; std::map<ir::value*, int> firsts_idx;
}; };
// abstract for dot and coresponding smem values // abstract for dot and corresponding smem values
class shared_layout: public data_layout { class shared_layout: public data_layout {
private: private:
static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator); static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
@@ -183,7 +259,8 @@ public:
const std::vector<unsigned>& shapes, const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values_, const std::vector<ir::value *> &values_,
ir::type *ty, ir::type *ty,
analysis::align* align); analysis::align* align, target *tgt,
bool is_tmp = false);
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); } void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
// accessors // accessors
size_t get_size() { return size_; } size_t get_size() { return size_; }
@@ -198,7 +275,10 @@ public:
ir::value* hmma_dot_b() { return hmma_dot_b_; } ir::value* hmma_dot_b() { return hmma_dot_b_; }
void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; } void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
int get_mma_vec() { return mma_vec_;} int get_mma_vec() { return mma_vec_;}
int get_mma_strided() { return mma_strided_; }
bool allow_swizzle() const { return allow_swizzle_; }
data_layout* get_arg_layout() { return arg_layout_; } data_layout* get_arg_layout() { return arg_layout_; }
bool is_tmp() const { return is_tmp_; }
private: private:
size_t size_; size_t size_;
@@ -210,6 +290,10 @@ private:
ir::value* hmma_dot_b_; ir::value* hmma_dot_b_;
data_layout* arg_layout_; data_layout* arg_layout_;
int mma_vec_; int mma_vec_;
int mma_strided_;
bool allow_swizzle_ = true;
target *tgt_;
bool is_tmp_;
}; };
@@ -228,13 +312,20 @@ private:
void create(size_t id, const std::vector<ir::value*>& values); void create(size_t id, const std::vector<ir::value*>& values);
public: void create_tmp_layout(size_t id, data_layout* arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
ir::instruction* i,
bool is_index = false);
public:
// constructor // constructor
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt); layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
// accessors // accessors
unsigned layout_of(ir::value *value) const { return groups_.at(value); } unsigned layout_of(ir::value *value) const { return groups_.at(value); }
bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); } bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
bool has(size_t id) { return layouts_.find(id) != layouts_.end(); }
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); } const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
size_t num_layouts() const { return values_.size();} size_t num_layouts() const { return values_.size();}
data_layout* get(size_t id) { return layouts_.at(id); } data_layout* get(size_t id) { return layouts_.at(id); }
@@ -242,7 +333,19 @@ public:
std::map<size_t, data_layout*> &get_all() { return layouts_; } std::map<size_t, data_layout*> &get_all() { return layouts_; }
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); } bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
int tmp(ir::value* i) { return tmp_.at(i);} int tmp(ir::value* i) { return tmp_.at(i);}
int has_tmp_index(ir::value* i) { return tmp_index_.find(i) != tmp_index_.end(); }
int tmp_index(ir::value* i) { return tmp_index_.at(i);}
void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; } void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
// layout checkers
bool is_scanline(ir::instruction* i);
bool is_coalesced_scanline(ir::instruction* i);
bool is_mma(ir::instruction* i);
bool is_a100_mma(ir::instruction* i);
// execution // execution
void run(ir::module &mod); void run(ir::module &mod);
@@ -256,6 +359,7 @@ private:
std::map<size_t, std::vector<ir::value*>> values_; std::map<size_t, std::vector<ir::value*>> values_;
std::map<size_t, data_layout*> layouts_; std::map<size_t, data_layout*> layouts_;
std::map<ir::value*, size_t> tmp_; std::map<ir::value*, size_t> tmp_;
std::map<ir::value*, size_t> tmp_index_;
}; };
} }

View File

@@ -1,12 +1,14 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H #ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H #define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#include <map>
#include <set>
#include <vector>
#include "triton/codegen/analysis/layout.h" #include "triton/codegen/analysis/layout.h"
#include "triton/tools/graph.h" #include "triton/tools/graph.h"
#include "llvm/ADT/MapVector.h"
#include <set>
#include <vector>
namespace triton{ namespace triton{
namespace ir{ namespace ir{
@@ -42,14 +44,14 @@ struct segment {
class liveness { class liveness {
private: private:
typedef std::map<shared_layout*, segment> intervals_map_t; typedef llvm::MapVector<shared_layout*, segment> intervals_map_t;
public: public:
// constructor // constructor
liveness(layouts *l): layouts_(l){ } liveness(layouts *l): layouts_(l){ }
// accessors // accessors
const intervals_map_t& get() const { return intervals_; } const intervals_map_t& get() const { return intervals_; }
segment get(shared_layout* v) const { return intervals_.at(v); } segment get(shared_layout* v) const { return intervals_.lookup(v); }
// run // run
void run(ir::module &mod); void run(ir::module &mod);

View File

@@ -0,0 +1,90 @@
#ifndef _TRITON_CODE_GEN_EXTERN_LIB_H_
#define _TRITON_CODE_GEN_EXTERN_LIB_H_
#include <memory>
#include <string>
#include <map>
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
namespace triton {
namespace codegen {
///
/// \brief ExternLib is a class that represents a library of external functions.
///
class ExternLib {
public:
ExternLib(const std::string &name, const std::string &path)
: name_(name), path_(path) {}
virtual ~ExternLib() = default;
virtual const std::string &name() const { return name_; }
virtual const std::string &path() const { return path_; }
///
/// \brief Load the library and return the module.
///
std::unique_ptr<llvm::Module> load(llvm::LLVMContext &ctx);
///
/// \brief Link the module into the given module.
///
void link(std::unique_ptr<llvm::Module> &llvm,
std::unique_ptr<llvm::Module> &mod);
///
/// \brief Run load, link, and opt on the module.
///
virtual void install(llvm::LLVMContext &ctx,
std::unique_ptr<llvm::Module> &llvm) {
auto mod = load(ctx);
link(llvm, mod);
opt(ctx, llvm);
}
///
/// \brief Run opt on the module.
///
virtual void opt(llvm::LLVMContext &ctx,
std::unique_ptr<llvm::Module> &llvm) = 0;
private:
std::string name_;
std::string path_;
};
///
/// \brief ExternLibMap is a map of ExternLibs from their names to their paths.
///
typedef std::map<std::string, std::unique_ptr<ExternLib>> ExternLibMap;
///
/// \brief Concrete class for NVIDIA's libdevice library.
///
class LibDevice final : public ExternLib {
public:
LibDevice(const std::string &name, const std::string &path)
: ExternLib(name, path) {}
virtual ~LibDevice() = default;
virtual void opt(llvm::LLVMContext &ctx,
std::unique_ptr<llvm::Module> &llvm) override;
};
///
/// \brief Create an ExternLib instance based on the name and path.
///
std::unique_ptr<ExternLib> create_extern_lib(const std::string &lib_name,
const std::string &lib_path);
} // namespace codegen
} // namespace triton
#endif

View File

@@ -3,6 +3,7 @@
#include <memory> #include <memory>
#include "extern_lib.h"
namespace llvm{ namespace llvm{
class Module; class Module;
@@ -30,12 +31,10 @@ namespace codegen{
// TODO: // TODO:
// There should be a proper pass manager there! // There should be a proper pass manager there!
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
codegen::target* target, ir::module &ir, llvm::LLVMContext &ctx, codegen::target *target,
int sm, int num_warps, int num_warps, int num_stages, int &shared_static,
int num_stages, int &shared_static); const ExternLibMap &extern_libs);
} }
} }

View File

@@ -4,7 +4,9 @@
#define _TRITON_SELECTION_GENERATOR_H_ #define _TRITON_SELECTION_GENERATOR_H_
#include "triton/ir/visitor.h" #include "triton/ir/visitor.h"
#include "triton/ir/instructions.h"
#include "triton/codegen/analysis/layout.h" #include "triton/codegen/analysis/layout.h"
#include "triton/codegen/extern_lib.h"
#include <functional> #include <functional>
// forward // forward
@@ -24,6 +26,7 @@ namespace llvm{
class IRBuilder; class IRBuilder;
class ArrayType; class ArrayType;
class Function; class Function;
class StructType;
} }
namespace triton{ namespace triton{
@@ -114,8 +117,17 @@ private:
private: private:
Type *cvt(ir::type *ty); Type *cvt(ir::type *ty);
llvm::Attribute cvt(ir::attribute attr); llvm::Attribute cvt(ir::attribute attr);
void packed_type(ir::value* i);
void forward_declare(ir::function* fn);
Value *cast_shared_layout_ptr(analysis::data_layout *layout, Type *ty);
public: private:
typedef std::function<void(
std::pair<Value *, Value *> &acc, std::function<Value *()> load_value_fn,
std::function<Value *()> load_index_fn, bool is_first)>
acc_fn_t;
public:
generator(analysis::axes *a_axes, generator(analysis::axes *a_axes,
analysis::layouts *layouts, analysis::layouts *layouts,
analysis::align *alignment, analysis::align *alignment,
@@ -125,6 +137,8 @@ public:
unsigned num_warps); unsigned num_warps);
void visit_value(ir::value* v); void visit_value(ir::value* v);
void visit_call_inst(ir::call_inst*);
void visit_launch_inst(ir::launch_inst *);
void visit_phi_node(ir::phi_node*); void visit_phi_node(ir::phi_node*);
void visit_binary_operator(ir::binary_operator*); void visit_binary_operator(ir::binary_operator*);
void visit_getelementptr_inst(ir::getelementptr_inst*); void visit_getelementptr_inst(ir::getelementptr_inst*);
@@ -134,9 +148,19 @@ public:
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3); std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3); std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3); std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
Value* bf16_to_fp32(Value *in0); Value* bf16_to_fp32(Value *in0);
Value* fp32_to_bf16(Value *in0); Value* fp32_to_bf16(Value *in0);
std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int16_to_float16x8(
Value *in0, Value *scale_x512, Value *shift
);
std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int32_to_float16x8(
Value *in0, Value *scale_x512, Value *shift
);
std::tuple<Value*, Value*, Value*, Value*> int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift);
std::tuple<Value*, Value*> prepare_scale_shift(Value *scale, Value *shift);
void visit_dequantize_inst(ir::dequantize_inst*);
void visit_cast_inst(ir::cast_inst*); void visit_cast_inst(ir::cast_inst*);
void visit_return_inst(ir::return_inst*); void visit_return_inst(ir::return_inst*);
void visit_cond_branch_inst(ir::cond_branch_inst*); void visit_cond_branch_inst(ir::cond_branch_inst*);
@@ -148,6 +172,8 @@ public:
void visit_unmasked_store_inst(ir::unmasked_store_inst*); void visit_unmasked_store_inst(ir::unmasked_store_inst*);
void visit_masked_store_inst(ir::masked_store_inst*); void visit_masked_store_inst(ir::masked_store_inst*);
void visit_cat_inst(ir::cat_inst*); void visit_cat_inst(ir::cat_inst*);
void visit_extract_value_inst(ir::extract_value_inst *);
void visit_insert_value_inst(ir::insert_value_inst *);
void visit_reshape_inst(ir::reshape_inst*); void visit_reshape_inst(ir::reshape_inst*);
void visit_splat_inst(ir::splat_inst*); void visit_splat_inst(ir::splat_inst*);
void visit_broadcast_inst(ir::broadcast_inst*); void visit_broadcast_inst(ir::broadcast_inst*);
@@ -168,8 +194,8 @@ public:
void visit_trans_inst(ir::trans_inst*); void visit_trans_inst(ir::trans_inst*);
void visit_sqrt_inst(ir::sqrt_inst*); void visit_sqrt_inst(ir::sqrt_inst*);
Value* shfl_sync(Value* acc, int32_t i); Value* shfl_sync(Value* acc, int32_t i);
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*); void visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*); void visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
void visit_reduce_inst(ir::reduce_inst*); void visit_reduce_inst(ir::reduce_inst*);
void visit_select_inst(ir::select_inst*); void visit_select_inst(ir::select_inst*);
void visit_layout_convert(ir::value *out, ir::value *in); void visit_layout_convert(ir::value *out, ir::value *in);
@@ -182,6 +208,9 @@ public:
void visit_async_wait_inst(ir::async_wait_inst*); void visit_async_wait_inst(ir::async_wait_inst*);
// void visit_make_range_dyn(ir::make_range_dyn*); // void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range(ir::make_range*); void visit_make_range(ir::make_range*);
void visit_clock_inst(ir::clock_inst*);
void visit_globaltimer_inst(ir::globaltimer_inst*);
void visit_extern_elementwise_inst(ir::extern_elementwise_inst*);
// void visit_make_range_sta(ir::make_range_sta*); // void visit_make_range_sta(ir::make_range_sta*);
void visit_undef_value(ir::undef_value*); void visit_undef_value(ir::undef_value*);
void visit_constant_int(ir::constant_int*); void visit_constant_int(ir::constant_int*);
@@ -197,12 +226,21 @@ public:
void visit_layout_scanline(analysis::scanline_layout*); void visit_layout_scanline(analysis::scanline_layout*);
void visit_layout_shared(analysis::shared_layout*); void visit_layout_shared(analysis::shared_layout*);
// Add a new external library based on given name and path if it doesn't exist
void add_extern_lib(const std::string &lib_name, const std::string &lib_path);
private: // Get all external libraries
const ExternLibMap &get_extern_lib_map() {
return extern_lib_map_;
}
private:
LLVMContext *ctx_; LLVMContext *ctx_;
Builder* builder_; Builder* builder_;
Module *mod_; Module *mod_;
std::map<std::string, std::unique_ptr<ExternLib>> extern_lib_map_;
analysis::axes *a_axes_; analysis::axes *a_axes_;
analysis::swizzle *swizzle_; analysis::swizzle *swizzle_;
std::map<unsigned, distributed_axis> axes_; std::map<unsigned, distributed_axis> axes_;
@@ -239,6 +277,7 @@ private:
/// triton bb -> llvm bb /// triton bb -> llvm bb
std::map<ir::value*, BasicBlock *> bbs_; std::map<ir::value*, BasicBlock *> bbs_;
std::map<ir::value*, std::vector<int>> ords_; std::map<ir::value*, std::vector<int>> ords_;
std::map<ir::value*, Function*> fns_;
// helper for creating llvm values // helper for creating llvm values
adder add; adder add;
@@ -250,6 +289,9 @@ private:
/// Record prefetch instrs that needs to be moved /// Record prefetch instrs that needs to be moved
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_; std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
// Eviction policies
std::map<ir::load_inst::EVICTION_POLICY, Value*> policies_;
}; };
} }

View File

@@ -32,11 +32,12 @@ private:
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen); ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
public: public:
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts); coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts, bool has_sm80);
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder); triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
void run(ir::module &mod); void run(ir::module &mod);
private: private:
bool has_sm80_;
analysis::align* align_; analysis::align* align_;
analysis::layouts* layout_; analysis::layouts* layout_;
}; };

View File

@@ -15,18 +15,26 @@ namespace ir {
} }
namespace codegen{ namespace codegen{
namespace analysis{
class layouts;
}
namespace transform{ namespace transform{
class cts { class cts {
private: private:
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared); bool is_shmem_op(ir::instruction* i, int op);
bool is_shmem_res(ir::value* i);
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*,ir::value*>& copies);
public: public:
cts(bool use_async = false): use_async_(use_async) {} cts(analysis::layouts* layouts, bool has_sm80 = false): layouts_(layouts), has_sm80_(has_sm80) {}
void run(ir::module &mod); void run(ir::module &mod);
private: private:
bool use_async_; bool has_sm80_;
analysis::layouts* layouts_;
}; };
} }

View File

@@ -0,0 +1,31 @@
#pragma once
#include <list>
namespace triton {
namespace ir {
class module;
class function;
class call_inst;
class builder;
}
namespace codegen{
namespace transform{
struct fncmp {
bool operator()(ir::function* x, ir::function* y) const;
};
class inliner {
public:
inliner() {}
void do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder, std::list<ir::call_inst*>& callsites);
void run(ir::module &mod);
};
}
}
}

View File

@@ -30,6 +30,9 @@ private:
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D); bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot(ir::instruction *value, ir::builder& builder); bool rewrite_dot(ir::instruction *value, ir::builder& builder);
bool rewrite_mult(ir::instruction *value, ir::builder& builder); bool rewrite_mult(ir::instruction *value, ir::builder& builder);
bool rewrite_insert_extract(ir::instruction *value, ir::builder& builder);
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder); bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder); bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder); bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);

View File

@@ -88,6 +88,7 @@ public:
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev); static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
static CUresult cuDeviceGetCount(int *count); static CUresult cuDeviceGetCount(int *count);
// link management // link management
static CUresult cuLinkAddFile_v2(CUlinkState state, CUjitInputType type, const char *path, unsigned int numOptions, CUjit_option *options, void **optionValues);
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues); static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut); static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut); static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
@@ -214,6 +215,7 @@ private:
static void* cuDeviceGetAttribute_; static void* cuDeviceGetAttribute_;
static void* cuDeviceGetCount_; static void* cuDeviceGetCount_;
// link management // link management
static void* cuLinkAddFile_v2_;
static void* cuLinkAddData_v2_; static void* cuLinkAddData_v2_;
static void* cuLinkCreate_v2_; static void* cuLinkCreate_v2_;
static void* cuLinkDestroy_; static void* cuLinkDestroy_;

View File

@@ -9,8 +9,9 @@ namespace triton{
namespace driver{ namespace driver{
void init_llvm(); void init_llvm();
std::string path_to_ptxas(int& version);
std::string llir_to_ptx(llvm::Module* module, int cc, int version); std::string llir_to_ptx(llvm::Module* module, int cc, int version);
std::string ptx_to_cubin(const std::string& ptx, int cc); std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
CUmodule ptx_to_cumodule(const std::string& ptx, int cc); CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc); std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
hipModule_t amdgpu_to_hipmodule(const std::string& path); hipModule_t amdgpu_to_hipmodule(const std::string& path);

5607
include/triton/external/CUDA/cuda.h vendored Executable file → Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
#pragma once #pragma once
#ifndef _TRITON_IR_BASIC_BLOCK_H_ #ifndef _TRITON_IR_BASIC_BLOCK_H_
#define _TRITON_IR_BASIC_BLOCK_H_ #define _TRITON_IR_BASIC_BLOCK_H_
@@ -27,7 +27,7 @@ public:
private: private:
// constructors // constructors
basic_block(context &ctx, const std::string &name, function *parent); basic_block(context &ctx, const std::string &name, function *parent, basic_block *next);
public: public:
// accessors // accessors
@@ -35,6 +35,7 @@ public:
context& get_context() { return ctx_; } context& get_context() { return ctx_; }
// get iterator to first instruction that is not a phi // get iterator to first instruction that is not a phi
void replace_phi_uses_with(basic_block* before, basic_block* after);
iterator get_first_non_phi(); iterator get_first_non_phi();
// get instruction list // get instruction list
@@ -60,13 +61,16 @@ public:
inline const instruction &back() const { return *inst_list_.back(); } inline const instruction &back() const { return *inst_list_.back(); }
inline instruction &back() { return *inst_list_.back(); } inline instruction &back() { return *inst_list_.back(); }
void append_instruction(ir::instruction* i);
// split
basic_block* split_before(ir::instruction* loc, const std::string& name);
// predecessors // predecessors
const std::vector<basic_block*>& get_predecessors() const { return preds_; } std::vector<basic_block*> get_predecessors() const;
const std::vector<basic_block*>& get_successors() const { return succs_; } std::vector<basic_block*> get_successors() const;
void add_predecessor(basic_block* pred);
// factory functions // factory functions
static basic_block* create(context &ctx, const std::string &name, function *parent); static basic_block* create(context &ctx, const std::string &name, function *parent, basic_block *next = nullptr);
void print(std::ostream &os); void print(std::ostream &os);

View File

@@ -22,13 +22,16 @@ class phi_node;
/* Builder */ /* Builder */
class builder{ class builder{
public:
typedef basic_block::iterator iterator; typedef basic_block::iterator iterator;
public: public:
// Constructor // Constructor
builder(context &ctx); builder(context &ctx);
// Getters // Getters
const context& get_context() { return ctx_; } // const context& get_context() const { return ctx_; }
context& get_context() { return ctx_; }
// Setters // Setters
void set_insert_point(iterator instr); void set_insert_point(iterator instr);
void set_insert_point(instruction* i); void set_insert_point(instruction* i);
@@ -38,8 +41,8 @@ public:
iterator get_insert_point() { return insert_point_;} iterator get_insert_point() { return insert_point_;}
// Constants // Constants
value *get_int1(bool val); value *get_int1(bool val);
value *get_int32(int32_t val); value *get_int32(uint32_t val);
value *get_int64(int64_t val); value *get_int64(uint64_t val);
value *get_float16(float val); value *get_float16(float val);
value *get_float32(float val); value *get_float32(float val);
value *get_float64(float val); value *get_float64(float val);
@@ -51,7 +54,9 @@ public:
type *get_int16_ty(); type *get_int16_ty();
type *get_int32_ty(); type *get_int32_ty();
type *get_int64_ty(); type *get_int64_ty();
type *get_fp8_ty();
type *get_half_ty(); type *get_half_ty();
type *get_bf16_ty();
type *get_float_ty(); type *get_float_ty();
type *get_double_ty(); type *get_double_ty();
// Insert // Insert
@@ -68,8 +73,13 @@ public:
value* create_br(basic_block *dest); value* create_br(basic_block *dest);
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest); value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
value* create_ret_void(); value* create_ret_void();
value* create_ret(value *ret);
// Dequantize instructions
value* create_dequantize(value *src, value *scale, value *shift, type *dest_ty);
// Cast instructions // Cast instructions
value* create_bitcast(value *src, type *dest_ty);
value *create_cast(cast_op_t op, value *v, type *dst_ty); value *create_cast(cast_op_t op, value *v, type *dst_ty);
value* create_int_to_ptr(value *src, type *dst_ty);
value* create_ptr_to_int(value *src, type *dst_ty); value* create_ptr_to_int(value *src, type *dst_ty);
value* create_si_to_fp(value *src, type *dst_ty); value* create_si_to_fp(value *src, type *dst_ty);
value* create_ui_to_fp(value *src, type *dst_ty); value* create_ui_to_fp(value *src, type *dst_ty);
@@ -79,6 +89,9 @@ public:
value* create_fp_trunc(value *src, type *dst_ty); value* create_fp_trunc(value *src, type *dst_ty);
value* create_int_cast(value *src, type *dst_ty, bool is_signed); value* create_int_cast(value *src, type *dst_ty, bool is_signed);
value *create_downcast(value *arg); value *create_downcast(value *arg);
// Call instruction
value* create_call(function* fn, const std::vector<value*>& args);
value* create_launch(function* fn, const std::vector<value*>& args, const std::vector<value*>& grid, value* num_warps);
// Phi instruction // Phi instruction
phi_node* create_phi(type *ty, unsigned num_reserved); phi_node* create_phi(type *ty, unsigned num_reserved);
// Binary instructions // Binary instructions
@@ -88,11 +101,11 @@ public:
value *create_frem(value *lhs, value *rhs); value *create_frem(value *lhs, value *rhs);
value *create_fadd(value *lhs, value *rhs); value *create_fadd(value *lhs, value *rhs);
value *create_fsub(value *lhs, value *rhs); value *create_fsub(value *lhs, value *rhs);
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sdiv(value *lhs, value *rhs); value *create_sdiv(value *lhs, value *rhs);
value *create_udiv(value *lhs, value *rhs); value *create_udiv(value *lhs, value *rhs);
value *create_srem(value *lhs, value *rhs); value *create_srem(value *lhs, value *rhs);
value *create_urem(value *lhs, value *rhs); value *create_urem(value *lhs, value *rhs);
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false); value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
@@ -131,25 +144,48 @@ public:
value *create_xor(value *lhs, value *rhs); value *create_xor(value *lhs, value *rhs);
value *create_or(value *lhs, value *rhs); value *create_or(value *lhs, value *rhs);
// Input/Output // Input/Output
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache); value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
value *create_store(value *ptr, value *val); value *create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction);
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache); value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
value *create_masked_store(value *ptr, value *val, value *mask); value *create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction);
// Struct instructions
value *create_insert_value(value* val, value *elt, size_t idx);
value *create_extract_value(value* val, size_t idx);
// Block instruction // Block instruction
value *create_splat(value *arg, const type::block_shapes_t &shapes); value *create_splat(value *arg, const type::block_shapes_t &shapes);
value *create_reshape(value *arg, const type::block_shapes_t &shapes); value *create_reshape(value *arg, const type::block_shapes_t &shapes);
value *create_cat(value *lhs, value *rhs); value *create_cat(value *lhs, value *rhs);
value *create_broadcast(value *arg, const type::block_shapes_t &shapes); value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
// Atomic instruction
value *create_atomic_cas(value *ptr, value *cmp, value *val);
value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk);
value *create_atomic_max(value *ptr, value *val, value *msk);
value *create_atomic_umax(value *ptr, value *val, value *msk);
value *create_atomic_min(value *ptr, value *val, value *msk);
value *create_atomic_umin(value *ptr, value *val, value *msk);
value *create_atomic_fadd(value *ptr, value *val, value *msk);
value *create_atomic_add(value *ptr, value *val, value *msk);
value *create_atomic_and(value *ptr, value *val, value *msk);
value *create_atomic_or(value *ptr, value *val, value *msk);
value *create_atomic_xor(value *ptr, value *val, value *msk);
value *create_atomic_xchg(value *ptr, value *val, value *msk);
// Utilities
value *create_clock();
value *create_globaltimer();
// Extern instruction
value *create_extern_elementwise(const std::string &lib_name,
const std::string &lib_path,
const std::string &symbol_name,
const std::vector<value *> &args,
type *ret_ty);
// Built-in instruction // Built-in instruction
value *create_get_program_id(unsigned axis); value *create_get_program_id(unsigned axis);
value *create_get_num_programs(unsigned axis); value *create_get_num_programs(unsigned axis);
value *create_atomic_cas(value *ptr, value *cmp, value *val);
value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
value *create_exp(value* arg); value *create_exp(value* arg);
value *create_cos(value* arg); value *create_cos(value* arg);
value *create_sin(value* arg); value *create_sin(value* arg);
value *create_log(value* arg); value *create_log(value* arg);
value *create_dot(value *A, value *B, value *C); value *create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32);
value *create_trans(value *A, const std::vector<int> &perm = {}); value *create_trans(value *A, const std::vector<int> &perm = {});
value *create_sqrt(value *A); value *create_sqrt(value *A);
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis); value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
@@ -158,7 +194,7 @@ public:
// These have no place in the IR, and hopefully they can be removed at some point // These have no place in the IR, and hopefully they can be removed at some point
value *create_umulhi(value* lhs, value* rhs); value *create_umulhi(value* lhs, value* rhs);
value *create_copy_to_shared(value *arg); value *create_copy_to_shared(value *arg);
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache); value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY);
value *create_copy_from_shared(value *arg); value *create_copy_from_shared(value *arg);
value *create_barrier(const std::string &name = ""); value *create_barrier(const std::string &name = "");
value *create_async_wait(int N); value *create_async_wait(int N);

View File

@@ -9,7 +9,6 @@
namespace triton{ namespace triton{
namespace ir{ namespace ir{
class builder;
class type; class type;
class context_impl; class context_impl;
@@ -21,7 +20,6 @@ public:
context& operator=(const context&) = delete; context& operator=(const context&) = delete;
public: public:
ir::builder* builder = nullptr;
std::shared_ptr<context_impl> p_impl; std::shared_ptr<context_impl> p_impl;
}; };

View File

@@ -3,17 +3,15 @@
#ifndef _TRITON_IR_CONTEXT_IMPL_H_ #ifndef _TRITON_IR_CONTEXT_IMPL_H_
#define _TRITON_IR_CONTEXT_IMPL_H_ #define _TRITON_IR_CONTEXT_IMPL_H_
#include <map>
#include "triton/ir/type.h" #include "triton/ir/type.h"
#include "triton/ir/constant.h"
#include <map>
#include <memory>
namespace triton{ namespace triton{
namespace ir{ namespace ir{
class context; class context;
class constant;
class constant_int;
class constant_fp;
class undef_value;
/* Context impl */ /* Context impl */
class context_impl { class context_impl {
@@ -29,16 +27,17 @@ public:
// integer types // integer types
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty; integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
// Pointer types // Pointer types
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys; std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
// Block types // Block types
std::map<std::pair<type*, type::block_shapes_t>, block_type*> block_tys; std::map<std::pair<type*, type::block_shapes_t>, std::unique_ptr<block_type>> block_tys;
// Struct types
std::map<type::contained_tys_vec_t, struct_type*> struct_tys;
// Int constants // Int constants
std::map<std::pair<type*, uint64_t>, constant_int*> int_constants_; std::map<std::pair<type*, uint64_t>, std::unique_ptr<constant_int>> int_constants_;
// Float constants // Float constants
std::map<std::pair<type*, double>, constant_fp*> fp_constants_; std::map<std::pair<type*, double>, std::unique_ptr<constant_fp>> fp_constants_;
// undef values // undef values
std::map<type*, undef_value*> uv_constants_; std::map<type*, std::unique_ptr<undef_value>> uv_constants_;
}; };

View File

@@ -1,110 +0,0 @@
#pragma once
#ifndef _TRITON_IR_DISPATCH_H_
#define _TRITON_IR_DISPATCH_H_
#include "triton/ir/builder.h"
#include <stdexcept>
namespace triton{
namespace ir{
/*----------------------------------------------
higher level functions that follow the likely
semantics of most expected frontends
----------------------------------------------*/
struct semantic_error: public std::runtime_error {
semantic_error(const std::string& msg):
std::runtime_error(msg) { }
};
struct dispatch{
typedef ir::type::block_shapes_t shape_t;
// programming model
static ir::value *program_id(int axis, ir::builder *builder);
static ir::value *num_programs(int axis, ir::builder *builder);
// binary operators
static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder);
// unary operators
static ir::value *plus(ir::value *input, ir::builder *builder);
static ir::value *minus(ir::value *input, ir::builder *builder);
static ir::value *invert(ir::value *input, ir::builder *builder);
// comparison operators
static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder);
// block creation
static ir::value* arange(int start, int end, ir::builder *builder);
static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder);
// casting ops
static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder);
static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
// memory operators
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, ir::builder *builder);
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
// linear algebra
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
// indexing
static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);
// reduction
static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
// math
static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder);
static ir::value *exp(ir::value *x, ir::builder *builder);
static ir::value *log(ir::value *x, ir::builder *builder);
static ir::value *cos(ir::value *x, ir::builder *builder);
static ir::value *sin(ir::value *x, ir::builder *builder);
static ir::value *sqrt(ir::value *x, ir::builder *builder);
// internal (debug/optimization)
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder);
static ir::value *debug_barrier(ir::builder *builder);
};
}
}
#endif

View File

@@ -95,6 +95,9 @@ enum value_id_t: unsigned {
INSTRUCTIONS INSTRUCTIONS
* ------------ */ * ------------ */
INST_BEGIN, INST_BEGIN,
// call
INST_CALL,
INST_LAUNCH,
// phi // phi
INST_PHI, INST_PHI,
// arithmetic // arithmetic
@@ -105,6 +108,8 @@ enum value_id_t: unsigned {
// cmp // cmp
INST_ICMP, INST_ICMP,
INST_FCMP, INST_FCMP,
// dequantize
INST_DEQUANTIZE,
// cast // cast
INST_CAST_TRUNC, INST_CAST_TRUNC,
INST_CAST_ZEXT, INST_CAST_ZEXT,
@@ -129,6 +134,9 @@ enum value_id_t: unsigned {
INST_MASKED_LOAD_ASYNC, INST_MASKED_LOAD_ASYNC,
INST_UNMASKED_STORE, INST_UNMASKED_STORE,
INST_MASKED_STORE, INST_MASKED_STORE,
// struct
INST_EXTRACT_VALUE,
INST_INSERT_VALUE,
// retile // retile
INST_RESHAPE, INST_RESHAPE,
INST_SPLAT, INST_SPLAT,
@@ -148,6 +156,8 @@ enum value_id_t: unsigned {
INST_COS, INST_COS,
INST_SIN, INST_SIN,
INST_LOG, INST_LOG,
// extern
INST_EXTERN_ELEMENTWISE,
// array arithmetic // array arithmetic
INST_TRANS, INST_TRANS,
INST_REDUCE, INST_REDUCE,
@@ -165,6 +175,8 @@ enum value_id_t: unsigned {
INST_MAKE_RANGE_STA, INST_MAKE_RANGE_STA,
INST_MAKE_RANGE, INST_MAKE_RANGE,
INST_PREFETCH_S, INST_PREFETCH_S,
INST_GLOBALTIMER,
INST_CLOCK,
}; };

View File

@@ -112,7 +112,7 @@ public:
static function *create(function_type *ty, linkage_types_t linkage, static function *create(function_type *ty, linkage_types_t linkage,
const std::string &name, module *mod); const std::string &name, module *mod);
// blocks // blocks
const blocks_t &blocks() { return blocks_; } blocks_t &blocks() { return blocks_; }
const blocks_t &blocks() const { return blocks_; } const blocks_t &blocks() const { return blocks_; }
void insert_block(basic_block* block, basic_block *next = nullptr); void insert_block(basic_block* block, basic_block *next = nullptr);
@@ -121,6 +121,8 @@ public:
const attr_map_t &attrs() { return attrs_; } const attr_map_t &attrs() { return attrs_; }
bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); } bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); }
std::set<attribute> get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; } std::set<attribute> get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
void set_is_kernel(bool new_val) { is_kernel_ = new_val; }
bool get_is_kernel() { return is_kernel_; }
void print(std::ostream &os); void print(std::ostream &os);
@@ -134,6 +136,7 @@ private:
args_t args_; args_t args_;
blocks_t blocks_; blocks_t blocks_;
attr_map_t attrs_; attr_map_t attrs_;
bool is_kernel_;
}; };
} }

View File

@@ -59,8 +59,8 @@ public:
std::string repr() const { return repr_impl(); } std::string repr() const { return repr_impl(); }
// metadata // metadata
void set_metadata(ir::metadata::kind_t kind, void set_metadata(ir::metadata::kind_t kind,
unsigned value) { metadatas_[kind] = value;} std::vector<unsigned> value) { metadatas_[kind] = value;}
unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];} std::vector<unsigned> get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
// cloning // cloning
ir::instruction* clone() { ir::instruction* clone() {
ir::instruction* res = clone_impl(); ir::instruction* res = clone_impl();
@@ -77,10 +77,55 @@ public:
private: private:
basic_block *parent_; basic_block *parent_;
std::map<ir::metadata::kind_t, unsigned> metadatas_; std::map<ir::metadata::kind_t, std::vector<unsigned>> metadatas_;
value_id_t id_; value_id_t id_;
}; };
//===----------------------------------------------------------------------===//
// call_inst classes
//===----------------------------------------------------------------------===//
class call_inst: public instruction {
private:
std::string repr_impl() const;
call_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::string& name, instruction* next);
public:
static call_inst* create(ir::function* fn, const std::vector<ir::value*>& values, const std::string &name = "", instruction *next = nullptr);
ir::function* get_fn() { return fn_; }
_TRITON_DEFINE_CLONE(call_inst)
_TRITON_DEFINE_ACCEPT(call_inst)
private:
ir::function* fn_;
};
class launch_inst: public instruction {
private:
std::string repr_impl() const { return "launch"; }
launch_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps,
const std::string &name = "", instruction *next = nullptr);
public:
static launch_inst* create(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps,
const std::string& name = "", instruction* next = nullptr);
ir::function* get_fn();
std::vector<ir::value*> get_values();
std::vector<ir::value*> get_grid();
ir::value* get_num_warps();
_TRITON_DEFINE_CLONE(launch_inst)
_TRITON_DEFINE_ACCEPT(launch_inst)
private:
unsigned val_begin;
unsigned val_end;
unsigned grid_begin;
unsigned grid_end;
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// phi_node classes // phi_node classes
@@ -117,6 +162,7 @@ private:
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// binary_operator classes // binary_operator classes
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class binary_operator: public instruction { class binary_operator: public instruction {
public: public:
typedef binary_op_t op_t; typedef binary_op_t op_t;
@@ -145,6 +191,10 @@ public:
bool is_shl() const; bool is_shl() const;
bool is_shr() const; bool is_shr() const;
// Approx
void set_fdiv_ieee_rounding(bool rnd) { fdiv_ieee_rnd_ = rnd; }
bool get_fdiv_ieee_rounding() { return fdiv_ieee_rnd_; }
// Wraps // Wraps
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; } void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; } void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
@@ -163,6 +213,8 @@ public:
binary_op_t op_; binary_op_t op_;
bool has_no_unsigned_wrap_; bool has_no_unsigned_wrap_;
bool has_no_signed_wrap_; bool has_no_signed_wrap_;
bool fdiv_ieee_rnd_;
}; };
@@ -222,6 +274,24 @@ protected:
unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next); unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next);
}; };
//===----------------------------------------------------------------------===//
// dequantize_inst classes
//===----------------------------------------------------------------------===//
class dequantize_inst: public instruction{
private:
std::string repr_impl() const override { return "dequantize"; }
protected:
dequantize_inst(type *ty, value *v, value *scale, value *shift, const std::string &name, instruction *next);
public:
static dequantize_inst *create(value *arg, value *scale, value *shift, type *ty,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(dequantize_inst)
_TRITON_DEFINE_ACCEPT(dequantize_inst)
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// cast_inst classes // cast_inst classes
@@ -383,13 +453,31 @@ private:
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class io_inst: public instruction { class io_inst: public instruction {
public:
enum EVICTION_POLICY : uint32_t {
NORMAL=0,
EVICT_FIRST,
EVICT_LAST,
};
protected: protected:
io_inst(type *ty, value_id_t id, unsigned num_ops, io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
const std::string &name = "", instruction *next = nullptr); const std::string &name = "", instruction *next = nullptr);
std::string get_eviction_policy_repr() const {
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
return "";
}
public: public:
// accessors // accessors
value *get_pointer_operand() { return get_operand(0); } value *get_pointer_operand() { return get_operand(0); }
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
protected:
EVICTION_POLICY eviction_;
}; };
// load // load
@@ -401,9 +489,13 @@ public:
CG, CG,
}; };
CACHE_MODIFIER get_cache_modifier() const { return cache_; } CACHE_MODIFIER get_cache_modifier() const { return cache_; }
bool get_is_volatile() const { return is_volatile_; }
protected: protected:
load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, EVICTION_POLICY eviction,
bool is_volatile,
const std::string &name = "", instruction *next = nullptr); const std::string &name = "", instruction *next = nullptr);
std::string get_cache_modifier_repr() const { std::string get_cache_modifier_repr() const {
if (cache_ == CA) return ".ca"; if (cache_ == CA) return ".ca";
@@ -412,20 +504,25 @@ protected:
} }
CACHE_MODIFIER cache_; CACHE_MODIFIER cache_;
std::string get_volatile_repr() {
return is_volatile_ ? ".volatile" : "";
}
bool is_volatile_;
private: private:
static type *get_pointee_type(type *ty); static type *get_pointee_type(type *ty);
}; };
// unmasked load // unmasked load
class unmasked_load_inst: public load_inst { class unmasked_load_inst: public load_inst {
private: private:
std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); } std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next); unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next);
public: public:
static unmasked_load_inst* create(value *ptr, static unmasked_load_inst* create(value *ptr,
CACHE_MODIFIER cache, CACHE_MODIFIER cache, EVICTION_POLICY eviction,
bool is_volatile,
const std::string &name = "", const std::string &name = "",
instruction *next = nullptr); instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_load_inst) _TRITON_DEFINE_CLONE(unmasked_load_inst)
@@ -436,7 +533,7 @@ public:
class masked_load_inst: public load_inst { class masked_load_inst: public load_inst {
private: private:
std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); } std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile,
const std::string &name, instruction *next); const std::string &name, instruction *next);
public: public:
@@ -445,7 +542,8 @@ public:
value *get_false_value_operand() { return get_operand(2); } value *get_false_value_operand() { return get_operand(2); }
// factory method // factory method
static masked_load_inst* create(value *ptr, value *mask, value *false_value, static masked_load_inst* create(value *ptr, value *mask, value *false_value,
CACHE_MODIFIER cache, CACHE_MODIFIER cache, EVICTION_POLICY eviction,
bool is_volatile,
const std::string &name = "", const std::string &name = "",
instruction *next = nullptr); instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_inst) _TRITON_DEFINE_CLONE(masked_load_inst)
@@ -455,8 +553,9 @@ public:
// masked load async // masked load async
class masked_load_async_inst: public load_inst { class masked_load_async_inst: public load_inst {
private: private:
std::string repr_impl() const { return "masked_load_async_async" + get_cache_modifier_repr(); } std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, masked_load_async_inst(value *ptr, value *mask, value *false_value,
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
const std::string &name, instruction *next); const std::string &name, instruction *next);
public: public:
@@ -466,6 +565,7 @@ public:
// factory method // factory method
static masked_load_async_inst* create(value *ptr, value *mask, value *false_value, static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache, load_inst::CACHE_MODIFIER cache,
EVICTION_POLICY eviction,
const std::string &name = "", const std::string &name = "",
instruction *next = nullptr); instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_async_inst) _TRITON_DEFINE_CLONE(masked_load_async_inst)
@@ -477,7 +577,7 @@ public:
// store // store
class store_inst: public io_inst { class store_inst: public io_inst {
protected: protected:
store_inst(value *ptr, value_id_t id, unsigned num_ops, store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
const std::string &name = "", instruction *next = nullptr); const std::string &name = "", instruction *next = nullptr);
public: public:
@@ -488,11 +588,11 @@ public:
class unmasked_store_inst: public store_inst{ class unmasked_store_inst: public store_inst{
private: private:
std::string repr_impl() const { return "unmasked_store"; } std::string repr_impl() const { return "unmasked_store"; }
unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next); unmasked_store_inst(value *ptr, value *v, EVICTION_POLICY eviction, const std::string &name, instruction *next);
public: public:
// factory method // factory method
static unmasked_store_inst* create(value* ptr, value *v, static unmasked_store_inst* create(value* ptr, value *v, EVICTION_POLICY eviction,
const std::string &name = "", const std::string &name = "",
instruction *next = nullptr); instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_store_inst) _TRITON_DEFINE_CLONE(unmasked_store_inst)
@@ -502,20 +602,58 @@ public:
class masked_store_inst: public store_inst{ class masked_store_inst: public store_inst{
private: private:
std::string repr_impl() const { return "masked_store"; } std::string repr_impl() const { return "masked_store"; }
masked_store_inst(value *ptr, value *v, value *mask, masked_store_inst(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
const std::string &name, instruction *next); const std::string &name, instruction *next);
public: public:
// accessors // accessors
value *get_mask_operand() { return get_operand(2); } value *get_mask_operand() { return get_operand(2); }
// factory method // factory method
static masked_store_inst* create(value *ptr, value *v, value *mask, static masked_store_inst* create(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
const std::string &name = "", const std::string &name = "",
instruction *next = nullptr); instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_store_inst) _TRITON_DEFINE_CLONE(masked_store_inst)
_TRITON_DEFINE_ACCEPT(masked_store_inst) _TRITON_DEFINE_ACCEPT(masked_store_inst)
}; };
//===----------------------------------------------------------------------===//
// struct classes
//===----------------------------------------------------------------------===//
// insert_value
class insert_value_inst: public instruction {
private:
std::string repr_impl() const { return "insertvalue"; }
insert_value_inst(value *val, value *elt, size_t idx, const std::string &name, instruction *next);
public:
static insert_value_inst* create(value *val, value* elt, size_t idx, const std::string &name = "", instruction *next = nullptr);
size_t get_idx() { return idx_; }
_TRITON_DEFINE_CLONE(insert_value_inst)
_TRITON_DEFINE_ACCEPT(insert_value_inst)
private:
size_t idx_;
};
// extract_value
class extract_value_inst: public instruction {
private:
std::string repr_impl() const { return "extractvalue"; }
extract_value_inst(value *val, size_t idx, const std::string &name, instruction *next);
public:
static extract_value_inst* create(value *val, size_t idx, const std::string &name = "", instruction *next = nullptr);
size_t get_idx() { return idx_; }
_TRITON_DEFINE_CLONE(extract_value_inst)
_TRITON_DEFINE_ACCEPT(extract_value_inst)
private:
size_t idx_;
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// retile_inst classes // retile_inst classes
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -641,6 +779,8 @@ private:
class atomic_inst: public io_inst { class atomic_inst: public io_inst {
public: public:
using io_inst::io_inst; using io_inst::io_inst;
atomic_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next):
io_inst(ty, id, num_ops, NORMAL, name, next) {}
}; };
class atomic_rmw_inst: public atomic_inst { class atomic_rmw_inst: public atomic_inst {
@@ -728,24 +868,40 @@ public:
class dot_inst: public builtin_inst { class dot_inst: public builtin_inst {
public: public:
enum TransT { NoTrans, Trans }; enum TransT { NoTrans, Trans };
enum DataType {
FP8, FP16, BF16, TF32, FP32,
INT1, INT4, INT8, INT32,
UNKNOWN,
};
private: private:
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next); dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next);
std::string repr_impl() const { return "dot"; } std::string repr_impl() const { return "dot"; }
bool is_prefetched_ = false;
public: public:
bool is_prefetched() const { return is_prefetched_; } bool is_prefetched() const { return is_prefetched_; }
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; } void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
bool allow_tf32() const { return allow_tf32_; }
bool is_trans_a() const { return AT_ == Trans; }
bool is_trans_b() const { return BT_ == Trans; }
public: public:
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr); static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); static instruction* create_nn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); static instruction* create_nt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); static instruction* create_tn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); static instruction* create_tt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(dot_inst) _TRITON_DEFINE_CLONE(dot_inst)
_TRITON_DEFINE_ACCEPT(dot_inst) _TRITON_DEFINE_ACCEPT(dot_inst)
private:
bool is_prefetched_ = false;
bool allow_tf32_ = false;
DataType C_type_ = DataType::FP32;
DataType A_type_ = DataType::FP16;
DataType B_type_ = DataType::FP16;
TransT AT_;
TransT BT_;
}; };
//class outer_inst: public builtin_inst { //class outer_inst: public builtin_inst {
@@ -787,8 +943,11 @@ public:
class reduce_inst: public builtin_inst { class reduce_inst: public builtin_inst {
public: public:
enum op_t{ enum op_t{
ADD, SUB, MAX, MIN, ADD, SUB, MAX, MIN, UMAX, UMIN,
FADD, FSUB, FMAX, FMIN ARGMAX, ARGMIN, ARGUMAX, ARGUMIN,
FADD, FSUB, FMAX, FMIN,
ARGFMAX, ARGFMIN,
XOR
}; };
private: private:
@@ -805,12 +964,19 @@ public:
static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr); static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; } unsigned get_axis() const { return axis_; }
op_t get_op() const { return op_; } op_t get_op() const { return op_; }
bool with_index() const {
return with_index_ops_.find(op_) != with_index_ops_.end();
}
private: private:
const static inline std::set<op_t> with_index_ops_ = {
op_t::ARGMAX, op_t::ARGMIN, op_t::ARGUMAX,
op_t::ARGUMIN, op_t::ARGFMAX, op_t::ARGFMIN};
unsigned axis_; unsigned axis_;
op_t op_; op_t op_;
}; };
class select_inst: public builtin_inst { class select_inst: public builtin_inst {
private: private:
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next); select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
@@ -928,7 +1094,53 @@ private:
constant_int* last_; constant_int* last_;
}; };
/* timing utilities */
class clock_inst: public instruction{
clock_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "clock"; }
_TRITON_DEFINE_CLONE(clock_inst)
_TRITON_DEFINE_ACCEPT(clock_inst)
public:
static clock_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
};
class globaltimer_inst: public instruction{
globaltimer_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "globaltimer"; }
_TRITON_DEFINE_CLONE(globaltimer_inst)
_TRITON_DEFINE_ACCEPT(globaltimer_inst)
public:
static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
};
class extern_elementwise_inst : public instruction {
extern_elementwise_inst(context &ctx, const std::vector<value *> &args,
type *dst_ty, const std::string &lib_name,
const std::string &extern_lib_path,
const std::string &symbol_name,
const std::string &name, instruction *next);
std::string repr_impl() const { return "extern_elementwise"; }
_TRITON_DEFINE_CLONE(extern_elementwise_inst)
_TRITON_DEFINE_ACCEPT(extern_elementwise_inst)
public:
static extern_elementwise_inst *create(
context &ctx, const std::vector<value *> &args, type *dst_ty,
const std::string &lib_name = "", const std::string &lib_path = "",
const std::string &symbol_name = "", const std::string &name = "",
instruction *next = nullptr);
const std::string &get_lib_name() const { return lib_name_; }
const std::string &get_lib_path() const { return lib_path_; }
const std::string &get_symbol_name() const { return symbol_name_; }
private:
std::string lib_name_;
std::string lib_path_;
std::string symbol_name_;
};
} }
} }

View File

@@ -3,6 +3,8 @@
#ifndef _TRITON_IR_METADATA_H_ #ifndef _TRITON_IR_METADATA_H_
#define _TRITON_IR_METADATA_H_ #define _TRITON_IR_METADATA_H_
#include <vector>
namespace triton{ namespace triton{
namespace ir{ namespace ir{
@@ -16,14 +18,14 @@ public:
}; };
private: private:
metadata(kind_t kind, unsigned value); metadata(kind_t kind, std::vector<unsigned> value);
public: public:
static metadata* get(kind_t kind, unsigned value); static metadata* get(kind_t kind, std::vector<unsigned> value);
private: private:
kind_t kind_; kind_t kind_;
unsigned value_; std::vector<unsigned> value_;
}; };
} }

View File

@@ -34,50 +34,74 @@ class constant;
class global_value; class global_value;
class alloc_const; class alloc_const;
/* Module */ class value_constructor {
class module {
typedef std::pair<std::string, basic_block*> val_key_t; typedef std::pair<std::string, basic_block*> val_key_t;
friend class function;
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
public:
typedef std::map<std::string, global_value*> symbols_map_t;
typedef std::vector<function*> functions_list_t;
struct current_iteration_info_t{
lang::iteration_statement *statement;
basic_block *block;
};
private: private:
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block); phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
value *try_remove_trivial_phis(ir::phi_node *&phi); value *try_remove_trivial_phis(ir::phi_node *&phi);
value *add_phi_operands(const std::string& name, phi_node *&phi); value *add_phi_operands(const std::string& name, phi_node *&phi);
value *get_value_recursive(const std::string& name, basic_block *block); value *get_value_recursive(const std::string& name, basic_block *block);
void push_function(function *fn) { functions_.push_back(fn); }
public: public:
module(const std::string &name, builder& builder); value_constructor(builder &builder);
builder& get_builder();
// Setters
void set_value(const std::string& name, basic_block* block, value *x); void set_value(const std::string& name, basic_block* block, value *x);
void set_value(const std::string& name, value* x); void set_value(const std::string& name, value* x);
void set_const(const std::string& name);
void set_continue_fn(std::function<ir::value*()> fn);
// Getters
const std::map<val_key_t, value*>& get_values() { return values_; } const std::map<val_key_t, value*>& get_values() { return values_; }
void set_values(const std::map<val_key_t, value*>& values) { values_ = values; } void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
value *get_value(const std::string& name, basic_block* block); value *get_value(const std::string& name, basic_block* block);
value *get_value(const std::string& name); value *get_value(const std::string& name);
void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; } void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }
const std::string& get_name();
std::function<ir::value*()> get_continue_fn();
// Seal block -- no more predecessors will be added // Seal block -- no more predecessors will be added
void seal_block(basic_block *block); void seal_block(basic_block *block);
// Metadata
private:
ir::builder& builder_;
std::map<val_key_t, value*> values_;
std::map<std::string, type*> types_;
std::set<basic_block*> sealed_blocks_;
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
std::map<value*, value**> current_phi_;
};
/* Module */
class module {
typedef std::pair<std::string, basic_block*> val_key_t;
typedef std::pair<ir::metadata::kind_t, std::vector<unsigned>> md_pair_t;
friend class function;
public:
typedef std::map<std::string, global_value*> symbols_map_t;
typedef std::vector<function*> functions_list_t;
private:
void push_function(function *fn) { functions_.push_back(fn); }
public:
module(const std::string &name, builder &builder): name_(name), builder_(builder) {}
builder &get_builder() { return builder_; };
const std::string& get_name() { return name_; };
// Functions // Functions
const functions_list_t &get_function_list() const { return functions_; } const functions_list_t &get_function_list() const { return functions_; }
functions_list_t &get_function_list() { return functions_; } function *get_function(const std::string& name) {
if(symbols_.find(name) == symbols_.end())
throw std::runtime_error("function " + name + " is not declared");
return (function*)symbols_.at(name);
}
function *get_or_insert_function(const std::string &name, function_type *ty); function *get_or_insert_function(const std::string &name, function_type *ty);
bool has_function(const std::string& name){
return symbols_.find(name) != symbols_.end();
}
void remove_function(ir::function* fn){
functions_.erase(std::remove(functions_.begin(), functions_.end(), fn), functions_.end());
}
void reset_ret_ty(const std::string& name, type* ty);
// Const allocation // Const allocation
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); } void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
const std::vector<ir::alloc_const*>& allocs() { return allocs_; } const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
@@ -85,22 +109,15 @@ public:
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; } void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
const std::map<std::string, ir::value*>& globals() const { return globals_; } const std::map<std::string, ir::value*>& globals() const { return globals_; }
// Metadata // Metadata
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
void print(std::ostream &os); void print(std::ostream &os);
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
const std::map<std::string, md_pair_t> &get_metadatas() const { return metadatas_; }
private: private:
std::string name_; std::string name_;
builder& builder_; builder &builder_;
std::map<val_key_t, value*> values_;
std::map<std::string, type*> types_;
std::set<std::string> const_;
std::set<basic_block*> sealed_blocks_;
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
functions_list_t functions_; functions_list_t functions_;
symbols_map_t symbols_; symbols_map_t symbols_;
std::function<ir::value*()> continue_fn_;
std::map<value*, value**> current_phi_;
std::vector<ir::alloc_const*> allocs_; std::vector<ir::alloc_const*> allocs_;
std::map<std::string, ir::value*> globals_; std::map<std::string, ir::value*> globals_;
std::map<std::string, md_pair_t> metadatas_; std::map<std::string, md_pair_t> metadatas_;

View File

@@ -1,4 +1,4 @@
#pragma once #pragma once
#ifndef _TRITON_IR_TYPE_H_ #ifndef _TRITON_IR_TYPE_H_
#define _TRITON_IR_TYPE_H_ #define _TRITON_IR_TYPE_H_
@@ -6,6 +6,7 @@
#include <cassert> #include <cassert>
#include <vector> #include <vector>
#include <string> #include <string>
#include <stdexcept>
namespace triton{ namespace triton{
namespace ir{ namespace ir{
@@ -20,7 +21,6 @@ class type {
public: public:
typedef std::vector<unsigned> block_shapes_t; typedef std::vector<unsigned> block_shapes_t;
protected:
typedef std::vector<type*> contained_tys_vec_t; typedef std::vector<type*> contained_tys_vec_t;
typedef contained_tys_vec_t::iterator ty_iterator; typedef contained_tys_vec_t::iterator ty_iterator;
typedef contained_tys_vec_t::const_iterator const_ty_iterator; typedef contained_tys_vec_t::const_iterator const_ty_iterator;
@@ -68,6 +68,8 @@ public:
type *get_tile_element_ty() const; type *get_tile_element_ty() const;
unsigned get_pointer_address_space() const; unsigned get_pointer_address_space() const;
type *get_pointer_element_ty() const; type *get_pointer_element_ty() const;
unsigned get_struct_numel() const { return contained_tys_.size(); }
type *get_struct_type(unsigned int i) const { return contained_tys_[i]; }
// primitive predicates // primitive predicates
bool is_void_ty() const { return id_ == VoidTyID; } bool is_void_ty() const { return id_ == VoidTyID; }
@@ -80,11 +82,10 @@ public:
bool is_metadata_ty() const { return id_ == MetadataTyID; } bool is_metadata_ty() const { return id_ == MetadataTyID; }
bool is_token_ty() const { return id_ == TokenTyID; } bool is_token_ty() const { return id_ == TokenTyID; }
bool is_integer_ty() const { return id_ == IntegerTyID; } bool is_integer_ty() const { return id_ == IntegerTyID; }
bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() &&
get_integer_bitwidth() == bitwidth;}
bool is_bool_ty() const { return is_integer_ty(1); } bool is_bool_ty() const { return is_integer_ty(1); }
bool is_pointer_ty() const { return id_ == PointerTyID; } bool is_pointer_ty() const { return id_ == PointerTyID; }
bool is_block_ty() const { return id_ == BlockTyID; } bool is_block_ty() const { return id_ == BlockTyID; }
bool is_struct_ty() const { return id_ == StructTyID; }
// Composite predicates // Composite predicates
bool is_int_or_tileint_ty(); bool is_int_or_tileint_ty();
@@ -128,6 +129,7 @@ public:
switch(id_) { switch(id_) {
case VoidTyID: return "void"; case VoidTyID: return "void";
case FP8TyID: return "fp8"; case FP8TyID: return "fp8";
case BF16TyID: return "bf16";
case FP16TyID: return "f16"; case FP16TyID: return "f16";
case BF16TyID: return "bf16"; case BF16TyID: return "bf16";
case FP32TyID: return "f32"; case FP32TyID: return "f32";
@@ -135,15 +137,14 @@ public:
case LabelTyID: return "label"; case LabelTyID: return "label";
case MetadataTyID: return "md"; case MetadataTyID: return "md";
case TokenTyID: return "tok"; case TokenTyID: return "tok";
case IntegerTyID: return "i" + std::to_string(get_integer_bitwidth()); case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth());
case FunctionTyID: return "fn"; case FunctionTyID: return "fn";
case PointerTyID: return get_pointer_element_ty()->repr() + "*"; case PointerTyID: return get_pointer_element_ty()->repr() + "*";
case StructTyID: return "struct"; case StructTyID: return "struct";
case BlockTyID: return tile_repr(); case BlockTyID: return tile_repr();
default: break; default: break;
} }
assert(false); throw std::logic_error("unknown type id '" + std::to_string(id_) + "'");
return "";
}; };
private: private:
@@ -160,7 +161,7 @@ class integer_type: public type {
private: private:
// constructors // constructors
integer_type(context &ctx, unsigned bitwidth) integer_type(context &ctx, unsigned bitwidth)
: type(ctx, IntegerTyID), bitwidth_(bitwidth){ } : type(ctx, IntegerTyID), bitwidth_(bitwidth) {}
public: public:
// accessors // accessors
@@ -182,6 +183,16 @@ public:
type* get_type_at_index(value *idx) const; type* get_type_at_index(value *idx) const;
}; };
class struct_type: public composite_type {
public:
struct_type(const contained_tys_vec_t& tys, bool is_packed);
unsigned get_num_types() const { return contained_tys_.size(); }
static struct_type* get(const contained_tys_vec_t& tys, bool is_packed);
private:
bool is_packed_;
};
class block_type: public composite_type { class block_type: public composite_type {
private: private:
block_type(type *ty, const block_shapes_t &shapes); block_type(type *ty, const block_shapes_t &shapes);
@@ -230,6 +241,7 @@ public:
ty_iterator params_end() { return contained_tys_.end(); } ty_iterator params_end() { return contained_tys_.end(); }
type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); } type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); }
type* get_return_ty() const { return contained_tys_.at(0); } type* get_return_ty() const { return contained_tys_.at(0); }
void reset_ret_ty(type* ty) { contained_tys_[0] = ty;}
// factory methods // factory methods
static function_type* get(type *ret_ty, const std::vector<type*>& param_tys); static function_type* get(type *ret_ty, const std::vector<type*>& param_tys);
}; };

View File

@@ -22,6 +22,7 @@ public:
}; };
void for_each_instruction(ir::module& mod, const std::function<void(triton::ir::instruction*)> &fn); void for_each_instruction(ir::module& mod, const std::function<void(triton::ir::instruction*)> &fn);
void for_each_instruction_backward(module &mod, const std::function<void (instruction *)> &do_work);
void for_each_value(ir::module& mod, const std::function<void(triton::ir::value *)> &fn); void for_each_value(ir::module& mod, const std::function<void(triton::ir::value *)> &fn);
} }

View File

@@ -21,7 +21,7 @@ class visitor;
class value { class value {
public: public:
typedef std::set<user*> users_t; typedef std::vector<user*> users_t;
public: public:
// constructor // constructor
@@ -30,7 +30,7 @@ public:
// uses // uses
void add_use(user* arg); void add_use(user* arg);
users_t::iterator erase_use(user* arg); users_t::iterator erase_use(user* arg);
const std::set<user*> &get_users() { return users_; } const std::vector<user*> &get_users() { return users_; }
void replace_all_uses_with(value *target); void replace_all_uses_with(value *target);
// name // name
void set_name(const std::string &name); void set_name(const std::string &name);

View File

@@ -11,12 +11,16 @@ class value;
class instruction; class instruction;
class call_inst;
class launch_inst;
class phi_node; class phi_node;
class binary_operator; class binary_operator;
class getelementptr_inst; class getelementptr_inst;
class icmp_inst; class icmp_inst;
class fcmp_inst; class fcmp_inst;
class dequantize_inst;
class cast_inst; class cast_inst;
class trunc_inst; class trunc_inst;
class z_ext_inst; class z_ext_inst;
@@ -42,6 +46,9 @@ class masked_load_inst;
class unmasked_store_inst; class unmasked_store_inst;
class masked_store_inst; class masked_store_inst;
class extract_value_inst;
class insert_value_inst;
class retile_inst; class retile_inst;
class reshape_inst; class reshape_inst;
class splat_inst; class splat_inst;
@@ -75,6 +82,10 @@ class async_wait_inst;
class make_range_dyn; class make_range_dyn;
class make_range; class make_range;
class prefetch_s_inst; class prefetch_s_inst;
class clock_inst;
class globaltimer_inst;
class extern_elementwise_inst;
class make_range_sta; class make_range_sta;
class undef_value; class undef_value;
@@ -103,6 +114,8 @@ public:
virtual ~visitor() {} virtual ~visitor() {}
virtual void visit_value(ir::value*); virtual void visit_value(ir::value*);
virtual void visit_call_inst(ir::call_inst*) = 0;
virtual void visit_launch_inst(ir::launch_inst*) = 0;
virtual void visit_basic_block(basic_block*) = 0; virtual void visit_basic_block(basic_block*) = 0;
virtual void visit_argument(argument*) = 0; virtual void visit_argument(argument*) = 0;
@@ -112,6 +125,7 @@ public:
virtual void visit_icmp_inst(icmp_inst*) = 0; virtual void visit_icmp_inst(icmp_inst*) = 0;
virtual void visit_fcmp_inst(fcmp_inst*) = 0; virtual void visit_fcmp_inst(fcmp_inst*) = 0;
virtual void visit_dequantize_inst(dequantize_inst*) = 0;
virtual void visit_cast_inst(cast_inst*) = 0; virtual void visit_cast_inst(cast_inst*) = 0;
virtual void visit_return_inst(return_inst*) = 0; virtual void visit_return_inst(return_inst*) = 0;
@@ -130,6 +144,9 @@ public:
virtual void visit_sin_inst(sin_inst*) = 0; virtual void visit_sin_inst(sin_inst*) = 0;
virtual void visit_log_inst(log_inst*) = 0; virtual void visit_log_inst(log_inst*) = 0;
virtual void visit_extract_value_inst(extract_value_inst*) = 0;
virtual void visit_insert_value_inst(insert_value_inst*) = 0;
virtual void visit_reshape_inst(reshape_inst*) = 0; virtual void visit_reshape_inst(reshape_inst*) = 0;
virtual void visit_splat_inst(splat_inst*) = 0; virtual void visit_splat_inst(splat_inst*) = 0;
virtual void visit_cat_inst(cat_inst*) = 0; virtual void visit_cat_inst(cat_inst*) = 0;
@@ -157,11 +174,15 @@ public:
virtual void visit_make_range(make_range*) = 0; virtual void visit_make_range(make_range*) = 0;
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0; virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
virtual void visit_function(function*) = 0; virtual void visit_function(function*) = 0;
virtual void visit_clock_inst(clock_inst*) = 0;
virtual void visit_globaltimer_inst(globaltimer_inst*) = 0;
virtual void visit_undef_value(undef_value*) = 0; virtual void visit_undef_value(undef_value*) = 0;
virtual void visit_constant_int(constant_int*) = 0; virtual void visit_constant_int(constant_int*) = 0;
virtual void visit_constant_fp(constant_fp*) = 0; virtual void visit_constant_fp(constant_fp*) = 0;
virtual void visit_alloc_const(alloc_const*) = 0; virtual void visit_alloc_const(alloc_const*) = 0;
virtual void visit_extern_elementwise_inst(extern_elementwise_inst*) = 0;
}; };
} }

View File

@@ -3,8 +3,9 @@
#ifndef _TRITON_TOOLS_THREAD_GRAPH_H_ #ifndef _TRITON_TOOLS_THREAD_GRAPH_H_
#define _TRITON_TOOLS_THREAD_GRAPH_H_ #define _TRITON_TOOLS_THREAD_GRAPH_H_
#include "llvm/ADT/SetVector.h"
#include <map> #include <map>
#include <set>
#include <vector> #include <vector>
#include <iostream> #include <iostream>
@@ -13,21 +14,21 @@ namespace tools{
template<class node_t> template<class node_t>
class graph { class graph {
typedef std::map<node_t, std::set<node_t>> edges_t; typedef std::map<node_t, llvm::SetVector<node_t>> edges_t;
public: public:
typedef std::map<size_t, std::vector<node_t>> cmap_t; typedef std::map<size_t, std::vector<node_t>> cmap_t;
typedef std::map<node_t, size_t> nmap_t; typedef std::map<node_t, size_t> nmap_t;
private: private:
void connected_components_impl(node_t x, std::set<node_t> &nodes, void connected_components_impl(node_t x, llvm::SetVector<node_t> &nodes,
nmap_t* nmap, cmap_t* cmap, int id) const { nmap_t* nmap, cmap_t* cmap, int id) const {
if(nmap) if(nmap)
(*nmap)[x] = id; (*nmap)[x] = id;
if(cmap) if(cmap)
(*cmap)[id].push_back(x); (*cmap)[id].push_back(x);
if(nodes.find(x) != nodes.end()) { if (nodes.count(x)) {
nodes.erase(x); nodes.remove(x);
for(const node_t &y: edges_.at(x)) for(const node_t &y: edges_.at(x))
connected_components_impl(y, nodes, nmap, cmap, id); connected_components_impl(y, nodes, nmap, cmap, id);
} }
@@ -39,7 +40,7 @@ public:
cmap->clear(); cmap->clear();
if(nmap) if(nmap)
nmap->clear(); nmap->clear();
std::set<node_t> nodes = nodes_; llvm::SetVector<node_t> nodes = nodes_;
unsigned id = 0; unsigned id = 0;
while(!nodes.empty()){ while(!nodes.empty()){
connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++); connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++);
@@ -59,7 +60,7 @@ public:
} }
private: private:
std::set<node_t> nodes_; llvm::SetVector<node_t> nodes_;
edges_t edges_; edges_t edges_;
}; };

View File

@@ -115,6 +115,18 @@ std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_ins
return add_to_cache(x, result, is_constant_); return add_to_cache(x, result, is_constant_);
} }
std::vector<align::cst_info> align::populate_is_constant_dequantize(ir::dequantize_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_block_shapes();
auto op_cst = populate_is_constant(op);
for(size_t d = 0; d < x_shapes.size(); d++) {
result.push_back(op_cst[d]);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) { std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
auto x_shapes = get_shapes(x); auto x_shapes = get_shapes(x);
std::vector<cst_info> result; std::vector<cst_info> result;
@@ -129,6 +141,36 @@ std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast
return add_to_cache(x, result, is_constant_); return add_to_cache(x, result, is_constant_);
} }
std::vector<align::cst_info> align::populate_is_constant_cmp(ir::cmp_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
auto lhs = populate_is_constant(lhs_op);
auto rhs = populate_is_constant(rhs_op);
auto lhs_max_contiguous = populate_max_contiguous(lhs_op);
auto rhs_max_contiguous = populate_max_contiguous(rhs_op);
auto lhs_multiple_of = populate_starting_multiple(lhs_op);
auto rhs_multiple_of = populate_starting_multiple(rhs_op);
for(size_t d = 0; d < x_shapes.size(); d++) {
cst_info ax = {1, 0};
// Examples:
// 16 17 18 ... 32 < 24 24 24 ... 24 => equal in groups of 8
// 16 17 18 ... 32 < 20 20 20 ... 20 => equal in groups of 4
// 16 17 18 ... 32 < 16 16 16 ... 16 => equal in groups of 16
//
// if LHS is a range of N continuous (or equal) elements that starts at M,
// and RHS is a set of N constants that start at K
// then the result in constant in groups of gcd(M, K)
if(rhs[d].num_cst % lhs_max_contiguous[d] == 0 ||
rhs[d].num_cst % lhs[d].num_cst == 0)
ax.num_cst = gcd(lhs_multiple_of[d], rhs_multiple_of[d]);
result.push_back(ax);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) { std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
auto x_shapes = get_shapes(x); auto x_shapes = get_shapes(x);
std::vector<cst_info> result; std::vector<cst_info> result;
@@ -136,12 +178,14 @@ std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operat
ir::value* rhs_op = x->get_operand(1); ir::value* rhs_op = x->get_operand(1);
auto lhs = populate_is_constant(lhs_op); auto lhs = populate_is_constant(lhs_op);
auto rhs = populate_is_constant(rhs_op); auto rhs = populate_is_constant(rhs_op);
auto max_contiguous = populate_max_contiguous(lhs_op); auto lhs_max_contiguous = populate_max_contiguous(lhs_op);
auto rhs_max_contiguous = populate_max_contiguous(rhs_op);
auto lhs_multiple_of = populate_starting_multiple(lhs_op);
auto rhs_multiple_of = populate_starting_multiple(rhs_op);
for(size_t d = 0; d < x_shapes.size(); d++) { for(size_t d = 0; d < x_shapes.size(); d++) {
cst_info ax; cst_info ax;
if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){ if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
// todo might not be entirely true unsigned num_constants = gcd(lhs_max_contiguous[d], rhs[d].value);
unsigned num_constants = gcd(max_contiguous[d], rhs[d].value);
ax = {num_constants, 0}; ax = {num_constants, 0};
} }
else else
@@ -180,10 +224,14 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
return populate_is_constant_splat(x); return populate_is_constant_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v)) if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_is_constant_reshape(x); return populate_is_constant_reshape(x);
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
return populate_is_constant_dequantize(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v)) if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_is_constant_broadcast(x); return populate_is_constant_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v)) if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_is_constant_binop(x); return populate_is_constant_binop(x);
if(auto *x = dynamic_cast<ir::cmp_inst*>(v))
return populate_is_constant_cmp(x);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)) if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_is_constant_gep(x); return populate_is_constant_gep(x);
return populate_is_constant_default(v); return populate_is_constant_default(v);
@@ -245,6 +293,23 @@ std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x
return add_to_cache(x, result, max_contiguous_); return add_to_cache(x, result, max_contiguous_);
} }
std::vector<unsigned> align::populate_max_contiguous_dequantize(ir::dequantize_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
auto op_last_dim = (op->get_type()->get_block_shapes()).back();
auto op_mc = populate_max_contiguous(op);
for(size_t d = 0; d < shapes.size(); d++) {
unsigned factor = 1;
if (d == shapes.size() - 1) {
factor = ret_last_dim / op_last_dim;
}
result.push_back(factor * op_mc[d]);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) { std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
auto shapes = get_shapes(x); auto shapes = get_shapes(x);
std::vector<unsigned> result; std::vector<unsigned> result;
@@ -285,8 +350,8 @@ std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator*
} }
if(x->is_int_add_sub()){ if(x->is_int_add_sub()){
unsigned lvalue = 1, rvalue = 1; unsigned lvalue = 1, rvalue = 1;
lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]); lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]); rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
value = std::max(lvalue, rvalue); value = std::max(lvalue, rvalue);
} }
result.push_back(value); result.push_back(value);
@@ -332,9 +397,9 @@ std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
if(max_contiguous_.find(v) != max_contiguous_.end()) if(max_contiguous_.find(v) != max_contiguous_.end())
return max_contiguous_.at(v); return max_contiguous_.at(v);
if(auto *x = dynamic_cast<ir::instruction*>(v)){ if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned max_contiguous = x->get_metadata(ir::metadata::max_contiguous); std::vector<unsigned> max_contiguous = x->get_metadata(ir::metadata::max_contiguous);
if(max_contiguous > 0) if(!max_contiguous.empty())
return add_to_cache(x, {max_contiguous}, max_contiguous_); return add_to_cache(x, max_contiguous, max_contiguous_);
} }
if(auto *x = dynamic_cast<ir::cast_inst*>(v)) if(auto *x = dynamic_cast<ir::cast_inst*>(v))
return populate_max_contiguous_cast(x); return populate_max_contiguous_cast(x);
@@ -342,6 +407,8 @@ std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
return populate_max_contiguous_splat(x); return populate_max_contiguous_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v)) if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_max_contiguous_reshape(x); return populate_max_contiguous_reshape(x);
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
return populate_max_contiguous_dequantize(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v)) if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_max_contiguous_broadcast(x); return populate_max_contiguous_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v)) if(auto *x = dynamic_cast<ir::binary_operator*>(v))
@@ -386,6 +453,23 @@ std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst
return add_to_cache(x, result, starting_multiple_); return add_to_cache(x, result, starting_multiple_);
} }
std::vector<unsigned> align::populate_starting_multiple_dequantize(ir::dequantize_inst* x){
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
auto op_last_dim = (op->get_type()->get_block_shapes()).back();
auto op_multiple = populate_starting_multiple(op);
for(size_t d = 0; d < shapes.size(); d++) {
unsigned factor = 1;
if (d == shapes.size() - 1) {
factor = ret_last_dim / op_last_dim;
}
result.push_back(factor * op_multiple[d]);
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){ std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
auto result = populate_starting_multiple(x->get_operand(0)); auto result = populate_starting_multiple(x->get_operand(0));
return add_to_cache(x, result, starting_multiple_); return add_to_cache(x, result, starting_multiple_);
@@ -401,7 +485,7 @@ std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operato
if(x->is_int_add_sub()) if(x->is_int_add_sub())
result[d] = gcd(lhs[d], rhs[d]); result[d] = gcd(lhs[d], rhs[d]);
if(x->is_int_div()) if(x->is_int_div())
result[d] = 1; result[d] = (lhs[d] == (1 << 31)) ? 1 << 31 : 1;
if(x->is_int_rem() && rhs[d] > 1){ if(x->is_int_rem() && rhs[d] > 1){
result[d] = gcd(lhs[d], rhs[d]); result[d] = gcd(lhs[d], rhs[d]);
} }
@@ -471,28 +555,42 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
return add_to_cache(v, {1}, starting_multiple_); return add_to_cache(v, {1}, starting_multiple_);
} }
unsigned get_max_multiple(int val){
if(val == 0) return 1 << 31;
if(val % 128 == 0) return 128;
if(val % 64 == 0) return 64;
if(val % 32 == 0) return 32;
if(val % 16 == 0) return 16;
if(val % 8 == 0) return 8;
if(val % 4 == 0) return 4;
if(val % 2 == 0) return 2;
return 1;
}
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
if(starting_multiple_.find(v) != starting_multiple_.end()) if(starting_multiple_.find(v) != starting_multiple_.end())
return starting_multiple_.at(v); return starting_multiple_.at(v);
if(auto *x = dynamic_cast<ir::instruction*>(v)){ if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of); std::vector<unsigned> multiple_of = x->get_metadata(ir::metadata::multiple_of);
if(multiple_of > 0) if(!multiple_of.empty())
return add_to_cache(x, {multiple_of}, starting_multiple_); return add_to_cache(x, multiple_of, starting_multiple_);
} }
if(auto *x = dynamic_cast<ir::cast_inst*>(v)) if(auto *x = dynamic_cast<ir::cast_inst*>(v))
return populate_starting_multiple_cast(x); return populate_starting_multiple_cast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v)) if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_starting_multiple_binop(x); return populate_starting_multiple_binop(x);
if(auto *x = dynamic_cast<ir::constant_int*>(v)) if(auto *x = dynamic_cast<ir::constant_int*>(v))
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_); return add_to_cache(x, {get_max_multiple(x->get_value())}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range*>(v)) if(auto *x = dynamic_cast<ir::make_range*>(v))
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_); return add_to_cache(x, {get_max_multiple(x->get_first()->get_value())}, starting_multiple_);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)) if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_starting_multiple_gep(x); return populate_starting_multiple_gep(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v)) if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_starting_multiple_splat(x); return populate_starting_multiple_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v)) if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_starting_multiple_reshape(x); return populate_starting_multiple_reshape(x);
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
return populate_starting_multiple_dequantize(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v)) if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_starting_multiple_broadcast(x); return populate_starting_multiple_broadcast(x);
if(auto *x = dynamic_cast<ir::phi_node*>(v)) if(auto *x = dynamic_cast<ir::phi_node*>(v))
@@ -511,12 +609,15 @@ std::vector<unsigned> align::contiguous(ir::value* v) const {
return max_contiguous_.at(v); return max_contiguous_.at(v);
} }
std::vector<align::cst_info> align::get_cst_info(ir::value* v) const {
return is_constant_.at(v);
}
void align::populate(ir::value *v) { void align::populate(ir::value *v) {
populate_is_constant(v); populate_is_constant(v);
populate_starting_multiple(v); populate_starting_multiple(v);
populate_max_contiguous(v); populate_max_contiguous(v);
} }
void align::run(ir::module &mod) { void align::run(ir::module &mod) {

View File

@@ -92,8 +92,10 @@ void allocation::run(ir::module &mod) {
} }
// Save maximum size of induced memory space // Save maximum size of induced memory space
allocated_size_ = 0; allocated_size_ = 0;
for(shared_layout* x: V) for(shared_layout* x: V){
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size()); allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
// std::cout << "start: " << starts[x] << " | end: " << starts[x] + x->get_size() << std::endl;
}
} }
} }

View File

@@ -56,6 +56,17 @@ void axes::update_graph_trans(ir::instruction *i) {
graph_.add_edge({i, perm[d]}, {op, d}); graph_.add_edge({i, perm[d]}, {op, d});
} }
void axes::update_graph_dequantize(ir::instruction *i) {
auto *dequantize = static_cast<ir::dequantize_inst*>(i);
auto shapes = dequantize->get_type()->get_block_shapes();
ir::value *op = dequantize->get_operand(0);
// add edge except the last axis
for(unsigned d = 0; d < shapes.size() - 1; d ++){
graph_.add_edge({i, d}, {op, d});
}
}
void axes::update_graph_broadcast(ir::instruction *i) { void axes::update_graph_broadcast(ir::instruction *i) {
auto *broadcast = static_cast<ir::broadcast_inst*>(i); auto *broadcast = static_cast<ir::broadcast_inst*>(i);
auto shapes = broadcast->get_type()->get_block_shapes(); auto shapes = broadcast->get_type()->get_block_shapes();
@@ -119,6 +130,7 @@ void axes::update_graph(ir::instruction *i) {
case ir::INST_SPLAT: return update_graph_no_edge(i); case ir::INST_SPLAT: return update_graph_no_edge(i);
case ir::INST_CAT: return update_graph_elementwise(i, true); case ir::INST_CAT: return update_graph_elementwise(i, true);
case ir::INST_TRANS: return update_graph_trans(i); case ir::INST_TRANS: return update_graph_trans(i);
case ir::INST_DEQUANTIZE: return update_graph_dequantize(i);
case ir::INST_BROADCAST: return update_graph_broadcast(i); case ir::INST_BROADCAST: return update_graph_broadcast(i);
case ir::INST_DOT: return update_graph_dot(i); case ir::INST_DOT: return update_graph_dot(i);
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i); case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);

View File

@@ -23,19 +23,67 @@ inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
return std::min(std::max(x, lo), hi); return std::min(std::max(x, lo), hi);
} }
inline bool is_hmma_c(ir::value *v){ inline bool is_hmma_c(ir::value *v, int sm){
bool result = false; bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){ if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0); ir::value *a = x->get_operand(0);
ir::type *a_ty = a->get_type(); ir::type *a_ty = a->get_type();
ir::value *b = x->get_operand(1); ir::value *b = x->get_operand(1);
ir::type *b_ty = b->get_type(); ir::type *b_ty = b->get_type();
result = a_ty->get_scalar_ty()->is_fp16_ty() && result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
b_ty->get_scalar_ty()->is_fp16_ty(); (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
(a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
x->allow_tf32() && sm >= 80) ||
(a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) &&
sm >= 80);
} }
return result; return result;
} }
static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
mma_layout::TensorCoreType mma_type;
if (auto* dot = dynamic_cast<ir::dot_inst*>(v)) {
ir::value* a = dot->get_operand(0);
ir::value* b = dot->get_operand(1);
ir::type* a_ty = a->get_type();
ir::type* b_ty = b->get_type();
ir::type* c_ty = v->get_type();
if (c_ty->get_scalar_ty()->is_fp32_ty()) {
// floating point tensor cores
if (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) {
mma_type = mma_layout::FP32_FP16_FP16_FP32;
return mma_type;
}
if (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) {
mma_type = mma_layout::FP32_BF16_BF16_FP32;
return mma_type;
}
if (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty()
&& dot->allow_tf32()) {
mma_type = mma_layout::FP32_TF32_TF32_FP32;
return mma_type;
}
} else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
// throw std::runtime_error("integer tensor cores are not yet supported");
// // integer tensor cores
// if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
// mma_type = mma_layout::INT32_INT1_INT1_INT32;
// return mma_type;
// }
// if (a_ty->get_scalar_ty()->is_integer_ty(4) && b_ty->get_scalar_ty()->is_integer_ty(4)) {
// mma_type = mma_layout::INT32_INT4_INT4_INT32;
// return mma_type;
// }
if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
mma_type = mma_layout::INT32_INT8_INT8_INT32;
return mma_type;
}
}
}
return mma_layout::NOT_APPLICABLE;
}
inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) { inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
for(ir::user* u: v->get_users()){ for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u); auto i = dynamic_cast<ir::io_inst*>(u);
@@ -52,12 +100,13 @@ inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
} }
} }
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) { inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n, int sm) {
for(ir::user* u: v->get_users()){ for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u); auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && is_hmma_c(i) && i->get_operand(n) == v) if(i && is_hmma_c(i, sm) && i->get_operand(n) == v) {
result = i; result = i;
} }
}
} }
@@ -142,7 +191,9 @@ mma_layout::mma_layout(size_t num_warps,
const std::vector<unsigned>& shape, const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values, const std::vector<ir::value *> &values,
analysis::align* align, target* tgt, analysis::align* align, target* tgt,
shared_layout *layout_a, shared_layout *layout_b): distributed_layout(MMA, axes, shape, values, align) { shared_layout *layout_a, shared_layout *layout_b,
ir::value *dot): distributed_layout(MMA, axes, shape, values, align) {
tensor_core_type_ = get_mma_type(dot);
/* fragments per warp */ /* fragments per warp */
// try to make things as square as possible to maximize data re-use // try to make things as square as possible to maximize data re-use
if(tgt->as_nvidia() && tgt->as_nvidia()->sm() < 80){ if(tgt->as_nvidia() && tgt->as_nvidia()->sm() < 80){
@@ -157,17 +208,19 @@ mma_layout::mma_layout(size_t num_warps,
int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1; int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
rep_ = {2*pack_size_0, 2*pack_size_1, 1}; rep_ = {2*pack_size_0, 2*pack_size_1, 1};
spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1}; spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
contig_per_thread_ = {1, 1};
order_ = {0, 1};
} }
else{ else{
fpw_ = {1, 1, 1}; spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
spw_ = {16, 8, 1}; contig_per_thread_ = {1, 2};
rep_ = {2, 2, 1}; order_ = {1, 0};
} }
order_ = {0, 1};
/* warps per tile */ /* warps per tile */
// try to make things as square as possible to maximize data re-use
wpt_ = {1, 1, 1}; wpt_ = {1, 1, 1};
// try to make warp-level tiles as square as possible to maximize data re-use
if (tgt->as_nvidia()->sm() < 80) {
std::vector<int> wpt_nm1; std::vector<int> wpt_nm1;
do{ do{
wpt_nm1 = wpt_; wpt_nm1 = wpt_;
@@ -176,6 +229,46 @@ mma_layout::mma_layout(size_t num_warps,
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps) if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]); wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
}while(wpt_nm1 != wpt_); }while(wpt_nm1 != wpt_);
} else {
bool changed = false;
// try to have a warp own entire rows of the output
// this makes it easier to fuse multiple mmas by fusing
// registers
bool one_warp_per_row = false;
for(ir::value* v: values)
for(ir::user* u: v->get_users()){
auto* dot = dynamic_cast<ir::dot_inst*>(u);
auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(u);
if((dot && dot->get_operand(2)!=v) || !layout_a->to_shared() || cts)
one_warp_per_row = shape[0] / spw_[0] >= num_warps;
}
// std::cout << one_warp_per_row << std::endl;
if(one_warp_per_row){
wpt_[1] = 1;
wpt_[0] = num_warps;
}
else{
do {
changed = false;
if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
break;
if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
if (wpt_[0] < shape_[0] / spw_[0]) {
wpt_[0] *= 2;
changed = true;
}
} else {
if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
wpt_[1] *= 2;
changed = true;
}
}
} while(changed);
}
}
// std::cout << wpt_[0] << " " << wpt_[1] << std::endl;
/* shape per block */ /* shape per block */
shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1}; shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
@@ -198,8 +291,6 @@ scanline_layout::scanline_layout(size_t num_warps,
bool is_dot = std::any_of(values.begin(), values.end(), bool is_dot = std::any_of(values.begin(), values.end(),
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); }); [&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
std::vector<ir::value*> ptrs; std::vector<ir::value*> ptrs;
for(ir::value *v: values) for(ir::value *v: values)
for(ir::user *usr: v->get_users()) for(ir::user *usr: v->get_users())
@@ -215,7 +306,6 @@ scanline_layout::scanline_layout(size_t num_warps,
contiguous = std::max<int>(contiguous, std::min<int>(align->get(ptr, i), 128 / nbits)); contiguous = std::max<int>(contiguous, std::min<int>(align->get(ptr, i), 128 / nbits));
} }
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i])); nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]); mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
size /= shape_[i]; size /= shape_[i];
@@ -277,11 +367,15 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
res.reset(new double_buffer_info_t{value_1, value_0, phi}); res.reset(new double_buffer_info_t{value_1, value_0, phi});
} }
static bool is_smem(ir::value* v) { static bool is_smem_in(ir::value* v, const ir::basic_block* bb) {
if (ir::instruction *instr = dynamic_cast<ir::instruction*>(v)) {
if (instr->get_parent() != bb)
return false;
if (dynamic_cast<ir::copy_to_shared_inst*>(v) || if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
dynamic_cast<ir::masked_load_async_inst*>(v)) dynamic_cast<ir::masked_load_async_inst*>(v)) {
return true; return true;
else }
}
return false; return false;
} }
@@ -297,14 +391,14 @@ static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::
ir::basic_block *cbb0 = cphi->get_incoming_block(0); ir::basic_block *cbb0 = cphi->get_incoming_block(0);
ir::basic_block *cbb1 = cphi->get_incoming_block(1); ir::basic_block *cbb1 = cphi->get_incoming_block(1);
if (is_smem(c0)) { if (is_smem_in(c0, cbb0)) {
assert(cbb0 == bb0); assert(cbb0 == bb0);
values_0.push_back(c0); values_0.push_back(c0);
if (auto phi1 = dynamic_cast<ir::phi_node*>(c1)) { if (auto phi1 = dynamic_cast<ir::phi_node*>(c1)) {
next = phi1; next = phi1;
continue; continue;
} else { } else {
if (is_smem(c1)) { if (is_smem_in(c1, cbb1)) {
value_1 = c1; value_1 = c1;
assert(cbb1 == bb1); assert(cbb1 == bb1);
return true; return true;
@@ -359,7 +453,8 @@ shared_layout::shared_layout(data_layout *arg,
const std::vector<unsigned>& shape, const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values, const std::vector<ir::value *> &values,
ir::type *ty, ir::type *ty,
analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) { analysis::align* align, target *tgt, bool is_tmp)
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt), is_tmp_(is_tmp){
size_ = 0; size_ = 0;
arg_layout_ = arg; arg_layout_ = arg;
@@ -385,12 +480,35 @@ shared_layout::shared_layout(data_layout *arg,
for(ir::value* v: values){ for(ir::value* v: values){
extract_dot_use(v, dot_a, 0); extract_dot_use(v, dot_a, 0);
extract_dot_use(v, dot_b, 1); extract_dot_use(v, dot_b, 1);
extract_hmma_dot_use(v, hmma_dot_a, 0); extract_hmma_dot_use(v, hmma_dot_a, /*op*/0, tgt_->as_nvidia()->sm());
extract_hmma_dot_use(v, hmma_dot_b, 1); extract_hmma_dot_use(v, hmma_dot_b, /*op*/1, tgt_->as_nvidia()->sm());
} }
hmma_dot_a_ = hmma_dot_a; hmma_dot_a_ = hmma_dot_a;
hmma_dot_b_ = hmma_dot_b; hmma_dot_b_ = hmma_dot_b;
// Update mma_vec
if (hmma_dot_a_) {
assert(order_.size() == 2);
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];
// for now, disable swizzle when using lds.8
if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32)
if (order_[0] == 0) // need transpose
allow_swizzle_ = false;
} else if (hmma_dot_b_) {
assert(order_.size() == 2);
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];
// for now, disable swizzle when using lds.8
if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32)
if (order_[0] == 1) // need transpose
allow_swizzle_ = false;
}
// size // size
size_ = ty_->get_primitive_size_in_bits() / 8; size_ = ty_->get_primitive_size_in_bits() / 8;
for(auto s: shape_) for(auto s: shape_)
@@ -454,7 +572,8 @@ void layouts::make_graph(ir::instruction *i) {
void layouts::create(size_t id, const std::vector<ir::value*>& values) { void layouts::create(size_t id, const std::vector<ir::value*>& values) {
// if(layouts_.find(id) != layouts_.end()) // if(layouts_.find(id) != layouts_.end())
// return; // return;
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c); auto it_hmma_c = std::find_if(values.begin(), values.end(),
[&](ir::value* v){ return is_hmma_c(v, tgt_->as_nvidia()->sm()); });
auto cmp = [](ir::value* x, ir::value *y) { auto cmp = [](ir::value* x, ir::value *y) {
std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()}; std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()}; std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
@@ -476,19 +595,61 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
ir::value *b = dot->get_operand(1); ir::value *b = dot->get_operand(1);
create(groups_.at(a), values_.at(groups_.at(a))); create(groups_.at(a), values_.at(groups_.at(a)));
create(groups_.at(b), values_.at(groups_.at(b))); create(groups_.at(b), values_.at(groups_.at(b)));
layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b))); layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_,
(shared_layout*)layouts_.at(groups_.at(a)),
(shared_layout*)layouts_.at(groups_.at(b)),
dot);
} }
else if(it_cts != values.end()){ else if(it_cts != values.end()){
ir::instruction *cts = (ir::instruction*)*it_cts; ir::instruction *cts = (ir::instruction*)*it_cts;
ir::value *arg = cts->get_operand(0); ir::value *arg = cts->get_operand(0);
create(groups_.at(arg), values_.at(groups_.at(arg))); create(groups_.at(arg), values_.at(groups_.at(arg)));
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_); layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_, tgt_);
} }
else{ else{
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_); layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
} }
} }
// layout checkers
bool layouts::is_scanline(ir::instruction *i) {
return this->get(i->get_operand(0))->to_scanline() != nullptr;
}
bool layouts::is_coalesced_scanline(ir::instruction *i) {
if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
auto *scanline = this->get(i->get_operand(0))->to_scanline();
return scanline && scanline->get_order()[0] == red->get_axis();
}
return false;
}
bool layouts::is_mma(ir::instruction *i) {
return this->get(i->get_operand(0))->to_mma() != nullptr;
}
bool layouts::is_a100_mma(ir::instruction *i) {
if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
return is_mma(red) && (tgt_->as_nvidia()->sm() >= 80) &&
(red->get_axis() == 1);
}
return false;
}
void layouts::create_tmp_layout(size_t id, data_layout *arg,
const std::vector<int> &axes,
const std::vector<unsigned> &shape,
ir::instruction *i, bool is_index) {
ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context())
: i->get_type()->get_scalar_ty();
layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_, true);
if (is_index) {
tmp_index_[i] = id;
} else {
tmp_[i] = id;
}
}
void layouts::run(ir::module &mod) { void layouts::run(ir::module &mod) {
// make graph // make graph
graph_.clear(); graph_.clear();
@@ -510,35 +671,47 @@ void layouts::run(ir::module &mod) {
// create temporaries // create temporaries
size_t id = values_.size(); size_t id = values_.size();
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) { ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
// std::cout << "layout: " << std::endl;
// i->print(std::cout);
if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) { if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
id++;
ir::value *arg = red->get_operand(0); ir::value *arg = red->get_operand(0);
unsigned axis = red->get_axis(); distributed_layout *layout =
dynamic_cast<analysis::distributed_layout *>(get(arg));
// shape // shape
auto shapes = arg->get_type()->get_block_shapes(); auto shapes = arg->get_type()->get_block_shapes();
scanline_layout *layout = get(arg)->to_scanline(); unsigned axis = red->get_axis();
shapes[axis] = layout->mts(axis); shapes[axis] =
layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
// create layout // create layout
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_); id++;
tmp_[red] = id; create_tmp_layout(id, layout, axes_->get(arg), shapes, red);
if (red->with_index()) {
id++;
create_tmp_layout(id, layout, axes_->get(arg), shapes, red, true);
}
} }
if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){ if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val)); distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0))); distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
id++;
size_t dim = val->get_type()->get_tile_rank(); size_t dim = val->get_type()->get_tile_rank();
ir::type::block_shapes_t shape(dim); ir::type::block_shapes_t shape(dim);
for(size_t k = 0; k < dim; k++){ for(size_t k = 0; k < dim; k++){
shape[k] = std::max(in_layout->shape_per_cta(k), shape[k] = std::max(in_layout->shape_per_cta(k),
out_layout->shape_per_cta(k)); out_layout->shape_per_cta(k));
} }
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_); auto in_ord = in_layout->get_order();
tmp_[val] = id; auto out_ord = out_layout->get_order();
int in_vec = in_layout->contig_per_thread(in_ord[0]);
int out_vec = out_layout->contig_per_thread(out_ord[0]);
int pad = std::max(in_vec, out_vec);
shape[out_ord[0]] += pad;
id++;
create_tmp_layout(id, out_layout, axes_->get(val), shape, val);
} }
if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){ if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
id++; id++;
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_); create_tmp_layout(id, nullptr, {}, {1}, atom);
tmp_[atom] = id;
} }
}); });

View File

@@ -14,40 +14,105 @@ namespace analysis{
void liveness::run(ir::module &mod) { void liveness::run(ir::module &mod) {
intervals_.clear(); intervals_.clear();
// Assigns index to each instruction std::map<ir::value*, std::set<shared_layout*>> layouts_map;
std::map<ir::value*, slot_index> indices; for(auto &x: layouts_->get_all()){
for(ir::function *fn: mod.get_function_list()){ shared_layout* layout = x.second->to_shared();
slot_index index = 0; if(!layout || layout->is_tmp())
for(ir::basic_block *block: fn->blocks()) continue;
for(ir::instruction *instr: block->get_inst_list()){ for(ir::value* v:layout->get_values()){
index += 1; layouts_map[v].insert(layout);
indices.insert({instr, index});
} }
} }
// create live intervals
std::map<ir::user*, std::set<shared_layout*>> live_in;
while(true){
bool changed = false;
ir::instruction* last_inst = nullptr;
ir::for_each_instruction_backward(mod, [&](ir::instruction* i){
// gen
std::set<shared_layout*> gen;
for(ir::value* v: i->ops())
for(shared_layout* layout: layouts_map[v])
gen.insert(layout);
// kill
std::set<shared_layout*> kill;
for(shared_layout* layout: layouts_map[i])
kill.insert(layout);
// temporaries are handled separately
if(layouts_->has_tmp(i)){
gen.insert(layouts_->get(layouts_->tmp(i))->to_shared());
kill.insert(layouts_->get(layouts_->tmp(i))->to_shared());
}
if(layouts_->has_tmp_index(i)){
gen.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
kill.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
}
// live-out
std::set<shared_layout*> live_out;
std::vector<ir::instruction*> succs = {last_inst};
if(i == i->get_parent()->get_inst_list().back())
for(ir::basic_block* succ: i->get_parent()->get_successors())
succs.push_back(succ->get_inst_list().front());
for(ir::instruction* succ: succs)
for(shared_layout* layout: live_in[succ])
if(!layout->is_tmp())
live_out.insert(layout);
// new sets
std::set<shared_layout*> live_out_minus_kill;
std::set_difference(live_out.begin(), live_out.end(), kill.begin(), kill.end(),
std::inserter(live_out_minus_kill, live_out_minus_kill.end()));
std::set<shared_layout*> new_live_in;
std::set_union(gen.begin(), gen.end(), live_out_minus_kill.begin(), live_out_minus_kill.end(),
std::inserter(new_live_in, new_live_in.end()));
changed = changed || (new_live_in != live_in[i]);
live_in[i] = new_live_in;
last_inst = i;
});
if(!changed)
break;
}
// ir::for_each_instruction(mod, [&](ir::instruction* i){
// i->print(std::cout);
// std::cout << " live_in: " << live_in[i].size() << std::endl;
// });
// Assigns index to each instruction
std::map<ir::value*, slot_index> indices;
slot_index index = 0;
ir::for_each_instruction(mod, [&](ir::instruction* instr){
index += 1;
indices.insert({instr, index});
});
for(auto &x: layouts_->get_all()){
shared_layout* layout = x.second->to_shared();
if(layout)
intervals_[layout] = segment{INT32_MAX, 0};
}
for(auto& x: live_in)
for(shared_layout* layout: x.second)
intervals_[layout].start = std::min<int>(intervals_[layout].start, indices[x.first]);
for(auto& x: live_in)
for(shared_layout* layout: x.second){
intervals_[layout].end = std::max<int>(intervals_[layout].end, indices[x.first] + 1);
}
for(auto &x: layouts_->get_all()) { for(auto &x: layouts_->get_all()) {
shared_layout* layout = x.second->to_shared(); shared_layout* layout = x.second->to_shared();
if(!layout) if(!layout)
continue; continue;
// users // std::cout << intervals_[layout].start << " " << intervals_[layout].end << std::endl;
std::set<ir::user*> users;
for(ir::value *v: layout->get_values()){
for(ir::user *u: v->get_users())
users.insert(u);
}
// compute intervals
unsigned start = INT32_MAX;
for(ir::value *v: layout->get_values())
if(indices.find(v) != indices.end())
start = std::min(start, indices.at(v));
unsigned end = 0;
for(ir::user *u: users)
if(indices.find(u) != indices.end())
end = std::max(end, indices.at(u));
if(end == 0)
end = start + 1;
intervals_[layout] = segment{start, end};
} }

View File

@@ -19,6 +19,7 @@ void swizzle::run(ir::module &) {
continue; continue;
ir::value* mma_dot_a = layout->hmma_dot_a(); ir::value* mma_dot_a = layout->hmma_dot_a();
ir::value* mma_dot_b = layout->hmma_dot_b(); ir::value* mma_dot_b = layout->hmma_dot_b();
if(!mma_dot_a && !mma_dot_b){ if(!mma_dot_a && !mma_dot_b){
per_phase_[layout] = 1; per_phase_[layout] = 1;
max_phase_[layout] = 1; max_phase_[layout] = 1;
@@ -27,22 +28,31 @@ void swizzle::run(ir::module &) {
} }
auto ord = layout->get_order(); auto ord = layout->get_order();
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout()); scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
if(!in_layout) int per_phase = 1;
continue;
int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(in_layout)
per_phase = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
else
per_phase = 1;
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){ if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
int inner = mma_dot_a ? 0 : 1; int inner = mma_dot_a ? 0 : 1;
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); per_phase_[layout] = per_phase;
max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout]; max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
if(mma_dot_a) if(mma_dot_a)
vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0); vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
else else
vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1); vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
} }
else{ else {
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); if (!layout->allow_swizzle()) {
max_phase_[layout] = 8 / per_phase_[layout]; per_phase_[layout] = 1;
vec_[layout] = 8; max_phase_[layout] = 1;
vec_[layout] = 1;
} else {
per_phase_[layout] = per_phase;
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
vec_[layout] = layout->get_mma_vec();
}
} }
} }
} }

63
lib/codegen/extern_lib.cc Normal file
View File

@@ -0,0 +1,63 @@
#include "triton/codegen/extern_lib.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Type.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "triton/codegen/pass.h"
namespace triton {
namespace codegen {
std::unique_ptr<llvm::Module> ExternLib::load(llvm::LLVMContext& ctx) {
llvm::SMDiagnostic err;
auto mod = llvm::parseIRFile(this->path_, err, ctx);
if (!mod) {
throw std::runtime_error("Failed to load extern lib " + this->name_ +
" at " + this->path_);
}
return mod;
}
void ExternLib::link(std::unique_ptr<llvm::Module>& llvm,
std::unique_ptr<llvm::Module>& mod) {
// Set triple and data layout to match the target module
mod->setTargetTriple(llvm->getTargetTriple());
mod->setDataLayout(llvm->getDataLayout());
if (llvm::Linker::linkModules(*llvm, std::move(mod))) {
throw std::runtime_error("Failed to link extern lib " + this->name_ +
" at " + this->path_);
}
}
void LibDevice::opt(llvm::LLVMContext& ctx, std::unique_ptr<llvm::Module>& llvm) {
// Add nvvm reflect flags to llvm module
// https://llvm.org/docs/LangRef.html#module-flags-metadata
// i32 4: Override the other module.
// i32 1: Emit an error
// If both modules specify Override, but the values differ, an error
// will be emitted.
llvm::Type* I32 = llvm::Type::getInt32Ty(ctx);
llvm::Metadata* md_four =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
llvm::Metadata* md_name = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
llvm::Metadata* md_one =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
llvm::MDNode* reflect = llvm::MDNode::get(ctx, {md_four, md_name, md_one});
llvm->addModuleFlag(reflect);
}
std::unique_ptr<ExternLib> create_extern_lib(const std::string& lib_name,
const std::string& lib_path) {
if (lib_name == "libdevice") {
return std::make_unique<LibDevice>(lib_name, lib_path);
} else {
throw std::runtime_error("Unknown external library: " + lib_name);
}
}
} // namespace codegen
} // namespace triton

View File

@@ -1,4 +1,14 @@
#include "triton/codegen/pass.h" #include "triton/codegen/pass.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "triton/codegen/analysis/align.h" #include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/axes.h" #include "triton/codegen/analysis/axes.h"
@@ -9,6 +19,7 @@
#include "triton/codegen/transform/cts.h" #include "triton/codegen/transform/cts.h"
#include "triton/codegen/transform/dce.h" #include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/disassociate.h" #include "triton/codegen/transform/disassociate.h"
#include "triton/codegen/transform/inline.h"
#include "triton/codegen/transform/membar.h" #include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/peephole.h" #include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/pipeline.h" #include "triton/codegen/transform/pipeline.h"
@@ -16,44 +27,90 @@
#include "triton/ir/function.h" #include "triton/ir/function.h"
#include "triton/ir/module.h" #include "triton/ir/module.h"
#include "triton/ir/print.h" #include "triton/ir/print.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
namespace triton { namespace triton {
namespace codegen { namespace codegen {
static void link_extern_libs(const ExternLibMap& user_extern_lib_map,
const ExternLibMap& target_extern_lib_map,
ir::module& ir, llvm::LLVMContext& ctx,
std::unique_ptr<llvm::Module>& llvm) {
for (const auto& iter : target_extern_lib_map) {
auto &lib_name = iter.first;
if (user_extern_lib_map.count(lib_name) != 0 &&
user_extern_lib_map.at(lib_name)->path() != "") {
// If the user specified a path for this library, use it.
user_extern_lib_map.at(lib_name)->install(ctx, llvm);
} else {
// Otherwise, use the default path.
iter.second->install(ctx, llvm);
}
}
std::set<llvm::StringRef> function_names;
for (auto& func : ir.get_function_list()) {
function_names.insert(func->get_name());
}
llvm::legacy::PassManager pass;
pass.add(llvm::createInternalizePass([&](const llvm::GlobalValue& v) -> bool {
if (function_names.count(v.getName()) != 0) {
// Preserve global functions
return true;
}
// Internalize all device functions
return false;
}));
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
pm.run(*llvm);
llvm::PassManagerBuilder builder;
builder.OptLevel = 3;
builder.SizeLevel = 0;
builder.populateModulePassManager(pass);
pass.run(*llvm);
}
// TODO: // TODO:
// There should be a proper pass manager there! // There should be a proper pass manager there!
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target, std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
int cc, int num_warps, int num_stages, int& shared_static) { ir::module& ir, llvm::LLVMContext& ctx, codegen::target* target,
int num_warps, int num_stages, int& shared_static,
const ExternLibMap& extern_lib_map) {
// generate llvm code // generate llvm code
std::string name = ir.get_function_list()[0]->get_name(); std::string name = ir.get_function_list()[0]->get_name();
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx)); std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
// optimizations // optimizations
bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80; bool has_sm80 = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
// create passes // create passes
codegen::analysis::align align; codegen::analysis::align align;
codegen::transform::inliner inliner;
codegen::analysis::axes axes; codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async); codegen::transform::pipeline pipeline(has_sm80, num_stages);
codegen::transform::pipeline pipeline(cts_use_async, num_stages);
codegen::transform::disassociate disassociate; codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, num_warps, target); codegen::analysis::layouts layouts(&axes, &align, num_warps, target);
codegen::transform::cts cts(&layouts, has_sm80);
codegen::analysis::liveness liveness(&layouts); codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target); codegen::analysis::swizzle swizzle(&layouts, target);
codegen::analysis::allocation allocation(&liveness); codegen::analysis::allocation allocation(&liveness);
codegen::transform::dce dce; codegen::transform::dce dce;
codegen::transform::peephole peephole(target, &layouts); codegen::transform::peephole peephole(target, &layouts);
codegen::transform::coalesce coalesce(&align, &layouts); codegen::transform::coalesce coalesce(&align, &layouts, has_sm80);
codegen::transform::prefetch prefetch_s(target); codegen::transform::prefetch prefetch_s(target);
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target); codegen::transform::membar barriers(&liveness, &layouts, &allocation,
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps); &prefetch_s, target);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle,
target, num_warps);
// run passes // run passes
inliner.run(ir);
dce.run(ir); dce.run(ir);
peephole.run(ir); peephole.run(ir);
dce.run(ir); dce.run(ir);
pipeline.run(ir); pipeline.run(ir);
dce.run(ir); dce.run(ir);
// ir.print(std::cout);
disassociate.run(ir); disassociate.run(ir);
dce.run(ir); dce.run(ir);
align.run(ir); align.run(ir);
@@ -61,8 +118,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
layouts.run(ir); layouts.run(ir);
peephole.run(ir); peephole.run(ir);
dce.run(ir); dce.run(ir);
if (target->is_gpu()) if (target->is_gpu()) cts.run(ir);
cts.run(ir);
align.run(ir); align.run(ir);
axes.run(ir); axes.run(ir);
layouts.run(ir); layouts.run(ir);
@@ -70,8 +126,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
dce.run(ir); dce.run(ir);
align.run(ir); align.run(ir);
dce.run(ir); dce.run(ir);
if (target->is_gpu()) if (target->is_gpu()) cts.run(ir);
cts.run(ir);
dce.run(ir); dce.run(ir);
align.run(ir); align.run(ir);
axes.run(ir); axes.run(ir);
@@ -82,12 +137,25 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
axes.run(ir); axes.run(ir);
layouts.run(ir); layouts.run(ir);
swizzle.run(ir); swizzle.run(ir);
// std::cout << "---" << std::endl;
// ir.print(std::cout);
// std::cout << "---" << std::endl;
// ir.print(std::cout);
liveness.run(ir); liveness.run(ir);
allocation.run(ir); allocation.run(ir);
prefetch_s.run(ir); prefetch_s.run(ir);
barriers.run(ir); barriers.run(ir);
// exit(1);
// ir.print(std::cout);
isel.visit(ir, *llvm); isel.visit(ir, *llvm);
shared_static = allocation.allocated_size(); shared_static = allocation.allocated_size();
if (isel.get_extern_lib_map().size() > 0) {
// If there's any extern lib calls,
// we need to link them in.
link_extern_libs(extern_lib_map, isel.get_extern_lib_map(), ir, ctx, llvm);
}
return llvm; return llvm;
} }

File diff suppressed because it is too large Load Diff

View File

@@ -12,46 +12,11 @@ namespace triton {
namespace codegen{ namespace codegen{
namespace transform{ namespace transform{
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts) coalesce::coalesce(analysis::align* align, analysis::layouts *layouts, bool has_sm80)
: align_(align), layout_(layouts) { } : align_(align), layout_(layouts), has_sm80_(has_sm80) { }
// simplify layout conversions using the following simple rules:
// - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
// - cvt_1(elementwise(x, y)) = elementwise(convert(x), convert(y))
//ir::value* coalesce::simplify(ir::instruction *inst, ir::builder& builder){
// ir::value* _op = inst->get_operand(0);
// ir::instruction* op = dynamic_cast<ir::instruction*>(_op);
// analysis::mma_layout* mma_in = layout_->get(op) ->to_mma();
// analysis::mma_layout* mma_out = layout_->get(inst)->to_mma();
// std::cout << 1 << std::endl;
// // i must be layout conversion instruction
// if(!mma_in && !mma_out)
// return inst;
// // - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
// bool is_op_cvt = op->get_id() == ir::INST_CVT_LAYOUT;
// if((mma_in || mma_out) && is_op_cvt &&
// (layout_->get(inst) == layout_->get(op->get_operand(0))))
// return op->get_operand(0);
// // - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_2(y))
// if(op->get_id() != ir::INST_BINOP && op->get_id() != ir::INST_GETELEMENTPTR)
// return inst;
// std::cout << 1 << std::endl;
// for(size_t i = 0; i < op->get_num_operands(); i++){
// ir::value* arg_i = op->get_operand(i);
// builder.set_insert_point(op);
// // create new layout transform
// ir::instruction* new_arg_i = inst->clone();
// builder.insert(new_arg_i);
// // set the right args
// new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
// op->replace_uses_of_with(arg_i, simplify(new_arg_i, builder));
// }
// std::cout << 2 << std::endl;
// return op;
//}
void coalesce::run(ir::module &mod) { void coalesce::run(ir::module &mod) {
std::set<analysis::data_layout*> invalidated;
ir::builder& builder = mod.get_builder(); ir::builder& builder = mod.get_builder();
// add layout conversion instructions // add layout conversion instructions
for(ir::function *fn: mod.get_function_list()) for(ir::function *fn: mod.get_function_list())
@@ -61,23 +26,43 @@ void coalesce::run(ir::module &mod) {
if(dynamic_cast<ir::store_inst*>(i) || dynamic_cast<ir::atomic_rmw_inst*>(i)) if(dynamic_cast<ir::store_inst*>(i) || dynamic_cast<ir::atomic_rmw_inst*>(i))
if(ir::value* op = i->get_operand(1)) if(ir::value* op = i->get_operand(1))
if(op->get_type()->is_block_ty()) if(op->get_type()->is_block_ty())
if(layout_->get(op)->to_mma()){ if(op->get_type()->get_tile_ranks1() == 2)
if(invalidated.find(layout_->get(op)) == invalidated.end())
if(layout_->get(op)->to_mma())
if(dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
ir::instruction* new_op = ir::cvt_layout_inst::create(op); ir::instruction* new_op = ir::cvt_layout_inst::create(op);
builder.set_insert_point(i); builder.set_insert_point(i);
builder.insert(new_op); builder.insert(new_op);
i->replace_uses_of_with(op, new_op); i->replace_uses_of_with(op, new_op);
} }
// coalesce before copy_to_shared
// only necessary for sm < 80 as Ampere+ can handle reduction
// on MMA layout
if(!has_sm80_)
if(dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::reduce_inst*>(i))
if(ir::value* op = i->get_operand(0))
if(op->get_type()->is_block_ty())
if(op->get_type()->get_tile_ranks1() == 2)
if(invalidated.find(layout_->get(op)) == invalidated.end())
if(layout_->get(op)->to_mma()){
ir::instruction* new_op = ir::cvt_layout_inst::create(op);
builder.set_insert_point(i);
builder.insert(new_op);
op->replace_all_uses_with(new_op);
new_op->replace_uses_of_with(new_op, op);
invalidated.insert(layout_->get(op));
}
// uncoalesce after load // uncoalesce after load
if(auto x = dynamic_cast<ir::load_inst*>(i)) if(auto x = dynamic_cast<ir::load_inst*>(i))
if(x->get_type()->is_block_ty()) if(x->get_type()->is_block_ty())
if(x->get_type()->get_tile_rank()==2) if(x->get_type()->get_tile_ranks1()==2)
if(layout_->get(x)->to_mma()){ if(layout_->get(x)->to_mma())
if(!has_sm80_ || dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
builder.set_insert_point_after(x); builder.set_insert_point_after(x);
ir::instruction* new_x = ir::cvt_layout_inst::create(x); ir::instruction* new_x = ir::cvt_layout_inst::create(x);
builder.insert(new_x); builder.insert(new_x);
x->replace_all_uses_with(new_x); x->replace_all_uses_with(new_x);
new_x->replace_uses_of_with(new_x, x); new_x->replace_uses_of_with(new_x, x);
// new_x->replace_uses_of_with(new_x, new_x);
} }
} }
for(ir::function *fn: mod.get_function_list()) for(ir::function *fn: mod.get_function_list())
@@ -90,9 +75,11 @@ void coalesce::run(ir::module &mod) {
auto out_contig = align_->contiguous(ptr); auto out_contig = align_->contiguous(ptr);
auto val_inst = dynamic_cast<ir::instruction*>(val); auto val_inst = dynamic_cast<ir::instruction*>(val);
if(!val_inst) if(!val_inst)
break; continue;
if(dynamic_cast<ir::cvt_layout_inst*>(val)) if(dynamic_cast<ir::cvt_layout_inst*>(val))
break; continue;
if(!val->get_type()->is_block_ty() || val->get_type()->get_tile_ranks1()==1)
continue;
std::vector<unsigned> in_contig; std::vector<unsigned> in_contig;
std::vector<ir::instruction*> queue = {val_inst}; std::vector<ir::instruction*> queue = {val_inst};
std::set<ir::instruction*> seen; std::set<ir::instruction*> seen;
@@ -101,6 +88,8 @@ void coalesce::run(ir::module &mod) {
ir::instruction* curr = queue.back(); ir::instruction* curr = queue.back();
seen.insert(curr); seen.insert(curr);
queue.pop_back(); queue.pop_back();
if(auto dot_inst = dynamic_cast<ir::dot_inst*>(curr))
break;
if(auto io_inst = dynamic_cast<ir::io_inst*>(curr)){ if(auto io_inst = dynamic_cast<ir::io_inst*>(curr)){
in_contig = align_->contiguous(io_inst->get_pointer_operand()); in_contig = align_->contiguous(io_inst->get_pointer_operand());
break; break;

View File

@@ -1,8 +1,10 @@
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/transform/cts.h" #include "triton/codegen/transform/cts.h"
#include "triton/ir/module.h" #include "triton/ir/module.h"
#include "triton/ir/function.h" #include "triton/ir/function.h"
#include "triton/ir/basic_block.h" #include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h" #include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
#include <iostream> #include <iostream>
namespace triton { namespace triton {
@@ -10,9 +12,9 @@ namespace codegen{
namespace transform{ namespace transform{
inline bool is_shmem_op(ir::instruction* i, int op) { bool cts::is_shmem_op(ir::instruction* i, int op) {
if(i->get_id() == ir::INST_DOT) if(i->get_id() == ir::INST_DOT)
return op==0 || op==1; return op == 0 || op == 1;
if(i->get_id() == ir::INST_COPY_FROM_SHARED) if(i->get_id() == ir::INST_COPY_FROM_SHARED)
return op==0; return op==0;
if(i->get_id() == ir::INST_TRANS) if(i->get_id() == ir::INST_TRANS)
@@ -20,7 +22,7 @@ inline bool is_shmem_op(ir::instruction* i, int op) {
return false; return false;
} }
inline bool is_shmem_res(ir::value* v){ bool cts::is_shmem_res(ir::value* v){
ir::instruction* i = dynamic_cast<ir::instruction*>(v); ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i) if(!i)
return false; return false;
@@ -35,7 +37,7 @@ inline bool is_shmem_res(ir::value* v){
// run pass on module // run pass on module
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) { void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*, ir::value*>& copies) {
auto *i = dynamic_cast<ir::instruction*>(x); auto *i = dynamic_cast<ir::instruction*>(x);
// not an instruction // not an instruction
if(!i) { if(!i) {
@@ -51,7 +53,7 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder,
// phi node // phi node
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) { if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
for(unsigned i = 0; i < phi->get_num_incoming(); ++i) for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
add_copy(phi, phi->get_incoming_value(i), builder, to_shared); add_copy(phi, phi->get_incoming_value(i), builder, to_shared, copies);
return; return;
} }
// already in shared memory // already in shared memory
@@ -65,30 +67,49 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder,
} }
else else
copy = builder.create_copy_from_shared(x); copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy); copies.insert({x, copy});
parent->replace_uses_of_with(x, copies.at(x));
} }
void cts::run(ir::module &mod) { void cts::run(ir::module &mod) {
// Precompute where copies should be added
std::set<ir::value*> shmem_ops;
std::set<ir::value*> shmem_res;
ir::for_each_instruction(mod, [&](ir::instruction* i) {
if(i->get_id() == ir::INST_DOT){
ir::dot_inst* dot = dynamic_cast<ir::dot_inst*>(i);
ir::value* lhs = i->get_operand(0);
ir::type* ty = lhs->get_type()->get_scalar_ty();
analysis::mma_layout* mma_lhs = layouts_->get(lhs)->to_mma();
// TODO: V100
bool is_lhs_shmem = !(mma_lhs && has_sm80_ && ty->get_primitive_size_in_bits() == 16 && !dot->is_trans_a());
if(is_lhs_shmem)
shmem_ops.insert(lhs);
shmem_ops.insert(i->get_operand(1));
}
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
shmem_ops.insert(i->get_operand(0));
if(i->get_id() == ir::INST_TRANS)
shmem_ops.insert(i->get_operand(0));
if(i->get_id() == ir::INST_TRANS ||
i->get_id() == ir::INST_COPY_TO_SHARED ||
i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
shmem_res.insert(i);
});
// Add shared copies // Add shared copies
std::map<ir::value*, ir::value*> copies;
ir::builder &builder = mod.get_builder(); ir::builder &builder = mod.get_builder();
for(ir::function* fn: mod.get_function_list()){ ir::for_each_instruction(mod, [&](ir::instruction* i) {
for(ir::basic_block* block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
size_t num_op = i->get_num_operands(); size_t num_op = i->get_num_operands();
for(size_t k = 0; k < num_op; k++){
ir::value* op = i->get_operand(k);
// copy to shared operands // copy to shared operands
for(size_t k = 0; k < num_op; k++) bool is_shmem_op = shmem_ops.find(op) != shmem_ops.end();
if(is_shmem_op(i, k)){ if(is_shmem_op)
add_copy(i, i->get_operand(k), builder, true); add_copy(i, op, builder, true, copies);
}
// copy from shared operands
for(size_t k = 0; k < num_op; k++)
if(!dynamic_cast<ir::phi_node*>(i) &&
!is_shmem_op(i,k) &&
is_shmem_res(i->get_operand(k))){
add_copy(i, i->get_operand(k), builder, false);
}
}
} }
});
} }

View File

@@ -3,6 +3,7 @@
#include "triton/ir/basic_block.h" #include "triton/ir/basic_block.h"
#include "triton/ir/module.h" #include "triton/ir/module.h"
#include "triton/ir/utils.h" #include "triton/ir/utils.h"
#include <iostream>
namespace triton { namespace triton {
namespace codegen{ namespace codegen{
@@ -28,6 +29,8 @@ void dce::run(ir::module &mod) {
case ir::INST_ATOMIC_CAS: case ir::INST_ATOMIC_CAS:
case ir::INST_ATOMIC_RMW: case ir::INST_ATOMIC_RMW:
case ir::INST_ATOMIC_EXCH: case ir::INST_ATOMIC_EXCH:
case ir::INST_CALL:
case ir::INST_LAUNCH:
case ir::INST_BARRIER: { case ir::INST_BARRIER: {
work_list.push_back(i); work_list.push_back(i);
marked.insert(i); marked.insert(i);
@@ -65,6 +68,7 @@ void dce::run(ir::module &mod) {
} }
} }
// delete // delete
for(ir::instruction* i: to_delete) for(ir::instruction* i: to_delete)
i->erase_from_parent(); i->erase_from_parent();

View File

@@ -0,0 +1,147 @@
#include <iostream>
#include "triton/codegen/transform/inline.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/utils.h"
namespace triton{
namespace codegen{
namespace transform{
bool fncmp::operator()(ir::function* x, ir::function* y) const {
auto fn_list = x->get_parent()->get_function_list();
return std::find(fn_list.begin(), fn_list.end(), x) < std::find(fn_list.begin(), fn_list.end(), y);
};
void inliner::do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder,
std::list<ir::call_inst*>& callsites){
ir::basic_block* parent_block = callsite->get_parent();
ir::function* parent_fn = parent_block->get_parent();
// the parent block is split into block A and block B:
// - block A (`new_blocks[0]`) is the entry block of the inlined function
// - block B (`exit`) resumes execution of the parent function
ir::basic_block* entry = parent_block->split_before(callsite, fn->get_name());
ir::basic_block* exit = entry->get_successors()[0];
std::vector<ir::basic_block*> new_blocks = {entry};
for(size_t i = 1; i < fn->blocks().size(); i++){
ir::basic_block* block = fn->blocks()[i];
ir::context& ctx = block->get_context();
const std::string& name = block->get_parent()->get_name() + "_" + block->get_name();
new_blocks.push_back(ir::basic_block::create(ctx, name, parent_fn));
}
// a phi node holds the return values of the inlined function
if(exit->get_inst_list().empty())
builder.set_insert_point(exit);
else
builder.set_insert_point(exit->get_first_non_phi());
ir::phi_node* exit_val = builder.create_phi(fn->get_fn_type()->get_return_ty(), 0);
callsite->replace_all_uses_with(exit_val);
callsite->erase_from_parent();
// get arguments `fn` is called with
std::vector<ir::value*> tgt_args(callsite->op_begin(), callsite->op_end());
std::vector<ir::argument*> src_args(fn->args().begin(), fn->args().end());
// Actually generate the instructions:
// - Remove the branch created by basic_block::split_before
// - Clone all instructions
// - Replace `ret` with incoming nodes to `exit_val` and branches to `exit`
ir::instruction* terminator = new_blocks[0]->get_inst_list().back();
// new_blocks[0]->get_inst_list().back()->erase_from_parent();
terminator->erase_from_parent();
std::map<ir::instruction*, ir::instruction*> inst_map;
std::map<ir::argument*, ir::value*> arg_map;
for(size_t k = 0; k < fn->args().size(); k++)
arg_map[fn->args()[k]] = callsite->ops()[k];
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
// clone instructions
for(size_t i = 0; i < new_blocks.size(); i++){
ir::basic_block* old_block = fn->blocks()[i];
ir::basic_block* new_block = new_blocks[i];
builder.set_insert_point(new_block);
for(ir::instruction* old_inst: old_block->get_inst_list()){
ir::instruction* new_inst = old_inst->clone();
inst_map[old_inst] = new_inst;
builder.insert(new_inst);
}
}
// update basic blocks
for(size_t i = 0; i < new_blocks.size(); i++) {
for (ir::instruction* new_inst: new_blocks[i]->get_inst_list()) {
// replace basic use cases
for(size_t k = 0; k < new_blocks.size(); k++)
new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]);
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(new_inst)) {
// additionally replace basic blocks of phi-nodes since
// replace_uses_of_with() does not replace them.
for(unsigned in = 0; in < phi->get_num_incoming(); in++)
for(size_t k = 0; k < new_blocks.size(); k++)
if (phi->get_incoming_block(in) == fn->blocks()[k])
phi->set_incoming_block(in, new_blocks[k]);
}
}
}
// replace operands of instructions after constructing inst_map
for (auto& it: inst_map) {
ir::instruction* new_inst = it.second;
for(size_t k = 0; k < new_inst->get_num_operands(); k++) {
ir::value* op = new_inst->get_operand(k);
if(auto arg_op = dynamic_cast<ir::argument*>(op))
new_inst->set_operand(k, arg_map.at(arg_op));
if(auto inst_op = dynamic_cast<ir::instruction*>(op))
if(inst_map.find(inst_op) != inst_map.end())
new_inst->set_operand(k, inst_map.at(inst_op));
}
// handles a ret instruciton.
// instead of returning we need to branch to after the function call
if(ir::return_inst* ret = dynamic_cast<ir::return_inst*>(new_inst)) {
if(ir::value* ret_val = ret->get_return_value())
exit_val->add_incoming(ret_val, new_inst->get_parent());
// replace ret with branch
ir::instruction* new_br_inst = ir::branch_inst::create(exit);
builder.set_insert_point(new_inst->get_parent());
builder.insert(new_br_inst);
new_inst->erase_from_parent();
}
}
if(exit_val->get_num_incoming() == 1)
exit_val->replace_all_uses_with(exit_val->get_incoming_value(0));
// done -- make sure insert point is properly set to exit block
builder.set_insert_point(exit);
}
void inliner::run(ir::module &mod) {
// gather all call sites
while(true){
std::map<ir::function*, size_t> counts;
for(ir::function* fn: mod.get_function_list())
counts[fn] = 0;
std::list<ir::call_inst*> callsites;
for(ir::function* fn: mod.get_function_list()){
for(ir::basic_block* block: fn->blocks())
for(ir::instruction* instr: block->get_inst_list())
if(ir::call_inst* call = dynamic_cast<ir::call_inst*>(instr)){
callsites.push_back(call);
counts[call->get_fn()] += 1;
}
}
for(auto& count: counts){
if(!count.first->get_is_kernel() && count.second == 0)
count.first->get_parent()->remove_function(count.first);
}
if(callsites.empty())
break;
for(ir::call_inst* call: callsites)
do_inline(call->get_fn(), call, mod.get_builder(), callsites);
}
}
}
}
}

View File

@@ -36,6 +36,9 @@ int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
else{ else{
if(layouts_->has_tmp(v)) if(layouts_->has_tmp(v))
return async_write.size() - 1; return async_write.size() - 1;
// // Ignore copy_to_shared. It won't modify async behavior.
// if(dynamic_cast<ir::copy_to_shared_inst*>(v))
// return 0;
auto it = std::find(async_write.begin(), async_write.end(), v); auto it = std::find(async_write.begin(), async_write.end(), v);
return std::distance(async_write.begin(), it); return std::distance(async_write.begin(), it);
} }
@@ -60,15 +63,22 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
continue; continue;
analysis::shared_layout* a_layout = layouts_->get(a)->to_shared(); analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr; analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr;
analysis::shared_layout* a_tmp_index = layouts_->has_tmp_index(a) ? layouts_->get(layouts_->tmp_index(a))->to_shared() : nullptr;
for(ir::value* b: bs){ for(ir::value* b: bs){
if(!b->get_type()->is_block_ty()) if(!b->get_type()->is_block_ty())
continue; continue;
analysis::shared_layout* b_layout = layouts_->get(b)->to_shared(); analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr; analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr;
analysis::shared_layout* b_tmp_index = layouts_->has_tmp_index(b) ? layouts_->get(layouts_->tmp_index(b))->to_shared() : nullptr;
if(intersect_with(a_layout, b_layout) || if(intersect_with(a_layout, b_layout) ||
intersect_with(a_layout, b_tmp) || intersect_with(a_layout, b_tmp) ||
intersect_with(a_layout, b_tmp_index) ||
intersect_with(a_tmp, b_layout) || intersect_with(a_tmp, b_layout) ||
intersect_with(a_tmp, b_tmp)) intersect_with(a_tmp, b_tmp) ||
intersect_with(a_tmp, b_tmp_index) ||
intersect_with(a_tmp_index, b_layout) ||
intersect_with(a_tmp_index, b_tmp) ||
intersect_with(a_tmp_index, b_tmp_index))
ret.insert(b); ret.insert(b);
} }
} }

View File

@@ -61,7 +61,8 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
// dot(a, b, c) + d -> dot(a, b, c + d) // dot(a, b, c) + d -> dot(a, b, c + d)
// d + dot(a, b, c) -> dot(a, b, c + d) // d + dot(a, b, c) -> dot(a, b, c + d)
auto add = dynamic_cast<ir::binary_operator*>(value); auto add = dynamic_cast<ir::binary_operator*>(value);
if(add && add->get_op() == ir::binary_op_t::FAdd) { if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) {
bool is_int_dot = add->get_op() == ir::binary_op_t::Add;
ir::value *lhs = add->get_operand(0); ir::value *lhs = add->get_operand(0);
ir::value *rhs = add->get_operand(1); ir::value *rhs = add->get_operand(1);
ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs); ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs);
@@ -72,15 +73,21 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
ir::value *other = (dot == lhs) ? rhs : lhs; ir::value *other = (dot == lhs) ? rhs : lhs;
ir::value *acc = dot->get_operand(2); ir::value *acc = dot->get_operand(2);
ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc); ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc);
ir::constant_fp *_0 = nullptr; ir::constant *_0 = nullptr;
if(splat) if(splat)
_0 = dynamic_cast<ir::constant_fp*>(splat->get_operand(0)); _0 = dynamic_cast<ir::constant*>(splat->get_operand(0));
if(!(_0 && _0->get_value() == 0.0)) if(!_0)
return false;
if (auto *fp_0 = dynamic_cast<ir::constant_fp*>(_0))
if (fp_0->get_value() != 0.0)
return false;
if (auto *int_0 = dynamic_cast<ir::constant_int*>(_0))
if (int_0->get_value() != 0)
return false; return false;
ir::value *a = dot->get_operand(0); ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1); ir::value *b = dot->get_operand(1);
builder.set_insert_point(add); builder.set_insert_point(add);
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->get_name())); ir::value * new_dot = builder.insert(ir::dot_inst::create(a, b, other, dot->is_trans_a(), dot->is_trans_b(), dot->allow_tf32(), dot->get_name()));
add->replace_all_uses_with(new_dot); add->replace_all_uses_with(new_dot);
return true; return true;
} }
@@ -116,7 +123,7 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
int nts = layout->nts(layout->get_order()[0]); int nts = layout->nts(layout->get_order()[0]);
int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(nts*dtsize >= 4){ if(nts*dtsize >= 4){
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier()); ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier(), ld->get_eviction_policy());
copy_to_shared->replace_all_uses_with(new_load); copy_to_shared->replace_all_uses_with(new_load);
return true; return true;
} }
@@ -171,6 +178,27 @@ bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
return false; return false;
} }
bool peephole::rewrite_insert_extract(ir::instruction *value, ir::builder& builder){
auto extracted = dynamic_cast<ir::extract_value_inst*>(value);
if(!extracted)
return false;
size_t extract_idx = extracted->get_idx();
ir::value* agg = extracted->get_operand(0);
auto insert = dynamic_cast<ir::insert_value_inst*>(agg);
while(insert){
agg = insert->get_operand(0);
ir::value* inserted = insert->get_operand(1);
size_t insert_idx = insert->get_idx();
insert = dynamic_cast<ir::insert_value_inst*>(agg);
if(extract_idx == insert_idx){
extracted->replace_all_uses_with(inserted);
return true;
}
insert = dynamic_cast<ir::insert_value_inst*>(agg);
}
return false;
}
bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) { bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) {
auto x = dynamic_cast<ir::getelementptr_inst*>(value); auto x = dynamic_cast<ir::getelementptr_inst*>(value);
@@ -207,7 +235,9 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b
ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(), ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
if_value->get_mask_operand(), if_value->get_mask_operand(),
select->get_else_value_op(), select->get_else_value_op(),
if_value->get_cache_modifier()); if_value->get_cache_modifier(),
if_value->get_eviction_policy(),
if_value->get_is_volatile());
select->replace_all_uses_with(new_load); select->replace_all_uses_with(new_load);
return true; return true;
} }
@@ -219,22 +249,22 @@ bool peephole::rewrite_cvt_layout(ir::instruction *value, ir::builder& builder){
ir::instruction* op = dynamic_cast<ir::instruction*>(cvt->get_operand(0)); ir::instruction* op = dynamic_cast<ir::instruction*>(cvt->get_operand(0));
if(!op) if(!op)
return false; return false;
// convert(elementwise(x, y)) = elementwise(convert(x), convert(y)) // // convert(elementwise(x, y)) = elementwise(convert(x), convert(y))
if(op->get_id() == ir::INST_BINOP){ // if(op->get_id() == ir::INST_BINOP){
for(size_t i = 0; i < op->get_num_operands(); i++){ // for(size_t i = 0; i < op->get_num_operands(); i++){
ir::value* arg_i = op->get_operand(i); // ir::value* arg_i = op->get_operand(i);
builder.set_insert_point(op); // builder.set_insert_point(op);
// create new layout transform // // create new layout transform
ir::instruction* new_arg_i = cvt->clone(); // ir::instruction* new_arg_i = cvt->clone();
layouts_->copy(new_arg_i, op); // layouts_->copy(new_arg_i, op);
builder.insert(new_arg_i); // builder.insert(new_arg_i);
// set the right args // // set the right args
new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i); // new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
op->replace_uses_of_with(arg_i, new_arg_i); // op->replace_uses_of_with(arg_i, new_arg_i);
} // }
cvt->replace_all_uses_with(op); // cvt->replace_all_uses_with(op);
return true; // return true;
} // }
auto cvt_op = dynamic_cast<ir::cvt_layout_inst*>(op); auto cvt_op = dynamic_cast<ir::cvt_layout_inst*>(op);
if(!cvt_op) if(!cvt_op)
return false; return false;
@@ -282,9 +312,11 @@ void peephole::run(ir::module &mod) {
was_modified = was_modified || rewrite_mult(i, builder); was_modified = was_modified || rewrite_mult(i, builder);
// was_modified = was_modified || rewrite_cts_cfs(i, builder); // was_modified = was_modified || rewrite_cts_cfs(i, builder);
// was_modified = was_modified || rewrite_trans_phi(i, builder); // was_modified = was_modified || rewrite_trans_phi(i, builder);
was_modified = was_modified || rewrite_insert_extract(i, builder);
was_modified = was_modified || rewrite_unit_red(i, builder); was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder); was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
was_modified = was_modified || rewrite_select_masked_load(i, builder); // TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD
// was_modified = was_modified || rewrite_select_masked_load(i, builder);
was_modified = was_modified || rewrite_cvt_layout(i, builder); was_modified = was_modified || rewrite_cvt_layout(i, builder);
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
was_modified = was_modified || rewrite_load_to_shared(i, builder); was_modified = was_modified || rewrite_load_to_shared(i, builder);

View File

@@ -134,6 +134,7 @@ void pipeline::run(ir::module &mod) {
ir::builder &builder = mod.get_builder(); ir::builder &builder = mod.get_builder();
const int num_stages = num_stages_; const int num_stages = num_stages_;
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
for(auto info: to_pipeline){ for(auto info: to_pipeline){
ir::load_inst* load = info.load; ir::load_inst* load = info.load;
ir::phi_node* ptr = info.ptr; ir::phi_node* ptr = info.ptr;
@@ -178,7 +179,7 @@ void pipeline::run(ir::module &mod) {
false_value = remat_false_value; false_value = remat_false_value;
} else } else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier()); first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
for (int stage = 1; stage < num_stages-1; ++stage) { for (int stage = 1; stage < num_stages-1; ++stage) {
// mask is the loop condition of the previous iteration // mask is the loop condition of the previous iteration
@@ -193,7 +194,7 @@ void pipeline::run(ir::module &mod) {
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask); first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
false_value = remat_false_value; false_value = remat_false_value;
} }
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier()); first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
} }
// create new phis for induction variables // create new phis for induction variables
@@ -222,7 +223,7 @@ void pipeline::run(ir::module &mod) {
next_mask = builder.create_and(next_mask, remat_mask); next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value; false_value = remat_false_value;
} }
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier()); ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
// phi node // phi node
@@ -257,7 +258,7 @@ void pipeline::run(ir::module &mod) {
} }
else else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier()); ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
// pre-fetch next iteration // pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back()); builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block); ir::value* next_ptr = ptr->get_value_for_block(block);
@@ -268,7 +269,7 @@ void pipeline::run(ir::module &mod) {
next_mask = builder.create_and(next_mask, remat_mask); next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value; false_value = remat_false_value;
} }
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier()); ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
// phi node // phi node
builder.set_insert_point(block->get_first_non_phi()); builder.set_insert_point(block->get_first_non_phi());
ir::phi_node* new_load = builder.create_phi(ty, 2); ir::phi_node* new_load = builder.create_phi(ty, 2);

View File

@@ -29,8 +29,16 @@ void prefetch::run(ir::module &mod) {
std::vector<ir::dot_inst*> to_prefetch; std::vector<ir::dot_inst*> to_prefetch;
ir::for_each_instruction(mod, [&](ir::instruction *i) { ir::for_each_instruction(mod, [&](ir::instruction *i) {
if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) { if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
// Now only do prefetching when dot is fp16 // Now only do prefetching when dot is using tensor cores
if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::FP16TyID) if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) ||
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_integer_ty(8)
&& dot->get_operand(1)->get_type()->get_scalar_ty()->is_integer_ty(8)
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
)
)
return; return;
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0)); auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1)); auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));

View File

@@ -91,9 +91,13 @@ void* dispatch::fname ## _;
bool dispatch::cuinit(){ bool dispatch::cuinit(){
if(cuda_==nullptr){ if(cuda_==nullptr){
#ifdef _WIN32
cuda_ = dlopen("cudart64_110.dll", RTLD_LAZY);
#else
cuda_ = dlopen("libcuda.so", RTLD_LAZY); cuda_ = dlopen("libcuda.so", RTLD_LAZY);
if(!cuda_) if(!cuda_)
cuda_ = dlopen("libcuda.so.1", RTLD_LAZY); cuda_ = dlopen("libcuda.so.1", RTLD_LAZY);
#endif
if(!cuda_) if(!cuda_)
throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH."); throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH.");
} }
@@ -134,6 +138,7 @@ CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute, CUdevice
CUDA_DEFINE1(CUresult, cuDeviceGetCount, int*) CUDA_DEFINE1(CUresult, cuDeviceGetCount, int*)
// link management // link management
CUDA_DEFINE6(CUresult, cuLinkAddFile_v2, CUlinkState, CUjitInputType, const char *, unsigned int , CUjit_option *, void **);
CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**); CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**);
CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option*, void**, CUlinkState*); CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option*, void**, CUlinkState*);
CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState); CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState);
@@ -176,8 +181,13 @@ CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent)
* NVML * NVML
* ------------------- */ * ------------------- */
bool dispatch::nvmlinit(){ bool dispatch::nvmlinit(){
#ifdef _WIN32
if(nvml_==nullptr)
nvml_ = dlopen("nvml.dll", RTLD_LAZY);
#else
if(nvml_==nullptr) if(nvml_==nullptr)
nvml_ = dlopen("libnvidia-ml.so", RTLD_LAZY); nvml_ = dlopen("libnvidia-ml.so", RTLD_LAZY);
#endif
nvmlReturn_t (*fptr)(); nvmlReturn_t (*fptr)();
nvmlInit_v2_ = dlsym(nvml_, "nvmlInit_v2"); nvmlInit_v2_ = dlsym(nvml_, "nvmlInit_v2");
*reinterpret_cast<void **>(&fptr) = nvmlInit_v2_; *reinterpret_cast<void **>(&fptr) = nvmlInit_v2_;

View File

@@ -90,7 +90,7 @@ void check(CUresult err)
case CUDA_ERROR_NOT_PERMITTED : throw not_permitted(); case CUDA_ERROR_NOT_PERMITTED : throw not_permitted();
case CUDA_ERROR_NOT_SUPPORTED : throw not_supported(); case CUDA_ERROR_NOT_SUPPORTED : throw not_supported();
case CUDA_ERROR_UNKNOWN : throw unknown(); case CUDA_ERROR_UNKNOWN : throw unknown();
default : throw unknown(); default : throw std::runtime_error("unimplemented code: " + std::to_string(err));
} }
} }

View File

@@ -1,27 +1,27 @@
/* Copyright 2015-2017 Philippe Tillet /* Copyright 2015-2017 Philippe Tillet
* *
* Permission is hereby granted, free of charge, to any person obtaining * Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files * a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction, * (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge, * including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software, * publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so, * and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions: * subject to the following conditions:
* *
* The above copyright notice and this permission notice shall be * The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software. * included in all copies or substantial portions of the Software.
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/ */
#include <fstream> #include <fstream>
#if __has_include(<unistd.h>) #if __has_include(<unistd.h>)
#include <unistd.h> #include <unistd.h>
#endif #endif
#include <memory> #include <memory>
#include <regex> #include <regex>
@@ -49,6 +49,7 @@
#include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Scalar.h"
// begin AMD stuff // begin AMD stuff
#include "llvm/Support/FileSystem.h" #include "llvm/Support/FileSystem.h"
@@ -61,12 +62,21 @@
#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Intrinsics.h"
// end AMD stuff // end AMD stuff
namespace triton{ extern "C"
namespace driver{ {
int set_curterm(char *nterm) { return 0; }
int del_curterm(char *nterm) { return 0; }
int tigetnum(char *capname) { return 0; }
int setupterm(char *term, int fildes, int *errret) { return 0; }
}
void init_llvm() { namespace triton
static bool init = false; {
if(!init){ namespace driver
{
void init_llvm()
{
LLVMInitializeNVPTXTargetInfo(); LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTarget(); LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetMC(); LLVMInitializeNVPTXTargetMC();
@@ -75,40 +85,92 @@ void init_llvm() {
LLVMInitializeAMDGPUTarget(); LLVMInitializeAMDGPUTarget();
LLVMInitializeAMDGPUTargetMC(); LLVMInitializeAMDGPUTargetMC();
LLVMInitializeAMDGPUAsmPrinter(); LLVMInitializeAMDGPUAsmPrinter();
init = true;
} }
}
/* ------------------------ */ /* ------------------------ */
// CUDA // // CUDA //
/* ------------------------ */ /* ------------------------ */
static bool find_and_replace(std::string& str, const std::string& begin, const std::string& end, const std::string& target){ static bool find_and_replace(std::string &str, const std::string &begin, const std::string &end, const std::string &target)
{
size_t start_replace = str.find(begin); size_t start_replace = str.find(begin);
size_t end_replace = str.find(end, start_replace); size_t end_replace = str.find(end, start_replace);
if(start_replace == std::string::npos) if (start_replace == std::string::npos)
return false; return false;
str.replace(start_replace, end_replace + 1 - start_replace, target); str.replace(start_replace, end_replace + 1 - start_replace, target);
return true; return true;
} }
int vptx(int version){ std::string path_to_ptxas(int &version)
if(version >= 11030) return 73; {
if(version >= 11020) return 72; std::vector<std::string> rets;
if(version >= 11010) return 71; std::string ret;
if(version >= 11000) return 70; // search paths for ptxas
if(version >= 10020) return 65; std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
if(version >= 10010) return 64; std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
if(version >= 10000) return 63; if (!triton_ptxas.empty())
throw std::runtime_error("Triton requires CUDA 10+"); ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
} // see what path for ptxas are valid
std::vector<std::string> working_ptxas;
for (std::string prefix : ptxas_prefixes)
{
std::string ptxas = prefix + "ptxas";
bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0;
if (works)
{
working_ptxas.push_back(ptxas);
rets.push_back(ret);
}
}
// error if no working ptxas was found
if (working_ptxas.empty())
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
" but a working version could not be found.");
std::string ptxas = working_ptxas.front();
// parse version
std::regex version_regex("release (\\d+)\\.(\\d+)");
std::smatch match;
bool found = false;
// currently choosing the first ptxas. Other logics can be implemented in future
for (std::string ret : rets)
{
if (std::regex_search(ret, match, version_regex))
{
int major = std::stoi(match[1]);
int minor = std::stoi(match[2]);
version = major * 1000 + minor * 10;
found = true;
break;
}
}
if (not found)
{
throw std::runtime_error("Error in parsing version");
}
return ptxas;
}
std::string llir_to_ptx(llvm::Module* module, int cc, int version){ int vptx(int version)
{
if (version >= 11040)
return 74;
// if(version >= 11030) return 73;
// if(version >= 11020) return 72;
// if(version >= 11010) return 71;
// if(version >= 11000) return 70;
// if(version >= 10020) return 65;
// if(version >= 10010) return 64;
// if(version >= 10000) return 63;
throw std::runtime_error("Triton requires CUDA 11.4+");
}
std::string llir_to_ptx(llvm::Module *module, int cc, int version)
{
// LLVM version in use may not officially support target hardware // LLVM version in use may not officially support target hardware
int max_nvvm_cc = 75; int max_nvvm_cc = 75;
int max_nvvm_ptx = 64; int max_nvvm_ptx = 74;
// options // options
auto options = llvm::cl::getRegisteredOptions(); auto options = llvm::cl::getRegisteredOptions();
auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]); auto *short_ptr = static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(short_ptr); assert(short_ptr);
short_ptr->setValue(true); short_ptr->setValue(true);
// compute capability // compute capability
@@ -122,25 +184,30 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
std::string triple = "nvptx64-nvidia-cuda"; std::string triple = "nvptx64-nvidia-cuda";
std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc)); std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
std::string layout = ""; std::string layout = "";
std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx)); std::string features = "";
// std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
init_llvm(); init_llvm();
// verify and store llvm // verify and store llvm
llvm::legacy::PassManager pm; llvm::legacy::PassManager pm;
// pm.add(llvm::createPrintModulePass(llvm::outs()));
pm.add(llvm::createVerifierPass()); pm.add(llvm::createVerifierPass());
pm.run(*module); pm.run(*module);
// module->print(llvm::outs(), nullptr);
// create machine // create machine
module->setTargetTriple(triple); module->setTargetTriple(triple);
std::string error; std::string error;
llvm::TargetMachine *machine;
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt; llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false; opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false; opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true; opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt, machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive); llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout // set data layout
if(layout.empty()) if (layout.empty())
module->setDataLayout(machine->createDataLayout()); module->setDataLayout(machine->createDataLayout());
else else
module->setDataLayout(layout); module->setDataLayout(layout);
@@ -157,18 +224,15 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
std::string result(buffer.begin(), buffer.end()); std::string result(buffer.begin(), buffer.end());
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n"); find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
find_and_replace(result, ".target", "\n", ".target " + sm + "\n"); find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
while(find_and_replace(result, "\t// begin inline asm", "\n", "")); while (find_and_replace(result, "\t// begin inline asm", "\n", ""))
while(find_and_replace(result, "\t// end inline asm", "\n", "")); ;
while (find_and_replace(result, "\t// end inline asm", "\n", ""))
;
return result; return result;
} }
std::string ptx_to_cubin(const std::string& ptx, int cc) {
std::string ptxas = "ptxas";
std::string version;
int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0;
if(!use_system_ptxas)
return "";
std::string ptx_to_cubin(const std::string &ptx, const std::string &ptxas, int cc)
{
// compile ptx with ptxas // compile ptx with ptxas
char _fsrc[L_tmpnam]; char _fsrc[L_tmpnam];
char _flog[L_tmpnam]; char _flog[L_tmpnam];
@@ -177,91 +241,41 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) {
std::string fsrc = _fsrc; std::string fsrc = _fsrc;
std::string flog = _flog; std::string flog = _flog;
std::string fbin = fsrc + ".o"; std::string fbin = fsrc + ".o";
const char* _fbin = fbin.c_str(); const char *_fbin = fbin.c_str();
std::ofstream ofs(fsrc); std::ofstream ofs(fsrc);
ofs << ptx; ofs << ptx << std::endl;
ofs.close(); ofs.close();
std::string cmd; std::string cmd;
int err; int err;
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
err = system(cmd.c_str()); err = system(cmd.c_str());
CUmodule ret; if (err != 0)
std::ifstream _cubin(_fbin, std::ios::binary ); {
std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {});
unlink(_fsrc);
unlink(_flog);
throw std::runtime_error("Internal Triton PTX codegen error: \n" + log);
}
std::ifstream _cubin(_fbin, std::ios::binary);
std::string cubin(std::istreambuf_iterator<char>(_cubin), {}); std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close(); _cubin.close();
dispatch::cuModuleLoadData(&ret, cubin.c_str());
unlink(_fsrc); unlink(_fsrc);
unlink(_flog); unlink(_flog);
unlink(_fbin); unlink(_fbin);
return cubin; return cubin;
}
CUmodule ptx_to_cumodule(const std::string& ptx, int cc) {
// JIT compile source-code
try{
// use ptxas if present in PATH. Otherwise, use JIT from the driver
std::string ptxas = "ptxas";
std::string version;
int use_system_ptxas = tools::exec(ptxas + " --version 2>&1", version) == 0;
// Use PTXAS via system call
if(use_system_ptxas){
// compile ptx with ptxas
char _fsrc[L_tmpnam];
char _flog[L_tmpnam];
std::tmpnam(_fsrc);
std::tmpnam(_flog);
std::string fsrc = _fsrc;
std::string flog = _flog;
std::string fbin = fsrc + ".o";
const char* _fbin = fbin.c_str();
std::ofstream ofs(fsrc);
ofs << ptx;
ofs.close();
std::string cmd;
int err;
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
err = system(cmd.c_str());
CUmodule ret;
std::ifstream _cubin(_fbin, std::ios::binary );
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
dispatch::cuModuleLoadData(&ret, cubin.c_str());
unlink(_fsrc);
unlink(_flog);
unlink(_fbin);
return ret;
} }
// Use PTXAS included in driver /* ------------------------ */
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER, // HIP //
CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER, /* ------------------------ */
CU_JIT_LOG_VERBOSE};
const unsigned int errbufsize = 8192;
const unsigned int logbufsize = 8192;
char _err[errbufsize];
char _log[logbufsize];
void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1};
CUmodule ret;
dispatch::cuModuleLoadDataEx(&ret, ptx.data(), 5, opt, optval);
return ret;
}
catch(exception::cuda::invalid_ptx const &){
std::cout << ptx << std::endl;
std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
throw;
}
}
/* ------------------------ */ std::string llir_to_amdgpu(llvm::Module *module, const std::string &_proc)
// HIP // {
/* ------------------------ */
std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
init_llvm(); init_llvm();
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo)); // proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo)); // features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
// create // create
llvm::SmallVector<char, 0> buffer; llvm::SmallVector<char, 0> buffer;
@@ -350,15 +364,18 @@ std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
return hsaco_path; return hsaco_path;
} }
return hsaco_path;
}
hipModule_t amdgpu_to_hipmodule(const std::string& path) { hipModule_t amdgpu_to_hipmodule(const std::string &path)
{
// Read HSACO. // Read HSACO.
std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate); std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg(); std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
std::vector<unsigned char> hsaco(hsaco_file_size); std::vector<unsigned char> hsaco(hsaco_file_size);
hsaco_file.seekg(0, std::ios::beg); hsaco_file.seekg(0, std::ios::beg);
hsaco_file.read(reinterpret_cast<char*>(&hsaco[0]), hsaco_file_size); hsaco_file.read(reinterpret_cast<char *>(&hsaco[0]), hsaco_file_size);
hsaco_file.close(); hsaco_file.close();
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer, hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer,
hipJitOptionInfoLogBufferSizeBytes, hipJitOptionInfoLogBuffer, hipJitOptionInfoLogBufferSizeBytes, hipJitOptionInfoLogBuffer,
@@ -367,16 +384,13 @@ hipModule_t amdgpu_to_hipmodule(const std::string& path) {
const unsigned int logbufsize = 8192; const unsigned int logbufsize = 8192;
char _err[errbufsize]; char _err[errbufsize];
char _log[logbufsize]; char _log[logbufsize];
void* optval[] = {(void*)(uintptr_t)errbufsize, void *optval[] = {(void *)(uintptr_t)errbufsize,
(void*)_err, (void*)(uintptr_t)logbufsize, (void *)_err, (void *)(uintptr_t)logbufsize,
(void*)_log, (void*)1}; (void *)_log, (void *)1};
hipModule_t ret; hipModule_t ret;
dispatch::hipModuleLoadDataEx(&ret, hsaco.data(), 5, opt, optval); dispatch::hipModuleLoadDataEx(&ret, hsaco.data(), 5, opt, optval);
return ret; return ret;
} }
}
}
} // namespace driver
} // namespace triton

View File

@@ -1,3 +1,5 @@
#include <iostream>
#include <algorithm>
#include "triton/ir/basic_block.h" #include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h" #include "triton/ir/instructions.h"
#include "triton/ir/type.h" #include "triton/ir/type.h"
@@ -9,23 +11,71 @@ namespace ir {
class phi_node; class phi_node;
basic_block::basic_block(context &ctx, const std::string &name, function *parent): basic_block::basic_block(context &ctx, const std::string &name, function *parent, basic_block* next):
value(type::get_label_ty(ctx), name), ctx_(ctx), parent_(parent) { value(type::get_label_ty(ctx), name), ctx_(ctx), parent_(parent) {
if(parent_) if(parent_)
parent_->insert_block(this); parent_->insert_block(this, next);
} }
basic_block* basic_block::create(context &ctx, const std::string &name, function *parent){ basic_block* basic_block::create(context &ctx, const std::string &name, function *parent, basic_block* next){
return new basic_block(ctx, name, parent); return new basic_block(ctx, name, parent, next);
} }
void basic_block::add_predecessor(basic_block *pred) { void basic_block::replace_phi_uses_with(basic_block* before, basic_block* after) {
preds_.push_back(pred); for(ir::instruction* i: inst_list_){
if(pred) auto* curr_phi = dynamic_cast<ir::phi_node*>(i);
pred->succs_.push_back(this); if(!curr_phi)
break;
// curr_phi->replace_uses_of_with(before, after);
for (size_t idx = 0; idx < curr_phi->get_num_incoming(); ++idx)
if (curr_phi->get_incoming_block(idx) == before)
curr_phi->set_incoming_block(idx, after);
}
} }
void basic_block::append_instruction(ir::instruction* i){
i->set_parent(this);
inst_list_.push_back(i);
}
basic_block* basic_block::split_before(ir::instruction* loc, const std::string& name) {
basic_block* ret = basic_block::create(ctx_, name, parent_, this);
ret->set_name(get_name());
set_name("after_" + name);
// splice instruction list
auto loc_it = std::find(inst_list_.begin(), inst_list_.end(), loc);
ret->get_inst_list().splice(ret->get_inst_list().begin(), inst_list_, inst_list_.begin(), loc_it);
for(ir::instruction* i: ret->get_inst_list())
i->set_parent(ret);
// the predecessors of `this` becomes the predecessors of `ret`
for(ir::basic_block* pred: get_predecessors()){
auto* term = dynamic_cast<ir::terminator_inst*>(pred->get_inst_list().back());
assert(term);
term->replace_uses_of_with(this, ret);
replace_phi_uses_with(pred, ret);
}
ir::branch_inst* br = branch_inst::create(this);
ret->append_instruction(br);
return ret;
}
std::vector<basic_block*> basic_block::get_predecessors() const {
std::vector<basic_block*> ret;
for(ir::user* u: users_)
if(auto term = dynamic_cast<ir::terminator_inst*>(u))
ret.push_back(term->get_parent());
return ret;
}
std::vector<basic_block*> basic_block::get_successors() const {
std::vector<basic_block*> ret;
for(ir::instruction* i: inst_list_)
for(ir::value* v: i->ops())
if(auto block = dynamic_cast<ir::basic_block*>(v))
ret.push_back(block);
return ret;
}
basic_block::iterator basic_block::get_first_non_phi(){ basic_block::iterator basic_block::get_first_non_phi(){
auto it = begin(); auto it = begin();

View File

@@ -48,10 +48,10 @@ void builder::set_insert_point(basic_block *block){
value *builder::get_int1(bool val) value *builder::get_int1(bool val)
{ return constant_int::get(type::get_int1_ty(ctx_), val); } { return constant_int::get(type::get_int1_ty(ctx_), val); }
value *builder::get_int32(int32_t val) value *builder::get_int32(uint32_t val)
{ return constant_int::get(type::get_int32_ty(ctx_), val);} { return constant_int::get(type::get_int32_ty(ctx_), val);}
value *builder::get_int64(int64_t val) value *builder::get_int64(uint64_t val)
{ return constant_int::get(type::get_int64_ty(ctx_), val);} { return constant_int::get(type::get_int64_ty(ctx_), val);}
value *builder::get_float16(float val) value *builder::get_float16(float val)
@@ -87,9 +87,15 @@ type *builder::get_int32_ty()
type *builder::get_int64_ty() type *builder::get_int64_ty()
{ return type::get_int64_ty(ctx_); } { return type::get_int64_ty(ctx_); }
type *builder::get_fp8_ty()
{ return type::get_fp8_ty(ctx_); }
type *builder::get_half_ty() type *builder::get_half_ty()
{ return type::get_fp16_ty(ctx_); } { return type::get_fp16_ty(ctx_); }
type *builder::get_bf16_ty()
{ return type::get_bf16_ty(ctx_); }
type *builder::get_float_ty() type *builder::get_float_ty()
{ return type::get_fp32_ty(ctx_); } { return type::get_fp32_ty(ctx_); }
@@ -102,13 +108,10 @@ type *builder::get_double_ty()
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
value* builder::create_br(basic_block *dest){ value* builder::create_br(basic_block *dest){
dest->add_predecessor(block_);
return insert(branch_inst::create(dest)); return insert(branch_inst::create(dest));
} }
value* builder::create_cond_br(value *cond, basic_block *if_dest, basic_block *else_dest){ value* builder::create_cond_br(value *cond, basic_block *if_dest, basic_block *else_dest){
if_dest->add_predecessor(block_);
else_dest->add_predecessor(block_);
return insert(branch_inst::create(cond, if_dest, else_dest)); return insert(branch_inst::create(cond, if_dest, else_dest));
} }
@@ -116,6 +119,18 @@ value *builder::create_ret_void() {
return insert(return_inst::create(ctx_)); return insert(return_inst::create(ctx_));
} }
value *builder::create_ret(value* val) {
return insert(return_inst::create(ctx_, val));
}
//===----------------------------------------------------------------------===//
// dequantize instructions
//===----------------------------------------------------------------------===//
value* builder::create_dequantize(value *src, value *scale, value *shift, type *dst_ty){
return insert(dequantize_inst::create(src, scale, shift, dst_ty));
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// cast instructions // cast instructions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -124,6 +139,8 @@ value *builder::create_ret_void() {
return create_cast(OPCODE, src, dst_ty);\ return create_cast(OPCODE, src, dst_ty);\
} }
DEFINE_CAST_INSTR(bitcast, cast_op_t::BitCast)
DEFINE_CAST_INSTR(int_to_ptr, cast_op_t::IntToPtr)
DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt) DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt)
DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP) DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP)
DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP) DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP)
@@ -148,6 +165,19 @@ phi_node* builder::create_phi(type *ty, unsigned num_reserved){
return insert(phi_node::create(ty, num_reserved)); return insert(phi_node::create(ty, num_reserved));
} }
//===----------------------------------------------------------------------===//
// call instructions
//===----------------------------------------------------------------------===//
value *builder::create_call(function* fn, const std::vector<value*>& args){
return insert(call_inst::create(fn, args));
}
value* builder::create_launch(function* fn, const std::vector<value*>& args, const std::vector<value*>& grid, value* num_warps){
return insert(launch_inst::create(fn, args, grid, num_warps));
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// binary float instructions // binary float instructions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -276,22 +306,35 @@ DEFINE_FCMP_INSTR(UNE, cmp_pred_t::FCMP_UNE)
// load/store instructions // load/store instructions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache){ value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
return insert(unmasked_load_inst::create(ptr, cache)); return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile));
} }
value *builder::create_store(value *ptr, value *val){ value *builder::create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction){
return insert(unmasked_store_inst::create(ptr, val)); return insert(unmasked_store_inst::create(ptr, val, eviction));
} }
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache){ value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
return insert(masked_load_inst::create(ptr, mask, false_value, cache)); return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile));
} }
value *builder::create_masked_store(value *ptr, value *val, value *mask){ value *builder::create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction){
return insert(masked_store_inst::create(ptr, val, mask)); return insert(masked_store_inst::create(ptr, val, mask, eviction));
} }
//===----------------------------------------------------------------------===//
// struct instructions
//===----------------------------------------------------------------------===//
// Struct instructions
value *builder::create_insert_value(value* val, value *elt, size_t idx){
return insert(insert_value_inst::create(val, elt, idx));
}
value *builder::create_extract_value(value* val, size_t idx) {
return insert(extract_value_inst::create(val, idx));
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// block instructions // block instructions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -316,6 +359,50 @@ value *builder::create_downcast(value *arg) {
return insert(downcast_inst::create(arg)); return insert(downcast_inst::create(arg));
} }
//
value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
return insert(atomic_rmw_inst::create(op, ptr, val, msk));
}
#define DEFINE_ATOMIC_RMW_INSTR(SUFFIX, OPCODE)\
value *builder::create_ ## SUFFIX(value *ptr, value *val, value *mask){\
return create_atomic_rmw(OPCODE, ptr, val, mask);\
}
DEFINE_ATOMIC_RMW_INSTR(atomic_max, ir::atomic_rmw_op_t::Max)
DEFINE_ATOMIC_RMW_INSTR(atomic_umax, ir::atomic_rmw_op_t::UMax)
DEFINE_ATOMIC_RMW_INSTR(atomic_min, ir::atomic_rmw_op_t::Min)
DEFINE_ATOMIC_RMW_INSTR(atomic_umin, ir::atomic_rmw_op_t::UMin)
DEFINE_ATOMIC_RMW_INSTR(atomic_fadd, ir::atomic_rmw_op_t::FAdd)
DEFINE_ATOMIC_RMW_INSTR(atomic_add, ir::atomic_rmw_op_t::Add)
DEFINE_ATOMIC_RMW_INSTR(atomic_and, ir::atomic_rmw_op_t::And)
DEFINE_ATOMIC_RMW_INSTR(atomic_or, ir::atomic_rmw_op_t::Or)
DEFINE_ATOMIC_RMW_INSTR(atomic_xor, ir::atomic_rmw_op_t::Xor)
DEFINE_ATOMIC_RMW_INSTR(atomic_xchg, ir::atomic_rmw_op_t::Xchg)
// Utilities
value *builder::create_clock() {
return insert(clock_inst::create(ctx_));
}
value *builder::create_globaltimer() {
return insert(globaltimer_inst::create(ctx_));
}
//===----------------------------------------------------------------------===//
// externs
//===----------------------------------------------------------------------===//
value *builder::create_extern_elementwise(const std::string &lib_name,
const std::string &lib_path,
const std::string &symbol_name,
const std::vector<value *> &args,
type *ret_ty) {
return insert(extern_elementwise_inst::create(ctx_, args, ret_ty, lib_name,
lib_path, symbol_name));
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// built-in instructions // built-in instructions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -332,9 +419,6 @@ value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){
return insert(atomic_cas_inst::create(ptr, cmp, val)); return insert(atomic_cas_inst::create(ptr, cmp, val));
} }
value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
return insert(atomic_rmw_inst::create(op, ptr, val, msk));
}
value *builder::create_exp(value *arg){ value *builder::create_exp(value *arg){
return insert(exp_inst::create(arg)); return insert(exp_inst::create(arg));
@@ -352,8 +436,8 @@ value *builder::create_log(value *arg){
return insert(log_inst::create(arg)); return insert(log_inst::create(arg));
} }
value *builder::create_dot(value *A, value *B, value *C) { value *builder::create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32) {
return insert(dot_inst::create_nn(A, B, C)); return insert(dot_inst::create(A, B, C, trans_a, trans_b, allow_tf32));
} }
value *builder::create_trans(value *A, const std::vector<int>& perm) { value *builder::create_trans(value *A, const std::vector<int>& perm) {
@@ -389,8 +473,8 @@ value *builder::create_copy_from_shared(value *arg) {
return insert(copy_from_shared_inst::create(arg)); return insert(copy_from_shared_inst::create(arg));
} }
value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache) { value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction) {
return insert(masked_load_async_inst::create(ptr, mask, false_value, cache)); return insert(masked_load_async_inst::create(ptr, mask, false_value, cache, eviction));
} }
value *builder::create_barrier(const std::string &name) { value *builder::create_barrier(const std::string &name) {

View File

@@ -18,6 +18,8 @@ constant *constant::get_null_value(type *ty) {
return constant_int::get(ty, 0); return constant_int::get(ty, 0);
case type::FP16TyID: case type::FP16TyID:
return constant_fp::get(type::get_fp16_ty(ctx), 0); return constant_fp::get(type::get_fp16_ty(ctx), 0);
case type::BF16TyID:
return constant_fp::get(type::get_bf16_ty(ctx), 0);
case type::FP32TyID: case type::FP32TyID:
return constant_fp::get(type::get_fp32_ty(ctx), 0); return constant_fp::get(type::get_fp32_ty(ctx), 0);
case type::FP64TyID: case type::FP64TyID:
@@ -47,10 +49,10 @@ constant_int *constant_int::get(type *ty, uint64_t value) {
if (!ty->is_integer_ty()) if (!ty->is_integer_ty())
throw std::runtime_error("Cannot create constant_int with non integer ty"); throw std::runtime_error("Cannot create constant_int with non integer ty");
context_impl *impl = ty->get_context().p_impl.get(); context_impl *impl = ty->get_context().p_impl.get();
constant_int *& cst = impl->int_constants_[std::make_pair(ty, value)]; std::unique_ptr<constant_int> &cst = impl->int_constants_[std::make_pair(ty, value)];
if(cst == nullptr) if(!cst)
cst = new constant_int(ty, value); cst.reset(new constant_int(ty, value));
return cst; return cst.get();
} }
@@ -73,10 +75,10 @@ constant *constant_fp::get_zero_value_for_negation(type *ty) {
constant *constant_fp::get(type *ty, double v){ constant *constant_fp::get(type *ty, double v){
context_impl *impl = ty->get_context().p_impl.get(); context_impl *impl = ty->get_context().p_impl.get();
constant_fp *&result = impl->fp_constants_[std::make_pair(ty, v)]; std::unique_ptr<constant_fp> &result = impl->fp_constants_[std::make_pair(ty, v)];
if(!result) if(!result)
result = new constant_fp(ty, v); result.reset(new constant_fp(ty, v));
return result; return result.get();
} }
@@ -86,10 +88,10 @@ undef_value::undef_value(type *ty)
undef_value *undef_value::get(type *ty) { undef_value *undef_value::get(type *ty) {
context_impl *impl = ty->get_context().p_impl.get(); context_impl *impl = ty->get_context().p_impl.get();
undef_value *&result = impl->uv_constants_[ty]; std::unique_ptr<undef_value> &result = impl->uv_constants_[ty];
if(!result) if(!result)
result = new undef_value(ty); result.reset(new undef_value(ty));
return result; return result.get();
} }
/* global value */ /* global value */

View File

@@ -24,9 +24,7 @@ context_impl::context_impl(context &ctx)
int16_ty(ctx, 16), int16_ty(ctx, 16),
int32_ty(ctx, 32), int32_ty(ctx, 32),
int64_ty(ctx, 64), int64_ty(ctx, 64),
int128_ty(ctx, 128){ int128_ty(ctx, 128) {}
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// context // context

View File

@@ -1,796 +0,0 @@
#include "triton/ir/dispatch.h"
#include <iostream>
namespace triton{
namespace ir{
ir::value* throw_unreachable(std::string key) {
throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. "
"This is likely a bug on our side.");
return 0;
}
//===----------------------------------------------------------------------===//
// Programming Model
//===----------------------------------------------------------------------===//
ir::value *dispatch::program_id(int axis, ir::builder *builder) {
return builder->create_get_program_id(axis);
}
ir::value *dispatch::num_programs(int axis, ir::builder *builder) {
return builder->create_get_num_programs(axis);
}
//===----------------------------------------------------------------------===//
// Implicit Casting Utilities
//===----------------------------------------------------------------------===//
ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){
int a_rank = a_ty->get_integer_bitwidth();
int b_rank = b_ty->get_integer_bitwidth();
return a_rank > b_rank ? a_ty : b_ty;
}
ir::type *computation_type(ir::type* a_ty, ir::type* b_ty){
context &ctx = a_ty->get_context();
// 1) if one operand is double, the other is implicitly
// converted to double
if(a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
return type::get_fp64_ty(ctx);
// 2) if one operand is float, the other is implicitly
// converted to float
if(a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
return type::get_fp32_ty(ctx);
// 3 ) if one operand is half, the other is implicitly
// converted to half
if(a_ty->is_fp16_ty() || b_ty->is_fp16_ty())
return type::get_fp16_ty(ctx);
if(a_ty->is_bf16_ty() || b_ty->is_bf16_ty())
return type::get_bf16_ty(ctx);
if(!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
throw_unreachable("augment_types");
// 4 ) both operands are integer and undergo
// integer promotion
return integer_promote(a_ty, b_ty);
}
//===----------------------------------------------------------------------===//
// Binary Operators
//===----------------------------------------------------------------------===//
void throw_incompatible_types(ir::type* type_a, ir::type* type_b) {
throw semantic_error("invalid operands of type " + type_a->repr() + " and " + type_b->repr());
}
void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){
if(type_a->is_pointer_ty()){
if(!allow_ptr_a)
throw_incompatible_types(type_a, type_b);
// T* + U* with T != U
if(type_b->is_pointer_ty() && (type_a != type_b))
throw_incompatible_types(type_a, type_b);
// T* + float
if(type_b->is_floating_point_ty())
throw_incompatible_types(type_a, type_b);
}
}
void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder,
bool allow_lhs_ptr = false, bool allow_rhs_ptr = false,
bool arithmetic_check = true){
// implicit broadcasting
std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder);
// implicit typecasting
ir::type *lhs_sca_ty = lhs->get_type()->get_scalar_ty();
ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty();
check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr);
check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr);
if(arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()){
ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty);
lhs = dispatch::cast(lhs, ret_sca_ty, builder);
rhs = dispatch::cast(rhs, ret_sca_ty, builder);
}
}
ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, true, true);
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
// offset + ptr
// ptr + offset
if(other_scalar_ty->is_pointer_ty() && !input_scalar_ty->is_pointer_ty())
std::swap(input, other);
if (input_scalar_ty->is_pointer_ty())
return builder->create_gep(input, {other});
// float + float
else if (input_scalar_ty->is_floating_point_ty())
return builder->create_fadd(input, other);
// int + int
else if (input_scalar_ty->is_integer_ty())
return builder->create_add(input, other);
return throw_unreachable("add");
}
ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, true, false);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// ptr - offset
if (scalar_ty->is_pointer_ty())
return builder->create_gep(input, {dispatch::minus(other, builder)});
// float + float
if (scalar_ty->is_floating_point_ty())
return builder->create_fsub(input, other);
// int + int
else if (scalar_ty->is_integer_ty())
return builder->create_sub(input, other);
return throw_unreachable("sub");
}
ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float * float
if (scalar_ty->is_floating_point_ty())
return builder->create_fmul(input, other);
// int * int
else if (scalar_ty->is_integer_ty())
return builder->create_mul(input, other);
return throw_unreachable("mul");
}
ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, false, false, false);
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
// float / int
if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_integer_ty())
other = cast(other, input_scalar_ty, builder);
// int / float
else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_floating_point_ty())
input = cast(input, other_scalar_ty, builder);
// int / int (cast to float32)
else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
input = cast(input, builder->get_float_ty(), builder);
other = cast(other, builder->get_float_ty(), builder);
}
// float / float (cast to highest exponent type)
else if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_floating_point_ty()){
if(input_scalar_ty->get_fp_mantissa_width() > other_scalar_ty->get_fp_mantissa_width())
other = cast(other, input_scalar_ty, builder);
else
input = cast(input, other_scalar_ty, builder);
}
// unreachable
else
return throw_unreachable("div");
return builder->create_fdiv(input, other);
}
ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){
binary_op_type_checking(input, other, builder, false, false, false);
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty);
input = dispatch::cast(input, ret_ty, builder);
other = dispatch::cast(other, ret_ty, builder);
return builder->create_sdiv(input, other);
}
return throw_unreachable("floordiv");
}
ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float % int
if (scalar_ty->is_floating_point_ty())
return builder->create_frem(input, other);
// int % int
else if (scalar_ty->is_integer_ty())
return builder->create_srem(input, other);
return throw_unreachable("mod");
}
void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder, bool force_lhs_type = false){
binary_op_type_checking(input, other, builder, false, false, false);
ir::type *input_sca_ty = input->get_type()->get_scalar_ty();
ir::type *other_sca_ty = other->get_type()->get_scalar_ty();
if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty())
throw_incompatible_types(input_sca_ty, other_sca_ty);
// for some reason pytorch assigns the result of binary op to have the type of the lhs...
if(force_lhs_type){
if(input_sca_ty->get_integer_bitwidth() != other_sca_ty->get_integer_bitwidth())
other = dispatch::cast(other, input_sca_ty, builder);
}
else{
if(input_sca_ty->get_integer_bitwidth() < other_sca_ty->get_integer_bitwidth())
input = dispatch::cast(input, other_sca_ty, builder);
else if(other_sca_ty->get_integer_bitwidth() < input_sca_ty->get_integer_bitwidth())
other = dispatch::cast(other, input_sca_ty, builder);
}
}
ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder, true);
return builder->create_and(input, other);
}
ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder, true);
return builder->create_or(input, other);
}
ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder, true);
return builder->create_xor(input, other);
}
ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder, false);
return builder->create_lshr(input, other);
}
ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder, false);
return builder->create_shl(input, other);
}
//===----------------------------------------------------------------------===//
// Unary Operators
//===----------------------------------------------------------------------===//
ir::value *dispatch::plus(ir::value *input, ir::builder *) {
return input;
}
ir::value *dispatch::minus(ir::value *input, ir::builder *builder) {
ir::type* input_sca_ty = input->get_type()->get_scalar_ty();
if(input_sca_ty->is_pointer_ty())
throw semantic_error("wrong type argument to unary minus (" + input_sca_ty->repr() + ")");
ir::value *_0 = ir::constant::get_null_value(input_sca_ty);
return dispatch::sub(_0, input, builder);
}
ir::value *dispatch::invert(ir::value *input, ir::builder *builder) {
ir::type* input_sca_ty = input->get_type()->get_scalar_ty();
if(input_sca_ty->is_pointer_ty() || input_sca_ty->is_floating_point_ty())
throw semantic_error("wrong type argument to unary invert (" + input_sca_ty->repr() + ")");
ir::value *_1 = ir::constant::get_all_ones_value(input_sca_ty);
return dispatch::xor_(input, _1, builder);
}
//===----------------------------------------------------------------------===//
// Comparison Operators
//===----------------------------------------------------------------------===//
ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float > float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGT(input, other);
// int > int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGT(input, other);
return throw_unreachable("greater_than");
}
ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float >= float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGE(input, other);
// int >= int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSGE(input, other);
return throw_unreachable("greater_equal");
}
ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float < float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLT(input, other);
// int < int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLT(input, other);
return throw_unreachable("less_than");
}
ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float < float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLE(input, other);
// int < int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpSLE(input, other);
return throw_unreachable("less_equal");
}
ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float == float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOEQ(input, other);
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpEQ(input, other);
return throw_unreachable("equal");
}
ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float == float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpUNE(input, other);
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpNE(input, other);
return throw_unreachable("equal");
}
//===----------------------------------------------------------------------===//
// Block Creation
//===----------------------------------------------------------------------===//
ir::value* dispatch::arange(int start, int end, ir::builder *builder) {
return builder->get_range(start, end);
}
ir::value* dispatch::zeros(shape_t shape, ir::type *dtype, ir::builder *builder) {
ir::value *_0 = ir::constant::get_null_value(dtype);
return builder->create_splat(_0, shape);
}
//===----------------------------------------------------------------------===//
// Shape Manipulation
//===----------------------------------------------------------------------===//
ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *builder) {
unsigned numel = 1;
for(unsigned s: dst_shape) numel *= s;
if(input->get_type()->get_tile_num_elements() != numel)
throw semantic_error("cannot reshape block of different shape");
return builder->create_reshape(input, dst_shape);
}
ir::value *dispatch::cat(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
return builder->create_cat(lhs, rhs);
}
ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) {
if (!input->get_type()->is_block_ty())
return builder->create_splat(input, shape);
auto src_shape = input->get_type()->get_block_shapes();
if (src_shape.size() != shape.size())
throw std::runtime_error("Cannot broadcast");
if(shape == src_shape)
return input;
return builder->create_broadcast(input, shape);
}
std::tuple<ir::value*, ir::value*> dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) {
ir::type *lhs_ty = lhs->get_type();
ir::type *rhs_ty = rhs->get_type();
// make_shape_compatible(block, scalar)
if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty())
rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes());
// make_shape_compatible(scalar, block)
else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty())
lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes());
// make_shape_compatible(block, block)
else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) {
auto lhs_shape = lhs_ty->get_block_shapes();
auto rhs_shape = rhs_ty->get_block_shapes();
if (lhs_shape.size() != rhs_shape.size())
throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank");
ir::type::block_shapes_t ret_shape;
for (size_t i = 0; i < lhs_shape.size(); ++i) {
unsigned left = lhs_shape[i];
unsigned right = rhs_shape[i];
if (left == 1)
ret_shape.push_back(right);
else if (right == 1)
ret_shape.push_back(left);
else if (left == right)
ret_shape.push_back(left);
else
throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) +
": " + std::to_string(left) + " and " + std::to_string(right));
}
if (lhs_shape != ret_shape)
lhs = builder->create_broadcast(lhs, ret_shape);
if (rhs_shape != ret_shape)
rhs = builder->create_broadcast(rhs, ret_shape);
}
return std::make_tuple(lhs, rhs);
}
ir::value *dispatch::bitcast(ir::value *input, ir::type *dst_ty, ir::builder *builder){
ir::type *src_ty = input->get_type();
if (src_ty->is_block_ty())
dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
if(src_ty == dst_ty)
return input;
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
if(src_sca_ty->is_pointer_ty() || dst_sca_ty->is_pointer_ty())
return cast(input, dst_ty, builder);
// Bitcast
int src_bits = src_sca_ty->get_primitive_size_in_bits();
int dst_bits = dst_sca_ty->get_primitive_size_in_bits();
if( src_bits!= dst_bits)
throw std::runtime_error("Cannot bitcast data-type of size " + std::to_string(src_bits) +
"to data-type of size " + std::to_string(dst_bits));
return builder->create_cast(ir::BitCast, input, dst_ty);
}
ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *builder) {
ir::type *src_ty = input->get_type();
if (src_ty->is_block_ty())
dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
if(src_ty == dst_ty)
return input;
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
// FP Truncation
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
if (truncate_fp)
return builder->create_fp_trunc(input, dst_ty);
// FP Extension
bool ext_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width();
if (ext_fp)
return builder->create_fp_ext(input, dst_ty);
// Int cast
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth())
return builder->create_int_cast(input, dst_ty, src_sca_ty != builder->get_int1_ty());
// Float -> Int
if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){
if(dst_sca_ty->is_bool_ty())
return builder->create_fp_to_ui(input, dst_ty);
else
return builder->create_fp_to_si(input, dst_ty);
}
// int -> Float
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){
if(src_sca_ty->is_bool_ty())
return builder->create_ui_to_fp(input, dst_ty);
else
return builder->create_si_to_fp(input, dst_ty);
}
if (src_sca_ty->is_pointer_ty() && !dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::PtrToInt, input, dst_ty);
if (!src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::IntToPtr, input, dst_ty);
// Ptr -> Ptr
if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::BitCast, input, dst_ty);
// * -> Bool
if (dst_sca_ty->is_bool_ty()) {
if (src_sca_ty->is_pointer_ty())
input = cast(input, builder->get_int64_ty(), builder);
ir::value *other = builder->get_int64(0);
if (src_ty->is_bool_ty())
other = builder->create_splat(other, src_ty->get_block_shapes());
return builder->create_icmpNE(input, other);
}
return throw_unreachable("cast");
}
//===----------------------------------------------------------------------===//
// Memory Operators
//===----------------------------------------------------------------------===//
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, ir::builder* builder) {
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
if(ptr->get_type()->is_block_ty()){
if(mask){
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
}
if(other){
other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder);
other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
}
}
ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
// treat bool* as int8*
if(elt_ty == builder->get_int1_ty()){
elt_ty = builder->get_int8_ty();
ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
ptr = dispatch::cast(ptr, ptr_ty, builder);
}
load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default
if (!cache_modifier.empty()) {
if (cache_modifier == ".ca")
cache = load_inst::CA;
else if (cache_modifier == ".cg")
cache = load_inst::CG;
else
throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
}
if (!mask && !other)
return builder->create_load(ptr, cache);
if (!mask)
throw std::runtime_error("`other` cannot be provided without `mask`");
auto shape = ptr->get_type()->get_block_shapes();
if(!other){
other = ir::undef_value::get(elt_ty);
if(ptr->get_type()->is_block_ty())
other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
}
return builder->create_masked_load(ptr, mask, other, cache);
}
ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
if(ptr->get_type()->is_block_ty())
val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
if(mask)
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
// treat bool* as int8*
if(elt_ty == builder->get_int1_ty()){
elt_ty = builder->get_int8_ty();
ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
ptr = dispatch::cast(ptr, ptr_ty, builder);
}
// cast to target data-type
#ifdef USE_ROCM
ir::type *src_ty = val->get_type();
ir::type *dst_ty = elt_ty;
if (src_ty->is_block_ty())
dst_ty = ir::block_type::get(dst_ty, src_ty->get_block_shapes());
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
// check if truncation is need
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
if (truncate_fp && elt_ty->is_fp16_ty())
{
std::cout << "WARNING: "<<"casting down to fp16 is broken on ROCM" << std::endl;
}
val = dispatch::cast(val, elt_ty, builder);
#else
val = dispatch::cast(val, elt_ty, builder);
#endif
if (!mask)
return builder->create_store(ptr, val);
if(!mask->get_type()->get_scalar_ty()->is_bool_ty())
throw semantic_error("Mask must have boolean scalar type");
return builder->create_masked_store(ptr, val, mask);
}
ir::value *dispatch::atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder){
return builder->create_atomic_cas(ptr, cmp, val);
}
void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
if(ptr->get_type()->is_block_ty()){
if(mask){
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
}
if(val){
val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
}
}
val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
if(!mask){
mask = builder->get_int1(true);
if(ptr->get_type()->is_block_ty())
mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes());
}
}
ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
// direct call to atomic_max for integers
if(sca_ty->is_integer_ty())
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
// for float
// return atomic_smax(i_ptr, i_val) if val >= 0
// return atomic_umin(i_ptr, i_val) if val < 0
ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, i_ptr, i_val, and_(mask, pos, builder));
ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, i_ptr, i_val, and_(mask, neg, builder));
return where(pos, pos_ret, neg_ret, builder);
}
ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
// direct call to atomic_max for integers
if(sca_ty->is_integer_ty())
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
// for float
// return atomic_smin(i_ptr, i_val) if val >= 0
// return atomic_umax(i_ptr, i_val) if val < 0
ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, i_ptr, i_val, and_(mask, pos, builder));
ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, i_ptr, i_val, and_(mask, neg, builder));
return where(pos, pos_ret, neg_ret, builder);
}
ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
auto op = sca_ty->is_floating_point_ty() ? ir::atomic_rmw_op_t::FAdd : ir::atomic_rmw_op_t::Add;
return builder->create_atomic_rmw(op, ptr, val, mask);
}
ir::value *dispatch::atomic_and(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::And, ptr, val, mask);
}
ir::value *dispatch::atomic_or(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Or, ptr, val, mask);
}
ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask);
}
ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xchg, ptr, val, mask);
}
//===----------------------------------------------------------------------===//
// Linear Algebra
//===----------------------------------------------------------------------===//
ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
lhs = dispatch::cast(lhs, builder->get_float_ty(), builder);
rhs = dispatch::cast(rhs, builder->get_float_ty(), builder);
ir::value *_0 = builder->get_float32(0);
unsigned M = lhs->get_type()->get_block_shapes()[0];
unsigned N = rhs->get_type()->get_block_shapes()[1];
_0 = builder->create_splat(_0, {M, N});
return builder->create_dot(lhs, rhs, _0);
}
//===----------------------------------------------------------------------===//
// Indexing
//===----------------------------------------------------------------------===//
ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder){
condition = dispatch::cast(condition, builder->get_int1_ty(), builder);
if(condition->get_type()->is_block_ty()){
x = dispatch::broadcast(x, condition->get_type()->get_block_shapes(), builder);
y = dispatch::broadcast(y, condition->get_type()->get_block_shapes(), builder);
}
if(x->get_type()->get_scalar_ty() != y->get_type()->get_scalar_ty())
throw_incompatible_types(x->get_type()->get_scalar_ty(), y->get_type()->get_scalar_ty());
return builder->create_select(condition, x, y);
}
//===----------------------------------------------------------------------===//
// Reductions
//===----------------------------------------------------------------------===//
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// input is extended to 32-bits if necessary
// this increases numerical accuracy and can be done pretty much for free
// on GPUs
if(scalar_ty->is_integer_ty() && scalar_ty->get_integer_bitwidth() <= 32)
input = dispatch::cast(input, type::get_int32_ty(scalar_ty->get_context()), builder);
if (scalar_ty->is_floating_point_ty())
return builder->create_reduce(input, FLOAT_OP, axis);
else if (scalar_ty->is_integer_ty())
return builder->create_reduce(input, INT_OP, axis);
return throw_unreachable(name);
}
ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
}
ir::value *dispatch::max(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
}
ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD);
}
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
ir::value *dispatch::umulhi(ir::value *x, ir::value* y, ir::builder *builder) {
binary_op_type_checking(x, y, builder);
return builder->insert(umulhi_inst::create(x, y));
}
ir::value *dispatch::exp(ir::value *x, ir::builder *builder) {
return builder->create_exp(x);
}
ir::value *dispatch::log(ir::value *x, ir::builder *builder) {
return builder->create_log(x);
}
ir::value *dispatch::cos(ir::value *x, ir::builder *builder) {
return builder->create_cos(x);
}
ir::value *dispatch::sin(ir::value *x, ir::builder *builder) {
return builder->create_sin(x);
}
ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) {
return builder->create_sqrt(x);
}
//
ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){
ir::instruction* i = dynamic_cast<ir::instruction*>(x);
if(!i)
throw_unreachable("multiple_of");
i->set_metadata(ir::metadata::multiple_of, value);
return i;
}
ir::value *dispatch::max_contiguous(ir::value *x, int value, ir::builder *){
ir::instruction* i = dynamic_cast<ir::instruction*>(x);
if(!i)
throw_unreachable("max_contiguous");
i->set_metadata(ir::metadata::max_contiguous, value);
return i;
}
ir::value *dispatch::debug_barrier(ir::builder *builder) {
return builder->create_barrier();
}
}
}

View File

@@ -33,8 +33,10 @@ void argument::accept(visitor *v) {
/* function */ /* function */
function::function(function_type *ty, linkage_types_t linkage, function::function(function_type *ty, linkage_types_t linkage,
const std::string &name, module *parent) const std::string &name, module *parent)
: global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty) { : global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty), is_kernel_(false) {
unsigned num_params = fn_ty_->get_num_params(); unsigned num_params = fn_ty_->get_num_params();
if(parent)
parent->push_function(this);
// skip if no parameter // skip if no parameter
if(num_params == 0) if(num_params == 0)
return; return;
@@ -44,8 +46,6 @@ function::function(function_type *ty, linkage_types_t linkage,
type *param_ty = fn_ty_->get_param_ty(i); type *param_ty = fn_ty_->get_param_ty(i);
args_[i] = argument::create(param_ty, "", this, i); args_[i] = argument::create(param_ty, "", this, i);
} }
if(parent)
parent->push_function(this);
} }
/* basic block */ /* basic block */

View File

@@ -5,6 +5,7 @@
#include "triton/ir/instructions.h" #include "triton/ir/instructions.h"
#include "triton/ir/constant.h" #include "triton/ir/constant.h"
#include "triton/ir/type.h" #include "triton/ir/type.h"
#include "triton/ir/function.h"
namespace triton{ namespace triton{
namespace ir{ namespace ir{
@@ -68,6 +69,7 @@ void phi_node::set_incoming_block(unsigned i, basic_block *block){
// Add incoming // Add incoming
void phi_node::add_incoming(value *v, basic_block *block){ void phi_node::add_incoming(value *v, basic_block *block){
assert(v && "PHI node got a null value!!");
resize_ops(get_num_operands() + 1); resize_ops(get_num_operands() + 1);
blocks_.resize(get_num_operands() + 1); blocks_.resize(get_num_operands() + 1);
set_incoming_value(get_num_operands() - 1, v); set_incoming_value(get_num_operands() - 1, v);
@@ -79,6 +81,70 @@ phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &n
return new phi_node(ty, num_reserved, name, next); return new phi_node(ty, num_reserved, name, next);
} }
//===----------------------------------------------------------------------===//
// call_inst classes
//===----------------------------------------------------------------------===//
std::string call_inst::repr_impl() const { return "call " + fn_->get_name(); }
call_inst::call_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::string& name, instruction* next)
: instruction(fn->get_fn_type()->get_return_ty(), INST_CALL, values.size(), name, next), fn_(fn){
for(size_t i = 0; i < values.size(); i++)
set_operand(i, values.at(i));
}
call_inst* call_inst::create(ir::function* fn, const std::vector<ir::value*>& values, const std::string &name, instruction *next) {
return new call_inst(fn, values, name, next);
}
// launch
launch_inst::launch_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps, const std::string& name, instruction* next)
: instruction(fn->get_fn_type()->get_return_ty(), INST_LAUNCH, 1 + values.size() + grid.size() + 1, name, next){
int k = 0;
if(grid.size() != 3)
throw std::runtime_error("grid must have 3 elements");
set_operand(k++, fn);
val_begin = k;
for(ir::value* v: values)
set_operand(k++, v);
val_end = k;
grid_begin = k;
for(ir::value* g: grid)
set_operand(k++, g);
grid_end = k;
set_operand(k++, num_warps);
}
ir::function* launch_inst::get_fn() {
return (ir::function*)get_operand(0);
}
std::vector<ir::value*> launch_inst::get_values() {
std::vector<ir::value*> ret;
for(int i = val_begin; i < val_end; i++)
ret.push_back(get_operand(i));
return ret;
}
std::vector<ir::value*> launch_inst::get_grid() {
std::vector<ir::value*> ret;
for(int i = grid_begin; i < grid_end; i++)
ret.push_back(get_operand(i));
return ret;
}
ir::value* launch_inst::get_num_warps() {
return get_operand(grid_end);
}
launch_inst* launch_inst::create(ir::function *fn, const std::vector<ir::value *> &values, const std::vector<ir::value *> &grid, ir::value *num_warps, const std::string &name, instruction *next) {
return new launch_inst(fn, values, grid, num_warps, name, next);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// binary_operator classes // binary_operator classes
@@ -134,7 +200,7 @@ bool binary_operator::is_int_add_sub() const {
binary_operator::binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next) binary_operator::binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
: instruction(ty, INST_BINOP, 2, name, next), op_(op){ : instruction(ty, INST_BINOP, 2, name, next), op_(op), fdiv_ieee_rnd_(false){
set_operand(0, lhs); set_operand(0, lhs);
set_operand(1, rhs); set_operand(1, rhs);
} }
@@ -232,6 +298,7 @@ icmp_inst::icmp_inst(type *ty, cmp_pred_t pred,
icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){ icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
assert(is_int_predicate(pred)); assert(is_int_predicate(pred));
assert(lhs->get_type() == rhs->get_type());
type *res_ty = make_cmp_result_type(lhs->get_type()); type *res_ty = make_cmp_result_type(lhs->get_type());
return new icmp_inst(res_ty, pred, lhs, rhs, name, next); return new icmp_inst(res_ty, pred, lhs, rhs, name, next);
} }
@@ -256,6 +323,21 @@ unary_inst::unary_inst(type *ty, value_id_t id, value *v, const std::string &nam
set_operand(0, v); set_operand(0, v);
} }
//===----------------------------------------------------------------------===//
// dequantize_inst classes
//===----------------------------------------------------------------------===//
dequantize_inst::dequantize_inst(type *ty, value *v, value *scale, value *shift, const std::string &name, instruction *next)
: instruction(ty, INST_DEQUANTIZE, 3, name, next) {
set_operand(0, v);
set_operand(1, scale);
set_operand(2, shift);
}
dequantize_inst *dequantize_inst::create(value *arg, value *scale, value *shift, type *ty, const std::string &name, instruction *next){
return new dequantize_inst(ty, arg, scale, shift, name, next);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// cast_inst classes // cast_inst classes
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -323,7 +405,7 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed,
// return_inst // return_inst
return_inst::return_inst(context &ctx, value *ret_val, instruction *next) return_inst::return_inst(context &ctx, value *ret_val, instruction *next)
: terminator_inst(type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){ : terminator_inst(ret_val?ret_val->get_type():type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){
if(ret_val) if(ret_val)
set_operand(0, ret_val); set_operand(0, ret_val);
} }
@@ -428,13 +510,13 @@ getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vector<val
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// io_inst // io_inst
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next) io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
: instruction(ty, id, num_ops, name, next) : instruction(ty, id, num_ops, name, next), eviction_(eviction)
{ } { }
// load_inst // load_inst
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next) load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache) : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, eviction, name, next), cache_(cache), is_volatile_(is_volatile)
{ } { }
// load // load
@@ -447,77 +529,110 @@ type *load_inst::get_pointee_type(type *ty) {
} }
// unmasked_load // unmasked_load
unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next) unmasked_load_inst::unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache,load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
: load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, name, next) { : load_inst(ptr, INST_UNMASKED_LOAD, 1, cache, eviction, is_volatile, name, next) {
set_operand(0, ptr); set_operand(0, ptr);
} }
unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next) { unmasked_load_inst* unmasked_load_inst::create(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next) {
return new unmasked_load_inst(ptr, cache, name, next); return new unmasked_load_inst(ptr, cache, eviction, is_volatile, name, next);
} }
// masked load // masked load
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
bool is_volatile,
const std::string &name, instruction *next) const std::string &name, instruction *next)
: load_inst(ptr, INST_MASKED_LOAD, 3, cache, name, next) { : load_inst(ptr, INST_MASKED_LOAD, 3, cache, eviction, is_volatile, name, next) {
set_operand(0, ptr); set_operand(0, ptr);
set_operand(1, mask); set_operand(1, mask);
set_operand(2, false_value); set_operand(2, false_value);
} }
masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value, masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
bool is_volatile,
const std::string &name, instruction *next) { const std::string &name, instruction *next) {
return new masked_load_inst(ptr, mask, false_value, cache, name, next); return new masked_load_inst(ptr, mask, false_value, cache, eviction, is_volatile, name, next);
} }
// masked load async // masked load async
masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value, masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction,
const std::string &name, instruction *next) const std::string &name, instruction *next)
: load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, name, next) { : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, cache, eviction, false, name, next) {
set_operand(0, ptr); set_operand(0, ptr);
set_operand(1, mask); set_operand(1, mask);
set_operand(2, false_value); set_operand(2, false_value);
} }
masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value, masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction,
const std::string &name, instruction *next) { const std::string &name, instruction *next) {
return new masked_load_async_inst(ptr, mask, false_value, cache, name, next); return new masked_load_async_inst(ptr, mask, false_value, cache, eviction, name, next);
} }
// store // store
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next) store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, name, next) : io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, eviction, name, next)
{ } { }
// unmasked_store // unmasked_store
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val, unmasked_store_inst::unmasked_store_inst(value *ptr, value *val, EVICTION_POLICY eviction,
const std::string &name, instruction *next) const std::string &name, instruction *next)
: store_inst(ptr, INST_UNMASKED_STORE, 2, name, next) { : store_inst(ptr, INST_UNMASKED_STORE, 2, eviction, name, next) {
set_operand(0, ptr); set_operand(0, ptr);
set_operand(1, val); set_operand(1, val);
} }
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val, unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val, EVICTION_POLICY eviction,
const std::string &name, instruction *next) { const std::string &name, instruction *next) {
return new unmasked_store_inst(ptr, val, name, next); return new unmasked_store_inst(ptr, val, eviction, name, next);
} }
// masked store // masked store
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
const std::string &name, instruction *next) const std::string &name, instruction *next)
: store_inst(ptr, INST_MASKED_STORE, 3, name, next) { : store_inst(ptr, INST_MASKED_STORE, 3, eviction, name, next) {
set_operand(0, ptr); set_operand(0, ptr);
set_operand(1, val); set_operand(1, val);
set_operand(2, mask); set_operand(2, mask);
} }
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) { masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
return new masked_store_inst(ptr, val, mask, name, next); const std::string &name, instruction *next) {
return new masked_store_inst(ptr, val, mask, eviction, name, next);
} }
//===----------------------------------------------------------------------===//
// struct classes
//===----------------------------------------------------------------------===//
// insert value
insert_value_inst::insert_value_inst(value *val, value *elt, size_t idx, const std::string& name, instruction *next)
: instruction(val->get_type(), INST_INSERT_VALUE, 2, name, next), idx_(idx) {
set_operand(0, val);
set_operand(1, elt);
}
insert_value_inst* insert_value_inst::create(value *val, value *elt, size_t idx, const std::string& name, instruction *next){
return new insert_value_inst(val, elt, idx, name, next);
}
// extract value
extract_value_inst::extract_value_inst(value *val, size_t idx, const std::string& name, instruction *next)
: instruction(val->get_type()->get_struct_type(idx), INST_EXTRACT_VALUE, 1, name, next), idx_(idx) {
set_operand(0, val);
}
extract_value_inst* extract_value_inst::create(value *val, size_t idx, const std::string& name, instruction *next){
return new extract_value_inst(val, idx, name, next);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// retile_inst classes // retile_inst classes
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -572,44 +687,48 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
return new downcast_inst(arg->get_type()->get_scalar_ty(), INST_DOWNCAST, arg, name, next); return new downcast_inst(arg->get_type()->get_scalar_ty(), INST_DOWNCAST, arg, name, next);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// matmul_inst classes // matmul_inst classes
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
const std::string &name, instruction *next) const std::string &name, instruction *next)
: builtin_inst(C->get_type(), INST_DOT, 3, name, next) { : builtin_inst(C->get_type(), INST_DOT, 3, name, next), AT_(AT), BT_(BT){
set_operand(0, A); set_operand(0, A);
set_operand(1, B); set_operand(1, B);
set_operand(2, C); set_operand(2, C);
allow_tf32_ = allow_tf32;
} }
instruction *dot_inst::create(value *A, value *B, value *C, instruction *dot_inst::create(value *A, value *B, value *C,
bool AT, bool BT, bool AT, bool BT, bool allow_tf32,
const std::string &name, instruction *next) { const std::string &name, instruction *next) {
TransT OPA = AT ? Trans : NoTrans; TransT OPA = AT ? Trans : NoTrans;
TransT OPB = BT ? Trans : NoTrans; TransT OPB = BT ? Trans : NoTrans;
return new dot_inst(A, B, C, OPA, OPB, name, next); return new dot_inst(A, B, C, OPA, OPB, allow_tf32, name, next);
} }
instruction *dot_inst::create_nn(value *A, value *B, value *C, instruction *dot_inst::create_nn(value *A, value *B, value *C, bool allow_tf32,
const std::string &name, instruction *next) { const std::string &name, instruction *next) {
return new dot_inst(A, B, C, NoTrans, NoTrans, name, next); return new dot_inst(A, B, C, NoTrans, NoTrans, allow_tf32, name, next);
} }
instruction *dot_inst::create_nt(value *A, value *B, value *C, instruction *dot_inst::create_nt(value *A, value *B, value *C, bool allow_tf32,
const std::string &name, instruction *next) { const std::string &name, instruction *next) {
return new dot_inst(A, B, C, NoTrans, Trans, name, next); return new dot_inst(A, B, C, NoTrans, Trans, allow_tf32, name, next);
} }
instruction *dot_inst::create_tn(value *A, value *B, value *C, instruction *dot_inst::create_tn(value *A, value *B, value *C, bool allow_tf32,
const std::string &name, instruction *next) { const std::string &name, instruction *next) {
return new dot_inst(A, B, C, Trans, NoTrans, name, next); return new dot_inst(A, B, C, Trans, NoTrans, allow_tf32, name, next);
} }
instruction *dot_inst::create_tt(value *A, value *B, value *C, instruction *dot_inst::create_tt(value *A, value *B, value *C, bool allow_tf32,
const std::string &name, instruction *next) { const std::string &name, instruction *next) {
return new dot_inst(A, B, C, Trans, Trans, name, next); return new dot_inst(A, B, C, Trans, Trans, allow_tf32, name, next);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -857,8 +976,7 @@ copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::stri
} }
// barrier // barrier
barrier_inst::barrier_inst(context &ctx, const std::string &name, barrier_inst::barrier_inst(context &ctx, const std::string &name, instruction *next)
instruction *next)
: instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { } : instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { }
barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) { barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
@@ -877,27 +995,44 @@ prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, cons
return new prefetch_s_inst(ctx, arg, inc, name, next); return new prefetch_s_inst(ctx, arg, inc, name, next);
} }
//// nv_dynamic_program_idx // global timer
//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next) globaltimer_inst::globaltimer_inst(context &ctx, const std::string &name, instruction *next)
// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { } : instruction(type::get_int64_ty(ctx), INST_GLOBALTIMER, 0, name, next) { }
//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) { globaltimer_inst* globaltimer_inst::create(context &ctx, const std::string &name, instruction *next) {
// return new make_range_dyn(ty, name, next); return new globaltimer_inst(ctx, name, next);
//} }
//// nv_static_program_idx // extern elementwise
//make_range_sta::make_range_sta(make_range *range) extern_elementwise_inst::extern_elementwise_inst(
// : constant(range->get_type(), 0), range_(range) { } context &ctx, const std::vector<value *> &args, type *ret_ty,
const std::string &lib_name, const std::string &lib_path,
const std::string &symbol_name, const std::string &name, instruction *next)
: instruction(ret_ty, INST_EXTERN_ELEMENTWISE, args.size(), name, next),
lib_name_(lib_name),
lib_path_(lib_path),
symbol_name_(symbol_name) {
for (size_t i = 0; i < args.size(); i++) {
set_operand(i, args[i]);
}
}
//make_range* make_range_sta::get_range() const extern_elementwise_inst *extern_elementwise_inst::create(
//{ return range_; } context &ctx, const std::vector<value *> &args, type *ret_ty,
const std::string &lib_name, const std::string &lib_path,
const std::string &symbol_name, const std::string &name,
instruction *next) {
return new extern_elementwise_inst(ctx, args, ret_ty, lib_name, lib_path,
symbol_name, name, next);
}
//make_range_sta* make_range_sta::get(make_range* range) { // clock
// static std::map<make_range*, make_range_sta*> cache; clock_inst::clock_inst(context &ctx, const std::string &name, instruction *next)
// if(cache.find(range) == cache.end()) : instruction(type::get_int64_ty(ctx), INST_CLOCK, 0, name, next) { }
// cache.insert({range, new make_range_sta(range)});
// return cache.at(range); clock_inst* clock_inst::create(context &ctx, const std::string &name, instruction *next) {
//} return new clock_inst(ctx, name, next);
}
// make_range // make_range
@@ -920,7 +1055,5 @@ const constant_int* make_range::get_last() const {
return last_; return last_;
} }
} }
} }

View File

@@ -3,10 +3,10 @@
namespace triton{ namespace triton{
namespace ir{ namespace ir{
metadata::metadata(kind_t kind, unsigned value) metadata::metadata(kind_t kind, std::vector<unsigned> value)
: kind_(kind), value_(value) { } : kind_(kind), value_(value) { }
metadata* metadata::get(kind_t kind, unsigned value) { metadata* metadata::get(kind_t kind, std::vector<unsigned> value) {
return new metadata(kind, value); return new metadata(kind, value);
} }

View File

@@ -9,151 +9,16 @@
namespace triton{ namespace triton{
namespace ir{ namespace ir{
/* Module */ void module::reset_ret_ty(const std::string& name, type* ty) {
module::module(const std::string &name, builder &builder) get_function(name)->get_fn_type()->reset_ret_ty(ty);
: name_(name), builder_(builder) {
sealed_blocks_.insert(nullptr);
}
ir::builder& module::get_builder() {
return builder_;
}
void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
values_[val_key_t{name, block}] = value;
auto it = metadatas_.find(name);
if(auto *x = dynamic_cast<ir::instruction*>(value))
if(it != metadatas_.end()){
x->set_metadata(it->second.first, it->second.second);
}
// value->set_name(name);
}
void module::set_value(const std::string& name, ir::value *value){
return set_value(name, builder_.get_insert_block(), value);
}
void module::set_const(const std::string& name){
const_.insert(name);
}
void module::set_continue_fn(std::function<ir::value*()> fn) {
continue_fn_ = fn;
}
std::function<ir::value*()> module::get_continue_fn() {
return continue_fn_;
}
ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
basic_block::iterator insert = block->get_first_non_phi();
if(insert != block->end()){
builder_.set_insert_point(insert);
}
ir::phi_node *res = builder_.create_phi(ty, num_values);
if(insert != block->end())
builder_.set_insert_point(block);
return res;
}
ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
// find non-self references
std::set<ir::value*> non_self_ref;
std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()),
[phi](ir::value* op){ return op != phi && op; });
// non-trivial
if(non_self_ref.size() != 1)
return phi;
// unique value or self-reference
ir::value *same = *non_self_ref.begin();
assert(same != nullptr);
phi->replace_all_uses_with(same);
phi->erase_from_parent();
std::set<ir::user*> users = phi->get_users();
for(ir::user* u: users)
if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
if(uphi != phi)
try_remove_trivial_phis(uphi);
return same;
}
ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi){
// already initialized
if(phi->get_num_operands())
return phi;
ir::basic_block *block = phi->get_parent();
for(ir::basic_block *pred: block->get_predecessors()){
ir::value *value = get_value(name, pred);
phi->add_incoming(value, pred);
}
return phi;
}
ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) {
ir::value *result;
bool is_const = const_.find(name) != const_.end();
auto &preds = block->get_predecessors();
ir::type *ty = types_.at(name);
if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
incomplete_phis_[block][name] = make_phi(ty, 1, block);
result = (ir::value*)incomplete_phis_[block][name];
}
else if(preds.size() <= 1){
bool has_pred = preds.size();
result = get_value(name, has_pred?preds.front():nullptr);
}
else{
ir::phi_node* phi = make_phi(ty, 1, block);
set_value(name, block, phi);
result = add_phi_operands(name, phi);
if(auto *phi = dynamic_cast<ir::phi_node*>(result))
result = try_remove_trivial_phis(phi);
}
if(auto *phi = dynamic_cast<ir::phi_node*>(result)){
result = try_remove_trivial_phis(phi);
}
set_value(name, block, result);
return result;
}
ir::value *module::get_value(const std::string& name, ir::basic_block *block) {
ir::basic_block* save_block = builder_.get_insert_block();
ir::basic_block::iterator save_pt = builder_.get_insert_point();
val_key_t key(name, block);
if(values_.find(key) != values_.end()){
return values_.at(key);
}
ir::value *result = get_value_recursive(name, block);
builder_.set_insert_point(save_block);
if(save_pt != save_block->end())
builder_.set_insert_point(save_pt);
return result;
}
ir::value *module::get_value(const std::string& name) {
return get_value(name, builder_.get_insert_block());
}
const std::string& module::get_name() {
return name_;
}
void module::seal_block(ir::basic_block *block){
for(auto &x: incomplete_phis_[block]){
add_phi_operands(x.first, x.second);
if(get_value(x.first) == x.second)
set_value(x.first, try_remove_trivial_phis(x.second));
}
sealed_blocks_.insert(block);
incomplete_phis_[block].clear();
} }
/* functions */ /* functions */
function *module::get_or_insert_function(const std::string &name, function_type *ty) { function *module::get_or_insert_function(const std::string &name, function_type *ty) {
function *&fn = (function*&)symbols_[name]; function *&fn = (function*&)symbols_[name];
if(fn == nullptr) if(fn == nullptr){
return fn = function::create(ty, global_value::external, name, this); fn = function::create(ty, global_value::external, name, this);
}
return fn; return fn;
} }

View File

@@ -92,7 +92,7 @@ public:
//------------------------- //-------------------------
void SlotTracker::process_module() { void SlotTracker::process_module() {
// Nothing to do at the moment. // Nothing to do at the moment.
// Create slots for global variable & unamed functions & ... // Create slots for global variable & unnamed functions & ...
module_processed = true; module_processed = true;
} }

View File

@@ -27,7 +27,7 @@ unsigned type::get_primitive_size_in_bits() const {
case BF16TyID: return 16; case BF16TyID: return 16;
case FP32TyID: return 32; case FP32TyID: return 32;
case FP64TyID: return 64; case FP64TyID: return 64;
case IntegerTyID: return ((integer_type*)(this))->get_bitwidth(); case IntegerTyID: return std::max<int>(8, ((integer_type*)(this))->get_bitwidth());
case BlockTyID: return ((block_type*)(this))->get_bitwidth(); case BlockTyID: return ((block_type*)(this))->get_bitwidth();
default: return 0; default: return 0;
} }
@@ -153,10 +153,10 @@ pointer_type* pointer_type::get(type *elt_ty, unsigned address_space){
assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!"); assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!");
// look-up // look-up
context_impl *impl = elt_ty->get_context().p_impl.get(); context_impl *impl = elt_ty->get_context().p_impl.get();
pointer_type *&entry = impl->ptr_tys[std::make_pair(elt_ty, address_space)]; std::unique_ptr<pointer_type> &entry = impl->ptr_tys[std::make_pair(elt_ty, address_space)];
if(!entry) if(!entry)
entry = new pointer_type(elt_ty, address_space); entry.reset(new pointer_type(elt_ty, address_space));
return entry; return entry.get();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -174,7 +174,26 @@ bool composite_type::index_valid(value *idx) const{
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// tile_type class // struct_type class
//===----------------------------------------------------------------------===//
struct_type::struct_type(const contained_tys_vec_t& tys, bool is_packed)
: composite_type(tys[0]->get_context(), StructTyID), is_packed_(is_packed) {
contained_tys_ = tys;
}
struct_type* struct_type::get(const contained_tys_vec_t& tys, bool is_packed) {
assert(tys.size());
context_impl* impl = tys[0]->get_context().p_impl.get();
struct_type *& entry = impl->struct_tys[tys];
if(!entry)
entry = new struct_type(tys, is_packed);
return entry;
}
//===----------------------------------------------------------------------===//
// block_type class
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
block_type::block_type(type *ty, const block_shapes_t &shapes) block_type::block_type(type *ty, const block_shapes_t &shapes)
@@ -203,10 +222,10 @@ block_type* block_type::get(type *elt_ty, const block_shapes_t &shapes) {
assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!"); assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!");
// look-up // look-up
context_impl *impl = elt_ty->get_context().p_impl.get(); context_impl *impl = elt_ty->get_context().p_impl.get();
block_type *&entry = impl->block_tys[std::make_pair(elt_ty, shapes)]; std::unique_ptr<block_type> &entry = impl->block_tys[std::make_pair(elt_ty, shapes)];
if(!entry) if(!entry)
entry = new block_type(elt_ty, shapes); entry.reset(new block_type(elt_ty, shapes));
return entry; return entry.get();
} }
block_type* block_type::get_same_shapes(type *ty, type *ref){ block_type* block_type::get_same_shapes(type *ty, type *ref){

View File

@@ -43,6 +43,15 @@ std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
return result; return result;
} }
void for_each_instruction_backward(module &mod, const std::function<void (instruction *)> &do_work) {
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: cfg::post_order(fn)){
auto inst_list = block->get_inst_list();
for(auto it = inst_list.rbegin(); it != inst_list.rend() ; it++)
do_work(*it);
}
}
void for_each_instruction(module &mod, const std::function<void (instruction *)> &do_work) { void for_each_instruction(module &mod, const std::function<void (instruction *)> &do_work) {
for(ir::function *fn: mod.get_function_list()) for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: cfg::reverse_post_order(fn)) for(ir::basic_block *block: cfg::reverse_post_order(fn))

View File

@@ -1,5 +1,6 @@
#include <cassert> #include <cassert>
#include <iostream> #include <iostream>
#include <algorithm>
#include "triton/ir/value.h" #include "triton/ir/value.h"
#include "triton/ir/instructions.h" #include "triton/ir/instructions.h"
@@ -17,11 +18,11 @@ value::value(type *ty, const std::string &name): ty_(ty){
} }
void value::add_use(user *arg) { void value::add_use(user *arg) {
users_.insert(arg); users_.push_back(arg);
} }
value::users_t::iterator value::erase_use(user *arg){ value::users_t::iterator value::erase_use(user *arg){
auto it = users_.find(arg); auto it = std::find(users_.begin(), users_.end(), arg);
if(it == users_.end()) if(it == users_.end())
return it; return it;
return users_.erase(it); return users_.erase(it);

View File

@@ -1,4 +1,5 @@
import torch import torch
import triton import triton
# ------------------------------- # -------------------------------
@@ -8,17 +9,17 @@ import triton
nt = {False: 'n', True: 't'} nt = {False: 'n', True: 't'}
square_confs = [ square_confs = [
triton.testing.Benchmark( triton.testing.Benchmark(
x_names = ['M', 'N', 'K'], x_names=['M', 'N', 'K'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg = 'block', line_arg='block',
line_vals = [16, 32, 64, 128], line_vals=[16, 32, 64, 128],
line_names = ['Block16', 'Block32', 'Block64', 'Block128'], line_names=['Block16', 'Block32', 'Block64', 'Block128'],
ylabel = 'TFLOPS', ylabel='TFLOPS',
plot_name = f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}', plot_name=f'{op_mode}-{layout_mode}-square-{nt[AT]}{nt[BT]}',
args = {'layout_mode': layout_mode, 'op_mode': op_mode, args={'layout_mode': layout_mode, 'op_mode': op_mode,
'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'} 'AT': AT, 'BT': BT, 'dtype': torch.float16, 'provider': 'triton'}
)\ )
for AT in [False] for BT in [False] \ for AT in [False] for BT in [False]
for op_mode in ['dsd'] for layout_mode in ['dense'] for op_mode in ['dsd'] for layout_mode in ['dense']
] ]
@@ -27,7 +28,7 @@ square_confs = [
def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000): def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider, warmup=100, rep=1000):
Z, H = 1, 1 Z, H = 1, 1
make_layout = { make_layout = {
'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),\ 'tril': lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)),
'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64), 'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64),
}[layout_mode] }[layout_mode]
# create layout # create layout
@@ -39,16 +40,16 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
# create op # create op
tflops = lambda ms: num_flops / ms * 1e3 tflops = lambda ms: num_flops / ms * 1e3
if provider == 'triton': if provider == 'triton':
op = triton.ops.blocksparse.matmul(layout, block, op_mode, trans_a=AT, trans_b=BT) op = triton.ops.blocksparse.matmul(layout, block, op_mode, device="cuda", trans_a=AT, trans_b=BT)
# inputs # inputs
a = triton.testing.sparsify_tensor(a, layout, block) if op_mode == 'dsd' else a a = triton.testing.sparsify_tensor(a, layout, block) if op_mode == 'dsd' else a
b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep) mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a, b), warmup=warmup, rep=rep)
num_flops = { num_flops = {
'sdd': 2 * Z * K * float(layout.sum()) * block * block,\ 'sdd': 2 * Z * K * float(layout.sum()) * block * block,
'dsd': 2 * Z * N * float(layout.sum()) * block * block,\ 'dsd': 2 * Z * N * float(layout.sum()) * block * block,
'dds': 2 * Z * M * float(layout.sum()) * block * block 'dds': 2 * Z * M * float(layout.sum()) * block * block
}[op_mode]*1e-12 }[op_mode] * 1e-12
return tflops(mean_ms), tflops(min_ms), tflops(max_ms) return tflops(mean_ms), tflops(min_ms), tflops(max_ms)
@@ -58,15 +59,15 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
square_confs = [ square_confs = [
triton.testing.Benchmark( triton.testing.Benchmark(
x_names = ['M', 'N'], x_names=['M', 'N'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144],
line_arg = 'block', line_arg='block',
line_vals = [16, 32, 64], line_vals=[16, 32, 64],
line_names = ['Block16', 'Block32', 'Block64'], line_names=['Block16', 'Block32', 'Block64'],
ylabel = 'GBPS', ylabel='GBPS',
plot_name = f'{layout_mode}-square', plot_name=f'{layout_mode}-square',
args = {'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'} args={'layout_mode': layout_mode, 'dtype': torch.float16, 'provider': 'triton'}
)\ )
for layout_mode in ['dense', 'tril'] for layout_mode in ['dense', 'tril']
] ]
@@ -82,7 +83,7 @@ def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
a = torch.randn((Z, H, M, N), dtype=dtype, device='cuda') a = torch.randn((Z, H, M, N), dtype=dtype, device='cuda')
if provider == 'triton': if provider == 'triton':
a = triton.testing.sparsify_tensor(a, layout, block) a = triton.testing.sparsify_tensor(a, layout, block)
op = triton.ops.blocksparse.softmax(layout, block) op = triton.ops.blocksparse.softmax(layout, block, device="cuda")
gbps = lambda ms: (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3) gbps = lambda ms: (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep) mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep)
return gbps(mean_ms), gbps(min_ms), gbps(max_ms) return gbps(mean_ms), gbps(min_ms), gbps(max_ms)

View File

@@ -1,17 +1,18 @@
import torch import torch
import triton import triton
confs = [ confs = [
triton.testing.Benchmark( triton.testing.Benchmark(
x_names = ['N'], x_names=['N'],
x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192], x_vals=[128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192],
line_arg = 'provider', line_arg='provider',
line_vals = ['triton', 'torch'], line_vals=['triton', 'torch'],
line_names = ['Triton', 'Torch'], line_names=['Triton', 'Torch'],
ylabel = 'GBPS', ylabel='GBPS',
plot_name = f'{mode}-2048', plot_name=f'{mode}-2048',
args = {'M': 2048, 'dtype': torch.float16, 'mode': mode} args={'M': 2048, 'dtype': torch.float16, 'mode': mode}
)\ )
for mode in ['forward', 'backward'] for mode in ['forward', 'backward']
] ]
@@ -24,7 +25,7 @@ def bench_op(M, N, dtype, mode, provider):
num_gb = (2 * x.numel() * x.element_size() * 1e-9) num_gb = (2 * x.numel() * x.element_size() * 1e-9)
gbps = lambda ms: num_gb / ms * 1e3 gbps = lambda ms: num_gb / ms * 1e3
# forward pass # forward pass
op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'), \ op = {'torch': torch.nn.CrossEntropyLoss(reduction='none'),
'triton': triton.ops.cross_entropy}[provider] 'triton': triton.ops.cross_entropy}[provider]
if mode == 'forward': if mode == 'forward':
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx)) mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(x, idx))

View File

@@ -1,11 +1,11 @@
import triton
import torch import torch
import os
import triton
def rounded_linspace(low, high, steps, div): def rounded_linspace(low, high, steps, div):
ret = torch.linspace(low, high, steps) ret = torch.linspace(low, high, steps)
ret = (ret.int() + div - 1) // div * div ret = torch.div(ret.int() + div - 1, div, rounding_mode='trunc') * div
ret = torch.unique(ret) ret = torch.unique(ret)
return list(map(int, ret)) return list(map(int, ret))
@@ -29,15 +29,15 @@ square_confs = [
transformer_confs = [ transformer_confs = [
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=[x], x_names=[x],
x_vals = rounded_linspace(NK//16, NK, 32, 128), x_vals=rounded_linspace(NK // 16, NK, 32, 128),
line_arg="provider", line_arg="provider",
line_vals=["cublas", "triton", "cutlass"], line_vals=["cublas", "triton", "cutlass"],
line_names=["cuBLAS", "Triton", "CUTLASS"], line_names=["cuBLAS", "Triton", "CUTLASS"],
ylabel="TFLOPS", ylabel="TFLOPS",
plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}", plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}",
args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16} args={"M": M, 'NK'.replace(x, ''): NK, "AT": False, "BT": False, "dtype": torch.float16}
) for NK in [12288]\ ) for NK in [12288]
for i, x in enumerate(["N", "K"])\ for i, x in enumerate(["N", "K"])
for M in [2048] for M in [2048]
] ]
@@ -46,9 +46,10 @@ transformer_confs = [
def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75): def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype) a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype) b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
if AT: a = a.t() if AT:
if BT: b = b.t() a = a.t()
num_flops = 2 * M * N * K if BT:
b = b.t()
tflops = lambda ms: 2. * M * N * K / ms * 1e-9 tflops = lambda ms: 2. * M * N * K / ms * 1e-9
if provider == "cublas": if provider == "cublas":
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep) ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
@@ -61,6 +62,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=25, rep=75):
try: try:
ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep) ms, min_ms, max_ms = triton.testing.do_bench(lambda: cutlass_matmul(a, b), warmup=warmup, rep=rep)
return tflops(ms), tflops(max_ms), tflops(min_ms) return tflops(ms), tflops(max_ms), tflops(min_ms)
except: except Exception:
return None return None
return None return None

View File

@@ -1,7 +1,8 @@
import argparse import argparse
import sys
import os
import inspect import inspect
import os
import sys
import triton import triton

View File

@@ -1 +0,0 @@
scipy >= 1.7.1

View File

@@ -1,2 +1,8 @@
[metadata] [metadata]
description-file = README.md description_file = README.md
[pycodestyle]
ignore = E501,E701,E731
[flake8]
ignore = E501,E701,E731

View File

@@ -1,48 +1,81 @@
import os
import re
import sys
import sysconfig
import platform
import subprocess
import distutils import distutils
import glob
import tempfile
import shutil
from distutils.version import LooseVersion
from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext
from setuptools.command.test import test as TestCommand
import distutils.spawn import distutils.spawn
import urllib.request import os
import platform
import re
import shutil
import subprocess
import sys
import tarfile import tarfile
import torch import urllib.request
from distutils.version import LooseVersion
from typing import NamedTuple
def get_llvm(): from setuptools import Extension, setup
# tries to find system LLVM from setuptools.command.build_ext import build_ext
versions = ['-13.0', '-13', '-13-64']
# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
def check_env_flag(name: str, default: str = "") -> bool:
return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
def get_build_type():
if check_env_flag("DEBUG"):
return "Debug"
elif check_env_flag("REL_WITH_DEB_INFO"):
return "RelWithDebInfo"
else:
return "Release"
def use_system_llvm():
if platform.system() == "Windows":
return True
versions = ['-11.0', '-11', '-11-64']
supported = ['llvm-config{v}'.format(v=v) for v in versions] supported = ['llvm-config{v}'.format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported] paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None] return any(p is not None for p in paths)
if paths:
return '', ''
if platform.system() == "Windows": def get_thirdparty_packages(triton_cache_path):
return '', '' class Package(NamedTuple):
# download if nothing is installed package: str
name = 'clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04' name: str
dir = '/tmp' url: str
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name) test_file: str
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name) include_flag: str
if not os.path.exists(llvm_library_dir): lib_flag: str
packages = [
Package("pybind11", "pybind11-2.10.0", "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz", "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "")
]
if not use_system_llvm():
# donwload LLVM if no suitable system LLVM is installed
packages.append(
Package("llvm", "clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04", "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04.tar.xz", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
)
thirdparty_cmake_args = []
for p in packages:
package_root_dir = os.path.join(triton_cache_path, p.package)
package_dir = os.path.join(package_root_dir, p.name)
test_file_path = os.path.join(package_dir, p.test_file)
if not os.path.exists(test_file_path):
try: try:
shutil.rmtree(os.path.join(dir, name)) shutil.rmtree(package_root_dir)
except: except Exception:
pass pass
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-13.0.0/{name}.tar.xz".format(name=name) os.makedirs(package_root_dir, exist_ok=True)
print('downloading and extracting ' + url + '...') print('downloading and extracting {} ...'.format(p.url))
ftpstream = urllib.request.urlopen(url) ftpstream = urllib.request.urlopen(p.url)
file = tarfile.open(fileobj=ftpstream, mode="r|xz") file = tarfile.open(fileobj=ftpstream, mode="r|*")
file.extractall(path=dir) file.extractall(path=package_root_dir)
return llvm_include_dir, llvm_library_dir if p.include_flag:
thirdparty_cmake_args.append("-D{}={}/include".format(p.include_flag, package_dir))
if p.lib_flag:
thirdparty_cmake_args.append("-D{}={}/lib".format(p.lib_flag, package_dir))
return thirdparty_cmake_args
class CMakeExtension(Extension): class CMakeExtension(Extension):
@@ -80,34 +113,24 @@ class CMakeBuild(build_ext):
self.build_extension(ext) self.build_extension(ext)
def build_extension(self, ext): def build_extension(self, ext):
llvm_include_dir, llvm_library_dir = get_llvm() triton_cache_path = os.path.join(os.environ["HOME"], ".triton")
self.debug = True thirdparty_cmake_args = get_thirdparty_packages(triton_cache_path)
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# create build directories # create build directories
build_suffix = 'debug' if self.debug else 'release'
llvm_build_dir = os.path.join(tempfile.gettempdir(), "llvm-" + build_suffix)
if not os.path.exists(self.build_temp): if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp) os.makedirs(self.build_temp)
if not os.path.exists(llvm_build_dir):
os.makedirs(llvm_build_dir)
# python directories # python directories
if torch.version.hip is not None: python_include_dirs = [distutils.sysconfig.get_python_inc()]
python_include_dirs= [distutils.sysconfig.get_python_inc()] +['/opt/rocm/include']
else:
python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
cmake_args = [ cmake_args = [
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DBUILD_TUTORIALS=OFF", "-DBUILD_TUTORIALS=OFF",
"-DBUILD_PYTHON_MODULE=ON", "-DBUILD_PYTHON_MODULE=ON",
"-DLLVM_INCLUDE_DIRS=" + llvm_include_dir, # '-DPYTHON_EXECUTABLE=' + sys.executable,
"-DLLVM_LIBRARY_DIR=" + llvm_library_dir, # '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
#'-DPYTHON_EXECUTABLE=' + sys.executable,
#'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
"-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir,
"-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs) "-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
] ] + thirdparty_cmake_args
# configuration # configuration
cfg = "Debug" if self.debug else "Release" cfg = get_build_type()
build_args = ["--config", cfg] build_args = ["--config", cfg]
if platform.system() == "Windows": if platform.system() == "Windows":
@@ -130,14 +153,22 @@ class CMakeBuild(build_ext):
setup( setup(
name="triton", name="triton",
version="1.1.2", version="2.0.0",
author="Philippe Tillet", author="Philippe Tillet",
author_email="phil@openai.com", author_email="phil@openai.com",
description="A language and compiler for custom Deep Learning operations", description="A language and compiler for custom Deep Learning operations",
long_description="", long_description="",
packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"], packages=["triton", "triton/_C", "triton/language", "triton/runtime", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
install_requires=["torch", "filelock"], install_requires=[
package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]}, "cmake",
"filelock",
"torch",
],
package_data={
"triton/ops": ["*.c"],
"triton/ops/blocksparse": ["*.c"],
"triton/language": ["*.bc"],
},
include_package_data=True, include_package_data=True,
ext_modules=[CMakeExtension("triton", "triton/_C/")], ext_modules=[CMakeExtension("triton", "triton/_C/")],
cmdclass={"build_ext": CMakeBuild}, cmdclass={"build_ext": CMakeBuild},
@@ -152,4 +183,19 @@ setup(
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.6",
], ],
extras_require={
"tests": [
"autopep8",
"flake8",
"isort",
"numpy",
"pytest",
"scipy>=1.7.1",
],
"tutorials": [
"matplotlib",
"pandas",
"tabulate",
],
},
) )

View File

@@ -105,7 +105,7 @@ ir::value *cast(ir::value *input, type_code _dtype, ir::builder *builder) {
other = builder->create_splat(other, src_ty->get_block_shapes()); other = builder->create_splat(other, src_ty->get_block_shapes());
return builder->create_icmpNE(input, other); return builder->create_icmpNE(input, other);
} }
throw_not_implemented("cast"); throw_not_implemented("cast from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
} }
/*---------------------------------------------- /*----------------------------------------------
@@ -253,7 +253,7 @@ ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
std::string where_docstr = R"pbdoc( std::string where_docstr = R"pbdoc(
Returns a block of elements from either `x` or `y`, depending on `condition`. Returns a block of elements from either `x` or `y`, depending on `condition`.
Note that `x` and `y` are always evaluated regardless of the value of `condition`. Note that `x` and `y` are always evaluated regardless of the value of `condition`.
If you want to avoid unintented memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead. If you want to avoid unintended memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead.
:param condition: When True (nonzero), yield x, otherwise yield y. :param condition: When True (nonzero), yield x, otherwise yield y.
:type condition: Block of triton.bool :type condition: Block of triton.bool
@@ -353,9 +353,6 @@ ir::value *sqrt(ir::value *input, ir::builder *builder) {
return builder->create_sqrt(input); return builder->create_sqrt(input);
}; };
/*----------------------------------------------
definition of triton.min
----------------------------------------------*/
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name, ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) { ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
ir::type *scalar_ty = input->get_type()->get_scalar_ty(); ir::type *scalar_ty = input->get_type()->get_scalar_ty();
@@ -367,6 +364,9 @@ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder
throw_not_int_or_float(name); throw_not_int_or_float(name);
} }
/*----------------------------------------------
definition of triton.min
----------------------------------------------*/
std::string min_docstr = R"pbdoc( std::string min_docstr = R"pbdoc(
Returns the minimum value of `input`. Returns the minimum value of `input`.
)pbdoc"; )pbdoc";
@@ -374,6 +374,16 @@ ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN); return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
}; };
/*----------------------------------------------
definition of triton.arg_min
----------------------------------------------*/
std::string min_docstr = R"pbdoc(
Returns the minimum value's index of `input`.
)pbdoc";
ir::value *argmin(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "argmin", ir::reduce_inst::ARGFMIN, ir::reduce_inst::ARGMIN);
};
/*---------------------------------------------- /*----------------------------------------------
definition of triton.max definition of triton.max
----------------------------------------------*/ ----------------------------------------------*/
@@ -384,6 +394,16 @@ ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX); return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
}; };
/*----------------------------------------------
definition of triton.arg_max
----------------------------------------------*/
std::string max_docstr = R"pbdoc(
Returns the maximum value's index of `input`.
)pbdoc";
ir::value *argmax(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "argmax", ir::reduce_inst::ARGFMAX, ir::reduce_inst::ARGMAX);
};
/*---------------------------------------------- /*----------------------------------------------
definition of triton.sum definition of triton.sum
----------------------------------------------*/ ----------------------------------------------*/

View File

@@ -1,493 +0,0 @@
/*
pybind11/attr.h: Infrastructure for processing custom
type and function attributes
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "cast.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
/// \addtogroup annotations
/// @{
/// Annotation for methods
struct is_method { handle class_; is_method(const handle &c) : class_(c) { } };
/// Annotation for operators
struct is_operator { };
/// Annotation for parent scope
struct scope { handle value; scope(const handle &s) : value(s) { } };
/// Annotation for documentation
struct doc { const char *value; doc(const char *value) : value(value) { } };
/// Annotation for function names
struct name { const char *value; name(const char *value) : value(value) { } };
/// Annotation indicating that a function is an overload associated with a given "sibling"
struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } };
/// Annotation indicating that a class derives from another given type
template <typename T> struct base {
PYBIND11_DEPRECATED("base<T>() was deprecated in favor of specifying 'T' as a template argument to class_")
base() { }
};
/// Keep patient alive while nurse lives
template <size_t Nurse, size_t Patient> struct keep_alive { };
/// Annotation indicating that a class is involved in a multiple inheritance relationship
struct multiple_inheritance { };
/// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class
struct dynamic_attr { };
/// Annotation which enables the buffer protocol for a type
struct buffer_protocol { };
/// Annotation which requests that a special metaclass is created for a type
struct metaclass {
handle value;
PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.")
metaclass() {}
/// Override pybind11's default metaclass
explicit metaclass(handle value) : value(value) { }
};
/// Annotation that marks a class as local to the module:
struct module_local { const bool value; constexpr module_local(bool v = true) : value(v) { } };
/// Annotation to mark enums as an arithmetic type
struct arithmetic { };
/** \rst
A call policy which places one or more guard variables (``Ts...``) around the function call.
For example, this definition:
.. code-block:: cpp
m.def("foo", foo, py::call_guard<T>());
is equivalent to the following pseudocode:
.. code-block:: cpp
m.def("foo", [](args...) {
T scope_guard;
return foo(args...); // forwarded arguments
});
\endrst */
template <typename... Ts> struct call_guard;
template <> struct call_guard<> { using type = detail::void_type; };
template <typename T>
struct call_guard<T> {
static_assert(std::is_default_constructible<T>::value,
"The guard type must be default constructible");
using type = T;
};
template <typename T, typename... Ts>
struct call_guard<T, Ts...> {
struct type {
T guard{}; // Compose multiple guard types with left-to-right default-constructor order
typename call_guard<Ts...>::type next{};
};
};
/// @} annotations
NAMESPACE_BEGIN(detail)
/* Forward declarations */
enum op_id : int;
enum op_type : int;
struct undefined_t;
template <op_id id, op_type ot, typename L = undefined_t, typename R = undefined_t> struct op_;
inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret);
/// Internal data structure which holds metadata about a keyword argument
struct argument_record {
const char *name; ///< Argument name
const char *descr; ///< Human-readable version of the argument value
handle value; ///< Associated Python object
bool convert : 1; ///< True if the argument is allowed to convert when loading
bool none : 1; ///< True if None is allowed when loading
argument_record(const char *name, const char *descr, handle value, bool convert, bool none)
: name(name), descr(descr), value(value), convert(convert), none(none) { }
};
/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.)
struct function_record {
function_record()
: is_constructor(false), is_new_style_constructor(false), is_stateless(false),
is_operator(false), has_args(false), has_kwargs(false), is_method(false) { }
/// Function name
char *name = nullptr; /* why no C++ strings? They generate heavier code.. */
// User-specified documentation string
char *doc = nullptr;
/// Human-readable version of the function signature
char *signature = nullptr;
/// List of registered keyword arguments
std::vector<argument_record> args;
/// Pointer to lambda function which converts arguments and performs the actual call
handle (*impl) (function_call &) = nullptr;
/// Storage for the wrapped function pointer and captured data, if any
void *data[3] = { };
/// Pointer to custom destructor for 'data' (if needed)
void (*free_data) (function_record *ptr) = nullptr;
/// Return value policy associated with this function
return_value_policy policy = return_value_policy::automatic;
/// True if name == '__init__'
bool is_constructor : 1;
/// True if this is a new-style `__init__` defined in `detail/init.h`
bool is_new_style_constructor : 1;
/// True if this is a stateless function pointer
bool is_stateless : 1;
/// True if this is an operator (__add__), etc.
bool is_operator : 1;
/// True if the function has a '*args' argument
bool has_args : 1;
/// True if the function has a '**kwargs' argument
bool has_kwargs : 1;
/// True if this is a method
bool is_method : 1;
/// Number of arguments (including py::args and/or py::kwargs, if present)
std::uint16_t nargs;
/// Python method object
PyMethodDef *def = nullptr;
/// Python handle to the parent scope (a class or a module)
handle scope;
/// Python handle to the sibling function representing an overload chain
handle sibling;
/// Pointer to next overload
function_record *next = nullptr;
};
/// Special data structure which (temporarily) holds metadata about a bound class
struct type_record {
PYBIND11_NOINLINE type_record()
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false),
default_holder(true), module_local(false) { }
/// Handle to the parent scope
handle scope;
/// Name of the class
const char *name = nullptr;
// Pointer to RTTI type_info data structure
const std::type_info *type = nullptr;
/// How large is the underlying C++ type?
size_t type_size = 0;
/// What is the alignment of the underlying C++ type?
size_t type_align = 0;
/// How large is the type's holder?
size_t holder_size = 0;
/// The global operator new can be overridden with a class-specific variant
void *(*operator_new)(size_t) = nullptr;
/// Function pointer to class_<..>::init_instance
void (*init_instance)(instance *, const void *) = nullptr;
/// Function pointer to class_<..>::dealloc
void (*dealloc)(detail::value_and_holder &) = nullptr;
/// List of base classes of the newly created type
list bases;
/// Optional docstring
const char *doc = nullptr;
/// Custom metaclass (optional)
handle metaclass;
/// Multiple inheritance marker
bool multiple_inheritance : 1;
/// Does the class manage a __dict__?
bool dynamic_attr : 1;
/// Does the class implement the buffer protocol?
bool buffer_protocol : 1;
/// Is the default (unique_ptr) holder type used?
bool default_holder : 1;
/// Is the class definition local to the module shared object?
bool module_local : 1;
PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) {
auto base_info = detail::get_type_info(base, false);
if (!base_info) {
std::string tname(base.name());
detail::clean_type_id(tname);
pybind11_fail("generic_type: type \"" + std::string(name) +
"\" referenced unknown base type \"" + tname + "\"");
}
if (default_holder != base_info->default_holder) {
std::string tname(base.name());
detail::clean_type_id(tname);
pybind11_fail("generic_type: type \"" + std::string(name) + "\" " +
(default_holder ? "does not have" : "has") +
" a non-default holder type while its base \"" + tname + "\" " +
(base_info->default_holder ? "does not" : "does"));
}
bases.append((PyObject *) base_info->type);
if (base_info->type->tp_dictoffset != 0)
dynamic_attr = true;
if (caster)
base_info->implicit_casts.emplace_back(type, caster);
}
};
inline function_call::function_call(const function_record &f, handle p) :
func(f), parent(p) {
args.reserve(f.nargs);
args_convert.reserve(f.nargs);
}
/// Tag for a new-style `__init__` defined in `detail/init.h`
struct is_new_style_constructor { };
/**
* Partial template specializations to process custom attributes provided to
* cpp_function_ and class_. These are either used to initialize the respective
* fields in the type_record and function_record data structures or executed at
* runtime to deal with custom call policies (e.g. keep_alive).
*/
template <typename T, typename SFINAE = void> struct process_attribute;
template <typename T> struct process_attribute_default {
/// Default implementation: do nothing
static void init(const T &, function_record *) { }
static void init(const T &, type_record *) { }
static void precall(function_call &) { }
static void postcall(function_call &, handle) { }
};
/// Process an attribute specifying the function's name
template <> struct process_attribute<name> : process_attribute_default<name> {
static void init(const name &n, function_record *r) { r->name = const_cast<char *>(n.value); }
};
/// Process an attribute specifying the function's docstring
template <> struct process_attribute<doc> : process_attribute_default<doc> {
static void init(const doc &n, function_record *r) { r->doc = const_cast<char *>(n.value); }
};
/// Process an attribute specifying the function's docstring (provided as a C-style string)
template <> struct process_attribute<const char *> : process_attribute_default<const char *> {
static void init(const char *d, function_record *r) { r->doc = const_cast<char *>(d); }
static void init(const char *d, type_record *r) { r->doc = const_cast<char *>(d); }
};
template <> struct process_attribute<char *> : process_attribute<const char *> { };
/// Process an attribute indicating the function's return value policy
template <> struct process_attribute<return_value_policy> : process_attribute_default<return_value_policy> {
static void init(const return_value_policy &p, function_record *r) { r->policy = p; }
};
/// Process an attribute which indicates that this is an overloaded function associated with a given sibling
template <> struct process_attribute<sibling> : process_attribute_default<sibling> {
static void init(const sibling &s, function_record *r) { r->sibling = s.value; }
};
/// Process an attribute which indicates that this function is a method
template <> struct process_attribute<is_method> : process_attribute_default<is_method> {
static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; }
};
/// Process an attribute which indicates the parent scope of a method
template <> struct process_attribute<scope> : process_attribute_default<scope> {
static void init(const scope &s, function_record *r) { r->scope = s.value; }
};
/// Process an attribute which indicates that this function is an operator
template <> struct process_attribute<is_operator> : process_attribute_default<is_operator> {
static void init(const is_operator &, function_record *r) { r->is_operator = true; }
};
template <> struct process_attribute<is_new_style_constructor> : process_attribute_default<is_new_style_constructor> {
static void init(const is_new_style_constructor &, function_record *r) { r->is_new_style_constructor = true; }
};
/// Process a keyword argument attribute (*without* a default value)
template <> struct process_attribute<arg> : process_attribute_default<arg> {
static void init(const arg &a, function_record *r) {
if (r->is_method && r->args.empty())
r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/);
r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none);
}
};
/// Process a keyword argument attribute (*with* a default value)
template <> struct process_attribute<arg_v> : process_attribute_default<arg_v> {
static void init(const arg_v &a, function_record *r) {
if (r->is_method && r->args.empty())
r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/);
if (!a.value) {
#if !defined(NDEBUG)
std::string descr("'");
if (a.name) descr += std::string(a.name) + ": ";
descr += a.type + "'";
if (r->is_method) {
if (r->name)
descr += " in method '" + (std::string) str(r->scope) + "." + (std::string) r->name + "'";
else
descr += " in method of '" + (std::string) str(r->scope) + "'";
} else if (r->name) {
descr += " in function '" + (std::string) r->name + "'";
}
pybind11_fail("arg(): could not convert default argument "
+ descr + " into a Python object (type not registered yet?)");
#else
pybind11_fail("arg(): could not convert default argument "
"into a Python object (type not registered yet?). "
"Compile in debug mode for more information.");
#endif
}
r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none);
}
};
/// Process a parent class attribute. Single inheritance only (class_ itself already guarantees that)
template <typename T>
struct process_attribute<T, enable_if_t<is_pyobject<T>::value>> : process_attribute_default<handle> {
static void init(const handle &h, type_record *r) { r->bases.append(h); }
};
/// Process a parent class attribute (deprecated, does not support multiple inheritance)
template <typename T>
struct process_attribute<base<T>> : process_attribute_default<base<T>> {
static void init(const base<T> &, type_record *r) { r->add_base(typeid(T), nullptr); }
};
/// Process a multiple inheritance attribute
template <>
struct process_attribute<multiple_inheritance> : process_attribute_default<multiple_inheritance> {
static void init(const multiple_inheritance &, type_record *r) { r->multiple_inheritance = true; }
};
template <>
struct process_attribute<dynamic_attr> : process_attribute_default<dynamic_attr> {
static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; }
};
template <>
struct process_attribute<buffer_protocol> : process_attribute_default<buffer_protocol> {
static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; }
};
template <>
struct process_attribute<metaclass> : process_attribute_default<metaclass> {
static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; }
};
template <>
struct process_attribute<module_local> : process_attribute_default<module_local> {
static void init(const module_local &l, type_record *r) { r->module_local = l.value; }
};
/// Process an 'arithmetic' attribute for enums (does nothing here)
template <>
struct process_attribute<arithmetic> : process_attribute_default<arithmetic> {};
template <typename... Ts>
struct process_attribute<call_guard<Ts...>> : process_attribute_default<call_guard<Ts...>> { };
/**
* Process a keep_alive call policy -- invokes keep_alive_impl during the
* pre-call handler if both Nurse, Patient != 0 and use the post-call handler
* otherwise
*/
template <size_t Nurse, size_t Patient> struct process_attribute<keep_alive<Nurse, Patient>> : public process_attribute_default<keep_alive<Nurse, Patient>> {
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
static void precall(function_call &call) { keep_alive_impl(Nurse, Patient, call, handle()); }
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
static void postcall(function_call &, handle) { }
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
static void precall(function_call &) { }
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
static void postcall(function_call &call, handle ret) { keep_alive_impl(Nurse, Patient, call, ret); }
};
/// Recursively iterate over variadic template arguments
template <typename... Args> struct process_attributes {
static void init(const Args&... args, function_record *r) {
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
ignore_unused(unused);
}
static void init(const Args&... args, type_record *r) {
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
ignore_unused(unused);
}
static void precall(function_call &call) {
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::precall(call), 0) ... };
ignore_unused(unused);
}
static void postcall(function_call &call, handle fn_ret) {
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::postcall(call, fn_ret), 0) ... };
ignore_unused(unused);
}
};
template <typename T>
using is_call_guard = is_instantiation<call_guard, T>;
/// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found)
template <typename... Extra>
using extract_guard_t = typename exactly_one_t<is_call_guard, call_guard<>, Extra...>::type;
/// Check the number of named arguments at compile time
template <typename... Extra,
size_t named = constexpr_sum(std::is_base_of<arg, Extra>::value...),
size_t self = constexpr_sum(std::is_same<is_method, Extra>::value...)>
constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) {
return named == 0 || (self + named + has_args + has_kwargs) == nargs;
}
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,108 +0,0 @@
/*
pybind11/buffer_info.h: Python buffer object interface
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "detail/common.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
/// Information record describing a Python buffer object
struct buffer_info {
void *ptr = nullptr; // Pointer to the underlying storage
ssize_t itemsize = 0; // Size of individual items in bytes
ssize_t size = 0; // Total number of entries
std::string format; // For homogeneous buffers, this should be set to format_descriptor<T>::format()
ssize_t ndim = 0; // Number of dimensions
std::vector<ssize_t> shape; // Shape of the tensor (1 entry per dimension)
std::vector<ssize_t> strides; // Number of entries between adjacent entries (for each per dimension)
buffer_info() { }
buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
: ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
shape(std::move(shape_in)), strides(std::move(strides_in)) {
if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size())
pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
for (size_t i = 0; i < (size_t) ndim; ++i)
size *= shape[i];
}
template <typename T>
buffer_info(T *ptr, detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
: buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor<T>::format(), static_cast<ssize_t>(shape_in->size()), std::move(shape_in), std::move(strides_in)) { }
buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size)
: buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { }
template <typename T>
buffer_info(T *ptr, ssize_t size)
: buffer_info(ptr, sizeof(T), format_descriptor<T>::format(), size) { }
explicit buffer_info(Py_buffer *view, bool ownview = true)
: buffer_info(view->buf, view->itemsize, view->format, view->ndim,
{view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) {
this->view = view;
this->ownview = ownview;
}
buffer_info(const buffer_info &) = delete;
buffer_info& operator=(const buffer_info &) = delete;
buffer_info(buffer_info &&other) {
(*this) = std::move(other);
}
buffer_info& operator=(buffer_info &&rhs) {
ptr = rhs.ptr;
itemsize = rhs.itemsize;
size = rhs.size;
format = std::move(rhs.format);
ndim = rhs.ndim;
shape = std::move(rhs.shape);
strides = std::move(rhs.strides);
std::swap(view, rhs.view);
std::swap(ownview, rhs.ownview);
return *this;
}
~buffer_info() {
if (view && ownview) { PyBuffer_Release(view); delete view; }
}
private:
struct private_ctr_tag { };
buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
detail::any_container<ssize_t> &&shape_in, detail::any_container<ssize_t> &&strides_in)
: buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { }
Py_buffer *view = nullptr;
bool ownview = false;
};
NAMESPACE_BEGIN(detail)
template <typename T, typename SFINAE = void> struct compare_buffer_info {
static bool compare(const buffer_info& b) {
return b.format == format_descriptor<T>::format() && b.itemsize == (ssize_t) sizeof(T);
}
};
template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
static bool compare(const buffer_info& b) {
return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
}
};
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

File diff suppressed because it is too large Load Diff

View File

@@ -1,162 +0,0 @@
/*
pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime
Copyright (c) 2016 Trent Houliston <trent@houliston.me> and
Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#include <cmath>
#include <ctime>
#include <chrono>
#include <datetime.h>
// Backport the PyDateTime_DELTA functions from Python3.3 if required
#ifndef PyDateTime_DELTA_GET_DAYS
#define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days)
#endif
#ifndef PyDateTime_DELTA_GET_SECONDS
#define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds)
#endif
#ifndef PyDateTime_DELTA_GET_MICROSECONDS
#define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds)
#endif
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
template <typename type> class duration_caster {
public:
typedef typename type::rep rep;
typedef typename type::period period;
typedef std::chrono::duration<uint_fast32_t, std::ratio<86400>> days;
bool load(handle src, bool) {
using namespace std::chrono;
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
if (!src) return false;
// If invoked with datetime.delta object
if (PyDelta_Check(src.ptr())) {
value = type(duration_cast<duration<rep, period>>(
days(PyDateTime_DELTA_GET_DAYS(src.ptr()))
+ seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr()))
+ microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr()))));
return true;
}
// If invoked with a float we assume it is seconds and convert
else if (PyFloat_Check(src.ptr())) {
value = type(duration_cast<duration<rep, period>>(duration<double>(PyFloat_AsDouble(src.ptr()))));
return true;
}
else return false;
}
// If this is a duration just return it back
static const std::chrono::duration<rep, period>& get_duration(const std::chrono::duration<rep, period> &src) {
return src;
}
// If this is a time_point get the time_since_epoch
template <typename Clock> static std::chrono::duration<rep, period> get_duration(const std::chrono::time_point<Clock, std::chrono::duration<rep, period>> &src) {
return src.time_since_epoch();
}
static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) {
using namespace std::chrono;
// Use overloaded function to get our duration from our source
// Works out if it is a duration or time_point and get the duration
auto d = get_duration(src);
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
// Declare these special duration types so the conversions happen with the correct primitive types (int)
using dd_t = duration<int, std::ratio<86400>>;
using ss_t = duration<int, std::ratio<1>>;
using us_t = duration<int, std::micro>;
auto dd = duration_cast<dd_t>(d);
auto subd = d - dd;
auto ss = duration_cast<ss_t>(subd);
auto us = duration_cast<us_t>(subd - ss);
return PyDelta_FromDSU(dd.count(), ss.count(), us.count());
}
PYBIND11_TYPE_CASTER(type, _("datetime.timedelta"));
};
// This is for casting times on the system clock into datetime.datetime instances
template <typename Duration> class type_caster<std::chrono::time_point<std::chrono::system_clock, Duration>> {
public:
typedef std::chrono::time_point<std::chrono::system_clock, Duration> type;
bool load(handle src, bool) {
using namespace std::chrono;
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
if (!src) return false;
if (PyDateTime_Check(src.ptr())) {
std::tm cal;
cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr());
cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr());
cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr());
cal.tm_mday = PyDateTime_GET_DAY(src.ptr());
cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1;
cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900;
cal.tm_isdst = -1;
value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr()));
return true;
}
else return false;
}
static handle cast(const std::chrono::time_point<std::chrono::system_clock, Duration> &src, return_value_policy /* policy */, handle /* parent */) {
using namespace std::chrono;
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
std::time_t tt = system_clock::to_time_t(src);
// this function uses static memory so it's best to copy it out asap just in case
// otherwise other code that is using localtime may break this (not just python code)
std::tm localtime = *std::localtime(&tt);
// Declare these special duration types so the conversions happen with the correct primitive types (int)
using us_t = duration<int, std::micro>;
return PyDateTime_FromDateAndTime(localtime.tm_year + 1900,
localtime.tm_mon + 1,
localtime.tm_mday,
localtime.tm_hour,
localtime.tm_min,
localtime.tm_sec,
(duration_cast<us_t>(src.time_since_epoch() % seconds(1))).count());
}
PYBIND11_TYPE_CASTER(type, _("datetime.datetime"));
};
// Other clocks that are not the system clock are not measured as datetime.datetime objects
// since they are not measured on calendar time. So instead we just make them timedeltas
// Or if they have passed us a time as a float we convert that
template <typename Clock, typename Duration> class type_caster<std::chrono::time_point<Clock, Duration>>
: public duration_caster<std::chrono::time_point<Clock, Duration>> {
};
template <typename Rep, typename Period> class type_caster<std::chrono::duration<Rep, Period>>
: public duration_caster<std::chrono::duration<Rep, Period>> {
};
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,2 +0,0 @@
#include "detail/common.h"
#warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'."

View File

@@ -1,65 +0,0 @@
/*
pybind11/complex.h: Complex number support
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#include <complex>
/// glibc defines I as a macro which breaks things, e.g., boost template names
#ifdef I
# undef I
#endif
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
template <typename T> struct format_descriptor<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
static constexpr const char c = format_descriptor<T>::c;
static constexpr const char value[3] = { 'Z', c, '\0' };
static std::string format() { return std::string(value); }
};
#ifndef PYBIND11_CPP17
template <typename T> constexpr const char format_descriptor<
std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
#endif
NAMESPACE_BEGIN(detail)
template <typename T> struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
static constexpr bool value = true;
static constexpr int index = is_fmt_numeric<T>::index + 3;
};
template <typename T> class type_caster<std::complex<T>> {
public:
bool load(handle src, bool convert) {
if (!src)
return false;
if (!convert && !PyComplex_Check(src.ptr()))
return false;
Py_complex result = PyComplex_AsCComplex(src.ptr());
if (result.real == -1.0 && PyErr_Occurred()) {
PyErr_Clear();
return false;
}
value = std::complex<T>((T) result.real, (T) result.imag);
return true;
}
static handle cast(const std::complex<T> &src, return_value_policy /* policy */, handle /* parent */) {
return PyComplex_FromDoubles((double) src.real(), (double) src.imag());
}
PYBIND11_TYPE_CASTER(std::complex<T>, _("complex"));
};
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,623 +0,0 @@
/*
pybind11/detail/class.h: Python C API implementation details for py::class_
Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "../attr.h"
#include "../options.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
#if PY_VERSION_HEX >= 0x03030000
# define PYBIND11_BUILTIN_QUALNAME
# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj)
#else
// In pre-3.3 Python, we still set __qualname__ so that we can produce reliable function type
// signatures; in 3.3+ this macro expands to nothing:
# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) setattr((PyObject *) obj, "__qualname__", nameobj)
#endif
inline PyTypeObject *type_incref(PyTypeObject *type) {
Py_INCREF(type);
return type;
}
#if !defined(PYPY_VERSION)
/// `pybind11_static_property.__get__()`: Always pass the class instead of the instance.
extern "C" inline PyObject *pybind11_static_get(PyObject *self, PyObject * /*ob*/, PyObject *cls) {
return PyProperty_Type.tp_descr_get(self, cls, cls);
}
/// `pybind11_static_property.__set__()`: Just like the above `__get__()`.
extern "C" inline int pybind11_static_set(PyObject *self, PyObject *obj, PyObject *value) {
PyObject *cls = PyType_Check(obj) ? obj : (PyObject *) Py_TYPE(obj);
return PyProperty_Type.tp_descr_set(self, cls, value);
}
/** A `static_property` is the same as a `property` but the `__get__()` and `__set__()`
methods are modified to always use the object type instead of a concrete instance.
Return value: New reference. */
inline PyTypeObject *make_static_property_type() {
constexpr auto *name = "pybind11_static_property";
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
/* Danger zone: from now (and until PyType_Ready), make sure to
issue no Python C API calls which could potentially invoke the
garbage collector (the GC will call type_traverse(), which will in
turn find the newly constructed type in an invalid state) */
auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0);
if (!heap_type)
pybind11_fail("make_static_property_type(): error allocating type!");
heap_type->ht_name = name_obj.inc_ref().ptr();
#ifdef PYBIND11_BUILTIN_QUALNAME
heap_type->ht_qualname = name_obj.inc_ref().ptr();
#endif
auto type = &heap_type->ht_type;
type->tp_name = name;
type->tp_base = type_incref(&PyProperty_Type);
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
type->tp_descr_get = pybind11_static_get;
type->tp_descr_set = pybind11_static_set;
if (PyType_Ready(type) < 0)
pybind11_fail("make_static_property_type(): failure in PyType_Ready()!");
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
return type;
}
#else // PYPY
/** PyPy has some issues with the above C API, so we evaluate Python code instead.
This function will only be called once so performance isn't really a concern.
Return value: New reference. */
inline PyTypeObject *make_static_property_type() {
auto d = dict();
PyObject *result = PyRun_String(R"(\
class pybind11_static_property(property):
def __get__(self, obj, cls):
return property.__get__(self, cls, cls)
def __set__(self, obj, value):
cls = obj if isinstance(obj, type) else type(obj)
property.__set__(self, cls, value)
)", Py_file_input, d.ptr(), d.ptr()
);
if (result == nullptr)
throw error_already_set();
Py_DECREF(result);
return (PyTypeObject *) d["pybind11_static_property"].cast<object>().release().ptr();
}
#endif // PYPY
/** Types with static properties need to handle `Type.static_prop = x` in a specific way.
By default, Python replaces the `static_property` itself, but for wrapped C++ types
we need to call `static_property.__set__()` in order to propagate the new value to
the underlying C++ data structure. */
extern "C" inline int pybind11_meta_setattro(PyObject* obj, PyObject* name, PyObject* value) {
// Use `_PyType_Lookup()` instead of `PyObject_GetAttr()` in order to get the raw
// descriptor (`property`) instead of calling `tp_descr_get` (`property.__get__()`).
PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name);
// The following assignment combinations are possible:
// 1. `Type.static_prop = value` --> descr_set: `Type.static_prop.__set__(value)`
// 2. `Type.static_prop = other_static_prop` --> setattro: replace existing `static_prop`
// 3. `Type.regular_attribute = value` --> setattro: regular attribute assignment
const auto static_prop = (PyObject *) get_internals().static_property_type;
const auto call_descr_set = descr && PyObject_IsInstance(descr, static_prop)
&& !PyObject_IsInstance(value, static_prop);
if (call_descr_set) {
// Call `static_property.__set__()` instead of replacing the `static_property`.
#if !defined(PYPY_VERSION)
return Py_TYPE(descr)->tp_descr_set(descr, obj, value);
#else
if (PyObject *result = PyObject_CallMethod(descr, "__set__", "OO", obj, value)) {
Py_DECREF(result);
return 0;
} else {
return -1;
}
#endif
} else {
// Replace existing attribute.
return PyType_Type.tp_setattro(obj, name, value);
}
}
#if PY_MAJOR_VERSION >= 3
/**
* Python 3's PyInstanceMethod_Type hides itself via its tp_descr_get, which prevents aliasing
* methods via cls.attr("m2") = cls.attr("m1"): instead the tp_descr_get returns a plain function,
* when called on a class, or a PyMethod, when called on an instance. Override that behaviour here
* to do a special case bypass for PyInstanceMethod_Types.
*/
extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name) {
PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name);
if (descr && PyInstanceMethod_Check(descr)) {
Py_INCREF(descr);
return descr;
}
else {
return PyType_Type.tp_getattro(obj, name);
}
}
#endif
/** This metaclass is assigned by default to all pybind11 types and is required in order
for static properties to function correctly. Users may override this using `py::metaclass`.
Return value: New reference. */
inline PyTypeObject* make_default_metaclass() {
constexpr auto *name = "pybind11_type";
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
/* Danger zone: from now (and until PyType_Ready), make sure to
issue no Python C API calls which could potentially invoke the
garbage collector (the GC will call type_traverse(), which will in
turn find the newly constructed type in an invalid state) */
auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0);
if (!heap_type)
pybind11_fail("make_default_metaclass(): error allocating metaclass!");
heap_type->ht_name = name_obj.inc_ref().ptr();
#ifdef PYBIND11_BUILTIN_QUALNAME
heap_type->ht_qualname = name_obj.inc_ref().ptr();
#endif
auto type = &heap_type->ht_type;
type->tp_name = name;
type->tp_base = type_incref(&PyType_Type);
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
type->tp_setattro = pybind11_meta_setattro;
#if PY_MAJOR_VERSION >= 3
type->tp_getattro = pybind11_meta_getattro;
#endif
if (PyType_Ready(type) < 0)
pybind11_fail("make_default_metaclass(): failure in PyType_Ready()!");
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
return type;
}
/// For multiple inheritance types we need to recursively register/deregister base pointers for any
/// base classes with pointers that are difference from the instance value pointer so that we can
/// correctly recognize an offset base class pointer. This calls a function with any offset base ptrs.
inline void traverse_offset_bases(void *valueptr, const detail::type_info *tinfo, instance *self,
bool (*f)(void * /*parentptr*/, instance * /*self*/)) {
for (handle h : reinterpret_borrow<tuple>(tinfo->type->tp_bases)) {
if (auto parent_tinfo = get_type_info((PyTypeObject *) h.ptr())) {
for (auto &c : parent_tinfo->implicit_casts) {
if (c.first == tinfo->cpptype) {
auto *parentptr = c.second(valueptr);
if (parentptr != valueptr)
f(parentptr, self);
traverse_offset_bases(parentptr, parent_tinfo, self, f);
break;
}
}
}
}
}
inline bool register_instance_impl(void *ptr, instance *self) {
get_internals().registered_instances.emplace(ptr, self);
return true; // unused, but gives the same signature as the deregister func
}
inline bool deregister_instance_impl(void *ptr, instance *self) {
auto &registered_instances = get_internals().registered_instances;
auto range = registered_instances.equal_range(ptr);
for (auto it = range.first; it != range.second; ++it) {
if (Py_TYPE(self) == Py_TYPE(it->second)) {
registered_instances.erase(it);
return true;
}
}
return false;
}
inline void register_instance(instance *self, void *valptr, const type_info *tinfo) {
register_instance_impl(valptr, self);
if (!tinfo->simple_ancestors)
traverse_offset_bases(valptr, tinfo, self, register_instance_impl);
}
inline bool deregister_instance(instance *self, void *valptr, const type_info *tinfo) {
bool ret = deregister_instance_impl(valptr, self);
if (!tinfo->simple_ancestors)
traverse_offset_bases(valptr, tinfo, self, deregister_instance_impl);
return ret;
}
/// Instance creation function for all pybind11 types. It allocates the internal instance layout for
/// holding C++ objects and holders. Allocation is done lazily (the first time the instance is cast
/// to a reference or pointer), and initialization is done by an `__init__` function.
inline PyObject *make_new_instance(PyTypeObject *type) {
#if defined(PYPY_VERSION)
// PyPy gets tp_basicsize wrong (issue 2482) under multiple inheritance when the first inherited
// object is a a plain Python type (i.e. not derived from an extension type). Fix it.
ssize_t instance_size = static_cast<ssize_t>(sizeof(instance));
if (type->tp_basicsize < instance_size) {
type->tp_basicsize = instance_size;
}
#endif
PyObject *self = type->tp_alloc(type, 0);
auto inst = reinterpret_cast<instance *>(self);
// Allocate the value/holder internals:
inst->allocate_layout();
inst->owned = true;
return self;
}
/// Instance creation function for all pybind11 types. It only allocates space for the
/// C++ object, but doesn't call the constructor -- an `__init__` function must do that.
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) {
return make_new_instance(type);
}
/// An `__init__` function constructs the C++ object. Users should provide at least one
/// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the
/// following default function will be used which simply throws an exception.
extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject *) {
PyTypeObject *type = Py_TYPE(self);
std::string msg;
#if defined(PYPY_VERSION)
msg += handle((PyObject *) type).attr("__module__").cast<std::string>() + ".";
#endif
msg += type->tp_name;
msg += ": No constructor defined!";
PyErr_SetString(PyExc_TypeError, msg.c_str());
return -1;
}
inline void add_patient(PyObject *nurse, PyObject *patient) {
auto &internals = get_internals();
auto instance = reinterpret_cast<detail::instance *>(nurse);
instance->has_patients = true;
Py_INCREF(patient);
internals.patients[nurse].push_back(patient);
}
inline void clear_patients(PyObject *self) {
auto instance = reinterpret_cast<detail::instance *>(self);
auto &internals = get_internals();
auto pos = internals.patients.find(self);
assert(pos != internals.patients.end());
// Clearing the patients can cause more Python code to run, which
// can invalidate the iterator. Extract the vector of patients
// from the unordered_map first.
auto patients = std::move(pos->second);
internals.patients.erase(pos);
instance->has_patients = false;
for (PyObject *&patient : patients)
Py_CLEAR(patient);
}
/// Clears all internal data from the instance and removes it from registered instances in
/// preparation for deallocation.
inline void clear_instance(PyObject *self) {
auto instance = reinterpret_cast<detail::instance *>(self);
// Deallocate any values/holders, if present:
for (auto &v_h : values_and_holders(instance)) {
if (v_h) {
// We have to deregister before we call dealloc because, for virtual MI types, we still
// need to be able to get the parent pointers.
if (v_h.instance_registered() && !deregister_instance(instance, v_h.value_ptr(), v_h.type))
pybind11_fail("pybind11_object_dealloc(): Tried to deallocate unregistered instance!");
if (instance->owned || v_h.holder_constructed())
v_h.type->dealloc(v_h);
}
}
// Deallocate the value/holder layout internals:
instance->deallocate_layout();
if (instance->weakrefs)
PyObject_ClearWeakRefs(self);
PyObject **dict_ptr = _PyObject_GetDictPtr(self);
if (dict_ptr)
Py_CLEAR(*dict_ptr);
if (instance->has_patients)
clear_patients(self);
}
/// Instance destructor function for all pybind11 types. It calls `type_info.dealloc`
/// to destroy the C++ object itself, while the rest is Python bookkeeping.
extern "C" inline void pybind11_object_dealloc(PyObject *self) {
clear_instance(self);
auto type = Py_TYPE(self);
type->tp_free(self);
// `type->tp_dealloc != pybind11_object_dealloc` means that we're being called
// as part of a derived type's dealloc, in which case we're not allowed to decref
// the type here. For cross-module compatibility, we shouldn't compare directly
// with `pybind11_object_dealloc`, but with the common one stashed in internals.
auto pybind11_object_type = (PyTypeObject *) get_internals().instance_base;
if (type->tp_dealloc == pybind11_object_type->tp_dealloc)
Py_DECREF(type);
}
/** Create the type which can be used as a common base for all classes. This is
needed in order to satisfy Python's requirements for multiple inheritance.
Return value: New reference. */
inline PyObject *make_object_base_type(PyTypeObject *metaclass) {
constexpr auto *name = "pybind11_object";
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
/* Danger zone: from now (and until PyType_Ready), make sure to
issue no Python C API calls which could potentially invoke the
garbage collector (the GC will call type_traverse(), which will in
turn find the newly constructed type in an invalid state) */
auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0);
if (!heap_type)
pybind11_fail("make_object_base_type(): error allocating type!");
heap_type->ht_name = name_obj.inc_ref().ptr();
#ifdef PYBIND11_BUILTIN_QUALNAME
heap_type->ht_qualname = name_obj.inc_ref().ptr();
#endif
auto type = &heap_type->ht_type;
type->tp_name = name;
type->tp_base = type_incref(&PyBaseObject_Type);
type->tp_basicsize = static_cast<ssize_t>(sizeof(instance));
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
type->tp_new = pybind11_object_new;
type->tp_init = pybind11_object_init;
type->tp_dealloc = pybind11_object_dealloc;
/* Support weak references (needed for the keep_alive feature) */
type->tp_weaklistoffset = offsetof(instance, weakrefs);
if (PyType_Ready(type) < 0)
pybind11_fail("PyType_Ready failed in make_object_base_type():" + error_string());
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
assert(!PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
return (PyObject *) heap_type;
}
/// dynamic_attr: Support for `d = instance.__dict__`.
extern "C" inline PyObject *pybind11_get_dict(PyObject *self, void *) {
PyObject *&dict = *_PyObject_GetDictPtr(self);
if (!dict)
dict = PyDict_New();
Py_XINCREF(dict);
return dict;
}
/// dynamic_attr: Support for `instance.__dict__ = dict()`.
extern "C" inline int pybind11_set_dict(PyObject *self, PyObject *new_dict, void *) {
if (!PyDict_Check(new_dict)) {
PyErr_Format(PyExc_TypeError, "__dict__ must be set to a dictionary, not a '%.200s'",
Py_TYPE(new_dict)->tp_name);
return -1;
}
PyObject *&dict = *_PyObject_GetDictPtr(self);
Py_INCREF(new_dict);
Py_CLEAR(dict);
dict = new_dict;
return 0;
}
/// dynamic_attr: Allow the garbage collector to traverse the internal instance `__dict__`.
extern "C" inline int pybind11_traverse(PyObject *self, visitproc visit, void *arg) {
PyObject *&dict = *_PyObject_GetDictPtr(self);
Py_VISIT(dict);
return 0;
}
/// dynamic_attr: Allow the GC to clear the dictionary.
extern "C" inline int pybind11_clear(PyObject *self) {
PyObject *&dict = *_PyObject_GetDictPtr(self);
Py_CLEAR(dict);
return 0;
}
/// Give instances of this type a `__dict__` and opt into garbage collection.
inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) {
auto type = &heap_type->ht_type;
#if defined(PYPY_VERSION)
pybind11_fail(std::string(type->tp_name) + ": dynamic attributes are "
"currently not supported in "
"conjunction with PyPy!");
#endif
type->tp_flags |= Py_TPFLAGS_HAVE_GC;
type->tp_dictoffset = type->tp_basicsize; // place dict at the end
type->tp_basicsize += (ssize_t)sizeof(PyObject *); // and allocate enough space for it
type->tp_traverse = pybind11_traverse;
type->tp_clear = pybind11_clear;
static PyGetSetDef getset[] = {
{const_cast<char*>("__dict__"), pybind11_get_dict, pybind11_set_dict, nullptr, nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr}
};
type->tp_getset = getset;
}
/// buffer_protocol: Fill in the view as specified by flags.
extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
// Look for a `get_buffer` implementation in this type's info or any bases (following MRO).
type_info *tinfo = nullptr;
for (auto type : reinterpret_borrow<tuple>(Py_TYPE(obj)->tp_mro)) {
tinfo = get_type_info((PyTypeObject *) type.ptr());
if (tinfo && tinfo->get_buffer)
break;
}
if (view == nullptr || !tinfo || !tinfo->get_buffer) {
if (view)
view->obj = nullptr;
PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error");
return -1;
}
std::memset(view, 0, sizeof(Py_buffer));
buffer_info *info = tinfo->get_buffer(obj, tinfo->get_buffer_data);
view->obj = obj;
view->ndim = 1;
view->internal = info;
view->buf = info->ptr;
view->itemsize = info->itemsize;
view->len = view->itemsize;
for (auto s : info->shape)
view->len *= s;
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT)
view->format = const_cast<char *>(info->format.c_str());
if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
view->ndim = (int) info->ndim;
view->strides = &info->strides[0];
view->shape = &info->shape[0];
}
Py_INCREF(view->obj);
return 0;
}
/// buffer_protocol: Release the resources of the buffer.
extern "C" inline void pybind11_releasebuffer(PyObject *, Py_buffer *view) {
delete (buffer_info *) view->internal;
}
/// Give this type a buffer interface.
inline void enable_buffer_protocol(PyHeapTypeObject *heap_type) {
heap_type->ht_type.tp_as_buffer = &heap_type->as_buffer;
#if PY_MAJOR_VERSION < 3
heap_type->ht_type.tp_flags |= Py_TPFLAGS_HAVE_NEWBUFFER;
#endif
heap_type->as_buffer.bf_getbuffer = pybind11_getbuffer;
heap_type->as_buffer.bf_releasebuffer = pybind11_releasebuffer;
}
/** Create a brand new Python type according to the `type_record` specification.
Return value: New reference. */
inline PyObject* make_new_python_type(const type_record &rec) {
auto name = reinterpret_steal<object>(PYBIND11_FROM_STRING(rec.name));
auto qualname = name;
if (rec.scope && !PyModule_Check(rec.scope.ptr()) && hasattr(rec.scope, "__qualname__")) {
#if PY_MAJOR_VERSION >= 3
qualname = reinterpret_steal<object>(
PyUnicode_FromFormat("%U.%U", rec.scope.attr("__qualname__").ptr(), name.ptr()));
#else
qualname = str(rec.scope.attr("__qualname__").cast<std::string>() + "." + rec.name);
#endif
}
object module;
if (rec.scope) {
if (hasattr(rec.scope, "__module__"))
module = rec.scope.attr("__module__");
else if (hasattr(rec.scope, "__name__"))
module = rec.scope.attr("__name__");
}
auto full_name = c_str(
#if !defined(PYPY_VERSION)
module ? str(module).cast<std::string>() + "." + rec.name :
#endif
rec.name);
char *tp_doc = nullptr;
if (rec.doc && options::show_user_defined_docstrings()) {
/* Allocate memory for docstring (using PyObject_MALLOC, since
Python will free this later on) */
size_t size = strlen(rec.doc) + 1;
tp_doc = (char *) PyObject_MALLOC(size);
memcpy((void *) tp_doc, rec.doc, size);
}
auto &internals = get_internals();
auto bases = tuple(rec.bases);
auto base = (bases.size() == 0) ? internals.instance_base
: bases[0].ptr();
/* Danger zone: from now (and until PyType_Ready), make sure to
issue no Python C API calls which could potentially invoke the
garbage collector (the GC will call type_traverse(), which will in
turn find the newly constructed type in an invalid state) */
auto metaclass = rec.metaclass.ptr() ? (PyTypeObject *) rec.metaclass.ptr()
: internals.default_metaclass;
auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0);
if (!heap_type)
pybind11_fail(std::string(rec.name) + ": Unable to create type object!");
heap_type->ht_name = name.release().ptr();
#ifdef PYBIND11_BUILTIN_QUALNAME
heap_type->ht_qualname = qualname.inc_ref().ptr();
#endif
auto type = &heap_type->ht_type;
type->tp_name = full_name;
type->tp_doc = tp_doc;
type->tp_base = type_incref((PyTypeObject *)base);
type->tp_basicsize = static_cast<ssize_t>(sizeof(instance));
if (bases.size() > 0)
type->tp_bases = bases.release().ptr();
/* Don't inherit base __init__ */
type->tp_init = pybind11_object_init;
/* Supported protocols */
type->tp_as_number = &heap_type->as_number;
type->tp_as_sequence = &heap_type->as_sequence;
type->tp_as_mapping = &heap_type->as_mapping;
/* Flags */
type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
#if PY_MAJOR_VERSION < 3
type->tp_flags |= Py_TPFLAGS_CHECKTYPES;
#endif
if (rec.dynamic_attr)
enable_dynamic_attributes(heap_type);
if (rec.buffer_protocol)
enable_buffer_protocol(heap_type);
if (PyType_Ready(type) < 0)
pybind11_fail(std::string(rec.name) + ": PyType_Ready failed (" + error_string() + ")!");
assert(rec.dynamic_attr ? PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)
: !PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
/* Register type with the parent scope */
if (rec.scope)
setattr(rec.scope, rec.name, (PyObject *) type);
else
Py_INCREF(type); // Keep it alive forever (reference leak)
if (module) // Needed by pydoc
setattr((PyObject *) type, "__module__", module);
PYBIND11_SET_OLDPY_QUALNAME(type, qualname);
return (PyObject *) type;
}
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,807 +0,0 @@
/*
pybind11/detail/common.h -- Basic macros
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#if !defined(NAMESPACE_BEGIN)
# define NAMESPACE_BEGIN(name) namespace name {
#endif
#if !defined(NAMESPACE_END)
# define NAMESPACE_END(name) }
#endif
// Robust support for some features and loading modules compiled against different pybind versions
// requires forcing hidden visibility on pybind code, so we enforce this by setting the attribute on
// the main `pybind11` namespace.
#if !defined(PYBIND11_NAMESPACE)
# ifdef __GNUG__
# define PYBIND11_NAMESPACE pybind11 __attribute__((visibility("hidden")))
# else
# define PYBIND11_NAMESPACE pybind11
# endif
#endif
#if !(defined(_MSC_VER) && __cplusplus == 199711L) && !defined(__INTEL_COMPILER)
# if __cplusplus >= 201402L
# define PYBIND11_CPP14
# if __cplusplus >= 201703L
# define PYBIND11_CPP17
# endif
# endif
#elif defined(_MSC_VER) && __cplusplus == 199711L
// MSVC sets _MSVC_LANG rather than __cplusplus (supposedly until the standard is fully implemented)
// Unless you use the /Zc:__cplusplus flag on Visual Studio 2017 15.7 Preview 3 or newer
# if _MSVC_LANG >= 201402L
# define PYBIND11_CPP14
# if _MSVC_LANG > 201402L && _MSC_VER >= 1910
# define PYBIND11_CPP17
# endif
# endif
#endif
// Compiler version assertions
#if defined(__INTEL_COMPILER)
# if __INTEL_COMPILER < 1700
# error pybind11 requires Intel C++ compiler v17 or newer
# endif
#elif defined(__clang__) && !defined(__apple_build_version__)
# if __clang_major__ < 3 || (__clang_major__ == 3 && __clang_minor__ < 3)
# error pybind11 requires clang 3.3 or newer
# endif
#elif defined(__clang__)
// Apple changes clang version macros to its Xcode version; the first Xcode release based on
// (upstream) clang 3.3 was Xcode 5:
# if __clang_major__ < 5
# error pybind11 requires Xcode/clang 5.0 or newer
# endif
#elif defined(__GNUG__)
# if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 8)
# error pybind11 requires gcc 4.8 or newer
# endif
#elif defined(_MSC_VER)
// Pybind hits various compiler bugs in 2015u2 and earlier, and also makes use of some stl features
// (e.g. std::negation) added in 2015u3:
# if _MSC_FULL_VER < 190024210
# error pybind11 requires MSVC 2015 update 3 or newer
# endif
#endif
#if !defined(PYBIND11_EXPORT)
# if defined(WIN32) || defined(_WIN32)
# define PYBIND11_EXPORT __declspec(dllexport)
# else
# define PYBIND11_EXPORT __attribute__ ((visibility("default")))
# endif
#endif
#if defined(_MSC_VER)
# define PYBIND11_NOINLINE __declspec(noinline)
#else
# define PYBIND11_NOINLINE __attribute__ ((noinline))
#endif
#if defined(PYBIND11_CPP14)
# define PYBIND11_DEPRECATED(reason) [[deprecated(reason)]]
#else
# define PYBIND11_DEPRECATED(reason) __attribute__((deprecated(reason)))
#endif
#define PYBIND11_VERSION_MAJOR 2
#define PYBIND11_VERSION_MINOR 3
#define PYBIND11_VERSION_PATCH 0
/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode
#if defined(_MSC_VER)
# if (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 4)
# define HAVE_ROUND 1
# endif
# pragma warning(push)
# pragma warning(disable: 4510 4610 4512 4005)
# if defined(_DEBUG)
# define PYBIND11_DEBUG_MARKER
# undef _DEBUG
# endif
#endif
#include <Python.h>
#include <frameobject.h>
#include <pythread.h>
#if defined(_WIN32) && (defined(min) || defined(max))
# error Macro clash with min and max -- define NOMINMAX when compiling your program on Windows
#endif
#if defined(isalnum)
# undef isalnum
# undef isalpha
# undef islower
# undef isspace
# undef isupper
# undef tolower
# undef toupper
#endif
#if defined(_MSC_VER)
# if defined(PYBIND11_DEBUG_MARKER)
# define _DEBUG
# undef PYBIND11_DEBUG_MARKER
# endif
# pragma warning(pop)
#endif
#include <cstddef>
#include <cstring>
#include <forward_list>
#include <vector>
#include <string>
#include <stdexcept>
#include <unordered_set>
#include <unordered_map>
#include <memory>
#include <typeindex>
#include <type_traits>
#if PY_MAJOR_VERSION >= 3 /// Compatibility macros for various Python versions
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyInstanceMethod_New(ptr)
#define PYBIND11_INSTANCE_METHOD_CHECK PyInstanceMethod_Check
#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyInstanceMethod_GET_FUNCTION
#define PYBIND11_BYTES_CHECK PyBytes_Check
#define PYBIND11_BYTES_FROM_STRING PyBytes_FromString
#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyBytes_FromStringAndSize
#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyBytes_AsStringAndSize
#define PYBIND11_BYTES_AS_STRING PyBytes_AsString
#define PYBIND11_BYTES_SIZE PyBytes_Size
#define PYBIND11_LONG_CHECK(o) PyLong_Check(o)
#define PYBIND11_LONG_AS_LONGLONG(o) PyLong_AsLongLong(o)
#define PYBIND11_LONG_FROM_SIGNED(o) PyLong_FromSsize_t((ssize_t) o)
#define PYBIND11_LONG_FROM_UNSIGNED(o) PyLong_FromSize_t((size_t) o)
#define PYBIND11_BYTES_NAME "bytes"
#define PYBIND11_STRING_NAME "str"
#define PYBIND11_SLICE_OBJECT PyObject
#define PYBIND11_FROM_STRING PyUnicode_FromString
#define PYBIND11_STR_TYPE ::pybind11::str
#define PYBIND11_BOOL_ATTR "__bool__"
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool)
#define PYBIND11_PLUGIN_IMPL(name) \
extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()
#else
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_)
#define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check
#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyMethod_GET_FUNCTION
#define PYBIND11_BYTES_CHECK PyString_Check
#define PYBIND11_BYTES_FROM_STRING PyString_FromString
#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyString_FromStringAndSize
#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyString_AsStringAndSize
#define PYBIND11_BYTES_AS_STRING PyString_AsString
#define PYBIND11_BYTES_SIZE PyString_Size
#define PYBIND11_LONG_CHECK(o) (PyInt_Check(o) || PyLong_Check(o))
#define PYBIND11_LONG_AS_LONGLONG(o) (PyInt_Check(o) ? (long long) PyLong_AsLong(o) : PyLong_AsLongLong(o))
#define PYBIND11_LONG_FROM_SIGNED(o) PyInt_FromSsize_t((ssize_t) o) // Returns long if needed.
#define PYBIND11_LONG_FROM_UNSIGNED(o) PyInt_FromSize_t((size_t) o) // Returns long if needed.
#define PYBIND11_BYTES_NAME "str"
#define PYBIND11_STRING_NAME "unicode"
#define PYBIND11_SLICE_OBJECT PySliceObject
#define PYBIND11_FROM_STRING PyString_FromString
#define PYBIND11_STR_TYPE ::pybind11::bytes
#define PYBIND11_BOOL_ATTR "__nonzero__"
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero)
#define PYBIND11_PLUGIN_IMPL(name) \
static PyObject *pybind11_init_wrapper(); \
extern "C" PYBIND11_EXPORT void init##name() { \
(void)pybind11_init_wrapper(); \
} \
PyObject *pybind11_init_wrapper()
#endif
#if PY_VERSION_HEX >= 0x03050000 && PY_VERSION_HEX < 0x03050200
extern "C" {
struct _Py_atomic_address { void *value; };
PyAPI_DATA(_Py_atomic_address) _PyThreadState_Current;
}
#endif
#define PYBIND11_TRY_NEXT_OVERLOAD ((PyObject *) 1) // special failure return code
#define PYBIND11_STRINGIFY(x) #x
#define PYBIND11_TOSTRING(x) PYBIND11_STRINGIFY(x)
#define PYBIND11_CONCAT(first, second) first##second
#define PYBIND11_CHECK_PYTHON_VERSION \
{ \
const char *compiled_ver = PYBIND11_TOSTRING(PY_MAJOR_VERSION) \
"." PYBIND11_TOSTRING(PY_MINOR_VERSION); \
const char *runtime_ver = Py_GetVersion(); \
size_t len = std::strlen(compiled_ver); \
if (std::strncmp(runtime_ver, compiled_ver, len) != 0 \
|| (runtime_ver[len] >= '0' && runtime_ver[len] <= '9')) { \
PyErr_Format(PyExc_ImportError, \
"Python version mismatch: module was compiled for Python %s, " \
"but the interpreter version is incompatible: %s.", \
compiled_ver, runtime_ver); \
return nullptr; \
} \
}
#define PYBIND11_CATCH_INIT_EXCEPTIONS \
catch (pybind11::error_already_set &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} catch (const std::exception &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} \
/** \rst
***Deprecated in favor of PYBIND11_MODULE***
This macro creates the entry point that will be invoked when the Python interpreter
imports a plugin library. Please create a `module` in the function body and return
the pointer to its underlying Python object at the end.
.. code-block:: cpp
PYBIND11_PLUGIN(example) {
pybind11::module m("example", "pybind11 example plugin");
/// Set up bindings here
return m.ptr();
}
\endrst */
#define PYBIND11_PLUGIN(name) \
PYBIND11_DEPRECATED("PYBIND11_PLUGIN is deprecated, use PYBIND11_MODULE") \
static PyObject *pybind11_init(); \
PYBIND11_PLUGIN_IMPL(name) { \
PYBIND11_CHECK_PYTHON_VERSION \
try { \
return pybind11_init(); \
} PYBIND11_CATCH_INIT_EXCEPTIONS \
} \
PyObject *pybind11_init()
/** \rst
This macro creates the entry point that will be invoked when the Python interpreter
imports an extension module. The module name is given as the fist argument and it
should not be in quotes. The second macro argument defines a variable of type
`py::module` which can be used to initialize the module.
.. code-block:: cpp
PYBIND11_MODULE(example, m) {
m.doc() = "pybind11 example module";
// Add bindings here
m.def("foo", []() {
return "Hello, World!";
});
}
\endrst */
#define PYBIND11_MODULE(name, variable) \
static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
PYBIND11_PLUGIN_IMPL(name) { \
PYBIND11_CHECK_PYTHON_VERSION \
auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
try { \
PYBIND11_CONCAT(pybind11_init_, name)(m); \
return m.ptr(); \
} PYBIND11_CATCH_INIT_EXCEPTIONS \
} \
void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
using ssize_t = Py_ssize_t;
using size_t = std::size_t;
/// Approach used to cast a previously unknown C++ instance into a Python object
enum class return_value_policy : uint8_t {
/** This is the default return value policy, which falls back to the policy
return_value_policy::take_ownership when the return value is a pointer.
Otherwise, it uses return_value::move or return_value::copy for rvalue
and lvalue references, respectively. See below for a description of what
all of these different policies do. */
automatic = 0,
/** As above, but use policy return_value_policy::reference when the return
value is a pointer. This is the default conversion policy for function
arguments when calling Python functions manually from C++ code (i.e. via
handle::operator()). You probably won't need to use this. */
automatic_reference,
/** Reference an existing object (i.e. do not create a new copy) and take
ownership. Python will call the destructor and delete operator when the
objects reference count reaches zero. Undefined behavior ensues when
the C++ side does the same.. */
take_ownership,
/** Create a new copy of the returned object, which will be owned by
Python. This policy is comparably safe because the lifetimes of the two
instances are decoupled. */
copy,
/** Use std::move to move the return value contents into a new instance
that will be owned by Python. This policy is comparably safe because the
lifetimes of the two instances (move source and destination) are
decoupled. */
move,
/** Reference an existing object, but do not take ownership. The C++ side
is responsible for managing the objects lifetime and deallocating it
when it is no longer used. Warning: undefined behavior will ensue when
the C++ side deletes an object that is still referenced and used by
Python. */
reference,
/** This policy only applies to methods and properties. It references the
object without taking ownership similar to the above
return_value_policy::reference policy. In contrast to that policy, the
function or propertys implicit this argument (called the parent) is
considered to be the the owner of the return value (the child).
pybind11 then couples the lifetime of the parent to the child via a
reference relationship that ensures that the parent cannot be garbage
collected while Python is still using the child. More advanced
variations of this scheme are also possible using combinations of
return_value_policy::reference and the keep_alive call policy */
reference_internal
};
NAMESPACE_BEGIN(detail)
inline static constexpr int log2(size_t n, int k = 0) { return (n <= 1) ? k : log2(n >> 1, k + 1); }
// Returns the size as a multiple of sizeof(void *), rounded up.
inline static constexpr size_t size_in_ptrs(size_t s) { return 1 + ((s - 1) >> log2(sizeof(void *))); }
/**
* The space to allocate for simple layout instance holders (see below) in multiple of the size of
* a pointer (e.g. 2 means 16 bytes on 64-bit architectures). The default is the minimum required
* to holder either a std::unique_ptr or std::shared_ptr (which is almost always
* sizeof(std::shared_ptr<T>)).
*/
constexpr size_t instance_simple_holder_in_ptrs() {
static_assert(sizeof(std::shared_ptr<int>) >= sizeof(std::unique_ptr<int>),
"pybind assumes std::shared_ptrs are at least as big as std::unique_ptrs");
return size_in_ptrs(sizeof(std::shared_ptr<int>));
}
// Forward declarations
struct type_info;
struct value_and_holder;
struct nonsimple_values_and_holders {
void **values_and_holders;
uint8_t *status;
};
/// The 'instance' type which needs to be standard layout (need to be able to use 'offsetof')
struct instance {
PyObject_HEAD
/// Storage for pointers and holder; see simple_layout, below, for a description
union {
void *simple_value_holder[1 + instance_simple_holder_in_ptrs()];
nonsimple_values_and_holders nonsimple;
};
/// Weak references
PyObject *weakrefs;
/// If true, the pointer is owned which means we're free to manage it with a holder.
bool owned : 1;
/**
* An instance has two possible value/holder layouts.
*
* Simple layout (when this flag is true), means the `simple_value_holder` is set with a pointer
* and the holder object governing that pointer, i.e. [val1*][holder]. This layout is applied
* whenever there is no python-side multiple inheritance of bound C++ types *and* the type's
* holder will fit in the default space (which is large enough to hold either a std::unique_ptr
* or std::shared_ptr).
*
* Non-simple layout applies when using custom holders that require more space than `shared_ptr`
* (which is typically the size of two pointers), or when multiple inheritance is used on the
* python side. Non-simple layout allocates the required amount of memory to have multiple
* bound C++ classes as parents. Under this layout, `nonsimple.values_and_holders` is set to a
* pointer to allocated space of the required space to hold a sequence of value pointers and
* holders followed `status`, a set of bit flags (1 byte each), i.e.
* [val1*][holder1][val2*][holder2]...[bb...] where each [block] is rounded up to a multiple of
* `sizeof(void *)`. `nonsimple.status` is, for convenience, a pointer to the
* beginning of the [bb...] block (but not independently allocated).
*
* Status bits indicate whether the associated holder is constructed (&
* status_holder_constructed) and whether the value pointer is registered (&
* status_instance_registered) in `registered_instances`.
*/
bool simple_layout : 1;
/// For simple layout, tracks whether the holder has been constructed
bool simple_holder_constructed : 1;
/// For simple layout, tracks whether the instance is registered in `registered_instances`
bool simple_instance_registered : 1;
/// If true, get_internals().patients has an entry for this object
bool has_patients : 1;
/// Initializes all of the above type/values/holders data (but not the instance values themselves)
void allocate_layout();
/// Destroys/deallocates all of the above
void deallocate_layout();
/// Returns the value_and_holder wrapper for the given type (or the first, if `find_type`
/// omitted). Returns a default-constructed (with `.inst = nullptr`) object on failure if
/// `throw_if_missing` is false.
value_and_holder get_value_and_holder(const type_info *find_type = nullptr, bool throw_if_missing = true);
/// Bit values for the non-simple status flags
static constexpr uint8_t status_holder_constructed = 1;
static constexpr uint8_t status_instance_registered = 2;
};
static_assert(std::is_standard_layout<instance>::value, "Internal error: `pybind11::detail::instance` is not standard layout!");
/// from __cpp_future__ import (convenient aliases from C++14/17)
#if defined(PYBIND11_CPP14) && (!defined(_MSC_VER) || _MSC_VER >= 1910)
using std::enable_if_t;
using std::conditional_t;
using std::remove_cv_t;
using std::remove_reference_t;
#else
template <bool B, typename T = void> using enable_if_t = typename std::enable_if<B, T>::type;
template <bool B, typename T, typename F> using conditional_t = typename std::conditional<B, T, F>::type;
template <typename T> using remove_cv_t = typename std::remove_cv<T>::type;
template <typename T> using remove_reference_t = typename std::remove_reference<T>::type;
#endif
/// Index sequences
#if defined(PYBIND11_CPP14)
using std::index_sequence;
using std::make_index_sequence;
#else
template<size_t ...> struct index_sequence { };
template<size_t N, size_t ...S> struct make_index_sequence_impl : make_index_sequence_impl <N - 1, N - 1, S...> { };
template<size_t ...S> struct make_index_sequence_impl <0, S...> { typedef index_sequence<S...> type; };
template<size_t N> using make_index_sequence = typename make_index_sequence_impl<N>::type;
#endif
/// Make an index sequence of the indices of true arguments
template <typename ISeq, size_t, bool...> struct select_indices_impl { using type = ISeq; };
template <size_t... IPrev, size_t I, bool B, bool... Bs> struct select_indices_impl<index_sequence<IPrev...>, I, B, Bs...>
: select_indices_impl<conditional_t<B, index_sequence<IPrev..., I>, index_sequence<IPrev...>>, I + 1, Bs...> {};
template <bool... Bs> using select_indices = typename select_indices_impl<index_sequence<>, 0, Bs...>::type;
/// Backports of std::bool_constant and std::negation to accommodate older compilers
template <bool B> using bool_constant = std::integral_constant<bool, B>;
template <typename T> struct negation : bool_constant<!T::value> { };
template <typename...> struct void_t_impl { using type = void; };
template <typename... Ts> using void_t = typename void_t_impl<Ts...>::type;
/// Compile-time all/any/none of that check the boolean value of all template types
#if defined(__cpp_fold_expressions) && !(defined(_MSC_VER) && (_MSC_VER < 1916))
template <class... Ts> using all_of = bool_constant<(Ts::value && ...)>;
template <class... Ts> using any_of = bool_constant<(Ts::value || ...)>;
#elif !defined(_MSC_VER)
template <bool...> struct bools {};
template <class... Ts> using all_of = std::is_same<
bools<Ts::value..., true>,
bools<true, Ts::value...>>;
template <class... Ts> using any_of = negation<all_of<negation<Ts>...>>;
#else
// MSVC has trouble with the above, but supports std::conjunction, which we can use instead (albeit
// at a slight loss of compilation efficiency).
template <class... Ts> using all_of = std::conjunction<Ts...>;
template <class... Ts> using any_of = std::disjunction<Ts...>;
#endif
template <class... Ts> using none_of = negation<any_of<Ts...>>;
template <class T, template<class> class... Predicates> using satisfies_all_of = all_of<Predicates<T>...>;
template <class T, template<class> class... Predicates> using satisfies_any_of = any_of<Predicates<T>...>;
template <class T, template<class> class... Predicates> using satisfies_none_of = none_of<Predicates<T>...>;
/// Strip the class from a method type
template <typename T> struct remove_class { };
template <typename C, typename R, typename... A> struct remove_class<R (C::*)(A...)> { typedef R type(A...); };
template <typename C, typename R, typename... A> struct remove_class<R (C::*)(A...) const> { typedef R type(A...); };
/// Helper template to strip away type modifiers
template <typename T> struct intrinsic_type { typedef T type; };
template <typename T> struct intrinsic_type<const T> { typedef typename intrinsic_type<T>::type type; };
template <typename T> struct intrinsic_type<T*> { typedef typename intrinsic_type<T>::type type; };
template <typename T> struct intrinsic_type<T&> { typedef typename intrinsic_type<T>::type type; };
template <typename T> struct intrinsic_type<T&&> { typedef typename intrinsic_type<T>::type type; };
template <typename T, size_t N> struct intrinsic_type<const T[N]> { typedef typename intrinsic_type<T>::type type; };
template <typename T, size_t N> struct intrinsic_type<T[N]> { typedef typename intrinsic_type<T>::type type; };
template <typename T> using intrinsic_t = typename intrinsic_type<T>::type;
/// Helper type to replace 'void' in some expressions
struct void_type { };
/// Helper template which holds a list of types
template <typename...> struct type_list { };
/// Compile-time integer sum
#ifdef __cpp_fold_expressions
template <typename... Ts> constexpr size_t constexpr_sum(Ts... ns) { return (0 + ... + size_t{ns}); }
#else
constexpr size_t constexpr_sum() { return 0; }
template <typename T, typename... Ts>
constexpr size_t constexpr_sum(T n, Ts... ns) { return size_t{n} + constexpr_sum(ns...); }
#endif
NAMESPACE_BEGIN(constexpr_impl)
/// Implementation details for constexpr functions
constexpr int first(int i) { return i; }
template <typename T, typename... Ts>
constexpr int first(int i, T v, Ts... vs) { return v ? i : first(i + 1, vs...); }
constexpr int last(int /*i*/, int result) { return result; }
template <typename T, typename... Ts>
constexpr int last(int i, int result, T v, Ts... vs) { return last(i + 1, v ? i : result, vs...); }
NAMESPACE_END(constexpr_impl)
/// Return the index of the first type in Ts which satisfies Predicate<T>. Returns sizeof...(Ts) if
/// none match.
template <template<typename> class Predicate, typename... Ts>
constexpr int constexpr_first() { return constexpr_impl::first(0, Predicate<Ts>::value...); }
/// Return the index of the last type in Ts which satisfies Predicate<T>, or -1 if none match.
template <template<typename> class Predicate, typename... Ts>
constexpr int constexpr_last() { return constexpr_impl::last(0, -1, Predicate<Ts>::value...); }
/// Return the Nth element from the parameter pack
template <size_t N, typename T, typename... Ts>
struct pack_element { using type = typename pack_element<N - 1, Ts...>::type; };
template <typename T, typename... Ts>
struct pack_element<0, T, Ts...> { using type = T; };
/// Return the one and only type which matches the predicate, or Default if none match.
/// If more than one type matches the predicate, fail at compile-time.
template <template<typename> class Predicate, typename Default, typename... Ts>
struct exactly_one {
static constexpr auto found = constexpr_sum(Predicate<Ts>::value...);
static_assert(found <= 1, "Found more than one type matching the predicate");
static constexpr auto index = found ? constexpr_first<Predicate, Ts...>() : 0;
using type = conditional_t<found, typename pack_element<index, Ts...>::type, Default>;
};
template <template<typename> class P, typename Default>
struct exactly_one<P, Default> { using type = Default; };
template <template<typename> class Predicate, typename Default, typename... Ts>
using exactly_one_t = typename exactly_one<Predicate, Default, Ts...>::type;
/// Defer the evaluation of type T until types Us are instantiated
template <typename T, typename... /*Us*/> struct deferred_type { using type = T; };
template <typename T, typename... Us> using deferred_t = typename deferred_type<T, Us...>::type;
/// Like is_base_of, but requires a strict base (i.e. `is_strict_base_of<T, T>::value == false`,
/// unlike `std::is_base_of`)
template <typename Base, typename Derived> using is_strict_base_of = bool_constant<
std::is_base_of<Base, Derived>::value && !std::is_same<Base, Derived>::value>;
/// Like is_base_of, but also requires that the base type is accessible (i.e. that a Derived pointer
/// can be converted to a Base pointer)
template <typename Base, typename Derived> using is_accessible_base_of = bool_constant<
std::is_base_of<Base, Derived>::value && std::is_convertible<Derived *, Base *>::value>;
template <template<typename...> class Base>
struct is_template_base_of_impl {
template <typename... Us> static std::true_type check(Base<Us...> *);
static std::false_type check(...);
};
/// Check if a template is the base of a type. For example:
/// `is_template_base_of<Base, T>` is true if `struct T : Base<U> {}` where U can be anything
template <template<typename...> class Base, typename T>
#if !defined(_MSC_VER)
using is_template_base_of = decltype(is_template_base_of_impl<Base>::check((intrinsic_t<T>*)nullptr));
#else // MSVC2015 has trouble with decltype in template aliases
struct is_template_base_of : decltype(is_template_base_of_impl<Base>::check((intrinsic_t<T>*)nullptr)) { };
#endif
/// Check if T is an instantiation of the template `Class`. For example:
/// `is_instantiation<shared_ptr, T>` is true if `T == shared_ptr<U>` where U can be anything.
template <template<typename...> class Class, typename T>
struct is_instantiation : std::false_type { };
template <template<typename...> class Class, typename... Us>
struct is_instantiation<Class, Class<Us...>> : std::true_type { };
/// Check if T is std::shared_ptr<U> where U can be anything
template <typename T> using is_shared_ptr = is_instantiation<std::shared_ptr, T>;
/// Check if T looks like an input iterator
template <typename T, typename = void> struct is_input_iterator : std::false_type {};
template <typename T>
struct is_input_iterator<T, void_t<decltype(*std::declval<T &>()), decltype(++std::declval<T &>())>>
: std::true_type {};
template <typename T> using is_function_pointer = bool_constant<
std::is_pointer<T>::value && std::is_function<typename std::remove_pointer<T>::type>::value>;
template <typename F> struct strip_function_object {
using type = typename remove_class<decltype(&F::operator())>::type;
};
// Extracts the function signature from a function, function pointer or lambda.
template <typename Function, typename F = remove_reference_t<Function>>
using function_signature_t = conditional_t<
std::is_function<F>::value,
F,
typename conditional_t<
std::is_pointer<F>::value || std::is_member_pointer<F>::value,
std::remove_pointer<F>,
strip_function_object<F>
>::type
>;
/// Returns true if the type looks like a lambda: that is, isn't a function, pointer or member
/// pointer. Note that this can catch all sorts of other things, too; this is intended to be used
/// in a place where passing a lambda makes sense.
template <typename T> using is_lambda = satisfies_none_of<remove_reference_t<T>,
std::is_function, std::is_pointer, std::is_member_pointer>;
/// Ignore that a variable is unused in compiler warnings
inline void ignore_unused(const int *) { }
/// Apply a function over each element of a parameter pack
#ifdef __cpp_fold_expressions
#define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) (((PATTERN), void()), ...)
#else
using expand_side_effects = bool[];
#define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) pybind11::detail::expand_side_effects{ ((PATTERN), void(), false)..., false }
#endif
NAMESPACE_END(detail)
/// C++ bindings of builtin Python exceptions
class builtin_exception : public std::runtime_error {
public:
using std::runtime_error::runtime_error;
/// Set the error using the Python C API
virtual void set_error() const = 0;
};
#define PYBIND11_RUNTIME_EXCEPTION(name, type) \
class name : public builtin_exception { public: \
using builtin_exception::builtin_exception; \
name() : name("") { } \
void set_error() const override { PyErr_SetString(type, what()); } \
};
PYBIND11_RUNTIME_EXCEPTION(stop_iteration, PyExc_StopIteration)
PYBIND11_RUNTIME_EXCEPTION(index_error, PyExc_IndexError)
PYBIND11_RUNTIME_EXCEPTION(key_error, PyExc_KeyError)
PYBIND11_RUNTIME_EXCEPTION(value_error, PyExc_ValueError)
PYBIND11_RUNTIME_EXCEPTION(type_error, PyExc_TypeError)
PYBIND11_RUNTIME_EXCEPTION(cast_error, PyExc_RuntimeError) /// Thrown when pybind11::cast or handle::call fail due to a type casting error
PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used internally
[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const char *reason) { throw std::runtime_error(reason); }
[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const std::string &reason) { throw std::runtime_error(reason); }
template <typename T, typename SFINAE = void> struct format_descriptor { };
NAMESPACE_BEGIN(detail)
// Returns the index of the given type in the type char array below, and in the list in numpy.h
// The order here is: bool; 8 ints ((signed,unsigned)x(8,16,32,64)bits); float,double,long double;
// complex float,double,long double. Note that the long double types only participate when long
// double is actually longer than double (it isn't under MSVC).
// NB: not only the string below but also complex.h and numpy.h rely on this order.
template <typename T, typename SFINAE = void> struct is_fmt_numeric { static constexpr bool value = false; };
template <typename T> struct is_fmt_numeric<T, enable_if_t<std::is_arithmetic<T>::value>> {
static constexpr bool value = true;
static constexpr int index = std::is_same<T, bool>::value ? 0 : 1 + (
std::is_integral<T>::value ? detail::log2(sizeof(T))*2 + std::is_unsigned<T>::value : 8 + (
std::is_same<T, double>::value ? 1 : std::is_same<T, long double>::value ? 2 : 0));
};
NAMESPACE_END(detail)
template <typename T> struct format_descriptor<T, detail::enable_if_t<std::is_arithmetic<T>::value>> {
static constexpr const char c = "?bBhHiIqQfdg"[detail::is_fmt_numeric<T>::index];
static constexpr const char value[2] = { c, '\0' };
static std::string format() { return std::string(1, c); }
};
#if !defined(PYBIND11_CPP17)
template <typename T> constexpr const char format_descriptor<
T, detail::enable_if_t<std::is_arithmetic<T>::value>>::value[2];
#endif
/// RAII wrapper that temporarily clears any Python error state
struct error_scope {
PyObject *type, *value, *trace;
error_scope() { PyErr_Fetch(&type, &value, &trace); }
~error_scope() { PyErr_Restore(type, value, trace); }
};
/// Dummy destructor wrapper that can be used to expose classes with a private destructor
struct nodelete { template <typename T> void operator()(T*) { } };
// overload_cast requires variable templates: C++14
#if defined(PYBIND11_CPP14)
#define PYBIND11_OVERLOAD_CAST 1
NAMESPACE_BEGIN(detail)
template <typename... Args>
struct overload_cast_impl {
constexpr overload_cast_impl() {} // MSVC 2015 needs this
template <typename Return>
constexpr auto operator()(Return (*pf)(Args...)) const noexcept
-> decltype(pf) { return pf; }
template <typename Return, typename Class>
constexpr auto operator()(Return (Class::*pmf)(Args...), std::false_type = {}) const noexcept
-> decltype(pmf) { return pmf; }
template <typename Return, typename Class>
constexpr auto operator()(Return (Class::*pmf)(Args...) const, std::true_type) const noexcept
-> decltype(pmf) { return pmf; }
};
NAMESPACE_END(detail)
/// Syntax sugar for resolving overloaded function pointers:
/// - regular: static_cast<Return (Class::*)(Arg0, Arg1, Arg2)>(&Class::func)
/// - sweet: overload_cast<Arg0, Arg1, Arg2>(&Class::func)
template <typename... Args>
static constexpr detail::overload_cast_impl<Args...> overload_cast = {};
// MSVC 2015 only accepts this particular initialization syntax for this variable template.
/// Const member function selector for overload_cast
/// - regular: static_cast<Return (Class::*)(Arg) const>(&Class::func)
/// - sweet: overload_cast<Arg>(&Class::func, const_)
static constexpr auto const_ = std::true_type{};
#else // no overload_cast: providing something that static_assert-fails:
template <typename... Args> struct overload_cast {
static_assert(detail::deferred_t<std::false_type, Args...>::value,
"pybind11::overload_cast<...> requires compiling in C++14 mode");
};
#endif // overload_cast
NAMESPACE_BEGIN(detail)
// Adaptor for converting arbitrary container arguments into a vector; implicitly convertible from
// any standard container (or C-style array) supporting std::begin/std::end, any singleton
// arithmetic type (if T is arithmetic), or explicitly constructible from an iterator pair.
template <typename T>
class any_container {
std::vector<T> v;
public:
any_container() = default;
// Can construct from a pair of iterators
template <typename It, typename = enable_if_t<is_input_iterator<It>::value>>
any_container(It first, It last) : v(first, last) { }
// Implicit conversion constructor from any arbitrary container type with values convertible to T
template <typename Container, typename = enable_if_t<std::is_convertible<decltype(*std::begin(std::declval<const Container &>())), T>::value>>
any_container(const Container &c) : any_container(std::begin(c), std::end(c)) { }
// initializer_list's aren't deducible, so don't get matched by the above template; we need this
// to explicitly allow implicit conversion from one:
template <typename TIn, typename = enable_if_t<std::is_convertible<TIn, T>::value>>
any_container(const std::initializer_list<TIn> &c) : any_container(c.begin(), c.end()) { }
// Avoid copying if given an rvalue vector of the correct type.
any_container(std::vector<T> &&v) : v(std::move(v)) { }
// Moves the vector out of an rvalue any_container
operator std::vector<T> &&() && { return std::move(v); }
// Dereferencing obtains a reference to the underlying vector
std::vector<T> &operator*() { return v; }
const std::vector<T> &operator*() const { return v; }
// -> lets you call methods on the underlying vector
std::vector<T> *operator->() { return &v; }
const std::vector<T> *operator->() const { return &v; }
};
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,100 +0,0 @@
/*
pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "common.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
#if !defined(_MSC_VER)
# define PYBIND11_DESCR_CONSTEXPR static constexpr
#else
# define PYBIND11_DESCR_CONSTEXPR const
#endif
/* Concatenate type signatures at compile time */
template <size_t N, typename... Ts>
struct descr {
char text[N + 1];
constexpr descr() : text{'\0'} { }
constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence<N>()) { }
template <size_t... Is>
constexpr descr(char const (&s)[N+1], index_sequence<Is...>) : text{s[Is]..., '\0'} { }
template <typename... Chars>
constexpr descr(char c, Chars... cs) : text{c, static_cast<char>(cs)..., '\0'} { }
static constexpr std::array<const std::type_info *, sizeof...(Ts) + 1> types() {
return {{&typeid(Ts)..., nullptr}};
}
};
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2, size_t... Is1, size_t... Is2>
constexpr descr<N1 + N2, Ts1..., Ts2...> plus_impl(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b,
index_sequence<Is1...>, index_sequence<Is2...>) {
return {a.text[Is1]..., b.text[Is2]...};
}
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2>
constexpr descr<N1 + N2, Ts1..., Ts2...> operator+(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b) {
return plus_impl(a, b, make_index_sequence<N1>(), make_index_sequence<N2>());
}
template <size_t N>
constexpr descr<N - 1> _(char const(&text)[N]) { return descr<N - 1>(text); }
constexpr descr<0> _(char const(&)[1]) { return {}; }
template <size_t Rem, size_t... Digits> struct int_to_str : int_to_str<Rem/10, Rem%10, Digits...> { };
template <size_t...Digits> struct int_to_str<0, Digits...> {
static constexpr auto digits = descr<sizeof...(Digits)>(('0' + Digits)...);
};
// Ternary description (like std::conditional)
template <bool B, size_t N1, size_t N2>
constexpr enable_if_t<B, descr<N1 - 1>> _(char const(&text1)[N1], char const(&)[N2]) {
return _(text1);
}
template <bool B, size_t N1, size_t N2>
constexpr enable_if_t<!B, descr<N2 - 1>> _(char const(&)[N1], char const(&text2)[N2]) {
return _(text2);
}
template <bool B, typename T1, typename T2>
constexpr enable_if_t<B, T1> _(const T1 &d, const T2 &) { return d; }
template <bool B, typename T1, typename T2>
constexpr enable_if_t<!B, T2> _(const T1 &, const T2 &d) { return d; }
template <size_t Size> auto constexpr _() -> decltype(int_to_str<Size / 10, Size % 10>::digits) {
return int_to_str<Size / 10, Size % 10>::digits;
}
template <typename Type> constexpr descr<1, Type> _() { return {'%'}; }
constexpr descr<0> concat() { return {}; }
template <size_t N, typename... Ts>
constexpr descr<N, Ts...> concat(const descr<N, Ts...> &descr) { return descr; }
template <size_t N, typename... Ts, typename... Args>
constexpr auto concat(const descr<N, Ts...> &d, const Args &...args)
-> decltype(std::declval<descr<N + 2, Ts...>>() + concat(args...)) {
return d + _(", ") + concat(args...);
}
template <size_t N, typename... Ts>
constexpr descr<N + 2, Ts...> type_descr(const descr<N, Ts...> &descr) {
return _("{") + descr + _("}");
}
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,335 +0,0 @@
/*
pybind11/detail/init.h: init factory function implementation and support code.
Copyright (c) 2017 Jason Rhinelander <jason@imaginary.ca>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "class.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
template <>
class type_caster<value_and_holder> {
public:
bool load(handle h, bool) {
value = reinterpret_cast<value_and_holder *>(h.ptr());
return true;
}
template <typename> using cast_op_type = value_and_holder &;
operator value_and_holder &() { return *value; }
static constexpr auto name = _<value_and_holder>();
private:
value_and_holder *value = nullptr;
};
NAMESPACE_BEGIN(initimpl)
inline void no_nullptr(void *ptr) {
if (!ptr) throw type_error("pybind11::init(): factory function returned nullptr");
}
// Implementing functions for all forms of py::init<...> and py::init(...)
template <typename Class> using Cpp = typename Class::type;
template <typename Class> using Alias = typename Class::type_alias;
template <typename Class> using Holder = typename Class::holder_type;
template <typename Class> using is_alias_constructible = std::is_constructible<Alias<Class>, Cpp<Class> &&>;
// Takes a Cpp pointer and returns true if it actually is a polymorphic Alias instance.
template <typename Class, enable_if_t<Class::has_alias, int> = 0>
bool is_alias(Cpp<Class> *ptr) {
return dynamic_cast<Alias<Class> *>(ptr) != nullptr;
}
// Failing fallback version of the above for a no-alias class (always returns false)
template <typename /*Class*/>
constexpr bool is_alias(void *) { return false; }
// Constructs and returns a new object; if the given arguments don't map to a constructor, we fall
// back to brace aggregate initiailization so that for aggregate initialization can be used with
// py::init, e.g. `py::init<int, int>` to initialize a `struct T { int a; int b; }`. For
// non-aggregate types, we need to use an ordinary T(...) constructor (invoking as `T{...}` usually
// works, but will not do the expected thing when `T` has an `initializer_list<T>` constructor).
template <typename Class, typename... Args, detail::enable_if_t<std::is_constructible<Class, Args...>::value, int> = 0>
inline Class *construct_or_initialize(Args &&...args) { return new Class(std::forward<Args>(args)...); }
template <typename Class, typename... Args, detail::enable_if_t<!std::is_constructible<Class, Args...>::value, int> = 0>
inline Class *construct_or_initialize(Args &&...args) { return new Class{std::forward<Args>(args)...}; }
// Attempts to constructs an alias using a `Alias(Cpp &&)` constructor. This allows types with
// an alias to provide only a single Cpp factory function as long as the Alias can be
// constructed from an rvalue reference of the base Cpp type. This means that Alias classes
// can, when appropriate, simply define a `Alias(Cpp &&)` constructor rather than needing to
// inherit all the base class constructors.
template <typename Class>
void construct_alias_from_cpp(std::true_type /*is_alias_constructible*/,
value_and_holder &v_h, Cpp<Class> &&base) {
v_h.value_ptr() = new Alias<Class>(std::move(base));
}
template <typename Class>
[[noreturn]] void construct_alias_from_cpp(std::false_type /*!is_alias_constructible*/,
value_and_holder &, Cpp<Class> &&) {
throw type_error("pybind11::init(): unable to convert returned instance to required "
"alias class: no `Alias<Class>(Class &&)` constructor available");
}
// Error-generating fallback for factories that don't match one of the below construction
// mechanisms.
template <typename Class>
void construct(...) {
static_assert(!std::is_same<Class, Class>::value /* always false */,
"pybind11::init(): init function must return a compatible pointer, "
"holder, or value");
}
// Pointer return v1: the factory function returns a class pointer for a registered class.
// If we don't need an alias (because this class doesn't have one, or because the final type is
// inherited on the Python side) we can simply take over ownership. Otherwise we need to try to
// construct an Alias from the returned base instance.
template <typename Class>
void construct(value_and_holder &v_h, Cpp<Class> *ptr, bool need_alias) {
no_nullptr(ptr);
if (Class::has_alias && need_alias && !is_alias<Class>(ptr)) {
// We're going to try to construct an alias by moving the cpp type. Whether or not
// that succeeds, we still need to destroy the original cpp pointer (either the
// moved away leftover, if the alias construction works, or the value itself if we
// throw an error), but we can't just call `delete ptr`: it might have a special
// deleter, or might be shared_from_this. So we construct a holder around it as if
// it was a normal instance, then steal the holder away into a local variable; thus
// the holder and destruction happens when we leave the C++ scope, and the holder
// class gets to handle the destruction however it likes.
v_h.value_ptr() = ptr;
v_h.set_instance_registered(true); // To prevent init_instance from registering it
v_h.type->init_instance(v_h.inst, nullptr); // Set up the holder
Holder<Class> temp_holder(std::move(v_h.holder<Holder<Class>>())); // Steal the holder
v_h.type->dealloc(v_h); // Destroys the moved-out holder remains, resets value ptr to null
v_h.set_instance_registered(false);
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(*ptr));
} else {
// Otherwise the type isn't inherited, so we don't need an Alias
v_h.value_ptr() = ptr;
}
}
// Pointer return v2: a factory that always returns an alias instance ptr. We simply take over
// ownership of the pointer.
template <typename Class, enable_if_t<Class::has_alias, int> = 0>
void construct(value_and_holder &v_h, Alias<Class> *alias_ptr, bool) {
no_nullptr(alias_ptr);
v_h.value_ptr() = static_cast<Cpp<Class> *>(alias_ptr);
}
// Holder return: copy its pointer, and move or copy the returned holder into the new instance's
// holder. This also handles types like std::shared_ptr<T> and std::unique_ptr<T> where T is a
// derived type (through those holder's implicit conversion from derived class holder constructors).
template <typename Class>
void construct(value_and_holder &v_h, Holder<Class> holder, bool need_alias) {
auto *ptr = holder_helper<Holder<Class>>::get(holder);
// If we need an alias, check that the held pointer is actually an alias instance
if (Class::has_alias && need_alias && !is_alias<Class>(ptr))
throw type_error("pybind11::init(): construction failed: returned holder-wrapped instance "
"is not an alias instance");
v_h.value_ptr() = ptr;
v_h.type->init_instance(v_h.inst, &holder);
}
// return-by-value version 1: returning a cpp class by value. If the class has an alias and an
// alias is required the alias must have an `Alias(Cpp &&)` constructor so that we can construct
// the alias from the base when needed (i.e. because of Python-side inheritance). When we don't
// need it, we simply move-construct the cpp value into a new instance.
template <typename Class>
void construct(value_and_holder &v_h, Cpp<Class> &&result, bool need_alias) {
static_assert(std::is_move_constructible<Cpp<Class>>::value,
"pybind11::init() return-by-value factory function requires a movable class");
if (Class::has_alias && need_alias)
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(result));
else
v_h.value_ptr() = new Cpp<Class>(std::move(result));
}
// return-by-value version 2: returning a value of the alias type itself. We move-construct an
// Alias instance (even if no the python-side inheritance is involved). The is intended for
// cases where Alias initialization is always desired.
template <typename Class>
void construct(value_and_holder &v_h, Alias<Class> &&result, bool) {
static_assert(std::is_move_constructible<Alias<Class>>::value,
"pybind11::init() return-by-alias-value factory function requires a movable alias class");
v_h.value_ptr() = new Alias<Class>(std::move(result));
}
// Implementing class for py::init<...>()
template <typename... Args>
struct constructor {
template <typename Class, typename... Extra, enable_if_t<!Class::has_alias, int> = 0>
static void execute(Class &cl, const Extra&... extra) {
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
}, is_new_style_constructor(), extra...);
}
template <typename Class, typename... Extra,
enable_if_t<Class::has_alias &&
std::is_constructible<Cpp<Class>, Args...>::value, int> = 0>
static void execute(Class &cl, const Extra&... extra) {
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
if (Py_TYPE(v_h.inst) == v_h.type->type)
v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
else
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
}, is_new_style_constructor(), extra...);
}
template <typename Class, typename... Extra,
enable_if_t<Class::has_alias &&
!std::is_constructible<Cpp<Class>, Args...>::value, int> = 0>
static void execute(Class &cl, const Extra&... extra) {
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
}, is_new_style_constructor(), extra...);
}
};
// Implementing class for py::init_alias<...>()
template <typename... Args> struct alias_constructor {
template <typename Class, typename... Extra,
enable_if_t<Class::has_alias && std::is_constructible<Alias<Class>, Args...>::value, int> = 0>
static void execute(Class &cl, const Extra&... extra) {
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
}, is_new_style_constructor(), extra...);
}
};
// Implementation class for py::init(Func) and py::init(Func, AliasFunc)
template <typename CFunc, typename AFunc = void_type (*)(),
typename = function_signature_t<CFunc>, typename = function_signature_t<AFunc>>
struct factory;
// Specialization for py::init(Func)
template <typename Func, typename Return, typename... Args>
struct factory<Func, void_type (*)(), Return(Args...)> {
remove_reference_t<Func> class_factory;
factory(Func &&f) : class_factory(std::forward<Func>(f)) { }
// The given class either has no alias or has no separate alias factory;
// this always constructs the class itself. If the class is registered with an alias
// type and an alias instance is needed (i.e. because the final type is a Python class
// inheriting from the C++ type) the returned value needs to either already be an alias
// instance, or the alias needs to be constructible from a `Class &&` argument.
template <typename Class, typename... Extra>
void execute(Class &cl, const Extra &...extra) && {
#if defined(PYBIND11_CPP14)
cl.def("__init__", [func = std::move(class_factory)]
#else
auto &func = class_factory;
cl.def("__init__", [func]
#endif
(value_and_holder &v_h, Args... args) {
construct<Class>(v_h, func(std::forward<Args>(args)...),
Py_TYPE(v_h.inst) != v_h.type->type);
}, is_new_style_constructor(), extra...);
}
};
// Specialization for py::init(Func, AliasFunc)
template <typename CFunc, typename AFunc,
typename CReturn, typename... CArgs, typename AReturn, typename... AArgs>
struct factory<CFunc, AFunc, CReturn(CArgs...), AReturn(AArgs...)> {
static_assert(sizeof...(CArgs) == sizeof...(AArgs),
"pybind11::init(class_factory, alias_factory): class and alias factories "
"must have identical argument signatures");
static_assert(all_of<std::is_same<CArgs, AArgs>...>::value,
"pybind11::init(class_factory, alias_factory): class and alias factories "
"must have identical argument signatures");
remove_reference_t<CFunc> class_factory;
remove_reference_t<AFunc> alias_factory;
factory(CFunc &&c, AFunc &&a)
: class_factory(std::forward<CFunc>(c)), alias_factory(std::forward<AFunc>(a)) { }
// The class factory is called when the `self` type passed to `__init__` is the direct
// class (i.e. not inherited), the alias factory when `self` is a Python-side subtype.
template <typename Class, typename... Extra>
void execute(Class &cl, const Extra&... extra) && {
static_assert(Class::has_alias, "The two-argument version of `py::init()` can "
"only be used if the class has an alias");
#if defined(PYBIND11_CPP14)
cl.def("__init__", [class_func = std::move(class_factory), alias_func = std::move(alias_factory)]
#else
auto &class_func = class_factory;
auto &alias_func = alias_factory;
cl.def("__init__", [class_func, alias_func]
#endif
(value_and_holder &v_h, CArgs... args) {
if (Py_TYPE(v_h.inst) == v_h.type->type)
// If the instance type equals the registered type we don't have inheritance, so
// don't need the alias and can construct using the class function:
construct<Class>(v_h, class_func(std::forward<CArgs>(args)...), false);
else
construct<Class>(v_h, alias_func(std::forward<CArgs>(args)...), true);
}, is_new_style_constructor(), extra...);
}
};
/// Set just the C++ state. Same as `__init__`.
template <typename Class, typename T>
void setstate(value_and_holder &v_h, T &&result, bool need_alias) {
construct<Class>(v_h, std::forward<T>(result), need_alias);
}
/// Set both the C++ and Python states
template <typename Class, typename T, typename O,
enable_if_t<std::is_convertible<O, handle>::value, int> = 0>
void setstate(value_and_holder &v_h, std::pair<T, O> &&result, bool need_alias) {
construct<Class>(v_h, std::move(result.first), need_alias);
setattr((PyObject *) v_h.inst, "__dict__", result.second);
}
/// Implementation for py::pickle(GetState, SetState)
template <typename Get, typename Set,
typename = function_signature_t<Get>, typename = function_signature_t<Set>>
struct pickle_factory;
template <typename Get, typename Set,
typename RetState, typename Self, typename NewInstance, typename ArgState>
struct pickle_factory<Get, Set, RetState(Self), NewInstance(ArgState)> {
static_assert(std::is_same<intrinsic_t<RetState>, intrinsic_t<ArgState>>::value,
"The type returned by `__getstate__` must be the same "
"as the argument accepted by `__setstate__`");
remove_reference_t<Get> get;
remove_reference_t<Set> set;
pickle_factory(Get get, Set set)
: get(std::forward<Get>(get)), set(std::forward<Set>(set)) { }
template <typename Class, typename... Extra>
void execute(Class &cl, const Extra &...extra) && {
cl.def("__getstate__", std::move(get));
#if defined(PYBIND11_CPP14)
cl.def("__setstate__", [func = std::move(set)]
#else
auto &func = set;
cl.def("__setstate__", [func]
#endif
(value_and_holder &v_h, ArgState state) {
setstate<Class>(v_h, func(std::forward<ArgState>(state)),
Py_TYPE(v_h.inst) != v_h.type->type);
}, is_new_style_constructor(), extra...);
}
};
NAMESPACE_END(initimpl)
NAMESPACE_END(detail)
NAMESPACE_END(pybind11)

View File

@@ -1,291 +0,0 @@
/*
pybind11/detail/internals.h: Internal data structure and related functions
Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "../pytypes.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
// Forward declarations
inline PyTypeObject *make_static_property_type();
inline PyTypeObject *make_default_metaclass();
inline PyObject *make_object_base_type(PyTypeObject *metaclass);
// The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in favor of the new
// Thread Specific Storage (TSS) API.
#if PY_VERSION_HEX >= 0x03070000
# define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr
# define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key))
# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value))
# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr)
#else
// Usually an int but a long on Cygwin64 with Python 3.x
# define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0
# define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key))
# if PY_MAJOR_VERSION < 3
# define PYBIND11_TLS_DELETE_VALUE(key) \
PyThread_delete_key_value(key)
# define PYBIND11_TLS_REPLACE_VALUE(key, value) \
do { \
PyThread_delete_key_value((key)); \
PyThread_set_key_value((key), (value)); \
} while (false)
# else
# define PYBIND11_TLS_DELETE_VALUE(key) \
PyThread_set_key_value((key), nullptr)
# define PYBIND11_TLS_REPLACE_VALUE(key, value) \
PyThread_set_key_value((key), (value))
# endif
#endif
// Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly
// other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module
// even when `A` is the same, non-hidden-visibility type (e.g. from a common include). Under
// libstdc++, this doesn't happen: equality and the type_index hash are based on the type name,
// which works. If not under a known-good stl, provide our own name-based hash and equality
// functions that use the type name.
#if defined(__GLIBCXX__)
inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; }
using type_hash = std::hash<std::type_index>;
using type_equal_to = std::equal_to<std::type_index>;
#else
inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) {
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
}
struct type_hash {
size_t operator()(const std::type_index &t) const {
size_t hash = 5381;
const char *ptr = t.name();
while (auto c = static_cast<unsigned char>(*ptr++))
hash = (hash * 33) ^ c;
return hash;
}
};
struct type_equal_to {
bool operator()(const std::type_index &lhs, const std::type_index &rhs) const {
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
}
};
#endif
template <typename value_type>
using type_map = std::unordered_map<std::type_index, value_type, type_hash, type_equal_to>;
struct overload_hash {
inline size_t operator()(const std::pair<const PyObject *, const char *>& v) const {
size_t value = std::hash<const void *>()(v.first);
value ^= std::hash<const void *>()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2);
return value;
}
};
/// Internal data structure used to track registered instances and types.
/// Whenever binary incompatible changes are made to this structure,
/// `PYBIND11_INTERNALS_VERSION` must be incremented.
struct internals {
type_map<type_info *> registered_types_cpp; // std::type_index -> pybind11's type information
std::unordered_map<PyTypeObject *, std::vector<type_info *>> registered_types_py; // PyTypeObject* -> base type_info(s)
std::unordered_multimap<const void *, instance*> registered_instances; // void * -> instance*
std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
type_map<std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
std::unordered_map<const PyObject *, std::vector<PyObject *>> patients;
std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across extensions
std::vector<PyObject *> loader_patient_stack; // Used by `loader_life_support`
std::forward_list<std::string> static_strings; // Stores the std::strings backing detail::c_str()
PyTypeObject *static_property_type;
PyTypeObject *default_metaclass;
PyObject *instance_base;
#if defined(WITH_THREAD)
PYBIND11_TLS_KEY_INIT(tstate);
PyInterpreterState *istate = nullptr;
#endif
};
/// Additional type information which does not fit into the PyTypeObject.
/// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`.
struct type_info {
PyTypeObject *type;
const std::type_info *cpptype;
size_t type_size, type_align, holder_size_in_ptrs;
void *(*operator_new)(size_t);
void (*init_instance)(instance *, const void *);
void (*dealloc)(value_and_holder &v_h);
std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
std::vector<std::pair<const std::type_info *, void *(*)(void *)>> implicit_casts;
std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
void *get_buffer_data = nullptr;
void *(*module_local_load)(PyObject *, const type_info *) = nullptr;
/* A simple type never occurs as a (direct or indirect) parent
* of a class that makes use of multiple inheritance */
bool simple_type : 1;
/* True if there is no multiple inheritance in this type's inheritance tree */
bool simple_ancestors : 1;
/* for base vs derived holder_type checks */
bool default_holder : 1;
/* true if this is a type registered with py::module_local */
bool module_local : 1;
};
/// Tracks the `internals` and `type_info` ABI version independent of the main library version
#define PYBIND11_INTERNALS_VERSION 3
#if defined(_DEBUG)
# define PYBIND11_BUILD_TYPE "_debug"
#else
# define PYBIND11_BUILD_TYPE ""
#endif
#if defined(WITH_THREAD)
# define PYBIND11_INTERNALS_KIND ""
#else
# define PYBIND11_INTERNALS_KIND "_without_thread"
#endif
#define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
#define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
/// Each module locally stores a pointer to the `internals` data. The data
/// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`.
inline internals **&get_internals_pp() {
static internals **internals_pp = nullptr;
return internals_pp;
}
/// Return a reference to the current `internals` data
PYBIND11_NOINLINE inline internals &get_internals() {
auto **&internals_pp = get_internals_pp();
if (internals_pp && *internals_pp)
return **internals_pp;
constexpr auto *id = PYBIND11_INTERNALS_ID;
auto builtins = handle(PyEval_GetBuiltins());
if (builtins.contains(id) && isinstance<capsule>(builtins[id])) {
internals_pp = static_cast<internals **>(capsule(builtins[id]));
// We loaded builtins through python's builtins, which means that our `error_already_set`
// and `builtin_exception` may be different local classes than the ones set up in the
// initial exception translator, below, so add another for our local exception classes.
//
// libstdc++ doesn't require this (types there are identified only by name)
#if !defined(__GLIBCXX__)
(*internals_pp)->registered_exception_translators.push_front(
[](std::exception_ptr p) -> void {
try {
if (p) std::rethrow_exception(p);
} catch (error_already_set &e) { e.restore(); return;
} catch (const builtin_exception &e) { e.set_error(); return;
}
}
);
#endif
} else {
if (!internals_pp) internals_pp = new internals*();
auto *&internals_ptr = *internals_pp;
internals_ptr = new internals();
#if defined(WITH_THREAD)
PyEval_InitThreads();
PyThreadState *tstate = PyThreadState_Get();
#if PY_VERSION_HEX >= 0x03070000
internals_ptr->tstate = PyThread_tss_alloc();
if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate))
pybind11_fail("get_internals: could not successfully initialize the TSS key!");
PyThread_tss_set(internals_ptr->tstate, tstate);
#else
internals_ptr->tstate = PyThread_create_key();
if (internals_ptr->tstate == -1)
pybind11_fail("get_internals: could not successfully initialize the TLS key!");
PyThread_set_key_value(internals_ptr->tstate, tstate);
#endif
internals_ptr->istate = tstate->interp;
#endif
builtins[id] = capsule(internals_pp);
internals_ptr->registered_exception_translators.push_front(
[](std::exception_ptr p) -> void {
try {
if (p) std::rethrow_exception(p);
} catch (error_already_set &e) { e.restore(); return;
} catch (const builtin_exception &e) { e.set_error(); return;
} catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return;
} catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
} catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
} catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
} catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return;
} catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
} catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return;
} catch (...) {
PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!");
return;
}
}
);
internals_ptr->static_property_type = make_static_property_type();
internals_ptr->default_metaclass = make_default_metaclass();
internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass);
}
return **internals_pp;
}
/// Works like `internals.registered_types_cpp`, but for module-local registered types:
inline type_map<type_info *> &registered_local_types_cpp() {
static type_map<type_info *> locals{};
return locals;
}
/// Constructs a std::string with the given arguments, stores it in `internals`, and returns its
/// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only
/// cleared when the program exits or after interpreter shutdown (when embedding), and so are
/// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name).
template <typename... Args>
const char *c_str(Args &&...args) {
auto &strings = get_internals().static_strings;
strings.emplace_front(std::forward<Args>(args)...);
return strings.front().c_str();
}
NAMESPACE_END(detail)
/// Returns a named pointer that is shared among all extension modules (using the same
/// pybind11 version) running in the current interpreter. Names starting with underscores
/// are reserved for internal usage. Returns `nullptr` if no matching entry was found.
inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) {
auto &internals = detail::get_internals();
auto it = internals.shared_data.find(name);
return it != internals.shared_data.end() ? it->second : nullptr;
}
/// Set the shared data that can be later recovered by `get_shared_data()`.
inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) {
detail::get_internals().shared_data[name] = data;
return data;
}
/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if
/// such entry exists. Otherwise, a new object of default-constructible type `T` is
/// added to the shared data under the given name and a reference to it is returned.
template<typename T>
T &get_or_create_shared_data(const std::string &name) {
auto &internals = detail::get_internals();
auto it = internals.shared_data.find(name);
T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr);
if (!ptr) {
ptr = new T();
internals.shared_data[name] = ptr;
}
return *ptr;
}
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,55 +0,0 @@
/*
pybind11/detail/typeid.h: Compiler-independent access to type identifiers
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include <cstdio>
#include <cstdlib>
#if defined(__GNUG__)
#include <cxxabi.h>
#endif
#include "common.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
/// Erase all occurrences of a substring
inline void erase_all(std::string &string, const std::string &search) {
for (size_t pos = 0;;) {
pos = string.find(search, pos);
if (pos == std::string::npos) break;
string.erase(pos, search.length());
}
}
PYBIND11_NOINLINE inline void clean_type_id(std::string &name) {
#if defined(__GNUG__)
int status = 0;
std::unique_ptr<char, void (*)(void *)> res {
abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free };
if (status == 0)
name = res.get();
#else
detail::erase_all(name, "class ");
detail::erase_all(name, "struct ");
detail::erase_all(name, "enum ");
#endif
detail::erase_all(name, "pybind11::");
}
NAMESPACE_END(detail)
/// Return a string representation of a C++ type
template <typename T> static std::string type_id() {
std::string name(typeid(T).name());
detail::clean_type_id(name);
return name;
}
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,607 +0,0 @@
/*
pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "numpy.h"
#if defined(__INTEL_COMPILER)
# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
#elif defined(__GNUG__) || defined(__clang__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wconversion"
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
# ifdef __clang__
// Eigen generates a bunch of implicit-copy-constructor-is-deprecated warnings with -Wdeprecated
// under Clang, so disable that warning here:
# pragma GCC diagnostic ignored "-Wdeprecated"
# endif
# if __GNUC__ >= 7
# pragma GCC diagnostic ignored "-Wint-in-bool-context"
# endif
#endif
#if defined(_MSC_VER)
# pragma warning(push)
# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
# pragma warning(disable: 4996) // warning C4996: std::unary_negate is deprecated in C++17
#endif
#include <Eigen/Core>
#include <Eigen/SparseCore>
// Eigen prior to 3.2.7 doesn't have proper move constructors--but worse, some classes get implicit
// move constructors that break things. We could detect this an explicitly copy, but an extra copy
// of matrices seems highly undesirable.
static_assert(EIGEN_VERSION_AT_LEAST(3,2,7), "Eigen support in pybind11 requires Eigen >= 3.2.7");
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
// Provide a convenience alias for easier pass-by-ref usage with fully dynamic strides:
using EigenDStride = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
template <typename MatrixType> using EigenDRef = Eigen::Ref<MatrixType, 0, EigenDStride>;
template <typename MatrixType> using EigenDMap = Eigen::Map<MatrixType, 0, EigenDStride>;
NAMESPACE_BEGIN(detail)
#if EIGEN_VERSION_AT_LEAST(3,3,0)
using EigenIndex = Eigen::Index;
#else
using EigenIndex = EIGEN_DEFAULT_DENSE_INDEX_TYPE;
#endif
// Matches Eigen::Map, Eigen::Ref, blocks, etc:
template <typename T> using is_eigen_dense_map = all_of<is_template_base_of<Eigen::DenseBase, T>, std::is_base_of<Eigen::MapBase<T, Eigen::ReadOnlyAccessors>, T>>;
template <typename T> using is_eigen_mutable_map = std::is_base_of<Eigen::MapBase<T, Eigen::WriteAccessors>, T>;
template <typename T> using is_eigen_dense_plain = all_of<negation<is_eigen_dense_map<T>>, is_template_base_of<Eigen::PlainObjectBase, T>>;
template <typename T> using is_eigen_sparse = is_template_base_of<Eigen::SparseMatrixBase, T>;
// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above. This
// basically covers anything that can be assigned to a dense matrix but that don't have a typical
// matrix data layout that can be copied from their .data(). For example, DiagonalMatrix and
// SelfAdjointView fall into this category.
template <typename T> using is_eigen_other = all_of<
is_template_base_of<Eigen::EigenBase, T>,
negation<any_of<is_eigen_dense_map<T>, is_eigen_dense_plain<T>, is_eigen_sparse<T>>>
>;
// Captures numpy/eigen conformability status (returned by EigenProps::conformable()):
template <bool EigenRowMajor> struct EigenConformable {
bool conformable = false;
EigenIndex rows = 0, cols = 0;
EigenDStride stride{0, 0}; // Only valid if negativestrides is false!
bool negativestrides = false; // If true, do not use stride!
EigenConformable(bool fits = false) : conformable{fits} {}
// Matrix type:
EigenConformable(EigenIndex r, EigenIndex c,
EigenIndex rstride, EigenIndex cstride) :
conformable{true}, rows{r}, cols{c} {
// TODO: when Eigen bug #747 is fixed, remove the tests for non-negativity. http://eigen.tuxfamily.org/bz/show_bug.cgi?id=747
if (rstride < 0 || cstride < 0) {
negativestrides = true;
} else {
stride = {EigenRowMajor ? rstride : cstride /* outer stride */,
EigenRowMajor ? cstride : rstride /* inner stride */ };
}
}
// Vector type:
EigenConformable(EigenIndex r, EigenIndex c, EigenIndex stride)
: EigenConformable(r, c, r == 1 ? c*stride : stride, c == 1 ? r : r*stride) {}
template <typename props> bool stride_compatible() const {
// To have compatible strides, we need (on both dimensions) one of fully dynamic strides,
// matching strides, or a dimension size of 1 (in which case the stride value is irrelevant)
return
!negativestrides &&
(props::inner_stride == Eigen::Dynamic || props::inner_stride == stride.inner() ||
(EigenRowMajor ? cols : rows) == 1) &&
(props::outer_stride == Eigen::Dynamic || props::outer_stride == stride.outer() ||
(EigenRowMajor ? rows : cols) == 1);
}
operator bool() const { return conformable; }
};
template <typename Type> struct eigen_extract_stride { using type = Type; };
template <typename PlainObjectType, int MapOptions, typename StrideType>
struct eigen_extract_stride<Eigen::Map<PlainObjectType, MapOptions, StrideType>> { using type = StrideType; };
template <typename PlainObjectType, int Options, typename StrideType>
struct eigen_extract_stride<Eigen::Ref<PlainObjectType, Options, StrideType>> { using type = StrideType; };
// Helper struct for extracting information from an Eigen type
template <typename Type_> struct EigenProps {
using Type = Type_;
using Scalar = typename Type::Scalar;
using StrideType = typename eigen_extract_stride<Type>::type;
static constexpr EigenIndex
rows = Type::RowsAtCompileTime,
cols = Type::ColsAtCompileTime,
size = Type::SizeAtCompileTime;
static constexpr bool
row_major = Type::IsRowMajor,
vector = Type::IsVectorAtCompileTime, // At least one dimension has fixed size 1
fixed_rows = rows != Eigen::Dynamic,
fixed_cols = cols != Eigen::Dynamic,
fixed = size != Eigen::Dynamic, // Fully-fixed size
dynamic = !fixed_rows && !fixed_cols; // Fully-dynamic size
template <EigenIndex i, EigenIndex ifzero> using if_zero = std::integral_constant<EigenIndex, i == 0 ? ifzero : i>;
static constexpr EigenIndex inner_stride = if_zero<StrideType::InnerStrideAtCompileTime, 1>::value,
outer_stride = if_zero<StrideType::OuterStrideAtCompileTime,
vector ? size : row_major ? cols : rows>::value;
static constexpr bool dynamic_stride = inner_stride == Eigen::Dynamic && outer_stride == Eigen::Dynamic;
static constexpr bool requires_row_major = !dynamic_stride && !vector && (row_major ? inner_stride : outer_stride) == 1;
static constexpr bool requires_col_major = !dynamic_stride && !vector && (row_major ? outer_stride : inner_stride) == 1;
// Takes an input array and determines whether we can make it fit into the Eigen type. If
// the array is a vector, we attempt to fit it into either an Eigen 1xN or Nx1 vector
// (preferring the latter if it will fit in either, i.e. for a fully dynamic matrix type).
static EigenConformable<row_major> conformable(const array &a) {
const auto dims = a.ndim();
if (dims < 1 || dims > 2)
return false;
if (dims == 2) { // Matrix type: require exact match (or dynamic)
EigenIndex
np_rows = a.shape(0),
np_cols = a.shape(1),
np_rstride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar)),
np_cstride = a.strides(1) / static_cast<ssize_t>(sizeof(Scalar));
if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols))
return false;
return {np_rows, np_cols, np_rstride, np_cstride};
}
// Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever
// is used, we want the (single) numpy stride value.
const EigenIndex n = a.shape(0),
stride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar));
if (vector) { // Eigen type is a compile-time vector
if (fixed && size != n)
return false; // Vector size mismatch
return {rows == 1 ? 1 : n, cols == 1 ? 1 : n, stride};
}
else if (fixed) {
// The type has a fixed size, but is not a vector: abort
return false;
}
else if (fixed_cols) {
// Since this isn't a vector, cols must be != 1. We allow this only if it exactly
// equals the number of elements (rows is Dynamic, and so 1 row is allowed).
if (cols != n) return false;
return {1, n, stride};
}
else {
// Otherwise it's either fully dynamic, or column dynamic; both become a column vector
if (fixed_rows && rows != n) return false;
return {n, 1, stride};
}
}
static constexpr bool show_writeable = is_eigen_dense_map<Type>::value && is_eigen_mutable_map<Type>::value;
static constexpr bool show_order = is_eigen_dense_map<Type>::value;
static constexpr bool show_c_contiguous = show_order && requires_row_major;
static constexpr bool show_f_contiguous = !show_c_contiguous && show_order && requires_col_major;
static constexpr auto descriptor =
_("numpy.ndarray[") + npy_format_descriptor<Scalar>::name +
_("[") + _<fixed_rows>(_<(size_t) rows>(), _("m")) +
_(", ") + _<fixed_cols>(_<(size_t) cols>(), _("n")) +
_("]") +
// For a reference type (e.g. Ref<MatrixXd>) we have other constraints that might need to be
// satisfied: writeable=True (for a mutable reference), and, depending on the map's stride
// options, possibly f_contiguous or c_contiguous. We include them in the descriptor output
// to provide some hint as to why a TypeError is occurring (otherwise it can be confusing to
// see that a function accepts a 'numpy.ndarray[float64[3,2]]' and an error message that you
// *gave* a numpy.ndarray of the right type and dimensions.
_<show_writeable>(", flags.writeable", "") +
_<show_c_contiguous>(", flags.c_contiguous", "") +
_<show_f_contiguous>(", flags.f_contiguous", "") +
_("]");
};
// Casts an Eigen type to numpy array. If given a base, the numpy array references the src data,
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
constexpr ssize_t elem_size = sizeof(typename props::Scalar);
array a;
if (props::vector)
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
else
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
src.data(), base);
if (!writeable)
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
return a.release();
}
// Takes an lvalue ref to some Eigen type and a (python) base object, creating a numpy array that
// reference the Eigen object's data with `base` as the python-registered base class (if omitted,
// the base will be set to None, and lifetime management is up to the caller). The numpy array is
// non-writeable if the given type is const.
template <typename props, typename Type>
handle eigen_ref_array(Type &src, handle parent = none()) {
// none here is to get past array's should-we-copy detection, which currently always
// copies when there is no base. Setting the base to None should be harmless.
return eigen_array_cast<props>(src, parent, !std::is_const<Type>::value);
}
// Takes a pointer to some dense, plain Eigen type, builds a capsule around it, then returns a numpy
// array that references the encapsulated data with a python-side reference to the capsule to tie
// its destruction to that of any dependent python objects. Const-ness is determined by whether or
// not the Type of the pointer given is const.
template <typename props, typename Type, typename = enable_if_t<is_eigen_dense_plain<Type>::value>>
handle eigen_encapsulate(Type *src) {
capsule base(src, [](void *o) { delete static_cast<Type *>(o); });
return eigen_ref_array<props>(*src, base);
}
// Type caster for regular, dense matrix types (e.g. MatrixXd), but not maps/refs/etc. of dense
// types.
template<typename Type>
struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
using Scalar = typename Type::Scalar;
using props = EigenProps<Type>;
bool load(handle src, bool convert) {
// If we're in no-convert mode, only load if given an array of the correct type
if (!convert && !isinstance<array_t<Scalar>>(src))
return false;
// Coerce into an array, but don't do type conversion yet; the copy below handles it.
auto buf = array::ensure(src);
if (!buf)
return false;
auto dims = buf.ndim();
if (dims < 1 || dims > 2)
return false;
auto fits = props::conformable(buf);
if (!fits)
return false;
// Allocate the new type, then build a numpy reference into it
value = Type(fits.rows, fits.cols);
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
if (dims == 1) ref = ref.squeeze();
else if (ref.ndim() == 1) buf = buf.squeeze();
int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
if (result < 0) { // Copy failed!
PyErr_Clear();
return false;
}
return true;
}
private:
// Cast implementation
template <typename CType>
static handle cast_impl(CType *src, return_value_policy policy, handle parent) {
switch (policy) {
case return_value_policy::take_ownership:
case return_value_policy::automatic:
return eigen_encapsulate<props>(src);
case return_value_policy::move:
return eigen_encapsulate<props>(new CType(std::move(*src)));
case return_value_policy::copy:
return eigen_array_cast<props>(*src);
case return_value_policy::reference:
case return_value_policy::automatic_reference:
return eigen_ref_array<props>(*src);
case return_value_policy::reference_internal:
return eigen_ref_array<props>(*src, parent);
default:
throw cast_error("unhandled return_value_policy: should not happen!");
};
}
public:
// Normal returned non-reference, non-const value:
static handle cast(Type &&src, return_value_policy /* policy */, handle parent) {
return cast_impl(&src, return_value_policy::move, parent);
}
// If you return a non-reference const, we mark the numpy array readonly:
static handle cast(const Type &&src, return_value_policy /* policy */, handle parent) {
return cast_impl(&src, return_value_policy::move, parent);
}
// lvalue reference return; default (automatic) becomes copy
static handle cast(Type &src, return_value_policy policy, handle parent) {
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
policy = return_value_policy::copy;
return cast_impl(&src, policy, parent);
}
// const lvalue reference return; default (automatic) becomes copy
static handle cast(const Type &src, return_value_policy policy, handle parent) {
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
policy = return_value_policy::copy;
return cast(&src, policy, parent);
}
// non-const pointer return
static handle cast(Type *src, return_value_policy policy, handle parent) {
return cast_impl(src, policy, parent);
}
// const pointer return
static handle cast(const Type *src, return_value_policy policy, handle parent) {
return cast_impl(src, policy, parent);
}
static constexpr auto name = props::descriptor;
operator Type*() { return &value; }
operator Type&() { return value; }
operator Type&&() && { return std::move(value); }
template <typename T> using cast_op_type = movable_cast_op_type<T>;
private:
Type value;
};
// Base class for casting reference/map/block/etc. objects back to python.
template <typename MapType> struct eigen_map_caster {
private:
using props = EigenProps<MapType>;
public:
// Directly referencing a ref/map's data is a bit dangerous (whatever the map/ref points to has
// to stay around), but we'll allow it under the assumption that you know what you're doing (and
// have an appropriate keep_alive in place). We return a numpy array pointing directly at the
// ref's data (The numpy array ends up read-only if the ref was to a const matrix type.) Note
// that this means you need to ensure you don't destroy the object in some other way (e.g. with
// an appropriate keep_alive, or with a reference to a statically allocated matrix).
static handle cast(const MapType &src, return_value_policy policy, handle parent) {
switch (policy) {
case return_value_policy::copy:
return eigen_array_cast<props>(src);
case return_value_policy::reference_internal:
return eigen_array_cast<props>(src, parent, is_eigen_mutable_map<MapType>::value);
case return_value_policy::reference:
case return_value_policy::automatic:
case return_value_policy::automatic_reference:
return eigen_array_cast<props>(src, none(), is_eigen_mutable_map<MapType>::value);
default:
// move, take_ownership don't make any sense for a ref/map:
pybind11_fail("Invalid return_value_policy for Eigen Map/Ref/Block type");
}
}
static constexpr auto name = props::descriptor;
// Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return
// types but not bound arguments). We still provide them (with an explicitly delete) so that
// you end up here if you try anyway.
bool load(handle, bool) = delete;
operator MapType() = delete;
template <typename> using cast_op_type = MapType;
};
// We can return any map-like object (but can only load Refs, specialized next):
template <typename Type> struct type_caster<Type, enable_if_t<is_eigen_dense_map<Type>::value>>
: eigen_map_caster<Type> {};
// Loader for Ref<...> arguments. See the documentation for info on how to make this work without
// copying (it requires some extra effort in many cases).
template <typename PlainObjectType, typename StrideType>
struct type_caster<
Eigen::Ref<PlainObjectType, 0, StrideType>,
enable_if_t<is_eigen_dense_map<Eigen::Ref<PlainObjectType, 0, StrideType>>::value>
> : public eigen_map_caster<Eigen::Ref<PlainObjectType, 0, StrideType>> {
private:
using Type = Eigen::Ref<PlainObjectType, 0, StrideType>;
using props = EigenProps<Type>;
using Scalar = typename props::Scalar;
using MapType = Eigen::Map<PlainObjectType, 0, StrideType>;
using Array = array_t<Scalar, array::forcecast |
((props::row_major ? props::inner_stride : props::outer_stride) == 1 ? array::c_style :
(props::row_major ? props::outer_stride : props::inner_stride) == 1 ? array::f_style : 0)>;
static constexpr bool need_writeable = is_eigen_mutable_map<Type>::value;
// Delay construction (these have no default constructor)
std::unique_ptr<MapType> map;
std::unique_ptr<Type> ref;
// Our array. When possible, this is just a numpy array pointing to the source data, but
// sometimes we can't avoid copying (e.g. input is not a numpy array at all, has an incompatible
// layout, or is an array of a type that needs to be converted). Using a numpy temporary
// (rather than an Eigen temporary) saves an extra copy when we need both type conversion and
// storage order conversion. (Note that we refuse to use this temporary copy when loading an
// argument for a Ref<M> with M non-const, i.e. a read-write reference).
Array copy_or_ref;
public:
bool load(handle src, bool convert) {
// First check whether what we have is already an array of the right type. If not, we can't
// avoid a copy (because the copy is also going to do type conversion).
bool need_copy = !isinstance<Array>(src);
EigenConformable<props::row_major> fits;
if (!need_copy) {
// We don't need a converting copy, but we also need to check whether the strides are
// compatible with the Ref's stride requirements
Array aref = reinterpret_borrow<Array>(src);
if (aref && (!need_writeable || aref.writeable())) {
fits = props::conformable(aref);
if (!fits) return false; // Incompatible dimensions
if (!fits.template stride_compatible<props>())
need_copy = true;
else
copy_or_ref = std::move(aref);
}
else {
need_copy = true;
}
}
if (need_copy) {
// We need to copy: If we need a mutable reference, or we're not supposed to convert
// (either because we're in the no-convert overload pass, or because we're explicitly
// instructed not to copy (via `py::arg().noconvert()`) we have to fail loading.
if (!convert || need_writeable) return false;
Array copy = Array::ensure(src);
if (!copy) return false;
fits = props::conformable(copy);
if (!fits || !fits.template stride_compatible<props>())
return false;
copy_or_ref = std::move(copy);
loader_life_support::add_patient(copy_or_ref);
}
ref.reset();
map.reset(new MapType(data(copy_or_ref), fits.rows, fits.cols, make_stride(fits.stride.outer(), fits.stride.inner())));
ref.reset(new Type(*map));
return true;
}
operator Type*() { return ref.get(); }
operator Type&() { return *ref; }
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
private:
template <typename T = Type, enable_if_t<is_eigen_mutable_map<T>::value, int> = 0>
Scalar *data(Array &a) { return a.mutable_data(); }
template <typename T = Type, enable_if_t<!is_eigen_mutable_map<T>::value, int> = 0>
const Scalar *data(Array &a) { return a.data(); }
// Attempt to figure out a constructor of `Stride` that will work.
// If both strides are fixed, use a default constructor:
template <typename S> using stride_ctor_default = bool_constant<
S::InnerStrideAtCompileTime != Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic &&
std::is_default_constructible<S>::value>;
// Otherwise, if there is a two-index constructor, assume it is (outer,inner) like
// Eigen::Stride, and use it:
template <typename S> using stride_ctor_dual = bool_constant<
!stride_ctor_default<S>::value && std::is_constructible<S, EigenIndex, EigenIndex>::value>;
// Otherwise, if there is a one-index constructor, and just one of the strides is dynamic, use
// it (passing whichever stride is dynamic).
template <typename S> using stride_ctor_outer = bool_constant<
!any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value &&
S::OuterStrideAtCompileTime == Eigen::Dynamic && S::InnerStrideAtCompileTime != Eigen::Dynamic &&
std::is_constructible<S, EigenIndex>::value>;
template <typename S> using stride_ctor_inner = bool_constant<
!any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value &&
S::InnerStrideAtCompileTime == Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic &&
std::is_constructible<S, EigenIndex>::value>;
template <typename S = StrideType, enable_if_t<stride_ctor_default<S>::value, int> = 0>
static S make_stride(EigenIndex, EigenIndex) { return S(); }
template <typename S = StrideType, enable_if_t<stride_ctor_dual<S>::value, int> = 0>
static S make_stride(EigenIndex outer, EigenIndex inner) { return S(outer, inner); }
template <typename S = StrideType, enable_if_t<stride_ctor_outer<S>::value, int> = 0>
static S make_stride(EigenIndex outer, EigenIndex) { return S(outer); }
template <typename S = StrideType, enable_if_t<stride_ctor_inner<S>::value, int> = 0>
static S make_stride(EigenIndex, EigenIndex inner) { return S(inner); }
};
// type_caster for special matrix types (e.g. DiagonalMatrix), which are EigenBase, but not
// EigenDense (i.e. they don't have a data(), at least not with the usual matrix layout).
// load() is not supported, but we can cast them into the python domain by first copying to a
// regular Eigen::Matrix, then casting that.
template <typename Type>
struct type_caster<Type, enable_if_t<is_eigen_other<Type>::value>> {
protected:
using Matrix = Eigen::Matrix<typename Type::Scalar, Type::RowsAtCompileTime, Type::ColsAtCompileTime>;
using props = EigenProps<Matrix>;
public:
static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
handle h = eigen_encapsulate<props>(new Matrix(src));
return h;
}
static handle cast(const Type *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); }
static constexpr auto name = props::descriptor;
// Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return
// types but not bound arguments). We still provide them (with an explicitly delete) so that
// you end up here if you try anyway.
bool load(handle, bool) = delete;
operator Type() = delete;
template <typename> using cast_op_type = Type;
};
template<typename Type>
struct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> {
typedef typename Type::Scalar Scalar;
typedef remove_reference_t<decltype(*std::declval<Type>().outerIndexPtr())> StorageIndex;
typedef typename Type::Index Index;
static constexpr bool rowMajor = Type::IsRowMajor;
bool load(handle src, bool) {
if (!src)
return false;
auto obj = reinterpret_borrow<object>(src);
object sparse_module = module::import("scipy.sparse");
object matrix_type = sparse_module.attr(
rowMajor ? "csr_matrix" : "csc_matrix");
if (!obj.get_type().is(matrix_type)) {
try {
obj = matrix_type(obj);
} catch (const error_already_set &) {
return false;
}
}
auto values = array_t<Scalar>((object) obj.attr("data"));
auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
auto nnz = obj.attr("nnz").cast<Index>();
if (!values || !innerIndices || !outerIndices)
return false;
value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());
return true;
}
static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
const_cast<Type&>(src).makeCompressed();
object matrix_type = module::import("scipy.sparse").attr(
rowMajor ? "csr_matrix" : "csc_matrix");
array data(src.nonZeros(), src.valuePtr());
array outerIndices((rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr());
array innerIndices(src.nonZeros(), src.innerIndexPtr());
return matrix_type(
std::make_tuple(data, innerIndices, outerIndices),
std::make_pair(src.rows(), src.cols())
).release();
}
PYBIND11_TYPE_CASTER(Type, _<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[")
+ npy_format_descriptor<Scalar>::name + _("]"));
};
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)
#if defined(__GNUG__) || defined(__clang__)
# pragma GCC diagnostic pop
#elif defined(_MSC_VER)
# pragma warning(pop)
#endif

View File

@@ -1,200 +0,0 @@
/*
pybind11/embed.h: Support for embedding the interpreter
Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#include "eval.h"
#if defined(PYPY_VERSION)
# error Embedding the interpreter is not supported with PyPy
#endif
#if PY_MAJOR_VERSION >= 3
# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
extern "C" PyObject *pybind11_init_impl_##name() { \
return pybind11_init_wrapper_##name(); \
}
#else
# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
extern "C" void pybind11_init_impl_##name() { \
pybind11_init_wrapper_##name(); \
}
#endif
/** \rst
Add a new module to the table of builtins for the interpreter. Must be
defined in global scope. The first macro parameter is the name of the
module (without quotes). The second parameter is the variable which will
be used as the interface to add functions and classes to the module.
.. code-block:: cpp
PYBIND11_EMBEDDED_MODULE(example, m) {
// ... initialize functions and classes here
m.def("foo", []() {
return "Hello, World!";
});
}
\endrst */
#define PYBIND11_EMBEDDED_MODULE(name, variable) \
static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \
auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
try { \
PYBIND11_CONCAT(pybind11_init_, name)(m); \
return m.ptr(); \
} catch (pybind11::error_already_set &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} catch (const std::exception &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} \
} \
PYBIND11_EMBEDDED_MODULE_IMPL(name) \
pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \
PYBIND11_CONCAT(pybind11_init_impl_, name)); \
void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
/// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks.
struct embedded_module {
#if PY_MAJOR_VERSION >= 3
using init_t = PyObject *(*)();
#else
using init_t = void (*)();
#endif
embedded_module(const char *name, init_t init) {
if (Py_IsInitialized())
pybind11_fail("Can't add new modules after the interpreter has been initialized");
auto result = PyImport_AppendInittab(name, init);
if (result == -1)
pybind11_fail("Insufficient memory to add a new module");
}
};
NAMESPACE_END(detail)
/** \rst
Initialize the Python interpreter. No other pybind11 or CPython API functions can be
called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The
optional parameter can be used to skip the registration of signal handlers (see the
`Python documentation`_ for details). Calling this function again after the interpreter
has already been initialized is a fatal error.
If initializing the Python interpreter fails, then the program is terminated. (This
is controlled by the CPython runtime and is an exception to pybind11's normal behavior
of throwing exceptions on errors.)
.. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx
\endrst */
inline void initialize_interpreter(bool init_signal_handlers = true) {
if (Py_IsInitialized())
pybind11_fail("The interpreter is already running");
Py_InitializeEx(init_signal_handlers ? 1 : 0);
// Make .py files in the working directory available by default
module::import("sys").attr("path").cast<list>().append(".");
}
/** \rst
Shut down the Python interpreter. No pybind11 or CPython API functions can be called
after this. In addition, pybind11 objects must not outlive the interpreter:
.. code-block:: cpp
{ // BAD
py::initialize_interpreter();
auto hello = py::str("Hello, World!");
py::finalize_interpreter();
} // <-- BOOM, hello's destructor is called after interpreter shutdown
{ // GOOD
py::initialize_interpreter();
{ // scoped
auto hello = py::str("Hello, World!");
} // <-- OK, hello is cleaned up properly
py::finalize_interpreter();
}
{ // BETTER
py::scoped_interpreter guard{};
auto hello = py::str("Hello, World!");
}
.. warning::
The interpreter can be restarted by calling `initialize_interpreter` again.
Modules created using pybind11 can be safely re-initialized. However, Python
itself cannot completely unload binary extension modules and there are several
caveats with regard to interpreter restarting. All the details can be found
in the CPython documentation. In short, not all interpreter memory may be
freed, either due to reference cycles or user-created global data.
\endrst */
inline void finalize_interpreter() {
handle builtins(PyEval_GetBuiltins());
const char *id = PYBIND11_INTERNALS_ID;
// Get the internals pointer (without creating it if it doesn't exist). It's possible for the
// internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()`
// during destruction), so we get the pointer-pointer here and check it after Py_Finalize().
detail::internals **internals_ptr_ptr = detail::get_internals_pp();
// It could also be stashed in builtins, so look there too:
if (builtins.contains(id) && isinstance<capsule>(builtins[id]))
internals_ptr_ptr = capsule(builtins[id]);
Py_Finalize();
if (internals_ptr_ptr) {
delete *internals_ptr_ptr;
*internals_ptr_ptr = nullptr;
}
}
/** \rst
Scope guard version of `initialize_interpreter` and `finalize_interpreter`.
This a move-only guard and only a single instance can exist.
.. code-block:: cpp
#include <pybind11/embed.h>
int main() {
py::scoped_interpreter guard{};
py::print(Hello, World!);
} // <-- interpreter shutdown
\endrst */
class scoped_interpreter {
public:
scoped_interpreter(bool init_signal_handlers = true) {
initialize_interpreter(init_signal_handlers);
}
scoped_interpreter(const scoped_interpreter &) = delete;
scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; }
scoped_interpreter &operator=(const scoped_interpreter &) = delete;
scoped_interpreter &operator=(scoped_interpreter &&) = delete;
~scoped_interpreter() {
if (is_valid)
finalize_interpreter();
}
private:
bool is_valid = true;
};
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,117 +0,0 @@
/*
pybind11/exec.h: Support for evaluating Python expressions and statements
from strings and files
Copyright (c) 2016 Klemens Morgenstern <klemens.morgenstern@ed-chemnitz.de> and
Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
enum eval_mode {
/// Evaluate a string containing an isolated expression
eval_expr,
/// Evaluate a string containing a single statement. Returns \c none
eval_single_statement,
/// Evaluate a string containing a sequence of statement. Returns \c none
eval_statements
};
template <eval_mode mode = eval_expr>
object eval(str expr, object global = globals(), object local = object()) {
if (!local)
local = global;
/* PyRun_String does not accept a PyObject / encoding specifier,
this seems to be the only alternative */
std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr;
int start;
switch (mode) {
case eval_expr: start = Py_eval_input; break;
case eval_single_statement: start = Py_single_input; break;
case eval_statements: start = Py_file_input; break;
default: pybind11_fail("invalid evaluation mode");
}
PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr());
if (!result)
throw error_already_set();
return reinterpret_steal<object>(result);
}
template <eval_mode mode = eval_expr, size_t N>
object eval(const char (&s)[N], object global = globals(), object local = object()) {
/* Support raw string literals by removing common leading whitespace */
auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s))
: str(s);
return eval<mode>(expr, global, local);
}
inline void exec(str expr, object global = globals(), object local = object()) {
eval<eval_statements>(expr, global, local);
}
template <size_t N>
void exec(const char (&s)[N], object global = globals(), object local = object()) {
eval<eval_statements>(s, global, local);
}
template <eval_mode mode = eval_statements>
object eval_file(str fname, object global = globals(), object local = object()) {
if (!local)
local = global;
int start;
switch (mode) {
case eval_expr: start = Py_eval_input; break;
case eval_single_statement: start = Py_single_input; break;
case eval_statements: start = Py_file_input; break;
default: pybind11_fail("invalid evaluation mode");
}
int closeFile = 1;
std::string fname_str = (std::string) fname;
#if PY_VERSION_HEX >= 0x03040000
FILE *f = _Py_fopen_obj(fname.ptr(), "r");
#elif PY_VERSION_HEX >= 0x03000000
FILE *f = _Py_fopen(fname.ptr(), "r");
#else
/* No unicode support in open() :( */
auto fobj = reinterpret_steal<object>(PyFile_FromString(
const_cast<char *>(fname_str.c_str()),
const_cast<char*>("r")));
FILE *f = nullptr;
if (fobj)
f = PyFile_AsFile(fobj.ptr());
closeFile = 0;
#endif
if (!f) {
PyErr_Clear();
pybind11_fail("File \"" + fname_str + "\" could not be opened!");
}
#if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION)
PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(),
local.ptr());
(void) closeFile;
#else
PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(),
local.ptr(), closeFile);
#endif
if (!result)
throw error_already_set();
return reinterpret_steal<object>(result);
}
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,108 +0,0 @@
/*
pybind11/functional.h: std::function<> support
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#include <functional>
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
template <typename Return, typename... Args>
struct type_caster<std::function<Return(Args...)>> {
using type = std::function<Return(Args...)>;
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
using function_type = Return (*) (Args...);
public:
bool load(handle src, bool convert) {
if (src.is_none()) {
// Defer accepting None to other overloads (if we aren't in convert mode):
if (!convert) return false;
return true;
}
if (!isinstance<function>(src))
return false;
auto func = reinterpret_borrow<function>(src);
/*
When passing a C++ function as an argument to another C++
function via Python, every function call would normally involve
a full C++ -> Python -> C++ roundtrip, which can be prohibitive.
Here, we try to at least detect the case where the function is
stateless (i.e. function pointer or lambda function without
captured variables), in which case the roundtrip can be avoided.
*/
if (auto cfunc = func.cpp_function()) {
auto c = reinterpret_borrow<capsule>(PyCFunction_GET_SELF(cfunc.ptr()));
auto rec = (function_record *) c;
if (rec && rec->is_stateless &&
same_type(typeid(function_type), *reinterpret_cast<const std::type_info *>(rec->data[1]))) {
struct capture { function_type f; };
value = ((capture *) &rec->data)->f;
return true;
}
}
// ensure GIL is held during functor destruction
struct func_handle {
function f;
func_handle(function&& f_) : f(std::move(f_)) {}
func_handle(const func_handle&) = default;
~func_handle() {
gil_scoped_acquire acq;
function kill_f(std::move(f));
}
};
// value = [hfunc = func_handle(std::move(func))](Args... args) -> Return {
// gil_scoped_acquire acq;
// object retval(hfunc.f(std::forward<Args>(args)...));
// /* Visual studio 2015 parser issue: need parentheses around this expression */
// return (retval.template cast<Return>());
// };
struct func_wrapper {
func_handle hfunc;
func_wrapper(func_handle&& hf): hfunc(std::move(hf)) {}
Return operator()(Args... args) const {
gil_scoped_acquire acq;
object retval(hfunc.f(std::forward<Args>(args)...));
/* Visual studio 2015 parser issue: need parentheses around this expression */
return (retval.template cast<Return>());
}
};
value = func_wrapper(func_handle(std::move(func)));
return true;
}
template <typename Func>
static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) {
if (!f_)
return none().inc_ref();
auto result = f_.template target<function_type>();
if (result)
return cpp_function(*result, policy).release();
else
return cpp_function(std::forward<Func>(f_), policy).release();
}
PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster<Args>::name...) + _("], ")
+ make_caster<retval_type>::name + _("]"));
};
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

Some files were not shown because too many files have changed in this diff Show More