148 Commits

Author SHA1 Message Date
Yang Hau
8650b4d1cb [DRIVER] Fix typos (#939) 2022-12-02 11:13:46 -08:00
Crutcher Dunnavant
44f577984d Fix format double substitution bug: {i} => {{i}} (#886)
The previous `{i}` was silently expanding to the `i` from the
enumeration loop on `regular_args` (when it wasn't empty).
2022-11-20 11:44:42 -08:00
Crutcher Dunnavant
0e4691e6dd [FRONTEND] Fix ExternLibrary(format=) bug; type annotate build_extern.py (#883)
Ran mypy over `build_extern.py`, cleaned up type annotations.

Found a fixed a bug where `ExternLibrary(format=)` was being ignored.
2022-11-17 18:45:30 +01:00
Natalia Gimelshein
0d7e753227 [TESTING] use torch.int for autotuning cache (#840)
For stupid reasons, ops on int8 are 3 times slower than on int, and for
another set of stupid reasons we are not using cudaMemset for `zero_`,
so using `int8` buffer in `do_bench` makes it slow.

Co-authored-by: Philippe Tillet <phil@openai.com>
2022-11-04 18:05:16 -07:00
Shintaro Iwasaki
77bc5187b5 Better NVIDIA Pascal GPU Support (#827)
This PR clarifies which features are supported on P100 via its tests,
though Pascal is not officially and fully supported by Triton.

## What this PR does

- Skip unsupported tests on P100.
  - Atomic RMW
- `tl.dot()` (perhaps not all patterns, but basically most `tl.dot()`
tests do not work on P100).
- Add an explicit error if shared memory size >= 64K on P100.
- Otherwise it causes `Invalid CUDA argument` error at
`cuLaunchKernel()`, but this error is not very straightforward to
understand. Instead of this generic CUDA argument error, this PR makes
Triton show an error during codegen when `sm < 70`. This check happens
in C/C++ so won't add an overhead in Triton's Python runtime.
- 3 tests (see below) are currently failing, but these are not marked as
skipped because any codegen update in the future can change the kernel
size of the other tests.
- This change won't affect Triton-MLIR. Hopefully Triton-MLIR's generic
`tl.dot()` implementation would support P100.

Importantly, Triton passed all the other tests on P100. Though this
support is not official, it is great for, for example, PyTorch's
TorchDynamo/Inductor, which can use Triton (without `tl.dot()`) for its
backend (https://github.com/pytorch/torchdynamo/issues/1591).

### Results on P100 (Google Cloud)

```sh
$ pytest test/unit
...
================================================================================== short test summary info ==================================================================================
FAILED test/unit/language/test_core.py::test_reduce2d[argmin-float32-shape99-1] - RuntimeError: Device does not support shared memory of 65536bytes
FAILED test/unit/language/test_core.py::test_reduce2d[argmax-float32-shape113-1] - RuntimeError: Device does not support shared memory of 65536bytes
FAILED test/unit/language/test_core.py::test_permute[float32-shape5-perm5] - RuntimeError: Device does not support shared memory of 67584bytes
================================================================== 3 failed, 3824 passed, 952 skipped in 470.90s (0:07:50) ==================================================================
```

<details><summary> <b>Environment Details (collapsed)</b></summary>
<p>

### VM details (Google Cloud)
https://cloud.google.com/
```
# You need a paid account (free trial does not cover GPUs)
Google Cloud -> New Project -> Compute-Engine -> VM Instance
Machine:
GPU: NVIDIA Tesla P100 x 1
CPU: 2 vCPUs, 7.5GB memory
Boot disk:
  OS: Ubuntu 18.04 LTS
  Disk: 40GB (cannot build Triton on the default 10GB disk)
- When I tried, about $1.2 per hour.
- US instances were full when I tried.  I used Asia or Australia.
- Needed a paid account (GPU is not covered by free trial)
- Needed quota request for any GPU instance (by default, no GPU instance is allowed).  Needed to wait an hour for approval
```

### Reproducer
```sh
## 1. Install CUDA and a driver
# Update the apt key (https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key/)
sudo apt-key del 7fa2af80
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb
# Download CUDA as instructed
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-ubuntu1804.pin
sudo mv cuda-ubuntu1804.pin /etc/apt/preferences.d/cuda-repository-pin-600
sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
sudo add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/ /"
sudo apt-get update
sudo apt-get -y install cuda
# Are you using P100?
nvidia-smi | grep "Tesla P100"

## 2. Setup the build environment
sudo apt update
sudo apt install -y build-essential wget git libz-dev
wget https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh
bash Anaconda3-2022.05-Linux-x86_64.sh -b -p $(pwd)/anaconda3
eval "$($(pwd)/anaconda3/bin/conda shell.bash hook)"
conda create -y --name triton_base
conda activate triton_base
conda install -y cmake setuptools

## 3. Build Triton
git clone https://github.com/openai/triton.git
cd triton/python
pip3 install -e '.[tests]'

## 4. Test
pytest test/unit
```

### Environment
```sh
$ nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  On   | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    25W / 250W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
```

</p></details>
2022-11-03 00:11:52 -07:00
Chenggang Zhao
f16138d447 [Frontend] Interface fixes for libdevice (#830)
- Unifying several interfaces with different types to a single one, e.g.
`fsub_ru` and `dsub_ru` -> `sub_ru`;
- Minor bug fix: `fast_pow` is incorrectly classified into the `pow`
interface, of which arguments are the same as `powf`;
- Explicit interfaces for casting functions, e.g. decoupling
`ll2float_ru` to `ll2float_ru` and `ull2float_ru`;
- Removing interfaces that are not in NVIDIA's official documents, e.g.
`fmaf_ieee_rn`, which is confusing together with `fmaf_rn`.

Note that this PR for the master branch is different from #829, which is
for the MLIR branch.
2022-11-01 10:51:58 -07:00
Mark Saroufim
578ada7740 [DOCS] Add install from source instructions to README (#821) 2022-10-31 11:08:18 -07:00
Phil Tillet
6311d70406 Revert "[BUILD] Now using cibuildwheel default"
This reverts commit 584086f08c.
2022-10-29 17:15:47 -07:00
Phil Tillet
584086f08c [BUILD] Now using cibuildwheel default 2022-10-29 16:59:06 -07:00
Keren Zhou
3ca667dfa8 [Frontend] Return a scalar if all input args are scalar (#816) 2022-10-28 23:27:06 -07:00
Yanbo Liang
5ca1ed0101 Add bf16/fp16/fp64 support for ty_to_cpp (#800)
In ```torch._inductor```, we [convert 0d CPU tensor to scalar during
triton codegen](https://github.com/pytorch/pytorch/pull/87329), so need
add missing triton support for bf16/fp16/fp64.
2022-10-24 19:41:25 -07:00
Keren Zhou
db3aa1d1fb [FRONTEND] Fix libdevice (#776)
Fix two problems in libdevice and external dispatch:

1. Use static triton types (e.g., tl.int32) instead of creating new
types. Otherwise, `tl.int32` and `tl.dtype('int32')` are not the same
thing.

2. The name of an extern inst should be empty but not the symbol name of
the inst. TTIR generator will assign names automatically. Otherwise, we
have the same variable name when there are multiple same extern insts.

Before the PR:

```bash
  __nv_exp = extern_elementwise f64<1024> %11;
  __nv_exp = extern_elementwise f64<1024> %11;
```

After the PR:

```bash
  %12 = extern_elementwise f64<1024> %11;
  %13 = extern_elementwise f64<1024> %11;
```
2022-10-13 17:18:16 -07:00
Twizzes
ddae106c0e [DOCS] Update installation.rst to fix windows build error (#747) 2022-10-13 13:27:15 -07:00
Keren Zhou
bc98aead33 [Backend] Fix for mov.u8 (#766)
Init a potential fix for mov.u8 which is not supported by ptx for now.
Use mov.u16 instead and cast it to u8.
2022-10-12 14:32:27 -07:00
Yu Guo
71b46acc42 [IR] Added special-purpose dequantize instruction (#759)
It is currently necessary for optimal performance in quantized workloads to add a special-purpose instruction in the IR. Backward compatibility with this instruction is *NOT* guaranteed.
2022-10-12 14:14:45 -07:00
Philippe Tillet
33e6f0df7f [DRIVER] Bumped CUDA requirement to 11.4+. This is to avoid bad performance surprises as older ptxas are much slower. (#769)
This also makes codegen simpler by avoiding special handling of eviction policies
2022-10-12 12:02:30 -07:00
Philippe Tillet
af76c989eb [RUNTIME] Make entry point cache key depend on triton version hash (#765) 2022-10-11 13:24:30 -07:00
Bin Bao
09cc2d454b [FRONTEND] Fix a bool tensor storing problem (#746) 2022-10-10 12:11:50 -07:00
Felipe Petroski Such
5d4b26d380 [RUNTIME] support multiple devices in the same process (#757) 2022-10-09 20:30:04 -07:00
Chris
9a11a567ce [DOCS] Fixed typos in 01-vector-add.py (#751) 2022-10-09 18:12:46 -07:00
Keren Zhou
11345e9b74 [RUNTIME] Add callback functions for external tools (#738) 2022-10-05 14:46:55 -07:00
Philippe Tillet
bdfdb9a1d2 [RUNTIME] Fixed JIT bug that leg some constexpr values to be overriden by specialization parameters (#742) 2022-10-05 11:00:32 -07:00
shenggan
77c752dc78 [RUNTIME] remove fixed cu_include_dir (#739)
Use environment variable `CUDA_HOME` with default value`/usr/local/cuda` for `cu_include_dir` #731
2022-10-04 19:49:57 -07:00
Natalia Gimelshein
d3c925db8a [FRONTEND] properly broadcast scalar where condition (#736) 2022-10-04 12:44:03 -07:00
fdrocha
2b0f877fad [RUNTIME] Support environments with multiple cudalibs (#733) 2022-10-03 18:36:24 +00:00
Keren Zhou
4a2d3b7d79 [RUNTIME] Dump llvm, ttir, and sass to help debugging (#732) 2022-10-03 00:39:52 +00:00
Natalia Gimelshein
f55960e773 [FRONTEND] fix broadcasting for where (#729)
Fixes #532, all 3 inputs to where have to be broadcast together.
2022-10-01 13:18:47 -07:00
Phil Tillet
b244db06da [TUTORIALS] Attention tutorial fixup 2022-09-30 19:31:43 -07:00
Shintaro Iwasaki
7b61303ea1 [CODEGEN] Fix extract_N_bufferable in layout analysis (#728) 2022-09-30 12:21:22 -07:00
Shintaro Iwasaki
ae59f51c2d [CODEGEN] Fix an inliner to call a function with a phi-node (#727) 2022-09-29 21:36:40 -07:00
albanD
f45e31ba7c [FRONTEND] Make sure to hold the gil when creating python objects (#726)
Without this patch, a debug version of python complains that:
```
Fatal Python error: Python memory allocator called without holding the GIL
Python runtime state: initialized
```
2022-09-29 18:06:22 -07:00
Philippe Tillet
dad97528b2 [TESTING] allclose fixup (#724) 2022-09-28 22:49:05 +00:00
Jason Ansel
998fd5f9af [FRONTEND] Make triton.compile work without a cuda context (#708)
This allows compiling in a subprocess. I'm not seeing a ton of speedup from this, but figure it is a good change anyway.
2022-09-24 13:41:47 -07:00
Shintaro Iwasaki
3ac929b48b [BUILD] Download pybind11 in setup.py (#703)
Based on the discussion in #700, this PR enables downloading pybind11 in
`setup.py` without `git submodule` instead of copy-pasting pybind11
code. The downloaded pybind11 will be in `~/.triton/pybind` (like
`llvm`).
2022-09-23 15:54:07 -07:00
Jason Ansel
579c03615d [FRONTEND] Reduce number of compiles in JITFunction (#704)
I suspect this was the cause of the "new compiles even on a warm cache"
behavior I was seeing, though haven't 100% confirmed it.

Python `set()` iteration order is nondeterministic when you create a new
process. So the same args could produce different `instance_descriptor`s
and have false cache misses.
2022-09-23 21:44:52 +00:00
Philippe Tillet
25e1b36785 Revert "[pybind11] Use git-submodule for pybind11" (#701)
Reverts openai/triton#699
2022-09-23 12:25:38 -07:00
Shintaro Iwasaki
61d104ab3a [FRONTEND] Use git-submodule for pybind11 (#699)
This PR changes the `pybind11` source code management from copy-paste to
a package controlled by git-submodule.

See the discussion in #694 for details.
2022-09-23 09:55:03 -07:00
Philippe Tillet
8c3d4d5749 [RUNTIME] now decoupling entry point from cubin (#696) 2022-09-22 16:44:22 -07:00
Shintaro Iwasaki
df67068bb0 [pybind11] Update pybind11 to 2.10.0 (#691)
This PR updates the version of pybind11 to 2.10.0 (the latest stable).
2022-09-21 20:18:02 -07:00
Philippe Tillet
677ddae618 [FRONTEND] Add warmup for triton.jit() (#684)
This revives #671 , removing the static functions that may unnecessarily hold a reference to the grid and the JITFunction object

Co-authored-by: Jason Ansel <jansel@jansel.net>
2022-09-21 19:13:20 +00:00
Jason Ansel
6abe813d1c Fix issue breaking cudagraphs (#685)
@ngimel figured this one out. 

The errors we were seeing from cudagraphs capture were coming from
`cuStreamGetCtx` which is not allowed while a stream is capturing.

It appears the result of `cuStreamGetCtx()` isn't even used, so I
believe it can just be removed.
2022-09-21 10:20:48 -07:00
Philippe Tillet
e318185eb4 [DOCS] Improved README.md wording (#683)
Initial wording dates from a time where nobody knew Triton, and
comparing it to CUDA helped differentiate it from other existing DSLs.
But nowadays this comparison doesn't make much sense; Triton is its own
thing, and some people may even still be more productive in CUDA than
Triton -- language preferences are subjective after all.
2022-09-20 18:09:43 -07:00
Philippe Tillet
7dc2a70edb Revert "Add .warmup() for triton.jit()" (#682)
Reverts openai/triton#671

It seems like for some reason this caused out-of-memory errors on some
of our internal workloads. I'm reverting this so that HEAD can be used
in production at OpenAI, and I will work on digging into this issue
asynchronously.
2022-09-20 16:05:14 -07:00
Philippe Tillet
48f30550f1 [FRONTEND] Now using raw compiler syscalls when possible (#678) 2022-09-19 21:01:36 -07:00
Jason Ansel
93b1adc53b [FRONTEND] Add .warmup() for triton.jit() (#671) 2022-09-18 23:09:34 -07:00
Phil Tillet
82956e5d6b [PACKAGING] Added missing package 2022-09-18 17:34:05 -07:00
Philippe Tillet
2baf333d44 [DOCS] Fixed typos (#670) 2022-09-18 17:13:12 -07:00
Jason Ansel
49f6bc3f2b [FRONTEND] Fix filename too long error in new runtime (#669) 2022-09-18 21:26:29 +00:00
Phil Tillet
00f4ef6958 [CI] wheel/docs workflows now only run on V100 machine 2022-09-18 13:28:35 -07:00
Jason Ansel
e647402fd3 Fix warning in generated C code (#667) 2022-09-18 12:57:32 -07:00
Philippe Tillet
4a77dfb042 [FRONTEND] Complete rewrite of the runtime (#644)
This PR completely rewrites the runtime of Triton to be more lean and
clearly separate the compilation step from the just-in-time caching logic.
This should substantially reduce launch overhead.
2022-09-18 08:51:48 -07:00
Ian Bearman
889d9e34a1 [REPO] update gitignore (#666)
Update `.gitignore` to include `.vs` and `.vscode`
2022-09-17 14:25:28 -07:00
Shintaro Iwasaki
c668d6596e [DOCS] Fix spelling (#664)
This PR applies minor spelling fix in comments and string literals to
`master`. It shouldn't hurt anything.
2022-09-16 12:26:40 -07:00
Sophia Wisdom
4580a04710 [FRONTEND] Improve error message for CPU tensors (#654)
Redo of #651 against master. Fixes #525 by catching CUDA error when we
check pytorch tensor size and rethrowing a more informative error that
says why we failed.
2022-09-14 14:26:42 -07:00
Philippe Tillet
cfbbc7b43a [CI] Added V100 tag to disambiguate self-hosted runners (#653) 2022-09-14 13:47:50 -07:00
Yunxing Dai
59a8e25f43 [DOCS] Fix typo (#650) 2022-09-14 12:17:05 -07:00
Da Yan
437ced38c2 fp8 <> bf16 conversion (#637)
Co-authored-by: Philippe Tillet <phil@openai.com>
2022-08-30 14:20:12 -07:00
Da Yan
210a296699 [BACKEND] bf16 flash-attention (#636) 2022-08-26 20:40:55 -07:00
Daniil Fukalov
fe0c29b9ec Fix inconsistent struct declaration instead of class. (#632)
Looks like typo.
2022-08-26 16:20:21 -07:00
Phil Wang
7394d732ad [DOCS] support for variable head dimensions in flash attention triton tutorial (#623) 2022-08-15 19:16:49 -07:00
Da Yan
3e2953f357 Allow multiple_of and max_contiguous to accept n-d values (#617) 2022-08-10 09:59:32 -07:00
Daniil Fukalov
cc79376222 Fix deprectaion warning on CreateGEP(Value *, ArrayRef<Value *>, const Twine &) (#608)
This variant of CreateGEP() is already removed in LLVM 14.
2022-08-07 17:10:18 -07:00
Daniil Fukalov
7b91c7befd Fix "warning: control reaches end of non-void function". (#607) 2022-08-02 16:12:48 -07:00
Sharad Vikram
968f59027e Expose module.print in pybind (#604) 2022-07-29 21:36:08 -07:00
Anton Kostin
923d468187 Update LICENSE (#602) 2022-07-25 09:30:03 -07:00
Jason Ansel
027321cdcf [FRONTEND] Make tl.rand() 1-exclusive (#601) 2022-07-24 17:47:23 -07:00
Jason Ansel
e02e56dc63 [FRONTEND] Add missing rfloordiv (#598)
* [FRONTEND] Add missing rfloordiv

* fix tests
2022-07-23 21:54:12 -07:00
Philippe Tillet
ab56d310dd [BACKEND][IR] Fixed up internal dtype size for booleans (1bit -> 8bit) (#600) 2022-07-23 20:08:03 -07:00
Da Yan
f28caddbf8 [FRONTEND] Allow tl.where to select pointers (#595) 2022-07-21 09:54:27 -07:00
Keren Zhou
af85f5fa46 [FRONTEND] Refresh cache when the source code of outlined functions are changed (#590) 2022-07-20 17:34:07 -07:00
daadaada
9b2bc88d11 [BACKEND] Better bf16 support (#588) 2022-07-19 21:22:37 -07:00
Philippe Tillet
86cab58d89 [CI] Changed dev wheel date to UTC time to match CRON schedule (#587) 2022-07-18 14:54:13 -07:00
Phil Tillet
5b04331dd2 [TUTORIALS] Added more credits in fused attention tutorial 2022-07-13 23:48:58 -07:00
Jason Ansel
0a3f3d5f25 [PACKAGING] Include triton/language/libdevice.10.bc in package data (#582) 2022-07-13 23:45:27 -07:00
Keren Zhou
4912916c11 [FRONTEND] Added support for element-wise function defined in external LLVM bitcode (e.g., libdevice) (#562) 2022-07-13 15:52:21 -07:00
Phil Tillet
971f5782b4 [tutorials] Added flash attention credits in tutorial 2022-07-11 18:56:48 -07:00
Philippe Tillet
d5eb9bc230 [tutorial] Added bwd in fused attention example (#579)
Doesn't work on V100
2022-07-11 15:43:46 -07:00
Jason Ansel
c9a2b9c7d4 [FRONTEND] Add missing args to get_simd_tflops() (#578) 2022-07-11 14:37:59 -07:00
Philippe Tillet
4a399a7e40 [BACKEND] Fix some bugs (atomics, a segfault...) (#577)
This should fix #558 , #573 and #574
2022-07-06 20:03:04 -07:00
vesuppi
22105bc33b [FRONTEND] Added type check in semantic arange (#572) 2022-07-03 15:25:37 -07:00
Keren Zhou
4bf509889b [BUILD] Change the default build type to Release (#571) 2022-07-01 12:17:22 -07:00
Keren Zhou
a74cce375f [FRONTEND] Raise broadcast error (#555) 2022-06-30 17:32:07 -07:00
Philippe Tillet
f733327ba4 [BACKEND][CODEGEN] Disabling L2 residency control by default (#570) 2022-06-29 17:05:13 -07:00
Natalia Gimelshein
1bbb2430d9 [TUTORIALS] adjust heuristics for dwdb kernel (#565) 2022-06-29 17:00:22 -07:00
Kashif Rasul
1895ceaa2d [TUTORIAL] Fix f-string for older python (#569)
fixes issue #568
2022-06-29 09:39:10 -07:00
Philippe Tillet
feb7a2a0dc [FRONTEND] Hotfix for store argument order (#567) 2022-06-28 00:24:02 -07:00
Philippe Tillet
5b4c8f221e [BACKEND] Compiler improvements (#557)
This PR adds several optimization capabilities in the compiler backend:
- Now using inline PTX for `tl.store`, making it possible to use things like evict_last
- For A100, mma layout can be directly converted to shared memory
- For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used both row- and col- major.
- Fixed liveness analysis; this was broken.
- Now can load/store directly mma layout without converting. Useful for when tl.dot accumulator is initialized with DRAM data inside of an inner loop.
- `tl.dot` can now take LHS inputs in registers when it comes from a previous `tl.dot` instruction. Useful for e.g. fused attention.
2022-06-27 11:49:19 -07:00
Keren Zhou
87413bc925 [BACKEND] Fix layout convert for non-contiguous input (#564) 2022-06-25 23:12:03 -07:00
Keren Zhou
d345ddf837 [DOCS] Separate atomic cas from other atomic operations since operands are very different (#559) 2022-06-22 17:51:17 -07:00
Keren Zhou
b02bac41ba [CI] Change cache dir (#561) 2022-06-22 11:44:35 -07:00
Keren Zhou
a428cf0bb2 [FRONTEND] Fix pytorch warning. (#560)
UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc').
2022-06-20 20:12:09 -07:00
Keren Zhou
b5e728cb14 Add argmin argmax (#552) 2022-06-15 13:55:20 -07:00
Jason Ansel
6b9756532f [BACKEND] Remove print in coalesce.cc (#551) 2022-06-15 13:13:20 -07:00
Madeleine Thompson
8ce2c12e33 [PYTHON] move ephemeral files to homedir (#549)
This prevents potential conflicts with other users on shared machines.
2022-06-13 19:37:52 -07:00
Keren Zhou
93209c07e0 [BACKEND][CODEGEN] Fix reduce uint (#547) 2022-06-13 16:43:57 -07:00
Philippe Tillet
58c8889235 [FRONTEND] Fix scanline layout (#548) 2022-06-13 16:21:10 -07:00
Natalia Gimelshein
7094657aa9 [FRONTEND] fix bool conversion of floating types (#545) 2022-06-13 15:52:37 -07:00
Keren Zhou
38573d1261 [FRONTEND] Return allocated registers and spilled registers for users (#541) 2022-06-07 18:37:12 -07:00
Mengchi Zhang
2cdc6d35c4 [FRONTEND] Give col_per_thread an initial value to make the compiler happy (#535)
Signed-off-by: Mengchi Zhang <mengchi@fb.com>
2022-06-06 12:48:23 -07:00
TC
f13cbaab9f [FRONTEND] assert that num_warps is a power of 2 (#539) 2022-06-06 11:37:08 -07:00
Philippe Tillet
751e325d2e [TUTORIALS] Fixed typo 2022-06-05 13:33:21 -07:00
Philippe Tillet
801c8a4c92 [TUTORIALS] Fixed typo 2022-06-05 12:32:07 -07:00
Philippe Tillet
8876e53206 [BACKEND] Restored reduction bugfixes 2022-06-03 11:38:52 -07:00
Philippe Tillet
a60374a597 Revert "[BACKEND] Various bug fixes; making reductions faster (#533)".
This is a more stable commit that produce bitwise identical code to earlier
versions. Using commits after this one may lead to slightly different numerics
2022-06-03 11:36:06 -07:00
Philippe Tillet
efa04cac1f [FRONTEND] A couple of bugfixes (#534) 2022-06-02 16:57:37 -07:00
Philippe Tillet
3e7500dfe6 [BACKEND] Various bug fixes; making reductions faster (#533) 2022-05-31 17:14:44 -07:00
Bert Maher
37037bb3be [FRONTEND] Default cache dir to /tmp/triton_$USER (#527) 2022-05-27 13:51:05 -07:00
Philippe Tillet
c82a206684 [FRONTEND] Better dot error message (#531) 2022-05-26 17:41:09 -07:00
Philippe Tillet
0e2883020a [BACKEND] Fixed typo in alignment analysis (#528) 2022-05-25 20:01:19 -07:00
Bert Maher
43fec2adca [FRONTEND] Add binding for create_int_to_ptr (#526) 2022-05-25 15:26:18 -07:00
Philippe Tillet
011bc83c1b [FRONTEND] For loops now promote initial value (#524) 2022-05-24 13:20:10 -07:00
Natalia Gimelshein
96bff90471 [FRONTEND] faster jit function launch (#523)
With fast (200 ns) get_stream function soon to be available from pytorch this shaves off approx 25-30 us from function launch, but even without that function due to caching device properties we are saving ~15-20us.
2022-05-24 12:08:49 -07:00
daadaada
d5eaa8dfa0 Making the generated Triton IR deterministic & a script to compare cached assembly (#522) 2022-05-24 08:56:36 -07:00
Shantanu
80f6a2698b [FRONTEND] Ensure version_key is called at most once (#519)
Co-authored-by: hauntsaninja <>
2022-05-23 13:40:08 -07:00
daadaada
205a493b10 [FRONTEND] Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas (#520)
Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas
2022-05-21 09:45:54 -07:00
Jiabao Lei
abea3dc2c6 [FRONTEND] provide device kwargs && fix fstring error for py<3.8 (#515)
Co-authored-by: Philippe Tillet <phil@openai.com>
2022-05-14 16:21:46 -07:00
Philippe Tillet
d35617bea1 [BACKEND][CODEGEN] Faster reduction for scanline layout (#516) 2022-05-14 15:26:13 -07:00
Mengchi Zhang
d1a22a94e6 [FRONTEND] Add empty return value and remove protect to open the access to contained_tys_vec_t (#514)
Signed-off-by: Mengchi Zhang <mengchi@fb.com>
2022-05-13 11:46:12 -07:00
Jason Ansel
d954a05989 [FRONTEND] Handle torch.uint8 args (#513)
Co-authored-by: Philippe Tillet <Phil.Tillet@gmail.com>
2022-05-12 13:07:39 -07:00
Philippe Tillet
0835a4fb05 [TUTORIALS] Removed #noformat in layer norm tutorial 2022-05-12 12:41:25 -07:00
Philippe Tillet
c736ba7c3e [TUTORIALS] Fixed formatting 2022-05-12 12:31:23 -07:00
Philippe Tillet
cd30a99aa2 [TUTORIALS] fixed formatting 2022-05-12 12:28:22 -07:00
Philippe Tillet
d87435e536 [TUTORIALS] Layer norm tutorial now uses residency control (#510) 2022-05-05 19:53:54 -07:00
Sriram Murali
7c9bc5a47b [CODEGEN] Change return type of generator::packed_type to appease build warnings (#507) 2022-05-04 20:03:37 -07:00
Philippe Tillet
95feb10ec9 [FRONTEND] fixup (#505) 2022-04-30 14:25:06 -07:00
Philippe Tillet
11a908655d [FRONTEND] Fixup 2022-04-29 14:35:09 -07:00
Phil Tillet
cd78ce4888 [FRONTEND] Improved error message when assigning None to non-constexpr 2022-04-29 09:17:54 -07:00
Philippe Tillet
ae2a1ab225 [BACKEND] Alignment pass improvements (#503) 2022-04-25 21:16:00 -07:00
Philippe Tillet
7d544799a0 [BACKEND] Now disabling L2 eviction policy for sm < 80 2022-04-25 09:35:36 -07:00
Philippe Tillet
3ca792043f [TEST] Added test for vectorization 2022-04-24 13:50:48 -07:00
Philippe Tillet
bda209002e [BACKEND][CODEGEN] vectorization bugfix (#502) 2022-04-23 13:18:33 -07:00
Philippe Tillet
0cc3b1129b [BACKEND][CODE_GEN] eviction policies now also apply to L2 (#501) 2022-04-21 23:56:01 -07:00
Philippe Tillet
7d6c504e8d [TESTING] Added testing utilities for fixing clock and using cuda-memcheck (#500) 2022-04-21 22:40:10 -07:00
Philippe Tillet
073be1d2ee [FRONTEND] check that tensors have power-of-two number of elements (#499) 2022-04-14 19:30:02 -07:00
Philippe Tillet
5c7122004c [TUTORIALS] Tutorial shouldn't expose clock. Just removed it. 2022-04-14 17:33:44 -07:00
Philippe Tillet
dc4d40faec [FRONTEND] now mangle constexpr float containing "e-" 2022-04-14 10:26:48 -07:00
Philippe Tillet
25f6689508 [FRONTEND] rename current stream monkey patch (#495) 2022-04-13 11:45:55 -07:00
Philippe Tillet
76bfac9f15 [FRONTEND] Improved constexpr handling (#493) 2022-04-12 00:02:54 -07:00
Philippe Tillet
14b0fd4cfb [FRONTEND] Added possibility for users to customize current stream query (#492) 2022-04-07 12:11:32 -07:00
Philippe Tillet
6424771f55 [CI] Documentation fixup 2022-04-07 09:42:35 -07:00
Philippe Tillet
9f08ecd684 [FRONTEND] Semantic analysis refactor (#491)
Moved dispatch.cc to semantic.py (@ptillet)
Integer signedness analysis was moved from C++ to python (@daadaada)
Cleaner frontend types (@daadaada)
Moved SSA construction to a separate object (@ptillet)


Co-authored-by: Yan Da <dyanab@connect.ust.hk>
2022-04-06 16:13:53 -07:00
Philippe Tillet
2bed6fc850 [LANG] Added support for device functions (#484) 2022-04-03 20:58:16 -07:00
apd10
e85c7a7fc7 Bugfix in ptxas path. (#487)
Bug: "ret" value is destroyed when a failing "ptxas --version" is run
overwriting the previous valid "ret" value.

Fix: keep rets only for those runs which are successful. Pick the first
one
2022-03-30 20:45:41 -07:00
Philippe Tillet
bace26143d [TUTORIALS] Removed leftover print 2022-03-28 16:53:23 -07:00
Philippe Tillet
e0cc488055 [FRONTEND] Added tl.clock and tl.globaltimer (#485) 2022-03-28 16:15:43 -07:00
Philippe Tillet
76a9ee50a8 Revert "[FRONTEND] Semantic analysis refactor (#473)" (#483)
This reverts commit 539961072c.
2022-03-24 17:16:50 -07:00
Philippe Tillet
ea6d1f1b85 [DRIVER] LLVM driver fixup (#482)
Current way of doing things is probably not super thread safe. init is shared between threads and some threads my not call the LLVMInitialize* function.
2022-03-23 00:24:45 -07:00
Keren Zhou
a4f68165cd [FRONTEND] Hot fix for lineno (#481)
Override __reduce__ to make CompilationError pickable and print out error messages
2022-03-22 22:09:49 -07:00
137 changed files with 10123 additions and 16165 deletions

View File

@@ -8,7 +8,7 @@ jobs:
Build-Documentation:
runs-on: self-hosted
runs-on: [self-hosted, V100]
steps:
@@ -18,6 +18,11 @@ jobs:
with:
ref: 'gh-pages'
- name: Clear docs
run: |
rm -r /tmp/triton-docs
continue-on-error: true
- name: Checkout branch
uses: actions/checkout@v1
@@ -31,7 +36,6 @@ jobs:
run: |
git branch
# update docs
rm -r /tmp/triton-docs;
mkdir /tmp/triton-docs;
mv docs/_build/html/* /tmp/triton-docs/
git checkout gh-pages

View File

@@ -5,14 +5,13 @@ on:
pull_request:
branches:
- master
- v2.0
jobs:
Integration-Tests:
runs-on: self-hosted
runs-on: [self-hosted, V100]
steps:
@@ -21,7 +20,7 @@ jobs:
- name: Clear cache
run: |
rm -r /tmp/triton/
rm -r ~/.triton/
continue-on-error: true
- name: Install Triton

View File

@@ -8,7 +8,7 @@ jobs:
Build-Wheels:
runs-on: self-hosted
runs-on: [self-hosted, V100]
steps:
@@ -18,7 +18,7 @@ jobs:
- name: Patch setup.py
run: |
#sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
export LATEST_DATE=$(git show -s --format=%ci `git rev-parse HEAD` | cut -d ' ' -f 1 | sed 's/-//g')
export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d' --format="%cd")
sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py
echo "" >> python/setup.cfg
echo "[build_ext]" >> python/setup.cfg

3
.gitignore vendored
View File

@@ -7,3 +7,6 @@ python/build/
python/triton.egg-info/
python/triton/_C/libtriton.pyd
python/triton/_C/libtriton.so
.vscode
.vs

View File

@@ -3,11 +3,6 @@ include(ExternalProject)
set(CMAKE_CXX_STANDARD 17)
if(NOT TRITON_LLVM_BUILD_DIR)
set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
endif()
project(triton)
include(CTest)
if(NOT WIN32)
@@ -31,6 +26,9 @@ endif()
# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
# Third-party
include_directories(${PYBIND11_INCLUDE_DIR})
if(WIN32)
SET(BUILD_SHARED_LIBS OFF)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/deps/dlfcn-win32/src)
@@ -65,7 +63,7 @@ if("${LLVM_LIBRARY_DIR}" STREQUAL "")
# sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
else()
set(LLVM_LDFLAGS "-L${LLVM_LIBRARY_DIR}")
set(LLVM_LIBRARIES
set(LLVM_LIBRARIES
libLLVMNVPTXCodeGen.a
libLLVMNVPTXDesc.a
libLLVMNVPTXInfo.a

View File

@@ -1,6 +1,6 @@
/*
* Copyright 2018-2020 Philippe Tillet
* Copyright 2020-2021 OpenAI
* Copyright 2020-2022 OpenAI
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
@@ -20,4 +20,4 @@
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
*/

View File

@@ -12,7 +12,7 @@
# Triton
This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs.
This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment for expressing tensor math workloads that offers high flexibility, developer productivity and end to end performance.
The foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing this work if you use Triton!
@@ -33,6 +33,15 @@ And the latest nightly release:
pip install -U --pre triton
```
# Install from source
```
git clone https://github.com/openai/triton.git;
cd triton/python;
pip install cmake; # build time dependency
pip install -e .
```
# Changelog
Version 1.1 is out! New features include:

View File

@@ -25,7 +25,7 @@
# LLVM_VERSION_STRING - Full LLVM version string (e.g. 6.0.0svn).
# LLVM_VERSION_BASE_STRING - Base LLVM version string without git/svn suffix (e.g. 6.0.0).
#
# Note: The variable names were chosen in conformance with the offical CMake
# Note: The variable names were chosen in conformance with the official CMake
# guidelines, see ${CMAKE_ROOT}/Modules/readme.txt.
# Try suffixed versions to pick up the newest LLVM install available on Debian
@@ -196,4 +196,4 @@ include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(LLVM
REQUIRED_VARS LLVM_ROOT_DIR
VERSION_VAR LLVM_VERSION_STRING)
VERSION_VAR LLVM_VERSION_STRING)

View File

@@ -45,7 +45,7 @@ def setup(app):
def wrapped(obj, **kwargs):
import triton
if isinstance(obj, triton.code_gen.JITFunction):
if isinstance(obj, triton.runtime.JITFunction):
obj = obj.fn
return old(obj)
@@ -56,7 +56,7 @@ def setup(app):
def documenter(app, obj, parent):
import triton
if isinstance(obj, triton.code_gen.JITFunction):
if isinstance(obj, triton.runtime.JITFunction):
obj = obj.fn
return old_documenter(app, obj, parent)

View File

@@ -34,11 +34,13 @@ You can install the Python package from source by running the following commands
.. code-block:: bash
git clone https://github.com/openai/triton.git;
cd triton/python;
cd triton;
git submodule update --init --recursive;
cd python;
pip install cmake; # build time dependency
pip install -e .
Note that, if llvm-11 is not present on your system, the setup.py script will download the official LLVM11 static libraries link against that.
Note that, if llvm-11 is not present on your system and you are on linux, the setup.py script will download the official LLVM11 static libraries link against that. For windows users, LLVM must be installed and configured in PATH.
You can then test your installation by running the unit tests:

View File

@@ -14,7 +14,7 @@ Traditional compilers typically rely on intermediate representations, such as LL
Program Representation
+++++++++++++++++++++++
Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample litterature on linear and integer programming.
Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample literature on linear and integer programming.
.. table::
:widths: 50 50

View File

@@ -106,9 +106,13 @@ Atomic Ops
:nosignatures:
atomic_cas
atomic_xchg
atomic_add
atomic_max
atomic_min
atomic_and
atomic_or
atomic_xor
Comparison ops

View File

@@ -12,7 +12,9 @@ namespace ir {
class phi_node;
class splat_inst;
class cast_inst;
class cmp_inst;
class reshape_inst;
class dequantize_inst;
class broadcast_inst;
class binary_operator;
class getelementptr_inst;
@@ -33,8 +35,10 @@ private:
std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
std::vector<cst_info> populate_is_constant_dequantize(ir::dequantize_inst* x);
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
std::vector<cst_info> populate_is_constant_cmp(ir::cmp_inst* x);
std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
std::vector<cst_info> populate_is_constant_default(ir::value* v);
std::vector<cst_info> populate_is_constant(ir::value *v);
@@ -42,6 +46,7 @@ private:
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_max_contiguous_dequantize(ir::dequantize_inst* x);
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
@@ -52,6 +57,7 @@ private:
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_starting_multiple_dequantize(ir::dequantize_inst* x);
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
@@ -65,6 +71,7 @@ public:
void run(ir::module &mod);
unsigned get(ir::value* v, unsigned ax) const;
std::vector<unsigned> contiguous(ir::value* v) const;
std::vector<cst_info> get_cst_info(ir::value* v) const;
private:
std::map<ir::value*, std::vector<cst_info>> is_constant_;

View File

@@ -25,6 +25,7 @@ private:
void update_graph_reduce(ir::instruction *i);
void update_graph_reshape(ir::instruction *i);
void update_graph_trans(ir::instruction *i);
void update_graph_dequantize(ir::instruction *i);
void update_graph_broadcast(ir::instruction *i);
void update_graph_dot(ir::instruction *i);
void update_graph_elementwise(ir::instruction *i,

View File

@@ -211,7 +211,8 @@ private:
TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32;
};
struct scanline_layout: public distributed_layout {
class scanline_layout: public distributed_layout {
public:
scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
@@ -224,7 +225,8 @@ struct scanline_layout: public distributed_layout {
int nts(size_t k) { return nts_.at(k); }
int contig_per_thread(size_t k) { return nts_.at(k); }
public:
int per_thread(size_t k) { return contig_per_thread(k) * shape_[k] / shape_per_cta(k);}
private:
// micro tile size. The size of a tile held by a thread block.
std::vector<int> mts_;
// nano tile size. The size of a tile held by a thread.
@@ -244,7 +246,7 @@ struct N_buffer_info_t {
std::map<ir::value*, int> firsts_idx;
};
// abstract for dot and coresponding smem values
// abstract for dot and corresponding smem values
class shared_layout: public data_layout {
private:
static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
@@ -257,7 +259,8 @@ public:
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values_,
ir::type *ty,
analysis::align* align, target *tgt);
analysis::align* align, target *tgt,
bool is_tmp = false);
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
// accessors
size_t get_size() { return size_; }
@@ -275,6 +278,7 @@ public:
int get_mma_strided() { return mma_strided_; }
bool allow_swizzle() const { return allow_swizzle_; }
data_layout* get_arg_layout() { return arg_layout_; }
bool is_tmp() const { return is_tmp_; }
private:
size_t size_;
@@ -289,6 +293,7 @@ private:
int mma_strided_;
bool allow_swizzle_ = true;
target *tgt_;
bool is_tmp_;
};
@@ -307,13 +312,20 @@ private:
void create(size_t id, const std::vector<ir::value*>& values);
public:
void create_tmp_layout(size_t id, data_layout* arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
ir::instruction* i,
bool is_index = false);
public:
// constructor
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
// accessors
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
bool has(size_t id) { return layouts_.find(id) != layouts_.end(); }
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
size_t num_layouts() const { return values_.size();}
data_layout* get(size_t id) { return layouts_.at(id); }
@@ -321,7 +333,19 @@ public:
std::map<size_t, data_layout*> &get_all() { return layouts_; }
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
int tmp(ir::value* i) { return tmp_.at(i);}
int has_tmp_index(ir::value* i) { return tmp_index_.find(i) != tmp_index_.end(); }
int tmp_index(ir::value* i) { return tmp_index_.at(i);}
void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
// layout checkers
bool is_scanline(ir::instruction* i);
bool is_coalesced_scanline(ir::instruction* i);
bool is_mma(ir::instruction* i);
bool is_a100_mma(ir::instruction* i);
// execution
void run(ir::module &mod);
@@ -335,6 +359,7 @@ private:
std::map<size_t, std::vector<ir::value*>> values_;
std::map<size_t, data_layout*> layouts_;
std::map<ir::value*, size_t> tmp_;
std::map<ir::value*, size_t> tmp_index_;
};
}

View File

@@ -1,12 +1,14 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#include <map>
#include <set>
#include <vector>
#include "triton/codegen/analysis/layout.h"
#include "triton/tools/graph.h"
#include "llvm/ADT/MapVector.h"
#include <set>
#include <vector>
namespace triton{
namespace ir{
@@ -42,14 +44,14 @@ struct segment {
class liveness {
private:
typedef std::map<shared_layout*, segment> intervals_map_t;
typedef llvm::MapVector<shared_layout*, segment> intervals_map_t;
public:
// constructor
liveness(layouts *l): layouts_(l){ }
// accessors
const intervals_map_t& get() const { return intervals_; }
segment get(shared_layout* v) const { return intervals_.at(v); }
segment get(shared_layout* v) const { return intervals_.lookup(v); }
// run
void run(ir::module &mod);

View File

@@ -0,0 +1,90 @@
#ifndef _TRITON_CODE_GEN_EXTERN_LIB_H_
#define _TRITON_CODE_GEN_EXTERN_LIB_H_
#include <memory>
#include <string>
#include <map>
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
namespace triton {
namespace codegen {
///
/// \brief ExternLib is a class that represents a library of external functions.
///
class ExternLib {
public:
ExternLib(const std::string &name, const std::string &path)
: name_(name), path_(path) {}
virtual ~ExternLib() = default;
virtual const std::string &name() const { return name_; }
virtual const std::string &path() const { return path_; }
///
/// \brief Load the library and return the module.
///
std::unique_ptr<llvm::Module> load(llvm::LLVMContext &ctx);
///
/// \brief Link the module into the given module.
///
void link(std::unique_ptr<llvm::Module> &llvm,
std::unique_ptr<llvm::Module> &mod);
///
/// \brief Run load, link, and opt on the module.
///
virtual void install(llvm::LLVMContext &ctx,
std::unique_ptr<llvm::Module> &llvm) {
auto mod = load(ctx);
link(llvm, mod);
opt(ctx, llvm);
}
///
/// \brief Run opt on the module.
///
virtual void opt(llvm::LLVMContext &ctx,
std::unique_ptr<llvm::Module> &llvm) = 0;
private:
std::string name_;
std::string path_;
};
///
/// \brief ExternLibMap is a map of ExternLibs from their names to their paths.
///
typedef std::map<std::string, std::unique_ptr<ExternLib>> ExternLibMap;
///
/// \brief Concrete class for NVIDIA's libdevice library.
///
class LibDevice final : public ExternLib {
public:
LibDevice(const std::string &name, const std::string &path)
: ExternLib(name, path) {}
virtual ~LibDevice() = default;
virtual void opt(llvm::LLVMContext &ctx,
std::unique_ptr<llvm::Module> &llvm) override;
};
///
/// \brief Create an ExternLib instance based on the name and path.
///
std::unique_ptr<ExternLib> create_extern_lib(const std::string &lib_name,
const std::string &lib_path);
} // namespace codegen
} // namespace triton
#endif

View File

@@ -3,6 +3,7 @@
#include <memory>
#include "extern_lib.h"
namespace llvm{
class Module;
@@ -30,12 +31,10 @@ namespace codegen{
// TODO:
// There should be a proper pass manager there!
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx,
codegen::target* target,
int sm, int num_warps,
int num_stages, int &shared_static);
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
ir::module &ir, llvm::LLVMContext &ctx, codegen::target *target,
int num_warps, int num_stages, int &shared_static,
const ExternLibMap &extern_libs);
}
}

View File

@@ -4,7 +4,9 @@
#define _TRITON_SELECTION_GENERATOR_H_
#include "triton/ir/visitor.h"
#include "triton/ir/instructions.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/extern_lib.h"
#include <functional>
// forward
@@ -24,6 +26,7 @@ namespace llvm{
class IRBuilder;
class ArrayType;
class Function;
class StructType;
}
namespace triton{
@@ -114,8 +117,17 @@ private:
private:
Type *cvt(ir::type *ty);
llvm::Attribute cvt(ir::attribute attr);
void packed_type(ir::value* i);
void forward_declare(ir::function* fn);
Value *cast_shared_layout_ptr(analysis::data_layout *layout, Type *ty);
public:
private:
typedef std::function<void(
std::pair<Value *, Value *> &acc, std::function<Value *()> load_value_fn,
std::function<Value *()> load_index_fn, bool is_first)>
acc_fn_t;
public:
generator(analysis::axes *a_axes,
analysis::layouts *layouts,
analysis::align *alignment,
@@ -125,6 +137,8 @@ public:
unsigned num_warps);
void visit_value(ir::value* v);
void visit_call_inst(ir::call_inst*);
void visit_launch_inst(ir::launch_inst *);
void visit_phi_node(ir::phi_node*);
void visit_binary_operator(ir::binary_operator*);
void visit_getelementptr_inst(ir::getelementptr_inst*);
@@ -134,9 +148,19 @@ public:
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
Value* bf16_to_fp32(Value *in0);
Value* fp32_to_bf16(Value *in0);
std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int16_to_float16x8(
Value *in0, Value *scale_x512, Value *shift
);
std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int32_to_float16x8(
Value *in0, Value *scale_x512, Value *shift
);
std::tuple<Value*, Value*, Value*, Value*> int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift);
std::tuple<Value*, Value*> prepare_scale_shift(Value *scale, Value *shift);
void visit_dequantize_inst(ir::dequantize_inst*);
void visit_cast_inst(ir::cast_inst*);
void visit_return_inst(ir::return_inst*);
void visit_cond_branch_inst(ir::cond_branch_inst*);
@@ -148,6 +172,8 @@ public:
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
void visit_masked_store_inst(ir::masked_store_inst*);
void visit_cat_inst(ir::cat_inst*);
void visit_extract_value_inst(ir::extract_value_inst *);
void visit_insert_value_inst(ir::insert_value_inst *);
void visit_reshape_inst(ir::reshape_inst*);
void visit_splat_inst(ir::splat_inst*);
void visit_broadcast_inst(ir::broadcast_inst*);
@@ -168,8 +194,8 @@ public:
void visit_trans_inst(ir::trans_inst*);
void visit_sqrt_inst(ir::sqrt_inst*);
Value* shfl_sync(Value* acc, int32_t i);
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
void visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
void visit_reduce_inst(ir::reduce_inst*);
void visit_select_inst(ir::select_inst*);
void visit_layout_convert(ir::value *out, ir::value *in);
@@ -182,6 +208,9 @@ public:
void visit_async_wait_inst(ir::async_wait_inst*);
// void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range(ir::make_range*);
void visit_clock_inst(ir::clock_inst*);
void visit_globaltimer_inst(ir::globaltimer_inst*);
void visit_extern_elementwise_inst(ir::extern_elementwise_inst*);
// void visit_make_range_sta(ir::make_range_sta*);
void visit_undef_value(ir::undef_value*);
void visit_constant_int(ir::constant_int*);
@@ -197,12 +226,21 @@ public:
void visit_layout_scanline(analysis::scanline_layout*);
void visit_layout_shared(analysis::shared_layout*);
// Add a new external library based on given name and path if it doesn't exist
void add_extern_lib(const std::string &lib_name, const std::string &lib_path);
private:
// Get all external libraries
const ExternLibMap &get_extern_lib_map() {
return extern_lib_map_;
}
private:
LLVMContext *ctx_;
Builder* builder_;
Module *mod_;
std::map<std::string, std::unique_ptr<ExternLib>> extern_lib_map_;
analysis::axes *a_axes_;
analysis::swizzle *swizzle_;
std::map<unsigned, distributed_axis> axes_;
@@ -235,10 +273,11 @@ private:
/// idx for multi-stage pipeline
std::map<analysis::data_layout*, Value*> read_smem_idx_;
std::map<analysis::data_layout*, Value*> write_smem_idx_;
/// triton bb -> llvm bb
std::map<ir::value*, BasicBlock *> bbs_;
std::map<ir::value*, std::vector<int>> ords_;
std::map<ir::value*, Function*> fns_;
// helper for creating llvm values
adder add;
@@ -250,6 +289,9 @@ private:
/// Record prefetch instrs that needs to be moved
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
// Eviction policies
std::map<ir::load_inst::EVICTION_POLICY, Value*> policies_;
};
}

View File

@@ -32,11 +32,12 @@ private:
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
public:
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts, bool has_sm80);
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
void run(ir::module &mod);
private:
bool has_sm80_;
analysis::align* align_;
analysis::layouts* layout_;
};

View File

@@ -15,18 +15,26 @@ namespace ir {
}
namespace codegen{
namespace analysis{
class layouts;
}
namespace transform{
class cts {
private:
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);
bool is_shmem_op(ir::instruction* i, int op);
bool is_shmem_res(ir::value* i);
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*,ir::value*>& copies);
public:
cts(bool use_async = false): use_async_(use_async) {}
cts(analysis::layouts* layouts, bool has_sm80 = false): layouts_(layouts), has_sm80_(has_sm80) {}
void run(ir::module &mod);
private:
bool use_async_;
bool has_sm80_;
analysis::layouts* layouts_;
};
}

View File

@@ -0,0 +1,31 @@
#pragma once
#include <list>
namespace triton {
namespace ir {
class module;
class function;
class call_inst;
class builder;
}
namespace codegen{
namespace transform{
struct fncmp {
bool operator()(ir::function* x, ir::function* y) const;
};
class inliner {
public:
inliner() {}
void do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder, std::list<ir::call_inst*>& callsites);
void run(ir::module &mod);
};
}
}
}

View File

@@ -30,6 +30,9 @@ private:
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
bool rewrite_mult(ir::instruction *value, ir::builder& builder);
bool rewrite_insert_extract(ir::instruction *value, ir::builder& builder);
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);

View File

@@ -89,6 +89,7 @@ public:
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
static CUresult cuDeviceGetCount(int *count);
// link management
static CUresult cuLinkAddFile_v2(CUlinkState state, CUjitInputType type, const char *path, unsigned int numOptions, CUjit_option *options, void **optionValues);
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
@@ -214,6 +215,7 @@ private:
static void* cuDeviceGetAttribute_;
static void* cuDeviceGetCount_;
// link management
static void* cuLinkAddFile_v2_;
static void* cuLinkAddData_v2_;
static void* cuLinkCreate_v2_;
static void* cuLinkDestroy_;

290
include/triton/external/CUDA/cuda.h vendored Executable file → Normal file
View File

@@ -224,7 +224,7 @@ typedef uint64_t cuuint64_t;
/**
* CUDA API version number
*/
#define CUDA_VERSION 11050
#define CUDA_VERSION 11040
#ifdef __cplusplus
extern "C" {
@@ -496,33 +496,7 @@ typedef enum CUarray_format_enum {
CU_AD_FORMAT_SIGNED_INT32 = 0x0a, /**< Signed 32-bit integers */
CU_AD_FORMAT_HALF = 0x10, /**< 16-bit floating point */
CU_AD_FORMAT_FLOAT = 0x20, /**< 32-bit floating point */
CU_AD_FORMAT_NV12 = 0xb0, /**< 8-bit YUV planar format, with 4:2:0 sampling */
CU_AD_FORMAT_UNORM_INT8X1 = 0xc0, /**< 1 channel unsigned 8-bit normalized integer */
CU_AD_FORMAT_UNORM_INT8X2 = 0xc1, /**< 2 channel unsigned 8-bit normalized integer */
CU_AD_FORMAT_UNORM_INT8X4 = 0xc2, /**< 4 channel unsigned 8-bit normalized integer */
CU_AD_FORMAT_UNORM_INT16X1 = 0xc3, /**< 1 channel unsigned 16-bit normalized integer */
CU_AD_FORMAT_UNORM_INT16X2 = 0xc4, /**< 2 channel unsigned 16-bit normalized integer */
CU_AD_FORMAT_UNORM_INT16X4 = 0xc5, /**< 4 channel unsigned 16-bit normalized integer */
CU_AD_FORMAT_SNORM_INT8X1 = 0xc6, /**< 1 channel signed 8-bit normalized integer */
CU_AD_FORMAT_SNORM_INT8X2 = 0xc7, /**< 2 channel signed 8-bit normalized integer */
CU_AD_FORMAT_SNORM_INT8X4 = 0xc8, /**< 4 channel signed 8-bit normalized integer */
CU_AD_FORMAT_SNORM_INT16X1 = 0xc9, /**< 1 channel signed 16-bit normalized integer */
CU_AD_FORMAT_SNORM_INT16X2 = 0xca, /**< 2 channel signed 16-bit normalized integer */
CU_AD_FORMAT_SNORM_INT16X4 = 0xcb, /**< 4 channel signed 16-bit normalized integer */
CU_AD_FORMAT_BC1_UNORM = 0x91, /**< 4 channel unsigned normalized block-compressed (BC1 compression) format */
CU_AD_FORMAT_BC1_UNORM_SRGB = 0x92, /**< 4 channel unsigned normalized block-compressed (BC1 compression) format with sRGB encoding*/
CU_AD_FORMAT_BC2_UNORM = 0x93, /**< 4 channel unsigned normalized block-compressed (BC2 compression) format */
CU_AD_FORMAT_BC2_UNORM_SRGB = 0x94, /**< 4 channel unsigned normalized block-compressed (BC2 compression) format with sRGB encoding*/
CU_AD_FORMAT_BC3_UNORM = 0x95, /**< 4 channel unsigned normalized block-compressed (BC3 compression) format */
CU_AD_FORMAT_BC3_UNORM_SRGB = 0x96, /**< 4 channel unsigned normalized block-compressed (BC3 compression) format with sRGB encoding*/
CU_AD_FORMAT_BC4_UNORM = 0x97, /**< 1 channel unsigned normalized block-compressed (BC4 compression) format */
CU_AD_FORMAT_BC4_SNORM = 0x98, /**< 1 channel signed normalized block-compressed (BC4 compression) format */
CU_AD_FORMAT_BC5_UNORM = 0x99, /**< 2 channel unsigned normalized block-compressed (BC5 compression) format */
CU_AD_FORMAT_BC5_SNORM = 0x9a, /**< 2 channel signed normalized block-compressed (BC5 compression) format */
CU_AD_FORMAT_BC6H_UF16 = 0x9b, /**< 3 channel unsigned half-float block-compressed (BC6H compression) format */
CU_AD_FORMAT_BC6H_SF16 = 0x9c, /**< 3 channel signed half-float block-compressed (BC6H compression) format */
CU_AD_FORMAT_BC7_UNORM = 0x9d, /**< 4 channel unsigned normalized block-compressed (BC7 compression) format */
CU_AD_FORMAT_BC7_UNORM_SRGB = 0x9e /**< 4 channel unsigned normalized block-compressed (BC7 compression) format with sRGB encoding */
CU_AD_FORMAT_NV12 = 0xb0
} CUarray_format;
/**
@@ -657,7 +631,7 @@ typedef enum CUdevice_attribute_enum {
CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED = 102, /**< Device supports virtual memory management APIs like ::cuMemAddressReserve, ::cuMemCreate, ::cuMemMap and related APIs */
CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED = 103, /**< Device supports exporting memory to a posix file descriptor with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */
CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED = 104, /**< Device supports exporting memory to a Win32 NT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */
CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED = 105, /**< Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */
CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED = 105, /**< Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested ::cuMemCreate */
CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR = 106, /**< Maximum number of blocks per multiprocessor */
CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED = 107, /**< Device supports compression of memory */
CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE = 108, /**< Maximum L2 persisting lines capacity setting in bytes. */
@@ -665,7 +639,7 @@ typedef enum CUdevice_attribute_enum {
CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED = 110, /**< Device supports specifying the GPUDirect RDMA flag with ::cuMemCreate */
CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK = 111, /**< Shared memory reserved by CUDA driver per block in bytes */
CU_DEVICE_ATTRIBUTE_SPARSE_CUDA_ARRAY_SUPPORTED = 112, /**< Device supports sparse CUDA arrays and sparse CUDA mipmapped arrays */
CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED = 113, /**< Device supports using the ::cuMemHostRegister flag ::CU_MEMHOSTERGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU */
CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED = 113, /**< Device supports using the ::cuMemHostRegister flag CU_MEMHOSTERGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU */
CU_DEVICE_ATTRIBUTE_TIMELINE_SEMAPHORE_INTEROP_SUPPORTED = 114, /**< External timeline semaphore interop is supported on the device */
CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED = 115, /**< Device supports using the ::cuMemAllocAsync and ::cuMemPool family of APIs */
CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED = 116, /**< Device supports GPUDirect RDMA APIs, like nvidia_p2p_get_pages (see https://docs.nvidia.com/cuda/gpudirect-rdma for more information) */
@@ -844,7 +818,7 @@ typedef enum CUcomputemode_enum {
* Memory advise values
*/
typedef enum CUmem_advise_enum {
CU_MEM_ADVISE_SET_READ_MOSTLY = 1, /**< Data will mostly be read and only occassionally be written to */
CU_MEM_ADVISE_SET_READ_MOSTLY = 1, /**< Data will mostly be read and only occasionally be written to */
CU_MEM_ADVISE_UNSET_READ_MOSTLY = 2, /**< Undo the effect of ::CU_MEM_ADVISE_SET_READ_MOSTLY */
CU_MEM_ADVISE_SET_PREFERRED_LOCATION = 3, /**< Set the preferred location for the data as the specified device */
CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION = 4, /**< Clear the preferred location for the data */
@@ -853,7 +827,7 @@ typedef enum CUmem_advise_enum {
} CUmem_advise;
typedef enum CUmem_range_attribute_enum {
CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY = 1, /**< Whether the range will mostly be read and only occassionally be written to */
CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY = 1, /**< Whether the range will mostly be read and only occasionally be written to */
CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION = 2, /**< The preferred location of the range */
CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY = 3, /**< Memory range has ::CU_MEM_ADVISE_SET_ACCESSED_BY set for specified device */
CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION = 4 /**< The last location to which the range was prefetched */
@@ -875,7 +849,7 @@ typedef enum CUjit_option_enum
* IN: Specifies minimum number of threads per block to target compilation
* for\n
* OUT: Returns the number of threads the compiler actually targeted.
* This restricts the resource utilization fo the compiler (e.g. max
* This restricts the resource utilization of the compiler (e.g. max
* registers) such that a block with the given number of threads should be
* able to launch based on register limitations. Note, this option does not
* currently take into account any other resource limitations, such as
@@ -1000,10 +974,10 @@ typedef enum CUjit_option_enum
CU_JIT_FAST_COMPILE,
/**
* Array of device symbol names that will be relocated to the corresponing
* Array of device symbol names that will be relocated to the corresponding
* host addresses stored in ::CU_JIT_GLOBAL_SYMBOL_ADDRESSES.\n
* Must contain ::CU_JIT_GLOBAL_SYMBOL_COUNT entries.\n
* When loding a device module, driver will relocate all encountered
* When loading a device module, driver will relocate all encountered
* unresolved symbols to the host addresses.\n
* It is only allowed to register symbols that correspond to unresolved
* global variables.\n
@@ -1220,7 +1194,7 @@ typedef enum CUlimit_enum {
* Resource types
*/
typedef enum CUresourcetype_enum {
CU_RESOURCE_TYPE_ARRAY = 0x00, /**< Array resoure */
CU_RESOURCE_TYPE_ARRAY = 0x00, /**< Array resource */
CU_RESOURCE_TYPE_MIPMAPPED_ARRAY = 0x01, /**< Mipmapped array resource */
CU_RESOURCE_TYPE_LINEAR = 0x02, /**< Linear resource */
CU_RESOURCE_TYPE_PITCH2D = 0x03 /**< Pitch 2D resource */
@@ -1650,8 +1624,7 @@ typedef enum cudaError_enum {
CUDA_ERROR_UNSUPPORTED_EXEC_AFFINITY = 224,
/**
* This indicates that the device kernel source is invalid. This includes
* compilation/linker errors encountered in device code or user error.
* This indicates that the device kernel source is invalid.
*/
CUDA_ERROR_INVALID_SOURCE = 300,
@@ -2068,9 +2041,9 @@ typedef size_t (CUDA_CB *CUoccupancyB2DSize)(int blockSize);
* On Windows the flag is a no-op.
* On Linux that memory is marked as non cache-coherent for the GPU and
* is expected to be physically contiguous. It may return
* ::CUDA_ERROR_NOT_PERMITTED if run as an unprivileged user,
* ::CUDA_ERROR_NOT_SUPPORTED on older Linux kernel versions.
* On all other platforms, it is not supported and ::CUDA_ERROR_NOT_SUPPORTED
* CUDA_ERROR_NOT_PERMITTED if run as an unprivileged user,
* CUDA_ERROR_NOT_SUPPORTED on older Linux kernel versions.
* On all other platforms, it is not supported and CUDA_ERROR_NOT_SUPPORTED
* is returned.
* Flag for ::cuMemHostRegister()
*/
@@ -2079,12 +2052,12 @@ typedef size_t (CUDA_CB *CUoccupancyB2DSize)(int blockSize);
/**
* If set, the passed memory pointer is treated as pointing to memory that is
* considered read-only by the device. On platforms without
* ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is
* CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is
* required in order to register memory mapped to the CPU as read-only. Support
* for the use of this flag can be queried from the device attribute
* ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with
* CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with
* a current context associated with a device that does not have this attribute
* set will cause ::cuMemHostRegister to error with ::CUDA_ERROR_NOT_SUPPORTED.
* set will cause ::cuMemHostRegister to error with CUDA_ERROR_NOT_SUPPORTED.
*/
#define CU_MEMHOSTREGISTER_READ_ONLY 0x08
@@ -2941,9 +2914,9 @@ typedef struct CUmemAllocationProp_st {
CUmemLocation location;
/**
* Windows-specific POBJECT_ATTRIBUTES required when
* ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This object atributes structure
* ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This object attributes structure
* includes security attributes that define
* the scope of which exported allocations may be tranferred to other
* the scope of which exported allocations may be transferred to other
* processes. In all other cases, this field is required to be zero.
*/
void *win32HandleMetaData;
@@ -3063,7 +3036,7 @@ typedef struct CUmemPoolProps_st {
/**
* Windows-specific LPSECURITYATTRIBUTES required when
* ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This security attribute defines
* the scope of which exported allocations may be tranferred to other
* the scope of which exported allocations may be transferred to other
* processes. In all other cases, this field is required to be zero.
*/
void *win32SecurityAttributes;
@@ -3546,7 +3519,7 @@ CUresult CUDAAPI cuDeviceGet(CUdevice *device, int ordinal);
CUresult CUDAAPI cuDeviceGetCount(int *count);
/**
* \brief Returns an identifer string for the device
* \brief Returns an identifier string for the device
*
* Returns an ASCII string identifying the device \p dev in the NULL-terminated
* string pointed to by \p name. \p len specifies the maximum length of the
@@ -3583,7 +3556,7 @@ CUresult CUDAAPI cuDeviceGetName(char *name, int len, CUdevice dev);
* Note there is a later version of this API, ::cuDeviceGetUuid_v2. It will
* supplant this version in 12.0, which is retained for minor version compatibility.
*
* Returns 16-octets identifing the device \p dev in the structure
* Returns 16-octets identifying the device \p dev in the structure
* pointed by the \p uuid.
*
* \param uuid - Returned UUID
@@ -3613,7 +3586,7 @@ CUresult CUDAAPI cuDeviceGetUuid(CUuuid *uuid, CUdevice dev);
/**
* \brief Return an UUID for the device (11.4+)
*
* Returns 16-octets identifing the device \p dev in the structure
* Returns 16-octets identifying the device \p dev in the structure
* pointed by the \p uuid. If the device is in MIG mode, returns its
* MIG UUID which uniquely identifies the subscribed MIG compute instance.
*
@@ -3735,117 +3708,117 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements,
* \p dev. The supported attributes are:
* - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK: Maximum number of threads per
* block;
* - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X: Maximum x-dimension of a block
* - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y: Maximum y-dimension of a block
* - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z: Maximum z-dimension of a block
* - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X: Maximum x-dimension of a grid
* - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y: Maximum y-dimension of a grid
* - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z: Maximum z-dimension of a grid
* - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X: Maximum x-dimension of a block;
* - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y: Maximum y-dimension of a block;
* - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z: Maximum z-dimension of a block;
* - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X: Maximum x-dimension of a grid;
* - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y: Maximum y-dimension of a grid;
* - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z: Maximum z-dimension of a grid;
* - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK: Maximum amount of
* shared memory available to a thread block in bytes
* shared memory available to a thread block in bytes;
* - ::CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY: Memory available on device for
* __constant__ variables in a CUDA C kernel in bytes
* - ::CU_DEVICE_ATTRIBUTE_WARP_SIZE: Warp size in threads
* __constant__ variables in a CUDA C kernel in bytes;
* - ::CU_DEVICE_ATTRIBUTE_WARP_SIZE: Warp size in threads;
* - ::CU_DEVICE_ATTRIBUTE_MAX_PITCH: Maximum pitch in bytes allowed by the
* memory copy functions that involve memory regions allocated through
* ::cuMemAllocPitch()
* ::cuMemAllocPitch();
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH: Maximum 1D
* texture width
* texture width;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH: Maximum width
* for a 1D texture bound to linear memory
* for a 1D texture bound to linear memory;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH: Maximum
* mipmapped 1D texture width
* mipmapped 1D texture width;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH: Maximum 2D
* texture width
* texture width;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT: Maximum 2D
* texture height
* texture height;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH: Maximum width
* for a 2D texture bound to linear memory
* for a 2D texture bound to linear memory;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT: Maximum height
* for a 2D texture bound to linear memory
* for a 2D texture bound to linear memory;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH: Maximum pitch
* in bytes for a 2D texture bound to linear memory
* in bytes for a 2D texture bound to linear memory;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH: Maximum
* mipmapped 2D texture width
* mipmapped 2D texture width;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT: Maximum
* mipmapped 2D texture height
* mipmapped 2D texture height;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH: Maximum 3D
* texture width
* texture width;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT: Maximum 3D
* texture height
* texture height;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH: Maximum 3D
* texture depth
* texture depth;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE:
* Alternate maximum 3D texture width, 0 if no alternate
* maximum 3D texture size is supported
* maximum 3D texture size is supported;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE:
* Alternate maximum 3D texture height, 0 if no alternate
* maximum 3D texture size is supported
* maximum 3D texture size is supported;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE:
* Alternate maximum 3D texture depth, 0 if no alternate
* maximum 3D texture size is supported
* maximum 3D texture size is supported;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH:
* Maximum cubemap texture width or height
* Maximum cubemap texture width or height;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH:
* Maximum 1D layered texture width
* Maximum 1D layered texture width;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS:
* Maximum layers in a 1D layered texture
* Maximum layers in a 1D layered texture;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH:
* Maximum 2D layered texture width
* Maximum 2D layered texture width;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT:
* Maximum 2D layered texture height
* Maximum 2D layered texture height;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS:
* Maximum layers in a 2D layered texture
* Maximum layers in a 2D layered texture;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH:
* Maximum cubemap layered texture width or height
* Maximum cubemap layered texture width or height;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS:
* Maximum layers in a cubemap layered texture
* Maximum layers in a cubemap layered texture;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH:
* Maximum 1D surface width
* Maximum 1D surface width;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH:
* Maximum 2D surface width
* Maximum 2D surface width;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT:
* Maximum 2D surface height
* Maximum 2D surface height;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH:
* Maximum 3D surface width
* Maximum 3D surface width;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT:
* Maximum 3D surface height
* Maximum 3D surface height;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH:
* Maximum 3D surface depth
* Maximum 3D surface depth;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH:
* Maximum 1D layered surface width
* Maximum 1D layered surface width;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS:
* Maximum layers in a 1D layered surface
* Maximum layers in a 1D layered surface;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH:
* Maximum 2D layered surface width
* Maximum 2D layered surface width;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT:
* Maximum 2D layered surface height
* Maximum 2D layered surface height;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS:
* Maximum layers in a 2D layered surface
* Maximum layers in a 2D layered surface;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH:
* Maximum cubemap surface width
* Maximum cubemap surface width;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH:
* Maximum cubemap layered surface width
* Maximum cubemap layered surface width;
* - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS:
* Maximum layers in a cubemap layered surface
* Maximum layers in a cubemap layered surface;
* - ::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK: Maximum number of 32-bit
* registers available to a thread block
* - ::CU_DEVICE_ATTRIBUTE_CLOCK_RATE: The typical clock frequency in kilohertz
* registers available to a thread block;
* - ::CU_DEVICE_ATTRIBUTE_CLOCK_RATE: The typical clock frequency in kilohertz;
* - ::CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT: Alignment requirement; texture
* base addresses aligned to ::textureAlign bytes do not need an offset
* applied to texture fetches
* applied to texture fetches;
* - ::CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT: Pitch alignment requirement
* for 2D texture references bound to pitched memory
* for 2D texture references bound to pitched memory;
* - ::CU_DEVICE_ATTRIBUTE_GPU_OVERLAP: 1 if the device can concurrently copy
* memory between host and device while executing a kernel, or 0 if not
* memory between host and device while executing a kernel, or 0 if not;
* - ::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT: Number of multiprocessors on
* the device
* the device;
* - ::CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT: 1 if there is a run time limit
* for kernels executed on the device, or 0 if not
* for kernels executed on the device, or 0 if not;
* - ::CU_DEVICE_ATTRIBUTE_INTEGRATED: 1 if the device is integrated with the
* memory subsystem, or 0 if not
* memory subsystem, or 0 if not;
* - ::CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY: 1 if the device can map host
* memory into the CUDA address space, or 0 if not
* memory into the CUDA address space, or 0 if not;
* - ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE: Compute mode that device is currently
* in. Available modes are as follows:
* - ::CU_COMPUTEMODE_DEFAULT: Default mode - Device is not restricted and
@@ -3858,33 +3831,33 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements,
* executing multiple kernels within the same context simultaneously, or 0 if
* not. It is not guaranteed that multiple kernels will be resident
* on the device concurrently so this feature should not be relied upon for
* correctness.
* correctness;
* - ::CU_DEVICE_ATTRIBUTE_ECC_ENABLED: 1 if error correction is enabled on the
* device, 0 if error correction is disabled or not supported by the device
* - ::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID: PCI bus identifier of the device
* device, 0 if error correction is disabled or not supported by the device;
* - ::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID: PCI bus identifier of the device;
* - ::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID: PCI device (also known as slot) identifier
* of the device
* of the device;
* - ::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID: PCI domain identifier of the device
* - ::CU_DEVICE_ATTRIBUTE_TCC_DRIVER: 1 if the device is using a TCC driver. TCC
* is only available on Tesla hardware running Windows Vista or later
* - ::CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE: Peak memory clock frequency in kilohertz
* - ::CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH: Global memory bus width in bits
* - ::CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE: Size of L2 cache in bytes. 0 if the device doesn't have L2 cache
* - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR: Maximum resident threads per multiprocessor
* is only available on Tesla hardware running Windows Vista or later;
* - ::CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE: Peak memory clock frequency in kilohertz;
* - ::CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH: Global memory bus width in bits;
* - ::CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE: Size of L2 cache in bytes. 0 if the device doesn't have L2 cache;
* - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR: Maximum resident threads per multiprocessor;
* - ::CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING: 1 if the device shares a unified address space with
* the host, or 0 if not
* - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: Major compute capability version number
* - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: Minor compute capability version number
* the host, or 0 if not;
* - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: Major compute capability version number;
* - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: Minor compute capability version number;
* - ::CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED: 1 if device supports caching globals
* in L1 cache, 0 if caching globals in L1 cache is not supported by the device
* in L1 cache, 0 if caching globals in L1 cache is not supported by the device;
* - ::CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED: 1 if device supports caching locals
* in L1 cache, 0 if caching locals in L1 cache is not supported by the device
* in L1 cache, 0 if caching locals in L1 cache is not supported by the device;
* - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR: Maximum amount of
* shared memory available to a multiprocessor in bytes; this amount is shared
* by all thread blocks simultaneously resident on a multiprocessor
* by all thread blocks simultaneously resident on a multiprocessor;
* - ::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR: Maximum number of 32-bit
* registers available to a multiprocessor; this number is shared by all thread
* blocks simultaneously resident on a multiprocessor
* blocks simultaneously resident on a multiprocessor;
* - ::CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY: 1 if device supports allocating managed memory
* on this system, 0 if allocating managed memory is not supported by the device on this system.
* - ::CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD: 1 if device is on a multi-GPU board, 0 if not.
@@ -3894,7 +3867,7 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements,
* supports native atomic operations.
* - ::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO: Ratio of single precision performance
* (in floating-point operations per second) to double precision performance.
* - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS: Device suppports coherently accessing
* - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS: Device supports coherently accessing
* pageable memory without calling cudaHostRegister on it.
* - ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS: Device can coherently access managed memory
* concurrently with the CPU.
@@ -3902,7 +3875,7 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements,
* - ::CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM: Device can access host registered
* memory at the same virtual address as the CPU.
* - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN: The maximum per block shared memory size
* suported on this device. This is the maximum value that can be opted into when using the cuFuncSetAttribute() call.
* supported on this device. This is the maximum value that can be opted into when using the cuFuncSetAttribute() call.
* For more details see ::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES
* - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES: Device accesses pageable memory via the host's
* page tables.
@@ -3910,20 +3883,14 @@ CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements,
* - ::CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED: Device supports virtual memory management APIs like ::cuMemAddressReserve, ::cuMemCreate, ::cuMemMap and related APIs
* - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED: Device supports exporting memory to a posix file descriptor with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate
* - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 NT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate
* - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate
* - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR: Maximum number of thread blocks that can reside on a multiprocessor
* - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested ::cuMemCreate
* - ::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE: Maximum L2 persisting lines capacity setting in bytes.
* - ::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE: Maximum value of CUaccessPolicyWindow::num_bytes.
* - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR: Maximum number of thread blocks that can reside on a multiprocessor.
* - ::CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED: Device supports compressible memory allocation via ::cuMemCreate
* - ::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE: Maximum L2 persisting lines capacity setting in bytes
* - ::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE: Maximum value of CUaccessPolicyWindow::num_bytes
* - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED: Device supports specifying the GPUDirect RDMA flag with ::cuMemCreate.
* - ::CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK: Amount of shared memory per block reserved by CUDA driver in bytes
* - ::CU_DEVICE_ATTRIBUTE_SPARSE_CUDA_ARRAY_SUPPORTED: Device supports sparse CUDA arrays and sparse CUDA mipmapped arrays.
* - ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED: Device supports using the ::cuMemHostRegister flag ::CU_MEMHOSTERGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU
* - ::CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK: Amount of shared memory per block reserved by CUDA driver in bytes.
* - ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED: Device supports using the ::cuMemHostRegister flag CU_MEMHOSTERGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU
* - ::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED: Device supports using the ::cuMemAllocAsync and ::cuMemPool family of APIs
* - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED: Device supports GPUDirect RDMA APIs, like nvidia_p2p_get_pages (see https://docs.nvidia.com/cuda/gpudirect-rdma for more information)
* - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS: The returned attribute shall be interpreted as a bitmask, where the individual bits are described by the ::CUflushGPUDirectRDMAWritesOptions enum
* - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WRITES_ORDERING: GPUDirect RDMA writes to the device do not need to be flushed for consumers within the scope indicated by the returned attribute. See ::CUGPUDirectRDMAWritesOrdering for the numerical values returned here.
* - ::CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES: Bitmask of handle types supported with mempool based IPC
*
* \param pi - Returned device attribute value
* \param attrib - Device attribute to query
@@ -4165,7 +4132,7 @@ __CUDA_DEPRECATED CUresult CUDAAPI cuDeviceGetProperties(CUdevprop *prop, CUdevi
*
* \deprecated
*
* This function was deprecated as of CUDA 5.0 and its functionality superceded
* This function was deprecated as of CUDA 5.0 and its functionality superseded
* by ::cuDeviceGetAttribute().
*
* Returns in \p *major and \p *minor the major and minor revision numbers that
@@ -4690,13 +4657,6 @@ CUresult CUDAAPI cuCtxCreate_v3(CUcontext *pctx, CUexecAffinityParam *paramsArra
* It is the responsibility of the calling function to ensure that no API
* call issues using \p ctx while ::cuCtxDestroy() is executing.
*
* Destroys and cleans up all resources associated with the context.
* It is the caller's responsibility to ensure that the context or its resources
* are not accessed or passed in subsequent API calls and doing so will result in undefined behavior.
* These resources include CUDA types such as ::CUmodule, ::CUfunction, ::CUstream, ::CUevent,
* ::CUarray, ::CUmipmappedArray, ::CUtexObject, ::CUsurfObject, ::CUtexref, ::CUsurfref,
* ::CUgraphicsResource, ::CUlinkState, ::CUexternalMemory and ::CUexternalSemaphore.
*
* If \p ctx is current to the calling thread then \p ctx will also be
* popped from the current thread's context stack (as though ::cuCtxPopCurrent()
* were called). If \p ctx is current to other threads, then \p ctx will
@@ -5002,10 +4962,10 @@ CUresult CUDAAPI cuCtxSynchronize(void);
* returned.
*
* - ::CU_LIMIT_MAX_L2_FETCH_GRANULARITY controls the L2 cache fetch granularity.
* Values can range from 0B to 128B. This is purely a performence hint and
* Values can range from 0B to 128B. This is purely a performance hint and
* it can be ignored or clamped depending on the platform.
*
* - ::CU_LIMIT_PERSISTING_L2_CACHE_SIZE controls size in bytes availabe for
* - ::CU_LIMIT_PERSISTING_L2_CACHE_SIZE controls size in bytes available for
* persisting L2 cache. This is purely a performance hint and it can be
* ignored or clamped depending on the platform.
*
@@ -5672,7 +5632,6 @@ CUresult CUDAAPI cuModuleLoadFatBinary(CUmodule *module, const void *fatCubin);
* ::CUDA_ERROR_INVALID_CONTEXT,
* ::CUDA_ERROR_INVALID_VALUE
* \notefnerr
* \note_destroy_ub
*
* \sa ::cuModuleGetFunction,
* ::cuModuleGetGlobal,
@@ -5993,9 +5952,8 @@ cuLinkDestroy(CUlinkState state);
/**
* \brief Gets free and total memory
*
* Returns in \p *total the total amount of memory available to the the current context.
* Returns in \p *free the amount of memory on the device that is free according to the OS.
* CUDA is not guaranteed to be able to allocate all of the memory that the OS reports as free.
* Returns in \p *free and \p *total respectively, the free and total amount of
* memory available for allocation by the CUDA context, in bytes.
*
* \param free - Returned free memory in bytes
* \param total - Returned total memory in bytes
@@ -6440,7 +6398,7 @@ CUresult CUDAAPI cuMemHostGetFlags(unsigned int *pFlags, void *p);
* ::cuStreamAttachMemAsync will be required to enable access on such devices.
*
* If the association is later changed via ::cuStreamAttachMemAsync to
* a single stream, the default association as specifed during ::cuMemAllocManaged
* a single stream, the default association as specified during ::cuMemAllocManaged
* is restored when that stream is destroyed. For __managed__ variables, the
* default association is always ::CU_MEM_ATTACH_GLOBAL. Note that destroying a
* stream is an asynchronous operation, and as a result, the change to default
@@ -6839,10 +6797,10 @@ CUresult CUDAAPI cuIpcCloseMemHandle(CUdeviceptr dptr);
*
* - ::CU_MEMHOSTREGISTER_READ_ONLY: The pointer is treated as pointing to memory
* that is considered read-only by the device. On platforms without
* ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is
* CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is
* required in order to register memory mapped to the CPU as read-only. Support
* for the use of this flag can be queried from the device attribute
* ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with
* CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with
* a current context associated with a device that does not have this attribute
* set will cause ::cuMemHostRegister to error with CUDA_ERROR_NOT_SUPPORTED.
*
@@ -8987,7 +8945,7 @@ CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, unsi
* float16's:
* \code
CUDA_ARRAY_DESCRIPTOR desc;
desc.Format = CU_AD_FORMAT_HALF;
desc.FormatFlags = CU_AD_FORMAT_HALF;
desc.NumChannels = 4;
desc.Width = width;
desc.Height = height;
@@ -8997,7 +8955,7 @@ CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, unsi
* of which is two 8-bit unsigned chars:
* \code
CUDA_ARRAY_DESCRIPTOR arrayDesc;
desc.Format = CU_AD_FORMAT_UNSIGNED_INT8;
desc.FormatFlags = CU_AD_FORMAT_UNSIGNED_INT8;
desc.NumChannels = 2;
desc.Width = width;
desc.Height = height;
@@ -9323,7 +9281,7 @@ CUresult CUDAAPI cuArrayDestroy(CUarray hArray);
* 4x16-bit float16's:
* \code
CUDA_ARRAY3D_DESCRIPTOR desc;
desc.Format = CU_AD_FORMAT_HALF;
desc.FormatFlags = CU_AD_FORMAT_HALF;
desc.NumChannels = 4;
desc.Width = width;
desc.Height = height;
@@ -9658,13 +9616,13 @@ CUresult CUDAAPI cuMemAddressFree(CUdeviceptr ptr, size_t size);
* \brief Create a CUDA memory handle representing a memory allocation of a given size described by the given properties
*
* This creates a memory allocation on the target device specified through the
* \p prop strcuture. The created allocation will not have any device or host
* \p prop structure. The created allocation will not have any device or host
* mappings. The generic memory \p handle for the allocation can be
* mapped to the address space of calling process via ::cuMemMap. This handle
* cannot be transmitted directly to other processes (see
* ::cuMemExportToShareableHandle). On Windows, the caller must also pass
* an LPSECURITYATTRIBUTE in \p prop to be associated with this handle which
* limits or allows access to this handle for a recepient process (see
* limits or allows access to this handle for a recipient process (see
* ::CUmemAllocationProp::win32HandleMetaData for more). The \p size of this
* allocation must be a multiple of the the value given via
* ::cuMemGetAllocationGranularity with the ::CU_MEM_ALLOC_GRANULARITY_MINIMUM
@@ -9702,7 +9660,7 @@ CUresult CUDAAPI cuMemCreate(CUmemGenericAllocationHandle *handle, size_t size,
* are unmapped and when all outstanding references to the handle (including it's
* shareable counterparts) are also released. The generic memory handle can be
* freed when there are still outstanding mappings made with this handle. Each
* time a recepient process imports a shareable handle, it needs to pair it with
* time a recipient process imports a shareable handle, it needs to pair it with
* ::cuMemRelease for the handle to be freed. If \p handle is not a valid handle
* the behavior is undefined.
*
@@ -11017,7 +10975,7 @@ CUresult CUDAAPI cuMemAdvise(CUdeviceptr devPtr, size_t count, CUmem_advise advi
* a GPU id or CU_DEVICE_CPU depending on whether the last location for prefetch was a GPU or the CPU
* respectively. If any page in the memory range was never explicitly prefetched or if all pages were not
* prefetched to the same location, CU_DEVICE_INVALID will be returned. Note that this simply returns the
* last location that the applicaton requested to prefetch the memory range to. It gives no indication as to
* last location that the application requested to prefetch the memory range to. It gives no indication as to
* whether the prefetch operation to that location has completed or even begun.
*
* \param data - A pointers to a memory location where the result
@@ -13603,7 +13561,7 @@ CUresult CUDAAPI cuLaunchCooperativeKernel(CUfunction f,
* All kernels launched must be identical with respect to the compiled code. Note that
* any __device__, __constant__ or __managed__ variables present in the module that owns
* the kernel launched on each device, are independently instantiated on every device.
* It is the application's responsiblity to ensure these variables are initialized and
* It is the application's responsibility to ensure these variables are initialized and
* used appropriately.
*
* The size of the grids as specified in blocks, the size of the blocks themselves
@@ -15180,7 +15138,7 @@ CUresult CUDAAPI cuGraphExternalSemaphoresWaitNodeSetParams(CUgraphNode hNode, c
* \param nodeParams - Parameters for the node
*
* When ::cuGraphAddMemAllocNode creates an allocation node, it returns the address of the allocation in
* \p nodeParams.dptr. The allocation's address remains fixed across instantiations and launches.
* \param nodeParams.dptr. The allocation's address remains fixed across instantiations and launches.
*
* If the allocation is freed in the same graph, by creating a free node using ::cuGraphAddMemFreeNode,
* the allocation can be accessed by nodes ordered after the allocation node but before the free node.
@@ -15356,9 +15314,7 @@ CUresult CUDAAPI cuGraphMemFreeNodeGetParams(CUgraphNode hNode, CUdeviceptr *dpt
*
* \sa
* ::cuGraphAddMemAllocNode,
* ::cuGraphAddMemFreeNode,
* ::cuDeviceSetGraphMemAttribute,
* ::cuDeviceGetGraphMemAttribute
* ::cuGraphAddMemFreeNode
*/
CUresult CUDAAPI cuDeviceGraphMemTrim(CUdevice device);
@@ -15384,7 +15340,6 @@ CUresult CUDAAPI cuDeviceGraphMemTrim(CUdevice device);
* ::CUDA_ERROR_INVALID_DEVICE
*
* \sa
* ::cuDeviceSetGraphMemAttribute,
* ::cuGraphAddMemAllocNode,
* ::cuGraphAddMemFreeNode
*/
@@ -15409,7 +15364,6 @@ CUresult CUDAAPI cuDeviceGetGraphMemAttribute(CUdevice device, CUgraphMem_attrib
* ::CUDA_ERROR_INVALID_DEVICE
*
* \sa
* ::cuDeviceGetGraphMemAttribute,
* ::cuGraphAddMemAllocNode,
* ::cuGraphAddMemFreeNode
*/

View File

@@ -328,7 +328,7 @@ typedef enum nvmlGpuLevel_enum
typedef enum nvmlGpuP2PStatus_enum
{
NVML_P2P_STATUS_OK = 0,
NVML_P2P_STATUS_CHIPSET_NOT_SUPPORED,
NVML_P2P_STATUS_CHIPSET_NOT_SUPPORTED,
NVML_P2P_STATUS_GPU_NOT_SUPPORTED,
NVML_P2P_STATUS_IOH_TOPOLOGY_NOT_SUPPORTED,
NVML_P2P_STATUS_DISABLED_BY_REGKEY,
@@ -736,7 +736,7 @@ typedef enum nvmlReturn_enum
NVML_ERROR_IN_USE = 19, //!< An operation cannot be performed because the GPU is currently in use
NVML_ERROR_MEMORY = 20, //!< Insufficient memory
NVML_ERROR_NO_DATA = 21, //!<No data
NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22, //!< The requested vgpu operation is not available on target device, becasue ECC is enabled
NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22, //!< The requested vgpu operation is not available on target device, because ECC is enabled
NVML_ERROR_UNKNOWN = 999 //!< An internal driver error occurred
} nvmlReturn_t;
@@ -1463,7 +1463,7 @@ typedef struct nvmlEncoderSessionInfo_st
*/
typedef enum nvmlFBCSessionType_enum
{
NVML_FBC_SESSION_TYPE_UNKNOWN = 0, //!< Unknwon
NVML_FBC_SESSION_TYPE_UNKNOWN = 0, //!< Unknown
NVML_FBC_SESSION_TYPE_TOSYS, //!< ToSys
NVML_FBC_SESSION_TYPE_CUDA, //!< Cuda
NVML_FBC_SESSION_TYPE_VID, //!< Vid
@@ -3678,10 +3678,10 @@ nvmlReturn_t DECLDIR nvmlDeviceGetEncoderStats (nvmlDevice_t device, unsigned in
* Retrieves information about active encoder sessions on a target device.
*
* An array of active encoder sessions is returned in the caller-supplied buffer pointed at by \a sessionInfos. The
* array elememt count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* array element count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* written to the buffer.
*
* If the supplied buffer is not large enough to accomodate the active session array, the function returns
* If the supplied buffer is not large enough to accommodate the active session array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlEncoderSessionInfo_t array required in \a sessionCount.
* To query the number of active encoder sessions, call this function with *sessionCount = 0. The code will return
* NVML_SUCCESS with number of active encoder sessions updated in *sessionCount.
@@ -3727,7 +3727,7 @@ nvmlReturn_t DECLDIR nvmlDeviceGetDecoderUtilization(nvmlDevice_t device, unsign
* For Maxwell &tm; or newer fully supported devices.
*
* @param device The identifier of the target device
* @param fbcStats Reference to nvmlFBCStats_t structure contianing NvFBC stats
* @param fbcStats Reference to nvmlFBCStats_t structure containing NvFBC stats
*
* @return
* - \ref NVML_SUCCESS if \a fbcStats is fetched
@@ -3742,10 +3742,10 @@ nvmlReturn_t DECLDIR nvmlDeviceGetFBCStats(nvmlDevice_t device, nvmlFBCStats_t *
* Retrieves information about active frame buffer capture sessions on a target device.
*
* An array of active encoder sessions is returned in the caller-supplied buffer pointed at by \a sessionInfo. The
* array elememt count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* array element count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* written to the buffer.
*
* If the supplied buffer is not large enough to accomodate the active session array, the function returns
* If the supplied buffer is not large enough to accommodate the active session array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlFBCSessionInfo_t array required in \a sessionCount.
* To query the number of active FBC sessions, call this function with *sessionCount = 0. The code will return
* NVML_SUCCESS with number of active FBC sessions updated in *sessionCount.
@@ -4208,7 +4208,7 @@ nvmlReturn_t DECLDIR nvmlDeviceGetRetiredPages(nvmlDevice_t device, nvmlPageReti
* The address information provided from this API is the hardware address of the page that was retired. Note
* that this does not match the virtual address used in CUDA, but will match the address information in XID 63
*
* \note nvmlDeviceGetRetiredPages_v2 adds an additional timestamps paramter to return the time of each page's
* \note nvmlDeviceGetRetiredPages_v2 adds an additional timestamps parameter to return the time of each page's
* retirement.
*
* For Kepler &tm; or newer fully supported devices.
@@ -4476,7 +4476,7 @@ nvmlReturn_t DECLDIR nvmlDeviceSetDriverModel(nvmlDevice_t device, nvmlDriverMod
* Set clocks that device will lock to.
*
* Sets the clocks that the device will be running at to the value in the range of minGpuClockMHz to maxGpuClockMHz.
* Setting this will supercede application clock values and take effect regardless if a cuda app is running.
* Setting this will supersede application clock values and take effect regardless if a cuda app is running.
* See /ref nvmlDeviceSetApplicationsClocks
*
* Can be used as a setting to request constant performance.
@@ -5297,7 +5297,7 @@ nvmlReturn_t DECLDIR nvmlDeviceSetVirtualizationMode(nvmlDevice_t device, nvmlGp
* pointed at by \a vgpuTypeIds. The element count of nvmlVgpuTypeId_t array is passed in \a vgpuCount, and \a vgpuCount
* is used to return the number of vGPU types written to the buffer.
*
* If the supplied buffer is not large enough to accomodate the vGPU type array, the function returns
* If the supplied buffer is not large enough to accommodate the vGPU type array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlVgpuTypeId_t array required in \a vgpuCount.
* To query the number of vGPU types supported for the GPU, call this function with *vgpuCount = 0.
* The code will return NVML_ERROR_INSUFFICIENT_SIZE, or NVML_SUCCESS if no vGPU types are supported.
@@ -5327,9 +5327,9 @@ nvmlReturn_t DECLDIR nvmlDeviceGetSupportedVgpus(nvmlDevice_t device, unsigned i
* can concurrently run on a device. For example, if only one vGPU type is allowed at a time on a device, then the creatable
* list will be restricted to whatever vGPU type is already running on the device.
*
* If the supplied buffer is not large enough to accomodate the vGPU type array, the function returns
* If the supplied buffer is not large enough to accommodate the vGPU type array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlVgpuTypeId_t array required in \a vgpuCount.
* To query the number of vGPU types createable for the GPU, call this function with *vgpuCount = 0.
* To query the number of vGPU types creatable for the GPU, call this function with *vgpuCount = 0.
* The code will return NVML_ERROR_INSUFFICIENT_SIZE, or NVML_SUCCESS if no vGPU types are creatable.
*
* @param device The identifier of the target device
@@ -5392,7 +5392,7 @@ nvmlReturn_t DECLDIR nvmlVgpuTypeGetName(nvmlVgpuTypeId_t vgpuTypeId, char *vgpu
*
* @param vgpuTypeId Handle to vGPU type
* @param deviceID Device ID and vendor ID of the device contained in single 32 bit value
* @param subsystemID Subsytem ID and subsytem vendor ID of the device contained in single 32 bit value
* @param subsystemID subsystem ID and subsystem vendor ID of the device contained in single 32 bit value
*
* @return
* - \ref NVML_SUCCESS successful completion
@@ -5516,10 +5516,10 @@ nvmlReturn_t DECLDIR nvmlVgpuTypeGetMaxInstances(nvmlDevice_t device, nvmlVgpuTy
* Retrieve the active vGPU instances on a device.
*
* An array of active vGPU instances is returned in the caller-supplied buffer pointed at by \a vgpuInstances. The
* array elememt count is passed in \a vgpuCount, and \a vgpuCount is used to return the number of vGPU instances
* array element count is passed in \a vgpuCount, and \a vgpuCount is used to return the number of vGPU instances
* written to the buffer.
*
* If the supplied buffer is not large enough to accomodate the vGPU instance array, the function returns
* If the supplied buffer is not large enough to accommodate the vGPU instance array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlVgpuInstance_t array required in \a vgpuCount.
* To query the number of active vGPU instances, call this function with *vgpuCount = 0. The code will return
* NVML_ERROR_INSUFFICIENT_SIZE, or NVML_SUCCESS if no vGPU Types are supported.
@@ -5702,7 +5702,7 @@ nvmlReturn_t DECLDIR nvmlVgpuInstanceGetFrameRateLimit(nvmlVgpuInstance_t vgpuIn
* @param encoderCapacity Reference to an unsigned int for the encoder capacity
*
* @return
* - \ref NVML_SUCCESS if \a encoderCapacity has been retrived
* - \ref NVML_SUCCESS if \a encoderCapacity has been retrieved
* - \ref NVML_ERROR_UNINITIALIZED if the library has not been successfully initialized
* - \ref NVML_ERROR_INVALID_ARGUMENT if \a vgpuInstance is 0, or \a encoderQueryType is invalid
* - \ref NVML_ERROR_NOT_FOUND if \a vgpuInstance does not match a valid active vGPU instance on the system
@@ -5863,10 +5863,10 @@ nvmlReturn_t DECLDIR nvmlVgpuInstanceGetEncoderStats(nvmlVgpuInstance_t vgpuInst
* Retrieves information about all active encoder sessions on a vGPU Instance.
*
* An array of active encoder sessions is returned in the caller-supplied buffer pointed at by \a sessionInfo. The
* array elememt count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* array element count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* written to the buffer.
*
* If the supplied buffer is not large enough to accomodate the active session array, the function returns
* If the supplied buffer is not large enough to accommodate the active session array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlEncoderSessionInfo_t array required in \a sessionCount.
* To query the number of active encoder sessions, call this function with *sessionCount = 0. The code will return
* NVML_SUCCESS with number of active encoder sessions updated in *sessionCount.
@@ -5896,7 +5896,7 @@ nvmlReturn_t DECLDIR nvmlVgpuInstanceGetEncoderSessions(nvmlVgpuInstance_t vgpuI
* For Maxwell &tm; or newer fully supported devices.
*
* @param vgpuInstance Identifier of the target vGPU instance
* @param fbcStats Reference to nvmlFBCStats_t structure contianing NvFBC stats
* @param fbcStats Reference to nvmlFBCStats_t structure containing NvFBC stats
*
* @return
* - \ref NVML_SUCCESS if \a fbcStats is fetched
@@ -5914,7 +5914,7 @@ nvmlReturn_t DECLDIR nvmlVgpuInstanceGetFBCStats(nvmlVgpuInstance_t vgpuInstance
* array element count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* written to the buffer.
*
* If the supplied buffer is not large enough to accomodate the active session array, the function returns
* If the supplied buffer is not large enough to accommodate the active session array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlFBCSessionInfo_t array required in \a sessionCount.
* To query the number of active FBC sessions, call this function with *sessionCount = 0. The code will return
* NVML_SUCCESS with number of active FBC sessions updated in *sessionCount.
@@ -6094,7 +6094,7 @@ typedef struct nvmlVgpuPgpuMetadata_st
unsigned int version; //!< Current version of the structure
unsigned int revision; //!< Current revision of the structure
char hostDriverVersion[NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE]; //!< Host driver version
unsigned int pgpuVirtualizationCaps; //!< Pgpu virtualizaion capabilities bitfileld
unsigned int pgpuVirtualizationCaps; //!< Pgpu virtualization capabilities bitfield
unsigned int reserved[7]; //!< Reserved for internal use
unsigned int opaqueDataSize; //!< Size of opaque data field in bytes
char opaqueData[4]; //!< Opaque data
@@ -6191,7 +6191,7 @@ nvmlReturn_t DECLDIR nvmlDeviceGetVgpuMetadata(nvmlDevice_t device, nvmlVgpuPgpu
*
* The caller passes in a buffer via \a compatibilityInfo, into which a compatibility information structure is written. The
* structure defines the states in which the vGPU / VM may be booted on the physical GPU. If the vGPU / VM compatibility
* with the physical GPU is limited, a limit code indicates the factor limiting compability.
* with the physical GPU is limited, a limit code indicates the factor limiting compatibility.
* (see \ref nvmlVgpuPgpuCompatibilityLimitCode_t for details).
*
* Note: vGPU compatibility does not take into account dynamic capacity conditions that may limit a system's ability to

View File

@@ -950,7 +950,7 @@ namespace half_float
/// Convert half-precision floating point to integer.
/// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
/// \tparam E `true` for round to even, `false` for round away from zero
/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \tparam T type to convert to (builtin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \param value binary representation of half-precision value
/// \return integral value
template<std::float_round_style R,bool E,typename T> T half2int_impl(uint16 value)
@@ -988,13 +988,13 @@ namespace half_float
/// Convert half-precision floating point to integer.
/// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \tparam T type to convert to (builtin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \param value binary representation of half-precision value
/// \return integral value
template<std::float_round_style R,typename T> T half2int(uint16 value) { return half2int_impl<R,HALF_ROUND_TIES_TO_EVEN,T>(value); }
/// Convert half-precision floating point to integer using round-to-nearest-away-from-zero.
/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \tparam T type to convert to (builtin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \param value binary representation of half-precision value
/// \return integral value
template<typename T> T half2int_up(uint16 value) { return half2int_impl<std::round_to_nearest,0,T>(value); }
@@ -1053,7 +1053,7 @@ namespace half_float
/// Half-precision floating point type.
/// This class implements an IEEE-conformant half-precision floating point type with the usual arithmetic operators and
/// conversions. It is implicitly convertible to single-precision floating point, which makes artihmetic expressions and
/// conversions. It is implicitly convertible to single-precision floating point, which makes arithmetic expressions and
/// functions with mixed-type operands to be of the most precise operand type. Additionally all arithmetic operations
/// (and many mathematical functions) are carried out in single-precision internally. All conversions from single- to
/// half-precision are done using the library's default rounding mode, but temporary results inside chained arithmetic
@@ -1062,7 +1062,7 @@ namespace half_float
/// According to the C++98/03 definition, the half type is not a POD type. But according to C++11's less strict and
/// extended definitions it is both a standard layout type and a trivially copyable type (even if not a POD type), which
/// means it can be standard-conformantly copied using raw binary copies. But in this context some more words about the
/// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not neccessarily have to be of
/// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not necessarily have to be of
/// exactly 16-bits size. But on any reasonable implementation the actual binary representation of this type will most
/// probably not ivolve any additional "magic" or padding beyond the simple binary representation of the underlying 16-bit
/// IEEE number, even if not strictly guaranteed by the standard. But even then it only has an actual size of 16 bits if
@@ -2181,7 +2181,7 @@ namespace half_float
/// Identity.
/// \param arg operand
/// \return uncahnged operand
/// \return unchanged operand
template<typename T> HALF_CONSTEXPR typename enable<T,T>::type operator+(T arg) { return arg; }
/// Negation.
@@ -2620,7 +2620,7 @@ namespace half_float
/// Multiply by power of two.
/// \param arg number to modify
/// \param exp power of two to multiply with
/// \return \a arg multplied by 2 raised to \a exp
/// \return \a arg multiplied by 2 raised to \a exp
// template<typename T> typename enable<half,T>::type ldexp(T arg, int exp) { return functions::scalbln(arg, exp); }
inline half ldexp(half arg, int exp) { return functions::scalbln(arg, exp); }
inline half ldexp(expr arg, int exp) { return functions::scalbln(arg, exp); }
@@ -2636,7 +2636,7 @@ namespace half_float
/// Multiply by power of two.
/// \param arg number to modify
/// \param exp power of two to multiply with
/// \return \a arg multplied by 2 raised to \a exp
/// \return \a arg multiplied by 2 raised to \a exp
// template<typename T> typename enable<half,T>::type scalbn(T arg, int exp) { return functions::scalbln(arg, exp); }
inline half scalbn(half arg, int exp) { return functions::scalbln(arg, exp); }
inline half scalbn(expr arg, int exp) { return functions::scalbln(arg, exp); }
@@ -2644,7 +2644,7 @@ namespace half_float
/// Multiply by power of two.
/// \param arg number to modify
/// \param exp power of two to multiply with
/// \return \a arg multplied by 2 raised to \a exp
/// \return \a arg multiplied by 2 raised to \a exp
// template<typename T> typename enable<half,T>::type scalbln(T arg, long exp) { return functions::scalbln(arg, exp); }
inline half scalbln(half arg, long exp) { return functions::scalbln(arg, exp); }
inline half scalbln(expr arg, long exp) { return functions::scalbln(arg, exp); }

View File

@@ -1,4 +1,4 @@
#pragma once
#pragma once
#ifndef _TRITON_IR_BASIC_BLOCK_H_
#define _TRITON_IR_BASIC_BLOCK_H_
@@ -27,7 +27,7 @@ public:
private:
// constructors
basic_block(context &ctx, const std::string &name, function *parent);
basic_block(context &ctx, const std::string &name, function *parent, basic_block *next);
public:
// accessors
@@ -35,6 +35,7 @@ public:
context& get_context() { return ctx_; }
// get iterator to first instruction that is not a phi
void replace_phi_uses_with(basic_block* before, basic_block* after);
iterator get_first_non_phi();
// get instruction list
@@ -60,13 +61,16 @@ public:
inline const instruction &back() const { return *inst_list_.back(); }
inline instruction &back() { return *inst_list_.back(); }
void append_instruction(ir::instruction* i);
// split
basic_block* split_before(ir::instruction* loc, const std::string& name);
// predecessors
const std::vector<basic_block*>& get_predecessors() const { return preds_; }
const std::vector<basic_block*>& get_successors() const { return succs_; }
void add_predecessor(basic_block* pred);
std::vector<basic_block*> get_predecessors() const;
std::vector<basic_block*> get_successors() const;
// factory functions
static basic_block* create(context &ctx, const std::string &name, function *parent);
static basic_block* create(context &ctx, const std::string &name, function *parent, basic_block *next = nullptr);
void print(std::ostream &os);

View File

@@ -22,13 +22,16 @@ class phi_node;
/* Builder */
class builder{
public:
typedef basic_block::iterator iterator;
public:
// Constructor
builder(context &ctx);
// Getters
const context& get_context() { return ctx_; }
// const context& get_context() const { return ctx_; }
context& get_context() { return ctx_; }
// Setters
void set_insert_point(iterator instr);
void set_insert_point(instruction* i);
@@ -69,6 +72,9 @@ public:
value* create_br(basic_block *dest);
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
value* create_ret_void();
value* create_ret(value *ret);
// Dequantize instructions
value* create_dequantize(value *src, value *scale, value *shift, type *dest_ty);
// Cast instructions
value* create_bitcast(value *src, type *dest_ty);
value *create_cast(cast_op_t op, value *v, type *dst_ty);
@@ -82,6 +88,9 @@ public:
value* create_fp_trunc(value *src, type *dst_ty);
value* create_int_cast(value *src, type *dst_ty, bool is_signed);
value *create_downcast(value *arg);
// Call instruction
value* create_call(function* fn, const std::vector<value*>& args);
value* create_launch(function* fn, const std::vector<value*>& args, const std::vector<value*>& grid, value* num_warps);
// Phi instruction
phi_node* create_phi(type *ty, unsigned num_reserved);
// Binary instructions
@@ -135,9 +144,12 @@ public:
value *create_or(value *lhs, value *rhs);
// Input/Output
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
value *create_store(value *ptr, value *val);
value *create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction);
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
value *create_masked_store(value *ptr, value *val, value *mask);
value *create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction);
// Struct instructions
value *create_insert_value(value* val, value *elt, size_t idx);
value *create_extract_value(value* val, size_t idx);
// Block instruction
value *create_splat(value *arg, const type::block_shapes_t &shapes);
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
@@ -156,6 +168,15 @@ public:
value *create_atomic_or(value *ptr, value *val, value *msk);
value *create_atomic_xor(value *ptr, value *val, value *msk);
value *create_atomic_xchg(value *ptr, value *val, value *msk);
// Utilities
value *create_clock();
value *create_globaltimer();
// Extern instruction
value *create_extern_elementwise(const std::string &lib_name,
const std::string &lib_path,
const std::string &symbol_name,
const std::vector<value *> &args,
type *ret_ty);
// Built-in instruction
value *create_get_program_id(unsigned axis);
value *create_get_num_programs(unsigned axis);
@@ -163,7 +184,7 @@ public:
value *create_cos(value* arg);
value *create_sin(value* arg);
value *create_log(value* arg);
value *create_dot(value *A, value *B, value *C, bool allow_tf32);
value *create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32);
value *create_trans(value *A, const std::vector<int> &perm = {});
value *create_sqrt(value *A);
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);

View File

@@ -30,7 +30,8 @@ public:
std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
// Block types
std::map<std::pair<type*, type::block_shapes_t>, std::unique_ptr<block_type>> block_tys;
// Struct types
std::map<type::contained_tys_vec_t, struct_type*> struct_tys;
// Int constants
std::map<std::pair<type*, uint64_t>, std::unique_ptr<constant_int>> int_constants_;
// Float constants

View File

@@ -95,6 +95,9 @@ enum value_id_t: unsigned {
INSTRUCTIONS
* ------------ */
INST_BEGIN,
// call
INST_CALL,
INST_LAUNCH,
// phi
INST_PHI,
// arithmetic
@@ -105,6 +108,8 @@ enum value_id_t: unsigned {
// cmp
INST_ICMP,
INST_FCMP,
// dequantize
INST_DEQUANTIZE,
// cast
INST_CAST_TRUNC,
INST_CAST_ZEXT,
@@ -129,6 +134,9 @@ enum value_id_t: unsigned {
INST_MASKED_LOAD_ASYNC,
INST_UNMASKED_STORE,
INST_MASKED_STORE,
// struct
INST_EXTRACT_VALUE,
INST_INSERT_VALUE,
// retile
INST_RESHAPE,
INST_SPLAT,
@@ -148,6 +156,8 @@ enum value_id_t: unsigned {
INST_COS,
INST_SIN,
INST_LOG,
// extern
INST_EXTERN_ELEMENTWISE,
// array arithmetic
INST_TRANS,
INST_REDUCE,
@@ -165,6 +175,8 @@ enum value_id_t: unsigned {
INST_MAKE_RANGE_STA,
INST_MAKE_RANGE,
INST_PREFETCH_S,
INST_GLOBALTIMER,
INST_CLOCK,
};

View File

@@ -24,7 +24,7 @@ public:
static argument* create(type *ty, const std::string &name,
function *parent = nullptr, unsigned arg_no = 0);
function* get_parent() const;
unsigned get_arg_no() const;
unsigned get_arg_no() const;
void accept(visitor *v);
@@ -112,7 +112,7 @@ public:
static function *create(function_type *ty, linkage_types_t linkage,
const std::string &name, module *mod);
// blocks
const blocks_t &blocks() { return blocks_; }
blocks_t &blocks() { return blocks_; }
const blocks_t &blocks() const { return blocks_; }
void insert_block(basic_block* block, basic_block *next = nullptr);
@@ -121,6 +121,8 @@ public:
const attr_map_t &attrs() { return attrs_; }
bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); }
std::set<attribute> get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
void set_is_kernel(bool new_val) { is_kernel_ = new_val; }
bool get_is_kernel() { return is_kernel_; }
void print(std::ostream &os);
@@ -134,6 +136,7 @@ private:
args_t args_;
blocks_t blocks_;
attr_map_t attrs_;
bool is_kernel_;
};
}

View File

@@ -59,8 +59,8 @@ public:
std::string repr() const { return repr_impl(); }
// metadata
void set_metadata(ir::metadata::kind_t kind,
unsigned value) { metadatas_[kind] = value;}
unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
std::vector<unsigned> value) { metadatas_[kind] = value;}
std::vector<unsigned> get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
// cloning
ir::instruction* clone() {
ir::instruction* res = clone_impl();
@@ -77,10 +77,55 @@ public:
private:
basic_block *parent_;
std::map<ir::metadata::kind_t, unsigned> metadatas_;
std::map<ir::metadata::kind_t, std::vector<unsigned>> metadatas_;
value_id_t id_;
};
//===----------------------------------------------------------------------===//
// call_inst classes
//===----------------------------------------------------------------------===//
class call_inst: public instruction {
private:
std::string repr_impl() const;
call_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::string& name, instruction* next);
public:
static call_inst* create(ir::function* fn, const std::vector<ir::value*>& values, const std::string &name = "", instruction *next = nullptr);
ir::function* get_fn() { return fn_; }
_TRITON_DEFINE_CLONE(call_inst)
_TRITON_DEFINE_ACCEPT(call_inst)
private:
ir::function* fn_;
};
class launch_inst: public instruction {
private:
std::string repr_impl() const { return "launch"; }
launch_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps,
const std::string &name = "", instruction *next = nullptr);
public:
static launch_inst* create(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps,
const std::string& name = "", instruction* next = nullptr);
ir::function* get_fn();
std::vector<ir::value*> get_values();
std::vector<ir::value*> get_grid();
ir::value* get_num_warps();
_TRITON_DEFINE_CLONE(launch_inst)
_TRITON_DEFINE_ACCEPT(launch_inst)
private:
unsigned val_begin;
unsigned val_end;
unsigned grid_begin;
unsigned grid_end;
};
//===----------------------------------------------------------------------===//
// phi_node classes
@@ -229,6 +274,24 @@ protected:
unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next);
};
//===----------------------------------------------------------------------===//
// dequantize_inst classes
//===----------------------------------------------------------------------===//
class dequantize_inst: public instruction{
private:
std::string repr_impl() const override { return "dequantize"; }
protected:
dequantize_inst(type *ty, value *v, value *scale, value *shift, const std::string &name, instruction *next);
public:
static dequantize_inst *create(value *arg, value *scale, value *shift, type *ty,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(dequantize_inst)
_TRITON_DEFINE_ACCEPT(dequantize_inst)
};
//===----------------------------------------------------------------------===//
// cast_inst classes
@@ -390,13 +453,31 @@ private:
//===----------------------------------------------------------------------===//
class io_inst: public instruction {
public:
enum EVICTION_POLICY : uint32_t {
NORMAL=0,
EVICT_FIRST,
EVICT_LAST,
};
protected:
io_inst(type *ty, value_id_t id, unsigned num_ops,
io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
const std::string &name = "", instruction *next = nullptr);
std::string get_eviction_policy_repr() const {
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
return "";
}
public:
// accessors
value *get_pointer_operand() { return get_operand(0); }
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
protected:
EVICTION_POLICY eviction_;
};
// load
@@ -408,14 +489,8 @@ public:
CG,
};
enum EVICTION_POLICY : uint32_t {
NORMAL=0,
EVICT_FIRST,
EVICT_LAST,
};
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
bool get_is_volatile() const { return is_volatile_; }
protected:
@@ -425,13 +500,8 @@ protected:
std::string get_cache_modifier_repr() const {
if (cache_ == CA) return ".ca";
if (cache_ == CG) return ".cg";
return "";
return "";
}
std::string get_eviction_policy_repr() const {
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
}
EVICTION_POLICY eviction_;
CACHE_MODIFIER cache_;
std::string get_volatile_repr() {
@@ -507,7 +577,7 @@ public:
// store
class store_inst: public io_inst {
protected:
store_inst(value *ptr, value_id_t id, unsigned num_ops,
store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
const std::string &name = "", instruction *next = nullptr);
public:
@@ -518,11 +588,11 @@ public:
class unmasked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "unmasked_store"; }
unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next);
unmasked_store_inst(value *ptr, value *v, EVICTION_POLICY eviction, const std::string &name, instruction *next);
public:
// factory method
static unmasked_store_inst* create(value* ptr, value *v,
static unmasked_store_inst* create(value* ptr, value *v, EVICTION_POLICY eviction,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_store_inst)
@@ -532,20 +602,58 @@ public:
class masked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "masked_store"; }
masked_store_inst(value *ptr, value *v, value *mask,
masked_store_inst(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
const std::string &name, instruction *next);
public:
// accessors
value *get_mask_operand() { return get_operand(2); }
// factory method
static masked_store_inst* create(value *ptr, value *v, value *mask,
static masked_store_inst* create(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_store_inst)
_TRITON_DEFINE_ACCEPT(masked_store_inst)
};
//===----------------------------------------------------------------------===//
// struct classes
//===----------------------------------------------------------------------===//
// insert_value
class insert_value_inst: public instruction {
private:
std::string repr_impl() const { return "insertvalue"; }
insert_value_inst(value *val, value *elt, size_t idx, const std::string &name, instruction *next);
public:
static insert_value_inst* create(value *val, value* elt, size_t idx, const std::string &name = "", instruction *next = nullptr);
size_t get_idx() { return idx_; }
_TRITON_DEFINE_CLONE(insert_value_inst)
_TRITON_DEFINE_ACCEPT(insert_value_inst)
private:
size_t idx_;
};
// extract_value
class extract_value_inst: public instruction {
private:
std::string repr_impl() const { return "extractvalue"; }
extract_value_inst(value *val, size_t idx, const std::string &name, instruction *next);
public:
static extract_value_inst* create(value *val, size_t idx, const std::string &name = "", instruction *next = nullptr);
size_t get_idx() { return idx_; }
_TRITON_DEFINE_CLONE(extract_value_inst)
_TRITON_DEFINE_ACCEPT(extract_value_inst)
private:
size_t idx_;
};
//===----------------------------------------------------------------------===//
// retile_inst classes
//===----------------------------------------------------------------------===//
@@ -671,6 +779,8 @@ private:
class atomic_inst: public io_inst {
public:
using io_inst::io_inst;
atomic_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next):
io_inst(ty, id, num_ops, NORMAL, name, next) {}
};
class atomic_rmw_inst: public atomic_inst {
@@ -758,20 +868,22 @@ public:
class dot_inst: public builtin_inst {
public:
enum TransT { NoTrans, Trans };
enum DataType {
FP8, FP16, BF16, TF32, FP32,
INT1, INT4, INT8, INT32,
enum DataType {
FP8, FP16, BF16, TF32, FP32,
INT1, INT4, INT8, INT32,
UNKNOWN,
};
private:
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next);
std::string repr_impl() const { return "dot"; }
public:
bool is_prefetched() const { return is_prefetched_; }
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
bool allow_tf32() const { return allow_tf32_; }
bool is_trans_a() const { return AT_ == Trans; }
bool is_trans_b() const { return BT_ == Trans; }
public:
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
@@ -788,6 +900,8 @@ private:
DataType C_type_ = DataType::FP32;
DataType A_type_ = DataType::FP16;
DataType B_type_ = DataType::FP16;
TransT AT_;
TransT BT_;
};
//class outer_inst: public builtin_inst {
@@ -829,8 +943,10 @@ public:
class reduce_inst: public builtin_inst {
public:
enum op_t{
ADD, SUB, MAX, MIN,
ADD, SUB, MAX, MIN, UMAX, UMIN,
ARGMAX, ARGMIN, ARGUMAX, ARGUMIN,
FADD, FSUB, FMAX, FMIN,
ARGFMAX, ARGFMIN,
XOR
};
@@ -848,12 +964,19 @@ public:
static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
op_t get_op() const { return op_; }
bool with_index() const {
return with_index_ops_.find(op_) != with_index_ops_.end();
}
private:
unsigned axis_;
op_t op_;
const static inline std::set<op_t> with_index_ops_ = {
op_t::ARGMAX, op_t::ARGMIN, op_t::ARGUMAX,
op_t::ARGUMIN, op_t::ARGFMAX, op_t::ARGFMIN};
unsigned axis_;
op_t op_;
};
class select_inst: public builtin_inst {
private:
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
@@ -941,11 +1064,11 @@ class prefetch_s_inst : public instruction {
std::string repr_impl() const { return "prefetch_s"; }
_TRITON_DEFINE_CLONE(prefetch_s_inst)
_TRITON_DEFINE_ACCEPT(prefetch_s_inst)
/// inc_: 0->first, 1->latch
int inc_ = 0;
public:
prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
: instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) {
set_operand(0, arg);
}
@@ -971,7 +1094,53 @@ private:
constant_int* last_;
};
/* timing utilities */
class clock_inst: public instruction{
clock_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "clock"; }
_TRITON_DEFINE_CLONE(clock_inst)
_TRITON_DEFINE_ACCEPT(clock_inst)
public:
static clock_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
};
class globaltimer_inst: public instruction{
globaltimer_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "globaltimer"; }
_TRITON_DEFINE_CLONE(globaltimer_inst)
_TRITON_DEFINE_ACCEPT(globaltimer_inst)
public:
static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
};
class extern_elementwise_inst : public instruction {
extern_elementwise_inst(context &ctx, const std::vector<value *> &args,
type *dst_ty, const std::string &lib_name,
const std::string &extern_lib_path,
const std::string &symbol_name,
const std::string &name, instruction *next);
std::string repr_impl() const { return "extern_elementwise"; }
_TRITON_DEFINE_CLONE(extern_elementwise_inst)
_TRITON_DEFINE_ACCEPT(extern_elementwise_inst)
public:
static extern_elementwise_inst *create(
context &ctx, const std::vector<value *> &args, type *dst_ty,
const std::string &lib_name = "", const std::string &lib_path = "",
const std::string &symbol_name = "", const std::string &name = "",
instruction *next = nullptr);
const std::string &get_lib_name() const { return lib_name_; }
const std::string &get_lib_path() const { return lib_path_; }
const std::string &get_symbol_name() const { return symbol_name_; }
private:
std::string lib_name_;
std::string lib_path_;
std::string symbol_name_;
};
}
}

View File

@@ -3,6 +3,8 @@
#ifndef _TRITON_IR_METADATA_H_
#define _TRITON_IR_METADATA_H_
#include <vector>
namespace triton{
namespace ir{
@@ -16,14 +18,14 @@ public:
};
private:
metadata(kind_t kind, unsigned value);
metadata(kind_t kind, std::vector<unsigned> value);
public:
static metadata* get(kind_t kind, unsigned value);
static metadata* get(kind_t kind, std::vector<unsigned> value);
private:
kind_t kind_;
unsigned value_;
std::vector<unsigned> value_;
};
}

View File

@@ -34,26 +34,50 @@ class constant;
class global_value;
class alloc_const;
/* Module */
class module {
class value_constructor {
typedef std::pair<std::string, basic_block*> val_key_t;
friend class function;
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
public:
typedef std::map<std::string, global_value*> symbols_map_t;
typedef std::vector<function*> functions_list_t;
struct current_iteration_info_t{
lang::iteration_statement *statement;
basic_block *block;
};
private:
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
value *try_remove_trivial_phis(ir::phi_node *&phi);
value *add_phi_operands(const std::string& name, phi_node *&phi);
value *get_value_recursive(const std::string& name, basic_block *block);
public:
value_constructor(builder &builder);
void set_value(const std::string& name, basic_block* block, value *x);
void set_value(const std::string& name, value* x);
const std::map<val_key_t, value*>& get_values() { return values_; }
void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
value *get_value(const std::string& name, basic_block* block);
value *get_value(const std::string& name);
void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }
// Seal block -- no more predecessors will be added
void seal_block(basic_block *block);
// Metadata
private:
ir::builder& builder_;
std::map<val_key_t, value*> values_;
std::map<std::string, type*> types_;
std::set<basic_block*> sealed_blocks_;
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
std::map<value*, value**> current_phi_;
};
/* Module */
class module {
typedef std::pair<std::string, basic_block*> val_key_t;
typedef std::pair<ir::metadata::kind_t, std::vector<unsigned>> md_pair_t;
friend class function;
public:
typedef std::map<std::string, global_value*> symbols_map_t;
typedef std::vector<function*> functions_list_t;
private:
void push_function(function *fn) { functions_.push_back(fn); }
public:
@@ -63,8 +87,21 @@ public:
// Functions
const functions_list_t &get_function_list() const { return functions_; }
functions_list_t &get_function_list() { return functions_; }
function *get_function(const std::string& name) {
if(symbols_.find(name) == symbols_.end())
throw std::runtime_error("function " + name + " is not declared");
return (function*)symbols_.at(name);
}
function *get_or_insert_function(const std::string &name, function_type *ty);
bool has_function(const std::string& name){
return symbols_.find(name) != symbols_.end();
}
void remove_function(ir::function* fn){
functions_.erase(std::remove(functions_.begin(), functions_.end(), fn), functions_.end());
}
void reset_ret_ty(const std::string& name, type* ty);
// Const allocation
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
@@ -72,9 +109,9 @@ public:
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
const std::map<std::string, ir::value*>& globals() const { return globals_; }
// Metadata
void print(std::ostream &os);
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
const std::map<std::string, md_pair_t> &get_metadatas() const { return metadatas_; }
void print(std::ostream &os);
private:
std::string name_;

View File

@@ -1,4 +1,4 @@
#pragma once
#pragma once
#ifndef _TRITON_IR_TYPE_H_
#define _TRITON_IR_TYPE_H_
@@ -21,7 +21,6 @@ class type {
public:
typedef std::vector<unsigned> block_shapes_t;
protected:
typedef std::vector<type*> contained_tys_vec_t;
typedef contained_tys_vec_t::iterator ty_iterator;
typedef contained_tys_vec_t::const_iterator const_ty_iterator;
@@ -69,6 +68,8 @@ public:
type *get_tile_element_ty() const;
unsigned get_pointer_address_space() const;
type *get_pointer_element_ty() const;
unsigned get_struct_numel() const { return contained_tys_.size(); }
type *get_struct_type(unsigned int i) const { return contained_tys_[i]; }
// primitive predicates
bool is_void_ty() const { return id_ == VoidTyID; }
@@ -84,6 +85,7 @@ public:
bool is_bool_ty() const { return is_integer_ty(1); }
bool is_pointer_ty() const { return id_ == PointerTyID; }
bool is_block_ty() const { return id_ == BlockTyID; }
bool is_struct_ty() const { return id_ == StructTyID; }
// Composite predicates
bool is_int_or_tileint_ty();
@@ -127,10 +129,10 @@ public:
switch(id_) {
case VoidTyID: return "void";
case FP8TyID: return "fp8";
case BF16TyID: return "bf16";
case FP16TyID: return "f16";
case FP32TyID: return "f32";
case FP64TyID: return "f64";
case BF16TyID: return "bf16";
case LabelTyID: return "label";
case MetadataTyID: return "md";
case TokenTyID: return "tok";
@@ -180,6 +182,16 @@ public:
type* get_type_at_index(value *idx) const;
};
class struct_type: public composite_type {
public:
struct_type(const contained_tys_vec_t& tys, bool is_packed);
unsigned get_num_types() const { return contained_tys_.size(); }
static struct_type* get(const contained_tys_vec_t& tys, bool is_packed);
private:
bool is_packed_;
};
class block_type: public composite_type {
private:
block_type(type *ty, const block_shapes_t &shapes);
@@ -228,6 +240,7 @@ public:
ty_iterator params_end() { return contained_tys_.end(); }
type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); }
type* get_return_ty() const { return contained_tys_.at(0); }
void reset_ret_ty(type* ty) { contained_tys_[0] = ty;}
// factory methods
static function_type* get(type *ret_ty, const std::vector<type*>& param_tys);
};

View File

@@ -22,6 +22,7 @@ public:
};
void for_each_instruction(ir::module& mod, const std::function<void(triton::ir::instruction*)> &fn);
void for_each_instruction_backward(module &mod, const std::function<void (instruction *)> &do_work);
void for_each_value(ir::module& mod, const std::function<void(triton::ir::value *)> &fn);
}

View File

@@ -21,7 +21,7 @@ class visitor;
class value {
public:
typedef std::set<user*> users_t;
typedef std::vector<user*> users_t;
public:
// constructor
@@ -30,7 +30,7 @@ public:
// uses
void add_use(user* arg);
users_t::iterator erase_use(user* arg);
const std::set<user*> &get_users() { return users_; }
const std::vector<user*> &get_users() { return users_; }
void replace_all_uses_with(value *target);
// name
void set_name(const std::string &name);

View File

@@ -11,12 +11,16 @@ class value;
class instruction;
class call_inst;
class launch_inst;
class phi_node;
class binary_operator;
class getelementptr_inst;
class icmp_inst;
class fcmp_inst;
class dequantize_inst;
class cast_inst;
class trunc_inst;
class z_ext_inst;
@@ -42,6 +46,9 @@ class masked_load_inst;
class unmasked_store_inst;
class masked_store_inst;
class extract_value_inst;
class insert_value_inst;
class retile_inst;
class reshape_inst;
class splat_inst;
@@ -75,6 +82,10 @@ class async_wait_inst;
class make_range_dyn;
class make_range;
class prefetch_s_inst;
class clock_inst;
class globaltimer_inst;
class extern_elementwise_inst;
class make_range_sta;
class undef_value;
@@ -103,6 +114,8 @@ public:
virtual ~visitor() {}
virtual void visit_value(ir::value*);
virtual void visit_call_inst(ir::call_inst*) = 0;
virtual void visit_launch_inst(ir::launch_inst*) = 0;
virtual void visit_basic_block(basic_block*) = 0;
virtual void visit_argument(argument*) = 0;
@@ -112,6 +125,7 @@ public:
virtual void visit_icmp_inst(icmp_inst*) = 0;
virtual void visit_fcmp_inst(fcmp_inst*) = 0;
virtual void visit_dequantize_inst(dequantize_inst*) = 0;
virtual void visit_cast_inst(cast_inst*) = 0;
virtual void visit_return_inst(return_inst*) = 0;
@@ -130,6 +144,9 @@ public:
virtual void visit_sin_inst(sin_inst*) = 0;
virtual void visit_log_inst(log_inst*) = 0;
virtual void visit_extract_value_inst(extract_value_inst*) = 0;
virtual void visit_insert_value_inst(insert_value_inst*) = 0;
virtual void visit_reshape_inst(reshape_inst*) = 0;
virtual void visit_splat_inst(splat_inst*) = 0;
virtual void visit_cat_inst(cat_inst*) = 0;
@@ -157,11 +174,15 @@ public:
virtual void visit_make_range(make_range*) = 0;
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
virtual void visit_function(function*) = 0;
virtual void visit_clock_inst(clock_inst*) = 0;
virtual void visit_globaltimer_inst(globaltimer_inst*) = 0;
virtual void visit_undef_value(undef_value*) = 0;
virtual void visit_constant_int(constant_int*) = 0;
virtual void visit_constant_fp(constant_fp*) = 0;
virtual void visit_alloc_const(alloc_const*) = 0;
virtual void visit_extern_elementwise_inst(extern_elementwise_inst*) = 0;
};
}

View File

@@ -3,8 +3,9 @@
#ifndef _TRITON_TOOLS_THREAD_GRAPH_H_
#define _TRITON_TOOLS_THREAD_GRAPH_H_
#include "llvm/ADT/SetVector.h"
#include <map>
#include <set>
#include <vector>
#include <iostream>
@@ -13,21 +14,21 @@ namespace tools{
template<class node_t>
class graph {
typedef std::map<node_t, std::set<node_t>> edges_t;
typedef std::map<node_t, llvm::SetVector<node_t>> edges_t;
public:
typedef std::map<size_t, std::vector<node_t>> cmap_t;
typedef std::map<node_t, size_t> nmap_t;
private:
void connected_components_impl(node_t x, std::set<node_t> &nodes,
void connected_components_impl(node_t x, llvm::SetVector<node_t> &nodes,
nmap_t* nmap, cmap_t* cmap, int id) const {
if(nmap)
(*nmap)[x] = id;
if(cmap)
(*cmap)[id].push_back(x);
if(nodes.find(x) != nodes.end()) {
nodes.erase(x);
if (nodes.count(x)) {
nodes.remove(x);
for(const node_t &y: edges_.at(x))
connected_components_impl(y, nodes, nmap, cmap, id);
}
@@ -39,7 +40,7 @@ public:
cmap->clear();
if(nmap)
nmap->clear();
std::set<node_t> nodes = nodes_;
llvm::SetVector<node_t> nodes = nodes_;
unsigned id = 0;
while(!nodes.empty()){
connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++);
@@ -59,7 +60,7 @@ public:
}
private:
std::set<node_t> nodes_;
llvm::SetVector<node_t> nodes_;
edges_t edges_;
};

View File

@@ -115,6 +115,18 @@ std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_ins
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_dequantize(ir::dequantize_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_block_shapes();
auto op_cst = populate_is_constant(op);
for(size_t d = 0; d < x_shapes.size(); d++) {
result.push_back(op_cst[d]);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
@@ -129,6 +141,36 @@ std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_cmp(ir::cmp_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
auto lhs = populate_is_constant(lhs_op);
auto rhs = populate_is_constant(rhs_op);
auto lhs_max_contiguous = populate_max_contiguous(lhs_op);
auto rhs_max_contiguous = populate_max_contiguous(rhs_op);
auto lhs_multiple_of = populate_starting_multiple(lhs_op);
auto rhs_multiple_of = populate_starting_multiple(rhs_op);
for(size_t d = 0; d < x_shapes.size(); d++) {
cst_info ax = {1, 0};
// Examples:
// 16 17 18 ... 32 < 24 24 24 ... 24 => equal in groups of 8
// 16 17 18 ... 32 < 20 20 20 ... 20 => equal in groups of 4
// 16 17 18 ... 32 < 16 16 16 ... 16 => equal in groups of 16
//
// if LHS is a range of N continuous (or equal) elements that starts at M,
// and RHS is a set of N constants that start at K
// then the result in constant in groups of gcd(M, K)
if(rhs[d].num_cst % lhs_max_contiguous[d] == 0 ||
rhs[d].num_cst % lhs[d].num_cst == 0)
ax.num_cst = gcd(lhs_multiple_of[d], rhs_multiple_of[d]);
result.push_back(ax);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
@@ -136,12 +178,14 @@ std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operat
ir::value* rhs_op = x->get_operand(1);
auto lhs = populate_is_constant(lhs_op);
auto rhs = populate_is_constant(rhs_op);
auto max_contiguous = populate_max_contiguous(lhs_op);
auto lhs_max_contiguous = populate_max_contiguous(lhs_op);
auto rhs_max_contiguous = populate_max_contiguous(rhs_op);
auto lhs_multiple_of = populate_starting_multiple(lhs_op);
auto rhs_multiple_of = populate_starting_multiple(rhs_op);
for(size_t d = 0; d < x_shapes.size(); d++) {
cst_info ax;
if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
// todo might not be entirely true
unsigned num_constants = gcd(max_contiguous[d], rhs[d].value);
unsigned num_constants = gcd(lhs_max_contiguous[d], rhs[d].value);
ax = {num_constants, 0};
}
else
@@ -180,10 +224,14 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
return populate_is_constant_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_is_constant_reshape(x);
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
return populate_is_constant_dequantize(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_is_constant_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_is_constant_binop(x);
if(auto *x = dynamic_cast<ir::cmp_inst*>(v))
return populate_is_constant_cmp(x);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_is_constant_gep(x);
return populate_is_constant_default(v);
@@ -245,6 +293,23 @@ std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_dequantize(ir::dequantize_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
auto op_last_dim = (op->get_type()->get_block_shapes()).back();
auto op_mc = populate_max_contiguous(op);
for(size_t d = 0; d < shapes.size(); d++) {
unsigned factor = 1;
if (d == shapes.size() - 1) {
factor = ret_last_dim / op_last_dim;
}
result.push_back(factor * op_mc[d]);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
@@ -285,8 +350,8 @@ std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator*
}
if(x->is_int_add_sub()){
unsigned lvalue = 1, rvalue = 1;
lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]);
rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]);
lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
value = std::max(lvalue, rvalue);
}
result.push_back(value);
@@ -332,9 +397,9 @@ std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
if(max_contiguous_.find(v) != max_contiguous_.end())
return max_contiguous_.at(v);
if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned max_contiguous = x->get_metadata(ir::metadata::max_contiguous);
if(max_contiguous > 0)
return add_to_cache(x, {max_contiguous}, max_contiguous_);
std::vector<unsigned> max_contiguous = x->get_metadata(ir::metadata::max_contiguous);
if(!max_contiguous.empty())
return add_to_cache(x, max_contiguous, max_contiguous_);
}
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
return populate_max_contiguous_cast(x);
@@ -342,6 +407,8 @@ std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
return populate_max_contiguous_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_max_contiguous_reshape(x);
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
return populate_max_contiguous_dequantize(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_max_contiguous_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
@@ -386,6 +453,23 @@ std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_dequantize(ir::dequantize_inst* x){
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
auto op_last_dim = (op->get_type()->get_block_shapes()).back();
auto op_multiple = populate_starting_multiple(op);
for(size_t d = 0; d < shapes.size(); d++) {
unsigned factor = 1;
if (d == shapes.size() - 1) {
factor = ret_last_dim / op_last_dim;
}
result.push_back(factor * op_multiple[d]);
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
auto result = populate_starting_multiple(x->get_operand(0));
return add_to_cache(x, result, starting_multiple_);
@@ -401,7 +485,7 @@ std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operato
if(x->is_int_add_sub())
result[d] = gcd(lhs[d], rhs[d]);
if(x->is_int_div())
result[d] = 1;
result[d] = (lhs[d] == (1 << 31)) ? 1 << 31 : 1;
if(x->is_int_rem() && rhs[d] > 1){
result[d] = gcd(lhs[d], rhs[d]);
}
@@ -471,28 +555,42 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
return add_to_cache(v, {1}, starting_multiple_);
}
unsigned get_max_multiple(int val){
if(val == 0) return 1 << 31;
if(val % 128 == 0) return 128;
if(val % 64 == 0) return 64;
if(val % 32 == 0) return 32;
if(val % 16 == 0) return 16;
if(val % 8 == 0) return 8;
if(val % 4 == 0) return 4;
if(val % 2 == 0) return 2;
return 1;
}
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
if(starting_multiple_.find(v) != starting_multiple_.end())
return starting_multiple_.at(v);
if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
if(multiple_of > 0)
return add_to_cache(x, {multiple_of}, starting_multiple_);
std::vector<unsigned> multiple_of = x->get_metadata(ir::metadata::multiple_of);
if(!multiple_of.empty())
return add_to_cache(x, multiple_of, starting_multiple_);
}
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
return populate_starting_multiple_cast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_starting_multiple_binop(x);
if(auto *x = dynamic_cast<ir::constant_int*>(v))
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
return add_to_cache(x, {get_max_multiple(x->get_value())}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range*>(v))
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
return add_to_cache(x, {get_max_multiple(x->get_first()->get_value())}, starting_multiple_);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_starting_multiple_gep(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_starting_multiple_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_starting_multiple_reshape(x);
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
return populate_starting_multiple_dequantize(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_starting_multiple_broadcast(x);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
@@ -511,12 +609,15 @@ std::vector<unsigned> align::contiguous(ir::value* v) const {
return max_contiguous_.at(v);
}
std::vector<align::cst_info> align::get_cst_info(ir::value* v) const {
return is_constant_.at(v);
}
void align::populate(ir::value *v) {
populate_is_constant(v);
populate_starting_multiple(v);
populate_max_contiguous(v);
}
void align::run(ir::module &mod) {

View File

@@ -92,8 +92,10 @@ void allocation::run(ir::module &mod) {
}
// Save maximum size of induced memory space
allocated_size_ = 0;
for(shared_layout* x: V)
for(shared_layout* x: V){
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
// std::cout << "start: " << starts[x] << " | end: " << starts[x] + x->get_size() << std::endl;
}
}
}

View File

@@ -56,6 +56,17 @@ void axes::update_graph_trans(ir::instruction *i) {
graph_.add_edge({i, perm[d]}, {op, d});
}
void axes::update_graph_dequantize(ir::instruction *i) {
auto *dequantize = static_cast<ir::dequantize_inst*>(i);
auto shapes = dequantize->get_type()->get_block_shapes();
ir::value *op = dequantize->get_operand(0);
// add edge except the last axis
for(unsigned d = 0; d < shapes.size() - 1; d ++){
graph_.add_edge({i, d}, {op, d});
}
}
void axes::update_graph_broadcast(ir::instruction *i) {
auto *broadcast = static_cast<ir::broadcast_inst*>(i);
auto shapes = broadcast->get_type()->get_block_shapes();
@@ -79,7 +90,7 @@ void axes::update_graph_dot(ir::instruction *i) {
graph_.add_edge({dot, d}, {D, d});
}
void axes::update_graph_elementwise(ir::instruction *i,
void axes::update_graph_elementwise(ir::instruction *i,
bool is_masked_load_async) {
if(i->get_num_operands() == 0)
return;
@@ -119,6 +130,7 @@ void axes::update_graph(ir::instruction *i) {
case ir::INST_SPLAT: return update_graph_no_edge(i);
case ir::INST_CAT: return update_graph_elementwise(i, true);
case ir::INST_TRANS: return update_graph_trans(i);
case ir::INST_DEQUANTIZE: return update_graph_dequantize(i);
case ir::INST_BROADCAST: return update_graph_broadcast(i);
case ir::INST_DOT: return update_graph_dot(i);
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);

View File

@@ -209,14 +209,13 @@ mma_layout::mma_layout(size_t num_warps,
rep_ = {2*pack_size_0, 2*pack_size_1, 1};
spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
contig_per_thread_ = {1, 1};
order_ = {0, 1};
}
else{
// fpw_ = {1, 1, 1};
spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
contig_per_thread_ = {1, 2};
// rep_ = {2, 2, 1};
order_ = {1, 0};
}
order_ = {0, 1};
/* warps per tile */
wpt_ = {1, 1, 1};
@@ -232,24 +231,45 @@ mma_layout::mma_layout(size_t num_warps,
}while(wpt_nm1 != wpt_);
} else {
bool changed = false;
do {
changed = false;
if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
break;
if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
if (wpt_[0] < shape_[0] / spw_[0]) {
wpt_[0] *= 2;
changed = true;
// try to have a warp own entire rows of the output
// this makes it easier to fuse multiple mmas by fusing
// registers
bool one_warp_per_row = false;
for(ir::value* v: values)
for(ir::user* u: v->get_users()){
auto* dot = dynamic_cast<ir::dot_inst*>(u);
auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(u);
if((dot && dot->get_operand(2)!=v) || !layout_a->to_shared() || cts)
one_warp_per_row = shape[0] / spw_[0] >= num_warps;
}
// std::cout << one_warp_per_row << std::endl;
if(one_warp_per_row){
wpt_[1] = 1;
wpt_[0] = num_warps;
}
else{
do {
changed = false;
if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
break;
if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
if (wpt_[0] < shape_[0] / spw_[0]) {
wpt_[0] *= 2;
changed = true;
}
} else {
if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
wpt_[1] *= 2;
changed = true;
}
}
} else {
if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
wpt_[1] *= 2;
changed = true;
}
}
} while (changed);
} while(changed);
}
}
// std::cout << wpt_[0] << " " << wpt_[1] << std::endl;
/* shape per block */
shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
}
@@ -347,12 +367,16 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
res.reset(new double_buffer_info_t{value_1, value_0, phi});
}
static bool is_smem(ir::value* v) {
if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
dynamic_cast<ir::masked_load_async_inst*>(v))
return true;
else
return false;
static bool is_smem_in(ir::value* v, const ir::basic_block* bb) {
if (ir::instruction *instr = dynamic_cast<ir::instruction*>(v)) {
if (instr->get_parent() != bb)
return false;
if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
dynamic_cast<ir::masked_load_async_inst*>(v)) {
return true;
}
}
return false;
}
/// param:
@@ -367,14 +391,14 @@ static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::
ir::basic_block *cbb0 = cphi->get_incoming_block(0);
ir::basic_block *cbb1 = cphi->get_incoming_block(1);
if (is_smem(c0)) {
if (is_smem_in(c0, cbb0)) {
assert(cbb0 == bb0);
values_0.push_back(c0);
if (auto phi1 = dynamic_cast<ir::phi_node*>(c1)) {
next = phi1;
continue;
} else {
if (is_smem(c1)) {
if (is_smem_in(c1, cbb1)) {
value_1 = c1;
assert(cbb1 == bb1);
return true;
@@ -429,8 +453,8 @@ shared_layout::shared_layout(data_layout *arg,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
ir::type *ty,
analysis::align* align, target *tgt)
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) {
analysis::align* align, target *tgt, bool is_tmp)
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt), is_tmp_(is_tmp){
size_ = 0;
arg_layout_ = arg;
@@ -587,6 +611,45 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
}
}
// layout checkers
bool layouts::is_scanline(ir::instruction *i) {
return this->get(i->get_operand(0))->to_scanline() != nullptr;
}
bool layouts::is_coalesced_scanline(ir::instruction *i) {
if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
auto *scanline = this->get(i->get_operand(0))->to_scanline();
return scanline && scanline->get_order()[0] == red->get_axis();
}
return false;
}
bool layouts::is_mma(ir::instruction *i) {
return this->get(i->get_operand(0))->to_mma() != nullptr;
}
bool layouts::is_a100_mma(ir::instruction *i) {
if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
return is_mma(red) && (tgt_->as_nvidia()->sm() >= 80) &&
(red->get_axis() == 1);
}
return false;
}
void layouts::create_tmp_layout(size_t id, data_layout *arg,
const std::vector<int> &axes,
const std::vector<unsigned> &shape,
ir::instruction *i, bool is_index) {
ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context())
: i->get_type()->get_scalar_ty();
layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_, true);
if (is_index) {
tmp_index_[i] = id;
} else {
tmp_[i] = id;
}
}
void layouts::run(ir::module &mod) {
// make graph
graph_.clear();
@@ -608,22 +671,29 @@ void layouts::run(ir::module &mod) {
// create temporaries
size_t id = values_.size();
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
// std::cout << "layout: " << std::endl;
// i->print(std::cout);
if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
id++;
ir::value *arg = red->get_operand(0);
unsigned axis = red->get_axis();
distributed_layout *layout =
dynamic_cast<analysis::distributed_layout *>(get(arg));
// shape
auto shapes = arg->get_type()->get_block_shapes();
scanline_layout *layout = get(arg)->to_scanline();
shapes[axis] = layout->mts(axis);
unsigned axis = red->get_axis();
shapes[axis] =
layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
// create layout
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[red] = id;
id++;
create_tmp_layout(id, layout, axes_->get(arg), shapes, red);
if (red->with_index()) {
id++;
create_tmp_layout(id, layout, axes_->get(arg), shapes, red, true);
}
}
if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
id++;
size_t dim = val->get_type()->get_tile_rank();
ir::type::block_shapes_t shape(dim);
for(size_t k = 0; k < dim; k++){
@@ -636,13 +706,12 @@ void layouts::run(ir::module &mod) {
int out_vec = out_layout->contig_per_thread(out_ord[0]);
int pad = std::max(in_vec, out_vec);
shape[out_ord[0]] += pad;
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[val] = id;
id++;
create_tmp_layout(id, out_layout, axes_->get(val), shape, val);
}
if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
id++;
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[atom] = id;
create_tmp_layout(id, nullptr, {}, {1}, atom);
}
});

View File

@@ -14,43 +14,108 @@ namespace analysis{
void liveness::run(ir::module &mod) {
intervals_.clear();
// Assigns index to each instruction
std::map<ir::value*, slot_index> indices;
for(ir::function *fn: mod.get_function_list()){
slot_index index = 0;
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *instr: block->get_inst_list()){
index += 1;
indices.insert({instr, index});
std::map<ir::value*, std::set<shared_layout*>> layouts_map;
for(auto &x: layouts_->get_all()){
shared_layout* layout = x.second->to_shared();
if(!layout || layout->is_tmp())
continue;
for(ir::value* v:layout->get_values()){
layouts_map[v].insert(layout);
}
}
// create live intervals
std::map<ir::user*, std::set<shared_layout*>> live_in;
while(true){
bool changed = false;
ir::instruction* last_inst = nullptr;
ir::for_each_instruction_backward(mod, [&](ir::instruction* i){
// gen
std::set<shared_layout*> gen;
for(ir::value* v: i->ops())
for(shared_layout* layout: layouts_map[v])
gen.insert(layout);
// kill
std::set<shared_layout*> kill;
for(shared_layout* layout: layouts_map[i])
kill.insert(layout);
// temporaries are handled separately
if(layouts_->has_tmp(i)){
gen.insert(layouts_->get(layouts_->tmp(i))->to_shared());
kill.insert(layouts_->get(layouts_->tmp(i))->to_shared());
}
if(layouts_->has_tmp_index(i)){
gen.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
kill.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
}
// live-out
std::set<shared_layout*> live_out;
std::vector<ir::instruction*> succs = {last_inst};
if(i == i->get_parent()->get_inst_list().back())
for(ir::basic_block* succ: i->get_parent()->get_successors())
succs.push_back(succ->get_inst_list().front());
for(ir::instruction* succ: succs)
for(shared_layout* layout: live_in[succ])
if(!layout->is_tmp())
live_out.insert(layout);
// new sets
std::set<shared_layout*> live_out_minus_kill;
std::set_difference(live_out.begin(), live_out.end(), kill.begin(), kill.end(),
std::inserter(live_out_minus_kill, live_out_minus_kill.end()));
std::set<shared_layout*> new_live_in;
std::set_union(gen.begin(), gen.end(), live_out_minus_kill.begin(), live_out_minus_kill.end(),
std::inserter(new_live_in, new_live_in.end()));
changed = changed || (new_live_in != live_in[i]);
live_in[i] = new_live_in;
last_inst = i;
});
if(!changed)
break;
}
// ir::for_each_instruction(mod, [&](ir::instruction* i){
// i->print(std::cout);
// std::cout << " live_in: " << live_in[i].size() << std::endl;
// });
// Assigns index to each instruction
std::map<ir::value*, slot_index> indices;
slot_index index = 0;
ir::for_each_instruction(mod, [&](ir::instruction* instr){
index += 1;
indices.insert({instr, index});
});
for(auto &x: layouts_->get_all()){
shared_layout* layout = x.second->to_shared();
if(layout)
intervals_[layout] = segment{INT32_MAX, 0};
}
for(auto& x: live_in)
for(shared_layout* layout: x.second)
intervals_[layout].start = std::min<int>(intervals_[layout].start, indices[x.first]);
for(auto& x: live_in)
for(shared_layout* layout: x.second){
intervals_[layout].end = std::max<int>(intervals_[layout].end, indices[x.first] + 1);
}
for(auto &x: layouts_->get_all()) {
shared_layout* layout = x.second->to_shared();
if(!layout)
continue;
// users
std::set<ir::user*> users;
for(ir::value *v: layout->get_values()){
for(ir::user *u: v->get_users())
users.insert(u);
}
// compute intervals
unsigned start = INT32_MAX;
for(ir::value *v: layout->get_values())
if(indices.find(v) != indices.end())
start = std::min(start, indices.at(v));
unsigned end = 0;
for(ir::user *u: users)
if(indices.find(u) != indices.end())
end = std::max(end, indices.at(u));
if(end == 0)
end = start + 1;
intervals_[layout] = segment{start, end};
// std::cout << intervals_[layout].start << " " << intervals_[layout].end << std::endl;
}
}

View File

@@ -28,12 +28,15 @@ void swizzle::run(ir::module &) {
}
auto ord = layout->get_order();
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
if(!in_layout)
continue;
int per_phase = 1;
int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(in_layout)
per_phase = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
else
per_phase = 1;
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
int inner = mma_dot_a ? 0 : 1;
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
per_phase_[layout] = per_phase;
max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
if(mma_dot_a)
vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
@@ -46,7 +49,7 @@ void swizzle::run(ir::module &) {
max_phase_[layout] = 1;
vec_[layout] = 1;
} else {
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
per_phase_[layout] = per_phase;
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
vec_[layout] = layout->get_mma_vec();
}

63
lib/codegen/extern_lib.cc Normal file
View File

@@ -0,0 +1,63 @@
#include "triton/codegen/extern_lib.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Type.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "triton/codegen/pass.h"
namespace triton {
namespace codegen {
std::unique_ptr<llvm::Module> ExternLib::load(llvm::LLVMContext& ctx) {
llvm::SMDiagnostic err;
auto mod = llvm::parseIRFile(this->path_, err, ctx);
if (!mod) {
throw std::runtime_error("Failed to load extern lib " + this->name_ +
" at " + this->path_);
}
return mod;
}
void ExternLib::link(std::unique_ptr<llvm::Module>& llvm,
std::unique_ptr<llvm::Module>& mod) {
// Set triple and data layout to match the target module
mod->setTargetTriple(llvm->getTargetTriple());
mod->setDataLayout(llvm->getDataLayout());
if (llvm::Linker::linkModules(*llvm, std::move(mod))) {
throw std::runtime_error("Failed to link extern lib " + this->name_ +
" at " + this->path_);
}
}
void LibDevice::opt(llvm::LLVMContext& ctx, std::unique_ptr<llvm::Module>& llvm) {
// Add nvvm reflect flags to llvm module
// https://llvm.org/docs/LangRef.html#module-flags-metadata
// i32 4: Override the other module.
// i32 1: Emit an error
// If both modules specify Override, but the values differ, an error
// will be emitted.
llvm::Type* I32 = llvm::Type::getInt32Ty(ctx);
llvm::Metadata* md_four =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
llvm::Metadata* md_name = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
llvm::Metadata* md_one =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
llvm::MDNode* reflect = llvm::MDNode::get(ctx, {md_four, md_name, md_one});
llvm->addModuleFlag(reflect);
}
std::unique_ptr<ExternLib> create_extern_lib(const std::string& lib_name,
const std::string& lib_path) {
if (lib_name == "libdevice") {
return std::make_unique<LibDevice>(lib_name, lib_path);
} else {
throw std::runtime_error("Unknown external library: " + lib_name);
}
}
} // namespace codegen
} // namespace triton

View File

@@ -1,4 +1,14 @@
#include "triton/codegen/pass.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/axes.h"
@@ -9,6 +19,7 @@
#include "triton/codegen/transform/cts.h"
#include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/disassociate.h"
#include "triton/codegen/transform/inline.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/pipeline.h"
@@ -16,43 +27,90 @@
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/print.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
namespace triton {
namespace codegen {
static void link_extern_libs(const ExternLibMap& user_extern_lib_map,
const ExternLibMap& target_extern_lib_map,
ir::module& ir, llvm::LLVMContext& ctx,
std::unique_ptr<llvm::Module>& llvm) {
for (const auto& iter : target_extern_lib_map) {
auto &lib_name = iter.first;
if (user_extern_lib_map.count(lib_name) != 0 &&
user_extern_lib_map.at(lib_name)->path() != "") {
// If the user specified a path for this library, use it.
user_extern_lib_map.at(lib_name)->install(ctx, llvm);
} else {
// Otherwise, use the default path.
iter.second->install(ctx, llvm);
}
}
std::set<llvm::StringRef> function_names;
for (auto& func : ir.get_function_list()) {
function_names.insert(func->get_name());
}
llvm::legacy::PassManager pass;
pass.add(llvm::createInternalizePass([&](const llvm::GlobalValue& v) -> bool {
if (function_names.count(v.getName()) != 0) {
// Preserve global functions
return true;
}
// Internalize all device functions
return false;
}));
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
pm.run(*llvm);
llvm::PassManagerBuilder builder;
builder.OptLevel = 3;
builder.SizeLevel = 0;
builder.populateModulePassManager(pass);
pass.run(*llvm);
}
// TODO:
// There should be a proper pass manager there!
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target,
int cc, int num_warps, int num_stages, int& shared_static) {
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
ir::module& ir, llvm::LLVMContext& ctx, codegen::target* target,
int num_warps, int num_stages, int& shared_static,
const ExternLibMap& extern_lib_map) {
// generate llvm code
std::string name = ir.get_function_list()[0]->get_name();
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
// optimizations
bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
bool has_sm80 = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
// create passes
codegen::analysis::align align;
codegen::transform::inliner inliner;
codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::pipeline pipeline(cts_use_async, num_stages);
codegen::transform::pipeline pipeline(has_sm80, num_stages);
codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, num_warps, target);
codegen::transform::cts cts(&layouts, has_sm80);
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target);
codegen::analysis::allocation allocation(&liveness);
codegen::transform::dce dce;
codegen::transform::peephole peephole(target, &layouts);
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::coalesce coalesce(&align, &layouts, has_sm80);
codegen::transform::prefetch prefetch_s(target);
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps);
codegen::transform::membar barriers(&liveness, &layouts, &allocation,
&prefetch_s, target);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle,
target, num_warps);
// run passes
inliner.run(ir);
dce.run(ir);
peephole.run(ir);
dce.run(ir);
pipeline.run(ir);
dce.run(ir);
dce.run(ir);
// ir.print(std::cout);
disassociate.run(ir);
dce.run(ir);
align.run(ir);
@@ -60,8 +118,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
layouts.run(ir);
peephole.run(ir);
dce.run(ir);
if (target->is_gpu())
cts.run(ir);
if (target->is_gpu()) cts.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
@@ -69,8 +126,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
dce.run(ir);
align.run(ir);
dce.run(ir);
if (target->is_gpu())
cts.run(ir);
if (target->is_gpu()) cts.run(ir);
dce.run(ir);
align.run(ir);
axes.run(ir);
@@ -81,14 +137,34 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
axes.run(ir);
layouts.run(ir);
swizzle.run(ir);
// std::cout << "---" << std::endl;
// ir.print(std::cout);
// std::cout << "---" << std::endl;
// ir.print(std::cout);
liveness.run(ir);
allocation.run(ir);
prefetch_s.run(ir);
barriers.run(ir);
// exit(1);
// ir.print(std::cout);
isel.visit(ir, *llvm);
shared_static = allocation.allocated_size();
if (target->as_nvidia() && target->as_nvidia()->sm() < 70) {
// sm < 70 (Pascal) has little shared memory resource.
// Instead of having "Error: Invalid argument" on launching a kernel, let's throw an error here.
if (shared_static >= 65536) {
throw std::runtime_error("Device does not support shared memory of " + std::to_string(shared_static) + "bytes");
}
}
if (isel.get_extern_lib_map().size() > 0) {
// If there's any extern lib calls,
// we need to link them in.
link_extern_libs(extern_lib_map, isel.get_extern_lib_map(), ir, ctx, llvm);
}
return llvm;
}
} // namespace codegen
} // namespace triton
} // namespace codegen
} // namespace triton

File diff suppressed because it is too large Load Diff

View File

@@ -12,46 +12,11 @@ namespace triton {
namespace codegen{
namespace transform{
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
: align_(align), layout_(layouts) { }
// simplify layout conversions using the following simple rules:
// - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
// - cvt_1(elementwise(x, y)) = elementwise(convert(x), convert(y))
//ir::value* coalesce::simplify(ir::instruction *inst, ir::builder& builder){
// ir::value* _op = inst->get_operand(0);
// ir::instruction* op = dynamic_cast<ir::instruction*>(_op);
// analysis::mma_layout* mma_in = layout_->get(op) ->to_mma();
// analysis::mma_layout* mma_out = layout_->get(inst)->to_mma();
// std::cout << 1 << std::endl;
// // i must be layout conversion instruction
// if(!mma_in && !mma_out)
// return inst;
// // - cvt_1(cvt_2(x)) if convert1 is the inverse of convert2
// bool is_op_cvt = op->get_id() == ir::INST_CVT_LAYOUT;
// if((mma_in || mma_out) && is_op_cvt &&
// (layout_->get(inst) == layout_->get(op->get_operand(0))))
// return op->get_operand(0);
// // - cvt_1(elementwise(x, y)) = elementwise(cvt_1(x), cvt_2(y))
// if(op->get_id() != ir::INST_BINOP && op->get_id() != ir::INST_GETELEMENTPTR)
// return inst;
// std::cout << 1 << std::endl;
// for(size_t i = 0; i < op->get_num_operands(); i++){
// ir::value* arg_i = op->get_operand(i);
// builder.set_insert_point(op);
// // create new layout transform
// ir::instruction* new_arg_i = inst->clone();
// builder.insert(new_arg_i);
// // set the right args
// new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
// op->replace_uses_of_with(arg_i, simplify(new_arg_i, builder));
// }
// std::cout << 2 << std::endl;
// return op;
//}
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts, bool has_sm80)
: align_(align), layout_(layouts), has_sm80_(has_sm80) { }
void coalesce::run(ir::module &mod) {
std::set<analysis::data_layout*> invalidated;
ir::builder& builder = mod.get_builder();
// add layout conversion instructions
for(ir::function *fn: mod.get_function_list())
@@ -61,17 +26,38 @@ void coalesce::run(ir::module &mod) {
if(dynamic_cast<ir::store_inst*>(i) || dynamic_cast<ir::atomic_rmw_inst*>(i))
if(ir::value* op = i->get_operand(1))
if(op->get_type()->is_block_ty())
if(layout_->get(op)->to_mma()){
if(op->get_type()->get_tile_ranks1() == 2)
if(invalidated.find(layout_->get(op)) == invalidated.end())
if(layout_->get(op)->to_mma())
if(dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
ir::instruction* new_op = ir::cvt_layout_inst::create(op);
builder.set_insert_point(i);
builder.insert(new_op);
i->replace_uses_of_with(op, new_op);
}
// coalesce before copy_to_shared
// only necessary for sm < 80 as Ampere+ can handle reduction
// on MMA layout
if(!has_sm80_)
if(dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::reduce_inst*>(i))
if(ir::value* op = i->get_operand(0))
if(op->get_type()->is_block_ty())
if(op->get_type()->get_tile_ranks1() == 2)
if(invalidated.find(layout_->get(op)) == invalidated.end())
if(layout_->get(op)->to_mma()){
ir::instruction* new_op = ir::cvt_layout_inst::create(op);
builder.set_insert_point(i);
builder.insert(new_op);
op->replace_all_uses_with(new_op);
new_op->replace_uses_of_with(new_op, op);
invalidated.insert(layout_->get(op));
}
// uncoalesce after load
if(auto x = dynamic_cast<ir::load_inst*>(i))
if(x->get_type()->is_block_ty())
if(x->get_type()->get_tile_rank()==2)
if(layout_->get(x)->to_mma()){
if(x->get_type()->get_tile_ranks1()==2)
if(layout_->get(x)->to_mma())
if(!has_sm80_ || dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
builder.set_insert_point_after(x);
ir::instruction* new_x = ir::cvt_layout_inst::create(x);
builder.insert(new_x);
@@ -89,9 +75,11 @@ void coalesce::run(ir::module &mod) {
auto out_contig = align_->contiguous(ptr);
auto val_inst = dynamic_cast<ir::instruction*>(val);
if(!val_inst)
break;
continue;
if(dynamic_cast<ir::cvt_layout_inst*>(val))
break;
continue;
if(!val->get_type()->is_block_ty() || val->get_type()->get_tile_ranks1()==1)
continue;
std::vector<unsigned> in_contig;
std::vector<ir::instruction*> queue = {val_inst};
std::set<ir::instruction*> seen;

View File

@@ -1,8 +1,10 @@
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/transform/cts.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
#include <iostream>
namespace triton {
@@ -10,9 +12,9 @@ namespace codegen{
namespace transform{
inline bool is_shmem_op(ir::instruction* i, int op) {
bool cts::is_shmem_op(ir::instruction* i, int op) {
if(i->get_id() == ir::INST_DOT)
return op==0 || op==1;
return op == 0 || op == 1;
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
return op==0;
if(i->get_id() == ir::INST_TRANS)
@@ -20,7 +22,7 @@ inline bool is_shmem_op(ir::instruction* i, int op) {
return false;
}
inline bool is_shmem_res(ir::value* v){
bool cts::is_shmem_res(ir::value* v){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return false;
@@ -35,7 +37,7 @@ inline bool is_shmem_res(ir::value* v){
// run pass on module
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*, ir::value*>& copies) {
auto *i = dynamic_cast<ir::instruction*>(x);
// not an instruction
if(!i) {
@@ -51,7 +53,7 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder,
// phi node
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
add_copy(phi, phi->get_incoming_value(i), builder, to_shared, copies);
return;
}
// already in shared memory
@@ -65,30 +67,49 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder,
}
else
copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy);
copies.insert({x, copy});
parent->replace_uses_of_with(x, copies.at(x));
}
void cts::run(ir::module &mod) {
// Add shared copies
ir::builder &builder = mod.get_builder();
for(ir::function* fn: mod.get_function_list()){
for(ir::basic_block* block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
size_t num_op = i->get_num_operands();
// copy to shared operands
for(size_t k = 0; k < num_op; k++)
if(is_shmem_op(i, k)){
add_copy(i, i->get_operand(k), builder, true);
}
// copy from shared operands
for(size_t k = 0; k < num_op; k++)
if(!dynamic_cast<ir::phi_node*>(i) &&
!is_shmem_op(i,k) &&
is_shmem_res(i->get_operand(k))){
add_copy(i, i->get_operand(k), builder, false);
}
// Precompute where copies should be added
std::set<ir::value*> shmem_ops;
std::set<ir::value*> shmem_res;
ir::for_each_instruction(mod, [&](ir::instruction* i) {
if(i->get_id() == ir::INST_DOT){
ir::dot_inst* dot = dynamic_cast<ir::dot_inst*>(i);
ir::value* lhs = i->get_operand(0);
ir::type* ty = lhs->get_type()->get_scalar_ty();
analysis::mma_layout* mma_lhs = layouts_->get(lhs)->to_mma();
// TODO: V100
bool is_lhs_shmem = !(mma_lhs && has_sm80_ && ty->get_primitive_size_in_bits() == 16 && !dot->is_trans_a());
if(is_lhs_shmem)
shmem_ops.insert(lhs);
shmem_ops.insert(i->get_operand(1));
}
}
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
shmem_ops.insert(i->get_operand(0));
if(i->get_id() == ir::INST_TRANS)
shmem_ops.insert(i->get_operand(0));
if(i->get_id() == ir::INST_TRANS ||
i->get_id() == ir::INST_COPY_TO_SHARED ||
i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
shmem_res.insert(i);
});
// Add shared copies
std::map<ir::value*, ir::value*> copies;
ir::builder &builder = mod.get_builder();
ir::for_each_instruction(mod, [&](ir::instruction* i) {
size_t num_op = i->get_num_operands();
for(size_t k = 0; k < num_op; k++){
ir::value* op = i->get_operand(k);
// copy to shared operands
bool is_shmem_op = shmem_ops.find(op) != shmem_ops.end();
if(is_shmem_op)
add_copy(i, op, builder, true, copies);
}
});
}

View File

@@ -3,6 +3,7 @@
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
#include <iostream>
namespace triton {
namespace codegen{
@@ -28,6 +29,8 @@ void dce::run(ir::module &mod) {
case ir::INST_ATOMIC_CAS:
case ir::INST_ATOMIC_RMW:
case ir::INST_ATOMIC_EXCH:
case ir::INST_CALL:
case ir::INST_LAUNCH:
case ir::INST_BARRIER: {
work_list.push_back(i);
marked.insert(i);
@@ -65,6 +68,7 @@ void dce::run(ir::module &mod) {
}
}
// delete
for(ir::instruction* i: to_delete)
i->erase_from_parent();

View File

@@ -0,0 +1,147 @@
#include <iostream>
#include "triton/codegen/transform/inline.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/utils.h"
namespace triton{
namespace codegen{
namespace transform{
bool fncmp::operator()(ir::function* x, ir::function* y) const {
auto fn_list = x->get_parent()->get_function_list();
return std::find(fn_list.begin(), fn_list.end(), x) < std::find(fn_list.begin(), fn_list.end(), y);
};
void inliner::do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder,
std::list<ir::call_inst*>& callsites){
ir::basic_block* parent_block = callsite->get_parent();
ir::function* parent_fn = parent_block->get_parent();
// the parent block is split into block A and block B:
// - block A (`new_blocks[0]`) is the entry block of the inlined function
// - block B (`exit`) resumes execution of the parent function
ir::basic_block* entry = parent_block->split_before(callsite, fn->get_name());
ir::basic_block* exit = entry->get_successors()[0];
std::vector<ir::basic_block*> new_blocks = {entry};
for(size_t i = 1; i < fn->blocks().size(); i++){
ir::basic_block* block = fn->blocks()[i];
ir::context& ctx = block->get_context();
const std::string& name = block->get_parent()->get_name() + "_" + block->get_name();
new_blocks.push_back(ir::basic_block::create(ctx, name, parent_fn));
}
// a phi node holds the return values of the inlined function
if(exit->get_inst_list().empty())
builder.set_insert_point(exit);
else
builder.set_insert_point(exit->get_first_non_phi());
ir::phi_node* exit_val = builder.create_phi(fn->get_fn_type()->get_return_ty(), 0);
callsite->replace_all_uses_with(exit_val);
callsite->erase_from_parent();
// get arguments `fn` is called with
std::vector<ir::value*> tgt_args(callsite->op_begin(), callsite->op_end());
std::vector<ir::argument*> src_args(fn->args().begin(), fn->args().end());
// Actually generate the instructions:
// - Remove the branch created by basic_block::split_before
// - Clone all instructions
// - Replace `ret` with incoming nodes to `exit_val` and branches to `exit`
ir::instruction* terminator = new_blocks[0]->get_inst_list().back();
// new_blocks[0]->get_inst_list().back()->erase_from_parent();
terminator->erase_from_parent();
std::map<ir::instruction*, ir::instruction*> inst_map;
std::map<ir::argument*, ir::value*> arg_map;
for(size_t k = 0; k < fn->args().size(); k++)
arg_map[fn->args()[k]] = callsite->ops()[k];
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
// clone instructions
for(size_t i = 0; i < new_blocks.size(); i++){
ir::basic_block* old_block = fn->blocks()[i];
ir::basic_block* new_block = new_blocks[i];
builder.set_insert_point(new_block);
for(ir::instruction* old_inst: old_block->get_inst_list()){
ir::instruction* new_inst = old_inst->clone();
inst_map[old_inst] = new_inst;
builder.insert(new_inst);
}
}
// update basic blocks
for(size_t i = 0; i < new_blocks.size(); i++) {
for (ir::instruction* new_inst: new_blocks[i]->get_inst_list()) {
// replace basic use cases
for(size_t k = 0; k < new_blocks.size(); k++)
new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]);
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(new_inst)) {
// additionally replace basic blocks of phi-nodes since
// replace_uses_of_with() does not replace them.
for(unsigned in = 0; in < phi->get_num_incoming(); in++)
for(size_t k = 0; k < new_blocks.size(); k++)
if (phi->get_incoming_block(in) == fn->blocks()[k])
phi->set_incoming_block(in, new_blocks[k]);
}
}
}
// replace operands of instructions after constructing inst_map
for (auto& it: inst_map) {
ir::instruction* new_inst = it.second;
for(size_t k = 0; k < new_inst->get_num_operands(); k++) {
ir::value* op = new_inst->get_operand(k);
if(auto arg_op = dynamic_cast<ir::argument*>(op))
new_inst->set_operand(k, arg_map.at(arg_op));
if(auto inst_op = dynamic_cast<ir::instruction*>(op))
if(inst_map.find(inst_op) != inst_map.end())
new_inst->set_operand(k, inst_map.at(inst_op));
}
// handles a ret instruction.
// instead of returning we need to branch to after the function call
if(ir::return_inst* ret = dynamic_cast<ir::return_inst*>(new_inst)) {
if(ir::value* ret_val = ret->get_return_value())
exit_val->add_incoming(ret_val, new_inst->get_parent());
// replace ret with branch
ir::instruction* new_br_inst = ir::branch_inst::create(exit);
builder.set_insert_point(new_inst->get_parent());
builder.insert(new_br_inst);
new_inst->erase_from_parent();
}
}
if(exit_val->get_num_incoming() == 1)
exit_val->replace_all_uses_with(exit_val->get_incoming_value(0));
// done -- make sure insert point is properly set to exit block
builder.set_insert_point(exit);
}
void inliner::run(ir::module &mod) {
// gather all call sites
while(true){
std::map<ir::function*, size_t> counts;
for(ir::function* fn: mod.get_function_list())
counts[fn] = 0;
std::list<ir::call_inst*> callsites;
for(ir::function* fn: mod.get_function_list()){
for(ir::basic_block* block: fn->blocks())
for(ir::instruction* instr: block->get_inst_list())
if(ir::call_inst* call = dynamic_cast<ir::call_inst*>(instr)){
callsites.push_back(call);
counts[call->get_fn()] += 1;
}
}
for(auto& count: counts){
if(!count.first->get_is_kernel() && count.second == 0)
count.first->get_parent()->remove_function(count.first);
}
if(callsites.empty())
break;
for(ir::call_inst* call: callsites)
do_inline(call->get_fn(), call, mod.get_builder(), callsites);
}
}
}
}
}

View File

@@ -36,6 +36,9 @@ int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
else{
if(layouts_->has_tmp(v))
return async_write.size() - 1;
// // Ignore copy_to_shared. It won't modify async behavior.
// if(dynamic_cast<ir::copy_to_shared_inst*>(v))
// return 0;
auto it = std::find(async_write.begin(), async_write.end(), v);
return std::distance(async_write.begin(), it);
}
@@ -60,15 +63,22 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
continue;
analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr;
analysis::shared_layout* a_tmp_index = layouts_->has_tmp_index(a) ? layouts_->get(layouts_->tmp_index(a))->to_shared() : nullptr;
for(ir::value* b: bs){
if(!b->get_type()->is_block_ty())
continue;
analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr;
analysis::shared_layout* b_tmp_index = layouts_->has_tmp_index(b) ? layouts_->get(layouts_->tmp_index(b))->to_shared() : nullptr;
if(intersect_with(a_layout, b_layout) ||
intersect_with(a_layout, b_tmp) ||
intersect_with(a_layout, b_tmp_index) ||
intersect_with(a_tmp, b_layout) ||
intersect_with(a_tmp, b_tmp))
intersect_with(a_tmp, b_tmp) ||
intersect_with(a_tmp, b_tmp_index) ||
intersect_with(a_tmp_index, b_layout) ||
intersect_with(a_tmp_index, b_tmp) ||
intersect_with(a_tmp_index, b_tmp_index))
ret.insert(b);
}
}

View File

@@ -87,7 +87,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1);
builder.set_insert_point(add);
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name()));
ir::value * new_dot = builder.insert(ir::dot_inst::create(a, b, other, dot->is_trans_a(), dot->is_trans_b(), dot->allow_tf32(), dot->get_name()));
add->replace_all_uses_with(new_dot);
return true;
}
@@ -150,32 +150,53 @@ bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
}
bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
auto binop = dynamic_cast<ir::binary_operator*>(value);
if(binop && binop->get_op() == ir::binary_op_t::Mul) {
ir::value *lhs = binop->get_operand(0);
ir::value *rhs = binop->get_operand(1);
ir::constant_int *_1_lhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_lhs = cst;
}
ir::constant_int *_1_rhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_rhs = cst;
}
if(_1_lhs){
binop->replace_all_uses_with(rhs);
return true;
}
else if(_1_rhs){
binop->replace_all_uses_with(lhs);
return true;
}
auto binop = dynamic_cast<ir::binary_operator*>(value);
if(binop && binop->get_op() == ir::binary_op_t::Mul) {
ir::value *lhs = binop->get_operand(0);
ir::value *rhs = binop->get_operand(1);
ir::constant_int *_1_lhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_lhs = cst;
}
ir::constant_int *_1_rhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_rhs = cst;
}
if(_1_lhs){
binop->replace_all_uses_with(rhs);
return true;
}
else if(_1_rhs){
binop->replace_all_uses_with(lhs);
return true;
}
}
return false;
}
bool peephole::rewrite_insert_extract(ir::instruction *value, ir::builder& builder){
auto extracted = dynamic_cast<ir::extract_value_inst*>(value);
if(!extracted)
return false;
size_t extract_idx = extracted->get_idx();
ir::value* agg = extracted->get_operand(0);
auto insert = dynamic_cast<ir::insert_value_inst*>(agg);
while(insert){
agg = insert->get_operand(0);
ir::value* inserted = insert->get_operand(1);
size_t insert_idx = insert->get_idx();
insert = dynamic_cast<ir::insert_value_inst*>(agg);
if(extract_idx == insert_idx){
extracted->replace_all_uses_with(inserted);
return true;
}
insert = dynamic_cast<ir::insert_value_inst*>(agg);
}
return false;
}
@@ -291,6 +312,7 @@ void peephole::run(ir::module &mod) {
was_modified = was_modified || rewrite_mult(i, builder);
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
// was_modified = was_modified || rewrite_trans_phi(i, builder);
was_modified = was_modified || rewrite_insert_extract(i, builder);
was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
// TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD

View File

@@ -134,6 +134,7 @@ void pipeline::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
const int num_stages = num_stages_;
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
for(auto info: to_pipeline){
ir::load_inst* load = info.load;
ir::phi_node* ptr = info.ptr;

View File

@@ -138,6 +138,7 @@ CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute, CUdevice
CUDA_DEFINE1(CUresult, cuDeviceGetCount, int*)
// link management
CUDA_DEFINE6(CUresult, cuLinkAddFile_v2, CUlinkState, CUjitInputType, const char *, unsigned int , CUjit_option *, void **);
CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**);
CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option*, void**, CUlinkState*);
CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState);

View File

@@ -90,7 +90,7 @@ void check(CUresult err)
case CUDA_ERROR_NOT_PERMITTED : throw not_permitted();
case CUDA_ERROR_NOT_SUPPORTED : throw not_supported();
case CUDA_ERROR_UNKNOWN : throw unknown();
default : throw unknown();
default : throw std::runtime_error("unimplemented code: " + std::to_string(err));
}
}

View File

@@ -1,27 +1,27 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include <fstream>
#if __has_include(<unistd.h>)
#include <unistd.h>
#include <unistd.h>
#endif
#include <memory>
#include <regex>
@@ -59,300 +59,318 @@
#include "llvm/Analysis/TargetLibraryInfo.h"
// end AMD stuff
extern "C"{
int set_curterm(char* nterm){ return 0; }
int del_curterm(char* nterm){ return 0; }
extern "C"
{
int set_curterm(char *nterm) { return 0; }
int del_curterm(char *nterm) { return 0; }
int tigetnum(char *capname) { return 0; }
int setupterm(char *term, int fildes, int *errret) { return 0; }
}
namespace triton{
namespace driver{
void init_llvm() {
static bool init = false;
if(!init){
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
LLVMInitializeAMDGPUTargetInfo();
LLVMInitializeAMDGPUTarget();
LLVMInitializeAMDGPUTargetMC();
LLVMInitializeAMDGPUAsmPrinter();
init = true;
}
}
/* ------------------------ */
// CUDA //
/* ------------------------ */
static bool find_and_replace(std::string& str, const std::string& begin, const std::string& end, const std::string& target){
size_t start_replace = str.find(begin);
size_t end_replace = str.find(end, start_replace);
if(start_replace == std::string::npos)
return false;
str.replace(start_replace, end_replace + 1 - start_replace, target);
return true;
}
std::string path_to_ptxas(int& version) {
std::string ret;
// search pathes for ptxas
std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
if(!triton_ptxas.empty())
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
// see what path for ptxas are valid
std::vector<std::string> working_ptxas;
for(std::string prefix: ptxas_prefixes){
std::string ptxas = prefix + "ptxas";
bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0;
if(works)
working_ptxas.push_back(ptxas);
}
// error if no working ptxas was found
if(working_ptxas.empty())
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
" but a working version could not be found.");
std::string ptxas = working_ptxas.front();
// parse version
std::regex version_regex("release (\\d+)\\.(\\d+)");
std::smatch match;
if(std::regex_search(ret, match, version_regex)){
int major = std::stoi(match[1]);
int minor = std::stoi(match[2]);
version = major*1000 + minor*10;
}
else
throw std::runtime_error("couldn't parse ptxas version: " + ret);
return ptxas;
}
int vptx(int version){
if(version >= 11040) return 74;
if(version >= 11030) return 73;
if(version >= 11020) return 72;
if(version >= 11010) return 71;
if(version >= 11000) return 70;
if(version >= 10020) return 65;
if(version >= 10010) return 64;
if(version >= 10000) return 63;
throw std::runtime_error("Triton requires CUDA 10+");
}
std::string llir_to_ptx(llvm::Module* module, int cc, int version){
// LLVM version in use may not officially support target hardware
int max_nvvm_cc = 75;
int max_nvvm_ptx = 74;
// options
auto options = llvm::cl::getRegisteredOptions();
auto* short_ptr = static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"]);
assert(short_ptr);
short_ptr->setValue(true);
// compute capability
std::string sm = "sm_" + std::to_string(cc);
// max PTX version
int ptx = vptx(version);
int ptx_major = ptx / 10;
int ptx_minor = ptx % 10;
// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "nvptx64-nvidia-cuda";
std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
std::string layout = "";
std::string features = "";
// std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
init_llvm();
// verify and store llvm
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
// pm.add(llvm::createDeadCodeEliminationPass());
// pm.add(llvm::createEarlyCSEPass());
pm.run(*module);
// module->print(llvm::outs(), nullptr);
// create machine
module->setTargetTriple(triple);
std::string error;
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if(layout.empty())
module->setDataLayout(machine->createDataLayout());
else
module->setDataLayout(layout);
// emit machine code
for (llvm::Function &f : module->functions())
f.addFnAttr(llvm::Attribute::AlwaysInline);
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
// emit
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile);
pass.run(*module);
// post-process
std::string result(buffer.begin(), buffer.end());
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
while(find_and_replace(result, "\t// begin inline asm", "\n", ""));
while(find_and_replace(result, "\t// end inline asm", "\n", ""));
return result;
}
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int cc) {
// compile ptx with ptxas
char _fsrc[L_tmpnam];
char _flog[L_tmpnam];
std::tmpnam(_fsrc);
std::tmpnam(_flog);
std::string fsrc = _fsrc;
std::string flog = _flog;
std::string fbin = fsrc + ".o";
const char* _fbin = fbin.c_str();
std::ofstream ofs(fsrc);
ofs << ptx << std::endl;
ofs.close();
std::string cmd;
int err;
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
err = system(cmd.c_str());
if(err != 0){
std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {});
unlink(_fsrc);
unlink(_flog);
throw std::runtime_error("Internal Triton PTX codegen error: \n" + log);
}
CUmodule ret;
std::ifstream _cubin(_fbin, std::ios::binary );
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
unlink(_fsrc);
unlink(_flog);
unlink(_fbin);
dispatch::cuModuleLoadData(&ret, cubin.c_str());
return cubin;
}
/* ------------------------ */
// HIP //
/* ------------------------ */
std::string llir_to_amdgpu(llvm::Module* module, const std::string& _proc) {
init_llvm();
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "amdgcn-amd-amdhsa";
std::string layout = "";
std::string features;
std::string proc = "gfx908";
// verify and store llvm
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
pm.run(*module);
// create machine
module->setTargetTriple(triple);
std::string error;
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::Reloc::PIC_, llvm::None,
llvm::CodeGenOpt::Aggressive);
// set data layout
if(layout.empty())
module->setDataLayout(machine->createDataLayout());
else
module->setDataLayout(layout);
// emit machine code
for (llvm::Function &f : module->functions())
f.addFnAttr(llvm::Attribute::AlwaysInline);
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
// create dump files
std::string module_name = module->getModuleIdentifier();
std::error_code ec;
// Save GCN ISA binary.
std::string isabin_path = std::string("/tmp/") + module_name + std::string(".o");
std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text));
if (ec)
namespace triton
{
namespace driver
{
std::cout << isabin_path << " was not created. error code: " << ec << std::endl;
}
// emit
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile);
pass.run(*module);
// Save GCN ISA.
std::string amdgcn_path = std::string("/tmp/") + module_name + std::string(".gcn");
std::string result(buffer.begin(), buffer.end());
std::ofstream amdgcn(amdgcn_path);
amdgcn << result;
amdgcn.close();
void init_llvm()
{
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
LLVMInitializeAMDGPUTargetInfo();
LLVMInitializeAMDGPUTarget();
LLVMInitializeAMDGPUTargetMC();
LLVMInitializeAMDGPUAsmPrinter();
}
// generate HASCO file
std::string hsaco_path = std::string("/tmp/") + module_name + std::string(".hsaco");
std::string error_message;
int lld_result =
llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld",
{"/opt/rocm/llvm/bin/ld.lld", "-flavor", "gnu", "-shared", "-o", hsaco_path, isabin_path},
llvm::None, {}, 0, 0, &error_message);
if (lld_result)
{
std::cout << "ld.lld execute fail: " << std::endl;
std::cout << error_message << std::endl;
std::cout << lld_result << std::endl;
}
/* ------------------------ */
// CUDA //
/* ------------------------ */
static bool find_and_replace(std::string &str, const std::string &begin, const std::string &end, const std::string &target)
{
size_t start_replace = str.find(begin);
size_t end_replace = str.find(end, start_replace);
if (start_replace == std::string::npos)
return false;
str.replace(start_replace, end_replace + 1 - start_replace, target);
return true;
}
return hsaco_path;
}
std::string path_to_ptxas(int &version)
{
std::vector<std::string> rets;
std::string ret;
// search paths for ptxas
std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
if (!triton_ptxas.empty())
ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
// see what path for ptxas are valid
std::vector<std::string> working_ptxas;
for (std::string prefix : ptxas_prefixes)
{
std::string ptxas = prefix + "ptxas";
bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0;
if (works)
{
working_ptxas.push_back(ptxas);
rets.push_back(ret);
}
}
// error if no working ptxas was found
if (working_ptxas.empty())
throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
" but a working version could not be found.");
std::string ptxas = working_ptxas.front();
// parse version
std::regex version_regex("release (\\d+)\\.(\\d+)");
std::smatch match;
bool found = false;
// currently choosing the first ptxas. Other logics can be implemented in future
for (std::string ret : rets)
{
if (std::regex_search(ret, match, version_regex))
{
int major = std::stoi(match[1]);
int minor = std::stoi(match[2]);
version = major * 1000 + minor * 10;
found = true;
break;
}
}
if (not found)
{
throw std::runtime_error("Error in parsing version");
}
return ptxas;
}
int vptx(int version)
{
if (version >= 11040)
return 74;
// if(version >= 11030) return 73;
// if(version >= 11020) return 72;
// if(version >= 11010) return 71;
// if(version >= 11000) return 70;
// if(version >= 10020) return 65;
// if(version >= 10010) return 64;
// if(version >= 10000) return 63;
throw std::runtime_error("Triton requires CUDA 11.4+");
}
hipModule_t amdgpu_to_hipmodule(const std::string& path) {
// Read HSACO.
std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
std::string llir_to_ptx(llvm::Module *module, int cc, int version)
{
// LLVM version in use may not officially support target hardware
int max_nvvm_cc = 75;
int max_nvvm_ptx = 74;
// options
auto options = llvm::cl::getRegisteredOptions();
auto *short_ptr = static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(short_ptr);
short_ptr->setValue(true);
// compute capability
std::string sm = "sm_" + std::to_string(cc);
// max PTX version
int ptx = vptx(version);
int ptx_major = ptx / 10;
int ptx_minor = ptx % 10;
// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "nvptx64-nvidia-cuda";
std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
std::string layout = "";
std::string features = "";
// std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
init_llvm();
// verify and store llvm
llvm::legacy::PassManager pm;
// pm.add(llvm::createPrintModulePass(llvm::outs()));
pm.add(llvm::createVerifierPass());
pm.run(*module);
// module->print(llvm::outs(), nullptr);
std::vector<unsigned char> hsaco(hsaco_file_size);
hsaco_file.seekg(0, std::ios::beg);
hsaco_file.read(reinterpret_cast<char*>(&hsaco[0]), hsaco_file_size);
hsaco_file.close();
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer,
// create machine
module->setTargetTriple(triple);
std::string error;
llvm::TargetMachine *machine;
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
// set data layout
if (layout.empty())
module->setDataLayout(machine->createDataLayout());
else
module->setDataLayout(layout);
// emit machine code
for (llvm::Function &f : module->functions())
f.addFnAttr(llvm::Attribute::AlwaysInline);
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
// emit
machine->addPassesToEmitFile(pass, stream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile);
pass.run(*module);
// post-process
std::string result(buffer.begin(), buffer.end());
find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
while (find_and_replace(result, "\t// begin inline asm", "\n", ""))
;
while (find_and_replace(result, "\t// end inline asm", "\n", ""))
;
return result;
}
std::string ptx_to_cubin(const std::string &ptx, const std::string &ptxas, int cc)
{
// compile ptx with ptxas
char _fsrc[L_tmpnam];
char _flog[L_tmpnam];
std::tmpnam(_fsrc);
std::tmpnam(_flog);
std::string fsrc = _fsrc;
std::string flog = _flog;
std::string fbin = fsrc + ".o";
const char *_fbin = fbin.c_str();
std::ofstream ofs(fsrc);
ofs << ptx << std::endl;
ofs.close();
std::string cmd;
int err;
cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog;
err = system(cmd.c_str());
if (err != 0)
{
std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {});
unlink(_fsrc);
unlink(_flog);
throw std::runtime_error("Internal Triton PTX codegen error: \n" + log);
}
std::ifstream _cubin(_fbin, std::ios::binary);
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
unlink(_fsrc);
unlink(_flog);
unlink(_fbin);
return cubin;
}
/* ------------------------ */
// HIP //
/* ------------------------ */
std::string llir_to_amdgpu(llvm::Module *module, const std::string &_proc)
{
init_llvm();
// proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
// features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "amdgcn-amd-amdhsa";
std::string layout = "";
std::string features;
std::string proc = "gfx908";
// verify and store llvm
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
pm.run(*module);
// create machine
module->setTargetTriple(triple);
std::string error;
auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
llvm::Reloc::PIC_, llvm::None,
llvm::CodeGenOpt::Aggressive);
// set data layout
if (layout.empty())
module->setDataLayout(machine->createDataLayout());
else
module->setDataLayout(layout);
// emit machine code
for (llvm::Function &f : module->functions())
f.addFnAttr(llvm::Attribute::AlwaysInline);
llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);
// create dump files
std::string module_name = module->getModuleIdentifier();
std::error_code ec;
// Save GCN ISA binary.
std::string isabin_path = std::string("/tmp/") + module_name + std::string(".o");
std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text));
if (ec)
{
std::cout << isabin_path << " was not created. error code: " << ec << std::endl;
}
// emit
machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile);
pass.run(*module);
// Save GCN ISA.
std::string amdgcn_path = std::string("/tmp/") + module_name + std::string(".gcn");
std::string result(buffer.begin(), buffer.end());
std::ofstream amdgcn(amdgcn_path);
amdgcn << result;
amdgcn.close();
// generate HASCO file
std::string hsaco_path = std::string("/tmp/") + module_name + std::string(".hsaco");
std::string error_message;
int lld_result =
llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld",
{"/opt/rocm/llvm/bin/ld.lld", "-flavor", "gnu", "-shared", "-o", hsaco_path, isabin_path},
llvm::None, {}, 0, 0, &error_message);
if (lld_result)
{
std::cout << "ld.lld execute fail: " << std::endl;
std::cout << error_message << std::endl;
std::cout << lld_result << std::endl;
}
return hsaco_path;
}
hipModule_t amdgpu_to_hipmodule(const std::string &path)
{
// Read HSACO.
std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
std::vector<unsigned char> hsaco(hsaco_file_size);
hsaco_file.seekg(0, std::ios::beg);
hsaco_file.read(reinterpret_cast<char *>(&hsaco[0]), hsaco_file_size);
hsaco_file.close();
hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer,
hipJitOptionInfoLogBufferSizeBytes, hipJitOptionInfoLogBuffer,
hipJitOptionLogVerbose};
const unsigned int errbufsize = 8192;
const unsigned int logbufsize = 8192;
char _err[errbufsize];
char _log[logbufsize];
void* optval[] = {(void*)(uintptr_t)errbufsize,
(void*)_err, (void*)(uintptr_t)logbufsize,
(void*)_log, (void*)1};
hipModule_t ret;
dispatch::hipModuleLoadDataEx(&ret, hsaco.data(), 5, opt, optval);
return ret;
}
}
}
const unsigned int errbufsize = 8192;
const unsigned int logbufsize = 8192;
char _err[errbufsize];
char _log[logbufsize];
void *optval[] = {(void *)(uintptr_t)errbufsize,
(void *)_err, (void *)(uintptr_t)logbufsize,
(void *)_log, (void *)1};
hipModule_t ret;
dispatch::hipModuleLoadDataEx(&ret, hsaco.data(), 5, opt, optval);
return ret;
}
} // namespace driver
} // namespace triton

View File

@@ -1,3 +1,5 @@
#include <iostream>
#include <algorithm>
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
@@ -9,23 +11,71 @@ namespace ir {
class phi_node;
basic_block::basic_block(context &ctx, const std::string &name, function *parent):
basic_block::basic_block(context &ctx, const std::string &name, function *parent, basic_block* next):
value(type::get_label_ty(ctx), name), ctx_(ctx), parent_(parent) {
if(parent_)
parent_->insert_block(this);
parent_->insert_block(this, next);
}
basic_block* basic_block::create(context &ctx, const std::string &name, function *parent){
return new basic_block(ctx, name, parent);
basic_block* basic_block::create(context &ctx, const std::string &name, function *parent, basic_block* next){
return new basic_block(ctx, name, parent, next);
}
void basic_block::add_predecessor(basic_block *pred) {
preds_.push_back(pred);
if(pred)
pred->succs_.push_back(this);
void basic_block::replace_phi_uses_with(basic_block* before, basic_block* after) {
for(ir::instruction* i: inst_list_){
auto* curr_phi = dynamic_cast<ir::phi_node*>(i);
if(!curr_phi)
break;
// curr_phi->replace_uses_of_with(before, after);
for (size_t idx = 0; idx < curr_phi->get_num_incoming(); ++idx)
if (curr_phi->get_incoming_block(idx) == before)
curr_phi->set_incoming_block(idx, after);
}
}
void basic_block::append_instruction(ir::instruction* i){
i->set_parent(this);
inst_list_.push_back(i);
}
basic_block* basic_block::split_before(ir::instruction* loc, const std::string& name) {
basic_block* ret = basic_block::create(ctx_, name, parent_, this);
ret->set_name(get_name());
set_name("after_" + name);
// splice instruction list
auto loc_it = std::find(inst_list_.begin(), inst_list_.end(), loc);
ret->get_inst_list().splice(ret->get_inst_list().begin(), inst_list_, inst_list_.begin(), loc_it);
for(ir::instruction* i: ret->get_inst_list())
i->set_parent(ret);
// the predecessors of `this` becomes the predecessors of `ret`
for(ir::basic_block* pred: get_predecessors()){
auto* term = dynamic_cast<ir::terminator_inst*>(pred->get_inst_list().back());
assert(term);
term->replace_uses_of_with(this, ret);
replace_phi_uses_with(pred, ret);
}
ir::branch_inst* br = branch_inst::create(this);
ret->append_instruction(br);
return ret;
}
std::vector<basic_block*> basic_block::get_predecessors() const {
std::vector<basic_block*> ret;
for(ir::user* u: users_)
if(auto term = dynamic_cast<ir::terminator_inst*>(u))
ret.push_back(term->get_parent());
return ret;
}
std::vector<basic_block*> basic_block::get_successors() const {
std::vector<basic_block*> ret;
for(ir::instruction* i: inst_list_)
for(ir::value* v: i->ops())
if(auto block = dynamic_cast<ir::basic_block*>(v))
ret.push_back(block);
return ret;
}
basic_block::iterator basic_block::get_first_non_phi(){
auto it = begin();

View File

@@ -105,13 +105,10 @@ type *builder::get_double_ty()
//===----------------------------------------------------------------------===//
value* builder::create_br(basic_block *dest){
dest->add_predecessor(block_);
return insert(branch_inst::create(dest));
}
value* builder::create_cond_br(value *cond, basic_block *if_dest, basic_block *else_dest){
if_dest->add_predecessor(block_);
else_dest->add_predecessor(block_);
return insert(branch_inst::create(cond, if_dest, else_dest));
}
@@ -119,6 +116,18 @@ value *builder::create_ret_void() {
return insert(return_inst::create(ctx_));
}
value *builder::create_ret(value* val) {
return insert(return_inst::create(ctx_, val));
}
//===----------------------------------------------------------------------===//
// dequantize instructions
//===----------------------------------------------------------------------===//
value* builder::create_dequantize(value *src, value *scale, value *shift, type *dst_ty){
return insert(dequantize_inst::create(src, scale, shift, dst_ty));
}
//===----------------------------------------------------------------------===//
// cast instructions
//===----------------------------------------------------------------------===//
@@ -153,6 +162,19 @@ phi_node* builder::create_phi(type *ty, unsigned num_reserved){
return insert(phi_node::create(ty, num_reserved));
}
//===----------------------------------------------------------------------===//
// call instructions
//===----------------------------------------------------------------------===//
value *builder::create_call(function* fn, const std::vector<value*>& args){
return insert(call_inst::create(fn, args));
}
value* builder::create_launch(function* fn, const std::vector<value*>& args, const std::vector<value*>& grid, value* num_warps){
return insert(launch_inst::create(fn, args, grid, num_warps));
}
//===----------------------------------------------------------------------===//
// binary float instructions
//===----------------------------------------------------------------------===//
@@ -285,18 +307,31 @@ value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_in
return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile));
}
value *builder::create_store(value *ptr, value *val){
return insert(unmasked_store_inst::create(ptr, val));
value *builder::create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction){
return insert(unmasked_store_inst::create(ptr, val, eviction));
}
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){
return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile));
}
value *builder::create_masked_store(value *ptr, value *val, value *mask){
return insert(masked_store_inst::create(ptr, val, mask));
value *builder::create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction){
return insert(masked_store_inst::create(ptr, val, mask, eviction));
}
//===----------------------------------------------------------------------===//
// struct instructions
//===----------------------------------------------------------------------===//
// Struct instructions
value *builder::create_insert_value(value* val, value *elt, size_t idx){
return insert(insert_value_inst::create(val, elt, idx));
}
value *builder::create_extract_value(value* val, size_t idx) {
return insert(extract_value_inst::create(val, idx));
}
//===----------------------------------------------------------------------===//
// block instructions
//===----------------------------------------------------------------------===//
@@ -343,6 +378,28 @@ DEFINE_ATOMIC_RMW_INSTR(atomic_or, ir::atomic_rmw_op_t::Or)
DEFINE_ATOMIC_RMW_INSTR(atomic_xor, ir::atomic_rmw_op_t::Xor)
DEFINE_ATOMIC_RMW_INSTR(atomic_xchg, ir::atomic_rmw_op_t::Xchg)
// Utilities
value *builder::create_clock() {
return insert(clock_inst::create(ctx_));
}
value *builder::create_globaltimer() {
return insert(globaltimer_inst::create(ctx_));
}
//===----------------------------------------------------------------------===//
// externs
//===----------------------------------------------------------------------===//
value *builder::create_extern_elementwise(const std::string &lib_name,
const std::string &lib_path,
const std::string &symbol_name,
const std::vector<value *> &args,
type *ret_ty) {
return insert(extern_elementwise_inst::create(ctx_, args, ret_ty, lib_name,
lib_path, symbol_name));
}
//===----------------------------------------------------------------------===//
// built-in instructions
//===----------------------------------------------------------------------===//
@@ -376,8 +433,8 @@ value *builder::create_log(value *arg){
return insert(log_inst::create(arg));
}
value *builder::create_dot(value *A, value *B, value *C, bool allow_tf32) {
return insert(dot_inst::create_nn(A, B, C, allow_tf32));
value *builder::create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32) {
return insert(dot_inst::create(A, B, C, trans_a, trans_b, allow_tf32));
}
value *builder::create_trans(value *A, const std::vector<int>& perm) {

View File

@@ -18,6 +18,8 @@ constant *constant::get_null_value(type *ty) {
return constant_int::get(ty, 0);
case type::FP16TyID:
return constant_fp::get(type::get_fp16_ty(ctx), 0);
case type::BF16TyID:
return constant_fp::get(type::get_bf16_ty(ctx), 0);
case type::FP32TyID:
return constant_fp::get(type::get_fp32_ty(ctx), 0);
case type::FP64TyID:

View File

@@ -33,8 +33,10 @@ void argument::accept(visitor *v) {
/* function */
function::function(function_type *ty, linkage_types_t linkage,
const std::string &name, module *parent)
: global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty) {
: global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty), is_kernel_(false) {
unsigned num_params = fn_ty_->get_num_params();
if(parent)
parent->push_function(this);
// skip if no parameter
if(num_params == 0)
return;
@@ -44,8 +46,6 @@ function::function(function_type *ty, linkage_types_t linkage,
type *param_ty = fn_ty_->get_param_ty(i);
args_[i] = argument::create(param_ty, "", this, i);
}
if(parent)
parent->push_function(this);
}
/* basic block */

View File

@@ -5,6 +5,7 @@
#include "triton/ir/instructions.h"
#include "triton/ir/constant.h"
#include "triton/ir/type.h"
#include "triton/ir/function.h"
namespace triton{
namespace ir{
@@ -68,6 +69,7 @@ void phi_node::set_incoming_block(unsigned i, basic_block *block){
// Add incoming
void phi_node::add_incoming(value *v, basic_block *block){
assert(v && "PHI node got a null value!!");
resize_ops(get_num_operands() + 1);
blocks_.resize(get_num_operands() + 1);
set_incoming_value(get_num_operands() - 1, v);
@@ -79,6 +81,70 @@ phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &n
return new phi_node(ty, num_reserved, name, next);
}
//===----------------------------------------------------------------------===//
// call_inst classes
//===----------------------------------------------------------------------===//
std::string call_inst::repr_impl() const { return "call " + fn_->get_name(); }
call_inst::call_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::string& name, instruction* next)
: instruction(fn->get_fn_type()->get_return_ty(), INST_CALL, values.size(), name, next), fn_(fn){
for(size_t i = 0; i < values.size(); i++)
set_operand(i, values.at(i));
}
call_inst* call_inst::create(ir::function* fn, const std::vector<ir::value*>& values, const std::string &name, instruction *next) {
return new call_inst(fn, values, name, next);
}
// launch
launch_inst::launch_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps, const std::string& name, instruction* next)
: instruction(fn->get_fn_type()->get_return_ty(), INST_LAUNCH, 1 + values.size() + grid.size() + 1, name, next){
int k = 0;
if(grid.size() != 3)
throw std::runtime_error("grid must have 3 elements");
set_operand(k++, fn);
val_begin = k;
for(ir::value* v: values)
set_operand(k++, v);
val_end = k;
grid_begin = k;
for(ir::value* g: grid)
set_operand(k++, g);
grid_end = k;
set_operand(k++, num_warps);
}
ir::function* launch_inst::get_fn() {
return (ir::function*)get_operand(0);
}
std::vector<ir::value*> launch_inst::get_values() {
std::vector<ir::value*> ret;
for(int i = val_begin; i < val_end; i++)
ret.push_back(get_operand(i));
return ret;
}
std::vector<ir::value*> launch_inst::get_grid() {
std::vector<ir::value*> ret;
for(int i = grid_begin; i < grid_end; i++)
ret.push_back(get_operand(i));
return ret;
}
ir::value* launch_inst::get_num_warps() {
return get_operand(grid_end);
}
launch_inst* launch_inst::create(ir::function *fn, const std::vector<ir::value *> &values, const std::vector<ir::value *> &grid, ir::value *num_warps, const std::string &name, instruction *next) {
return new launch_inst(fn, values, grid, num_warps, name, next);
}
//===----------------------------------------------------------------------===//
// binary_operator classes
@@ -257,6 +323,21 @@ unary_inst::unary_inst(type *ty, value_id_t id, value *v, const std::string &nam
set_operand(0, v);
}
//===----------------------------------------------------------------------===//
// dequantize_inst classes
//===----------------------------------------------------------------------===//
dequantize_inst::dequantize_inst(type *ty, value *v, value *scale, value *shift, const std::string &name, instruction *next)
: instruction(ty, INST_DEQUANTIZE, 3, name, next) {
set_operand(0, v);
set_operand(1, scale);
set_operand(2, shift);
}
dequantize_inst *dequantize_inst::create(value *arg, value *scale, value *shift, type *ty, const std::string &name, instruction *next){
return new dequantize_inst(ty, arg, scale, shift, name, next);
}
//===----------------------------------------------------------------------===//
// cast_inst classes
//===----------------------------------------------------------------------===//
@@ -324,7 +405,7 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed,
// return_inst
return_inst::return_inst(context &ctx, value *ret_val, instruction *next)
: terminator_inst(type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){
: terminator_inst(ret_val?ret_val->get_type():type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){
if(ret_val)
set_operand(0, ret_val);
}
@@ -429,13 +510,13 @@ getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vector<val
//===----------------------------------------------------------------------===//
// io_inst
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
: instruction(ty, id, num_ops, name, next)
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
: instruction(ty, id, num_ops, name, next), eviction_(eviction)
{ }
// load_inst
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, load_inst::CACHE_MODIFIER cache, EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next)
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile)
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, eviction, name, next), cache_(cache), is_volatile_(is_volatile)
{ }
// load
@@ -492,35 +573,66 @@ masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask,
// store
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, name, next)
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next)
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, eviction, name, next)
{ }
// unmasked_store
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val,
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val, EVICTION_POLICY eviction,
const std::string &name, instruction *next)
: store_inst(ptr, INST_UNMASKED_STORE, 2, name, next) {
: store_inst(ptr, INST_UNMASKED_STORE, 2, eviction, name, next) {
set_operand(0, ptr);
set_operand(1, val);
}
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val,
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val, EVICTION_POLICY eviction,
const std::string &name, instruction *next) {
return new unmasked_store_inst(ptr, val, name, next);
return new unmasked_store_inst(ptr, val, eviction, name, next);
}
// masked store
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask,
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
const std::string &name, instruction *next)
: store_inst(ptr, INST_MASKED_STORE, 3, name, next) {
: store_inst(ptr, INST_MASKED_STORE, 3, eviction, name, next) {
set_operand(0, ptr);
set_operand(1, val);
set_operand(2, mask);
}
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) {
return new masked_store_inst(ptr, val, mask, name, next);
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, EVICTION_POLICY eviction,
const std::string &name, instruction *next) {
return new masked_store_inst(ptr, val, mask, eviction, name, next);
}
//===----------------------------------------------------------------------===//
// struct classes
//===----------------------------------------------------------------------===//
// insert value
insert_value_inst::insert_value_inst(value *val, value *elt, size_t idx, const std::string& name, instruction *next)
: instruction(val->get_type(), INST_INSERT_VALUE, 2, name, next), idx_(idx) {
set_operand(0, val);
set_operand(1, elt);
}
insert_value_inst* insert_value_inst::create(value *val, value *elt, size_t idx, const std::string& name, instruction *next){
return new insert_value_inst(val, elt, idx, name, next);
}
// extract value
extract_value_inst::extract_value_inst(value *val, size_t idx, const std::string& name, instruction *next)
: instruction(val->get_type()->get_struct_type(idx), INST_EXTRACT_VALUE, 1, name, next), idx_(idx) {
set_operand(0, val);
}
extract_value_inst* extract_value_inst::create(value *val, size_t idx, const std::string& name, instruction *next){
return new extract_value_inst(val, idx, name, next);
}
//===----------------------------------------------------------------------===//
// retile_inst classes
//===----------------------------------------------------------------------===//
@@ -575,13 +687,16 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
return new downcast_inst(arg->get_type()->get_scalar_ty(), INST_DOWNCAST, arg, name, next);
}
//===----------------------------------------------------------------------===//
// matmul_inst classes
//===----------------------------------------------------------------------===//
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
const std::string &name, instruction *next)
: builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
: builtin_inst(C->get_type(), INST_DOT, 3, name, next), AT_(AT), BT_(BT){
set_operand(0, A);
set_operand(1, B);
set_operand(2, C);
@@ -861,8 +976,7 @@ copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::stri
}
// barrier
barrier_inst::barrier_inst(context &ctx, const std::string &name,
instruction *next)
barrier_inst::barrier_inst(context &ctx, const std::string &name, instruction *next)
: instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { }
barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
@@ -881,27 +995,44 @@ prefetch_s_inst *prefetch_s_inst::create(context &ctx, value *arg, int inc, cons
return new prefetch_s_inst(ctx, arg, inc, name, next);
}
//// nv_dynamic_program_idx
//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
// global timer
globaltimer_inst::globaltimer_inst(context &ctx, const std::string &name, instruction *next)
: instruction(type::get_int64_ty(ctx), INST_GLOBALTIMER, 0, name, next) { }
//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
// return new make_range_dyn(ty, name, next);
//}
globaltimer_inst* globaltimer_inst::create(context &ctx, const std::string &name, instruction *next) {
return new globaltimer_inst(ctx, name, next);
}
//// nv_static_program_idx
//make_range_sta::make_range_sta(make_range *range)
// : constant(range->get_type(), 0), range_(range) { }
// extern elementwise
extern_elementwise_inst::extern_elementwise_inst(
context &ctx, const std::vector<value *> &args, type *ret_ty,
const std::string &lib_name, const std::string &lib_path,
const std::string &symbol_name, const std::string &name, instruction *next)
: instruction(ret_ty, INST_EXTERN_ELEMENTWISE, args.size(), name, next),
lib_name_(lib_name),
lib_path_(lib_path),
symbol_name_(symbol_name) {
for (size_t i = 0; i < args.size(); i++) {
set_operand(i, args[i]);
}
}
//make_range* make_range_sta::get_range() const
//{ return range_; }
extern_elementwise_inst *extern_elementwise_inst::create(
context &ctx, const std::vector<value *> &args, type *ret_ty,
const std::string &lib_name, const std::string &lib_path,
const std::string &symbol_name, const std::string &name,
instruction *next) {
return new extern_elementwise_inst(ctx, args, ret_ty, lib_name, lib_path,
symbol_name, name, next);
}
//make_range_sta* make_range_sta::get(make_range* range) {
// static std::map<make_range*, make_range_sta*> cache;
// if(cache.find(range) == cache.end())
// cache.insert({range, new make_range_sta(range)});
// return cache.at(range);
//}
// clock
clock_inst::clock_inst(context &ctx, const std::string &name, instruction *next)
: instruction(type::get_int64_ty(ctx), INST_CLOCK, 0, name, next) { }
clock_inst* clock_inst::create(context &ctx, const std::string &name, instruction *next) {
return new clock_inst(ctx, name, next);
}
// make_range

View File

@@ -3,10 +3,10 @@
namespace triton{
namespace ir{
metadata::metadata(kind_t kind, unsigned value)
metadata::metadata(kind_t kind, std::vector<unsigned> value)
: kind_(kind), value_(value) { }
metadata* metadata::get(kind_t kind, unsigned value) {
metadata* metadata::get(kind_t kind, std::vector<unsigned> value) {
return new metadata(kind, value);
}

View File

@@ -9,11 +9,16 @@
namespace triton{
namespace ir{
void module::reset_ret_ty(const std::string& name, type* ty) {
get_function(name)->get_fn_type()->reset_ret_ty(ty);
}
/* functions */
function *module::get_or_insert_function(const std::string &name, function_type *ty) {
function *&fn = (function*&)symbols_[name];
if(fn == nullptr)
return fn = function::create(ty, global_value::external, name, this);
if(fn == nullptr){
fn = function::create(ty, global_value::external, name, this);
}
return fn;
}

View File

@@ -92,7 +92,7 @@ public:
//-------------------------
void SlotTracker::process_module() {
// Nothing to do at the moment.
// Create slots for global variable & unamed functions & ...
// Create slots for global variable & unnamed functions & ...
module_processed = true;
}

View File

@@ -27,7 +27,7 @@ unsigned type::get_primitive_size_in_bits() const {
case BF16TyID: return 16;
case FP32TyID: return 32;
case FP64TyID: return 64;
case IntegerTyID: return ((integer_type*)(this))->get_bitwidth();
case IntegerTyID: return std::max<int>(8, ((integer_type*)(this))->get_bitwidth());
case BlockTyID: return ((block_type*)(this))->get_bitwidth();
default: return 0;
}
@@ -174,7 +174,26 @@ bool composite_type::index_valid(value *idx) const{
}
//===----------------------------------------------------------------------===//
// tile_type class
// struct_type class
//===----------------------------------------------------------------------===//
struct_type::struct_type(const contained_tys_vec_t& tys, bool is_packed)
: composite_type(tys[0]->get_context(), StructTyID), is_packed_(is_packed) {
contained_tys_ = tys;
}
struct_type* struct_type::get(const contained_tys_vec_t& tys, bool is_packed) {
assert(tys.size());
context_impl* impl = tys[0]->get_context().p_impl.get();
struct_type *& entry = impl->struct_tys[tys];
if(!entry)
entry = new struct_type(tys, is_packed);
return entry;
}
//===----------------------------------------------------------------------===//
// block_type class
//===----------------------------------------------------------------------===//
block_type::block_type(type *ty, const block_shapes_t &shapes)

View File

@@ -43,6 +43,15 @@ std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
return result;
}
void for_each_instruction_backward(module &mod, const std::function<void (instruction *)> &do_work) {
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: cfg::post_order(fn)){
auto inst_list = block->get_inst_list();
for(auto it = inst_list.rbegin(); it != inst_list.rend() ; it++)
do_work(*it);
}
}
void for_each_instruction(module &mod, const std::function<void (instruction *)> &do_work) {
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: cfg::reverse_post_order(fn))

View File

@@ -1,5 +1,6 @@
#include <cassert>
#include <iostream>
#include <algorithm>
#include "triton/ir/value.h"
#include "triton/ir/instructions.h"
@@ -17,11 +18,11 @@ value::value(type *ty, const std::string &name): ty_(ty){
}
void value::add_use(user *arg) {
users_.insert(arg);
users_.push_back(arg);
}
value::users_t::iterator value::erase_use(user *arg){
auto it = users_.find(arg);
auto it = std::find(users_.begin(), users_.end(), arg);
if(it == users_.end())
return it;
return users_.erase(it);

View File

@@ -40,7 +40,7 @@ def bench_matmul(M, N, K, block, layout_mode, op_mode, AT, BT, dtype, provider,
# create op
tflops = lambda ms: num_flops / ms * 1e3
if provider == 'triton':
op = triton.ops.blocksparse.matmul(layout, block, op_mode, trans_a=AT, trans_b=BT)
op = triton.ops.blocksparse.matmul(layout, block, op_mode, device="cuda", trans_a=AT, trans_b=BT)
# inputs
a = triton.testing.sparsify_tensor(a, layout, block) if op_mode == 'dsd' else a
b = triton.testing.sparsify_tensor(b, layout, block) if op_mode == 'dds' else b
@@ -83,7 +83,7 @@ def bench_softmax(M, N, block, layout_mode, dtype, provider, warmup=10, rep=50):
a = torch.randn((Z, H, M, N), dtype=dtype, device='cuda')
if provider == 'triton':
a = triton.testing.sparsify_tensor(a, layout, block)
op = triton.ops.blocksparse.softmax(layout, block)
op = triton.ops.blocksparse.softmax(layout, block, device="cuda")
gbps = lambda ms: (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
mean_ms, min_ms, max_ms = triton.testing.do_bench(lambda: op(a), warmup=warmup, rep=rep)
return gbps(mean_ms), gbps(min_ms), gbps(max_ms)

View File

@@ -5,7 +5,7 @@ import triton
def rounded_linspace(low, high, steps, div):
ret = torch.linspace(low, high, steps)
ret = (ret.int() + div - 1) // div * div
ret = torch.div(ret.int() + div - 1, div, rounding_mode='trunc') * div
ret = torch.unique(ret)
return list(map(int, ret))

View File

@@ -7,40 +7,75 @@ import shutil
import subprocess
import sys
import tarfile
import tempfile
import urllib.request
from distutils.version import LooseVersion
from typing import NamedTuple
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
def get_llvm():
# tries to find system LLVM
# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
def check_env_flag(name: str, default: str = "") -> bool:
return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
def get_build_type():
if check_env_flag("DEBUG"):
return "Debug"
elif check_env_flag("REL_WITH_DEB_INFO"):
return "RelWithDebInfo"
else:
return "Release"
def use_system_llvm():
if platform.system() == "Windows":
return True
versions = ['-11.0', '-11', '-11-64']
supported = ['llvm-config{v}'.format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None]
if paths:
return '', ''
if platform.system() == "Windows":
return '', ''
# download if nothing is installed
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
dir = '/tmp'
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name)
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)
if not os.path.exists(llvm_library_dir):
try:
shutil.rmtree(os.path.join(dir, name))
except Exception:
pass
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
print('downloading and extracting ' + url + '...')
ftpstream = urllib.request.urlopen(url)
file = tarfile.open(fileobj=ftpstream, mode="r|xz")
file.extractall(path=dir)
return llvm_include_dir, llvm_library_dir
return any(p is not None for p in paths)
def get_thirdparty_packages(triton_cache_path):
class Package(NamedTuple):
package: str
name: str
url: str
test_file: str
include_flag: str
lib_flag: str
packages = [
Package("pybind11", "pybind11-2.10.0", "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz", "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "")
]
if not use_system_llvm():
# download LLVM if no suitable system LLVM is installed
packages.append(
Package("llvm", "clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04", "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04.tar.xz", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR")
)
thirdparty_cmake_args = []
for p in packages:
package_root_dir = os.path.join(triton_cache_path, p.package)
package_dir = os.path.join(package_root_dir, p.name)
test_file_path = os.path.join(package_dir, p.test_file)
if not os.path.exists(test_file_path):
try:
shutil.rmtree(package_root_dir)
except Exception:
pass
os.makedirs(package_root_dir, exist_ok=True)
print('downloading and extracting {} ...'.format(p.url))
ftpstream = urllib.request.urlopen(p.url)
file = tarfile.open(fileobj=ftpstream, mode="r|*")
file.extractall(path=package_root_dir)
if p.include_flag:
thirdparty_cmake_args.append("-D{}={}/include".format(p.include_flag, package_dir))
if p.lib_flag:
thirdparty_cmake_args.append("-D{}={}/lib".format(p.lib_flag, package_dir))
return thirdparty_cmake_args
class CMakeExtension(Extension):
@@ -78,31 +113,24 @@ class CMakeBuild(build_ext):
self.build_extension(ext)
def build_extension(self, ext):
llvm_include_dir, llvm_library_dir = get_llvm()
self.debug = True
triton_cache_path = os.path.join(os.environ["HOME"], ".triton")
thirdparty_cmake_args = get_thirdparty_packages(triton_cache_path)
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# create build directories
build_suffix = 'debug' if self.debug else 'release'
llvm_build_dir = os.path.join(tempfile.gettempdir(), "llvm-" + build_suffix)
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
if not os.path.exists(llvm_build_dir):
os.makedirs(llvm_build_dir)
# python directories
python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
python_include_dirs = [distutils.sysconfig.get_python_inc()]
cmake_args = [
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DBUILD_TUTORIALS=OFF",
"-DBUILD_PYTHON_MODULE=ON",
"-DLLVM_INCLUDE_DIRS=" + llvm_include_dir,
"-DLLVM_LIBRARY_DIR=" + llvm_library_dir,
# '-DPYTHON_EXECUTABLE=' + sys.executable,
# '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
"-DTRITON_LLVM_BUILD_DIR=" + llvm_build_dir,
"-DPYTHON_INCLUDE_DIRS=" + ";".join(python_include_dirs)
]
] + thirdparty_cmake_args
# configuration
cfg = "Debug" if self.debug else "Release"
cfg = get_build_type()
build_args = ["--config", cfg]
if platform.system() == "Windows":
@@ -127,13 +155,17 @@ setup(
author_email="phil@openai.com",
description="A language and compiler for custom Deep Learning operations",
long_description="",
packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
packages=["triton", "triton/_C", "triton/language", "triton/runtime", "triton/tools", "triton/ops", "triton/ops/blocksparse"],
install_requires=[
"cmake",
"filelock",
"torch",
],
package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]},
package_data={
"triton/ops": ["*.c"],
"triton/ops/blocksparse": ["*.c"],
"triton/language": ["*.bc"],
},
include_package_data=True,
ext_modules=[CMakeExtension("triton", "triton/_C/")],
cmdclass={"build_ext": CMakeBuild},

View File

@@ -232,7 +232,7 @@ ir::value *store(ir::value *ptr, ir::value *val, std::optional<ir::value *> _mas
----------------------------------------------*/
std::string dot_docstr = R"pbdoc(
Returns the matrix product of two blocks.
The two blocks must be two dimensionals and have compatible inner dimensions.
The two blocks must be two dimensions and have compatible inner dimensions.
:param input: The first block to be multiplied.
:type input: 2D block of scalar-type in {`float16`, `float32`}
@@ -253,7 +253,7 @@ ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
std::string where_docstr = R"pbdoc(
Returns a block of elements from either `x` or `y`, depending on `condition`.
Note that `x` and `y` are always evaluated regardless of the value of `condition`.
If you want to avoid unintented memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead.
If you want to avoid unintended memory operations, use the `mask` arguments in `triton.load` and `triton.store` instead.
:param condition: When True (nonzero), yield x, otherwise yield y.
:type condition: Block of triton.bool
@@ -353,9 +353,6 @@ ir::value *sqrt(ir::value *input, ir::builder *builder) {
return builder->create_sqrt(input);
};
/*----------------------------------------------
definition of triton.min
----------------------------------------------*/
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
@@ -367,6 +364,9 @@ ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder
throw_not_int_or_float(name);
}
/*----------------------------------------------
definition of triton.min
----------------------------------------------*/
std::string min_docstr = R"pbdoc(
Returns the minimum value of `input`.
)pbdoc";
@@ -374,6 +374,16 @@ ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
};
/*----------------------------------------------
definition of triton.arg_min
----------------------------------------------*/
std::string min_docstr = R"pbdoc(
Returns the minimum value's index of `input`.
)pbdoc";
ir::value *argmin(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "argmin", ir::reduce_inst::ARGFMIN, ir::reduce_inst::ARGMIN);
};
/*----------------------------------------------
definition of triton.max
----------------------------------------------*/
@@ -384,6 +394,16 @@ ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
};
/*----------------------------------------------
definition of triton.arg_max
----------------------------------------------*/
std::string max_docstr = R"pbdoc(
Returns the maximum value's index of `input`.
)pbdoc";
ir::value *argmax(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "argmax", ir::reduce_inst::ARGFMAX, ir::reduce_inst::ARGMAX);
};
/*----------------------------------------------
definition of triton.sum
----------------------------------------------*/

View File

@@ -1,493 +0,0 @@
/*
pybind11/attr.h: Infrastructure for processing custom
type and function attributes
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "cast.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
/// \addtogroup annotations
/// @{
/// Annotation for methods
struct is_method { handle class_; is_method(const handle &c) : class_(c) { } };
/// Annotation for operators
struct is_operator { };
/// Annotation for parent scope
struct scope { handle value; scope(const handle &s) : value(s) { } };
/// Annotation for documentation
struct doc { const char *value; doc(const char *value) : value(value) { } };
/// Annotation for function names
struct name { const char *value; name(const char *value) : value(value) { } };
/// Annotation indicating that a function is an overload associated with a given "sibling"
struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } };
/// Annotation indicating that a class derives from another given type
template <typename T> struct base {
PYBIND11_DEPRECATED("base<T>() was deprecated in favor of specifying 'T' as a template argument to class_")
base() { }
};
/// Keep patient alive while nurse lives
template <size_t Nurse, size_t Patient> struct keep_alive { };
/// Annotation indicating that a class is involved in a multiple inheritance relationship
struct multiple_inheritance { };
/// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class
struct dynamic_attr { };
/// Annotation which enables the buffer protocol for a type
struct buffer_protocol { };
/// Annotation which requests that a special metaclass is created for a type
struct metaclass {
handle value;
PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.")
metaclass() {}
/// Override pybind11's default metaclass
explicit metaclass(handle value) : value(value) { }
};
/// Annotation that marks a class as local to the module:
struct module_local { const bool value; constexpr module_local(bool v = true) : value(v) { } };
/// Annotation to mark enums as an arithmetic type
struct arithmetic { };
/** \rst
A call policy which places one or more guard variables (``Ts...``) around the function call.
For example, this definition:
.. code-block:: cpp
m.def("foo", foo, py::call_guard<T>());
is equivalent to the following pseudocode:
.. code-block:: cpp
m.def("foo", [](args...) {
T scope_guard;
return foo(args...); // forwarded arguments
});
\endrst */
template <typename... Ts> struct call_guard;
template <> struct call_guard<> { using type = detail::void_type; };
template <typename T>
struct call_guard<T> {
static_assert(std::is_default_constructible<T>::value,
"The guard type must be default constructible");
using type = T;
};
template <typename T, typename... Ts>
struct call_guard<T, Ts...> {
struct type {
T guard{}; // Compose multiple guard types with left-to-right default-constructor order
typename call_guard<Ts...>::type next{};
};
};
/// @} annotations
NAMESPACE_BEGIN(detail)
/* Forward declarations */
enum op_id : int;
enum op_type : int;
struct undefined_t;
template <op_id id, op_type ot, typename L = undefined_t, typename R = undefined_t> struct op_;
inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret);
/// Internal data structure which holds metadata about a keyword argument
struct argument_record {
const char *name; ///< Argument name
const char *descr; ///< Human-readable version of the argument value
handle value; ///< Associated Python object
bool convert : 1; ///< True if the argument is allowed to convert when loading
bool none : 1; ///< True if None is allowed when loading
argument_record(const char *name, const char *descr, handle value, bool convert, bool none)
: name(name), descr(descr), value(value), convert(convert), none(none) { }
};
/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.)
struct function_record {
function_record()
: is_constructor(false), is_new_style_constructor(false), is_stateless(false),
is_operator(false), has_args(false), has_kwargs(false), is_method(false) { }
/// Function name
char *name = nullptr; /* why no C++ strings? They generate heavier code.. */
// User-specified documentation string
char *doc = nullptr;
/// Human-readable version of the function signature
char *signature = nullptr;
/// List of registered keyword arguments
std::vector<argument_record> args;
/// Pointer to lambda function which converts arguments and performs the actual call
handle (*impl) (function_call &) = nullptr;
/// Storage for the wrapped function pointer and captured data, if any
void *data[3] = { };
/// Pointer to custom destructor for 'data' (if needed)
void (*free_data) (function_record *ptr) = nullptr;
/// Return value policy associated with this function
return_value_policy policy = return_value_policy::automatic;
/// True if name == '__init__'
bool is_constructor : 1;
/// True if this is a new-style `__init__` defined in `detail/init.h`
bool is_new_style_constructor : 1;
/// True if this is a stateless function pointer
bool is_stateless : 1;
/// True if this is an operator (__add__), etc.
bool is_operator : 1;
/// True if the function has a '*args' argument
bool has_args : 1;
/// True if the function has a '**kwargs' argument
bool has_kwargs : 1;
/// True if this is a method
bool is_method : 1;
/// Number of arguments (including py::args and/or py::kwargs, if present)
std::uint16_t nargs;
/// Python method object
PyMethodDef *def = nullptr;
/// Python handle to the parent scope (a class or a module)
handle scope;
/// Python handle to the sibling function representing an overload chain
handle sibling;
/// Pointer to next overload
function_record *next = nullptr;
};
/// Special data structure which (temporarily) holds metadata about a bound class
struct type_record {
PYBIND11_NOINLINE type_record()
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false),
default_holder(true), module_local(false) { }
/// Handle to the parent scope
handle scope;
/// Name of the class
const char *name = nullptr;
// Pointer to RTTI type_info data structure
const std::type_info *type = nullptr;
/// How large is the underlying C++ type?
size_t type_size = 0;
/// What is the alignment of the underlying C++ type?
size_t type_align = 0;
/// How large is the type's holder?
size_t holder_size = 0;
/// The global operator new can be overridden with a class-specific variant
void *(*operator_new)(size_t) = nullptr;
/// Function pointer to class_<..>::init_instance
void (*init_instance)(instance *, const void *) = nullptr;
/// Function pointer to class_<..>::dealloc
void (*dealloc)(detail::value_and_holder &) = nullptr;
/// List of base classes of the newly created type
list bases;
/// Optional docstring
const char *doc = nullptr;
/// Custom metaclass (optional)
handle metaclass;
/// Multiple inheritance marker
bool multiple_inheritance : 1;
/// Does the class manage a __dict__?
bool dynamic_attr : 1;
/// Does the class implement the buffer protocol?
bool buffer_protocol : 1;
/// Is the default (unique_ptr) holder type used?
bool default_holder : 1;
/// Is the class definition local to the module shared object?
bool module_local : 1;
PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) {
auto base_info = detail::get_type_info(base, false);
if (!base_info) {
std::string tname(base.name());
detail::clean_type_id(tname);
pybind11_fail("generic_type: type \"" + std::string(name) +
"\" referenced unknown base type \"" + tname + "\"");
}
if (default_holder != base_info->default_holder) {
std::string tname(base.name());
detail::clean_type_id(tname);
pybind11_fail("generic_type: type \"" + std::string(name) + "\" " +
(default_holder ? "does not have" : "has") +
" a non-default holder type while its base \"" + tname + "\" " +
(base_info->default_holder ? "does not" : "does"));
}
bases.append((PyObject *) base_info->type);
if (base_info->type->tp_dictoffset != 0)
dynamic_attr = true;
if (caster)
base_info->implicit_casts.emplace_back(type, caster);
}
};
inline function_call::function_call(const function_record &f, handle p) :
func(f), parent(p) {
args.reserve(f.nargs);
args_convert.reserve(f.nargs);
}
/// Tag for a new-style `__init__` defined in `detail/init.h`
struct is_new_style_constructor { };
/**
* Partial template specializations to process custom attributes provided to
* cpp_function_ and class_. These are either used to initialize the respective
* fields in the type_record and function_record data structures or executed at
* runtime to deal with custom call policies (e.g. keep_alive).
*/
template <typename T, typename SFINAE = void> struct process_attribute;
template <typename T> struct process_attribute_default {
/// Default implementation: do nothing
static void init(const T &, function_record *) { }
static void init(const T &, type_record *) { }
static void precall(function_call &) { }
static void postcall(function_call &, handle) { }
};
/// Process an attribute specifying the function's name
template <> struct process_attribute<name> : process_attribute_default<name> {
static void init(const name &n, function_record *r) { r->name = const_cast<char *>(n.value); }
};
/// Process an attribute specifying the function's docstring
template <> struct process_attribute<doc> : process_attribute_default<doc> {
static void init(const doc &n, function_record *r) { r->doc = const_cast<char *>(n.value); }
};
/// Process an attribute specifying the function's docstring (provided as a C-style string)
template <> struct process_attribute<const char *> : process_attribute_default<const char *> {
static void init(const char *d, function_record *r) { r->doc = const_cast<char *>(d); }
static void init(const char *d, type_record *r) { r->doc = const_cast<char *>(d); }
};
template <> struct process_attribute<char *> : process_attribute<const char *> { };
/// Process an attribute indicating the function's return value policy
template <> struct process_attribute<return_value_policy> : process_attribute_default<return_value_policy> {
static void init(const return_value_policy &p, function_record *r) { r->policy = p; }
};
/// Process an attribute which indicates that this is an overloaded function associated with a given sibling
template <> struct process_attribute<sibling> : process_attribute_default<sibling> {
static void init(const sibling &s, function_record *r) { r->sibling = s.value; }
};
/// Process an attribute which indicates that this function is a method
template <> struct process_attribute<is_method> : process_attribute_default<is_method> {
static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; }
};
/// Process an attribute which indicates the parent scope of a method
template <> struct process_attribute<scope> : process_attribute_default<scope> {
static void init(const scope &s, function_record *r) { r->scope = s.value; }
};
/// Process an attribute which indicates that this function is an operator
template <> struct process_attribute<is_operator> : process_attribute_default<is_operator> {
static void init(const is_operator &, function_record *r) { r->is_operator = true; }
};
template <> struct process_attribute<is_new_style_constructor> : process_attribute_default<is_new_style_constructor> {
static void init(const is_new_style_constructor &, function_record *r) { r->is_new_style_constructor = true; }
};
/// Process a keyword argument attribute (*without* a default value)
template <> struct process_attribute<arg> : process_attribute_default<arg> {
static void init(const arg &a, function_record *r) {
if (r->is_method && r->args.empty())
r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/);
r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none);
}
};
/// Process a keyword argument attribute (*with* a default value)
template <> struct process_attribute<arg_v> : process_attribute_default<arg_v> {
static void init(const arg_v &a, function_record *r) {
if (r->is_method && r->args.empty())
r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/);
if (!a.value) {
#if !defined(NDEBUG)
std::string descr("'");
if (a.name) descr += std::string(a.name) + ": ";
descr += a.type + "'";
if (r->is_method) {
if (r->name)
descr += " in method '" + (std::string) str(r->scope) + "." + (std::string) r->name + "'";
else
descr += " in method of '" + (std::string) str(r->scope) + "'";
} else if (r->name) {
descr += " in function '" + (std::string) r->name + "'";
}
pybind11_fail("arg(): could not convert default argument "
+ descr + " into a Python object (type not registered yet?)");
#else
pybind11_fail("arg(): could not convert default argument "
"into a Python object (type not registered yet?). "
"Compile in debug mode for more information.");
#endif
}
r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none);
}
};
/// Process a parent class attribute. Single inheritance only (class_ itself already guarantees that)
template <typename T>
struct process_attribute<T, enable_if_t<is_pyobject<T>::value>> : process_attribute_default<handle> {
static void init(const handle &h, type_record *r) { r->bases.append(h); }
};
/// Process a parent class attribute (deprecated, does not support multiple inheritance)
template <typename T>
struct process_attribute<base<T>> : process_attribute_default<base<T>> {
static void init(const base<T> &, type_record *r) { r->add_base(typeid(T), nullptr); }
};
/// Process a multiple inheritance attribute
template <>
struct process_attribute<multiple_inheritance> : process_attribute_default<multiple_inheritance> {
static void init(const multiple_inheritance &, type_record *r) { r->multiple_inheritance = true; }
};
template <>
struct process_attribute<dynamic_attr> : process_attribute_default<dynamic_attr> {
static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; }
};
template <>
struct process_attribute<buffer_protocol> : process_attribute_default<buffer_protocol> {
static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; }
};
template <>
struct process_attribute<metaclass> : process_attribute_default<metaclass> {
static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; }
};
template <>
struct process_attribute<module_local> : process_attribute_default<module_local> {
static void init(const module_local &l, type_record *r) { r->module_local = l.value; }
};
/// Process an 'arithmetic' attribute for enums (does nothing here)
template <>
struct process_attribute<arithmetic> : process_attribute_default<arithmetic> {};
template <typename... Ts>
struct process_attribute<call_guard<Ts...>> : process_attribute_default<call_guard<Ts...>> { };
/**
* Process a keep_alive call policy -- invokes keep_alive_impl during the
* pre-call handler if both Nurse, Patient != 0 and use the post-call handler
* otherwise
*/
template <size_t Nurse, size_t Patient> struct process_attribute<keep_alive<Nurse, Patient>> : public process_attribute_default<keep_alive<Nurse, Patient>> {
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
static void precall(function_call &call) { keep_alive_impl(Nurse, Patient, call, handle()); }
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
static void postcall(function_call &, handle) { }
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
static void precall(function_call &) { }
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
static void postcall(function_call &call, handle ret) { keep_alive_impl(Nurse, Patient, call, ret); }
};
/// Recursively iterate over variadic template arguments
template <typename... Args> struct process_attributes {
static void init(const Args&... args, function_record *r) {
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
ignore_unused(unused);
}
static void init(const Args&... args, type_record *r) {
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
ignore_unused(unused);
}
static void precall(function_call &call) {
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::precall(call), 0) ... };
ignore_unused(unused);
}
static void postcall(function_call &call, handle fn_ret) {
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::postcall(call, fn_ret), 0) ... };
ignore_unused(unused);
}
};
template <typename T>
using is_call_guard = is_instantiation<call_guard, T>;
/// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found)
template <typename... Extra>
using extract_guard_t = typename exactly_one_t<is_call_guard, call_guard<>, Extra...>::type;
/// Check the number of named arguments at compile time
template <typename... Extra,
size_t named = constexpr_sum(std::is_base_of<arg, Extra>::value...),
size_t self = constexpr_sum(std::is_same<is_method, Extra>::value...)>
constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) {
return named == 0 || (self + named + has_args + has_kwargs) == nargs;
}
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,108 +0,0 @@
/*
pybind11/buffer_info.h: Python buffer object interface
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "detail/common.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
/// Information record describing a Python buffer object
struct buffer_info {
void *ptr = nullptr; // Pointer to the underlying storage
ssize_t itemsize = 0; // Size of individual items in bytes
ssize_t size = 0; // Total number of entries
std::string format; // For homogeneous buffers, this should be set to format_descriptor<T>::format()
ssize_t ndim = 0; // Number of dimensions
std::vector<ssize_t> shape; // Shape of the tensor (1 entry per dimension)
std::vector<ssize_t> strides; // Number of entries between adjacent entries (for each per dimension)
buffer_info() { }
buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
: ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
shape(std::move(shape_in)), strides(std::move(strides_in)) {
if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size())
pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
for (size_t i = 0; i < (size_t) ndim; ++i)
size *= shape[i];
}
template <typename T>
buffer_info(T *ptr, detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
: buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor<T>::format(), static_cast<ssize_t>(shape_in->size()), std::move(shape_in), std::move(strides_in)) { }
buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size)
: buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { }
template <typename T>
buffer_info(T *ptr, ssize_t size)
: buffer_info(ptr, sizeof(T), format_descriptor<T>::format(), size) { }
explicit buffer_info(Py_buffer *view, bool ownview = true)
: buffer_info(view->buf, view->itemsize, view->format, view->ndim,
{view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) {
this->view = view;
this->ownview = ownview;
}
buffer_info(const buffer_info &) = delete;
buffer_info& operator=(const buffer_info &) = delete;
buffer_info(buffer_info &&other) {
(*this) = std::move(other);
}
buffer_info& operator=(buffer_info &&rhs) {
ptr = rhs.ptr;
itemsize = rhs.itemsize;
size = rhs.size;
format = std::move(rhs.format);
ndim = rhs.ndim;
shape = std::move(rhs.shape);
strides = std::move(rhs.strides);
std::swap(view, rhs.view);
std::swap(ownview, rhs.ownview);
return *this;
}
~buffer_info() {
if (view && ownview) { PyBuffer_Release(view); delete view; }
}
private:
struct private_ctr_tag { };
buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
detail::any_container<ssize_t> &&shape_in, detail::any_container<ssize_t> &&strides_in)
: buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { }
Py_buffer *view = nullptr;
bool ownview = false;
};
NAMESPACE_BEGIN(detail)
template <typename T, typename SFINAE = void> struct compare_buffer_info {
static bool compare(const buffer_info& b) {
return b.format == format_descriptor<T>::format() && b.itemsize == (ssize_t) sizeof(T);
}
};
template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
static bool compare(const buffer_info& b) {
return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
}
};
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

File diff suppressed because it is too large Load Diff

View File

@@ -1,162 +0,0 @@
/*
pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime
Copyright (c) 2016 Trent Houliston <trent@houliston.me> and
Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#include <cmath>
#include <ctime>
#include <chrono>
#include <datetime.h>
// Backport the PyDateTime_DELTA functions from Python3.3 if required
#ifndef PyDateTime_DELTA_GET_DAYS
#define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days)
#endif
#ifndef PyDateTime_DELTA_GET_SECONDS
#define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds)
#endif
#ifndef PyDateTime_DELTA_GET_MICROSECONDS
#define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds)
#endif
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
template <typename type> class duration_caster {
public:
typedef typename type::rep rep;
typedef typename type::period period;
typedef std::chrono::duration<uint_fast32_t, std::ratio<86400>> days;
bool load(handle src, bool) {
using namespace std::chrono;
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
if (!src) return false;
// If invoked with datetime.delta object
if (PyDelta_Check(src.ptr())) {
value = type(duration_cast<duration<rep, period>>(
days(PyDateTime_DELTA_GET_DAYS(src.ptr()))
+ seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr()))
+ microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr()))));
return true;
}
// If invoked with a float we assume it is seconds and convert
else if (PyFloat_Check(src.ptr())) {
value = type(duration_cast<duration<rep, period>>(duration<double>(PyFloat_AsDouble(src.ptr()))));
return true;
}
else return false;
}
// If this is a duration just return it back
static const std::chrono::duration<rep, period>& get_duration(const std::chrono::duration<rep, period> &src) {
return src;
}
// If this is a time_point get the time_since_epoch
template <typename Clock> static std::chrono::duration<rep, period> get_duration(const std::chrono::time_point<Clock, std::chrono::duration<rep, period>> &src) {
return src.time_since_epoch();
}
static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) {
using namespace std::chrono;
// Use overloaded function to get our duration from our source
// Works out if it is a duration or time_point and get the duration
auto d = get_duration(src);
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
// Declare these special duration types so the conversions happen with the correct primitive types (int)
using dd_t = duration<int, std::ratio<86400>>;
using ss_t = duration<int, std::ratio<1>>;
using us_t = duration<int, std::micro>;
auto dd = duration_cast<dd_t>(d);
auto subd = d - dd;
auto ss = duration_cast<ss_t>(subd);
auto us = duration_cast<us_t>(subd - ss);
return PyDelta_FromDSU(dd.count(), ss.count(), us.count());
}
PYBIND11_TYPE_CASTER(type, _("datetime.timedelta"));
};
// This is for casting times on the system clock into datetime.datetime instances
template <typename Duration> class type_caster<std::chrono::time_point<std::chrono::system_clock, Duration>> {
public:
typedef std::chrono::time_point<std::chrono::system_clock, Duration> type;
bool load(handle src, bool) {
using namespace std::chrono;
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
if (!src) return false;
if (PyDateTime_Check(src.ptr())) {
std::tm cal;
cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr());
cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr());
cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr());
cal.tm_mday = PyDateTime_GET_DAY(src.ptr());
cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1;
cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900;
cal.tm_isdst = -1;
value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr()));
return true;
}
else return false;
}
static handle cast(const std::chrono::time_point<std::chrono::system_clock, Duration> &src, return_value_policy /* policy */, handle /* parent */) {
using namespace std::chrono;
// Lazy initialise the PyDateTime import
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
std::time_t tt = system_clock::to_time_t(src);
// this function uses static memory so it's best to copy it out asap just in case
// otherwise other code that is using localtime may break this (not just python code)
std::tm localtime = *std::localtime(&tt);
// Declare these special duration types so the conversions happen with the correct primitive types (int)
using us_t = duration<int, std::micro>;
return PyDateTime_FromDateAndTime(localtime.tm_year + 1900,
localtime.tm_mon + 1,
localtime.tm_mday,
localtime.tm_hour,
localtime.tm_min,
localtime.tm_sec,
(duration_cast<us_t>(src.time_since_epoch() % seconds(1))).count());
}
PYBIND11_TYPE_CASTER(type, _("datetime.datetime"));
};
// Other clocks that are not the system clock are not measured as datetime.datetime objects
// since they are not measured on calendar time. So instead we just make them timedeltas
// Or if they have passed us a time as a float we convert that
template <typename Clock, typename Duration> class type_caster<std::chrono::time_point<Clock, Duration>>
: public duration_caster<std::chrono::time_point<Clock, Duration>> {
};
template <typename Rep, typename Period> class type_caster<std::chrono::duration<Rep, Period>>
: public duration_caster<std::chrono::duration<Rep, Period>> {
};
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,2 +0,0 @@
#include "detail/common.h"
#warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'."

View File

@@ -1,65 +0,0 @@
/*
pybind11/complex.h: Complex number support
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#include <complex>
/// glibc defines I as a macro which breaks things, e.g., boost template names
#ifdef I
# undef I
#endif
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
template <typename T> struct format_descriptor<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
static constexpr const char c = format_descriptor<T>::c;
static constexpr const char value[3] = { 'Z', c, '\0' };
static std::string format() { return std::string(value); }
};
#ifndef PYBIND11_CPP17
template <typename T> constexpr const char format_descriptor<
std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
#endif
NAMESPACE_BEGIN(detail)
template <typename T> struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
static constexpr bool value = true;
static constexpr int index = is_fmt_numeric<T>::index + 3;
};
template <typename T> class type_caster<std::complex<T>> {
public:
bool load(handle src, bool convert) {
if (!src)
return false;
if (!convert && !PyComplex_Check(src.ptr()))
return false;
Py_complex result = PyComplex_AsCComplex(src.ptr());
if (result.real == -1.0 && PyErr_Occurred()) {
PyErr_Clear();
return false;
}
value = std::complex<T>((T) result.real, (T) result.imag);
return true;
}
static handle cast(const std::complex<T> &src, return_value_policy /* policy */, handle /* parent */) {
return PyComplex_FromDoubles((double) src.real(), (double) src.imag());
}
PYBIND11_TYPE_CASTER(std::complex<T>, _("complex"));
};
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,623 +0,0 @@
/*
pybind11/detail/class.h: Python C API implementation details for py::class_
Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "../attr.h"
#include "../options.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
#if PY_VERSION_HEX >= 0x03030000
# define PYBIND11_BUILTIN_QUALNAME
# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj)
#else
// In pre-3.3 Python, we still set __qualname__ so that we can produce reliable function type
// signatures; in 3.3+ this macro expands to nothing:
# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) setattr((PyObject *) obj, "__qualname__", nameobj)
#endif
inline PyTypeObject *type_incref(PyTypeObject *type) {
Py_INCREF(type);
return type;
}
#if !defined(PYPY_VERSION)
/// `pybind11_static_property.__get__()`: Always pass the class instead of the instance.
extern "C" inline PyObject *pybind11_static_get(PyObject *self, PyObject * /*ob*/, PyObject *cls) {
return PyProperty_Type.tp_descr_get(self, cls, cls);
}
/// `pybind11_static_property.__set__()`: Just like the above `__get__()`.
extern "C" inline int pybind11_static_set(PyObject *self, PyObject *obj, PyObject *value) {
PyObject *cls = PyType_Check(obj) ? obj : (PyObject *) Py_TYPE(obj);
return PyProperty_Type.tp_descr_set(self, cls, value);
}
/** A `static_property` is the same as a `property` but the `__get__()` and `__set__()`
methods are modified to always use the object type instead of a concrete instance.
Return value: New reference. */
inline PyTypeObject *make_static_property_type() {
constexpr auto *name = "pybind11_static_property";
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
/* Danger zone: from now (and until PyType_Ready), make sure to
issue no Python C API calls which could potentially invoke the
garbage collector (the GC will call type_traverse(), which will in
turn find the newly constructed type in an invalid state) */
auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0);
if (!heap_type)
pybind11_fail("make_static_property_type(): error allocating type!");
heap_type->ht_name = name_obj.inc_ref().ptr();
#ifdef PYBIND11_BUILTIN_QUALNAME
heap_type->ht_qualname = name_obj.inc_ref().ptr();
#endif
auto type = &heap_type->ht_type;
type->tp_name = name;
type->tp_base = type_incref(&PyProperty_Type);
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
type->tp_descr_get = pybind11_static_get;
type->tp_descr_set = pybind11_static_set;
if (PyType_Ready(type) < 0)
pybind11_fail("make_static_property_type(): failure in PyType_Ready()!");
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
return type;
}
#else // PYPY
/** PyPy has some issues with the above C API, so we evaluate Python code instead.
This function will only be called once so performance isn't really a concern.
Return value: New reference. */
inline PyTypeObject *make_static_property_type() {
auto d = dict();
PyObject *result = PyRun_String(R"(\
class pybind11_static_property(property):
def __get__(self, obj, cls):
return property.__get__(self, cls, cls)
def __set__(self, obj, value):
cls = obj if isinstance(obj, type) else type(obj)
property.__set__(self, cls, value)
)", Py_file_input, d.ptr(), d.ptr()
);
if (result == nullptr)
throw error_already_set();
Py_DECREF(result);
return (PyTypeObject *) d["pybind11_static_property"].cast<object>().release().ptr();
}
#endif // PYPY
/** Types with static properties need to handle `Type.static_prop = x` in a specific way.
By default, Python replaces the `static_property` itself, but for wrapped C++ types
we need to call `static_property.__set__()` in order to propagate the new value to
the underlying C++ data structure. */
extern "C" inline int pybind11_meta_setattro(PyObject* obj, PyObject* name, PyObject* value) {
// Use `_PyType_Lookup()` instead of `PyObject_GetAttr()` in order to get the raw
// descriptor (`property`) instead of calling `tp_descr_get` (`property.__get__()`).
PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name);
// The following assignment combinations are possible:
// 1. `Type.static_prop = value` --> descr_set: `Type.static_prop.__set__(value)`
// 2. `Type.static_prop = other_static_prop` --> setattro: replace existing `static_prop`
// 3. `Type.regular_attribute = value` --> setattro: regular attribute assignment
const auto static_prop = (PyObject *) get_internals().static_property_type;
const auto call_descr_set = descr && PyObject_IsInstance(descr, static_prop)
&& !PyObject_IsInstance(value, static_prop);
if (call_descr_set) {
// Call `static_property.__set__()` instead of replacing the `static_property`.
#if !defined(PYPY_VERSION)
return Py_TYPE(descr)->tp_descr_set(descr, obj, value);
#else
if (PyObject *result = PyObject_CallMethod(descr, "__set__", "OO", obj, value)) {
Py_DECREF(result);
return 0;
} else {
return -1;
}
#endif
} else {
// Replace existing attribute.
return PyType_Type.tp_setattro(obj, name, value);
}
}
#if PY_MAJOR_VERSION >= 3
/**
* Python 3's PyInstanceMethod_Type hides itself via its tp_descr_get, which prevents aliasing
* methods via cls.attr("m2") = cls.attr("m1"): instead the tp_descr_get returns a plain function,
* when called on a class, or a PyMethod, when called on an instance. Override that behaviour here
* to do a special case bypass for PyInstanceMethod_Types.
*/
extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name) {
PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name);
if (descr && PyInstanceMethod_Check(descr)) {
Py_INCREF(descr);
return descr;
}
else {
return PyType_Type.tp_getattro(obj, name);
}
}
#endif
/** This metaclass is assigned by default to all pybind11 types and is required in order
for static properties to function correctly. Users may override this using `py::metaclass`.
Return value: New reference. */
inline PyTypeObject* make_default_metaclass() {
constexpr auto *name = "pybind11_type";
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
/* Danger zone: from now (and until PyType_Ready), make sure to
issue no Python C API calls which could potentially invoke the
garbage collector (the GC will call type_traverse(), which will in
turn find the newly constructed type in an invalid state) */
auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0);
if (!heap_type)
pybind11_fail("make_default_metaclass(): error allocating metaclass!");
heap_type->ht_name = name_obj.inc_ref().ptr();
#ifdef PYBIND11_BUILTIN_QUALNAME
heap_type->ht_qualname = name_obj.inc_ref().ptr();
#endif
auto type = &heap_type->ht_type;
type->tp_name = name;
type->tp_base = type_incref(&PyType_Type);
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
type->tp_setattro = pybind11_meta_setattro;
#if PY_MAJOR_VERSION >= 3
type->tp_getattro = pybind11_meta_getattro;
#endif
if (PyType_Ready(type) < 0)
pybind11_fail("make_default_metaclass(): failure in PyType_Ready()!");
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
return type;
}
/// For multiple inheritance types we need to recursively register/deregister base pointers for any
/// base classes with pointers that are difference from the instance value pointer so that we can
/// correctly recognize an offset base class pointer. This calls a function with any offset base ptrs.
inline void traverse_offset_bases(void *valueptr, const detail::type_info *tinfo, instance *self,
bool (*f)(void * /*parentptr*/, instance * /*self*/)) {
for (handle h : reinterpret_borrow<tuple>(tinfo->type->tp_bases)) {
if (auto parent_tinfo = get_type_info((PyTypeObject *) h.ptr())) {
for (auto &c : parent_tinfo->implicit_casts) {
if (c.first == tinfo->cpptype) {
auto *parentptr = c.second(valueptr);
if (parentptr != valueptr)
f(parentptr, self);
traverse_offset_bases(parentptr, parent_tinfo, self, f);
break;
}
}
}
}
}
inline bool register_instance_impl(void *ptr, instance *self) {
get_internals().registered_instances.emplace(ptr, self);
return true; // unused, but gives the same signature as the deregister func
}
inline bool deregister_instance_impl(void *ptr, instance *self) {
auto &registered_instances = get_internals().registered_instances;
auto range = registered_instances.equal_range(ptr);
for (auto it = range.first; it != range.second; ++it) {
if (Py_TYPE(self) == Py_TYPE(it->second)) {
registered_instances.erase(it);
return true;
}
}
return false;
}
inline void register_instance(instance *self, void *valptr, const type_info *tinfo) {
register_instance_impl(valptr, self);
if (!tinfo->simple_ancestors)
traverse_offset_bases(valptr, tinfo, self, register_instance_impl);
}
inline bool deregister_instance(instance *self, void *valptr, const type_info *tinfo) {
bool ret = deregister_instance_impl(valptr, self);
if (!tinfo->simple_ancestors)
traverse_offset_bases(valptr, tinfo, self, deregister_instance_impl);
return ret;
}
/// Instance creation function for all pybind11 types. It allocates the internal instance layout for
/// holding C++ objects and holders. Allocation is done lazily (the first time the instance is cast
/// to a reference or pointer), and initialization is done by an `__init__` function.
inline PyObject *make_new_instance(PyTypeObject *type) {
#if defined(PYPY_VERSION)
// PyPy gets tp_basicsize wrong (issue 2482) under multiple inheritance when the first inherited
// object is a a plain Python type (i.e. not derived from an extension type). Fix it.
ssize_t instance_size = static_cast<ssize_t>(sizeof(instance));
if (type->tp_basicsize < instance_size) {
type->tp_basicsize = instance_size;
}
#endif
PyObject *self = type->tp_alloc(type, 0);
auto inst = reinterpret_cast<instance *>(self);
// Allocate the value/holder internals:
inst->allocate_layout();
inst->owned = true;
return self;
}
/// Instance creation function for all pybind11 types. It only allocates space for the
/// C++ object, but doesn't call the constructor -- an `__init__` function must do that.
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) {
return make_new_instance(type);
}
/// An `__init__` function constructs the C++ object. Users should provide at least one
/// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the
/// following default function will be used which simply throws an exception.
extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject *) {
PyTypeObject *type = Py_TYPE(self);
std::string msg;
#if defined(PYPY_VERSION)
msg += handle((PyObject *) type).attr("__module__").cast<std::string>() + ".";
#endif
msg += type->tp_name;
msg += ": No constructor defined!";
PyErr_SetString(PyExc_TypeError, msg.c_str());
return -1;
}
inline void add_patient(PyObject *nurse, PyObject *patient) {
auto &internals = get_internals();
auto instance = reinterpret_cast<detail::instance *>(nurse);
instance->has_patients = true;
Py_INCREF(patient);
internals.patients[nurse].push_back(patient);
}
inline void clear_patients(PyObject *self) {
auto instance = reinterpret_cast<detail::instance *>(self);
auto &internals = get_internals();
auto pos = internals.patients.find(self);
assert(pos != internals.patients.end());
// Clearing the patients can cause more Python code to run, which
// can invalidate the iterator. Extract the vector of patients
// from the unordered_map first.
auto patients = std::move(pos->second);
internals.patients.erase(pos);
instance->has_patients = false;
for (PyObject *&patient : patients)
Py_CLEAR(patient);
}
/// Clears all internal data from the instance and removes it from registered instances in
/// preparation for deallocation.
inline void clear_instance(PyObject *self) {
auto instance = reinterpret_cast<detail::instance *>(self);
// Deallocate any values/holders, if present:
for (auto &v_h : values_and_holders(instance)) {
if (v_h) {
// We have to deregister before we call dealloc because, for virtual MI types, we still
// need to be able to get the parent pointers.
if (v_h.instance_registered() && !deregister_instance(instance, v_h.value_ptr(), v_h.type))
pybind11_fail("pybind11_object_dealloc(): Tried to deallocate unregistered instance!");
if (instance->owned || v_h.holder_constructed())
v_h.type->dealloc(v_h);
}
}
// Deallocate the value/holder layout internals:
instance->deallocate_layout();
if (instance->weakrefs)
PyObject_ClearWeakRefs(self);
PyObject **dict_ptr = _PyObject_GetDictPtr(self);
if (dict_ptr)
Py_CLEAR(*dict_ptr);
if (instance->has_patients)
clear_patients(self);
}
/// Instance destructor function for all pybind11 types. It calls `type_info.dealloc`
/// to destroy the C++ object itself, while the rest is Python bookkeeping.
extern "C" inline void pybind11_object_dealloc(PyObject *self) {
clear_instance(self);
auto type = Py_TYPE(self);
type->tp_free(self);
// `type->tp_dealloc != pybind11_object_dealloc` means that we're being called
// as part of a derived type's dealloc, in which case we're not allowed to decref
// the type here. For cross-module compatibility, we shouldn't compare directly
// with `pybind11_object_dealloc`, but with the common one stashed in internals.
auto pybind11_object_type = (PyTypeObject *) get_internals().instance_base;
if (type->tp_dealloc == pybind11_object_type->tp_dealloc)
Py_DECREF(type);
}
/** Create the type which can be used as a common base for all classes. This is
needed in order to satisfy Python's requirements for multiple inheritance.
Return value: New reference. */
inline PyObject *make_object_base_type(PyTypeObject *metaclass) {
constexpr auto *name = "pybind11_object";
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
/* Danger zone: from now (and until PyType_Ready), make sure to
issue no Python C API calls which could potentially invoke the
garbage collector (the GC will call type_traverse(), which will in
turn find the newly constructed type in an invalid state) */
auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0);
if (!heap_type)
pybind11_fail("make_object_base_type(): error allocating type!");
heap_type->ht_name = name_obj.inc_ref().ptr();
#ifdef PYBIND11_BUILTIN_QUALNAME
heap_type->ht_qualname = name_obj.inc_ref().ptr();
#endif
auto type = &heap_type->ht_type;
type->tp_name = name;
type->tp_base = type_incref(&PyBaseObject_Type);
type->tp_basicsize = static_cast<ssize_t>(sizeof(instance));
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
type->tp_new = pybind11_object_new;
type->tp_init = pybind11_object_init;
type->tp_dealloc = pybind11_object_dealloc;
/* Support weak references (needed for the keep_alive feature) */
type->tp_weaklistoffset = offsetof(instance, weakrefs);
if (PyType_Ready(type) < 0)
pybind11_fail("PyType_Ready failed in make_object_base_type():" + error_string());
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
assert(!PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
return (PyObject *) heap_type;
}
/// dynamic_attr: Support for `d = instance.__dict__`.
extern "C" inline PyObject *pybind11_get_dict(PyObject *self, void *) {
PyObject *&dict = *_PyObject_GetDictPtr(self);
if (!dict)
dict = PyDict_New();
Py_XINCREF(dict);
return dict;
}
/// dynamic_attr: Support for `instance.__dict__ = dict()`.
extern "C" inline int pybind11_set_dict(PyObject *self, PyObject *new_dict, void *) {
if (!PyDict_Check(new_dict)) {
PyErr_Format(PyExc_TypeError, "__dict__ must be set to a dictionary, not a '%.200s'",
Py_TYPE(new_dict)->tp_name);
return -1;
}
PyObject *&dict = *_PyObject_GetDictPtr(self);
Py_INCREF(new_dict);
Py_CLEAR(dict);
dict = new_dict;
return 0;
}
/// dynamic_attr: Allow the garbage collector to traverse the internal instance `__dict__`.
extern "C" inline int pybind11_traverse(PyObject *self, visitproc visit, void *arg) {
PyObject *&dict = *_PyObject_GetDictPtr(self);
Py_VISIT(dict);
return 0;
}
/// dynamic_attr: Allow the GC to clear the dictionary.
extern "C" inline int pybind11_clear(PyObject *self) {
PyObject *&dict = *_PyObject_GetDictPtr(self);
Py_CLEAR(dict);
return 0;
}
/// Give instances of this type a `__dict__` and opt into garbage collection.
inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) {
auto type = &heap_type->ht_type;
#if defined(PYPY_VERSION)
pybind11_fail(std::string(type->tp_name) + ": dynamic attributes are "
"currently not supported in "
"conjunction with PyPy!");
#endif
type->tp_flags |= Py_TPFLAGS_HAVE_GC;
type->tp_dictoffset = type->tp_basicsize; // place dict at the end
type->tp_basicsize += (ssize_t)sizeof(PyObject *); // and allocate enough space for it
type->tp_traverse = pybind11_traverse;
type->tp_clear = pybind11_clear;
static PyGetSetDef getset[] = {
{const_cast<char*>("__dict__"), pybind11_get_dict, pybind11_set_dict, nullptr, nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr}
};
type->tp_getset = getset;
}
/// buffer_protocol: Fill in the view as specified by flags.
extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
// Look for a `get_buffer` implementation in this type's info or any bases (following MRO).
type_info *tinfo = nullptr;
for (auto type : reinterpret_borrow<tuple>(Py_TYPE(obj)->tp_mro)) {
tinfo = get_type_info((PyTypeObject *) type.ptr());
if (tinfo && tinfo->get_buffer)
break;
}
if (view == nullptr || !tinfo || !tinfo->get_buffer) {
if (view)
view->obj = nullptr;
PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error");
return -1;
}
std::memset(view, 0, sizeof(Py_buffer));
buffer_info *info = tinfo->get_buffer(obj, tinfo->get_buffer_data);
view->obj = obj;
view->ndim = 1;
view->internal = info;
view->buf = info->ptr;
view->itemsize = info->itemsize;
view->len = view->itemsize;
for (auto s : info->shape)
view->len *= s;
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT)
view->format = const_cast<char *>(info->format.c_str());
if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
view->ndim = (int) info->ndim;
view->strides = &info->strides[0];
view->shape = &info->shape[0];
}
Py_INCREF(view->obj);
return 0;
}
/// buffer_protocol: Release the resources of the buffer.
extern "C" inline void pybind11_releasebuffer(PyObject *, Py_buffer *view) {
delete (buffer_info *) view->internal;
}
/// Give this type a buffer interface.
inline void enable_buffer_protocol(PyHeapTypeObject *heap_type) {
heap_type->ht_type.tp_as_buffer = &heap_type->as_buffer;
#if PY_MAJOR_VERSION < 3
heap_type->ht_type.tp_flags |= Py_TPFLAGS_HAVE_NEWBUFFER;
#endif
heap_type->as_buffer.bf_getbuffer = pybind11_getbuffer;
heap_type->as_buffer.bf_releasebuffer = pybind11_releasebuffer;
}
/** Create a brand new Python type according to the `type_record` specification.
Return value: New reference. */
inline PyObject* make_new_python_type(const type_record &rec) {
auto name = reinterpret_steal<object>(PYBIND11_FROM_STRING(rec.name));
auto qualname = name;
if (rec.scope && !PyModule_Check(rec.scope.ptr()) && hasattr(rec.scope, "__qualname__")) {
#if PY_MAJOR_VERSION >= 3
qualname = reinterpret_steal<object>(
PyUnicode_FromFormat("%U.%U", rec.scope.attr("__qualname__").ptr(), name.ptr()));
#else
qualname = str(rec.scope.attr("__qualname__").cast<std::string>() + "." + rec.name);
#endif
}
object module;
if (rec.scope) {
if (hasattr(rec.scope, "__module__"))
module = rec.scope.attr("__module__");
else if (hasattr(rec.scope, "__name__"))
module = rec.scope.attr("__name__");
}
auto full_name = c_str(
#if !defined(PYPY_VERSION)
module ? str(module).cast<std::string>() + "." + rec.name :
#endif
rec.name);
char *tp_doc = nullptr;
if (rec.doc && options::show_user_defined_docstrings()) {
/* Allocate memory for docstring (using PyObject_MALLOC, since
Python will free this later on) */
size_t size = strlen(rec.doc) + 1;
tp_doc = (char *) PyObject_MALLOC(size);
memcpy((void *) tp_doc, rec.doc, size);
}
auto &internals = get_internals();
auto bases = tuple(rec.bases);
auto base = (bases.size() == 0) ? internals.instance_base
: bases[0].ptr();
/* Danger zone: from now (and until PyType_Ready), make sure to
issue no Python C API calls which could potentially invoke the
garbage collector (the GC will call type_traverse(), which will in
turn find the newly constructed type in an invalid state) */
auto metaclass = rec.metaclass.ptr() ? (PyTypeObject *) rec.metaclass.ptr()
: internals.default_metaclass;
auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0);
if (!heap_type)
pybind11_fail(std::string(rec.name) + ": Unable to create type object!");
heap_type->ht_name = name.release().ptr();
#ifdef PYBIND11_BUILTIN_QUALNAME
heap_type->ht_qualname = qualname.inc_ref().ptr();
#endif
auto type = &heap_type->ht_type;
type->tp_name = full_name;
type->tp_doc = tp_doc;
type->tp_base = type_incref((PyTypeObject *)base);
type->tp_basicsize = static_cast<ssize_t>(sizeof(instance));
if (bases.size() > 0)
type->tp_bases = bases.release().ptr();
/* Don't inherit base __init__ */
type->tp_init = pybind11_object_init;
/* Supported protocols */
type->tp_as_number = &heap_type->as_number;
type->tp_as_sequence = &heap_type->as_sequence;
type->tp_as_mapping = &heap_type->as_mapping;
/* Flags */
type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
#if PY_MAJOR_VERSION < 3
type->tp_flags |= Py_TPFLAGS_CHECKTYPES;
#endif
if (rec.dynamic_attr)
enable_dynamic_attributes(heap_type);
if (rec.buffer_protocol)
enable_buffer_protocol(heap_type);
if (PyType_Ready(type) < 0)
pybind11_fail(std::string(rec.name) + ": PyType_Ready failed (" + error_string() + ")!");
assert(rec.dynamic_attr ? PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)
: !PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
/* Register type with the parent scope */
if (rec.scope)
setattr(rec.scope, rec.name, (PyObject *) type);
else
Py_INCREF(type); // Keep it alive forever (reference leak)
if (module) // Needed by pydoc
setattr((PyObject *) type, "__module__", module);
PYBIND11_SET_OLDPY_QUALNAME(type, qualname);
return (PyObject *) type;
}
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,807 +0,0 @@
/*
pybind11/detail/common.h -- Basic macros
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#if !defined(NAMESPACE_BEGIN)
# define NAMESPACE_BEGIN(name) namespace name {
#endif
#if !defined(NAMESPACE_END)
# define NAMESPACE_END(name) }
#endif
// Robust support for some features and loading modules compiled against different pybind versions
// requires forcing hidden visibility on pybind code, so we enforce this by setting the attribute on
// the main `pybind11` namespace.
#if !defined(PYBIND11_NAMESPACE)
# ifdef __GNUG__
# define PYBIND11_NAMESPACE pybind11 __attribute__((visibility("hidden")))
# else
# define PYBIND11_NAMESPACE pybind11
# endif
#endif
#if !(defined(_MSC_VER) && __cplusplus == 199711L) && !defined(__INTEL_COMPILER)
# if __cplusplus >= 201402L
# define PYBIND11_CPP14
# if __cplusplus >= 201703L
# define PYBIND11_CPP17
# endif
# endif
#elif defined(_MSC_VER) && __cplusplus == 199711L
// MSVC sets _MSVC_LANG rather than __cplusplus (supposedly until the standard is fully implemented)
// Unless you use the /Zc:__cplusplus flag on Visual Studio 2017 15.7 Preview 3 or newer
# if _MSVC_LANG >= 201402L
# define PYBIND11_CPP14
# if _MSVC_LANG > 201402L && _MSC_VER >= 1910
# define PYBIND11_CPP17
# endif
# endif
#endif
// Compiler version assertions
#if defined(__INTEL_COMPILER)
# if __INTEL_COMPILER < 1700
# error pybind11 requires Intel C++ compiler v17 or newer
# endif
#elif defined(__clang__) && !defined(__apple_build_version__)
# if __clang_major__ < 3 || (__clang_major__ == 3 && __clang_minor__ < 3)
# error pybind11 requires clang 3.3 or newer
# endif
#elif defined(__clang__)
// Apple changes clang version macros to its Xcode version; the first Xcode release based on
// (upstream) clang 3.3 was Xcode 5:
# if __clang_major__ < 5
# error pybind11 requires Xcode/clang 5.0 or newer
# endif
#elif defined(__GNUG__)
# if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 8)
# error pybind11 requires gcc 4.8 or newer
# endif
#elif defined(_MSC_VER)
// Pybind hits various compiler bugs in 2015u2 and earlier, and also makes use of some stl features
// (e.g. std::negation) added in 2015u3:
# if _MSC_FULL_VER < 190024210
# error pybind11 requires MSVC 2015 update 3 or newer
# endif
#endif
#if !defined(PYBIND11_EXPORT)
# if defined(WIN32) || defined(_WIN32)
# define PYBIND11_EXPORT __declspec(dllexport)
# else
# define PYBIND11_EXPORT __attribute__ ((visibility("default")))
# endif
#endif
#if defined(_MSC_VER)
# define PYBIND11_NOINLINE __declspec(noinline)
#else
# define PYBIND11_NOINLINE __attribute__ ((noinline))
#endif
#if defined(PYBIND11_CPP14)
# define PYBIND11_DEPRECATED(reason) [[deprecated(reason)]]
#else
# define PYBIND11_DEPRECATED(reason) __attribute__((deprecated(reason)))
#endif
#define PYBIND11_VERSION_MAJOR 2
#define PYBIND11_VERSION_MINOR 3
#define PYBIND11_VERSION_PATCH 0
/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode
#if defined(_MSC_VER)
# if (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 4)
# define HAVE_ROUND 1
# endif
# pragma warning(push)
# pragma warning(disable: 4510 4610 4512 4005)
# if defined(_DEBUG)
# define PYBIND11_DEBUG_MARKER
# undef _DEBUG
# endif
#endif
#include <Python.h>
#include <frameobject.h>
#include <pythread.h>
#if defined(_WIN32) && (defined(min) || defined(max))
# error Macro clash with min and max -- define NOMINMAX when compiling your program on Windows
#endif
#if defined(isalnum)
# undef isalnum
# undef isalpha
# undef islower
# undef isspace
# undef isupper
# undef tolower
# undef toupper
#endif
#if defined(_MSC_VER)
# if defined(PYBIND11_DEBUG_MARKER)
# define _DEBUG
# undef PYBIND11_DEBUG_MARKER
# endif
# pragma warning(pop)
#endif
#include <cstddef>
#include <cstring>
#include <forward_list>
#include <vector>
#include <string>
#include <stdexcept>
#include <unordered_set>
#include <unordered_map>
#include <memory>
#include <typeindex>
#include <type_traits>
#if PY_MAJOR_VERSION >= 3 /// Compatibility macros for various Python versions
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyInstanceMethod_New(ptr)
#define PYBIND11_INSTANCE_METHOD_CHECK PyInstanceMethod_Check
#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyInstanceMethod_GET_FUNCTION
#define PYBIND11_BYTES_CHECK PyBytes_Check
#define PYBIND11_BYTES_FROM_STRING PyBytes_FromString
#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyBytes_FromStringAndSize
#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyBytes_AsStringAndSize
#define PYBIND11_BYTES_AS_STRING PyBytes_AsString
#define PYBIND11_BYTES_SIZE PyBytes_Size
#define PYBIND11_LONG_CHECK(o) PyLong_Check(o)
#define PYBIND11_LONG_AS_LONGLONG(o) PyLong_AsLongLong(o)
#define PYBIND11_LONG_FROM_SIGNED(o) PyLong_FromSsize_t((ssize_t) o)
#define PYBIND11_LONG_FROM_UNSIGNED(o) PyLong_FromSize_t((size_t) o)
#define PYBIND11_BYTES_NAME "bytes"
#define PYBIND11_STRING_NAME "str"
#define PYBIND11_SLICE_OBJECT PyObject
#define PYBIND11_FROM_STRING PyUnicode_FromString
#define PYBIND11_STR_TYPE ::pybind11::str
#define PYBIND11_BOOL_ATTR "__bool__"
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool)
#define PYBIND11_PLUGIN_IMPL(name) \
extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()
#else
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_)
#define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check
#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyMethod_GET_FUNCTION
#define PYBIND11_BYTES_CHECK PyString_Check
#define PYBIND11_BYTES_FROM_STRING PyString_FromString
#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyString_FromStringAndSize
#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyString_AsStringAndSize
#define PYBIND11_BYTES_AS_STRING PyString_AsString
#define PYBIND11_BYTES_SIZE PyString_Size
#define PYBIND11_LONG_CHECK(o) (PyInt_Check(o) || PyLong_Check(o))
#define PYBIND11_LONG_AS_LONGLONG(o) (PyInt_Check(o) ? (long long) PyLong_AsLong(o) : PyLong_AsLongLong(o))
#define PYBIND11_LONG_FROM_SIGNED(o) PyInt_FromSsize_t((ssize_t) o) // Returns long if needed.
#define PYBIND11_LONG_FROM_UNSIGNED(o) PyInt_FromSize_t((size_t) o) // Returns long if needed.
#define PYBIND11_BYTES_NAME "str"
#define PYBIND11_STRING_NAME "unicode"
#define PYBIND11_SLICE_OBJECT PySliceObject
#define PYBIND11_FROM_STRING PyString_FromString
#define PYBIND11_STR_TYPE ::pybind11::bytes
#define PYBIND11_BOOL_ATTR "__nonzero__"
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero)
#define PYBIND11_PLUGIN_IMPL(name) \
static PyObject *pybind11_init_wrapper(); \
extern "C" PYBIND11_EXPORT void init##name() { \
(void)pybind11_init_wrapper(); \
} \
PyObject *pybind11_init_wrapper()
#endif
#if PY_VERSION_HEX >= 0x03050000 && PY_VERSION_HEX < 0x03050200
extern "C" {
struct _Py_atomic_address { void *value; };
PyAPI_DATA(_Py_atomic_address) _PyThreadState_Current;
}
#endif
#define PYBIND11_TRY_NEXT_OVERLOAD ((PyObject *) 1) // special failure return code
#define PYBIND11_STRINGIFY(x) #x
#define PYBIND11_TOSTRING(x) PYBIND11_STRINGIFY(x)
#define PYBIND11_CONCAT(first, second) first##second
#define PYBIND11_CHECK_PYTHON_VERSION \
{ \
const char *compiled_ver = PYBIND11_TOSTRING(PY_MAJOR_VERSION) \
"." PYBIND11_TOSTRING(PY_MINOR_VERSION); \
const char *runtime_ver = Py_GetVersion(); \
size_t len = std::strlen(compiled_ver); \
if (std::strncmp(runtime_ver, compiled_ver, len) != 0 \
|| (runtime_ver[len] >= '0' && runtime_ver[len] <= '9')) { \
PyErr_Format(PyExc_ImportError, \
"Python version mismatch: module was compiled for Python %s, " \
"but the interpreter version is incompatible: %s.", \
compiled_ver, runtime_ver); \
return nullptr; \
} \
}
#define PYBIND11_CATCH_INIT_EXCEPTIONS \
catch (pybind11::error_already_set &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} catch (const std::exception &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} \
/** \rst
***Deprecated in favor of PYBIND11_MODULE***
This macro creates the entry point that will be invoked when the Python interpreter
imports a plugin library. Please create a `module` in the function body and return
the pointer to its underlying Python object at the end.
.. code-block:: cpp
PYBIND11_PLUGIN(example) {
pybind11::module m("example", "pybind11 example plugin");
/// Set up bindings here
return m.ptr();
}
\endrst */
#define PYBIND11_PLUGIN(name) \
PYBIND11_DEPRECATED("PYBIND11_PLUGIN is deprecated, use PYBIND11_MODULE") \
static PyObject *pybind11_init(); \
PYBIND11_PLUGIN_IMPL(name) { \
PYBIND11_CHECK_PYTHON_VERSION \
try { \
return pybind11_init(); \
} PYBIND11_CATCH_INIT_EXCEPTIONS \
} \
PyObject *pybind11_init()
/** \rst
This macro creates the entry point that will be invoked when the Python interpreter
imports an extension module. The module name is given as the fist argument and it
should not be in quotes. The second macro argument defines a variable of type
`py::module` which can be used to initialize the module.
.. code-block:: cpp
PYBIND11_MODULE(example, m) {
m.doc() = "pybind11 example module";
// Add bindings here
m.def("foo", []() {
return "Hello, World!";
});
}
\endrst */
#define PYBIND11_MODULE(name, variable) \
static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
PYBIND11_PLUGIN_IMPL(name) { \
PYBIND11_CHECK_PYTHON_VERSION \
auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
try { \
PYBIND11_CONCAT(pybind11_init_, name)(m); \
return m.ptr(); \
} PYBIND11_CATCH_INIT_EXCEPTIONS \
} \
void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
using ssize_t = Py_ssize_t;
using size_t = std::size_t;
/// Approach used to cast a previously unknown C++ instance into a Python object
enum class return_value_policy : uint8_t {
/** This is the default return value policy, which falls back to the policy
return_value_policy::take_ownership when the return value is a pointer.
Otherwise, it uses return_value::move or return_value::copy for rvalue
and lvalue references, respectively. See below for a description of what
all of these different policies do. */
automatic = 0,
/** As above, but use policy return_value_policy::reference when the return
value is a pointer. This is the default conversion policy for function
arguments when calling Python functions manually from C++ code (i.e. via
handle::operator()). You probably won't need to use this. */
automatic_reference,
/** Reference an existing object (i.e. do not create a new copy) and take
ownership. Python will call the destructor and delete operator when the
objects reference count reaches zero. Undefined behavior ensues when
the C++ side does the same.. */
take_ownership,
/** Create a new copy of the returned object, which will be owned by
Python. This policy is comparably safe because the lifetimes of the two
instances are decoupled. */
copy,
/** Use std::move to move the return value contents into a new instance
that will be owned by Python. This policy is comparably safe because the
lifetimes of the two instances (move source and destination) are
decoupled. */
move,
/** Reference an existing object, but do not take ownership. The C++ side
is responsible for managing the objects lifetime and deallocating it
when it is no longer used. Warning: undefined behavior will ensue when
the C++ side deletes an object that is still referenced and used by
Python. */
reference,
/** This policy only applies to methods and properties. It references the
object without taking ownership similar to the above
return_value_policy::reference policy. In contrast to that policy, the
function or propertys implicit this argument (called the parent) is
considered to be the the owner of the return value (the child).
pybind11 then couples the lifetime of the parent to the child via a
reference relationship that ensures that the parent cannot be garbage
collected while Python is still using the child. More advanced
variations of this scheme are also possible using combinations of
return_value_policy::reference and the keep_alive call policy */
reference_internal
};
NAMESPACE_BEGIN(detail)
inline static constexpr int log2(size_t n, int k = 0) { return (n <= 1) ? k : log2(n >> 1, k + 1); }
// Returns the size as a multiple of sizeof(void *), rounded up.
inline static constexpr size_t size_in_ptrs(size_t s) { return 1 + ((s - 1) >> log2(sizeof(void *))); }
/**
* The space to allocate for simple layout instance holders (see below) in multiple of the size of
* a pointer (e.g. 2 means 16 bytes on 64-bit architectures). The default is the minimum required
* to holder either a std::unique_ptr or std::shared_ptr (which is almost always
* sizeof(std::shared_ptr<T>)).
*/
constexpr size_t instance_simple_holder_in_ptrs() {
static_assert(sizeof(std::shared_ptr<int>) >= sizeof(std::unique_ptr<int>),
"pybind assumes std::shared_ptrs are at least as big as std::unique_ptrs");
return size_in_ptrs(sizeof(std::shared_ptr<int>));
}
// Forward declarations
struct type_info;
struct value_and_holder;
struct nonsimple_values_and_holders {
void **values_and_holders;
uint8_t *status;
};
/// The 'instance' type which needs to be standard layout (need to be able to use 'offsetof')
struct instance {
PyObject_HEAD
/// Storage for pointers and holder; see simple_layout, below, for a description
union {
void *simple_value_holder[1 + instance_simple_holder_in_ptrs()];
nonsimple_values_and_holders nonsimple;
};
/// Weak references
PyObject *weakrefs;
/// If true, the pointer is owned which means we're free to manage it with a holder.
bool owned : 1;
/**
* An instance has two possible value/holder layouts.
*
* Simple layout (when this flag is true), means the `simple_value_holder` is set with a pointer
* and the holder object governing that pointer, i.e. [val1*][holder]. This layout is applied
* whenever there is no python-side multiple inheritance of bound C++ types *and* the type's
* holder will fit in the default space (which is large enough to hold either a std::unique_ptr
* or std::shared_ptr).
*
* Non-simple layout applies when using custom holders that require more space than `shared_ptr`
* (which is typically the size of two pointers), or when multiple inheritance is used on the
* python side. Non-simple layout allocates the required amount of memory to have multiple
* bound C++ classes as parents. Under this layout, `nonsimple.values_and_holders` is set to a
* pointer to allocated space of the required space to hold a sequence of value pointers and
* holders followed `status`, a set of bit flags (1 byte each), i.e.
* [val1*][holder1][val2*][holder2]...[bb...] where each [block] is rounded up to a multiple of
* `sizeof(void *)`. `nonsimple.status` is, for convenience, a pointer to the
* beginning of the [bb...] block (but not independently allocated).
*
* Status bits indicate whether the associated holder is constructed (&
* status_holder_constructed) and whether the value pointer is registered (&
* status_instance_registered) in `registered_instances`.
*/
bool simple_layout : 1;
/// For simple layout, tracks whether the holder has been constructed
bool simple_holder_constructed : 1;
/// For simple layout, tracks whether the instance is registered in `registered_instances`
bool simple_instance_registered : 1;
/// If true, get_internals().patients has an entry for this object
bool has_patients : 1;
/// Initializes all of the above type/values/holders data (but not the instance values themselves)
void allocate_layout();
/// Destroys/deallocates all of the above
void deallocate_layout();
/// Returns the value_and_holder wrapper for the given type (or the first, if `find_type`
/// omitted). Returns a default-constructed (with `.inst = nullptr`) object on failure if
/// `throw_if_missing` is false.
value_and_holder get_value_and_holder(const type_info *find_type = nullptr, bool throw_if_missing = true);
/// Bit values for the non-simple status flags
static constexpr uint8_t status_holder_constructed = 1;
static constexpr uint8_t status_instance_registered = 2;
};
static_assert(std::is_standard_layout<instance>::value, "Internal error: `pybind11::detail::instance` is not standard layout!");
/// from __cpp_future__ import (convenient aliases from C++14/17)
#if defined(PYBIND11_CPP14) && (!defined(_MSC_VER) || _MSC_VER >= 1910)
using std::enable_if_t;
using std::conditional_t;
using std::remove_cv_t;
using std::remove_reference_t;
#else
template <bool B, typename T = void> using enable_if_t = typename std::enable_if<B, T>::type;
template <bool B, typename T, typename F> using conditional_t = typename std::conditional<B, T, F>::type;
template <typename T> using remove_cv_t = typename std::remove_cv<T>::type;
template <typename T> using remove_reference_t = typename std::remove_reference<T>::type;
#endif
/// Index sequences
#if defined(PYBIND11_CPP14)
using std::index_sequence;
using std::make_index_sequence;
#else
template<size_t ...> struct index_sequence { };
template<size_t N, size_t ...S> struct make_index_sequence_impl : make_index_sequence_impl <N - 1, N - 1, S...> { };
template<size_t ...S> struct make_index_sequence_impl <0, S...> { typedef index_sequence<S...> type; };
template<size_t N> using make_index_sequence = typename make_index_sequence_impl<N>::type;
#endif
/// Make an index sequence of the indices of true arguments
template <typename ISeq, size_t, bool...> struct select_indices_impl { using type = ISeq; };
template <size_t... IPrev, size_t I, bool B, bool... Bs> struct select_indices_impl<index_sequence<IPrev...>, I, B, Bs...>
: select_indices_impl<conditional_t<B, index_sequence<IPrev..., I>, index_sequence<IPrev...>>, I + 1, Bs...> {};
template <bool... Bs> using select_indices = typename select_indices_impl<index_sequence<>, 0, Bs...>::type;
/// Backports of std::bool_constant and std::negation to accommodate older compilers
template <bool B> using bool_constant = std::integral_constant<bool, B>;
template <typename T> struct negation : bool_constant<!T::value> { };
template <typename...> struct void_t_impl { using type = void; };
template <typename... Ts> using void_t = typename void_t_impl<Ts...>::type;
/// Compile-time all/any/none of that check the boolean value of all template types
#if defined(__cpp_fold_expressions) && !(defined(_MSC_VER) && (_MSC_VER < 1916))
template <class... Ts> using all_of = bool_constant<(Ts::value && ...)>;
template <class... Ts> using any_of = bool_constant<(Ts::value || ...)>;
#elif !defined(_MSC_VER)
template <bool...> struct bools {};
template <class... Ts> using all_of = std::is_same<
bools<Ts::value..., true>,
bools<true, Ts::value...>>;
template <class... Ts> using any_of = negation<all_of<negation<Ts>...>>;
#else
// MSVC has trouble with the above, but supports std::conjunction, which we can use instead (albeit
// at a slight loss of compilation efficiency).
template <class... Ts> using all_of = std::conjunction<Ts...>;
template <class... Ts> using any_of = std::disjunction<Ts...>;
#endif
template <class... Ts> using none_of = negation<any_of<Ts...>>;
template <class T, template<class> class... Predicates> using satisfies_all_of = all_of<Predicates<T>...>;
template <class T, template<class> class... Predicates> using satisfies_any_of = any_of<Predicates<T>...>;
template <class T, template<class> class... Predicates> using satisfies_none_of = none_of<Predicates<T>...>;
/// Strip the class from a method type
template <typename T> struct remove_class { };
template <typename C, typename R, typename... A> struct remove_class<R (C::*)(A...)> { typedef R type(A...); };
template <typename C, typename R, typename... A> struct remove_class<R (C::*)(A...) const> { typedef R type(A...); };
/// Helper template to strip away type modifiers
template <typename T> struct intrinsic_type { typedef T type; };
template <typename T> struct intrinsic_type<const T> { typedef typename intrinsic_type<T>::type type; };
template <typename T> struct intrinsic_type<T*> { typedef typename intrinsic_type<T>::type type; };
template <typename T> struct intrinsic_type<T&> { typedef typename intrinsic_type<T>::type type; };
template <typename T> struct intrinsic_type<T&&> { typedef typename intrinsic_type<T>::type type; };
template <typename T, size_t N> struct intrinsic_type<const T[N]> { typedef typename intrinsic_type<T>::type type; };
template <typename T, size_t N> struct intrinsic_type<T[N]> { typedef typename intrinsic_type<T>::type type; };
template <typename T> using intrinsic_t = typename intrinsic_type<T>::type;
/// Helper type to replace 'void' in some expressions
struct void_type { };
/// Helper template which holds a list of types
template <typename...> struct type_list { };
/// Compile-time integer sum
#ifdef __cpp_fold_expressions
template <typename... Ts> constexpr size_t constexpr_sum(Ts... ns) { return (0 + ... + size_t{ns}); }
#else
constexpr size_t constexpr_sum() { return 0; }
template <typename T, typename... Ts>
constexpr size_t constexpr_sum(T n, Ts... ns) { return size_t{n} + constexpr_sum(ns...); }
#endif
NAMESPACE_BEGIN(constexpr_impl)
/// Implementation details for constexpr functions
constexpr int first(int i) { return i; }
template <typename T, typename... Ts>
constexpr int first(int i, T v, Ts... vs) { return v ? i : first(i + 1, vs...); }
constexpr int last(int /*i*/, int result) { return result; }
template <typename T, typename... Ts>
constexpr int last(int i, int result, T v, Ts... vs) { return last(i + 1, v ? i : result, vs...); }
NAMESPACE_END(constexpr_impl)
/// Return the index of the first type in Ts which satisfies Predicate<T>. Returns sizeof...(Ts) if
/// none match.
template <template<typename> class Predicate, typename... Ts>
constexpr int constexpr_first() { return constexpr_impl::first(0, Predicate<Ts>::value...); }
/// Return the index of the last type in Ts which satisfies Predicate<T>, or -1 if none match.
template <template<typename> class Predicate, typename... Ts>
constexpr int constexpr_last() { return constexpr_impl::last(0, -1, Predicate<Ts>::value...); }
/// Return the Nth element from the parameter pack
template <size_t N, typename T, typename... Ts>
struct pack_element { using type = typename pack_element<N - 1, Ts...>::type; };
template <typename T, typename... Ts>
struct pack_element<0, T, Ts...> { using type = T; };
/// Return the one and only type which matches the predicate, or Default if none match.
/// If more than one type matches the predicate, fail at compile-time.
template <template<typename> class Predicate, typename Default, typename... Ts>
struct exactly_one {
static constexpr auto found = constexpr_sum(Predicate<Ts>::value...);
static_assert(found <= 1, "Found more than one type matching the predicate");
static constexpr auto index = found ? constexpr_first<Predicate, Ts...>() : 0;
using type = conditional_t<found, typename pack_element<index, Ts...>::type, Default>;
};
template <template<typename> class P, typename Default>
struct exactly_one<P, Default> { using type = Default; };
template <template<typename> class Predicate, typename Default, typename... Ts>
using exactly_one_t = typename exactly_one<Predicate, Default, Ts...>::type;
/// Defer the evaluation of type T until types Us are instantiated
template <typename T, typename... /*Us*/> struct deferred_type { using type = T; };
template <typename T, typename... Us> using deferred_t = typename deferred_type<T, Us...>::type;
/// Like is_base_of, but requires a strict base (i.e. `is_strict_base_of<T, T>::value == false`,
/// unlike `std::is_base_of`)
template <typename Base, typename Derived> using is_strict_base_of = bool_constant<
std::is_base_of<Base, Derived>::value && !std::is_same<Base, Derived>::value>;
/// Like is_base_of, but also requires that the base type is accessible (i.e. that a Derived pointer
/// can be converted to a Base pointer)
template <typename Base, typename Derived> using is_accessible_base_of = bool_constant<
std::is_base_of<Base, Derived>::value && std::is_convertible<Derived *, Base *>::value>;
template <template<typename...> class Base>
struct is_template_base_of_impl {
template <typename... Us> static std::true_type check(Base<Us...> *);
static std::false_type check(...);
};
/// Check if a template is the base of a type. For example:
/// `is_template_base_of<Base, T>` is true if `struct T : Base<U> {}` where U can be anything
template <template<typename...> class Base, typename T>
#if !defined(_MSC_VER)
using is_template_base_of = decltype(is_template_base_of_impl<Base>::check((intrinsic_t<T>*)nullptr));
#else // MSVC2015 has trouble with decltype in template aliases
struct is_template_base_of : decltype(is_template_base_of_impl<Base>::check((intrinsic_t<T>*)nullptr)) { };
#endif
/// Check if T is an instantiation of the template `Class`. For example:
/// `is_instantiation<shared_ptr, T>` is true if `T == shared_ptr<U>` where U can be anything.
template <template<typename...> class Class, typename T>
struct is_instantiation : std::false_type { };
template <template<typename...> class Class, typename... Us>
struct is_instantiation<Class, Class<Us...>> : std::true_type { };
/// Check if T is std::shared_ptr<U> where U can be anything
template <typename T> using is_shared_ptr = is_instantiation<std::shared_ptr, T>;
/// Check if T looks like an input iterator
template <typename T, typename = void> struct is_input_iterator : std::false_type {};
template <typename T>
struct is_input_iterator<T, void_t<decltype(*std::declval<T &>()), decltype(++std::declval<T &>())>>
: std::true_type {};
template <typename T> using is_function_pointer = bool_constant<
std::is_pointer<T>::value && std::is_function<typename std::remove_pointer<T>::type>::value>;
template <typename F> struct strip_function_object {
using type = typename remove_class<decltype(&F::operator())>::type;
};
// Extracts the function signature from a function, function pointer or lambda.
template <typename Function, typename F = remove_reference_t<Function>>
using function_signature_t = conditional_t<
std::is_function<F>::value,
F,
typename conditional_t<
std::is_pointer<F>::value || std::is_member_pointer<F>::value,
std::remove_pointer<F>,
strip_function_object<F>
>::type
>;
/// Returns true if the type looks like a lambda: that is, isn't a function, pointer or member
/// pointer. Note that this can catch all sorts of other things, too; this is intended to be used
/// in a place where passing a lambda makes sense.
template <typename T> using is_lambda = satisfies_none_of<remove_reference_t<T>,
std::is_function, std::is_pointer, std::is_member_pointer>;
/// Ignore that a variable is unused in compiler warnings
inline void ignore_unused(const int *) { }
/// Apply a function over each element of a parameter pack
#ifdef __cpp_fold_expressions
#define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) (((PATTERN), void()), ...)
#else
using expand_side_effects = bool[];
#define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) pybind11::detail::expand_side_effects{ ((PATTERN), void(), false)..., false }
#endif
NAMESPACE_END(detail)
/// C++ bindings of builtin Python exceptions
class builtin_exception : public std::runtime_error {
public:
using std::runtime_error::runtime_error;
/// Set the error using the Python C API
virtual void set_error() const = 0;
};
#define PYBIND11_RUNTIME_EXCEPTION(name, type) \
class name : public builtin_exception { public: \
using builtin_exception::builtin_exception; \
name() : name("") { } \
void set_error() const override { PyErr_SetString(type, what()); } \
};
PYBIND11_RUNTIME_EXCEPTION(stop_iteration, PyExc_StopIteration)
PYBIND11_RUNTIME_EXCEPTION(index_error, PyExc_IndexError)
PYBIND11_RUNTIME_EXCEPTION(key_error, PyExc_KeyError)
PYBIND11_RUNTIME_EXCEPTION(value_error, PyExc_ValueError)
PYBIND11_RUNTIME_EXCEPTION(type_error, PyExc_TypeError)
PYBIND11_RUNTIME_EXCEPTION(cast_error, PyExc_RuntimeError) /// Thrown when pybind11::cast or handle::call fail due to a type casting error
PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used internally
[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const char *reason) { throw std::runtime_error(reason); }
[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const std::string &reason) { throw std::runtime_error(reason); }
template <typename T, typename SFINAE = void> struct format_descriptor { };
NAMESPACE_BEGIN(detail)
// Returns the index of the given type in the type char array below, and in the list in numpy.h
// The order here is: bool; 8 ints ((signed,unsigned)x(8,16,32,64)bits); float,double,long double;
// complex float,double,long double. Note that the long double types only participate when long
// double is actually longer than double (it isn't under MSVC).
// NB: not only the string below but also complex.h and numpy.h rely on this order.
template <typename T, typename SFINAE = void> struct is_fmt_numeric { static constexpr bool value = false; };
template <typename T> struct is_fmt_numeric<T, enable_if_t<std::is_arithmetic<T>::value>> {
static constexpr bool value = true;
static constexpr int index = std::is_same<T, bool>::value ? 0 : 1 + (
std::is_integral<T>::value ? detail::log2(sizeof(T))*2 + std::is_unsigned<T>::value : 8 + (
std::is_same<T, double>::value ? 1 : std::is_same<T, long double>::value ? 2 : 0));
};
NAMESPACE_END(detail)
template <typename T> struct format_descriptor<T, detail::enable_if_t<std::is_arithmetic<T>::value>> {
static constexpr const char c = "?bBhHiIqQfdg"[detail::is_fmt_numeric<T>::index];
static constexpr const char value[2] = { c, '\0' };
static std::string format() { return std::string(1, c); }
};
#if !defined(PYBIND11_CPP17)
template <typename T> constexpr const char format_descriptor<
T, detail::enable_if_t<std::is_arithmetic<T>::value>>::value[2];
#endif
/// RAII wrapper that temporarily clears any Python error state
struct error_scope {
PyObject *type, *value, *trace;
error_scope() { PyErr_Fetch(&type, &value, &trace); }
~error_scope() { PyErr_Restore(type, value, trace); }
};
/// Dummy destructor wrapper that can be used to expose classes with a private destructor
struct nodelete { template <typename T> void operator()(T*) { } };
// overload_cast requires variable templates: C++14
#if defined(PYBIND11_CPP14)
#define PYBIND11_OVERLOAD_CAST 1
NAMESPACE_BEGIN(detail)
template <typename... Args>
struct overload_cast_impl {
constexpr overload_cast_impl() {} // MSVC 2015 needs this
template <typename Return>
constexpr auto operator()(Return (*pf)(Args...)) const noexcept
-> decltype(pf) { return pf; }
template <typename Return, typename Class>
constexpr auto operator()(Return (Class::*pmf)(Args...), std::false_type = {}) const noexcept
-> decltype(pmf) { return pmf; }
template <typename Return, typename Class>
constexpr auto operator()(Return (Class::*pmf)(Args...) const, std::true_type) const noexcept
-> decltype(pmf) { return pmf; }
};
NAMESPACE_END(detail)
/// Syntax sugar for resolving overloaded function pointers:
/// - regular: static_cast<Return (Class::*)(Arg0, Arg1, Arg2)>(&Class::func)
/// - sweet: overload_cast<Arg0, Arg1, Arg2>(&Class::func)
template <typename... Args>
static constexpr detail::overload_cast_impl<Args...> overload_cast = {};
// MSVC 2015 only accepts this particular initialization syntax for this variable template.
/// Const member function selector for overload_cast
/// - regular: static_cast<Return (Class::*)(Arg) const>(&Class::func)
/// - sweet: overload_cast<Arg>(&Class::func, const_)
static constexpr auto const_ = std::true_type{};
#else // no overload_cast: providing something that static_assert-fails:
template <typename... Args> struct overload_cast {
static_assert(detail::deferred_t<std::false_type, Args...>::value,
"pybind11::overload_cast<...> requires compiling in C++14 mode");
};
#endif // overload_cast
NAMESPACE_BEGIN(detail)
// Adaptor for converting arbitrary container arguments into a vector; implicitly convertible from
// any standard container (or C-style array) supporting std::begin/std::end, any singleton
// arithmetic type (if T is arithmetic), or explicitly constructible from an iterator pair.
template <typename T>
class any_container {
std::vector<T> v;
public:
any_container() = default;
// Can construct from a pair of iterators
template <typename It, typename = enable_if_t<is_input_iterator<It>::value>>
any_container(It first, It last) : v(first, last) { }
// Implicit conversion constructor from any arbitrary container type with values convertible to T
template <typename Container, typename = enable_if_t<std::is_convertible<decltype(*std::begin(std::declval<const Container &>())), T>::value>>
any_container(const Container &c) : any_container(std::begin(c), std::end(c)) { }
// initializer_list's aren't deducible, so don't get matched by the above template; we need this
// to explicitly allow implicit conversion from one:
template <typename TIn, typename = enable_if_t<std::is_convertible<TIn, T>::value>>
any_container(const std::initializer_list<TIn> &c) : any_container(c.begin(), c.end()) { }
// Avoid copying if given an rvalue vector of the correct type.
any_container(std::vector<T> &&v) : v(std::move(v)) { }
// Moves the vector out of an rvalue any_container
operator std::vector<T> &&() && { return std::move(v); }
// Dereferencing obtains a reference to the underlying vector
std::vector<T> &operator*() { return v; }
const std::vector<T> &operator*() const { return v; }
// -> lets you call methods on the underlying vector
std::vector<T> *operator->() { return &v; }
const std::vector<T> *operator->() const { return &v; }
};
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,100 +0,0 @@
/*
pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "common.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
#if !defined(_MSC_VER)
# define PYBIND11_DESCR_CONSTEXPR static constexpr
#else
# define PYBIND11_DESCR_CONSTEXPR const
#endif
/* Concatenate type signatures at compile time */
template <size_t N, typename... Ts>
struct descr {
char text[N + 1];
constexpr descr() : text{'\0'} { }
constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence<N>()) { }
template <size_t... Is>
constexpr descr(char const (&s)[N+1], index_sequence<Is...>) : text{s[Is]..., '\0'} { }
template <typename... Chars>
constexpr descr(char c, Chars... cs) : text{c, static_cast<char>(cs)..., '\0'} { }
static constexpr std::array<const std::type_info *, sizeof...(Ts) + 1> types() {
return {{&typeid(Ts)..., nullptr}};
}
};
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2, size_t... Is1, size_t... Is2>
constexpr descr<N1 + N2, Ts1..., Ts2...> plus_impl(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b,
index_sequence<Is1...>, index_sequence<Is2...>) {
return {a.text[Is1]..., b.text[Is2]...};
}
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2>
constexpr descr<N1 + N2, Ts1..., Ts2...> operator+(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b) {
return plus_impl(a, b, make_index_sequence<N1>(), make_index_sequence<N2>());
}
template <size_t N>
constexpr descr<N - 1> _(char const(&text)[N]) { return descr<N - 1>(text); }
constexpr descr<0> _(char const(&)[1]) { return {}; }
template <size_t Rem, size_t... Digits> struct int_to_str : int_to_str<Rem/10, Rem%10, Digits...> { };
template <size_t...Digits> struct int_to_str<0, Digits...> {
static constexpr auto digits = descr<sizeof...(Digits)>(('0' + Digits)...);
};
// Ternary description (like std::conditional)
template <bool B, size_t N1, size_t N2>
constexpr enable_if_t<B, descr<N1 - 1>> _(char const(&text1)[N1], char const(&)[N2]) {
return _(text1);
}
template <bool B, size_t N1, size_t N2>
constexpr enable_if_t<!B, descr<N2 - 1>> _(char const(&)[N1], char const(&text2)[N2]) {
return _(text2);
}
template <bool B, typename T1, typename T2>
constexpr enable_if_t<B, T1> _(const T1 &d, const T2 &) { return d; }
template <bool B, typename T1, typename T2>
constexpr enable_if_t<!B, T2> _(const T1 &, const T2 &d) { return d; }
template <size_t Size> auto constexpr _() -> decltype(int_to_str<Size / 10, Size % 10>::digits) {
return int_to_str<Size / 10, Size % 10>::digits;
}
template <typename Type> constexpr descr<1, Type> _() { return {'%'}; }
constexpr descr<0> concat() { return {}; }
template <size_t N, typename... Ts>
constexpr descr<N, Ts...> concat(const descr<N, Ts...> &descr) { return descr; }
template <size_t N, typename... Ts, typename... Args>
constexpr auto concat(const descr<N, Ts...> &d, const Args &...args)
-> decltype(std::declval<descr<N + 2, Ts...>>() + concat(args...)) {
return d + _(", ") + concat(args...);
}
template <size_t N, typename... Ts>
constexpr descr<N + 2, Ts...> type_descr(const descr<N, Ts...> &descr) {
return _("{") + descr + _("}");
}
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,335 +0,0 @@
/*
pybind11/detail/init.h: init factory function implementation and support code.
Copyright (c) 2017 Jason Rhinelander <jason@imaginary.ca>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "class.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
template <>
class type_caster<value_and_holder> {
public:
bool load(handle h, bool) {
value = reinterpret_cast<value_and_holder *>(h.ptr());
return true;
}
template <typename> using cast_op_type = value_and_holder &;
operator value_and_holder &() { return *value; }
static constexpr auto name = _<value_and_holder>();
private:
value_and_holder *value = nullptr;
};
NAMESPACE_BEGIN(initimpl)
inline void no_nullptr(void *ptr) {
if (!ptr) throw type_error("pybind11::init(): factory function returned nullptr");
}
// Implementing functions for all forms of py::init<...> and py::init(...)
template <typename Class> using Cpp = typename Class::type;
template <typename Class> using Alias = typename Class::type_alias;
template <typename Class> using Holder = typename Class::holder_type;
template <typename Class> using is_alias_constructible = std::is_constructible<Alias<Class>, Cpp<Class> &&>;
// Takes a Cpp pointer and returns true if it actually is a polymorphic Alias instance.
template <typename Class, enable_if_t<Class::has_alias, int> = 0>
bool is_alias(Cpp<Class> *ptr) {
return dynamic_cast<Alias<Class> *>(ptr) != nullptr;
}
// Failing fallback version of the above for a no-alias class (always returns false)
template <typename /*Class*/>
constexpr bool is_alias(void *) { return false; }
// Constructs and returns a new object; if the given arguments don't map to a constructor, we fall
// back to brace aggregate initiailization so that for aggregate initialization can be used with
// py::init, e.g. `py::init<int, int>` to initialize a `struct T { int a; int b; }`. For
// non-aggregate types, we need to use an ordinary T(...) constructor (invoking as `T{...}` usually
// works, but will not do the expected thing when `T` has an `initializer_list<T>` constructor).
template <typename Class, typename... Args, detail::enable_if_t<std::is_constructible<Class, Args...>::value, int> = 0>
inline Class *construct_or_initialize(Args &&...args) { return new Class(std::forward<Args>(args)...); }
template <typename Class, typename... Args, detail::enable_if_t<!std::is_constructible<Class, Args...>::value, int> = 0>
inline Class *construct_or_initialize(Args &&...args) { return new Class{std::forward<Args>(args)...}; }
// Attempts to constructs an alias using a `Alias(Cpp &&)` constructor. This allows types with
// an alias to provide only a single Cpp factory function as long as the Alias can be
// constructed from an rvalue reference of the base Cpp type. This means that Alias classes
// can, when appropriate, simply define a `Alias(Cpp &&)` constructor rather than needing to
// inherit all the base class constructors.
template <typename Class>
void construct_alias_from_cpp(std::true_type /*is_alias_constructible*/,
value_and_holder &v_h, Cpp<Class> &&base) {
v_h.value_ptr() = new Alias<Class>(std::move(base));
}
template <typename Class>
[[noreturn]] void construct_alias_from_cpp(std::false_type /*!is_alias_constructible*/,
value_and_holder &, Cpp<Class> &&) {
throw type_error("pybind11::init(): unable to convert returned instance to required "
"alias class: no `Alias<Class>(Class &&)` constructor available");
}
// Error-generating fallback for factories that don't match one of the below construction
// mechanisms.
template <typename Class>
void construct(...) {
static_assert(!std::is_same<Class, Class>::value /* always false */,
"pybind11::init(): init function must return a compatible pointer, "
"holder, or value");
}
// Pointer return v1: the factory function returns a class pointer for a registered class.
// If we don't need an alias (because this class doesn't have one, or because the final type is
// inherited on the Python side) we can simply take over ownership. Otherwise we need to try to
// construct an Alias from the returned base instance.
template <typename Class>
void construct(value_and_holder &v_h, Cpp<Class> *ptr, bool need_alias) {
no_nullptr(ptr);
if (Class::has_alias && need_alias && !is_alias<Class>(ptr)) {
// We're going to try to construct an alias by moving the cpp type. Whether or not
// that succeeds, we still need to destroy the original cpp pointer (either the
// moved away leftover, if the alias construction works, or the value itself if we
// throw an error), but we can't just call `delete ptr`: it might have a special
// deleter, or might be shared_from_this. So we construct a holder around it as if
// it was a normal instance, then steal the holder away into a local variable; thus
// the holder and destruction happens when we leave the C++ scope, and the holder
// class gets to handle the destruction however it likes.
v_h.value_ptr() = ptr;
v_h.set_instance_registered(true); // To prevent init_instance from registering it
v_h.type->init_instance(v_h.inst, nullptr); // Set up the holder
Holder<Class> temp_holder(std::move(v_h.holder<Holder<Class>>())); // Steal the holder
v_h.type->dealloc(v_h); // Destroys the moved-out holder remains, resets value ptr to null
v_h.set_instance_registered(false);
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(*ptr));
} else {
// Otherwise the type isn't inherited, so we don't need an Alias
v_h.value_ptr() = ptr;
}
}
// Pointer return v2: a factory that always returns an alias instance ptr. We simply take over
// ownership of the pointer.
template <typename Class, enable_if_t<Class::has_alias, int> = 0>
void construct(value_and_holder &v_h, Alias<Class> *alias_ptr, bool) {
no_nullptr(alias_ptr);
v_h.value_ptr() = static_cast<Cpp<Class> *>(alias_ptr);
}
// Holder return: copy its pointer, and move or copy the returned holder into the new instance's
// holder. This also handles types like std::shared_ptr<T> and std::unique_ptr<T> where T is a
// derived type (through those holder's implicit conversion from derived class holder constructors).
template <typename Class>
void construct(value_and_holder &v_h, Holder<Class> holder, bool need_alias) {
auto *ptr = holder_helper<Holder<Class>>::get(holder);
// If we need an alias, check that the held pointer is actually an alias instance
if (Class::has_alias && need_alias && !is_alias<Class>(ptr))
throw type_error("pybind11::init(): construction failed: returned holder-wrapped instance "
"is not an alias instance");
v_h.value_ptr() = ptr;
v_h.type->init_instance(v_h.inst, &holder);
}
// return-by-value version 1: returning a cpp class by value. If the class has an alias and an
// alias is required the alias must have an `Alias(Cpp &&)` constructor so that we can construct
// the alias from the base when needed (i.e. because of Python-side inheritance). When we don't
// need it, we simply move-construct the cpp value into a new instance.
template <typename Class>
void construct(value_and_holder &v_h, Cpp<Class> &&result, bool need_alias) {
static_assert(std::is_move_constructible<Cpp<Class>>::value,
"pybind11::init() return-by-value factory function requires a movable class");
if (Class::has_alias && need_alias)
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(result));
else
v_h.value_ptr() = new Cpp<Class>(std::move(result));
}
// return-by-value version 2: returning a value of the alias type itself. We move-construct an
// Alias instance (even if no the python-side inheritance is involved). The is intended for
// cases where Alias initialization is always desired.
template <typename Class>
void construct(value_and_holder &v_h, Alias<Class> &&result, bool) {
static_assert(std::is_move_constructible<Alias<Class>>::value,
"pybind11::init() return-by-alias-value factory function requires a movable alias class");
v_h.value_ptr() = new Alias<Class>(std::move(result));
}
// Implementing class for py::init<...>()
template <typename... Args>
struct constructor {
template <typename Class, typename... Extra, enable_if_t<!Class::has_alias, int> = 0>
static void execute(Class &cl, const Extra&... extra) {
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
}, is_new_style_constructor(), extra...);
}
template <typename Class, typename... Extra,
enable_if_t<Class::has_alias &&
std::is_constructible<Cpp<Class>, Args...>::value, int> = 0>
static void execute(Class &cl, const Extra&... extra) {
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
if (Py_TYPE(v_h.inst) == v_h.type->type)
v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
else
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
}, is_new_style_constructor(), extra...);
}
template <typename Class, typename... Extra,
enable_if_t<Class::has_alias &&
!std::is_constructible<Cpp<Class>, Args...>::value, int> = 0>
static void execute(Class &cl, const Extra&... extra) {
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
}, is_new_style_constructor(), extra...);
}
};
// Implementing class for py::init_alias<...>()
template <typename... Args> struct alias_constructor {
template <typename Class, typename... Extra,
enable_if_t<Class::has_alias && std::is_constructible<Alias<Class>, Args...>::value, int> = 0>
static void execute(Class &cl, const Extra&... extra) {
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
}, is_new_style_constructor(), extra...);
}
};
// Implementation class for py::init(Func) and py::init(Func, AliasFunc)
template <typename CFunc, typename AFunc = void_type (*)(),
typename = function_signature_t<CFunc>, typename = function_signature_t<AFunc>>
struct factory;
// Specialization for py::init(Func)
template <typename Func, typename Return, typename... Args>
struct factory<Func, void_type (*)(), Return(Args...)> {
remove_reference_t<Func> class_factory;
factory(Func &&f) : class_factory(std::forward<Func>(f)) { }
// The given class either has no alias or has no separate alias factory;
// this always constructs the class itself. If the class is registered with an alias
// type and an alias instance is needed (i.e. because the final type is a Python class
// inheriting from the C++ type) the returned value needs to either already be an alias
// instance, or the alias needs to be constructible from a `Class &&` argument.
template <typename Class, typename... Extra>
void execute(Class &cl, const Extra &...extra) && {
#if defined(PYBIND11_CPP14)
cl.def("__init__", [func = std::move(class_factory)]
#else
auto &func = class_factory;
cl.def("__init__", [func]
#endif
(value_and_holder &v_h, Args... args) {
construct<Class>(v_h, func(std::forward<Args>(args)...),
Py_TYPE(v_h.inst) != v_h.type->type);
}, is_new_style_constructor(), extra...);
}
};
// Specialization for py::init(Func, AliasFunc)
template <typename CFunc, typename AFunc,
typename CReturn, typename... CArgs, typename AReturn, typename... AArgs>
struct factory<CFunc, AFunc, CReturn(CArgs...), AReturn(AArgs...)> {
static_assert(sizeof...(CArgs) == sizeof...(AArgs),
"pybind11::init(class_factory, alias_factory): class and alias factories "
"must have identical argument signatures");
static_assert(all_of<std::is_same<CArgs, AArgs>...>::value,
"pybind11::init(class_factory, alias_factory): class and alias factories "
"must have identical argument signatures");
remove_reference_t<CFunc> class_factory;
remove_reference_t<AFunc> alias_factory;
factory(CFunc &&c, AFunc &&a)
: class_factory(std::forward<CFunc>(c)), alias_factory(std::forward<AFunc>(a)) { }
// The class factory is called when the `self` type passed to `__init__` is the direct
// class (i.e. not inherited), the alias factory when `self` is a Python-side subtype.
template <typename Class, typename... Extra>
void execute(Class &cl, const Extra&... extra) && {
static_assert(Class::has_alias, "The two-argument version of `py::init()` can "
"only be used if the class has an alias");
#if defined(PYBIND11_CPP14)
cl.def("__init__", [class_func = std::move(class_factory), alias_func = std::move(alias_factory)]
#else
auto &class_func = class_factory;
auto &alias_func = alias_factory;
cl.def("__init__", [class_func, alias_func]
#endif
(value_and_holder &v_h, CArgs... args) {
if (Py_TYPE(v_h.inst) == v_h.type->type)
// If the instance type equals the registered type we don't have inheritance, so
// don't need the alias and can construct using the class function:
construct<Class>(v_h, class_func(std::forward<CArgs>(args)...), false);
else
construct<Class>(v_h, alias_func(std::forward<CArgs>(args)...), true);
}, is_new_style_constructor(), extra...);
}
};
/// Set just the C++ state. Same as `__init__`.
template <typename Class, typename T>
void setstate(value_and_holder &v_h, T &&result, bool need_alias) {
construct<Class>(v_h, std::forward<T>(result), need_alias);
}
/// Set both the C++ and Python states
template <typename Class, typename T, typename O,
enable_if_t<std::is_convertible<O, handle>::value, int> = 0>
void setstate(value_and_holder &v_h, std::pair<T, O> &&result, bool need_alias) {
construct<Class>(v_h, std::move(result.first), need_alias);
setattr((PyObject *) v_h.inst, "__dict__", result.second);
}
/// Implementation for py::pickle(GetState, SetState)
template <typename Get, typename Set,
typename = function_signature_t<Get>, typename = function_signature_t<Set>>
struct pickle_factory;
template <typename Get, typename Set,
typename RetState, typename Self, typename NewInstance, typename ArgState>
struct pickle_factory<Get, Set, RetState(Self), NewInstance(ArgState)> {
static_assert(std::is_same<intrinsic_t<RetState>, intrinsic_t<ArgState>>::value,
"The type returned by `__getstate__` must be the same "
"as the argument accepted by `__setstate__`");
remove_reference_t<Get> get;
remove_reference_t<Set> set;
pickle_factory(Get get, Set set)
: get(std::forward<Get>(get)), set(std::forward<Set>(set)) { }
template <typename Class, typename... Extra>
void execute(Class &cl, const Extra &...extra) && {
cl.def("__getstate__", std::move(get));
#if defined(PYBIND11_CPP14)
cl.def("__setstate__", [func = std::move(set)]
#else
auto &func = set;
cl.def("__setstate__", [func]
#endif
(value_and_holder &v_h, ArgState state) {
setstate<Class>(v_h, func(std::forward<ArgState>(state)),
Py_TYPE(v_h.inst) != v_h.type->type);
}, is_new_style_constructor(), extra...);
}
};
NAMESPACE_END(initimpl)
NAMESPACE_END(detail)
NAMESPACE_END(pybind11)

View File

@@ -1,293 +0,0 @@
/*
pybind11/detail/internals.h: Internal data structure and related functions
Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "../pytypes.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
// Forward declarations
inline PyTypeObject *make_static_property_type();
inline PyTypeObject *make_default_metaclass();
inline PyObject *make_object_base_type(PyTypeObject *metaclass);
// The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in favor of the new
// Thread Specific Storage (TSS) API.
#if PY_VERSION_HEX >= 0x03070000
# define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr
# define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key))
# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value))
# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr)
#else
// Usually an int but a long on Cygwin64 with Python 3.x
# define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0
# define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key))
# if PY_MAJOR_VERSION < 3
# define PYBIND11_TLS_DELETE_VALUE(key) \
PyThread_delete_key_value(key)
# define PYBIND11_TLS_REPLACE_VALUE(key, value) \
do { \
PyThread_delete_key_value((key)); \
PyThread_set_key_value((key), (value)); \
} while (false)
# else
# define PYBIND11_TLS_DELETE_VALUE(key) \
PyThread_set_key_value((key), nullptr)
# define PYBIND11_TLS_REPLACE_VALUE(key, value) \
PyThread_set_key_value((key), (value))
# endif
#endif
// Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly
// other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module
// even when `A` is the same, non-hidden-visibility type (e.g. from a common include). Under
// libstdc++, this doesn't happen: equality and the type_index hash are based on the type name,
// which works. If not under a known-good stl, provide our own name-based hash and equality
// functions that use the type name.
#if defined(__GLIBCXX__)
inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; }
using type_hash = std::hash<std::type_index>;
using type_equal_to = std::equal_to<std::type_index>;
#else
inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) {
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
}
struct type_hash {
size_t operator()(const std::type_index &t) const {
size_t hash = 5381;
const char *ptr = t.name();
while (auto c = static_cast<unsigned char>(*ptr++))
hash = (hash * 33) ^ c;
return hash;
}
};
struct type_equal_to {
bool operator()(const std::type_index &lhs, const std::type_index &rhs) const {
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
}
};
#endif
template <typename value_type>
using type_map = std::unordered_map<std::type_index, value_type, type_hash, type_equal_to>;
struct overload_hash {
inline size_t operator()(const std::pair<const PyObject *, const char *>& v) const {
size_t value = std::hash<const void *>()(v.first);
value ^= std::hash<const void *>()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2);
return value;
}
};
/// Internal data structure used to track registered instances and types.
/// Whenever binary incompatible changes are made to this structure,
/// `PYBIND11_INTERNALS_VERSION` must be incremented.
struct internals {
type_map<type_info *> registered_types_cpp; // std::type_index -> pybind11's type information
std::unordered_map<PyTypeObject *, std::vector<type_info *>> registered_types_py; // PyTypeObject* -> base type_info(s)
std::unordered_multimap<const void *, instance*> registered_instances; // void * -> instance*
std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
type_map<std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
std::unordered_map<const PyObject *, std::vector<PyObject *>> patients;
std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across extensions
std::vector<PyObject *> loader_patient_stack; // Used by `loader_life_support`
std::forward_list<std::string> static_strings; // Stores the std::strings backing detail::c_str()
PyTypeObject *static_property_type;
PyTypeObject *default_metaclass;
PyObject *instance_base;
#if defined(WITH_THREAD)
PYBIND11_TLS_KEY_INIT(tstate);
PyInterpreterState *istate = nullptr;
#endif
};
/// Additional type information which does not fit into the PyTypeObject.
/// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`.
struct type_info {
PyTypeObject *type;
const std::type_info *cpptype;
size_t type_size, type_align, holder_size_in_ptrs;
void *(*operator_new)(size_t);
void (*init_instance)(instance *, const void *);
void (*dealloc)(value_and_holder &v_h);
std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
std::vector<std::pair<const std::type_info *, void *(*)(void *)>> implicit_casts;
std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
void *get_buffer_data = nullptr;
void *(*module_local_load)(PyObject *, const type_info *) = nullptr;
/* A simple type never occurs as a (direct or indirect) parent
* of a class that makes use of multiple inheritance */
bool simple_type : 1;
/* True if there is no multiple inheritance in this type's inheritance tree */
bool simple_ancestors : 1;
/* for base vs derived holder_type checks */
bool default_holder : 1;
/* true if this is a type registered with py::module_local */
bool module_local : 1;
};
/// Tracks the `internals` and `type_info` ABI version independent of the main library version
#define PYBIND11_INTERNALS_VERSION 3
#if defined(_DEBUG)
# define PYBIND11_BUILD_TYPE "_debug"
#else
# define PYBIND11_BUILD_TYPE ""
#endif
#if defined(WITH_THREAD)
# define PYBIND11_INTERNALS_KIND ""
#else
# define PYBIND11_INTERNALS_KIND "_without_thread"
#endif
#define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
#define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
/// Each module locally stores a pointer to the `internals` data. The data
/// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`.
inline internals **&get_internals_pp() {
static internals **internals_pp = nullptr;
return internals_pp;
}
/// Return a reference to the current `internals` data
PYBIND11_NOINLINE inline internals &get_internals() {
auto **&internals_pp = get_internals_pp();
if (internals_pp && *internals_pp)
return **internals_pp;
constexpr auto *id = PYBIND11_INTERNALS_ID;
auto builtins = handle(PyEval_GetBuiltins());
if (builtins.contains(id) && isinstance<capsule>(builtins[id])) {
internals_pp = static_cast<internals **>(capsule(builtins[id]));
// We loaded builtins through python's builtins, which means that our `error_already_set`
// and `builtin_exception` may be different local classes than the ones set up in the
// initial exception translator, below, so add another for our local exception classes.
//
// libstdc++ doesn't require this (types there are identified only by name)
#if !defined(__GLIBCXX__)
(*internals_pp)->registered_exception_translators.push_front(
[](std::exception_ptr p) -> void {
try {
if (p) std::rethrow_exception(p);
} catch (error_already_set &e) { e.restore(); return;
} catch (const builtin_exception &e) { e.set_error(); return;
}
}
);
#endif
} else {
if (!internals_pp) internals_pp = new internals*();
auto *&internals_ptr = *internals_pp;
internals_ptr = new internals();
#if defined(WITH_THREAD)
#if PY_VERSION_HEX < 0x03090000
PyEval_InitThreads();
#endif
PyThreadState *tstate = PyThreadState_Get();
#if PY_VERSION_HEX >= 0x03070000
internals_ptr->tstate = PyThread_tss_alloc();
if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate))
pybind11_fail("get_internals: could not successfully initialize the TSS key!");
PyThread_tss_set(internals_ptr->tstate, tstate);
#else
internals_ptr->tstate = PyThread_create_key();
if (internals_ptr->tstate == -1)
pybind11_fail("get_internals: could not successfully initialize the TLS key!");
PyThread_set_key_value(internals_ptr->tstate, tstate);
#endif
internals_ptr->istate = tstate->interp;
#endif
builtins[id] = capsule(internals_pp);
internals_ptr->registered_exception_translators.push_front(
[](std::exception_ptr p) -> void {
try {
if (p) std::rethrow_exception(p);
} catch (error_already_set &e) { e.restore(); return;
} catch (const builtin_exception &e) { e.set_error(); return;
} catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return;
} catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
} catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
} catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
} catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return;
} catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
} catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return;
} catch (...) {
PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!");
return;
}
}
);
internals_ptr->static_property_type = make_static_property_type();
internals_ptr->default_metaclass = make_default_metaclass();
internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass);
}
return **internals_pp;
}
/// Works like `internals.registered_types_cpp`, but for module-local registered types:
inline type_map<type_info *> &registered_local_types_cpp() {
static type_map<type_info *> locals{};
return locals;
}
/// Constructs a std::string with the given arguments, stores it in `internals`, and returns its
/// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only
/// cleared when the program exits or after interpreter shutdown (when embedding), and so are
/// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name).
template <typename... Args>
const char *c_str(Args &&...args) {
auto &strings = get_internals().static_strings;
strings.emplace_front(std::forward<Args>(args)...);
return strings.front().c_str();
}
NAMESPACE_END(detail)
/// Returns a named pointer that is shared among all extension modules (using the same
/// pybind11 version) running in the current interpreter. Names starting with underscores
/// are reserved for internal usage. Returns `nullptr` if no matching entry was found.
inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) {
auto &internals = detail::get_internals();
auto it = internals.shared_data.find(name);
return it != internals.shared_data.end() ? it->second : nullptr;
}
/// Set the shared data that can be later recovered by `get_shared_data()`.
inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) {
detail::get_internals().shared_data[name] = data;
return data;
}
/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if
/// such entry exists. Otherwise, a new object of default-constructible type `T` is
/// added to the shared data under the given name and a reference to it is returned.
template<typename T>
T &get_or_create_shared_data(const std::string &name) {
auto &internals = detail::get_internals();
auto it = internals.shared_data.find(name);
T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr);
if (!ptr) {
ptr = new T();
internals.shared_data[name] = ptr;
}
return *ptr;
}
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,55 +0,0 @@
/*
pybind11/detail/typeid.h: Compiler-independent access to type identifiers
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include <cstdio>
#include <cstdlib>
#if defined(__GNUG__)
#include <cxxabi.h>
#endif
#include "common.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
/// Erase all occurrences of a substring
inline void erase_all(std::string &string, const std::string &search) {
for (size_t pos = 0;;) {
pos = string.find(search, pos);
if (pos == std::string::npos) break;
string.erase(pos, search.length());
}
}
PYBIND11_NOINLINE inline void clean_type_id(std::string &name) {
#if defined(__GNUG__)
int status = 0;
std::unique_ptr<char, void (*)(void *)> res {
abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free };
if (status == 0)
name = res.get();
#else
detail::erase_all(name, "class ");
detail::erase_all(name, "struct ");
detail::erase_all(name, "enum ");
#endif
detail::erase_all(name, "pybind11::");
}
NAMESPACE_END(detail)
/// Return a string representation of a C++ type
template <typename T> static std::string type_id() {
std::string name(typeid(T).name());
detail::clean_type_id(name);
return name;
}
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,607 +0,0 @@
/*
pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "numpy.h"
#if defined(__INTEL_COMPILER)
# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
#elif defined(__GNUG__) || defined(__clang__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wconversion"
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
# ifdef __clang__
// Eigen generates a bunch of implicit-copy-constructor-is-deprecated warnings with -Wdeprecated
// under Clang, so disable that warning here:
# pragma GCC diagnostic ignored "-Wdeprecated"
# endif
# if __GNUC__ >= 7
# pragma GCC diagnostic ignored "-Wint-in-bool-context"
# endif
#endif
#if defined(_MSC_VER)
# pragma warning(push)
# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
# pragma warning(disable: 4996) // warning C4996: std::unary_negate is deprecated in C++17
#endif
#include <Eigen/Core>
#include <Eigen/SparseCore>
// Eigen prior to 3.2.7 doesn't have proper move constructors--but worse, some classes get implicit
// move constructors that break things. We could detect this an explicitly copy, but an extra copy
// of matrices seems highly undesirable.
static_assert(EIGEN_VERSION_AT_LEAST(3,2,7), "Eigen support in pybind11 requires Eigen >= 3.2.7");
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
// Provide a convenience alias for easier pass-by-ref usage with fully dynamic strides:
using EigenDStride = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
template <typename MatrixType> using EigenDRef = Eigen::Ref<MatrixType, 0, EigenDStride>;
template <typename MatrixType> using EigenDMap = Eigen::Map<MatrixType, 0, EigenDStride>;
NAMESPACE_BEGIN(detail)
#if EIGEN_VERSION_AT_LEAST(3,3,0)
using EigenIndex = Eigen::Index;
#else
using EigenIndex = EIGEN_DEFAULT_DENSE_INDEX_TYPE;
#endif
// Matches Eigen::Map, Eigen::Ref, blocks, etc:
template <typename T> using is_eigen_dense_map = all_of<is_template_base_of<Eigen::DenseBase, T>, std::is_base_of<Eigen::MapBase<T, Eigen::ReadOnlyAccessors>, T>>;
template <typename T> using is_eigen_mutable_map = std::is_base_of<Eigen::MapBase<T, Eigen::WriteAccessors>, T>;
template <typename T> using is_eigen_dense_plain = all_of<negation<is_eigen_dense_map<T>>, is_template_base_of<Eigen::PlainObjectBase, T>>;
template <typename T> using is_eigen_sparse = is_template_base_of<Eigen::SparseMatrixBase, T>;
// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above. This
// basically covers anything that can be assigned to a dense matrix but that don't have a typical
// matrix data layout that can be copied from their .data(). For example, DiagonalMatrix and
// SelfAdjointView fall into this category.
template <typename T> using is_eigen_other = all_of<
is_template_base_of<Eigen::EigenBase, T>,
negation<any_of<is_eigen_dense_map<T>, is_eigen_dense_plain<T>, is_eigen_sparse<T>>>
>;
// Captures numpy/eigen conformability status (returned by EigenProps::conformable()):
template <bool EigenRowMajor> struct EigenConformable {
bool conformable = false;
EigenIndex rows = 0, cols = 0;
EigenDStride stride{0, 0}; // Only valid if negativestrides is false!
bool negativestrides = false; // If true, do not use stride!
EigenConformable(bool fits = false) : conformable{fits} {}
// Matrix type:
EigenConformable(EigenIndex r, EigenIndex c,
EigenIndex rstride, EigenIndex cstride) :
conformable{true}, rows{r}, cols{c} {
// TODO: when Eigen bug #747 is fixed, remove the tests for non-negativity. http://eigen.tuxfamily.org/bz/show_bug.cgi?id=747
if (rstride < 0 || cstride < 0) {
negativestrides = true;
} else {
stride = {EigenRowMajor ? rstride : cstride /* outer stride */,
EigenRowMajor ? cstride : rstride /* inner stride */ };
}
}
// Vector type:
EigenConformable(EigenIndex r, EigenIndex c, EigenIndex stride)
: EigenConformable(r, c, r == 1 ? c*stride : stride, c == 1 ? r : r*stride) {}
template <typename props> bool stride_compatible() const {
// To have compatible strides, we need (on both dimensions) one of fully dynamic strides,
// matching strides, or a dimension size of 1 (in which case the stride value is irrelevant)
return
!negativestrides &&
(props::inner_stride == Eigen::Dynamic || props::inner_stride == stride.inner() ||
(EigenRowMajor ? cols : rows) == 1) &&
(props::outer_stride == Eigen::Dynamic || props::outer_stride == stride.outer() ||
(EigenRowMajor ? rows : cols) == 1);
}
operator bool() const { return conformable; }
};
template <typename Type> struct eigen_extract_stride { using type = Type; };
template <typename PlainObjectType, int MapOptions, typename StrideType>
struct eigen_extract_stride<Eigen::Map<PlainObjectType, MapOptions, StrideType>> { using type = StrideType; };
template <typename PlainObjectType, int Options, typename StrideType>
struct eigen_extract_stride<Eigen::Ref<PlainObjectType, Options, StrideType>> { using type = StrideType; };
// Helper struct for extracting information from an Eigen type
template <typename Type_> struct EigenProps {
using Type = Type_;
using Scalar = typename Type::Scalar;
using StrideType = typename eigen_extract_stride<Type>::type;
static constexpr EigenIndex
rows = Type::RowsAtCompileTime,
cols = Type::ColsAtCompileTime,
size = Type::SizeAtCompileTime;
static constexpr bool
row_major = Type::IsRowMajor,
vector = Type::IsVectorAtCompileTime, // At least one dimension has fixed size 1
fixed_rows = rows != Eigen::Dynamic,
fixed_cols = cols != Eigen::Dynamic,
fixed = size != Eigen::Dynamic, // Fully-fixed size
dynamic = !fixed_rows && !fixed_cols; // Fully-dynamic size
template <EigenIndex i, EigenIndex ifzero> using if_zero = std::integral_constant<EigenIndex, i == 0 ? ifzero : i>;
static constexpr EigenIndex inner_stride = if_zero<StrideType::InnerStrideAtCompileTime, 1>::value,
outer_stride = if_zero<StrideType::OuterStrideAtCompileTime,
vector ? size : row_major ? cols : rows>::value;
static constexpr bool dynamic_stride = inner_stride == Eigen::Dynamic && outer_stride == Eigen::Dynamic;
static constexpr bool requires_row_major = !dynamic_stride && !vector && (row_major ? inner_stride : outer_stride) == 1;
static constexpr bool requires_col_major = !dynamic_stride && !vector && (row_major ? outer_stride : inner_stride) == 1;
// Takes an input array and determines whether we can make it fit into the Eigen type. If
// the array is a vector, we attempt to fit it into either an Eigen 1xN or Nx1 vector
// (preferring the latter if it will fit in either, i.e. for a fully dynamic matrix type).
static EigenConformable<row_major> conformable(const array &a) {
const auto dims = a.ndim();
if (dims < 1 || dims > 2)
return false;
if (dims == 2) { // Matrix type: require exact match (or dynamic)
EigenIndex
np_rows = a.shape(0),
np_cols = a.shape(1),
np_rstride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar)),
np_cstride = a.strides(1) / static_cast<ssize_t>(sizeof(Scalar));
if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols))
return false;
return {np_rows, np_cols, np_rstride, np_cstride};
}
// Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever
// is used, we want the (single) numpy stride value.
const EigenIndex n = a.shape(0),
stride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar));
if (vector) { // Eigen type is a compile-time vector
if (fixed && size != n)
return false; // Vector size mismatch
return {rows == 1 ? 1 : n, cols == 1 ? 1 : n, stride};
}
else if (fixed) {
// The type has a fixed size, but is not a vector: abort
return false;
}
else if (fixed_cols) {
// Since this isn't a vector, cols must be != 1. We allow this only if it exactly
// equals the number of elements (rows is Dynamic, and so 1 row is allowed).
if (cols != n) return false;
return {1, n, stride};
}
else {
// Otherwise it's either fully dynamic, or column dynamic; both become a column vector
if (fixed_rows && rows != n) return false;
return {n, 1, stride};
}
}
static constexpr bool show_writeable = is_eigen_dense_map<Type>::value && is_eigen_mutable_map<Type>::value;
static constexpr bool show_order = is_eigen_dense_map<Type>::value;
static constexpr bool show_c_contiguous = show_order && requires_row_major;
static constexpr bool show_f_contiguous = !show_c_contiguous && show_order && requires_col_major;
static constexpr auto descriptor =
_("numpy.ndarray[") + npy_format_descriptor<Scalar>::name +
_("[") + _<fixed_rows>(_<(size_t) rows>(), _("m")) +
_(", ") + _<fixed_cols>(_<(size_t) cols>(), _("n")) +
_("]") +
// For a reference type (e.g. Ref<MatrixXd>) we have other constraints that might need to be
// satisfied: writeable=True (for a mutable reference), and, depending on the map's stride
// options, possibly f_contiguous or c_contiguous. We include them in the descriptor output
// to provide some hint as to why a TypeError is occurring (otherwise it can be confusing to
// see that a function accepts a 'numpy.ndarray[float64[3,2]]' and an error message that you
// *gave* a numpy.ndarray of the right type and dimensions.
_<show_writeable>(", flags.writeable", "") +
_<show_c_contiguous>(", flags.c_contiguous", "") +
_<show_f_contiguous>(", flags.f_contiguous", "") +
_("]");
};
// Casts an Eigen type to numpy array. If given a base, the numpy array references the src data,
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
constexpr ssize_t elem_size = sizeof(typename props::Scalar);
array a;
if (props::vector)
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
else
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
src.data(), base);
if (!writeable)
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
return a.release();
}
// Takes an lvalue ref to some Eigen type and a (python) base object, creating a numpy array that
// reference the Eigen object's data with `base` as the python-registered base class (if omitted,
// the base will be set to None, and lifetime management is up to the caller). The numpy array is
// non-writeable if the given type is const.
template <typename props, typename Type>
handle eigen_ref_array(Type &src, handle parent = none()) {
// none here is to get past array's should-we-copy detection, which currently always
// copies when there is no base. Setting the base to None should be harmless.
return eigen_array_cast<props>(src, parent, !std::is_const<Type>::value);
}
// Takes a pointer to some dense, plain Eigen type, builds a capsule around it, then returns a numpy
// array that references the encapsulated data with a python-side reference to the capsule to tie
// its destruction to that of any dependent python objects. Const-ness is determined by whether or
// not the Type of the pointer given is const.
template <typename props, typename Type, typename = enable_if_t<is_eigen_dense_plain<Type>::value>>
handle eigen_encapsulate(Type *src) {
capsule base(src, [](void *o) { delete static_cast<Type *>(o); });
return eigen_ref_array<props>(*src, base);
}
// Type caster for regular, dense matrix types (e.g. MatrixXd), but not maps/refs/etc. of dense
// types.
template<typename Type>
struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
using Scalar = typename Type::Scalar;
using props = EigenProps<Type>;
bool load(handle src, bool convert) {
// If we're in no-convert mode, only load if given an array of the correct type
if (!convert && !isinstance<array_t<Scalar>>(src))
return false;
// Coerce into an array, but don't do type conversion yet; the copy below handles it.
auto buf = array::ensure(src);
if (!buf)
return false;
auto dims = buf.ndim();
if (dims < 1 || dims > 2)
return false;
auto fits = props::conformable(buf);
if (!fits)
return false;
// Allocate the new type, then build a numpy reference into it
value = Type(fits.rows, fits.cols);
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
if (dims == 1) ref = ref.squeeze();
else if (ref.ndim() == 1) buf = buf.squeeze();
int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
if (result < 0) { // Copy failed!
PyErr_Clear();
return false;
}
return true;
}
private:
// Cast implementation
template <typename CType>
static handle cast_impl(CType *src, return_value_policy policy, handle parent) {
switch (policy) {
case return_value_policy::take_ownership:
case return_value_policy::automatic:
return eigen_encapsulate<props>(src);
case return_value_policy::move:
return eigen_encapsulate<props>(new CType(std::move(*src)));
case return_value_policy::copy:
return eigen_array_cast<props>(*src);
case return_value_policy::reference:
case return_value_policy::automatic_reference:
return eigen_ref_array<props>(*src);
case return_value_policy::reference_internal:
return eigen_ref_array<props>(*src, parent);
default:
throw cast_error("unhandled return_value_policy: should not happen!");
};
}
public:
// Normal returned non-reference, non-const value:
static handle cast(Type &&src, return_value_policy /* policy */, handle parent) {
return cast_impl(&src, return_value_policy::move, parent);
}
// If you return a non-reference const, we mark the numpy array readonly:
static handle cast(const Type &&src, return_value_policy /* policy */, handle parent) {
return cast_impl(&src, return_value_policy::move, parent);
}
// lvalue reference return; default (automatic) becomes copy
static handle cast(Type &src, return_value_policy policy, handle parent) {
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
policy = return_value_policy::copy;
return cast_impl(&src, policy, parent);
}
// const lvalue reference return; default (automatic) becomes copy
static handle cast(const Type &src, return_value_policy policy, handle parent) {
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
policy = return_value_policy::copy;
return cast(&src, policy, parent);
}
// non-const pointer return
static handle cast(Type *src, return_value_policy policy, handle parent) {
return cast_impl(src, policy, parent);
}
// const pointer return
static handle cast(const Type *src, return_value_policy policy, handle parent) {
return cast_impl(src, policy, parent);
}
static constexpr auto name = props::descriptor;
operator Type*() { return &value; }
operator Type&() { return value; }
operator Type&&() && { return std::move(value); }
template <typename T> using cast_op_type = movable_cast_op_type<T>;
private:
Type value;
};
// Base class for casting reference/map/block/etc. objects back to python.
template <typename MapType> struct eigen_map_caster {
private:
using props = EigenProps<MapType>;
public:
// Directly referencing a ref/map's data is a bit dangerous (whatever the map/ref points to has
// to stay around), but we'll allow it under the assumption that you know what you're doing (and
// have an appropriate keep_alive in place). We return a numpy array pointing directly at the
// ref's data (The numpy array ends up read-only if the ref was to a const matrix type.) Note
// that this means you need to ensure you don't destroy the object in some other way (e.g. with
// an appropriate keep_alive, or with a reference to a statically allocated matrix).
static handle cast(const MapType &src, return_value_policy policy, handle parent) {
switch (policy) {
case return_value_policy::copy:
return eigen_array_cast<props>(src);
case return_value_policy::reference_internal:
return eigen_array_cast<props>(src, parent, is_eigen_mutable_map<MapType>::value);
case return_value_policy::reference:
case return_value_policy::automatic:
case return_value_policy::automatic_reference:
return eigen_array_cast<props>(src, none(), is_eigen_mutable_map<MapType>::value);
default:
// move, take_ownership don't make any sense for a ref/map:
pybind11_fail("Invalid return_value_policy for Eigen Map/Ref/Block type");
}
}
static constexpr auto name = props::descriptor;
// Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return
// types but not bound arguments). We still provide them (with an explicitly delete) so that
// you end up here if you try anyway.
bool load(handle, bool) = delete;
operator MapType() = delete;
template <typename> using cast_op_type = MapType;
};
// We can return any map-like object (but can only load Refs, specialized next):
template <typename Type> struct type_caster<Type, enable_if_t<is_eigen_dense_map<Type>::value>>
: eigen_map_caster<Type> {};
// Loader for Ref<...> arguments. See the documentation for info on how to make this work without
// copying (it requires some extra effort in many cases).
template <typename PlainObjectType, typename StrideType>
struct type_caster<
Eigen::Ref<PlainObjectType, 0, StrideType>,
enable_if_t<is_eigen_dense_map<Eigen::Ref<PlainObjectType, 0, StrideType>>::value>
> : public eigen_map_caster<Eigen::Ref<PlainObjectType, 0, StrideType>> {
private:
using Type = Eigen::Ref<PlainObjectType, 0, StrideType>;
using props = EigenProps<Type>;
using Scalar = typename props::Scalar;
using MapType = Eigen::Map<PlainObjectType, 0, StrideType>;
using Array = array_t<Scalar, array::forcecast |
((props::row_major ? props::inner_stride : props::outer_stride) == 1 ? array::c_style :
(props::row_major ? props::outer_stride : props::inner_stride) == 1 ? array::f_style : 0)>;
static constexpr bool need_writeable = is_eigen_mutable_map<Type>::value;
// Delay construction (these have no default constructor)
std::unique_ptr<MapType> map;
std::unique_ptr<Type> ref;
// Our array. When possible, this is just a numpy array pointing to the source data, but
// sometimes we can't avoid copying (e.g. input is not a numpy array at all, has an incompatible
// layout, or is an array of a type that needs to be converted). Using a numpy temporary
// (rather than an Eigen temporary) saves an extra copy when we need both type conversion and
// storage order conversion. (Note that we refuse to use this temporary copy when loading an
// argument for a Ref<M> with M non-const, i.e. a read-write reference).
Array copy_or_ref;
public:
bool load(handle src, bool convert) {
// First check whether what we have is already an array of the right type. If not, we can't
// avoid a copy (because the copy is also going to do type conversion).
bool need_copy = !isinstance<Array>(src);
EigenConformable<props::row_major> fits;
if (!need_copy) {
// We don't need a converting copy, but we also need to check whether the strides are
// compatible with the Ref's stride requirements
Array aref = reinterpret_borrow<Array>(src);
if (aref && (!need_writeable || aref.writeable())) {
fits = props::conformable(aref);
if (!fits) return false; // Incompatible dimensions
if (!fits.template stride_compatible<props>())
need_copy = true;
else
copy_or_ref = std::move(aref);
}
else {
need_copy = true;
}
}
if (need_copy) {
// We need to copy: If we need a mutable reference, or we're not supposed to convert
// (either because we're in the no-convert overload pass, or because we're explicitly
// instructed not to copy (via `py::arg().noconvert()`) we have to fail loading.
if (!convert || need_writeable) return false;
Array copy = Array::ensure(src);
if (!copy) return false;
fits = props::conformable(copy);
if (!fits || !fits.template stride_compatible<props>())
return false;
copy_or_ref = std::move(copy);
loader_life_support::add_patient(copy_or_ref);
}
ref.reset();
map.reset(new MapType(data(copy_or_ref), fits.rows, fits.cols, make_stride(fits.stride.outer(), fits.stride.inner())));
ref.reset(new Type(*map));
return true;
}
operator Type*() { return ref.get(); }
operator Type&() { return *ref; }
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
private:
template <typename T = Type, enable_if_t<is_eigen_mutable_map<T>::value, int> = 0>
Scalar *data(Array &a) { return a.mutable_data(); }
template <typename T = Type, enable_if_t<!is_eigen_mutable_map<T>::value, int> = 0>
const Scalar *data(Array &a) { return a.data(); }
// Attempt to figure out a constructor of `Stride` that will work.
// If both strides are fixed, use a default constructor:
template <typename S> using stride_ctor_default = bool_constant<
S::InnerStrideAtCompileTime != Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic &&
std::is_default_constructible<S>::value>;
// Otherwise, if there is a two-index constructor, assume it is (outer,inner) like
// Eigen::Stride, and use it:
template <typename S> using stride_ctor_dual = bool_constant<
!stride_ctor_default<S>::value && std::is_constructible<S, EigenIndex, EigenIndex>::value>;
// Otherwise, if there is a one-index constructor, and just one of the strides is dynamic, use
// it (passing whichever stride is dynamic).
template <typename S> using stride_ctor_outer = bool_constant<
!any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value &&
S::OuterStrideAtCompileTime == Eigen::Dynamic && S::InnerStrideAtCompileTime != Eigen::Dynamic &&
std::is_constructible<S, EigenIndex>::value>;
template <typename S> using stride_ctor_inner = bool_constant<
!any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value &&
S::InnerStrideAtCompileTime == Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic &&
std::is_constructible<S, EigenIndex>::value>;
template <typename S = StrideType, enable_if_t<stride_ctor_default<S>::value, int> = 0>
static S make_stride(EigenIndex, EigenIndex) { return S(); }
template <typename S = StrideType, enable_if_t<stride_ctor_dual<S>::value, int> = 0>
static S make_stride(EigenIndex outer, EigenIndex inner) { return S(outer, inner); }
template <typename S = StrideType, enable_if_t<stride_ctor_outer<S>::value, int> = 0>
static S make_stride(EigenIndex outer, EigenIndex) { return S(outer); }
template <typename S = StrideType, enable_if_t<stride_ctor_inner<S>::value, int> = 0>
static S make_stride(EigenIndex, EigenIndex inner) { return S(inner); }
};
// type_caster for special matrix types (e.g. DiagonalMatrix), which are EigenBase, but not
// EigenDense (i.e. they don't have a data(), at least not with the usual matrix layout).
// load() is not supported, but we can cast them into the python domain by first copying to a
// regular Eigen::Matrix, then casting that.
template <typename Type>
struct type_caster<Type, enable_if_t<is_eigen_other<Type>::value>> {
protected:
using Matrix = Eigen::Matrix<typename Type::Scalar, Type::RowsAtCompileTime, Type::ColsAtCompileTime>;
using props = EigenProps<Matrix>;
public:
static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
handle h = eigen_encapsulate<props>(new Matrix(src));
return h;
}
static handle cast(const Type *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); }
static constexpr auto name = props::descriptor;
// Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return
// types but not bound arguments). We still provide them (with an explicitly delete) so that
// you end up here if you try anyway.
bool load(handle, bool) = delete;
operator Type() = delete;
template <typename> using cast_op_type = Type;
};
template<typename Type>
struct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> {
typedef typename Type::Scalar Scalar;
typedef remove_reference_t<decltype(*std::declval<Type>().outerIndexPtr())> StorageIndex;
typedef typename Type::Index Index;
static constexpr bool rowMajor = Type::IsRowMajor;
bool load(handle src, bool) {
if (!src)
return false;
auto obj = reinterpret_borrow<object>(src);
object sparse_module = module::import("scipy.sparse");
object matrix_type = sparse_module.attr(
rowMajor ? "csr_matrix" : "csc_matrix");
if (!obj.get_type().is(matrix_type)) {
try {
obj = matrix_type(obj);
} catch (const error_already_set &) {
return false;
}
}
auto values = array_t<Scalar>((object) obj.attr("data"));
auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
auto nnz = obj.attr("nnz").cast<Index>();
if (!values || !innerIndices || !outerIndices)
return false;
value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());
return true;
}
static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
const_cast<Type&>(src).makeCompressed();
object matrix_type = module::import("scipy.sparse").attr(
rowMajor ? "csr_matrix" : "csc_matrix");
array data(src.nonZeros(), src.valuePtr());
array outerIndices((rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr());
array innerIndices(src.nonZeros(), src.innerIndexPtr());
return matrix_type(
std::make_tuple(data, innerIndices, outerIndices),
std::make_pair(src.rows(), src.cols())
).release();
}
PYBIND11_TYPE_CASTER(Type, _<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[")
+ npy_format_descriptor<Scalar>::name + _("]"));
};
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)
#if defined(__GNUG__) || defined(__clang__)
# pragma GCC diagnostic pop
#elif defined(_MSC_VER)
# pragma warning(pop)
#endif

View File

@@ -1,200 +0,0 @@
/*
pybind11/embed.h: Support for embedding the interpreter
Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#include "eval.h"
#if defined(PYPY_VERSION)
# error Embedding the interpreter is not supported with PyPy
#endif
#if PY_MAJOR_VERSION >= 3
# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
extern "C" PyObject *pybind11_init_impl_##name() { \
return pybind11_init_wrapper_##name(); \
}
#else
# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
extern "C" void pybind11_init_impl_##name() { \
pybind11_init_wrapper_##name(); \
}
#endif
/** \rst
Add a new module to the table of builtins for the interpreter. Must be
defined in global scope. The first macro parameter is the name of the
module (without quotes). The second parameter is the variable which will
be used as the interface to add functions and classes to the module.
.. code-block:: cpp
PYBIND11_EMBEDDED_MODULE(example, m) {
// ... initialize functions and classes here
m.def("foo", []() {
return "Hello, World!";
});
}
\endrst */
#define PYBIND11_EMBEDDED_MODULE(name, variable) \
static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \
auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
try { \
PYBIND11_CONCAT(pybind11_init_, name)(m); \
return m.ptr(); \
} catch (pybind11::error_already_set &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} catch (const std::exception &e) { \
PyErr_SetString(PyExc_ImportError, e.what()); \
return nullptr; \
} \
} \
PYBIND11_EMBEDDED_MODULE_IMPL(name) \
pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \
PYBIND11_CONCAT(pybind11_init_impl_, name)); \
void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
/// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks.
struct embedded_module {
#if PY_MAJOR_VERSION >= 3
using init_t = PyObject *(*)();
#else
using init_t = void (*)();
#endif
embedded_module(const char *name, init_t init) {
if (Py_IsInitialized())
pybind11_fail("Can't add new modules after the interpreter has been initialized");
auto result = PyImport_AppendInittab(name, init);
if (result == -1)
pybind11_fail("Insufficient memory to add a new module");
}
};
NAMESPACE_END(detail)
/** \rst
Initialize the Python interpreter. No other pybind11 or CPython API functions can be
called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The
optional parameter can be used to skip the registration of signal handlers (see the
`Python documentation`_ for details). Calling this function again after the interpreter
has already been initialized is a fatal error.
If initializing the Python interpreter fails, then the program is terminated. (This
is controlled by the CPython runtime and is an exception to pybind11's normal behavior
of throwing exceptions on errors.)
.. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx
\endrst */
inline void initialize_interpreter(bool init_signal_handlers = true) {
if (Py_IsInitialized())
pybind11_fail("The interpreter is already running");
Py_InitializeEx(init_signal_handlers ? 1 : 0);
// Make .py files in the working directory available by default
module::import("sys").attr("path").cast<list>().append(".");
}
/** \rst
Shut down the Python interpreter. No pybind11 or CPython API functions can be called
after this. In addition, pybind11 objects must not outlive the interpreter:
.. code-block:: cpp
{ // BAD
py::initialize_interpreter();
auto hello = py::str("Hello, World!");
py::finalize_interpreter();
} // <-- BOOM, hello's destructor is called after interpreter shutdown
{ // GOOD
py::initialize_interpreter();
{ // scoped
auto hello = py::str("Hello, World!");
} // <-- OK, hello is cleaned up properly
py::finalize_interpreter();
}
{ // BETTER
py::scoped_interpreter guard{};
auto hello = py::str("Hello, World!");
}
.. warning::
The interpreter can be restarted by calling `initialize_interpreter` again.
Modules created using pybind11 can be safely re-initialized. However, Python
itself cannot completely unload binary extension modules and there are several
caveats with regard to interpreter restarting. All the details can be found
in the CPython documentation. In short, not all interpreter memory may be
freed, either due to reference cycles or user-created global data.
\endrst */
inline void finalize_interpreter() {
handle builtins(PyEval_GetBuiltins());
const char *id = PYBIND11_INTERNALS_ID;
// Get the internals pointer (without creating it if it doesn't exist). It's possible for the
// internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()`
// during destruction), so we get the pointer-pointer here and check it after Py_Finalize().
detail::internals **internals_ptr_ptr = detail::get_internals_pp();
// It could also be stashed in builtins, so look there too:
if (builtins.contains(id) && isinstance<capsule>(builtins[id]))
internals_ptr_ptr = capsule(builtins[id]);
Py_Finalize();
if (internals_ptr_ptr) {
delete *internals_ptr_ptr;
*internals_ptr_ptr = nullptr;
}
}
/** \rst
Scope guard version of `initialize_interpreter` and `finalize_interpreter`.
This a move-only guard and only a single instance can exist.
.. code-block:: cpp
#include <pybind11/embed.h>
int main() {
py::scoped_interpreter guard{};
py::print(Hello, World!);
} // <-- interpreter shutdown
\endrst */
class scoped_interpreter {
public:
scoped_interpreter(bool init_signal_handlers = true) {
initialize_interpreter(init_signal_handlers);
}
scoped_interpreter(const scoped_interpreter &) = delete;
scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; }
scoped_interpreter &operator=(const scoped_interpreter &) = delete;
scoped_interpreter &operator=(scoped_interpreter &&) = delete;
~scoped_interpreter() {
if (is_valid)
finalize_interpreter();
}
private:
bool is_valid = true;
};
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,117 +0,0 @@
/*
pybind11/exec.h: Support for evaluating Python expressions and statements
from strings and files
Copyright (c) 2016 Klemens Morgenstern <klemens.morgenstern@ed-chemnitz.de> and
Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
enum eval_mode {
/// Evaluate a string containing an isolated expression
eval_expr,
/// Evaluate a string containing a single statement. Returns \c none
eval_single_statement,
/// Evaluate a string containing a sequence of statement. Returns \c none
eval_statements
};
template <eval_mode mode = eval_expr>
object eval(str expr, object global = globals(), object local = object()) {
if (!local)
local = global;
/* PyRun_String does not accept a PyObject / encoding specifier,
this seems to be the only alternative */
std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr;
int start;
switch (mode) {
case eval_expr: start = Py_eval_input; break;
case eval_single_statement: start = Py_single_input; break;
case eval_statements: start = Py_file_input; break;
default: pybind11_fail("invalid evaluation mode");
}
PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr());
if (!result)
throw error_already_set();
return reinterpret_steal<object>(result);
}
template <eval_mode mode = eval_expr, size_t N>
object eval(const char (&s)[N], object global = globals(), object local = object()) {
/* Support raw string literals by removing common leading whitespace */
auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s))
: str(s);
return eval<mode>(expr, global, local);
}
inline void exec(str expr, object global = globals(), object local = object()) {
eval<eval_statements>(expr, global, local);
}
template <size_t N>
void exec(const char (&s)[N], object global = globals(), object local = object()) {
eval<eval_statements>(s, global, local);
}
template <eval_mode mode = eval_statements>
object eval_file(str fname, object global = globals(), object local = object()) {
if (!local)
local = global;
int start;
switch (mode) {
case eval_expr: start = Py_eval_input; break;
case eval_single_statement: start = Py_single_input; break;
case eval_statements: start = Py_file_input; break;
default: pybind11_fail("invalid evaluation mode");
}
int closeFile = 1;
std::string fname_str = (std::string) fname;
#if PY_VERSION_HEX >= 0x03040000
FILE *f = _Py_fopen_obj(fname.ptr(), "r");
#elif PY_VERSION_HEX >= 0x03000000
FILE *f = _Py_fopen(fname.ptr(), "r");
#else
/* No unicode support in open() :( */
auto fobj = reinterpret_steal<object>(PyFile_FromString(
const_cast<char *>(fname_str.c_str()),
const_cast<char*>("r")));
FILE *f = nullptr;
if (fobj)
f = PyFile_AsFile(fobj.ptr());
closeFile = 0;
#endif
if (!f) {
PyErr_Clear();
pybind11_fail("File \"" + fname_str + "\" could not be opened!");
}
#if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION)
PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(),
local.ptr());
(void) closeFile;
#else
PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(),
local.ptr(), closeFile);
#endif
if (!result)
throw error_already_set();
return reinterpret_steal<object>(result);
}
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,108 +0,0 @@
/*
pybind11/functional.h: std::function<> support
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#include <functional>
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
template <typename Return, typename... Args>
struct type_caster<std::function<Return(Args...)>> {
using type = std::function<Return(Args...)>;
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
using function_type = Return (*) (Args...);
public:
bool load(handle src, bool convert) {
if (src.is_none()) {
// Defer accepting None to other overloads (if we aren't in convert mode):
if (!convert) return false;
return true;
}
if (!isinstance<function>(src))
return false;
auto func = reinterpret_borrow<function>(src);
/*
When passing a C++ function as an argument to another C++
function via Python, every function call would normally involve
a full C++ -> Python -> C++ roundtrip, which can be prohibitive.
Here, we try to at least detect the case where the function is
stateless (i.e. function pointer or lambda function without
captured variables), in which case the roundtrip can be avoided.
*/
if (auto cfunc = func.cpp_function()) {
auto c = reinterpret_borrow<capsule>(PyCFunction_GET_SELF(cfunc.ptr()));
auto rec = (function_record *) c;
if (rec && rec->is_stateless &&
same_type(typeid(function_type), *reinterpret_cast<const std::type_info *>(rec->data[1]))) {
struct capture { function_type f; };
value = ((capture *) &rec->data)->f;
return true;
}
}
// ensure GIL is held during functor destruction
struct func_handle {
function f;
func_handle(function&& f_) : f(std::move(f_)) {}
func_handle(const func_handle&) = default;
~func_handle() {
gil_scoped_acquire acq;
function kill_f(std::move(f));
}
};
// value = [hfunc = func_handle(std::move(func))](Args... args) -> Return {
// gil_scoped_acquire acq;
// object retval(hfunc.f(std::forward<Args>(args)...));
// /* Visual studio 2015 parser issue: need parentheses around this expression */
// return (retval.template cast<Return>());
// };
struct func_wrapper {
func_handle hfunc;
func_wrapper(func_handle&& hf): hfunc(std::move(hf)) {}
Return operator()(Args... args) const {
gil_scoped_acquire acq;
object retval(hfunc.f(std::forward<Args>(args)...));
/* Visual studio 2015 parser issue: need parentheses around this expression */
return (retval.template cast<Return>());
}
};
value = func_wrapper(func_handle(std::move(func)));
return true;
}
template <typename Func>
static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) {
if (!f_)
return none().inc_ref();
auto result = f_.template target<function_type>();
if (result)
return cpp_function(*result, policy).release();
else
return cpp_function(std::forward<Func>(f_), policy).release();
}
PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster<Args>::name...) + _("], ")
+ make_caster<retval_type>::name + _("]"));
};
NAMESPACE_END(detail)
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,207 +0,0 @@
/*
pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to Python
Copyright (c) 2017 Henry F. Schreiner
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#include <streambuf>
#include <ostream>
#include <string>
#include <memory>
#include <iostream>
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
// Buffer that writes to Python instead of C++
class pythonbuf : public std::streambuf {
private:
using traits_type = std::streambuf::traits_type;
const size_t buf_size;
std::unique_ptr<char[]> d_buffer;
object pywrite;
object pyflush;
int overflow(int c) {
if (!traits_type::eq_int_type(c, traits_type::eof())) {
*pptr() = traits_type::to_char_type(c);
pbump(1);
}
return sync() == 0 ? traits_type::not_eof(c) : traits_type::eof();
}
int sync() {
if (pbase() != pptr()) {
// This subtraction cannot be negative, so dropping the sign
str line(pbase(), static_cast<size_t>(pptr() - pbase()));
{
gil_scoped_acquire tmp;
pywrite(line);
pyflush();
}
setp(pbase(), epptr());
}
return 0;
}
public:
pythonbuf(object pyostream, size_t buffer_size = 1024)
: buf_size(buffer_size),
d_buffer(new char[buf_size]),
pywrite(pyostream.attr("write")),
pyflush(pyostream.attr("flush")) {
setp(d_buffer.get(), d_buffer.get() + buf_size - 1);
}
/// Sync before destroy
~pythonbuf() {
sync();
}
};
NAMESPACE_END(detail)
/** \rst
This a move-only guard that redirects output.
.. code-block:: cpp
#include <pybind11/iostream.h>
...
{
py::scoped_ostream_redirect output;
std::cout << "Hello, World!"; // Python stdout
} // <-- return std::cout to normal
You can explicitly pass the c++ stream and the python object,
for example to guard stderr instead.
.. code-block:: cpp
{
py::scoped_ostream_redirect output{std::cerr, py::module::import("sys").attr("stderr")};
std::cerr << "Hello, World!";
}
\endrst */
class scoped_ostream_redirect {
protected:
std::streambuf *old;
std::ostream &costream;
detail::pythonbuf buffer;
public:
scoped_ostream_redirect(
std::ostream &costream = std::cout,
object pyostream = module::import("sys").attr("stdout"))
: costream(costream), buffer(pyostream) {
old = costream.rdbuf(&buffer);
}
~scoped_ostream_redirect() {
costream.rdbuf(old);
}
scoped_ostream_redirect(const scoped_ostream_redirect &) = delete;
scoped_ostream_redirect(scoped_ostream_redirect &&other) = default;
scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete;
scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete;
};
/** \rst
Like `scoped_ostream_redirect`, but redirects cerr by default. This class
is provided primary to make ``py::call_guard`` easier to make.
.. code-block:: cpp
m.def("noisy_func", &noisy_func,
py::call_guard<scoped_ostream_redirect,
scoped_estream_redirect>());
\endrst */
class scoped_estream_redirect : public scoped_ostream_redirect {
public:
scoped_estream_redirect(
std::ostream &costream = std::cerr,
object pyostream = module::import("sys").attr("stderr"))
: scoped_ostream_redirect(costream,pyostream) {}
};
NAMESPACE_BEGIN(detail)
// Class to redirect output as a context manager. C++ backend.
class OstreamRedirect {
bool do_stdout_;
bool do_stderr_;
std::unique_ptr<scoped_ostream_redirect> redirect_stdout;
std::unique_ptr<scoped_estream_redirect> redirect_stderr;
public:
OstreamRedirect(bool do_stdout = true, bool do_stderr = true)
: do_stdout_(do_stdout), do_stderr_(do_stderr) {}
void enter() {
if (do_stdout_)
redirect_stdout.reset(new scoped_ostream_redirect());
if (do_stderr_)
redirect_stderr.reset(new scoped_estream_redirect());
}
void exit() {
redirect_stdout.reset();
redirect_stderr.reset();
}
};
NAMESPACE_END(detail)
/** \rst
This is a helper function to add a C++ redirect context manager to Python
instead of using a C++ guard. To use it, add the following to your binding code:
.. code-block:: cpp
#include <pybind11/iostream.h>
...
py::add_ostream_redirect(m, "ostream_redirect");
You now have a Python context manager that redirects your output:
.. code-block:: python
with m.ostream_redirect():
m.print_to_cout_function()
This manager can optionally be told which streams to operate on:
.. code-block:: python
with m.ostream_redirect(stdout=true, stderr=true):
m.noisy_function_with_error_printing()
\endrst */
inline class_<detail::OstreamRedirect> add_ostream_redirect(module m, std::string name = "ostream_redirect") {
return class_<detail::OstreamRedirect>(m, name.c_str(), module_local())
.def(init<bool,bool>(), arg("stdout")=true, arg("stderr")=true)
.def("__enter__", &detail::OstreamRedirect::enter)
.def("__exit__", [](detail::OstreamRedirect &self_, args) { self_.exit(); });
}
NAMESPACE_END(PYBIND11_NAMESPACE)

File diff suppressed because it is too large Load Diff

View File

@@ -1,168 +0,0 @@
/*
pybind11/operator.h: Metatemplates for operator overloading
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#if defined(__clang__) && !defined(__INTEL_COMPILER)
# pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type()))
#elif defined(_MSC_VER)
# pragma warning(push)
# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
/// Enumeration with all supported operator types
enum op_id : int {
op_add, op_sub, op_mul, op_div, op_mod, op_divmod, op_pow, op_lshift,
op_rshift, op_and, op_xor, op_or, op_neg, op_pos, op_abs, op_invert,
op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le,
op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift,
op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero,
op_repr, op_truediv, op_itruediv, op_hash
};
enum op_type : int {
op_l, /* base type on left */
op_r, /* base type on right */
op_u /* unary operator */
};
struct self_t { };
static const self_t self = self_t();
/// Type for an unused type slot
struct undefined_t { };
/// Don't warn about an unused variable
inline self_t __self() { return self; }
/// base template of operator implementations
template <op_id, op_type, typename B, typename L, typename R> struct op_impl { };
/// Operator implementation generator
template <op_id id, op_type ot, typename L, typename R> struct op_ {
template <typename Class, typename... Extra> void execute(Class &cl, const Extra&... extra) const {
using Base = typename Class::type;
using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
using op = op_impl<id, ot, Base, L_type, R_type>;
cl.def(op::name(), &op::execute, is_operator(), extra...);
#if PY_MAJOR_VERSION < 3
if (id == op_truediv || id == op_itruediv)
cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__",
&op::execute, is_operator(), extra...);
#endif
}
template <typename Class, typename... Extra> void execute_cast(Class &cl, const Extra&... extra) const {
using Base = typename Class::type;
using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
using op = op_impl<id, ot, Base, L_type, R_type>;
cl.def(op::name(), &op::execute_cast, is_operator(), extra...);
#if PY_MAJOR_VERSION < 3
if (id == op_truediv || id == op_itruediv)
cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__",
&op::execute, is_operator(), extra...);
#endif
}
};
#define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \
template <typename B, typename L, typename R> struct op_impl<op_##id, op_l, B, L, R> { \
static char const* name() { return "__" #id "__"; } \
static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \
static B execute_cast(const L &l, const R &r) { return B(expr); } \
}; \
template <typename B, typename L, typename R> struct op_impl<op_##id, op_r, B, L, R> { \
static char const* name() { return "__" #rid "__"; } \
static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \
static B execute_cast(const R &r, const L &l) { return B(expr); } \
}; \
inline op_<op_##id, op_l, self_t, self_t> op(const self_t &, const self_t &) { \
return op_<op_##id, op_l, self_t, self_t>(); \
} \
template <typename T> op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
return op_<op_##id, op_l, self_t, T>(); \
} \
template <typename T> op_<op_##id, op_r, T, self_t> op(const T &, const self_t &) { \
return op_<op_##id, op_r, T, self_t>(); \
}
#define PYBIND11_INPLACE_OPERATOR(id, op, expr) \
template <typename B, typename L, typename R> struct op_impl<op_##id, op_l, B, L, R> { \
static char const* name() { return "__" #id "__"; } \
static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \
static B execute_cast(L &l, const R &r) { return B(expr); } \
}; \
template <typename T> op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
return op_<op_##id, op_l, self_t, T>(); \
}
#define PYBIND11_UNARY_OPERATOR(id, op, expr) \
template <typename B, typename L> struct op_impl<op_##id, op_u, B, L, undefined_t> { \
static char const* name() { return "__" #id "__"; } \
static auto execute(const L &l) -> decltype(expr) { return expr; } \
static B execute_cast(const L &l) { return B(expr); } \
}; \
inline op_<op_##id, op_u, self_t, undefined_t> op(const self_t &) { \
return op_<op_##id, op_u, self_t, undefined_t>(); \
}
PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r)
PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r)
PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r)
PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r)
PYBIND11_BINARY_OPERATOR(mod, rmod, operator%, l % r)
PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r)
PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r)
PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r)
PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r)
PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r)
PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r)
PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r)
PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r)
PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r)
PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r)
PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r)
//PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r))
PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r)
PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r)
PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r)
PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r)
PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r)
PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r)
PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r)
PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r)
PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r)
PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r)
PYBIND11_UNARY_OPERATOR(neg, operator-, -l)
PYBIND11_UNARY_OPERATOR(pos, operator+, +l)
PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l))
PYBIND11_UNARY_OPERATOR(hash, hash, std::hash<L>()(l))
PYBIND11_UNARY_OPERATOR(invert, operator~, (~l))
PYBIND11_UNARY_OPERATOR(bool, operator!, !!l)
PYBIND11_UNARY_OPERATOR(int, int_, (int) l)
PYBIND11_UNARY_OPERATOR(float, float_, (double) l)
#undef PYBIND11_BINARY_OPERATOR
#undef PYBIND11_INPLACE_OPERATOR
#undef PYBIND11_UNARY_OPERATOR
NAMESPACE_END(detail)
using detail::self;
NAMESPACE_END(PYBIND11_NAMESPACE)
#if defined(_MSC_VER)
# pragma warning(pop)
#endif

View File

@@ -1,65 +0,0 @@
/*
pybind11/options.h: global settings that are configurable at runtime.
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "detail/common.h"
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
class options {
public:
// Default RAII constructor, which leaves settings as they currently are.
options() : previous_state(global_state()) {}
// Class is non-copyable.
options(const options&) = delete;
options& operator=(const options&) = delete;
// Destructor, which restores settings that were in effect before.
~options() {
global_state() = previous_state;
}
// Setter methods (affect the global state):
options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; }
options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; }
options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; }
options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; }
// Getter methods (return the global state):
static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; }
static bool show_function_signatures() { return global_state().show_function_signatures; }
// This type is not meant to be allocated on the heap.
void* operator new(size_t) = delete;
private:
struct state {
bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings.
bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings.
};
static state &global_state() {
static state instance;
return instance;
}
state previous_state;
};
NAMESPACE_END(PYBIND11_NAMESPACE)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,386 +0,0 @@
/*
pybind11/stl.h: Transparent conversion for STL data types
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11.h"
#include <set>
#include <unordered_set>
#include <map>
#include <unordered_map>
#include <iostream>
#include <list>
#include <deque>
#include <valarray>
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif
#ifdef __has_include
// std::optional (but including it in c++14 mode isn't allowed)
# if defined(PYBIND11_CPP17) && __has_include(<optional>)
# include <optional>
# define PYBIND11_HAS_OPTIONAL 1
# endif
// std::experimental::optional (but not allowed in c++11 mode)
# if defined(PYBIND11_CPP14) && (__has_include(<experimental/optional>) && \
!__has_include(<optional>))
# include <experimental/optional>
# define PYBIND11_HAS_EXP_OPTIONAL 1
# endif
// std::variant
# if defined(PYBIND11_CPP17) && __has_include(<variant>)
# include <variant>
# define PYBIND11_HAS_VARIANT 1
# endif
#elif defined(_MSC_VER) && defined(PYBIND11_CPP17)
# include <optional>
# include <variant>
# define PYBIND11_HAS_OPTIONAL 1
# define PYBIND11_HAS_VARIANT 1
#endif
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
/// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for
/// forwarding a container element). Typically used indirect via forwarded_type(), below.
template <typename T, typename U>
using forwarded_type = conditional_t<
std::is_lvalue_reference<T>::value, remove_reference_t<U> &, remove_reference_t<U> &&>;
/// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically
/// used for forwarding a container's elements.
template <typename T, typename U>
forwarded_type<T, U> forward_like(U &&u) {
return std::forward<detail::forwarded_type<T, U>>(std::forward<U>(u));
}
template <typename Type, typename Key> struct set_caster {
using type = Type;
using key_conv = make_caster<Key>;
bool load(handle src, bool convert) {
if (!isinstance<pybind11::set>(src))
return false;
auto s = reinterpret_borrow<pybind11::set>(src);
value.clear();
for (auto entry : s) {
key_conv conv;
if (!conv.load(entry, convert))
return false;
value.insert(cast_op<Key &&>(std::move(conv)));
}
return true;
}
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
if (!std::is_lvalue_reference<T>::value)
policy = return_value_policy_override<Key>::policy(policy);
pybind11::set s;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(key_conv::cast(forward_like<T>(value), policy, parent));
if (!value_ || !s.add(value_))
return handle();
}
return s.release();
}
PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name + _("]"));
};
template <typename Type, typename Key, typename Value> struct map_caster {
using key_conv = make_caster<Key>;
using value_conv = make_caster<Value>;
bool load(handle src, bool convert) {
if (!isinstance<dict>(src))
return false;
auto d = reinterpret_borrow<dict>(src);
value.clear();
for (auto it : d) {
key_conv kconv;
value_conv vconv;
if (!kconv.load(it.first.ptr(), convert) ||
!vconv.load(it.second.ptr(), convert))
return false;
value.emplace(cast_op<Key &&>(std::move(kconv)), cast_op<Value &&>(std::move(vconv)));
}
return true;
}
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
dict d;
return_value_policy policy_key = policy;
return_value_policy policy_value = policy;
if (!std::is_lvalue_reference<T>::value) {
policy_key = return_value_policy_override<Key>::policy(policy_key);
policy_value = return_value_policy_override<Value>::policy(policy_value);
}
for (auto &&kv : src) {
auto key = reinterpret_steal<object>(key_conv::cast(forward_like<T>(kv.first), policy_key, parent));
auto value = reinterpret_steal<object>(value_conv::cast(forward_like<T>(kv.second), policy_value, parent));
if (!key || !value)
return handle();
d[key] = value;
}
return d.release();
}
PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name + _(", ") + value_conv::name + _("]"));
};
template <typename Type, typename Value> struct list_caster {
using value_conv = make_caster<Value>;
bool load(handle src, bool convert) {
if (!isinstance<sequence>(src) || isinstance<str>(src))
return false;
auto s = reinterpret_borrow<sequence>(src);
value.clear();
reserve_maybe(s, &value);
for (auto it : s) {
value_conv conv;
if (!conv.load(it, convert))
return false;
value.push_back(cast_op<Value &&>(std::move(conv)));
}
return true;
}
private:
template <typename T = Type,
enable_if_t<std::is_same<decltype(std::declval<T>().reserve(0)), void>::value, int> = 0>
void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); }
void reserve_maybe(sequence, void *) { }
public:
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
if (!std::is_lvalue_reference<T>::value)
policy = return_value_policy_override<Value>::policy(policy);
list l(src.size());
size_t index = 0;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(value_conv::cast(forward_like<T>(value), policy, parent));
if (!value_)
return handle();
PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference
}
return l.release();
}
PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name + _("]"));
};
template <typename Type, typename Alloc> struct type_caster<std::vector<Type, Alloc>>
: list_caster<std::vector<Type, Alloc>, Type> { };
template <typename Type, typename Alloc> struct type_caster<std::deque<Type, Alloc>>
: list_caster<std::deque<Type, Alloc>, Type> { };
template <typename Type, typename Alloc> struct type_caster<std::list<Type, Alloc>>
: list_caster<std::list<Type, Alloc>, Type> { };
template <typename ArrayType, typename Value, bool Resizable, size_t Size = 0> struct array_caster {
using value_conv = make_caster<Value>;
private:
template <bool R = Resizable>
bool require_size(enable_if_t<R, size_t> size) {
if (value.size() != size)
value.resize(size);
return true;
}
template <bool R = Resizable>
bool require_size(enable_if_t<!R, size_t> size) {
return size == Size;
}
public:
bool load(handle src, bool convert) {
if (!isinstance<sequence>(src))
return false;
auto l = reinterpret_borrow<sequence>(src);
if (!require_size(l.size()))
return false;
size_t ctr = 0;
for (auto it : l) {
value_conv conv;
if (!conv.load(it, convert))
return false;
value[ctr++] = cast_op<Value &&>(std::move(conv));
}
return true;
}
template <typename T>
static handle cast(T &&src, return_value_policy policy, handle parent) {
list l(src.size());
size_t index = 0;
for (auto &&value : src) {
auto value_ = reinterpret_steal<object>(value_conv::cast(forward_like<T>(value), policy, parent));
if (!value_)
return handle();
PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference
}
return l.release();
}
PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name + _<Resizable>(_(""), _("[") + _<Size>() + _("]")) + _("]"));
};
template <typename Type, size_t Size> struct type_caster<std::array<Type, Size>>
: array_caster<std::array<Type, Size>, Type, false, Size> { };
template <typename Type> struct type_caster<std::valarray<Type>>
: array_caster<std::valarray<Type>, Type, true> { };
template <typename Key, typename Compare, typename Alloc> struct type_caster<std::set<Key, Compare, Alloc>>
: set_caster<std::set<Key, Compare, Alloc>, Key> { };
template <typename Key, typename Hash, typename Equal, typename Alloc> struct type_caster<std::unordered_set<Key, Hash, Equal, Alloc>>
: set_caster<std::unordered_set<Key, Hash, Equal, Alloc>, Key> { };
template <typename Key, typename Value, typename Compare, typename Alloc> struct type_caster<std::map<Key, Value, Compare, Alloc>>
: map_caster<std::map<Key, Value, Compare, Alloc>, Key, Value> { };
template <typename Key, typename Value, typename Hash, typename Equal, typename Alloc> struct type_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>>
: map_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>, Key, Value> { };
// This type caster is intended to be used for std::optional and std::experimental::optional
template<typename T> struct optional_caster {
using value_conv = make_caster<typename T::value_type>;
template <typename T_>
static handle cast(T_ &&src, return_value_policy policy, handle parent) {
if (!src)
return none().inc_ref();
policy = return_value_policy_override<typename T::value_type>::policy(policy);
return value_conv::cast(*std::forward<T_>(src), policy, parent);
}
bool load(handle src, bool convert) {
if (!src) {
return false;
} else if (src.is_none()) {
return true; // default-constructed value is already empty
}
value_conv inner_caster;
if (!inner_caster.load(src, convert))
return false;
value.emplace(cast_op<typename T::value_type &&>(std::move(inner_caster)));
return true;
}
PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name + _("]"));
};
#if PYBIND11_HAS_OPTIONAL
template<typename T> struct type_caster<std::optional<T>>
: public optional_caster<std::optional<T>> {};
template<> struct type_caster<std::nullopt_t>
: public void_caster<std::nullopt_t> {};
#endif
#if PYBIND11_HAS_EXP_OPTIONAL
template<typename T> struct type_caster<std::experimental::optional<T>>
: public optional_caster<std::experimental::optional<T>> {};
template<> struct type_caster<std::experimental::nullopt_t>
: public void_caster<std::experimental::nullopt_t> {};
#endif
/// Visit a variant and cast any found type to Python
struct variant_caster_visitor {
return_value_policy policy;
handle parent;
using result_type = handle; // required by boost::variant in C++11
template <typename T>
result_type operator()(T &&src) const {
return make_caster<T>::cast(std::forward<T>(src), policy, parent);
}
};
/// Helper class which abstracts away variant's `visit` function. `std::variant` and similar
/// `namespace::variant` types which provide a `namespace::visit()` function are handled here
/// automatically using argument-dependent lookup. Users can provide specializations for other
/// variant-like classes, e.g. `boost::variant` and `boost::apply_visitor`.
template <template<typename...> class Variant>
struct visit_helper {
template <typename... Args>
static auto call(Args &&...args) -> decltype(visit(std::forward<Args>(args)...)) {
return visit(std::forward<Args>(args)...);
}
};
/// Generic variant caster
template <typename Variant> struct variant_caster;
template <template<typename...> class V, typename... Ts>
struct variant_caster<V<Ts...>> {
static_assert(sizeof...(Ts) > 0, "Variant must consist of at least one alternative.");
template <typename U, typename... Us>
bool load_alternative(handle src, bool convert, type_list<U, Us...>) {
auto caster = make_caster<U>();
if (caster.load(src, convert)) {
value = cast_op<U>(caster);
return true;
}
return load_alternative(src, convert, type_list<Us...>{});
}
bool load_alternative(handle, bool, type_list<>) { return false; }
bool load(handle src, bool convert) {
// Do a first pass without conversions to improve constructor resolution.
// E.g. `py::int_(1).cast<variant<double, int>>()` needs to fill the `int`
// slot of the variant. Without two-pass loading `double` would be filled
// because it appears first and a conversion is possible.
if (convert && load_alternative(src, false, type_list<Ts...>{}))
return true;
return load_alternative(src, convert, type_list<Ts...>{});
}
template <typename Variant>
static handle cast(Variant &&src, return_value_policy policy, handle parent) {
return visit_helper<V>::call(variant_caster_visitor{policy, parent},
std::forward<Variant>(src));
}
using Type = V<Ts...>;
PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster<Ts>::name...) + _("]"));
};
#if PYBIND11_HAS_VARIANT
template <typename... Ts>
struct type_caster<std::variant<Ts...>> : variant_caster<std::variant<Ts...>> { };
#endif
NAMESPACE_END(detail)
inline std::ostream &operator<<(std::ostream &os, const handle &obj) {
os << (std::string) str(obj);
return os;
}
NAMESPACE_END(PYBIND11_NAMESPACE)
#if defined(_MSC_VER)
#pragma warning(pop)
#endif

View File

@@ -1,630 +0,0 @@
/*
pybind11/std_bind.h: Binding generators for STL data types
Copyright (c) 2016 Sergey Lyskov and Wenzel Jakob
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "detail/common.h"
#include "operators.h"
#include <algorithm>
#include <sstream>
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
NAMESPACE_BEGIN(detail)
/* SFINAE helper class used by 'is_comparable */
template <typename T> struct container_traits {
template <typename T2> static std::true_type test_comparable(decltype(std::declval<const T2 &>() == std::declval<const T2 &>())*);
template <typename T2> static std::false_type test_comparable(...);
template <typename T2> static std::true_type test_value(typename T2::value_type *);
template <typename T2> static std::false_type test_value(...);
template <typename T2> static std::true_type test_pair(typename T2::first_type *, typename T2::second_type *);
template <typename T2> static std::false_type test_pair(...);
static constexpr const bool is_comparable = std::is_same<std::true_type, decltype(test_comparable<T>(nullptr))>::value;
static constexpr const bool is_pair = std::is_same<std::true_type, decltype(test_pair<T>(nullptr, nullptr))>::value;
static constexpr const bool is_vector = std::is_same<std::true_type, decltype(test_value<T>(nullptr))>::value;
static constexpr const bool is_element = !is_pair && !is_vector;
};
/* Default: is_comparable -> std::false_type */
template <typename T, typename SFINAE = void>
struct is_comparable : std::false_type { };
/* For non-map data structures, check whether operator== can be instantiated */
template <typename T>
struct is_comparable<
T, enable_if_t<container_traits<T>::is_element &&
container_traits<T>::is_comparable>>
: std::true_type { };
/* For a vector/map data structure, recursively check the value type (which is std::pair for maps) */
template <typename T>
struct is_comparable<T, enable_if_t<container_traits<T>::is_vector>> {
static constexpr const bool value =
is_comparable<typename T::value_type>::value;
};
/* For pairs, recursively check the two data types */
template <typename T>
struct is_comparable<T, enable_if_t<container_traits<T>::is_pair>> {
static constexpr const bool value =
is_comparable<typename T::first_type>::value &&
is_comparable<typename T::second_type>::value;
};
/* Fallback functions */
template <typename, typename, typename... Args> void vector_if_copy_constructible(const Args &...) { }
template <typename, typename, typename... Args> void vector_if_equal_operator(const Args &...) { }
template <typename, typename, typename... Args> void vector_if_insertion_operator(const Args &...) { }
template <typename, typename, typename... Args> void vector_modifiers(const Args &...) { }
template<typename Vector, typename Class_>
void vector_if_copy_constructible(enable_if_t<is_copy_constructible<Vector>::value, Class_> &cl) {
cl.def(init<const Vector &>(), "Copy constructor");
}
template<typename Vector, typename Class_>
void vector_if_equal_operator(enable_if_t<is_comparable<Vector>::value, Class_> &cl) {
using T = typename Vector::value_type;
cl.def(self == self);
cl.def(self != self);
cl.def("count",
[](const Vector &v, const T &x) {
return std::count(v.begin(), v.end(), x);
},
arg("x"),
"Return the number of times ``x`` appears in the list"
);
cl.def("remove", [](Vector &v, const T &x) {
auto p = std::find(v.begin(), v.end(), x);
if (p != v.end())
v.erase(p);
else
throw value_error();
},
arg("x"),
"Remove the first item from the list whose value is x. "
"It is an error if there is no such item."
);
cl.def("__contains__",
[](const Vector &v, const T &x) {
return std::find(v.begin(), v.end(), x) != v.end();
},
arg("x"),
"Return true the container contains ``x``"
);
}
// Vector modifiers -- requires a copyable vector_type:
// (Technically, some of these (pop and __delitem__) don't actually require copyability, but it seems
// silly to allow deletion but not insertion, so include them here too.)
template <typename Vector, typename Class_>
void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_type>::value, Class_> &cl) {
using T = typename Vector::value_type;
using SizeType = typename Vector::size_type;
using DiffType = typename Vector::difference_type;
cl.def("append",
[](Vector &v, const T &value) { v.push_back(value); },
arg("x"),
"Add an item to the end of the list");
cl.def(init([](iterable it) {
auto v = std::unique_ptr<Vector>(new Vector());
v->reserve(len_hint(it));
for (handle h : it)
v->push_back(h.cast<T>());
return v.release();
}));
cl.def("extend",
[](Vector &v, const Vector &src) {
v.insert(v.end(), src.begin(), src.end());
},
arg("L"),
"Extend the list by appending all the items in the given list"
);
cl.def("extend",
[](Vector &v, iterable it) {
const size_t old_size = v.size();
v.reserve(old_size + len_hint(it));
try {
for (handle h : it) {
v.push_back(h.cast<T>());
}
} catch (const cast_error &) {
v.erase(v.begin() + static_cast<typename Vector::difference_type>(old_size), v.end());
try {
v.shrink_to_fit();
} catch (const std::exception &) {
// Do nothing
}
throw;
}
},
arg("L"),
"Extend the list by appending all the items in the given list"
);
cl.def("insert",
[](Vector &v, SizeType i, const T &x) {
if (i > v.size())
throw index_error();
v.insert(v.begin() + (DiffType) i, x);
},
arg("i") , arg("x"),
"Insert an item at a given position."
);
cl.def("pop",
[](Vector &v) {
if (v.empty())
throw index_error();
T t = v.back();
v.pop_back();
return t;
},
"Remove and return the last item"
);
cl.def("pop",
[](Vector &v, SizeType i) {
if (i >= v.size())
throw index_error();
T t = v[i];
v.erase(v.begin() + (DiffType) i);
return t;
},
arg("i"),
"Remove and return the item at index ``i``"
);
cl.def("__setitem__",
[](Vector &v, SizeType i, const T &t) {
if (i >= v.size())
throw index_error();
v[i] = t;
}
);
/// Slicing protocol
cl.def("__getitem__",
[](const Vector &v, slice slice) -> Vector * {
size_t start, stop, step, slicelength;
if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
throw error_already_set();
Vector *seq = new Vector();
seq->reserve((size_t) slicelength);
for (size_t i=0; i<slicelength; ++i) {
seq->push_back(v[start]);
start += step;
}
return seq;
},
arg("s"),
"Retrieve list elements using a slice object"
);
cl.def("__setitem__",
[](Vector &v, slice slice, const Vector &value) {
size_t start, stop, step, slicelength;
if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
throw error_already_set();
if (slicelength != value.size())
throw std::runtime_error("Left and right hand size of slice assignment have different sizes!");
for (size_t i=0; i<slicelength; ++i) {
v[start] = value[i];
start += step;
}
},
"Assign list elements using a slice object"
);
cl.def("__delitem__",
[](Vector &v, SizeType i) {
if (i >= v.size())
throw index_error();
v.erase(v.begin() + DiffType(i));
},
"Delete the list elements at index ``i``"
);
cl.def("__delitem__",
[](Vector &v, slice slice) {
size_t start, stop, step, slicelength;
if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
throw error_already_set();
if (step == 1 && false) {
v.erase(v.begin() + (DiffType) start, v.begin() + DiffType(start + slicelength));
} else {
for (size_t i = 0; i < slicelength; ++i) {
v.erase(v.begin() + DiffType(start));
start += step - 1;
}
}
},
"Delete list elements using a slice object"
);
}
// If the type has an operator[] that doesn't return a reference (most notably std::vector<bool>),
// we have to access by copying; otherwise we return by reference.
template <typename Vector> using vector_needs_copy = negation<
std::is_same<decltype(std::declval<Vector>()[typename Vector::size_type()]), typename Vector::value_type &>>;
// The usual case: access and iterate by reference
template <typename Vector, typename Class_>
void vector_accessor(enable_if_t<!vector_needs_copy<Vector>::value, Class_> &cl) {
using T = typename Vector::value_type;
using SizeType = typename Vector::size_type;
using ItType = typename Vector::iterator;
cl.def("__getitem__",
[](Vector &v, SizeType i) -> T & {
if (i >= v.size())
throw index_error();
return v[i];
},
return_value_policy::reference_internal // ref + keepalive
);
cl.def("__iter__",
[](Vector &v) {
return make_iterator<
return_value_policy::reference_internal, ItType, ItType, T&>(
v.begin(), v.end());
},
keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
);
}
// The case for special objects, like std::vector<bool>, that have to be returned-by-copy:
template <typename Vector, typename Class_>
void vector_accessor(enable_if_t<vector_needs_copy<Vector>::value, Class_> &cl) {
using T = typename Vector::value_type;
using SizeType = typename Vector::size_type;
using ItType = typename Vector::iterator;
cl.def("__getitem__",
[](const Vector &v, SizeType i) -> T {
if (i >= v.size())
throw index_error();
return v[i];
}
);
cl.def("__iter__",
[](Vector &v) {
return make_iterator<
return_value_policy::copy, ItType, ItType, T>(
v.begin(), v.end());
},
keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
);
}
template <typename Vector, typename Class_> auto vector_if_insertion_operator(Class_ &cl, std::string const &name)
-> decltype(std::declval<std::ostream&>() << std::declval<typename Vector::value_type>(), void()) {
using size_type = typename Vector::size_type;
cl.def("__repr__",
[name](Vector &v) {
std::ostringstream s;
s << name << '[';
for (size_type i=0; i < v.size(); ++i) {
s << v[i];
if (i != v.size() - 1)
s << ", ";
}
s << ']';
return s.str();
},
"Return the canonical string representation of this list."
);
}
// Provide the buffer interface for vectors if we have data() and we have a format for it
// GCC seems to have "void std::vector<bool>::data()" - doing SFINAE on the existence of data() is insufficient, we need to check it returns an appropriate pointer
template <typename Vector, typename = void>
struct vector_has_data_and_format : std::false_type {};
template <typename Vector>
struct vector_has_data_and_format<Vector, enable_if_t<std::is_same<decltype(format_descriptor<typename Vector::value_type>::format(), std::declval<Vector>().data()), typename Vector::value_type*>::value>> : std::true_type {};
// Add the buffer interface to a vector
template <typename Vector, typename Class_, typename... Args>
enable_if_t<detail::any_of<std::is_same<Args, buffer_protocol>...>::value>
vector_buffer(Class_& cl) {
using T = typename Vector::value_type;
static_assert(vector_has_data_and_format<Vector>::value, "There is not an appropriate format descriptor for this vector");
// numpy.h declares this for arbitrary types, but it may raise an exception and crash hard at runtime if PYBIND11_NUMPY_DTYPE hasn't been called, so check here
format_descriptor<T>::format();
cl.def_buffer([](Vector& v) -> buffer_info {
return buffer_info(v.data(), static_cast<ssize_t>(sizeof(T)), format_descriptor<T>::format(), 1, {v.size()}, {sizeof(T)});
});
cl.def(init([](buffer buf) {
auto info = buf.request();
if (info.ndim != 1 || info.strides[0] % static_cast<ssize_t>(sizeof(T)))
throw type_error("Only valid 1D buffers can be copied to a vector");
if (!detail::compare_buffer_info<T>::compare(info) || (ssize_t) sizeof(T) != info.itemsize)
throw type_error("Format mismatch (Python: " + info.format + " C++: " + format_descriptor<T>::format() + ")");
auto vec = std::unique_ptr<Vector>(new Vector());
vec->reserve((size_t) info.shape[0]);
T *p = static_cast<T*>(info.ptr);
ssize_t step = info.strides[0] / static_cast<ssize_t>(sizeof(T));
T *end = p + info.shape[0] * step;
for (; p != end; p += step)
vec->push_back(*p);
return vec.release();
}));
return;
}
template <typename Vector, typename Class_, typename... Args>
enable_if_t<!detail::any_of<std::is_same<Args, buffer_protocol>...>::value> vector_buffer(Class_&) {}
NAMESPACE_END(detail)
//
// std::vector
//
template <typename Vector, typename holder_type = std::unique_ptr<Vector>, typename... Args>
class_<Vector, holder_type> bind_vector(handle scope, std::string const &name, Args&&... args) {
using Class_ = class_<Vector, holder_type>;
// If the value_type is unregistered (e.g. a converting type) or is itself registered
// module-local then make the vector binding module-local as well:
using vtype = typename Vector::value_type;
auto vtype_info = detail::get_type_info(typeid(vtype));
bool local = !vtype_info || vtype_info->module_local;
Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
// Declare the buffer interface if a buffer_protocol() is passed in
detail::vector_buffer<Vector, Class_, Args...>(cl);
cl.def(init<>());
// Register copy constructor (if possible)
detail::vector_if_copy_constructible<Vector, Class_>(cl);
// Register comparison-related operators and functions (if possible)
detail::vector_if_equal_operator<Vector, Class_>(cl);
// Register stream insertion operator (if possible)
detail::vector_if_insertion_operator<Vector, Class_>(cl, name);
// Modifiers require copyable vector value type
detail::vector_modifiers<Vector, Class_>(cl);
// Accessor and iterator; return by value if copyable, otherwise we return by ref + keep-alive
detail::vector_accessor<Vector, Class_>(cl);
cl.def("__bool__",
[](const Vector &v) -> bool {
return !v.empty();
},
"Check whether the list is nonempty"
);
cl.def("__len__", &Vector::size);
#if 0
// C++ style functions deprecated, leaving it here as an example
cl.def(init<size_type>());
cl.def("resize",
(void (Vector::*) (size_type count)) & Vector::resize,
"changes the number of elements stored");
cl.def("erase",
[](Vector &v, SizeType i) {
if (i >= v.size())
throw index_error();
v.erase(v.begin() + i);
}, "erases element at index ``i``");
cl.def("empty", &Vector::empty, "checks whether the container is empty");
cl.def("size", &Vector::size, "returns the number of elements");
cl.def("push_back", (void (Vector::*)(const T&)) &Vector::push_back, "adds an element to the end");
cl.def("pop_back", &Vector::pop_back, "removes the last element");
cl.def("max_size", &Vector::max_size, "returns the maximum possible number of elements");
cl.def("reserve", &Vector::reserve, "reserves storage");
cl.def("capacity", &Vector::capacity, "returns the number of elements that can be held in currently allocated storage");
cl.def("shrink_to_fit", &Vector::shrink_to_fit, "reduces memory usage by freeing unused memory");
cl.def("clear", &Vector::clear, "clears the contents");
cl.def("swap", &Vector::swap, "swaps the contents");
cl.def("front", [](Vector &v) {
if (v.size()) return v.front();
else throw index_error();
}, "access the first element");
cl.def("back", [](Vector &v) {
if (v.size()) return v.back();
else throw index_error();
}, "access the last element ");
#endif
return cl;
}
//
// std::map, std::unordered_map
//
NAMESPACE_BEGIN(detail)
/* Fallback functions */
template <typename, typename, typename... Args> void map_if_insertion_operator(const Args &...) { }
template <typename, typename, typename... Args> void map_assignment(const Args &...) { }
// Map assignment when copy-assignable: just copy the value
template <typename Map, typename Class_>
void map_assignment(enable_if_t<std::is_copy_assignable<typename Map::mapped_type>::value, Class_> &cl) {
using KeyType = typename Map::key_type;
using MappedType = typename Map::mapped_type;
cl.def("__setitem__",
[](Map &m, const KeyType &k, const MappedType &v) {
auto it = m.find(k);
if (it != m.end()) it->second = v;
else m.emplace(k, v);
}
);
}
// Not copy-assignable, but still copy-constructible: we can update the value by erasing and reinserting
template<typename Map, typename Class_>
void map_assignment(enable_if_t<
!std::is_copy_assignable<typename Map::mapped_type>::value &&
is_copy_constructible<typename Map::mapped_type>::value,
Class_> &cl) {
using KeyType = typename Map::key_type;
using MappedType = typename Map::mapped_type;
cl.def("__setitem__",
[](Map &m, const KeyType &k, const MappedType &v) {
// We can't use m[k] = v; because value type might not be default constructable
auto r = m.emplace(k, v);
if (!r.second) {
// value type is not copy assignable so the only way to insert it is to erase it first...
m.erase(r.first);
m.emplace(k, v);
}
}
);
}
template <typename Map, typename Class_> auto map_if_insertion_operator(Class_ &cl, std::string const &name)
-> decltype(std::declval<std::ostream&>() << std::declval<typename Map::key_type>() << std::declval<typename Map::mapped_type>(), void()) {
cl.def("__repr__",
[name](Map &m) {
std::ostringstream s;
s << name << '{';
bool f = false;
for (auto const &kv : m) {
if (f)
s << ", ";
s << kv.first << ": " << kv.second;
f = true;
}
s << '}';
return s.str();
},
"Return the canonical string representation of this map."
);
}
NAMESPACE_END(detail)
template <typename Map, typename holder_type = std::unique_ptr<Map>, typename... Args>
class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args&&... args) {
using KeyType = typename Map::key_type;
using MappedType = typename Map::mapped_type;
using Class_ = class_<Map, holder_type>;
// If either type is a non-module-local bound type then make the map binding non-local as well;
// otherwise (e.g. both types are either module-local or converting) the map will be
// module-local.
auto tinfo = detail::get_type_info(typeid(MappedType));
bool local = !tinfo || tinfo->module_local;
if (local) {
tinfo = detail::get_type_info(typeid(KeyType));
local = !tinfo || tinfo->module_local;
}
Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
cl.def(init<>());
// Register stream insertion operator (if possible)
detail::map_if_insertion_operator<Map, Class_>(cl, name);
cl.def("__bool__",
[](const Map &m) -> bool { return !m.empty(); },
"Check whether the map is nonempty"
);
cl.def("__iter__",
[](Map &m) { return make_key_iterator(m.begin(), m.end()); },
keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
);
cl.def("items",
[](Map &m) { return make_iterator(m.begin(), m.end()); },
keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
);
cl.def("__getitem__",
[](Map &m, const KeyType &k) -> MappedType & {
auto it = m.find(k);
if (it == m.end())
throw key_error();
return it->second;
},
return_value_policy::reference_internal // ref + keepalive
);
cl.def("__contains__",
[](Map &m, const KeyType &k) -> bool {
auto it = m.find(k);
if (it == m.end())
return false;
return true;
}
);
// Assignment provided only if the type is copyable
detail::map_assignment<Map, Class_>(cl);
cl.def("__delitem__",
[](Map &m, const KeyType &k) {
auto it = m.find(k);
if (it == m.end())
throw key_error();
m.erase(it);
}
);
cl.def("__len__", &Map::size);
return cl;
}
NAMESPACE_END(PYBIND11_NAMESPACE)

View File

@@ -1,5 +1,6 @@
#include "triton/codegen/pass.h"
#include "triton/codegen/target.h"
#include "triton/codegen/extern_lib.h"
#include "triton/driver/error.h"
#include "triton/driver/llvm.h"
#include "triton/ir/builder.h"
@@ -19,7 +20,6 @@
#include <stdexcept>
#include <string>
#include "llvm/IR/Module.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
namespace py = pybind11;
@@ -83,8 +83,8 @@ void cu_enqueue(uint64_t stream, uint64_t kernel,
CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
CU_LAUNCH_PARAM_END
};
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
block_0, block_1, block_2,
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
block_0, block_1, block_2,
shared_mem, (CUstream)stream, nullptr, config);
}
@@ -97,8 +97,8 @@ void hip_enqueue(uint64_t stream, uint64_t kernel,
HIP_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
HIP_LAUNCH_PARAM_END
};
drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2,
block_0, block_1, block_2,
drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2,
block_0, block_1, block_2,
shared_mem, (hipStream_t)stream, nullptr, config);
}
@@ -140,7 +140,7 @@ size_t get_pointer_range_size(uint64_t addr){
// Launch
void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
int num_warps, int num_stages) {
int num_warps, int num_stages, py::dict& extern_libs) {
size_t len = PyList_Size(args.ptr());
params.reserve(8*len); // 8 max bytes by argument
char* params_ptr = &params[0];
@@ -226,18 +226,29 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
// copy param
std::memcpy(params_ptr, &value, 8);
params_ptr += 8;
// udpate cache key
// update cache key
cache_key += dtype_cache_key_part(arg.attr("dtype"));
cache_key += "*";
cache_key += "[multipleof(";
size_t range_size = get_pointer_range_size(value);
size_t range_size;
try {
range_size = get_pointer_range_size(value);
} catch (...) {
throw std::runtime_error("argument tensor #" + std::to_string(i) + " is not on cuda! " + std::string(py::str(arg)));
}
cache_key += std::to_string(std::min(pow2_divisor(value), pow2_divisor(range_size)));
cache_key += ")]";
continue;
}
// argument is `constexpr`
if(py::hasattr(arg, "value")){
if (py::hasattr(arg, "value")) {
py::object value = arg.attr("value");
// check if value is a callable object using PyCallable_Check
if (PyCallable_Check(value.ptr())) {
throw std::runtime_error(
"constant argument cannot be a callable object: " +
std::string(py::str(arg)));
}
py::object name = arg_names[i];
constants[name] = value;
py::object repr = py::repr(value);
@@ -256,6 +267,11 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
throw std::runtime_error(err_msg);
}
params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
for (auto item : extern_libs) {
cache_key += "-" + item.first.cast<std::string>();
cache_key += "_" + item.second.cast<std::string>();
}
}
//
@@ -286,9 +302,9 @@ void init_triton_runtime(py::module &&m) {
// cache key
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
py::function add_to_cache, py::object grid){
m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
py::dict extern_libs, py::function add_to_cache, py::object grid){
// parse arguments to compute cache key, compile-time constants and packed kernel arguments
long _num_warps = PyLong_AsLong(num_warps.ptr());
long _num_stages = PyLong_AsLong(num_stages.ptr());
@@ -296,13 +312,14 @@ void init_triton_runtime(py::module &&m) {
std::string params;
size_t params_size;
py::dict constants;
parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages);
parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params,
params_size, constants, _num_warps, _num_stages, extern_libs);
// get cached binary
py::str key(cache_key);
py::bool_ noop = false;
if(!bin_cache.contains(key)) {
noop = add_to_cache(key, args, device, num_warps, num_stages);
noop = add_to_cache(key, args, device, num_warps, num_stages, extern_libs);
}
if (noop)
return (py::object)py::none();
@@ -334,8 +351,8 @@ void init_triton_runtime(py::module &&m) {
// release the gil in case the enqueue blocks
// cuda will block if too many ops are enqueued
py::gil_scoped_release allow_threads;
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
_num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
nullptr, config);
}
return bin;
@@ -355,7 +372,7 @@ void init_triton_runtime(py::module &&m) {
m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
if (backend == HOST)
return 0;
if(backend == CUDA)
if(backend == CUDA)
return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(device);
if(backend == ROCM)
return hipGetInfo<hipDeviceAttributeMaxSharedMemoryPerBlock>(device);
@@ -405,7 +422,7 @@ void init_triton_runtime(py::module &&m) {
hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
});
}
/*****************************************************************************/
@@ -413,133 +430,110 @@ void init_triton_runtime(py::module &&m) {
/*****************************************************************************/
typedef std::map<std::string, py::object> asm_map_t;
// ---------------------------------------
// Load provided assembly code into driver
// ---------------------------------------
// CUDA
std::tuple<uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
// load assembly
std::string assembly;
if(asm_map.find("cubin") != asm_map.end())
assembly = py::cast<std::string>(asm_map["cubin"]);
else
assembly = py::cast<std::string>(asm_map["ptx"]);
// create driver handles
CUfunction fun;
CUmodule mod;
drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
// set dynamic shared memory if necessary
int shared_optin;
drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
if(n_shared_bytes > 49152 && shared_optin > 49152){
drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
int shared_total, shared_static;
int n_spills, n_reg;
drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev);
drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
drv::dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
}
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
}
// ROCM
std::tuple<uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
py::bytes _assembly = asm_map["hsaco"];
std::string assembly = py::cast<std::string>(_assembly);
// HSA-CO -> hipModule
hipModule_t mod = drv::amdgpu_to_hipmodule(assembly);
// Handle to the kernel
hipFunction_t fun;
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
// record asm
return std::make_tuple((uint64_t)mod, (uint64_t)fun);
}
// ---------------------------------------
// ---------------------------------------
// Compile Triton-IR to assembly
// ---------------------------------------
// CUDA
std::tuple<std::string, asm_map_t, int> cu_compile_ttir(const std::string& name, ir::module &ir,
uint64_t device, int num_warps, int num_stages,
asm_map_t &asm_map){
int n_shared_bytes;
py::gil_scoped_release allow_threads;
llvm::LLVMContext ctx;
// device properties
CUdevice dev = (CUdevice)device;
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
size_t cc = major*10 + minor;
int version;
std::string ptxas_path = drv::path_to_ptxas(version);
// Triton-IR -> NVPTX LLVM-IR
triton::codegen::nvidia_cu_target target(cc);
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes);
std::string tmp;
llvm::raw_string_ostream llir(tmp);
llir << *llvm;
llir.flush();
asm_map["llir"] = py::cast(tmp);
// LLVM-IR -> PTX
std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
asm_map["ptx"] = py::cast(ptx);
// PTX -> Binary
std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
if(!cubin.empty()){
py::bytes bytes(cubin);
asm_map["cubin"] = bytes;
}
return std::make_tuple(name, asm_map, n_shared_bytes);
}
// HIP
std::tuple<std::string, asm_map_t, int> hip_compile_ttir(const std::string& name, ir::module &ir,
uint64_t device, int num_warps, int num_stages,
asm_map_t &asm_map){
llvm::LLVMContext ctx;
// Triton-IR -> NVPTX LLVM-IR
triton::codegen::amd_cl_target target;
int n_shared_bytes;
auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes);
std::string tmp;
llvm::raw_string_ostream llir(tmp);
llir << *llvm;
llir.flush();
asm_map["llir"] = py::cast(tmp);
// LLVM-IR -> HSA-CO
std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908");
asm_map["hsaco"] = py::cast(path);
return std::make_tuple(name, asm_map, n_shared_bytes);
}
// ---------------------------------------
void init_triton_codegen(py::module &&m) {
m.def(
"compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages) {
std::string name = ir.get_function_list()[0]->get_name();
// record asm as we generate
asm_map_t asm_map;
std::ostringstream ttir;
ir.print(ttir);
asm_map["ttir"] = py::cast(ttir.str());
llvm::LLVMContext ctx;
if(backend == CUDA)
return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
if(backend == ROCM)
return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
}, py::return_value_policy::take_ownership);
m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
py::gil_scoped_release allow_threads;
if(backend == CUDA)
return cu_load_binary(name, asm_map, n_shared_bytes, dev);
if(backend == ROCM)
return hip_load_binary(name, asm_map, n_shared_bytes, dev);
}, py::return_value_policy::take_ownership);
m.def("compile_ttir",
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
std::ostringstream ttir;
int n_shared_bytes;
std::string tmp;
std::string ptx;
std::string cubin;
std::string name;
{ // Scope where the GIL is released
py::gil_scoped_release allow_threads;
name = ir.get_function_list()[0]->get_name();
ir.print(ttir);
llvm::LLVMContext ctx;
// construct extern lib map
triton::codegen::ExternLibMap extern_lib_map;
for (auto item : extern_libs) {
auto name = item.first.cast<std::string>();
auto path = item.second.cast<std::string>();
extern_lib_map.emplace(
name, triton::codegen::create_extern_lib(name, path));
}
// device properties
if (cc == 0) {
CUdevice dev = (CUdevice)device;
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
cc = major*10 + minor;
}
int version;
std::string ptxas_path = drv::path_to_ptxas(version);
// Triton-IR -> NVPTX LLVM-IR
triton::codegen::nvidia_cu_target target(cc);
auto llvm = triton::codegen::add_passes_to_emit_bin(
ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
llvm::raw_string_ostream llir(tmp);
llir << *llvm;
llir.flush();
// LLVM-IR -> PTX
ptx = drv::llir_to_ptx(llvm.get(), cc, version);
// PTX -> Binary
cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
}
asm_map_t asm_map;
asm_map["ttir"] = py::cast(ttir.str());
asm_map["llir"] = py::cast(tmp);
asm_map["ptx"] = py::cast(ptx);
if(!cubin.empty()){
py::bytes bytes(cubin);
asm_map["cubin"] = bytes;
}
return std::make_tuple(name, asm_map, n_shared_bytes);
},
py::return_value_policy::take_ownership);
// ---------------------------------------
// Load provided assembly code into driver
// ---------------------------------------
m.def("load_binary", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
py::gil_scoped_release allow_threads;
// create driver handles
CUfunction fun;
CUmodule mod;
drv::dispatch::cuModuleLoadData(&mod, data.c_str());
drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
// get allocated registers and spilled registers from the function
int n_regs = 0;
int n_spills = 0;
drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
n_spills /= 4;
// set dynamic shared memory if necessary
int shared_optin;
drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device);
if(n_shared_bytes > 49152 && shared_optin > 49152){
drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
int shared_total, shared_static;
drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device);
drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
}
return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
},
py::return_value_policy::take_ownership
);
struct InstanceDescriptor
{
std::unordered_set<int> divisibleBy16;
std::unordered_set<int> equalTo1;
};
py::class_<InstanceDescriptor>(m, "instance_descriptor")
.def(py::init<>())
.def(py::init<std::unordered_set<int>, std::unordered_set<int>>())
.def_readonly("divisible_by_16", &InstanceDescriptor::divisibleBy16)
.def_readonly("equal_to_1", &InstanceDescriptor::equalTo1);
}
@@ -556,22 +550,30 @@ void init_triton_ir(py::module &&m) {
.value("CA", ir::load_inst::CA)
.value("CG", ir::load_inst::CG)
.export_values();
py::enum_<ir::load_inst::EVICTION_POLICY>(m, "EVICTION_POLICY")
.value("NORMAL", ir::load_inst::NORMAL)
.value("EVICT_FIRST", ir::load_inst::EVICT_FIRST)
.value("EVICT_LAST", ir::load_inst::EVICT_LAST)
.export_values();
py::enum_<ir::reduce_inst::op_t>(m, "REDUCE_OP")
.value("ADD", ir::reduce_inst::ADD)
.value("FADD", ir::reduce_inst::FADD)
.value("MIN", ir::reduce_inst::MIN)
.value("MAX", ir::reduce_inst::MAX)
.value("UMIN", ir::reduce_inst::UMIN)
.value("UMAX", ir::reduce_inst::UMAX)
.value("ARGMIN", ir::reduce_inst::ARGMIN)
.value("ARGMAX", ir::reduce_inst::ARGMAX)
.value("ARGUMIN", ir::reduce_inst::ARGUMIN)
.value("ARGUMAX", ir::reduce_inst::ARGUMAX)
.value("FMIN", ir::reduce_inst::FMIN)
.value("FMAX", ir::reduce_inst::FMAX)
.value("ARGFMIN", ir::reduce_inst::ARGFMIN)
.value("ARGFMAX", ir::reduce_inst::ARGFMAX)
.value("XOR", ir::reduce_inst::XOR);
py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")
.value("ADD", ir::atomic_rmw_op_t::Add)
.value("FADD", ir::atomic_rmw_op_t::FAdd)
@@ -588,13 +590,13 @@ void init_triton_ir(py::module &&m) {
.def(py::init<>());
py::class_<ir::value>(m, "value")
.def("multiple_of", [](ir::value *self, int val) {
.def("multiple_of", [](ir::value *self, std::vector<unsigned> val) {
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
instr->set_metadata(ir::metadata::multiple_of, val);
} else
throw std::runtime_error("multiple_of");
})
.def("max_contiguous", [](ir::value *self, int val) {
.def("max_contiguous", [](ir::value *self, std::vector<unsigned> val) {
if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
instr->set_metadata(ir::metadata::max_contiguous, val);
} else
@@ -674,6 +676,7 @@ void init_triton_ir(py::module &&m) {
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
.def("is_floating", &ir::type::is_floating_point_ty)
.def("is_block", &ir::type::is_block_ty)
.def("is_struct", &ir::type::is_struct_ty)
.def("is_void", &ir::type::is_void_ty)
.def("is_bool", &ir::type::is_bool_ty)
.def("is_fp8", &ir::type::is_fp8_ty)
@@ -699,23 +702,41 @@ void init_triton_ir(py::module &&m) {
.def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference)
.def_property_readonly("address_space", &ir::pointer_type::get_pointer_address_space, ret::reference);
py::class_<ir::function_type, ir::type>(m, "function_type");
py::class_<ir::function_type, ir::type>(m, "function_type")
.def_property_readonly("ret_ty", &ir::function_type::get_return_ty)
.def_property_readonly("arg_tys", [](ir::function_type* self){
return std::vector<ir::type*>(self->params_begin(), self->params_end());
});
py::class_<ir::integer_type, ir::type>(m, "integer_type");
py::class_<ir::block_type, ir::type>(m, "block_type")
.def_property_readonly("shape", &ir::block_type::get_shapes)
.def_property_readonly("numel", &ir::type::get_tile_num_elements);
py::class_<ir::module>(m, "module")
py::class_<ir::struct_type, ir::type>(m, "struct_type")
.def("get", &ir::struct_type::get, ret::reference)
.def_property_readonly("num_types", &ir::struct_type::get_num_types);
py::class_<ir::module>(m, "module", py::dynamic_attr())
.def(py::init<std::string, ir::builder &>())
.def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) {
const auto metadatas = self->get_metadatas();
auto it = metadatas.find(name);
if (it != metadatas.end())
if (auto *instr = dynamic_cast<ir::instruction*>(value)) {
instr->set_metadata(it->second.first, it->second.second);
}
.def("has_function", &ir::module::has_function)
.def("get_function", &ir::module::get_function, ret::reference)
.def("get_functions", &ir::module::get_function_list, ret::reference)
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
.def("print", [](ir::module *self) {
self->print(std::cout);
})
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference);
.def("reset_ret_ty", &ir::module::reset_ret_ty)
.def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) {
const auto metadatas = self->get_metadatas();
auto it = metadatas.find(name);
if (it != metadatas.end())
if (auto *instr = dynamic_cast<ir::instruction*>(value)) {
instr->set_metadata(it->second.first, it->second.second);
}
})
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
using eattr = ir::attribute_kind_t;
py::enum_<eattr>(m, "attribute_kind")
@@ -728,17 +749,23 @@ void init_triton_ir(py::module &&m) {
.value("not_implemented", eattr::not_implemented);
py::class_<ir::attribute>(m, "attribute")
.def(py::init<eattr, int>());
.def(py::init<eattr, int>())
.def_property_readonly("value", &ir::attribute::get_value);
py::class_<ir::function>(m, "function")
.def_property_readonly("args", &ir::function::args)
.def_property_readonly("attrs", &ir::function::attrs)
.def("add_attr", &ir::function::add_attr);
.def("set_is_kernel", &ir::function::set_is_kernel)
.def("add_attr", &ir::function::add_attr)
.def("has_attr", &ir::function::has_attr)
.def("get_attrs", &ir::function::get_attributes);
py::class_<ir::argument, ir::value>(m, "argument");
py::class_<ir::argument, ir::value>(m, "argument")
.def_property_readonly("parent", &ir::argument::get_parent, ret::reference)
.def_property_readonly("arg_no", &ir::argument::get_arg_no);
py::class_<ir::basic_block, ir::value>(m, "basic_block")
.def("create", &ir::basic_block::create, ret::reference)
.def("create", &ir::basic_block::create, ret::reference, py::arg(), py::arg(), py::arg() = nullptr)
.def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference)
.def("get_first_non_phi", [](ir::basic_block *self) -> ir::instruction* {
ir::basic_block::iterator it = self->get_first_non_phi();
@@ -748,14 +775,19 @@ void init_triton_ir(py::module &&m) {
}, ret::reference)
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
py::class_<ir::builder::iterator>(m, "bb_iterator");
py::class_<ir::builder>(m, "builder", py::dynamic_attr())
.def(py::init<ir::context &>())
// getters
.def_property_readonly("context", &ir::builder::get_context, ret::reference)
// control flow
.def("call", &ir::builder::create_call, ret::reference)
.def("launch", &ir::builder::create_launch, ret::reference)
.def("br", &ir::builder::create_br, ret::reference)
.def("cond_br", &ir::builder::create_cond_br, ret::reference)
.def("ret_void", &ir::builder::create_ret_void, ret::reference)
.def("ret", &ir::builder::create_ret, ret::reference)
// insertion block/point, insert points are represented as (*bb, *instr)
.def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
.def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
@@ -802,6 +834,8 @@ void init_triton_ir(py::module &&m) {
.def("create_br", &ir::builder::create_br, ret::reference)
.def("create_cond_br", &ir::builder::create_cond_br, ret::reference)
.def("create_ret_void", &ir::builder::create_ret_void, ret::reference)
// Dequantize instructions
.def("create_dequantize", &ir::builder::create_dequantize, ret::reference)
// Cast instructions
.def("create_bitcast", &ir::builder::create_bitcast, ret::reference)
.def("create_cast", &ir::builder::create_cast, ret::reference)
@@ -814,6 +848,8 @@ void init_triton_ir(py::module &&m) {
.def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference)
.def("create_int_cast", &ir::builder::create_int_cast, ret::reference)
.def("create_downcast", &ir::builder::create_downcast, ret::reference)
.def("create_int_to_ptr", &ir::builder::create_int_to_ptr, ret::reference)
.def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference)
// phi
.def("create_phi", &ir::builder::create_phi, ret::reference)
// Binary instructions
@@ -823,27 +859,27 @@ void init_triton_ir(py::module &&m) {
.def("create_frem", &ir::builder::create_frem, ret::reference)
.def("create_fadd", &ir::builder::create_fadd, ret::reference)
.def("create_fsub", &ir::builder::create_fsub, ret::reference)
.def("create_mul", &ir::builder::create_mul, ret::reference,
py::arg("lhs"), py::arg("rhs"),
.def("create_mul", &ir::builder::create_mul, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_sdiv", &ir::builder::create_sdiv, ret::reference)
.def("create_udiv", &ir::builder::create_udiv, ret::reference)
.def("create_srem", &ir::builder::create_srem, ret::reference)
.def("create_urem", &ir::builder::create_urem, ret::reference)
.def("create_add", &ir::builder::create_add, ret::reference,
py::arg("lhs"), py::arg("rhs"),
.def("create_add", &ir::builder::create_add, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_sub", &ir::builder::create_sub, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_shl", &ir::builder::create_shl, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_lshr", &ir::builder::create_lshr, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
.def("create_ashr", &ir::builder::create_ashr, ret::reference,
py::arg("lhs"), py::arg("rhs"),
py::arg("lhs"), py::arg("rhs"),
py::arg("has_nuw")=false, py::arg("has_nsw")=false)
// GEP
.def("create_gep", &ir::builder::create_gep, ret::reference)
@@ -890,7 +926,11 @@ void init_triton_ir(py::module &&m) {
// atomic
.def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference)
.def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference)
// Utilities
.def("create_clock", &ir::builder::create_clock, ret::reference)
.def("create_globaltimer", &ir::builder::create_globaltimer, ret::reference)
// Extern instruction
.def("create_extern_elementwise", &ir::builder::create_extern_elementwise, ret::reference)
// Built-in instruction
.def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference)
.def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference)
@@ -903,6 +943,9 @@ void init_triton_ir(py::module &&m) {
.def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
.def("create_reduce", &ir::builder::create_reduce, ret::reference)
.def("create_select", &ir::builder::create_select, ret::reference)
// struct
.def("insert_value", &ir::builder::create_insert_value, ret::reference)
.def("extract_value", &ir::builder::create_extract_value, ret::reference)
// Intrinsics
// These have no place in the IR, and hopefully they can be removed at some point
.def("create_umulhi", &ir::builder::create_umulhi, ret::reference)

View File

@@ -128,7 +128,7 @@ elementwise_data = {
1024 * 16: 0.0219,
1024 * 64: 0.0791,
1024 * 256: 0.243,
1024 * 1024: 0.534,
1024 * 1024: 0.530,
1024 * 4096: 0.796,
1024 * 16384: 0.905,
1024 * 65536: 0.939,
@@ -152,7 +152,7 @@ def test_elementwise(N):
cur_mem_clock = nvsmi(['clocks.current.memory'])[0]
ref_mem_clock = mem_clocks[DEVICE_NAME]
max_gpu_perf = get_dram_gbps()
assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memmory must run at {ref_mem_clock} MHz'
assert abs(cur_mem_clock - ref_mem_clock) < 10, f'GPU memory must run at {ref_mem_clock} MHz'
z = torch.empty((N, ), dtype=torch.float16, device='cuda')
x = torch.randn_like(z)
y = torch.randn_like(z)

Some files were not shown because too many files have changed in this diff Show More