323 Commits

Author SHA1 Message Date
Yang Hau
8650b4d1cb [DRIVER] Fix typos (#939) 2022-12-02 11:13:46 -08:00
Crutcher Dunnavant
44f577984d Fix format double substitution bug: {i} => {{i}} (#886)
The previous `{i}` was silently expanding to the `i` from the
enumeration loop on `regular_args` (when it wasn't empty).
2022-11-20 11:44:42 -08:00
Crutcher Dunnavant
0e4691e6dd [FRONTEND] Fix ExternLibrary(format=) bug; type annotate build_extern.py (#883)
Ran mypy over `build_extern.py`, cleaned up type annotations.

Found a fixed a bug where `ExternLibrary(format=)` was being ignored.
2022-11-17 18:45:30 +01:00
Natalia Gimelshein
0d7e753227 [TESTING] use torch.int for autotuning cache (#840)
For stupid reasons, ops on int8 are 3 times slower than on int, and for
another set of stupid reasons we are not using cudaMemset for `zero_`,
so using `int8` buffer in `do_bench` makes it slow.

Co-authored-by: Philippe Tillet <phil@openai.com>
2022-11-04 18:05:16 -07:00
Shintaro Iwasaki
77bc5187b5 Better NVIDIA Pascal GPU Support (#827)
This PR clarifies which features are supported on P100 via its tests,
though Pascal is not officially and fully supported by Triton.

## What this PR does

- Skip unsupported tests on P100.
  - Atomic RMW
- `tl.dot()` (perhaps not all patterns, but basically most `tl.dot()`
tests do not work on P100).
- Add an explicit error if shared memory size >= 64K on P100.
- Otherwise it causes `Invalid CUDA argument` error at
`cuLaunchKernel()`, but this error is not very straightforward to
understand. Instead of this generic CUDA argument error, this PR makes
Triton show an error during codegen when `sm < 70`. This check happens
in C/C++ so won't add an overhead in Triton's Python runtime.
- 3 tests (see below) are currently failing, but these are not marked as
skipped because any codegen update in the future can change the kernel
size of the other tests.
- This change won't affect Triton-MLIR. Hopefully Triton-MLIR's generic
`tl.dot()` implementation would support P100.

Importantly, Triton passed all the other tests on P100. Though this
support is not official, it is great for, for example, PyTorch's
TorchDynamo/Inductor, which can use Triton (without `tl.dot()`) for its
backend (https://github.com/pytorch/torchdynamo/issues/1591).

### Results on P100 (Google Cloud)

```sh
$ pytest test/unit
...
================================================================================== short test summary info ==================================================================================
FAILED test/unit/language/test_core.py::test_reduce2d[argmin-float32-shape99-1] - RuntimeError: Device does not support shared memory of 65536bytes
FAILED test/unit/language/test_core.py::test_reduce2d[argmax-float32-shape113-1] - RuntimeError: Device does not support shared memory of 65536bytes
FAILED test/unit/language/test_core.py::test_permute[float32-shape5-perm5] - RuntimeError: Device does not support shared memory of 67584bytes
================================================================== 3 failed, 3824 passed, 952 skipped in 470.90s (0:07:50) ==================================================================
```

<details><summary> <b>Environment Details (collapsed)</b></summary>
<p>

### VM details (Google Cloud)
https://cloud.google.com/
```
# You need a paid account (free trial does not cover GPUs)
Google Cloud -> New Project -> Compute-Engine -> VM Instance
Machine:
GPU: NVIDIA Tesla P100 x 1
CPU: 2 vCPUs, 7.5GB memory
Boot disk:
  OS: Ubuntu 18.04 LTS
  Disk: 40GB (cannot build Triton on the default 10GB disk)
- When I tried, about $1.2 per hour.
- US instances were full when I tried.  I used Asia or Australia.
- Needed a paid account (GPU is not covered by free trial)
- Needed quota request for any GPU instance (by default, no GPU instance is allowed).  Needed to wait an hour for approval
```

### Reproducer
```sh
## 1. Install CUDA and a driver
# Update the apt key (https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key/)
sudo apt-key del 7fa2af80
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb
# Download CUDA as instructed
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-ubuntu1804.pin
sudo mv cuda-ubuntu1804.pin /etc/apt/preferences.d/cuda-repository-pin-600
sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
sudo add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/ /"
sudo apt-get update
sudo apt-get -y install cuda
# Are you using P100?
nvidia-smi | grep "Tesla P100"

## 2. Setup the build environment
sudo apt update
sudo apt install -y build-essential wget git libz-dev
wget https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh
bash Anaconda3-2022.05-Linux-x86_64.sh -b -p $(pwd)/anaconda3
eval "$($(pwd)/anaconda3/bin/conda shell.bash hook)"
conda create -y --name triton_base
conda activate triton_base
conda install -y cmake setuptools

## 3. Build Triton
git clone https://github.com/openai/triton.git
cd triton/python
pip3 install -e '.[tests]'

## 4. Test
pytest test/unit
```

### Environment
```sh
$ nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  On   | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    25W / 250W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
```

</p></details>
2022-11-03 00:11:52 -07:00
Chenggang Zhao
f16138d447 [Frontend] Interface fixes for libdevice (#830)
- Unifying several interfaces with different types to a single one, e.g.
`fsub_ru` and `dsub_ru` -> `sub_ru`;
- Minor bug fix: `fast_pow` is incorrectly classified into the `pow`
interface, of which arguments are the same as `powf`;
- Explicit interfaces for casting functions, e.g. decoupling
`ll2float_ru` to `ll2float_ru` and `ull2float_ru`;
- Removing interfaces that are not in NVIDIA's official documents, e.g.
`fmaf_ieee_rn`, which is confusing together with `fmaf_rn`.

Note that this PR for the master branch is different from #829, which is
for the MLIR branch.
2022-11-01 10:51:58 -07:00
Mark Saroufim
578ada7740 [DOCS] Add install from source instructions to README (#821) 2022-10-31 11:08:18 -07:00
Phil Tillet
6311d70406 Revert "[BUILD] Now using cibuildwheel default"
This reverts commit 584086f08c.
2022-10-29 17:15:47 -07:00
Phil Tillet
584086f08c [BUILD] Now using cibuildwheel default 2022-10-29 16:59:06 -07:00
Keren Zhou
3ca667dfa8 [Frontend] Return a scalar if all input args are scalar (#816) 2022-10-28 23:27:06 -07:00
Yanbo Liang
5ca1ed0101 Add bf16/fp16/fp64 support for ty_to_cpp (#800)
In ```torch._inductor```, we [convert 0d CPU tensor to scalar during
triton codegen](https://github.com/pytorch/pytorch/pull/87329), so need
add missing triton support for bf16/fp16/fp64.
2022-10-24 19:41:25 -07:00
Keren Zhou
db3aa1d1fb [FRONTEND] Fix libdevice (#776)
Fix two problems in libdevice and external dispatch:

1. Use static triton types (e.g., tl.int32) instead of creating new
types. Otherwise, `tl.int32` and `tl.dtype('int32')` are not the same
thing.

2. The name of an extern inst should be empty but not the symbol name of
the inst. TTIR generator will assign names automatically. Otherwise, we
have the same variable name when there are multiple same extern insts.

Before the PR:

```bash
  __nv_exp = extern_elementwise f64<1024> %11;
  __nv_exp = extern_elementwise f64<1024> %11;
```

After the PR:

```bash
  %12 = extern_elementwise f64<1024> %11;
  %13 = extern_elementwise f64<1024> %11;
```
2022-10-13 17:18:16 -07:00
Twizzes
ddae106c0e [DOCS] Update installation.rst to fix windows build error (#747) 2022-10-13 13:27:15 -07:00
Keren Zhou
bc98aead33 [Backend] Fix for mov.u8 (#766)
Init a potential fix for mov.u8 which is not supported by ptx for now.
Use mov.u16 instead and cast it to u8.
2022-10-12 14:32:27 -07:00
Yu Guo
71b46acc42 [IR] Added special-purpose dequantize instruction (#759)
It is currently necessary for optimal performance in quantized workloads to add a special-purpose instruction in the IR. Backward compatibility with this instruction is *NOT* guaranteed.
2022-10-12 14:14:45 -07:00
Philippe Tillet
33e6f0df7f [DRIVER] Bumped CUDA requirement to 11.4+. This is to avoid bad performance surprises as older ptxas are much slower. (#769)
This also makes codegen simpler by avoiding special handling of eviction policies
2022-10-12 12:02:30 -07:00
Philippe Tillet
af76c989eb [RUNTIME] Make entry point cache key depend on triton version hash (#765) 2022-10-11 13:24:30 -07:00
Bin Bao
09cc2d454b [FRONTEND] Fix a bool tensor storing problem (#746) 2022-10-10 12:11:50 -07:00
Felipe Petroski Such
5d4b26d380 [RUNTIME] support multiple devices in the same process (#757) 2022-10-09 20:30:04 -07:00
Chris
9a11a567ce [DOCS] Fixed typos in 01-vector-add.py (#751) 2022-10-09 18:12:46 -07:00
Keren Zhou
11345e9b74 [RUNTIME] Add callback functions for external tools (#738) 2022-10-05 14:46:55 -07:00
Philippe Tillet
bdfdb9a1d2 [RUNTIME] Fixed JIT bug that leg some constexpr values to be overriden by specialization parameters (#742) 2022-10-05 11:00:32 -07:00
shenggan
77c752dc78 [RUNTIME] remove fixed cu_include_dir (#739)
Use environment variable `CUDA_HOME` with default value`/usr/local/cuda` for `cu_include_dir` #731
2022-10-04 19:49:57 -07:00
Natalia Gimelshein
d3c925db8a [FRONTEND] properly broadcast scalar where condition (#736) 2022-10-04 12:44:03 -07:00
fdrocha
2b0f877fad [RUNTIME] Support environments with multiple cudalibs (#733) 2022-10-03 18:36:24 +00:00
Keren Zhou
4a2d3b7d79 [RUNTIME] Dump llvm, ttir, and sass to help debugging (#732) 2022-10-03 00:39:52 +00:00
Natalia Gimelshein
f55960e773 [FRONTEND] fix broadcasting for where (#729)
Fixes #532, all 3 inputs to where have to be broadcast together.
2022-10-01 13:18:47 -07:00
Phil Tillet
b244db06da [TUTORIALS] Attention tutorial fixup 2022-09-30 19:31:43 -07:00
Shintaro Iwasaki
7b61303ea1 [CODEGEN] Fix extract_N_bufferable in layout analysis (#728) 2022-09-30 12:21:22 -07:00
Shintaro Iwasaki
ae59f51c2d [CODEGEN] Fix an inliner to call a function with a phi-node (#727) 2022-09-29 21:36:40 -07:00
albanD
f45e31ba7c [FRONTEND] Make sure to hold the gil when creating python objects (#726)
Without this patch, a debug version of python complains that:
```
Fatal Python error: Python memory allocator called without holding the GIL
Python runtime state: initialized
```
2022-09-29 18:06:22 -07:00
Philippe Tillet
dad97528b2 [TESTING] allclose fixup (#724) 2022-09-28 22:49:05 +00:00
Jason Ansel
998fd5f9af [FRONTEND] Make triton.compile work without a cuda context (#708)
This allows compiling in a subprocess. I'm not seeing a ton of speedup from this, but figure it is a good change anyway.
2022-09-24 13:41:47 -07:00
Shintaro Iwasaki
3ac929b48b [BUILD] Download pybind11 in setup.py (#703)
Based on the discussion in #700, this PR enables downloading pybind11 in
`setup.py` without `git submodule` instead of copy-pasting pybind11
code. The downloaded pybind11 will be in `~/.triton/pybind` (like
`llvm`).
2022-09-23 15:54:07 -07:00
Jason Ansel
579c03615d [FRONTEND] Reduce number of compiles in JITFunction (#704)
I suspect this was the cause of the "new compiles even on a warm cache"
behavior I was seeing, though haven't 100% confirmed it.

Python `set()` iteration order is nondeterministic when you create a new
process. So the same args could produce different `instance_descriptor`s
and have false cache misses.
2022-09-23 21:44:52 +00:00
Philippe Tillet
25e1b36785 Revert "[pybind11] Use git-submodule for pybind11" (#701)
Reverts openai/triton#699
2022-09-23 12:25:38 -07:00
Shintaro Iwasaki
61d104ab3a [FRONTEND] Use git-submodule for pybind11 (#699)
This PR changes the `pybind11` source code management from copy-paste to
a package controlled by git-submodule.

See the discussion in #694 for details.
2022-09-23 09:55:03 -07:00
Philippe Tillet
8c3d4d5749 [RUNTIME] now decoupling entry point from cubin (#696) 2022-09-22 16:44:22 -07:00
Shintaro Iwasaki
df67068bb0 [pybind11] Update pybind11 to 2.10.0 (#691)
This PR updates the version of pybind11 to 2.10.0 (the latest stable).
2022-09-21 20:18:02 -07:00
Philippe Tillet
677ddae618 [FRONTEND] Add warmup for triton.jit() (#684)
This revives #671 , removing the static functions that may unnecessarily hold a reference to the grid and the JITFunction object

Co-authored-by: Jason Ansel <jansel@jansel.net>
2022-09-21 19:13:20 +00:00
Jason Ansel
6abe813d1c Fix issue breaking cudagraphs (#685)
@ngimel figured this one out. 

The errors we were seeing from cudagraphs capture were coming from
`cuStreamGetCtx` which is not allowed while a stream is capturing.

It appears the result of `cuStreamGetCtx()` isn't even used, so I
believe it can just be removed.
2022-09-21 10:20:48 -07:00
Philippe Tillet
e318185eb4 [DOCS] Improved README.md wording (#683)
Initial wording dates from a time where nobody knew Triton, and
comparing it to CUDA helped differentiate it from other existing DSLs.
But nowadays this comparison doesn't make much sense; Triton is its own
thing, and some people may even still be more productive in CUDA than
Triton -- language preferences are subjective after all.
2022-09-20 18:09:43 -07:00
Philippe Tillet
7dc2a70edb Revert "Add .warmup() for triton.jit()" (#682)
Reverts openai/triton#671

It seems like for some reason this caused out-of-memory errors on some
of our internal workloads. I'm reverting this so that HEAD can be used
in production at OpenAI, and I will work on digging into this issue
asynchronously.
2022-09-20 16:05:14 -07:00
Philippe Tillet
48f30550f1 [FRONTEND] Now using raw compiler syscalls when possible (#678) 2022-09-19 21:01:36 -07:00
Jason Ansel
93b1adc53b [FRONTEND] Add .warmup() for triton.jit() (#671) 2022-09-18 23:09:34 -07:00
Phil Tillet
82956e5d6b [PACKAGING] Added missing package 2022-09-18 17:34:05 -07:00
Philippe Tillet
2baf333d44 [DOCS] Fixed typos (#670) 2022-09-18 17:13:12 -07:00
Jason Ansel
49f6bc3f2b [FRONTEND] Fix filename too long error in new runtime (#669) 2022-09-18 21:26:29 +00:00
Phil Tillet
00f4ef6958 [CI] wheel/docs workflows now only run on V100 machine 2022-09-18 13:28:35 -07:00
Jason Ansel
e647402fd3 Fix warning in generated C code (#667) 2022-09-18 12:57:32 -07:00
Philippe Tillet
4a77dfb042 [FRONTEND] Complete rewrite of the runtime (#644)
This PR completely rewrites the runtime of Triton to be more lean and
clearly separate the compilation step from the just-in-time caching logic.
This should substantially reduce launch overhead.
2022-09-18 08:51:48 -07:00
Ian Bearman
889d9e34a1 [REPO] update gitignore (#666)
Update `.gitignore` to include `.vs` and `.vscode`
2022-09-17 14:25:28 -07:00
Shintaro Iwasaki
c668d6596e [DOCS] Fix spelling (#664)
This PR applies minor spelling fix in comments and string literals to
`master`. It shouldn't hurt anything.
2022-09-16 12:26:40 -07:00
Sophia Wisdom
4580a04710 [FRONTEND] Improve error message for CPU tensors (#654)
Redo of #651 against master. Fixes #525 by catching CUDA error when we
check pytorch tensor size and rethrowing a more informative error that
says why we failed.
2022-09-14 14:26:42 -07:00
Philippe Tillet
cfbbc7b43a [CI] Added V100 tag to disambiguate self-hosted runners (#653) 2022-09-14 13:47:50 -07:00
Yunxing Dai
59a8e25f43 [DOCS] Fix typo (#650) 2022-09-14 12:17:05 -07:00
Da Yan
437ced38c2 fp8 <> bf16 conversion (#637)
Co-authored-by: Philippe Tillet <phil@openai.com>
2022-08-30 14:20:12 -07:00
Da Yan
210a296699 [BACKEND] bf16 flash-attention (#636) 2022-08-26 20:40:55 -07:00
Daniil Fukalov
fe0c29b9ec Fix inconsistent struct declaration instead of class. (#632)
Looks like typo.
2022-08-26 16:20:21 -07:00
Phil Wang
7394d732ad [DOCS] support for variable head dimensions in flash attention triton tutorial (#623) 2022-08-15 19:16:49 -07:00
Da Yan
3e2953f357 Allow multiple_of and max_contiguous to accept n-d values (#617) 2022-08-10 09:59:32 -07:00
Daniil Fukalov
cc79376222 Fix deprectaion warning on CreateGEP(Value *, ArrayRef<Value *>, const Twine &) (#608)
This variant of CreateGEP() is already removed in LLVM 14.
2022-08-07 17:10:18 -07:00
Daniil Fukalov
7b91c7befd Fix "warning: control reaches end of non-void function". (#607) 2022-08-02 16:12:48 -07:00
Sharad Vikram
968f59027e Expose module.print in pybind (#604) 2022-07-29 21:36:08 -07:00
Anton Kostin
923d468187 Update LICENSE (#602) 2022-07-25 09:30:03 -07:00
Jason Ansel
027321cdcf [FRONTEND] Make tl.rand() 1-exclusive (#601) 2022-07-24 17:47:23 -07:00
Jason Ansel
e02e56dc63 [FRONTEND] Add missing rfloordiv (#598)
* [FRONTEND] Add missing rfloordiv

* fix tests
2022-07-23 21:54:12 -07:00
Philippe Tillet
ab56d310dd [BACKEND][IR] Fixed up internal dtype size for booleans (1bit -> 8bit) (#600) 2022-07-23 20:08:03 -07:00
Da Yan
f28caddbf8 [FRONTEND] Allow tl.where to select pointers (#595) 2022-07-21 09:54:27 -07:00
Keren Zhou
af85f5fa46 [FRONTEND] Refresh cache when the source code of outlined functions are changed (#590) 2022-07-20 17:34:07 -07:00
daadaada
9b2bc88d11 [BACKEND] Better bf16 support (#588) 2022-07-19 21:22:37 -07:00
Philippe Tillet
86cab58d89 [CI] Changed dev wheel date to UTC time to match CRON schedule (#587) 2022-07-18 14:54:13 -07:00
Phil Tillet
5b04331dd2 [TUTORIALS] Added more credits in fused attention tutorial 2022-07-13 23:48:58 -07:00
Jason Ansel
0a3f3d5f25 [PACKAGING] Include triton/language/libdevice.10.bc in package data (#582) 2022-07-13 23:45:27 -07:00
Keren Zhou
4912916c11 [FRONTEND] Added support for element-wise function defined in external LLVM bitcode (e.g., libdevice) (#562) 2022-07-13 15:52:21 -07:00
Phil Tillet
971f5782b4 [tutorials] Added flash attention credits in tutorial 2022-07-11 18:56:48 -07:00
Philippe Tillet
d5eb9bc230 [tutorial] Added bwd in fused attention example (#579)
Doesn't work on V100
2022-07-11 15:43:46 -07:00
Jason Ansel
c9a2b9c7d4 [FRONTEND] Add missing args to get_simd_tflops() (#578) 2022-07-11 14:37:59 -07:00
Philippe Tillet
4a399a7e40 [BACKEND] Fix some bugs (atomics, a segfault...) (#577)
This should fix #558 , #573 and #574
2022-07-06 20:03:04 -07:00
vesuppi
22105bc33b [FRONTEND] Added type check in semantic arange (#572) 2022-07-03 15:25:37 -07:00
Keren Zhou
4bf509889b [BUILD] Change the default build type to Release (#571) 2022-07-01 12:17:22 -07:00
Keren Zhou
a74cce375f [FRONTEND] Raise broadcast error (#555) 2022-06-30 17:32:07 -07:00
Philippe Tillet
f733327ba4 [BACKEND][CODEGEN] Disabling L2 residency control by default (#570) 2022-06-29 17:05:13 -07:00
Natalia Gimelshein
1bbb2430d9 [TUTORIALS] adjust heuristics for dwdb kernel (#565) 2022-06-29 17:00:22 -07:00
Kashif Rasul
1895ceaa2d [TUTORIAL] Fix f-string for older python (#569)
fixes issue #568
2022-06-29 09:39:10 -07:00
Philippe Tillet
feb7a2a0dc [FRONTEND] Hotfix for store argument order (#567) 2022-06-28 00:24:02 -07:00
Philippe Tillet
5b4c8f221e [BACKEND] Compiler improvements (#557)
This PR adds several optimization capabilities in the compiler backend:
- Now using inline PTX for `tl.store`, making it possible to use things like evict_last
- For A100, mma layout can be directly converted to shared memory
- For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used both row- and col- major.
- Fixed liveness analysis; this was broken.
- Now can load/store directly mma layout without converting. Useful for when tl.dot accumulator is initialized with DRAM data inside of an inner loop.
- `tl.dot` can now take LHS inputs in registers when it comes from a previous `tl.dot` instruction. Useful for e.g. fused attention.
2022-06-27 11:49:19 -07:00
Keren Zhou
87413bc925 [BACKEND] Fix layout convert for non-contiguous input (#564) 2022-06-25 23:12:03 -07:00
Keren Zhou
d345ddf837 [DOCS] Separate atomic cas from other atomic operations since operands are very different (#559) 2022-06-22 17:51:17 -07:00
Keren Zhou
b02bac41ba [CI] Change cache dir (#561) 2022-06-22 11:44:35 -07:00
Keren Zhou
a428cf0bb2 [FRONTEND] Fix pytorch warning. (#560)
UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc').
2022-06-20 20:12:09 -07:00
Keren Zhou
b5e728cb14 Add argmin argmax (#552) 2022-06-15 13:55:20 -07:00
Jason Ansel
6b9756532f [BACKEND] Remove print in coalesce.cc (#551) 2022-06-15 13:13:20 -07:00
Madeleine Thompson
8ce2c12e33 [PYTHON] move ephemeral files to homedir (#549)
This prevents potential conflicts with other users on shared machines.
2022-06-13 19:37:52 -07:00
Keren Zhou
93209c07e0 [BACKEND][CODEGEN] Fix reduce uint (#547) 2022-06-13 16:43:57 -07:00
Philippe Tillet
58c8889235 [FRONTEND] Fix scanline layout (#548) 2022-06-13 16:21:10 -07:00
Natalia Gimelshein
7094657aa9 [FRONTEND] fix bool conversion of floating types (#545) 2022-06-13 15:52:37 -07:00
Keren Zhou
38573d1261 [FRONTEND] Return allocated registers and spilled registers for users (#541) 2022-06-07 18:37:12 -07:00
Mengchi Zhang
2cdc6d35c4 [FRONTEND] Give col_per_thread an initial value to make the compiler happy (#535)
Signed-off-by: Mengchi Zhang <mengchi@fb.com>
2022-06-06 12:48:23 -07:00
TC
f13cbaab9f [FRONTEND] assert that num_warps is a power of 2 (#539) 2022-06-06 11:37:08 -07:00
Philippe Tillet
751e325d2e [TUTORIALS] Fixed typo 2022-06-05 13:33:21 -07:00
Philippe Tillet
801c8a4c92 [TUTORIALS] Fixed typo 2022-06-05 12:32:07 -07:00
Philippe Tillet
8876e53206 [BACKEND] Restored reduction bugfixes 2022-06-03 11:38:52 -07:00
Philippe Tillet
a60374a597 Revert "[BACKEND] Various bug fixes; making reductions faster (#533)".
This is a more stable commit that produce bitwise identical code to earlier
versions. Using commits after this one may lead to slightly different numerics
2022-06-03 11:36:06 -07:00
Philippe Tillet
efa04cac1f [FRONTEND] A couple of bugfixes (#534) 2022-06-02 16:57:37 -07:00
Philippe Tillet
3e7500dfe6 [BACKEND] Various bug fixes; making reductions faster (#533) 2022-05-31 17:14:44 -07:00
Bert Maher
37037bb3be [FRONTEND] Default cache dir to /tmp/triton_$USER (#527) 2022-05-27 13:51:05 -07:00
Philippe Tillet
c82a206684 [FRONTEND] Better dot error message (#531) 2022-05-26 17:41:09 -07:00
Philippe Tillet
0e2883020a [BACKEND] Fixed typo in alignment analysis (#528) 2022-05-25 20:01:19 -07:00
Bert Maher
43fec2adca [FRONTEND] Add binding for create_int_to_ptr (#526) 2022-05-25 15:26:18 -07:00
Philippe Tillet
011bc83c1b [FRONTEND] For loops now promote initial value (#524) 2022-05-24 13:20:10 -07:00
Natalia Gimelshein
96bff90471 [FRONTEND] faster jit function launch (#523)
With fast (200 ns) get_stream function soon to be available from pytorch this shaves off approx 25-30 us from function launch, but even without that function due to caching device properties we are saving ~15-20us.
2022-05-24 12:08:49 -07:00
daadaada
d5eaa8dfa0 Making the generated Triton IR deterministic & a script to compare cached assembly (#522) 2022-05-24 08:56:36 -07:00
Shantanu
80f6a2698b [FRONTEND] Ensure version_key is called at most once (#519)
Co-authored-by: hauntsaninja <>
2022-05-23 13:40:08 -07:00
daadaada
205a493b10 [FRONTEND] Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas (#520)
Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas
2022-05-21 09:45:54 -07:00
Jiabao Lei
abea3dc2c6 [FRONTEND] provide device kwargs && fix fstring error for py<3.8 (#515)
Co-authored-by: Philippe Tillet <phil@openai.com>
2022-05-14 16:21:46 -07:00
Philippe Tillet
d35617bea1 [BACKEND][CODEGEN] Faster reduction for scanline layout (#516) 2022-05-14 15:26:13 -07:00
Mengchi Zhang
d1a22a94e6 [FRONTEND] Add empty return value and remove protect to open the access to contained_tys_vec_t (#514)
Signed-off-by: Mengchi Zhang <mengchi@fb.com>
2022-05-13 11:46:12 -07:00
Jason Ansel
d954a05989 [FRONTEND] Handle torch.uint8 args (#513)
Co-authored-by: Philippe Tillet <Phil.Tillet@gmail.com>
2022-05-12 13:07:39 -07:00
Philippe Tillet
0835a4fb05 [TUTORIALS] Removed #noformat in layer norm tutorial 2022-05-12 12:41:25 -07:00
Philippe Tillet
c736ba7c3e [TUTORIALS] Fixed formatting 2022-05-12 12:31:23 -07:00
Philippe Tillet
cd30a99aa2 [TUTORIALS] fixed formatting 2022-05-12 12:28:22 -07:00
Philippe Tillet
d87435e536 [TUTORIALS] Layer norm tutorial now uses residency control (#510) 2022-05-05 19:53:54 -07:00
Sriram Murali
7c9bc5a47b [CODEGEN] Change return type of generator::packed_type to appease build warnings (#507) 2022-05-04 20:03:37 -07:00
Philippe Tillet
95feb10ec9 [FRONTEND] fixup (#505) 2022-04-30 14:25:06 -07:00
Philippe Tillet
11a908655d [FRONTEND] Fixup 2022-04-29 14:35:09 -07:00
Phil Tillet
cd78ce4888 [FRONTEND] Improved error message when assigning None to non-constexpr 2022-04-29 09:17:54 -07:00
Philippe Tillet
ae2a1ab225 [BACKEND] Alignment pass improvements (#503) 2022-04-25 21:16:00 -07:00
Philippe Tillet
7d544799a0 [BACKEND] Now disabling L2 eviction policy for sm < 80 2022-04-25 09:35:36 -07:00
Philippe Tillet
3ca792043f [TEST] Added test for vectorization 2022-04-24 13:50:48 -07:00
Philippe Tillet
bda209002e [BACKEND][CODEGEN] vectorization bugfix (#502) 2022-04-23 13:18:33 -07:00
Philippe Tillet
0cc3b1129b [BACKEND][CODE_GEN] eviction policies now also apply to L2 (#501) 2022-04-21 23:56:01 -07:00
Philippe Tillet
7d6c504e8d [TESTING] Added testing utilities for fixing clock and using cuda-memcheck (#500) 2022-04-21 22:40:10 -07:00
Philippe Tillet
073be1d2ee [FRONTEND] check that tensors have power-of-two number of elements (#499) 2022-04-14 19:30:02 -07:00
Philippe Tillet
5c7122004c [TUTORIALS] Tutorial shouldn't expose clock. Just removed it. 2022-04-14 17:33:44 -07:00
Philippe Tillet
dc4d40faec [FRONTEND] now mangle constexpr float containing "e-" 2022-04-14 10:26:48 -07:00
Philippe Tillet
25f6689508 [FRONTEND] rename current stream monkey patch (#495) 2022-04-13 11:45:55 -07:00
Philippe Tillet
76bfac9f15 [FRONTEND] Improved constexpr handling (#493) 2022-04-12 00:02:54 -07:00
Philippe Tillet
14b0fd4cfb [FRONTEND] Added possibility for users to customize current stream query (#492) 2022-04-07 12:11:32 -07:00
Philippe Tillet
6424771f55 [CI] Documentation fixup 2022-04-07 09:42:35 -07:00
Philippe Tillet
9f08ecd684 [FRONTEND] Semantic analysis refactor (#491)
Moved dispatch.cc to semantic.py (@ptillet)
Integer signedness analysis was moved from C++ to python (@daadaada)
Cleaner frontend types (@daadaada)
Moved SSA construction to a separate object (@ptillet)


Co-authored-by: Yan Da <dyanab@connect.ust.hk>
2022-04-06 16:13:53 -07:00
Philippe Tillet
2bed6fc850 [LANG] Added support for device functions (#484) 2022-04-03 20:58:16 -07:00
apd10
e85c7a7fc7 Bugfix in ptxas path. (#487)
Bug: "ret" value is destroyed when a failing "ptxas --version" is run
overwriting the previous valid "ret" value.

Fix: keep rets only for those runs which are successful. Pick the first
one
2022-03-30 20:45:41 -07:00
Philippe Tillet
bace26143d [TUTORIALS] Removed leftover print 2022-03-28 16:53:23 -07:00
Philippe Tillet
e0cc488055 [FRONTEND] Added tl.clock and tl.globaltimer (#485) 2022-03-28 16:15:43 -07:00
Philippe Tillet
76a9ee50a8 Revert "[FRONTEND] Semantic analysis refactor (#473)" (#483)
This reverts commit 539961072c.
2022-03-24 17:16:50 -07:00
Philippe Tillet
ea6d1f1b85 [DRIVER] LLVM driver fixup (#482)
Current way of doing things is probably not super thread safe. init is shared between threads and some threads my not call the LLVMInitialize* function.
2022-03-23 00:24:45 -07:00
Keren Zhou
a4f68165cd [FRONTEND] Hot fix for lineno (#481)
Override __reduce__ to make CompilationError pickable and print out error messages
2022-03-22 22:09:49 -07:00
daadaada
539961072c [FRONTEND] Semantic analysis refactor (#473)
Moved dispatch.cc to semantic.py
Integer signedness now moved from C++ to python
Cleaner frontend type

Co-authored-by: Phil Tillet <phil@openai.com>
2022-03-16 21:25:30 -07:00
Yongjik Kim
0dd2ec2e3a [FRONTEND] Add an assert in case we get a CPU tensor. (#478) 2022-03-16 14:38:56 -07:00
Philippe Tillet
d4d8eaf6c0 [FRONTEND] improved caching mechanism (#474)
Co-authored-by: Greg Brockman <gdb@gregbrockman.com>
Co-authored-by: Christopher Hesse <christopherhesse@users.noreply.github.com>
2022-03-15 12:20:51 -07:00
Doğukan Tuna
21f8a0646d [DOCS] Minor README.md (#470)
Added binary distribution for quick installation
2022-03-05 00:50:37 -08:00
Philippe Tillet
a50a47a85b [CODEGEN] Reverted some changes from previous PR; fixed vectorization characteristics of mma layout (#469) 2022-03-04 01:53:31 -08:00
Philippe Tillet
bb5765df5c [CODEGEN] Now padding shared memory for layout conversion (#468) 2022-03-03 22:19:05 -08:00
daadaada
d9dd97492f Use unique_ptr in ir::context_impl (#462)
Co-authored-by: Philippe Tillet <Phil.Tillet@gmail.com>
2022-02-24 16:07:10 -08:00
Philippe Tillet
98ed7db8c1 [CODEGEN] Improvements and bugfixes (#463) 2022-02-24 14:56:24 -08:00
daadaada
a9dfdcaaa9 [FRONTEND] Make the performance model work for int8, tf32, and fp32 (#456) 2022-02-11 22:34:42 -08:00
Philippe Tillet
9b100302d3 [FRONTEND] Now using pybind11 to release GIL (#458) 2022-02-10 01:57:39 -08:00
Philippe Tillet
40093a9878 [DOCS] Multiple versions are now supported (#457) 2022-02-09 01:32:41 -08:00
Philippe Tillet
4941bc7001 [DOCS] Some more fixes (#455) 2022-02-08 16:53:56 -08:00
Philippe Tillet
2fdf0a4fe8 [DOCS] changed build command 2022-02-08 11:45:21 -08:00
Philippe Tillet
077d6c8ff0 [DOCS] re-activated tutorials 2022-02-08 11:42:39 -08:00
Philippe Tillet
822ddcd14b [DOCS] Added versioning (#453) 2022-02-08 11:28:18 -08:00
Philippe Tillet
7b48340ffd [CI] Some fixes for the build (#451) 2022-02-06 19:11:33 -08:00
Philippe Tillet
5a8a544d10 [OPS][BLOCKSPARSE] Improved robustness, clarity and performance (#450)
* dds layout now internally re-uses dsd code path for increased code 
* at_mask and kp_mask related things are now dropped from the softmax API. I couldn't think of any case where it was needed beyond is_causal. And if there is any, we should probably find a way to get it implemented statically so that users don't have to materialize masks.
 * fixed bug in blocksparse matmul that caused troubles when layout had a full row/col of zeros
 * blocksparse softmax now no longer modifies any data in-place
 * blocksparse softmax now takes an is_dense arguments that provides better performance. Passing is_dense=True, is_causal=True is the best way to achieve triangular attention.
  * unit tests now test backward pass
2022-02-06 18:00:45 -08:00
Philippe Tillet
69ff52ea1f [CODEGEN] removed buggy (and mostly useless) optimization in peephole pass (#449) 2022-02-05 21:37:23 -08:00
TC
137bb67fad [LANG] Add fp16 to fp8 conversion (#444) 2022-02-02 20:42:09 -08:00
Philippe Tillet
3b20170fa3 Merge pull request #448 from openai/v2.0
`v2.0` is now merged into `master`
2022-01-30 20:49:08 -08:00
Philippe Tillet
b0d6e2f322 [STYLE] run autopep 2022-01-30 20:27:44 -08:00
Philippe Tillet
2922dc141c Merge branch 'master' into v2.0 2022-01-30 20:25:01 -08:00
Philippe Tillet
807d8a1945 [ALL] Merge master (#447) 2022-01-30 20:21:20 -08:00
Philippe Tillet
bef76b142a [BACKEND] float division is now approximate by default (#446) 2022-01-29 18:29:29 -08:00
Philippe Tillet
bd52e530a0 [OPS][BLOCKSPARSE] Fix padding issue in DSD LUT (#445) 2022-01-28 21:40:30 -08:00
daadaada
e68d6a7776 [BACKEND] Making the warp-level tile "more square" to increase data-reuse for tl.dot. (#442)
* Increase smem data-reuse for some layouts

* tweak

* Keep the original tiling logic for sm < 80

Co-authored-by: Philippe Tillet <phil@openai.com>
2022-01-27 09:59:54 -08:00
daadaada
59d371c6eb [BACKEND] Added Int8 mma (#440) 2022-01-27 09:12:44 -08:00
Benjamin Lefaudeux
3a23c1dd33 [BACKEND] minor, hotfix for gcc compilation (#439) 2022-01-23 14:24:02 -08:00
Philippe Tillet
ccf9abe0ba [FRONTEND][RANDOM] Improved backward compatibility of RNG (#438)
The unsigned int PR definitely improved our RNG. However, it requires
different floating point arithmetics which, means the results are not
bit-wise identical to how they were before. This commit revives backward
compatibility, but we should change it back to the "right" way later.
2022-01-21 18:05:55 -08:00
Philippe Tillet
4c97d1ecd7 [FRONTEND] Bunch of fixes here and there (#436) 2022-01-20 10:55:59 -08:00
Philippe Tillet
e0c5709cc8 [FRONTEND] Fixed semantics bug on ptr to bool conversions (#432) 2022-01-17 18:00:03 -08:00
daadaada
2a944ded53 [TESTS] Added bfloat16 tests (#430) 2022-01-13 23:38:32 -08:00
Philippe Tillet
4c94359199 [FRONTEND] Alignment fix-up (#428) 2022-01-11 23:11:58 -08:00
Philippe Tillet
bbc78f6516 [FRONTEND][RANDOM] Make sure offset dtype is always uint32 before calling uint32_to_uniform_float (#427) 2022-01-11 11:08:49 -08:00
Botao Yu
bf32205edc [OPS][BLOCKSPARSE] Remove unnecessary loop and add cuda bool layout support (#425) 2022-01-11 11:07:16 -08:00
daadaada
94a2e10fe5 [BACKEND] Add bf16 & tf32 mma supports (on A100) (#426) 2022-01-11 10:20:31 -08:00
Madeleine Thompson
efdabe6073 [STYLE] check python with flake8 (#424)
I've been using this locally to find errors without running tests, and now that we're using autopep8, it passes with minimal suppressions. This is also what turned up the issues with the tutorials, which were fixed in #422.
2022-01-07 15:28:36 -08:00
Madeleine Thompson
a70acfec77 [STYLE] add isort and autopep8 config files and check on CI (#423)
Also a fix a few more style issues from the "aggressive" mode of autopep8.
2022-01-07 13:11:34 -08:00
Madeleine Thompson
9801aa7b56 [DOCS] fix tutorials for v2.0 (#422)
- Fix meta-parameter usage on tutorials.
- Install tutorial dependencies on CI.
- Switch from `requirements-test.txt` to `extras_require` for test dependencies, and also use it for tutorial dependencies.
- Make some performance tests deterministic.
2022-01-07 12:34:38 -08:00
Madeleine Thompson
8bf551ae7a [STYLE] run autopep8 and isort (#421)
Run:
```
isort ./python
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg` and then clean up a few warts. This PR should be a no-op; the idea is that this is all boring whitespace changes, and any config file changes will be in a different change to make it easier to review.
2022-01-06 14:34:17 -08:00
Shantanu
6f7acad48f [CODEGEN] Avoid use of deprecated AST nodes (#418)
Co-authored-by: hauntsaninja <>
2022-01-06 12:04:33 -08:00
Madeleine Thompson
120cda015e [FRONTEND] use unsigned integers to simplify RNG (#417) 2022-01-06 10:49:09 -08:00
Philippe Tillet
001fb757fe [OPS][BLOCKSPARSE] Added .contiguous() in blocksparse inputs when necessary (#420) 2022-01-06 09:56:22 -08:00
Madeleine Thompson
0ab9d67bad uint8, uint16, uint32, and uint64 in kernels (#413)
A forthcoming PR will update the RNG to use these types.

Also:
- Add tests for the `//`, `<<`, and `>>` operators.
- Change `TensorWrapper` to unwrap objects when the resulting object would be simpler.
- Clean up `throw_unreachable`, since it was triggering compiler warnings.
2022-01-05 15:27:17 -08:00
Madeleine Thompson
d8db0308cb [TEST] use numpy for reference results in test_core.py (#409)
Since numpy supports unsigned integers, and pytorch doesn't, this will make it easier to test unsigned integer support.

This adds an explicit requirement for numpy in tests, but we already required scipy, so it was already an implicit dependency.
2022-01-04 13:07:29 -08:00
Philippe Tillet
03f1256f60 [FRONTEND] Added volatile flag for load (#407) 2021-12-30 22:33:24 -08:00
Noah Ziems
3edc2633e9 [TUTORIALS] Fix 01-vector-add.py typo (#406) 2021-12-29 15:09:34 -08:00
Madeleine Thompson
985798f101 add missing bfloat16 repr and improve assertions (#403)
- `BF16TyID` was missing a repr implementation.
- Throw a better exception on impossible casts.
- Add a few assertions. Tested with a debug build.
- Add `pointer_dtype.__str__` to aid kernel debugging.
2021-12-23 17:01:17 -08:00
Philippe Tillet
d8fce83e7a [FRONTEND] Remade exception picklable 2021-12-21 22:14:06 -08:00
Philippe Tillet
a425f24d54 [FRONTEND] Better cache hook (#400)
Added an additional `repr` argument to the cache hook, which represents a human-readable string representation of the signature and argument attributes associated with the compiled binary.
2021-12-21 21:29:47 -08:00
Philippe Tillet
2509124dd0 [DRIVER] Fixed some issue with how ptxas is used (#399)
Now using tmpnam and properly deleting temporaries when an exception is raised
2021-12-21 14:31:51 -08:00
daadaada
39d4bfed83 [OPS] Add performance model for gemm/gemv (#397)
Significantly improves the performance of `triton.ops.matmul` in memory-bound settings via the use of many more block configs coupled with a performance model to drive the auto-tuning process.
2021-12-21 09:56:10 -08:00
Madeleine Thompson
5cdb948c05 [FRONTEND] signed-integer math fixes and testing (#395)
- Promote 16-bit floating-point `/` and `%` to 32-bit; we have to anyway.
- Do not force result of integer binary operations to be the LHS type. There used to be a bug in pytorch that did this, which Triton matched, but that bug is fixed now.
- When testing signed integer operations, use random numbers from the full range of the type.
- Add an optional `seed` argument to `triton.testing.random` so binary operations are not tested with both sides equal when the LHS and RHS have the same type.
- Fix a bad `CompilationError` invocation.
- Fix a warning suppression that causes tests to fail if you run them with `-W error` on python 3.8.
2021-12-21 09:46:05 -08:00
daadaada
4a8953efa3 [FRONTEND] Replace the legacy print call in triton.cc with the SlotTracker-based one. (#396)
The legacy print call will assign names (e.g., %10) to values, which can be undesirable in some cases.
2021-12-18 18:03:22 -08:00
Madeleine Thompson
fa62b4a8f6 [FRONTEND] better stringification (#394)
- Don't override `self.args` in `CompilationError`, and show the line number and column in error messages. This causes it to generate an easier-to-read backtrace.
- Better `__str__` on `TensorWrapper`, `dtype`, and `block`.
2021-12-17 20:11:45 -08:00
Philippe Tillet
4e93b41c52 [GENERAL] Some minor fixups (#393)
* [RUNTIME] Now displaying error message when generated PTX is invalid

* [CODEGEN] Now converting `if` condition to bool implicitly
2021-12-17 18:06:21 -08:00
Philippe Tillet
e062812969 [CODEGEN] Disabled peephole for masked load + select -- masked_load
doesn't work as expected when vectorized
2021-12-17 12:44:47 -08:00
Victor
eb077fc993 [RUNTIME] fixed NVidia DLL names on Windows (#392) 2021-12-16 22:09:52 -08:00
Philippe Tillet
e0b92c1380 [FRONTEND] Reverted from .random import *. There are still some
namespace errors in the Triton frontend apparently
2021-12-16 18:37:51 -08:00
Philippe Tillet
558555630f [FRONTEND] Added xor_sum 2021-12-16 17:55:35 -08:00
Madeleine Thompson
e575ae3443 [FRONTEND] Minor accumulated style and warning fixes (#388)
- Fix some whitespace.
- Make an undeclared dependency on `pytest` explicit.
- Fix deprecated `description-file` use.
- `#ifdef` out a deprecated `PyEval_InitThreads` call.
- Use a slightly different numpy invocation in `test_random.py` to quiet down overflow warnings in tests.
- Fix a deprecated cast in `test_core.py`.
- Suppress a warning about `visit_Constant` in Python 3.9+; we can't migrate yet because it'd break Python 3.6 and 3.7.
- Use chained exceptions for `CompilationError` rather than rolling our own; it makes the error messages nicer.
- Add a `__str__` for `tl.dtype` to make debugging kernels easier; it lets you `print` a dtype to see what type was inferred.
- Fix a few bad escapes.
2021-12-10 15:19:20 -08:00
Philippe Tillet
9def2424ab [RUNTIME] Fix typo in IfExp 2021-12-09 15:14:41 -08:00
Philippe Tillet
e31b9b4e66 [RUNTIME] Better support for None (#387)
* regression test fails but it doesn't make sense to me.
2021-12-09 13:21:22 -08:00
Victor
73b04d71b2 Fixes for building on Windows (#382)
* make C++ code compatible with Windows + MSVC

* added dlfcn-win32 for cross-platform dlopen

* fixed building and pip install on Windows

* fixed shared library file name under Windows
2021-12-07 14:10:58 -08:00
Victor
0ff1a26b70 fixed p2p tests failing when there are no supported p2p devices (#386) 2021-12-06 18:14:03 -08:00
Philippe Tillet
f23bf55f15 [RUNTIME] release the gil on launch (#383) 2021-12-03 13:01:01 -08:00
Philippe Tillet
8ec9f037bb [BACKEND/CODE_GEN] Fixed float32 matmul problem (#380) 2021-11-30 22:00:56 -08:00
Philippe Tillet
c86ad9c9ab [FRONTEND] Added default arguments to non-kernel @triton.jit'd function (#379) 2021-11-29 19:11:26 -08:00
daadaada
1296eb877b [RUNTIME] Config hook v2.0 (#373)
* Add pre_hook to triton.Config
* Use argument names in triton.heuristics
* Update base perf
* Remove meta from heuristics
2021-11-21 11:20:59 -08:00
Philippe Tillet
5693b582ea [RUNTIME] Now using pybind11 to avoid memory leaks (#377) 2021-11-21 02:30:22 -08:00
Philippe Tillet
edd4b0c8b7 [CODEGEN] Fixed issue with jit function passed as constexpr 2021-11-16 09:53:34 -08:00
Philippe Tillet
5b7ba3eb96 [CODEGEN] Reverted to old launch method (memory leak?) 2021-11-16 01:21:03 -08:00
Philippe Tillet
791b953b21 [CODEGEN] Reverted to old way to query current stream 2021-11-16 00:17:27 -08:00
Philippe Tillet
b908095872 [VERSION] Bumped triton.__version__ to 2.0.0 2021-11-12 15:10:36 -08:00
Philippe Tillet
01cc3d4503 [RUNTIME] Restored do_not_specialize (#374) 2021-11-12 15:06:55 -08:00
Philippe Tillet
e66bf76354 [RUNTIME] Bunch of bugfixes (#372) 2021-11-12 00:55:00 -08:00
Philippe Tillet
f7ab96cfd7 [FRONTEND] Fixed some issues with constexpr 2021-11-09 13:03:09 -08:00
daadaada
9a02dddf29 Fix sdd_lut (#368) 2021-11-08 08:25:05 -08:00
Philippe Tillet
5d54352164 [FRONTEND] Significantly reduce kernel launch time (#367) 2021-11-04 13:25:24 -07:00
Philippe Tillet
2acaa4d0dd [LANG] Added support for constexpr (#361) 2021-10-30 00:32:58 -07:00
Philippe Tillet
b7f0e87dc2 [DRIVER] Removed std::cout log message 2021-10-29 10:42:10 -07:00
Philippe Tillet
770ea96cca [PACKAGING] Bumped dev version to 2.0.0 2021-10-29 01:28:17 -07:00
Philippe Tillet
969d6de8a2 [PACKAGING] Bumped dev version to 1.1.2 2021-10-29 01:26:21 -07:00
Philippe Tillet
2d6df9b518 [PACKAGING] Bumped dev version to 1.1.2 2021-10-29 01:24:19 -07:00
Philippe Tillet
1b842f8e5e [CI] Now running integration tests on pull requests on branch v2.0 2021-10-29 01:11:12 -07:00
Philippe Tillet
d3e584d4ba Revert "[DRIVER] Fixed CUDA 10.1 bug (#357)" (#358)
This reverts commit d35014ba47.
2021-10-26 15:04:49 -07:00
Philippe Tillet
d35014ba47 [DRIVER] Fixed CUDA 10.1 bug (#357) 2021-10-26 11:17:06 -07:00
Philippe Tillet
5ce1b726dc [CODEGEN] Various bugfixes that make it possible to fuse RNG in a matmul epilogue (#356) 2021-10-24 02:30:46 -07:00
daadaada
858dec8372 [CODEGEN] Add cache modifier to tl.load (#351)
* Add cache modifier to tl.load
* Add comment to cache_modifier
* Remove force_nc_cache
* Update test
2021-10-17 22:14:04 -07:00
Philippe Tillet
90ded16c32 [DOCS] Added placeholder docstring for layernorm tutorial 2021-10-15 19:04:01 -07:00
Philippe Tillet
abbc554838 [VERSION] Bumped version to 1.1.1 (#350) 2021-10-14 18:09:39 -07:00
Philippe Tillet
9b32075062 [CODEGEN] Some compiler improvements (#349) 2021-10-13 17:49:39 -07:00
Stephen McGroarty
c2e6b90ff1 [CODEGEN] Fixes masked load exception (#342) 2021-10-13 13:31:52 -07:00
Philippe Tillet
bfacc191b3 [FRONTEND] Now cache re-compiles when language changes (#348) 2021-10-13 12:29:57 -07:00
Shantanu
f5ad168686 [PYTHON] Fix up __version__ (#345)
Co-authored-by: hauntsaninja <>
2021-10-13 00:09:00 -07:00
Philippe Tillet
c3c0ff0552 [LANGUAGE] Fixed issue with duplicates in large arrays of random uniform numbers (#338) 2021-10-10 15:22:34 -07:00
daadaada
9e9d781912 [CODEGEN] Pipeline fixup (#336) 2021-10-10 01:47:11 -07:00
daadaada
d5f20dbce0 [IR] Fix error when building in debug mode (#331) 2021-10-08 21:40:20 -07:00
Philippe Tillet
d4baad426d [DOCS] Added layer norm example (#326) 2021-10-08 11:02:10 -07:00
Philippe Tillet
5123db0b7d [LANG] Various (relatively minor) improvements (#320) 2021-10-04 18:39:40 -07:00
Min Xu
12b6158c5c [DOCS] Minor fix (#317)
Co-authored-by: Min Xu <min.xu.public@gmail.com>
2021-09-30 17:33:08 -07:00
Philippe Tillet
b352b16567 [DOCS] Installation documentation now doesn't suggest to run regression
tests
2021-09-29 18:32:33 -07:00
Philippe Tillet
d132b7442b [DOCS] Minor README edits 2021-09-28 00:39:33 -07:00
Philippe Tillet
44442db96e [VERSION] Bumped to 1.1 (#313) 2021-09-28 00:25:42 -07:00
Philippe Tillet
bfcfad7abe [FRONTEND] Disable P2P (#312) 2021-09-27 21:18:27 -07:00
Philippe Tillet
2c287544cb [OPS] Faster and cleaner block-sparse implementation (#311) 2021-09-27 18:25:16 -07:00
Philippe Tillet
c3756d1c33 [FRONTEND] Add do_not_specialize to triton.jit to prevent specialization of kernel argument (#309) 2021-09-24 20:27:10 -07:00
Philippe Tillet
83da3febf2 [FRONTEND] Added simple hook for when something is written to the cache (#308) 2021-09-23 22:23:17 -07:00
Shantanu
0735061fce [FRONTEND] fix for unpickleable keys (#307)
In #306, I added the key to the cache data, so we can introspect to
investigate cache misses. Unfortunately, the key isn't pickleable,
so just add the str version instead.

Co-authored-by: hauntsaninja <>
2021-09-23 21:23:59 -07:00
Shantanu
2066ccd87e [FRONTEND] single file caches (#306)
Co-authored-by: hauntsaninja <>
2021-09-23 20:21:19 -07:00
Philippe Tillet
e22d92c63c [RUNTIME] removed obsolete putenv call (#305) 2021-09-23 17:51:58 -07:00
Shantanu
87f8d9f163 [PYTHON] Fix up __version__ (#304)
This should match setup.py

Co-authored-by: hauntsaninja <>
Co-authored-by: Philippe Tillet <phil@openai.com>
2021-09-23 17:36:33 -07:00
Philippe Tillet
ec2e7b8f48 [CODEGEN] Fixed nasty bug in coalesce pass (#303) 2021-09-23 17:05:11 -07:00
Shantanu
d253eb8719 [FRONTEND] Add cache_version to triton.jit (#301) 2021-09-23 16:45:54 -07:00
Philippe Tillet
5211f23a63 [FRONTEND] updated TensorWrapper (#299) 2021-09-22 13:53:27 -07:00
Philippe Tillet
2849e7a773 [CODEGEN] now re-coalescing before atomics (#298) 2021-09-22 13:35:53 -07:00
Philippe Tillet
41dbaf3b3f [FRONTEND] Fixed typo in cache for .dumb db (#296) 2021-09-21 17:03:41 -07:00
Philippe Tillet
c151e0f6aa [FRONTEND] Simplified detection of corrupted cache (#295) 2021-09-21 16:36:24 -07:00
Philippe Tillet
e96edc16ff [FRONTEND] Compute cache now supports atomic writes (#294)
Note that killing a Triton process while it updates the cache will result in the cache being wiped out. This is because copying a whole `db` to a temporary file can be quite expensive on some systems.
2021-09-21 14:10:02 -07:00
Benjamin Lefaudeux
b53f5f3803 [OPS][BLOCKSPARSE] safeguarding a couple more configurations (#292) 2021-09-20 17:15:31 -07:00
Philippe Tillet
a12827848d [FRONTEND] Now using exist_ok=True when creating cache directories (#288) 2021-09-18 23:44:21 -07:00
Philippe Tillet
6e5b0b4301 [FRONTEND] Added on-disk cache for compiled kernels (#287) 2021-09-18 22:48:26 -07:00
Benjamin Lefaudeux
bd855ac13d [DOCS] Adding some doc on the benchmarks + requirements file (#285) 2021-09-18 16:37:30 -07:00
Philippe Tillet
313d6488f6 [CODEGEN] Fixed over-aggressive division handling in alignment pass (#280) 2021-09-15 00:40:17 -07:00
Philippe Tillet
da5063d898 [TEST] Added performance regression tests (#283) 2021-09-14 01:46:32 -07:00
Philippe Tillet
8fdd7e7ed6 [LANG] Fixed semantics of boolean load/store (#282) 2021-09-13 17:39:06 -07:00
Philippe Tillet
3e395bc84e [LANG] Fixed semantics of NaN in float comparisons (#281) 2021-09-13 15:06:29 -07:00
Min Xu
cecca90bea [DOCS] update installation doc and add gitignore (#279)
Co-authored-by: Min Xu <min.xu.public@gmail.com>
2021-09-12 21:11:45 -07:00
Philippe Tillet
4163d32c49 [DOCS] Fixed leftover exit() in 01-vector-add tutorial 2021-09-10 15:52:26 -07:00
Philippe Tillet
34369906b4 [PYTHON] Fix-up the previous commit 2021-09-10 11:13:25 -07:00
Philippe Tillet
ac10551d55 [PYTHON] Now providing triton.next_power_of_2 (#273) 2021-09-10 11:05:44 -07:00
Philippe Tillet
43723ccb95 [FRONTEND] Removed circular import that broke Python 3.6 support (#272) 2021-09-09 13:46:55 -07:00
Philippe Tillet
585e5cd0ec [TEST] Added test for empty kernel (#271) 2021-09-09 10:20:37 -07:00
Philippe Tillet
94c83d30ce [GENERAL] Removed deprecated driver files and added basic compatibility with rocm (#268)
- Removed driver module -- accelerator runtime is handled by pytorch
- Added basic support for ROCM based on @micmelesse 's PR -- now can execute empty kernel on AMD devices without any compile-time changes
- Now only using PREFER_SHARED for kernels when the size of shared memory is greater than 49k. Otherwise there can be poor L1 performance for broadcast tensors
2021-09-09 00:04:28 -07:00
Szymon Sidor
8bedcce9be [LANG] Added seeded random number generation - philox (#261) 2021-09-02 22:02:40 -07:00
Philippe Tillet
c069ef907e [PYTHON] triton.language is now a submodule rather than a single file (#260) 2021-09-02 13:30:14 -07:00
Philippe Tillet
8a882b215f [CODEGEN] Fixed performance regression on vectorized loads (#259) 2021-09-02 01:07:31 -07:00
Philippe Tillet
768e0ded28 [CODEGEN] Fixed bug in pipelining pass and casting semantics analysis (#257) 2021-09-01 20:58:47 -07:00
Rohit Dwivedula
c0daffc625 [DOCS] @heuristics -> @triton.heuristics in some snippets (#253) 2021-09-01 18:50:17 -07:00
daadaada
274d613488 [IR] Better printer (#256) 2021-09-01 09:55:12 -07:00
Philippe Tillet
4ff3714d61 [CODEGEN] Various bugfixes and stability improvements in compiler backend (#240) 2021-08-30 11:50:35 -07:00
daadaada
85426dbaf7 [DOCS] Add comments in layout.h (#249) 2021-08-28 18:07:32 -07:00
milesial
5b29da719d [DRIVER] Add CUDA P2P support (#209) 2021-08-20 21:00:54 -07:00
Sasank Chilamkurthy
6aa5720d75 [DOCS] use numel for num_elements in elementwise tutorial (#228) 2021-08-19 19:35:12 -07:00
Philippe Tillet
f26a48a3b4 [DOCS] Various improvements (#224)
- Added docstr for autotune, Config, heuristics
- Added docstr for atomics
- Hiding internal _builder argument used for built-in language primitives
- Re-factor docstr to use common templates between similar functions.
2021-08-18 11:15:53 -07:00
Philippe Tillet
226fde6ea1 [CODEGEN] Now using atomic_rmw code path for atomic_xchg (#222) 2021-08-17 16:33:23 -07:00
Philippe Tillet
64b8e7222d [LICENSE] Edit copyright notice (#219) 2021-08-17 09:25:19 -07:00
Philippe Tillet
a714b6b856 [PYTHON] re-activated auto-tuner configurations for triton.ops.matmul (#212) 2021-08-16 22:56:21 -07:00
Philippe Tillet
bb1eebb4b4 [CODEGEN] Fixed bug for visit_reduce1d with 64-bit data-types (#207) 2021-08-14 21:07:01 -07:00
Philippe Tillet
6e7593b446 added reset_to_zero in vector addition (#205) 2021-08-14 10:58:38 -07:00
Philippe Tillet
c45c2e9684 [DOCS] Added docs for cos/sin/sqrt (#204) 2021-08-14 10:34:07 -07:00
Philippe Tillet
c7a272cb91 [FRONTEND] Added default arguments for range (#203) 2021-08-14 10:11:18 -07:00
Philippe Tillet
b120d70a0a [CI] Moved from assert_allclose to assert_almost_equal (#200) 2021-08-12 12:00:30 -07:00
Philippe Tillet
70e28ff380 [DOCS] Minor modifications of the matmul tutorial (#199)
Making the code more compact and fixing inconsistencies between text variable names and final python program.
2021-08-11 18:59:15 -07:00
Philippe Tillet
398d4b4aeb [DOCS] softmax tutorial fixup (#198) 2021-08-11 17:35:00 -07:00
Philippe Tillet
83da7065da [DRIVER] Portability fixup (#195) 2021-08-07 18:53:11 -07:00
Philippe Tillet
298da78058 [CODEGEN/DRIVER] Tweaks for performance optimization (#193) 2021-08-07 16:41:44 -07:00
Nicholas Joseph
6cd1ec3955 [DOCS] Fix formatting mistakes (#192) 2021-08-06 12:58:43 -07:00
Nicholas Joseph
68f7eeba92 [DOCS] Improve matmul tutorial readability (#188) 2021-08-05 16:05:56 -07:00
Nicholas Joseph
4e6f667c2f [DOCS] Improve readability of 02-fused-softmax.py (#186) 2021-08-05 09:39:07 -07:00
Nicholas Joseph
23c71538fc [DOCS] Improve tutorial readability (#185) 2021-08-05 09:27:06 -07:00
Philippe Tillet
3cb77aa126 [README] Added "we're hiring!" with link to some of our blog posts (#180) 2021-08-02 16:46:26 -07:00
Xiangru Lian
9967e9d4b4 [DOCS] Fix fused softmax example script naive softmax implementation (#178) 2021-08-02 09:37:31 -07:00
Philippe Tillet
e8031fe61f [DRIVER] More robust support of unsupported CUDA version (#179) 2021-08-02 09:06:55 -07:00
milesial
b7cdf670c3 [DOCS] Fix related work (#172) 2021-08-01 11:06:37 -07:00
daadaada
c7060eadb2 [CODEGEN] Fix bug in auto-pipeline pass when a value depends on multiple phis (#164) 2021-07-31 23:40:36 -07:00
Philippe Tillet
c0bb895d9d [BUILD] More portable detection of terminfo (#173) 2021-07-31 17:09:49 -07:00
Philippe Tillet
a34c57402f [PYTHON] Improved error message for CPU (#167) 2021-07-30 09:47:27 -07:00
Ikko Ashimine
2293afece7 [README] GitHub format (#165)
Github -> GitHub
2021-07-30 09:47:08 -07:00
Philippe Tillet
cb5c280691 [DOCS] Added contributions section to README.md 2021-07-29 11:40:34 -07:00
Reid Draper
2322d6df2a [CI] Update ptillet to openai (#152) 2021-07-29 11:39:50 -07:00
Philippe Tillet
2f0f51be50 [DRIVER] No longer crashing when encountering CUDA version >11.4 2021-07-29 11:27:55 -07:00
Philippe Tillet
41ecd96300 [DOCS] minor grammar improvements 2021-07-28 14:18:31 -07:00
Avi Radinsky
d3851d8989 [DOCS] Typo fix (#151) 2021-07-28 12:07:12 -07:00
Philippe Tillet
4b9df06568 [CI] Bumped dev version to 1.0.1 and fixed permissions in documentation.yml (#149) 2021-07-28 04:35:14 -07:00
201 changed files with 24495 additions and 49778 deletions

View File

@@ -8,33 +8,48 @@ jobs:
Build-Documentation:
runs-on: self-hosted
runs-on: [self-hosted, V100]
steps:
- name: Checkout gh-pages
uses: actions/checkout@v1
with:
ref: 'gh-pages'
- name: Clear docs
run: |
rm -r /tmp/triton-docs
continue-on-error: true
- name: Checkout branch
uses: actions/checkout@v1
- name: Install Triton
run: |
alias python='python3'
cd python
pip3 install -e .
- name: Build docs
run: |
git fetch origin master:master
cd docs
make html
sphinx-multiversion . _build/html/
- name: Publish docs
run: |
git branch
# update docs
mkdir /tmp/triton-docs;
mv docs/_build/html/* /tmp/triton-docs/
git checkout gh-pages
sh ./update-website.sh
git remote set-url origin git@github.com:ptillet/triton.git
git push
cp -r CNAME /tmp/triton-docs/
cp -r index.html /tmp/triton-docs/
cp -r .nojekyll /tmp/triton-docs/
rm -r *
cp -r /tmp/triton-docs/* .
# ln -s master/index.html .
# mv master docs
git add .
git commit -am "[GH-PAGES] Updated website"
# publish docs
eval `ssh-agent -s`
DISPLAY=:0 SSH_ASKPASS=~/.ssh/give_pass.sh ssh-add ${{ secrets.SSH_KEY }} <<< ${{ secrets.SSH_PASS }}
git remote set-url origin git@github.com:openai/triton.git
git push

View File

@@ -11,24 +11,44 @@ jobs:
Integration-Tests:
runs-on: self-hosted
runs-on: [self-hosted, V100]
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Clear cache
run: |
rm -r ~/.triton/
continue-on-error: true
- name: Install Triton
run: |
alias python='python3'
cd python
pip3 install -e .
pip3 install -e '.[tests]'
- name: Run benchmarks
run: |
cd python/bench
python3 -m run
- name: Check imports
run: "isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )"
- name: Run unit tests
- name: Check style
run: "autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )"
- name: Flake8
run: "flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )"
- name: Unit tests
run: |
pytest .
cd python/test/unit
pytest -vs .
- name: Regression tests
run: |
cd python/test/regression
sudo nvidia-smi -i 0 -pm 1
sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350
sudo nvidia-smi -i 0 --lock-memory-clocks=877,877
pytest -vs .
sudo nvidia-smi -i 0 -rgc
sudo nvidia-smi -i 0 -rmc

View File

@@ -8,7 +8,7 @@ jobs:
Build-Wheels:
runs-on: self-hosted
runs-on: [self-hosted, V100]
steps:
@@ -18,7 +18,7 @@ jobs:
- name: Patch setup.py
run: |
#sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
export LATEST_DATE=$(git show -s --format=%ci `git rev-parse HEAD` | cut -d ' ' -f 1 | sed 's/-//g')
export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d' --format="%cd")
sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py
echo "" >> python/setup.cfg
echo "[build_ext]" >> python/setup.cfg

12
.gitignore vendored Normal file
View File

@@ -0,0 +1,12 @@
build/
__pycache__
.pytest_cache
python/build/
python/triton.egg-info/
python/triton/_C/libtriton.pyd
python/triton/_C/libtriton.so
.vscode
.vs

3
.gitmodules vendored Normal file
View File

@@ -0,0 +1,3 @@
[submodule "deps/dlfcn-win32"]
path = deps/dlfcn-win32
url = https://github.com/dlfcn-win32/dlfcn-win32.git

4
.isort.cfg Normal file
View File

@@ -0,0 +1,4 @@
[settings]
known_local_folder=triton
line_length=88
py_version=36

View File

@@ -1,14 +1,13 @@
cmake_minimum_required(VERSION 3.6)
include(ExternalProject)
if(NOT TRITON_LLVM_BUILD_DIR)
set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
endif()
set(CMAKE_CXX_STANDARD 17)
project(triton)
include(CTest)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
if(NOT WIN32)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
endif()
# Options
option(BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
@@ -20,8 +19,22 @@ if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release")
endif()
if(NOT WIN32)
find_library(TERMINFO_LIBRARY tinfo)
endif()
# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
# Third-party
include_directories(${PYBIND11_INCLUDE_DIR})
if(WIN32)
SET(BUILD_SHARED_LIBS OFF)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/deps/dlfcn-win32/src)
add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
@@ -29,7 +42,20 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
# LLVM
##########
if("${LLVM_LIBRARY_DIR}" STREQUAL "")
find_package(LLVM 11 REQUIRED COMPONENTS "nvptx")
if(WIN32)
find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu)
include_directories(${LLVM_INCLUDE_DIRS})
separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS})
add_definitions(${LLVM_DEFINITIONS_LIST})
llvm_map_components_to_libnames(LLVM_LIBRARIES support core
NVPTXInfo nvptxcodegen
AMDGPUInfo AMDGPUcodegen
)
else()
find_package(LLVM 11 REQUIRED COMPONENTS "nvptx;amdgpu")
endif()
message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14")
@@ -37,14 +63,61 @@ if("${LLVM_LIBRARY_DIR}" STREQUAL "")
# sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
else()
set(LLVM_LDFLAGS "-L${LLVM_LIBRARY_DIR}")
set(LLVM_LIBRARIES libLLVMNVPTXCodeGen.a libLLVMSelectionDAG.a libLLVMipo.a libLLVMInstrumentation.a
libLLVMVectorize.a libLLVMLinker.a libLLVMIRReader.a libLLVMAsmParser.a libLLVMFrontendOpenMP.a
libLLVMAsmPrinter.a libLLVMDebugInfoDWARF.a libLLVMCodeGen.a libLLVMTarget.a libLLVMScalarOpts.a
libLLVMInstCombine.a libLLVMAggressiveInstCombine.a libLLVMTransformUtils.a libLLVMBitWriter.a
libLLVMAnalysis.a libLLVMProfileData.a libLLVMObject.a libLLVMTextAPI.a libLLVMMCParser.a
libLLVMBitReader.a libLLVMCore.a libLLVMRemarks.a libLLVMBitstreamReader.a libLLVMNVPTXDesc.a
libLLVMMC.a libLLVMDebugInfoCodeView.a libLLVMDebugInfoMSF.a libLLVMBinaryFormat.a libLLVMNVPTXInfo.a
libLLVMSupport.a libLLVMDemangle.a)
set(LLVM_LIBRARIES
libLLVMNVPTXCodeGen.a
libLLVMNVPTXDesc.a
libLLVMNVPTXInfo.a
libLLVMAMDGPUDisassembler.a
libLLVMMCDisassembler.a
libLLVMAMDGPUCodeGen.a
libLLVMMIRParser.a
libLLVMGlobalISel.a
libLLVMSelectionDAG.a
libLLVMipo.a
libLLVMInstrumentation.a
libLLVMVectorize.a
libLLVMLinker.a
libLLVMIRReader.a
libLLVMAsmParser.a
libLLVMFrontendOpenMP.a
libLLVMAsmPrinter.a
libLLVMDebugInfoDWARF.a
libLLVMCodeGen.a
libLLVMTarget.a
libLLVMScalarOpts.a
libLLVMInstCombine.a
libLLVMAggressiveInstCombine.a
libLLVMTransformUtils.a
libLLVMBitWriter.a
libLLVMAnalysis.a
libLLVMProfileData.a
libLLVMObject.a
libLLVMTextAPI.a
libLLVMBitReader.a
libLLVMAMDGPUAsmParser.a
libLLVMMCParser.a
libLLVMAMDGPUDesc.a
libLLVMAMDGPUUtils.a
libLLVMMC.a
libLLVMDebugInfoCodeView.a
libLLVMDebugInfoMSF.a
libLLVMCore.a
libLLVMRemarks.a
libLLVMBitstreamReader.a
libLLVMBinaryFormat.a
libLLVMAMDGPUInfo.a
libLLVMSupport.a
libLLVMDemangle.a
libLLVMPasses.a
libLLVMAnalysis.a
libLLVMTransformUtils.a
libLLVMScalarOpts.a
libLLVMTransformUtils.a
libLLVMipo.a
libLLVMObjCARCOpts.a
libLLVMCoroutines.a
libLLVMAnalysis.a
)
endif()
include_directories("${LLVM_INCLUDE_DIRS}")
@@ -68,12 +141,25 @@ endif()
# Triton
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
if (WIN32 AND BUILD_PYTHON_MODULE)
find_package(Python3 REQUIRED COMPONENTS Development)
Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
set_target_properties(triton PROPERTIES SUFFIX ".pyd")
set_target_properties(triton PROPERTIES PREFIX "lib")
else()
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
endif()
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
target_link_libraries(triton ${LLVM_LIBRARIES} z tinfo)
if(WIN32)
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
else()
target_link_libraries(triton ${LLVM_LIBRARIES} z)
endif()
if(BUILD_PYTHON_MODULE)
if(BUILD_PYTHON_MODULE AND NOT WIN32)
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
# Check if the platform is MacOS
if(APPLE)

View File

@@ -1,4 +1,6 @@
/* Copyright 2018-2021 Philippe Tillet
/*
* Copyright 2018-2020 Philippe Tillet
* Copyright 2020-2022 OpenAI
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
@@ -19,8 +21,3 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
// The compiler front-end is based on a modified version of WGTCC
// https://github.com/wgtdkp/wgtcc
// Copyright (c) 2016 wgtdkp

View File

@@ -12,12 +12,52 @@
# Triton
This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs.
This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment for expressing tensor math workloads that offers high flexibility, developer productivity and end to end performance.
The foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing this work if you use Triton!
The [official documentation](https://triton-lang.org) contains installation instructions and tutorials.
# Quick Installation
You can install the latest stable release of Triton from pip:
```bash
pip install triton
```
Binary wheels are available for CPython 3.6-3.9 and PyPy 3.6-3.7.
And the latest nightly release:
```bash
pip install -U --pre triton
```
# Install from source
```
git clone https://github.com/openai/triton.git;
cd triton/python;
pip install cmake; # build time dependency
pip install -e .
```
# Changelog
Version 1.1 is out! New features include:
- Many, many bugfixes
- More documentation
- Automatic on-disk caching of compiled binary objects
- Random Number Generation
- Faster (up to 2x on A100), cleaner blocksparse ops
# Contributing
Community contributions are more than welcome, whether it be to fix bugs or to add new features. Feel free to open GitHub issues about your contribution ideas, and we will review them. A contributor's guide containing general guidelines is coming soon!
If youre interested in joining our team and working on Triton & GPU kernels, [were hiring](https://openai.com/jobs/#acceleration)!
# Compatibility
Supported Platforms:
@@ -29,4 +69,4 @@ Supported Hardware:
# Disclaimer
Triton is a fairly recent project, and it is under active development. We expect it to be pretty useful in a wide variety of cases, but don't be surprised if it's a bit rough around the edges :)
Triton is a fairly recent project, and it is under active development. We expect it to be pretty useful in a wide variety of cases, but don't be surprised if it's a bit rough around the edges :)

View File

@@ -25,7 +25,7 @@
# LLVM_VERSION_STRING - Full LLVM version string (e.g. 6.0.0svn).
# LLVM_VERSION_BASE_STRING - Base LLVM version string without git/svn suffix (e.g. 6.0.0).
#
# Note: The variable names were chosen in conformance with the offical CMake
# Note: The variable names were chosen in conformance with the official CMake
# guidelines, see ${CMAKE_ROOT}/Modules/readme.txt.
# Try suffixed versions to pick up the newest LLVM install available on Debian
@@ -196,4 +196,4 @@ include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(LLVM
REQUIRED_VARS LLVM_ROOT_DIR
VERSION_VAR LLVM_VERSION_STRING)
VERSION_VAR LLVM_VERSION_STRING)

1
deps/dlfcn-win32 vendored Submodule

Submodule deps/dlfcn-win32 added at 522c301ec3

27
docs/_templates/versions.html vendored Normal file
View File

@@ -0,0 +1,27 @@
{%- if current_version %}
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Other Versions</span>
v: {{ current_version.name }}
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">
{%- if versions.tags %}
<dl>
<dt>Tags</dt>
{%- for item in versions.tags %}
<dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
{%- endfor %}
</dl>
{%- endif %}
{%- if versions.branches %}
<dl>
<dt>Branches</dt>
{%- for item in versions.branches %}
<dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
{%- endfor %}
</dl>
{%- endif %}
</div>
</div>
{%- endif %}

View File

@@ -24,25 +24,39 @@
# -- General configuration ------------------------------------------------
def process_sig(app, what, name, obj, options, signature, return_annotation):
if signature and '_builder' in signature:
signature = signature.split('_builder')[0] + ")"
return (signature, return_annotation)
def setup(app):
"""Customize function args retrieving to get args under decorator."""
import sphinx
import triton
import os
app.connect("autodoc-process-signature", process_sig)
os.system("pip install -e ../python")
def forward_jit_fn(func):
old = func
def wrapped(obj, **kwargs):
if isinstance(obj, triton.code_gen.JITFunction):
import triton
if isinstance(obj, triton.runtime.JITFunction):
obj = obj.fn
return old(obj)
return wrapped
old_documenter = sphinx.ext.autosummary.get_documenter
def documenter(app, obj, parent):
if isinstance(obj, triton.code_gen.JITFunction):
import triton
if isinstance(obj, triton.runtime.JITFunction):
obj = obj.fn
return old_documenter(app, obj, parent)
@@ -56,9 +70,17 @@ def setup(app):
import sys
import os
sys.path.insert(0, os.path.abspath('../python/'))
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon']
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon', 'sphinx_multiversion']
autosummary_generate = True
# versioning config
smv_tag_whitelist = r'^(v1.1.2)$'
smv_branch_whitelist = r'^master$'
smv_remote_whitelist = None
smv_released_pattern = r'^tags/.*$'
smv_outputdir_format = '{ref.name}'
smv_prefer_remote_refs = False
# Sphinx gallery
extensions += ['sphinx_gallery.gen_gallery']
from sphinx_gallery.sorting import FileNameSortKey
@@ -68,10 +90,18 @@ sphinx_gallery_conf = {
'filename_pattern': '',
'ignore_pattern': r'__init__\.py',
'within_subsection_order': FileNameSortKey,
'reference_url': {
'sphinx_gallery': None,
}
}
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
html_sidebars = {
'**': [
'_templates/versions.html',
],
}
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:

View File

@@ -8,6 +8,8 @@ Binary Distributions
You can install the latest stable release of Triton from pip:
.. code-block:: bash
pip install triton
Binary wheels are available for CPython 3.6-3.9 and PyPy 3.6-3.7.
@@ -31,22 +33,25 @@ You can install the Python package from source by running the following commands
.. code-block:: bash
git clone https://github.com/ptillet/triton.git;
cd triton/python;
git clone https://github.com/openai/triton.git;
cd triton;
git submodule update --init --recursive;
cd python;
pip install cmake; # build time dependency
pip install -e .
Note that, if llvm-11 is not present on your system, the setup.py script will download LLVM static libraries on the web and link against that.
Note that, if llvm-11 is not present on your system and you are on linux, the setup.py script will download the official LLVM11 static libraries link against that. For windows users, LLVM must be installed and configured in PATH.
You can then test your installation by running the unit tests:
.. code-block:: bash
pytest -vs .
pip install -e '.[tests]'
pytest -vs test/unit/
and the benchmarks
.. code-block:: bash
cd bench/
python -m run --with-plots --result-dir /tmp/triton-bench
python -m run --with-plots --result-dir /tmp/triton-bench

Binary file not shown.

After

Width:  |  Height:  |  Size: 465 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

View File

@@ -1,7 +1,7 @@
Welcome to Triton's documentation!
==================================
Triton is an language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.
Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.
Getting Started
---------------
@@ -49,4 +49,4 @@ Check out the following documents to learn more about Triton and how it compares
:hidden:
programming-guide/chapter-1/introduction
programming-guide/chapter-2/related-work
programming-guide/chapter-2/related-work

View File

@@ -2,7 +2,7 @@
Related Work
==============
At first sight, Triton may seem like just yet another DSL for DNNs. The purpose of this section is to contextualize Triton and highlights its differences with the two leading approaches in this domain: polyhedral compilation and scheduling languages.
At first sight, Triton may seem like just yet another DSL for DNNs. The purpose of this section is to contextualize Triton and highlight its differences with the two leading approaches in this domain: polyhedral compilation and scheduling languages.
-----------------------
Polyhedral Compilation
@@ -14,7 +14,7 @@ Traditional compilers typically rely on intermediate representations, such as LL
Program Representation
+++++++++++++++++++++++
Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample litterature on linear and integer programming.
Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample literature on linear and integer programming.
.. table::
:widths: 50 50
@@ -121,7 +121,7 @@ Limitations
Unfortunately, polyhedral compilers suffer from two major limitations that have prevented its adoption as a universal method for code generation in neural networks.
First, the set of possible program transformations $\Omega = \{ \Theta_S ~|~ S \in \text{program} \}$ is large, and grows with the number of statements in the program as well as with the size of their iteration domain. Verifying the legality of each transformation can also require the resolution of complex integer linear programs, making polyhedral compilation very computationally expensive. To make matters worse, hardware properties (e.g., cache size, number of SMs) and contextual characteristics (e.g., input tensor shapes) also have to be taken into account by this framework, leading to expensive auto-tuning procedures [SATO2019]_.
First, the set of possible program transformations :math:`\Omega = \{ \Theta_S ~|~ S \in \text{program} \}` is large, and grows with the number of statements in the program as well as with the size of their iteration domain. Verifying the legality of each transformation can also require the resolution of complex integer linear programs, making polyhedral compilation very computationally expensive. To make matters worse, hardware properties (e.g., cache size, number of SMs) and contextual characteristics (e.g., input tensor shapes) also have to be taken into account by this framework, leading to expensive auto-tuning procedures [SATO2019]_.
Second, the polyhedral framework is not very generally applicable; SCoPs are relatively common [GIRBAL2006]_ but require loop bounds and array subscripts to be affine functions of loop indices, which typically only occurs in regular, dense computations. For this reason, this framework still has to be successfully applied to sparse -- or even structured-sparse -- neural networks, whose importance has been rapidly rising over the past few years.
@@ -131,7 +131,7 @@ On the other hand, blocked program representations advocated by this dissertatio
Scheduling Languages
-----------------------
Separation of concerns \cite{dijkstra82} is a well-known design principle in computer science: programs should be decomposed into modular layers of abstraction that separate the semantics of their algorithms from the details of their implementation. Systems like Halide and TVM push this philosophy one step further, and enforce this separation at the grammatical level through the use of a **scheduling language**. The benefits of this methodology are particularly visible in the case of matrix multiplication, where, as one can see below, the definition of the algorithm (Line 1-7) is completely disjoint from its implementation (Line 8-16), meaning that both can be maintained, optimized and distributed independently.
Separation of concerns [DIJKSTRA82]_ is a well-known design principle in computer science: programs should be decomposed into modular layers of abstraction that separate the semantics of their algorithms from the details of their implementation. Systems like Halide and TVM push this philosophy one step further, and enforce this separation at the grammatical level through the use of a **scheduling language**. The benefits of this methodology are particularly visible in the case of matrix multiplication, where, as one can see below, the definition of the algorithm (Line 1-7) is completely disjoint from its implementation (Line 8-16), meaning that both can be maintained, optimized and distributed independently.
.. code-block:: python
:linenos:
@@ -168,7 +168,7 @@ Scheduling languages are, without a doubt, one of the most popular approaches fo
Limitations
++++++++++++
This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indice without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse com-putations, whose iteration spaces may be irregular.
This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indice without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular.
.. table::
:widths: 50 50
@@ -206,4 +206,5 @@ References
.. [GROSSER2012] T. Grosser et al., "Polly - Performing Polyhedral Optimizations on a Low-Level Intermediate Representation", Parallel Processing Letters 2012
.. [SATO2019] Y. Sato et al., "An Autotuning Framework for Scalable Execution of Tiled Code via Iterative Polyhedral Compilation", TACO 2019
.. [GIRBAL2006] S. Girbal et al., "Semi-Automatic Composition of Loop Transformations for Deep Parallelism and Memory Hierarchies", International Journal of Parallel Programming 2006
.. [MULLAPUDI2016] R. Mullapudi et al., "Automatically scheduling halide image processing pipelines", TOG 2016
.. [DIJKSTRA82] E. W. Dijkstra et al., "On the role of scientific thought", Selected writings on computing: a personal perspective 1982
.. [MULLAPUDI2016] R. Mullapudi et al., "Automatically scheduling halide image processing pipelines", TOG 2016

View File

@@ -80,6 +80,9 @@ Math Ops
exp
log
cos
sin
sqrt
sigmoid
softmax
@@ -95,6 +98,22 @@ Reduction Ops
min
sum
Atomic Ops
---------------
.. autosummary::
:toctree: generated
:nosignatures:
atomic_cas
atomic_xchg
atomic_add
atomic_max
atomic_min
atomic_and
atomic_or
atomic_xor
Comparison ops
---------------
@@ -106,6 +125,19 @@ Comparison ops
minimum
maximum
.. _Random Number Generation:
Random Number Generation
-------------------------
.. autosummary::
:toctree: generated
:nosignatures:
randint4x
randint
rand
randn
Compiler Hint Ops
-------------------
@@ -114,4 +146,4 @@ Compiler Hint Ops
:toctree: generated
:nosignatures:
multiple_of
multiple_of

View File

@@ -7,4 +7,7 @@ triton
:toctree: generated
:nosignatures:
jit
jit
autotune
heuristics
Config

View File

@@ -12,7 +12,9 @@ namespace ir {
class phi_node;
class splat_inst;
class cast_inst;
class cmp_inst;
class reshape_inst;
class dequantize_inst;
class broadcast_inst;
class binary_operator;
class getelementptr_inst;
@@ -33,8 +35,10 @@ private:
std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
std::vector<cst_info> populate_is_constant_dequantize(ir::dequantize_inst* x);
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
std::vector<cst_info> populate_is_constant_cmp(ir::cmp_inst* x);
std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
std::vector<cst_info> populate_is_constant_default(ir::value* v);
std::vector<cst_info> populate_is_constant(ir::value *v);
@@ -42,6 +46,7 @@ private:
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_max_contiguous_dequantize(ir::dequantize_inst* x);
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
@@ -52,6 +57,7 @@ private:
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_starting_multiple_dequantize(ir::dequantize_inst* x);
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
@@ -65,6 +71,7 @@ public:
void run(ir::module &mod);
unsigned get(ir::value* v, unsigned ax) const;
std::vector<unsigned> contiguous(ir::value* v) const;
std::vector<cst_info> get_cst_info(ir::value* v) const;
private:
std::map<ir::value*, std::vector<cst_info>> is_constant_;

View File

@@ -25,9 +25,11 @@ private:
void update_graph_reduce(ir::instruction *i);
void update_graph_reshape(ir::instruction *i);
void update_graph_trans(ir::instruction *i);
void update_graph_dequantize(ir::instruction *i);
void update_graph_broadcast(ir::instruction *i);
void update_graph_dot(ir::instruction *i);
void update_graph_elementwise(ir::instruction *i, bool connect_ret=true);
void update_graph_elementwise(ir::instruction *i,
bool is_masked_load_async=false);
void update_graph_no_edge(ir::instruction *i);
void update_graph(ir::instruction *i);

View File

@@ -93,7 +93,80 @@ protected:
shape_t shape_;
};
class mma_layout: public data_layout {
class distributed_layout: public data_layout{
public:
distributed_layout(id_t id,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value*>& values,
analysis::align* align);
int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
virtual int contig_per_thread(size_t k) = 0;
protected:
std::vector<int> shape_per_cta_;
};
class mma_layout: public distributed_layout {
public:
enum TensorCoreType : uint8_t {
// floating-point tensor core instr
FP32_FP16_FP16_FP32 = 0, // default
FP32_BF16_BF16_FP32,
FP32_TF32_TF32_FP32,
// integer tensor core instr
INT32_INT1_INT1_INT32, // Not implemented
INT32_INT4_INT4_INT32, // Not implemented
INT32_INT8_INT8_INT32, // Not implemented
//
NOT_APPLICABLE,
};
// Used on nvidia GPUs with sm >= 80
inline static const std::map<TensorCoreType, std::vector<int>> mma_instr_shape_ = {
{FP32_FP16_FP16_FP32, {16, 8, 16}},
{FP32_BF16_BF16_FP32, {16, 8, 16}},
{FP32_TF32_TF32_FP32, {16, 8, 8}},
{INT32_INT1_INT1_INT32, {16, 8, 256}},
{INT32_INT4_INT4_INT32, {16, 8, 64}},
{INT32_INT8_INT8_INT32, {16, 8, 32}},
};
// shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices)
inline static const std::map<TensorCoreType, std::vector<int>> mma_mat_shape_ = {
{FP32_FP16_FP16_FP32, {8, 8, 8}},
{FP32_BF16_BF16_FP32, {8, 8, 8}},
{FP32_TF32_TF32_FP32, {8, 8, 4}},
{INT32_INT1_INT1_INT32, {8, 8, 64}},
{INT32_INT4_INT4_INT32, {8, 8, 32}},
{INT32_INT8_INT8_INT32, {8, 8, 16}},
};
inline static const std::map<TensorCoreType, std::string> mma_instr_ptx_ = {
{FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
{FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
{FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},
{INT32_INT1_INT1_INT32, "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"},
{INT32_INT4_INT4_INT32, "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"},
{INT32_INT8_INT8_INT32, "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
};
// vector length per ldmatrix (16*8/elelment_size_in_bits)
inline static const std::map<TensorCoreType, int> mma_instr_vec_ = {
{FP32_FP16_FP16_FP32, 8},
{FP32_BF16_BF16_FP32, 8},
{FP32_TF32_TF32_FP32, 4},
{INT32_INT1_INT1_INT32, 128},
{INT32_INT4_INT4_INT32, 32},
{INT32_INT8_INT8_INT32, 16},
};
public:
mma_layout(size_t num_warps,
const std::vector<int>& axes,
@@ -101,24 +174,45 @@ public:
const std::vector<ir::value *> &values,
analysis::align* align, target *tgt,
shared_layout* layout_a,
shared_layout* layout_b);
shared_layout* layout_b,
ir::value *dot);
void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
// accessor
int fpw(size_t k) { return fpw_.at(k); }
int wpt(size_t k) { return wpt_.at(k); }
int spw(size_t k) { return spw_.at(k); }
int spt(size_t k) { return spt_.at(k); }
int rep(size_t k) { return rep_.at(k); }
int contig_per_thread(size_t k) { return contig_per_thread_.at(k); }
// helpers for generator.cc
std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); }
std::vector<int> get_mma_instr_shape() const { return mma_instr_shape_.at(tensor_core_type_); }
std::vector<int> get_mma_mat_shape() const { return mma_mat_shape_.at(tensor_core_type_); }
int get_vec_a() const { return mma_instr_vec_.at(tensor_core_type_); }
int get_vec_b() const { return mma_instr_vec_.at(tensor_core_type_); }
// setter
void set_tensor_core_type(TensorCoreType type) { tensor_core_type_ = type; }
private:
// fragment per warp
std::vector<int> fpw_;
// shape per warp
std::vector<int> spw_;
// warp per tile
std::vector<int> wpt_;
// shape per tile
std::vector<int> spt_;
// repetitions
std::vector<int> rep_;
// contiguous per thread
std::vector<int> contig_per_thread_;
TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32;
};
struct scanline_layout: public data_layout {
class scanline_layout: public distributed_layout {
public:
scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
@@ -129,9 +223,13 @@ struct scanline_layout: public data_layout {
// accessor
int mts(size_t k) { return mts_.at(k); }
int nts(size_t k) { return nts_.at(k); }
int contig_per_thread(size_t k) { return nts_.at(k); }
public:
int per_thread(size_t k) { return contig_per_thread(k) * shape_[k] / shape_per_cta(k);}
private:
// micro tile size. The size of a tile held by a thread block.
std::vector<int> mts_;
// nano tile size. The size of a tile held by a thread.
std::vector<int> nts_;
};
@@ -148,7 +246,7 @@ struct N_buffer_info_t {
std::map<ir::value*, int> firsts_idx;
};
// abstract for dot and coresponding smem values
// abstract for dot and corresponding smem values
class shared_layout: public data_layout {
private:
static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
@@ -161,7 +259,8 @@ public:
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values_,
ir::type *ty,
analysis::align* align);
analysis::align* align, target *tgt,
bool is_tmp = false);
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
// accessors
size_t get_size() { return size_; }
@@ -176,7 +275,10 @@ public:
ir::value* hmma_dot_b() { return hmma_dot_b_; }
void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
int get_mma_vec() { return mma_vec_;}
int get_mma_strided() { return mma_strided_; }
bool allow_swizzle() const { return allow_swizzle_; }
data_layout* get_arg_layout() { return arg_layout_; }
bool is_tmp() const { return is_tmp_; }
private:
size_t size_;
@@ -188,6 +290,10 @@ private:
ir::value* hmma_dot_b_;
data_layout* arg_layout_;
int mma_vec_;
int mma_strided_;
bool allow_swizzle_ = true;
target *tgt_;
bool is_tmp_;
};
@@ -206,12 +312,20 @@ private:
void create(size_t id, const std::vector<ir::value*>& values);
public:
void create_tmp_layout(size_t id, data_layout* arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
ir::instruction* i,
bool is_index = false);
public:
// constructor
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
// accessors
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
bool has(size_t id) { return layouts_.find(id) != layouts_.end(); }
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
size_t num_layouts() const { return values_.size();}
data_layout* get(size_t id) { return layouts_.at(id); }
@@ -219,6 +333,18 @@ public:
std::map<size_t, data_layout*> &get_all() { return layouts_; }
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
int tmp(ir::value* i) { return tmp_.at(i);}
int has_tmp_index(ir::value* i) { return tmp_index_.find(i) != tmp_index_.end(); }
int tmp_index(ir::value* i) { return tmp_index_.at(i);}
void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
// layout checkers
bool is_scanline(ir::instruction* i);
bool is_coalesced_scanline(ir::instruction* i);
bool is_mma(ir::instruction* i);
bool is_a100_mma(ir::instruction* i);
// execution
void run(ir::module &mod);
@@ -233,6 +359,7 @@ private:
std::map<size_t, std::vector<ir::value*>> values_;
std::map<size_t, data_layout*> layouts_;
std::map<ir::value*, size_t> tmp_;
std::map<ir::value*, size_t> tmp_index_;
};
}

View File

@@ -1,12 +1,14 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#include <map>
#include <set>
#include <vector>
#include "triton/codegen/analysis/layout.h"
#include "triton/tools/graph.h"
#include "llvm/ADT/MapVector.h"
#include <set>
#include <vector>
namespace triton{
namespace ir{
@@ -42,14 +44,14 @@ struct segment {
class liveness {
private:
typedef std::map<shared_layout*, segment> intervals_map_t;
typedef llvm::MapVector<shared_layout*, segment> intervals_map_t;
public:
// constructor
liveness(layouts *l): layouts_(l){ }
// accessors
const intervals_map_t& get() const { return intervals_; }
segment get(shared_layout* v) const { return intervals_.at(v); }
segment get(shared_layout* v) const { return intervals_.lookup(v); }
// run
void run(ir::module &mod);

View File

@@ -0,0 +1,90 @@
#ifndef _TRITON_CODE_GEN_EXTERN_LIB_H_
#define _TRITON_CODE_GEN_EXTERN_LIB_H_
#include <memory>
#include <string>
#include <map>
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
namespace triton {
namespace codegen {
///
/// \brief ExternLib is a class that represents a library of external functions.
///
class ExternLib {
public:
ExternLib(const std::string &name, const std::string &path)
: name_(name), path_(path) {}
virtual ~ExternLib() = default;
virtual const std::string &name() const { return name_; }
virtual const std::string &path() const { return path_; }
///
/// \brief Load the library and return the module.
///
std::unique_ptr<llvm::Module> load(llvm::LLVMContext &ctx);
///
/// \brief Link the module into the given module.
///
void link(std::unique_ptr<llvm::Module> &llvm,
std::unique_ptr<llvm::Module> &mod);
///
/// \brief Run load, link, and opt on the module.
///
virtual void install(llvm::LLVMContext &ctx,
std::unique_ptr<llvm::Module> &llvm) {
auto mod = load(ctx);
link(llvm, mod);
opt(ctx, llvm);
}
///
/// \brief Run opt on the module.
///
virtual void opt(llvm::LLVMContext &ctx,
std::unique_ptr<llvm::Module> &llvm) = 0;
private:
std::string name_;
std::string path_;
};
///
/// \brief ExternLibMap is a map of ExternLibs from their names to their paths.
///
typedef std::map<std::string, std::unique_ptr<ExternLib>> ExternLibMap;
///
/// \brief Concrete class for NVIDIA's libdevice library.
///
class LibDevice final : public ExternLib {
public:
LibDevice(const std::string &name, const std::string &path)
: ExternLib(name, path) {}
virtual ~LibDevice() = default;
virtual void opt(llvm::LLVMContext &ctx,
std::unique_ptr<llvm::Module> &llvm) override;
};
///
/// \brief Create an ExternLib instance based on the name and path.
///
std::unique_ptr<ExternLib> create_extern_lib(const std::string &lib_name,
const std::string &lib_path);
} // namespace codegen
} // namespace triton
#endif

View File

@@ -3,9 +3,19 @@
#include <memory>
#include "extern_lib.h"
namespace llvm{
class Module;
class LLVMContext;
}
namespace triton{
namespace codegen {
class target;
}
namespace ir{
class module;
}
@@ -21,10 +31,10 @@ namespace codegen{
// TODO:
// There should be a proper pass manager there!
void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages, bool force_nc_cache,
driver::module*& mod, driver::kernel*& ker, size_t& shared_mem);
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
ir::module &ir, llvm::LLVMContext &ctx, codegen::target *target,
int num_warps, int num_stages, int &shared_static,
const ExternLibMap &extern_libs);
}
}

View File

@@ -4,7 +4,9 @@
#define _TRITON_SELECTION_GENERATOR_H_
#include "triton/ir/visitor.h"
#include "triton/ir/instructions.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/extern_lib.h"
#include <functional>
// forward
@@ -24,6 +26,7 @@ namespace llvm{
class IRBuilder;
class ArrayType;
class Function;
class StructType;
}
namespace triton{
@@ -114,18 +117,28 @@ private:
private:
Type *cvt(ir::type *ty);
llvm::Attribute cvt(ir::attribute attr);
void packed_type(ir::value* i);
void forward_declare(ir::function* fn);
Value *cast_shared_layout_ptr(analysis::data_layout *layout, Type *ty);
public:
private:
typedef std::function<void(
std::pair<Value *, Value *> &acc, std::function<Value *()> load_value_fn,
std::function<Value *()> load_index_fn, bool is_first)>
acc_fn_t;
public:
generator(analysis::axes *a_axes,
analysis::layouts *layouts,
analysis::align *alignment,
analysis::allocation *alloc,
analysis::swizzle *swizzle,
target *tgt,
unsigned num_warps,
bool force_nc_cache = false);
unsigned num_warps);
void visit_value(ir::value* v);
void visit_call_inst(ir::call_inst*);
void visit_launch_inst(ir::launch_inst *);
void visit_phi_node(ir::phi_node*);
void visit_binary_operator(ir::binary_operator*);
void visit_getelementptr_inst(ir::getelementptr_inst*);
@@ -135,9 +148,19 @@ public:
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
Value* bf16_to_fp32(Value *in0);
Value* fp32_to_bf16(Value *in0);
std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int16_to_float16x8(
Value *in0, Value *scale_x512, Value *shift
);
std::tuple<Value*, Value*, Value*, Value*, Value*, Value*, Value*, Value*> int32_to_float16x8(
Value *in0, Value *scale_x512, Value *shift
);
std::tuple<Value*, Value*, Value*, Value*> int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift);
std::tuple<Value*, Value*> prepare_scale_shift(Value *scale, Value *shift);
void visit_dequantize_inst(ir::dequantize_inst*);
void visit_cast_inst(ir::cast_inst*);
void visit_return_inst(ir::return_inst*);
void visit_cond_branch_inst(ir::cond_branch_inst*);
@@ -148,18 +171,21 @@ public:
void visit_store_inst(ir::store_inst*);
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
void visit_masked_store_inst(ir::masked_store_inst*);
void visit_cat_inst(ir::cat_inst*);
void visit_extract_value_inst(ir::extract_value_inst *);
void visit_insert_value_inst(ir::insert_value_inst *);
void visit_reshape_inst(ir::reshape_inst*);
void visit_splat_inst(ir::splat_inst*);
void visit_broadcast_inst(ir::broadcast_inst*);
void visit_downcast_inst(ir::downcast_inst*);
void visit_exp_inst(ir::exp_inst*);
void visit_cos_inst(ir::cos_inst*);
void visit_umulhi_inst(ir::umulhi_inst* x);
void visit_sin_inst(ir::sin_inst*);
void visit_log_inst(ir::log_inst*);
void visit_get_program_id_inst(ir::get_program_id_inst*);
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
void visit_atomic_exch_inst(ir::atomic_exch_inst*);
void visit_atomic_rmw_inst(ir::atomic_rmw_inst*);
void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
@@ -167,11 +193,13 @@ public:
void visit_dot_inst(ir::dot_inst*);
void visit_trans_inst(ir::trans_inst*);
void visit_sqrt_inst(ir::sqrt_inst*);
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
Value* shfl_sync(Value* acc, int32_t i);
void visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
void visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral);
void visit_reduce_inst(ir::reduce_inst*);
void visit_select_inst(ir::select_inst*);
void visit_recoalesce_inst(ir::recoalesce_inst*);
void visit_layout_convert(ir::value *out, ir::value *in);
void visit_cvt_layout_inst(ir::cvt_layout_inst*);
void visit_masked_load_async_inst(ir::masked_load_async_inst*);
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
@@ -180,6 +208,9 @@ public:
void visit_async_wait_inst(ir::async_wait_inst*);
// void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range(ir::make_range*);
void visit_clock_inst(ir::clock_inst*);
void visit_globaltimer_inst(ir::globaltimer_inst*);
void visit_extern_elementwise_inst(ir::extern_elementwise_inst*);
// void visit_make_range_sta(ir::make_range_sta*);
void visit_undef_value(ir::undef_value*);
void visit_constant_int(ir::constant_int*);
@@ -195,12 +226,21 @@ public:
void visit_layout_scanline(analysis::scanline_layout*);
void visit_layout_shared(analysis::shared_layout*);
// Add a new external library based on given name and path if it doesn't exist
void add_extern_lib(const std::string &lib_name, const std::string &lib_path);
private:
// Get all external libraries
const ExternLibMap &get_extern_lib_map() {
return extern_lib_map_;
}
private:
LLVMContext *ctx_;
Builder* builder_;
Module *mod_;
std::map<std::string, std::unique_ptr<ExternLib>> extern_lib_map_;
analysis::axes *a_axes_;
analysis::swizzle *swizzle_;
std::map<unsigned, distributed_axis> axes_;
@@ -212,7 +252,6 @@ private:
std::set<ir::value*> seen_;
unsigned num_warps_;
bool force_nc_cache_;
std::map<analysis::data_layout*, Value*> offset_a_m_;
std::map<analysis::data_layout*, Value*> offset_a_k_;
@@ -234,10 +273,11 @@ private:
/// idx for multi-stage pipeline
std::map<analysis::data_layout*, Value*> read_smem_idx_;
std::map<analysis::data_layout*, Value*> write_smem_idx_;
/// triton bb -> llvm bb
std::map<ir::value*, BasicBlock *> bbs_;
std::map<ir::value*, std::vector<int>> ords_;
std::map<ir::value*, Function*> fns_;
// helper for creating llvm values
adder add;
@@ -249,6 +289,9 @@ private:
/// Record prefetch instrs that needs to be moved
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
// Eviction policies
std::map<ir::load_inst::EVICTION_POLICY, Value*> policies_;
};
}

View File

@@ -32,10 +32,12 @@ private:
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
public:
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts, bool has_sm80);
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
void run(ir::module &mod);
private:
bool has_sm80_;
analysis::align* align_;
analysis::layouts* layout_;
};

View File

@@ -15,22 +15,30 @@ namespace ir {
}
namespace codegen{
namespace analysis{
class layouts;
}
namespace transform{
class cts {
private:
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);
bool is_shmem_op(ir::instruction* i, int op);
bool is_shmem_res(ir::value* i);
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*,ir::value*>& copies);
public:
cts(bool use_async = false): use_async_(use_async) {}
cts(analysis::layouts* layouts, bool has_sm80 = false): layouts_(layouts), has_sm80_(has_sm80) {}
void run(ir::module &mod);
private:
bool use_async_;
bool has_sm80_;
analysis::layouts* layouts_;
};
}
}
}
#endif
#endif

View File

@@ -0,0 +1,31 @@
#pragma once
#include <list>
namespace triton {
namespace ir {
class module;
class function;
class call_inst;
class builder;
}
namespace codegen{
namespace transform{
struct fncmp {
bool operator()(ir::function* x, ir::function* y) const;
};
class inliner {
public:
inliner() {}
void do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder, std::list<ir::call_inst*>& callsites);
void run(ir::module &mod);
};
}
}
}

View File

@@ -30,13 +30,15 @@ private:
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
bool rewrite_mult(ir::instruction *value, ir::builder& builder);
bool rewrite_insert_extract(ir::instruction *value, ir::builder& builder);
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
private:
bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder);
public:
peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
void run(ir::module &mod);

View File

@@ -1,137 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_BACKEND_H_
#define _TRITON_DRIVER_BACKEND_H_
#include <map>
#include <list>
#include <vector>
#include "triton/driver/context.h"
namespace llvm
{
class Module;
}
namespace triton
{
namespace driver
{
class buffer;
class stream;
class device;
class context;
class platform;
class module;
class kernel;
struct backend
{
// platforms
class platforms
{
friend class backend;
private:
static void init();
public:
static void get(std::vector<driver::platform*> &results);
private:
static std::vector<driver::platform*> cache_;
};
// devices
class devices
{
friend class backend;
private:
static void init(const std::vector<platform *> &platforms);
public:
static void get(std::vector<driver::device*>& devs);
private:
static std::vector<driver::device*> cache_;
};
// modules
class modules
{
friend class backend;
public:
static void release();
private:
static std::map<std::tuple<driver::stream*, std::string>, driver::module*> cache_;
};
// kernels
class kernels
{
friend class backend;
public:
static void release();
static driver::kernel* get(driver::module* mod, const std::string & name);
private:
static std::map<std::tuple<module*, std::string>, driver::kernel*> cache_;
};
// contexts
class contexts
{
friend class backend;
private:
static void init(const std::vector<device *> &);
static void release();
public:
static driver::context* get_default();
static driver::context* import(CUcontext ctx)
{
for(driver::context* x: cache_){
driver::cu_context* cu_x = (driver::cu_context*)x;
if(*cu_x->cu()==ctx)
return x;
}
cache_.emplace_back(new driver::cu_context(ctx, false));
return cache_.back();
}
static void get(std::list<driver::context*> &);
private:
static std::list<driver::context*> cache_;
};
// streams
class streams
{
friend class backend;
private:
static void init(std::list<context*> const &);
static void release();
public:
static void get(driver::context*, std::vector<driver::stream *> &streams);
static driver::stream* get(driver::context*, unsigned int id = 0);
static driver::stream* get_default();
private:
static std::map<driver::context*, std::vector<driver::stream*> > cache_;
};
static void init();
static void release();
static void synchronize(triton::driver::context *);
static unsigned int default_device;
};
}
}
#endif

View File

@@ -1,48 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_BUFFER_H_
#define _TRITON_DRIVER_BUFFER_H_
#include "triton/driver/handle.h"
#include "triton/driver/context.h"
namespace triton
{
namespace driver
{
class stream;
// Base
class buffer : public polymorphic_resource<CUdeviceptr, host_buffer_t> {
public:
buffer(size_t size, CUdeviceptr cl, bool take_ownership);
buffer(size_t size, host_buffer_t hst, bool take_ownership);
uintptr_t addr_as_uintptr_t();
static buffer* create(driver::context* ctx, size_t size);
size_t size();
protected:
size_t size_;
};
// CPU
class host_buffer: public buffer
{
public:
host_buffer(size_t size);
};
// CUDA
class cu_buffer: public buffer
{
public:
cu_buffer(size_t size);
cu_buffer(size_t size, CUdeviceptr cu, bool take_ownership);
void set_zero(triton::driver::stream *queue, size_t size);
};
}
}
#endif

View File

@@ -1,50 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_CONTEXT_H_
#define _TRITON_DRIVER_CONTEXT_H_
#include "triton/driver/device.h"
#include "triton/driver/handle.h"
namespace triton
{
namespace driver
{
class context: public polymorphic_resource<CUcontext, host_context_t>{
protected:
static std::string get_cache_path();
public:
context(driver::device *dev, CUcontext cu, bool take_ownership);
context(driver::device *dev, host_context_t hst, bool take_ownership);
driver::device* device() const;
std::string const & cache_path() const;
// factory methods
static context* create(driver::device *dev);
protected:
driver::device* dev_;
std::string cache_path_;
};
// Host
class host_context: public context {
public:
host_context(driver::device* dev);
};
// CUDA
class cu_context: public context {
private:
static CUdevice get_device_of(CUcontext);
public:
//Constructors
cu_context(CUcontext cu, bool take_ownership = true);
cu_context(driver::device* dev);
};
}
}
#endif

View File

@@ -1,81 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_DEVICE_H_
#define _TRITON_DRIVER_DEVICE_H_
#include "triton/driver/platform.h"
#include "triton/driver/handle.h"
namespace triton
{
namespace codegen
{
class target;
}
namespace driver
{
class context;
// Base device
class device: public polymorphic_resource<CUdevice, host_device_t>{
public:
using polymorphic_resource::polymorphic_resource;
virtual size_t max_threads_per_block() const = 0;
virtual size_t max_shared_memory() const = 0;
virtual std::unique_ptr<codegen::target> make_target() const = 0;
};
// Host device
class host_device: public device {
public:
host_device(): device(host_device_t(), true){ }
size_t max_threads_per_block() const { return 1; }
size_t max_shared_memory() const { return 0; }
std::unique_ptr<codegen::target> make_target() const;
};
// CUDA device
class cu_device: public device {
private:
//Metaprogramming elper to get cuda info from attribute
template<CUdevice_attribute attr>
int cuGetInfo() const;
inline nvmlDevice_t nvml_device() const;
public:
cu_device(CUdevice cu = CUdevice(), bool take_ownership = true): device(cu, take_ownership){}
// Informations
std::string infos() const;
size_t address_bits() const;
std::vector<size_t> max_block_dim() const;
size_t warp_size() const;
// Compute Capability
void interpret_as(int cc);
int compute_capability() const;
// Identifier
std::string name() const;
std::string pci_bus_id() const;
// Clocks
size_t current_sm_clock() const;
size_t current_mem_clock() const;
size_t max_threads_per_block() const;
size_t max_shared_memory() const;
size_t max_sm_clock() const;
size_t max_mem_clock() const;
void set_max_clock();
// Target
std::unique_ptr<codegen::target> make_target() const;
private:
std::shared_ptr<int> interpreted_as_;
};
}
}
#endif

View File

@@ -10,6 +10,10 @@
#include "triton/external/CUDA/cuda.h"
#include "triton/external/CUDA/nvml.h"
//// HIP backend
//#define __HIP_PLATFORM_AMD__
#include "triton/external/hip.h"
//Exceptions
#include <iostream>
#include <stdexcept>
@@ -28,6 +32,7 @@ class cu_context;
template<class T> void check(T){}
void check(CUresult err);
void check(hipError_t err);
class dispatch
{
@@ -58,68 +63,127 @@ protected:
}
public:
static void release();
// Nvidia
static bool nvmlinit();
static bool cuinit();
static bool spvllvminit();
static void release();
// AMD
static bool hipinit();
// CUDA
static CUresult cuCtxGetCurrent(CUcontext *pctx);
static CUresult cuCtxSetCurrent(CUcontext ctx);
/* ------------------- *
* CUDA
* ------------------- */
// context management
static CUresult cuInit(unsigned int Flags);
static CUresult cuCtxDestroy_v2(CUcontext ctx);
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
static CUresult cuMemFree_v2(CUdeviceptr dptr);
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
static CUresult cuCtxGetDevice(CUdevice* result);
static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
static CUresult cuDriverGetVersion(int *driverVersion);
// device management
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
static CUresult cuModuleLoadData(CUmodule* module, const void* image);
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
static CUresult cuModuleUnload(CUmodule hmod);
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
static CUresult cuDeviceGetCount(int *count);
// link management
static CUresult cuLinkAddFile_v2(CUlinkState state, CUjitInputType type, const char *path, unsigned int numOptions, CUjit_option *options, void **optionValues);
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
static CUresult cuLinkDestroy(CUlinkState state);
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
static CUresult cuDeviceGetCount(int *count);
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
static CUresult cuInit(unsigned int Flags);
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
// module management
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
static CUresult cuModuleLoadData(CUmodule* module, const void* image);
static CUresult cuModuleUnload(CUmodule hmod);
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
// stream management
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
static CUresult cuStreamSynchronize(CUstream hStream);
static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
static CUresult cuStreamDestroy_v2(CUstream hStream);
static CUresult cuEventDestroy_v2(CUevent hEvent);
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
// function management
static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config);
// memory management
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
static CUresult cuCtxGetDevice(CUdevice* result);
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
static CUresult cuFuncSetCacheConfig (CUfunction hfunc, CUfunc_cache config);
// NVML
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
static CUresult cuMemFree_v2(CUdeviceptr dptr);
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
// event management
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
static CUresult cuEventDestroy_v2(CUevent hEvent);
/* ------------------- *
* NVML
* ------------------- */
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
/* ------------------- *
* HIP
* ------------------- */
// context management
static hipError_t hipInit(unsigned int Flags);
static hipError_t hipCtxDestroy(hipCtx_t ctx);
static hipError_t hipCtxCreate(hipCtx_t *pctx, unsigned int flags, hipDevice_t dev);
static hipError_t hipCtxPushCurrent(hipCtx_t ctx);
static hipError_t hipCtxPopCurrent(hipCtx_t *pctx);
static hipError_t hipCtxGetDevice(hipDevice_t* result);
static hipError_t hipCtxEnablePeerAccess(hipCtx_t peerContext, unsigned int flags);
static hipError_t hipDriverGetVersion(int *driverVersion);
// device management
static hipError_t hipGetDevice(hipDevice_t *device, int ordinal);
static hipError_t hipDeviceGetName(char *name, int len, hipDevice_t dev);
static hipError_t hipDeviceGetPCIBusId(char *id, int len, hipDevice_t dev);
static hipError_t hipDeviceGetAttribute(int *pi, hipDeviceAttribute_t attrib, hipDevice_t dev);
static hipError_t hipGetDeviceCount(int *count);
// module management
static hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t* bytes, hipModule_t hmod, const char *name);
static hipError_t hipModuleLoad(hipModule_t *module, const char *fname);
static hipError_t hipModuleLoadData(hipModule_t* module, const void* image);
static hipError_t hipModuleUnload(hipModule_t hmod);
static hipError_t hipModuleLoadDataEx(hipModule_t *module, const void *image, unsigned int numOptions, hipJitOption *options, void **optionValues);
static hipError_t hipModuleGetFunction(hipFunction_t *hfunc, hipModule_t hmod, const char *name);
// stream management
static hipError_t hipStreamCreate(hipStream_t *phStream, unsigned int Flags);
static hipError_t hipStreamSynchronize(hipStream_t hStream);
static hipError_t hipStreamDestroy(hipStream_t hStream);
static hipError_t hipModuleLaunchKernel(hipFunction_t f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, hipStream_t hStream, void **kernelParams, void **extra);
// function management
static hipError_t hipFuncGetAttributes(hipFuncAttributes* attrib, void* hfunc);
static hipError_t hipFuncSetAttribute(hipFunction_t hfunc, hipFuncAttribute attrib, int value);
static hipError_t hipFuncSetCacheConfig(hipFunction_t hfunc, hipFuncCache_t config);
// memory management
static hipError_t hipMalloc(hipDeviceptr_t *dptr, size_t bytesize);
static hipError_t hipPointerGetAttribute(void * data, CUpointer_attribute attribute, hipDeviceptr_t ptr);
static hipError_t hipMemsetD8Async(hipDeviceptr_t dst, unsigned char x, size_t N, hipStream_t stream);
static hipError_t hipMemcpyDtoH(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount);
static hipError_t hipFree(hipDeviceptr_t dptr);
static hipError_t hipMemcpyDtoHAsync(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount, hipStream_t hStream);
static hipError_t hipMemcpyHtoDAsync(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount, hipStream_t hStream);
static hipError_t hipMemcpyHtoD(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount);
// event management
static hipError_t hipEventCreate(hipEvent_t *phEvent, unsigned int Flags);
static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart, hipEvent_t hEnd);
static hipError_t hipEventRecord(hipEvent_t hEvent, hipStream_t hStream);
static hipError_t hipEventDestroy(hipEvent_t hEvent);
// SPIR-V libraries
static int initializeLLVMToSPIRVPass(llvm::PassRegistry &);
static bool writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg);
private:
@@ -127,67 +191,124 @@ private:
// Libraries
static void* cuda_;
static void* nvml_;
static void* vulkan_;
static void* spvllvm_;
static void* spvcross_;
static void* opengl_;
static void* hip_;
// CUDA functions
/* ------------------- *
* CUDA
* ------------------- */
// context management
static void* cuCtxGetCurrent_;
static void* cuCtxSetCurrent_;
static void* cuCtxDestroy_v2_;
static void* cuEventCreate_;
static void* cuDeviceGet_;
static void* cuMemcpyDtoH_v2_;
static void* cuStreamCreate_;
static void* cuEventElapsedTime_;
static void* cuMemFree_v2_;
static void* cuMemcpyDtoHAsync_v2_;
static void* cuCtxCreate_v2_;
static void* cuCtxGetDevice_;
static void* cuCtxPushCurrent_v2_;
static void* cuCtxPopCurrent_v2_;
static void* cuCtxEnablePeerAccess_;
static void* cuDriverGetVersion_;
static void* cuInit_;
// device management
static void* cuDeviceGet_;
static void* cuDeviceGetName_;
static void* cuDeviceGetPCIBusId_;
static void* cuModuleGetGlobal_v2_;
static void* cuMemcpyHtoDAsync_v2_;
static void* cuModuleLoad_;
static void* cuLaunchKernel_;
static void* cuModuleUnload_;
static void* cuModuleLoadDataEx_;
static void* cuDeviceGetAttribute_;
static void* cuDeviceGetCount_;
// link management
static void* cuLinkAddFile_v2_;
static void* cuLinkAddData_v2_;
static void* cuLinkCreate_v2_;
static void* cuLinkDestroy_;
static void* cuModuleLoadData_;
static void* cuLinkComplete_;
static void* cuDeviceGetAttribute_;
static void* cuDeviceGetCount_;
static void* cuMemcpyHtoD_v2_;
static void* cuInit_;
static void* cuEventRecord_;
static void* cuCtxCreate_v2_;
// module management
static void* cuModuleGetGlobal_v2_;
static void* cuModuleLoad_;
static void* cuModuleUnload_;
static void* cuModuleLoadDataEx_;
static void* cuModuleLoadData_;
static void* cuModuleGetFunction_;
// stream management
static void* cuStreamCreate_;
static void* cuStreamSynchronize_;
static void* cuStreamDestroy_v2_;
static void* cuStreamGetCtx_;
static void* cuEventDestroy_v2_;
static void* cuMemAlloc_v2_;
static void* cuPointerGetAttribute_;
static void* cuCtxGetDevice_;
static void* cuMemsetD8Async_;
static void* cuCtxPushCurrent_v2_;
static void* cuCtxPopCurrent_v2_;
static void* cuLaunchKernel_;
// function management
static void* cuFuncGetAttribute_;
static void* cuFuncSetAttribute_;
static void* cuFuncSetCacheConfig_;
// NVML
// memory management
static void* cuMemcpyDtoH_v2_;
static void* cuMemFree_v2_;
static void* cuMemcpyDtoHAsync_v2_;
static void* cuMemcpyHtoDAsync_v2_;
static void* cuMemcpyHtoD_v2_;
static void* cuMemAlloc_v2_;
static void* cuMemsetD8Async_;
static void* cuPointerGetAttribute_;
// event management
static void* cuEventCreate_;
static void* cuEventElapsedTime_;
static void* cuEventRecord_;
static void* cuEventDestroy_v2_;
/* ------------------- *
* NVML
* ------------------- */
static void* nvmlInit_v2_;
static void* nvmlDeviceGetHandleByPciBusId_v2_;
static void* nvmlDeviceGetClockInfo_;
static void* nvmlDeviceGetMaxClockInfo_;
static void* nvmlDeviceSetApplicationsClocks_;
// LLVM to SPIR-V
static void* initializeLLVMToSPIRVPass_;
static void* writeSpirv_;
/* ------------------- *
* HIP
* ------------------- */
// context management
static void* hipInit_;
static void* hipCtxDestroy_;
static void* hipCtxCreate_;
static void* hipCtxPushCurrent_;
static void* hipCtxPopCurrent_;
static void* hipCtxGetDevice_;
static void* hipCtxEnablePeerAccess_;
static void* hipDriverGetVersion_;
// device management
static void* hipGetDevice_;
static void* hipDeviceGetName_;
static void* hipDeviceGetPCIBusId_;
static void* hipDeviceGetAttribute_;
static void* hipGetDeviceCount_;
// module management
static void* hipModuleGetGlobal_;
static void* hipModuleLoad_;
static void* hipModuleLoadData_;
static void* hipModuleUnload_;
static void* hipModuleLoadDataEx_;
static void* hipModuleGetFunction_;
// stream management
static void* hipStreamCreate_;
static void* hipStreamSynchronize_;
static void* hipStreamDestroy_;
static void* hipModuleLaunchKernel_;;
// function management
static void* hipFuncGetAttributes_;
static void* hipFuncSetAttribute_;
static void* hipFuncSetCacheConfig_;
// memory management
static void* hipMalloc_;
static void* hipPointerGetAttribute_;
static void* hipMemsetD8Async_;
static void* hipMemcpyDtoH_;
static void* hipFree_;
static void* hipMemcpyDtoHAsync_;
static void* hipMemcpyHtoDAsync_;
static void* hipMemcpyHtoD_;
// event management
static void* hipEventCreate_;
static void* hipEventElapsedTime_;
static void* hipEventRecord_;
static void* hipEventDestroy_;
};
}

View File

@@ -141,6 +141,78 @@ namespace triton
TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
}
namespace hip
{
class base: public std::exception{};
#define TRITON_CREATE_HIP_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "HIP: Error- " msg; } }
TRITON_CREATE_HIP_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_HIP_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_HIP_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_HIP_EXCEPTION(deinitialized ,"deinitialized");
TRITON_CREATE_HIP_EXCEPTION(profiler_disabled ,"profiler disabled");
TRITON_CREATE_HIP_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
TRITON_CREATE_HIP_EXCEPTION(profiler_already_started ,"profiler already started");
TRITON_CREATE_HIP_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
TRITON_CREATE_HIP_EXCEPTION(no_device ,"no device");
TRITON_CREATE_HIP_EXCEPTION(invalid_device ,"invalid device");
TRITON_CREATE_HIP_EXCEPTION(invalid_image ,"invalid image");
TRITON_CREATE_HIP_EXCEPTION(invalid_context ,"invalid context");
TRITON_CREATE_HIP_EXCEPTION(context_already_current ,"context already current");
TRITON_CREATE_HIP_EXCEPTION(map_failed ,"map failed");
TRITON_CREATE_HIP_EXCEPTION(unmap_failed ,"unmap failed");
TRITON_CREATE_HIP_EXCEPTION(array_is_mapped ,"array is mapped");
TRITON_CREATE_HIP_EXCEPTION(already_mapped ,"already mapped");
TRITON_CREATE_HIP_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
TRITON_CREATE_HIP_EXCEPTION(already_acquired ,"already acquired");
TRITON_CREATE_HIP_EXCEPTION(not_mapped ,"not mapped");
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_array ,"not mapped as array");
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
TRITON_CREATE_HIP_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
TRITON_CREATE_HIP_EXCEPTION(unsupported_limit ,"unsupported limit");
TRITON_CREATE_HIP_EXCEPTION(context_already_in_use ,"context already in use");
TRITON_CREATE_HIP_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
TRITON_CREATE_HIP_EXCEPTION(invalid_ptx ,"invalid ptx");
TRITON_CREATE_HIP_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
TRITON_CREATE_HIP_EXCEPTION(invalid_source ,"invalid source");
TRITON_CREATE_HIP_EXCEPTION(file_not_found ,"file not found");
TRITON_CREATE_HIP_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
TRITON_CREATE_HIP_EXCEPTION(shared_object_init_failed ,"shared object init failed");
TRITON_CREATE_HIP_EXCEPTION(operating_system ,"operating system");
TRITON_CREATE_HIP_EXCEPTION(invalid_handle ,"invalid handle");
TRITON_CREATE_HIP_EXCEPTION(not_found ,"not found");
TRITON_CREATE_HIP_EXCEPTION(not_ready ,"not ready");
TRITON_CREATE_HIP_EXCEPTION(illegal_address ,"illegal address");
TRITON_CREATE_HIP_EXCEPTION(launch_out_of_resources ,"launch out of resources");
TRITON_CREATE_HIP_EXCEPTION(launch_timeout ,"launch timeout");
TRITON_CREATE_HIP_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
TRITON_CREATE_HIP_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
TRITON_CREATE_HIP_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
TRITON_CREATE_HIP_EXCEPTION(primary_context_active ,"primary context active");
TRITON_CREATE_HIP_EXCEPTION(context_is_destroyed ,"context is destroyed");
TRITON_CREATE_HIP_EXCEPTION(assert_error ,"assert");
TRITON_CREATE_HIP_EXCEPTION(too_many_peers ,"too many peers");
TRITON_CREATE_HIP_EXCEPTION(host_memory_already_registered ,"host memory already registered");
TRITON_CREATE_HIP_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
TRITON_CREATE_HIP_EXCEPTION(hardware_stack_error ,"hardware stack error");
TRITON_CREATE_HIP_EXCEPTION(illegal_instruction ,"illegal instruction");
TRITON_CREATE_HIP_EXCEPTION(misaligned_address ,"misaligned address");
TRITON_CREATE_HIP_EXCEPTION(invalid_address_space ,"invalid address space");
TRITON_CREATE_HIP_EXCEPTION(invalid_pc ,"invalid pc");
TRITON_CREATE_HIP_EXCEPTION(launch_failed ,"launch failed");
TRITON_CREATE_HIP_EXCEPTION(not_permitted ,"not permitted");
TRITON_CREATE_HIP_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_HIP_EXCEPTION(invalid_symbol ,"invalid symbol");
TRITON_CREATE_HIP_EXCEPTION(unknown ,"unknown");
#undef TRITON_CREATE_CUDA_EXCEPTION
}
}
}
}

View File

@@ -1,146 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_HANDLE_H_
#define _TRITON_DRIVER_HANDLE_H_
#include <memory>
#include <map>
#include <iostream>
#include <functional>
#include <type_traits>
#include "triton/driver/dispatch.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "triton/tools/thread_pool.h"
namespace llvm
{
class ExecutionEngine;
class Function;
}
namespace triton
{
namespace driver
{
enum backend_t {
CUDA,
Host
};
// Host handles
struct host_platform_t{
};
struct host_device_t{
};
struct host_context_t{
};
struct host_stream_t{
std::shared_ptr<ThreadPool> pool;
std::shared_ptr<std::vector<std::future<void>>> futures;
std::vector<std::shared_ptr<char*>> args;
};
struct host_module_t{
std::string error;
llvm::ExecutionEngine* engine;
std::map<std::string, llvm::Function*> functions;
void(*fn)(char**, int32_t, int32_t, int32_t);
llvm::orc::ExecutionSession* ES;
llvm::orc::RTDyldObjectLinkingLayer* ObjectLayer;
llvm::orc::IRCompileLayer* CompileLayer;
llvm::DataLayout* DL;
llvm::orc::MangleAndInterner* Mangle;
llvm::orc::ThreadSafeContext* Ctx;
llvm::orc::JITDylib *MainJD;
};
struct host_function_t{
llvm::Function* fn;
};
struct host_buffer_t{
char* data;
};
// Extra CUDA handles
struct cu_event_t{
operator bool() const { return first && second; }
CUevent first;
CUevent second;
};
struct CUPlatform{
CUPlatform() : status_(dispatch::cuInit(0)) { }
operator bool() const { return status_; }
private:
CUresult status_;
};
template<class T, class CUType>
class handle_interface{
public:
//Accessors
operator CUType() const { return *(((T*)this)->cu().h_); }
//Comparison
bool operator==(handle_interface const & y) { return (CUType)(*this) == (CUType)(y); }
bool operator!=(handle_interface const & y) { return (CUType)(*this) != (CUType)(y); }
bool operator<(handle_interface const & y) { return (CUType)(*this) < (CUType)(y); }
};
template<class T>
class handle{
public:
template<class, class> friend class handle_interface;
public:
//Constructors
handle(T h, bool take_ownership = true);
handle();
~handle();
T& operator*() { return *h_; }
T const & operator*() const { return *h_; }
T* operator->() const { return h_.get(); }
protected:
std::shared_ptr<T> h_;
bool has_ownership_;
};
template<class CUType, class HostType>
class polymorphic_resource {
public:
polymorphic_resource(CUType cu, bool take_ownership): cu_(cu, take_ownership), backend_(CUDA){}
polymorphic_resource(HostType hst, bool take_ownership): hst_(hst, take_ownership), backend_(Host){}
virtual ~polymorphic_resource() { }
handle<CUType> cu() { return cu_; }
handle<HostType> hst() { return hst_; }
const handle<CUType>& cu() const { return cu_; }
const handle<HostType>& hst() const { return hst_; }
backend_t backend() { return backend_; }
protected:
handle<CUType> cu_;
handle<HostType> hst_;
backend_t backend_;
};
}
}
#endif

View File

@@ -1,53 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_KERNEL_H_
#define _TRITON_DRIVER_KERNEL_H_
#include "triton/driver/module.h"
#include "triton/driver/handle.h"
#include <memory>
namespace llvm
{
class GenericValue;
}
namespace triton
{
namespace driver
{
class cu_buffer;
// Base
class kernel: public polymorphic_resource<CUfunction, host_function_t> {
public:
kernel(driver::module* program, CUfunction fn, bool has_ownership);
kernel(driver::module* program, host_function_t fn, bool has_ownership);
driver::module* module();
static kernel* create(driver::module* program, const char* name);
private:
driver::module* program_;
};
// Host
class host_kernel: public kernel {
public:
//Constructors
host_kernel(driver::module* program, const char* name);
};
// CUDA
class cu_kernel: public kernel {
public:
//Constructors
cu_kernel(driver::module* program, const char * name);
};
}
}
#endif

View File

@@ -0,0 +1,20 @@
#include <string>
#include "triton/driver/dispatch.h"
namespace llvm{
class Module;
}
namespace triton{
namespace driver{
void init_llvm();
std::string path_to_ptxas(int& version);
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas_path, int cc);
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
hipModule_t amdgpu_to_hipmodule(const std::string& path);
}
}

View File

@@ -1,84 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_MODULE_H_
#define _TRITON_DRIVER_MODULE_H_
#include <map>
#include "triton/driver/handle.h"
#include "triton/driver/context.h"
#include "triton/driver/buffer.h"
namespace llvm
{
class Module;
template<class T>
class SmallVectorImpl;
}
namespace triton
{
namespace driver
{
class cu_context;
class cu_device;
// Base
class module: public polymorphic_resource<CUmodule, host_module_t> {
protected:
void init_llvm();
enum file_type_t{
Object,
Assembly
};
public:
module(CUmodule mod, bool has_ownership);
module(host_module_t mod, bool has_ownership);
static module* create(driver::device* device, std::unique_ptr<llvm::Module> src);
void compile_llvm_module(std::unique_ptr<llvm::Module> module, const std::string& triple,
const std::string &proc, std::string layout,
llvm::SmallVectorImpl<char> &buffer,
const std::string &features,
file_type_t file_type);
virtual std::unique_ptr<buffer> symbol(const char * name) const = 0;
int spilled() const { return spilled_; }
protected:
int spilled_;
};
// CPU
class host_module: public module{
public:
host_module(std::unique_ptr<llvm::Module> module);
std::unique_ptr<buffer> symbol(const char * name) const;
};
// CUDA
class cu_module: public module {
std::string compile_llvm_module(llvm::Module* module, driver::device* device);
void init_from_ptx(const std::string& ptx, cu_device *device);
public:
cu_module(driver::device* device, std::unique_ptr<llvm::Module> module);
cu_module(driver::device* device, const std::string& source);
std::unique_ptr<buffer> symbol(const char * name) const;
std::string llir() const { return llir_; }
const std::string& ptx() const { return ptx_; }
const std::string& cubin() const { return cubin_; }
private:
std::string ptx_;
std::string cubin_;
std::string llir_;
};
}
}
#endif

View File

@@ -1,58 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_PLATFORM_H_
#define _TRITON_DRIVER_PLATFORM_H_
#include <vector>
#include <string>
#include "triton/driver/handle.h"
namespace triton
{
namespace driver
{
class device;
class platform
{
public:
// Constructor
platform(const std::string& name): name_(name){ }
// Accessors
std::string name() const { return name_; }
// Virtual methods
virtual std::string version() const = 0;
virtual void devices(std::vector<driver::device *> &devices) const = 0;
private:
std::string name_;
};
// CUDA
class cu_platform: public platform
{
public:
cu_platform(): platform("CUDA") { }
std::string version() const;
void devices(std::vector<driver::device*> &devices) const;
private:
handle<CUPlatform> cu_;
};
// Host
class host_platform: public platform
{
public:
host_platform(): platform("CPU") { }
std::string version() const;
void devices(std::vector<driver::device*> &devices) const;
};
}
}
#endif

View File

@@ -1,68 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_STREAM_H_
#define _TRITON_DRIVER_STREAM_H_
#include <map>
#include "triton/driver/context.h"
#include "triton/driver/device.h"
#include "triton/driver/handle.h"
#include "triton/driver/buffer.h"
namespace triton
{
namespace driver
{
class kernel;
class event;
class Range;
class cu_buffer;
// Base
class stream: public polymorphic_resource<CUstream, host_stream_t> {
public:
stream(CUstream, bool has_ownership);
stream(host_stream_t, bool has_ownership);
// factory
static driver::stream* create(backend_t backend);
// methods
virtual void synchronize() = 0;
virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem = 0) = 0;
virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
// template helpers
template<class T> void write(driver::buffer* buf, bool blocking, std::size_t offset, std::vector<T> const & x)
{ write(buf, blocking, offset, x.size()*sizeof(T), x.data()); }
template<class T> void read(driver::buffer* buf, bool blocking, std::size_t offset, std::vector<T>& x)
{ read(buf, blocking, offset, x.size()*sizeof(T), x.data()); }
};
// Host
class host_stream: public stream {
public:
host_stream();
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};
// CUDA
class cu_stream: public stream {
public:
cu_stream(CUstream str, bool take_ownership);
cu_stream();
void synchronize();
void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, void* args, size_t args_size, size_t shared_mem);
void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
};
}
}
#endif

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,131 +0,0 @@
/**********************************************************************************
* Copyright (c) 2008-2015 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
**********************************************************************************/
/* $Revision: 11708 $ on $Date: 2010-06-13 23:36:24 -0700 (Sun, 13 Jun 2010) $ */
#ifndef __OPENCL_CL_D3D10_H
#define __OPENCL_CL_D3D10_H
#include <d3d10.h>
#include "cl.h"
#include "cl_platform.h"
#ifdef __cplusplus
extern "C" {
#endif
/******************************************************************************
* cl_khr_d3d10_sharing */
#define cl_khr_d3d10_sharing 1
typedef cl_uint cl_d3d10_device_source_khr;
typedef cl_uint cl_d3d10_device_set_khr;
/******************************************************************************/
/* Error Codes */
#define CL_INVALID_D3D10_DEVICE_KHR -1002
#define CL_INVALID_D3D10_RESOURCE_KHR -1003
#define CL_D3D10_RESOURCE_ALREADY_ACQUIRED_KHR -1004
#define CL_D3D10_RESOURCE_NOT_ACQUIRED_KHR -1005
/* cl_d3d10_device_source_nv */
#define CL_D3D10_DEVICE_KHR 0x4010
#define CL_D3D10_DXGI_ADAPTER_KHR 0x4011
/* cl_d3d10_device_set_nv */
#define CL_PREFERRED_DEVICES_FOR_D3D10_KHR 0x4012
#define CL_ALL_DEVICES_FOR_D3D10_KHR 0x4013
/* cl_context_info */
#define CL_CONTEXT_D3D10_DEVICE_KHR 0x4014
#define CL_CONTEXT_D3D10_PREFER_SHARED_RESOURCES_KHR 0x402C
/* cl_mem_info */
#define CL_MEM_D3D10_RESOURCE_KHR 0x4015
/* cl_image_info */
#define CL_IMAGE_D3D10_SUBRESOURCE_KHR 0x4016
/* cl_command_type */
#define CL_COMMAND_ACQUIRE_D3D10_OBJECTS_KHR 0x4017
#define CL_COMMAND_RELEASE_D3D10_OBJECTS_KHR 0x4018
/******************************************************************************/
typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetDeviceIDsFromD3D10KHR_fn)(
cl_platform_id platform,
cl_d3d10_device_source_khr d3d_device_source,
void * d3d_object,
cl_d3d10_device_set_khr d3d_device_set,
cl_uint num_entries,
cl_device_id * devices,
cl_uint * num_devices) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D10BufferKHR_fn)(
cl_context context,
cl_mem_flags flags,
ID3D10Buffer * resource,
cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D10Texture2DKHR_fn)(
cl_context context,
cl_mem_flags flags,
ID3D10Texture2D * resource,
UINT subresource,
cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D10Texture3DKHR_fn)(
cl_context context,
cl_mem_flags flags,
ID3D10Texture3D * resource,
UINT subresource,
cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireD3D10ObjectsKHR_fn)(
cl_command_queue command_queue,
cl_uint num_objects,
const cl_mem * mem_objects,
cl_uint num_events_in_wait_list,
const cl_event * event_wait_list,
cl_event * event) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseD3D10ObjectsKHR_fn)(
cl_command_queue command_queue,
cl_uint num_objects,
const cl_mem * mem_objects,
cl_uint num_events_in_wait_list,
const cl_event * event_wait_list,
cl_event * event) CL_API_SUFFIX__VERSION_1_0;
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_CL_D3D10_H */

View File

@@ -1,131 +0,0 @@
/**********************************************************************************
* Copyright (c) 2008-2015 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
**********************************************************************************/
/* $Revision: 11708 $ on $Date: 2010-06-13 23:36:24 -0700 (Sun, 13 Jun 2010) $ */
#ifndef __OPENCL_CL_D3D11_H
#define __OPENCL_CL_D3D11_H
#include <d3d11.h>
#include "cl.h"
#include "cl_platform.h"
#ifdef __cplusplus
extern "C" {
#endif
/******************************************************************************
* cl_khr_d3d11_sharing */
#define cl_khr_d3d11_sharing 1
typedef cl_uint cl_d3d11_device_source_khr;
typedef cl_uint cl_d3d11_device_set_khr;
/******************************************************************************/
/* Error Codes */
#define CL_INVALID_D3D11_DEVICE_KHR -1006
#define CL_INVALID_D3D11_RESOURCE_KHR -1007
#define CL_D3D11_RESOURCE_ALREADY_ACQUIRED_KHR -1008
#define CL_D3D11_RESOURCE_NOT_ACQUIRED_KHR -1009
/* cl_d3d11_device_source */
#define CL_D3D11_DEVICE_KHR 0x4019
#define CL_D3D11_DXGI_ADAPTER_KHR 0x401A
/* cl_d3d11_device_set */
#define CL_PREFERRED_DEVICES_FOR_D3D11_KHR 0x401B
#define CL_ALL_DEVICES_FOR_D3D11_KHR 0x401C
/* cl_context_info */
#define CL_CONTEXT_D3D11_DEVICE_KHR 0x401D
#define CL_CONTEXT_D3D11_PREFER_SHARED_RESOURCES_KHR 0x402D
/* cl_mem_info */
#define CL_MEM_D3D11_RESOURCE_KHR 0x401E
/* cl_image_info */
#define CL_IMAGE_D3D11_SUBRESOURCE_KHR 0x401F
/* cl_command_type */
#define CL_COMMAND_ACQUIRE_D3D11_OBJECTS_KHR 0x4020
#define CL_COMMAND_RELEASE_D3D11_OBJECTS_KHR 0x4021
/******************************************************************************/
typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetDeviceIDsFromD3D11KHR_fn)(
cl_platform_id platform,
cl_d3d11_device_source_khr d3d_device_source,
void * d3d_object,
cl_d3d11_device_set_khr d3d_device_set,
cl_uint num_entries,
cl_device_id * devices,
cl_uint * num_devices) CL_API_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D11BufferKHR_fn)(
cl_context context,
cl_mem_flags flags,
ID3D11Buffer * resource,
cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D11Texture2DKHR_fn)(
cl_context context,
cl_mem_flags flags,
ID3D11Texture2D * resource,
UINT subresource,
cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromD3D11Texture3DKHR_fn)(
cl_context context,
cl_mem_flags flags,
ID3D11Texture3D * resource,
UINT subresource,
cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireD3D11ObjectsKHR_fn)(
cl_command_queue command_queue,
cl_uint num_objects,
const cl_mem * mem_objects,
cl_uint num_events_in_wait_list,
const cl_event * event_wait_list,
cl_event * event) CL_API_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseD3D11ObjectsKHR_fn)(
cl_command_queue command_queue,
cl_uint num_objects,
const cl_mem * mem_objects,
cl_uint num_events_in_wait_list,
const cl_event * event_wait_list,
cl_event * event) CL_API_SUFFIX__VERSION_1_2;
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_CL_D3D11_H */

View File

@@ -1,132 +0,0 @@
/**********************************************************************************
* Copyright (c) 2008-2015 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
**********************************************************************************/
/* $Revision: 11708 $ on $Date: 2010-06-13 23:36:24 -0700 (Sun, 13 Jun 2010) $ */
#ifndef __OPENCL_CL_DX9_MEDIA_SHARING_H
#define __OPENCL_CL_DX9_MEDIA_SHARING_H
#include "cl.h"
#include "cl_platform.h"
#ifdef __cplusplus
extern "C" {
#endif
/******************************************************************************/
/* cl_khr_dx9_media_sharing */
#define cl_khr_dx9_media_sharing 1
typedef cl_uint cl_dx9_media_adapter_type_khr;
typedef cl_uint cl_dx9_media_adapter_set_khr;
#if defined(_WIN32)
#include <d3d9.h>
typedef struct _cl_dx9_surface_info_khr
{
IDirect3DSurface9 *resource;
HANDLE shared_handle;
} cl_dx9_surface_info_khr;
#endif
/******************************************************************************/
/* Error Codes */
#define CL_INVALID_DX9_MEDIA_ADAPTER_KHR -1010
#define CL_INVALID_DX9_MEDIA_SURFACE_KHR -1011
#define CL_DX9_MEDIA_SURFACE_ALREADY_ACQUIRED_KHR -1012
#define CL_DX9_MEDIA_SURFACE_NOT_ACQUIRED_KHR -1013
/* cl_media_adapter_type_khr */
#define CL_ADAPTER_D3D9_KHR 0x2020
#define CL_ADAPTER_D3D9EX_KHR 0x2021
#define CL_ADAPTER_DXVA_KHR 0x2022
/* cl_media_adapter_set_khr */
#define CL_PREFERRED_DEVICES_FOR_DX9_MEDIA_ADAPTER_KHR 0x2023
#define CL_ALL_DEVICES_FOR_DX9_MEDIA_ADAPTER_KHR 0x2024
/* cl_context_info */
#define CL_CONTEXT_ADAPTER_D3D9_KHR 0x2025
#define CL_CONTEXT_ADAPTER_D3D9EX_KHR 0x2026
#define CL_CONTEXT_ADAPTER_DXVA_KHR 0x2027
/* cl_mem_info */
#define CL_MEM_DX9_MEDIA_ADAPTER_TYPE_KHR 0x2028
#define CL_MEM_DX9_MEDIA_SURFACE_INFO_KHR 0x2029
/* cl_image_info */
#define CL_IMAGE_DX9_MEDIA_PLANE_KHR 0x202A
/* cl_command_type */
#define CL_COMMAND_ACQUIRE_DX9_MEDIA_SURFACES_KHR 0x202B
#define CL_COMMAND_RELEASE_DX9_MEDIA_SURFACES_KHR 0x202C
/******************************************************************************/
typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetDeviceIDsFromDX9MediaAdapterKHR_fn)(
cl_platform_id platform,
cl_uint num_media_adapters,
cl_dx9_media_adapter_type_khr * media_adapter_type,
void * media_adapters,
cl_dx9_media_adapter_set_khr media_adapter_set,
cl_uint num_entries,
cl_device_id * devices,
cl_uint * num_devices) CL_API_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromDX9MediaSurfaceKHR_fn)(
cl_context context,
cl_mem_flags flags,
cl_dx9_media_adapter_type_khr adapter_type,
void * surface_info,
cl_uint plane,
cl_int * errcode_ret) CL_API_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireDX9MediaSurfacesKHR_fn)(
cl_command_queue command_queue,
cl_uint num_objects,
const cl_mem * mem_objects,
cl_uint num_events_in_wait_list,
const cl_event * event_wait_list,
cl_event * event) CL_API_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseDX9MediaSurfacesKHR_fn)(
cl_command_queue command_queue,
cl_uint num_objects,
const cl_mem * mem_objects,
cl_uint num_events_in_wait_list,
const cl_event * event_wait_list,
cl_event * event) CL_API_SUFFIX__VERSION_1_2;
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_CL_DX9_MEDIA_SHARING_H */

View File

@@ -1,182 +0,0 @@
/**********************************************************************************
* Copyright (c) 2008-2016 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
**********************************************************************************/
/*****************************************************************************\
Copyright (c) 2013-2016 Intel Corporation All Rights Reserved.
THESE MATERIALS ARE PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL INTEL OR ITS
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THESE
MATERIALS, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
File Name: cl_dx9_media_sharing_intel.h
Abstract:
Notes:
\*****************************************************************************/
#ifndef __OPENCL_CL_DX9_MEDIA_SHARING_INTEL_H
#define __OPENCL_CL_DX9_MEDIA_SHARING_INTEL_H
#include <CL/cl.h>
#include <CL/cl_platform.h>
#include <d3d9.h>
#include <dxvahd.h>
#include <wtypes.h>
#include <d3d9types.h>
#ifdef __cplusplus
extern "C" {
#endif
/***************************************
* cl_intel_dx9_media_sharing extension *
****************************************/
#define cl_intel_dx9_media_sharing 1
typedef cl_uint cl_dx9_device_source_intel;
typedef cl_uint cl_dx9_device_set_intel;
/* error codes */
#define CL_INVALID_DX9_DEVICE_INTEL -1010
#define CL_INVALID_DX9_RESOURCE_INTEL -1011
#define CL_DX9_RESOURCE_ALREADY_ACQUIRED_INTEL -1012
#define CL_DX9_RESOURCE_NOT_ACQUIRED_INTEL -1013
/* cl_dx9_device_source_intel */
#define CL_D3D9_DEVICE_INTEL 0x4022
#define CL_D3D9EX_DEVICE_INTEL 0x4070
#define CL_DXVA_DEVICE_INTEL 0x4071
/* cl_dx9_device_set_intel */
#define CL_PREFERRED_DEVICES_FOR_DX9_INTEL 0x4024
#define CL_ALL_DEVICES_FOR_DX9_INTEL 0x4025
/* cl_context_info */
#define CL_CONTEXT_D3D9_DEVICE_INTEL 0x4026
#define CL_CONTEXT_D3D9EX_DEVICE_INTEL 0x4072
#define CL_CONTEXT_DXVA_DEVICE_INTEL 0x4073
/* cl_mem_info */
#define CL_MEM_DX9_RESOURCE_INTEL 0x4027
#define CL_MEM_DX9_SHARED_HANDLE_INTEL 0x4074
/* cl_image_info */
#define CL_IMAGE_DX9_PLANE_INTEL 0x4075
/* cl_command_type */
#define CL_COMMAND_ACQUIRE_DX9_OBJECTS_INTEL 0x402A
#define CL_COMMAND_RELEASE_DX9_OBJECTS_INTEL 0x402B
/******************************************************************************/
extern CL_API_ENTRY cl_int CL_API_CALL
clGetDeviceIDsFromDX9INTEL(
cl_platform_id /* platform */,
cl_dx9_device_source_intel /* dx9_device_source */,
void* /* dx9_object */,
cl_dx9_device_set_intel /* dx9_device_set */,
cl_uint /* num_entries */,
cl_device_id* /* devices */,
cl_uint* /* num_devices */) CL_EXT_SUFFIX__VERSION_1_1;
typedef CL_API_ENTRY cl_int (CL_API_CALL* clGetDeviceIDsFromDX9INTEL_fn)(
cl_platform_id /* platform */,
cl_dx9_device_source_intel /* dx9_device_source */,
void* /* dx9_object */,
cl_dx9_device_set_intel /* dx9_device_set */,
cl_uint /* num_entries */,
cl_device_id* /* devices */,
cl_uint* /* num_devices */) CL_EXT_SUFFIX__VERSION_1_1;
extern CL_API_ENTRY cl_mem CL_API_CALL
clCreateFromDX9MediaSurfaceINTEL(
cl_context /* context */,
cl_mem_flags /* flags */,
IDirect3DSurface9* /* resource */,
HANDLE /* sharedHandle */,
UINT /* plane */,
cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromDX9MediaSurfaceINTEL_fn)(
cl_context /* context */,
cl_mem_flags /* flags */,
IDirect3DSurface9* /* resource */,
HANDLE /* sharedHandle */,
UINT /* plane */,
cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueAcquireDX9ObjectsINTEL(
cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem* /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event* /* event_wait_list */,
cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_1;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireDX9ObjectsINTEL_fn)(
cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem* /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event* /* event_wait_list */,
cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_1;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueReleaseDX9ObjectsINTEL(
cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
cl_mem* /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event* /* event_wait_list */,
cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_1;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseDX9ObjectsINTEL_fn)(
cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
cl_mem* /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event* /* event_wait_list */,
cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_1;
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_CL_DX9_MEDIA_SHARING_INTEL_H */

View File

@@ -1,136 +0,0 @@
/*******************************************************************************
* Copyright (c) 2008-2015 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
******************************************************************************/
#ifndef __OPENCL_CL_EGL_H
#define __OPENCL_CL_EGL_H
#ifdef __APPLE__
#else
#include "cl.h"
#endif
#ifdef __cplusplus
extern "C" {
#endif
/* Command type for events created with clEnqueueAcquireEGLObjectsKHR */
#define CL_COMMAND_EGL_FENCE_SYNC_OBJECT_KHR 0x202F
#define CL_COMMAND_ACQUIRE_EGL_OBJECTS_KHR 0x202D
#define CL_COMMAND_RELEASE_EGL_OBJECTS_KHR 0x202E
/* Error type for clCreateFromEGLImageKHR */
#define CL_INVALID_EGL_OBJECT_KHR -1093
#define CL_EGL_RESOURCE_NOT_ACQUIRED_KHR -1092
/* CLeglImageKHR is an opaque handle to an EGLImage */
typedef void* CLeglImageKHR;
/* CLeglDisplayKHR is an opaque handle to an EGLDisplay */
typedef void* CLeglDisplayKHR;
/* CLeglSyncKHR is an opaque handle to an EGLSync object */
typedef void* CLeglSyncKHR;
/* properties passed to clCreateFromEGLImageKHR */
typedef intptr_t cl_egl_image_properties_khr;
#define cl_khr_egl_image 1
extern CL_API_ENTRY cl_mem CL_API_CALL
clCreateFromEGLImageKHR(cl_context /* context */,
CLeglDisplayKHR /* egldisplay */,
CLeglImageKHR /* eglimage */,
cl_mem_flags /* flags */,
const cl_egl_image_properties_khr * /* properties */,
cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_mem (CL_API_CALL *clCreateFromEGLImageKHR_fn)(
cl_context context,
CLeglDisplayKHR egldisplay,
CLeglImageKHR eglimage,
cl_mem_flags flags,
const cl_egl_image_properties_khr * properties,
cl_int * errcode_ret);
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueAcquireEGLObjectsKHR(cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem * /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireEGLObjectsKHR_fn)(
cl_command_queue command_queue,
cl_uint num_objects,
const cl_mem * mem_objects,
cl_uint num_events_in_wait_list,
const cl_event * event_wait_list,
cl_event * event);
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueReleaseEGLObjectsKHR(cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem * /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseEGLObjectsKHR_fn)(
cl_command_queue command_queue,
cl_uint num_objects,
const cl_mem * mem_objects,
cl_uint num_events_in_wait_list,
const cl_event * event_wait_list,
cl_event * event);
#define cl_khr_egl_event 1
extern CL_API_ENTRY cl_event CL_API_CALL
clCreateEventFromEGLSyncKHR(cl_context /* context */,
CLeglSyncKHR /* sync */,
CLeglDisplayKHR /* display */,
cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_event (CL_API_CALL *clCreateEventFromEGLSyncKHR_fn)(
cl_context context,
CLeglSyncKHR sync,
CLeglDisplayKHR display,
cl_int * errcode_ret);
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_CL_EGL_H */

View File

@@ -1,670 +0,0 @@
/*******************************************************************************
* Copyright (c) 2008-2015 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
******************************************************************************/
/* $Revision: 11928 $ on $Date: 2010-07-13 09:04:56 -0700 (Tue, 13 Jul 2010) $ */
/* cl_ext.h contains OpenCL extensions which don't have external */
/* (OpenGL, D3D) dependencies. */
#ifndef __CL_EXT_H
#define __CL_EXT_H
#ifdef __cplusplus
extern "C" {
#endif
#ifdef __APPLE__
#include <OpenCL/cl.h>
#include <AvailabilityMacros.h>
#else
#include "cl.h"
#endif
/* cl_khr_fp64 extension - no extension #define since it has no functions */
#define CL_DEVICE_DOUBLE_FP_CONFIG 0x1032
/* cl_khr_fp16 extension - no extension #define since it has no functions */
#define CL_DEVICE_HALF_FP_CONFIG 0x1033
/* Memory object destruction
*
* Apple extension for use to manage externally allocated buffers used with cl_mem objects with CL_MEM_USE_HOST_PTR
*
* Registers a user callback function that will be called when the memory object is deleted and its resources
* freed. Each call to clSetMemObjectCallbackFn registers the specified user callback function on a callback
* stack associated with memobj. The registered user callback functions are called in the reverse order in
* which they were registered. The user callback functions are called and then the memory object is deleted
* and its resources freed. This provides a mechanism for the application (and libraries) using memobj to be
* notified when the memory referenced by host_ptr, specified when the memory object is created and used as
* the storage bits for the memory object, can be reused or freed.
*
* The application may not call CL api's with the cl_mem object passed to the pfn_notify.
*
* Please check for the "cl_APPLE_SetMemObjectDestructor" extension using clGetDeviceInfo(CL_DEVICE_EXTENSIONS)
* before using.
*/
#define cl_APPLE_SetMemObjectDestructor 1
cl_int CL_API_ENTRY clSetMemObjectDestructorAPPLE( cl_mem /* memobj */,
void (* /*pfn_notify*/)( cl_mem /* memobj */, void* /*user_data*/),
void * /*user_data */ ) CL_EXT_SUFFIX__VERSION_1_0;
/* Context Logging Functions
*
* The next three convenience functions are intended to be used as the pfn_notify parameter to clCreateContext().
* Please check for the "cl_APPLE_ContextLoggingFunctions" extension using clGetDeviceInfo(CL_DEVICE_EXTENSIONS)
* before using.
*
* clLogMessagesToSystemLog fowards on all log messages to the Apple System Logger
*/
#define cl_APPLE_ContextLoggingFunctions 1
extern void CL_API_ENTRY clLogMessagesToSystemLogAPPLE( const char * /* errstr */,
const void * /* private_info */,
size_t /* cb */,
void * /* user_data */ ) CL_EXT_SUFFIX__VERSION_1_0;
/* clLogMessagesToStdout sends all log messages to the file descriptor stdout */
extern void CL_API_ENTRY clLogMessagesToStdoutAPPLE( const char * /* errstr */,
const void * /* private_info */,
size_t /* cb */,
void * /* user_data */ ) CL_EXT_SUFFIX__VERSION_1_0;
/* clLogMessagesToStderr sends all log messages to the file descriptor stderr */
extern void CL_API_ENTRY clLogMessagesToStderrAPPLE( const char * /* errstr */,
const void * /* private_info */,
size_t /* cb */,
void * /* user_data */ ) CL_EXT_SUFFIX__VERSION_1_0;
/************************
* cl_khr_icd extension *
************************/
#define cl_khr_icd 1
/* cl_platform_info */
#define CL_PLATFORM_ICD_SUFFIX_KHR 0x0920
/* Additional Error Codes */
#define CL_PLATFORM_NOT_FOUND_KHR -1001
extern CL_API_ENTRY cl_int CL_API_CALL
clIcdGetPlatformIDsKHR(cl_uint /* num_entries */,
cl_platform_id * /* platforms */,
cl_uint * /* num_platforms */);
typedef CL_API_ENTRY cl_int (CL_API_CALL *clIcdGetPlatformIDsKHR_fn)(
cl_uint /* num_entries */,
cl_platform_id * /* platforms */,
cl_uint * /* num_platforms */);
/* Extension: cl_khr_image2D_buffer
*
* This extension allows a 2D image to be created from a cl_mem buffer without a copy.
* The type associated with a 2D image created from a buffer in an OpenCL program is image2d_t.
* Both the sampler and sampler-less read_image built-in functions are supported for 2D images
* and 2D images created from a buffer. Similarly, the write_image built-ins are also supported
* for 2D images created from a buffer.
*
* When the 2D image from buffer is created, the client must specify the width,
* height, image format (i.e. channel order and channel data type) and optionally the row pitch
*
* The pitch specified must be a multiple of CL_DEVICE_IMAGE_PITCH_ALIGNMENT pixels.
* The base address of the buffer must be aligned to CL_DEVICE_IMAGE_BASE_ADDRESS_ALIGNMENT pixels.
*/
/*************************************
* cl_khr_initalize_memory extension *
*************************************/
#define CL_CONTEXT_MEMORY_INITIALIZE_KHR 0x2030
/**************************************
* cl_khr_terminate_context extension *
**************************************/
#define CL_DEVICE_TERMINATE_CAPABILITY_KHR 0x2031
#define CL_CONTEXT_TERMINATE_KHR 0x2032
#define cl_khr_terminate_context 1
extern CL_API_ENTRY cl_int CL_API_CALL clTerminateContextKHR(cl_context /* context */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clTerminateContextKHR_fn)(cl_context /* context */) CL_EXT_SUFFIX__VERSION_1_2;
/*
* Extension: cl_khr_spir
*
* This extension adds support to create an OpenCL program object from a
* Standard Portable Intermediate Representation (SPIR) instance
*/
#define CL_DEVICE_SPIR_VERSIONS 0x40E0
#define CL_PROGRAM_BINARY_TYPE_INTERMEDIATE 0x40E1
/*****************************************
* cl_khr_create_command_queue extension *
*****************************************/
#define cl_khr_create_command_queue 1
typedef cl_bitfield cl_queue_properties_khr;
extern CL_API_ENTRY cl_command_queue CL_API_CALL
clCreateCommandQueueWithPropertiesKHR( cl_context /* context */,
cl_device_id /* device */,
const cl_queue_properties_khr* /* properties */,
cl_int* /* errcode_ret */ ) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_command_queue
(CL_API_CALL *clCreateCommandQueueWithPropertiesKHR_fn)( cl_context /* context */,
cl_device_id /* device */,
const cl_queue_properties_khr* /* properties */,
cl_int* /* errcode_ret */ ) CL_EXT_SUFFIX__VERSION_1_2;
/******************************************
* cl_nv_device_attribute_query extension *
******************************************/
/* cl_nv_device_attribute_query extension - no extension #define since it has no functions */
#define CL_DEVICE_COMPUTE_CAPABILITY_MAJOR_NV 0x4000
#define CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV 0x4001
#define CL_DEVICE_REGISTERS_PER_BLOCK_NV 0x4002
#define CL_DEVICE_WARP_SIZE_NV 0x4003
#define CL_DEVICE_GPU_OVERLAP_NV 0x4004
#define CL_DEVICE_KERNEL_EXEC_TIMEOUT_NV 0x4005
#define CL_DEVICE_INTEGRATED_MEMORY_NV 0x4006
/*********************************
* cl_amd_device_memory_flags *
*********************************/
#define cl_amd_device_memory_flags 1
#define CL_MEM_USE_PERSISTENT_MEM_AMD (1 << 6) // Alloc from GPU's CPU visible heap
/* cl_device_info */
#define CL_DEVICE_MAX_ATOMIC_COUNTERS_EXT 0x4032
/*********************************
* cl_amd_device_attribute_query *
*********************************/
#define CL_DEVICE_PROFILING_TIMER_OFFSET_AMD 0x4036
#define CL_DEVICE_TOPOLOGY_AMD 0x4037
#define CL_DEVICE_BOARD_NAME_AMD 0x4038
#define CL_DEVICE_GLOBAL_FREE_MEMORY_AMD 0x4039
#define CL_DEVICE_SIMD_PER_COMPUTE_UNIT_AMD 0x4040
#define CL_DEVICE_SIMD_WIDTH_AMD 0x4041
#define CL_DEVICE_SIMD_INSTRUCTION_WIDTH_AMD 0x4042
#define CL_DEVICE_WAVEFRONT_WIDTH_AMD 0x4043
#define CL_DEVICE_GLOBAL_MEM_CHANNELS_AMD 0x4044
#define CL_DEVICE_GLOBAL_MEM_CHANNEL_BANKS_AMD 0x4045
#define CL_DEVICE_GLOBAL_MEM_CHANNEL_BANK_WIDTH_AMD 0x4046
#define CL_DEVICE_LOCAL_MEM_SIZE_PER_COMPUTE_UNIT_AMD 0x4047
#define CL_DEVICE_LOCAL_MEM_BANKS_AMD 0x4048
typedef union
{
struct { cl_uint type; cl_uint data[5]; } raw;
struct { cl_uint type; cl_char unused[17]; cl_char bus; cl_char device; cl_char function; } pcie;
} cl_device_topology_amd;
#define CL_DEVICE_TOPOLOGY_TYPE_PCIE_AMD 1
/**************************
* cl_amd_offline_devices *
**************************/
#define CL_CONTEXT_OFFLINE_DEVICES_AMD 0x403F
/*********************************
* cl_arm_printf extension
*********************************/
#define CL_PRINTF_CALLBACK_ARM 0x40B0
#define CL_PRINTF_BUFFERSIZE_ARM 0x40B1
#ifdef CL_VERSION_1_1
/***********************************
* cl_ext_device_fission extension *
***********************************/
#define cl_ext_device_fission 1
extern CL_API_ENTRY cl_int CL_API_CALL
clReleaseDeviceEXT( cl_device_id /*device*/ ) CL_EXT_SUFFIX__VERSION_1_1;
typedef CL_API_ENTRY cl_int
(CL_API_CALL *clReleaseDeviceEXT_fn)( cl_device_id /*device*/ ) CL_EXT_SUFFIX__VERSION_1_1;
extern CL_API_ENTRY cl_int CL_API_CALL
clRetainDeviceEXT( cl_device_id /*device*/ ) CL_EXT_SUFFIX__VERSION_1_1;
typedef CL_API_ENTRY cl_int
(CL_API_CALL *clRetainDeviceEXT_fn)( cl_device_id /*device*/ ) CL_EXT_SUFFIX__VERSION_1_1;
typedef cl_ulong cl_device_partition_property_ext;
extern CL_API_ENTRY cl_int CL_API_CALL
clCreateSubDevicesEXT( cl_device_id /*in_device*/,
const cl_device_partition_property_ext * /* properties */,
cl_uint /*num_entries*/,
cl_device_id * /*out_devices*/,
cl_uint * /*num_devices*/ ) CL_EXT_SUFFIX__VERSION_1_1;
typedef CL_API_ENTRY cl_int
( CL_API_CALL * clCreateSubDevicesEXT_fn)( cl_device_id /*in_device*/,
const cl_device_partition_property_ext * /* properties */,
cl_uint /*num_entries*/,
cl_device_id * /*out_devices*/,
cl_uint * /*num_devices*/ ) CL_EXT_SUFFIX__VERSION_1_1;
/* cl_device_partition_property_ext */
#define CL_DEVICE_PARTITION_EQUALLY_EXT 0x4050
#define CL_DEVICE_PARTITION_BY_COUNTS_EXT 0x4051
#define CL_DEVICE_PARTITION_BY_NAMES_EXT 0x4052
#define CL_DEVICE_PARTITION_BY_AFFINITY_DOMAIN_EXT 0x4053
/* clDeviceGetInfo selectors */
#define CL_DEVICE_PARENT_DEVICE_EXT 0x4054
#define CL_DEVICE_PARTITION_TYPES_EXT 0x4055
#define CL_DEVICE_AFFINITY_DOMAINS_EXT 0x4056
#define CL_DEVICE_REFERENCE_COUNT_EXT 0x4057
#define CL_DEVICE_PARTITION_STYLE_EXT 0x4058
/* error codes */
#define CL_DEVICE_PARTITION_FAILED_EXT -1057
#define CL_INVALID_PARTITION_COUNT_EXT -1058
#define CL_INVALID_PARTITION_NAME_EXT -1059
/* CL_AFFINITY_DOMAINs */
#define CL_AFFINITY_DOMAIN_L1_CACHE_EXT 0x1
#define CL_AFFINITY_DOMAIN_L2_CACHE_EXT 0x2
#define CL_AFFINITY_DOMAIN_L3_CACHE_EXT 0x3
#define CL_AFFINITY_DOMAIN_L4_CACHE_EXT 0x4
#define CL_AFFINITY_DOMAIN_NUMA_EXT 0x10
#define CL_AFFINITY_DOMAIN_NEXT_FISSIONABLE_EXT 0x100
/* cl_device_partition_property_ext list terminators */
#define CL_PROPERTIES_LIST_END_EXT ((cl_device_partition_property_ext) 0)
#define CL_PARTITION_BY_COUNTS_LIST_END_EXT ((cl_device_partition_property_ext) 0)
#define CL_PARTITION_BY_NAMES_LIST_END_EXT ((cl_device_partition_property_ext) 0 - 1)
/* cl_ext_atomic_counters_32 and cl_ext_atomic_counters_64 extensions
* no extension #define since they have no functions
*/
#define CL_DEVICE_MAX_ATOMIC_COUNTERS_EXT 0x4032
/*********************************
* cl_qcom_ext_host_ptr extension
*********************************/
#define CL_MEM_EXT_HOST_PTR_QCOM (1 << 29)
#define CL_DEVICE_EXT_MEM_PADDING_IN_BYTES_QCOM 0x40A0
#define CL_DEVICE_PAGE_SIZE_QCOM 0x40A1
#define CL_IMAGE_ROW_ALIGNMENT_QCOM 0x40A2
#define CL_IMAGE_SLICE_ALIGNMENT_QCOM 0x40A3
#define CL_MEM_HOST_UNCACHED_QCOM 0x40A4
#define CL_MEM_HOST_WRITEBACK_QCOM 0x40A5
#define CL_MEM_HOST_WRITETHROUGH_QCOM 0x40A6
#define CL_MEM_HOST_WRITE_COMBINING_QCOM 0x40A7
typedef cl_uint cl_image_pitch_info_qcom;
extern CL_API_ENTRY cl_int CL_API_CALL
clGetDeviceImageInfoQCOM(cl_device_id device,
size_t image_width,
size_t image_height,
const cl_image_format *image_format,
cl_image_pitch_info_qcom param_name,
size_t param_value_size,
void *param_value,
size_t *param_value_size_ret);
typedef struct _cl_mem_ext_host_ptr
{
/* Type of external memory allocation. */
/* Legal values will be defined in layered extensions. */
cl_uint allocation_type;
/* Host cache policy for this external memory allocation. */
cl_uint host_cache_policy;
} cl_mem_ext_host_ptr;
/*********************************
* cl_qcom_ion_host_ptr extension
*********************************/
#define CL_MEM_ION_HOST_PTR_QCOM 0x40A8
typedef struct _cl_mem_ion_host_ptr
{
/* Type of external memory allocation. */
/* Must be CL_MEM_ION_HOST_PTR_QCOM for ION allocations. */
cl_mem_ext_host_ptr ext_host_ptr;
/* ION file descriptor */
int ion_filedesc;
/* Host pointer to the ION allocated memory */
void* ion_hostptr;
} cl_mem_ion_host_ptr;
#endif /* CL_VERSION_1_1 */
#if defined(CL_VERSION_1_2)
/******************************************
* cl_img_yuv_image extension *
******************************************/
/* Image formats used in clCreateImage */
#define CL_NV21_IMG 0x40D0
#define CL_YV12_IMG 0x40D1
/******************************************
* cl_img_cached_allocations extension *
******************************************/
/* Flag values used by clCreteBuffer */
#define CL_MEM_USE_UNCACHED_CPU_MEMORY_IMG (1 << 26)
#define CL_MEM_USE_CACHED_CPU_MEMORY_IMG (1 << 27)
/******************************************
* cl_img_use_gralloc_ptr extension *
******************************************/
/* Flag values used by clCreteBuffer */
#define CL_MEM_USE_GRALLOC_PTR_IMG (1 << 28)
/* To be used by clGetEventInfo: */
#define CL_COMMAND_ACQUIRE_GRALLOC_OBJECTS_IMG 0x40D2
#define CL_COMMAND_RELEASE_GRALLOC_OBJECTS_IMG 0x40D3
/* Error code from clEnqueueReleaseGrallocObjectsIMG */
#define CL_GRALLOC_RESOURCE_NOT_ACQUIRED_IMG 0x40D4
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueAcquireGrallocObjectsIMG(cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem * /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueReleaseGrallocObjectsIMG(cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem * /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2;
#endif /* CL_VERSION_1_2 */
#ifdef CL_VERSION_2_0
/*********************************
* cl_khr_subgroups extension
*********************************/
#define cl_khr_subgroups 1
/* cl_kernel_sub_group_info is declared in CL.h. */
/* cl_kernel_sub_group_info */
#define CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE_KHR 0x2033
#define CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE_KHR 0x2034
extern CL_API_ENTRY cl_int CL_API_CALL
clGetKernelSubGroupInfoKHR(cl_kernel /* in_kernel */,
cl_device_id /*in_device*/,
cl_kernel_sub_group_info /* param_name */,
size_t /*input_value_size*/,
const void * /*input_value*/,
size_t /*param_value_size*/,
void* /*param_value*/,
size_t* /*param_value_size_ret*/ ) CL_EXT_SUFFIX__VERSION_2_0_DEPRECATED;
typedef CL_API_ENTRY cl_int
( CL_API_CALL * clGetKernelSubGroupInfoKHR_fn)(cl_kernel /* in_kernel */,
cl_device_id /*in_device*/,
cl_kernel_sub_group_info /* param_name */,
size_t /*input_value_size*/,
const void * /*input_value*/,
size_t /*param_value_size*/,
void* /*param_value*/,
size_t* /*param_value_size_ret*/ ) CL_EXT_SUFFIX__VERSION_2_0_DEPRECATED;
#endif /* CL_VERSION_2_0 */
#ifdef CL_VERSION_2_1
/*********************************
* cl_khr_priority_hints extension
*********************************/
#define cl_khr_priority_hints 1
typedef cl_uint cl_queue_priority_khr;
/* cl_command_queue_properties */
#define CL_QUEUE_PRIORITY_KHR 0x1096
/* cl_queue_priority_khr */
#define CL_QUEUE_PRIORITY_HIGH_KHR (1<<0)
#define CL_QUEUE_PRIORITY_MED_KHR (1<<1)
#define CL_QUEUE_PRIORITY_LOW_KHR (1<<2)
#endif /* CL_VERSION_2_1 */
#ifdef CL_VERSION_2_1
/*********************************
* cl_khr_throttle_hints extension
*********************************/
#define cl_khr_throttle_hints 1
typedef cl_uint cl_queue_throttle_khr;
/* cl_command_queue_properties */
#define CL_QUEUE_THROTTLE_KHR 0x1097
/* cl_queue_throttle_khr */
#define CL_QUEUE_THROTTLE_HIGH_KHR (1<<0)
#define CL_QUEUE_THROTTLE_MED_KHR (1<<1)
#define CL_QUEUE_THROTTLE_LOW_KHR (1<<2)
#endif /* CL_VERSION_2_1 */
#ifdef CL_VERSION_2_2
/*********************************
* cl_khr_subgroup_named_barrier
*********************************/
#define cl_khr_subgroup_named_barrier 1
/* cl_device_info */
#define CL_DEVICE_MAX_NAMED_BARRIER_COUNT_KHR 0x2035
#endif /* CL_VERSION_2_2 */
/**********************************
* cl_arm_import_memory extension *
**********************************/
#ifdef CL_VERSION_1_0
typedef intptr_t cl_import_properties_arm;
/* Default and valid proporties name for cl_arm_import_memory */
#define CL_IMPORT_TYPE_ARM 0x40B2
/* Host process memory type default value for CL_IMPORT_TYPE_ARM property */
#define CL_IMPORT_TYPE_HOST_ARM 0x40B3
/* DMA BUF memory type value for CL_IMPORT_TYPE_ARM property */
#define CL_IMPORT_TYPE_DMA_BUF_ARM 0x40B4
/* Secure DMA BUF memory type value for CL_IMPORT_TYPE_ARM property */
#define CL_IMPORT_TYPE_SECURE_ARM 0x40B5
/* This extension adds a new function that allows for direct memory import into
* OpenCL via the clImportMemoryARM function.
*
* Memory imported through this interface will be mapped into the device's page
* tables directly, providing zero copy access. It will never fall back to copy
* operations and aliased buffers.
*
* Types of memory supported for import are specified as additional extension
* strings.
*
* This extension produces cl_mem allocations which are compatible with all other
* users of cl_mem in the standard API.
*
* This extension maps pages with the same properties as the normal buffer creation
* function clCreateBuffer.
*/
extern CL_API_ENTRY cl_mem CL_API_CALL
clImportMemoryARM( cl_context context,
cl_mem_flags flags,
const cl_import_properties_arm *properties,
void *memory,
size_t size,
cl_int *errcode_ret) CL_EXT_SUFFIX__VERSION_1_0;
#endif /* CL_VERSION_1_0 */
/******************************************
* cl_arm_shared_virtual_memory extension *
******************************************/
#ifdef CL_VERSION_1_2
/* Used by clGetDeviceInfo */
#define CL_DEVICE_SVM_CAPABILITIES_ARM 0x40B6
/* Used by clGetMemObjectInfo */
#define CL_MEM_USES_SVM_POINTER_ARM 0x40B7
/* Used by clSetKernelExecInfoARM: */
#define CL_KERNEL_EXEC_INFO_SVM_PTRS_ARM 0x40B8
#define CL_KERNEL_EXEC_INFO_SVM_FINE_GRAIN_SYSTEM_ARM 0x40B9
/* To be used by clGetEventInfo: */
#define CL_COMMAND_SVM_FREE_ARM 0x40BA
#define CL_COMMAND_SVM_MEMCPY_ARM 0x40BB
#define CL_COMMAND_SVM_MEMFILL_ARM 0x40BC
#define CL_COMMAND_SVM_MAP_ARM 0x40BD
#define CL_COMMAND_SVM_UNMAP_ARM 0x40BE
/* Flag values returned by clGetDeviceInfo with CL_DEVICE_SVM_CAPABILITIES_ARM as the param_name. */
#define CL_DEVICE_SVM_COARSE_GRAIN_BUFFER_ARM (1 << 0)
#define CL_DEVICE_SVM_FINE_GRAIN_BUFFER_ARM (1 << 1)
#define CL_DEVICE_SVM_FINE_GRAIN_SYSTEM_ARM (1 << 2)
#define CL_DEVICE_SVM_ATOMICS_ARM (1 << 3)
/* Flag values used by clSVMAllocARM: */
#define CL_MEM_SVM_FINE_GRAIN_BUFFER_ARM (1 << 10)
#define CL_MEM_SVM_ATOMICS_ARM (1 << 11)
typedef cl_bitfield cl_svm_mem_flags_arm;
typedef cl_uint cl_kernel_exec_info_arm;
typedef cl_bitfield cl_device_svm_capabilities_arm;
extern CL_API_ENTRY void * CL_API_CALL
clSVMAllocARM(cl_context /* context */,
cl_svm_mem_flags_arm /* flags */,
size_t /* size */,
cl_uint /* alignment */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY void CL_API_CALL
clSVMFreeARM(cl_context /* context */,
void * /* svm_pointer */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueSVMFreeARM(cl_command_queue /* command_queue */,
cl_uint /* num_svm_pointers */,
void *[] /* svm_pointers[] */,
void (CL_CALLBACK * /*pfn_free_func*/)(cl_command_queue /* queue */,
cl_uint /* num_svm_pointers */,
void *[] /* svm_pointers[] */,
void * /* user_data */),
void * /* user_data */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueSVMMemcpyARM(cl_command_queue /* command_queue */,
cl_bool /* blocking_copy */,
void * /* dst_ptr */,
const void * /* src_ptr */,
size_t /* size */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueSVMMemFillARM(cl_command_queue /* command_queue */,
void * /* svm_ptr */,
const void * /* pattern */,
size_t /* pattern_size */,
size_t /* size */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueSVMMapARM(cl_command_queue /* command_queue */,
cl_bool /* blocking_map */,
cl_map_flags /* flags */,
void * /* svm_ptr */,
size_t /* size */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueSVMUnmapARM(cl_command_queue /* command_queue */,
void * /* svm_ptr */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clSetKernelArgSVMPointerARM(cl_kernel /* kernel */,
cl_uint /* arg_index */,
const void * /* arg_value */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clSetKernelExecInfoARM(cl_kernel /* kernel */,
cl_kernel_exec_info_arm /* param_name */,
size_t /* param_value_size */,
const void * /* param_value */) CL_EXT_SUFFIX__VERSION_1_2;
#endif /* CL_VERSION_1_2 */
#ifdef __cplusplus
}
#endif
#endif /* __CL_EXT_H */

View File

@@ -1,429 +0,0 @@
/*******************************************************************************
* Copyright (c) 2008-2017 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
******************************************************************************/
/*****************************************************************************\
Copyright (c) 2013-2017 Intel Corporation All Rights Reserved.
THESE MATERIALS ARE PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL INTEL OR ITS
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THESE
MATERIALS, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
File Name: cl_ext_intel.h
Abstract:
Notes:
\*****************************************************************************/
#ifndef __CL_EXT_INTEL_H
#define __CL_EXT_INTEL_H
#ifdef __APPLE__
#include <OpenCL/cl.h>
#include <OpenCL/cl_platform.h>
#else
#include "cl.h"
#include "cl_platform.h"
#endif
#ifdef __cplusplus
extern "C" {
#endif
/***************************************
* cl_intel_thread_local_exec extension *
****************************************/
#define cl_intel_thread_local_exec 1
#define CL_QUEUE_THREAD_LOCAL_EXEC_ENABLE_INTEL (((cl_bitfield)1) << 31)
/***********************************************
* cl_intel_device_partition_by_names extension *
************************************************/
#define cl_intel_device_partition_by_names 1
#define CL_DEVICE_PARTITION_BY_NAMES_INTEL 0x4052
#define CL_PARTITION_BY_NAMES_LIST_END_INTEL -1
/************************************************
* cl_intel_accelerator extension *
* cl_intel_motion_estimation extension *
* cl_intel_advanced_motion_estimation extension *
*************************************************/
#define cl_intel_accelerator 1
#define cl_intel_motion_estimation 1
#define cl_intel_advanced_motion_estimation 1
typedef struct _cl_accelerator_intel* cl_accelerator_intel;
typedef cl_uint cl_accelerator_type_intel;
typedef cl_uint cl_accelerator_info_intel;
typedef struct _cl_motion_estimation_desc_intel {
cl_uint mb_block_type;
cl_uint subpixel_mode;
cl_uint sad_adjust_mode;
cl_uint search_path_type;
} cl_motion_estimation_desc_intel;
/* error codes */
#define CL_INVALID_ACCELERATOR_INTEL -1094
#define CL_INVALID_ACCELERATOR_TYPE_INTEL -1095
#define CL_INVALID_ACCELERATOR_DESCRIPTOR_INTEL -1096
#define CL_ACCELERATOR_TYPE_NOT_SUPPORTED_INTEL -1097
/* cl_accelerator_type_intel */
#define CL_ACCELERATOR_TYPE_MOTION_ESTIMATION_INTEL 0x0
/* cl_accelerator_info_intel */
#define CL_ACCELERATOR_DESCRIPTOR_INTEL 0x4090
#define CL_ACCELERATOR_REFERENCE_COUNT_INTEL 0x4091
#define CL_ACCELERATOR_CONTEXT_INTEL 0x4092
#define CL_ACCELERATOR_TYPE_INTEL 0x4093
/* cl_motion_detect_desc_intel flags */
#define CL_ME_MB_TYPE_16x16_INTEL 0x0
#define CL_ME_MB_TYPE_8x8_INTEL 0x1
#define CL_ME_MB_TYPE_4x4_INTEL 0x2
#define CL_ME_SUBPIXEL_MODE_INTEGER_INTEL 0x0
#define CL_ME_SUBPIXEL_MODE_HPEL_INTEL 0x1
#define CL_ME_SUBPIXEL_MODE_QPEL_INTEL 0x2
#define CL_ME_SAD_ADJUST_MODE_NONE_INTEL 0x0
#define CL_ME_SAD_ADJUST_MODE_HAAR_INTEL 0x1
#define CL_ME_SEARCH_PATH_RADIUS_2_2_INTEL 0x0
#define CL_ME_SEARCH_PATH_RADIUS_4_4_INTEL 0x1
#define CL_ME_SEARCH_PATH_RADIUS_16_12_INTEL 0x5
#define CL_ME_SKIP_BLOCK_TYPE_16x16_INTEL 0x0
#define CL_ME_CHROMA_INTRA_PREDICT_ENABLED_INTEL 0x1
#define CL_ME_LUMA_INTRA_PREDICT_ENABLED_INTEL 0x2
#define CL_ME_SKIP_BLOCK_TYPE_8x8_INTEL 0x4
#define CL_ME_FORWARD_INPUT_MODE_INTEL 0x1
#define CL_ME_BACKWARD_INPUT_MODE_INTEL 0x2
#define CL_ME_BIDIRECTION_INPUT_MODE_INTEL 0x3
#define CL_ME_BIDIR_WEIGHT_QUARTER_INTEL 16
#define CL_ME_BIDIR_WEIGHT_THIRD_INTEL 21
#define CL_ME_BIDIR_WEIGHT_HALF_INTEL 32
#define CL_ME_BIDIR_WEIGHT_TWO_THIRD_INTEL 43
#define CL_ME_BIDIR_WEIGHT_THREE_QUARTER_INTEL 48
#define CL_ME_COST_PENALTY_NONE_INTEL 0x0
#define CL_ME_COST_PENALTY_LOW_INTEL 0x1
#define CL_ME_COST_PENALTY_NORMAL_INTEL 0x2
#define CL_ME_COST_PENALTY_HIGH_INTEL 0x3
#define CL_ME_COST_PRECISION_QPEL_INTEL 0x0
#define CL_ME_COST_PRECISION_HPEL_INTEL 0x1
#define CL_ME_COST_PRECISION_PEL_INTEL 0x2
#define CL_ME_COST_PRECISION_DPEL_INTEL 0x3
#define CL_ME_LUMA_PREDICTOR_MODE_VERTICAL_INTEL 0x0
#define CL_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_INTEL 0x1
#define CL_ME_LUMA_PREDICTOR_MODE_DC_INTEL 0x2
#define CL_ME_LUMA_PREDICTOR_MODE_DIAGONAL_DOWN_LEFT_INTEL 0x3
#define CL_ME_LUMA_PREDICTOR_MODE_DIAGONAL_DOWN_RIGHT_INTEL 0x4
#define CL_ME_LUMA_PREDICTOR_MODE_PLANE_INTEL 0x4
#define CL_ME_LUMA_PREDICTOR_MODE_VERTICAL_RIGHT_INTEL 0x5
#define CL_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_DOWN_INTEL 0x6
#define CL_ME_LUMA_PREDICTOR_MODE_VERTICAL_LEFT_INTEL 0x7
#define CL_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_UP_INTEL 0x8
#define CL_ME_CHROMA_PREDICTOR_MODE_DC_INTEL 0x0
#define CL_ME_CHROMA_PREDICTOR_MODE_HORIZONTAL_INTEL 0x1
#define CL_ME_CHROMA_PREDICTOR_MODE_VERTICAL_INTEL 0x2
#define CL_ME_CHROMA_PREDICTOR_MODE_PLANE_INTEL 0x3
/* cl_device_info */
#define CL_DEVICE_ME_VERSION_INTEL 0x407E
#define CL_ME_VERSION_LEGACY_INTEL 0x0
#define CL_ME_VERSION_ADVANCED_VER_1_INTEL 0x1
#define CL_ME_VERSION_ADVANCED_VER_2_INTEL 0x2
extern CL_API_ENTRY cl_accelerator_intel CL_API_CALL
clCreateAcceleratorINTEL(
cl_context /* context */,
cl_accelerator_type_intel /* accelerator_type */,
size_t /* descriptor_size */,
const void* /* descriptor */,
cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_accelerator_intel (CL_API_CALL *clCreateAcceleratorINTEL_fn)(
cl_context /* context */,
cl_accelerator_type_intel /* accelerator_type */,
size_t /* descriptor_size */,
const void* /* descriptor */,
cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clGetAcceleratorInfoINTEL(
cl_accelerator_intel /* accelerator */,
cl_accelerator_info_intel /* param_name */,
size_t /* param_value_size */,
void* /* param_value */,
size_t* /* param_value_size_ret */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetAcceleratorInfoINTEL_fn)(
cl_accelerator_intel /* accelerator */,
cl_accelerator_info_intel /* param_name */,
size_t /* param_value_size */,
void* /* param_value */,
size_t* /* param_value_size_ret */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clRetainAcceleratorINTEL(
cl_accelerator_intel /* accelerator */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clRetainAcceleratorINTEL_fn)(
cl_accelerator_intel /* accelerator */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clReleaseAcceleratorINTEL(
cl_accelerator_intel /* accelerator */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clReleaseAcceleratorINTEL_fn)(
cl_accelerator_intel /* accelerator */) CL_EXT_SUFFIX__VERSION_1_2;
/******************************************
* cl_intel_simultaneous_sharing extension *
*******************************************/
#define cl_intel_simultaneous_sharing 1
#define CL_DEVICE_SIMULTANEOUS_INTEROPS_INTEL 0x4104
#define CL_DEVICE_NUM_SIMULTANEOUS_INTEROPS_INTEL 0x4105
/***********************************
* cl_intel_egl_image_yuv extension *
************************************/
#define cl_intel_egl_image_yuv 1
#define CL_EGL_YUV_PLANE_INTEL 0x4107
/********************************
* cl_intel_packed_yuv extension *
*********************************/
#define cl_intel_packed_yuv 1
#define CL_YUYV_INTEL 0x4076
#define CL_UYVY_INTEL 0x4077
#define CL_YVYU_INTEL 0x4078
#define CL_VYUY_INTEL 0x4079
/********************************************
* cl_intel_required_subgroup_size extension *
*********************************************/
#define cl_intel_required_subgroup_size 1
#define CL_DEVICE_SUB_GROUP_SIZES_INTEL 0x4108
#define CL_KERNEL_SPILL_MEM_SIZE_INTEL 0x4109
#define CL_KERNEL_COMPILE_SUB_GROUP_SIZE_INTEL 0x410A
/****************************************
* cl_intel_driver_diagnostics extension *
*****************************************/
#define cl_intel_driver_diagnostics 1
typedef cl_uint cl_diagnostics_verbose_level;
#define CL_CONTEXT_SHOW_DIAGNOSTICS_INTEL 0x4106
#define CL_CONTEXT_DIAGNOSTICS_LEVEL_ALL_INTEL ( 0xff )
#define CL_CONTEXT_DIAGNOSTICS_LEVEL_GOOD_INTEL ( 1 )
#define CL_CONTEXT_DIAGNOSTICS_LEVEL_BAD_INTEL ( 1 << 1 )
#define CL_CONTEXT_DIAGNOSTICS_LEVEL_NEUTRAL_INTEL ( 1 << 2 )
/********************************
* cl_intel_planar_yuv extension *
*********************************/
#define CL_NV12_INTEL 0x410E
#define CL_MEM_NO_ACCESS_INTEL ( 1 << 24 )
#define CL_MEM_ACCESS_FLAGS_UNRESTRICTED_INTEL ( 1 << 25 )
#define CL_DEVICE_PLANAR_YUV_MAX_WIDTH_INTEL 0x417E
#define CL_DEVICE_PLANAR_YUV_MAX_HEIGHT_INTEL 0x417F
/*******************************************************
* cl_intel_device_side_avc_motion_estimation extension *
********************************************************/
#define CL_DEVICE_AVC_ME_VERSION_INTEL 0x410B
#define CL_DEVICE_AVC_ME_SUPPORTS_TEXTURE_SAMPLER_USE_INTEL 0x410C
#define CL_DEVICE_AVC_ME_SUPPORTS_PREEMPTION_INTEL 0x410D
#define CL_AVC_ME_VERSION_0_INTEL 0x0; // No support.
#define CL_AVC_ME_VERSION_1_INTEL 0x1; // First supported version.
#define CL_AVC_ME_MAJOR_16x16_INTEL 0x0
#define CL_AVC_ME_MAJOR_16x8_INTEL 0x1
#define CL_AVC_ME_MAJOR_8x16_INTEL 0x2
#define CL_AVC_ME_MAJOR_8x8_INTEL 0x3
#define CL_AVC_ME_MINOR_8x8_INTEL 0x0
#define CL_AVC_ME_MINOR_8x4_INTEL 0x1
#define CL_AVC_ME_MINOR_4x8_INTEL 0x2
#define CL_AVC_ME_MINOR_4x4_INTEL 0x3
#define CL_AVC_ME_MAJOR_FORWARD_INTEL 0x0
#define CL_AVC_ME_MAJOR_BACKWARD_INTEL 0x1
#define CL_AVC_ME_MAJOR_BIDIRECTIONAL_INTEL 0x2
#define CL_AVC_ME_PARTITION_MASK_ALL_INTEL 0x0
#define CL_AVC_ME_PARTITION_MASK_16x16_INTEL 0x7E
#define CL_AVC_ME_PARTITION_MASK_16x8_INTEL 0x7D
#define CL_AVC_ME_PARTITION_MASK_8x16_INTEL 0x7B
#define CL_AVC_ME_PARTITION_MASK_8x8_INTEL 0x77
#define CL_AVC_ME_PARTITION_MASK_8x4_INTEL 0x6F
#define CL_AVC_ME_PARTITION_MASK_4x8_INTEL 0x5F
#define CL_AVC_ME_PARTITION_MASK_4x4_INTEL 0x3F
#define CL_AVC_ME_SEARCH_WINDOW_EXHAUSTIVE_INTEL 0x0
#define CL_AVC_ME_SEARCH_WINDOW_SMALL_INTEL 0x1
#define CL_AVC_ME_SEARCH_WINDOW_TINY_INTEL 0x2
#define CL_AVC_ME_SEARCH_WINDOW_EXTRA_TINY_INTEL 0x3
#define CL_AVC_ME_SEARCH_WINDOW_DIAMOND_INTEL 0x4
#define CL_AVC_ME_SEARCH_WINDOW_LARGE_DIAMOND_INTEL 0x5
#define CL_AVC_ME_SEARCH_WINDOW_RESERVED0_INTEL 0x6
#define CL_AVC_ME_SEARCH_WINDOW_RESERVED1_INTEL 0x7
#define CL_AVC_ME_SEARCH_WINDOW_CUSTOM_INTEL 0x8
#define CL_AVC_ME_SEARCH_WINDOW_16x12_RADIUS_INTEL 0x9
#define CL_AVC_ME_SEARCH_WINDOW_4x4_RADIUS_INTEL 0x2
#define CL_AVC_ME_SEARCH_WINDOW_2x2_RADIUS_INTEL 0xa
#define CL_AVC_ME_SAD_ADJUST_MODE_NONE_INTEL 0x0
#define CL_AVC_ME_SAD_ADJUST_MODE_HAAR_INTEL 0x2
#define CL_AVC_ME_SUBPIXEL_MODE_INTEGER_INTEL 0x0
#define CL_AVC_ME_SUBPIXEL_MODE_HPEL_INTEL 0x1
#define CL_AVC_ME_SUBPIXEL_MODE_QPEL_INTEL 0x3
#define CL_AVC_ME_COST_PRECISION_QPEL_INTEL 0x0
#define CL_AVC_ME_COST_PRECISION_HPEL_INTEL 0x1
#define CL_AVC_ME_COST_PRECISION_PEL_INTEL 0x2
#define CL_AVC_ME_COST_PRECISION_DPEL_INTEL 0x3
#define CL_AVC_ME_BIDIR_WEIGHT_QUARTER_INTEL 0x10
#define CL_AVC_ME_BIDIR_WEIGHT_THIRD_INTEL 0x15
#define CL_AVC_ME_BIDIR_WEIGHT_HALF_INTEL 0x20
#define CL_AVC_ME_BIDIR_WEIGHT_TWO_THIRD_INTEL 0x2B
#define CL_AVC_ME_BIDIR_WEIGHT_THREE_QUARTER_INTEL 0x30
#define CL_AVC_ME_BORDER_REACHED_LEFT_INTEL 0x0
#define CL_AVC_ME_BORDER_REACHED_RIGHT_INTEL 0x2
#define CL_AVC_ME_BORDER_REACHED_TOP_INTEL 0x4
#define CL_AVC_ME_BORDER_REACHED_BOTTOM_INTEL 0x8
#define CL_AVC_ME_SKIP_BLOCK_PARTITION_16x16_INTEL 0x0
#define CL_AVC_ME_SKIP_BLOCK_PARTITION_8x8_INTEL 0x4000
#define CL_AVC_ME_SKIP_BLOCK_16x16_FORWARD_ENABLE_INTEL ( 0x1 << 24 )
#define CL_AVC_ME_SKIP_BLOCK_16x16_BACKWARD_ENABLE_INTEL ( 0x2 << 24 )
#define CL_AVC_ME_SKIP_BLOCK_16x16_DUAL_ENABLE_INTEL ( 0x3 << 24 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_FORWARD_ENABLE_INTEL ( 0x55 << 24 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_BACKWARD_ENABLE_INTEL ( 0xAA << 24 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_DUAL_ENABLE_INTEL ( 0xFF << 24 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_0_FORWARD_ENABLE_INTEL ( 0x1 << 24 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_0_BACKWARD_ENABLE_INTEL ( 0x2 << 24 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_1_FORWARD_ENABLE_INTEL ( 0x1 << 26 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_1_BACKWARD_ENABLE_INTEL ( 0x2 << 26 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_2_FORWARD_ENABLE_INTEL ( 0x1 << 28 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_2_BACKWARD_ENABLE_INTEL ( 0x2 << 28 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_3_FORWARD_ENABLE_INTEL ( 0x1 << 30 )
#define CL_AVC_ME_SKIP_BLOCK_8x8_3_BACKWARD_ENABLE_INTEL ( 0x2 << 30 )
#define CL_AVC_ME_BLOCK_BASED_SKIP_4x4_INTEL 0x00
#define CL_AVC_ME_BLOCK_BASED_SKIP_8x8_INTEL 0x80
#define CL_AVC_ME_INTRA_16x16_INTEL 0x0
#define CL_AVC_ME_INTRA_8x8_INTEL 0x1
#define CL_AVC_ME_INTRA_4x4_INTEL 0x2
#define CL_AVC_ME_INTRA_LUMA_PARTITION_MASK_16x16_INTEL 0x6
#define CL_AVC_ME_INTRA_LUMA_PARTITION_MASK_8x8_INTEL 0x5
#define CL_AVC_ME_INTRA_LUMA_PARTITION_MASK_4x4_INTEL 0x3
#define CL_AVC_ME_INTRA_NEIGHBOR_LEFT_MASK_ENABLE_INTEL 0x60
#define CL_AVC_ME_INTRA_NEIGHBOR_UPPER_MASK_ENABLE_INTEL 0x10
#define CL_AVC_ME_INTRA_NEIGHBOR_UPPER_RIGHT_MASK_ENABLE_INTEL 0x8
#define CL_AVC_ME_INTRA_NEIGHBOR_UPPER_LEFT_MASK_ENABLE_INTEL 0x4
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_VERTICAL_INTEL 0x0
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_INTEL 0x1
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_DC_INTEL 0x2
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_DIAGONAL_DOWN_LEFT_INTEL 0x3
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_DIAGONAL_DOWN_RIGHT_INTEL 0x4
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_PLANE_INTEL 0x4
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_VERTICAL_RIGHT_INTEL 0x5
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_DOWN_INTEL 0x6
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_VERTICAL_LEFT_INTEL 0x7
#define CL_AVC_ME_LUMA_PREDICTOR_MODE_HORIZONTAL_UP_INTEL 0x8
#define CL_AVC_ME_CHROMA_PREDICTOR_MODE_DC_INTEL 0x0
#define CL_AVC_ME_CHROMA_PREDICTOR_MODE_HORIZONTAL_INTEL 0x1
#define CL_AVC_ME_CHROMA_PREDICTOR_MODE_VERTICAL_INTEL 0x2
#define CL_AVC_ME_CHROMA_PREDICTOR_MODE_PLANE_INTEL 0x3
#define CL_AVC_ME_FRAME_FORWARD_INTEL 0x1
#define CL_AVC_ME_FRAME_BACKWARD_INTEL 0x2
#define CL_AVC_ME_FRAME_DUAL_INTEL 0x3
#define CL_AVC_ME_SLICE_TYPE_PRED_INTEL 0x0
#define CL_AVC_ME_SLICE_TYPE_BPRED_INTEL 0x1
#define CL_AVC_ME_SLICE_TYPE_INTRA_INTEL 0x2
#define CL_AVC_ME_INTERLACED_SCAN_TOP_FIELD_INTEL 0x0
#define CL_AVC_ME_INTERLACED_SCAN_BOTTOM_FIELD_INTEL 0x1
#ifdef __cplusplus
}
#endif
#endif /* __CL_EXT_INTEL_H */

View File

@@ -1,167 +0,0 @@
/**********************************************************************************
* Copyright (c) 2008-2015 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
**********************************************************************************/
#ifndef __OPENCL_CL_GL_H
#define __OPENCL_CL_GL_H
#ifdef __APPLE__
#include <OpenCL/cl.h>
#else
#include "cl.h"
#endif
#ifdef __cplusplus
extern "C" {
#endif
typedef cl_uint cl_gl_object_type;
typedef cl_uint cl_gl_texture_info;
typedef cl_uint cl_gl_platform_info;
typedef struct __GLsync *cl_GLsync;
/* cl_gl_object_type = 0x2000 - 0x200F enum values are currently taken */
#define CL_GL_OBJECT_BUFFER 0x2000
#define CL_GL_OBJECT_TEXTURE2D 0x2001
#define CL_GL_OBJECT_TEXTURE3D 0x2002
#define CL_GL_OBJECT_RENDERBUFFER 0x2003
#define CL_GL_OBJECT_TEXTURE2D_ARRAY 0x200E
#define CL_GL_OBJECT_TEXTURE1D 0x200F
#define CL_GL_OBJECT_TEXTURE1D_ARRAY 0x2010
#define CL_GL_OBJECT_TEXTURE_BUFFER 0x2011
/* cl_gl_texture_info */
#define CL_GL_TEXTURE_TARGET 0x2004
#define CL_GL_MIPMAP_LEVEL 0x2005
#define CL_GL_NUM_SAMPLES 0x2012
extern CL_API_ENTRY cl_mem CL_API_CALL
clCreateFromGLBuffer(cl_context /* context */,
cl_mem_flags /* flags */,
cl_GLuint /* bufobj */,
int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
extern CL_API_ENTRY cl_mem CL_API_CALL
clCreateFromGLTexture(cl_context /* context */,
cl_mem_flags /* flags */,
cl_GLenum /* target */,
cl_GLint /* miplevel */,
cl_GLuint /* texture */,
cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_mem CL_API_CALL
clCreateFromGLRenderbuffer(cl_context /* context */,
cl_mem_flags /* flags */,
cl_GLuint /* renderbuffer */,
cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
extern CL_API_ENTRY cl_int CL_API_CALL
clGetGLObjectInfo(cl_mem /* memobj */,
cl_gl_object_type * /* gl_object_type */,
cl_GLuint * /* gl_object_name */) CL_API_SUFFIX__VERSION_1_0;
extern CL_API_ENTRY cl_int CL_API_CALL
clGetGLTextureInfo(cl_mem /* memobj */,
cl_gl_texture_info /* param_name */,
size_t /* param_value_size */,
void * /* param_value */,
size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueAcquireGLObjects(cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem * /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueReleaseGLObjects(cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem * /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event * /* event_wait_list */,
cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
/* Deprecated OpenCL 1.1 APIs */
extern CL_API_ENTRY CL_EXT_PREFIX__VERSION_1_1_DEPRECATED cl_mem CL_API_CALL
clCreateFromGLTexture2D(cl_context /* context */,
cl_mem_flags /* flags */,
cl_GLenum /* target */,
cl_GLint /* miplevel */,
cl_GLuint /* texture */,
cl_int * /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED;
extern CL_API_ENTRY CL_EXT_PREFIX__VERSION_1_1_DEPRECATED cl_mem CL_API_CALL
clCreateFromGLTexture3D(cl_context /* context */,
cl_mem_flags /* flags */,
cl_GLenum /* target */,
cl_GLint /* miplevel */,
cl_GLuint /* texture */,
cl_int * /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1_DEPRECATED;
/* cl_khr_gl_sharing extension */
#define cl_khr_gl_sharing 1
typedef cl_uint cl_gl_context_info;
/* Additional Error Codes */
#define CL_INVALID_GL_SHAREGROUP_REFERENCE_KHR -1000
/* cl_gl_context_info */
#define CL_CURRENT_DEVICE_FOR_GL_CONTEXT_KHR 0x2006
#define CL_DEVICES_FOR_GL_CONTEXT_KHR 0x2007
/* Additional cl_context_properties */
#define CL_GL_CONTEXT_KHR 0x2008
#define CL_EGL_DISPLAY_KHR 0x2009
#define CL_GLX_DISPLAY_KHR 0x200A
#define CL_WGL_HDC_KHR 0x200B
#define CL_CGL_SHAREGROUP_KHR 0x200C
extern CL_API_ENTRY cl_int CL_API_CALL
clGetGLContextInfoKHR(const cl_context_properties * /* properties */,
cl_gl_context_info /* param_name */,
size_t /* param_value_size */,
void * /* param_value */,
size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetGLContextInfoKHR_fn)(
const cl_context_properties * properties,
cl_gl_context_info param_name,
size_t param_value_size,
void * param_value,
size_t * param_value_size_ret);
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_CL_GL_H */

View File

@@ -1,74 +0,0 @@
/**********************************************************************************
* Copyright (c) 2008-2015 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
**********************************************************************************/
/* $Revision: 11708 $ on $Date: 2010-06-13 23:36:24 -0700 (Sun, 13 Jun 2010) $ */
/* cl_gl_ext.h contains vendor (non-KHR) OpenCL extensions which have */
/* OpenGL dependencies. */
#ifndef __OPENCL_CL_GL_EXT_H
#define __OPENCL_CL_GL_EXT_H
#ifdef __cplusplus
extern "C" {
#endif
#ifdef __APPLE__
#include <OpenCL/cl_gl.h>
#else
#include "cl_gl.h"
#endif
/*
* For each extension, follow this template
* cl_VEN_extname extension */
/* #define cl_VEN_extname 1
* ... define new types, if any
* ... define new tokens, if any
* ... define new APIs, if any
*
* If you need GLtypes here, mirror them with a cl_GLtype, rather than including a GL header
* This allows us to avoid having to decide whether to include GL headers or GLES here.
*/
/*
* cl_khr_gl_event extension
* See section 9.9 in the OpenCL 1.1 spec for more information
*/
#define CL_COMMAND_GL_FENCE_SYNC_OBJECT_KHR 0x200D
extern CL_API_ENTRY cl_event CL_API_CALL
clCreateEventFromGLsyncKHR(cl_context /* context */,
cl_GLsync /* cl_GLsync */,
cl_int * /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_1;
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_CL_GL_EXT_H */

File diff suppressed because it is too large Load Diff

View File

@@ -1,172 +0,0 @@
/**********************************************************************************
* Copyright (c) 2008-2016 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
**********************************************************************************/
/*****************************************************************************\
Copyright (c) 2013-2016 Intel Corporation All Rights Reserved.
THESE MATERIALS ARE PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL INTEL OR ITS
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THESE
MATERIALS, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
File Name: cl_va_api_media_sharing_intel.h
Abstract:
Notes:
\*****************************************************************************/
#ifndef __OPENCL_CL_VA_API_MEDIA_SHARING_INTEL_H
#define __OPENCL_CL_VA_API_MEDIA_SHARING_INTEL_H
#include "cl.h"
#include "cl_platform.h"
#include <va/va.h>
#ifdef __cplusplus
extern "C" {
#endif
/******************************************
* cl_intel_va_api_media_sharing extension *
*******************************************/
#define cl_intel_va_api_media_sharing 1
/* error codes */
#define CL_INVALID_VA_API_MEDIA_ADAPTER_INTEL -1098
#define CL_INVALID_VA_API_MEDIA_SURFACE_INTEL -1099
#define CL_VA_API_MEDIA_SURFACE_ALREADY_ACQUIRED_INTEL -1100
#define CL_VA_API_MEDIA_SURFACE_NOT_ACQUIRED_INTEL -1101
/* cl_va_api_device_source_intel */
#define CL_VA_API_DISPLAY_INTEL 0x4094
/* cl_va_api_device_set_intel */
#define CL_PREFERRED_DEVICES_FOR_VA_API_INTEL 0x4095
#define CL_ALL_DEVICES_FOR_VA_API_INTEL 0x4096
/* cl_context_info */
#define CL_CONTEXT_VA_API_DISPLAY_INTEL 0x4097
/* cl_mem_info */
#define CL_MEM_VA_API_MEDIA_SURFACE_INTEL 0x4098
/* cl_image_info */
#define CL_IMAGE_VA_API_PLANE_INTEL 0x4099
/* cl_command_type */
#define CL_COMMAND_ACQUIRE_VA_API_MEDIA_SURFACES_INTEL 0x409A
#define CL_COMMAND_RELEASE_VA_API_MEDIA_SURFACES_INTEL 0x409B
typedef cl_uint cl_va_api_device_source_intel;
typedef cl_uint cl_va_api_device_set_intel;
extern CL_API_ENTRY cl_int CL_API_CALL
clGetDeviceIDsFromVA_APIMediaAdapterINTEL(
cl_platform_id /* platform */,
cl_va_api_device_source_intel /* media_adapter_type */,
void* /* media_adapter */,
cl_va_api_device_set_intel /* media_adapter_set */,
cl_uint /* num_entries */,
cl_device_id* /* devices */,
cl_uint* /* num_devices */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL * clGetDeviceIDsFromVA_APIMediaAdapterINTEL_fn)(
cl_platform_id /* platform */,
cl_va_api_device_source_intel /* media_adapter_type */,
void* /* media_adapter */,
cl_va_api_device_set_intel /* media_adapter_set */,
cl_uint /* num_entries */,
cl_device_id* /* devices */,
cl_uint* /* num_devices */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_mem CL_API_CALL
clCreateFromVA_APIMediaSurfaceINTEL(
cl_context /* context */,
cl_mem_flags /* flags */,
VASurfaceID* /* surface */,
cl_uint /* plane */,
cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_mem (CL_API_CALL * clCreateFromVA_APIMediaSurfaceINTEL_fn)(
cl_context /* context */,
cl_mem_flags /* flags */,
VASurfaceID* /* surface */,
cl_uint /* plane */,
cl_int* /* errcode_ret */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueAcquireVA_APIMediaSurfacesINTEL(
cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem* /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event* /* event_wait_list */,
cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueAcquireVA_APIMediaSurfacesINTEL_fn)(
cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem* /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event* /* event_wait_list */,
cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_2;
extern CL_API_ENTRY cl_int CL_API_CALL
clEnqueueReleaseVA_APIMediaSurfacesINTEL(
cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem* /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event* /* event_wait_list */,
cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_2;
typedef CL_API_ENTRY cl_int (CL_API_CALL *clEnqueueReleaseVA_APIMediaSurfacesINTEL_fn)(
cl_command_queue /* command_queue */,
cl_uint /* num_objects */,
const cl_mem* /* mem_objects */,
cl_uint /* num_events_in_wait_list */,
const cl_event* /* event_wait_list */,
cl_event* /* event */) CL_EXT_SUFFIX__VERSION_1_2;
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_CL_VA_API_MEDIA_SHARING_INTEL_H */

View File

@@ -1,59 +0,0 @@
/*******************************************************************************
* Copyright (c) 2008-2015 The Khronos Group Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and/or associated documentation files (the
* "Materials"), to deal in the Materials without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Materials, and to
* permit persons to whom the Materials are furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Materials.
*
* MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS
* KHRONOS STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS
* SPECIFICATIONS AND HEADER INFORMATION ARE LOCATED AT
* https://www.khronos.org/registry/
*
* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
******************************************************************************/
/* $Revision: 11708 $ on $Date: 2010-06-13 23:36:24 -0700 (Sun, 13 Jun 2010) $ */
#ifndef __OPENCL_H
#define __OPENCL_H
#ifdef __cplusplus
extern "C" {
#endif
#ifdef __APPLE__
#include <OpenCL/cl.h>
#include <OpenCL/cl_gl.h>
#include <OpenCL/cl_gl_ext.h>
#include <OpenCL/cl_ext.h>
#else
#include "cl.h"
#include "cl_gl.h"
#include "cl_gl_ext.h"
#include "cl_ext.h"
#endif
#ifdef __cplusplus
}
#endif
#endif /* __OPENCL_H */

6027
include/triton/external/CUDA/cuda.h vendored Executable file → Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -328,7 +328,7 @@ typedef enum nvmlGpuLevel_enum
typedef enum nvmlGpuP2PStatus_enum
{
NVML_P2P_STATUS_OK = 0,
NVML_P2P_STATUS_CHIPSET_NOT_SUPPORED,
NVML_P2P_STATUS_CHIPSET_NOT_SUPPORTED,
NVML_P2P_STATUS_GPU_NOT_SUPPORTED,
NVML_P2P_STATUS_IOH_TOPOLOGY_NOT_SUPPORTED,
NVML_P2P_STATUS_DISABLED_BY_REGKEY,
@@ -736,7 +736,7 @@ typedef enum nvmlReturn_enum
NVML_ERROR_IN_USE = 19, //!< An operation cannot be performed because the GPU is currently in use
NVML_ERROR_MEMORY = 20, //!< Insufficient memory
NVML_ERROR_NO_DATA = 21, //!<No data
NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22, //!< The requested vgpu operation is not available on target device, becasue ECC is enabled
NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22, //!< The requested vgpu operation is not available on target device, because ECC is enabled
NVML_ERROR_UNKNOWN = 999 //!< An internal driver error occurred
} nvmlReturn_t;
@@ -1463,7 +1463,7 @@ typedef struct nvmlEncoderSessionInfo_st
*/
typedef enum nvmlFBCSessionType_enum
{
NVML_FBC_SESSION_TYPE_UNKNOWN = 0, //!< Unknwon
NVML_FBC_SESSION_TYPE_UNKNOWN = 0, //!< Unknown
NVML_FBC_SESSION_TYPE_TOSYS, //!< ToSys
NVML_FBC_SESSION_TYPE_CUDA, //!< Cuda
NVML_FBC_SESSION_TYPE_VID, //!< Vid
@@ -3678,10 +3678,10 @@ nvmlReturn_t DECLDIR nvmlDeviceGetEncoderStats (nvmlDevice_t device, unsigned in
* Retrieves information about active encoder sessions on a target device.
*
* An array of active encoder sessions is returned in the caller-supplied buffer pointed at by \a sessionInfos. The
* array elememt count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* array element count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* written to the buffer.
*
* If the supplied buffer is not large enough to accomodate the active session array, the function returns
* If the supplied buffer is not large enough to accommodate the active session array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlEncoderSessionInfo_t array required in \a sessionCount.
* To query the number of active encoder sessions, call this function with *sessionCount = 0. The code will return
* NVML_SUCCESS with number of active encoder sessions updated in *sessionCount.
@@ -3727,7 +3727,7 @@ nvmlReturn_t DECLDIR nvmlDeviceGetDecoderUtilization(nvmlDevice_t device, unsign
* For Maxwell &tm; or newer fully supported devices.
*
* @param device The identifier of the target device
* @param fbcStats Reference to nvmlFBCStats_t structure contianing NvFBC stats
* @param fbcStats Reference to nvmlFBCStats_t structure containing NvFBC stats
*
* @return
* - \ref NVML_SUCCESS if \a fbcStats is fetched
@@ -3742,10 +3742,10 @@ nvmlReturn_t DECLDIR nvmlDeviceGetFBCStats(nvmlDevice_t device, nvmlFBCStats_t *
* Retrieves information about active frame buffer capture sessions on a target device.
*
* An array of active encoder sessions is returned in the caller-supplied buffer pointed at by \a sessionInfo. The
* array elememt count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* array element count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* written to the buffer.
*
* If the supplied buffer is not large enough to accomodate the active session array, the function returns
* If the supplied buffer is not large enough to accommodate the active session array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlFBCSessionInfo_t array required in \a sessionCount.
* To query the number of active FBC sessions, call this function with *sessionCount = 0. The code will return
* NVML_SUCCESS with number of active FBC sessions updated in *sessionCount.
@@ -4208,7 +4208,7 @@ nvmlReturn_t DECLDIR nvmlDeviceGetRetiredPages(nvmlDevice_t device, nvmlPageReti
* The address information provided from this API is the hardware address of the page that was retired. Note
* that this does not match the virtual address used in CUDA, but will match the address information in XID 63
*
* \note nvmlDeviceGetRetiredPages_v2 adds an additional timestamps paramter to return the time of each page's
* \note nvmlDeviceGetRetiredPages_v2 adds an additional timestamps parameter to return the time of each page's
* retirement.
*
* For Kepler &tm; or newer fully supported devices.
@@ -4476,7 +4476,7 @@ nvmlReturn_t DECLDIR nvmlDeviceSetDriverModel(nvmlDevice_t device, nvmlDriverMod
* Set clocks that device will lock to.
*
* Sets the clocks that the device will be running at to the value in the range of minGpuClockMHz to maxGpuClockMHz.
* Setting this will supercede application clock values and take effect regardless if a cuda app is running.
* Setting this will supersede application clock values and take effect regardless if a cuda app is running.
* See /ref nvmlDeviceSetApplicationsClocks
*
* Can be used as a setting to request constant performance.
@@ -5297,7 +5297,7 @@ nvmlReturn_t DECLDIR nvmlDeviceSetVirtualizationMode(nvmlDevice_t device, nvmlGp
* pointed at by \a vgpuTypeIds. The element count of nvmlVgpuTypeId_t array is passed in \a vgpuCount, and \a vgpuCount
* is used to return the number of vGPU types written to the buffer.
*
* If the supplied buffer is not large enough to accomodate the vGPU type array, the function returns
* If the supplied buffer is not large enough to accommodate the vGPU type array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlVgpuTypeId_t array required in \a vgpuCount.
* To query the number of vGPU types supported for the GPU, call this function with *vgpuCount = 0.
* The code will return NVML_ERROR_INSUFFICIENT_SIZE, or NVML_SUCCESS if no vGPU types are supported.
@@ -5327,9 +5327,9 @@ nvmlReturn_t DECLDIR nvmlDeviceGetSupportedVgpus(nvmlDevice_t device, unsigned i
* can concurrently run on a device. For example, if only one vGPU type is allowed at a time on a device, then the creatable
* list will be restricted to whatever vGPU type is already running on the device.
*
* If the supplied buffer is not large enough to accomodate the vGPU type array, the function returns
* If the supplied buffer is not large enough to accommodate the vGPU type array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlVgpuTypeId_t array required in \a vgpuCount.
* To query the number of vGPU types createable for the GPU, call this function with *vgpuCount = 0.
* To query the number of vGPU types creatable for the GPU, call this function with *vgpuCount = 0.
* The code will return NVML_ERROR_INSUFFICIENT_SIZE, or NVML_SUCCESS if no vGPU types are creatable.
*
* @param device The identifier of the target device
@@ -5392,7 +5392,7 @@ nvmlReturn_t DECLDIR nvmlVgpuTypeGetName(nvmlVgpuTypeId_t vgpuTypeId, char *vgpu
*
* @param vgpuTypeId Handle to vGPU type
* @param deviceID Device ID and vendor ID of the device contained in single 32 bit value
* @param subsystemID Subsytem ID and subsytem vendor ID of the device contained in single 32 bit value
* @param subsystemID subsystem ID and subsystem vendor ID of the device contained in single 32 bit value
*
* @return
* - \ref NVML_SUCCESS successful completion
@@ -5516,10 +5516,10 @@ nvmlReturn_t DECLDIR nvmlVgpuTypeGetMaxInstances(nvmlDevice_t device, nvmlVgpuTy
* Retrieve the active vGPU instances on a device.
*
* An array of active vGPU instances is returned in the caller-supplied buffer pointed at by \a vgpuInstances. The
* array elememt count is passed in \a vgpuCount, and \a vgpuCount is used to return the number of vGPU instances
* array element count is passed in \a vgpuCount, and \a vgpuCount is used to return the number of vGPU instances
* written to the buffer.
*
* If the supplied buffer is not large enough to accomodate the vGPU instance array, the function returns
* If the supplied buffer is not large enough to accommodate the vGPU instance array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlVgpuInstance_t array required in \a vgpuCount.
* To query the number of active vGPU instances, call this function with *vgpuCount = 0. The code will return
* NVML_ERROR_INSUFFICIENT_SIZE, or NVML_SUCCESS if no vGPU Types are supported.
@@ -5702,7 +5702,7 @@ nvmlReturn_t DECLDIR nvmlVgpuInstanceGetFrameRateLimit(nvmlVgpuInstance_t vgpuIn
* @param encoderCapacity Reference to an unsigned int for the encoder capacity
*
* @return
* - \ref NVML_SUCCESS if \a encoderCapacity has been retrived
* - \ref NVML_SUCCESS if \a encoderCapacity has been retrieved
* - \ref NVML_ERROR_UNINITIALIZED if the library has not been successfully initialized
* - \ref NVML_ERROR_INVALID_ARGUMENT if \a vgpuInstance is 0, or \a encoderQueryType is invalid
* - \ref NVML_ERROR_NOT_FOUND if \a vgpuInstance does not match a valid active vGPU instance on the system
@@ -5863,10 +5863,10 @@ nvmlReturn_t DECLDIR nvmlVgpuInstanceGetEncoderStats(nvmlVgpuInstance_t vgpuInst
* Retrieves information about all active encoder sessions on a vGPU Instance.
*
* An array of active encoder sessions is returned in the caller-supplied buffer pointed at by \a sessionInfo. The
* array elememt count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* array element count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* written to the buffer.
*
* If the supplied buffer is not large enough to accomodate the active session array, the function returns
* If the supplied buffer is not large enough to accommodate the active session array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlEncoderSessionInfo_t array required in \a sessionCount.
* To query the number of active encoder sessions, call this function with *sessionCount = 0. The code will return
* NVML_SUCCESS with number of active encoder sessions updated in *sessionCount.
@@ -5896,7 +5896,7 @@ nvmlReturn_t DECLDIR nvmlVgpuInstanceGetEncoderSessions(nvmlVgpuInstance_t vgpuI
* For Maxwell &tm; or newer fully supported devices.
*
* @param vgpuInstance Identifier of the target vGPU instance
* @param fbcStats Reference to nvmlFBCStats_t structure contianing NvFBC stats
* @param fbcStats Reference to nvmlFBCStats_t structure containing NvFBC stats
*
* @return
* - \ref NVML_SUCCESS if \a fbcStats is fetched
@@ -5914,7 +5914,7 @@ nvmlReturn_t DECLDIR nvmlVgpuInstanceGetFBCStats(nvmlVgpuInstance_t vgpuInstance
* array element count is passed in \a sessionCount, and \a sessionCount is used to return the number of sessions
* written to the buffer.
*
* If the supplied buffer is not large enough to accomodate the active session array, the function returns
* If the supplied buffer is not large enough to accommodate the active session array, the function returns
* NVML_ERROR_INSUFFICIENT_SIZE, with the element count of nvmlFBCSessionInfo_t array required in \a sessionCount.
* To query the number of active FBC sessions, call this function with *sessionCount = 0. The code will return
* NVML_SUCCESS with number of active FBC sessions updated in *sessionCount.
@@ -6094,7 +6094,7 @@ typedef struct nvmlVgpuPgpuMetadata_st
unsigned int version; //!< Current version of the structure
unsigned int revision; //!< Current revision of the structure
char hostDriverVersion[NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE]; //!< Host driver version
unsigned int pgpuVirtualizationCaps; //!< Pgpu virtualizaion capabilities bitfileld
unsigned int pgpuVirtualizationCaps; //!< Pgpu virtualization capabilities bitfield
unsigned int reserved[7]; //!< Reserved for internal use
unsigned int opaqueDataSize; //!< Size of opaque data field in bytes
char opaqueData[4]; //!< Opaque data
@@ -6191,7 +6191,7 @@ nvmlReturn_t DECLDIR nvmlDeviceGetVgpuMetadata(nvmlDevice_t device, nvmlVgpuPgpu
*
* The caller passes in a buffer via \a compatibilityInfo, into which a compatibility information structure is written. The
* structure defines the states in which the vGPU / VM may be booted on the physical GPU. If the vGPU / VM compatibility
* with the physical GPU is limited, a limit code indicates the factor limiting compability.
* with the physical GPU is limited, a limit code indicates the factor limiting compatibility.
* (see \ref nvmlVgpuPgpuCompatibilityLimitCode_t for details).
*
* Note: vGPU compatibility does not take into account dynamic capacity conditions that may limit a system's ability to

View File

@@ -950,7 +950,7 @@ namespace half_float
/// Convert half-precision floating point to integer.
/// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
/// \tparam E `true` for round to even, `false` for round away from zero
/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \tparam T type to convert to (builtin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \param value binary representation of half-precision value
/// \return integral value
template<std::float_round_style R,bool E,typename T> T half2int_impl(uint16 value)
@@ -988,13 +988,13 @@ namespace half_float
/// Convert half-precision floating point to integer.
/// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \tparam T type to convert to (builtin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \param value binary representation of half-precision value
/// \return integral value
template<std::float_round_style R,typename T> T half2int(uint16 value) { return half2int_impl<R,HALF_ROUND_TIES_TO_EVEN,T>(value); }
/// Convert half-precision floating point to integer using round-to-nearest-away-from-zero.
/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \tparam T type to convert to (builtin integer type with at least 16 bits precision, excluding any implicit sign bits)
/// \param value binary representation of half-precision value
/// \return integral value
template<typename T> T half2int_up(uint16 value) { return half2int_impl<std::round_to_nearest,0,T>(value); }
@@ -1053,7 +1053,7 @@ namespace half_float
/// Half-precision floating point type.
/// This class implements an IEEE-conformant half-precision floating point type with the usual arithmetic operators and
/// conversions. It is implicitly convertible to single-precision floating point, which makes artihmetic expressions and
/// conversions. It is implicitly convertible to single-precision floating point, which makes arithmetic expressions and
/// functions with mixed-type operands to be of the most precise operand type. Additionally all arithmetic operations
/// (and many mathematical functions) are carried out in single-precision internally. All conversions from single- to
/// half-precision are done using the library's default rounding mode, but temporary results inside chained arithmetic
@@ -1062,7 +1062,7 @@ namespace half_float
/// According to the C++98/03 definition, the half type is not a POD type. But according to C++11's less strict and
/// extended definitions it is both a standard layout type and a trivially copyable type (even if not a POD type), which
/// means it can be standard-conformantly copied using raw binary copies. But in this context some more words about the
/// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not neccessarily have to be of
/// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not necessarily have to be of
/// exactly 16-bits size. But on any reasonable implementation the actual binary representation of this type will most
/// probably not ivolve any additional "magic" or padding beyond the simple binary representation of the underlying 16-bit
/// IEEE number, even if not strictly guaranteed by the standard. But even then it only has an actual size of 16 bits if
@@ -2181,7 +2181,7 @@ namespace half_float
/// Identity.
/// \param arg operand
/// \return uncahnged operand
/// \return unchanged operand
template<typename T> HALF_CONSTEXPR typename enable<T,T>::type operator+(T arg) { return arg; }
/// Negation.
@@ -2620,7 +2620,7 @@ namespace half_float
/// Multiply by power of two.
/// \param arg number to modify
/// \param exp power of two to multiply with
/// \return \a arg multplied by 2 raised to \a exp
/// \return \a arg multiplied by 2 raised to \a exp
// template<typename T> typename enable<half,T>::type ldexp(T arg, int exp) { return functions::scalbln(arg, exp); }
inline half ldexp(half arg, int exp) { return functions::scalbln(arg, exp); }
inline half ldexp(expr arg, int exp) { return functions::scalbln(arg, exp); }
@@ -2636,7 +2636,7 @@ namespace half_float
/// Multiply by power of two.
/// \param arg number to modify
/// \param exp power of two to multiply with
/// \return \a arg multplied by 2 raised to \a exp
/// \return \a arg multiplied by 2 raised to \a exp
// template<typename T> typename enable<half,T>::type scalbn(T arg, int exp) { return functions::scalbln(arg, exp); }
inline half scalbn(half arg, int exp) { return functions::scalbln(arg, exp); }
inline half scalbn(expr arg, int exp) { return functions::scalbln(arg, exp); }
@@ -2644,7 +2644,7 @@ namespace half_float
/// Multiply by power of two.
/// \param arg number to modify
/// \param exp power of two to multiply with
/// \return \a arg multplied by 2 raised to \a exp
/// \return \a arg multiplied by 2 raised to \a exp
// template<typename T> typename enable<half,T>::type scalbln(T arg, long exp) { return functions::scalbln(arg, exp); }
inline half scalbln(half arg, long exp) { return functions::scalbln(arg, exp); }
inline half scalbln(expr arg, long exp) { return functions::scalbln(arg, exp); }

288
include/triton/external/hip.h vendored Normal file
View File

@@ -0,0 +1,288 @@
/*
* @brief hipError_t
* @enum
* @ingroup Enumerations
*/
// Developer note - when updating these, update the hipErrorName and hipErrorString functions in
// NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path.
// Ignoring error-code return values from hip APIs is discouraged. On C++17,
// we can make that yield a warning
/*
* @brief hipError_t
* @enum
* @ingroup Enumerations
*/
// Developer note - when updating these, update the hipErrorName and hipErrorString functions in
// NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path.
#include <cstddef>
typedef enum hipError_t {
hipSuccess = 0, ///< Successful completion.
hipErrorInvalidValue = 1, ///< One or more of the parameters passed to the API call is NULL
///< or not in an acceptable range.
hipErrorOutOfMemory = 2,
// Deprecated
hipErrorMemoryAllocation = 2, ///< Memory allocation error.
hipErrorNotInitialized = 3,
// Deprecated
hipErrorInitializationError = 3,
hipErrorDeinitialized = 4,
hipErrorProfilerDisabled = 5,
hipErrorProfilerNotInitialized = 6,
hipErrorProfilerAlreadyStarted = 7,
hipErrorProfilerAlreadyStopped = 8,
hipErrorInvalidConfiguration = 9,
hipErrorInvalidPitchValue = 12,
hipErrorInvalidSymbol = 13,
hipErrorInvalidDevicePointer = 17, ///< Invalid Device Pointer
hipErrorInvalidMemcpyDirection = 21, ///< Invalid memory copy direction
hipErrorInsufficientDriver = 35,
hipErrorMissingConfiguration = 52,
hipErrorPriorLaunchFailure = 53,
hipErrorInvalidDeviceFunction = 98,
hipErrorNoDevice = 100, ///< Call to hipGetDeviceCount returned 0 devices
hipErrorInvalidDevice = 101, ///< DeviceID must be in range 0...#compute-devices.
hipErrorInvalidImage = 200,
hipErrorInvalidContext = 201, ///< Produced when input context is invalid.
hipErrorContextAlreadyCurrent = 202,
hipErrorMapFailed = 205,
// Deprecated
hipErrorMapBufferObjectFailed = 205, ///< Produced when the IPC memory attach failed from ROCr.
hipErrorUnmapFailed = 206,
hipErrorArrayIsMapped = 207,
hipErrorAlreadyMapped = 208,
hipErrorNoBinaryForGpu = 209,
hipErrorAlreadyAcquired = 210,
hipErrorNotMapped = 211,
hipErrorNotMappedAsArray = 212,
hipErrorNotMappedAsPointer = 213,
hipErrorECCNotCorrectable = 214,
hipErrorUnsupportedLimit = 215,
hipErrorContextAlreadyInUse = 216,
hipErrorPeerAccessUnsupported = 217,
hipErrorInvalidKernelFile = 218, ///< In CUDA DRV, it is CUDA_ERROR_INVALID_PTX
hipErrorInvalidGraphicsContext = 219,
hipErrorInvalidSource = 300,
hipErrorFileNotFound = 301,
hipErrorSharedObjectSymbolNotFound = 302,
hipErrorSharedObjectInitFailed = 303,
hipErrorOperatingSystem = 304,
hipErrorInvalidHandle = 400,
// Deprecated
hipErrorInvalidResourceHandle = 400, ///< Resource handle (hipEvent_t or hipStream_t) invalid.
hipErrorNotFound = 500,
hipErrorNotReady = 600, ///< Indicates that asynchronous operations enqueued earlier are not
///< ready. This is not actually an error, but is used to distinguish
///< from hipSuccess (which indicates completion). APIs that return
///< this error include hipEventQuery and hipStreamQuery.
hipErrorIllegalAddress = 700,
hipErrorLaunchOutOfResources = 701, ///< Out of resources error.
hipErrorLaunchTimeOut = 702,
hipErrorPeerAccessAlreadyEnabled =
704, ///< Peer access was already enabled from the current device.
hipErrorPeerAccessNotEnabled =
705, ///< Peer access was never enabled from the current device.
hipErrorSetOnActiveProcess = 708,
hipErrorAssert = 710, ///< Produced when the kernel calls assert.
hipErrorHostMemoryAlreadyRegistered =
712, ///< Produced when trying to lock a page-locked memory.
hipErrorHostMemoryNotRegistered =
713, ///< Produced when trying to unlock a non-page-locked memory.
hipErrorLaunchFailure =
719, ///< An exception occurred on the device while executing a kernel.
hipErrorCooperativeLaunchTooLarge =
720, ///< This error indicates that the number of blocks launched per grid for a kernel
///< that was launched via cooperative launch APIs exceeds the maximum number of
///< allowed blocks for the current device
hipErrorNotSupported = 801, ///< Produced when the hip API is not supported/implemented
hipErrorUnknown = 999, //< Unknown error.
// HSA Runtime Error Codes start here.
hipErrorRuntimeMemory = 1052, ///< HSA runtime memory call returned error. Typically not seen
///< in production systems.
hipErrorRuntimeOther = 1053, ///< HSA runtime call other than memory returned error. Typically
///< not seen in production systems.
hipErrorTbd ///< Marker that more error codes are needed.
} hipError_t;
typedef struct ihipCtx_t* hipCtx_t;
// Note many APIs also use integer deviceIds as an alternative to the device pointer:
typedef int hipDevice_t;
typedef enum hipDeviceP2PAttr {
hipDevP2PAttrPerformanceRank = 0,
hipDevP2PAttrAccessSupported,
hipDevP2PAttrNativeAtomicSupported,
hipDevP2PAttrHipArrayAccessSupported
} hipDeviceP2PAttr;
typedef struct ihipStream_t* hipStream_t;
#define hipIpcMemLazyEnablePeerAccess 0
#define HIP_IPC_HANDLE_SIZE 64
typedef struct hipIpcMemHandle_st {
char reserved[HIP_IPC_HANDLE_SIZE];
} hipIpcMemHandle_t;
typedef struct hipIpcEventHandle_st {
char reserved[HIP_IPC_HANDLE_SIZE];
} hipIpcEventHandle_t;
typedef struct ihipModule_t* hipModule_t;
typedef struct ihipModuleSymbol_t* hipFunction_t;
typedef struct hipFuncAttributes {
int binaryVersion;
int cacheModeCA;
size_t constSizeBytes;
size_t localSizeBytes;
int maxDynamicSharedSizeBytes;
int maxThreadsPerBlock;
int numRegs;
int preferredShmemCarveout;
int ptxVersion;
size_t sharedSizeBytes;
} hipFuncAttributes;
typedef struct ihipEvent_t* hipEvent_t;
/*
* @brief hipDeviceAttribute_t
* @enum
* @ingroup Enumerations
*/
typedef enum hipDeviceAttribute_t {
hipDeviceAttributeMaxThreadsPerBlock, ///< Maximum number of threads per block.
hipDeviceAttributeMaxBlockDimX, ///< Maximum x-dimension of a block.
hipDeviceAttributeMaxBlockDimY, ///< Maximum y-dimension of a block.
hipDeviceAttributeMaxBlockDimZ, ///< Maximum z-dimension of a block.
hipDeviceAttributeMaxGridDimX, ///< Maximum x-dimension of a grid.
hipDeviceAttributeMaxGridDimY, ///< Maximum y-dimension of a grid.
hipDeviceAttributeMaxGridDimZ, ///< Maximum z-dimension of a grid.
hipDeviceAttributeMaxSharedMemoryPerBlock, ///< Maximum shared memory available per block in
///< bytes.
hipDeviceAttributeTotalConstantMemory, ///< Constant memory size in bytes.
hipDeviceAttributeWarpSize, ///< Warp size in threads.
hipDeviceAttributeMaxRegistersPerBlock, ///< Maximum number of 32-bit registers available to a
///< thread block. This number is shared by all thread
///< blocks simultaneously resident on a
///< multiprocessor.
hipDeviceAttributeClockRate, ///< Peak clock frequency in kilohertz.
hipDeviceAttributeMemoryClockRate, ///< Peak memory clock frequency in kilohertz.
hipDeviceAttributeMemoryBusWidth, ///< Global memory bus width in bits.
hipDeviceAttributeMultiprocessorCount, ///< Number of multiprocessors on the device.
hipDeviceAttributeComputeMode, ///< Compute mode that device is currently in.
hipDeviceAttributeL2CacheSize, ///< Size of L2 cache in bytes. 0 if the device doesn't have L2
///< cache.
hipDeviceAttributeMaxThreadsPerMultiProcessor, ///< Maximum resident threads per
///< multiprocessor.
hipDeviceAttributeComputeCapabilityMajor, ///< Major compute capability version number.
hipDeviceAttributeComputeCapabilityMinor, ///< Minor compute capability version number.
hipDeviceAttributeConcurrentKernels, ///< Device can possibly execute multiple kernels
///< concurrently.
hipDeviceAttributePciBusId, ///< PCI Bus ID.
hipDeviceAttributePciDeviceId, ///< PCI Device ID.
hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, ///< Maximum Shared Memory Per
///< Multiprocessor.
hipDeviceAttributeIsMultiGpuBoard, ///< Multiple GPU devices.
hipDeviceAttributeIntegrated, ///< iGPU
hipDeviceAttributeCooperativeLaunch, ///< Support cooperative launch
hipDeviceAttributeCooperativeMultiDeviceLaunch, ///< Support cooperative launch on multiple devices
hipDeviceAttributeMaxTexture1DWidth, ///< Maximum number of elements in 1D images
hipDeviceAttributeMaxTexture2DWidth, ///< Maximum dimension width of 2D images in image elements
hipDeviceAttributeMaxTexture2DHeight, ///< Maximum dimension height of 2D images in image elements
hipDeviceAttributeMaxTexture3DWidth, ///< Maximum dimension width of 3D images in image elements
hipDeviceAttributeMaxTexture3DHeight, ///< Maximum dimensions height of 3D images in image elements
hipDeviceAttributeMaxTexture3DDepth, ///< Maximum dimensions depth of 3D images in image elements
hipDeviceAttributeHdpMemFlushCntl, ///< Address of the HDP_MEM_COHERENCY_FLUSH_CNTL register
hipDeviceAttributeHdpRegFlushCntl, ///< Address of the HDP_REG_COHERENCY_FLUSH_CNTL register
hipDeviceAttributeMaxPitch, ///< Maximum pitch in bytes allowed by memory copies
hipDeviceAttributeTextureAlignment, ///<Alignment requirement for textures
hipDeviceAttributeTexturePitchAlignment, ///<Pitch alignment requirement for 2D texture references bound to pitched memory;
hipDeviceAttributeKernelExecTimeout, ///<Run time limit for kernels executed on the device
hipDeviceAttributeCanMapHostMemory, ///<Device can map host memory into device address space
hipDeviceAttributeEccEnabled, ///<Device has ECC support enabled
hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc, ///< Supports cooperative launch on multiple
///devices with unmatched functions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim, ///< Supports cooperative launch on multiple
///devices with unmatched grid dimensions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim, ///< Supports cooperative launch on multiple
///devices with unmatched block dimensions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem, ///< Supports cooperative launch on multiple
///devices with unmatched shared memories
hipDeviceAttributeAsicRevision, ///< Revision of the GPU in this device
hipDeviceAttributeManagedMemory, ///< Device supports allocating managed memory on this system
hipDeviceAttributeDirectManagedMemAccessFromHost, ///< Host can directly access managed memory on
/// the device without migration
hipDeviceAttributeConcurrentManagedAccess, ///< Device can coherently access managed memory
/// concurrently with the CPU
hipDeviceAttributePageableMemoryAccess, ///< Device supports coherently accessing pageable memory
/// without calling hipHostRegister on it
hipDeviceAttributePageableMemoryAccessUsesHostPageTables, ///< Device accesses pageable memory via
/// the host's page tables
hipDeviceAttributeCanUseStreamWaitValue ///< '1' if Device supports hipStreamWaitValue32() and
///< hipStreamWaitValue64() , '0' otherwise.
} hipDeviceAttribute_t;
typedef void* hipDeviceptr_t;
/*
* @brief hipJitOption
* @enum
* @ingroup Enumerations
*/
typedef enum hipJitOption {
hipJitOptionMaxRegisters = 0,
hipJitOptionThreadsPerBlock,
hipJitOptionWallTime,
hipJitOptionInfoLogBuffer,
hipJitOptionInfoLogBufferSizeBytes,
hipJitOptionErrorLogBuffer,
hipJitOptionErrorLogBufferSizeBytes,
hipJitOptionOptimizationLevel,
hipJitOptionTargetFromContext,
hipJitOptionTarget,
hipJitOptionFallbackStrategy,
hipJitOptionGenerateDebugInfo,
hipJitOptionLogVerbose,
hipJitOptionGenerateLineInfo,
hipJitOptionCacheMode,
hipJitOptionSm3xOpt,
hipJitOptionFastCompile,
hipJitOptionNumOptions
} hipJitOption;
/**
* @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
*/
typedef enum hipFuncAttribute {
hipFuncAttributeMaxDynamicSharedMemorySize = 8,
hipFuncAttributePreferredSharedMemoryCarveout = 9,
hipFuncAttributeMax
} hipFuncAttribute;
/**
* @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
*/
typedef enum hipFuncCache_t {
hipFuncCachePreferNone, ///< no preference for shared memory or L1 (default)
hipFuncCachePreferShared, ///< prefer larger shared memory and smaller L1 cache
hipFuncCachePreferL1, ///< prefer larger L1 cache and smaller shared memory
hipFuncCachePreferEqual, ///< prefer equal size L1 cache and shared memory
} hipFuncCache_t;
#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void*)0x01)
#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void*)0x02)
#define HIP_LAUNCH_PARAM_END ((void*)0x03)

View File

@@ -1,4 +1,4 @@
#pragma once
#pragma once
#ifndef _TRITON_IR_BASIC_BLOCK_H_
#define _TRITON_IR_BASIC_BLOCK_H_
@@ -27,7 +27,7 @@ public:
private:
// constructors
basic_block(context &ctx, const std::string &name, function *parent);
basic_block(context &ctx, const std::string &name, function *parent, basic_block *next);
public:
// accessors
@@ -35,10 +35,12 @@ public:
context& get_context() { return ctx_; }
// get iterator to first instruction that is not a phi
void replace_phi_uses_with(basic_block* before, basic_block* after);
iterator get_first_non_phi();
// get instruction list
inst_list_t &get_inst_list() { return inst_list_; }
const inst_list_t &get_inst_list() const { return inst_list_; }
void erase(instruction *i) { inst_list_.remove(i); }
// instruction iterator functions
@@ -59,13 +61,18 @@ public:
inline const instruction &back() const { return *inst_list_.back(); }
inline instruction &back() { return *inst_list_.back(); }
void append_instruction(ir::instruction* i);
// split
basic_block* split_before(ir::instruction* loc, const std::string& name);
// predecessors
const std::vector<basic_block*>& get_predecessors() const { return preds_; }
const std::vector<basic_block*>& get_successors() const { return succs_; }
void add_predecessor(basic_block* pred);
std::vector<basic_block*> get_predecessors() const;
std::vector<basic_block*> get_successors() const;
// factory functions
static basic_block* create(context &ctx, const std::string &name, function *parent);
static basic_block* create(context &ctx, const std::string &name, function *parent, basic_block *next = nullptr);
void print(std::ostream &os);
// visitor
void accept(visitor *v) { v->visit_basic_block(this); }

View File

@@ -22,13 +22,16 @@ class phi_node;
/* Builder */
class builder{
public:
typedef basic_block::iterator iterator;
public:
// Constructor
builder(context &ctx);
// Getters
const context& get_context() { return ctx_; }
// const context& get_context() const { return ctx_; }
context& get_context() { return ctx_; }
// Setters
void set_insert_point(iterator instr);
void set_insert_point(instruction* i);
@@ -38,8 +41,8 @@ public:
iterator get_insert_point() { return insert_point_;}
// Constants
value *get_int1(bool val);
value *get_int32(int32_t val);
value *get_int64(int64_t val);
value *get_int32(uint32_t val);
value *get_int64(uint64_t val);
value *get_float16(float val);
value *get_float32(float val);
value *get_range(int32_t lo, int32_t hi);
@@ -50,7 +53,9 @@ public:
type *get_int16_ty();
type *get_int32_ty();
type *get_int64_ty();
type *get_fp8_ty();
type *get_half_ty();
type *get_bf16_ty();
type *get_float_ty();
type *get_double_ty();
// Insert
@@ -67,8 +72,13 @@ public:
value* create_br(basic_block *dest);
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
value* create_ret_void();
value* create_ret(value *ret);
// Dequantize instructions
value* create_dequantize(value *src, value *scale, value *shift, type *dest_ty);
// Cast instructions
value* create_bitcast(value *src, type *dest_ty);
value *create_cast(cast_op_t op, value *v, type *dst_ty);
value* create_int_to_ptr(value *src, type *dst_ty);
value* create_ptr_to_int(value *src, type *dst_ty);
value* create_si_to_fp(value *src, type *dst_ty);
value* create_ui_to_fp(value *src, type *dst_ty);
@@ -78,6 +88,9 @@ public:
value* create_fp_trunc(value *src, type *dst_ty);
value* create_int_cast(value *src, type *dst_ty, bool is_signed);
value *create_downcast(value *arg);
// Call instruction
value* create_call(function* fn, const std::vector<value*>& args);
value* create_launch(function* fn, const std::vector<value*>& args, const std::vector<value*>& grid, value* num_warps);
// Phi instruction
phi_node* create_phi(type *ty, unsigned num_reserved);
// Binary instructions
@@ -87,11 +100,11 @@ public:
value *create_frem(value *lhs, value *rhs);
value *create_fadd(value *lhs, value *rhs);
value *create_fsub(value *lhs, value *rhs);
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sdiv(value *lhs, value *rhs);
value *create_udiv(value *lhs, value *rhs);
value *create_srem(value *lhs, value *rhs);
value *create_urem(value *lhs, value *rhs);
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
@@ -130,32 +143,57 @@ public:
value *create_xor(value *lhs, value *rhs);
value *create_or(value *lhs, value *rhs);
// Input/Output
value *create_load(value *arg);
value *create_store(value *ptr, value *val);
value *create_masked_load(value *arg, value *mask, value *false_value);
value *create_masked_store(value *ptr, value *val, value *mask);
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
value *create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction);
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile);
value *create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction);
// Struct instructions
value *create_insert_value(value* val, value *elt, size_t idx);
value *create_extract_value(value* val, size_t idx);
// Block instruction
value *create_splat(value *arg, const type::block_shapes_t &shapes);
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
value *create_cat(value *lhs, value *rhs);
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
// Atomic instruction
value *create_atomic_cas(value *ptr, value *cmp, value *val);
value *create_atomic_rmw(atomic_rmw_op_t op, value *ptr, value *val, value *msk);
value *create_atomic_max(value *ptr, value *val, value *msk);
value *create_atomic_umax(value *ptr, value *val, value *msk);
value *create_atomic_min(value *ptr, value *val, value *msk);
value *create_atomic_umin(value *ptr, value *val, value *msk);
value *create_atomic_fadd(value *ptr, value *val, value *msk);
value *create_atomic_add(value *ptr, value *val, value *msk);
value *create_atomic_and(value *ptr, value *val, value *msk);
value *create_atomic_or(value *ptr, value *val, value *msk);
value *create_atomic_xor(value *ptr, value *val, value *msk);
value *create_atomic_xchg(value *ptr, value *val, value *msk);
// Utilities
value *create_clock();
value *create_globaltimer();
// Extern instruction
value *create_extern_elementwise(const std::string &lib_name,
const std::string &lib_path,
const std::string &symbol_name,
const std::vector<value *> &args,
type *ret_ty);
// Built-in instruction
value *create_get_program_id(unsigned axis);
value *create_get_num_programs(unsigned axis);
value *create_atomic_cas(value *ptr, value *cmp, value *val);
value *create_atomic_exch(value *ptr, value *val);
value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
value *create_exp(value* arg);
value *create_cos(value* arg);
value *create_sin(value* arg);
value *create_log(value* arg);
value *create_dot(value *A, value *B, value *C);
value *create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32);
value *create_trans(value *A, const std::vector<int> &perm = {});
value *create_sqrt(value *A);
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
value *create_select(value *pred, value *if_value, value *else_value);
// Intrinsics
// These have no place in the IR, and hopefully they can be removed at some point
value *create_umulhi(value* lhs, value* rhs);
value *create_copy_to_shared(value *arg);
value *create_masked_load_async(value *arg, value *mask, value *false_value);
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY);
value *create_copy_from_shared(value *arg);
value *create_barrier(const std::string &name = "");
value *create_async_wait(int N);

View File

@@ -9,7 +9,6 @@
namespace triton{
namespace ir{
class builder;
class type;
class context_impl;
@@ -21,7 +20,6 @@ public:
context& operator=(const context&) = delete;
public:
ir::builder* builder = nullptr;
std::shared_ptr<context_impl> p_impl;
};

View File

@@ -3,17 +3,15 @@
#ifndef _TRITON_IR_CONTEXT_IMPL_H_
#define _TRITON_IR_CONTEXT_IMPL_H_
#include <map>
#include "triton/ir/type.h"
#include "triton/ir/constant.h"
#include <map>
#include <memory>
namespace triton{
namespace ir{
class context;
class constant;
class constant_int;
class constant_fp;
class undef_value;
/* Context impl */
class context_impl {
@@ -29,16 +27,17 @@ public:
// integer types
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
// Pointer types
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
std::map<std::pair<type*, unsigned>, std::unique_ptr<pointer_type>> ptr_tys;
// Block types
std::map<std::pair<type*, type::block_shapes_t>, block_type*> block_tys;
std::map<std::pair<type*, type::block_shapes_t>, std::unique_ptr<block_type>> block_tys;
// Struct types
std::map<type::contained_tys_vec_t, struct_type*> struct_tys;
// Int constants
std::map<std::pair<type*, uint64_t>, constant_int*> int_constants_;
std::map<std::pair<type*, uint64_t>, std::unique_ptr<constant_int>> int_constants_;
// Float constants
std::map<std::pair<type*, double>, constant_fp*> fp_constants_;
std::map<std::pair<type*, double>, std::unique_ptr<constant_fp>> fp_constants_;
// undef values
std::map<type*, undef_value*> uv_constants_;
std::map<type*, std::unique_ptr<undef_value>> uv_constants_;
};

View File

@@ -1,107 +0,0 @@
#pragma once
#ifndef _TRITON_IR_DISPATCH_H_
#define _TRITON_IR_DISPATCH_H_
#include "triton/ir/builder.h"
#include <stdexcept>
namespace triton{
namespace ir{
/*----------------------------------------------
higher level functions that follow the likely
semantics of most expected frontends
----------------------------------------------*/
struct semantic_error: public std::runtime_error {
semantic_error(const std::string& msg):
std::runtime_error(msg) { }
};
struct dispatch{
typedef ir::type::block_shapes_t shape_t;
// programming model
static ir::value *program_id(int axis, ir::builder *builder);
static ir::value *num_programs(int axis, ir::builder *builder);
// binary operators
static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder);
// unary operators
static ir::value *plus(ir::value *input, ir::builder *builder);
static ir::value *minus(ir::value *input, ir::builder *builder);
static ir::value *invert(ir::value *input, ir::builder *builder);
// comparison operators
static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder);
// block creation
static ir::value* arange(int start, int end, ir::builder *builder);
static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder);
// casting ops
static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
// memory operators
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder *builder);
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder);
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
// linear algebra
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
// indexing
static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);
// reduction
static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
// math
static ir::value *exp(ir::value *x, ir::builder *builder);
static ir::value *log(ir::value *x, ir::builder *builder);
static ir::value *cos(ir::value *x, ir::builder *builder);
static ir::value *sin(ir::value *x, ir::builder *builder);
static ir::value *sqrt(ir::value *x, ir::builder *builder);
// internal (debug/optimization)
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
static ir::value *debug_barrier(ir::builder *builder);
};
}
}
#endif

View File

@@ -38,6 +38,7 @@ enum class atomic_rmw_op_t: unsigned int{
UMax,
UMin,
FAdd,
Xchg,
};
enum cast_op_t: unsigned int {
@@ -94,6 +95,9 @@ enum value_id_t: unsigned {
INSTRUCTIONS
* ------------ */
INST_BEGIN,
// call
INST_CALL,
INST_LAUNCH,
// phi
INST_PHI,
// arithmetic
@@ -104,6 +108,8 @@ enum value_id_t: unsigned {
// cmp
INST_ICMP,
INST_FCMP,
// dequantize
INST_DEQUANTIZE,
// cast
INST_CAST_TRUNC,
INST_CAST_ZEXT,
@@ -128,9 +134,13 @@ enum value_id_t: unsigned {
INST_MASKED_LOAD_ASYNC,
INST_UNMASKED_STORE,
INST_MASKED_STORE,
// struct
INST_EXTRACT_VALUE,
INST_INSERT_VALUE,
// retile
INST_RESHAPE,
INST_SPLAT,
INST_CAT,
INST_BROADCAST,
INST_DOWNCAST,
// builtin
@@ -141,10 +151,13 @@ enum value_id_t: unsigned {
INST_ATOMIC_EXCH,
INST_ATOMIC_RMW,
// math
INST_UMULHI,
INST_EXP,
INST_COS,
INST_SIN,
INST_LOG,
// extern
INST_EXTERN_ELEMENTWISE,
// array arithmetic
INST_TRANS,
INST_REDUCE,
@@ -152,6 +165,9 @@ enum value_id_t: unsigned {
// intrinsics
INST_COPY_TO_SHARED,
INST_COPY_FROM_SHARED,
INST_CVT_LAYOUT,
INST_CVT_SCANLINE,
INST_DECOALESCE,
INST_RECOALESCE,
INST_BARRIER,
INST_ASYNC_WAIT,
@@ -159,6 +175,8 @@ enum value_id_t: unsigned {
INST_MAKE_RANGE_STA,
INST_MAKE_RANGE,
INST_PREFETCH_S,
INST_GLOBALTIMER,
INST_CLOCK,
};

View File

@@ -24,7 +24,7 @@ public:
static argument* create(type *ty, const std::string &name,
function *parent = nullptr, unsigned arg_no = 0);
function* get_parent() const;
unsigned get_arg_no() const;
unsigned get_arg_no() const;
void accept(visitor *v);
@@ -104,19 +104,27 @@ public:
// accessors
const args_t &args() const { return args_; }
function_type* get_fn_type() { return fn_ty_; }
const function_type* get_fn_type() const { return fn_ty_; }
module *get_parent() { return parent_; }
const module *get_parent() const { return parent_; }
// factory methods
static function *create(function_type *ty, linkage_types_t linkage,
const std::string &name, module *mod);
// blocks
const blocks_t &blocks() { return blocks_; }
blocks_t &blocks() { return blocks_; }
const blocks_t &blocks() const { return blocks_; }
void insert_block(basic_block* block, basic_block *next = nullptr);
// attributes
void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
const attr_map_t &attrs() { return attrs_; }
bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); }
std::set<attribute> get_attributes(argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
std::set<attribute> get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
void set_is_kernel(bool new_val) { is_kernel_ = new_val; }
bool get_is_kernel() { return is_kernel_; }
void print(std::ostream &os);
// visitor
void accept(visitor *v) { v->visit_function(this); }
@@ -128,6 +136,7 @@ private:
args_t args_;
blocks_t blocks_;
attr_map_t attrs_;
bool is_kernel_;
};
}

View File

@@ -59,8 +59,8 @@ public:
std::string repr() const { return repr_impl(); }
// metadata
void set_metadata(ir::metadata::kind_t kind,
unsigned value) { metadatas_[kind] = value;}
unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
std::vector<unsigned> value) { metadatas_[kind] = value;}
std::vector<unsigned> get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
// cloning
ir::instruction* clone() {
ir::instruction* res = clone_impl();
@@ -73,12 +73,59 @@ public:
// instruction id
value_id_t get_id() const { return id_; }
void print(std::ostream &os);
private:
basic_block *parent_;
std::map<ir::metadata::kind_t, unsigned> metadatas_;
std::map<ir::metadata::kind_t, std::vector<unsigned>> metadatas_;
value_id_t id_;
};
//===----------------------------------------------------------------------===//
// call_inst classes
//===----------------------------------------------------------------------===//
class call_inst: public instruction {
private:
std::string repr_impl() const;
call_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::string& name, instruction* next);
public:
static call_inst* create(ir::function* fn, const std::vector<ir::value*>& values, const std::string &name = "", instruction *next = nullptr);
ir::function* get_fn() { return fn_; }
_TRITON_DEFINE_CLONE(call_inst)
_TRITON_DEFINE_ACCEPT(call_inst)
private:
ir::function* fn_;
};
class launch_inst: public instruction {
private:
std::string repr_impl() const { return "launch"; }
launch_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps,
const std::string &name = "", instruction *next = nullptr);
public:
static launch_inst* create(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps,
const std::string& name = "", instruction* next = nullptr);
ir::function* get_fn();
std::vector<ir::value*> get_values();
std::vector<ir::value*> get_grid();
ir::value* get_num_warps();
_TRITON_DEFINE_CLONE(launch_inst)
_TRITON_DEFINE_ACCEPT(launch_inst)
private:
unsigned val_begin;
unsigned val_end;
unsigned grid_begin;
unsigned grid_end;
};
//===----------------------------------------------------------------------===//
// phi_node classes
@@ -115,6 +162,7 @@ private:
//===----------------------------------------------------------------------===//
// binary_operator classes
//===----------------------------------------------------------------------===//
class binary_operator: public instruction {
public:
typedef binary_op_t op_t;
@@ -143,6 +191,10 @@ public:
bool is_shl() const;
bool is_shr() const;
// Approx
void set_fdiv_ieee_rounding(bool rnd) { fdiv_ieee_rnd_ = rnd; }
bool get_fdiv_ieee_rounding() { return fdiv_ieee_rnd_; }
// Wraps
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
@@ -161,6 +213,8 @@ public:
binary_op_t op_;
bool has_no_unsigned_wrap_;
bool has_no_signed_wrap_;
bool fdiv_ieee_rnd_;
};
@@ -220,6 +274,24 @@ protected:
unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next);
};
//===----------------------------------------------------------------------===//
// dequantize_inst classes
//===----------------------------------------------------------------------===//
class dequantize_inst: public instruction{
private:
std::string repr_impl() const override { return "dequantize"; }
protected:
dequantize_inst(type *ty, value *v, value *scale, value *shift, const std::string &name, instruction *next);
public:
static dequantize_inst *create(value *arg, value *scale, value *shift, type *ty,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(dequantize_inst)
_TRITON_DEFINE_ACCEPT(dequantize_inst)
};
//===----------------------------------------------------------------------===//
// cast_inst classes
@@ -381,20 +453,61 @@ private:
//===----------------------------------------------------------------------===//
class io_inst: public instruction {
public:
enum EVICTION_POLICY : uint32_t {
NORMAL=0,
EVICT_FIRST,
EVICT_LAST,
};
protected:
io_inst(type *ty, value_id_t id, unsigned num_ops,
io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
const std::string &name = "", instruction *next = nullptr);
std::string get_eviction_policy_repr() const {
if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
if (eviction_ == EVICT_LAST) return ".L2::evict_last";
return "";
}
public:
// accessors
value *get_pointer_operand() { return get_operand(0); }
EVICTION_POLICY get_eviction_policy() const { return eviction_; }
protected:
EVICTION_POLICY eviction_;
};
// load
class load_inst: public io_inst {
public:
enum CACHE_MODIFIER : uint32_t {
NONE=0,
CA,
CG,
};
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
bool get_is_volatile() const { return is_volatile_; }
protected:
load_inst(value *ptr, value_id_t id, unsigned num_ops,
load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, EVICTION_POLICY eviction,
bool is_volatile,
const std::string &name = "", instruction *next = nullptr);
std::string get_cache_modifier_repr() const {
if (cache_ == CA) return ".ca";
if (cache_ == CG) return ".cg";
return "";
}
CACHE_MODIFIER cache_;
std::string get_volatile_repr() {
return is_volatile_ ? ".volatile" : "";
}
bool is_volatile_;
private:
static type *get_pointee_type(type *ty);
@@ -403,11 +516,13 @@ private:
// unmasked load
class unmasked_load_inst: public load_inst {
private:
std::string repr_impl() const { return "unmasked_load"; }
unmasked_load_inst(value *ptr, const std::string &name, instruction *next);
std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next);
public:
static unmasked_load_inst* create(value *ptr,
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
bool is_volatile,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_load_inst)
@@ -417,8 +532,8 @@ public:
// masked load
class masked_load_inst: public load_inst {
private:
std::string repr_impl() const { return "masked_load"; }
masked_load_inst(value *ptr, value *mask, value *false_value,
std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile,
const std::string &name, instruction *next);
public:
@@ -427,6 +542,8 @@ public:
value *get_false_value_operand() { return get_operand(2); }
// factory method
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
bool is_volatile,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_inst)
@@ -436,9 +553,10 @@ public:
// masked load async
class masked_load_async_inst: public load_inst {
private:
std::string repr_impl() const { return "masked_load_async_async"; }
std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
masked_load_async_inst(value *ptr, value *mask, value *false_value,
const std::string &name, instruction *next);
CACHE_MODIFIER cache, EVICTION_POLICY eviction,
const std::string &name, instruction *next);
public:
// accessors
@@ -446,6 +564,8 @@ public:
value *get_false_value_operand() { return get_operand(2); }
// factory method
static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache,
EVICTION_POLICY eviction,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_async_inst)
@@ -457,7 +577,7 @@ public:
// store
class store_inst: public io_inst {
protected:
store_inst(value *ptr, value_id_t id, unsigned num_ops,
store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
const std::string &name = "", instruction *next = nullptr);
public:
@@ -468,11 +588,11 @@ public:
class unmasked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "unmasked_store"; }
unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next);
unmasked_store_inst(value *ptr, value *v, EVICTION_POLICY eviction, const std::string &name, instruction *next);
public:
// factory method
static unmasked_store_inst* create(value* ptr, value *v,
static unmasked_store_inst* create(value* ptr, value *v, EVICTION_POLICY eviction,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_store_inst)
@@ -482,24 +602,77 @@ public:
class masked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "masked_store"; }
masked_store_inst(value *ptr, value *v, value *mask,
masked_store_inst(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
const std::string &name, instruction *next);
public:
// accessors
value *get_mask_operand() { return get_operand(2); }
// factory method
static masked_store_inst* create(value *ptr, value *v, value *mask,
static masked_store_inst* create(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_store_inst)
_TRITON_DEFINE_ACCEPT(masked_store_inst)
};
//===----------------------------------------------------------------------===//
// struct classes
//===----------------------------------------------------------------------===//
// insert_value
class insert_value_inst: public instruction {
private:
std::string repr_impl() const { return "insertvalue"; }
insert_value_inst(value *val, value *elt, size_t idx, const std::string &name, instruction *next);
public:
static insert_value_inst* create(value *val, value* elt, size_t idx, const std::string &name = "", instruction *next = nullptr);
size_t get_idx() { return idx_; }
_TRITON_DEFINE_CLONE(insert_value_inst)
_TRITON_DEFINE_ACCEPT(insert_value_inst)
private:
size_t idx_;
};
// extract_value
class extract_value_inst: public instruction {
private:
std::string repr_impl() const { return "extractvalue"; }
extract_value_inst(value *val, size_t idx, const std::string &name, instruction *next);
public:
static extract_value_inst* create(value *val, size_t idx, const std::string &name = "", instruction *next = nullptr);
size_t get_idx() { return idx_; }
_TRITON_DEFINE_CLONE(extract_value_inst)
_TRITON_DEFINE_ACCEPT(extract_value_inst)
private:
size_t idx_;
};
//===----------------------------------------------------------------------===//
// retile_inst classes
//===----------------------------------------------------------------------===//
// cat
class cat_inst: public instruction {
private:
std::string repr_impl() const { return "cat"; }
cat_inst(value *x, value *y, const std::string &name, instruction *next);
public:
static instruction* create(value *lhs, value *rhs,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(cat_inst)
_TRITON_DEFINE_ACCEPT(cat_inst)
};
// retile
class retile_inst: public unary_inst {
@@ -606,6 +779,8 @@ private:
class atomic_inst: public io_inst {
public:
using io_inst::io_inst;
atomic_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next):
io_inst(ty, id, num_ops, NORMAL, name, next) {}
};
class atomic_rmw_inst: public atomic_inst {
@@ -634,18 +809,17 @@ public:
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
};
class atomic_exch_inst: public atomic_inst {
class umulhi_inst: public builtin_inst {
private:
atomic_exch_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "atomic_exch"; }
_TRITON_DEFINE_CLONE(atomic_exch_inst)
_TRITON_DEFINE_ACCEPT(atomic_exch_inst)
umulhi_inst(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "umulhi"; }
_TRITON_DEFINE_CLONE(umulhi_inst)
_TRITON_DEFINE_ACCEPT(umulhi_inst)
public:
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
static instruction* create(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
};
class exp_inst: public builtin_inst {
private:
exp_inst(value *val, const std::string &name = "", instruction *next = nullptr);
@@ -694,24 +868,40 @@ public:
class dot_inst: public builtin_inst {
public:
enum TransT { NoTrans, Trans };
enum DataType {
FP8, FP16, BF16, TF32, FP32,
INT1, INT4, INT8, INT32,
UNKNOWN,
};
private:
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next);
std::string repr_impl() const { return "dot"; }
bool is_prefetched_ = false;
public:
bool is_prefetched() const { return is_prefetched_; }
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
bool allow_tf32() const { return allow_tf32_; }
bool is_trans_a() const { return AT_ == Trans; }
bool is_trans_b() const { return BT_ == Trans; }
public:
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(dot_inst)
_TRITON_DEFINE_ACCEPT(dot_inst)
private:
bool is_prefetched_ = false;
bool allow_tf32_ = false;
DataType C_type_ = DataType::FP32;
DataType A_type_ = DataType::FP16;
DataType B_type_ = DataType::FP16;
TransT AT_;
TransT BT_;
};
//class outer_inst: public builtin_inst {
@@ -753,8 +943,11 @@ public:
class reduce_inst: public builtin_inst {
public:
enum op_t{
ADD, SUB, MAX, MIN,
FADD, FSUB, FMAX, FMIN
ADD, SUB, MAX, MIN, UMAX, UMIN,
ARGMAX, ARGMIN, ARGUMAX, ARGUMIN,
FADD, FSUB, FMAX, FMIN,
ARGFMAX, ARGFMIN,
XOR
};
private:
@@ -771,12 +964,19 @@ public:
static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
op_t get_op() const { return op_; }
bool with_index() const {
return with_index_ops_.find(op_) != with_index_ops_.end();
}
private:
unsigned axis_;
op_t op_;
const static inline std::set<op_t> with_index_ops_ = {
op_t::ARGMAX, op_t::ARGMIN, op_t::ARGUMAX,
op_t::ARGUMIN, op_t::ARGFMAX, op_t::ARGFMIN};
unsigned axis_;
op_t op_;
};
class select_inst: public builtin_inst {
private:
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
@@ -795,6 +995,7 @@ public:
// intrinsics classes
//===----------------------------------------------------------------------===//
class copy_to_shared_inst: public unary_inst{
private:
using unary_inst::unary_inst;
@@ -819,16 +1020,15 @@ public:
_TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
};
class recoalesce_inst: public unary_inst{
class cvt_layout_inst: public unary_inst {
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "recoalesce_inst"; }
std::string repr_impl() const { return "cvt_layout_inst"; }
public:
static recoalesce_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(recoalesce_inst)
_TRITON_DEFINE_ACCEPT(recoalesce_inst)
static cvt_layout_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(cvt_layout_inst)
_TRITON_DEFINE_ACCEPT(cvt_layout_inst)
};
class barrier_inst: public instruction{
@@ -864,11 +1064,11 @@ class prefetch_s_inst : public instruction {
std::string repr_impl() const { return "prefetch_s"; }
_TRITON_DEFINE_CLONE(prefetch_s_inst)
_TRITON_DEFINE_ACCEPT(prefetch_s_inst)
/// inc_: 0->first, 1->latch
int inc_ = 0;
public:
prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
: instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) {
set_operand(0, arg);
}
@@ -877,35 +1077,6 @@ public:
instruction *next=nullptr);
};
//// On NVIDIA, implementation is such that
//// constant_range = nv_dynamic_program_idx + nv_static_program_idx
//// so as to enable re-association on nv_static_program_idx which is constant
//class make_range_dyn: public instruction {
//private:
// make_range_dyn(type *ty, const std::string &name, instruction *next);
// std::string repr_impl() const { return "nv_dynamic_program_idx"; }
// _TRITON_DEFINE_CLONE(make_range_dyn)
// _TRITON_DEFINE_ACCEPT(make_range_dyn)
//public:
// static make_range_dyn* create(type *ty, const std::string &name = "", instruction *next = nullptr);
//};
//class make_range_sta: public constant {
//private:
// make_range_sta(make_range *range);
//public:
// static make_range_sta *get(make_range* range);
// make_range* get_range() const;
// std::string repr() const { return "nv_static_program_idx"; }
// _TRITON_DEFINE_ACCEPT(make_range_sta)
//private:
// make_range *range_;
//};
/* constant range */
class make_range: public instruction{
make_range(type *ty, constant_int* first, constant_int* last);
@@ -923,7 +1094,53 @@ private:
constant_int* last_;
};
/* timing utilities */
class clock_inst: public instruction{
clock_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "clock"; }
_TRITON_DEFINE_CLONE(clock_inst)
_TRITON_DEFINE_ACCEPT(clock_inst)
public:
static clock_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
};
class globaltimer_inst: public instruction{
globaltimer_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "globaltimer"; }
_TRITON_DEFINE_CLONE(globaltimer_inst)
_TRITON_DEFINE_ACCEPT(globaltimer_inst)
public:
static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
};
class extern_elementwise_inst : public instruction {
extern_elementwise_inst(context &ctx, const std::vector<value *> &args,
type *dst_ty, const std::string &lib_name,
const std::string &extern_lib_path,
const std::string &symbol_name,
const std::string &name, instruction *next);
std::string repr_impl() const { return "extern_elementwise"; }
_TRITON_DEFINE_CLONE(extern_elementwise_inst)
_TRITON_DEFINE_ACCEPT(extern_elementwise_inst)
public:
static extern_elementwise_inst *create(
context &ctx, const std::vector<value *> &args, type *dst_ty,
const std::string &lib_name = "", const std::string &lib_path = "",
const std::string &symbol_name = "", const std::string &name = "",
instruction *next = nullptr);
const std::string &get_lib_name() const { return lib_name_; }
const std::string &get_lib_path() const { return lib_path_; }
const std::string &get_symbol_name() const { return symbol_name_; }
private:
std::string lib_name_;
std::string lib_path_;
std::string symbol_name_;
};
}
}

View File

@@ -3,6 +3,8 @@
#ifndef _TRITON_IR_METADATA_H_
#define _TRITON_IR_METADATA_H_
#include <vector>
namespace triton{
namespace ir{
@@ -11,18 +13,19 @@ namespace ir{
class metadata{
public:
enum kind_t{
multiple_of
multiple_of,
max_contiguous
};
private:
metadata(kind_t kind, unsigned value);
metadata(kind_t kind, std::vector<unsigned> value);
public:
static metadata* get(kind_t kind, unsigned value);
static metadata* get(kind_t kind, std::vector<unsigned> value);
private:
kind_t kind_;
unsigned value_;
std::vector<unsigned> value_;
};
}

View File

@@ -34,50 +34,74 @@ class constant;
class global_value;
class alloc_const;
/* Module */
class module {
class value_constructor {
typedef std::pair<std::string, basic_block*> val_key_t;
friend class function;
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
public:
typedef std::map<std::string, global_value*> symbols_map_t;
typedef std::vector<function*> functions_list_t;
struct current_iteration_info_t{
lang::iteration_statement *statement;
basic_block *block;
};
private:
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
value *try_remove_trivial_phis(ir::phi_node *&phi);
value *add_phi_operands(const std::string& name, phi_node *&phi);
value *get_value_recursive(const std::string& name, basic_block *block);
void push_function(function *fn) { functions_.push_back(fn); }
public:
module(const std::string &name, builder& builder);
builder& get_builder();
// Setters
value_constructor(builder &builder);
void set_value(const std::string& name, basic_block* block, value *x);
void set_value(const std::string& name, value* x);
void set_const(const std::string& name);
void set_continue_fn(std::function<ir::value*()> fn);
// Getters
const std::map<val_key_t, value*>& get_values() { return values_; }
void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
value *get_value(const std::string& name, basic_block* block);
value *get_value(const std::string& name);
void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }
const std::string& get_name();
std::function<ir::value*()> get_continue_fn();
// Seal block -- no more predecessors will be added
void seal_block(basic_block *block);
// Metadata
private:
ir::builder& builder_;
std::map<val_key_t, value*> values_;
std::map<std::string, type*> types_;
std::set<basic_block*> sealed_blocks_;
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
std::map<value*, value**> current_phi_;
};
/* Module */
class module {
typedef std::pair<std::string, basic_block*> val_key_t;
typedef std::pair<ir::metadata::kind_t, std::vector<unsigned>> md_pair_t;
friend class function;
public:
typedef std::map<std::string, global_value*> symbols_map_t;
typedef std::vector<function*> functions_list_t;
private:
void push_function(function *fn) { functions_.push_back(fn); }
public:
module(const std::string &name, builder &builder): name_(name), builder_(builder) {}
builder &get_builder() { return builder_; };
const std::string& get_name() { return name_; };
// Functions
const functions_list_t &get_function_list() const { return functions_; }
functions_list_t &get_function_list() { return functions_; }
function *get_function(const std::string& name) {
if(symbols_.find(name) == symbols_.end())
throw std::runtime_error("function " + name + " is not declared");
return (function*)symbols_.at(name);
}
function *get_or_insert_function(const std::string &name, function_type *ty);
bool has_function(const std::string& name){
return symbols_.find(name) != symbols_.end();
}
void remove_function(ir::function* fn){
functions_.erase(std::remove(functions_.begin(), functions_.end(), fn), functions_.end());
}
void reset_ret_ty(const std::string& name, type* ty);
// Const allocation
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
@@ -85,20 +109,15 @@ public:
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
const std::map<std::string, ir::value*>& globals() const { return globals_; }
// Metadata
void print(std::ostream &os);
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
const std::map<std::string, md_pair_t> &get_metadatas() const { return metadatas_; }
private:
std::string name_;
builder& builder_;
std::map<val_key_t, value*> values_;
std::map<std::string, type*> types_;
std::set<std::string> const_;
std::set<basic_block*> sealed_blocks_;
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
builder &builder_;
functions_list_t functions_;
symbols_map_t symbols_;
std::function<ir::value*()> continue_fn_;
std::map<value*, value**> current_phi_;
std::vector<ir::alloc_const*> allocs_;
std::map<std::string, ir::value*> globals_;
std::map<std::string, md_pair_t> metadatas_;

View File

@@ -1,4 +1,4 @@
#pragma once
#pragma once
#ifndef _TRITON_IR_TYPE_H_
#define _TRITON_IR_TYPE_H_
@@ -6,6 +6,7 @@
#include <cassert>
#include <vector>
#include <string>
#include <stdexcept>
namespace triton{
namespace ir{
@@ -20,7 +21,6 @@ class type {
public:
typedef std::vector<unsigned> block_shapes_t;
protected:
typedef std::vector<type*> contained_tys_vec_t;
typedef contained_tys_vec_t::iterator ty_iterator;
typedef contained_tys_vec_t::const_iterator const_ty_iterator;
@@ -68,23 +68,24 @@ public:
type *get_tile_element_ty() const;
unsigned get_pointer_address_space() const;
type *get_pointer_element_ty() const;
unsigned get_struct_numel() const { return contained_tys_.size(); }
type *get_struct_type(unsigned int i) const { return contained_tys_[i]; }
// primitive predicates
bool is_void_ty() const { return id_ == VoidTyID; }
bool is_fp8_ty() const { return id_ == FP8TyID; }
bool is_fp16_ty() const { return id_ == FP16TyID; }
bool is_bf16_ty() const { return id_ == BF16TyID; }
bool is_fp32_ty() const { return id_ == FP32TyID; }
bool is_fp64_ty() const { return id_ == FP64TyID; }
bool is_fp32_ty() const { return id_ == FP32TyID; }
bool is_fp64_ty() const { return id_ == FP64TyID; }
bool is_label_ty() const { return id_ == LabelTyID;}
bool is_metadata_ty() const { return id_ == MetadataTyID; }
bool is_token_ty() const { return id_ == TokenTyID; }
bool is_integer_ty() const { return id_ == IntegerTyID; }
bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() &&
get_integer_bitwidth() == bitwidth;}
bool is_bool_ty() const { return is_integer_ty(1); }
bool is_pointer_ty() const { return id_ == PointerTyID; }
bool is_block_ty() const { return id_ == BlockTyID; }
bool is_struct_ty() const { return id_ == StructTyID; }
// Composite predicates
bool is_int_or_tileint_ty();
@@ -128,21 +129,21 @@ public:
switch(id_) {
case VoidTyID: return "void";
case FP8TyID: return "fp8";
case BF16TyID: return "bf16";
case FP16TyID: return "f16";
case FP32TyID: return "f32";
case FP64TyID: return "f64";
case LabelTyID: return "label";
case MetadataTyID: return "md";
case TokenTyID: return "tok";
case IntegerTyID: return "i" + std::to_string(get_integer_bitwidth());
case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth());
case FunctionTyID: return "fn";
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
case StructTyID: return "struct";
case BlockTyID: return tile_repr();
default: break;
}
assert(false);
return "";
throw std::logic_error("unknown type id '" + std::to_string(id_) + "'");
};
private:
@@ -159,7 +160,7 @@ class integer_type: public type {
private:
// constructors
integer_type(context &ctx, unsigned bitwidth)
: type(ctx, IntegerTyID), bitwidth_(bitwidth){ }
: type(ctx, IntegerTyID), bitwidth_(bitwidth) {}
public:
// accessors
@@ -181,6 +182,16 @@ public:
type* get_type_at_index(value *idx) const;
};
class struct_type: public composite_type {
public:
struct_type(const contained_tys_vec_t& tys, bool is_packed);
unsigned get_num_types() const { return contained_tys_.size(); }
static struct_type* get(const contained_tys_vec_t& tys, bool is_packed);
private:
bool is_packed_;
};
class block_type: public composite_type {
private:
block_type(type *ty, const block_shapes_t &shapes);
@@ -229,6 +240,7 @@ public:
ty_iterator params_end() { return contained_tys_.end(); }
type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); }
type* get_return_ty() const { return contained_tys_.at(0); }
void reset_ret_ty(type* ty) { contained_tys_[0] = ty;}
// factory methods
static function_type* get(type *ret_ty, const std::vector<type*>& param_tys);
};

View File

@@ -22,6 +22,7 @@ public:
};
void for_each_instruction(ir::module& mod, const std::function<void(triton::ir::instruction*)> &fn);
void for_each_instruction_backward(module &mod, const std::function<void (instruction *)> &do_work);
void for_each_value(ir::module& mod, const std::function<void(triton::ir::value *)> &fn);
}

View File

@@ -21,7 +21,7 @@ class visitor;
class value {
public:
typedef std::set<user*> users_t;
typedef std::vector<user*> users_t;
public:
// constructor
@@ -30,11 +30,12 @@ public:
// uses
void add_use(user* arg);
users_t::iterator erase_use(user* arg);
const std::set<user*> &get_users() { return users_; }
const std::vector<user*> &get_users() { return users_; }
void replace_all_uses_with(value *target);
// name
void set_name(const std::string &name);
const std::string &get_name() const { return name_; }
bool has_name() const { return !name_.empty(); }
type* get_type() const { return ty_; }
// visitor
virtual void accept(visitor *v) = 0;
@@ -70,6 +71,7 @@ public:
// Operands
const ops_t& ops() { return ops_; }
const ops_t& ops() const { return ops_; }
op_iterator op_begin() { return ops_.begin(); }
op_iterator op_end() { return ops_.end(); }
void set_operand(unsigned i, value *x);

View File

@@ -11,12 +11,16 @@ class value;
class instruction;
class call_inst;
class launch_inst;
class phi_node;
class binary_operator;
class getelementptr_inst;
class icmp_inst;
class fcmp_inst;
class dequantize_inst;
class cast_inst;
class trunc_inst;
class z_ext_inst;
@@ -42,12 +46,17 @@ class masked_load_inst;
class unmasked_store_inst;
class masked_store_inst;
class extract_value_inst;
class insert_value_inst;
class retile_inst;
class reshape_inst;
class splat_inst;
class cat_inst;
class broadcast_inst;
class downcast_inst;
class umulhi_inst;
class exp_inst;
class cos_inst;
class sin_inst;
@@ -57,7 +66,6 @@ class get_program_id_inst;
class get_num_programs_inst;
class atomic_inst;
class atomic_cas_inst;
class atomic_exch_inst;
class atomic_rmw_inst;
class dot_inst;
class trans_inst;
@@ -65,7 +73,7 @@ class sqrt_inst;
class reduce_inst;
class select_inst;
class recoalesce_inst;
class cvt_layout_inst;
class copy_to_shared_inst;
class copy_from_shared_inst;
class masked_load_async_inst;
@@ -74,6 +82,10 @@ class async_wait_inst;
class make_range_dyn;
class make_range;
class prefetch_s_inst;
class clock_inst;
class globaltimer_inst;
class extern_elementwise_inst;
class make_range_sta;
class undef_value;
@@ -102,6 +114,8 @@ public:
virtual ~visitor() {}
virtual void visit_value(ir::value*);
virtual void visit_call_inst(ir::call_inst*) = 0;
virtual void visit_launch_inst(ir::launch_inst*) = 0;
virtual void visit_basic_block(basic_block*) = 0;
virtual void visit_argument(argument*) = 0;
@@ -111,6 +125,7 @@ public:
virtual void visit_icmp_inst(icmp_inst*) = 0;
virtual void visit_fcmp_inst(fcmp_inst*) = 0;
virtual void visit_dequantize_inst(dequantize_inst*) = 0;
virtual void visit_cast_inst(cast_inst*) = 0;
virtual void visit_return_inst(return_inst*) = 0;
@@ -123,20 +138,24 @@ public:
virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0;
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
virtual void visit_umulhi_inst(umulhi_inst*) = 0;
virtual void visit_exp_inst(exp_inst*) = 0;
virtual void visit_cos_inst(cos_inst*) = 0;
virtual void visit_sin_inst(sin_inst*) = 0;
virtual void visit_log_inst(log_inst*) = 0;
virtual void visit_extract_value_inst(extract_value_inst*) = 0;
virtual void visit_insert_value_inst(insert_value_inst*) = 0;
virtual void visit_reshape_inst(reshape_inst*) = 0;
virtual void visit_splat_inst(splat_inst*) = 0;
virtual void visit_cat_inst(cat_inst*) = 0;
virtual void visit_broadcast_inst(broadcast_inst*) = 0;
virtual void visit_downcast_inst(downcast_inst*) = 0;
virtual void visit_get_program_id_inst(get_program_id_inst*) = 0;
virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0;
virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0;
virtual void visit_atomic_exch_inst(atomic_exch_inst*) = 0;
virtual void visit_atomic_rmw_inst(atomic_rmw_inst*) = 0;
virtual void visit_dot_inst(dot_inst*) = 0;
virtual void visit_trans_inst(trans_inst*) = 0;
@@ -144,23 +163,26 @@ public:
virtual void visit_reduce_inst(reduce_inst*) = 0;
virtual void visit_select_inst(select_inst*) = 0;
virtual void visit_recoalesce_inst(recoalesce_inst*) = 0;
virtual void visit_cvt_layout_inst(cvt_layout_inst*) = 0;
virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0;
virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0;
virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0;
virtual void visit_barrier_inst(barrier_inst*) = 0;
virtual void visit_async_wait_inst(async_wait_inst*) = 0;
// virtual void visit_make_range_dyn(make_range_dyn*) = 0;
virtual void visit_make_range(make_range*) = 0;
virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0;
virtual void visit_function(function*) = 0;
virtual void visit_clock_inst(clock_inst*) = 0;
virtual void visit_globaltimer_inst(globaltimer_inst*) = 0;
// virtual void visit_make_range_sta(make_range_sta*) = 0;
virtual void visit_undef_value(undef_value*) = 0;
virtual void visit_constant_int(constant_int*) = 0;
virtual void visit_constant_fp(constant_fp*) = 0;
virtual void visit_alloc_const(alloc_const*) = 0;
virtual void visit_extern_elementwise_inst(extern_elementwise_inst*) = 0;
};
}

View File

@@ -3,30 +3,32 @@
#ifndef _TRITON_TOOLS_THREAD_GRAPH_H_
#define _TRITON_TOOLS_THREAD_GRAPH_H_
#include "llvm/ADT/SetVector.h"
#include <map>
#include <set>
#include <vector>
#include <iostream>
namespace triton {
namespace tools{
template<class node_t>
class graph {
typedef std::map<node_t, std::set<node_t>> edges_t;
typedef std::map<node_t, llvm::SetVector<node_t>> edges_t;
public:
typedef std::map<size_t, std::vector<node_t>> cmap_t;
typedef std::map<node_t, size_t> nmap_t;
private:
void connected_components_impl(node_t x, std::set<node_t> &nodes,
void connected_components_impl(node_t x, llvm::SetVector<node_t> &nodes,
nmap_t* nmap, cmap_t* cmap, int id) const {
if(nmap)
(*nmap)[x] = id;
if(cmap)
(*cmap)[id].push_back(x);
if(nodes.find(x) != nodes.end()) {
nodes.erase(x);
if (nodes.count(x)) {
nodes.remove(x);
for(const node_t &y: edges_.at(x))
connected_components_impl(y, nodes, nmap, cmap, id);
}
@@ -38,10 +40,11 @@ public:
cmap->clear();
if(nmap)
nmap->clear();
std::set<node_t> nodes = nodes_;
llvm::SetVector<node_t> nodes = nodes_;
unsigned id = 0;
while(!nodes.empty())
while(!nodes.empty()){
connected_components_impl(*nodes.begin(), nodes, nmap, cmap, id++);
}
}
void add_edge(node_t x, node_t y) {
@@ -57,7 +60,7 @@ public:
}
private:
std::set<node_t> nodes_;
llvm::SetVector<node_t> nodes_;
edges_t edges_;
};

View File

@@ -0,0 +1,46 @@
#ifndef TRITON_TOOLS_SYS_EXEC_HPP
#define TRITON_TOOLS_SYS_EXEC_HPP
#include <cstdio>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
namespace triton
{
namespace tools
{
#ifdef _WIN32
#define popen _popen
#define pclose _pclose
#endif
#ifndef WEXITSTATUS
#define WEXITSTATUS(stat_val) ((unsigned)(stat_val) & 255)
#endif
int exec(const std::string& cmd, std::string& result) {
char buffer[128];
FILE* pipe = popen(cmd.c_str(), "r");
if (!pipe)
return 0;
result.clear();
try {
while (fgets(buffer, sizeof buffer, pipe) != NULL)
result += buffer;
} catch (...) {
pclose(pipe);
return 0;
}
int status = pclose(pipe);
return WEXITSTATUS(status);
}
}
}
#endif

View File

@@ -33,19 +33,10 @@ namespace tools
inline std::string getenv(const char * name)
{
#ifdef _MSC_VER
char* cache_path = 0;
std::size_t sz = 0;
_dupenv_s(&cache_path, &sz, name);
#else
const char * cstr = std::getenv(name);
#endif
const char * cstr = std::getenv(name);
if(!cstr)
return "";
std::string result(cstr);
#ifdef _MSC_VER
free(cache_path);
#endif
return result;
}

View File

@@ -115,6 +115,18 @@ std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_ins
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_dequantize(ir::dequantize_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_block_shapes();
auto op_cst = populate_is_constant(op);
for(size_t d = 0; d < x_shapes.size(); d++) {
result.push_back(op_cst[d]);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
@@ -129,6 +141,36 @@ std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_cmp(ir::cmp_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
auto lhs = populate_is_constant(lhs_op);
auto rhs = populate_is_constant(rhs_op);
auto lhs_max_contiguous = populate_max_contiguous(lhs_op);
auto rhs_max_contiguous = populate_max_contiguous(rhs_op);
auto lhs_multiple_of = populate_starting_multiple(lhs_op);
auto rhs_multiple_of = populate_starting_multiple(rhs_op);
for(size_t d = 0; d < x_shapes.size(); d++) {
cst_info ax = {1, 0};
// Examples:
// 16 17 18 ... 32 < 24 24 24 ... 24 => equal in groups of 8
// 16 17 18 ... 32 < 20 20 20 ... 20 => equal in groups of 4
// 16 17 18 ... 32 < 16 16 16 ... 16 => equal in groups of 16
//
// if LHS is a range of N continuous (or equal) elements that starts at M,
// and RHS is a set of N constants that start at K
// then the result in constant in groups of gcd(M, K)
if(rhs[d].num_cst % lhs_max_contiguous[d] == 0 ||
rhs[d].num_cst % lhs[d].num_cst == 0)
ax.num_cst = gcd(lhs_multiple_of[d], rhs_multiple_of[d]);
result.push_back(ax);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
@@ -136,12 +178,14 @@ std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operat
ir::value* rhs_op = x->get_operand(1);
auto lhs = populate_is_constant(lhs_op);
auto rhs = populate_is_constant(rhs_op);
auto max_contiguous = populate_max_contiguous(lhs_op);
auto lhs_max_contiguous = populate_max_contiguous(lhs_op);
auto rhs_max_contiguous = populate_max_contiguous(rhs_op);
auto lhs_multiple_of = populate_starting_multiple(lhs_op);
auto rhs_multiple_of = populate_starting_multiple(rhs_op);
for(size_t d = 0; d < x_shapes.size(); d++) {
cst_info ax;
if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
// todo might not be entirely true
unsigned num_constants = gcd(max_contiguous[d], rhs[d].value);
unsigned num_constants = gcd(lhs_max_contiguous[d], rhs[d].value);
ax = {num_constants, 0};
}
else
@@ -180,10 +224,14 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
return populate_is_constant_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_is_constant_reshape(x);
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
return populate_is_constant_dequantize(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_is_constant_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_is_constant_binop(x);
if(auto *x = dynamic_cast<ir::cmp_inst*>(v))
return populate_is_constant_cmp(x);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_is_constant_gep(x);
return populate_is_constant_default(v);
@@ -245,6 +293,23 @@ std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_dequantize(ir::dequantize_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
auto op_last_dim = (op->get_type()->get_block_shapes()).back();
auto op_mc = populate_max_contiguous(op);
for(size_t d = 0; d < shapes.size(); d++) {
unsigned factor = 1;
if (d == shapes.size() - 1) {
factor = ret_last_dim / op_last_dim;
}
result.push_back(factor * op_mc[d]);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
@@ -285,8 +350,8 @@ std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator*
}
if(x->is_int_add_sub()){
unsigned lvalue = 1, rvalue = 1;
lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]);
rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]);
lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
value = std::max(lvalue, rvalue);
}
result.push_back(value);
@@ -331,12 +396,19 @@ std::vector<unsigned> align::populate_max_contiguous_cast(ir::cast_inst* v){
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
if(max_contiguous_.find(v) != max_contiguous_.end())
return max_contiguous_.at(v);
if(auto *x = dynamic_cast<ir::instruction*>(v)){
std::vector<unsigned> max_contiguous = x->get_metadata(ir::metadata::max_contiguous);
if(!max_contiguous.empty())
return add_to_cache(x, max_contiguous, max_contiguous_);
}
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
return populate_max_contiguous_cast(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_max_contiguous_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_max_contiguous_reshape(x);
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
return populate_max_contiguous_dequantize(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_max_contiguous_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
@@ -381,6 +453,23 @@ std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_dequantize(ir::dequantize_inst* x){
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
auto op_last_dim = (op->get_type()->get_block_shapes()).back();
auto op_multiple = populate_starting_multiple(op);
for(size_t d = 0; d < shapes.size(); d++) {
unsigned factor = 1;
if (d == shapes.size() - 1) {
factor = ret_last_dim / op_last_dim;
}
result.push_back(factor * op_multiple[d]);
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
auto result = populate_starting_multiple(x->get_operand(0));
return add_to_cache(x, result, starting_multiple_);
@@ -396,7 +485,7 @@ std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operato
if(x->is_int_add_sub())
result[d] = gcd(lhs[d], rhs[d]);
if(x->is_int_div())
result[d] = std::max<unsigned>(lhs[d] / rhs[d], 1);
result[d] = (lhs[d] == (1 << 31)) ? 1 << 31 : 1;
if(x->is_int_rem() && rhs[d] > 1){
result[d] = gcd(lhs[d], rhs[d]);
}
@@ -466,28 +555,42 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
return add_to_cache(v, {1}, starting_multiple_);
}
unsigned get_max_multiple(int val){
if(val == 0) return 1 << 31;
if(val % 128 == 0) return 128;
if(val % 64 == 0) return 64;
if(val % 32 == 0) return 32;
if(val % 16 == 0) return 16;
if(val % 8 == 0) return 8;
if(val % 4 == 0) return 4;
if(val % 2 == 0) return 2;
return 1;
}
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
if(starting_multiple_.find(v) != starting_multiple_.end())
return starting_multiple_.at(v);
if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
if(multiple_of > 0)
return add_to_cache(x, {multiple_of}, starting_multiple_);
std::vector<unsigned> multiple_of = x->get_metadata(ir::metadata::multiple_of);
if(!multiple_of.empty())
return add_to_cache(x, multiple_of, starting_multiple_);
}
if(auto *x = dynamic_cast<ir::cast_inst*>(v))
return populate_starting_multiple_cast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_starting_multiple_binop(x);
if(auto *x = dynamic_cast<ir::constant_int*>(v))
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
return add_to_cache(x, {get_max_multiple(x->get_value())}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range*>(v))
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
return add_to_cache(x, {get_max_multiple(x->get_first()->get_value())}, starting_multiple_);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_starting_multiple_gep(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_starting_multiple_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_starting_multiple_reshape(x);
if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
return populate_starting_multiple_dequantize(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_starting_multiple_broadcast(x);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
@@ -506,12 +609,15 @@ std::vector<unsigned> align::contiguous(ir::value* v) const {
return max_contiguous_.at(v);
}
std::vector<align::cst_info> align::get_cst_info(ir::value* v) const {
return is_constant_.at(v);
}
void align::populate(ir::value *v) {
populate_is_constant(v);
populate_starting_multiple(v);
populate_max_contiguous(v);
}
void align::run(ir::module &mod) {

View File

@@ -50,7 +50,6 @@ void allocation::run(ir::module &mod) {
J.erase(j_it);
}
}
// Build interference graph
std::map<shared_layout*, std::set<shared_layout*>> interferences;
for(shared_layout* x: V)
@@ -66,13 +65,10 @@ void allocation::run(ir::module &mod) {
&& XS.intersect(YS))
interferences[x].insert(y);
}
// Initialize colors
std::map<shared_layout*, int> colors;
for(shared_layout* X: V)
colors[X] = (X==V[0])?0:-1;
// First-fit graph coloring
std::vector<bool> available(V.size());
for(shared_layout* x: V){
@@ -87,7 +83,6 @@ void allocation::run(ir::module &mod) {
auto It = std::find(available.begin(), available.end(), true);
colors[x] = std::distance(available.begin(), It);
}
// Finalize allocation
for(shared_layout* x: V){
unsigned Adj = 0;
@@ -95,11 +90,12 @@ void allocation::run(ir::module &mod) {
Adj = std::max<unsigned>(Adj, starts[y] + y->get_size());
offsets_[x] = starts[x] + colors[x] * Adj;
}
// Save maximum size of induced memory space
allocated_size_ = 0;
for(shared_layout* x: V)
for(shared_layout* x: V){
allocated_size_ = std::max<size_t>(allocated_size_, starts[x] + x->get_size());
// std::cout << "start: " << starts[x] << " | end: " << starts[x] + x->get_size() << std::endl;
}
}
}

View File

@@ -56,6 +56,17 @@ void axes::update_graph_trans(ir::instruction *i) {
graph_.add_edge({i, perm[d]}, {op, d});
}
void axes::update_graph_dequantize(ir::instruction *i) {
auto *dequantize = static_cast<ir::dequantize_inst*>(i);
auto shapes = dequantize->get_type()->get_block_shapes();
ir::value *op = dequantize->get_operand(0);
// add edge except the last axis
for(unsigned d = 0; d < shapes.size() - 1; d ++){
graph_.add_edge({i, d}, {op, d});
}
}
void axes::update_graph_broadcast(ir::instruction *i) {
auto *broadcast = static_cast<ir::broadcast_inst*>(i);
auto shapes = broadcast->get_type()->get_block_shapes();
@@ -79,19 +90,28 @@ void axes::update_graph_dot(ir::instruction *i) {
graph_.add_edge({dot, d}, {D, d});
}
void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
void axes::update_graph_elementwise(ir::instruction *i,
bool is_masked_load_async) {
if(i->get_num_operands() == 0)
return;
ir::value *op = i->get_operand(0);
if(!op->get_type()->is_block_ty())
return;
auto rank = op->get_type()->get_tile_rank();
for(unsigned d = 0; d < rank; d++)
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
if(connect_ret && !i->get_type()->is_void_ty())
graph_.add_edge({i, d}, {opx, d});
graph_.add_edge({opx, d}, {opy, d});
for(unsigned d = 0; d < rank; d++) {
// If we are dealing with a masked async load we need to attach the
// dimensions so we match the behaviour of the copy_to_shared instruction
// which async masked load replaces.
if (is_masked_load_async) {
graph_.add_edge({i, d}, {i, d});
}
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()) {
if(!is_masked_load_async && !i->get_type()->is_void_ty())
graph_.add_edge({i, d}, {opx, d});
graph_.add_edge({opx, d}, {opy, d});
}
}
}
@@ -105,17 +125,19 @@ void axes::update_graph_no_edge(ir::instruction *i) {
void axes::update_graph(ir::instruction *i) {
switch (i->get_id()) {
case ir::INST_REDUCE: return update_graph_reduce(i);
case ir::INST_RESHAPE: return update_graph_reshape(i);
case ir::INST_SPLAT: return update_graph_no_edge(i);;
case ir::INST_TRANS: return update_graph_trans(i);
case ir::INST_BROADCAST: return update_graph_broadcast(i);
case ir::INST_DOT: return update_graph_dot(i);
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
case ir::INST_MASKED_LOAD_ASYNC:return update_graph_elementwise(i, false);
case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
case ir::INST_RECOALESCE: return update_graph_no_edge(i);
default: return update_graph_elementwise(i);
case ir::INST_REDUCE: return update_graph_reduce(i);
case ir::INST_RESHAPE: return update_graph_reshape(i);
case ir::INST_SPLAT: return update_graph_no_edge(i);
case ir::INST_CAT: return update_graph_elementwise(i, true);
case ir::INST_TRANS: return update_graph_trans(i);
case ir::INST_DEQUANTIZE: return update_graph_dequantize(i);
case ir::INST_BROADCAST: return update_graph_broadcast(i);
case ir::INST_DOT: return update_graph_dot(i);
case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, true);
case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
case ir::INST_CVT_LAYOUT: return update_graph_no_edge(i);
default: return update_graph_elementwise(i);
}
return;
}
@@ -135,11 +157,15 @@ std::vector<int> axes::get(ir::value *value) {
void axes::run(ir::module &mod) {
// make graph
graph_.clear();
axes_.clear();
ir::for_each_instruction(mod, [this](ir::instruction *x) {
update_graph(x);
});
// find connected components
graph_.connected_components(nullptr, &axes_);
std::set<size_t> uniq;
for(auto x: axes_)
uniq.insert(x.second);
}
}

View File

@@ -23,19 +23,67 @@ inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
return std::min(std::max(x, lo), hi);
}
inline bool is_hmma_c(ir::value *v){
inline bool is_hmma_c(ir::value *v, int sm){
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
ir::type *a_ty = a->get_type();
ir::value *b = x->get_operand(1);
ir::type *b_ty = b->get_type();
result = a_ty->get_scalar_ty()->is_fp16_ty() &&
b_ty->get_scalar_ty()->is_fp16_ty();
result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
(a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
(a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
x->allow_tf32() && sm >= 80) ||
(a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) &&
sm >= 80);
}
return result;
}
static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
mma_layout::TensorCoreType mma_type;
if (auto* dot = dynamic_cast<ir::dot_inst*>(v)) {
ir::value* a = dot->get_operand(0);
ir::value* b = dot->get_operand(1);
ir::type* a_ty = a->get_type();
ir::type* b_ty = b->get_type();
ir::type* c_ty = v->get_type();
if (c_ty->get_scalar_ty()->is_fp32_ty()) {
// floating point tensor cores
if (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) {
mma_type = mma_layout::FP32_FP16_FP16_FP32;
return mma_type;
}
if (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) {
mma_type = mma_layout::FP32_BF16_BF16_FP32;
return mma_type;
}
if (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty()
&& dot->allow_tf32()) {
mma_type = mma_layout::FP32_TF32_TF32_FP32;
return mma_type;
}
} else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
// throw std::runtime_error("integer tensor cores are not yet supported");
// // integer tensor cores
// if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
// mma_type = mma_layout::INT32_INT1_INT1_INT32;
// return mma_type;
// }
// if (a_ty->get_scalar_ty()->is_integer_ty(4) && b_ty->get_scalar_ty()->is_integer_ty(4)) {
// mma_type = mma_layout::INT32_INT4_INT4_INT32;
// return mma_type;
// }
if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
mma_type = mma_layout::INT32_INT8_INT8_INT32;
return mma_type;
}
}
}
return mma_layout::NOT_APPLICABLE;
}
inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u);
@@ -52,11 +100,12 @@ inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
}
}
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n, int sm) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && is_hmma_c(i) && i->get_operand(n) == v)
if(i && is_hmma_c(i, sm) && i->get_operand(n) == v) {
result = i;
}
}
}
@@ -109,9 +158,6 @@ data_layout::data_layout(id_t id,
max_contiguous = curr;
}
}
bool is_recoalesce = false;
for(ir::value* v: values)
is_recoalesce = is_recoalesce || dynamic_cast<ir::recoalesce_inst*>(v);
if(max_contiguous.size() > 0){
std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
return max_contiguous[a] > max_contiguous[b];
@@ -129,6 +175,13 @@ int data_layout::find_axis(int to_find) const {
}
distributed_layout::distributed_layout(id_t id,
const std::vector<int> &axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align): data_layout(id, axes, shape, values, align)
{ }
/* -------------------------------- *
* MMA Layout *
* -------------------------------- */
@@ -138,20 +191,13 @@ mma_layout::mma_layout(size_t num_warps,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align, target* tgt,
shared_layout *layout_a, shared_layout *layout_b): data_layout(MMA, axes, shape, values, align) {
shared_layout *layout_a, shared_layout *layout_b,
ir::value *dot): distributed_layout(MMA, axes, shape, values, align) {
tensor_core_type_ = get_mma_type(dot);
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
if(tgt->as_nvidia()->sm() < 80){
fpw_ = {2, 2, 1};
// std::vector<int> fpw_nm1;
// unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
// do {
// fpw_nm1 = fpw_;
// if(fpw_[0]*fpw_[1] < num_fragments)
// fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
// if(fpw_[0]*fpw_[1] < num_fragments)
// fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
// }while(fpw_nm1 != fpw_);
auto ord_a = layout_a->get_order();
auto ord_b = layout_b->get_order();
bool is_a_row = ord_a[0] != 0;
@@ -162,27 +208,70 @@ mma_layout::mma_layout(size_t num_warps,
int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
rep_ = {2*pack_size_0, 2*pack_size_1, 1};
spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
contig_per_thread_ = {1, 1};
order_ = {0, 1};
}
else{
fpw_ = {1, 1, 1};
spw_ = {16, 8, 1};
rep_ = {2, 2, 1};
spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
contig_per_thread_ = {1, 2};
order_ = {1, 0};
}
/* warps per tile */
// try to make things as square as possible to maximize data re-use
wpt_ = {1, 1, 1};
std::vector<int> wpt_nm1;
do{
wpt_nm1 = wpt_;
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
}while(wpt_nm1 != wpt_);
// try to make warp-level tiles as square as possible to maximize data re-use
if (tgt->as_nvidia()->sm() < 80) {
std::vector<int> wpt_nm1;
do{
wpt_nm1 = wpt_;
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
}while(wpt_nm1 != wpt_);
} else {
bool changed = false;
// try to have a warp own entire rows of the output
// this makes it easier to fuse multiple mmas by fusing
// registers
bool one_warp_per_row = false;
for(ir::value* v: values)
for(ir::user* u: v->get_users()){
auto* dot = dynamic_cast<ir::dot_inst*>(u);
auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(u);
if((dot && dot->get_operand(2)!=v) || !layout_a->to_shared() || cts)
one_warp_per_row = shape[0] / spw_[0] >= num_warps;
}
// std::cout << one_warp_per_row << std::endl;
if(one_warp_per_row){
wpt_[1] = 1;
wpt_[0] = num_warps;
}
else{
do {
changed = false;
if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
break;
if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
if (wpt_[0] < shape_[0] / spw_[0]) {
wpt_[0] *= 2;
changed = true;
}
} else {
if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
wpt_[1] *= 2;
changed = true;
}
}
} while(changed);
}
}
// std::cout << wpt_[0] << " " << wpt_[1] << std::endl;
/* shape per block */
spt_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
}
@@ -194,7 +283,7 @@ scanline_layout::scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align, target *tgt): data_layout(SCANLINE, axes, shape, values, align){
analysis::align* align, target *tgt): distributed_layout(SCANLINE, axes, shape, values, align){
unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
unsigned num_threads = tgt->is_gpu() ? num_warps * 32 : 1;
nts_.resize(shape_.size());
@@ -202,19 +291,19 @@ scanline_layout::scanline_layout(size_t num_warps,
bool is_dot = std::any_of(values.begin(), values.end(),
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
ir::value *ptr = nullptr;
std::vector<ir::value*> ptrs;
for(ir::value *v: values)
for(ir::user *usr: v->get_users())
if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
if(!ptr || ptr->get_type()->get_tile_rank() < io->get_pointer_operand()->get_type()->get_tile_rank())
ptr = io->get_pointer_operand();
}
for(ir::user *usr: v->get_users())
if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
if(ptrs.empty() || ptrs[0]->get_type()->get_tile_rank() <= io->get_pointer_operand()->get_type()->get_tile_rank())
ptrs.push_back(io->get_pointer_operand());
}
unsigned i = order_[0];
int contiguous = 1;
if(ptr){
for(ir::value* ptr: ptrs){
int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
contiguous = std::min<int>(align->get(ptr, i), 128 / nbits);
contiguous = std::max<int>(contiguous, std::min<int>(align->get(ptr, i), 128 / nbits));
}
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
@@ -230,6 +319,10 @@ scanline_layout::scanline_layout(size_t num_warps,
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
num_threads = num_threads / mts_[i];
}
shape_per_cta_.resize(shape_.size());
for(size_t d = 0; d < shape_.size(); d++)
shape_per_cta_[d] = mts_[d]*nts_[d];
}
@@ -274,12 +367,16 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
res.reset(new double_buffer_info_t{value_1, value_0, phi});
}
static bool is_smem(ir::value* v) {
if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
dynamic_cast<ir::masked_load_async_inst*>(v))
return true;
else
return false;
static bool is_smem_in(ir::value* v, const ir::basic_block* bb) {
if (ir::instruction *instr = dynamic_cast<ir::instruction*>(v)) {
if (instr->get_parent() != bb)
return false;
if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
dynamic_cast<ir::masked_load_async_inst*>(v)) {
return true;
}
}
return false;
}
/// param:
@@ -294,14 +391,14 @@ static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::
ir::basic_block *cbb0 = cphi->get_incoming_block(0);
ir::basic_block *cbb1 = cphi->get_incoming_block(1);
if (is_smem(c0)) {
if (is_smem_in(c0, cbb0)) {
assert(cbb0 == bb0);
values_0.push_back(c0);
if (auto phi1 = dynamic_cast<ir::phi_node*>(c1)) {
next = phi1;
continue;
} else {
if (is_smem(c1)) {
if (is_smem_in(c1, cbb1)) {
value_1 = c1;
assert(cbb1 == bb1);
return true;
@@ -356,7 +453,8 @@ shared_layout::shared_layout(data_layout *arg,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
ir::type *ty,
analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
analysis::align* align, target *tgt, bool is_tmp)
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt), is_tmp_(is_tmp){
size_ = 0;
arg_layout_ = arg;
@@ -382,12 +480,35 @@ shared_layout::shared_layout(data_layout *arg,
for(ir::value* v: values){
extract_dot_use(v, dot_a, 0);
extract_dot_use(v, dot_b, 1);
extract_hmma_dot_use(v, hmma_dot_a, 0);
extract_hmma_dot_use(v, hmma_dot_b, 1);
extract_hmma_dot_use(v, hmma_dot_a, /*op*/0, tgt_->as_nvidia()->sm());
extract_hmma_dot_use(v, hmma_dot_b, /*op*/1, tgt_->as_nvidia()->sm());
}
hmma_dot_a_ = hmma_dot_a;
hmma_dot_b_ = hmma_dot_b;
// Update mma_vec
if (hmma_dot_a_) {
assert(order_.size() == 2);
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];
// for now, disable swizzle when using lds.8
if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32)
if (order_[0] == 0) // need transpose
allow_swizzle_ = false;
} else if (hmma_dot_b_) {
assert(order_.size() == 2);
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];
// for now, disable swizzle when using lds.8
if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32)
if (order_[0] == 1) // need transpose
allow_swizzle_ = false;
}
// size
size_ = ty_->get_primitive_size_in_bits() / 8;
for(auto s: shape_)
@@ -451,7 +572,8 @@ void layouts::make_graph(ir::instruction *i) {
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
// if(layouts_.find(id) != layouts_.end())
// return;
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
auto it_hmma_c = std::find_if(values.begin(), values.end(),
[&](ir::value* v){ return is_hmma_c(v, tgt_->as_nvidia()->sm()); });
auto cmp = [](ir::value* x, ir::value *y) {
std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
@@ -473,26 +595,72 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
ir::value *b = dot->get_operand(1);
create(groups_.at(a), values_.at(groups_.at(a)));
create(groups_.at(b), values_.at(groups_.at(b)));
layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b)));
layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_,
(shared_layout*)layouts_.at(groups_.at(a)),
(shared_layout*)layouts_.at(groups_.at(b)),
dot);
}
else if(it_cts != values.end()){
ir::instruction *cts = (ir::instruction*)*it_cts;
ir::value *arg = cts->get_operand(0);
create(groups_.at(arg), values_.at(groups_.at(arg)));
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_, tgt_);
}
else{
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
}
}
// layout checkers
bool layouts::is_scanline(ir::instruction *i) {
return this->get(i->get_operand(0))->to_scanline() != nullptr;
}
bool layouts::is_coalesced_scanline(ir::instruction *i) {
if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
auto *scanline = this->get(i->get_operand(0))->to_scanline();
return scanline && scanline->get_order()[0] == red->get_axis();
}
return false;
}
bool layouts::is_mma(ir::instruction *i) {
return this->get(i->get_operand(0))->to_mma() != nullptr;
}
bool layouts::is_a100_mma(ir::instruction *i) {
if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
return is_mma(red) && (tgt_->as_nvidia()->sm() >= 80) &&
(red->get_axis() == 1);
}
return false;
}
void layouts::create_tmp_layout(size_t id, data_layout *arg,
const std::vector<int> &axes,
const std::vector<unsigned> &shape,
ir::instruction *i, bool is_index) {
ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context())
: i->get_type()->get_scalar_ty();
layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_, true);
if (is_index) {
tmp_index_[i] = id;
} else {
tmp_[i] = id;
}
}
void layouts::run(ir::module &mod) {
// make graph
graph_.clear();
layouts_.clear();
groups_.clear();
ir::for_each_instruction(mod, [this](ir::instruction* i) {
make_graph(i);
});
// connected components
graph_.connected_components(&values_, &groups_);
@@ -503,42 +671,50 @@ void layouts::run(ir::module &mod) {
// create temporaries
size_t id = values_.size();
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
// std::cout << "layout: " << std::endl;
// i->print(std::cout);
if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
id++;
ir::value *arg = red->get_operand(0);
unsigned axis = red->get_axis();
distributed_layout *layout =
dynamic_cast<analysis::distributed_layout *>(get(arg));
// shape
auto shapes = arg->get_type()->get_block_shapes();
scanline_layout *layout = get(arg)->to_scanline();
shapes[axis] = layout->mts(axis);
unsigned axis = red->get_axis();
shapes[axis] =
layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
// create layout
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
tmp_[red] = id;
}
if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
ir::value *val = recoalasce->get_operand(0);
mma_layout* in_layout = get(val)->to_mma();
scanline_layout* out_layout = get(i)->to_scanline();
if(!in_layout || !out_layout)
return;
id++;
ir::type::block_shapes_t in_shape = val->get_type()->get_block_shapes();
ir::type::block_shapes_t shape(in_shape.size());
size_t ld = out_layout->get_order(0);
shape[ld] = in_shape[ld];
for(size_t k = 0; k < in_shape.size(); k++)
if(k != ld)
shape[k] = in_layout->to_mma()->spt(k);
// create layout
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
tmp_[recoalasce] = id;
create_tmp_layout(id, layout, axes_->get(arg), shapes, red);
if (red->with_index()) {
id++;
create_tmp_layout(id, layout, axes_->get(arg), shapes, red, true);
}
}
if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
size_t dim = val->get_type()->get_tile_rank();
ir::type::block_shapes_t shape(dim);
for(size_t k = 0; k < dim; k++){
shape[k] = std::max(in_layout->shape_per_cta(k),
out_layout->shape_per_cta(k));
}
auto in_ord = in_layout->get_order();
auto out_ord = out_layout->get_order();
int in_vec = in_layout->contig_per_thread(in_ord[0]);
int out_vec = out_layout->contig_per_thread(out_ord[0]);
int pad = std::max(in_vec, out_vec);
shape[out_ord[0]] += pad;
id++;
create_tmp_layout(id, out_layout, axes_->get(val), shape, val);
}
if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
id++;
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
tmp_[atom] = id;
create_tmp_layout(id, nullptr, {}, {1}, atom);
}
});
}
}

View File

@@ -14,43 +14,108 @@ namespace analysis{
void liveness::run(ir::module &mod) {
intervals_.clear();
// Assigns index to each instruction
std::map<ir::value*, slot_index> indices;
for(ir::function *fn: mod.get_function_list()){
slot_index index = 0;
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *instr: block->get_inst_list()){
index += 1;
indices.insert({instr, index});
std::map<ir::value*, std::set<shared_layout*>> layouts_map;
for(auto &x: layouts_->get_all()){
shared_layout* layout = x.second->to_shared();
if(!layout || layout->is_tmp())
continue;
for(ir::value* v:layout->get_values()){
layouts_map[v].insert(layout);
}
}
// create live intervals
std::map<ir::user*, std::set<shared_layout*>> live_in;
while(true){
bool changed = false;
ir::instruction* last_inst = nullptr;
ir::for_each_instruction_backward(mod, [&](ir::instruction* i){
// gen
std::set<shared_layout*> gen;
for(ir::value* v: i->ops())
for(shared_layout* layout: layouts_map[v])
gen.insert(layout);
// kill
std::set<shared_layout*> kill;
for(shared_layout* layout: layouts_map[i])
kill.insert(layout);
// temporaries are handled separately
if(layouts_->has_tmp(i)){
gen.insert(layouts_->get(layouts_->tmp(i))->to_shared());
kill.insert(layouts_->get(layouts_->tmp(i))->to_shared());
}
if(layouts_->has_tmp_index(i)){
gen.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
kill.insert(layouts_->get(layouts_->tmp_index(i))->to_shared());
}
// live-out
std::set<shared_layout*> live_out;
std::vector<ir::instruction*> succs = {last_inst};
if(i == i->get_parent()->get_inst_list().back())
for(ir::basic_block* succ: i->get_parent()->get_successors())
succs.push_back(succ->get_inst_list().front());
for(ir::instruction* succ: succs)
for(shared_layout* layout: live_in[succ])
if(!layout->is_tmp())
live_out.insert(layout);
// new sets
std::set<shared_layout*> live_out_minus_kill;
std::set_difference(live_out.begin(), live_out.end(), kill.begin(), kill.end(),
std::inserter(live_out_minus_kill, live_out_minus_kill.end()));
std::set<shared_layout*> new_live_in;
std::set_union(gen.begin(), gen.end(), live_out_minus_kill.begin(), live_out_minus_kill.end(),
std::inserter(new_live_in, new_live_in.end()));
changed = changed || (new_live_in != live_in[i]);
live_in[i] = new_live_in;
last_inst = i;
});
if(!changed)
break;
}
// ir::for_each_instruction(mod, [&](ir::instruction* i){
// i->print(std::cout);
// std::cout << " live_in: " << live_in[i].size() << std::endl;
// });
// Assigns index to each instruction
std::map<ir::value*, slot_index> indices;
slot_index index = 0;
ir::for_each_instruction(mod, [&](ir::instruction* instr){
index += 1;
indices.insert({instr, index});
});
for(auto &x: layouts_->get_all()){
shared_layout* layout = x.second->to_shared();
if(layout)
intervals_[layout] = segment{INT32_MAX, 0};
}
for(auto& x: live_in)
for(shared_layout* layout: x.second)
intervals_[layout].start = std::min<int>(intervals_[layout].start, indices[x.first]);
for(auto& x: live_in)
for(shared_layout* layout: x.second){
intervals_[layout].end = std::max<int>(intervals_[layout].end, indices[x.first] + 1);
}
for(auto &x: layouts_->get_all()) {
shared_layout* layout = x.second->to_shared();
if(!layout)
continue;
// users
std::set<ir::user*> users;
for(ir::value *v: layout->get_values()){
for(ir::user *u: v->get_users())
users.insert(u);
}
// compute intervals
unsigned start = INT32_MAX;
for(ir::value *v: layout->get_values())
if(indices.find(v) != indices.end())
start = std::min(start, indices.at(v));
unsigned end = 0;
for(ir::user *u: users)
if(indices.find(u) != indices.end())
end = std::max(end, indices.at(u));
if(end == 0)
end = start + 1;
intervals_[layout] = segment{start, end};
// std::cout << intervals_[layout].start << " " << intervals_[layout].end << std::endl;
}
}

View File

@@ -19,6 +19,7 @@ void swizzle::run(ir::module &) {
continue;
ir::value* mma_dot_a = layout->hmma_dot_a();
ir::value* mma_dot_b = layout->hmma_dot_b();
if(!mma_dot_a && !mma_dot_b){
per_phase_[layout] = 1;
max_phase_[layout] = 1;
@@ -27,22 +28,31 @@ void swizzle::run(ir::module &) {
}
auto ord = layout->get_order();
scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
if(!in_layout)
continue;
int per_phase = 1;
int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(tgt_->as_nvidia()->sm() < 80){
if(in_layout)
per_phase = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
else
per_phase = 1;
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
int inner = mma_dot_a ? 0 : 1;
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
per_phase_[layout] = per_phase;
max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
if(mma_dot_a)
vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
else
vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
}
else{
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = 8 / per_phase_[layout];
vec_[layout] = 8;
else {
if (!layout->allow_swizzle()) {
per_phase_[layout] = 1;
max_phase_[layout] = 1;
vec_[layout] = 1;
} else {
per_phase_[layout] = per_phase;
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
vec_[layout] = layout->get_mma_vec();
}
}
}
}

63
lib/codegen/extern_lib.cc Normal file
View File

@@ -0,0 +1,63 @@
#include "triton/codegen/extern_lib.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Type.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "triton/codegen/pass.h"
namespace triton {
namespace codegen {
std::unique_ptr<llvm::Module> ExternLib::load(llvm::LLVMContext& ctx) {
llvm::SMDiagnostic err;
auto mod = llvm::parseIRFile(this->path_, err, ctx);
if (!mod) {
throw std::runtime_error("Failed to load extern lib " + this->name_ +
" at " + this->path_);
}
return mod;
}
void ExternLib::link(std::unique_ptr<llvm::Module>& llvm,
std::unique_ptr<llvm::Module>& mod) {
// Set triple and data layout to match the target module
mod->setTargetTriple(llvm->getTargetTriple());
mod->setDataLayout(llvm->getDataLayout());
if (llvm::Linker::linkModules(*llvm, std::move(mod))) {
throw std::runtime_error("Failed to link extern lib " + this->name_ +
" at " + this->path_);
}
}
void LibDevice::opt(llvm::LLVMContext& ctx, std::unique_ptr<llvm::Module>& llvm) {
// Add nvvm reflect flags to llvm module
// https://llvm.org/docs/LangRef.html#module-flags-metadata
// i32 4: Override the other module.
// i32 1: Emit an error
// If both modules specify Override, but the values differ, an error
// will be emitted.
llvm::Type* I32 = llvm::Type::getInt32Ty(ctx);
llvm::Metadata* md_four =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
llvm::Metadata* md_name = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
llvm::Metadata* md_one =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
llvm::MDNode* reflect = llvm::MDNode::get(ctx, {md_four, md_name, md_one});
llvm->addModuleFlag(reflect);
}
std::unique_ptr<ExternLib> create_extern_lib(const std::string& lib_name,
const std::string& lib_path) {
if (lib_name == "libdevice") {
return std::make_unique<LibDevice>(lib_name, lib_path);
} else {
throw std::runtime_error("Unknown external library: " + lib_name);
}
}
} // namespace codegen
} // namespace triton

View File

@@ -1,4 +1,14 @@
#include "triton/codegen/pass.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/axes.h"
@@ -9,57 +19,98 @@
#include "triton/codegen/transform/cts.h"
#include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/disassociate.h"
#include "triton/codegen/transform/inline.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/pipeline.h"
#include "triton/codegen/transform/prefetch.h"
#include "triton/driver/device.h"
#include "triton/driver/kernel.h"
#include "triton/driver/module.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/print.h"
#include "llvm/IR/Module.h"
namespace triton {
namespace codegen {
static void link_extern_libs(const ExternLibMap& user_extern_lib_map,
const ExternLibMap& target_extern_lib_map,
ir::module& ir, llvm::LLVMContext& ctx,
std::unique_ptr<llvm::Module>& llvm) {
for (const auto& iter : target_extern_lib_map) {
auto &lib_name = iter.first;
if (user_extern_lib_map.count(lib_name) != 0 &&
user_extern_lib_map.at(lib_name)->path() != "") {
// If the user specified a path for this library, use it.
user_extern_lib_map.at(lib_name)->install(ctx, llvm);
} else {
// Otherwise, use the default path.
iter.second->install(ctx, llvm);
}
}
std::set<llvm::StringRef> function_names;
for (auto& func : ir.get_function_list()) {
function_names.insert(func->get_name());
}
llvm::legacy::PassManager pass;
pass.add(llvm::createInternalizePass([&](const llvm::GlobalValue& v) -> bool {
if (function_names.count(v.getName()) != 0) {
// Preserve global functions
return true;
}
// Internalize all device functions
return false;
}));
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
pm.run(*llvm);
llvm::PassManagerBuilder builder;
builder.OptLevel = 3;
builder.SizeLevel = 0;
builder.populateModulePassManager(pass);
pass.run(*llvm);
}
// TODO:
// There should be a proper pass manager there!
void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages, bool force_nc_cache,
driver::module *&mod, driver::kernel *&ker, size_t &shared_mem) {
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
ir::module& ir, llvm::LLVMContext& ctx, codegen::target* target,
int num_warps, int num_stages, int& shared_static,
const ExternLibMap& extern_lib_map) {
// generate llvm code
llvm::LLVMContext ctx;
std::string name = ir.get_function_list()[0]->get_name();
std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
// optimizations
std::unique_ptr<codegen::target> target = dev->make_target();
bool cts_use_async = target->as_nvidia()->sm() >= 80;
bool has_sm80 = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
// create passes
codegen::analysis::align align;
codegen::transform::inliner inliner;
codegen::analysis::axes axes;
codegen::transform::cts cts(cts_use_async);
codegen::transform::pipeline pipeline(cts_use_async, num_stages);
codegen::transform::pipeline pipeline(has_sm80, num_stages);
codegen::transform::disassociate disassociate;
codegen::analysis::layouts layouts(&axes, &align, num_warps, target.get());
codegen::analysis::layouts layouts(&axes, &align, num_warps, target);
codegen::transform::cts cts(&layouts, has_sm80);
codegen::analysis::liveness liveness(&layouts);
codegen::analysis::swizzle swizzle(&layouts, target.get());
codegen::analysis::swizzle swizzle(&layouts, target);
codegen::analysis::allocation allocation(&liveness);
codegen::transform::dce dce;
codegen::transform::peephole peephole(target.get(), &layouts);
// codegen::transform::reassociate reassociate;
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::transform::prefetch prefetch_s(target.get());
codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target.get());
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps, force_nc_cache);
codegen::transform::peephole peephole(target, &layouts);
codegen::transform::coalesce coalesce(&align, &layouts, has_sm80);
codegen::transform::prefetch prefetch_s(target);
codegen::transform::membar barriers(&liveness, &layouts, &allocation,
&prefetch_s, target);
codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle,
target, num_warps);
// run passes
inliner.run(ir);
dce.run(ir);
peephole.run(ir);
dce.run(ir);
// ir::print(ir, std::cout);
pipeline.run(ir);
dce.run(ir);
// ir::print(ir, std::cout);
// ir.print(std::cout);
disassociate.run(ir);
dce.run(ir);
align.run(ir);
@@ -67,8 +118,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
layouts.run(ir);
peephole.run(ir);
dce.run(ir);
if (target->is_gpu())
cts.run(ir);
if (target->is_gpu()) cts.run(ir);
align.run(ir);
axes.run(ir);
layouts.run(ir);
@@ -76,10 +126,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
dce.run(ir);
align.run(ir);
dce.run(ir);
if (target->is_gpu()) {
// reassociate.run(ir);
cts.run(ir);
}
if (target->is_gpu()) cts.run(ir);
dce.run(ir);
align.run(ir);
axes.run(ir);
@@ -90,18 +137,34 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps,
axes.run(ir);
layouts.run(ir);
swizzle.run(ir);
// std::cout << "---" << std::endl;
// ir.print(std::cout);
// std::cout << "---" << std::endl;
// ir.print(std::cout);
liveness.run(ir);
allocation.run(ir);
prefetch_s.run(ir);
// ir::print(ir, std::cout);
barriers.run(ir);
// ir::print(ir, std::cout);
// ir::print(ir, std::cout);
// exit(1);
// ir.print(std::cout);
isel.visit(ir, *llvm);
mod = driver::module::create(dev, std::move(llvm));
ker = driver::kernel::create(&*mod, name.c_str());
shared_mem = allocation.allocated_size();
shared_static = allocation.allocated_size();
if (target->as_nvidia() && target->as_nvidia()->sm() < 70) {
// sm < 70 (Pascal) has little shared memory resource.
// Instead of having "Error: Invalid argument" on launching a kernel, let's throw an error here.
if (shared_static >= 65536) {
throw std::runtime_error("Device does not support shared memory of " + std::to_string(shared_static) + "bytes");
}
}
if (isel.get_extern_lib_map().size() > 0) {
// If there's any extern lib calls,
// we need to link them in.
link_extern_libs(extern_lib_map, isel.get_extern_lib_map(), ir, ctx, llvm);
}
return llvm;
}
} // namespace codegen
} // namespace triton
} // namespace codegen
} // namespace triton

File diff suppressed because it is too large Load Diff

View File

@@ -12,131 +12,105 @@ namespace triton {
namespace codegen{
namespace transform{
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
: align_(align), layout_(layouts) { }
// Find all values that are used as pointer operands in LD/ST
void coalesce::extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u);
if(i && i->get_pointer_operand() == v)
result.insert(i);
}
}
void coalesce::extract_ld(ir::io_inst* i, std::map<int, std::vector<ir::io_inst*>>& result) {
ir::value *ptr = i->get_pointer_operand();
auto contiguous = align_->contiguous(ptr);
auto it = std::max_element(contiguous.begin(), contiguous.end());
int axis = std::distance(contiguous.begin(), it);
result[axis].push_back(i);
}
ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder,
std::map<ir::value*, ir::value*>& seen) {
if(seen.find(x) != seen.end())
return seen.at(x);
auto i = dynamic_cast<ir::instruction*>(x);
// not an instruction -- forward value
if(!i)
return x;
// already in shared memory -- forward value
if(dynamic_cast<ir::copy_to_shared_inst*>(x)){
return x;
}
// set insert point
auto& inst_list = i->get_parent()->get_inst_list();
auto pos = ++std::find(inst_list.begin(), inst_list.end(), i);
builder.set_insert_point(pos);
if(dynamic_cast<ir::load_inst*>(x)){
ir::value *ret = builder.insert(ir::copy_to_shared_inst::create(x));
return ret;
}
// default -- recursive clone
ir::instruction *cloned = builder.insert(i->clone());
seen[i] = cloned;
// rematerialize operands
for(ir::value *op: cloned->ops())
cloned->replace_uses_of_with(op, rematerialize(op, builder, seen));
return cloned;
}
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts, bool has_sm80)
: align_(align), layout_(layouts), has_sm80_(has_sm80) { }
void coalesce::run(ir::module &mod) {
size_t num_groups = layout_->num_layouts();
for(size_t id = 0; id < num_groups; id++) {
if(!layout_->get(id)->to_mma())
continue;
// extract memory stores
const auto& values = layout_->values_of(id);
ir::value* dot = nullptr;
for(ir::value *v: values)
if(auto x = dynamic_cast<ir::dot_inst*>(v))
dot = x;
ir::builder& builder = mod.get_builder();
std::vector<ir::value*> worklist = {dot};
std::set<ir::value*> seen;
while(!worklist.empty()) {
ir::value *current = worklist.back();
seen.insert(current);
worklist.pop_back();
// stop if trunc
if(auto x = dynamic_cast<ir::fp_trunc_inst*>(current)){
std::set<analysis::data_layout*> invalidated;
ir::builder& builder = mod.get_builder();
// add layout conversion instructions
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
// coalesce before store
if(dynamic_cast<ir::store_inst*>(i) || dynamic_cast<ir::atomic_rmw_inst*>(i))
if(ir::value* op = i->get_operand(1))
if(op->get_type()->is_block_ty())
if(op->get_type()->get_tile_ranks1() == 2)
if(invalidated.find(layout_->get(op)) == invalidated.end())
if(layout_->get(op)->to_mma())
if(dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
ir::instruction* new_op = ir::cvt_layout_inst::create(op);
builder.set_insert_point(i);
builder.insert(new_op);
i->replace_uses_of_with(op, new_op);
}
// coalesce before copy_to_shared
// only necessary for sm < 80 as Ampere+ can handle reduction
// on MMA layout
if(!has_sm80_)
if(dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::reduce_inst*>(i))
if(ir::value* op = i->get_operand(0))
if(op->get_type()->is_block_ty())
if(op->get_type()->get_tile_ranks1() == 2)
if(invalidated.find(layout_->get(op)) == invalidated.end())
if(layout_->get(op)->to_mma()){
ir::instruction* new_op = ir::cvt_layout_inst::create(op);
builder.set_insert_point(i);
builder.insert(new_op);
op->replace_all_uses_with(new_op);
new_op->replace_uses_of_with(new_op, op);
invalidated.insert(layout_->get(op));
}
// uncoalesce after load
if(auto x = dynamic_cast<ir::load_inst*>(i))
if(x->get_type()->is_block_ty())
if(x->get_type()->get_tile_ranks1()==2)
if(layout_->get(x)->to_mma())
if(!has_sm80_ || dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
builder.set_insert_point_after(x);
ir::recoalesce_inst* rc = ir::recoalesce_inst::create(x);
builder.insert(rc);
x->replace_all_uses_with(rc);
rc->replace_uses_of_with(rc, x);
break;
}
// recurse
for(ir::user *u: current->get_users())
if(seen.find(u) == seen.end())
worklist.push_back(u);
ir::instruction* new_x = ir::cvt_layout_inst::create(x);
builder.insert(new_x);
x->replace_all_uses_with(new_x);
new_x->replace_uses_of_with(new_x, x);
}
}
// find values to rematerialize
std::vector<ir::io_inst*> remat;
for(size_t id = 0; id < num_groups; id++) {
const auto& values = layout_->values_of(id);
// extract pointers used in ld/st operations
std::set<ir::io_inst*> io;
for(ir::value *v: values)
extract_io_use(v, io);
// extract leading axes
std::map<int, std::vector<ir::io_inst*>> axes;
for(ir::io_inst *i: io){
if(i->get_pointer_operand()->get_type()->get_tile_rank() == layout_->get(id)->get_rank()){
extract_ld(i, axes);
}
}
// update list of values to rematerialize
if(axes.empty())
continue;
for(auto it = ++axes.rbegin(); it != axes.rend(); it++){
if(it->second.size() == 1)
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
// re-arrange scanline to promote memory coalescing
if(auto x = dynamic_cast<ir::store_inst*>(i)){
ir::value* ptr = x->get_pointer_operand();
ir::value* val = x->get_value_operand();
auto out_contig = align_->contiguous(ptr);
auto val_inst = dynamic_cast<ir::instruction*>(val);
if(!val_inst)
continue;
remat.insert(remat.begin(), it->second.begin(), it->second.end());
}
}
// rematerialize values
for(ir::io_inst *r: remat) {
ir::builder& builder = mod.get_builder();
// rematerialize operands
std::map<ir::value*, ir::value*> seen;
for(ir::value *op: r->ops())
r->replace_uses_of_with(op, rematerialize(op, mod.get_builder(), seen));
// copy to shared if load
auto& inst_list = r->get_parent()->get_inst_list();
auto pos = ++std::find(inst_list.begin(), inst_list.end(), r);
builder.set_insert_point(pos);
if(dynamic_cast<ir::load_inst*>(r)){
ir::instruction *cts = builder.insert(ir::copy_to_shared_inst::create(r));
r->replace_all_uses_with(cts);
cts->replace_uses_of_with(cts, r);
if(dynamic_cast<ir::cvt_layout_inst*>(val))
continue;
if(!val->get_type()->is_block_ty() || val->get_type()->get_tile_ranks1()==1)
continue;
std::vector<unsigned> in_contig;
std::vector<ir::instruction*> queue = {val_inst};
std::set<ir::instruction*> seen;
std::vector<ir::io_inst*> ios;
while(!queue.empty()){
ir::instruction* curr = queue.back();
seen.insert(curr);
queue.pop_back();
if(auto dot_inst = dynamic_cast<ir::dot_inst*>(curr))
break;
if(auto io_inst = dynamic_cast<ir::io_inst*>(curr)){
in_contig = align_->contiguous(io_inst->get_pointer_operand());
break;
}
for(ir::value* op: curr->ops()){
auto inst_op = dynamic_cast<ir::instruction*>(op);
if(!inst_op || seen.find(inst_op) != seen.end())
continue;
if(!op->get_type()->is_block_ty() ||
!val->get_type()->is_block_ty())
continue;
if(op->get_type()->get_tile_num_elements() ==
val->get_type()->get_tile_num_elements())
queue.push_back(inst_op);
}
}
if(in_contig.size() <= 1 || out_contig==in_contig)
continue;
builder.set_insert_point_after(val_inst);
auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst));
x->replace_uses_of_with(val_inst, new_val);
}
}
}

View File

@@ -1,8 +1,10 @@
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/transform/cts.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
#include <iostream>
namespace triton {
@@ -10,9 +12,9 @@ namespace codegen{
namespace transform{
inline bool is_shmem_op(ir::instruction* i, int op) {
bool cts::is_shmem_op(ir::instruction* i, int op) {
if(i->get_id() == ir::INST_DOT)
return op==0 || op==1;
return op == 0 || op == 1;
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
return op==0;
if(i->get_id() == ir::INST_TRANS)
@@ -20,7 +22,7 @@ inline bool is_shmem_op(ir::instruction* i, int op) {
return false;
}
inline bool is_shmem_res(ir::value* v){
bool cts::is_shmem_res(ir::value* v){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return false;
@@ -35,7 +37,7 @@ inline bool is_shmem_res(ir::value* v){
// run pass on module
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*, ir::value*>& copies) {
auto *i = dynamic_cast<ir::instruction*>(x);
// not an instruction
if(!i) {
@@ -51,7 +53,7 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder,
// phi node
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
add_copy(phi, phi->get_incoming_value(i), builder, to_shared, copies);
return;
}
// already in shared memory
@@ -65,33 +67,52 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder,
}
else
copy = builder.create_copy_from_shared(x);
parent->replace_uses_of_with(x, copy);
copies.insert({x, copy});
parent->replace_uses_of_with(x, copies.at(x));
}
void cts::run(ir::module &mod) {
// Add shared copies
ir::builder &builder = mod.get_builder();
for(ir::function* fn: mod.get_function_list()){
for(ir::basic_block* block: fn->blocks())
for(ir::instruction* i: block->get_inst_list()){
size_t num_op = i->get_num_operands();
// copy to shared operands
for(size_t k = 0; k < num_op; k++)
if(is_shmem_op(i, k)){
add_copy(i, i->get_operand(k), builder, true);
}
// copy from shared operands
for(size_t k = 0; k < num_op; k++)
if(!dynamic_cast<ir::phi_node*>(i) &&
!is_shmem_op(i,k) &&
is_shmem_res(i->get_operand(k))){
add_copy(i, i->get_operand(k), builder, false);
}
// Precompute where copies should be added
std::set<ir::value*> shmem_ops;
std::set<ir::value*> shmem_res;
ir::for_each_instruction(mod, [&](ir::instruction* i) {
if(i->get_id() == ir::INST_DOT){
ir::dot_inst* dot = dynamic_cast<ir::dot_inst*>(i);
ir::value* lhs = i->get_operand(0);
ir::type* ty = lhs->get_type()->get_scalar_ty();
analysis::mma_layout* mma_lhs = layouts_->get(lhs)->to_mma();
// TODO: V100
bool is_lhs_shmem = !(mma_lhs && has_sm80_ && ty->get_primitive_size_in_bits() == 16 && !dot->is_trans_a());
if(is_lhs_shmem)
shmem_ops.insert(lhs);
shmem_ops.insert(i->get_operand(1));
}
}
if(i->get_id() == ir::INST_COPY_FROM_SHARED)
shmem_ops.insert(i->get_operand(0));
if(i->get_id() == ir::INST_TRANS)
shmem_ops.insert(i->get_operand(0));
if(i->get_id() == ir::INST_TRANS ||
i->get_id() == ir::INST_COPY_TO_SHARED ||
i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
shmem_res.insert(i);
});
// Add shared copies
std::map<ir::value*, ir::value*> copies;
ir::builder &builder = mod.get_builder();
ir::for_each_instruction(mod, [&](ir::instruction* i) {
size_t num_op = i->get_num_operands();
for(size_t k = 0; k < num_op; k++){
ir::value* op = i->get_operand(k);
// copy to shared operands
bool is_shmem_op = shmem_ops.find(op) != shmem_ops.end();
if(is_shmem_op)
add_copy(i, op, builder, true, copies);
}
});
}
}
}
}
}

View File

@@ -3,6 +3,7 @@
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
#include <iostream>
namespace triton {
namespace codegen{
@@ -28,6 +29,8 @@ void dce::run(ir::module &mod) {
case ir::INST_ATOMIC_CAS:
case ir::INST_ATOMIC_RMW:
case ir::INST_ATOMIC_EXCH:
case ir::INST_CALL:
case ir::INST_LAUNCH:
case ir::INST_BARRIER: {
work_list.push_back(i);
marked.insert(i);
@@ -65,6 +68,7 @@ void dce::run(ir::module &mod) {
}
}
// delete
for(ir::instruction* i: to_delete)
i->erase_from_parent();

View File

@@ -9,67 +9,50 @@ namespace triton {
namespace codegen{
namespace transform{
void extract_retile_chain(ir::user *root,
std::map<int, std::set<ir::user*>>& result,
int depth,
ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root,
std::set<ir::value*>& seen) {
if (dynamic_cast<ir::phi_node*>(root))
return root;
if(!seen.insert(root).second)
return;
result[depth].insert(root);
if(dynamic_cast<ir::make_range*>(root) ||
dynamic_cast<ir::splat_inst*>(root)){
return;
}
return root;
if(!root->get_type()->is_block_ty())
return root;
bld.set_insert_point(root);
ir::instruction *new_root = bld.insert(root->clone());
for(ir::value *op: root->ops()){
ir::user *u = dynamic_cast<ir::user*>(op);
if(!u)
ir::instruction *i = dynamic_cast<ir::instruction*>(op);
if(!i || i->get_id() == ir::INST_REDUCE)
continue;
extract_retile_chain(u, result, depth + 1, seen);
ir::instruction* new_op = rematerialize(bld, i, seen);
new_root->replace_uses_of_with(op, new_op);
}
return new_root;
}
void disassociate::run(ir::module &mod) {
ir::builder &bld = mod.get_builder();
std::map<ir::user*, std::map<int, std::set<ir::user*>>> clone_info;
// ir::for_each_instruction(mod, [&](ir::instruction *i){
// bld.set_insert_point(i);
// for(ir::value* op: i->ops()){
// auto reshape = dynamic_cast<ir::make_range*>(op);
// if(!reshape)
// continue;
// ir::instruction* new_op = bld.insert(reshape->clone());
// i->replace_uses_of_with(op, new_op);
// }
// });
ir::for_each_instruction(mod, [&](ir::instruction *i){
if(dynamic_cast<ir::reshape_inst*>(i)){
ir::value* op = i->get_operand(0);
if(!dynamic_cast<ir::user*>(op))
return;
if(op->get_type()->get_tile_rank() > i->get_type()->get_tile_rank())
return;
std::map<int, std::set<ir::user*>> chains;
if(dynamic_cast<ir::reshape_inst*>(i) || dynamic_cast<ir::splat_inst*>(i)){
std::set<ir::value*> seen;
extract_retile_chain(i, chains, 0, seen);
if(chains.size())
clone_info[i] = chains;
ir::instruction* new_i = rematerialize(bld, i, seen);
i->replace_all_uses_with(new_i);
}
});
for(const auto& x: clone_info){
int depth = 1;
std::map<ir::instruction*, ir::instruction*> clone_map;
while(x.second.find(depth) != x.second.end()){
// clone all users
const auto& remat = x.second.at(depth);
for(ir::user* u: remat){
ir::instruction *y = (ir::instruction*)u;
ir::instruction *cloned = y->clone();
bld.set_insert_point(y);
bld.insert(cloned);
clone_map[y] = cloned;
// replace operands of parents
if(depth > 1)
for(ir::user* ux: x.second.at(depth - 1))
clone_map.at((ir::instruction*)ux)->replace_uses_of_with(y, cloned);
else
x.first->replace_uses_of_with(y, cloned);
}
depth += 1;
}
}
}

View File

@@ -0,0 +1,147 @@
#include <iostream>
#include "triton/codegen/transform/inline.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/utils.h"
namespace triton{
namespace codegen{
namespace transform{
bool fncmp::operator()(ir::function* x, ir::function* y) const {
auto fn_list = x->get_parent()->get_function_list();
return std::find(fn_list.begin(), fn_list.end(), x) < std::find(fn_list.begin(), fn_list.end(), y);
};
void inliner::do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder,
std::list<ir::call_inst*>& callsites){
ir::basic_block* parent_block = callsite->get_parent();
ir::function* parent_fn = parent_block->get_parent();
// the parent block is split into block A and block B:
// - block A (`new_blocks[0]`) is the entry block of the inlined function
// - block B (`exit`) resumes execution of the parent function
ir::basic_block* entry = parent_block->split_before(callsite, fn->get_name());
ir::basic_block* exit = entry->get_successors()[0];
std::vector<ir::basic_block*> new_blocks = {entry};
for(size_t i = 1; i < fn->blocks().size(); i++){
ir::basic_block* block = fn->blocks()[i];
ir::context& ctx = block->get_context();
const std::string& name = block->get_parent()->get_name() + "_" + block->get_name();
new_blocks.push_back(ir::basic_block::create(ctx, name, parent_fn));
}
// a phi node holds the return values of the inlined function
if(exit->get_inst_list().empty())
builder.set_insert_point(exit);
else
builder.set_insert_point(exit->get_first_non_phi());
ir::phi_node* exit_val = builder.create_phi(fn->get_fn_type()->get_return_ty(), 0);
callsite->replace_all_uses_with(exit_val);
callsite->erase_from_parent();
// get arguments `fn` is called with
std::vector<ir::value*> tgt_args(callsite->op_begin(), callsite->op_end());
std::vector<ir::argument*> src_args(fn->args().begin(), fn->args().end());
// Actually generate the instructions:
// - Remove the branch created by basic_block::split_before
// - Clone all instructions
// - Replace `ret` with incoming nodes to `exit_val` and branches to `exit`
ir::instruction* terminator = new_blocks[0]->get_inst_list().back();
// new_blocks[0]->get_inst_list().back()->erase_from_parent();
terminator->erase_from_parent();
std::map<ir::instruction*, ir::instruction*> inst_map;
std::map<ir::argument*, ir::value*> arg_map;
for(size_t k = 0; k < fn->args().size(); k++)
arg_map[fn->args()[k]] = callsite->ops()[k];
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
// clone instructions
for(size_t i = 0; i < new_blocks.size(); i++){
ir::basic_block* old_block = fn->blocks()[i];
ir::basic_block* new_block = new_blocks[i];
builder.set_insert_point(new_block);
for(ir::instruction* old_inst: old_block->get_inst_list()){
ir::instruction* new_inst = old_inst->clone();
inst_map[old_inst] = new_inst;
builder.insert(new_inst);
}
}
// update basic blocks
for(size_t i = 0; i < new_blocks.size(); i++) {
for (ir::instruction* new_inst: new_blocks[i]->get_inst_list()) {
// replace basic use cases
for(size_t k = 0; k < new_blocks.size(); k++)
new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]);
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(new_inst)) {
// additionally replace basic blocks of phi-nodes since
// replace_uses_of_with() does not replace them.
for(unsigned in = 0; in < phi->get_num_incoming(); in++)
for(size_t k = 0; k < new_blocks.size(); k++)
if (phi->get_incoming_block(in) == fn->blocks()[k])
phi->set_incoming_block(in, new_blocks[k]);
}
}
}
// replace operands of instructions after constructing inst_map
for (auto& it: inst_map) {
ir::instruction* new_inst = it.second;
for(size_t k = 0; k < new_inst->get_num_operands(); k++) {
ir::value* op = new_inst->get_operand(k);
if(auto arg_op = dynamic_cast<ir::argument*>(op))
new_inst->set_operand(k, arg_map.at(arg_op));
if(auto inst_op = dynamic_cast<ir::instruction*>(op))
if(inst_map.find(inst_op) != inst_map.end())
new_inst->set_operand(k, inst_map.at(inst_op));
}
// handles a ret instruction.
// instead of returning we need to branch to after the function call
if(ir::return_inst* ret = dynamic_cast<ir::return_inst*>(new_inst)) {
if(ir::value* ret_val = ret->get_return_value())
exit_val->add_incoming(ret_val, new_inst->get_parent());
// replace ret with branch
ir::instruction* new_br_inst = ir::branch_inst::create(exit);
builder.set_insert_point(new_inst->get_parent());
builder.insert(new_br_inst);
new_inst->erase_from_parent();
}
}
if(exit_val->get_num_incoming() == 1)
exit_val->replace_all_uses_with(exit_val->get_incoming_value(0));
// done -- make sure insert point is properly set to exit block
builder.set_insert_point(exit);
}
void inliner::run(ir::module &mod) {
// gather all call sites
while(true){
std::map<ir::function*, size_t> counts;
for(ir::function* fn: mod.get_function_list())
counts[fn] = 0;
std::list<ir::call_inst*> callsites;
for(ir::function* fn: mod.get_function_list()){
for(ir::basic_block* block: fn->blocks())
for(ir::instruction* instr: block->get_inst_list())
if(ir::call_inst* call = dynamic_cast<ir::call_inst*>(instr)){
callsites.push_back(call);
counts[call->get_fn()] += 1;
}
}
for(auto& count: counts){
if(!count.first->get_is_kernel() && count.second == 0)
count.first->get_parent()->remove_function(count.first);
}
if(callsites.empty())
break;
for(ir::call_inst* call: callsites)
do_inline(call->get_fn(), call, mod.get_builder(), callsites);
}
}
}
}
}

View File

@@ -36,6 +36,9 @@ int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
else{
if(layouts_->has_tmp(v))
return async_write.size() - 1;
// // Ignore copy_to_shared. It won't modify async behavior.
// if(dynamic_cast<ir::copy_to_shared_inst*>(v))
// return 0;
auto it = std::find(async_write.begin(), async_write.end(), v);
return std::distance(async_write.begin(), it);
}
@@ -60,15 +63,22 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b
continue;
analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr;
analysis::shared_layout* a_tmp_index = layouts_->has_tmp_index(a) ? layouts_->get(layouts_->tmp_index(a))->to_shared() : nullptr;
for(ir::value* b: bs){
if(!b->get_type()->is_block_ty())
continue;
analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr;
analysis::shared_layout* b_tmp_index = layouts_->has_tmp_index(b) ? layouts_->get(layouts_->tmp_index(b))->to_shared() : nullptr;
if(intersect_with(a_layout, b_layout) ||
intersect_with(a_layout, b_tmp) ||
intersect_with(a_layout, b_tmp_index) ||
intersect_with(a_tmp, b_layout) ||
intersect_with(a_tmp, b_tmp))
intersect_with(a_tmp, b_tmp) ||
intersect_with(a_tmp, b_tmp_index) ||
intersect_with(a_tmp_index, b_layout) ||
intersect_with(a_tmp_index, b_tmp) ||
intersect_with(a_tmp_index, b_tmp_index))
ret.insert(b);
}
}

View File

@@ -61,7 +61,8 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
// dot(a, b, c) + d -> dot(a, b, c + d)
// d + dot(a, b, c) -> dot(a, b, c + d)
auto add = dynamic_cast<ir::binary_operator*>(value);
if(add && add->get_op() == ir::binary_op_t::FAdd) {
if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) {
bool is_int_dot = add->get_op() == ir::binary_op_t::Add;
ir::value *lhs = add->get_operand(0);
ir::value *rhs = add->get_operand(1);
ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs);
@@ -72,15 +73,21 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
ir::value *other = (dot == lhs) ? rhs : lhs;
ir::value *acc = dot->get_operand(2);
ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc);
ir::constant_fp *_0 = nullptr;
ir::constant *_0 = nullptr;
if(splat)
_0 = dynamic_cast<ir::constant_fp*>(splat->get_operand(0));
if(!(_0 && _0->get_value() == 0.0))
_0 = dynamic_cast<ir::constant*>(splat->get_operand(0));
if(!_0)
return false;
if (auto *fp_0 = dynamic_cast<ir::constant_fp*>(_0))
if (fp_0->get_value() != 0.0)
return false;
if (auto *int_0 = dynamic_cast<ir::constant_int*>(_0))
if (int_0->get_value() != 0)
return false;
ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1);
builder.set_insert_point(add);
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->get_name()));
ir::value * new_dot = builder.insert(ir::dot_inst::create(a, b, other, dot->is_trans_a(), dot->is_trans_b(), dot->allow_tf32(), dot->get_name()));
add->replace_all_uses_with(new_dot);
return true;
}
@@ -116,7 +123,7 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build
int nts = layout->nts(layout->get_order()[0]);
int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
if(nts*dtsize >= 4){
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier(), ld->get_eviction_policy());
copy_to_shared->replace_all_uses_with(new_load);
return true;
}
@@ -143,32 +150,53 @@ bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
}
bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
auto binop = dynamic_cast<ir::binary_operator*>(value);
if(binop && binop->get_op() == ir::binary_op_t::Mul) {
ir::value *lhs = binop->get_operand(0);
ir::value *rhs = binop->get_operand(1);
ir::constant_int *_1_lhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_lhs = cst;
}
ir::constant_int *_1_rhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_rhs = cst;
}
if(_1_lhs){
binop->replace_all_uses_with(rhs);
return true;
}
else if(_1_rhs){
binop->replace_all_uses_with(lhs);
return true;
}
auto binop = dynamic_cast<ir::binary_operator*>(value);
if(binop && binop->get_op() == ir::binary_op_t::Mul) {
ir::value *lhs = binop->get_operand(0);
ir::value *rhs = binop->get_operand(1);
ir::constant_int *_1_lhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_lhs = cst;
}
ir::constant_int *_1_rhs = nullptr;
if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
if(cst && cst->get_value() == 1)
_1_rhs = cst;
}
if(_1_lhs){
binop->replace_all_uses_with(rhs);
return true;
}
else if(_1_rhs){
binop->replace_all_uses_with(lhs);
return true;
}
}
return false;
}
bool peephole::rewrite_insert_extract(ir::instruction *value, ir::builder& builder){
auto extracted = dynamic_cast<ir::extract_value_inst*>(value);
if(!extracted)
return false;
size_t extract_idx = extracted->get_idx();
ir::value* agg = extracted->get_operand(0);
auto insert = dynamic_cast<ir::insert_value_inst*>(agg);
while(insert){
agg = insert->get_operand(0);
ir::value* inserted = insert->get_operand(1);
size_t insert_idx = insert->get_idx();
insert = dynamic_cast<ir::insert_value_inst*>(agg);
if(extract_idx == insert_idx){
extracted->replace_all_uses_with(inserted);
return true;
}
insert = dynamic_cast<ir::insert_value_inst*>(agg);
}
return false;
}
@@ -206,11 +234,50 @@ bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& b
builder.set_insert_point(select);
ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
if_value->get_mask_operand(),
select->get_else_value_op());
select->get_else_value_op(),
if_value->get_cache_modifier(),
if_value->get_eviction_policy(),
if_value->get_is_volatile());
select->replace_all_uses_with(new_load);
return true;
}
bool peephole::rewrite_cvt_layout(ir::instruction *value, ir::builder& builder){
auto cvt = dynamic_cast<ir::cvt_layout_inst*>(value);
if(!cvt)
return false;
ir::instruction* op = dynamic_cast<ir::instruction*>(cvt->get_operand(0));
if(!op)
return false;
// // convert(elementwise(x, y)) = elementwise(convert(x), convert(y))
// if(op->get_id() == ir::INST_BINOP){
// for(size_t i = 0; i < op->get_num_operands(); i++){
// ir::value* arg_i = op->get_operand(i);
// builder.set_insert_point(op);
// // create new layout transform
// ir::instruction* new_arg_i = cvt->clone();
// layouts_->copy(new_arg_i, op);
// builder.insert(new_arg_i);
// // set the right args
// new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
// op->replace_uses_of_with(arg_i, new_arg_i);
// }
// cvt->replace_all_uses_with(op);
// return true;
// }
auto cvt_op = dynamic_cast<ir::cvt_layout_inst*>(op);
if(!cvt_op)
return false;
// convert1(convert2(x)) if convert1 is the inverse of convert2
ir::value* op_op = cvt_op->get_operand(0);
if(layouts_->has(cvt) && layouts_->has(op_op) &&
layouts_->get(cvt) && layouts_->get(op_op)){
cvt->replace_all_uses_with(op_op);
return true;
}
return false;
}
void peephole::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
// keep track of whether any modification was made
@@ -245,10 +312,13 @@ void peephole::run(ir::module &mod) {
was_modified = was_modified || rewrite_mult(i, builder);
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
// was_modified = was_modified || rewrite_trans_phi(i, builder);
was_modified = was_modified || rewrite_insert_extract(i, builder);
was_modified = was_modified || rewrite_unit_red(i, builder);
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
was_modified = was_modified || rewrite_select_masked_load(i, builder);
if(tgt_->as_nvidia()->sm() >= 80)
// TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD
// was_modified = was_modified || rewrite_select_masked_load(i, builder);
was_modified = was_modified || rewrite_cvt_layout(i, builder);
if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
was_modified = was_modified || rewrite_load_to_shared(i, builder);
if(was_modified)
seen.insert(i);

View File

@@ -23,29 +23,6 @@ void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instru
recursive_deps(u, block, ret);
}
/// assume incoming block is 1
ir::value* rematerialize_vals(ir::builder& builder, ir::value* v,
std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
return v;
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)) {
if (prev_phi_vals.find(phi) == prev_phi_vals.end())
throw std::runtime_error("Don't have that phi node\n");
return prev_phi_vals.at(phi);
}
std::vector<ir::value*> new_ops;
for(ir::value* op: i->ops()){
new_ops.push_back(rematerialize_vals(builder, op, prev_phi_vals));
}
ir::instruction* ret = i->clone();
for(size_t k = 0; k < new_ops.size(); k++)
ret->set_operand(k, new_ops[k]);
builder.insert(ret);
return ret;
}
void get_induction_vars(ir::value* cond, std::set<ir::phi_node*>& phis) {
auto instr = dynamic_cast<ir::instruction*>(cond);
for (auto op : instr->ops()) {
@@ -58,17 +35,21 @@ void get_induction_vars(ir::value* cond, std::set<ir::phi_node*>& phis) {
}
}
/// Returns phi_val if sees a phi node
ir::value* rematerialize_val(ir::builder& builder, ir::value* v, ir::value* phi_val) {
/// assume incoming block is 1
ir::value* rematerialize_vals(ir::builder& builder, ir::basic_block* block, ir::value* v,
std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
if(!i || i->get_parent() != block)
return v;
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
return phi_val;
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)) {
if (prev_phi_vals.find(phi) == prev_phi_vals.end())
throw std::runtime_error("Don't have that phi node\n");
return prev_phi_vals.at(phi);
}
std::vector<ir::value*> new_ops;
for(ir::value* op: i->ops()){
new_ops.push_back(rematerialize_val(builder, op, phi_val));
new_ops.push_back(rematerialize_vals(builder, block, op, prev_phi_vals));
}
ir::instruction* ret = i->clone();
for(size_t k = 0; k < new_ops.size(); k++)
@@ -77,16 +58,17 @@ ir::value* rematerialize_val(ir::builder& builder, ir::value* v, ir::value* phi_
return ret;
}
ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
ir::value* rematerialize(ir::builder& builder, ir::basic_block* block,
ir::value* v, size_t phi_idx){
ir::instruction* i = dynamic_cast<ir::instruction*>(v);
if(!i)
if(!i || i->get_parent() != block)
return v;
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
return phi->get_incoming_value(phi_idx);
std::vector<ir::value*> new_ops;
for(ir::value* op: i->ops()){
new_ops.push_back(rematerialize(builder, op, phi_idx));
new_ops.push_back(rematerialize(builder, block, op, phi_idx));
}
ir::instruction* ret = i->clone();
for(size_t k = 0; k < new_ops.size(); k++)
@@ -96,18 +78,20 @@ ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){
}
/// moving the prev phi vals to the next iteration
void update_prev_phi_vals(ir::builder& builder, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
for (auto& [phi, val] : prev_phi_vals) {
// TODO: handling nested phis
val = rematerialize_val(builder, phi->get_incoming_value(1), val);
std::map<ir::phi_node*, ir::value*> update_prev_phi_vals(
ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
std::map<ir::phi_node*, ir::value*> next_phi_vals;
for (auto &[phi, val] : prev_phi_vals) {
next_phi_vals[phi] = rematerialize_vals(builder, block, phi->get_incoming_value(1), prev_phi_vals);
}
return next_phi_vals;
}
void finalize_iv_vals(ir::builder& builder, std::map<ir::phi_node*, ir::value*>& load_ivs,
void finalize_iv_vals(ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& load_ivs,
std::map<ir::phi_node*, ir::value*>& next_load_ivs) {
for (auto& [phi, val] : load_ivs) {
if (auto new_phi = dynamic_cast<ir::phi_node*>(val)) {
ir::value* next_k = rematerialize_vals(builder, phi->get_incoming_value(1), load_ivs);
ir::value* next_k = rematerialize_vals(builder, block, phi->get_incoming_value(1), load_ivs);
assert(new_phi->get_num_operands() == 1 && "should be incomplete phi");
new_phi->add_incoming(next_k, phi->get_incoming_block(1));
// cache next_k (to be used by next_mask)
@@ -117,30 +101,43 @@ void finalize_iv_vals(ir::builder& builder, std::map<ir::phi_node*, ir::value*>&
}
}
struct pipeline_info_t {
ir::load_inst* load;
ir::phi_node* ptr;
ir::dot_inst* dot;
pipeline_info_t(ir::load_inst* load, ir::phi_node* ptr, ir::dot_inst* dot)
: load(load), ptr(ptr), dot(dot) {}
};
void pipeline::run(ir::module &mod) {
if (num_stages_ <= 1)
return;
// *Very* conservative heuristics for pre-fetching.
// A load instruction can be pipelined if:
// - the pointer is a phi node that references a value
// in its basic block (i.e., pointer induction variable)
// - the load has only a single use in a dot instruction
// As more use cases become apparent, this pass will be improved
std::vector<std::pair<ir::load_inst*, ir::phi_node*>> to_pipeline;
std::vector<pipeline_info_t> to_pipeline;
ir::for_each_instruction(mod, [&](ir::instruction *i){
if(auto* load = dynamic_cast<ir::load_inst*>(i)){
ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
auto users = load->get_users();
auto dot = dynamic_cast<ir::dot_inst*>(*users.begin());
if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
&& users.size() == 1 && dynamic_cast<ir::dot_inst*>(*users.begin()))
to_pipeline.push_back({load, ptr});
&& users.size() == 1 && dot)
to_pipeline.push_back({load, ptr, dot});
}});
// do the pipelining
std::vector<ir::phi_node*> new_loads;
ir::builder &builder = mod.get_builder();
const int num_stages = num_stages_;
std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
for(auto info: to_pipeline){
ir::load_inst* load = info.first;
ir::phi_node* ptr = info.second;
ir::load_inst* load = info.load;
ir::phi_node* ptr = info.ptr;
ir::basic_block* block = load->get_parent();
ir::basic_block* header = block->get_predecessors()[0];
auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
@@ -163,12 +160,11 @@ void pipeline::run(ir::module &mod) {
std::map<ir::phi_node*, ir::value*> prev_phi_vals;
// initialize prev_phi_vals
// note: we assume that ptr & other values only depend on ptr & iv (phis)
// TODO: can we just add all phis here?
prev_phi_vals[ptr] = ptr->get_value_for_block(header);
for (ir::phi_node* iv : induction_vars)
prev_phi_vals[iv] = iv->get_value_for_block(header);
prev_phi_vals[ptr] = ptr->get_value_for_block(header);
// Add all phi nodes. The following DCE pass will delete dead ones.
for (ir::instruction *instr : block->get_inst_list())
if (auto *phi = dynamic_cast<ir::phi_node*>(instr))
if (phi->get_incoming_block(1) == block)
prev_phi_vals[phi] = phi->get_value_for_block(header);
builder.set_insert_point(header->get_inst_list().back());
first_ptrs[0] = ptr->get_value_for_block(header);
@@ -176,57 +172,58 @@ void pipeline::run(ir::module &mod) {
first_masks[0] = builder.create_splat(loop_conds[0], ty->get_block_shapes());
ir::value* false_value = nullptr;
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask =rematerialize_vals(builder, masked_load->get_mask_operand(), prev_phi_vals) ;
ir::value* remat_mask =rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals) ;
ir::value* remat_false_value =
rematerialize_vals(builder, masked_load->get_false_value_operand(), prev_phi_vals);
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals);
first_masks[0] = builder.create_and(first_masks[0], remat_mask);
false_value = remat_false_value;
} else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value);
first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
for (int stage = 1; stage < num_stages-1; ++stage) {
// mask is the loop condition of the previous iteration
loop_conds[stage] = rematerialize_vals(builder, block_cond, prev_phi_vals);
update_prev_phi_vals(builder, prev_phi_vals);
first_ptrs[stage] = rematerialize_vals(builder, ptr, prev_phi_vals);
loop_conds[stage] = rematerialize_vals(builder, block, block_cond, prev_phi_vals);
prev_phi_vals = update_prev_phi_vals(builder, block, prev_phi_vals);
first_ptrs[stage] = rematerialize_vals(builder, block, ptr, prev_phi_vals);
first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes());
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), prev_phi_vals);
ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals);
ir::value* remat_false_value =
rematerialize_vals(builder, masked_load->get_false_value_operand(), prev_phi_vals);
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals);
first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
false_value = remat_false_value;
}
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value);
first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
}
// create new phis for induction variables
builder.set_insert_point(block->get_first_non_phi());
std::map<ir::phi_node*, ir::value*> load_ivs;
std::map<ir::phi_node*, ir::value*> next_load_ivs;
for (ir::phi_node* iv : induction_vars) {
for (auto& [iv, val] : prev_phi_vals) {
ir::phi_node* pn = builder.create_phi(iv->get_type(), 2);
pn->add_incoming(prev_phi_vals[iv], header);
load_ivs[iv] = pn;
}
// add incoming for phis & update next_load_ivs
finalize_iv_vals(builder, load_ivs, next_load_ivs);
finalize_iv_vals(builder, block, load_ivs, next_load_ivs);
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block);
// ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_ptr = rematerialize_vals(builder, block, ptr->get_value_for_block(block), load_ivs);
ir::value* next_mask = builder.create_splat(
rematerialize_vals(builder, block_cond, load_ivs), ty->get_block_shapes());
rematerialize_vals(builder, block, block_cond, load_ivs), ty->get_block_shapes());
if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), next_load_ivs);
ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), next_load_ivs);
// TODO: false may depends on some other phi nodes
ir::value* remat_false_value =
rematerialize_vals(builder, masked_load->get_false_value_operand(), next_load_ivs);
rematerialize_vals(builder, block, masked_load->get_false_value_operand(), next_load_ivs);
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
// phi node
@@ -254,25 +251,25 @@ void pipeline::run(ir::module &mod) {
ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
ir::value* false_value;
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 0);
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 0);
ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 0);
ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 0);
first_mask = builder.create_and(first_mask, remat_mask);
false_value = remat_false_value;
}
else
false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value);
ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
// pre-fetch next iteration
builder.set_insert_point(block->get_inst_list().back());
ir::value* next_ptr = ptr->get_value_for_block(block);
ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 1);
ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 1);
ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 1);
ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 1);
next_mask = builder.create_and(next_mask, remat_mask);
false_value = remat_false_value;
}
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value);
ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
// phi node
builder.set_insert_point(block->get_first_non_phi());
ir::phi_node* new_load = builder.create_phi(ty, 2);
@@ -304,22 +301,23 @@ void pipeline::run(ir::module &mod) {
std::vector<ir::instruction*> insts;
ir::load_inst* dst;
};
std::map<ir::basic_block*, move_config_t> to_move;
std::vector<move_config_t> to_move(to_pipeline.size());
if(has_copy_async_){
for(ir::function* fn: mod.get_function_list())
for(ir::basic_block* bb: fn->blocks())
for(ir::instruction* inst: bb->get_inst_list()){
if(auto* i = dynamic_cast<ir::dot_inst*>(inst))
recursive_deps(i, bb, to_move[bb].insts);
if(auto* i = dynamic_cast<ir::load_inst*>(inst))
to_move[bb].dst = i;
for (size_t idx = 0; idx < to_pipeline.size(); ++idx) {
auto info = to_pipeline[idx];
ir::load_inst* load = info.load;
ir::phi_node* ptr = info.ptr;
ir::dot_inst* dot = info.dot;
ir::basic_block* bb = dot->get_parent();
recursive_deps(dot, bb, to_move[idx].insts);
to_move[idx].dst = load;
}
for(auto& x: to_move){
builder.set_insert_point_after(x.second.dst);
for(ir::instruction* i: x.second.insts){
x.first->erase(i);
for(auto& move_config: to_move){
builder.set_insert_point_after(move_config.dst);
for(ir::instruction* i: move_config.insts){
i->get_parent()->erase(i);
builder.insert(i);
}
}

View File

@@ -29,8 +29,16 @@ void prefetch::run(ir::module &mod) {
std::vector<ir::dot_inst*> to_prefetch;
ir::for_each_instruction(mod, [&](ir::instruction *i) {
if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
// Now only do prefetching when dot is fp16
if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::FP16TyID)
// Now only do prefetching when dot is using tensor cores
if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) ||
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_integer_ty(8)
&& dot->get_operand(1)->get_type()->get_scalar_ty()->is_integer_ty(8)
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
)
)
return;
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
@@ -83,7 +91,7 @@ void prefetch::run(ir::module &mod) {
}
// move loads to the beginning of the loop
if (tgt_->as_nvidia()->sm() < 80) {
if (tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80) {
for (ir::function *fn : mod.get_function_list())
for (ir::basic_block *bb : fn->blocks()) {
// only apply to loop body

View File

@@ -1,231 +0,0 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include <vector>
#include <stdexcept>
#include "triton/driver/dispatch.h"
#include "triton/driver/backend.h"
#include "triton/driver/buffer.h"
#include "triton/driver/context.h"
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
namespace triton
{
namespace driver
{
/*-----------------------------------*/
//----------- Platforms ------------*/
/*-----------------------------------*/
void backend::platforms::init() {
if(!cache_.empty())
return;
//if CUDA is here
if(dispatch::cuinit()){
cache_.push_back(new cu_platform());
}
//if host should be added
bool host_visible = true;
if(host_visible){
cache_.push_back(new host_platform());
}
// //if OpenCL is here
// if(dispatch::clinit()){
// cl_uint num_platforms;
// dispatch::clGetPlatformIDs(0, nullptr, &num_platforms);
// std::vector<cl_platform_id> ids(num_platforms);
// dispatch::clGetPlatformIDs(num_platforms, ids.data(), nullptr);
// for(cl_platform_id id: ids)
// cache_.push_back(new cl_platform(id));
// }
if(cache_.empty())
throw std::runtime_error("Triton: No backend available. Make sure CUDA is available in your library path");
}
void backend::platforms::get(std::vector<platform *> &results) {
std::copy(cache_.begin(), cache_.end(), std::back_inserter(results));
}
std::vector<driver::platform*> backend::platforms::cache_;
/*-----------------------------------*/
//----------- Devices --------------*/
/*-----------------------------------*/
void backend::devices::init(std::vector<platform*> const & platforms) {
if(!cache_.empty())
return;
for(driver::platform* pf: platforms)
pf->devices(cache_);
if(cache_.empty())
throw std::runtime_error("Triton: No device available. Make sure that your platform is configured properly");
}
void backend::devices::get(std::vector<device*> &devs) {
std::copy(cache_.begin(), cache_.end(), std::back_inserter(devs));
}
std::vector<driver::device*> backend::devices::cache_;
/*-----------------------------------*/
//---------- Modules ----------------*/
/*-----------------------------------*/
void backend::modules::release(){
for(auto & x: cache_)
delete x.second;
cache_.clear();
}
std::map<std::tuple<driver::stream*, std::string>, driver::module*> backend::modules::cache_;
/*-----------------------------------*/
//----------- Kernels --------------*/
/*-----------------------------------*/
void backend::kernels::release(){
for(auto & x: cache_)
delete x.second;
cache_.clear();
}
driver::kernel* backend::kernels::get(driver::module *mod, std::string const & name){
std::tuple<driver::module*, std::string> key(mod, name);
if(cache_.find(key)==cache_.end()){
return &*cache_.insert({key, driver::kernel::create(mod, name.c_str())}).first->second;
}
return cache_.at(key);
}
std::map<std::tuple<driver::module*, std::string>, driver::kernel*> backend::kernels::cache_;
/*-----------------------------------*/
//------------ Queues --------------*/
/*-----------------------------------*/
void backend::streams::init(std::list<driver::context*> const & contexts){
for(driver::context* ctx : contexts)
if(cache_.find(ctx)==cache_.end())
cache_.insert(std::make_pair(ctx, std::vector<driver::stream*>{driver::stream::create(ctx->backend())}));
}
void backend::streams::release(){
for(auto & x: cache_)
for(auto & y: x.second)
delete y;
cache_.clear();
}
driver::stream* backend::streams::get_default()
{ return get(contexts::get_default(), 0); }
driver::stream* backend::streams::get(driver::context* context, unsigned int id){
init(std::list<driver::context*>(1,context));
for(auto & x : cache_)
if(x.first==context)
return x.second[id];
throw;
}
void backend::streams::get(driver::context* context, std::vector<driver::stream*> & queues){
init(std::list<driver::context*>(1,context));
queues = cache_.at(context);
}
std::map<driver::context*, std::vector<driver::stream*>> backend::streams::cache_;
/*-----------------------------------*/
//------------ Contexts ------------*/
/*-----------------------------------*/
void backend::contexts::init(std::vector<driver::device*> const & devices){
for(driver::device* dvc: devices)
cache_.push_back(driver::context::create(dvc));
}
void backend::contexts::release(){
for(auto & x: cache_)
delete x;
cache_.clear();
}
driver::context* backend::contexts::get_default(){
backend::init();
auto it = cache_.begin();
std::advance(it, default_device);
return *it;
}
void backend::contexts::get(std::list<driver::context*> & contexts){
backend::init();
contexts = cache_;
}
std::list<driver::context*> backend::contexts::cache_;
/*-----------------------------------*/
//------------ General -------------*/
/*-----------------------------------*/
void backend::synchronize(driver::context* context){
for(driver::stream * queue: streams::cache_.at(context))
queue->synchronize();
}
void backend::release(){
backend::kernels::release();
// backend::programs::release();
backend::streams::release();
backend::contexts::release();
}
void backend::init(){
if(!contexts::cache_.empty())
return;
// initialize platforms
backend::platforms::init();
// initialize devices
backend::devices::init(platforms::cache_);
// initialize contexts
backend::contexts::init(devices::cache_);
// initialize streams
streams::init(contexts::cache_);
}
unsigned int backend::default_device = 0;
}
}

View File

@@ -1,90 +0,0 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "triton/driver/stream.h"
#include "triton/driver/buffer.h"
#include "triton/driver/context.h"
#include "triton/driver/dispatch.h"
namespace triton
{
namespace driver
{
//
buffer::buffer(size_t size, CUdeviceptr cu, bool take_ownership)
: polymorphic_resource(cu, take_ownership), size_(size) { }
buffer::buffer(size_t size, host_buffer_t hst, bool take_ownership)
: polymorphic_resource(hst, take_ownership), size_(size) { }
size_t buffer::size() {
return size_;
}
uintptr_t buffer::addr_as_uintptr_t() {
switch(backend_){
case CUDA: return *cu_;
case Host: return (uintptr_t)hst_->data;
default: return 0;
}
}
buffer* buffer::create(driver::context* ctx, size_t size) {
switch(ctx->backend()){
case CUDA: return new cu_buffer(size);
case Host: return new host_buffer(size);
default: throw std::runtime_error("unknown backend");
}
}
//
host_buffer::host_buffer(size_t size)
: buffer(size, host_buffer_t(), true){
hst_->data = new char[size];
}
//
cu_buffer::cu_buffer(size_t size)
: buffer(size, CUdeviceptr(), true) {
dispatch::cuMemAlloc(&*cu_, size);
}
cu_buffer::cu_buffer(size_t size, CUdeviceptr cu, bool take_ownership)
: buffer(size, cu, take_ownership){
}
void cu_buffer::set_zero(driver::stream* queue, size_t size){
dispatch::cuMemsetD8Async(*cu_, 0, size, *queue->cu());
}
}
}

View File

@@ -1,118 +0,0 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include <cassert>
#include "triton/driver/context.h"
#include "triton/driver/module.h"
#include "triton/tools/sys/getenv.hpp"
#include "triton/tools/sys/mkdir.hpp"
namespace triton
{
namespace driver
{
/* ------------------------ */
// BASE //
/* ------------------------ */
context::context(driver::device *dev, CUcontext cu, bool take_ownership):
polymorphic_resource(cu, take_ownership),
dev_(dev), cache_path_(get_cache_path()) {
}
context::context(driver::device *dev, host_context_t hst, bool take_ownership):
polymorphic_resource(hst, take_ownership),
dev_(dev), cache_path_(get_cache_path()){
}
context* context::create(driver::device *dev){
switch(dev->backend()){
case CUDA: return new cu_context(dev);
case Host: return new host_context(dev);
default: throw std::runtime_error("unknown backend");
}
}
driver::device* context::device() const {
return dev_;
}
std::string context::get_cache_path(){
//user-specified cache path
std::string result = tools::getenv("TRITON_CACHE_PATH");
if(!result.empty()){
if(tools::mkpath(result)==0)
return result;
}
//create in home
result = tools::getenv("HOME");
if(!result.empty())
{
result = result + "/.triton/cache/";
if(tools::mkpath(result)==0)
return result;
}
//couldn't find a directory
return "";
}
std::string const & context::cache_path() const{
return cache_path_;
}
/* ------------------------ */
// Host //
/* ------------------------ */
host_context::host_context(driver::device* dev): context(dev, host_context_t(), true){
}
/* ------------------------ */
// CUDA //
/* ------------------------ */
// import CUdevice
CUdevice cu_context::get_device_of(CUcontext context){
dispatch::cuCtxPushCurrent_v2(context);
CUdevice res;
dispatch::cuCtxGetDevice(&res);
dispatch::cuCtxPopCurrent_v2(NULL);
return res;
}
// wrapper for cuda context
cu_context::cu_context(CUcontext context, bool take_ownership): driver::context(new driver::cu_device(get_device_of(context), false),
context, take_ownership) {
}
cu_context::cu_context(driver::device* device): context(device, CUcontext(), true){
dispatch::cuCtxCreate(&*cu_, CU_CTX_SCHED_AUTO, *((driver::cu_device*)dev_)->cu());
// dispatch::cuCtxPopCurrent_v2(NULL);
}
}
}

Some files were not shown because too many files have changed in this diff Show More