256 Commits

Author SHA1 Message Date
Da Yan
0f5c6e619c [BUILD] Add the missing triton/impl to setup.py (#1042) 2023-01-09 19:03:45 +00:00
Connor Baker
c20215dad1 [FRONTEND] Update PTX/SM support for LLVM14 (PR #1038 redux) (#1039)
=
2023-01-09 10:31:55 -08:00
Keren Zhou
733301ff31 [Backend] Rewrite code for linking external library to expose more inlining opportunities (#1037)
- Also make it cleaner. 
- And mark out the code needs to be fixed in `semantic.py`.
2023-01-08 13:44:29 -08:00
Shintaro Iwasaki
ff399fbc20 [Build] Support GCC 8.x to build Triton (#1036) 2023-01-06 19:36:14 -08:00
Keren Zhou
4023149ee3 [Frontend] Convert constexpr to value for store and load ops (#1030)
Fixing problem 2 in https://github.com/openai/triton/issues/1017

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-01-05 14:40:16 -05:00
Gregory Axler
2193bee94e [Example] Fix the compile function in copy_strided.py (#1029) 2023-01-05 10:37:41 -08:00
Sophia Wisdom
411bacb2a8 [FRONTEND] Add logical operations on constexprs (#1033) 2023-01-04 18:06:32 -08:00
Sharad Vikram
bc73bbb12c [FRONTEND] Fix argmin/max output type (#1012)
Currently Triton returns tensors with the input types rather than i32
when doing reduce argmax/argmin.
2023-01-03 23:12:16 -08:00
Keren Zhou
8460ea3df1 [Frontend] Fix import for libdevice (#1028)
This is a hotfix for issue 1 in
https://github.com/openai/triton/issues/1017
2023-01-03 15:48:05 -08:00
Keren Zhou
678b9f53a2 [Backend] Use post-order traversal for liveness numbering (#1027)
Also add tests for `tt.trans`.
2023-01-03 15:11:54 -08:00
goostavz
0e8590f1c9 [BACKEND] Add generic support of convert_layout from distributed to shared (#1025) 2022-12-30 11:29:58 -08:00
fdrocha
194ba103b1 [BUILD] Fixed error when compiling in systems with multiple versions of python installed (#1019) 2022-12-29 15:10:34 -08:00
goostavz
1d3029faf8 [Backend] Add value cache in emitting indices calculation and some refinement (#1018)
1, add explicit value cache in emitting indices calculation;
2, move the indices calculation emitting logics into
ConvertTritonGPUOpToLLVMPatternBase to avoid the redundant build cost by
templates. Refer to the discussion in this thread by @LyricZhao :
https://triton-lang.slack.com/archives/C042VBSQWNS/p1671336755922969
2022-12-29 11:19:59 -08:00
Yan Chunwei
2ba74d2729 [OPTIMIZER] Update the versionMinor in MMA layout for volta (#1014)
Continue the work https://github.com/openai/triton/pull/990

# Background
The `versionMinor` in MmaEncodingAttr holds some states of DotOp's
operands in Volta, while such operands will be modified by some
patterns, making the states out-of-date.

This PR helps to correct the states.

# Implementation
It adds three new patterns:

1. `CollectMmaToUpdateForVolta` helps to collect and build a map holding
the MmaEncodingAttr instances with wrong states and create new correct
ones for them,
2. `UpdateMMAVersionMinorForVolta` helps to replace the Ops generating
the wrong MmaEncodingAttr instances with new correct ones, currently it
supports the following Ops
    a. `convert_layout[X -> mma]`
    b. `arith.constant SplatAttr : !tensor<mma>`
    c. `dot ... : !tensor<mma>`

# Limitation
This PR chooses the mapping way to bypass the IR walk complexity from
the circular dependency between dot_operand[parent] and mma.
We use the MmaEncodingAttr instance as the mapping key, but there might
be multiple DotOp holding different DotOprand(IsMMAv1Row) that have the
same wrong MmaEncodingAttr instance.
To make each DotOp's (wrong) MmaEncodingAttr unique, we might need an ID
field to MmaEncodingAttr.
2022-12-28 12:24:01 +08:00
Keren Zhou
fd2da4aff6 [BACKEND] Support splat constant on the DotOperandLayout (#1008) 2022-12-22 00:48:46 -08:00
Sharad Vikram
925d3d7f98 [FRONTEND] Export broadcast and broadcast_to in triton.language (#1007) 2022-12-22 01:57:33 +00:00
Keren Zhou
b5aafb0dab [FRONTEND] Fix 3d indexing (#1006) 2022-12-21 12:52:32 -08:00
Philippe Tillet
20100a7254 Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)
This PR merges the `triton-mlir` branch, in which we have been quietly
rewriting the Triton backend from scratch to increase maintainability,
stability and ultimately performance. Changes to the runtime are
minimal, and this new version aims to remain backward-compatible with
the previous commit. The legacy backend is now officially deprecated,
but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-12-21 01:30:50 -08:00
Yang Hau
8650b4d1cb [DRIVER] Fix typos (#939) 2022-12-02 11:13:46 -08:00
Crutcher Dunnavant
44f577984d Fix format double substitution bug: {i} => {{i}} (#886)
The previous `{i}` was silently expanding to the `i` from the
enumeration loop on `regular_args` (when it wasn't empty).
2022-11-20 11:44:42 -08:00
Crutcher Dunnavant
0e4691e6dd [FRONTEND] Fix ExternLibrary(format=) bug; type annotate build_extern.py (#883)
Ran mypy over `build_extern.py`, cleaned up type annotations.

Found a fixed a bug where `ExternLibrary(format=)` was being ignored.
2022-11-17 18:45:30 +01:00
Natalia Gimelshein
0d7e753227 [TESTING] use torch.int for autotuning cache (#840)
For stupid reasons, ops on int8 are 3 times slower than on int, and for
another set of stupid reasons we are not using cudaMemset for `zero_`,
so using `int8` buffer in `do_bench` makes it slow.

Co-authored-by: Philippe Tillet <phil@openai.com>
2022-11-04 18:05:16 -07:00
Shintaro Iwasaki
77bc5187b5 Better NVIDIA Pascal GPU Support (#827)
This PR clarifies which features are supported on P100 via its tests,
though Pascal is not officially and fully supported by Triton.

## What this PR does

- Skip unsupported tests on P100.
  - Atomic RMW
- `tl.dot()` (perhaps not all patterns, but basically most `tl.dot()`
tests do not work on P100).
- Add an explicit error if shared memory size >= 64K on P100.
- Otherwise it causes `Invalid CUDA argument` error at
`cuLaunchKernel()`, but this error is not very straightforward to
understand. Instead of this generic CUDA argument error, this PR makes
Triton show an error during codegen when `sm < 70`. This check happens
in C/C++ so won't add an overhead in Triton's Python runtime.
- 3 tests (see below) are currently failing, but these are not marked as
skipped because any codegen update in the future can change the kernel
size of the other tests.
- This change won't affect Triton-MLIR. Hopefully Triton-MLIR's generic
`tl.dot()` implementation would support P100.

Importantly, Triton passed all the other tests on P100. Though this
support is not official, it is great for, for example, PyTorch's
TorchDynamo/Inductor, which can use Triton (without `tl.dot()`) for its
backend (https://github.com/pytorch/torchdynamo/issues/1591).

### Results on P100 (Google Cloud)

```sh
$ pytest test/unit
...
================================================================================== short test summary info ==================================================================================
FAILED test/unit/language/test_core.py::test_reduce2d[argmin-float32-shape99-1] - RuntimeError: Device does not support shared memory of 65536bytes
FAILED test/unit/language/test_core.py::test_reduce2d[argmax-float32-shape113-1] - RuntimeError: Device does not support shared memory of 65536bytes
FAILED test/unit/language/test_core.py::test_permute[float32-shape5-perm5] - RuntimeError: Device does not support shared memory of 67584bytes
================================================================== 3 failed, 3824 passed, 952 skipped in 470.90s (0:07:50) ==================================================================
```

<details><summary> <b>Environment Details (collapsed)</b></summary>
<p>

### VM details (Google Cloud)
https://cloud.google.com/
```
# You need a paid account (free trial does not cover GPUs)
Google Cloud -> New Project -> Compute-Engine -> VM Instance
Machine:
GPU: NVIDIA Tesla P100 x 1
CPU: 2 vCPUs, 7.5GB memory
Boot disk:
  OS: Ubuntu 18.04 LTS
  Disk: 40GB (cannot build Triton on the default 10GB disk)
- When I tried, about $1.2 per hour.
- US instances were full when I tried.  I used Asia or Australia.
- Needed a paid account (GPU is not covered by free trial)
- Needed quota request for any GPU instance (by default, no GPU instance is allowed).  Needed to wait an hour for approval
```

### Reproducer
```sh
## 1. Install CUDA and a driver
# Update the apt key (https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key/)
sudo apt-key del 7fa2af80
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb
# Download CUDA as instructed
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-ubuntu1804.pin
sudo mv cuda-ubuntu1804.pin /etc/apt/preferences.d/cuda-repository-pin-600
sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
sudo add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/ /"
sudo apt-get update
sudo apt-get -y install cuda
# Are you using P100?
nvidia-smi | grep "Tesla P100"

## 2. Setup the build environment
sudo apt update
sudo apt install -y build-essential wget git libz-dev
wget https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh
bash Anaconda3-2022.05-Linux-x86_64.sh -b -p $(pwd)/anaconda3
eval "$($(pwd)/anaconda3/bin/conda shell.bash hook)"
conda create -y --name triton_base
conda activate triton_base
conda install -y cmake setuptools

## 3. Build Triton
git clone https://github.com/openai/triton.git
cd triton/python
pip3 install -e '.[tests]'

## 4. Test
pytest test/unit
```

### Environment
```sh
$ nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  On   | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    25W / 250W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
```

</p></details>
2022-11-03 00:11:52 -07:00
Chenggang Zhao
f16138d447 [Frontend] Interface fixes for libdevice (#830)
- Unifying several interfaces with different types to a single one, e.g.
`fsub_ru` and `dsub_ru` -> `sub_ru`;
- Minor bug fix: `fast_pow` is incorrectly classified into the `pow`
interface, of which arguments are the same as `powf`;
- Explicit interfaces for casting functions, e.g. decoupling
`ll2float_ru` to `ll2float_ru` and `ull2float_ru`;
- Removing interfaces that are not in NVIDIA's official documents, e.g.
`fmaf_ieee_rn`, which is confusing together with `fmaf_rn`.

Note that this PR for the master branch is different from #829, which is
for the MLIR branch.
2022-11-01 10:51:58 -07:00
Mark Saroufim
578ada7740 [DOCS] Add install from source instructions to README (#821) 2022-10-31 11:08:18 -07:00
Phil Tillet
6311d70406 Revert "[BUILD] Now using cibuildwheel default"
This reverts commit 584086f08c.
2022-10-29 17:15:47 -07:00
Phil Tillet
584086f08c [BUILD] Now using cibuildwheel default 2022-10-29 16:59:06 -07:00
Keren Zhou
3ca667dfa8 [Frontend] Return a scalar if all input args are scalar (#816) 2022-10-28 23:27:06 -07:00
Yanbo Liang
5ca1ed0101 Add bf16/fp16/fp64 support for ty_to_cpp (#800)
In ```torch._inductor```, we [convert 0d CPU tensor to scalar during
triton codegen](https://github.com/pytorch/pytorch/pull/87329), so need
add missing triton support for bf16/fp16/fp64.
2022-10-24 19:41:25 -07:00
Keren Zhou
db3aa1d1fb [FRONTEND] Fix libdevice (#776)
Fix two problems in libdevice and external dispatch:

1. Use static triton types (e.g., tl.int32) instead of creating new
types. Otherwise, `tl.int32` and `tl.dtype('int32')` are not the same
thing.

2. The name of an extern inst should be empty but not the symbol name of
the inst. TTIR generator will assign names automatically. Otherwise, we
have the same variable name when there are multiple same extern insts.

Before the PR:

```bash
  __nv_exp = extern_elementwise f64<1024> %11;
  __nv_exp = extern_elementwise f64<1024> %11;
```

After the PR:

```bash
  %12 = extern_elementwise f64<1024> %11;
  %13 = extern_elementwise f64<1024> %11;
```
2022-10-13 17:18:16 -07:00
Twizzes
ddae106c0e [DOCS] Update installation.rst to fix windows build error (#747) 2022-10-13 13:27:15 -07:00
Keren Zhou
bc98aead33 [Backend] Fix for mov.u8 (#766)
Init a potential fix for mov.u8 which is not supported by ptx for now.
Use mov.u16 instead and cast it to u8.
2022-10-12 14:32:27 -07:00
Yu Guo
71b46acc42 [IR] Added special-purpose dequantize instruction (#759)
It is currently necessary for optimal performance in quantized workloads to add a special-purpose instruction in the IR. Backward compatibility with this instruction is *NOT* guaranteed.
2022-10-12 14:14:45 -07:00
Philippe Tillet
33e6f0df7f [DRIVER] Bumped CUDA requirement to 11.4+. This is to avoid bad performance surprises as older ptxas are much slower. (#769)
This also makes codegen simpler by avoiding special handling of eviction policies
2022-10-12 12:02:30 -07:00
Philippe Tillet
af76c989eb [RUNTIME] Make entry point cache key depend on triton version hash (#765) 2022-10-11 13:24:30 -07:00
Bin Bao
09cc2d454b [FRONTEND] Fix a bool tensor storing problem (#746) 2022-10-10 12:11:50 -07:00
Felipe Petroski Such
5d4b26d380 [RUNTIME] support multiple devices in the same process (#757) 2022-10-09 20:30:04 -07:00
Chris
9a11a567ce [DOCS] Fixed typos in 01-vector-add.py (#751) 2022-10-09 18:12:46 -07:00
Keren Zhou
11345e9b74 [RUNTIME] Add callback functions for external tools (#738) 2022-10-05 14:46:55 -07:00
Philippe Tillet
bdfdb9a1d2 [RUNTIME] Fixed JIT bug that leg some constexpr values to be overriden by specialization parameters (#742) 2022-10-05 11:00:32 -07:00
shenggan
77c752dc78 [RUNTIME] remove fixed cu_include_dir (#739)
Use environment variable `CUDA_HOME` with default value`/usr/local/cuda` for `cu_include_dir` #731
2022-10-04 19:49:57 -07:00
Natalia Gimelshein
d3c925db8a [FRONTEND] properly broadcast scalar where condition (#736) 2022-10-04 12:44:03 -07:00
fdrocha
2b0f877fad [RUNTIME] Support environments with multiple cudalibs (#733) 2022-10-03 18:36:24 +00:00
Keren Zhou
4a2d3b7d79 [RUNTIME] Dump llvm, ttir, and sass to help debugging (#732) 2022-10-03 00:39:52 +00:00
Natalia Gimelshein
f55960e773 [FRONTEND] fix broadcasting for where (#729)
Fixes #532, all 3 inputs to where have to be broadcast together.
2022-10-01 13:18:47 -07:00
Phil Tillet
b244db06da [TUTORIALS] Attention tutorial fixup 2022-09-30 19:31:43 -07:00
Shintaro Iwasaki
7b61303ea1 [CODEGEN] Fix extract_N_bufferable in layout analysis (#728) 2022-09-30 12:21:22 -07:00
Shintaro Iwasaki
ae59f51c2d [CODEGEN] Fix an inliner to call a function with a phi-node (#727) 2022-09-29 21:36:40 -07:00
albanD
f45e31ba7c [FRONTEND] Make sure to hold the gil when creating python objects (#726)
Without this patch, a debug version of python complains that:
```
Fatal Python error: Python memory allocator called without holding the GIL
Python runtime state: initialized
```
2022-09-29 18:06:22 -07:00
Philippe Tillet
dad97528b2 [TESTING] allclose fixup (#724) 2022-09-28 22:49:05 +00:00
Jason Ansel
998fd5f9af [FRONTEND] Make triton.compile work without a cuda context (#708)
This allows compiling in a subprocess. I'm not seeing a ton of speedup from this, but figure it is a good change anyway.
2022-09-24 13:41:47 -07:00
Shintaro Iwasaki
3ac929b48b [BUILD] Download pybind11 in setup.py (#703)
Based on the discussion in #700, this PR enables downloading pybind11 in
`setup.py` without `git submodule` instead of copy-pasting pybind11
code. The downloaded pybind11 will be in `~/.triton/pybind` (like
`llvm`).
2022-09-23 15:54:07 -07:00
Jason Ansel
579c03615d [FRONTEND] Reduce number of compiles in JITFunction (#704)
I suspect this was the cause of the "new compiles even on a warm cache"
behavior I was seeing, though haven't 100% confirmed it.

Python `set()` iteration order is nondeterministic when you create a new
process. So the same args could produce different `instance_descriptor`s
and have false cache misses.
2022-09-23 21:44:52 +00:00
Philippe Tillet
25e1b36785 Revert "[pybind11] Use git-submodule for pybind11" (#701)
Reverts openai/triton#699
2022-09-23 12:25:38 -07:00
Shintaro Iwasaki
61d104ab3a [FRONTEND] Use git-submodule for pybind11 (#699)
This PR changes the `pybind11` source code management from copy-paste to
a package controlled by git-submodule.

See the discussion in #694 for details.
2022-09-23 09:55:03 -07:00
Philippe Tillet
8c3d4d5749 [RUNTIME] now decoupling entry point from cubin (#696) 2022-09-22 16:44:22 -07:00
Shintaro Iwasaki
df67068bb0 [pybind11] Update pybind11 to 2.10.0 (#691)
This PR updates the version of pybind11 to 2.10.0 (the latest stable).
2022-09-21 20:18:02 -07:00
Philippe Tillet
677ddae618 [FRONTEND] Add warmup for triton.jit() (#684)
This revives #671 , removing the static functions that may unnecessarily hold a reference to the grid and the JITFunction object

Co-authored-by: Jason Ansel <jansel@jansel.net>
2022-09-21 19:13:20 +00:00
Jason Ansel
6abe813d1c Fix issue breaking cudagraphs (#685)
@ngimel figured this one out. 

The errors we were seeing from cudagraphs capture were coming from
`cuStreamGetCtx` which is not allowed while a stream is capturing.

It appears the result of `cuStreamGetCtx()` isn't even used, so I
believe it can just be removed.
2022-09-21 10:20:48 -07:00
Philippe Tillet
e318185eb4 [DOCS] Improved README.md wording (#683)
Initial wording dates from a time where nobody knew Triton, and
comparing it to CUDA helped differentiate it from other existing DSLs.
But nowadays this comparison doesn't make much sense; Triton is its own
thing, and some people may even still be more productive in CUDA than
Triton -- language preferences are subjective after all.
2022-09-20 18:09:43 -07:00
Philippe Tillet
7dc2a70edb Revert "Add .warmup() for triton.jit()" (#682)
Reverts openai/triton#671

It seems like for some reason this caused out-of-memory errors on some
of our internal workloads. I'm reverting this so that HEAD can be used
in production at OpenAI, and I will work on digging into this issue
asynchronously.
2022-09-20 16:05:14 -07:00
Philippe Tillet
48f30550f1 [FRONTEND] Now using raw compiler syscalls when possible (#678) 2022-09-19 21:01:36 -07:00
Jason Ansel
93b1adc53b [FRONTEND] Add .warmup() for triton.jit() (#671) 2022-09-18 23:09:34 -07:00
Phil Tillet
82956e5d6b [PACKAGING] Added missing package 2022-09-18 17:34:05 -07:00
Philippe Tillet
2baf333d44 [DOCS] Fixed typos (#670) 2022-09-18 17:13:12 -07:00
Jason Ansel
49f6bc3f2b [FRONTEND] Fix filename too long error in new runtime (#669) 2022-09-18 21:26:29 +00:00
Phil Tillet
00f4ef6958 [CI] wheel/docs workflows now only run on V100 machine 2022-09-18 13:28:35 -07:00
Jason Ansel
e647402fd3 Fix warning in generated C code (#667) 2022-09-18 12:57:32 -07:00
Philippe Tillet
4a77dfb042 [FRONTEND] Complete rewrite of the runtime (#644)
This PR completely rewrites the runtime of Triton to be more lean and
clearly separate the compilation step from the just-in-time caching logic.
This should substantially reduce launch overhead.
2022-09-18 08:51:48 -07:00
Ian Bearman
889d9e34a1 [REPO] update gitignore (#666)
Update `.gitignore` to include `.vs` and `.vscode`
2022-09-17 14:25:28 -07:00
Shintaro Iwasaki
c668d6596e [DOCS] Fix spelling (#664)
This PR applies minor spelling fix in comments and string literals to
`master`. It shouldn't hurt anything.
2022-09-16 12:26:40 -07:00
Sophia Wisdom
4580a04710 [FRONTEND] Improve error message for CPU tensors (#654)
Redo of #651 against master. Fixes #525 by catching CUDA error when we
check pytorch tensor size and rethrowing a more informative error that
says why we failed.
2022-09-14 14:26:42 -07:00
Philippe Tillet
cfbbc7b43a [CI] Added V100 tag to disambiguate self-hosted runners (#653) 2022-09-14 13:47:50 -07:00
Yunxing Dai
59a8e25f43 [DOCS] Fix typo (#650) 2022-09-14 12:17:05 -07:00
Da Yan
437ced38c2 fp8 <> bf16 conversion (#637)
Co-authored-by: Philippe Tillet <phil@openai.com>
2022-08-30 14:20:12 -07:00
Da Yan
210a296699 [BACKEND] bf16 flash-attention (#636) 2022-08-26 20:40:55 -07:00
Daniil Fukalov
fe0c29b9ec Fix inconsistent struct declaration instead of class. (#632)
Looks like typo.
2022-08-26 16:20:21 -07:00
Phil Wang
7394d732ad [DOCS] support for variable head dimensions in flash attention triton tutorial (#623) 2022-08-15 19:16:49 -07:00
Da Yan
3e2953f357 Allow multiple_of and max_contiguous to accept n-d values (#617) 2022-08-10 09:59:32 -07:00
Daniil Fukalov
cc79376222 Fix deprectaion warning on CreateGEP(Value *, ArrayRef<Value *>, const Twine &) (#608)
This variant of CreateGEP() is already removed in LLVM 14.
2022-08-07 17:10:18 -07:00
Daniil Fukalov
7b91c7befd Fix "warning: control reaches end of non-void function". (#607) 2022-08-02 16:12:48 -07:00
Sharad Vikram
968f59027e Expose module.print in pybind (#604) 2022-07-29 21:36:08 -07:00
Anton Kostin
923d468187 Update LICENSE (#602) 2022-07-25 09:30:03 -07:00
Jason Ansel
027321cdcf [FRONTEND] Make tl.rand() 1-exclusive (#601) 2022-07-24 17:47:23 -07:00
Jason Ansel
e02e56dc63 [FRONTEND] Add missing rfloordiv (#598)
* [FRONTEND] Add missing rfloordiv

* fix tests
2022-07-23 21:54:12 -07:00
Philippe Tillet
ab56d310dd [BACKEND][IR] Fixed up internal dtype size for booleans (1bit -> 8bit) (#600) 2022-07-23 20:08:03 -07:00
Da Yan
f28caddbf8 [FRONTEND] Allow tl.where to select pointers (#595) 2022-07-21 09:54:27 -07:00
Keren Zhou
af85f5fa46 [FRONTEND] Refresh cache when the source code of outlined functions are changed (#590) 2022-07-20 17:34:07 -07:00
daadaada
9b2bc88d11 [BACKEND] Better bf16 support (#588) 2022-07-19 21:22:37 -07:00
Philippe Tillet
86cab58d89 [CI] Changed dev wheel date to UTC time to match CRON schedule (#587) 2022-07-18 14:54:13 -07:00
Phil Tillet
5b04331dd2 [TUTORIALS] Added more credits in fused attention tutorial 2022-07-13 23:48:58 -07:00
Jason Ansel
0a3f3d5f25 [PACKAGING] Include triton/language/libdevice.10.bc in package data (#582) 2022-07-13 23:45:27 -07:00
Keren Zhou
4912916c11 [FRONTEND] Added support for element-wise function defined in external LLVM bitcode (e.g., libdevice) (#562) 2022-07-13 15:52:21 -07:00
Phil Tillet
971f5782b4 [tutorials] Added flash attention credits in tutorial 2022-07-11 18:56:48 -07:00
Philippe Tillet
d5eb9bc230 [tutorial] Added bwd in fused attention example (#579)
Doesn't work on V100
2022-07-11 15:43:46 -07:00
Jason Ansel
c9a2b9c7d4 [FRONTEND] Add missing args to get_simd_tflops() (#578) 2022-07-11 14:37:59 -07:00
Philippe Tillet
4a399a7e40 [BACKEND] Fix some bugs (atomics, a segfault...) (#577)
This should fix #558 , #573 and #574
2022-07-06 20:03:04 -07:00
vesuppi
22105bc33b [FRONTEND] Added type check in semantic arange (#572) 2022-07-03 15:25:37 -07:00
Keren Zhou
4bf509889b [BUILD] Change the default build type to Release (#571) 2022-07-01 12:17:22 -07:00
Keren Zhou
a74cce375f [FRONTEND] Raise broadcast error (#555) 2022-06-30 17:32:07 -07:00
Philippe Tillet
f733327ba4 [BACKEND][CODEGEN] Disabling L2 residency control by default (#570) 2022-06-29 17:05:13 -07:00
Natalia Gimelshein
1bbb2430d9 [TUTORIALS] adjust heuristics for dwdb kernel (#565) 2022-06-29 17:00:22 -07:00
Kashif Rasul
1895ceaa2d [TUTORIAL] Fix f-string for older python (#569)
fixes issue #568
2022-06-29 09:39:10 -07:00
Philippe Tillet
feb7a2a0dc [FRONTEND] Hotfix for store argument order (#567) 2022-06-28 00:24:02 -07:00
Philippe Tillet
5b4c8f221e [BACKEND] Compiler improvements (#557)
This PR adds several optimization capabilities in the compiler backend:
- Now using inline PTX for `tl.store`, making it possible to use things like evict_last
- For A100, mma layout can be directly converted to shared memory
- For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used both row- and col- major.
- Fixed liveness analysis; this was broken.
- Now can load/store directly mma layout without converting. Useful for when tl.dot accumulator is initialized with DRAM data inside of an inner loop.
- `tl.dot` can now take LHS inputs in registers when it comes from a previous `tl.dot` instruction. Useful for e.g. fused attention.
2022-06-27 11:49:19 -07:00
Keren Zhou
87413bc925 [BACKEND] Fix layout convert for non-contiguous input (#564) 2022-06-25 23:12:03 -07:00
Keren Zhou
d345ddf837 [DOCS] Separate atomic cas from other atomic operations since operands are very different (#559) 2022-06-22 17:51:17 -07:00
Keren Zhou
b02bac41ba [CI] Change cache dir (#561) 2022-06-22 11:44:35 -07:00
Keren Zhou
a428cf0bb2 [FRONTEND] Fix pytorch warning. (#560)
UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc').
2022-06-20 20:12:09 -07:00
Keren Zhou
b5e728cb14 Add argmin argmax (#552) 2022-06-15 13:55:20 -07:00
Jason Ansel
6b9756532f [BACKEND] Remove print in coalesce.cc (#551) 2022-06-15 13:13:20 -07:00
Madeleine Thompson
8ce2c12e33 [PYTHON] move ephemeral files to homedir (#549)
This prevents potential conflicts with other users on shared machines.
2022-06-13 19:37:52 -07:00
Keren Zhou
93209c07e0 [BACKEND][CODEGEN] Fix reduce uint (#547) 2022-06-13 16:43:57 -07:00
Philippe Tillet
58c8889235 [FRONTEND] Fix scanline layout (#548) 2022-06-13 16:21:10 -07:00
Natalia Gimelshein
7094657aa9 [FRONTEND] fix bool conversion of floating types (#545) 2022-06-13 15:52:37 -07:00
Keren Zhou
38573d1261 [FRONTEND] Return allocated registers and spilled registers for users (#541) 2022-06-07 18:37:12 -07:00
Mengchi Zhang
2cdc6d35c4 [FRONTEND] Give col_per_thread an initial value to make the compiler happy (#535)
Signed-off-by: Mengchi Zhang <mengchi@fb.com>
2022-06-06 12:48:23 -07:00
TC
f13cbaab9f [FRONTEND] assert that num_warps is a power of 2 (#539) 2022-06-06 11:37:08 -07:00
Philippe Tillet
751e325d2e [TUTORIALS] Fixed typo 2022-06-05 13:33:21 -07:00
Philippe Tillet
801c8a4c92 [TUTORIALS] Fixed typo 2022-06-05 12:32:07 -07:00
Philippe Tillet
8876e53206 [BACKEND] Restored reduction bugfixes 2022-06-03 11:38:52 -07:00
Philippe Tillet
a60374a597 Revert "[BACKEND] Various bug fixes; making reductions faster (#533)".
This is a more stable commit that produce bitwise identical code to earlier
versions. Using commits after this one may lead to slightly different numerics
2022-06-03 11:36:06 -07:00
Philippe Tillet
efa04cac1f [FRONTEND] A couple of bugfixes (#534) 2022-06-02 16:57:37 -07:00
Philippe Tillet
3e7500dfe6 [BACKEND] Various bug fixes; making reductions faster (#533) 2022-05-31 17:14:44 -07:00
Bert Maher
37037bb3be [FRONTEND] Default cache dir to /tmp/triton_$USER (#527) 2022-05-27 13:51:05 -07:00
Philippe Tillet
c82a206684 [FRONTEND] Better dot error message (#531) 2022-05-26 17:41:09 -07:00
Philippe Tillet
0e2883020a [BACKEND] Fixed typo in alignment analysis (#528) 2022-05-25 20:01:19 -07:00
Bert Maher
43fec2adca [FRONTEND] Add binding for create_int_to_ptr (#526) 2022-05-25 15:26:18 -07:00
Philippe Tillet
011bc83c1b [FRONTEND] For loops now promote initial value (#524) 2022-05-24 13:20:10 -07:00
Natalia Gimelshein
96bff90471 [FRONTEND] faster jit function launch (#523)
With fast (200 ns) get_stream function soon to be available from pytorch this shaves off approx 25-30 us from function launch, but even without that function due to caching device properties we are saving ~15-20us.
2022-05-24 12:08:49 -07:00
daadaada
d5eaa8dfa0 Making the generated Triton IR deterministic & a script to compare cached assembly (#522) 2022-05-24 08:56:36 -07:00
Shantanu
80f6a2698b [FRONTEND] Ensure version_key is called at most once (#519)
Co-authored-by: hauntsaninja <>
2022-05-23 13:40:08 -07:00
daadaada
205a493b10 [FRONTEND] Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas (#520)
Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas
2022-05-21 09:45:54 -07:00
Jiabao Lei
abea3dc2c6 [FRONTEND] provide device kwargs && fix fstring error for py<3.8 (#515)
Co-authored-by: Philippe Tillet <phil@openai.com>
2022-05-14 16:21:46 -07:00
Philippe Tillet
d35617bea1 [BACKEND][CODEGEN] Faster reduction for scanline layout (#516) 2022-05-14 15:26:13 -07:00
Mengchi Zhang
d1a22a94e6 [FRONTEND] Add empty return value and remove protect to open the access to contained_tys_vec_t (#514)
Signed-off-by: Mengchi Zhang <mengchi@fb.com>
2022-05-13 11:46:12 -07:00
Jason Ansel
d954a05989 [FRONTEND] Handle torch.uint8 args (#513)
Co-authored-by: Philippe Tillet <Phil.Tillet@gmail.com>
2022-05-12 13:07:39 -07:00
Philippe Tillet
0835a4fb05 [TUTORIALS] Removed #noformat in layer norm tutorial 2022-05-12 12:41:25 -07:00
Philippe Tillet
c736ba7c3e [TUTORIALS] Fixed formatting 2022-05-12 12:31:23 -07:00
Philippe Tillet
cd30a99aa2 [TUTORIALS] fixed formatting 2022-05-12 12:28:22 -07:00
Philippe Tillet
d87435e536 [TUTORIALS] Layer norm tutorial now uses residency control (#510) 2022-05-05 19:53:54 -07:00
Sriram Murali
7c9bc5a47b [CODEGEN] Change return type of generator::packed_type to appease build warnings (#507) 2022-05-04 20:03:37 -07:00
Philippe Tillet
95feb10ec9 [FRONTEND] fixup (#505) 2022-04-30 14:25:06 -07:00
Philippe Tillet
11a908655d [FRONTEND] Fixup 2022-04-29 14:35:09 -07:00
Phil Tillet
cd78ce4888 [FRONTEND] Improved error message when assigning None to non-constexpr 2022-04-29 09:17:54 -07:00
Philippe Tillet
ae2a1ab225 [BACKEND] Alignment pass improvements (#503) 2022-04-25 21:16:00 -07:00
Philippe Tillet
7d544799a0 [BACKEND] Now disabling L2 eviction policy for sm < 80 2022-04-25 09:35:36 -07:00
Philippe Tillet
3ca792043f [TEST] Added test for vectorization 2022-04-24 13:50:48 -07:00
Philippe Tillet
bda209002e [BACKEND][CODEGEN] vectorization bugfix (#502) 2022-04-23 13:18:33 -07:00
Philippe Tillet
0cc3b1129b [BACKEND][CODE_GEN] eviction policies now also apply to L2 (#501) 2022-04-21 23:56:01 -07:00
Philippe Tillet
7d6c504e8d [TESTING] Added testing utilities for fixing clock and using cuda-memcheck (#500) 2022-04-21 22:40:10 -07:00
Philippe Tillet
073be1d2ee [FRONTEND] check that tensors have power-of-two number of elements (#499) 2022-04-14 19:30:02 -07:00
Philippe Tillet
5c7122004c [TUTORIALS] Tutorial shouldn't expose clock. Just removed it. 2022-04-14 17:33:44 -07:00
Philippe Tillet
dc4d40faec [FRONTEND] now mangle constexpr float containing "e-" 2022-04-14 10:26:48 -07:00
Philippe Tillet
25f6689508 [FRONTEND] rename current stream monkey patch (#495) 2022-04-13 11:45:55 -07:00
Philippe Tillet
76bfac9f15 [FRONTEND] Improved constexpr handling (#493) 2022-04-12 00:02:54 -07:00
Philippe Tillet
14b0fd4cfb [FRONTEND] Added possibility for users to customize current stream query (#492) 2022-04-07 12:11:32 -07:00
Philippe Tillet
6424771f55 [CI] Documentation fixup 2022-04-07 09:42:35 -07:00
Philippe Tillet
9f08ecd684 [FRONTEND] Semantic analysis refactor (#491)
Moved dispatch.cc to semantic.py (@ptillet)
Integer signedness analysis was moved from C++ to python (@daadaada)
Cleaner frontend types (@daadaada)
Moved SSA construction to a separate object (@ptillet)


Co-authored-by: Yan Da <dyanab@connect.ust.hk>
2022-04-06 16:13:53 -07:00
Philippe Tillet
2bed6fc850 [LANG] Added support for device functions (#484) 2022-04-03 20:58:16 -07:00
apd10
e85c7a7fc7 Bugfix in ptxas path. (#487)
Bug: "ret" value is destroyed when a failing "ptxas --version" is run
overwriting the previous valid "ret" value.

Fix: keep rets only for those runs which are successful. Pick the first
one
2022-03-30 20:45:41 -07:00
Philippe Tillet
bace26143d [TUTORIALS] Removed leftover print 2022-03-28 16:53:23 -07:00
Philippe Tillet
e0cc488055 [FRONTEND] Added tl.clock and tl.globaltimer (#485) 2022-03-28 16:15:43 -07:00
Philippe Tillet
76a9ee50a8 Revert "[FRONTEND] Semantic analysis refactor (#473)" (#483)
This reverts commit 539961072c.
2022-03-24 17:16:50 -07:00
Philippe Tillet
ea6d1f1b85 [DRIVER] LLVM driver fixup (#482)
Current way of doing things is probably not super thread safe. init is shared between threads and some threads my not call the LLVMInitialize* function.
2022-03-23 00:24:45 -07:00
Keren Zhou
a4f68165cd [FRONTEND] Hot fix for lineno (#481)
Override __reduce__ to make CompilationError pickable and print out error messages
2022-03-22 22:09:49 -07:00
daadaada
539961072c [FRONTEND] Semantic analysis refactor (#473)
Moved dispatch.cc to semantic.py
Integer signedness now moved from C++ to python
Cleaner frontend type

Co-authored-by: Phil Tillet <phil@openai.com>
2022-03-16 21:25:30 -07:00
Yongjik Kim
0dd2ec2e3a [FRONTEND] Add an assert in case we get a CPU tensor. (#478) 2022-03-16 14:38:56 -07:00
Philippe Tillet
d4d8eaf6c0 [FRONTEND] improved caching mechanism (#474)
Co-authored-by: Greg Brockman <gdb@gregbrockman.com>
Co-authored-by: Christopher Hesse <christopherhesse@users.noreply.github.com>
2022-03-15 12:20:51 -07:00
Doğukan Tuna
21f8a0646d [DOCS] Minor README.md (#470)
Added binary distribution for quick installation
2022-03-05 00:50:37 -08:00
Philippe Tillet
a50a47a85b [CODEGEN] Reverted some changes from previous PR; fixed vectorization characteristics of mma layout (#469) 2022-03-04 01:53:31 -08:00
Philippe Tillet
bb5765df5c [CODEGEN] Now padding shared memory for layout conversion (#468) 2022-03-03 22:19:05 -08:00
daadaada
d9dd97492f Use unique_ptr in ir::context_impl (#462)
Co-authored-by: Philippe Tillet <Phil.Tillet@gmail.com>
2022-02-24 16:07:10 -08:00
Philippe Tillet
98ed7db8c1 [CODEGEN] Improvements and bugfixes (#463) 2022-02-24 14:56:24 -08:00
daadaada
a9dfdcaaa9 [FRONTEND] Make the performance model work for int8, tf32, and fp32 (#456) 2022-02-11 22:34:42 -08:00
Philippe Tillet
9b100302d3 [FRONTEND] Now using pybind11 to release GIL (#458) 2022-02-10 01:57:39 -08:00
Philippe Tillet
40093a9878 [DOCS] Multiple versions are now supported (#457) 2022-02-09 01:32:41 -08:00
Philippe Tillet
4941bc7001 [DOCS] Some more fixes (#455) 2022-02-08 16:53:56 -08:00
Philippe Tillet
2fdf0a4fe8 [DOCS] changed build command 2022-02-08 11:45:21 -08:00
Philippe Tillet
077d6c8ff0 [DOCS] re-activated tutorials 2022-02-08 11:42:39 -08:00
Philippe Tillet
822ddcd14b [DOCS] Added versioning (#453) 2022-02-08 11:28:18 -08:00
Philippe Tillet
7b48340ffd [CI] Some fixes for the build (#451) 2022-02-06 19:11:33 -08:00
Philippe Tillet
5a8a544d10 [OPS][BLOCKSPARSE] Improved robustness, clarity and performance (#450)
* dds layout now internally re-uses dsd code path for increased code 
* at_mask and kp_mask related things are now dropped from the softmax API. I couldn't think of any case where it was needed beyond is_causal. And if there is any, we should probably find a way to get it implemented statically so that users don't have to materialize masks.
 * fixed bug in blocksparse matmul that caused troubles when layout had a full row/col of zeros
 * blocksparse softmax now no longer modifies any data in-place
 * blocksparse softmax now takes an is_dense arguments that provides better performance. Passing is_dense=True, is_causal=True is the best way to achieve triangular attention.
  * unit tests now test backward pass
2022-02-06 18:00:45 -08:00
Philippe Tillet
69ff52ea1f [CODEGEN] removed buggy (and mostly useless) optimization in peephole pass (#449) 2022-02-05 21:37:23 -08:00
TC
137bb67fad [LANG] Add fp16 to fp8 conversion (#444) 2022-02-02 20:42:09 -08:00
Philippe Tillet
3b20170fa3 Merge pull request #448 from openai/v2.0
`v2.0` is now merged into `master`
2022-01-30 20:49:08 -08:00
Philippe Tillet
b0d6e2f322 [STYLE] run autopep 2022-01-30 20:27:44 -08:00
Philippe Tillet
2922dc141c Merge branch 'master' into v2.0 2022-01-30 20:25:01 -08:00
Philippe Tillet
807d8a1945 [ALL] Merge master (#447) 2022-01-30 20:21:20 -08:00
Philippe Tillet
bef76b142a [BACKEND] float division is now approximate by default (#446) 2022-01-29 18:29:29 -08:00
Philippe Tillet
bd52e530a0 [OPS][BLOCKSPARSE] Fix padding issue in DSD LUT (#445) 2022-01-28 21:40:30 -08:00
daadaada
e68d6a7776 [BACKEND] Making the warp-level tile "more square" to increase data-reuse for tl.dot. (#442)
* Increase smem data-reuse for some layouts

* tweak

* Keep the original tiling logic for sm < 80

Co-authored-by: Philippe Tillet <phil@openai.com>
2022-01-27 09:59:54 -08:00
daadaada
59d371c6eb [BACKEND] Added Int8 mma (#440) 2022-01-27 09:12:44 -08:00
Benjamin Lefaudeux
3a23c1dd33 [BACKEND] minor, hotfix for gcc compilation (#439) 2022-01-23 14:24:02 -08:00
Philippe Tillet
ccf9abe0ba [FRONTEND][RANDOM] Improved backward compatibility of RNG (#438)
The unsigned int PR definitely improved our RNG. However, it requires
different floating point arithmetics which, means the results are not
bit-wise identical to how they were before. This commit revives backward
compatibility, but we should change it back to the "right" way later.
2022-01-21 18:05:55 -08:00
Philippe Tillet
4c97d1ecd7 [FRONTEND] Bunch of fixes here and there (#436) 2022-01-20 10:55:59 -08:00
Philippe Tillet
e0c5709cc8 [FRONTEND] Fixed semantics bug on ptr to bool conversions (#432) 2022-01-17 18:00:03 -08:00
daadaada
2a944ded53 [TESTS] Added bfloat16 tests (#430) 2022-01-13 23:38:32 -08:00
Philippe Tillet
4c94359199 [FRONTEND] Alignment fix-up (#428) 2022-01-11 23:11:58 -08:00
Philippe Tillet
bbc78f6516 [FRONTEND][RANDOM] Make sure offset dtype is always uint32 before calling uint32_to_uniform_float (#427) 2022-01-11 11:08:49 -08:00
Botao Yu
bf32205edc [OPS][BLOCKSPARSE] Remove unnecessary loop and add cuda bool layout support (#425) 2022-01-11 11:07:16 -08:00
daadaada
94a2e10fe5 [BACKEND] Add bf16 & tf32 mma supports (on A100) (#426) 2022-01-11 10:20:31 -08:00
Madeleine Thompson
efdabe6073 [STYLE] check python with flake8 (#424)
I've been using this locally to find errors without running tests, and now that we're using autopep8, it passes with minimal suppressions. This is also what turned up the issues with the tutorials, which were fixed in #422.
2022-01-07 15:28:36 -08:00
Madeleine Thompson
a70acfec77 [STYLE] add isort and autopep8 config files and check on CI (#423)
Also a fix a few more style issues from the "aggressive" mode of autopep8.
2022-01-07 13:11:34 -08:00
Madeleine Thompson
9801aa7b56 [DOCS] fix tutorials for v2.0 (#422)
- Fix meta-parameter usage on tutorials.
- Install tutorial dependencies on CI.
- Switch from `requirements-test.txt` to `extras_require` for test dependencies, and also use it for tutorial dependencies.
- Make some performance tests deterministic.
2022-01-07 12:34:38 -08:00
Madeleine Thompson
8bf551ae7a [STYLE] run autopep8 and isort (#421)
Run:
```
isort ./python
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg` and then clean up a few warts. This PR should be a no-op; the idea is that this is all boring whitespace changes, and any config file changes will be in a different change to make it easier to review.
2022-01-06 14:34:17 -08:00
Shantanu
6f7acad48f [CODEGEN] Avoid use of deprecated AST nodes (#418)
Co-authored-by: hauntsaninja <>
2022-01-06 12:04:33 -08:00
Madeleine Thompson
120cda015e [FRONTEND] use unsigned integers to simplify RNG (#417) 2022-01-06 10:49:09 -08:00
Philippe Tillet
001fb757fe [OPS][BLOCKSPARSE] Added .contiguous() in blocksparse inputs when necessary (#420) 2022-01-06 09:56:22 -08:00
Madeleine Thompson
0ab9d67bad uint8, uint16, uint32, and uint64 in kernels (#413)
A forthcoming PR will update the RNG to use these types.

Also:
- Add tests for the `//`, `<<`, and `>>` operators.
- Change `TensorWrapper` to unwrap objects when the resulting object would be simpler.
- Clean up `throw_unreachable`, since it was triggering compiler warnings.
2022-01-05 15:27:17 -08:00
Madeleine Thompson
d8db0308cb [TEST] use numpy for reference results in test_core.py (#409)
Since numpy supports unsigned integers, and pytorch doesn't, this will make it easier to test unsigned integer support.

This adds an explicit requirement for numpy in tests, but we already required scipy, so it was already an implicit dependency.
2022-01-04 13:07:29 -08:00
Philippe Tillet
03f1256f60 [FRONTEND] Added volatile flag for load (#407) 2021-12-30 22:33:24 -08:00
Noah Ziems
3edc2633e9 [TUTORIALS] Fix 01-vector-add.py typo (#406) 2021-12-29 15:09:34 -08:00
Madeleine Thompson
985798f101 add missing bfloat16 repr and improve assertions (#403)
- `BF16TyID` was missing a repr implementation.
- Throw a better exception on impossible casts.
- Add a few assertions. Tested with a debug build.
- Add `pointer_dtype.__str__` to aid kernel debugging.
2021-12-23 17:01:17 -08:00
Philippe Tillet
d8fce83e7a [FRONTEND] Remade exception picklable 2021-12-21 22:14:06 -08:00
Philippe Tillet
a425f24d54 [FRONTEND] Better cache hook (#400)
Added an additional `repr` argument to the cache hook, which represents a human-readable string representation of the signature and argument attributes associated with the compiled binary.
2021-12-21 21:29:47 -08:00
Philippe Tillet
2509124dd0 [DRIVER] Fixed some issue with how ptxas is used (#399)
Now using tmpnam and properly deleting temporaries when an exception is raised
2021-12-21 14:31:51 -08:00
daadaada
39d4bfed83 [OPS] Add performance model for gemm/gemv (#397)
Significantly improves the performance of `triton.ops.matmul` in memory-bound settings via the use of many more block configs coupled with a performance model to drive the auto-tuning process.
2021-12-21 09:56:10 -08:00
Madeleine Thompson
5cdb948c05 [FRONTEND] signed-integer math fixes and testing (#395)
- Promote 16-bit floating-point `/` and `%` to 32-bit; we have to anyway.
- Do not force result of integer binary operations to be the LHS type. There used to be a bug in pytorch that did this, which Triton matched, but that bug is fixed now.
- When testing signed integer operations, use random numbers from the full range of the type.
- Add an optional `seed` argument to `triton.testing.random` so binary operations are not tested with both sides equal when the LHS and RHS have the same type.
- Fix a bad `CompilationError` invocation.
- Fix a warning suppression that causes tests to fail if you run them with `-W error` on python 3.8.
2021-12-21 09:46:05 -08:00
daadaada
4a8953efa3 [FRONTEND] Replace the legacy print call in triton.cc with the SlotTracker-based one. (#396)
The legacy print call will assign names (e.g., %10) to values, which can be undesirable in some cases.
2021-12-18 18:03:22 -08:00
Madeleine Thompson
fa62b4a8f6 [FRONTEND] better stringification (#394)
- Don't override `self.args` in `CompilationError`, and show the line number and column in error messages. This causes it to generate an easier-to-read backtrace.
- Better `__str__` on `TensorWrapper`, `dtype`, and `block`.
2021-12-17 20:11:45 -08:00
Philippe Tillet
4e93b41c52 [GENERAL] Some minor fixups (#393)
* [RUNTIME] Now displaying error message when generated PTX is invalid

* [CODEGEN] Now converting `if` condition to bool implicitly
2021-12-17 18:06:21 -08:00
Philippe Tillet
e062812969 [CODEGEN] Disabled peephole for masked load + select -- masked_load
doesn't work as expected when vectorized
2021-12-17 12:44:47 -08:00
Victor
eb077fc993 [RUNTIME] fixed NVidia DLL names on Windows (#392) 2021-12-16 22:09:52 -08:00
Philippe Tillet
e0b92c1380 [FRONTEND] Reverted from .random import *. There are still some
namespace errors in the Triton frontend apparently
2021-12-16 18:37:51 -08:00
Philippe Tillet
558555630f [FRONTEND] Added xor_sum 2021-12-16 17:55:35 -08:00
Madeleine Thompson
e575ae3443 [FRONTEND] Minor accumulated style and warning fixes (#388)
- Fix some whitespace.
- Make an undeclared dependency on `pytest` explicit.
- Fix deprecated `description-file` use.
- `#ifdef` out a deprecated `PyEval_InitThreads` call.
- Use a slightly different numpy invocation in `test_random.py` to quiet down overflow warnings in tests.
- Fix a deprecated cast in `test_core.py`.
- Suppress a warning about `visit_Constant` in Python 3.9+; we can't migrate yet because it'd break Python 3.6 and 3.7.
- Use chained exceptions for `CompilationError` rather than rolling our own; it makes the error messages nicer.
- Add a `__str__` for `tl.dtype` to make debugging kernels easier; it lets you `print` a dtype to see what type was inferred.
- Fix a few bad escapes.
2021-12-10 15:19:20 -08:00
Philippe Tillet
9def2424ab [RUNTIME] Fix typo in IfExp 2021-12-09 15:14:41 -08:00
Philippe Tillet
e31b9b4e66 [RUNTIME] Better support for None (#387)
* regression test fails but it doesn't make sense to me.
2021-12-09 13:21:22 -08:00
Victor
73b04d71b2 Fixes for building on Windows (#382)
* make C++ code compatible with Windows + MSVC

* added dlfcn-win32 for cross-platform dlopen

* fixed building and pip install on Windows

* fixed shared library file name under Windows
2021-12-07 14:10:58 -08:00
Victor
0ff1a26b70 fixed p2p tests failing when there are no supported p2p devices (#386) 2021-12-06 18:14:03 -08:00
Philippe Tillet
f23bf55f15 [RUNTIME] release the gil on launch (#383) 2021-12-03 13:01:01 -08:00
Philippe Tillet
8ec9f037bb [BACKEND/CODE_GEN] Fixed float32 matmul problem (#380) 2021-11-30 22:00:56 -08:00
Philippe Tillet
c86ad9c9ab [FRONTEND] Added default arguments to non-kernel @triton.jit'd function (#379) 2021-11-29 19:11:26 -08:00
daadaada
1296eb877b [RUNTIME] Config hook v2.0 (#373)
* Add pre_hook to triton.Config
* Use argument names in triton.heuristics
* Update base perf
* Remove meta from heuristics
2021-11-21 11:20:59 -08:00
Philippe Tillet
5693b582ea [RUNTIME] Now using pybind11 to avoid memory leaks (#377) 2021-11-21 02:30:22 -08:00
Philippe Tillet
edd4b0c8b7 [CODEGEN] Fixed issue with jit function passed as constexpr 2021-11-16 09:53:34 -08:00
Philippe Tillet
5b7ba3eb96 [CODEGEN] Reverted to old launch method (memory leak?) 2021-11-16 01:21:03 -08:00
Philippe Tillet
791b953b21 [CODEGEN] Reverted to old way to query current stream 2021-11-16 00:17:27 -08:00
Philippe Tillet
b908095872 [VERSION] Bumped triton.__version__ to 2.0.0 2021-11-12 15:10:36 -08:00
Philippe Tillet
01cc3d4503 [RUNTIME] Restored do_not_specialize (#374) 2021-11-12 15:06:55 -08:00
Philippe Tillet
e66bf76354 [RUNTIME] Bunch of bugfixes (#372) 2021-11-12 00:55:00 -08:00
Philippe Tillet
f7ab96cfd7 [FRONTEND] Fixed some issues with constexpr 2021-11-09 13:03:09 -08:00
daadaada
9a02dddf29 Fix sdd_lut (#368) 2021-11-08 08:25:05 -08:00
Philippe Tillet
5d54352164 [FRONTEND] Significantly reduce kernel launch time (#367) 2021-11-04 13:25:24 -07:00
Philippe Tillet
2acaa4d0dd [LANG] Added support for constexpr (#361) 2021-10-30 00:32:58 -07:00
Philippe Tillet
b7f0e87dc2 [DRIVER] Removed std::cout log message 2021-10-29 10:42:10 -07:00
Philippe Tillet
770ea96cca [PACKAGING] Bumped dev version to 2.0.0 2021-10-29 01:28:17 -07:00
Philippe Tillet
969d6de8a2 [PACKAGING] Bumped dev version to 1.1.2 2021-10-29 01:26:21 -07:00
Philippe Tillet
2d6df9b518 [PACKAGING] Bumped dev version to 1.1.2 2021-10-29 01:24:19 -07:00
Philippe Tillet
1b842f8e5e [CI] Now running integration tests on pull requests on branch v2.0 2021-10-29 01:11:12 -07:00
Philippe Tillet
d3e584d4ba Revert "[DRIVER] Fixed CUDA 10.1 bug (#357)" (#358)
This reverts commit d35014ba47.
2021-10-26 15:04:49 -07:00
Philippe Tillet
d35014ba47 [DRIVER] Fixed CUDA 10.1 bug (#357) 2021-10-26 11:17:06 -07:00
Philippe Tillet
5ce1b726dc [CODEGEN] Various bugfixes that make it possible to fuse RNG in a matmul epilogue (#356) 2021-10-24 02:30:46 -07:00
daadaada
858dec8372 [CODEGEN] Add cache modifier to tl.load (#351)
* Add cache modifier to tl.load
* Add comment to cache_modifier
* Remove force_nc_cache
* Update test
2021-10-17 22:14:04 -07:00
Philippe Tillet
90ded16c32 [DOCS] Added placeholder docstring for layernorm tutorial 2021-10-15 19:04:01 -07:00
323 changed files with 34646 additions and 56603 deletions

1
.clang-format Normal file
View File

@@ -0,0 +1 @@
BasedOnStyle: LLVM

57
.github/CODEOWNERS vendored Normal file
View File

@@ -0,0 +1,57 @@
# These owners will be the default owners for everything in
# the repo. Unless a later match takes precedence,
# @global-owner1 and @global-owner2 will be requested for
# review when someone opens a pull request.
* @ptillet
# --------
# Analyses
# --------
# Alias analysis
include/triton/Analysis/Alias.h @Jokeren
lib/Analysis/Alias.cpp @Jokeren
# Allocation analysis
include/triton/Analysis/Allocation.h @Jokeren
lib/Analysis/Allocation.cpp @Jokeren
# Membar analysis
include/triton/Analysis/Membar.h @Jokeren
lib/Analysis/Membar.cpp @Jokeren
# AxisInfo analysis
include/triton/Analysis/AxisInfo.h @ptillet
lib/Analysis/AxisInfo.cpp @ptillet
# Utilities
include/triton/Analysis/Utility.h @Jokeren
lib/Analysis/Utility.cpp @Jokeren
# ----------
# Dialects
# ----------
# Pipeline pass
lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @daadaada
# Prefetch pass
lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @daadaada
# Coalesce pass
lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @ptillet
# Layout simplification pass
lib/Dialect/TritonGPU/Transforms/Combine.cpp @ptillet
# -----------
# Conversions
# -----------
# TritonGPUToLLVM
include/triton/Conversion/TritonGPUToLLVM/ @goostavz @Superjomn
lib/Conversions/TritonGPUToLLVM @goostavz @Superjomn
# TritonToTritonGPU
include/triton/Conversion/TritonToTritonGPU/ @daadaada
lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @daadaada
# -------
# Targets
# -------
# LLVMIR
include/triton/Target/LLVMIR/ @goostavz @Superjomn
lib/Target/LLVMIR @goostavz @Superjomn
# PTX
include/triton/Target/PTX/ @goostavz @Superjomn
lib/Target/PTX @goostavz @Superjomn

View File

@@ -1,42 +0,0 @@
name: Documentation
on:
workflow_dispatch:
schedule:
- cron: "0 0 * * *"
jobs:
Build-Documentation:
runs-on: self-hosted
steps:
- name: Checkout gh-pages
uses: actions/checkout@v1
with:
ref: 'gh-pages'
- name: Checkout branch
uses: actions/checkout@v1
- name: Install Triton
run: |
alias python='python3'
cd python
pip3 install -e .
- name: Build docs
run: |
cd docs
make html
- name: Publish docs
run: |
git checkout gh-pages
sh ./update-website.sh
eval `ssh-agent -s`
DISPLAY=:0 SSH_ASKPASS=~/.ssh/give_pass.sh ssh-add ${{ secrets.SSH_KEY }} <<< ${{ secrets.SSH_PASS }}
git remote set-url origin git@github.com:openai/triton.git
git push

View File

@@ -5,42 +5,88 @@ on:
pull_request: pull_request:
branches: branches:
- master - master
- triton-mlir
jobs: jobs:
Runner-Preparation:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- name: Prepare runner matrix
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix::[["self-hosted", "A10"], ["self-hosted", "V100"], "macos-10.15"]'
else
echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
fi
Integration-Tests: Integration-Tests:
needs: Runner-Preparation
runs-on: self-hosted
runs-on: ${{ matrix.runner }}
strategy:
matrix:
runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix)}}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v2 uses: actions/checkout@v2
- name: Clear cache - name: Clear cache
run: | run: |
rm -r /tmp/triton/ rm -rf ~/.triton/cache/
continue-on-error: true
- name: Check imports
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install isort
isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )
- name: Check python style
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install autopep8
autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )
- name: Check cpp style
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install clang-format
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)
- name: Flake8
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install flake8
flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )
- name: Install Triton - name: Install Triton
run: | run: |
alias python='python3'
cd python cd python
pip3 install -e . TRITON_USE_ASSERT_ENABLED_LLVM=TRUE pip3 install -e '.[tests]'
- name: Unit tests - name: Run lit tests
run: | run: |
cd python/test/unit cd python
pytest -vs . LIT_TEST_DIR="build/$(ls build)/test"
if [ ! -d "$LIT_TEST_DIR" ]; then
echo "Not found `$LIT_TEST_DIR`. Did you change an installation method?" ; exit -1
fi
lit -v "$LIT_TEST_DIR"
- name: Regression tests - name: Run python tests
if: ${{matrix.runner[0] == 'self-hosted'}}
run: | run: |
cd python/test/regression cd python/test/unit/
sudo nvidia-smi -i 0 -pm 1 pytest
sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350
sudo nvidia-smi -i 0 --lock-memory-clocks=877,877
pytest -vs .
sudo nvidia-smi -i 0 -rgc
sudo nvidia-smi -i 0 -rmc
- name: Run CXX unittests
run: |
cd python/
cd "build/$(ls build)"
ctest

View File

@@ -8,7 +8,7 @@ jobs:
Build-Wheels: Build-Wheels:
runs-on: self-hosted runs-on: [self-hosted, V100]
steps: steps:
@@ -18,7 +18,7 @@ jobs:
- name: Patch setup.py - name: Patch setup.py
run: | run: |
#sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py #sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
export LATEST_DATE=$(git show -s --format=%ci `git rev-parse HEAD` | cut -d ' ' -f 1 | sed 's/-//g') export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d' --format="%cd")
sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py
echo "" >> python/setup.cfg echo "" >> python/setup.cfg
echo "[build_ext]" >> python/setup.cfg echo "[build_ext]" >> python/setup.cfg

20
.gitignore vendored
View File

@@ -1,6 +1,20 @@
# Triton builds
build/
# Triton Python module builds
python/build/
python/triton.egg-info/
python/triton/_C/libtriton.pyd
python/triton/_C/libtriton.so
# Python caches
__pycache__ __pycache__
.pytest_cache .pytest_cache
python/build/ # VS Code project files
python/triton.egg-info/ .vscode
python/triton/_C/libtriton.so .vs
# JetBrains project files
.idea
cmake-build-*

3
.gitmodules vendored Normal file
View File

@@ -0,0 +1,3 @@
[submodule "deps/dlfcn-win32"]
path = deps/dlfcn-win32
url = https://github.com/dlfcn-win32/dlfcn-win32.git

4
.isort.cfg Normal file
View File

@@ -0,0 +1,4 @@
[settings]
known_local_folder=triton
line_length=88
py_version=36

View File

@@ -1,18 +1,27 @@
cmake_minimum_required(VERSION 3.6) cmake_minimum_required(VERSION 3.6)
include(ExternalProject) include(ExternalProject)
if(NOT TRITON_LLVM_BUILD_DIR) set(CMAKE_CXX_STANDARD 17)
set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
endif()
set(CMAKE_INCLUDE_CURRENT_DIR ON)
project(triton) project(triton)
include(CTest) include(CTest)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") if(NOT WIN32)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
endif()
# Options # Options
option(BUILD_TUTORIALS "Build C++ Triton tutorials" ON) option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
# Ensure Python3 vars are set correctly
# used conditionally in this file and by lit tests
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
# Customized release build type with assertions: TritonRelBuildWithAsserts
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
# Default build type # Default build type
if(NOT CMAKE_BUILD_TYPE) if(NOT CMAKE_BUILD_TYPE)
@@ -20,100 +29,207 @@ if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release") set(CMAKE_BUILD_TYPE "Release")
endif() endif()
find_library(TERMINFO_LIBRARY tinfo) if(NOT WIN32)
find_library(TERMINFO_LIBRARY tinfo)
endif()
# Compiler flags # Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
# Third-party
include_directories(${PYBIND11_INCLUDE_DIR})
if(WIN32)
SET(BUILD_SHARED_LIBS OFF)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/deps/dlfcn-win32/src)
add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")
if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET 11.6)
endif()
########## ##########
# LLVM # LLVM
########## ##########
if("${LLVM_LIBRARY_DIR}" STREQUAL "") if (NOT MLIR_DIR)
find_package(LLVM 11 REQUIRED COMPONENTS "nvptx;amdgpu") if(NOT LLVM_LIBRARY_DIR)
if(WIN32)
find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu)
include_directories(${LLVM_INCLUDE_DIRS})
separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS})
add_definitions(${LLVM_DEFINITIONS_LIST})
llvm_map_components_to_libnames(LLVM_LIBRARIES support core
NVPTXInfo nvptxcodegen
AMDGPUInfo AMDGPUcodegen
)
else()
find_package(LLVM 11 REQUIRED COMPONENTS "nvptx;amdgpu")
endif()
message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
if(APPLE) if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14") set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14")
endif() endif()
# sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros # sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
else() else()
set(LLVM_LDFLAGS "-L${LLVM_LIBRARY_DIR}") set(LLVM_LDFLAGS "-L${LLVM_LIBRARY_DIR}")
set(LLVM_LIBRARIES set(LLVM_LIBRARIES
libLLVMNVPTXCodeGen.a libLLVMNVPTXCodeGen.a
libLLVMNVPTXDesc.a libLLVMNVPTXDesc.a
libLLVMNVPTXInfo.a libLLVMNVPTXInfo.a
libLLVMAMDGPUDisassembler.a libLLVMAMDGPUDisassembler.a
libLLVMMCDisassembler.a libLLVMMCDisassembler.a
libLLVMAMDGPUCodeGen.a libLLVMAMDGPUCodeGen.a
libLLVMMIRParser.a libLLVMMIRParser.a
libLLVMGlobalISel.a libLLVMGlobalISel.a
libLLVMSelectionDAG.a libLLVMSelectionDAG.a
libLLVMipo.a libLLVMipo.a
libLLVMInstrumentation.a libLLVMInstrumentation.a
libLLVMVectorize.a libLLVMVectorize.a
libLLVMLinker.a libLLVMLinker.a
libLLVMIRReader.a libLLVMIRReader.a
libLLVMAsmParser.a libLLVMAsmParser.a
libLLVMFrontendOpenMP.a libLLVMFrontendOpenMP.a
libLLVMAsmPrinter.a libLLVMAsmPrinter.a
libLLVMDebugInfoDWARF.a libLLVMDebugInfoDWARF.a
libLLVMCodeGen.a libLLVMCodeGen.a
libLLVMTarget.a libLLVMTarget.a
libLLVMScalarOpts.a libLLVMScalarOpts.a
libLLVMInstCombine.a libLLVMInstCombine.a
libLLVMAggressiveInstCombine.a libLLVMAggressiveInstCombine.a
libLLVMTransformUtils.a libLLVMTransformUtils.a
libLLVMBitWriter.a libLLVMBitWriter.a
libLLVMAnalysis.a libLLVMAnalysis.a
libLLVMProfileData.a libLLVMProfileData.a
libLLVMObject.a libLLVMObject.a
libLLVMTextAPI.a libLLVMTextAPI.a
libLLVMBitReader.a libLLVMBitReader.a
libLLVMAMDGPUAsmParser.a libLLVMAMDGPUAsmParser.a
libLLVMMCParser.a libLLVMMCParser.a
libLLVMAMDGPUDesc.a libLLVMAMDGPUDesc.a
libLLVMAMDGPUUtils.a libLLVMAMDGPUUtils.a
libLLVMMC.a libLLVMMC.a
libLLVMDebugInfoCodeView.a libLLVMDebugInfoCodeView.a
libLLVMDebugInfoMSF.a libLLVMDebugInfoMSF.a
libLLVMCore.a libLLVMCore.a
libLLVMRemarks.a libLLVMRemarks.a
libLLVMBitstreamReader.a libLLVMBitstreamReader.a
libLLVMBinaryFormat.a libLLVMBinaryFormat.a
libLLVMAMDGPUInfo.a libLLVMAMDGPUInfo.a
libLLVMSupport.a libLLVMSupport.a
libLLVMDemangle.a libLLVMDemangle.a
) libLLVMPasses.a
libLLVMAnalysis.a
libLLVMTransformUtils.a
libLLVMScalarOpts.a
libLLVMTransformUtils.a
libLLVMipo.a
libLLVMObjCARCOpts.a
libLLVMCoroutines.a
libLLVMAnalysis.a
)
endif()
set (MLIR_DIR ${LLVM_LIBRARY_DIR}/cmake/mlir)
endif() endif()
include_directories("${LLVM_INCLUDE_DIRS}")
# Python module # Python module
if(BUILD_PYTHON_MODULE) if(TRITON_BUILD_PYTHON_MODULE)
message(STATUS "Adding Python module") message(STATUS "Adding Python module")
# Build CUTLASS python wrapper if requested
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}") set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc)
set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}") include_directories("." ${PYTHON_SRC_PATH})
if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL "")) if (PYTHON_INCLUDE_DIRS)
set(CUTLASS_SRC ${PYTHON_SRC_PATH}/cutlass.cc) include_directories(${PYTHON_INCLUDE_DIRS})
add_definitions(-DWITH_CUTLASS_BINDINGS) else()
set(CUTLASS_LIBRARIES "cutlass.a") include_directories(${Python3_INCLUDE_DIRS})
link_directories(${Python3_LIBRARY_DIRS})
link_libraries(${Python3_LIBRARIES})
add_link_options(${Python3_LINK_OPTIONS})
endif() endif()
include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC})
endif() endif()
# Triton # # Triton
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc) # file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC}) # if (WIN32 AND TRITON_BUILD_PYTHON_MODULE)
# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# set_target_properties(triton PROPERTIES SUFFIX ".pyd")
# set_target_properties(triton PROPERTIES PREFIX "lib")
# else()
# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# endif()
# MLIR
find_package(MLIR REQUIRED CONFIG PATHS ${MLIR_DIR})
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
include(TableGen) # required by AddMLIR
include(AddLLVM)
include(AddMLIR)
# Disable warnings that show up in external code (gtest;pybind11)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default")
include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
# link_directories(${LLVM_LIBRARY_DIR})
add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(bin)
add_library(triton SHARED ${PYTHON_SRC})
# find_package(PythonLibs REQUIRED)
set(TRITON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
target_link_libraries(triton
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonLLVMIR
TritonPTX
${dialect_libs}
${conversion_libs}
# optimizations
MLIRPass
MLIRTransforms
MLIRLLVMIR
MLIRSupport
MLIRTargetLLVMIRExport
MLIRExecutionEngine
MLIRMathToLLVM
MLIRNVVMToLLVMIRTranslation
MLIRIR
)
target_link_options(triton PRIVATE ${LLVM_LDFLAGS}) target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
target_link_libraries(triton ${LLVM_LIBRARIES} z ${TERMINFO_LIBRARY})
if(WIN32)
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
elseif(APPLE)
target_link_libraries(triton ${LLVM_LIBRARIES} z)
else()
target_link_libraries(triton ${LLVM_LIBRARIES} z stdc++fs)
endif()
if(BUILD_PYTHON_MODULE) if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
# Check if the platform is MacOS # Check if the platform is MacOS
if(APPLE) if(APPLE)
@@ -121,3 +237,7 @@ if(BUILD_PYTHON_MODULE)
endif() endif()
target_link_libraries(triton ${CUTLASS_LIBRARIES} ${PYTHON_LDFLAGS}) target_link_libraries(triton ${CUTLASS_LIBRARIES} ${PYTHON_LDFLAGS})
endif() endif()
add_subdirectory(test)
add_subdirectory(unittest)

View File

@@ -1,6 +1,6 @@
/* /*
* Copyright 2018-2020 Philippe Tillet * Copyright 2018-2020 Philippe Tillet
* Copyright 2020-2021 OpenAI * Copyright 2020-2022 OpenAI
* *
* Permission is hereby granted, free of charge, to any person obtaining * Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files * a copy of this software and associated documentation files
@@ -20,4 +20,4 @@
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/ */

View File

@@ -18,6 +18,30 @@ The foundations of this project are described in the following MAPL2019 publicat
The [official documentation](https://triton-lang.org) contains installation instructions and tutorials. The [official documentation](https://triton-lang.org) contains installation instructions and tutorials.
# Quick Installation
You can install the latest stable release of Triton from pip:
```bash
pip install triton
```
Binary wheels are available for CPython 3.6-3.9 and PyPy 3.6-3.7.
And the latest nightly release:
```bash
pip install -U --pre triton
```
# Install from source
```
git clone https://github.com/openai/triton.git;
cd triton/python;
pip install cmake; # build time dependency
pip install -e .
```
# Changelog # Changelog
Version 1.1 is out! New features include: Version 1.1 is out! New features include:

60
bin/CMakeLists.txt Normal file
View File

@@ -0,0 +1,60 @@
add_subdirectory(FileCheck)
# add_llvm_executable(FileCheck FileCheck/FileCheck.cpp)
# target_link_libraries(FileCheck PRIVATE LLVMFileCheck LLVMSupport)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED)
# TODO: what's this?
llvm_update_compile_flags(triton-opt)
target_link_libraries(triton-opt PRIVATE
TritonAnalysis
TritonTransforms
TritonGPUTransforms
${dialect_libs}
${conversion_libs}
# tests
TritonTestAnalysis
# MLIR core
MLIROptLib
MLIRPass
MLIRTransforms
)
mlir_check_all_link_libraries(triton-opt)
# add_llvm_executable(triton-translate triton-translate.cpp PARTIAL_SOURCES_INTENDED)
#llvm_update_compile_flags(triton-translate)
# target_link_libraries(triton-translate PRIVATE
# TritonAnalysis
# TritonTransforms
# TritonGPUTransforms
# TritonLLVMIR
# TritonDriver
# ${dialect_libs}
# ${conversion_libs}
# # tests
# TritonTestAnalysis
# LLVMCore
# LLVMSupport
# LLVMOption
# LLVMCodeGen
# LLVMAsmParser
# # MLIR core
# MLIROptLib
# MLIRIR
# MLIRPass
# MLIRSupport
# MLIRTransforms
# MLIRExecutionEngine
# MLIRMathToLLVM
# MLIRTransformUtils
# MLIRLLVMToLLVMIRTranslation
# MLIRNVVMToLLVMIRTranslation
# )
# mlir_check_all_link_libraries(triton-translate)

View File

@@ -0,0 +1,2 @@
add_llvm_executable(FileCheck FileCheck.cpp)
target_link_libraries(FileCheck PRIVATE LLVMFileCheck LLVMSupport)

882
bin/FileCheck/FileCheck.cpp Normal file
View File

@@ -0,0 +1,882 @@
//===- FileCheck.cpp - Check that File's Contents match what is expected --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// FileCheck does a line-by line check of a file that validates whether it
// contains the expected content. This is useful for regression tests etc.
//
// This program exits with an exit status of 2 on error, exit status of 0 if
// the file matched the expected contents, and exit status of 1 if it did not
// contain the expected contents.
//
//===----------------------------------------------------------------------===//
#include "llvm/FileCheck/FileCheck.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Process.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Support/raw_ostream.h"
#include <cmath>
#include <map>
using namespace llvm;
static cl::extrahelp FileCheckOptsEnv(
"\nOptions are parsed from the environment variable FILECHECK_OPTS and\n"
"from the command line.\n");
static cl::opt<std::string>
CheckFilename(cl::Positional, cl::desc("<check-file>"), cl::Optional);
static cl::opt<std::string>
InputFilename("input-file", cl::desc("File to check (defaults to stdin)"),
cl::init("-"), cl::value_desc("filename"));
static cl::list<std::string> CheckPrefixes(
"check-prefix",
cl::desc("Prefix to use from check file (defaults to 'CHECK')"));
static cl::alias CheckPrefixesAlias(
"check-prefixes", cl::aliasopt(CheckPrefixes), cl::CommaSeparated,
cl::NotHidden,
cl::desc(
"Alias for -check-prefix permitting multiple comma separated values"));
static cl::list<std::string> CommentPrefixes(
"comment-prefixes", cl::CommaSeparated, cl::Hidden,
cl::desc("Comma-separated list of comment prefixes to use from check file\n"
"(defaults to 'COM,RUN'). Please avoid using this feature in\n"
"LLVM's LIT-based test suites, which should be easier to\n"
"maintain if they all follow a consistent comment style. This\n"
"feature is meant for non-LIT test suites using FileCheck."));
static cl::opt<bool> NoCanonicalizeWhiteSpace(
"strict-whitespace",
cl::desc("Do not treat all horizontal whitespace as equivalent"));
static cl::opt<bool> IgnoreCase("ignore-case",
cl::desc("Use case-insensitive matching"));
static cl::list<std::string> ImplicitCheckNot(
"implicit-check-not",
cl::desc("Add an implicit negative check with this pattern to every\n"
"positive check. This can be used to ensure that no instances of\n"
"this pattern occur which are not matched by a positive pattern"),
cl::value_desc("pattern"));
static cl::list<std::string>
GlobalDefines("D", cl::AlwaysPrefix,
cl::desc("Define a variable to be used in capture patterns."),
cl::value_desc("VAR=VALUE"));
static cl::opt<bool> AllowEmptyInput(
"allow-empty", cl::init(false),
cl::desc("Allow the input file to be empty. This is useful when making\n"
"checks that some error message does not occur, for example."));
static cl::opt<bool> AllowUnusedPrefixes(
"allow-unused-prefixes", cl::init(false), cl::ZeroOrMore,
cl::desc("Allow prefixes to be specified but not appear in the test."));
static cl::opt<bool> MatchFullLines(
"match-full-lines", cl::init(false),
cl::desc("Require all positive matches to cover an entire input line.\n"
"Allows leading and trailing whitespace if --strict-whitespace\n"
"is not also passed."));
static cl::opt<bool> EnableVarScope(
"enable-var-scope", cl::init(false),
cl::desc("Enables scope for regex variables. Variables with names that\n"
"do not start with '$' will be reset at the beginning of\n"
"each CHECK-LABEL block."));
static cl::opt<bool> AllowDeprecatedDagOverlap(
"allow-deprecated-dag-overlap", cl::init(false),
cl::desc("Enable overlapping among matches in a group of consecutive\n"
"CHECK-DAG directives. This option is deprecated and is only\n"
"provided for convenience as old tests are migrated to the new\n"
"non-overlapping CHECK-DAG implementation.\n"));
static cl::opt<bool> Verbose(
"v", cl::init(false), cl::ZeroOrMore,
cl::desc("Print directive pattern matches, or add them to the input dump\n"
"if enabled.\n"));
static cl::opt<bool> VerboseVerbose(
"vv", cl::init(false), cl::ZeroOrMore,
cl::desc("Print information helpful in diagnosing internal FileCheck\n"
"issues, or add it to the input dump if enabled. Implies\n"
"-v.\n"));
// The order of DumpInputValue members affects their precedence, as documented
// for -dump-input below.
enum DumpInputValue {
DumpInputNever,
DumpInputFail,
DumpInputAlways,
DumpInputHelp
};
static cl::list<DumpInputValue> DumpInputs(
"dump-input",
cl::desc("Dump input to stderr, adding annotations representing\n"
"currently enabled diagnostics. When there are multiple\n"
"occurrences of this option, the <value> that appears earliest\n"
"in the list below has precedence. The default is 'fail'.\n"),
cl::value_desc("mode"),
cl::values(clEnumValN(DumpInputHelp, "help", "Explain input dump and quit"),
clEnumValN(DumpInputAlways, "always", "Always dump input"),
clEnumValN(DumpInputFail, "fail", "Dump input on failure"),
clEnumValN(DumpInputNever, "never", "Never dump input")));
// The order of DumpInputFilterValue members affects their precedence, as
// documented for -dump-input-filter below.
enum DumpInputFilterValue {
DumpInputFilterError,
DumpInputFilterAnnotation,
DumpInputFilterAnnotationFull,
DumpInputFilterAll
};
static cl::list<DumpInputFilterValue> DumpInputFilters(
"dump-input-filter",
cl::desc("In the dump requested by -dump-input, print only input lines of\n"
"kind <value> plus any context specified by -dump-input-context.\n"
"When there are multiple occurrences of this option, the <value>\n"
"that appears earliest in the list below has precedence. The\n"
"default is 'error' when -dump-input=fail, and it's 'all' when\n"
"-dump-input=always.\n"),
cl::values(clEnumValN(DumpInputFilterAll, "all", "All input lines"),
clEnumValN(DumpInputFilterAnnotationFull, "annotation-full",
"Input lines with annotations"),
clEnumValN(DumpInputFilterAnnotation, "annotation",
"Input lines with starting points of annotations"),
clEnumValN(DumpInputFilterError, "error",
"Input lines with starting points of error "
"annotations")));
static cl::list<unsigned> DumpInputContexts(
"dump-input-context", cl::value_desc("N"),
cl::desc("In the dump requested by -dump-input, print <N> input lines\n"
"before and <N> input lines after any lines specified by\n"
"-dump-input-filter. When there are multiple occurrences of\n"
"this option, the largest specified <N> has precedence. The\n"
"default is 5.\n"));
typedef cl::list<std::string>::const_iterator prefix_iterator;
static void DumpCommandLine(int argc, char **argv) {
errs() << "FileCheck command line: ";
for (int I = 0; I < argc; I++)
errs() << " " << argv[I];
errs() << "\n";
}
struct MarkerStyle {
/// The starting char (before tildes) for marking the line.
char Lead;
/// What color to use for this annotation.
raw_ostream::Colors Color;
/// A note to follow the marker, or empty string if none.
std::string Note;
/// Does this marker indicate inclusion by -dump-input-filter=error?
bool FiltersAsError;
MarkerStyle() {}
MarkerStyle(char Lead, raw_ostream::Colors Color,
const std::string &Note = "", bool FiltersAsError = false)
: Lead(Lead), Color(Color), Note(Note), FiltersAsError(FiltersAsError) {
assert((!FiltersAsError || !Note.empty()) &&
"expected error diagnostic to have note");
}
};
static MarkerStyle GetMarker(FileCheckDiag::MatchType MatchTy) {
switch (MatchTy) {
case FileCheckDiag::MatchFoundAndExpected:
return MarkerStyle('^', raw_ostream::GREEN);
case FileCheckDiag::MatchFoundButExcluded:
return MarkerStyle('!', raw_ostream::RED, "error: no match expected",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchFoundButWrongLine:
return MarkerStyle('!', raw_ostream::RED, "error: match on wrong line",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchFoundButDiscarded:
return MarkerStyle('!', raw_ostream::CYAN,
"discard: overlaps earlier match");
case FileCheckDiag::MatchFoundErrorNote:
// Note should always be overridden within the FileCheckDiag.
return MarkerStyle('!', raw_ostream::RED,
"error: unknown error after match",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchNoneAndExcluded:
return MarkerStyle('X', raw_ostream::GREEN);
case FileCheckDiag::MatchNoneButExpected:
return MarkerStyle('X', raw_ostream::RED, "error: no match found",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchNoneForInvalidPattern:
return MarkerStyle('X', raw_ostream::RED,
"error: match failed for invalid pattern",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchFuzzy:
return MarkerStyle('?', raw_ostream::MAGENTA, "possible intended match",
/*FiltersAsError=*/true);
}
llvm_unreachable_internal("unexpected match type");
}
static void DumpInputAnnotationHelp(raw_ostream &OS) {
OS << "The following description was requested by -dump-input=help to\n"
<< "explain the input dump printed by FileCheck.\n"
<< "\n"
<< "Related command-line options:\n"
<< "\n"
<< " - -dump-input=<value> enables or disables the input dump\n"
<< " - -dump-input-filter=<value> filters the input lines\n"
<< " - -dump-input-context=<N> adjusts the context of filtered lines\n"
<< " - -v and -vv add more annotations\n"
<< " - -color forces colors to be enabled both in the dump and below\n"
<< " - -help documents the above options in more detail\n"
<< "\n"
<< "These options can also be set via FILECHECK_OPTS. For example, for\n"
<< "maximum debugging output on failures:\n"
<< "\n"
<< " $ FILECHECK_OPTS='-dump-input-filter=all -vv -color' ninja check\n"
<< "\n"
<< "Input dump annotation format:\n"
<< "\n";
// Labels for input lines.
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "L:";
OS << " labels line number L of the input file\n"
<< " An extra space is added after each input line to represent"
<< " the\n"
<< " newline character\n";
// Labels for annotation lines.
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "T:L";
OS << " labels the only match result for either (1) a pattern of type T"
<< " from\n"
<< " line L of the check file if L is an integer or (2) the"
<< " I-th implicit\n"
<< " pattern if L is \"imp\" followed by an integer "
<< "I (index origin one)\n";
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "T:L'N";
OS << " labels the Nth match result for such a pattern\n";
// Markers on annotation lines.
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "^~~";
OS << " marks good match (reported if -v)\n"
<< " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "!~~";
OS << " marks bad match, such as:\n"
<< " - CHECK-NEXT on same line as previous match (error)\n"
<< " - CHECK-NOT found (error)\n"
<< " - CHECK-DAG overlapping match (discarded, reported if "
<< "-vv)\n"
<< " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "X~~";
OS << " marks search range when no match is found, such as:\n"
<< " - CHECK-NEXT not found (error)\n"
<< " - CHECK-NOT not found (success, reported if -vv)\n"
<< " - CHECK-DAG not found after discarded matches (error)\n"
<< " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "?";
OS << " marks fuzzy match when no match is found\n";
// Elided lines.
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "...";
OS << " indicates elided input lines and annotations, as specified by\n"
<< " -dump-input-filter and -dump-input-context\n";
// Colors.
OS << " - colors ";
WithColor(OS, raw_ostream::GREEN, true) << "success";
OS << ", ";
WithColor(OS, raw_ostream::RED, true) << "error";
OS << ", ";
WithColor(OS, raw_ostream::MAGENTA, true) << "fuzzy match";
OS << ", ";
WithColor(OS, raw_ostream::CYAN, true, false) << "discarded match";
OS << ", ";
WithColor(OS, raw_ostream::CYAN, true, true) << "unmatched input";
OS << "\n";
}
/// An annotation for a single input line.
struct InputAnnotation {
/// The index of the match result across all checks
unsigned DiagIndex;
/// The label for this annotation.
std::string Label;
/// Is this the initial fragment of a diagnostic that has been broken across
/// multiple lines?
bool IsFirstLine;
/// What input line (one-origin indexing) this annotation marks. This might
/// be different from the starting line of the original diagnostic if
/// !IsFirstLine.
unsigned InputLine;
/// The column range (one-origin indexing, open end) in which to mark the
/// input line. If InputEndCol is UINT_MAX, treat it as the last column
/// before the newline.
unsigned InputStartCol, InputEndCol;
/// The marker to use.
MarkerStyle Marker;
/// Whether this annotation represents a good match for an expected pattern.
bool FoundAndExpectedMatch;
};
/// Get an abbreviation for the check type.
static std::string GetCheckTypeAbbreviation(Check::FileCheckType Ty) {
switch (Ty) {
case Check::CheckPlain:
if (Ty.getCount() > 1)
return "count";
return "check";
case Check::CheckNext:
return "next";
case Check::CheckSame:
return "same";
case Check::CheckNot:
return "not";
case Check::CheckDAG:
return "dag";
case Check::CheckLabel:
return "label";
case Check::CheckEmpty:
return "empty";
case Check::CheckComment:
return "com";
case Check::CheckEOF:
return "eof";
case Check::CheckBadNot:
return "bad-not";
case Check::CheckBadCount:
return "bad-count";
case Check::CheckNone:
llvm_unreachable("invalid FileCheckType");
}
llvm_unreachable("unknown FileCheckType");
}
static void
BuildInputAnnotations(const SourceMgr &SM, unsigned CheckFileBufferID,
const std::pair<unsigned, unsigned> &ImpPatBufferIDRange,
const std::vector<FileCheckDiag> &Diags,
std::vector<InputAnnotation> &Annotations,
unsigned &LabelWidth) {
struct CompareSMLoc {
bool operator()(const SMLoc &LHS, const SMLoc &RHS) const {
return LHS.getPointer() < RHS.getPointer();
}
};
// How many diagnostics does each pattern have?
std::map<SMLoc, unsigned, CompareSMLoc> DiagCountPerPattern;
for (auto Diag : Diags)
++DiagCountPerPattern[Diag.CheckLoc];
// How many diagnostics have we seen so far per pattern?
std::map<SMLoc, unsigned, CompareSMLoc> DiagIndexPerPattern;
// How many total diagnostics have we seen so far?
unsigned DiagIndex = 0;
// What's the widest label?
LabelWidth = 0;
for (auto DiagItr = Diags.begin(), DiagEnd = Diags.end(); DiagItr != DiagEnd;
++DiagItr) {
InputAnnotation A;
A.DiagIndex = DiagIndex++;
// Build label, which uniquely identifies this check result.
unsigned CheckBufferID = SM.FindBufferContainingLoc(DiagItr->CheckLoc);
auto CheckLineAndCol =
SM.getLineAndColumn(DiagItr->CheckLoc, CheckBufferID);
llvm::raw_string_ostream Label(A.Label);
Label << GetCheckTypeAbbreviation(DiagItr->CheckTy) << ":";
if (CheckBufferID == CheckFileBufferID)
Label << CheckLineAndCol.first;
else if (ImpPatBufferIDRange.first <= CheckBufferID &&
CheckBufferID < ImpPatBufferIDRange.second)
Label << "imp" << (CheckBufferID - ImpPatBufferIDRange.first + 1);
else
llvm_unreachable("expected diagnostic's check location to be either in "
"the check file or for an implicit pattern");
if (DiagCountPerPattern[DiagItr->CheckLoc] > 1)
Label << "'" << DiagIndexPerPattern[DiagItr->CheckLoc]++;
LabelWidth = std::max((std::string::size_type)LabelWidth, A.Label.size());
A.Marker = GetMarker(DiagItr->MatchTy);
if (!DiagItr->Note.empty()) {
A.Marker.Note = DiagItr->Note;
// It's less confusing if notes that don't actually have ranges don't have
// markers. For example, a marker for 'with "VAR" equal to "5"' would
// seem to indicate where "VAR" matches, but the location we actually have
// for the marker simply points to the start of the match/search range for
// the full pattern of which the substitution is potentially just one
// component.
if (DiagItr->InputStartLine == DiagItr->InputEndLine &&
DiagItr->InputStartCol == DiagItr->InputEndCol)
A.Marker.Lead = ' ';
}
if (DiagItr->MatchTy == FileCheckDiag::MatchFoundErrorNote) {
assert(!DiagItr->Note.empty() &&
"expected custom note for MatchFoundErrorNote");
A.Marker.Note = "error: " + A.Marker.Note;
}
A.FoundAndExpectedMatch =
DiagItr->MatchTy == FileCheckDiag::MatchFoundAndExpected;
// Compute the mark location, and break annotation into multiple
// annotations if it spans multiple lines.
A.IsFirstLine = true;
A.InputLine = DiagItr->InputStartLine;
A.InputStartCol = DiagItr->InputStartCol;
if (DiagItr->InputStartLine == DiagItr->InputEndLine) {
// Sometimes ranges are empty in order to indicate a specific point, but
// that would mean nothing would be marked, so adjust the range to
// include the following character.
A.InputEndCol =
std::max(DiagItr->InputStartCol + 1, DiagItr->InputEndCol);
Annotations.push_back(A);
} else {
assert(DiagItr->InputStartLine < DiagItr->InputEndLine &&
"expected input range not to be inverted");
A.InputEndCol = UINT_MAX;
Annotations.push_back(A);
for (unsigned L = DiagItr->InputStartLine + 1, E = DiagItr->InputEndLine;
L <= E; ++L) {
// If a range ends before the first column on a line, then it has no
// characters on that line, so there's nothing to render.
if (DiagItr->InputEndCol == 1 && L == E)
break;
InputAnnotation B;
B.DiagIndex = A.DiagIndex;
B.Label = A.Label;
B.IsFirstLine = false;
B.InputLine = L;
B.Marker = A.Marker;
B.Marker.Lead = '~';
B.Marker.Note = "";
B.InputStartCol = 1;
if (L != E)
B.InputEndCol = UINT_MAX;
else
B.InputEndCol = DiagItr->InputEndCol;
B.FoundAndExpectedMatch = A.FoundAndExpectedMatch;
Annotations.push_back(B);
}
}
}
}
static unsigned FindInputLineInFilter(
DumpInputFilterValue DumpInputFilter, unsigned CurInputLine,
const std::vector<InputAnnotation>::iterator &AnnotationBeg,
const std::vector<InputAnnotation>::iterator &AnnotationEnd) {
if (DumpInputFilter == DumpInputFilterAll)
return CurInputLine;
for (auto AnnotationItr = AnnotationBeg; AnnotationItr != AnnotationEnd;
++AnnotationItr) {
switch (DumpInputFilter) {
case DumpInputFilterAll:
llvm_unreachable("unexpected DumpInputFilterAll");
break;
case DumpInputFilterAnnotationFull:
return AnnotationItr->InputLine;
case DumpInputFilterAnnotation:
if (AnnotationItr->IsFirstLine)
return AnnotationItr->InputLine;
break;
case DumpInputFilterError:
if (AnnotationItr->IsFirstLine && AnnotationItr->Marker.FiltersAsError)
return AnnotationItr->InputLine;
break;
}
}
return UINT_MAX;
}
/// To OS, print a vertical ellipsis (right-justified at LabelWidth) if it would
/// occupy less lines than ElidedLines, but print ElidedLines otherwise. Either
/// way, clear ElidedLines. Thus, if ElidedLines is empty, do nothing.
static void DumpEllipsisOrElidedLines(raw_ostream &OS, std::string &ElidedLines,
unsigned LabelWidth) {
if (ElidedLines.empty())
return;
unsigned EllipsisLines = 3;
if (EllipsisLines < StringRef(ElidedLines).count('\n')) {
for (unsigned i = 0; i < EllipsisLines; ++i) {
WithColor(OS, raw_ostream::BLACK, /*Bold=*/true)
<< right_justify(".", LabelWidth);
OS << '\n';
}
} else
OS << ElidedLines;
ElidedLines.clear();
}
static void DumpAnnotatedInput(raw_ostream &OS, const FileCheckRequest &Req,
DumpInputFilterValue DumpInputFilter,
unsigned DumpInputContext,
StringRef InputFileText,
std::vector<InputAnnotation> &Annotations,
unsigned LabelWidth) {
OS << "Input was:\n<<<<<<\n";
// Sort annotations.
llvm::sort(Annotations,
[](const InputAnnotation &A, const InputAnnotation &B) {
// 1. Sort annotations in the order of the input lines.
//
// This makes it easier to find relevant annotations while
// iterating input lines in the implementation below. FileCheck
// does not always produce diagnostics in the order of input
// lines due to, for example, CHECK-DAG and CHECK-NOT.
if (A.InputLine != B.InputLine)
return A.InputLine < B.InputLine;
// 2. Sort annotations in the temporal order FileCheck produced
// their associated diagnostics.
//
// This sort offers several benefits:
//
// A. On a single input line, the order of annotations reflects
// the FileCheck logic for processing directives/patterns.
// This can be helpful in understanding cases in which the
// order of the associated directives/patterns in the check
// file or on the command line either (i) does not match the
// temporal order in which FileCheck looks for matches for the
// directives/patterns (due to, for example, CHECK-LABEL,
// CHECK-NOT, or `--implicit-check-not`) or (ii) does match
// that order but does not match the order of those
// diagnostics along an input line (due to, for example,
// CHECK-DAG).
//
// On the other hand, because our presentation format presents
// input lines in order, there's no clear way to offer the
// same benefit across input lines. For consistency, it might
// then seem worthwhile to have annotations on a single line
// also sorted in input order (that is, by input column).
// However, in practice, this appears to be more confusing
// than helpful. Perhaps it's intuitive to expect annotations
// to be listed in the temporal order in which they were
// produced except in cases the presentation format obviously
// and inherently cannot support it (that is, across input
// lines).
//
// B. When diagnostics' annotations are split among multiple
// input lines, the user must track them from one input line
// to the next. One property of the sort chosen here is that
// it facilitates the user in this regard by ensuring the
// following: when comparing any two input lines, a
// diagnostic's annotations are sorted in the same position
// relative to all other diagnostics' annotations.
return A.DiagIndex < B.DiagIndex;
});
// Compute the width of the label column.
const unsigned char *InputFilePtr = InputFileText.bytes_begin(),
*InputFileEnd = InputFileText.bytes_end();
unsigned LineCount = InputFileText.count('\n');
if (InputFileEnd[-1] != '\n')
++LineCount;
unsigned LineNoWidth = std::log10(LineCount) + 1;
// +3 below adds spaces (1) to the left of the (right-aligned) line numbers
// on input lines and (2) to the right of the (left-aligned) labels on
// annotation lines so that input lines and annotation lines are more
// visually distinct. For example, the spaces on the annotation lines ensure
// that input line numbers and check directive line numbers never align
// horizontally. Those line numbers might not even be for the same file.
// One space would be enough to achieve that, but more makes it even easier
// to see.
LabelWidth = std::max(LabelWidth, LineNoWidth) + 3;
// Print annotated input lines.
unsigned PrevLineInFilter = 0; // 0 means none so far
unsigned NextLineInFilter = 0; // 0 means uncomputed, UINT_MAX means none
std::string ElidedLines;
raw_string_ostream ElidedLinesOS(ElidedLines);
ColorMode TheColorMode =
WithColor(OS).colorsEnabled() ? ColorMode::Enable : ColorMode::Disable;
if (TheColorMode == ColorMode::Enable)
ElidedLinesOS.enable_colors(true);
auto AnnotationItr = Annotations.begin(), AnnotationEnd = Annotations.end();
for (unsigned Line = 1;
InputFilePtr != InputFileEnd || AnnotationItr != AnnotationEnd; ++Line) {
const unsigned char *InputFileLine = InputFilePtr;
// Compute the previous and next line included by the filter.
if (NextLineInFilter < Line)
NextLineInFilter = FindInputLineInFilter(DumpInputFilter, Line,
AnnotationItr, AnnotationEnd);
assert(NextLineInFilter && "expected NextLineInFilter to be computed");
if (NextLineInFilter == Line)
PrevLineInFilter = Line;
// Elide this input line and its annotations if it's not within the
// context specified by -dump-input-context of an input line included by
// -dump-input-filter. However, in case the resulting ellipsis would occupy
// more lines than the input lines and annotations it elides, buffer the
// elided lines and annotations so we can print them instead.
raw_ostream *LineOS = &OS;
if ((!PrevLineInFilter || PrevLineInFilter + DumpInputContext < Line) &&
(NextLineInFilter == UINT_MAX ||
Line + DumpInputContext < NextLineInFilter))
LineOS = &ElidedLinesOS;
else {
LineOS = &OS;
DumpEllipsisOrElidedLines(OS, ElidedLinesOS.str(), LabelWidth);
}
// Print right-aligned line number.
WithColor(*LineOS, raw_ostream::BLACK, /*Bold=*/true, /*BF=*/false,
TheColorMode)
<< format_decimal(Line, LabelWidth) << ": ";
// For the case where -v and colors are enabled, find the annotations for
// good matches for expected patterns in order to highlight everything
// else in the line. There are no such annotations if -v is disabled.
std::vector<InputAnnotation> FoundAndExpectedMatches;
if (Req.Verbose && TheColorMode == ColorMode::Enable) {
for (auto I = AnnotationItr; I != AnnotationEnd && I->InputLine == Line;
++I) {
if (I->FoundAndExpectedMatch)
FoundAndExpectedMatches.push_back(*I);
}
}
// Print numbered line with highlighting where there are no matches for
// expected patterns.
bool Newline = false;
{
WithColor COS(*LineOS, raw_ostream::SAVEDCOLOR, /*Bold=*/false,
/*BG=*/false, TheColorMode);
bool InMatch = false;
if (Req.Verbose)
COS.changeColor(raw_ostream::CYAN, true, true);
for (unsigned Col = 1; InputFilePtr != InputFileEnd && !Newline; ++Col) {
bool WasInMatch = InMatch;
InMatch = false;
for (auto M : FoundAndExpectedMatches) {
if (M.InputStartCol <= Col && Col < M.InputEndCol) {
InMatch = true;
break;
}
}
if (!WasInMatch && InMatch)
COS.resetColor();
else if (WasInMatch && !InMatch)
COS.changeColor(raw_ostream::CYAN, true, true);
if (*InputFilePtr == '\n') {
Newline = true;
COS << ' ';
} else
COS << *InputFilePtr;
++InputFilePtr;
}
}
*LineOS << '\n';
unsigned InputLineWidth = InputFilePtr - InputFileLine;
// Print any annotations.
while (AnnotationItr != AnnotationEnd && AnnotationItr->InputLine == Line) {
WithColor COS(*LineOS, AnnotationItr->Marker.Color, /*Bold=*/true,
/*BG=*/false, TheColorMode);
// The two spaces below are where the ": " appears on input lines.
COS << left_justify(AnnotationItr->Label, LabelWidth) << " ";
unsigned Col;
for (Col = 1; Col < AnnotationItr->InputStartCol; ++Col)
COS << ' ';
COS << AnnotationItr->Marker.Lead;
// If InputEndCol=UINT_MAX, stop at InputLineWidth.
for (++Col; Col < AnnotationItr->InputEndCol && Col <= InputLineWidth;
++Col)
COS << '~';
const std::string &Note = AnnotationItr->Marker.Note;
if (!Note.empty()) {
// Put the note at the end of the input line. If we were to instead
// put the note right after the marker, subsequent annotations for the
// same input line might appear to mark this note instead of the input
// line.
for (; Col <= InputLineWidth; ++Col)
COS << ' ';
COS << ' ' << Note;
}
COS << '\n';
++AnnotationItr;
}
}
DumpEllipsisOrElidedLines(OS, ElidedLinesOS.str(), LabelWidth);
OS << ">>>>>>\n";
}
int main(int argc, char **argv) {
// Enable use of ANSI color codes because FileCheck is using them to
// highlight text.
llvm::sys::Process::UseANSIEscapeCodes(true);
InitLLVM X(argc, argv);
cl::ParseCommandLineOptions(argc, argv, /*Overview*/ "", /*Errs*/ nullptr,
"FILECHECK_OPTS");
// Select -dump-input* values. The -help documentation specifies the default
// value and which value to choose if an option is specified multiple times.
// In the latter case, the general rule of thumb is to choose the value that
// provides the most information.
DumpInputValue DumpInput =
DumpInputs.empty()
? DumpInputFail
: *std::max_element(DumpInputs.begin(), DumpInputs.end());
DumpInputFilterValue DumpInputFilter;
if (DumpInputFilters.empty())
DumpInputFilter = DumpInput == DumpInputAlways ? DumpInputFilterAll
: DumpInputFilterError;
else
DumpInputFilter =
*std::max_element(DumpInputFilters.begin(), DumpInputFilters.end());
unsigned DumpInputContext = DumpInputContexts.empty()
? 5
: *std::max_element(DumpInputContexts.begin(),
DumpInputContexts.end());
if (DumpInput == DumpInputHelp) {
DumpInputAnnotationHelp(outs());
return 0;
}
if (CheckFilename.empty()) {
errs() << "<check-file> not specified\n";
return 2;
}
FileCheckRequest Req;
append_range(Req.CheckPrefixes, CheckPrefixes);
append_range(Req.CommentPrefixes, CommentPrefixes);
append_range(Req.ImplicitCheckNot, ImplicitCheckNot);
bool GlobalDefineError = false;
for (StringRef G : GlobalDefines) {
size_t EqIdx = G.find('=');
if (EqIdx == std::string::npos) {
errs() << "Missing equal sign in command-line definition '-D" << G
<< "'\n";
GlobalDefineError = true;
continue;
}
if (EqIdx == 0) {
errs() << "Missing variable name in command-line definition '-D" << G
<< "'\n";
GlobalDefineError = true;
continue;
}
Req.GlobalDefines.push_back(G);
}
if (GlobalDefineError)
return 2;
Req.AllowEmptyInput = AllowEmptyInput;
Req.AllowUnusedPrefixes = AllowUnusedPrefixes;
Req.EnableVarScope = EnableVarScope;
Req.AllowDeprecatedDagOverlap = AllowDeprecatedDagOverlap;
Req.Verbose = Verbose;
Req.VerboseVerbose = VerboseVerbose;
Req.NoCanonicalizeWhiteSpace = NoCanonicalizeWhiteSpace;
Req.MatchFullLines = MatchFullLines;
Req.IgnoreCase = IgnoreCase;
if (VerboseVerbose)
Req.Verbose = true;
FileCheck FC(Req);
if (!FC.ValidateCheckPrefixes())
return 2;
Regex PrefixRE = FC.buildCheckPrefixRegex();
std::string REError;
if (!PrefixRE.isValid(REError)) {
errs() << "Unable to combine check-prefix strings into a prefix regular "
"expression! This is likely a bug in FileCheck's verification of "
"the check-prefix strings. Regular expression parsing failed "
"with the following error: "
<< REError << "\n";
return 2;
}
SourceMgr SM;
// Read the expected strings from the check file.
ErrorOr<std::unique_ptr<MemoryBuffer>> CheckFileOrErr =
MemoryBuffer::getFileOrSTDIN(CheckFilename, /*IsText=*/true);
if (std::error_code EC = CheckFileOrErr.getError()) {
errs() << "Could not open check file '" << CheckFilename
<< "': " << EC.message() << '\n';
return 2;
}
MemoryBuffer &CheckFile = *CheckFileOrErr.get();
SmallString<4096> CheckFileBuffer;
StringRef CheckFileText = FC.CanonicalizeFile(CheckFile, CheckFileBuffer);
unsigned CheckFileBufferID =
SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(
CheckFileText, CheckFile.getBufferIdentifier()),
SMLoc());
std::pair<unsigned, unsigned> ImpPatBufferIDRange;
if (FC.readCheckFile(SM, CheckFileText, PrefixRE, &ImpPatBufferIDRange))
return 2;
// Open the file to check and add it to SourceMgr.
ErrorOr<std::unique_ptr<MemoryBuffer>> InputFileOrErr =
MemoryBuffer::getFileOrSTDIN(InputFilename, /*IsText=*/true);
if (InputFilename == "-")
InputFilename = "<stdin>"; // Overwrite for improved diagnostic messages
if (std::error_code EC = InputFileOrErr.getError()) {
errs() << "Could not open input file '" << InputFilename
<< "': " << EC.message() << '\n';
return 2;
}
MemoryBuffer &InputFile = *InputFileOrErr.get();
if (InputFile.getBufferSize() == 0 && !AllowEmptyInput) {
errs() << "FileCheck error: '" << InputFilename << "' is empty.\n";
DumpCommandLine(argc, argv);
return 2;
}
SmallString<4096> InputFileBuffer;
StringRef InputFileText = FC.CanonicalizeFile(InputFile, InputFileBuffer);
SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(
InputFileText, InputFile.getBufferIdentifier()),
SMLoc());
std::vector<FileCheckDiag> Diags;
int ExitCode = FC.checkInput(SM, InputFileText,
DumpInput == DumpInputNever ? nullptr : &Diags)
? EXIT_SUCCESS
: 1;
if (DumpInput == DumpInputAlways ||
(ExitCode == 1 && DumpInput == DumpInputFail)) {
errs() << "\n"
<< "Input file: " << InputFilename << "\n"
<< "Check file: " << CheckFilename << "\n"
<< "\n"
<< "-dump-input=help explains the following input dump.\n"
<< "\n";
std::vector<InputAnnotation> Annotations;
unsigned LabelWidth;
BuildInputAnnotations(SM, CheckFileBufferID, ImpPatBufferIDRange, Diags,
Annotations, LabelWidth);
DumpAnnotatedInput(errs(), Req, DumpInputFilter, DumpInputContext,
InputFileText, Annotations, LabelWidth);
}
return ExitCode;
}

42
bin/triton-opt.cpp Normal file
View File

@@ -0,0 +1,42 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Conversion/Passes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Support/MlirOptMain.h"
namespace mlir {
namespace test {
void registerTestAliasPass();
void registerTestAlignmentPass();
void registerTestAllocationPass();
void registerTestMembarPass();
} // namespace test
} // namespace mlir
int main(int argc, char **argv) {
mlir::registerAllPasses();
mlir::registerTritonPasses();
mlir::registerTritonGPUPasses();
mlir::test::registerTestAliasPass();
mlir::test::registerTestAlignmentPass();
mlir::test::registerTestAllocationPass();
mlir::test::registerTestMembarPass();
mlir::triton::registerConvertTritonToTritonGPUPass();
mlir::triton::registerConvertTritonGPUToLLVMPass();
// TODO: register Triton & TritonGPU passes
mlir::DialectRegistry registry;
registry.insert<mlir::triton::TritonDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect,
mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();
return mlir::asMainReturnCode(mlir::MlirOptMain(
argc, argv, "Triton (GPU) optimizer driver\n", registry));
}

131
bin/triton-translate.cpp Normal file
View File

@@ -0,0 +1,131 @@
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/driver/llvm.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include <iostream>
namespace mlir {
namespace triton {
OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
MLIRContext &context) {
std::string errorMessage;
auto input = openInputFile(inputFilename, &errorMessage);
if (!input) {
llvm::errs() << errorMessage << "\n";
return nullptr;
}
mlir::DialectRegistry registry;
registry.insert<TritonDialect, triton::gpu::TritonGPUDialect,
mlir::math::MathDialect, arith::ArithmeticDialect,
StandardOpsDialect, scf::SCFDialect>();
context.appendDialectRegistry(registry);
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer)
-> OwningOpRef<ModuleOp> {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
context.loadAllAvailableDialects();
context.allowUnregisteredDialects();
OwningOpRef<ModuleOp> module(parseSourceFile(sourceMgr, &context));
if (!module) {
llvm::errs() << "Parse MLIR file failed.";
return nullptr;
}
return module;
};
auto module = processBuffer(std::move(input));
if (!module) {
return nullptr;
}
return module;
}
LogicalResult tritonTranslateMain(int argc, char **argv,
llvm::StringRef toolName) {
static llvm::cl::opt<std::string> inputFilename(
llvm::cl::Positional, llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
static llvm::cl::opt<std::string> outputFilename(
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
static llvm::cl::opt<std::string> targetKind(
"target", llvm::cl::desc("<translation target, options: llvmir/ptx>"),
llvm::cl::value_desc("target"), llvm::cl::init("llvmir"));
static llvm::cl::opt<int> SMArch("sm", llvm::cl::desc("sm arch"),
llvm::cl::init(80));
static llvm::cl::opt<int> ptxVersion(
"ptx-version", llvm::cl::desc("PTX version"), llvm::cl::init(10000));
llvm::InitLLVM y(argc, argv);
registerAsmPrinterCLOptions();
registerMLIRContextCLOptions();
llvm::cl::ParseCommandLineOptions(argc, argv, toolName);
mlir::MLIRContext context;
auto module = loadMLIRModule(inputFilename, context);
if (!module) {
return failure();
}
std::string errorMessage;
auto output = openOutputFile(outputFilename, &errorMessage);
if (!output) {
llvm::errs() << errorMessage << "\n";
return failure();
}
llvm::LLVMContext llvmContext;
auto llvmir =
translateTritonGPUToLLVMIR(&llvmContext, *module, SMArch.getValue());
if (!llvmir) {
llvm::errs() << "Translate to LLVM IR failed";
}
if (targetKind == "llvmir")
llvm::outs() << *llvmir << '\n';
else if (targetKind == "ptx")
llvm::outs() << ::triton::driver::llir_to_ptx(
llvmir.get(), SMArch.getValue(), ptxVersion.getValue());
return success();
}
} // namespace triton
} // namespace mlir
int main(int argc, char **argv) {
return failed(mlir::triton::tritonTranslateMain(
argc, argv, "Triton Translate Testing Tool."));
}

View File

@@ -25,7 +25,7 @@
# LLVM_VERSION_STRING - Full LLVM version string (e.g. 6.0.0svn). # LLVM_VERSION_STRING - Full LLVM version string (e.g. 6.0.0svn).
# LLVM_VERSION_BASE_STRING - Base LLVM version string without git/svn suffix (e.g. 6.0.0). # LLVM_VERSION_BASE_STRING - Base LLVM version string without git/svn suffix (e.g. 6.0.0).
# #
# Note: The variable names were chosen in conformance with the offical CMake # Note: The variable names were chosen in conformance with the official CMake
# guidelines, see ${CMAKE_ROOT}/Modules/readme.txt. # guidelines, see ${CMAKE_ROOT}/Modules/readme.txt.
# Try suffixed versions to pick up the newest LLVM install available on Debian # Try suffixed versions to pick up the newest LLVM install available on Debian
@@ -196,4 +196,4 @@ include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(LLVM find_package_handle_standard_args(LLVM
REQUIRED_VARS LLVM_ROOT_DIR REQUIRED_VARS LLVM_ROOT_DIR
VERSION_VAR LLVM_VERSION_STRING) VERSION_VAR LLVM_VERSION_STRING)

27
docs/_templates/versions.html vendored Normal file
View File

@@ -0,0 +1,27 @@
{%- if current_version %}
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Other Versions</span>
v: {{ current_version.name }}
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">
{%- if versions.tags %}
<dl>
<dt>Tags</dt>
{%- for item in versions.tags %}
<dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
{%- endfor %}
</dl>
{%- endif %}
{%- if versions.branches %}
<dl>
<dt>Branches</dt>
{%- for item in versions.branches %}
<dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
{%- endfor %}
</dl>
{%- endif %}
</div>
</div>
{%- endif %}

View File

@@ -34,14 +34,17 @@ def process_sig(app, what, name, obj, options, signature, return_annotation):
def setup(app): def setup(app):
"""Customize function args retrieving to get args under decorator.""" """Customize function args retrieving to get args under decorator."""
import sphinx import sphinx
import triton import os
app.connect("autodoc-process-signature", process_sig) app.connect("autodoc-process-signature", process_sig)
os.system("pip install -e ../python")
def forward_jit_fn(func): def forward_jit_fn(func):
old = func old = func
def wrapped(obj, **kwargs): def wrapped(obj, **kwargs):
import triton
if isinstance(obj, triton.code_gen.JITFunction): if isinstance(obj, triton.code_gen.JITFunction):
obj = obj.fn obj = obj.fn
return old(obj) return old(obj)
@@ -52,6 +55,7 @@ def setup(app):
old_documenter = sphinx.ext.autosummary.get_documenter old_documenter = sphinx.ext.autosummary.get_documenter
def documenter(app, obj, parent): def documenter(app, obj, parent):
import triton
if isinstance(obj, triton.code_gen.JITFunction): if isinstance(obj, triton.code_gen.JITFunction):
obj = obj.fn obj = obj.fn
return old_documenter(app, obj, parent) return old_documenter(app, obj, parent)
@@ -66,9 +70,17 @@ def setup(app):
import sys import sys
import os import os
sys.path.insert(0, os.path.abspath('../python/')) sys.path.insert(0, os.path.abspath('../python/'))
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon'] extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon', 'sphinx_multiversion']
autosummary_generate = True autosummary_generate = True
# versioning config
smv_tag_whitelist = r'^(v1.1.2)$'
smv_branch_whitelist = r'^master$'
smv_remote_whitelist = None
smv_released_pattern = r'^tags/.*$'
smv_outputdir_format = '{ref.name}'
smv_prefer_remote_refs = False
# Sphinx gallery # Sphinx gallery
extensions += ['sphinx_gallery.gen_gallery'] extensions += ['sphinx_gallery.gen_gallery']
from sphinx_gallery.sorting import FileNameSortKey from sphinx_gallery.sorting import FileNameSortKey
@@ -85,6 +97,11 @@ sphinx_gallery_conf = {
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates'] templates_path = ['_templates']
html_sidebars = {
'**': [
'_templates/versions.html',
],
}
# The suffix(es) of source filenames. # The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string: # You can specify multiple suffix as a list of string:

View File

@@ -44,7 +44,7 @@ You can then test your installation by running the unit tests:
.. code-block:: bash .. code-block:: bash
pip install -r requirements-test.txt pip install -e '.[tests]'
pytest -vs test/unit/ pytest -vs test/unit/
and the benchmarks and the benchmarks

View File

@@ -14,7 +14,7 @@ Traditional compilers typically rely on intermediate representations, such as LL
Program Representation Program Representation
+++++++++++++++++++++++ +++++++++++++++++++++++
Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample litterature on linear and integer programming. Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample literature on linear and integer programming.
.. table:: .. table::
:widths: 50 50 :widths: 50 50
@@ -168,7 +168,7 @@ Scheduling languages are, without a doubt, one of the most popular approaches fo
Limitations Limitations
++++++++++++ ++++++++++++
This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indice without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular. This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indices without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular.
.. table:: .. table::
:widths: 50 50 :widths: 50 50

1
include/CMakeLists.txt Normal file
View File

@@ -0,0 +1 @@
add_subdirectory(triton)

View File

@@ -0,0 +1,80 @@
#ifndef TRITON_ANALYSIS_ALIAS_H
#define TRITON_ANALYSIS_ALIAS_H
#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "llvm/ADT/DenseSet.h"
namespace mlir {
class AliasInfo {
public:
AliasInfo() = default;
AliasInfo(Value value) { insert(value); }
void insert(Value value) { allocs.insert(value); }
const DenseSet<Value> &getAllocs() const { return allocs; }
bool operator==(const AliasInfo &other) const {
return allocs == other.allocs;
}
/// The pessimistic value state of a value without alias
static AliasInfo getPessimisticValueState(MLIRContext *context) {
return AliasInfo();
}
static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); }
/// The union of both arguments
static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs);
private:
/// The set of allocated values that are aliased by this lattice.
/// For now, we only consider aliased value produced by the following
/// situations:
/// 1. values returned by scf.yield
/// 2. block arguments in scf.for
/// Example:
/// alloc v1 alloc v2
/// | |
/// |--------------| |------------|
/// scf.for v3 scf.for v4 scf.for v5
/// |
/// scf.yield v6
///
/// v1's alloc [v1]
/// v2's alloc [v2]
/// v3's alloc [v1]
/// v4's alloc [v1, v2]
/// v5's alloc [v2]
/// v6's alloc [v1]
///
/// Therefore, v1's liveness range is the union of v3, v4, and v6
/// v2's liveness range is the union of v4 and v5.
DenseSet<Value> allocs;
};
//===----------------------------------------------------------------------===//
// Shared Memory Alias Analysis
//===----------------------------------------------------------------------===//
class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis<AliasInfo> {
public:
using ForwardDataFlowAnalysis<AliasInfo>::ForwardDataFlowAnalysis;
/// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use.
/// Given two values, returns their aliasing behavior.
AliasResult alias(Value lhs, Value rhs);
/// Returns the modify-reference behavior of `op` on `location`.
ModRefResult getModRef(Operation *op, Value location);
/// Computes if the alloc set of the results are changed.
ChangeResult
visitOperation(Operation *op,
ArrayRef<LatticeElement<AliasInfo> *> operands) override;
};
} // namespace mlir
#endif // TRITON_ANALYSIS_ALIAS_H

View File

@@ -0,0 +1,192 @@
#ifndef TRITON_ANALYSIS_ALLOCATION_H
#define TRITON_ANALYSIS_ALLOCATION_H
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/raw_ostream.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <atomic>
#include <limits>
namespace mlir {
namespace triton {
class AllocationAnalysis;
SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec);
} // namespace triton
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
/// A class that represents an interval, specified using a start and an end
/// values: [Start, End).
template <typename T> class Interval {
public:
Interval() {}
Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); }
T start() const { return Start; }
T end() const { return End; }
T size() const { return End - Start; }
bool contains(T Addr) const { return Start <= Addr && Addr < End; }
bool intersects(const Interval &R) const {
return Start < R.End && R.Start < End;
}
bool operator==(const Interval &R) const {
return Start == R.Start && End == R.End;
}
bool operator!=(const Interval &R) const { return !(*this == R); }
bool operator<(const Interval &R) const {
return std::make_pair(Start, End) < std::make_pair(R.Start, R.End);
}
private:
T Start = std::numeric_limits<T>::min();
T End = std::numeric_limits<T>::max();
};
class Allocation {
public:
/// A unique identifier for shared memory buffers
using BufferId = size_t;
using BufferIdSetT = DenseSet<BufferId>;
static constexpr BufferId InvalidBufferId =
std::numeric_limits<BufferId>::max();
/// Creates a new Allocation analysis that computes the shared memory
/// information for all associated shared memory values.
Allocation(Operation *operation) : operation(operation) { run(); }
/// Returns the operation this analysis was constructed from.
Operation *getOperation() const { return operation; }
/// Returns the offset of the given buffer in the shared memory.
size_t getOffset(BufferId bufferId) const {
return bufferSet.lookup(bufferId).offset;
}
/// Returns the size of the given buffer in the shared memory.
size_t getAllocatedSize(BufferId bufferId) const {
return bufferSet.lookup(bufferId).size;
}
/// Returns the buffer id of the given value.
/// This interface only returns the allocated buffer id.
/// If you want to get all the buffer ids that are associated with the given
/// value, including alias buffers, use getBufferIds.
BufferId getBufferId(Value value) const {
if (valueBuffer.count(value)) {
return valueBuffer.lookup(value)->id;
} else {
return InvalidBufferId;
}
}
/// Returns all the buffer ids of the given value, including alias buffers.
BufferIdSetT getBufferIds(Value value) const {
BufferIdSetT bufferIds;
auto allocBufferId = getBufferId(value);
if (allocBufferId != InvalidBufferId)
bufferIds.insert(allocBufferId);
for (auto *buffer : aliasBuffer.lookup(value)) {
if (buffer->id != InvalidBufferId)
bufferIds.insert(buffer->id);
}
return bufferIds;
}
/// Returns the scratch buffer id of the given value.
BufferId getBufferId(Operation *operation) const {
if (opScratch.count(operation)) {
return opScratch.lookup(operation)->id;
} else {
return InvalidBufferId;
}
}
/// Returns the size of total shared memory allocated
size_t getSharedMemorySize() const { return sharedMemorySize; }
bool isIntersected(BufferId lhsId, BufferId rhsId) const {
if (lhsId == InvalidBufferId || rhsId == InvalidBufferId)
return false;
auto lhsBuffer = bufferSet.lookup(lhsId);
auto rhsBuffer = bufferSet.lookup(rhsId);
return lhsBuffer.intersects(rhsBuffer);
}
private:
/// A class that represents a shared memory buffer
struct BufferT {
enum class BufferKind { Explicit, Scratch };
/// MT: thread-safe
inline static std::atomic<BufferId> nextId = 0;
BufferKind kind;
BufferId id;
size_t size;
size_t offset;
bool operator==(const BufferT &other) const { return id == other.id; }
bool operator<(const BufferT &other) const { return id < other.id; }
BufferT() : BufferT(BufferKind::Explicit) {}
BufferT(BufferKind kind) : BufferT(kind, 0, 0) {}
BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
BufferT(BufferKind kind, size_t size, size_t offset)
: kind(kind), id(nextId++), size(size), offset(offset) {}
bool intersects(const BufferT &other) const {
return Interval<size_t>(offset, offset + size)
.intersects(
Interval<size_t>(other.offset, other.offset + other.size));
}
};
/// Op -> Scratch Buffer
using OpScratchMapT = DenseMap<Operation *, BufferT *>;
/// Value -> Explicit Buffer
using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
/// Value -> Alias Buffer
using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
/// BufferId -> Buffer
using BufferSetT = DenseMap<BufferId, BufferT>;
/// Runs allocation analysis on the given top-level operation.
void run();
private:
template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
void addBuffer(KeyType &key, Args &&...args) {
auto buffer = BufferT(Kind, std::forward<Args>(args)...);
bufferSet[buffer.id] = std::move(buffer);
if constexpr (Kind == BufferT::BufferKind::Explicit) {
valueBuffer[key] = &bufferSet[buffer.id];
} else {
opScratch[key] = &bufferSet[buffer.id];
}
}
void addAlias(Value value, Value alloc) {
aliasBuffer[value].insert(valueBuffer[alloc]);
}
private:
Operation *operation;
OpScratchMapT opScratch;
ValueBufferMapT valueBuffer;
AliasBufferMapT aliasBuffer;
BufferSetT bufferSet;
size_t sharedMemorySize = 0;
friend class triton::AllocationAnalysis;
};
} // namespace mlir
#endif // TRITON_ANALYSIS_ALLOCATION_H

View File

@@ -0,0 +1,144 @@
#ifndef TRITON_ANALYSIS_AXISINFO_H
#define TRITON_ANALYSIS_AXISINFO_H
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
namespace mlir {
//===----------------------------------------------------------------------===//
// AxisInfo
//===----------------------------------------------------------------------===//
/// This lattice value represents known information on the axes of a lattice.
/// Axis information is represented by a std::map<int, int>
class AxisInfo {
public:
typedef SmallVector<int, 4> DimVectorT;
public:
// Default constructor
AxisInfo() : AxisInfo({}, {}, {}) {}
// Construct contiguity info with known contiguity
AxisInfo(DimVectorT knownContiguity, DimVectorT knownDivisibility,
DimVectorT knownConstancy)
: contiguity(knownContiguity), divisibility(knownDivisibility),
constancy(knownConstancy), rank(contiguity.size()) {
assert(knownDivisibility.size() == (size_t)rank);
assert(knownConstancy.size() == (size_t)rank);
}
// Accessors
int getContiguity(size_t d) const { return contiguity[d]; }
const DimVectorT &getContiguity() const { return contiguity; }
int getDivisibility(size_t d) const { return divisibility[d]; }
const DimVectorT &getDivisibility() const { return divisibility; }
int getConstancy(size_t d) const { return constancy[d]; }
const DimVectorT &getConstancy() const { return constancy; }
int getRank() const { return rank; }
// Comparison
bool operator==(const AxisInfo &other) const {
return (contiguity == other.contiguity) &&
(divisibility == other.divisibility) &&
(constancy == other.constancy);
}
/// The pessimistic value state of the contiguity is unknown.
static AxisInfo getPessimisticValueState(MLIRContext *context) {
return AxisInfo();
}
static AxisInfo getPessimisticValueState(Value value);
// The gcd of both arguments for each dimension
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
private:
/// The _contiguity_ information maps the `d`-th
/// dimension to the length of the shortest
/// sequence of contiguous integers along it
/// For example:
/// [10, 11, 12, 13, 18, 19, 20, 21]
/// [20, 21, 22, 23, 28, 29, 30, 31]
/// Would have contiguity [1, 4].
/// and
/// [12, 16, 20, 24]
/// [13, 17, 21, 25]
/// [14, 18, 22, 26]
/// [15, 19, 23, 27]
/// [18, 22, 26, 30]
/// [19, 23, 27, 31]
/// Would have contiguity [2, 1].
DimVectorT contiguity;
/// The _divisibility_ information maps the `d`-th
/// dimension to the largest power-of-two that
/// divides the first element of all the values along it
/// For example:
/// [10, 11, 12, 13, 18, 19, 20, 21]
/// [20, 21, 22, 23, 28, 29, 30, 31]
// would have divisibility [1, 2]
// and
/// [12, 16, 20, 24]
/// [13, 17, 21, 25]
/// [14, 18, 22, 26]
/// [15, 19, 23, 27]
// would have divisibility [4, 1]
DimVectorT divisibility;
/// The _constancy_ information maps the `d`-th
/// dimension to the length of the shortest
/// sequence of constant integer along it. This is
/// particularly useful to infer the contiguity
/// of operations (e.g., add) involving a constant
/// For example
/// [8, 8, 8, 8, 12, 12, 12, 12]
/// [16, 16, 16, 16, 20, 20, 20, 20]
/// would have constancy [1, 4]
DimVectorT constancy;
// number of dimensions of the lattice
int rank;
};
class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
private:
static const int maxPow2Divisor = 65536;
int highestPowOf2Divisor(int n) {
if (n == 0)
return maxPow2Divisor;
return (n & (~(n - 1)));
}
AxisInfo visitBinaryOp(
Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy);
public:
using ForwardDataFlowAnalysis<AxisInfo>::ForwardDataFlowAnalysis;
ChangeResult
visitOperation(Operation *op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
unsigned getPtrVectorSize(Value ptr);
unsigned getPtrAlignment(Value ptr);
unsigned getMaskAlignment(Value mask);
};
} // namespace mlir
#endif

View File

@@ -0,0 +1,119 @@
#ifndef TRITON_ANALYSIS_MEMBAR_H
#define TRITON_ANALYSIS_MEMBAR_H
#include "Allocation.h"
#include "llvm/ADT/SmallPtrSet.h"
namespace mlir {
class OpBuilder;
//===----------------------------------------------------------------------===//
// Shared Memory Barrier Analysis
//===----------------------------------------------------------------------===//
class MembarAnalysis {
public:
/// Creates a new Membar analysis that generates the shared memory barrier
/// in the following circumstances:
/// - RAW: If a shared memory write is followed by a shared memory read, and
/// their addresses are intersected, a barrier is inserted.
/// - WAR: If a shared memory read is followed by a shared memory read, and
/// their addresses are intersected, a barrier is inserted.
/// The following circumstances do not require a barrier:
/// - WAW: not possible because overlapped memory allocation is not allowed.
/// - RAR: no write is performed.
/// Temporary storage of operations such as Reduce are considered as both
/// a shared memory read. If the temporary storage is written but not read,
/// it is considered as the problem of the operation itself but not the membar
/// analysis.
/// The following circumstances are not considered yet:
/// - Double buffers
/// - N buffers
MembarAnalysis(Allocation *allocation) : allocation(allocation) {}
/// Runs the membar analysis to the given operation, inserts a barrier if
/// necessary.
void run();
private:
struct RegionInfo {
using BufferIdSetT = Allocation::BufferIdSetT;
BufferIdSetT syncReadBuffers;
BufferIdSetT syncWriteBuffers;
RegionInfo() = default;
RegionInfo(const BufferIdSetT &syncReadBuffers,
const BufferIdSetT &syncWriteBuffers)
: syncReadBuffers(syncReadBuffers), syncWriteBuffers(syncWriteBuffers) {
}
/// Unions two RegionInfo objects.
void join(const RegionInfo &other) {
syncReadBuffers.insert(other.syncReadBuffers.begin(),
other.syncReadBuffers.end());
syncWriteBuffers.insert(other.syncWriteBuffers.begin(),
other.syncWriteBuffers.end());
}
/// Returns true if buffers in two RegionInfo objects are intersected.
bool isIntersected(const RegionInfo &other, Allocation *allocation) const {
return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers,
allocation) ||
/*WAR*/
isIntersected(syncReadBuffers, other.syncWriteBuffers,
allocation) ||
/*WAW*/
isIntersected(syncWriteBuffers, other.syncWriteBuffers,
allocation);
}
/// Clears the buffers because a barrier is inserted.
void sync() {
syncReadBuffers.clear();
syncWriteBuffers.clear();
}
private:
/// Returns true if buffers in two sets are intersected.
bool isIntersected(const BufferIdSetT &lhs, const BufferIdSetT &rhs,
Allocation *allocation) const {
return std::any_of(lhs.begin(), lhs.end(), [&](auto lhsId) {
return std::any_of(rhs.begin(), rhs.end(), [&](auto rhsId) {
return allocation->isIntersected(lhsId, rhsId);
});
});
}
};
/// Applies the barrier analysis based on the SCF dialect, in which each
/// region has a single basic block only.
/// Example:
/// region1
/// op1
/// op2 (scf.if)
/// region2
/// op3
/// op4
/// region3
/// op5
/// op6
/// op7
/// region2 and region3 started with the information of region1.
/// Each region is analyzed separately and keeps their own copy of the
/// information. At op7, we union the information of the region2 and region3
/// and update the information of region1.
void dfsOperation(Operation *operation, RegionInfo *blockInfo,
OpBuilder *builder);
/// Updates the RegionInfo operation based on the operation.
void transfer(Operation *operation, RegionInfo *blockInfo,
OpBuilder *builder);
private:
Allocation *allocation;
};
} // namespace mlir
#endif // TRITON_ANALYSIS_MEMBAR_H

View File

@@ -0,0 +1,82 @@
#ifndef TRITON_ANALYSIS_UTILITY_H
#define TRITON_ANALYSIS_UTILITY_H
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <numeric>
#include <string>
namespace mlir {
class ReduceOpHelper {
public:
explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
srcTy = op.operand().getType().cast<RankedTensorType>();
}
ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
Attribute getSrcLayout() { return srcTy.getEncoding(); }
bool isFastReduction();
unsigned getInterWarpSize();
unsigned getIntraWarpSize();
unsigned getThreadsReductionAxis();
SmallVector<unsigned> getScratchConfigBasic();
SmallVector<SmallVector<unsigned>> getScratchConfigsFast();
unsigned getScratchSizeInBytes();
private:
triton::ReduceOp op;
RankedTensorType srcTy{};
};
bool isSharedEncoding(Value value);
bool maybeSharedAllocationOp(Operation *op);
bool maybeAliasOp(Operation *op);
bool supportMMA(triton::DotOp op, int version);
bool supportMMA(Value value, int version);
Type getElementType(Value value);
std::string getValueOperandName(Value value, AsmState &state);
template <typename T_OUT, typename T_IN>
inline SmallVector<T_OUT> convertType(ArrayRef<T_IN> in) {
SmallVector<T_OUT> out;
for (const T_IN &i : in)
out.push_back(T_OUT(i));
return out;
}
template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}
template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }
// output[i] = input[order[i]]
template <typename T, typename RES_T = T>
SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
size_t rank = order.size();
assert(input.size() == rank);
SmallVector<RES_T> result(rank);
for (auto it : llvm::enumerate(order)) {
result[it.index()] = input[it.value()];
}
return result;
}
} // namespace mlir
#endif // TRITON_ANALYSIS_UTILITY_H

View File

@@ -0,0 +1,2 @@
add_subdirectory(Conversion)
add_subdirectory(Dialect)

View File

@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(TritonConversionPassIncGen)

View File

@@ -0,0 +1,40 @@
#ifndef TRITON_CONVERSION_MLIR_TYPES_H_
#define TRITON_CONVERSION_MLIR_TYPES_H_
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
// This file redefines some common MLIR types for easy usage.
namespace mlir {
namespace triton {
namespace type {
// Integer types
// TODO(Superjomn): may change `static` into better implementations
static Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); }
static Type i16Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 16); }
static Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); }
static Type u32Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 32, IntegerType::Unsigned);
}
static Type u1Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 1, IntegerType::Unsigned);
}
// Float types
static Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
static Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
static Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
static Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
static bool isFloat(Type type) {
return type.isF32() || type.isF64() || type.isF16() || type.isF128();
}
static bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
} // namespace type
} // namespace triton
} // namespace mlir
#endif // TRITON_CONVERSION_MLIR_TYPES_H_

View File

@@ -0,0 +1,17 @@
#ifndef TRITON_CONVERSION_PASSES_H
#define TRITON_CONVERSION_PASSES_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
namespace mlir {
namespace triton {
#define GEN_PASS_REGISTRATION
#include "triton/Conversion/Passes.h.inc"
} // namespace triton
} // namespace mlir
#endif

View File

@@ -0,0 +1,54 @@
#ifndef TRITON_CONVERSION_PASSES
#define TRITON_CONVERSION_PASSES
include "mlir/Pass/PassBase.td"
def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> {
let summary = "Convert Triton to TritonGPU";
let description = [{
}];
let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()";
let dependentDialects = ["mlir::arith::ArithmeticDialect",
"mlir::math::MathDialect",
"mlir::StandardOpsDialect",
// TODO: Does this pass depend on SCF?
"mlir::scf::SCFDialect",
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect"];
let options = [
Option<"numWarps", "num-warps",
"int32_t", /*default*/"4",
"number of warps">
];
}
def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert TritonGPU to LLVM";
let description = [{
}];
let constructor = "mlir::triton::createConvertTritonGPUToLLVMPass()";
let dependentDialects = ["mlir::arith::ArithmeticDialect",
"mlir::math::MathDialect",
"mlir::gpu::GPUDialect",
"mlir::scf::SCFDialect",
"mlir::LLVM::LLVMDialect",
"mlir::tensor::TensorDialect",
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect",
"mlir::NVVM::NVVMDialect",
"mlir::StandardOpsDialect"];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
#endif

View File

@@ -0,0 +1,326 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_ASM_FORMAT_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_ASM_FORMAT_H
#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
#include <string>
namespace mlir {
class ConversionPatternRewriter;
class Location;
namespace triton {
using llvm::StringRef;
struct PTXInstr;
struct PTXInstrCommon;
struct PTXInstrExecution;
// PTXBuilder helps to manage a PTX asm program consists of one or multiple
// instructions.
//
// A helper for building an ASM program, the objective of PTXBuilder is to give
// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
// Currently, several factors are introduced to reduce the need for mixing
// string and C++ if-else code.
//
// Usage:
// To build: @$3 asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k),
// "b"(p));
//
// PTXBuilder builder;
// auto& add = builder.create<>();
// add.predicate(pVal).o("lo").o("u32"); // add any suffix
// // predicate here binds %0 to pVal, pVal is a mlir::Value
//
// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal
// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal
// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal
// add(iOpr, jOpr, kOpr).predicate(predVal); // set operands and predicate
//
// To get the asm code:
// builder.dump()
//
// To get all the mlir::Value used in the PTX code,
//
// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal}
//
// To get the string containing all the constraints with "," separated,
// builder.getConstraints() // get "=r,r,k"
//
// PTXBuilder can build a PTX asm with multiple instructions, sample code:
//
// PTXBuilder builder;
// auto& mov = builder.create("mov");
// auto& cp = builder.create("cp");
// mov(...);
// cp(...);
// This will get a PTX code with two instructions.
//
// Similar to a C function, a declared PTXInstr instance can be launched
// multiple times with different operands, e.g.
//
// auto& mov = builder.create("mov");
// mov(... some operands ...);
// mov(... some different operands ...);
//
// Finally, we will get a PTX code with two mov instructions.
//
// There are several derived instruction type for typical instructions, for
// example, the PtxIOInstr for ld and st instructions.
struct PTXBuilder {
struct Operand {
std::string constraint;
Value value;
int idx{-1};
llvm::SmallVector<Operand *> list;
std::function<std::string(int idx)> repr;
// for list
Operand() = default;
Operand(const Operation &) = delete;
Operand(Value value, StringRef constraint)
: constraint(constraint), value(value) {}
bool isList() const { return !value && constraint.empty(); }
Operand *listAppend(Operand *arg) {
list.push_back(arg);
return this;
}
Operand *listGet(size_t nth) const {
assert(nth < list.size());
return list[nth];
}
std::string dump() const;
};
template <typename INSTR = PTXInstr, typename... Args>
INSTR *create(Args &&...args) {
instrs.emplace_back(std::make_unique<INSTR>(this, args...));
return static_cast<INSTR *>(instrs.back().get());
}
// Create a list of operands.
Operand *newListOperand() { return newOperand(); }
Operand *newListOperand(ArrayRef<std::pair<mlir::Value, std::string>> items) {
auto *list = newOperand();
for (auto &item : items) {
list->listAppend(newOperand(item.first, item.second));
}
return list;
}
Operand *newListOperand(unsigned count, mlir::Value val,
const std::string &constraint) {
auto *list = newOperand();
for (unsigned i = 0; i < count; ++i) {
list->listAppend(newOperand(val, constraint));
}
return list;
}
Operand *newListOperand(unsigned count, const std::string &constraint) {
auto *list = newOperand();
for (unsigned i = 0; i < count; ++i) {
list->listAppend(newOperand(constraint));
}
return list;
}
// Create a new operand. It will not add to operand list.
// @value: the MLIR value bind to this operand.
// @constraint: ASM operand constraint, .e.g. "=r"
// @formatter: extra format to represent this operand in ASM code, default is
// "%{0}".format(operand.idx).
Operand *newOperand(mlir::Value value, StringRef constraint,
std::function<std::string(int idx)> formatter = nullptr);
// Create a new operand which is written to, that is, the constraint starts
// with "=", e.g. "=r".
Operand *newOperand(StringRef constraint);
// Create a constant integer operand.
Operand *newConstantOperand(int64_t v);
// Create a constant operand with explicit code specified.
Operand *newConstantOperand(const std::string &v);
Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0);
llvm::SmallVector<Operand *, 4> getAllArgs() const;
llvm::SmallVector<Value, 4> getAllMLIRArgs() const;
std::string getConstraints() const;
std::string dump() const;
mlir::Value launch(ConversionPatternRewriter &rewriter, Location loc,
Type resTy, bool hasSideEffect = true,
bool isAlignStack = false,
ArrayRef<Attribute> attrs = {}) const;
private:
Operand *newOperand() {
argArchive.emplace_back(std::make_unique<Operand>());
return argArchive.back().get();
}
// Make the operands in argArchive follow the provided \param order.
void reorderArgArchive(ArrayRef<Operand *> order) {
assert(order.size() == argArchive.size());
// The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but
// it does necessary when onlyAttachMLIRArgs is true for the $0, $1... are
// determined by PTX code snippet passed from external.
sort(argArchive.begin(), argArchive.end(),
[&](std::unique_ptr<Operand> &a, std::unique_ptr<Operand> &b) {
auto ida = std::find(order.begin(), order.end(), a.get());
auto idb = std::find(order.begin(), order.end(), b.get());
assert(ida != order.end());
assert(idb != order.end());
return ida < idb;
});
}
friend struct PTXInstr;
friend struct PTXInstrCommon;
protected:
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
llvm::SmallVector<std::unique_ptr<PTXInstrCommon>, 2> instrs;
llvm::SmallVector<std::unique_ptr<PTXInstrExecution>, 4> executions;
int oprCounter{};
};
// PTX instruction common interface.
// Put the generic logic for all the instructions here.
struct PTXInstrCommon {
explicit PTXInstrCommon(PTXBuilder *builder) : builder(builder) {}
using Operand = PTXBuilder::Operand;
// clang-format off
PTXInstrExecution& operator()() { return call({}); }
PTXInstrExecution& operator()(Operand* a) { return call({a}); }
PTXInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d) { return call({a, b, c, d}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { return call({a, b, c, d, e}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { return call({a, b, c, d, e, f}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { return call({a, b, c, d, e, f, g}); }
// clang-format on
// Set operands of this instruction.
PTXInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs = false);
protected:
// "Call" the instruction with operands.
// \param oprs The operands of this instruction.
// \param onlyAttachMLIRArgs Indicate that it simply attach the MLIR Arguments
// to the inline Asm without generating the operand ids(such as $0, $1) in PTX
// code.
PTXInstrExecution &call(llvm::ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs = false);
PTXBuilder *builder{};
llvm::SmallVector<std::string, 4> instrParts;
friend struct PTXInstrExecution;
};
template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
using Operand = PTXBuilder::Operand;
explicit PTXInstrBase(PTXBuilder *builder, const std::string &name)
: PTXInstrCommon(builder) {
o(name);
}
// Append a suffix to the instruction.
// e.g. PTXInstr("add").o("s32") get a add.s32.
// A predicate is used to tell whether to apply the suffix, so that no if-else
// code needed. e.g. `PTXInstr("add").o("s32", isS32).o("u32", !isS32);` will
// get a `add.s32` if isS32 is true.
ConcreteT &o(const std::string &suffix, bool predicate = true) {
if (predicate)
instrParts.push_back(suffix);
return *static_cast<ConcreteT *>(this);
}
};
struct PTXInstr : public PTXInstrBase<PTXInstr> {
using PTXInstrBase<PTXInstr>::PTXInstrBase;
// Append a ".global" to the instruction.
PTXInstr &global();
// Append a ".shared" to the instruction.
PTXInstr &shared();
// Append a ".v[0-9]+" to the instruction
PTXInstr &v(int vecWidth, bool predicate = true);
// Append a".b[0-9]+" to the instruction
PTXInstr &b(int width);
};
// Record the operands and context for "launching" a PtxInstr.
struct PTXInstrExecution {
using Operand = PTXBuilder::Operand;
llvm::SmallVector<Operand *> argsInOrder;
PTXInstrExecution() = default;
explicit PTXInstrExecution(PTXInstrCommon *instr,
llvm::ArrayRef<Operand *> oprs,
bool onlyAttachMLIRArgs)
: argsInOrder(oprs.begin(), oprs.end()), instr(instr),
onlyAttachMLIRArgs(onlyAttachMLIRArgs) {}
// Prefix a predicate to the instruction.
PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
pred = instr->builder->newOperand(value, constraint);
return *this;
}
// Prefix a !predicate to the instruction.
PTXInstrExecution &predicateNot(mlir::Value value, StringRef constraint) {
pred = instr->builder->newOperand(value, constraint);
pred->repr = [](int idx) { return "@!$" + std::to_string(idx); };
return *this;
}
std::string dump() const;
SmallVector<Operand *> getArgList() const;
PTXInstrCommon *instr{};
Operand *pred{};
bool onlyAttachMLIRArgs{};
};
/// ====== Some instruction wrappers ======
// We add the wrappers to make the usage more intuitive by avoiding mixing the
// PTX code with some trivial C++ code.
struct PTXCpAsyncLoadInstr : PTXInstrBase<PTXCpAsyncLoadInstr> {
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
triton::CacheModifier modifier)
: PTXInstrBase(builder, "cp.async") {
o(triton::stringifyCacheModifier(modifier).str());
o("shared");
o("global");
}
};
} // namespace triton
} // namespace mlir
#endif

View File

@@ -0,0 +1,22 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_PASS_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_PASS_H
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonGPUToLLVMPass(int computeCapability = 80);
} // namespace triton
} // namespace mlir
#endif

View File

@@ -0,0 +1,25 @@
#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H
#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
namespace triton {
constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
// Create the pass with numWarps passed from cl::opt.
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
// Create the pass with numWarps set explicitly.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUPass(int numWarps);
} // namespace triton
} // namespace mlir
#endif

View File

@@ -0,0 +1,2 @@
add_subdirectory(Triton)
add_subdirectory(TritonGPU)

View File

@@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,19 @@
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
set(LLVM_TARGET_DEFINITIONS TritonDialect.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
mlir_tablegen(Types.h.inc -gen-typedef-decls)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td)
mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
add_public_tablegen_target(TritonTableGen)

View File

@@ -0,0 +1,48 @@
#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
#define TRITON_DIALECT_TRITON_IR_DIALECT_H_
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
#include "triton/Dialect/Triton/IR/Traits.h"
#include "triton/Dialect/Triton/IR/Types.h"
#define GET_OP_CLASSES
#include "triton/Dialect/Triton/IR/Ops.h.inc"
namespace mlir {
namespace triton {
class DialectInferLayoutInterface
: public DialectInterface::Base<DialectInferLayoutInterface> {
public:
DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}
virtual LogicalResult
inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
Attribute &resultEncoding) const = 0;
virtual LogicalResult
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
Attribute &resultEncoding,
Optional<Location> location) const = 0;
// Note: this function only verify operand encoding but doesn't infer result
// encoding
virtual LogicalResult
inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
Attribute retEncoding,
Optional<Location> location) const = 0;
};
} // namespace triton
} // namespace mlir
#endif // TRITON_IR_DIALECT_H_

View File

@@ -0,0 +1,9 @@
#ifndef TRITON_IR_INTERFACES_H_
#define TRITON_IR_INTERFACES_H_
#include "mlir/IR/OpDefinition.h"
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
#endif // TRITON_IR_TYPES_H_

View File

@@ -0,0 +1,60 @@
#ifndef TRITON_IR_TRAITS_H_
#define TRITON_IR_TRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include <iostream>
namespace mlir {
namespace OpTrait {
// These functions are out-of-line implementations of the methods in the
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
LogicalResult verifySameOperandsAndResultEncoding(Operation *op);
LogicalResult verifySameOperandsEncoding(Operation *op);
// The rationale for this trait is to prevent users from creating programs
// that would have catastrophic register pressure and cause the compiler to
// hang.
// Since H100 has 256KB registers, we should allow users to create tensors
// of size up to 256K elements. It will spill for datatypes wider than 1B,
// but we probably should limit number of elements (rather than bytes) to
// keep specs simple
int constexpr maxTensorNumElements = 1048576;
LogicalResult verifyTensorSize(Operation *op);
} // namespace impl
template <class ConcreteType>
class TensorSizeTrait : public TraitBase<ConcreteType, TensorSizeTrait> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyTensorSize(op);
}
};
template <typename ConcreteType>
class SameOperandsAndResultEncoding
: public TraitBase<ConcreteType, SameOperandsAndResultEncoding> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifySameOperandsAndResultEncoding(op);
}
};
template <typename ConcreteType>
class SameOperandsEncoding
: public TraitBase<ConcreteType, SameOperandsEncoding> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifySameOperandsEncoding(op);
}
};
} // namespace OpTrait
} // namespace mlir
#endif

View File

@@ -0,0 +1,68 @@
#ifndef TRITON_ATTR_DEFS
#define TRITON_ATTR_DEFS
include "mlir/IR/EnumAttr.td"
// Attrs for LoadOp
def TT_CacheModifierAttr : I32EnumAttr<
"CacheModifier", "",
[
I32EnumAttrCase<"NONE", 1, "none">,
I32EnumAttrCase<"CA", 2, "ca">,
I32EnumAttrCase<"CG", 3, "cg">,
]> {
let cppNamespace = "::mlir::triton";
}
def TT_EvictionPolicyAttr : I32EnumAttr<
"EvictionPolicy", "",
[
I32EnumAttrCase<"NORMAL", 1, "evict_normal">,
I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">,
I32EnumAttrCase<"EVICT_LAST", 3, "evict_last">
]> {
let cppNamespace = "::mlir::triton";
}
// reduction
def TT_RedOpAttr : I32EnumAttr<
/*name*/"RedOp", /*summary*/"",
/*case*/
[
I32EnumAttrCase</*sym*/"ADD", 1, /*str*/"add">,
I32EnumAttrCase<"FADD", 2, "fadd">,
I32EnumAttrCase<"MIN", 3, "min">,
I32EnumAttrCase<"MAX", 4, "max">,
I32EnumAttrCase<"UMIN", 5, "umin">,
I32EnumAttrCase<"UMAX", 6, "umax">,
I32EnumAttrCase<"ARGMIN", 7, "argmin">,
I32EnumAttrCase<"ARGMAX", 8, "argmax">,
I32EnumAttrCase<"ARGUMIN", 9, "argumin">,
I32EnumAttrCase<"ARGUMAX", 10, "argumax">,
I32EnumAttrCase<"FMIN", 11, "fmin">,
I32EnumAttrCase<"FMAX", 12, "fmax">,
I32EnumAttrCase<"ARGFMIN", 13, "argfmin">,
I32EnumAttrCase<"ARGFMAX", 14, "argfmax">,
I32EnumAttrCase<"XOR", 15, "xor">
]> {
let cppNamespace = "::mlir::triton";
}
// atomic
def TT_AtomicRMWAttr : I32EnumAttr<
"RMWOp", "",
[
I32EnumAttrCase<"AND", 1, "and">,
I32EnumAttrCase<"OR", 2, "or">,
I32EnumAttrCase<"XOR", 3, "xor">,
I32EnumAttrCase<"ADD", 4, "add">,
I32EnumAttrCase<"FADD", 5, "fadd">,
I32EnumAttrCase<"MAX", 6, "max">,
I32EnumAttrCase<"MIN", 7, "min">,
I32EnumAttrCase<"UMAX", 8, "umax">,
I32EnumAttrCase<"UMIN", 9, "umin">,
I32EnumAttrCase<"XCHG", 10, "exch">
]> {
let cppNamespace = "::mlir::triton";
}
#endif

View File

@@ -0,0 +1,46 @@
#ifndef TRITON_DIALECT
#define TRITON_DIALECT
include "mlir/IR/OpBase.td"
def Triton_Dialect : Dialect {
let name = "tt";
let cppNamespace = "::mlir::triton";
let summary = "The Triton IR in MLIR";
let description = [{
Triton Dialect.
Dependent Dialects:
* Arithmetic:
* addf, addi, andi, cmpf, cmpi, divf, fptosi, ...
* Math:
* exp, sin, cos, log, ...
* StructuredControlFlow:
* ForOp, IfOp, WhileOp, YieldOp, ConditionOp
}];
let dependentDialects = [
"arith::ArithmeticDialect",
"math::MathDialect",
"StandardOpsDialect",
"scf::SCFDialect",
// Since LLVM 15
// "cf::ControlFlowDialect",
// "func::FuncDialect"
];
let extraClassDeclaration = [{
void registerTypes();
}];
let hasConstantMaterializer = 1;
}
include "triton/Dialect/Triton/IR/TritonTypes.td"
#endif // TRITON_DIALECT

View File

@@ -0,0 +1,11 @@
#ifndef TRITON_INTERFACES
#define TRITON_INTERFACES
include "mlir/IR/OpBase.td"
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
#endif // TRITON_INTERFACES

View File

@@ -0,0 +1,423 @@
#ifndef TRITON_OPS
#define TRITON_OPS
include "triton/Dialect/Triton/IR/TritonDialect.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
//
// Op Base
//
class TT_Op<string mnemonic, list<Trait> traits = []> :
Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])> {
}
//
// CastOps
//
// Use cast ops in arith:
// bitcast
// fptoui, fptosi, uitofp, sitofp,
// extf, tructf,
// extui, extsi, tructi
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast int64 to pointer";
let arguments = (ins TT_I64Like:$from);
let results = (outs TT_PtrLike:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast pointer to int64";
let arguments = (ins TT_PtrLike:$from);
let results = (outs TT_I64Like:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}
// arith.bitcast doesn't support pointers
def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast between types of the same bitwidth";
let arguments = (ins TT_Type:$from);
let results = (outs TT_Type:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
// TODO: Add verifier
}
def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "Floating point casting for custom types";
let description = [{
Floating point casting for custom types (F8).
F8 <-> FP16, BF16, FP32, FP64
}];
let arguments = (ins TT_FloatLike:$from);
let results = (outs TT_FloatLike:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
// TODO: We need a verifier here.
}
//
// Pointer Arith Ops
//
def TT_AddPtrOp : TT_Op<"addptr",
[NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
TypesMatchWith<"result type matches ptr type",
"result", "ptr", "$_self">]> {
let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset);
let results = (outs TT_PtrLike:$result);
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)";
}
//
// Load/Store Ops
//
def TT_LoadOp : TT_Op<"load",
[SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
AttrSizedOperandSegments,
MemoryEffects<[MemRead]>,
TypesMatchWith<"infer ptr type from result type",
"result", "ptr", "getPointerTypeSameShape($_self)">,
TypesMatchWith<"infer mask type from result type or none",
"result", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
TypesMatchWith<"infer other type from result type or none",
"result", "other", "$_self",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "load";
let arguments = (ins TT_PtrLike:$ptr, Optional<TT_BoolLike>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile);
let results = (outs TT_Type:$result);
let builders = [
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
];
// let assemblyFormat = "operands attr-dict `:` type($result)";
let parser = [{ return mlir::triton::parseLoadOp(parser, result); }];
let printer = [{ return mlir::triton::printLoadOp(p, *this); }];
let hasCanonicalizer = 1;
}
def TT_StoreOp : TT_Op<"store",
[SameOperandsShape,
SameOperandsEncoding,
MemoryEffects<[MemWrite]>,
TypesMatchWith<"infer ptr type from value type",
"value", "ptr",
"getPointerTypeSameShape($_self)">,
TypesMatchWith<"infer mask type from value type",
"value", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "store";
let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask);
let builders = [
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
];
// let assemblyFormat = "operands attr-dict `:` type($value)";
let parser = [{ return mlir::triton::parseStoreOp(parser, result); }];
let printer = [{ return mlir::triton::printStoreOp(p, *this); }];
let hasCanonicalizer = 1;
}
//
// Atomic Op
//
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
MemoryEffects<[MemRead]>,
MemoryEffects<[MemWrite]>,
TypesMatchWith<"infer ptr type from value type",
"val", "ptr",
"getPointerTypeSameShape($_self)">,
TypesMatchWith<"infer mask type from value type",
"val", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "atomic rmw";
let description = [{
load data at $ptr, do $rmw_op with $val, and store result to $ptr.
return old value at $ptr
}];
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr,
TT_Type:$val, Optional<TT_BoolLike>:$mask);
let results = (outs TT_Type:$result);
}
def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>,
MemoryEffects<[MemWrite]>,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "atomic cas";
let description = [{
compare $cmp with data $old at location $ptr,
if $old == $cmp, store $val to $ptr,
else store $old to $ptr,
return $old
}];
let arguments = (ins TT_Ptr:$ptr, TT_Type:$cmp, TT_Type:$val);
let results = (outs TT_Type:$result);
}
//
// Shape Manipulation Ops
//
def TT_SplatOp : TT_Op<"splat", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "splat";
let arguments = (ins TT_Type:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
}
def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
SameOperandsAndResultElementType]> {
let summary = "expand_dims";
let arguments = (ins TT_Tensor:$src, I32Attr:$axis);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TT_ViewOp : TT_Op<"view", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "view";
let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "broadcast. No left-padding as of now.";
let arguments = (ins TT_Type:$src);
let results = (outs TT_Type:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
}
def TT_CatOp : TT_Op<"cat", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "concatenate 2 tensors";
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
}
def TT_TransOp : TT_Op<"trans", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "transpose a tensor";
let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
//
// SPMD Ops
//
def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoSideEffect]> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoSideEffect]> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
//
// Dot Op
//
def TT_DotOp : TT_Op<"dot", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "dot";
let description = [{
$d = matrix_multiply($a, $b) + $c
}];
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32);
let results = (outs TT_FpIntTensor:$d);
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
}
//
// Reduce Op
//
def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "reduce";
let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
let results = (outs TT_Type:$result);
let builders = [
OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
];
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
let extraClassDeclaration = [{
// This member function is marked static because we need to call it before the ReduceOp
// is constructed, see the implementation of create_reduce in triton.cc.
static bool withIndex(mlir::triton::RedOp redOp);
}];
}
//
// External elementwise op
//
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
SameVariadicOperandSize]> {
let summary = "ext_elemwise";
let description = [{
call an external function $symbol implemented in $libpath/$libname with $args
return $libpath/$libname:$symbol($args...)
}];
let arguments = (ins Variadic<TT_Type>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol);
let results = (outs TT_Type:$result);
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
}
//
// Make Range Op
//
// TODO: should have ConstantLike as Trait
def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
let summary = "make range";
let description = [{
Returns an 1D int32 tensor.
Values span from $start to $end (exclusive), with step = 1
}];
let arguments = (ins I32Attr:$start, I32Attr:$end);
let results = (outs TT_IntTensor:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
//
// Make PrintfOp
//
def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
Arguments<(ins StrAttr:$prefix,
Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
let summary = "Device-side printf, as in CUDA for debugging";
let description = [{
`tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
format are generated automatically from the arguments.
}];
let assemblyFormat = [{
$prefix attr-dict ($args^ `:` type($args))?
}];
}
#endif // Triton_OPS

View File

@@ -0,0 +1,71 @@
#ifndef TRITON_TYPES
#define TRITON_TYPES
include "triton/Dialect/Triton/IR/TritonDialect.td"
//
// Types
//
class TritonTypeDef<string name, string _mnemonic>
: TypeDef<Triton_Dialect, name> {
// Used by printer/parser
let mnemonic = _mnemonic;
}
// Floating-point Type
def F8 : TritonTypeDef<"Float8", "f8">;
def TT_Float : AnyTypeOf<[F8, F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
// Boolean Type
// TT_Bool -> I1
def TT_BoolTensor : TensorOf<[I1]>;
def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;
// Integer Type
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
def TT_IntTensor : TensorOf<[TT_Int]>;
def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>;
// I32 Type
// TT_I32 -> I32
// TT_I32Tensor -> I32Tensor
def TT_I32Like: AnyTypeOf<[I32, I32Tensor]>;
// I64 Type
// TT_I64 -> I64
// TT_I64Tensor -> I64Tensor
def TT_I64Like: AnyTypeOf<[I64, I64Tensor]>;
// Pointer Type
def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
let summary = "pointer type";
let description = [{
Triton PointerType
}];
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
let builders = [
TypeBuilderWithInferredContext<(ins
"Type":$pointeeType,
"int":$addressSpace
), [{
return $_get(pointeeType.getContext(), pointeeType, addressSpace);
}]>
];
let skipDefaultBuilders = 1;
}
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>;
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntTensor]>;
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike]>;
#endif

View File

@@ -0,0 +1,10 @@
#ifndef TRITON_IR_TYPES_H_
#define TRITON_IR_TYPES_H_
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/IR/Types.h.inc"
#endif // TRITON_IR_TYPES_H_

View File

@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton)
add_public_tablegen_target(TritonTransformsIncGen)

View File

@@ -0,0 +1,18 @@
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
#define TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace triton {
std::unique_ptr<Pass> createCombineOpsPass();
} // namespace triton
#define GEN_PASS_REGISTRATION
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
} // namespace mlir
#endif

View File

@@ -0,0 +1,23 @@
#ifndef TRITON_PASSES
#define TRITON_PASSES
include "mlir/Pass/PassBase.td"
def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp"> {
let summary = "combine ops";
let description = [{
dot(a, b, 0) + c => dot(a, b, c)
addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1))
select(cond, load(ptrs, broadcast(cond), ???), other) =>
load(ptrs, broadcast(cond), other)
}];
let constructor = "mlir::triton::createCombineOpsPass()";
let dependentDialects = ["mlir::arith::ArithmeticDialect",
/*SelectOp*/"mlir::StandardOpsDialect"];
}
#endif

View File

@@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,12 @@
set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
add_public_tablegen_target(TritonGPUTableGen)
set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(TritonGPUAttrDefsIncGen)

View File

@@ -0,0 +1,48 @@
#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
// TritonGPU depends on Triton
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#include "triton/Dialect/TritonGPU/IR/Traits.h"
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"
namespace mlir {
namespace triton {
namespace gpu {
unsigned getElemsPerThread(Type type);
SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout);
SmallVector<unsigned> getWarpsPerCTA(const Attribute &layout);
SmallVector<unsigned> getSizePerThread(const Attribute &layout);
SmallVector<unsigned> getContigPerThread(Attribute layout);
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);
SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
SmallVector<unsigned> getOrder(const Attribute &layout);
bool isaDistributedLayout(const Attribute &layout);
} // namespace gpu
} // namespace triton
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_

View File

@@ -0,0 +1,31 @@
#ifndef TRITON_GPU_IR_TRAITS_H_
#define TRITON_GPU_IR_TRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace OpTrait {
// These functions are out-of-line implementations of the methods in the
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
LogicalResult verifyResultsAreSharedEncoding(Operation *op);
} // namespace impl
template <typename ConcreteType>
class ResultsAreSharedEncoding
: public TraitBase<ConcreteType, ResultsAreSharedEncoding> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyResultsAreSharedEncoding(op);
}
};
} // namespace OpTrait
} // namespace mlir
#endif

View File

@@ -0,0 +1,481 @@
#ifndef TRITONGPU_ATTRDEFS
#define TRITONGPU_ATTRDEFS
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
//===----------------------------------------------------------------------===//
// TritonGPU Attribute Definitions
//===----------------------------------------------------------------------===//
class TritonGPU_Attr<string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<TritonGPU_Dialect, name, traits, baseCppClass> {
let description = [{
TritonGPU Tensors differ from usual tensors in that they contain a _layout_ attribute which determines
how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function
\mathcal{L} that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers T corresponding
to the indices of the CUDA threads allowed to access some data at index $i$.
For example, let us consider the layout function:
\mathcal{L}(0, 0) = {0, 4}
\mathcal{L}(0, 1) = {1, 5}
\mathcal{L}(1, 0) = {2, 6}
\mathcal{L}(1, 1) = {3, 7}
Then, attaching $\mathcal{L} to a tensor $T$ would mean that:
- T[0,0] is owned by both cuda thread 0 and 4
- T[0,1] is owned by both cuda thread 1 and 5
- T[1,0] is owned by both cuda thread 2 and 6
- T[1,1] is owned by both cuda thread 3 and 7
Right now, Triton implements two classes of layouts: shared, and distributed.
}];
code extraBaseClassDeclaration = [{
unsigned getElemsPerThread(ArrayRef<int64_t> shape) const;
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
}];
}
//===----------------------------------------------------------------------===//
// Shared Layout Encoding
//===----------------------------------------------------------------------===//
def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding"> {
let mnemonic = "shared";
let description = [{
An encoding for tensors whose elements may be simultaneously accessed by
different cuda threads in the programs, via shared memory. In other words,
for all indices i \in R^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.
In order to avoid shared memory bank conflicts, elements may be swizzled
in memory. For example, a swizzled row-major layout could store its data
as follows:
A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2
A_{1, 0} A_{1, 1} A_{1, 2} A_{1, 3} ... [phase 0] /
groups of vec=2 elements
are stored contiguously
_ _ _ _ /\_ _ _ _
A_{2, 2} A_{2, 3} A_{2, 0} A_{2, 1} ... [phase 1] \ per phase = 2
A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
}];
let parameters = (
ins
// swizzle info
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase,
ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
);
let builders = [
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"Type":$eltTy), [{
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();
if(!mmaEnc)
return $_get(context, 1, 1, 1, order);
int opIdx = dotOpEnc.getOpIdx();
// number of rows per phase
int perPhase = 128 / (shape[order[0]] * (eltTy.getIntOrFloatBitWidth() / 8));
perPhase = std::max<int>(perPhase, 1);
// index of the inner dimension in `order`
unsigned inner = (opIdx == 0) ? 0 : 1;
// ---- begin Volta ----
if (mmaEnc.isVolta()) {
bool is_row = order[0] != 0;
bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
is_row && (shape[order[0]] <= 16);
// TODO[Superjomn]: Support the case when is_vec4=false later
// Currently, we only support ld.v2, for the mma layout varies with different ld vector width.
is_vec4 = true;
int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) :
((is_row && !is_vec4) ? 2 : 1);
int rep = 2 * pack_size;
int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase;
int vec = 2 * rep;
return $_get(context, vec, perPhase, maxPhase, order);
}
// ---- begin Ampere ----
if (mmaEnc.isAmpere()) {
std::vector<size_t> matShape = {8, 8,
2 * 64 / eltTy.getIntOrFloatBitWidth()};
// for now, disable swizzle when using transposed int8 tensor cores
if (eltTy.isInteger(8) && order[0] == inner)
return $_get(context, 1, 1, 1, order);
// --- handle A operand ---
if (opIdx == 0) { // compute swizzling for A operand
int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m
int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2];
int maxPhase = mmaStride / perPhase;
return $_get(context, vec, perPhase, maxPhase, order);
}
// --- handle B operand ---
if (opIdx == 1) {
int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k
int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1];
int maxPhase = mmaStride / perPhase;
return $_get(context, vec, perPhase, maxPhase, order);
}
llvm_unreachable("invalid operand index");
}
// ---- not implemented ----
llvm_unreachable("unsupported swizzling for provided MMA version");
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration;
}
//===----------------------------------------------------------------------===//
// Distributed Layout Encoding
//===----------------------------------------------------------------------===//
class DistributedEncoding<string name> : TritonGPU_Attr<name> {
let description = [{
Distributed encodings have a layout function that is entirely characterized
by a d-dimensional tensor L. Note that L doesn't need to have the same shape
(or even the same rank) as the tensor it is encoding.
The layout function \mathcal{L} of this layout is then defined, for an
index `i` \in R^D, as follows:
\mathcal{L}(A)[i_d] = L[(i_d + k_d*A.shape[d]) % L.shape[d]] \forall k_d such as i_d + k_d*A.shape[d] < L.shape[d]
For example, for a tensor/layout pair
A = [x x x x x x x x]
[x x x x x x x x]
L = [0 1 2 3 ]
[4 5 6 7 ]
[8 9 10 11]
[12 13 14 15]
Then the data of A would be distributed as follow between the 16 CUDA threads:
L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
{4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ]
}];
let extraClassDeclaration = extraBaseClassDeclaration;
}
//===----------------------------------------------------------------------===//
// Blocked Layout Encoding
//===----------------------------------------------------------------------===//
def BlockedEncodingAttr : DistributedEncoding<"BlockedEncoding"> {
let mnemonic = "blocked";
let description = [{
An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout
used to promote memory coalescing in LoadInst and StoreInst.
It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which
specify the amount of elements owned by each CUDA thread, warp and CTA respectively.
For example, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows.
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
...
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
for
#triton_gpu.blocked_layout<{
sizePerThread = {2, 2}
threadsPerWarp = {8, 4}
warpsPerCTA = {1, 2}
}>
}];
let builders = [
// Custom builder initializes sizePerWarp and sizePerCTA automatically
// TODO: compiles on MacOS but not linux?
// AttrBuilder<(ins "ArrayRef<unsigned>":$sizePerThread,
// "ArrayRef<unsigned>":$threadsPerWarp,
// "ArrayRef<unsigned>":$warpsPerCTA,
// "ArrayRef<unsigned>":$order), [{
// int rank = threadsPerWarp.size();
// SmallVector<unsigned, 4> sizePerWarp(rank);
// SmallVector<unsigned, 4> sizePerCTA(rank);
// for (unsigned i = 0; i < rank; i++) {
// sizePerWarp.push_back(sizePerThread[i] * threadsPerWarp[i]);
// sizePerCTA.push_back(sizePerWarp[i] * warpsPerCTA[i]);
// }
// return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, sizePerWarp, sizePerCTA);
// }]>,
// Custom builder initializes sizePerWarp and sizePerCTA automatically
// Default builder takes sizePerThread, order and numWarps, and tries to
// pack numWarps*32 threads in the provided order for use in a type
// of the given shape.
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$sizePerThread,
"ArrayRef<unsigned>":$order,
"unsigned":$numWarps), [{
int rank = sizePerThread.size();
unsigned remainingLanes = 32;
unsigned remainingThreads = numWarps*32;
unsigned remainingWarps = numWarps;
unsigned prevLanes = 1;
unsigned prevWarps = 1;
SmallVector<unsigned, 4> threadsPerWarp(rank);
SmallVector<unsigned, 4> warpsPerCTA(rank);
for (int _dim = 0; _dim < rank - 1; ++_dim) {
int i = order[_dim];
unsigned threadsPerCTA = std::clamp<unsigned>(remainingThreads, 1, shape[i] / sizePerThread[i]);
threadsPerWarp[i] = std::clamp<unsigned>(threadsPerCTA, 1, remainingLanes);
warpsPerCTA[i] = std::clamp<unsigned>(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps);
remainingWarps /= warpsPerCTA[i];
remainingLanes /= threadsPerWarp[i];
remainingThreads /= threadsPerCTA;
prevLanes *= threadsPerWarp[i];
prevWarps *= warpsPerCTA[i];
}
// Expand the last dimension to fill the remaining lanes and warps
threadsPerWarp[order[rank-1]] = 32 / prevLanes;
warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
SliceEncodingAttr squeeze(int axis);
}];
let parameters = (
ins
ArrayRefParameter<"unsigned">:$sizePerThread,
ArrayRefParameter<"unsigned">:$threadsPerWarp,
ArrayRefParameter<"unsigned">:$warpsPerCTA,
// fastest-changing axis first
ArrayRefParameter<
"unsigned",
"order of axes by the rate of changing"
>:$order
// These attributes can be inferred from the rest
// ArrayRefParameter<"unsigned">:$sizePerWarp,
// ArrayRefParameter<"unsigned">:$sizePerCTA
);
}
//===----------------------------------------------------------------------===//
// MMA Layout Encoding
//===----------------------------------------------------------------------===//
// TODO: MMAv1 and MMAv2 should be two instances of the same class
def MmaEncodingAttr : DistributedEncoding<"MmaEncoding"> {
let mnemonic = "mma";
let description = [{
An encoding for tensors that have been produced by tensor cores.
It is characterized by two parameters:
- A 'versionMajor' which specifies the generation the tensor cores
whose output is being partitioned: 1 for first-gen tensor cores (Volta),
and 2 for second-gen tensor cores (Turing/Ampere).
- A 'versionMinor' which indicates the specific layout of a tensor core
generation, e.g. for Volta, there might be multiple kinds of layouts annotated
by 0,1,2 and so on.
- A `blockTileSize` to indicate how data should be
partitioned between warps.
// -------------------------------- version = 1 --------------------------- //
For first-gen tensor cores, the implicit warpTileSize is [16, 16].
Note: the layout is different from the recommended in PTX ISA
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
(mma.884 section, FP32 accumulator).
For example, when versionMinor=1, the matrix L corresponding to
blockTileSize=[32,16] is:
warp 0
--------------------------------/\-------------------------------
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ]
[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ]
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ]
[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ]
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ]
[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ]
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ]
[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ]
warp 1 = warp0 + 32
--------------------------------/\-------------------------------
[ 32 32 34 34 40 40 42 42 32 32 34 34 40 40 42 42 ]
[ 33 33 35 35 41 41 43 43 33 33 35 35 41 41 43 43 ]
[ ............................................................... ]
// -------------------------------- version = 2 --------------------------- //
For second-gen tensor cores, the implicit warpTileSize is [16, 8].
Information about this layout can be found in the official PTX documentation
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
(mma.16816 section, FP32 accumulator).
For example, the matrix L corresponding to blockTileSize=[32,16] is:
warp 0 warp 1
-----------------/\------------- ----------------/\-------------
[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35
[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39
[ .............................. ..............................
[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63
[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35
[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39
[ .............................. ..............................
[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63
warp 3 warp 4
----------------/\------------- ----------------/\-------------
[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99
[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103
[ .............................. ...............................
[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127
[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99
[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103
[ .............................. ...............................
[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127
}];
let parameters = (
ins
"unsigned":$versionMajor,
"unsigned":$versionMinor,
ArrayRefParameter<"unsigned">:$warpsPerCTA
);
let builders = [
// specific for MMAV1(Volta)
AttrBuilder<(ins "int":$versionMajor,
"ArrayRef<unsigned>":$warpsPerCTA,
"ArrayRef<int64_t>":$shapeA,
"ArrayRef<int64_t>":$shapeB,
"bool":$isARow,
"bool":$isBRow), [{
assert(versionMajor == 1 && "Only MMAv1 has multiple versionMinor.");
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
// 4-bits to encode 4 booleans: [isARow, isBRow, isAVec4, isBVec4]
int versionMinor = (isARow * (1<<0)) |\
(isBRow * (1<<1)) |\
(isAVec4 * (1<<2)) |\
(isBVec4 * (1<<3));
return $_get(context, versionMajor, versionMinor, warpsPerCTA);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
bool isVolta() const;
bool isAmpere() const;
// Get [isARow, isBRow, isAVec4, isBVec4] from versionMinor
std::tuple<bool, bool, bool, bool> decodeVoltaLayoutStates() const;
}];
}
def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
let mnemonic = "slice";
let description = [{
TODO: improve docs
A = [x x x x x x x x]
parent = [0 1 2 3 ]
[4 5 6 7 ]
[8 9 10 11]
[12 13 14 15]
dim = 0
Then the data of A would be distributed as follow between the 16 CUDA threads:
L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ]
This is useful for constructing the inverse layout of an expand_dims operation during some optimization passes.
}];
let parameters = (
ins
"unsigned":$dim,
// TODO: constraint here to only take distributed encodings
"Attribute":$parent
);
let extraClassDeclaration = extraBaseClassDeclaration # [{
template<class T>
SmallVector<T> paddedShape(ArrayRef<T> shape) const;
}];
}
def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
let mnemonic = "dot_op";
let description = [{
In TritonGPU dialect, considering `d = tt.dot a, b, c`
tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
a's opIdx is 0, b's opIdx is 1.
The parend field in DotOperandEncodingAttr is the layout of d.
For MMA v1, an additional attribute `isMMAv1Row` determines whether e.g. the a operand is used
in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation
section 9.7.13.4.1 for more details.
}];
let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent,
"Attribute":$isMMAv1Row
);
let builders = [
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent), [{
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().isVolta()){
isMMAv1Row = BoolAttr::get(context, true);
}
return $_get(context, opIdx, parent, isMMAv1Row);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration;
}
#endif

View File

@@ -0,0 +1,36 @@
#ifndef TRITONGPU_DIALECT
#define TRITONGPU_DIALECT
include "mlir/IR/OpBase.td"
def TritonGPU_Dialect : Dialect {
let name = "triton_gpu";
let cppNamespace = "::mlir::triton::gpu";
let hasOperationAttrVerify = 1;
let description = [{
Triton GPU Dialect.
}];
let dependentDialects = [
"triton::TritonDialect",
"mlir::gpu::GPUDialect",
"tensor::TensorDialect",
];
let extraClassDeclaration = [{
static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
static int getNumWarps(ModuleOp mod) {
if(!mod->hasAttr("triton_gpu.num-warps"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-warps attribute");
return mod->getAttr("triton_gpu.num-warps").cast<IntegerAttr>().getInt();
}
}];
}
#endif

View File

@@ -0,0 +1,198 @@
#ifndef TRITONGPU_OPS
#define TRITONGPU_OPS
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;
class TTG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonGPU_Dialect, mnemonic, traits>;
def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
[SameOperandsAndResultShape, NoSideEffect]> {
let summary = "convert layout";
let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
let summary = "async wait";
let arguments = (ins I32Attr:$num);
let assemblyFormat = "attr-dict";
let extraClassDeclaration = [{
static bool isSupported(int computeCapability) {
return computeCapability >= 80;
}
}];
}
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
// This is needed because these ops don't
// handle encodings
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "integer comparison operation";
let description = [{}];
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
TT_IntLike:$lhs,
TT_IntLike:$rhs);
let results = (outs TT_BoolLike:$result);
}
def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "floating-point comparison operation";
let description = [{}];
let arguments = (ins Arith_CmpFPredicateAttr:$predicate,
TT_FloatLike:$lhs,
TT_FloatLike:$rhs);
let results = (outs TT_BoolLike:$result);
}
// TODO: migrate to arith::SelectOp on LLVM16
def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "select operation";
let description = [{}];
let arguments = (ins TT_BoolLike:$condition,
TT_Tensor:$true_value,
TT_Tensor:$false_value);
let results = (outs TT_Tensor:$result);
}
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
[AttrSizedOperandSegments,
ResultsAreSharedEncoding,
MemoryEffects<[MemRead]>,
TypesMatchWith<"infer mask type from src type",
"src", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
TypesMatchWith<"infer other type from src type",
"src", "other", "getPointeeType($_self)",
"($_op.getOperands().size() <= 4) || std::equal_to<>()">]> {
let summary = "insert slice async";
let description = [{
This operation inserts a tensor `$src` into another tensor `$dst` as specified by the operations
`$index` argument and `$axis` attribute.
It returns a copy of `$dst` with the proper slice updated asynchronously with the value of `$src`.
This operation is non-blocking, and `$results` will have the updated value after the corresponding async_wait.
When converting from `tt.load` to `triton_gpu.insert_slice_async`, the `$evict`, `$cache`, and `$isVolatile` fields
might be ignored on certain hardware. For example, on NVIDIA GPUs, the cache policy is determined by the backend,
and `$evict` and `$isVolatile` are ignored because they apply to L1 cache only.
The insert_slice_async operation supports the following arguments:
* src: the tensor that is inserted.
* dst: the tensor into which the `$src` tensor is inserted.
* index: the index of the `$src` tensor at the given `$axis` from which the `$dst` tensor is inserted into
* mask: optional tensor-rank number of boolean masks which specify which
elements of the `$src` tensor are inserted into the `$dst` tensor.
* other: optional tensor-rank number of other tensors which specify what
values are inserted into the `$dst` tensor if the corresponding
element of the `$mask` tensor is false.
In the future, we may decompose this operation into a sequence of:
* `async` operation to specify a sequence of asynchronous operations
* `load` operation to load a tensor from global memory
* `insert_slice` operations to insert the `$src` tensor into the `$dst` tensor
Example:
```
%1 = triton_gpu.alloc_tensor : tensor<2x32xf32>
%2 = triton_gpu.insert_slice_async %0, %1, %index { axis = 0 } : tensor<32x!tt.ptr<f32>, #AL> -> tensor<2x32xf32, #A>
triiton_gpu.async_wait { num = 0 : i32 }
```
}];
let arguments = (ins TT_PtrTensor:$src, TT_Tensor:$dst, I32:$index,
Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile, I32Attr:$axis);
let builders = [
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, "Value":$mask,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
"Value":$mask, "Value":$other,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
];
let results = (outs TT_Tensor:$result);
//let assemblyFormat = [{
// $src `,` $dst ``
// $index, $mask, $other
// attr-dict `:` type($src) `->` type($dst)
//}];
let extraClassDeclaration = [{
static DenseSet<unsigned> getEligibleLoadByteWidth(int computeCapability) {
DenseSet<unsigned> validLoadBytes;
if (computeCapability >= 80) {
validLoadBytes = {4, 8, 16};
}
return validLoadBytes;
}
}];
// The custom parser could be replaced with oilist in LLVM-16
let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
}
def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [MemoryEffects<[MemAlloc]>, // Allocate shared memory
ResultsAreSharedEncoding]> {
let summary = "allocate tensor";
let description = [{
This operation defines a tensor of a particular shape.
The contents of the tensor are supposed to be in shared memory.
Note: This op can be repalced to a `bufferization.alloc_tensor` in LLVM 16.
}];
let assemblyFormat = [{attr-dict `:` type($result)}];
let results = (outs TT_Tensor:$result);
}
#endif

View File

@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonGPU)
add_public_tablegen_target(TritonGPUTransformsIncGen)

View File

@@ -0,0 +1,25 @@
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
// TODO(Keren): prefetch pass not working yet
std::unique_ptr<Pass> createTritonGPUPrefetchPass();
std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
std::unique_ptr<Pass> createTritonGPUCoalescePass();
std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
std::unique_ptr<Pass> createTritonGPUVerifier();
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
} // namespace mlir
#endif

View File

@@ -0,0 +1,87 @@
#ifndef TRITONGPU_PASSES
#define TRITONGPU_PASSES
include "mlir/Pass/PassBase.td"
def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
let summary = "pipeline";
let description = [{
Unroll loops to hide global memory -> shared memory latency.
}];
let constructor = "mlir::createTritonGPUPipelinePass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithmeticDialect"];
let options = [
Option<"numStages", "num-stages",
"int32_t", /*default*/"2",
"number of pipeline stages">
];
}
def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
let summary = "prefetch";
let description = [{
Prefetch operands (a and b) of tt.dot into shared memory to hide shared memory -> register latency.
}];
let constructor = "mlir::createTritonGPUPrefetchPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithmeticDialect"];
}
def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
let summary = "coalesce";
let description = [{
TODO
}];
let constructor = "mlir::createTritonGPUCoalescePass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
let summary = "combine triton gpu ops";
let description = [{
convert_layout(convert_layout(%src, #LAYOUT_0), #LAYOUT_1) =>
convert_layout(%src, #LAYOUT_1)
convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT
}];
let constructor = "mlir::createTritonGPUCombineOpsPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}
def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
let summary = "canonicalize scf.ForOp ops";
let description = [{
This implements some optimizations that are missing in the standard scf.ForOp
canonicalizer.
}];
let constructor = "mlir::createTritonGPUCanonicalizeLoopsPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
#endif

View File

@@ -0,0 +1,33 @@
//===----------------------------------------------------------------------===//
//
// Defines utilities to use while converting to the TritonGPU dialect.
//
//===----------------------------------------------------------------------===//
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
class TritonGPUTypeConverter : public TypeConverter {
public:
TritonGPUTypeConverter(MLIRContext *context, int numWarps);
int getNumWarps() const { return numWarps; }
private:
MLIRContext *context;
int numWarps;
};
class TritonGPUConversionTarget : public ConversionTarget {
public:
explicit TritonGPUConversionTarget(MLIRContext &ctx,
TritonGPUTypeConverter &typeConverter);
};
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_

View File

@@ -0,0 +1,37 @@
#ifndef TRITON_TARGET_LLVMIRTRANSLATION_H
#define TRITON_TARGET_LLVMIRTRANSLATION_H
#include "llvm/ADT/StringRef.h"
#include <memory>
#include <string>
#include <vector>
namespace llvm {
class Module;
class LLVMContext;
} // namespace llvm
namespace mlir {
class ModuleOp;
} // namespace mlir
namespace mlir {
namespace triton {
// add external dependent libs
void addExternalLibs(mlir::ModuleOp &module,
const std::vector<std::string> &names,
const std::vector<std::string> &paths);
// Translate TritonGPU dialect to LLVMIR, return null if failed.
std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module, int computeCapability);
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
} // namespace triton
} // namespace mlir
#endif // TRITON_TARGET_LLVMIRTRANSLATION_H

View File

@@ -0,0 +1,17 @@
#ifndef TRITON_TARGET_PTXTRANSLATION_H
#define TRITON_TARGET_PTXTRANSLATION_H
#include <string>
namespace llvm {
class Module;
} // namespace llvm
namespace triton {
// Translate TritonGPU IR to PTX code.
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version);
} // namespace triton
#endif

View File

@@ -22,35 +22,32 @@
#ifndef TDL_TOOLS_SYS_GETENV_HPP #ifndef TDL_TOOLS_SYS_GETENV_HPP
#define TDL_TOOLS_SYS_GETENV_HPP #define TDL_TOOLS_SYS_GETENV_HPP
#include <string> #include <algorithm>
#include <cstdlib> #include <cstdlib>
#include <string>
namespace triton namespace triton {
{
namespace tools namespace tools {
{
inline std::string getenv(const char * name)
{
#ifdef _MSC_VER
char* cache_path = 0;
std::size_t sz = 0;
_dupenv_s(&cache_path, &sz, name);
#else
const char * cstr = std::getenv(name);
#endif
if(!cstr)
return "";
std::string result(cstr);
#ifdef _MSC_VER
free(cache_path);
#endif
return result;
}
inline std::string getenv(const char *name) {
const char *cstr = std::getenv(name);
if (!cstr)
return "";
std::string result(cstr);
return result;
} }
inline bool getBoolEnv(const std::string &env) {
const char *s = std::getenv(env.c_str());
std::string str(s ? s : "");
std::transform(str.begin(), str.end(), str.begin(),
[](unsigned char c) { return std::tolower(c); });
return (str == "on" || str == "true" || str == "1");
} }
} // namespace tools
} // namespace triton
#endif #endif

View File

@@ -1,80 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
#define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
#include <map>
#include <vector>
namespace triton {
namespace ir {
class value;
class module;
class phi_node;
class splat_inst;
class cast_inst;
class reshape_inst;
class broadcast_inst;
class binary_operator;
class getelementptr_inst;
}
namespace codegen{
namespace analysis{
class align {
private:
struct cst_info {
unsigned num_cst;
unsigned value;
};
// helpers
std::vector<unsigned> get_shapes(ir::value *v);
// populate is_constant
std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
std::vector<cst_info> populate_is_constant_default(ir::value* v);
std::vector<cst_info> populate_is_constant(ir::value *v);
// populate max_contiguous
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_max_contiguous_cast(ir::cast_inst* x);
std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
std::vector<unsigned> populate_max_contiguous(ir::value *v);
// populate starting_multiple
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_starting_multiple_cast(ir::cast_inst* x);
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
std::vector<unsigned> populate_starting_multiple(ir::value *v);
// populate all maps
void populate(ir::value *v);
public:
void run(ir::module &mod);
unsigned get(ir::value* v, unsigned ax) const;
std::vector<unsigned> contiguous(ir::value* v) const;
private:
std::map<ir::value*, std::vector<cst_info>> is_constant_;
std::map<ir::value*, std::vector<unsigned>> max_contiguous_;
std::map<ir::value*, std::vector<unsigned>> starting_multiple_;
};
}
}
}
#endif

View File

@@ -1,47 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
#define TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
#include <map>
#include <set>
#include <iostream>
#include "triton/codegen/analysis/liveness.h"
namespace triton{
namespace ir{
class value;
class function;
class module;
}
namespace codegen{
namespace analysis{
class tiles;
class liveness;
class cts;
class allocation {
public:
allocation(liveness *live)
: liveness_(live) { }
// accessors
bool has_offset(const data_layout *x) const { return offsets_.find(x) != offsets_.end(); }
unsigned offset(const data_layout *x) const { return offsets_.at(x); }
unsigned allocated_size() const { return allocated_size_; }
// run
void run(ir::module& mod);
private:
std::map<const data_layout*, unsigned> offsets_;
size_t allocated_size_;
// dependences
liveness *liveness_;
};
}
}
}
#endif

View File

@@ -1,52 +0,0 @@
#ifndef _TRITON_CODEGEN_ANALYSIS_AXES_H_
#define _TRITON_CODEGEN_ANALYSIS_AXES_H_
#include "triton/tools/graph.h"
#include <map>
#include <vector>
namespace triton{
namespace ir{
class value;
class module;
class instruction;
}
namespace codegen{
namespace analysis{
class axes {
typedef std::pair<ir::value*, unsigned> node_t;
private:
// update graph
void update_graph_store(ir::instruction *i);
void update_graph_reduce(ir::instruction *i);
void update_graph_reshape(ir::instruction *i);
void update_graph_trans(ir::instruction *i);
void update_graph_broadcast(ir::instruction *i);
void update_graph_dot(ir::instruction *i);
void update_graph_elementwise(ir::instruction *i,
bool is_masked_load_async=false);
void update_graph_no_edge(ir::instruction *i);
void update_graph(ir::instruction *i);
public:
axes();
void run(ir::module &mod);
// accessors
int get(ir::value *value, unsigned dim);
std::vector<int> get(ir::value *value);
private:
tools::graph<node_t> graph_;
std::map<node_t, size_t> axes_;
};
}
}
}
#endif

View File

@@ -1,265 +0,0 @@
#ifndef _TRITON_CODEGEN_ANALYSIS_GRID_H_
#define _TRITON_CODEGEN_ANALYSIS_GRID_H_
#include <map>
#include <set>
#include <vector>
#include <memory>
#include "triton/tools/graph.h"
#include "triton/codegen/target.h"
namespace triton{
namespace ir{
class value;
class type;
class module;
class instruction;
class phi_node;
}
namespace codegen{
namespace analysis{
class axes;
class align;
class layout_visitor;
class data_layout;
class mma_layout;
class scanline_layout;
class shared_layout;
class layout_visitor {
public:
virtual void visit_layout(data_layout *);
virtual void visit_layout_mma(mma_layout*) = 0;
virtual void visit_layout_scanline(scanline_layout*) = 0;
virtual void visit_layout_shared(shared_layout*) = 0;
};
class data_layout {
protected:
enum id_t {
MMA,
SCANLINE,
SHARED
};
typedef std::vector<int> axes_t;
typedef std::vector<unsigned> shape_t;
typedef std::vector<int> order_t;
typedef std::vector<ir::value*> values_t;
private:
template<typename T>
T* downcast(id_t id) {
if(id_ == id)
return static_cast<T*>(this);
return nullptr;
}
public:
data_layout(id_t id,
const std::vector<int>& axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align);
// visitor
virtual void accept(layout_visitor* vst) = 0;
// downcast
mma_layout* to_mma() { return downcast<mma_layout>(MMA); }
scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
// accessors
size_t get_rank() { return shape_.size(); }
const shape_t& get_shape() const { return shape_; }
const order_t& get_order() const { return order_; }
const values_t& get_values() const { return values_;}
int get_axis(size_t k) const { return axes_.at(k); }
std::vector<int> get_axes() const { return axes_; }
const int get_order(size_t k) const { return order_.at(k); }
// find the position of given axis
int find_axis(int to_find) const;
private:
id_t id_;
axes_t axes_;
values_t values_;
protected:
order_t order_;
shape_t shape_;
};
class distributed_layout: public data_layout{
public:
distributed_layout(id_t id,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value*>& values,
analysis::align* align);
int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
protected:
std::vector<int> shape_per_cta_;
};
class mma_layout: public distributed_layout {
public:
mma_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values,
analysis::align* align, target *tgt,
shared_layout* layout_a,
shared_layout* layout_b);
void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
// accessor
int fpw(size_t k) { return fpw_.at(k); }
int wpt(size_t k) { return wpt_.at(k); }
int spw(size_t k) { return spw_.at(k); }
int rep(size_t k) { return rep_.at(k); }
private:
// fragment per warp
std::vector<int> fpw_;
// shape per warp
std::vector<int> spw_;
// warp per tile
std::vector<int> wpt_;
// shape per tile
std::vector<int> spt_;
// repetitions
std::vector<int> rep_;
};
struct scanline_layout: public distributed_layout {
scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align,
target* tgt);
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
// accessor
int mts(size_t k) { return mts_.at(k); }
int nts(size_t k) { return nts_.at(k); }
public:
// micro tile size. The size of a tile held by a thread block.
std::vector<int> mts_;
// nano tile size. The size of a tile held by a thread.
std::vector<int> nts_;
};
struct double_buffer_info_t {
ir::value* first;
ir::value* latch;
ir::phi_node* phi;
};
struct N_buffer_info_t {
std::vector<ir::value*> firsts; // not necessarily ordered as input order
ir::value* latch;
ir::phi_node* phi;
std::map<ir::value*, int> firsts_idx;
};
// abstract for dot and coresponding smem values
class shared_layout: public data_layout {
private:
static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
static void extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t>& res, int &prev_stages);
public:
shared_layout(data_layout *arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values_,
ir::type *ty,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
// accessors
size_t get_size() { return size_; }
ir::type* get_type() { return ty_; }
double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
N_buffer_info_t* get_N_buffer() { return N_buffer_.get(); }
int get_num_stages() const;
size_t get_per_stage_size() const { return size_ / get_num_stages(); }
size_t get_per_stage_elements() const;
size_t get_num_per_phase() { return num_per_phase_; }
ir::value* hmma_dot_a() { return hmma_dot_a_; }
ir::value* hmma_dot_b() { return hmma_dot_b_; }
void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
int get_mma_vec() { return mma_vec_;}
data_layout* get_arg_layout() { return arg_layout_; }
private:
size_t size_;
ir::type *ty_;
std::shared_ptr<double_buffer_info_t> double_buffer_;
std::shared_ptr<N_buffer_info_t> N_buffer_;
size_t num_per_phase_;
ir::value* hmma_dot_a_;
ir::value* hmma_dot_b_;
data_layout* arg_layout_;
int mma_vec_;
};
class layouts {
typedef ir::value* node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
private:
// graph creation
void connect(ir::value *x, ir::value *y);
void make_graph(ir::instruction *i);
void init_hmma_tile(data_layout& layouts);
void init_scanline_tile(data_layout &layouts);
void create(size_t id, const std::vector<ir::value*>& values);
public:
// constructor
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
// accessors
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
size_t num_layouts() const { return values_.size();}
data_layout* get(size_t id) { return layouts_.at(id); }
data_layout* get(ir::value *v) { return get(layout_of(v));}
std::map<size_t, data_layout*> &get_all() { return layouts_; }
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
int tmp(ir::value* i) { return tmp_.at(i);}
void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
// execution
void run(ir::module &mod);
private:
analysis::axes* axes_;
analysis::align* align_;
size_t num_warps_;
target* tgt_;
tools::graph<ir::value*> graph_;
std::map<ir::value*, size_t> groups_;
std::map<size_t, std::vector<ir::value*>> values_;
std::map<size_t, data_layout*> layouts_;
std::map<ir::value*, size_t> tmp_;
};
}
}
}
#endif

View File

@@ -1,67 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#include <map>
#include <set>
#include <vector>
#include "triton/codegen/analysis/layout.h"
#include "triton/tools/graph.h"
namespace triton{
namespace ir{
class value;
class phi_node;
class function;
class module;
class instruction;
}
namespace codegen{
namespace analysis{
typedef unsigned slot_index;
class tiles;
class layouts;
class data_layout;
struct segment {
slot_index start;
slot_index end;
bool contains(slot_index idx) const {
return start <= idx && idx < end;
}
bool intersect(const segment &Other){
return contains(Other.start) || Other.contains(start);
}
};
class liveness {
private:
typedef std::map<shared_layout*, segment> intervals_map_t;
public:
// constructor
liveness(layouts *l): layouts_(l){ }
// accessors
const intervals_map_t& get() const { return intervals_; }
segment get(shared_layout* v) const { return intervals_.at(v); }
// run
void run(ir::module &mod);
private:
// analysis
layouts *layouts_;
intervals_map_t intervals_;
};
}
}
}
#endif

View File

@@ -1,43 +0,0 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
#define TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
#include <map>
namespace triton{
namespace ir{
class module;
}
namespace codegen{
class target;
namespace analysis{
class layouts;
class data_layout;
class swizzle {
public:
// constructor
swizzle(layouts *l, target* tgt): layouts_(l), tgt_(tgt){ }
// accessors
int get_per_phase(data_layout* layout) { return per_phase_.at(layout); }
int get_max_phase(data_layout* layout) { return max_phase_.at(layout); }
int get_vec (data_layout* layout) { return vec_.at(layout); }
// run
void run(ir::module &mod);
private:
layouts* layouts_;
target* tgt_;
std::map<data_layout*, int> per_phase_;
std::map<data_layout*, int> max_phase_;
std::map<data_layout*, int> vec_;
};
}
}
}
#endif

View File

@@ -1,42 +0,0 @@
#ifndef _TRITON_CODEGEN_PASS_H_
#define _TRITON_CODEGEN_PASS_H_
#include <memory>
namespace llvm{
class Module;
class LLVMContext;
}
namespace triton{
namespace codegen {
class target;
}
namespace ir{
class module;
}
namespace driver{
class device;
class module;
class kernel;
}
}
namespace triton{
namespace codegen{
// TODO:
// There should be a proper pass manager there!
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx,
codegen::target* target,
int sm, int num_warps,
int num_stages, bool force_nc_cache, int &shared_static);
}
}
#endif

View File

@@ -1,258 +0,0 @@
#pragma once
#ifndef _TRITON_SELECTION_GENERATOR_H_
#define _TRITON_SELECTION_GENERATOR_H_
#include "triton/ir/visitor.h"
#include "triton/codegen/analysis/layout.h"
#include <functional>
// forward
namespace llvm{
class Type;
class Value;
class PHINode;
class BasicBlock;
class Attribute;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
namespace triton{
namespace ir{
class attribute;
class load_inst;
class store_inst;
}
namespace codegen{
// forward
namespace analysis{
class liveness;
class tiles;
class align;
class allocation;
class cts;
class axes;
class layouts;
class swizzle;
}
// typedef
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Attribute Attribute;
typedef llvm::BasicBlock BasicBlock;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
typedef std::vector<Value*> indices_t;
class target;
}
}
namespace triton{
namespace codegen{
struct distributed_axis {
int contiguous;
std::vector<Value*> values;
Value* thread_id;
};
class adder{
public:
adder(Builder** builder): builder_(builder) { }
Value* operator()(Value* x, Value* y, const std::string& name = "");
private:
Builder** builder_;
};
class multiplier{
public:
multiplier(Builder** builder): builder_(builder) { }
Value* operator()(Value* x, Value* y, const std::string& name = "");
private:
Builder** builder_;
};
class geper{
public:
geper(Builder** builder): builder_(builder) { }
Value* operator()(Value *ptr, Value* off, const std::string& name = "");
Value* operator()(Type* ty, Value*ptr, std::vector<Value*> vals, const std::string& name = "");
private:
Builder** builder_;
};
class generator: public ir::visitor, public analysis::layout_visitor {
private:
void init_idx(ir::value *x);
Instruction* add_barrier();
Value* shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx);
void finalize_shared_layout(analysis::shared_layout*);
void finalize_function(ir::function*);
void finalize_phi_node(ir::phi_node*);
private:
Type *cvt(ir::type *ty);
llvm::Attribute cvt(ir::attribute attr);
public:
generator(analysis::axes *a_axes,
analysis::layouts *layouts,
analysis::align *alignment,
analysis::allocation *alloc,
analysis::swizzle *swizzle,
target *tgt,
unsigned num_warps,
bool force_nc_cache = false);
void visit_value(ir::value* v);
void visit_phi_node(ir::phi_node*);
void visit_binary_operator(ir::binary_operator*);
void visit_getelementptr_inst(ir::getelementptr_inst*);
void visit_icmp_inst(ir::icmp_inst*);
void visit_fcmp_inst(ir::fcmp_inst*);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
Value* bf16_to_fp32(Value *in0);
Value* fp32_to_bf16(Value *in0);
void visit_cast_inst(ir::cast_inst*);
void visit_return_inst(ir::return_inst*);
void visit_cond_branch_inst(ir::cond_branch_inst*);
void visit_uncond_branch_inst(ir::uncond_branch_inst*);
void visit_load_inst(ir::load_inst*);
void visit_unmasked_load_inst(ir::unmasked_load_inst*);
void visit_masked_load_inst(ir::masked_load_inst*);
void visit_store_inst(ir::store_inst*);
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
void visit_masked_store_inst(ir::masked_store_inst*);
void visit_reshape_inst(ir::reshape_inst*);
void visit_splat_inst(ir::splat_inst*);
void visit_broadcast_inst(ir::broadcast_inst*);
void visit_downcast_inst(ir::downcast_inst*);
void visit_exp_inst(ir::exp_inst*);
void visit_cos_inst(ir::cos_inst*);
void visit_sin_inst(ir::sin_inst*);
void visit_log_inst(ir::log_inst*);
void visit_get_program_id_inst(ir::get_program_id_inst*);
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
void visit_atomic_rmw_inst(ir::atomic_rmw_inst*);
void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
void visit_dot_inst(ir::dot_inst*);
void visit_trans_inst(ir::trans_inst*);
void visit_sqrt_inst(ir::sqrt_inst*);
Value* shfl_sync(Value* acc, int32_t i);
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reduce_inst(ir::reduce_inst*);
void visit_select_inst(ir::select_inst*);
void visit_layout_convert(ir::value *out, ir::value *in);
void visit_cvt_layout_inst(ir::cvt_layout_inst*);
void visit_masked_load_async_inst(ir::masked_load_async_inst*);
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
void visit_barrier_inst(ir::barrier_inst*);
void visit_prefetch_s_inst(ir::prefetch_s_inst*);
void visit_async_wait_inst(ir::async_wait_inst*);
// void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range(ir::make_range*);
// void visit_make_range_sta(ir::make_range_sta*);
void visit_undef_value(ir::undef_value*);
void visit_constant_int(ir::constant_int*);
void visit_constant_fp(ir::constant_fp*);
void visit_alloc_const(ir::alloc_const*);
void visit_function(ir::function*);
void visit_basic_block(ir::basic_block*);
void visit_argument(ir::argument*);
void visit(ir::module &, llvm::Module &);
// layouts
void visit_layout_mma(analysis::mma_layout*);
void visit_layout_scanline(analysis::scanline_layout*);
void visit_layout_shared(analysis::shared_layout*);
private:
LLVMContext *ctx_;
Builder* builder_;
Module *mod_;
analysis::axes *a_axes_;
analysis::swizzle *swizzle_;
std::map<unsigned, distributed_axis> axes_;
target *tgt_;
analysis::layouts *layouts_;
analysis::align *alignment_;
analysis::allocation *alloc_;
Value *shmem_;
std::set<ir::value*> seen_;
unsigned num_warps_;
bool force_nc_cache_;
std::map<analysis::data_layout*, Value*> offset_a_m_;
std::map<analysis::data_layout*, Value*> offset_a_k_;
std::map<analysis::data_layout*, Value*> offset_b_k_;
std::map<analysis::data_layout*, Value*> offset_b_n_;
/// layout -> base ptr
std::map<analysis::data_layout*, Value*> shared_ptr_;
std::map<analysis::data_layout*, Value*> shared_pre_ptr_;
std::map<analysis::data_layout*, Value*> shared_next_ptr_;
/// offset for double-buffered layout
std::map<analysis::data_layout*, Value*> shared_off_;
/// Base shmem pointer of ir value
std::map<ir::value*, Value*> shmems_;
std::map<ir::value*, Value*> shoffs_;
std::map<ir::value*, std::vector<indices_t>> idxs_;
std::map<ir::value*, std::map<indices_t, Value*>> vals_;
/// idx for multi-stage pipeline
std::map<analysis::data_layout*, Value*> read_smem_idx_;
std::map<analysis::data_layout*, Value*> write_smem_idx_;
/// triton bb -> llvm bb
std::map<ir::value*, BasicBlock *> bbs_;
std::map<ir::value*, std::vector<int>> ords_;
// helper for creating llvm values
adder add;
multiplier mul;
geper gep;
/// PHI nodes
std::vector<std::tuple<llvm::PHINode*, Value*, ir::basic_block*>> lazy_phi_incs_;
/// Record prefetch instrs that needs to be moved
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
};
}
}
#endif

View File

@@ -1,105 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_TARGET_H
#define TDL_INCLUDE_IR_CODEGEN_TARGET_H
namespace llvm{
class Type;
class Value;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
// typedefs
namespace triton{
namespace codegen{
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
}
}
namespace triton{
namespace codegen{
class nvidia_cu_target;
class target {
public:
target(bool is_gpu): is_gpu_(is_gpu){}
virtual ~target() {}
virtual void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn) = 0;
virtual Instruction* add_barrier(Module *module, Builder& builder) = 0;
virtual Instruction* add_memfence(Module *module, Builder& builder) = 0;
virtual Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax) = 0;
virtual Value* get_local_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
virtual unsigned guaranteed_alignment() = 0;
nvidia_cu_target* as_nvidia();
bool is_gpu() const;
private:
bool is_gpu_;
};
class amd_cl_target: public target {
public:
amd_cl_target(): target(true){}
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
unsigned guaranteed_alignment() { return 16; }
};
class nvidia_cu_target: public target {
public:
nvidia_cu_target(int sm): target(true), sm_(sm){}
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
int sm() { return sm_; }
unsigned guaranteed_alignment() { return 16; }
private:
int sm_;
};
class cpu_target: public target {
public:
cpu_target(): target(false){}
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
unsigned guaranteed_alignment() { return 1; }
};
}
}
#endif

View File

@@ -1,48 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
#include <map>
#include <set>
#include <vector>
namespace triton {
namespace ir {
class module;
class value;
class io_inst;
class instruction;
class builder;
}
namespace codegen{
namespace analysis{
class align;
class layouts;
class cts;
}
namespace transform{
class coalesce {
private:
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result);
void extract_ld(ir::io_inst *i, std::map<int, std::vector<triton::ir::io_inst *> > &result);
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
public:
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
void run(ir::module &mod);
private:
analysis::align* align_;
analysis::layouts* layout_;
};
}
}
}
#endif

View File

@@ -1,36 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
#define TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
#include <set>
#include <map>
namespace triton {
namespace ir {
class module;
class value;
class phi_node;
class instruction;
class builder;
}
namespace codegen{
namespace transform{
class cts {
private:
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);
public:
cts(bool use_async = false): use_async_(use_async) {}
void run(ir::module &mod);
private:
bool use_async_;
};
}
}
}
#endif

View File

@@ -1,24 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
namespace triton {
namespace ir {
class module;
}
namespace codegen{
namespace transform{
class dce {
public:
dce() {}
void run(ir::module &mod);
};
}
}
}
#endif

View File

@@ -1,22 +0,0 @@
#ifndef _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
#define _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
namespace triton {
namespace ir {
class module;
}
namespace codegen{
namespace transform{
class disassociate {
public:
void run(ir::module &mod);
};
}
}
}
#endif

View File

@@ -1,72 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_BARRIERS_H
#define TDL_INCLUDE_CODEGEN_BARRIERS_H
#include <vector>
#include <map>
#include <list>
#include <set>
#include "triton/codegen/target.h"
namespace triton {
namespace ir {
class module;
class basic_block;
class instruction;
class masked_load_async_inst;
class value;
class builder;
}
namespace codegen{
namespace analysis{
class allocation;
class liveness;
class layouts;
class cts;
class shared_layout;
}
namespace transform{
class prefetch;
class membar {
private:
typedef std::pair<unsigned, unsigned> interval_t;
typedef std::set<ir::value*> val_set_t;
typedef std::vector<ir::value*> val_vec_t;
private:
bool intersect(const val_set_t &X, const val_set_t &Y);
bool check_safe_war(ir::instruction* i);
int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
bool intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout);
val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
public:
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc,
transform::prefetch *prefetch, target* tgt):
liveness_(liveness), layouts_(layouts), alloc_(alloc), prefetch_(prefetch), tgt_(tgt) {}
void run(ir::module &mod);
private:
analysis::liveness *liveness_;
analysis::layouts *layouts_;
analysis::allocation *alloc_;
transform::prefetch *prefetch_;
target* tgt_;
};
}
}
}
#endif

View File

@@ -1,53 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
#include "triton/codegen/target.h"
namespace triton {
namespace ir {
class module;
class value;
class instruction;
class trans_inst;
class builder;
class constant_int;
class dot_inst;
}
namespace codegen{
namespace analysis{
class layouts;
}
namespace transform{
class peephole {
private:
// bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder);
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
bool rewrite_mult(ir::instruction *value, ir::builder& builder);
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder);
public:
peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
void run(ir::module &mod);
private:
target* tgt_;
analysis::layouts* layouts_;
};
}
}
}
#endif

View File

@@ -1,30 +0,0 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
// forward declaration
namespace triton {
namespace ir {
class module;
}
} // namespace triton
namespace triton {
namespace codegen {
namespace transform {
class pipeline {
public:
pipeline(bool has_copy_async, int num_stages)
: has_copy_async_(has_copy_async), num_stages_(num_stages) {}
void run(ir::module &module);
private:
bool has_copy_async_;
int num_stages_;
};
} // namespace transform
} // namespace codegen
} // namespace triton
#endif

View File

@@ -1,27 +0,0 @@
#ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
#define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
#include <set>
// forward dclaration
namespace triton::ir{
class module;
class value;
}
namespace triton::codegen {
class target;
}
namespace triton::codegen::transform {
class prefetch {
target* tgt_;
std::set<ir::value*> prefetched_vals_;
public:
prefetch(target *tgt) : tgt_(tgt) {}
void run(ir::module &module);
bool is_prefetched(ir::value* v) { return prefetched_vals_.find(v) != prefetched_vals_.end(); }
};
}
#endif

View File

@@ -1,26 +0,0 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_REORDER_H
#define TRITON_INCLUDE_IR_CODEGEN_REORDER_H
namespace triton {
// forward declaration
namespace ir {
class module;
}
namespace codegen{
namespace transform{
class reorder {
public:
void run(ir::module& module);
};
}
}
}
#endif

View File

@@ -1,316 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_DISPATCH_H_
#define _TRITON_DRIVER_DISPATCH_H_
#include <type_traits>
#include <dlfcn.h>
//CUDA Backend
#include "triton/external/CUDA/cuda.h"
#include "triton/external/CUDA/nvml.h"
//// HIP backend
//#define __HIP_PLATFORM_AMD__
#include "triton/external/hip.h"
//Exceptions
#include <iostream>
#include <stdexcept>
namespace llvm {
class PassRegistry;
class Module;
}
namespace triton
{
namespace driver
{
class cu_context;
template<class T> void check(T){}
void check(CUresult err);
void check(hipError_t err);
class dispatch
{
protected:
template <class F>
struct return_type;
template <class R, class... A>
struct return_type<R (*)(A...)>
{ typedef R type; };
typedef bool (*f_init_t)();
template<f_init_t initializer, typename FunPtrT, typename... Args>
static typename return_type<FunPtrT>::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args)
{
initializer();
if(cache == nullptr){
cache = dlsym(lib_h, name);
if(cache == 0)
throw std::runtime_error("dlsym unable to load function");
}
FunPtrT fptr;
*reinterpret_cast<void **>(&fptr) = cache;
typename return_type<FunPtrT>::type res = (*fptr)(args...);
check(res);
return res;
}
public:
static void release();
// Nvidia
static bool nvmlinit();
static bool cuinit();
// AMD
static bool hipinit();
/* ------------------- *
* CUDA
* ------------------- */
// context management
static CUresult cuInit(unsigned int Flags);
static CUresult cuCtxDestroy_v2(CUcontext ctx);
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
static CUresult cuCtxGetDevice(CUdevice* result);
static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
static CUresult cuDriverGetVersion(int *driverVersion);
// device management
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
static CUresult cuDeviceGetCount(int *count);
// link management
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
static CUresult cuLinkDestroy(CUlinkState state);
// module management
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
static CUresult cuModuleLoadData(CUmodule* module, const void* image);
static CUresult cuModuleUnload(CUmodule hmod);
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
// stream management
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
static CUresult cuStreamSynchronize(CUstream hStream);
static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
static CUresult cuStreamDestroy_v2(CUstream hStream);
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
// function management
static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config);
// memory management
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
static CUresult cuMemFree_v2(CUdeviceptr dptr);
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
// event management
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
static CUresult cuEventDestroy_v2(CUevent hEvent);
/* ------------------- *
* NVML
* ------------------- */
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
/* ------------------- *
* HIP
* ------------------- */
// context management
static hipError_t hipInit(unsigned int Flags);
static hipError_t hipCtxDestroy(hipCtx_t ctx);
static hipError_t hipCtxCreate(hipCtx_t *pctx, unsigned int flags, hipDevice_t dev);
static hipError_t hipCtxPushCurrent(hipCtx_t ctx);
static hipError_t hipCtxPopCurrent(hipCtx_t *pctx);
static hipError_t hipCtxGetDevice(hipDevice_t* result);
static hipError_t hipCtxEnablePeerAccess(hipCtx_t peerContext, unsigned int flags);
static hipError_t hipDriverGetVersion(int *driverVersion);
// device management
static hipError_t hipGetDevice(hipDevice_t *device, int ordinal);
static hipError_t hipDeviceGetName(char *name, int len, hipDevice_t dev);
static hipError_t hipDeviceGetPCIBusId(char *id, int len, hipDevice_t dev);
static hipError_t hipDeviceGetAttribute(int *pi, hipDeviceAttribute_t attrib, hipDevice_t dev);
static hipError_t hipGetDeviceCount(int *count);
// module management
static hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t* bytes, hipModule_t hmod, const char *name);
static hipError_t hipModuleLoad(hipModule_t *module, const char *fname);
static hipError_t hipModuleLoadData(hipModule_t* module, const void* image);
static hipError_t hipModuleUnload(hipModule_t hmod);
static hipError_t hipModuleLoadDataEx(hipModule_t *module, const void *image, unsigned int numOptions, hipJitOption *options, void **optionValues);
static hipError_t hipModuleGetFunction(hipFunction_t *hfunc, hipModule_t hmod, const char *name);
// stream management
static hipError_t hipStreamCreate(hipStream_t *phStream, unsigned int Flags);
static hipError_t hipStreamSynchronize(hipStream_t hStream);
static hipError_t hipStreamDestroy(hipStream_t hStream);
static hipError_t hipModuleLaunchKernel(hipFunction_t f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, hipStream_t hStream, void **kernelParams, void **extra);
// function management
static hipError_t hipFuncGetAttributes(hipFuncAttributes* attrib, void* hfunc);
static hipError_t hipFuncSetAttribute(hipFunction_t hfunc, hipFuncAttribute attrib, int value);
static hipError_t hipFuncSetCacheConfig(hipFunction_t hfunc, hipFuncCache_t config);
// memory management
static hipError_t hipMalloc(hipDeviceptr_t *dptr, size_t bytesize);
static hipError_t hipPointerGetAttribute(void * data, CUpointer_attribute attribute, hipDeviceptr_t ptr);
static hipError_t hipMemsetD8Async(hipDeviceptr_t dst, unsigned char x, size_t N, hipStream_t stream);
static hipError_t hipMemcpyDtoH(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount);
static hipError_t hipFree(hipDeviceptr_t dptr);
static hipError_t hipMemcpyDtoHAsync(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount, hipStream_t hStream);
static hipError_t hipMemcpyHtoDAsync(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount, hipStream_t hStream);
static hipError_t hipMemcpyHtoD(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount);
// event management
static hipError_t hipEventCreate(hipEvent_t *phEvent, unsigned int Flags);
static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart, hipEvent_t hEnd);
static hipError_t hipEventRecord(hipEvent_t hEvent, hipStream_t hStream);
static hipError_t hipEventDestroy(hipEvent_t hEvent);
private:
// Libraries
static void* cuda_;
static void* nvml_;
static void* hip_;
/* ------------------- *
* CUDA
* ------------------- */
// context management
static void* cuCtxGetCurrent_;
static void* cuCtxSetCurrent_;
static void* cuCtxDestroy_v2_;
static void* cuCtxCreate_v2_;
static void* cuCtxGetDevice_;
static void* cuCtxPushCurrent_v2_;
static void* cuCtxPopCurrent_v2_;
static void* cuCtxEnablePeerAccess_;
static void* cuDriverGetVersion_;
static void* cuInit_;
// device management
static void* cuDeviceGet_;
static void* cuDeviceGetName_;
static void* cuDeviceGetPCIBusId_;
static void* cuDeviceGetAttribute_;
static void* cuDeviceGetCount_;
// link management
static void* cuLinkAddData_v2_;
static void* cuLinkCreate_v2_;
static void* cuLinkDestroy_;
static void* cuLinkComplete_;
// module management
static void* cuModuleGetGlobal_v2_;
static void* cuModuleLoad_;
static void* cuModuleUnload_;
static void* cuModuleLoadDataEx_;
static void* cuModuleLoadData_;
static void* cuModuleGetFunction_;
// stream management
static void* cuStreamCreate_;
static void* cuStreamSynchronize_;
static void* cuStreamDestroy_v2_;
static void* cuStreamGetCtx_;
static void* cuLaunchKernel_;
// function management
static void* cuFuncGetAttribute_;
static void* cuFuncSetAttribute_;
static void* cuFuncSetCacheConfig_;
// memory management
static void* cuMemcpyDtoH_v2_;
static void* cuMemFree_v2_;
static void* cuMemcpyDtoHAsync_v2_;
static void* cuMemcpyHtoDAsync_v2_;
static void* cuMemcpyHtoD_v2_;
static void* cuMemAlloc_v2_;
static void* cuMemsetD8Async_;
static void* cuPointerGetAttribute_;
// event management
static void* cuEventCreate_;
static void* cuEventElapsedTime_;
static void* cuEventRecord_;
static void* cuEventDestroy_v2_;
/* ------------------- *
* NVML
* ------------------- */
static void* nvmlInit_v2_;
static void* nvmlDeviceGetHandleByPciBusId_v2_;
static void* nvmlDeviceGetClockInfo_;
static void* nvmlDeviceGetMaxClockInfo_;
static void* nvmlDeviceSetApplicationsClocks_;
/* ------------------- *
* HIP
* ------------------- */
// context management
static void* hipInit_;
static void* hipCtxDestroy_;
static void* hipCtxCreate_;
static void* hipCtxPushCurrent_;
static void* hipCtxPopCurrent_;
static void* hipCtxGetDevice_;
static void* hipCtxEnablePeerAccess_;
static void* hipDriverGetVersion_;
// device management
static void* hipGetDevice_;
static void* hipDeviceGetName_;
static void* hipDeviceGetPCIBusId_;
static void* hipDeviceGetAttribute_;
static void* hipGetDeviceCount_;
// module management
static void* hipModuleGetGlobal_;
static void* hipModuleLoad_;
static void* hipModuleLoadData_;
static void* hipModuleUnload_;
static void* hipModuleLoadDataEx_;
static void* hipModuleGetFunction_;
// stream management
static void* hipStreamCreate_;
static void* hipStreamSynchronize_;
static void* hipStreamDestroy_;
static void* hipModuleLaunchKernel_;;
// function management
static void* hipFuncGetAttributes_;
static void* hipFuncSetAttribute_;
static void* hipFuncSetCacheConfig_;
// memory management
static void* hipMalloc_;
static void* hipPointerGetAttribute_;
static void* hipMemsetD8Async_;
static void* hipMemcpyDtoH_;
static void* hipFree_;
static void* hipMemcpyDtoHAsync_;
static void* hipMemcpyHtoDAsync_;
static void* hipMemcpyHtoD_;
// event management
static void* hipEventCreate_;
static void* hipEventElapsedTime_;
static void* hipEventRecord_;
static void* hipEventDestroy_;
};
}
}
#endif

View File

@@ -1,220 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_ERROR_H_
#define _TRITON_DRIVER_ERROR_H_
#include <exception>
#include "triton/driver/dispatch.h"
namespace triton
{
namespace driver
{
namespace exception
{
namespace nvrtc
{
#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "NVRTC: Error- " msg; } }
TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option");
TRITON_CREATE_NVRTC_EXCEPTION(compilation ,"compilation");
TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure");
TRITON_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error");
#undef TRITON_CREATE_NVRTC_EXCEPTION
}
namespace cuda
{
class base: public std::exception{};
#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "CUDA: Error- " msg; } }
TRITON_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized");
TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled");
TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started");
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
TRITON_CREATE_CUDA_EXCEPTION(no_device ,"no device");
TRITON_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device");
TRITON_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image");
TRITON_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context");
TRITON_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current");
TRITON_CREATE_CUDA_EXCEPTION(map_failed ,"map failed");
TRITON_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed");
TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped");
TRITON_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped");
TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
TRITON_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit");
TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx");
TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
TRITON_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source");
TRITON_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found");
TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed");
TRITON_CREATE_CUDA_EXCEPTION(operating_system ,"operating system");
TRITON_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle");
TRITON_CREATE_CUDA_EXCEPTION(not_found ,"not found");
TRITON_CREATE_CUDA_EXCEPTION(not_ready ,"not ready");
TRITON_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address");
TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources");
TRITON_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout");
TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
TRITON_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active");
TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed");
TRITON_CREATE_CUDA_EXCEPTION(assert_error ,"assert");
TRITON_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers");
TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered");
TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error");
TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction");
TRITON_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address");
TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space");
TRITON_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc");
TRITON_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed");
TRITON_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted");
TRITON_CREATE_CUDA_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUDA_EXCEPTION(unknown ,"unknown");
#undef TRITON_CREATE_CUDA_EXCEPTION
}
namespace cublas
{
class base: public std::exception{};
#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "CUBLAS: Error- " msg; } }
TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed");
TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch");
TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error");
TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed");
TRITON_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error");
TRITON_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUBLAS_EXCEPTION(license_error ,"license error");
TRITON_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown");
#undef TRITON_CREATE_CUBLAS_EXCEPTION
}
namespace cudnn
{
#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "CUDNN: Error- " msg; } }
TRITON_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed");
TRITON_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param");
TRITON_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error");
TRITON_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch");
TRITON_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error");
TRITON_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed");
TRITON_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUDNN_EXCEPTION(license_error ,"license error");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
}
namespace hip
{
class base: public std::exception{};
#define TRITON_CREATE_HIP_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "HIP: Error- " msg; } }
TRITON_CREATE_HIP_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_HIP_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_HIP_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_HIP_EXCEPTION(deinitialized ,"deinitialized");
TRITON_CREATE_HIP_EXCEPTION(profiler_disabled ,"profiler disabled");
TRITON_CREATE_HIP_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
TRITON_CREATE_HIP_EXCEPTION(profiler_already_started ,"profiler already started");
TRITON_CREATE_HIP_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
TRITON_CREATE_HIP_EXCEPTION(no_device ,"no device");
TRITON_CREATE_HIP_EXCEPTION(invalid_device ,"invalid device");
TRITON_CREATE_HIP_EXCEPTION(invalid_image ,"invalid image");
TRITON_CREATE_HIP_EXCEPTION(invalid_context ,"invalid context");
TRITON_CREATE_HIP_EXCEPTION(context_already_current ,"context already current");
TRITON_CREATE_HIP_EXCEPTION(map_failed ,"map failed");
TRITON_CREATE_HIP_EXCEPTION(unmap_failed ,"unmap failed");
TRITON_CREATE_HIP_EXCEPTION(array_is_mapped ,"array is mapped");
TRITON_CREATE_HIP_EXCEPTION(already_mapped ,"already mapped");
TRITON_CREATE_HIP_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
TRITON_CREATE_HIP_EXCEPTION(already_acquired ,"already acquired");
TRITON_CREATE_HIP_EXCEPTION(not_mapped ,"not mapped");
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_array ,"not mapped as array");
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
TRITON_CREATE_HIP_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
TRITON_CREATE_HIP_EXCEPTION(unsupported_limit ,"unsupported limit");
TRITON_CREATE_HIP_EXCEPTION(context_already_in_use ,"context already in use");
TRITON_CREATE_HIP_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
TRITON_CREATE_HIP_EXCEPTION(invalid_ptx ,"invalid ptx");
TRITON_CREATE_HIP_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
TRITON_CREATE_HIP_EXCEPTION(invalid_source ,"invalid source");
TRITON_CREATE_HIP_EXCEPTION(file_not_found ,"file not found");
TRITON_CREATE_HIP_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
TRITON_CREATE_HIP_EXCEPTION(shared_object_init_failed ,"shared object init failed");
TRITON_CREATE_HIP_EXCEPTION(operating_system ,"operating system");
TRITON_CREATE_HIP_EXCEPTION(invalid_handle ,"invalid handle");
TRITON_CREATE_HIP_EXCEPTION(not_found ,"not found");
TRITON_CREATE_HIP_EXCEPTION(not_ready ,"not ready");
TRITON_CREATE_HIP_EXCEPTION(illegal_address ,"illegal address");
TRITON_CREATE_HIP_EXCEPTION(launch_out_of_resources ,"launch out of resources");
TRITON_CREATE_HIP_EXCEPTION(launch_timeout ,"launch timeout");
TRITON_CREATE_HIP_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
TRITON_CREATE_HIP_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
TRITON_CREATE_HIP_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
TRITON_CREATE_HIP_EXCEPTION(primary_context_active ,"primary context active");
TRITON_CREATE_HIP_EXCEPTION(context_is_destroyed ,"context is destroyed");
TRITON_CREATE_HIP_EXCEPTION(assert_error ,"assert");
TRITON_CREATE_HIP_EXCEPTION(too_many_peers ,"too many peers");
TRITON_CREATE_HIP_EXCEPTION(host_memory_already_registered ,"host memory already registered");
TRITON_CREATE_HIP_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
TRITON_CREATE_HIP_EXCEPTION(hardware_stack_error ,"hardware stack error");
TRITON_CREATE_HIP_EXCEPTION(illegal_instruction ,"illegal instruction");
TRITON_CREATE_HIP_EXCEPTION(misaligned_address ,"misaligned address");
TRITON_CREATE_HIP_EXCEPTION(invalid_address_space ,"invalid address space");
TRITON_CREATE_HIP_EXCEPTION(invalid_pc ,"invalid pc");
TRITON_CREATE_HIP_EXCEPTION(launch_failed ,"launch failed");
TRITON_CREATE_HIP_EXCEPTION(not_permitted ,"not permitted");
TRITON_CREATE_HIP_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_HIP_EXCEPTION(invalid_symbol ,"invalid symbol");
TRITON_CREATE_HIP_EXCEPTION(unknown ,"unknown");
#undef TRITON_CREATE_CUDA_EXCEPTION
}
}
}
}
#endif

View File

@@ -1,19 +0,0 @@
#include <string>
#include "triton/driver/dispatch.h"
namespace llvm{
class Module;
}
namespace triton{
namespace driver{
void init_llvm();
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
std::string ptx_to_cubin(const std::string& ptx, int cc);
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
hipModule_t amdgpu_to_hipmodule(const std::string& path);
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,288 +0,0 @@
/*
* @brief hipError_t
* @enum
* @ingroup Enumerations
*/
// Developer note - when updating these, update the hipErrorName and hipErrorString functions in
// NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path.
// Ignoring error-code return values from hip APIs is discouraged. On C++17,
// we can make that yield a warning
/*
* @brief hipError_t
* @enum
* @ingroup Enumerations
*/
// Developer note - when updating these, update the hipErrorName and hipErrorString functions in
// NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path.
#include <cstddef>
typedef enum hipError_t {
hipSuccess = 0, ///< Successful completion.
hipErrorInvalidValue = 1, ///< One or more of the parameters passed to the API call is NULL
///< or not in an acceptable range.
hipErrorOutOfMemory = 2,
// Deprecated
hipErrorMemoryAllocation = 2, ///< Memory allocation error.
hipErrorNotInitialized = 3,
// Deprecated
hipErrorInitializationError = 3,
hipErrorDeinitialized = 4,
hipErrorProfilerDisabled = 5,
hipErrorProfilerNotInitialized = 6,
hipErrorProfilerAlreadyStarted = 7,
hipErrorProfilerAlreadyStopped = 8,
hipErrorInvalidConfiguration = 9,
hipErrorInvalidPitchValue = 12,
hipErrorInvalidSymbol = 13,
hipErrorInvalidDevicePointer = 17, ///< Invalid Device Pointer
hipErrorInvalidMemcpyDirection = 21, ///< Invalid memory copy direction
hipErrorInsufficientDriver = 35,
hipErrorMissingConfiguration = 52,
hipErrorPriorLaunchFailure = 53,
hipErrorInvalidDeviceFunction = 98,
hipErrorNoDevice = 100, ///< Call to hipGetDeviceCount returned 0 devices
hipErrorInvalidDevice = 101, ///< DeviceID must be in range 0...#compute-devices.
hipErrorInvalidImage = 200,
hipErrorInvalidContext = 201, ///< Produced when input context is invalid.
hipErrorContextAlreadyCurrent = 202,
hipErrorMapFailed = 205,
// Deprecated
hipErrorMapBufferObjectFailed = 205, ///< Produced when the IPC memory attach failed from ROCr.
hipErrorUnmapFailed = 206,
hipErrorArrayIsMapped = 207,
hipErrorAlreadyMapped = 208,
hipErrorNoBinaryForGpu = 209,
hipErrorAlreadyAcquired = 210,
hipErrorNotMapped = 211,
hipErrorNotMappedAsArray = 212,
hipErrorNotMappedAsPointer = 213,
hipErrorECCNotCorrectable = 214,
hipErrorUnsupportedLimit = 215,
hipErrorContextAlreadyInUse = 216,
hipErrorPeerAccessUnsupported = 217,
hipErrorInvalidKernelFile = 218, ///< In CUDA DRV, it is CUDA_ERROR_INVALID_PTX
hipErrorInvalidGraphicsContext = 219,
hipErrorInvalidSource = 300,
hipErrorFileNotFound = 301,
hipErrorSharedObjectSymbolNotFound = 302,
hipErrorSharedObjectInitFailed = 303,
hipErrorOperatingSystem = 304,
hipErrorInvalidHandle = 400,
// Deprecated
hipErrorInvalidResourceHandle = 400, ///< Resource handle (hipEvent_t or hipStream_t) invalid.
hipErrorNotFound = 500,
hipErrorNotReady = 600, ///< Indicates that asynchronous operations enqueued earlier are not
///< ready. This is not actually an error, but is used to distinguish
///< from hipSuccess (which indicates completion). APIs that return
///< this error include hipEventQuery and hipStreamQuery.
hipErrorIllegalAddress = 700,
hipErrorLaunchOutOfResources = 701, ///< Out of resources error.
hipErrorLaunchTimeOut = 702,
hipErrorPeerAccessAlreadyEnabled =
704, ///< Peer access was already enabled from the current device.
hipErrorPeerAccessNotEnabled =
705, ///< Peer access was never enabled from the current device.
hipErrorSetOnActiveProcess = 708,
hipErrorAssert = 710, ///< Produced when the kernel calls assert.
hipErrorHostMemoryAlreadyRegistered =
712, ///< Produced when trying to lock a page-locked memory.
hipErrorHostMemoryNotRegistered =
713, ///< Produced when trying to unlock a non-page-locked memory.
hipErrorLaunchFailure =
719, ///< An exception occurred on the device while executing a kernel.
hipErrorCooperativeLaunchTooLarge =
720, ///< This error indicates that the number of blocks launched per grid for a kernel
///< that was launched via cooperative launch APIs exceeds the maximum number of
///< allowed blocks for the current device
hipErrorNotSupported = 801, ///< Produced when the hip API is not supported/implemented
hipErrorUnknown = 999, //< Unknown error.
// HSA Runtime Error Codes start here.
hipErrorRuntimeMemory = 1052, ///< HSA runtime memory call returned error. Typically not seen
///< in production systems.
hipErrorRuntimeOther = 1053, ///< HSA runtime call other than memory returned error. Typically
///< not seen in production systems.
hipErrorTbd ///< Marker that more error codes are needed.
} hipError_t;
typedef struct ihipCtx_t* hipCtx_t;
// Note many APIs also use integer deviceIds as an alternative to the device pointer:
typedef int hipDevice_t;
typedef enum hipDeviceP2PAttr {
hipDevP2PAttrPerformanceRank = 0,
hipDevP2PAttrAccessSupported,
hipDevP2PAttrNativeAtomicSupported,
hipDevP2PAttrHipArrayAccessSupported
} hipDeviceP2PAttr;
typedef struct ihipStream_t* hipStream_t;
#define hipIpcMemLazyEnablePeerAccess 0
#define HIP_IPC_HANDLE_SIZE 64
typedef struct hipIpcMemHandle_st {
char reserved[HIP_IPC_HANDLE_SIZE];
} hipIpcMemHandle_t;
typedef struct hipIpcEventHandle_st {
char reserved[HIP_IPC_HANDLE_SIZE];
} hipIpcEventHandle_t;
typedef struct ihipModule_t* hipModule_t;
typedef struct ihipModuleSymbol_t* hipFunction_t;
typedef struct hipFuncAttributes {
int binaryVersion;
int cacheModeCA;
size_t constSizeBytes;
size_t localSizeBytes;
int maxDynamicSharedSizeBytes;
int maxThreadsPerBlock;
int numRegs;
int preferredShmemCarveout;
int ptxVersion;
size_t sharedSizeBytes;
} hipFuncAttributes;
typedef struct ihipEvent_t* hipEvent_t;
/*
* @brief hipDeviceAttribute_t
* @enum
* @ingroup Enumerations
*/
typedef enum hipDeviceAttribute_t {
hipDeviceAttributeMaxThreadsPerBlock, ///< Maximum number of threads per block.
hipDeviceAttributeMaxBlockDimX, ///< Maximum x-dimension of a block.
hipDeviceAttributeMaxBlockDimY, ///< Maximum y-dimension of a block.
hipDeviceAttributeMaxBlockDimZ, ///< Maximum z-dimension of a block.
hipDeviceAttributeMaxGridDimX, ///< Maximum x-dimension of a grid.
hipDeviceAttributeMaxGridDimY, ///< Maximum y-dimension of a grid.
hipDeviceAttributeMaxGridDimZ, ///< Maximum z-dimension of a grid.
hipDeviceAttributeMaxSharedMemoryPerBlock, ///< Maximum shared memory available per block in
///< bytes.
hipDeviceAttributeTotalConstantMemory, ///< Constant memory size in bytes.
hipDeviceAttributeWarpSize, ///< Warp size in threads.
hipDeviceAttributeMaxRegistersPerBlock, ///< Maximum number of 32-bit registers available to a
///< thread block. This number is shared by all thread
///< blocks simultaneously resident on a
///< multiprocessor.
hipDeviceAttributeClockRate, ///< Peak clock frequency in kilohertz.
hipDeviceAttributeMemoryClockRate, ///< Peak memory clock frequency in kilohertz.
hipDeviceAttributeMemoryBusWidth, ///< Global memory bus width in bits.
hipDeviceAttributeMultiprocessorCount, ///< Number of multiprocessors on the device.
hipDeviceAttributeComputeMode, ///< Compute mode that device is currently in.
hipDeviceAttributeL2CacheSize, ///< Size of L2 cache in bytes. 0 if the device doesn't have L2
///< cache.
hipDeviceAttributeMaxThreadsPerMultiProcessor, ///< Maximum resident threads per
///< multiprocessor.
hipDeviceAttributeComputeCapabilityMajor, ///< Major compute capability version number.
hipDeviceAttributeComputeCapabilityMinor, ///< Minor compute capability version number.
hipDeviceAttributeConcurrentKernels, ///< Device can possibly execute multiple kernels
///< concurrently.
hipDeviceAttributePciBusId, ///< PCI Bus ID.
hipDeviceAttributePciDeviceId, ///< PCI Device ID.
hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, ///< Maximum Shared Memory Per
///< Multiprocessor.
hipDeviceAttributeIsMultiGpuBoard, ///< Multiple GPU devices.
hipDeviceAttributeIntegrated, ///< iGPU
hipDeviceAttributeCooperativeLaunch, ///< Support cooperative launch
hipDeviceAttributeCooperativeMultiDeviceLaunch, ///< Support cooperative launch on multiple devices
hipDeviceAttributeMaxTexture1DWidth, ///< Maximum number of elements in 1D images
hipDeviceAttributeMaxTexture2DWidth, ///< Maximum dimension width of 2D images in image elements
hipDeviceAttributeMaxTexture2DHeight, ///< Maximum dimension height of 2D images in image elements
hipDeviceAttributeMaxTexture3DWidth, ///< Maximum dimension width of 3D images in image elements
hipDeviceAttributeMaxTexture3DHeight, ///< Maximum dimensions height of 3D images in image elements
hipDeviceAttributeMaxTexture3DDepth, ///< Maximum dimensions depth of 3D images in image elements
hipDeviceAttributeHdpMemFlushCntl, ///< Address of the HDP_MEM_COHERENCY_FLUSH_CNTL register
hipDeviceAttributeHdpRegFlushCntl, ///< Address of the HDP_REG_COHERENCY_FLUSH_CNTL register
hipDeviceAttributeMaxPitch, ///< Maximum pitch in bytes allowed by memory copies
hipDeviceAttributeTextureAlignment, ///<Alignment requirement for textures
hipDeviceAttributeTexturePitchAlignment, ///<Pitch alignment requirement for 2D texture references bound to pitched memory;
hipDeviceAttributeKernelExecTimeout, ///<Run time limit for kernels executed on the device
hipDeviceAttributeCanMapHostMemory, ///<Device can map host memory into device address space
hipDeviceAttributeEccEnabled, ///<Device has ECC support enabled
hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc, ///< Supports cooperative launch on multiple
///devices with unmatched functions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim, ///< Supports cooperative launch on multiple
///devices with unmatched grid dimensions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim, ///< Supports cooperative launch on multiple
///devices with unmatched block dimensions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem, ///< Supports cooperative launch on multiple
///devices with unmatched shared memories
hipDeviceAttributeAsicRevision, ///< Revision of the GPU in this device
hipDeviceAttributeManagedMemory, ///< Device supports allocating managed memory on this system
hipDeviceAttributeDirectManagedMemAccessFromHost, ///< Host can directly access managed memory on
/// the device without migration
hipDeviceAttributeConcurrentManagedAccess, ///< Device can coherently access managed memory
/// concurrently with the CPU
hipDeviceAttributePageableMemoryAccess, ///< Device supports coherently accessing pageable memory
/// without calling hipHostRegister on it
hipDeviceAttributePageableMemoryAccessUsesHostPageTables, ///< Device accesses pageable memory via
/// the host's page tables
hipDeviceAttributeCanUseStreamWaitValue ///< '1' if Device supports hipStreamWaitValue32() and
///< hipStreamWaitValue64() , '0' otherwise.
} hipDeviceAttribute_t;
typedef void* hipDeviceptr_t;
/*
* @brief hipJitOption
* @enum
* @ingroup Enumerations
*/
typedef enum hipJitOption {
hipJitOptionMaxRegisters = 0,
hipJitOptionThreadsPerBlock,
hipJitOptionWallTime,
hipJitOptionInfoLogBuffer,
hipJitOptionInfoLogBufferSizeBytes,
hipJitOptionErrorLogBuffer,
hipJitOptionErrorLogBufferSizeBytes,
hipJitOptionOptimizationLevel,
hipJitOptionTargetFromContext,
hipJitOptionTarget,
hipJitOptionFallbackStrategy,
hipJitOptionGenerateDebugInfo,
hipJitOptionLogVerbose,
hipJitOptionGenerateLineInfo,
hipJitOptionCacheMode,
hipJitOptionSm3xOpt,
hipJitOptionFastCompile,
hipJitOptionNumOptions
} hipJitOption;
/**
* @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
*/
typedef enum hipFuncAttribute {
hipFuncAttributeMaxDynamicSharedMemorySize = 8,
hipFuncAttributePreferredSharedMemoryCarveout = 9,
hipFuncAttributeMax
} hipFuncAttribute;
/**
* @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
*/
typedef enum hipFuncCache_t {
hipFuncCachePreferNone, ///< no preference for shared memory or L1 (default)
hipFuncCachePreferShared, ///< prefer larger shared memory and smaller L1 cache
hipFuncCachePreferL1, ///< prefer larger L1 cache and smaller shared memory
hipFuncCachePreferEqual, ///< prefer equal size L1 cache and shared memory
} hipFuncCache_t;
#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void*)0x01)
#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void*)0x02)
#define HIP_LAUNCH_PARAM_END ((void*)0x03)

View File

@@ -1,88 +0,0 @@
#pragma once
#ifndef _TRITON_IR_BASIC_BLOCK_H_
#define _TRITON_IR_BASIC_BLOCK_H_
#include <string>
#include <list>
#include "value.h"
#include "visitor.h"
namespace triton{
namespace ir{
class context;
class function;
class instruction;
/* Basic Block */
class basic_block: public value{
public:
// instruction iterator types
typedef std::list<instruction*> inst_list_t;
typedef inst_list_t::iterator iterator;
typedef inst_list_t::const_iterator const_iterator;
typedef inst_list_t::reverse_iterator reverse_iterator;
typedef inst_list_t::const_reverse_iterator const_reverse_iterator;
private:
// constructors
basic_block(context &ctx, const std::string &name, function *parent);
public:
// accessors
function* get_parent() { return parent_; }
context& get_context() { return ctx_; }
// get iterator to first instruction that is not a phi
iterator get_first_non_phi();
// get instruction list
inst_list_t &get_inst_list() { return inst_list_; }
const inst_list_t &get_inst_list() const { return inst_list_; }
void erase(instruction *i) { inst_list_.remove(i); }
// instruction iterator functions
inline iterator begin() { return inst_list_.begin(); }
inline const_iterator begin() const { return inst_list_.begin(); }
inline iterator end () { return inst_list_.end(); }
inline const_iterator end () const { return inst_list_.end(); }
inline reverse_iterator rbegin() { return inst_list_.rbegin(); }
inline const_reverse_iterator rbegin() const { return inst_list_.rbegin(); }
inline reverse_iterator rend () { return inst_list_.rend(); }
inline const_reverse_iterator rend () const { return inst_list_.rend(); }
inline size_t size() const { return inst_list_.size(); }
inline bool empty() const { return inst_list_.empty(); }
inline const instruction &front() const { return *inst_list_.front(); }
inline instruction &front() { return *inst_list_.front(); }
inline const instruction &back() const { return *inst_list_.back(); }
inline instruction &back() { return *inst_list_.back(); }
// predecessors
const std::vector<basic_block*>& get_predecessors() const { return preds_; }
const std::vector<basic_block*>& get_successors() const { return succs_; }
void add_predecessor(basic_block* pred);
// factory functions
static basic_block* create(context &ctx, const std::string &name, function *parent);
void print(std::ostream &os);
// visitor
void accept(visitor *v) { v->visit_basic_block(this); }
private:
context &ctx_;
std::string name_;
function *parent_;
std::vector<basic_block*> preds_;
std::vector<basic_block*> succs_;
inst_list_t inst_list_;
};
}
}
#endif

View File

@@ -1,173 +0,0 @@
#pragma once
#ifndef _TRITON_IR_BUILDER_H_
#define _TRITON_IR_BUILDER_H_
#include <vector>
#include <string>
#include "instructions.h"
#include "basic_block.h"
#include "type.h"
namespace triton{
namespace ir{
class basic_block;
class value;
class type;
class constant_int;
class instruction;
class context;
class phi_node;
/* Builder */
class builder{
typedef basic_block::iterator iterator;
public:
// Constructor
builder(context &ctx);
// Getters
const context& get_context() { return ctx_; }
// Setters
void set_insert_point(iterator instr);
void set_insert_point(instruction* i);
void set_insert_point_after(instruction* i);
void set_insert_point(basic_block* block);
basic_block* get_insert_block() { return block_; }
iterator get_insert_point() { return insert_point_;}
// Constants
value *get_int1(bool val);
value *get_int32(int32_t val);
value *get_int64(int64_t val);
value *get_float16(float val);
value *get_float32(float val);
value *get_range(int32_t lo, int32_t hi);
// Types
type *get_void_ty();
type *get_int1_ty();
type *get_int8_ty();
type *get_int16_ty();
type *get_int32_ty();
type *get_int64_ty();
type *get_half_ty();
type *get_float_ty();
type *get_double_ty();
// Insert
template<typename InstTy>
InstTy* insert(InstTy *inst){
assert(block_);
block_->get_inst_list().insert(insert_point_, inst);
inst->set_parent(block_);
// for(ir::value* op: inst->ops())
// op->add_use(inst);
return inst;
}
// terminator instructions
value* create_br(basic_block *dest);
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
value* create_ret_void();
// Cast instructions
value *create_cast(cast_op_t op, value *v, type *dst_ty);
value* create_ptr_to_int(value *src, type *dst_ty);
value* create_si_to_fp(value *src, type *dst_ty);
value* create_ui_to_fp(value *src, type *dst_ty);
value* create_fp_to_si(value *src, type *dst_ty);
value* create_fp_to_ui(value *src, type *dst_ty);
value* create_fp_ext(value *src, type *dst_ty);
value* create_fp_trunc(value *src, type *dst_ty);
value* create_int_cast(value *src, type *dst_ty, bool is_signed);
value *create_downcast(value *arg);
// Phi instruction
phi_node* create_phi(type *ty, unsigned num_reserved);
// Binary instructions
value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, bool has_nuw, bool has_nsw);
value *create_fmul(value *lhs, value *rhs);
value *create_fdiv(value *lhs, value *rhs);
value *create_frem(value *lhs, value *rhs);
value *create_fadd(value *lhs, value *rhs);
value *create_fsub(value *lhs, value *rhs);
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sdiv(value *lhs, value *rhs);
value *create_udiv(value *lhs, value *rhs);
value *create_srem(value *lhs, value *rhs);
value *create_urem(value *lhs, value *rhs);
value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_lshr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_ashr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
// GEP
value *create_gep(value *ptr, const std::vector<value*>& idx_list);
// Comparison (int)
value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs);
value *create_icmpSLE(value *lhs, value *rhs);
value *create_icmpSLT(value *lhs, value *rhs);
value *create_icmpSGE(value *lhs, value *rhs);
value *create_icmpSGT(value *lhs, value *rhs);
value *create_icmpULE(value *lhs, value *rhs);
value *create_icmpULT(value *lhs, value *rhs);
value *create_icmpUGE(value *lhs, value *rhs);
value *create_icmpUGT(value *lhs, value *rhs);
value *create_icmpEQ(value *lhs, value *rhs);
value *create_icmpNE(value *lhs, value *rhs);
// Comparison (float)
value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs);
value *create_fcmpOLT(value *lhs, value *rhs);
value *create_fcmpOGT(value *lhs, value *rhs);
value *create_fcmpOLE(value *lhs, value *rhs);
value *create_fcmpOGE(value *lhs, value *rhs);
value *create_fcmpOEQ(value *lhs, value *rhs);
value *create_fcmpONE(value *lhs, value *rhs);
value *create_fcmpULT(value *lhs, value *rhs);
value *create_fcmpUGT(value *lhs, value *rhs);
value *create_fcmpULE(value *lhs, value *rhs);
value *create_fcmpUGE(value *lhs, value *rhs);
value *create_fcmpUEQ(value *lhs, value *rhs);
value *create_fcmpUNE(value *lhs, value *rhs);
// Logical
value *create_and(value *lhs, value *rhs);
value *create_xor(value *lhs, value *rhs);
value *create_or(value *lhs, value *rhs);
// Input/Output
value *create_load(value *arg);
value *create_store(value *ptr, value *val);
value *create_masked_load(value *arg, value *mask, value *false_value);
value *create_masked_store(value *ptr, value *val, value *mask);
// Block instruction
value *create_splat(value *arg, const type::block_shapes_t &shapes);
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
// Built-in instruction
value *create_get_program_id(unsigned axis);
value *create_get_num_programs(unsigned axis);
value *create_atomic_cas(value *ptr, value *cmp, value *val);
value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
value *create_exp(value* arg);
value *create_cos(value* arg);
value *create_sin(value* arg);
value *create_log(value* arg);
value *create_dot(value *A, value *B, value *C);
value *create_trans(value *A, const std::vector<int> &perm = {});
value *create_sqrt(value *A);
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
value *create_select(value *pred, value *if_value, value *else_value);
// Intrinsics
value *create_copy_to_shared(value *arg);
value *create_masked_load_async(value *arg, value *mask, value *false_value);
value *create_copy_from_shared(value *arg);
value *create_barrier(const std::string &name = "");
value *create_async_wait(int N);
value *create_prefetch_s(value *arg, int inc);
private:
context &ctx_;
basic_block *block_;
iterator insert_point_;
};
}
}
#endif

View File

@@ -1,113 +0,0 @@
#pragma once
#ifndef _TRITON_IR_CONSTANT_H_
#define _TRITON_IR_CONSTANT_H_
#include "enums.h"
#include "value.h"
#include <cassert>
#include "visitor.h"
namespace triton{
namespace ir{
class type;
class context;
/* Constant */
class constant: public user{
protected:
using user::user;
public:
static constant* get_all_ones_value(type *ty);
static constant* get_null_value(type *ty);
virtual std::string repr() const = 0;
};
/* Undef value */
class undef_value: public constant{
private:
undef_value(type *ty);
public:
static undef_value* get(type* ty);
std::string repr() const { return "undef"; }
void accept(visitor* vst) { vst->visit_undef_value(this); }
};
/* Constant int */
class constant_int: public constant{
protected:
constant_int(type *ty, uint64_t value);
public:
virtual uint64_t get_value() const { return value_; }
static constant_int *get(type *ty, uint64_t value);
std::string repr() const { return std::to_string(value_); }
void accept(visitor* vst) { vst->visit_constant_int(this); }
protected:
uint64_t value_;
};
/* Constant fp */
class constant_fp: public constant{
constant_fp(type *ty, double value);
public:
double get_value() { return value_; }
static constant* get_negative_zero(type *ty);
static constant* get_zero_value_for_negation(type *ty);
static constant* get(context &ctx, double v);
static constant* get(type *ty, double v);
std::string repr() const { return std::to_string(value_); }
void accept(visitor* vst) { vst->visit_constant_fp(this); }
private:
double value_;
};
/* Global Value */
class global_value: public constant {
public:
enum linkage_types_t {
external
};
public:
global_value(type *ty, unsigned num_ops,
linkage_types_t linkage, const std::string &name,
unsigned addr_space);
std::string repr() const { return get_name(); }
private:
linkage_types_t linkage_;
};
/* global object */
class global_object: public global_value {
public:
global_object(type *ty, unsigned num_ops,
linkage_types_t linkage, const std::string &name,
unsigned addr_space = 0);
std::string repr() const { return get_name(); }
};
/* global variable */
class alloc_const: public global_object {
public:
alloc_const(type *ty, constant_int *size,
const std::string &name = "");
std::string repr() const { return get_name(); }
void accept(visitor* vst) { vst->visit_alloc_const(this); }
};
}
}
#endif

View File

@@ -1,31 +0,0 @@
#pragma once
#ifndef _TRITON_IR_CONTEXT_H_
#define _TRITON_IR_CONTEXT_H_
#include <memory>
#include "triton/ir/type.h"
namespace triton{
namespace ir{
class builder;
class type;
class context_impl;
/* Context */
class context {
public:
context();
context(const context&) = delete;
context& operator=(const context&) = delete;
public:
ir::builder* builder = nullptr;
std::shared_ptr<context_impl> p_impl;
};
}
}
#endif

View File

@@ -1,48 +0,0 @@
#pragma once
#ifndef _TRITON_IR_CONTEXT_IMPL_H_
#define _TRITON_IR_CONTEXT_IMPL_H_
#include <map>
#include "triton/ir/type.h"
namespace triton{
namespace ir{
class context;
class constant;
class constant_int;
class constant_fp;
class undef_value;
/* Context impl */
class context_impl {
public:
// constructors
context_impl(context &ctx);
public:
// non-numeric types
type void_ty, label_ty;
// floating point types
type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
// integer types
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
// Pointer types
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
// Block types
std::map<std::pair<type*, type::block_shapes_t>, block_type*> block_tys;
// Int constants
std::map<std::pair<type*, uint64_t>, constant_int*> int_constants_;
// Float constants
std::map<std::pair<type*, double>, constant_fp*> fp_constants_;
// undef values
std::map<type*, undef_value*> uv_constants_;
};
}
}
#endif

View File

@@ -1,108 +0,0 @@
#pragma once
#ifndef _TRITON_IR_DISPATCH_H_
#define _TRITON_IR_DISPATCH_H_
#include "triton/ir/builder.h"
#include <stdexcept>
namespace triton{
namespace ir{
/*----------------------------------------------
higher level functions that follow the likely
semantics of most expected frontends
----------------------------------------------*/
struct semantic_error: public std::runtime_error {
semantic_error(const std::string& msg):
std::runtime_error(msg) { }
};
struct dispatch{
typedef ir::type::block_shapes_t shape_t;
// programming model
static ir::value *program_id(int axis, ir::builder *builder);
static ir::value *num_programs(int axis, ir::builder *builder);
// binary operators
static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder);
// unary operators
static ir::value *plus(ir::value *input, ir::builder *builder);
static ir::value *minus(ir::value *input, ir::builder *builder);
static ir::value *invert(ir::value *input, ir::builder *builder);
// comparison operators
static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder);
// block creation
static ir::value* arange(int start, int end, ir::builder *builder);
static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder);
// casting ops
static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
// memory operators
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder *builder);
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
// linear algebra
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
// indexing
static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);
// reduction
static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
// math
static ir::value *exp(ir::value *x, ir::builder *builder);
static ir::value *log(ir::value *x, ir::builder *builder);
static ir::value *cos(ir::value *x, ir::builder *builder);
static ir::value *sin(ir::value *x, ir::builder *builder);
static ir::value *sqrt(ir::value *x, ir::builder *builder);
// internal (debug/optimization)
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder);
static ir::value *debug_barrier(ir::builder *builder);
};
}
}
#endif

View File

@@ -1,173 +0,0 @@
#pragma once
#ifndef _TRITON_IR_ENUMS_H_
#define _TRITON_IR_ENUMS_H_
namespace triton{
namespace ir{
enum binary_op_t: unsigned int{
Add,
FAdd,
Sub,
FSub,
Mul,
FMul,
UDiv,
SDiv,
FDiv,
URem,
SRem,
FRem,
Shl,
LShr,
AShr,
And,
Or,
Xor
};
enum class atomic_rmw_op_t: unsigned int{
And,
Or,
Xor,
Add,
Max,
Min,
UMax,
UMin,
FAdd,
Xchg,
};
enum cast_op_t: unsigned int {
Trunc,
ZExt,
SExt,
FPTrunc,
FPExt,
UIToFP,
SIToFP,
FPToUI,
FPToSI,
PtrToInt,
IntToPtr,
BitCast,
AddrSpaceCast
};
enum cmp_pred_t: unsigned int {
FIRST_FCMP_PREDICATE,
FCMP_FALSE,
FCMP_OEQ,
FCMP_OGT,
FCMP_OGE,
FCMP_OLT,
FCMP_OLE,
FCMP_ONE,
FCMP_ORD,
FCMP_UNO,
FCMP_UEQ,
FCMP_UGT,
FCMP_UGE,
FCMP_ULT,
FCMP_ULE,
FCMP_UNE,
FCMP_TRUE,
LAST_FCMP_PREDICATE,
FIRST_ICMP_PREDICATE,
ICMP_EQ,
ICMP_NE,
ICMP_UGT,
ICMP_UGE,
ICMP_ULT,
ICMP_ULE,
ICMP_SGT,
ICMP_SGE,
ICMP_SLT,
ICMP_SLE,
LAST_ICMP_PREDICATE
};
enum value_id_t: unsigned {
/* ------------ *
INSTRUCTIONS
* ------------ */
INST_BEGIN,
// phi
INST_PHI,
// arithmetic
INST_BINOP,
INST_GETELEMENTPTR,
INST_SELECT,
INST_SQRT,
// cmp
INST_ICMP,
INST_FCMP,
// cast
INST_CAST_TRUNC,
INST_CAST_ZEXT,
INST_CAST_SEXT,
INST_CAST_FP_TRUNC,
INST_CAST_FP_EXT,
INST_CAST_UI_TO_FP,
INST_CAST_SI_TO_FP,
INST_CAST_FP_TO_UI,
INST_CAST_FP_TO_SI,
INST_CAST_PTR_TO_INT,
INST_CAST_INT_TO_PTR,
INST_CAST_BIT_CAST,
INST_CAST_ADDR_SPACE_CAST,
// terminators
INST_RETURN,
INST_COND_BRANCH,
INST_UNCOND_BRANCH,
// io
INST_UNMASKED_LOAD,
INST_MASKED_LOAD,
INST_MASKED_LOAD_ASYNC,
INST_UNMASKED_STORE,
INST_MASKED_STORE,
// retile
INST_RESHAPE,
INST_SPLAT,
INST_BROADCAST,
INST_DOWNCAST,
// builtin
INST_GET_PROGRAM_ID,
INST_GET_NUM_PROGRAMS,
// atomics
INST_ATOMIC_CAS,
INST_ATOMIC_EXCH,
INST_ATOMIC_RMW,
// math
INST_EXP,
INST_COS,
INST_SIN,
INST_LOG,
// array arithmetic
INST_TRANS,
INST_REDUCE,
INST_DOT,
// intrinsics
INST_COPY_TO_SHARED,
INST_COPY_FROM_SHARED,
INST_CVT_LAYOUT,
INST_CVT_SCANLINE,
INST_DECOALESCE,
INST_RECOALESCE,
INST_BARRIER,
INST_ASYNC_WAIT,
INST_MAKE_RANGE_DYN,
INST_MAKE_RANGE_STA,
INST_MAKE_RANGE,
INST_PREFETCH_S,
};
}
}
#endif

View File

@@ -1,142 +0,0 @@
#pragma once
#ifndef _TRITON_IR_FUNCTION_H_
#define _TRITON_IR_FUNCTION_H_
#include <string>
#include <map>
#include "value.h"
#include "constant.h"
namespace triton{
namespace ir{
class function;
class function_type;
class module;
class basic_block;
/* Argument */
class argument: public value{
argument(type *ty, const std::string &name, function *parent, unsigned arg_no);
public:
static argument* create(type *ty, const std::string &name,
function *parent = nullptr, unsigned arg_no = 0);
function* get_parent() const;
unsigned get_arg_no() const;
void accept(visitor *v);
private:
function *parent_;
unsigned arg_no_;
};
/* Attribute */
enum attribute_kind_t {
readonly = 0,
writeonly,
noalias,
aligned,
multiple_of,
retune,
not_implemented
};
class attribute {
public:
attribute(attribute_kind_t kind, unsigned value = 0):
kind_(kind), value_(value){}
bool operator<(const attribute& other) const {
return std::make_pair(kind_, value_) < std::make_pair(other.kind_, other.value_);
}
attribute_kind_t get_kind() const {
return kind_;
}
unsigned get_value() const {
return value_;
}
bool is_llvm_attr() const {
return kind_ != multiple_of;
}
std::string repr() const {
switch(kind_){
case readonly: return ".readonly";
case writeonly: return ".writeonly";
case noalias: return ".noalias";
case aligned: return ".aligned(" + std::to_string(value_) + ")";
case multiple_of: return ".multipleof(" + std::to_string(value_) + ")";
case retune: return ".retunr";
default: break;
}
assert(false);
return "";
}
private:
attribute_kind_t kind_;
unsigned value_;
};
/* Function */
class function: public global_object{
typedef std::vector<argument*> args_t;
typedef args_t::iterator arg_iterator;
typedef args_t::const_iterator const_arg_iterator;
typedef std::vector<basic_block*> blocks_t;
typedef blocks_t::iterator block_iterator;
typedef blocks_t::const_iterator const_block_iterator;
typedef std::map<unsigned, std::set<attribute>> attr_map_t;
private:
function(function_type *ty, linkage_types_t linkage,
const std::string &name = "", module *parent = nullptr);
public:
// accessors
const args_t &args() const { return args_; }
function_type* get_fn_type() { return fn_ty_; }
const function_type* get_fn_type() const { return fn_ty_; }
module *get_parent() { return parent_; }
const module *get_parent() const { return parent_; }
// factory methods
static function *create(function_type *ty, linkage_types_t linkage,
const std::string &name, module *mod);
// blocks
const blocks_t &blocks() { return blocks_; }
const blocks_t &blocks() const { return blocks_; }
void insert_block(basic_block* block, basic_block *next = nullptr);
// attributes
void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
const attr_map_t &attrs() { return attrs_; }
bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); }
std::set<attribute> get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
void print(std::ostream &os);
// visitor
void accept(visitor *v) { v->visit_function(this); }
private:
module *parent_;
bool init_;
function_type *fn_ty_;
args_t args_;
blocks_t blocks_;
attr_map_t attrs_;
};
}
}
#endif

View File

@@ -1,919 +0,0 @@
#pragma once
#ifndef _TRITON_IR_INSTRUCTIONS_H_
#define _TRITON_IR_INSTRUCTIONS_H_
#include <vector>
#include <map>
#include "triton/ir/enums.h"
#include "triton/ir/constant.h"
#include "triton/ir/value.h"
#include "triton/ir/type.h"
#include "triton/ir/metadata.h"
#include "triton/ir/visitor.h"
#define _TRITON_DEFINE_CLONE(name) \
ir::instruction* clone_impl() const { return new name(*this); }
#define _TRITON_DEFINE_ACCEPT(name) \
void accept(visitor* v) { v->visit_ ## name (this); }
namespace triton{
namespace ir{
class constant_int;
class constant;
class make_range;
class basic_block;
class context;
class visitor;
//===----------------------------------------------------------------------===//
// instruction classes
//===----------------------------------------------------------------------===//
class result_reference;
class instruction: public user{
public:
virtual std::string repr_impl() const = 0;
private:
virtual ir::instruction* clone_impl() const = 0;
protected:
// constructors
instruction(type *ty, value_id_t ity, unsigned num_ops,
const std::string &name = "", instruction *next = nullptr);
public:
// parent
void set_parent(basic_block *block) { parent_ = block; }
const basic_block *get_parent() const { return parent_; }
basic_block *get_parent() { return parent_; }
void erase_from_parent();
// helpers
bool has_tile_result_or_op();
// repr
std::string repr() const { return repr_impl(); }
// metadata
void set_metadata(ir::metadata::kind_t kind,
unsigned value) { metadatas_[kind] = value;}
unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
// cloning
ir::instruction* clone() {
ir::instruction* res = clone_impl();
// for(auto it = op_begin(); it != op_end(); it++)
// (*it)->add_use(res);
res->parent_ = nullptr;
res->users_.clear();
return res;
}
// instruction id
value_id_t get_id() const { return id_; }
void print(std::ostream &os);
private:
basic_block *parent_;
std::map<ir::metadata::kind_t, unsigned> metadatas_;
value_id_t id_;
};
//===----------------------------------------------------------------------===//
// phi_node classes
//===----------------------------------------------------------------------===//
class phi_node: public instruction {
private:
phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next);
std::string repr_impl() const { return "phi"; }
public:
void set_incoming_value(unsigned i, value *v);
void set_incoming_block(unsigned i, basic_block *block);
value *get_value_for_block(basic_block *block);
value *get_incoming_value(unsigned i) { return get_operand(i); }
basic_block *get_incoming_block(unsigned i) { return blocks_[i]; }
unsigned get_num_incoming() { return get_num_operands(); }
void add_incoming(value *v, basic_block *block);
// Type
void set_type(type *ty) { ty_ = ty; }
// Factory methods
static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(phi_node)
_TRITON_DEFINE_ACCEPT(phi_node)
private:
unsigned num_reserved_;
std::vector<basic_block*> blocks_;
};
//===----------------------------------------------------------------------===//
// binary_operator classes
//===----------------------------------------------------------------------===//
class binary_operator: public instruction {
public:
typedef binary_op_t op_t;
private:
std::string repr_impl() const;
protected:
// Constructors
binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next);
public:
// Get operand
binary_op_t get_op() const { return op_; }
// Bool
bool is_terminator() const;
bool is_binary_op() const;
bool is_int_div_rem() const;
bool is_shift() const;
bool is_cast() const;
bool is_int_mult() const;
bool is_int_add_sub() const;
bool is_int_div() const;
bool is_int_rem() const;
bool is_shl() const;
bool is_shr() const;
// Wraps
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
// Factory methods
static binary_operator *create(binary_op_t op, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr);
// static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr);
// static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr);
// static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(binary_operator)
_TRITON_DEFINE_ACCEPT(binary_operator)
public:
binary_op_t op_;
bool has_no_unsigned_wrap_;
bool has_no_signed_wrap_;
};
//===----------------------------------------------------------------------===//
// cmp_inst classes
//===----------------------------------------------------------------------===//
class cmp_inst: public instruction{
public:
typedef cmp_pred_t pred_t;
private:
std::string repr_impl() const;
protected:
cmp_inst(type *ty, value_id_t id, cmp_pred_t pred,
value *lhs, value *rhs, const std::string &name, instruction *next);
static bool is_fp_predicate(cmp_pred_t pred);
static bool is_int_predicate(cmp_pred_t pred);
static type* make_cmp_result_type(type *ty);
public:
cmp_pred_t get_pred() const { return pred_; }
private:
cmp_pred_t pred_;
};
class icmp_inst: public cmp_inst {
icmp_inst(type *ty, cmp_pred_t pred,
value *lhs, value *rhs, const std::string &name, instruction *next);
public:
static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(icmp_inst)
_TRITON_DEFINE_ACCEPT(icmp_inst)
};
class fcmp_inst: public cmp_inst {
fcmp_inst(type *ty, cmp_pred_t pred,
value *lhs, value *rhs, const std::string &name, instruction *next);
public:
static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(fcmp_inst)
_TRITON_DEFINE_ACCEPT(fcmp_inst)
};
//===----------------------------------------------------------------------===//
// unary_inst classes
//===----------------------------------------------------------------------===//
class unary_inst: public instruction {
protected:
unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next);
};
//===----------------------------------------------------------------------===//
// cast_inst classes
//===----------------------------------------------------------------------===//
class cast_inst: public unary_inst{
private:
std::string repr_impl() const;
protected:
cast_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next, cast_op_t op)
: unary_inst(ty, id, v, name, next), op_(op) { }
private:
static bool is_valid(cast_op_t op, value *arg, type *ty);
public:
// accessors
cast_op_t get_op() const { return op_; }
// factory methods
static cast_inst *create(cast_op_t op, value *arg, type *ty,
const std::string &name = "", instruction *next = nullptr);
static cast_inst *create_integer_cast(value *arg, type *ty, bool is_signed,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_ACCEPT(cast_inst)
private:
cast_op_t op_;
};
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, id, op) \
class name : public cast_inst { \
_TRITON_DEFINE_CLONE(name) \
friend class cast_inst; \
name(type *ty, value *v, const std::string &name, instruction *next) \
: cast_inst(ty, id, v, name, next, op){ } \
};
TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, INST_CAST_TRUNC, cast_op_t::Trunc)
TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, INST_CAST_ZEXT, cast_op_t::ZExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, INST_CAST_SEXT, cast_op_t::SExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, INST_CAST_FP_TRUNC, cast_op_t::FPTrunc)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, INST_CAST_FP_EXT, cast_op_t::FPExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, INST_CAST_UI_TO_FP, cast_op_t::UIToFP)
TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, INST_CAST_SI_TO_FP, cast_op_t::SIToFP)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, INST_CAST_FP_TO_UI, cast_op_t::FPToUI)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, INST_CAST_FP_TO_SI, cast_op_t::FPToSI)
TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, INST_CAST_PTR_TO_INT, cast_op_t::PtrToInt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, INST_CAST_INT_TO_PTR, cast_op_t::IntToPtr)
TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, INST_CAST_BIT_CAST, cast_op_t::BitCast)
TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, INST_CAST_ADDR_SPACE_CAST, cast_op_t::AddrSpaceCast)
//===----------------------------------------------------------------------===//
// terminator_inst classes
//===----------------------------------------------------------------------===//
class terminator_inst: public instruction{
using instruction::instruction;
};
// return instruction
class return_inst: public terminator_inst {
private:
std::string repr_impl() const { return "ret"; }
return_inst(context &ctx, value *ret_val, instruction *next);
public:
// accessors
value *get_return_value()
{ return get_num_operands() ? get_operand(0) : nullptr; }
unsigned get_num_successors() const { return 0; }
// factory methods
static return_inst* create(context &ctx, value *ret_val = nullptr, instruction *next = nullptr);
_TRITON_DEFINE_CLONE(return_inst)
_TRITON_DEFINE_ACCEPT(return_inst)
};
// base branch instruction
class branch_inst: public terminator_inst{
private:
std::string repr_impl() const { return "br"; }
protected:
using terminator_inst::terminator_inst;
public:
static branch_inst* create(basic_block *dest,
instruction *next = nullptr);
static branch_inst* create(value *cond, basic_block *if_dest, basic_block *else_dest,
instruction *next = nullptr);
};
// conditional branch
class cond_branch_inst: public branch_inst {
private:
friend class branch_inst;
cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next);
public:
basic_block *get_true_dest() { return (basic_block*)get_operand(0); }
basic_block *get_false_dest() { return (basic_block*)get_operand(1); }
value *get_cond() { return get_operand(2); }
_TRITON_DEFINE_CLONE(cond_branch_inst)
_TRITON_DEFINE_ACCEPT(cond_branch_inst)
};
// unconditional branch
class uncond_branch_inst: public branch_inst {
private:
friend class branch_inst;
uncond_branch_inst(basic_block *dst, instruction *next);
public:
basic_block *get_dest() { return (basic_block*)get_operand(0); }
_TRITON_DEFINE_CLONE(uncond_branch_inst)
_TRITON_DEFINE_ACCEPT(uncond_branch_inst)
};
//===----------------------------------------------------------------------===//
// getelementptr_inst classes
//===----------------------------------------------------------------------===//
class getelementptr_inst: public instruction {
private:
std::string repr_impl() const { return "getelementptr"; }
getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value*> &idx, const std::string &name, instruction *next);
private:
static type *get_return_type(type *ty, value *ptr, const std::vector<value*> &idx);
static type *get_indexed_type_impl(type *ty, const std::vector<value *> &idx);
static type *get_indexed_type(type *ty, const std::vector<value*> &idx);
public:
// accessors
type *get_source_elt_ty() { return source_elt_ty; }
op_iterator idx_begin() { return op_begin() + 1; }
op_iterator idx_end() { return op_end(); }
value *get_pointer_operand() { return *op_begin(); }
// factory methods
static getelementptr_inst* create(value *ptr, const std::vector<value*> &idx,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(getelementptr_inst)
_TRITON_DEFINE_ACCEPT(getelementptr_inst)
private:
type *source_elt_ty;
type *res_elt_ty;
};
//===----------------------------------------------------------------------===//
// load_inst/store_inst classes
//===----------------------------------------------------------------------===//
class io_inst: public instruction {
protected:
io_inst(type *ty, value_id_t id, unsigned num_ops,
const std::string &name = "", instruction *next = nullptr);
public:
// accessors
value *get_pointer_operand() { return get_operand(0); }
};
// load
class load_inst: public io_inst {
protected:
load_inst(value *ptr, value_id_t id, unsigned num_ops,
const std::string &name = "", instruction *next = nullptr);
private:
static type *get_pointee_type(type *ty);
};
// unmasked load
class unmasked_load_inst: public load_inst {
private:
std::string repr_impl() const { return "unmasked_load"; }
unmasked_load_inst(value *ptr, const std::string &name, instruction *next);
public:
static unmasked_load_inst* create(value *ptr,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_load_inst)
_TRITON_DEFINE_ACCEPT(unmasked_load_inst)
};
// masked load
class masked_load_inst: public load_inst {
private:
std::string repr_impl() const { return "masked_load"; }
masked_load_inst(value *ptr, value *mask, value *false_value,
const std::string &name, instruction *next);
public:
// accessors
value *get_mask_operand() { return get_operand(1); }
value *get_false_value_operand() { return get_operand(2); }
// factory method
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_inst)
_TRITON_DEFINE_ACCEPT(masked_load_inst)
};
// masked load async
class masked_load_async_inst: public load_inst {
private:
std::string repr_impl() const { return "masked_load_async_async"; }
masked_load_async_inst(value *ptr, value *mask, value *false_value,
const std::string &name, instruction *next);
public:
// accessors
value *get_mask_operand() { return get_operand(1); }
value *get_false_value_operand() { return get_operand(2); }
// factory method
static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_async_inst)
_TRITON_DEFINE_ACCEPT(masked_load_async_inst)
};
// store
class store_inst: public io_inst {
protected:
store_inst(value *ptr, value_id_t id, unsigned num_ops,
const std::string &name = "", instruction *next = nullptr);
public:
value *get_value_operand() { return get_operand(1); }
};
// unmasked_store
class unmasked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "unmasked_store"; }
unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next);
public:
// factory method
static unmasked_store_inst* create(value* ptr, value *v,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_store_inst)
_TRITON_DEFINE_ACCEPT(unmasked_store_inst)
};
class masked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "masked_store"; }
masked_store_inst(value *ptr, value *v, value *mask,
const std::string &name, instruction *next);
public:
// accessors
value *get_mask_operand() { return get_operand(2); }
// factory method
static masked_store_inst* create(value *ptr, value *v, value *mask,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_store_inst)
_TRITON_DEFINE_ACCEPT(masked_store_inst)
};
//===----------------------------------------------------------------------===//
// retile_inst classes
//===----------------------------------------------------------------------===//
// retile
class retile_inst: public unary_inst {
protected:
retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes, const std::string &name, instruction *next);
};
// reshape
class reshape_inst: public retile_inst {
private:
using retile_inst::retile_inst;
std::string repr_impl() const { return "reshape"; }
public:
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(reshape_inst)
_TRITON_DEFINE_ACCEPT(reshape_inst)
};
// splat
class splat_inst: public retile_inst {
private:
using retile_inst::retile_inst;
std::string repr_impl() const { return "splat"; }
public:
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(splat_inst)
_TRITON_DEFINE_ACCEPT(splat_inst)
};
// broadcast
class broadcast_inst: public retile_inst {
private:
using retile_inst::retile_inst;
std::string repr_impl() const { return "broadcast"; }
public:
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(broadcast_inst)
_TRITON_DEFINE_ACCEPT(broadcast_inst)
};
// downcast
class downcast_inst: public unary_inst {
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "downcast"; }
public:
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(downcast_inst)
_TRITON_DEFINE_ACCEPT(downcast_inst)
};
//===----------------------------------------------------------------------===//
// builtin_inst classes
//===----------------------------------------------------------------------===//
class builtin_inst: public instruction{
protected:
using instruction::instruction;
};
class get_program_id_inst: public builtin_inst {
private:
get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
std::string repr_impl() const { return "get_program_id(" + std::to_string(axis_) + ")"; }
public:
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
_TRITON_DEFINE_CLONE(get_program_id_inst)
_TRITON_DEFINE_ACCEPT(get_program_id_inst)
private:
unsigned axis_;
};
class get_num_programs_inst: public builtin_inst {
private:
get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
std::string repr_impl() const { return "get_num_programs(" + std::to_string(axis_) + ")"; }
public:
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
_TRITON_DEFINE_CLONE(get_num_programs_inst)
_TRITON_DEFINE_ACCEPT(get_num_programs_inst)
private:
unsigned axis_;
};
class atomic_inst: public io_inst {
public:
using io_inst::io_inst;
};
class atomic_rmw_inst: public atomic_inst {
private:
atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "atomic_rmw"; }
_TRITON_DEFINE_CLONE(atomic_rmw_inst)
_TRITON_DEFINE_ACCEPT(atomic_rmw_inst)
public:
static instruction* create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
atomic_rmw_op_t get_op() { return op_; }
private:
atomic_rmw_op_t op_;
};
class atomic_cas_inst: public atomic_inst {
private:
atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
std::string repr_impl() const { return "atomic_cas"; }
_TRITON_DEFINE_CLONE(atomic_cas_inst)
_TRITON_DEFINE_ACCEPT(atomic_cas_inst)
public:
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
};
class exp_inst: public builtin_inst {
private:
exp_inst(value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "exp"; }
_TRITON_DEFINE_CLONE(exp_inst)
_TRITON_DEFINE_ACCEPT(exp_inst)
public:
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
};
class cos_inst: public builtin_inst {
private:
cos_inst(value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "cos"; }
_TRITON_DEFINE_CLONE(cos_inst)
_TRITON_DEFINE_ACCEPT(cos_inst)
public:
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
};
class sin_inst: public builtin_inst {
private:
sin_inst(value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "sin"; }
_TRITON_DEFINE_CLONE(sin_inst)
_TRITON_DEFINE_ACCEPT(sin_inst)
public:
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
};
class log_inst: public builtin_inst {
private:
log_inst(value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "log"; }
_TRITON_DEFINE_CLONE(log_inst)
_TRITON_DEFINE_ACCEPT(log_inst)
public:
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
};
class dot_inst: public builtin_inst {
public:
enum TransT { NoTrans, Trans };
private:
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
std::string repr_impl() const { return "dot"; }
bool is_prefetched_ = false;
public:
bool is_prefetched() const { return is_prefetched_; }
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
public:
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(dot_inst)
_TRITON_DEFINE_ACCEPT(dot_inst)
};
//class outer_inst: public builtin_inst {
//private:
// outer_inst(value *A, value *B, value *C, const std::string &name, instruction *next);
//public:
// static instruction* create(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
//};
class trans_inst: public builtin_inst {
public:
ir::type* get_res_ty(ir::type* in, std::vector<int> perm);
std::vector<int> init_perm(ir::type* ty, const std::vector<int>& perm);
private:
trans_inst(value *arg, const std::vector<int>& perm, const std::string& name, instruction* next);
std::string repr_impl() const { return "trans"; }
public:
static instruction* create(value *arg, const std::vector<int> &perm = {}, const std::string &name = "", instruction *next = nullptr);
const std::vector<int> get_perm() const;
_TRITON_DEFINE_CLONE(trans_inst)
_TRITON_DEFINE_ACCEPT(trans_inst)
private:
std::vector<int> perm_;
};
class sqrt_inst: public builtin_inst {
private:
sqrt_inst(value *arg, const std::string& name, instruction* next);
std::string repr_impl() const { return "sqrt"; }
public:
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(sqrt_inst)
_TRITON_DEFINE_ACCEPT(sqrt_inst)
};
class reduce_inst: public builtin_inst {
public:
enum op_t{
ADD, SUB, MAX, MIN,
FADD, FSUB, FMAX, FMIN
};
private:
static type* get_res_type(value *arg, unsigned axis);
static std::string to_str(op_t op);
private:
reduce_inst(value* arg, op_t op, unsigned axis, const std::string& name, instruction* next);
std::string repr_impl() const { return "reduce"; }
_TRITON_DEFINE_CLONE(reduce_inst)
_TRITON_DEFINE_ACCEPT(reduce_inst)
public:
static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
op_t get_op() const { return op_; }
private:
unsigned axis_;
op_t op_;
};
class select_inst: public builtin_inst {
private:
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
std::string repr_impl() const { return "select"; }
_TRITON_DEFINE_CLONE(select_inst)
_TRITON_DEFINE_ACCEPT(select_inst)
public:
static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr);
value* get_pred_op() { return get_operand(0); }
value* get_if_value_op() { return get_operand(1); }
value* get_else_value_op() { return get_operand(2); }
};
//===----------------------------------------------------------------------===//
// intrinsics classes
//===----------------------------------------------------------------------===//
class copy_to_shared_inst: public unary_inst{
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "copy_to_shared"; }
public:
static copy_to_shared_inst* create(value *arg, const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(copy_to_shared_inst)
_TRITON_DEFINE_ACCEPT(copy_to_shared_inst)
};
class copy_from_shared_inst: public unary_inst{
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "copy_from_shared"; }
public:
static copy_from_shared_inst* create(value *arg, const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(copy_from_shared_inst)
_TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
};
class cvt_layout_inst: public unary_inst {
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "cvt_layout_inst"; }
public:
static cvt_layout_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(cvt_layout_inst)
_TRITON_DEFINE_ACCEPT(cvt_layout_inst)
};
class barrier_inst: public instruction{
private:
barrier_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "barrier"; }
_TRITON_DEFINE_CLONE(barrier_inst)
_TRITON_DEFINE_ACCEPT(barrier_inst)
public:
static barrier_inst* create(context &ctx, const std::string &name = "",
instruction *next = nullptr);
};
class async_wait_inst: public instruction{
private:
async_wait_inst(context &ctx, int N, const std::string &name, instruction *next);
std::string repr_impl() const { return "async_wait_group " + std::to_string(N_) ; }
_TRITON_DEFINE_CLONE(async_wait_inst)
_TRITON_DEFINE_ACCEPT(async_wait_inst)
public:
static async_wait_inst* create(context &ctx, int N,
const std::string &name = "", instruction *next = nullptr);
int get_N() { return N_; }
void set_N(int n) { N_ = n; }
private:
int N_;
};
class prefetch_s_inst : public instruction {
std::string repr_impl() const { return "prefetch_s"; }
_TRITON_DEFINE_CLONE(prefetch_s_inst)
_TRITON_DEFINE_ACCEPT(prefetch_s_inst)
/// inc_: 0->first, 1->latch
int inc_ = 0;
public:
prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
: instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) {
set_operand(0, arg);
}
int get_inc() const { return inc_; }
static prefetch_s_inst *create(context &ctx, value *arg, int inc, const std::string &name = "",
instruction *next=nullptr);
};
//// On NVIDIA, implementation is such that
//// constant_range = nv_dynamic_program_idx + nv_static_program_idx
//// so as to enable re-association on nv_static_program_idx which is constant
//class make_range_dyn: public instruction {
//private:
// make_range_dyn(type *ty, const std::string &name, instruction *next);
// std::string repr_impl() const { return "nv_dynamic_program_idx"; }
// _TRITON_DEFINE_CLONE(make_range_dyn)
// _TRITON_DEFINE_ACCEPT(make_range_dyn)
//public:
// static make_range_dyn* create(type *ty, const std::string &name = "", instruction *next = nullptr);
//};
//class make_range_sta: public constant {
//private:
// make_range_sta(make_range *range);
//public:
// static make_range_sta *get(make_range* range);
// make_range* get_range() const;
// std::string repr() const { return "nv_static_program_idx"; }
// _TRITON_DEFINE_ACCEPT(make_range_sta)
//private:
// make_range *range_;
//};
/* constant range */
class make_range: public instruction{
make_range(type *ty, constant_int* first, constant_int* last);
std::string repr_impl() const { return "make_range[" + first_->repr() + " : " + last_->repr() + "]"; }
_TRITON_DEFINE_CLONE(make_range)
_TRITON_DEFINE_ACCEPT(make_range)
public:
static make_range *create(constant_int *first, constant_int *last);
const constant_int* get_first() const;
const constant_int* get_last() const;
private:
constant_int* first_;
constant_int* last_;
};
}
}
#endif

View File

@@ -1,32 +0,0 @@
#pragma once
#ifndef _TRITON_IR_METADATA_H_
#define _TRITON_IR_METADATA_H_
namespace triton{
namespace ir{
/* Metadata */
class metadata{
public:
enum kind_t{
multiple_of,
max_contiguous
};
private:
metadata(kind_t kind, unsigned value);
public:
static metadata* get(kind_t kind, unsigned value);
private:
kind_t kind_;
unsigned value_;
};
}
}
#endif

View File

@@ -1,112 +0,0 @@
#pragma once
#ifndef _TRITON_IR_MODULE_H_
#define _TRITON_IR_MODULE_H_
#include <map>
#include <set>
#include <stack>
#include <string>
#include <functional>
#include "triton/ir/builder.h"
#include "triton/ir/metadata.h"
#include "triton/ir/context.h"
namespace triton{
namespace lang{
class iteration_statement;
class compound_statement;
}
namespace ir{
class basic_block;
class phi_node;
class value;
class context;
class function;
class attribute;
class function_type;
class constant;
class global_value;
class alloc_const;
/* Module */
class module {
typedef std::pair<std::string, basic_block*> val_key_t;
friend class function;
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
public:
typedef std::map<std::string, global_value*> symbols_map_t;
typedef std::vector<function*> functions_list_t;
struct current_iteration_info_t{
lang::iteration_statement *statement;
basic_block *block;
};
private:
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
value *try_remove_trivial_phis(ir::phi_node *&phi);
value *add_phi_operands(const std::string& name, phi_node *&phi);
value *get_value_recursive(const std::string& name, basic_block *block);
void push_function(function *fn) { functions_.push_back(fn); }
public:
module(const std::string &name, builder& builder);
builder& get_builder();
// Setters
void set_value(const std::string& name, basic_block* block, value *x);
void set_value(const std::string& name, value* x);
void set_const(const std::string& name);
void set_continue_fn(std::function<ir::value*()> fn);
// Getters
const std::map<val_key_t, value*>& get_values() { return values_; }
void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
value *get_value(const std::string& name, basic_block* block);
value *get_value(const std::string& name);
void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }
const std::string& get_name();
std::function<ir::value*()> get_continue_fn();
// Seal block -- no more predecessors will be added
void seal_block(basic_block *block);
// Functions
const functions_list_t &get_function_list() const { return functions_; }
functions_list_t &get_function_list() { return functions_; }
function *get_or_insert_function(const std::string &name, function_type *ty);
// Const allocation
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
// Register global
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
const std::map<std::string, ir::value*>& globals() const { return globals_; }
// Metadata
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
void print(std::ostream &os);
private:
std::string name_;
builder& builder_;
std::map<val_key_t, value*> values_;
std::map<std::string, type*> types_;
std::set<std::string> const_;
std::set<basic_block*> sealed_blocks_;
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
functions_list_t functions_;
symbols_map_t symbols_;
std::function<ir::value*()> continue_fn_;
std::map<value*, value**> current_phi_;
std::vector<ir::alloc_const*> allocs_;
std::map<std::string, ir::value*> globals_;
std::map<std::string, md_pair_t> metadatas_;
};
}
}
#endif

Some files were not shown because too many files have changed in this diff Show More