324 Commits

Author SHA1 Message Date
Qingyi Liu
a6d672166c [Triton-MLIR][OPTIMIZER] Add ExtElemwiseOp to expensive_to_remat list 2022-11-04 15:23:58 +08:00
Keren Zhou
4218e68d74 [Triton-MLIR] [Frontend] Return a scalar if all input args are scalar (#839) 2022-11-03 20:27:47 -07:00
ben-zhang-609
61f2ff98df [triton-mlir] add flag "Link only needed" for external libs. (#838) 2022-11-03 18:50:20 +08:00
Philippe Tillet
91a9773b38 [OPTIMIZER] Minor bugfixes that affected matmul codegen performance (#834) 2022-11-02 22:58:09 -07:00
Philippe Tillet
847a318a03 [CI] macos-latest -> macos-10.15 (#836)
This avoids trying to install missing pytorch wheels after the latest
runner image update
2022-11-02 22:22:02 -07:00
ben-zhang-609
5feb6e24f9 [Triton-MLIR]Add ptx vprintf support (#825)
Not know how to write unit test for this feature.

Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
2022-11-02 16:39:09 +08:00
Philippe Tillet
12d60cb4a3 [BACKEND] Added support for 1D conversion blocked -> slice (#831) 2022-11-01 13:19:58 -07:00
Chenggang Zhao
c9d84237e8 [Triton-MLIR][Frontend] Interface fixes for libdevice (#829)
- Unifying several interfaces with different types to a single one, e.g.
`fsub_ru` and `dsub_ru` -> `sub_ru`;
- Minor bug fix: `fast_pow` is incorrectly classified into the `pow`
interface, of which arguments are the same as `powf`;
- Explicit interfaces for casting functions, e.g. decoupling
`ll2float_ru` to `ll2float_ru` and `ull2float_ru`;
- Removing interfaces that are not in NVIDIA's official documents, e.g.
`fmaf_ieee_rn`, which is confusing together with `fmaf_rn`.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2022-11-01 10:51:32 -07:00
Qingyi Liu
cdc0ec5077 [Triton-MLIR][Backend] Fix reduce conversion and unit tests for int dtypes (#826) 2022-11-01 17:42:59 +08:00
Yan Chunwei
031c2ae77b [Triton-MLIR][BACKEND] Port the mma<v1> conversion (#815)
This PR does

- port the mma<v1> related code, and support dot conversion and
convert_layout[shared->dot_op<mma<v1>>]
- add a lit test for dot v1
2022-11-01 09:42:14 +08:00
Philippe Tillet
cb1b87a688 [FRONTEND] Made test_if/test_default pass (#823) 2022-10-30 15:32:55 -07:00
Philippe Tillet
e61dc75942 [FRONTEND] Fixed inliner and got more tests to pass (#822)
This adds a `DialectInlinerInterface` to the Triton dialect. This, along
with a few other minor semantic changes, fixes our tests on call
instructions. Also added the option to provide use an "LLVM_SYSPATH"
environment variable to link against locally build of LLVM; this was
useful for debugging this issue.
2022-10-30 14:10:02 -07:00
Ian Bearman
71428194a1 [BUILD] Add Back Test Target (#820) 2022-10-29 10:38:50 -07:00
Philippe Tillet
7dfab26a39 [FRONTEND][BACKEND] Fixed various bugs (#819)
- Fixed bugs on layout conversions for int1 data (we should use int8
internally for int1 data to prevent llvm from using vec<i1> which has
different semantics)
- Fixed semantics of some casts to bool in the frontend
2022-10-29 06:34:14 +00:00
Philippe Tillet
82834d34f9 [BUILD] No longer use include((HandleLLVMOptions) (#818) 2022-10-28 17:02:49 -07:00
Ian Bearman
f2106d0aa2 [BUILD] Fix Warnings and Enable Warnings as Errors (#794) 2022-10-28 12:36:09 -07:00
Philippe Tillet
ac0f6793cc [BACKEND] Added support for scalars in LoadOp / StoreOp / ElementwiseOp (#814)
Also fixed various errors that showed up in `test_core.py`, and added more TODOs for open (hopefully relatively minor) issues
2022-10-28 16:17:55 +08:00
ben-zhang-609
3685194456 [Triton-MLIR][BACKEND] Add elementwise ops and tests (#804)
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2022-10-28 05:26:29 +00:00
Keren Zhou
3b80801dff [Triton-MLIR][Backend] Fix many problems to get the pipeline working (#809)
1. Rewrite code generation of insert_slice_async.
2. Correct the wrong index passed to extract_slice in pipeline.
3. Add a prologue in pipeline to wait for dangling cp.asyncs.  
4. Move scf to cf conversion inside TritonGPUToLLVM because we need to
perform membar before scf to cf. It shouldn't be a technical limitation
and could be improved by a more general membar analysis.
5. Use an attribute to memoize the shared memory size and support
dynamic shared memory.
6. Prevent the combine pass to reorder insert_slice and extract_slice
across async_wait

Co-authored-by: Superjomn <yanchunwei@outlook.com>
2022-10-27 22:09:06 -07:00
Qingyi Liu
42db3538e4 [Triton-MLIR][Backend] Add ReduceOpConversion into TritonGPUToLLVM conversion (#774)
What is done in this PR:
- [x] Add `ConvertLayout`, `getSizePerThread` and `getShapePerCTA`
implementation for `SliceEncodingAttr`
- [x] Split `emitIndices` into two phases:
`emitBaseIndexForBlockedLayout` and `emitOffsetForBlockedLayout`
- [x] Add `ReduceOpConversion::matchAndRewriteBasic` implementation
- [x] Add `ReduceOpConversion::matchAndRewriteFast` implementation with
ptx instruction `shfl.sync`
- [x] Add support for scalar value in `StoreOpConversion`
- [x] Add Reduce1d and Reduce2d unit tests and pass all unit tests

Co-authored-by: Qingyi Liu <liuqingyi1993@gmail.com>
2022-10-28 11:07:45 +08:00
Philippe Tillet
3e6cc6d66c [FRONTEND] Made more tests pass (#805) 2022-10-26 17:47:33 -07:00
goostavz
bb7008651a [Backend] Hacky fix of missing barrier in ConvertLayout blocked->shared (#803)
Barrier should be set by a separate pass, but it seems like there may be some bugs
2022-10-26 13:39:38 -07:00
Yan Chunwei
4dc2396ca0 [Triton-MLIR][BACKEND] Support $c from mma layout in dot (#798)
This PR does

1. Support the case where $c holding a mma layout, this should be useful
in forloop in k-axis in GEMM
2. Fix the `unrealized_conversion_cast` in ConvertLayout[shared->dot_op]

Known issue

1. There is some IO conflict in GEMM with a k-forloop, it is temporarily
solved by [adding a
barrier](https://github.com/openai/triton/pull/798/files#diff-8a9a5a7f4a025fb1299af29d190d5626bd9000406d3ea47c49679272d3d6abe9R3028)
in dot conversion, but we are still working on it, will get a more
generic fix for it in the following PR.
2. The parallel pass will result in a buggy instruction result type
```mlir
%1049 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.commit_group ;", ""  : () -> !llvm.void
%1050 = builtin.unrealized_conversion_cast %1049 : !llvm.void to !llvm.ptr<f16, 3>
```
So we temporarily disable it.
2022-10-26 10:33:04 +08:00
Philippe Tillet
a2cbe7af91 [FRONTEND] Enhanced support for binary operators (#801)
Disabled modulo test (due to change in behavior for `frem` in nvptx
between llvm-11 and llvm-14) and bfloat16 (will require some work to
emulate in software similar to how it's done in `master`)
2022-10-24 19:47:01 -07:00
Philippe Tillet
fcb228d1d4 Merge select commits from master branch into triton-mlir (#799)
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: vesuppi <zt9465@gmail.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: daadaada <dyanab@connect.ust.hk>
Co-authored-by: Anton Kostin <masguit42@users.noreply.github.com>
Co-authored-by: Yunxing Dai <nov503@gmail.com>
Co-authored-by: Shintaro Iwasaki <shintaro.iwasaki.work@gmail.com>
2022-10-24 14:52:37 -07:00
Yan Chunwei
877844de4f [Triton-MLIR][BACKEND] add convert_layout[shared->dot_op] converstion to adapt DotOperand layout (#786)
This PR helps to

1. Adapt the existing DotOp conversion to the design of the new
DotOperand layout,
2. Making the DotOp conversion work with both shared-layout inputs case
and dotoperand-layout inputs case for further upstream switch.
2022-10-24 11:40:13 +08:00
Philippe Tillet
3aa8296b06 [BUILD] Download pybind11 in setup.py (#703) (#797)
Cherry-picks #703 and resolves conflicts

Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
2022-10-23 18:52:48 -07:00
Yan Chunwei
1bf59d315c [Triton-MLIR][FRONTEND] Remove the dangling check-triton call in setup.py (#796) 2022-10-23 18:26:18 -07:00
Philippe Tillet
bb0f9235d1 [OPTIMIZER] Made layout simplification pass efficient for fused attention kernels (#790) 2022-10-21 16:52:15 -07:00
goostavz
c4726333bf [Triton-MLIR] Minor fixes related with scf/swizzling support (#791)
1, Disable static loop unrolling in the frontend by default;
2, A minor fix in axisAnalysis in order to support scf;
3, A minor fix in TritonGPUToLLVM to support swizzling.
2022-10-21 11:46:28 +08:00
Philippe Tillet
dc0588a898 [OPTIMIZER] Improved layout simplification pass so it handles swizzled layouts better (#789)
Note: uncommented `test_gemm`, since backend has an issue with swizzling. This will get uncommented in a subsequent PR.
2022-10-20 19:03:37 -07:00
Shintaro Iwasaki
0d22d2bc03 [TritonMLIR] Disallow 0D tensor (#788) 2022-10-19 10:34:32 -07:00
Yan Chunwei
4464646efb [Triton-MLIR][BACKEND] Fix masked load store op vector size (#785)
Correct the Load/Store Op's vector size with the mask's alignment
correctly considered.

Some cases:

```mlir
// num_warp = 2
// block_size = 128
func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, 
  %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
    // mask = make_range(128) < n_element
}
```
This should get the vec=2 `ld`/`st` instructions.

While the following example

```mlir
// num_warp = 2
// block_size = 128
func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, 
  %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
    // mask = make_range(128) < n_element
}
```
it should get the vec=1 `ld`/`st` instructions.
2022-10-18 11:43:50 +08:00
Philippe Tillet
38a80664b5 [OPTIMIZER] Updated TritonGPU-combine pass (#784)
WIP but should work int t…he cases we need so far
2022-10-16 21:19:42 -07:00
goostavz
e948a618b3 [Triton-MLIR] fix a tiny bug in coalesce pass (#782) 2022-10-16 20:29:55 -07:00
Shintaro Iwasaki
5898352f97 [Triton-IR] Fix LoadOp definition (#771) (#777) 2022-10-13 18:53:00 -07:00
Da Yan
963d031247 [Triton-IR] Fix LoadOp Triton->TritonGPU conversion (#775) 2022-10-13 12:57:39 -07:00
Yan Chunwei
1baa4e125f [triton-mlir][BACKEND] decouple loading from mma codegen in dot conversion (#764)
This PR decouples the operand loading from the mma codegen to make it
ready for the ongoing `DotOperandEncodingAttr` migration.
The existing DotOp conversion is composed of the following two
procedures:
1. Loading the $a,$b,$c operand from smem to registers
2. Conducting the MMA instruction codegen.

While in the latest design, the 1st stage should be part of the
`convert_layout(shared_layout) -> dot_operand_layout`, that's why the
decoupling is necessary.

Some details, this PR introduces a `MMA16816ConversionHelper` class, it
has `loadA`, `loadB` and `loadC` methods to help load $a, $b and $c from
smem to registers, both `loadA` and `loadB` methods returns a
`LLVM::Struct` which should be compatible with the new
`DotOperandEncodingAttr` conversion.

The conversion layout for $a and $b is as follows:
```c++
// The layout is a list of Value with coordinate of (i,j), the order is as
// the follows:
// [
//  (0,0), (0,1), (1,0), (1,1), # i=0, j=0
//  (0,2), (0,3), (1,2), (1,3), # i=0, j=1
//  (0,4), (0,5), (1,4), (1,5), # i=0, j=2
//  ...
//  (2,0), (2,1), (3,0), (3,1), # i=1, j=0
//  (2,2), (2,3), (3,2), (3,3), # i=1, j=1
//  (2,4), (2,5), (2,4), (2,5), # i=1, j=2
//  ...
// ]
// i \in [0, n0) and j \in [0, n1)
```
In the `convertDot` method, it takes loaded $a, $b, $c($a and $b are
type of `LLVM::Struct` while $c is a scalar Value), extract the elements
from `LLVM::Struct` following the layout above, and pass the elements to
MMA inline asm.
2022-10-12 10:45:17 +08:00
Philippe Tillet
623c99609f [Triton-IR] Added type inference and verifier for Triton-IR operations (#767) 2022-10-11 18:16:41 -07:00
Philippe Tillet
b6e5a231e5 [OPTIMIZER] Added swizzling pass (#758) 2022-10-10 01:12:37 -07:00
Yan Chunwei
555f94f9b9 [triton-mlir][BACKEND] Support masked load/store (#657)
This PR does

- fix some bugs to support masked load/store,
- refine frontend, and support the `and` and `or` syntax in mask(by
extending the BoolOp in python ast.visitor), e.g. `tl.store(...,
mask=offset<n and other_conditions)`,
- add `arith.cmpI` and `arith.cmpF` op conversion in backend(required by
mask),
- add more test cases in vecadd.
2022-10-10 13:29:53 +08:00
Ian Bearman
ccc5ab6ac9 [BUILD] When set, use MLIR_DIR for finding both MLIR and LLVM (#755) 2022-10-09 13:11:20 -07:00
Ian Bearman
89f6e1db5e [BUILD] use cmake to set include path when build isn't triggered by setup.py (#754) 2022-10-09 12:30:44 -07:00
Ian Bearman
863578a7fa [BUILD] Enable current-dir inclusion (#753)
This change enables `CMAKE_INCLUDE_CURRENT_DIR` when building Triton.
2022-10-09 18:09:49 +00:00
Ian Bearman
448d14a598 [BUILD] Add TRITON Prefix to build variables (#752) 2022-10-09 10:55:17 -07:00
goostavz
1d772cd843 [Triton-MLIR][Backend] Add SCF lowering in the backend (#750) 2022-10-08 18:36:37 +08:00
Philippe Tillet
498c685b46 [OPTIMIZER] layout simplification: ignore non-tensor iter arguments in for loop rematerialization (#749) 2022-10-07 21:52:29 -07:00
goostavz
e843257295 [Backend] Fix a bug in emitIndicesForBlocked (#740) 2022-10-04 21:29:59 -07:00
Keren Zhou
289ff293cc [Triton-MLIR] Generate LLVM/PTX code for async ops (#735) 2022-10-04 09:37:00 -07:00
goostavz
f9d7f2f126 [Triton-MLIR][Backend] Support ConvertLayout blocked->shared and a few fixes related with mma(#716) 2022-10-03 19:33:25 +08:00
Keren Zhou
baba98ad69 [Triton-MLIR] Fix threadsPerWarp derivation in BlockedEncodingAttr (#722)
Example:

```
    auto encoding = triton::gpu::BlockedEncodingAttr::get(
        &getContext(), {8, 32}, {2, 2}, {1, 0}, 2);
     //shape = [32 x 8], order = [1, 0], sizePerThread=[2, 2], numWarps=2
```

Expected output:

```
      //#triton_gpu.blocked_layout<{
      //  sizePerThread = {2, 2}
      //  threadsPerWarp = {8, 4}
      //  warpsPerCTA = {2, 1}
      //}>
```

Incorrect output by the current branch

```
      //#triton_gpu.blocked_layout<{
      //  sizePerThread = {2, 2}
      //  threadsPerWarp = {16, 2}
      //  warpsPerCTA = {2, 1}
      //}>
```
2022-09-27 16:41:30 -07:00
Philippe Tillet
9ddf0921fb [OPTIMIZER] Added DotOp to the list of expensive ops we don't want to rematerialize. (#718) 2022-09-27 09:05:49 -07:00
Yan Chunwei
df8d276089 [Triton-MLIR][Backend] Fix smem base bug in dot codegen (#715)
Get SMEM base address of an input operand from `adapter.arg()` instead
of `getSharedMemoryBase(arg, ...)`, for the latter one not works with
memory alias, for example:

```llvm
%a = extract_slice %b, %offset
%c = dot %a, %d
```

`%a` should have different smem base address from `%b`
2022-09-27 17:28:17 +08:00
Yan Chunwei
3a84278530 [Triton-MLIR][BACKEND] Refine dot conversion (#710)
This PR does

1. Refine the dot conversion
2. some other tiny code refinement
2022-09-27 14:38:34 +08:00
goostavz
61b61755e5 [Triton-MLIR][Backend] Support layout conversion between mmaLayout and blockedLayout (#693) 2022-09-27 03:58:47 +00:00
Philippe Tillet
1e91ed30d0 [RUNTIME] Major code cleanup (#711)
This PR does the following:
- CUDA utilities (e.g., cuGetInfo) won't be compiled as part of libtriton.so anymore.
- Refactoring driver/llvm.cc to split it between PTX codegen and python.
- By extension this will also deprecate include/external so Triton won't have to live with a copy of some CUDA/Hip headers anymore.
- `triton-translate` becomes a `triton.tools.aot` Python utility that re-uses functions from the triton.compile sub-module.
2022-09-26 16:38:06 -07:00
Philippe Tillet
8bb09f83ee [CI] Added CODEOWNERS file (#709) 2022-09-24 16:32:44 -07:00
Philippe Tillet
22ec22c257 [FRONTEND] Backport new runtime from master (#706)
This PR merges the new runtime back into the `triton-mlir` branch. This
adds caching and just-in-time compilation functionality to the
triton-mlir project, and paves the way for re-using tests from the
master branch.
2022-09-23 16:09:43 -07:00
Keren Zhou
ecd1bc33df [Triton-MLIR] Keren/code gen for extract slice and alloc tensor (#692)
Co-authored-by: gzhu <goostavz@outlook.com>
2022-09-23 19:38:14 +00:00
Philippe Tillet
c56f0198dd Revert "[Triton-MLIR][pybind11] Update pybind11 to 2.10.0" (#702)
Reverts openai/triton#694
2022-09-23 12:31:33 -07:00
Yan Chunwei
922155f1d2 [BACKEND] add dot conversion (mma version=2) (#672)
LLVM Conversion for Dot op.

Due to the lack of `convert_layout`, currently, the dot only supports
the following combination of operands

- `$a` in shared layout
- `$b` in shared layout
- `$c` in MMA layout(but only Splat-like, leaving the generic cases to
`convert_layout`)

This PR focus on `mma.16816` related logic support, leaving the other
cases to the following PR.

Co-authored-by: Philippe Tillet <phil@openai.com>
2022-09-22 20:43:54 -07:00
Shintaro Iwasaki
23f424c660 [Triton-MLIR][pybind11] Update pybind11 to 2.10.0 (#694)
This PR applies #691 to the Triton-MLIR branch.
2022-09-22 17:53:42 -07:00
Shintaro Iwasaki
940ef3f0ac [BACKEND] llvm::dyn_cast -> llvm::dyn_cast_or_null (#689) 2022-09-22 03:26:40 +00:00
goostavz
15bfd0cb79 [BACKEND] Support of ConvertLayoutOp from blocked to blocked and SliceLayout with blocked parent (#658) 2022-09-17 14:58:42 -07:00
Shintaro Iwasaki
13669b46a6 [DOCS] Correct spelling (#665)
This PR corrects spelling like #664 for Triton-MLIR. It should not break anything.
2022-09-16 15:07:34 -07:00
Shintaro Iwasaki
e9e1a4e682 [FRONTEND] Fix the implicit broadcasting rule (#663)
This PR solves the cast issue that appears in some tutorial code.
2022-09-16 10:49:15 -07:00
Philippe Tillet
80e3fb5270 [CI] Now using clang-format from pip (#662) 2022-09-15 16:24:37 -07:00
Shintaro Iwasaki
43be75ad42 [FRONTEND] Add scalar type support for some ops (#661)
This PR adds basic support for scalar-type inputs to some ops (cast and pointer arithmetics) for Triton-MLIR. Also renames getelementptr -> addptr
2022-09-15 16:12:52 -07:00
Da Yan
2e08450c80 [OPTIMIZER] Better pipeline tests (#660) 2022-09-14 23:26:40 -07:00
Shintaro Iwasaki
297d27e1c8 [Triton-MLIR] add GitHub CI runners (#655)
This PR is to add GitHub Actions runners to the CI for better coverage.
2022-09-14 23:09:56 -07:00
Philippe Tillet
c14dff2190 [CI] Added A10 tag to disambiguate self-hosted runners (#652) 2022-09-14 13:08:01 -07:00
Keren Zhou
16aed94ff5 [Analysis/Allocation] Allocation passes now assumes that slices always alias (#108)
This code in this branch assumes the `src` operand in
`insert_slice_async` always aliases the result, which shouldn't hold for
generally cases but is just a workaround to make the pipeline pass work.

I'm also working on the complete analysis in another
[branch](https://github.com/openai/triton-mlir/tree/keren/analyze-slice).
2022-09-09 12:03:41 -07:00
Philippe Tillet
9bd5a3dcd2 [OPTIMIZER] Pipeline async buffer (#110) 2022-09-09 11:01:14 -07:00
Yan Chunwei
2a852044d9 [BACKEND] Add C++ tests for PTXFormat and some tiny refinement (#109)
This PR does

1. Add some C++ tests for `PTXFormat`
2. Enhance the functionality of `PTXFormat`, make a `PTXInstr` instance
can be called multiple times similar as a C function.
2022-09-09 09:15:07 -07:00
Yan Chunwei
a9464f4993 [Backend] Vectorize Load/Store Ops (#86)
This PR does the following things:

- Code refactoring on Load and Store op codegen, rewrite with same logic
and share much code
- Support the vectorized load/store
2022-09-06 12:28:09 -07:00
Da Yan
35e346bcff [OPTIMIZER] Better pipeline pass (#100)
* Use `insert_slice_async` instead of `CopyAsync`
* Move async.wait to loop header

Co-authored-by: Jokeren <kerenzhou@openai.com>
2022-09-06 08:31:13 -07:00
Philippe Tillet
a0bab9748e [OPTIMIZER] Coalesce pass no longer takes a num-warps argument (#99)
Improved design to avoid inconsistent `num-warps` value between the pass and the parent module of the operation it processes.
2022-09-05 18:09:02 -07:00
Jun Yang
ea175f689e [CI]Added initial framework of CXX unittest (#98)
Based on the discussion in #53 , I just added the initial flow of CXX unittests for this repo, with providing two dummy UTs as placeholder to show the usage, feel free to add your own CXX unittests. 
@Superjomn  @ptillet 

@ptillet , in this PR, I also configure the integration-tests.yml to add the unittest into github CI check. 

Thanks
2022-09-04 12:50:27 +08:00
Philippe Tillet
d0b4c67b05 [OPTIMIZER] Improved layout conversion simplification algorithm (#97)
This PR both simplifies the layout conversion simplification algorithm, and also improves it to make it work with vectorized element-wise ops. The conversion optimizer still has a lot of room for improvements, and other PRs will address its limitations (ideally via some sort of explicit cost model)
2022-09-02 16:52:44 -07:00
Shintaro Iwasaki
3c635449e5 [Triton] Support math and libdevice ops (#91)
This PR adds basic math ops by using `MathDialect` and `libdevice` ops by using `extern_elementwise`. This is needed to compile some tutorial code (e.g., `softmax`). This PR implements only interface till PTX (so from frontend to TritonGPU-MLIR) 
- Currently till TritonGPU. It cannot be lowered to PTX now.
- No special optimizations (e.g., constant folding etc) are applied.
  - 14.x does not define folders for many operators for math ops, but 15.x seems to increase its coverage: https://github.com/llvm/llvm-project/blob/llvmorg-15.0.0-rc3/mlir/include/mlir/Dialect/Math/IR/MathOps.td
  - No constant folding etc for `libdevice` ops.

```py
import triton
import triton.language as tl
import sys

@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    offsets = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offsets)
    x = tl.sin(x)
    output = tl.libdevice.sin(x)
    output = tl.libdevice.fdiv_rn(output, output)
    output = tl.libdevice.fmaf_rd(output, output, output)
    tl.store(y_ptr + offsets, output)


if __name__ == "__main__" and len(sys.argv) >= 2:
    signature = "*fp32,*fp32"
    constants = {'BLOCK_SIZE': 1024}
    output = triton.compile(add_kernel, signature, device=0, constants=constants, output="ttgir")
    print(output)
```
->
```llvm
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func @add_kernel__Pfp32_Pfp32__2c1024(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) {
    %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
    %1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %2 = tt.getelementptr %1, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
    %4 = math.sin %3 : tensor<1024xf32, #blocked>
    %5 = tt.ext_elemwise %4 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_sinf"} : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %6 = tt.ext_elemwise %5, %5 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fdiv_rn"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %7 = tt.ext_elemwise %6, %6, %6 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fmaf_rd"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %8 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %9 = tt.getelementptr %8, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    tt.store %9, %7 : tensor<1024xf32, #blocked>
    return
  }
}
```
2022-09-01 16:34:27 -07:00
Keren Zhou
328b87aec6 Keren/tensor slice insert alloc (#94)
This branch defines three new triton_gpu operations to partially solve #87. Below is an overview:

```
%tensor = triton_gpu.alloc_tensor : tensor<2x16x16xf16, #A>
%b = triton_gpu.insert_slice_async %a_ptr, %tensor, %offset {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr<f16>, #AL> -> tensor<2x16x16xf16, #A>
%c = triton_gpu.extract_slice %b, %offset {axis = 0 : i32} : tensor<2x16x16xf16, #A> -> tensor<16x16xf16, #A>
```

We plan to fully replace `copy_async` with `insert_slice_async`. **This hasn't been done yet.**
2022-09-01 12:37:17 -07:00
Shintaro Iwasaki
d01353de07 [CI] add assert-enabled MLIR option (#78)
This deprecates the use of release-build LLVM hosted by the LLVM project, which makes debugging harder for developers.

This PR implements the following solution:
1. Create LLVM release tarballs with assert enabled on our own (using Docker)
2. Host them in our own GitHub repositories
3. Use our LLVM for CI and/or development if `TRITON_USE_ASSERT_ENABLED_LLVM=1` is set.
2022-08-31 18:55:32 -07:00
Keren Zhou
02ebf24d35 Analyze shared memory alias (#81)
The purpose of this PR is analyzing shared memory aliases so that we can
fix memory allocation bugs and save memory allocations in triton code
involving complex control flows.

Changes to memory bar and allocation are on the way.

Co-authored-by: Philippe Tillet <phil@openai.com>
2022-08-29 10:43:20 -07:00
Philippe Tillet
83287d7193 [CI] enable self-hosted runner (#85) 2022-08-25 19:12:16 -07:00
goostavz
bedbf221c0 [BACKEND] Support optional mask in TritonGPUToLLVM (#80)
Co-authored-by: gzhu <gzhu@nvidia.com>
2022-08-24 17:51:37 -07:00
Shintaro Iwasaki
84aa7d025a [TritonIR] simplify Load/StoreOps when mask is true/false (#79)
* [TritonIR] fix Load/Store/CopyAsyncOp's parsers

* [TritonIR] simplify Load/StoreOps when mask is true/false

* [TEST] adds tests to check load/store simplification
2022-08-24 12:55:49 -07:00
Yan Chunwei
1b513c9866 [BACKEND] Refactoring codegen for LoadOp with PTXFormat (#77)
This PR does following things:

Enhance the PTXFormat by
Introducing PTXBuilder to enable multiple instructions in a single asm program
override PTXInstr's operator() method to enable instr(opr0, opr1) style of setting operands for an instruction
Refactor the PTX code used in LoadOpConversion with PTXFormat

Authored-by: goostavz <gzhu@nvidia.com>
2022-08-23 15:51:13 -07:00
Shintaro Iwasaki
0ebef11c77 [TritonIR] Make mask operand optional (#74) 2022-08-22 22:00:17 -07:00
goostavz
de2dd04c8a [BACKEND] two minor bugfix on StoreOpLowering and kernel launch & support optional other in LoadOpLowering (#69)
* [BACKEND] two minor bugfix on StoreOpLowering and kernel launch & support optional other in LoadOpLowering

* Clean code

Co-authored-by: goostavz <gzhu@nvidia.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
2022-08-22 21:47:09 -07:00
Da Yan
92ef552a54 [OPTIMIZER] Fix Num in AsyncWaitOp generated by the pipeline pass (#72) 2022-08-22 15:58:10 -07:00
Yan Chunwei
10ba51c3bb [FRONTEND] add python e2e launch empty kernel test (#68) 2022-08-19 10:46:01 -07:00
Shintaro Iwasaki
9aa00249a6 [TritonIR] make other optional and remove isOtherUnspecified (#67)
[Triton] make other optional and remove isOtherUnspecified
2022-08-18 18:19:55 -07:00
Philippe Tillet
192be76b3c [OPTIMIZER] Rewrite patterns for layout conversions (#64) 2022-08-18 12:49:37 -07:00
Keren Zhou
e0bedeb44c [BACKEND] Keren/shared memory barrier (#59) 2022-08-18 12:32:57 -07:00
Da Yan
8776ad1a0e [OPTIMIZER] Let the pipeline pass insert async wait. (#63) 2022-08-18 10:31:57 -07:00
Shintaro Iwasaki
d69ce77b19 [FRONTEND] add an attr for masked load without explicit other (#55) 2022-08-18 09:51:37 -07:00
goostavz
fc58250a06 [BACKEND] Add backend support of arith::AddIOp, arith::AddFOp, GetProgramIdOp & GEPOp and bugfix for SplatOp, StoreOp, FuncOp (#60)
Add backend support of arith::AddIOp, arith::AddFOp, GetProgramIdOp, GEPOp and bugfix for SplatOp, StoreOp, FuncOp

Co-authored-by: gzhu <gzhu@nvidia.com>
2022-08-18 20:46:45 +08:00
Yan Chunwei
b1673caaf6 [FRONTEND] Expose end-to-end compile to python frontend (#58) 2022-08-17 10:42:48 -07:00
Yan Chunwei
95bbac41e7 [BACKEND] Add LLVM-translation for store and splat ops (#47) 2022-08-15 00:46:37 -07:00
goostavz
993ba7035a [BACKEND] Codegen bringup, index calculation of blocked_layout & support of LoadOp, BroadcastOp, ViewOp & MakeRangeOp (#38)
Co-authored-by: gzhu <gzhu@nvidia.com>
2022-08-14 19:58:59 -07:00
Da Yan
e5ec8e16ea [BUILD] Fix setup.py (#45) 2022-08-13 16:38:31 -07:00
Shintaro Iwasaki
d5856435d7 [CI] explicitly run unit tests (#54) 2022-08-12 13:39:04 -07:00
Shintaro Iwasaki
2ba9a83465 [BUILD] fix minor issues with MLIR assert enabled (#46) 2022-08-11 21:20:47 -07:00
Philippe Tillet
3a48ca0d4d [BUILD] Fix includes (#49) 2022-08-11 11:49:29 -07:00
Yan Chunwei
83ef74f248 [BACKEND] Extracting numWarps from tritonGPU module (#39) 2022-08-08 09:40:20 -07:00
Yan Chunwei
920723cf3d [BACKEND] add triton-translate to translate mlir to llvmir or PTX code (#37) 2022-08-07 22:34:36 -07:00
Philippe Tillet
490d34e0d5 [FRONTEND] Fixed python bindings link options (#40) 2022-08-07 13:09:12 -07:00
Philippe Tillet
78ebbe24c7 [FRONTEND] Added ExpandDimsOp primitive (#36) 2022-08-04 18:41:06 -07:00
Keren Zhou
a7b49b3227 [BACKEND] Memory allocation (#33) 2022-08-04 11:22:49 -07:00
Yan Chunwei
b988bae813 Init TritonGPU to LLVM dialect conversion (#32)
* add toLLVM pass

* update num-warps setting in mlir
2022-08-04 10:15:45 +08:00
Philippe Tillet
3236642e8f [OPTIMIZER] Added memory coalescing pass (#31) 2022-07-31 20:59:31 -07:00
Philippe Tillet
d1593e6ca8 [TritonGPU] Improved documentation and semantics of layout encodings (#30) 2022-07-31 13:59:44 -07:00
Yan Chunwei
e02c82c765 [TritonIR] Convert Triton dialect's Combine pass to MLIR DRR based (#16) 2022-07-27 12:50:08 -07:00
Philippe Tillet
432c3df265 [BUILD] MacOS can now build compiler and run MLIR tests (#25) 2022-07-27 01:32:10 -07:00
Philippe Tillet
6d62d88d4f [CI] run clang-format (#24) 2022-07-26 17:25:03 -07:00
Philippe Tillet
25357083e6 [CI] Added basic CI skeletons (#23)
Includes minor fixes to make things compile and pass static checks properly
2022-07-26 14:16:30 -07:00
Philippe Tillet
3265e0df5a [PYTHON] Cleaned up legacy code; added simple standalone compilation API (#22) 2022-07-26 11:06:45 -07:00
Keren Zhou
96cc6fb563 [TritonGPU] Pretty printer for layouts (#21) 2022-07-26 10:50:11 -07:00
Philippe Tillet
27c9f3d8cb [FRONTEND] Added comment on TensorSizeTrait::maxElement (#20) 2022-07-25 01:18:45 -07:00
Keren Zhou
7eda373a12 Add lit dependency (#9) 2022-07-24 19:14:52 -07:00
Philippe Tillet
a633d2b403 [Analysis] Added Axis Info Analysis (#8) 2022-07-19 13:38:48 -07:00
Philippe Tillet
df940aaab0 Merge pull request #7 from openai/broadcastAxis-fix
Fix blocked layout parser
2022-07-15 08:39:49 -07:00
Yan Da
63e6a85901 Fix blocked layout parser 2022-07-15 15:19:11 +08:00
Phil Tillet
65237f6117 [PACKAGING] Added FileCheck 2022-07-07 16:53:19 -07:00
Yan Da
9d1b5e3f79 special encoding for broadcast 2022-06-18 21:16:45 +08:00
Yan Da
53cf93ce6a Revert "Remove TypeConverter from TritonToTritonGPU conversion"
This reverts commit 64d0b87ef0.
2022-06-18 14:57:41 +08:00
Yan Da
64d0b87ef0 Remove TypeConverter from TritonToTritonGPU conversion 2022-06-18 14:34:59 +08:00
Yan Da
9feb256b71 op combine in Triton Dialect: broadcast(cst) -> cst 2022-06-17 16:19:47 +08:00
Yan Da
35736aa44e more progress on the testing infrastructure 2022-06-12 15:14:45 +08:00
Yan Da
22c65a53d9 more progress on test/CMakeLists.txt 2022-06-10 21:37:56 +08:00
Yan Da
0ee6e486f8 add cse pass to the pipeline & pass num-warps as an argument 2022-06-10 17:31:48 +08:00
Yan Da
117a402c1b more comments to TypeConverter & update warpTileSize 2022-06-08 16:20:07 +08:00
Yan Da
49d1821149 conversion test 2022-06-08 16:19:15 +08:00
Yan Da
26fcc12afd better unit tests 2022-06-07 19:35:38 +08:00
Yan Da
7b09b5f9e9 the pipeline pass now generates and accepts valid IR 2022-06-07 19:34:59 +08:00
Yan Da
560e29229b register conversion in triton-opt 2022-06-07 19:33:51 +08:00
Yan Da
0e11435448 more tests 2022-06-06 21:10:28 +08:00
Yan Da
366dddc3bc update mma encoding & triton-opt 2022-06-06 21:03:58 +08:00
Yan Da
7807f64ef3 rename sharded_layout => blocked_layout 2022-06-05 16:14:59 +08:00
Yan Da
bbf75b492f more tests 2022-06-05 15:10:09 +08:00
Yan Da
a4a2c72173 default address space of PointerType 0 => 1 2022-06-05 15:09:41 +08:00
Yan Da
d5eca56cf3 more TritonGPU unit tests 2022-06-05 14:25:09 +08:00
Yan Da
55cf9a0a97 Add triton's opt 2022-06-04 22:10:00 +08:00
Yan Da
830fe19d58 Merge branch 'mlir-rewrite' of https://github.com/daadaada/mlir-rewrite into mlir-rewrite 2022-06-01 10:59:20 +08:00
Yan Da
935390dc03 update examples 2022-06-01 10:59:16 +08:00
Da Yan
e36a54eb86 more progress on the definition of layouts 2022-05-31 11:43:21 +00:00
Yan Da
41d338d848 Fix op mapping in pipeline.cpp 2022-05-26 13:57:01 +08:00
Yan Da
c529b462f5 more fixes on pipeline.cpp 2022-05-26 13:14:41 +08:00
Yan Da
71d1c10e19 Remove weird includes 2022-05-25 21:54:06 +08:00
Yan Da
9308e9c90c A more general pipeliner 2022-05-25 21:52:51 +08:00
Yan Da
441fd7c3cc assembly format 2022-05-25 17:53:24 +08:00
Yan Da
e6f89a5777 Fix ReduceOp conversion 2022-05-25 16:03:06 +08:00
Yan Da
9b670cfb9f Add ReduceOp 2022-05-25 14:15:36 +08:00
Yan Da
a2c9f919a8 TritonGPU verifier 2022-05-24 19:48:56 +08:00
Yan Da
36c45ec687 make numStages an option in PipelinePass 2022-05-23 12:47:55 +08:00
Yan Da
39b1235082 fix atomic_cas 2022-05-22 19:43:04 +08:00
Yan Da
79298d61bc fix a pipeline issue 2022-05-16 19:38:40 +08:00
Yan Da
c3c4ac3733 TritonGPU combiner 2022-05-16 19:17:15 +08:00
Yan Da
e3916c3a46 TritonGPU combiner 2022-05-16 19:16:01 +08:00
Yan Da
0e68e6eb59 delete erroneous include 2022-05-15 22:30:26 +08:00
Yan Da
7027af9666 The pipeline pass is now functional 2022-05-15 22:29:27 +08:00
Yan Da
7e0e7ec365 more progress on the pipeline pass 2022-05-14 22:04:36 +08:00
Yan Da
978463ba39 more progress on the pipeline pass 2022-05-13 21:32:35 +08:00
Yan Da
d23d7b244c More on the pipeline pass 2022-05-11 20:31:08 +08:00
Yan Da
1a4fbed25b Skeleton for the pipeline pass 2022-05-11 16:13:53 +08:00
Yan Da
96876a46d1 More progress on Triton=>TritonGPU conversion (works for matmul) 2022-05-09 21:19:53 +08:00
Yan Da
0c5319eed9 More progress on SCF type conversion 2022-05-05 20:56:55 +08:00
Yan Da
26c59e4718 More on SCF conversion 2022-05-04 21:50:32 +08:00
Yan Da
a96fe07e1c DotOp conversion 2022-05-04 15:56:24 +08:00
Yan Da
2d281cbc0a ConstantOp conversion pattern 2022-05-04 15:35:43 +08:00
Yan Da
b9279d2e3b More progress on TritonGPU conversion 2022-05-04 14:54:31 +08:00
Yan Da
3ad7bee35e More conversion patterns 2022-05-04 12:50:02 +08:00
Yan Da
5f08e2fdae More arith patterns 2022-05-02 22:31:29 +08:00
Yan Da
75d32e2442 More on TritonGPU conversion 2022-05-02 21:51:00 +08:00
Yan Da
1428185c9c More progress on TritonGPUTypeConverter & TritonGPUConversionTarget 2022-05-01 22:06:54 +08:00
Yan Da
4ece9fd1f3 Move dependentDialects from .cpp to .td 2022-05-01 13:06:51 +08:00
Phil Tillet
d9017f8593 add basic template for legalizing arithmetic op 2022-04-30 20:42:25 -07:00
Phil Tillet
2c6a213131 [TRITONGPU] Added template for Triton -> TritonGPU conversion 2022-04-30 16:00:39 -07:00
Yan Da
2239ac1998 more progress on TritonGPU 2022-04-28 18:51:31 +08:00
Philippe Tillet
012e8c5b2b fixup 2022-04-27 16:39:27 -07:00
Philippe Tillet
513bcaee50 Added some ASCII art for encoding documentation 2022-04-27 16:28:27 -07:00
Yan Da
29859605ee Remove unused files 2022-04-27 21:20:07 +08:00
Yan Da
38d13ae618 Some progress on TritonGPU 2022-04-27 21:16:45 +08:00
Yan Da
edca91bf8f Update traits (NoSideEffect) 2022-04-27 19:41:07 +08:00
Yan Da
8dfe78f6cf Add TritonCombineOps 2022-04-27 19:28:21 +08:00
Yan Da
c70f6b666e Merge previous changes 2022-04-27 14:06:55 +08:00
Yan Da
74585fb970 Add Triton CombineOps 2022-04-27 13:45:56 +08:00
Philippe Tillet
81001d318c Putting Triton dialect in its own folder 2022-04-26 14:39:27 -07:00
Philippe Tillet
62a64ff29b Fixed Python link bug in CMakeLists 2022-04-26 14:39:18 -07:00
Yan Da
9e304cf79d Allow JITFunction to return multiple results 2022-04-15 15:38:19 +08:00
Yan Da
1c52bd587d Device function & PassManager 2022-04-15 14:41:57 +08:00
apd10
44d75cf9bb Bugfix in ptxas path. (#487)
Bug: "ret" value is destroyed when a failing "ptxas --version" is run
overwriting the previous valid "ret" value.

Fix: keep rets only for those runs which are successful. Pick the first
one
2022-04-12 13:07:28 +08:00
Philippe Tillet
9be2d655a3 [DRIVER] LLVM driver fixup (#482)
Current way of doing things is probably not super thread safe. init is shared between threads and some threads my not call the LLVMInitialize* function.
2022-04-12 13:03:02 +08:00
Keren Zhou
f51e0b1be4 [FRONTEND] Hot fix for lineno (#481)
Override __reduce__ to make CompilationError pickable and print out error messages
2022-04-12 13:02:33 +08:00
Yan Da
7e0fd97965 Add set_attr(...) to ir.OpState 2022-04-11 12:26:54 +08:00
Yan Da
4eb062f313 fix issues in visit_If 2022-04-10 16:28:45 +08:00
Yan Da
fcbbb3c10e Fix visit_While issues 2022-04-10 16:16:13 +08:00
Yan Da
19f81b7dea Add scf-codegen tests 2022-04-10 15:49:09 +08:00
Yan Da
9c7b3d5173 Manage insertion block with context manager 2022-04-10 15:02:12 +08:00
Yan Da
aa6e086881 Add more comments 2022-04-10 14:36:03 +08:00
Yan Da
f1cc67bbc3 triton -> tt 2022-04-10 12:07:19 +08:00
Yan Da
28e96bbfd1 Remove the dependency on TensorDialect 2022-04-08 19:43:09 +08:00
Yan Da
a3d0812d27 Update example 2022-04-08 19:38:25 +08:00
Yan Da
62f7609612 More on type inference & assembly format 2022-04-08 19:37:57 +08:00
Yan Da
13aead4808 Use TableGen to define new types 2022-04-08 16:32:46 +08:00
Yan Da
6002340456 Better textual representation 2022-04-07 20:44:41 +08:00
Yan Da
0864b253bb the matmul example 2022-04-07 20:27:18 +08:00
Yan Da
62f772123c now kernel functions return nothing (instead of none) 2022-04-07 20:22:17 +08:00
Yan Da
040a2b6c75 Fix OpBuilder 2022-04-07 20:01:31 +08:00
Yan Da
6b4da6f016 Documentation 2022-04-07 16:00:53 +08:00
Yan Da
16d44e5c4c Verify power-of-2 2022-04-07 15:28:02 +08:00
Yan Da
9cf4107990 Add TensorSizeTrait 2022-04-07 15:18:43 +08:00
Yan Da
39fad2b18a More progress on WhileOp 2022-04-05 17:55:43 +08:00
Yan Da
d7fbddc7d4 Fix ret::reference issue 2022-04-05 16:09:09 +08:00
Yan Da
c7ad928e60 More progress on WhileOp codegen 2022-04-05 15:55:48 +08:00
Yan Da
76d9249724 examples 2022-04-04 12:59:54 +08:00
Yan Da
0f96da336a codegen for If 2022-04-04 12:58:37 +08:00
Yan Da
9df899b291 Some progress on visit_If 2022-04-03 22:34:46 +08:00
Yan Da
c71c50cd0c ForOp's SSA construction 2022-04-03 19:11:47 +08:00
Yan Da
61413b8a97 More python bindings 2022-04-01 22:22:39 +08:00
Yan Da
9dafa0e2e3 Update trtion dependencies 2022-04-01 20:16:07 +08:00
Yan Da
bde103fab0 Replace MlirType with mlir::Type 2022-04-01 18:46:46 +08:00
Yan Da
4ad432f1fc More on scf Ops 2022-03-31 21:42:48 +08:00
Yan Da
2041b67fbf Now vecadd works 2022-03-30 20:21:47 +08:00
Yan Da
e381dc72c5 Use mlir::Block to replace MlirBlock 2022-03-30 16:31:03 +08:00
Yan Da
e95d98a886 bindings for ModuleOp 2022-03-30 13:32:52 +08:00
Yan Da
38e67b4293 Add more Ops 2022-03-28 19:50:23 +08:00
Yan Da
0d139ec460 Introducing SCF 2022-03-26 17:02:32 +08:00
Yan Da
c53f3486e4 create shr 2022-03-26 16:41:49 +08:00
Yan Da
ba16116f96 Let python manage created objects 2022-03-26 16:31:01 +08:00
Yan Da
fed9925bbd Using stable LLVM release 2022-03-26 16:25:18 +08:00
Yan Da
a17fba86b1 Logic Op creation 2022-03-26 16:16:20 +08:00
Yan Da
5e117966d0 CatOp 2022-03-25 14:17:17 +08:00
Yan Da
d5612333c0 More fcmp ops 2022-03-25 14:12:20 +08:00
Yan Da
07881b4d41 Update includes 2022-03-24 13:46:35 +08:00
Yan Da
cf7fc8d642 Update includes 2022-03-24 13:33:54 +08:00
Yan Da
78c3480c85 Add vecadd example 2022-03-23 13:32:12 +08:00
Yan Da
14a71dcb6f Replace MlirOperation with MlirValue 2022-03-23 13:31:14 +08:00
Yan Da
f2ab318614 New python binding 2022-03-22 21:53:22 +08:00
Yan Da
419bbe0f6e Reverts back to MLIR 14 & updates CMakeLists 2022-03-20 16:41:48 +08:00
Yan Da
a2c31ff434 Init commit 2022-03-17 20:40:55 +08:00
daadaada
539961072c [FRONTEND] Semantic analysis refactor (#473)
Moved dispatch.cc to semantic.py
Integer signedness now moved from C++ to python
Cleaner frontend type

Co-authored-by: Phil Tillet <phil@openai.com>
2022-03-16 21:25:30 -07:00
Yongjik Kim
0dd2ec2e3a [FRONTEND] Add an assert in case we get a CPU tensor. (#478) 2022-03-16 14:38:56 -07:00
Philippe Tillet
d4d8eaf6c0 [FRONTEND] improved caching mechanism (#474)
Co-authored-by: Greg Brockman <gdb@gregbrockman.com>
Co-authored-by: Christopher Hesse <christopherhesse@users.noreply.github.com>
2022-03-15 12:20:51 -07:00
Doğukan Tuna
21f8a0646d [DOCS] Minor README.md (#470)
Added binary distribution for quick installation
2022-03-05 00:50:37 -08:00
Philippe Tillet
a50a47a85b [CODEGEN] Reverted some changes from previous PR; fixed vectorization characteristics of mma layout (#469) 2022-03-04 01:53:31 -08:00
Philippe Tillet
bb5765df5c [CODEGEN] Now padding shared memory for layout conversion (#468) 2022-03-03 22:19:05 -08:00
daadaada
d9dd97492f Use unique_ptr in ir::context_impl (#462)
Co-authored-by: Philippe Tillet <Phil.Tillet@gmail.com>
2022-02-24 16:07:10 -08:00
Philippe Tillet
98ed7db8c1 [CODEGEN] Improvements and bugfixes (#463) 2022-02-24 14:56:24 -08:00
daadaada
a9dfdcaaa9 [FRONTEND] Make the performance model work for int8, tf32, and fp32 (#456) 2022-02-11 22:34:42 -08:00
Philippe Tillet
9b100302d3 [FRONTEND] Now using pybind11 to release GIL (#458) 2022-02-10 01:57:39 -08:00
Philippe Tillet
40093a9878 [DOCS] Multiple versions are now supported (#457) 2022-02-09 01:32:41 -08:00
Philippe Tillet
4941bc7001 [DOCS] Some more fixes (#455) 2022-02-08 16:53:56 -08:00
Philippe Tillet
2fdf0a4fe8 [DOCS] changed build command 2022-02-08 11:45:21 -08:00
Philippe Tillet
077d6c8ff0 [DOCS] re-activated tutorials 2022-02-08 11:42:39 -08:00
Philippe Tillet
822ddcd14b [DOCS] Added versioning (#453) 2022-02-08 11:28:18 -08:00
Philippe Tillet
7b48340ffd [CI] Some fixes for the build (#451) 2022-02-06 19:11:33 -08:00
Philippe Tillet
5a8a544d10 [OPS][BLOCKSPARSE] Improved robustness, clarity and performance (#450)
* dds layout now internally re-uses dsd code path for increased code 
* at_mask and kp_mask related things are now dropped from the softmax API. I couldn't think of any case where it was needed beyond is_causal. And if there is any, we should probably find a way to get it implemented statically so that users don't have to materialize masks.
 * fixed bug in blocksparse matmul that caused troubles when layout had a full row/col of zeros
 * blocksparse softmax now no longer modifies any data in-place
 * blocksparse softmax now takes an is_dense arguments that provides better performance. Passing is_dense=True, is_causal=True is the best way to achieve triangular attention.
  * unit tests now test backward pass
2022-02-06 18:00:45 -08:00
Philippe Tillet
69ff52ea1f [CODEGEN] removed buggy (and mostly useless) optimization in peephole pass (#449) 2022-02-05 21:37:23 -08:00
TC
137bb67fad [LANG] Add fp16 to fp8 conversion (#444) 2022-02-02 20:42:09 -08:00
Philippe Tillet
3b20170fa3 Merge pull request #448 from openai/v2.0
`v2.0` is now merged into `master`
2022-01-30 20:49:08 -08:00
Philippe Tillet
b0d6e2f322 [STYLE] run autopep 2022-01-30 20:27:44 -08:00
Philippe Tillet
2922dc141c Merge branch 'master' into v2.0 2022-01-30 20:25:01 -08:00
Philippe Tillet
807d8a1945 [ALL] Merge master (#447) 2022-01-30 20:21:20 -08:00
Philippe Tillet
bef76b142a [BACKEND] float division is now approximate by default (#446) 2022-01-29 18:29:29 -08:00
Philippe Tillet
bd52e530a0 [OPS][BLOCKSPARSE] Fix padding issue in DSD LUT (#445) 2022-01-28 21:40:30 -08:00
daadaada
e68d6a7776 [BACKEND] Making the warp-level tile "more square" to increase data-reuse for tl.dot. (#442)
* Increase smem data-reuse for some layouts

* tweak

* Keep the original tiling logic for sm < 80

Co-authored-by: Philippe Tillet <phil@openai.com>
2022-01-27 09:59:54 -08:00
daadaada
59d371c6eb [BACKEND] Added Int8 mma (#440) 2022-01-27 09:12:44 -08:00
Benjamin Lefaudeux
3a23c1dd33 [BACKEND] minor, hotfix for gcc compilation (#439) 2022-01-23 14:24:02 -08:00
Philippe Tillet
ccf9abe0ba [FRONTEND][RANDOM] Improved backward compatibility of RNG (#438)
The unsigned int PR definitely improved our RNG. However, it requires
different floating point arithmetics which, means the results are not
bit-wise identical to how they were before. This commit revives backward
compatibility, but we should change it back to the "right" way later.
2022-01-21 18:05:55 -08:00
Philippe Tillet
4c97d1ecd7 [FRONTEND] Bunch of fixes here and there (#436) 2022-01-20 10:55:59 -08:00
Philippe Tillet
e0c5709cc8 [FRONTEND] Fixed semantics bug on ptr to bool conversions (#432) 2022-01-17 18:00:03 -08:00
daadaada
2a944ded53 [TESTS] Added bfloat16 tests (#430) 2022-01-13 23:38:32 -08:00
Philippe Tillet
4c94359199 [FRONTEND] Alignment fix-up (#428) 2022-01-11 23:11:58 -08:00
Philippe Tillet
bbc78f6516 [FRONTEND][RANDOM] Make sure offset dtype is always uint32 before calling uint32_to_uniform_float (#427) 2022-01-11 11:08:49 -08:00
Botao Yu
bf32205edc [OPS][BLOCKSPARSE] Remove unnecessary loop and add cuda bool layout support (#425) 2022-01-11 11:07:16 -08:00
daadaada
94a2e10fe5 [BACKEND] Add bf16 & tf32 mma supports (on A100) (#426) 2022-01-11 10:20:31 -08:00
Madeleine Thompson
efdabe6073 [STYLE] check python with flake8 (#424)
I've been using this locally to find errors without running tests, and now that we're using autopep8, it passes with minimal suppressions. This is also what turned up the issues with the tutorials, which were fixed in #422.
2022-01-07 15:28:36 -08:00
Madeleine Thompson
a70acfec77 [STYLE] add isort and autopep8 config files and check on CI (#423)
Also a fix a few more style issues from the "aggressive" mode of autopep8.
2022-01-07 13:11:34 -08:00
Madeleine Thompson
9801aa7b56 [DOCS] fix tutorials for v2.0 (#422)
- Fix meta-parameter usage on tutorials.
- Install tutorial dependencies on CI.
- Switch from `requirements-test.txt` to `extras_require` for test dependencies, and also use it for tutorial dependencies.
- Make some performance tests deterministic.
2022-01-07 12:34:38 -08:00
Madeleine Thompson
8bf551ae7a [STYLE] run autopep8 and isort (#421)
Run:
```
isort ./python
autopep8 -i --ignore E501,E701,E731 $(find ./python/ -name '*.py')
```
with an `.isort.cfg` and then clean up a few warts. This PR should be a no-op; the idea is that this is all boring whitespace changes, and any config file changes will be in a different change to make it easier to review.
2022-01-06 14:34:17 -08:00
Shantanu
6f7acad48f [CODEGEN] Avoid use of deprecated AST nodes (#418)
Co-authored-by: hauntsaninja <>
2022-01-06 12:04:33 -08:00
Madeleine Thompson
120cda015e [FRONTEND] use unsigned integers to simplify RNG (#417) 2022-01-06 10:49:09 -08:00
Philippe Tillet
001fb757fe [OPS][BLOCKSPARSE] Added .contiguous() in blocksparse inputs when necessary (#420) 2022-01-06 09:56:22 -08:00
Madeleine Thompson
0ab9d67bad uint8, uint16, uint32, and uint64 in kernels (#413)
A forthcoming PR will update the RNG to use these types.

Also:
- Add tests for the `//`, `<<`, and `>>` operators.
- Change `TensorWrapper` to unwrap objects when the resulting object would be simpler.
- Clean up `throw_unreachable`, since it was triggering compiler warnings.
2022-01-05 15:27:17 -08:00
Madeleine Thompson
d8db0308cb [TEST] use numpy for reference results in test_core.py (#409)
Since numpy supports unsigned integers, and pytorch doesn't, this will make it easier to test unsigned integer support.

This adds an explicit requirement for numpy in tests, but we already required scipy, so it was already an implicit dependency.
2022-01-04 13:07:29 -08:00
Philippe Tillet
03f1256f60 [FRONTEND] Added volatile flag for load (#407) 2021-12-30 22:33:24 -08:00
Noah Ziems
3edc2633e9 [TUTORIALS] Fix 01-vector-add.py typo (#406) 2021-12-29 15:09:34 -08:00
Madeleine Thompson
985798f101 add missing bfloat16 repr and improve assertions (#403)
- `BF16TyID` was missing a repr implementation.
- Throw a better exception on impossible casts.
- Add a few assertions. Tested with a debug build.
- Add `pointer_dtype.__str__` to aid kernel debugging.
2021-12-23 17:01:17 -08:00
Philippe Tillet
d8fce83e7a [FRONTEND] Remade exception picklable 2021-12-21 22:14:06 -08:00
Philippe Tillet
a425f24d54 [FRONTEND] Better cache hook (#400)
Added an additional `repr` argument to the cache hook, which represents a human-readable string representation of the signature and argument attributes associated with the compiled binary.
2021-12-21 21:29:47 -08:00
Philippe Tillet
2509124dd0 [DRIVER] Fixed some issue with how ptxas is used (#399)
Now using tmpnam and properly deleting temporaries when an exception is raised
2021-12-21 14:31:51 -08:00
daadaada
39d4bfed83 [OPS] Add performance model for gemm/gemv (#397)
Significantly improves the performance of `triton.ops.matmul` in memory-bound settings via the use of many more block configs coupled with a performance model to drive the auto-tuning process.
2021-12-21 09:56:10 -08:00
Madeleine Thompson
5cdb948c05 [FRONTEND] signed-integer math fixes and testing (#395)
- Promote 16-bit floating-point `/` and `%` to 32-bit; we have to anyway.
- Do not force result of integer binary operations to be the LHS type. There used to be a bug in pytorch that did this, which Triton matched, but that bug is fixed now.
- When testing signed integer operations, use random numbers from the full range of the type.
- Add an optional `seed` argument to `triton.testing.random` so binary operations are not tested with both sides equal when the LHS and RHS have the same type.
- Fix a bad `CompilationError` invocation.
- Fix a warning suppression that causes tests to fail if you run them with `-W error` on python 3.8.
2021-12-21 09:46:05 -08:00
daadaada
4a8953efa3 [FRONTEND] Replace the legacy print call in triton.cc with the SlotTracker-based one. (#396)
The legacy print call will assign names (e.g., %10) to values, which can be undesirable in some cases.
2021-12-18 18:03:22 -08:00
Madeleine Thompson
fa62b4a8f6 [FRONTEND] better stringification (#394)
- Don't override `self.args` in `CompilationError`, and show the line number and column in error messages. This causes it to generate an easier-to-read backtrace.
- Better `__str__` on `TensorWrapper`, `dtype`, and `block`.
2021-12-17 20:11:45 -08:00
Philippe Tillet
4e93b41c52 [GENERAL] Some minor fixups (#393)
* [RUNTIME] Now displaying error message when generated PTX is invalid

* [CODEGEN] Now converting `if` condition to bool implicitly
2021-12-17 18:06:21 -08:00
Philippe Tillet
e062812969 [CODEGEN] Disabled peephole for masked load + select -- masked_load
doesn't work as expected when vectorized
2021-12-17 12:44:47 -08:00
Victor
eb077fc993 [RUNTIME] fixed NVidia DLL names on Windows (#392) 2021-12-16 22:09:52 -08:00
Philippe Tillet
e0b92c1380 [FRONTEND] Reverted from .random import *. There are still some
namespace errors in the Triton frontend apparently
2021-12-16 18:37:51 -08:00
Philippe Tillet
558555630f [FRONTEND] Added xor_sum 2021-12-16 17:55:35 -08:00
Madeleine Thompson
e575ae3443 [FRONTEND] Minor accumulated style and warning fixes (#388)
- Fix some whitespace.
- Make an undeclared dependency on `pytest` explicit.
- Fix deprecated `description-file` use.
- `#ifdef` out a deprecated `PyEval_InitThreads` call.
- Use a slightly different numpy invocation in `test_random.py` to quiet down overflow warnings in tests.
- Fix a deprecated cast in `test_core.py`.
- Suppress a warning about `visit_Constant` in Python 3.9+; we can't migrate yet because it'd break Python 3.6 and 3.7.
- Use chained exceptions for `CompilationError` rather than rolling our own; it makes the error messages nicer.
- Add a `__str__` for `tl.dtype` to make debugging kernels easier; it lets you `print` a dtype to see what type was inferred.
- Fix a few bad escapes.
2021-12-10 15:19:20 -08:00
Philippe Tillet
9def2424ab [RUNTIME] Fix typo in IfExp 2021-12-09 15:14:41 -08:00
Philippe Tillet
e31b9b4e66 [RUNTIME] Better support for None (#387)
* regression test fails but it doesn't make sense to me.
2021-12-09 13:21:22 -08:00
Victor
73b04d71b2 Fixes for building on Windows (#382)
* make C++ code compatible with Windows + MSVC

* added dlfcn-win32 for cross-platform dlopen

* fixed building and pip install on Windows

* fixed shared library file name under Windows
2021-12-07 14:10:58 -08:00
Victor
0ff1a26b70 fixed p2p tests failing when there are no supported p2p devices (#386) 2021-12-06 18:14:03 -08:00
Philippe Tillet
f23bf55f15 [RUNTIME] release the gil on launch (#383) 2021-12-03 13:01:01 -08:00
Philippe Tillet
8ec9f037bb [BACKEND/CODE_GEN] Fixed float32 matmul problem (#380) 2021-11-30 22:00:56 -08:00
Philippe Tillet
c86ad9c9ab [FRONTEND] Added default arguments to non-kernel @triton.jit'd function (#379) 2021-11-29 19:11:26 -08:00
daadaada
1296eb877b [RUNTIME] Config hook v2.0 (#373)
* Add pre_hook to triton.Config
* Use argument names in triton.heuristics
* Update base perf
* Remove meta from heuristics
2021-11-21 11:20:59 -08:00
Philippe Tillet
5693b582ea [RUNTIME] Now using pybind11 to avoid memory leaks (#377) 2021-11-21 02:30:22 -08:00
Philippe Tillet
edd4b0c8b7 [CODEGEN] Fixed issue with jit function passed as constexpr 2021-11-16 09:53:34 -08:00
Philippe Tillet
5b7ba3eb96 [CODEGEN] Reverted to old launch method (memory leak?) 2021-11-16 01:21:03 -08:00
Philippe Tillet
791b953b21 [CODEGEN] Reverted to old way to query current stream 2021-11-16 00:17:27 -08:00
Philippe Tillet
b908095872 [VERSION] Bumped triton.__version__ to 2.0.0 2021-11-12 15:10:36 -08:00
Philippe Tillet
01cc3d4503 [RUNTIME] Restored do_not_specialize (#374) 2021-11-12 15:06:55 -08:00
Philippe Tillet
e66bf76354 [RUNTIME] Bunch of bugfixes (#372) 2021-11-12 00:55:00 -08:00
Philippe Tillet
f7ab96cfd7 [FRONTEND] Fixed some issues with constexpr 2021-11-09 13:03:09 -08:00
daadaada
9a02dddf29 Fix sdd_lut (#368) 2021-11-08 08:25:05 -08:00
Philippe Tillet
5d54352164 [FRONTEND] Significantly reduce kernel launch time (#367) 2021-11-04 13:25:24 -07:00
Philippe Tillet
2acaa4d0dd [LANG] Added support for constexpr (#361) 2021-10-30 00:32:58 -07:00
Philippe Tillet
b7f0e87dc2 [DRIVER] Removed std::cout log message 2021-10-29 10:42:10 -07:00
Philippe Tillet
770ea96cca [PACKAGING] Bumped dev version to 2.0.0 2021-10-29 01:28:17 -07:00
Philippe Tillet
969d6de8a2 [PACKAGING] Bumped dev version to 1.1.2 2021-10-29 01:26:21 -07:00
309 changed files with 29319 additions and 57384 deletions

1
.clang-format Normal file
View File

@@ -0,0 +1 @@
BasedOnStyle: LLVM

55
.github/CODEOWNERS vendored Normal file
View File

@@ -0,0 +1,55 @@
# These owners will be the default owners for everything in
# the repo. Unless a later match takes precedence,
# @global-owner1 and @global-owner2 will be requested for
# review when someone opens a pull request.
* @ptillet
# --------
# Analyses
# --------
# Alias analysis
include/triton/Analysis/Alias.h @Jokeren
lib/Analysis/Alias.cpp @Jokeren
# Allocation analysis
include/triton/Analysis/Allocation.h @Jokeren
lib/Analysis/Allocation.cpp @Jokeren
# Membar analysis
include/triton/Analysis/Membar.h @Jokeren
lib/Analysis/Membar.cpp @Jokeren
# AxisInfo analysis
include/triton/Analysis/AxisInfo.h @ptillet
lib/Analysis/AxisInfo.cpp @ptillet
# Utilities
include/triton/Analysis/Utility.h @Jokeren
lib/Analysis/Utility.cpp @Jokeren
# ----------
# Dialects
# ----------
# Pipeline pass
lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @daadaada
# Coalesce pass
lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @ptillet
# Layout simplification pass
lib/Dialect/TritonGPU/Transforms/Combine.cpp @ptillet
# -----------
# Conversions
# -----------
# TritonGPUToLLVM
include/triton/Conversion/TritonGPUToLLVM/ @goostavz @Superjomn
lib/Conversions/TritonGPUToLLVM @goostavz @Superjomn
# TritonToTritonGPU
include/triton/Conversion/TritonToTritonGPU/ @daadaada
lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @daadaada
# -------
# Targets
# -------
# LLVMIR
include/triton/Target/LLVMIR/ @goostavz @Superjomn
lib/Target/LLVMIR @goostavz @Superjomn
# PTX
include/triton/Target/PTX/ @goostavz @Superjomn
lib/Target/PTX @goostavz @Superjomn

View File

@@ -1,42 +0,0 @@
name: Documentation
on:
workflow_dispatch:
schedule:
- cron: "0 0 * * *"
jobs:
Build-Documentation:
runs-on: self-hosted
steps:
- name: Checkout gh-pages
uses: actions/checkout@v1
with:
ref: 'gh-pages'
- name: Checkout branch
uses: actions/checkout@v1
- name: Install Triton
run: |
alias python='python3'
cd python
pip3 install -e .
- name: Build docs
run: |
cd docs
make html
- name: Publish docs
run: |
git checkout gh-pages
sh ./update-website.sh
eval `ssh-agent -s`
DISPLAY=:0 SSH_ASKPASS=~/.ssh/give_pass.sh ssh-add ${{ secrets.SSH_KEY }} <<< ${{ secrets.SSH_PASS }}
git remote set-url origin git@github.com:openai/triton.git
git push

View File

@@ -4,44 +4,88 @@ on:
workflow_dispatch:
pull_request:
branches:
- master
- v2.0
- main
- triton-mlir
jobs:
Runner-Preparation:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- name: Prepare runner matrix
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix::[["self-hosted", "A10"], "macos-10.15"]'
else
echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
fi
Integration-Tests:
runs-on: self-hosted
needs: Runner-Preparation
runs-on: ${{ matrix.runner }}
strategy:
matrix:
runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix)}}
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Clear cache
run: |
rm -r /tmp/triton/
continue-on-error: true
rm -rf ~/.triton/cache/
- name: Check imports
if: startsWith(matrix.runner, 'ubuntu')
run: |
pip install isort
isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )
- name: Check python style
if: startsWith(matrix.runner, 'ubuntu')
run: |
pip install autopep8
autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )
- name: Check cpp style
if: startsWith(matrix.runner, 'ubuntu')
run: |
pip install clang-format
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)
- name: Flake8
if: startsWith(matrix.runner, 'ubuntu')
run: |
pip install flake8
flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )
- name: Install Triton
run: |
alias python='python3'
cd python
pip3 install -e .
TRITON_USE_ASSERT_ENABLED_LLVM=TRUE pip3 install -e '.[tests]'
- name: Unit tests
- name: Run lit tests
run: |
cd python/test/unit
pytest -vs .
cd python
LIT_TEST_DIR="build/$(ls build)/test"
if [ ! -d "$LIT_TEST_DIR" ]; then
echo "Not found `$LIT_TEST_DIR`. Did you change an installation method?" ; exit -1
fi
lit -v "$LIT_TEST_DIR"
- name: Regression tests
- name: Run python tests
if: ${{matrix.runner[0] == 'self-hosted'}}
run: |
cd python/test/regression
sudo nvidia-smi -i 0 -pm 1
sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350
sudo nvidia-smi -i 0 --lock-memory-clocks=877,877
pytest -vs .
sudo nvidia-smi -i 0 -rgc
sudo nvidia-smi -i 0 -rmc
cd python/tests
pytest
- name: Run CXX unittests
run: |
cd python/
cd "build/$(ls build)"
ctest

View File

@@ -8,7 +8,7 @@ jobs:
Build-Wheels:
runs-on: self-hosted
runs-on: [self-hosted, V100]
steps:
@@ -18,7 +18,7 @@ jobs:
- name: Patch setup.py
run: |
#sed -i 's/name\=\"triton\"/name="triton-nightly"/g' python/setup.py
export LATEST_DATE=$(git show -s --format=%ci `git rev-parse HEAD` | cut -d ' ' -f 1 | sed 's/-//g')
export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d' --format="%cd")
sed -i -r "s/version\=\"(.*)\"/version=\"\1-dev"$LATEST_DATE"\"/g" python/setup.py
echo "" >> python/setup.cfg
echo "[build_ext]" >> python/setup.cfg

6
.gitignore vendored
View File

@@ -1,6 +1,12 @@
build/
__pycache__
.pytest_cache
python/build/
python/triton.egg-info/
python/triton/_C/libtriton.pyd
python/triton/_C/libtriton.so
.vscode
.vs

3
.gitmodules vendored Normal file
View File

@@ -0,0 +1,3 @@
[submodule "deps/dlfcn-win32"]
path = deps/dlfcn-win32
url = https://github.com/dlfcn-win32/dlfcn-win32.git

4
.isort.cfg Normal file
View File

@@ -0,0 +1,4 @@
[settings]
known_local_folder=triton
line_length=88
py_version=36

View File

@@ -1,18 +1,19 @@
cmake_minimum_required(VERSION 3.6)
include(ExternalProject)
if(NOT TRITON_LLVM_BUILD_DIR)
set(TRITON_LLVM_BUILD_DIR ${CMAKE_BINARY_DIR})
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_INCLUDE_CURRENT_DIR ON)
project(triton)
include(CTest)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
if(NOT WIN32)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
endif()
# Options
option(BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
# Default build type
if(NOT CMAKE_BUILD_TYPE)
@@ -20,100 +21,207 @@ if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release")
endif()
find_library(TERMINFO_LIBRARY tinfo)
if(NOT WIN32)
find_library(TERMINFO_LIBRARY tinfo)
endif()
# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17")
# Third-party
include_directories(${PYBIND11_INCLUDE_DIR})
if(WIN32)
SET(BUILD_SHARED_LIBS OFF)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/deps/dlfcn-win32/src)
add_subdirectory(deps/dlfcn-win32/src ${CMAKE_BINARY_DIR}/dlfcn-win32)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")
if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET 11.6)
endif()
##########
# LLVM
##########
if("${LLVM_LIBRARY_DIR}" STREQUAL "")
find_package(LLVM 11 REQUIRED COMPONENTS "nvptx;amdgpu")
if (NOT MLIR_DIR)
if(NOT LLVM_LIBRARY_DIR)
if(WIN32)
find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu)
include_directories(${LLVM_INCLUDE_DIRS})
separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS})
add_definitions(${LLVM_DEFINITIONS_LIST})
llvm_map_components_to_libnames(LLVM_LIBRARIES support core
NVPTXInfo nvptxcodegen
AMDGPUInfo AMDGPUcodegen
)
else()
find_package(LLVM 11 REQUIRED COMPONENTS "nvptx;amdgpu")
endif()
message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14")
endif()
# sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
else()
# sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
else()
set(LLVM_LDFLAGS "-L${LLVM_LIBRARY_DIR}")
set(LLVM_LIBRARIES
libLLVMNVPTXCodeGen.a
libLLVMNVPTXDesc.a
libLLVMNVPTXInfo.a
libLLVMAMDGPUDisassembler.a
libLLVMMCDisassembler.a
libLLVMAMDGPUCodeGen.a
libLLVMMIRParser.a
libLLVMGlobalISel.a
libLLVMSelectionDAG.a
libLLVMipo.a
libLLVMInstrumentation.a
libLLVMVectorize.a
libLLVMLinker.a
libLLVMIRReader.a
libLLVMAsmParser.a
libLLVMFrontendOpenMP.a
libLLVMAsmPrinter.a
libLLVMDebugInfoDWARF.a
libLLVMCodeGen.a
libLLVMTarget.a
libLLVMScalarOpts.a
libLLVMInstCombine.a
libLLVMAggressiveInstCombine.a
libLLVMTransformUtils.a
libLLVMBitWriter.a
libLLVMAnalysis.a
libLLVMProfileData.a
libLLVMObject.a
libLLVMTextAPI.a
libLLVMBitReader.a
libLLVMAMDGPUAsmParser.a
libLLVMMCParser.a
libLLVMAMDGPUDesc.a
libLLVMAMDGPUUtils.a
libLLVMMC.a
libLLVMDebugInfoCodeView.a
libLLVMDebugInfoMSF.a
libLLVMCore.a
libLLVMRemarks.a
libLLVMBitstreamReader.a
libLLVMBinaryFormat.a
libLLVMAMDGPUInfo.a
libLLVMSupport.a
libLLVMDemangle.a
)
set(LLVM_LIBRARIES
libLLVMNVPTXCodeGen.a
libLLVMNVPTXDesc.a
libLLVMNVPTXInfo.a
libLLVMAMDGPUDisassembler.a
libLLVMMCDisassembler.a
libLLVMAMDGPUCodeGen.a
libLLVMMIRParser.a
libLLVMGlobalISel.a
libLLVMSelectionDAG.a
libLLVMipo.a
libLLVMInstrumentation.a
libLLVMVectorize.a
libLLVMLinker.a
libLLVMIRReader.a
libLLVMAsmParser.a
libLLVMFrontendOpenMP.a
libLLVMAsmPrinter.a
libLLVMDebugInfoDWARF.a
libLLVMCodeGen.a
libLLVMTarget.a
libLLVMScalarOpts.a
libLLVMInstCombine.a
libLLVMAggressiveInstCombine.a
libLLVMTransformUtils.a
libLLVMBitWriter.a
libLLVMAnalysis.a
libLLVMProfileData.a
libLLVMObject.a
libLLVMTextAPI.a
libLLVMBitReader.a
libLLVMAMDGPUAsmParser.a
libLLVMMCParser.a
libLLVMAMDGPUDesc.a
libLLVMAMDGPUUtils.a
libLLVMMC.a
libLLVMDebugInfoCodeView.a
libLLVMDebugInfoMSF.a
libLLVMCore.a
libLLVMRemarks.a
libLLVMBitstreamReader.a
libLLVMBinaryFormat.a
libLLVMAMDGPUInfo.a
libLLVMSupport.a
libLLVMDemangle.a
libLLVMPasses.a
libLLVMAnalysis.a
libLLVMTransformUtils.a
libLLVMScalarOpts.a
libLLVMTransformUtils.a
libLLVMipo.a
libLLVMObjCARCOpts.a
libLLVMCoroutines.a
libLLVMAnalysis.a
)
endif()
set (MLIR_DIR ${LLVM_LIBRARY_DIR}/cmake/mlir)
endif()
include_directories("${LLVM_INCLUDE_DIRS}")
# Python module
if(BUILD_PYTHON_MODULE)
if(TRITON_BUILD_PYTHON_MODULE)
message(STATUS "Adding Python module")
# Build CUTLASS python wrapper if requested
set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
set(CUTLASS_INCLUDE_DIR "$ENV{CUTLASS_INCLUDE_DIR}")
set(CUTLASS_LIBRARY_DIR "$ENV{CUTLASS_LIBRARY_DIR}")
if(NOT("${CUTLASS_INCLUDE_DIR}" STREQUAL "") AND NOT("${CUTLASS_LIBRARY_DIR}" STREQUAL ""))
set(CUTLASS_SRC ${PYTHON_SRC_PATH}/cutlass.cc)
add_definitions(-DWITH_CUTLASS_BINDINGS)
set(CUTLASS_LIBRARIES "cutlass.a")
include_directories("." ${PYTHON_SRC_PATH})
if (PYTHON_INCLUDE_DIRS)
include_directories(${PYTHON_INCLUDE_DIRS})
else()
find_package(Python3 REQUIRED COMPONENTS Development)
include_directories(${Python3_INCLUDE_DIRS})
link_directories(${Python3_LIBRARY_DIRS})
link_libraries(${Python3_LIBRARIES})
add_link_options(${Python3_LINK_OPTIONS})
endif()
include_directories("." ${PYTHON_SRC_PATH} ${PYTHON_INCLUDE_DIRS} ${CUTLASS_INCLUDE_DIR})
link_directories(${PYTHON_LINK_DIRS} ${CUTLASS_LIBRARY_DIR})
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc ${PYTHON_SRC_PATH}/superblock.cc ${CUTLASS_SRC})
set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc)
endif()
# Triton
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# # Triton
# file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
# if (WIN32 AND TRITON_BUILD_PYTHON_MODULE)
# find_package(Python3 REQUIRED COMPONENTS Development)
# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# set_target_properties(triton PROPERTIES SUFFIX ".pyd")
# set_target_properties(triton PROPERTIES PREFIX "lib")
# else()
# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# endif()
# MLIR
find_package(MLIR REQUIRED CONFIG PATHS ${MLIR_DIR})
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
include(TableGen) # required by AddMLIR
include(AddLLVM)
include(AddMLIR)
# Disable warnings that show up in external code (gtest;pybind11)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default")
include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
# link_directories(${LLVM_LIBRARY_DIR})
add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(bin)
add_library(triton SHARED ${PYTHON_SRC})
# find_package(PythonLibs REQUIRED)
set(TRITON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
target_link_libraries(triton
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonLLVMIR
TritonPTX
${dialect_libs}
${conversion_libs}
# optimizations
MLIRPass
MLIRTransforms
MLIRLLVMIR
MLIRSupport
MLIRTargetLLVMIRExport
MLIRExecutionEngine
MLIRMathToLLVM
MLIRNVVMToLLVMIRTranslation
MLIRIR
)
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
target_link_libraries(triton ${LLVM_LIBRARIES} z ${TERMINFO_LIBRARY})
if(WIN32)
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
else()
target_link_libraries(triton ${LLVM_LIBRARIES} z)
endif()
if(BUILD_PYTHON_MODULE)
if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
# Check if the platform is MacOS
if(APPLE)
@@ -121,3 +229,7 @@ if(BUILD_PYTHON_MODULE)
endif()
target_link_libraries(triton ${CUTLASS_LIBRARIES} ${PYTHON_LDFLAGS})
endif()
add_subdirectory(test)
add_subdirectory(unittest)

View File

@@ -1,6 +1,6 @@
/*
* Copyright 2018-2020 Philippe Tillet
* Copyright 2020-2021 OpenAI
* Copyright 2020-2022 OpenAI
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
@@ -20,4 +20,4 @@
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
*/

View File

@@ -18,6 +18,21 @@ The foundations of this project are described in the following MAPL2019 publicat
The [official documentation](https://triton-lang.org) contains installation instructions and tutorials.
# Quick Installation
You can install the latest stable release of Triton from pip:
```bash
pip install triton
```
Binary wheels are available for CPython 3.6-3.9 and PyPy 3.6-3.7.
And the latest nightly release:
```bash
pip install -U --pre triton
```
# Changelog
Version 1.1 is out! New features include:

60
bin/CMakeLists.txt Normal file
View File

@@ -0,0 +1,60 @@
add_subdirectory(FileCheck)
# add_llvm_executable(FileCheck FileCheck/FileCheck.cpp)
# target_link_libraries(FileCheck PRIVATE LLVMFileCheck LLVMSupport)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED)
# TODO: what's this?
llvm_update_compile_flags(triton-opt)
target_link_libraries(triton-opt PRIVATE
TritonAnalysis
TritonTransforms
TritonGPUTransforms
${dialect_libs}
${conversion_libs}
# tests
TritonTestAnalysis
# MLIR core
MLIROptLib
MLIRPass
MLIRTransforms
)
mlir_check_all_link_libraries(triton-opt)
# add_llvm_executable(triton-translate triton-translate.cpp PARTIAL_SOURCES_INTENDED)
#llvm_update_compile_flags(triton-translate)
# target_link_libraries(triton-translate PRIVATE
# TritonAnalysis
# TritonTransforms
# TritonGPUTransforms
# TritonLLVMIR
# TritonDriver
# ${dialect_libs}
# ${conversion_libs}
# # tests
# TritonTestAnalysis
# LLVMCore
# LLVMSupport
# LLVMOption
# LLVMCodeGen
# LLVMAsmParser
# # MLIR core
# MLIROptLib
# MLIRIR
# MLIRPass
# MLIRSupport
# MLIRTransforms
# MLIRExecutionEngine
# MLIRMathToLLVM
# MLIRTransformUtils
# MLIRLLVMToLLVMIRTranslation
# MLIRNVVMToLLVMIRTranslation
# )
# mlir_check_all_link_libraries(triton-translate)

View File

@@ -0,0 +1,2 @@
add_llvm_executable(FileCheck FileCheck.cpp)
target_link_libraries(FileCheck PRIVATE LLVMFileCheck LLVMSupport)

882
bin/FileCheck/FileCheck.cpp Normal file
View File

@@ -0,0 +1,882 @@
//===- FileCheck.cpp - Check that File's Contents match what is expected --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// FileCheck does a line-by line check of a file that validates whether it
// contains the expected content. This is useful for regression tests etc.
//
// This program exits with an exit status of 2 on error, exit status of 0 if
// the file matched the expected contents, and exit status of 1 if it did not
// contain the expected contents.
//
//===----------------------------------------------------------------------===//
#include "llvm/FileCheck/FileCheck.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Process.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Support/raw_ostream.h"
#include <cmath>
#include <map>
using namespace llvm;
static cl::extrahelp FileCheckOptsEnv(
"\nOptions are parsed from the environment variable FILECHECK_OPTS and\n"
"from the command line.\n");
static cl::opt<std::string>
CheckFilename(cl::Positional, cl::desc("<check-file>"), cl::Optional);
static cl::opt<std::string>
InputFilename("input-file", cl::desc("File to check (defaults to stdin)"),
cl::init("-"), cl::value_desc("filename"));
static cl::list<std::string> CheckPrefixes(
"check-prefix",
cl::desc("Prefix to use from check file (defaults to 'CHECK')"));
static cl::alias CheckPrefixesAlias(
"check-prefixes", cl::aliasopt(CheckPrefixes), cl::CommaSeparated,
cl::NotHidden,
cl::desc(
"Alias for -check-prefix permitting multiple comma separated values"));
static cl::list<std::string> CommentPrefixes(
"comment-prefixes", cl::CommaSeparated, cl::Hidden,
cl::desc("Comma-separated list of comment prefixes to use from check file\n"
"(defaults to 'COM,RUN'). Please avoid using this feature in\n"
"LLVM's LIT-based test suites, which should be easier to\n"
"maintain if they all follow a consistent comment style. This\n"
"feature is meant for non-LIT test suites using FileCheck."));
static cl::opt<bool> NoCanonicalizeWhiteSpace(
"strict-whitespace",
cl::desc("Do not treat all horizontal whitespace as equivalent"));
static cl::opt<bool> IgnoreCase("ignore-case",
cl::desc("Use case-insensitive matching"));
static cl::list<std::string> ImplicitCheckNot(
"implicit-check-not",
cl::desc("Add an implicit negative check with this pattern to every\n"
"positive check. This can be used to ensure that no instances of\n"
"this pattern occur which are not matched by a positive pattern"),
cl::value_desc("pattern"));
static cl::list<std::string>
GlobalDefines("D", cl::AlwaysPrefix,
cl::desc("Define a variable to be used in capture patterns."),
cl::value_desc("VAR=VALUE"));
static cl::opt<bool> AllowEmptyInput(
"allow-empty", cl::init(false),
cl::desc("Allow the input file to be empty. This is useful when making\n"
"checks that some error message does not occur, for example."));
static cl::opt<bool> AllowUnusedPrefixes(
"allow-unused-prefixes", cl::init(false), cl::ZeroOrMore,
cl::desc("Allow prefixes to be specified but not appear in the test."));
static cl::opt<bool> MatchFullLines(
"match-full-lines", cl::init(false),
cl::desc("Require all positive matches to cover an entire input line.\n"
"Allows leading and trailing whitespace if --strict-whitespace\n"
"is not also passed."));
static cl::opt<bool> EnableVarScope(
"enable-var-scope", cl::init(false),
cl::desc("Enables scope for regex variables. Variables with names that\n"
"do not start with '$' will be reset at the beginning of\n"
"each CHECK-LABEL block."));
static cl::opt<bool> AllowDeprecatedDagOverlap(
"allow-deprecated-dag-overlap", cl::init(false),
cl::desc("Enable overlapping among matches in a group of consecutive\n"
"CHECK-DAG directives. This option is deprecated and is only\n"
"provided for convenience as old tests are migrated to the new\n"
"non-overlapping CHECK-DAG implementation.\n"));
static cl::opt<bool> Verbose(
"v", cl::init(false), cl::ZeroOrMore,
cl::desc("Print directive pattern matches, or add them to the input dump\n"
"if enabled.\n"));
static cl::opt<bool> VerboseVerbose(
"vv", cl::init(false), cl::ZeroOrMore,
cl::desc("Print information helpful in diagnosing internal FileCheck\n"
"issues, or add it to the input dump if enabled. Implies\n"
"-v.\n"));
// The order of DumpInputValue members affects their precedence, as documented
// for -dump-input below.
enum DumpInputValue {
DumpInputNever,
DumpInputFail,
DumpInputAlways,
DumpInputHelp
};
static cl::list<DumpInputValue> DumpInputs(
"dump-input",
cl::desc("Dump input to stderr, adding annotations representing\n"
"currently enabled diagnostics. When there are multiple\n"
"occurrences of this option, the <value> that appears earliest\n"
"in the list below has precedence. The default is 'fail'.\n"),
cl::value_desc("mode"),
cl::values(clEnumValN(DumpInputHelp, "help", "Explain input dump and quit"),
clEnumValN(DumpInputAlways, "always", "Always dump input"),
clEnumValN(DumpInputFail, "fail", "Dump input on failure"),
clEnumValN(DumpInputNever, "never", "Never dump input")));
// The order of DumpInputFilterValue members affects their precedence, as
// documented for -dump-input-filter below.
enum DumpInputFilterValue {
DumpInputFilterError,
DumpInputFilterAnnotation,
DumpInputFilterAnnotationFull,
DumpInputFilterAll
};
static cl::list<DumpInputFilterValue> DumpInputFilters(
"dump-input-filter",
cl::desc("In the dump requested by -dump-input, print only input lines of\n"
"kind <value> plus any context specified by -dump-input-context.\n"
"When there are multiple occurrences of this option, the <value>\n"
"that appears earliest in the list below has precedence. The\n"
"default is 'error' when -dump-input=fail, and it's 'all' when\n"
"-dump-input=always.\n"),
cl::values(clEnumValN(DumpInputFilterAll, "all", "All input lines"),
clEnumValN(DumpInputFilterAnnotationFull, "annotation-full",
"Input lines with annotations"),
clEnumValN(DumpInputFilterAnnotation, "annotation",
"Input lines with starting points of annotations"),
clEnumValN(DumpInputFilterError, "error",
"Input lines with starting points of error "
"annotations")));
static cl::list<unsigned> DumpInputContexts(
"dump-input-context", cl::value_desc("N"),
cl::desc("In the dump requested by -dump-input, print <N> input lines\n"
"before and <N> input lines after any lines specified by\n"
"-dump-input-filter. When there are multiple occurrences of\n"
"this option, the largest specified <N> has precedence. The\n"
"default is 5.\n"));
typedef cl::list<std::string>::const_iterator prefix_iterator;
static void DumpCommandLine(int argc, char **argv) {
errs() << "FileCheck command line: ";
for (int I = 0; I < argc; I++)
errs() << " " << argv[I];
errs() << "\n";
}
struct MarkerStyle {
/// The starting char (before tildes) for marking the line.
char Lead;
/// What color to use for this annotation.
raw_ostream::Colors Color;
/// A note to follow the marker, or empty string if none.
std::string Note;
/// Does this marker indicate inclusion by -dump-input-filter=error?
bool FiltersAsError;
MarkerStyle() {}
MarkerStyle(char Lead, raw_ostream::Colors Color,
const std::string &Note = "", bool FiltersAsError = false)
: Lead(Lead), Color(Color), Note(Note), FiltersAsError(FiltersAsError) {
assert((!FiltersAsError || !Note.empty()) &&
"expected error diagnostic to have note");
}
};
static MarkerStyle GetMarker(FileCheckDiag::MatchType MatchTy) {
switch (MatchTy) {
case FileCheckDiag::MatchFoundAndExpected:
return MarkerStyle('^', raw_ostream::GREEN);
case FileCheckDiag::MatchFoundButExcluded:
return MarkerStyle('!', raw_ostream::RED, "error: no match expected",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchFoundButWrongLine:
return MarkerStyle('!', raw_ostream::RED, "error: match on wrong line",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchFoundButDiscarded:
return MarkerStyle('!', raw_ostream::CYAN,
"discard: overlaps earlier match");
case FileCheckDiag::MatchFoundErrorNote:
// Note should always be overridden within the FileCheckDiag.
return MarkerStyle('!', raw_ostream::RED,
"error: unknown error after match",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchNoneAndExcluded:
return MarkerStyle('X', raw_ostream::GREEN);
case FileCheckDiag::MatchNoneButExpected:
return MarkerStyle('X', raw_ostream::RED, "error: no match found",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchNoneForInvalidPattern:
return MarkerStyle('X', raw_ostream::RED,
"error: match failed for invalid pattern",
/*FiltersAsError=*/true);
case FileCheckDiag::MatchFuzzy:
return MarkerStyle('?', raw_ostream::MAGENTA, "possible intended match",
/*FiltersAsError=*/true);
}
llvm_unreachable_internal("unexpected match type");
}
static void DumpInputAnnotationHelp(raw_ostream &OS) {
OS << "The following description was requested by -dump-input=help to\n"
<< "explain the input dump printed by FileCheck.\n"
<< "\n"
<< "Related command-line options:\n"
<< "\n"
<< " - -dump-input=<value> enables or disables the input dump\n"
<< " - -dump-input-filter=<value> filters the input lines\n"
<< " - -dump-input-context=<N> adjusts the context of filtered lines\n"
<< " - -v and -vv add more annotations\n"
<< " - -color forces colors to be enabled both in the dump and below\n"
<< " - -help documents the above options in more detail\n"
<< "\n"
<< "These options can also be set via FILECHECK_OPTS. For example, for\n"
<< "maximum debugging output on failures:\n"
<< "\n"
<< " $ FILECHECK_OPTS='-dump-input-filter=all -vv -color' ninja check\n"
<< "\n"
<< "Input dump annotation format:\n"
<< "\n";
// Labels for input lines.
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "L:";
OS << " labels line number L of the input file\n"
<< " An extra space is added after each input line to represent"
<< " the\n"
<< " newline character\n";
// Labels for annotation lines.
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "T:L";
OS << " labels the only match result for either (1) a pattern of type T"
<< " from\n"
<< " line L of the check file if L is an integer or (2) the"
<< " I-th implicit\n"
<< " pattern if L is \"imp\" followed by an integer "
<< "I (index origin one)\n";
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "T:L'N";
OS << " labels the Nth match result for such a pattern\n";
// Markers on annotation lines.
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "^~~";
OS << " marks good match (reported if -v)\n"
<< " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "!~~";
OS << " marks bad match, such as:\n"
<< " - CHECK-NEXT on same line as previous match (error)\n"
<< " - CHECK-NOT found (error)\n"
<< " - CHECK-DAG overlapping match (discarded, reported if "
<< "-vv)\n"
<< " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "X~~";
OS << " marks search range when no match is found, such as:\n"
<< " - CHECK-NEXT not found (error)\n"
<< " - CHECK-NOT not found (success, reported if -vv)\n"
<< " - CHECK-DAG not found after discarded matches (error)\n"
<< " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "?";
OS << " marks fuzzy match when no match is found\n";
// Elided lines.
OS << " - ";
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "...";
OS << " indicates elided input lines and annotations, as specified by\n"
<< " -dump-input-filter and -dump-input-context\n";
// Colors.
OS << " - colors ";
WithColor(OS, raw_ostream::GREEN, true) << "success";
OS << ", ";
WithColor(OS, raw_ostream::RED, true) << "error";
OS << ", ";
WithColor(OS, raw_ostream::MAGENTA, true) << "fuzzy match";
OS << ", ";
WithColor(OS, raw_ostream::CYAN, true, false) << "discarded match";
OS << ", ";
WithColor(OS, raw_ostream::CYAN, true, true) << "unmatched input";
OS << "\n";
}
/// An annotation for a single input line.
struct InputAnnotation {
/// The index of the match result across all checks
unsigned DiagIndex;
/// The label for this annotation.
std::string Label;
/// Is this the initial fragment of a diagnostic that has been broken across
/// multiple lines?
bool IsFirstLine;
/// What input line (one-origin indexing) this annotation marks. This might
/// be different from the starting line of the original diagnostic if
/// !IsFirstLine.
unsigned InputLine;
/// The column range (one-origin indexing, open end) in which to mark the
/// input line. If InputEndCol is UINT_MAX, treat it as the last column
/// before the newline.
unsigned InputStartCol, InputEndCol;
/// The marker to use.
MarkerStyle Marker;
/// Whether this annotation represents a good match for an expected pattern.
bool FoundAndExpectedMatch;
};
/// Get an abbreviation for the check type.
static std::string GetCheckTypeAbbreviation(Check::FileCheckType Ty) {
switch (Ty) {
case Check::CheckPlain:
if (Ty.getCount() > 1)
return "count";
return "check";
case Check::CheckNext:
return "next";
case Check::CheckSame:
return "same";
case Check::CheckNot:
return "not";
case Check::CheckDAG:
return "dag";
case Check::CheckLabel:
return "label";
case Check::CheckEmpty:
return "empty";
case Check::CheckComment:
return "com";
case Check::CheckEOF:
return "eof";
case Check::CheckBadNot:
return "bad-not";
case Check::CheckBadCount:
return "bad-count";
case Check::CheckNone:
llvm_unreachable("invalid FileCheckType");
}
llvm_unreachable("unknown FileCheckType");
}
static void
BuildInputAnnotations(const SourceMgr &SM, unsigned CheckFileBufferID,
const std::pair<unsigned, unsigned> &ImpPatBufferIDRange,
const std::vector<FileCheckDiag> &Diags,
std::vector<InputAnnotation> &Annotations,
unsigned &LabelWidth) {
struct CompareSMLoc {
bool operator()(const SMLoc &LHS, const SMLoc &RHS) const {
return LHS.getPointer() < RHS.getPointer();
}
};
// How many diagnostics does each pattern have?
std::map<SMLoc, unsigned, CompareSMLoc> DiagCountPerPattern;
for (auto Diag : Diags)
++DiagCountPerPattern[Diag.CheckLoc];
// How many diagnostics have we seen so far per pattern?
std::map<SMLoc, unsigned, CompareSMLoc> DiagIndexPerPattern;
// How many total diagnostics have we seen so far?
unsigned DiagIndex = 0;
// What's the widest label?
LabelWidth = 0;
for (auto DiagItr = Diags.begin(), DiagEnd = Diags.end(); DiagItr != DiagEnd;
++DiagItr) {
InputAnnotation A;
A.DiagIndex = DiagIndex++;
// Build label, which uniquely identifies this check result.
unsigned CheckBufferID = SM.FindBufferContainingLoc(DiagItr->CheckLoc);
auto CheckLineAndCol =
SM.getLineAndColumn(DiagItr->CheckLoc, CheckBufferID);
llvm::raw_string_ostream Label(A.Label);
Label << GetCheckTypeAbbreviation(DiagItr->CheckTy) << ":";
if (CheckBufferID == CheckFileBufferID)
Label << CheckLineAndCol.first;
else if (ImpPatBufferIDRange.first <= CheckBufferID &&
CheckBufferID < ImpPatBufferIDRange.second)
Label << "imp" << (CheckBufferID - ImpPatBufferIDRange.first + 1);
else
llvm_unreachable("expected diagnostic's check location to be either in "
"the check file or for an implicit pattern");
if (DiagCountPerPattern[DiagItr->CheckLoc] > 1)
Label << "'" << DiagIndexPerPattern[DiagItr->CheckLoc]++;
LabelWidth = std::max((std::string::size_type)LabelWidth, A.Label.size());
A.Marker = GetMarker(DiagItr->MatchTy);
if (!DiagItr->Note.empty()) {
A.Marker.Note = DiagItr->Note;
// It's less confusing if notes that don't actually have ranges don't have
// markers. For example, a marker for 'with "VAR" equal to "5"' would
// seem to indicate where "VAR" matches, but the location we actually have
// for the marker simply points to the start of the match/search range for
// the full pattern of which the substitution is potentially just one
// component.
if (DiagItr->InputStartLine == DiagItr->InputEndLine &&
DiagItr->InputStartCol == DiagItr->InputEndCol)
A.Marker.Lead = ' ';
}
if (DiagItr->MatchTy == FileCheckDiag::MatchFoundErrorNote) {
assert(!DiagItr->Note.empty() &&
"expected custom note for MatchFoundErrorNote");
A.Marker.Note = "error: " + A.Marker.Note;
}
A.FoundAndExpectedMatch =
DiagItr->MatchTy == FileCheckDiag::MatchFoundAndExpected;
// Compute the mark location, and break annotation into multiple
// annotations if it spans multiple lines.
A.IsFirstLine = true;
A.InputLine = DiagItr->InputStartLine;
A.InputStartCol = DiagItr->InputStartCol;
if (DiagItr->InputStartLine == DiagItr->InputEndLine) {
// Sometimes ranges are empty in order to indicate a specific point, but
// that would mean nothing would be marked, so adjust the range to
// include the following character.
A.InputEndCol =
std::max(DiagItr->InputStartCol + 1, DiagItr->InputEndCol);
Annotations.push_back(A);
} else {
assert(DiagItr->InputStartLine < DiagItr->InputEndLine &&
"expected input range not to be inverted");
A.InputEndCol = UINT_MAX;
Annotations.push_back(A);
for (unsigned L = DiagItr->InputStartLine + 1, E = DiagItr->InputEndLine;
L <= E; ++L) {
// If a range ends before the first column on a line, then it has no
// characters on that line, so there's nothing to render.
if (DiagItr->InputEndCol == 1 && L == E)
break;
InputAnnotation B;
B.DiagIndex = A.DiagIndex;
B.Label = A.Label;
B.IsFirstLine = false;
B.InputLine = L;
B.Marker = A.Marker;
B.Marker.Lead = '~';
B.Marker.Note = "";
B.InputStartCol = 1;
if (L != E)
B.InputEndCol = UINT_MAX;
else
B.InputEndCol = DiagItr->InputEndCol;
B.FoundAndExpectedMatch = A.FoundAndExpectedMatch;
Annotations.push_back(B);
}
}
}
}
static unsigned FindInputLineInFilter(
DumpInputFilterValue DumpInputFilter, unsigned CurInputLine,
const std::vector<InputAnnotation>::iterator &AnnotationBeg,
const std::vector<InputAnnotation>::iterator &AnnotationEnd) {
if (DumpInputFilter == DumpInputFilterAll)
return CurInputLine;
for (auto AnnotationItr = AnnotationBeg; AnnotationItr != AnnotationEnd;
++AnnotationItr) {
switch (DumpInputFilter) {
case DumpInputFilterAll:
llvm_unreachable("unexpected DumpInputFilterAll");
break;
case DumpInputFilterAnnotationFull:
return AnnotationItr->InputLine;
case DumpInputFilterAnnotation:
if (AnnotationItr->IsFirstLine)
return AnnotationItr->InputLine;
break;
case DumpInputFilterError:
if (AnnotationItr->IsFirstLine && AnnotationItr->Marker.FiltersAsError)
return AnnotationItr->InputLine;
break;
}
}
return UINT_MAX;
}
/// To OS, print a vertical ellipsis (right-justified at LabelWidth) if it would
/// occupy less lines than ElidedLines, but print ElidedLines otherwise. Either
/// way, clear ElidedLines. Thus, if ElidedLines is empty, do nothing.
static void DumpEllipsisOrElidedLines(raw_ostream &OS, std::string &ElidedLines,
unsigned LabelWidth) {
if (ElidedLines.empty())
return;
unsigned EllipsisLines = 3;
if (EllipsisLines < StringRef(ElidedLines).count('\n')) {
for (unsigned i = 0; i < EllipsisLines; ++i) {
WithColor(OS, raw_ostream::BLACK, /*Bold=*/true)
<< right_justify(".", LabelWidth);
OS << '\n';
}
} else
OS << ElidedLines;
ElidedLines.clear();
}
static void DumpAnnotatedInput(raw_ostream &OS, const FileCheckRequest &Req,
DumpInputFilterValue DumpInputFilter,
unsigned DumpInputContext,
StringRef InputFileText,
std::vector<InputAnnotation> &Annotations,
unsigned LabelWidth) {
OS << "Input was:\n<<<<<<\n";
// Sort annotations.
llvm::sort(Annotations,
[](const InputAnnotation &A, const InputAnnotation &B) {
// 1. Sort annotations in the order of the input lines.
//
// This makes it easier to find relevant annotations while
// iterating input lines in the implementation below. FileCheck
// does not always produce diagnostics in the order of input
// lines due to, for example, CHECK-DAG and CHECK-NOT.
if (A.InputLine != B.InputLine)
return A.InputLine < B.InputLine;
// 2. Sort annotations in the temporal order FileCheck produced
// their associated diagnostics.
//
// This sort offers several benefits:
//
// A. On a single input line, the order of annotations reflects
// the FileCheck logic for processing directives/patterns.
// This can be helpful in understanding cases in which the
// order of the associated directives/patterns in the check
// file or on the command line either (i) does not match the
// temporal order in which FileCheck looks for matches for the
// directives/patterns (due to, for example, CHECK-LABEL,
// CHECK-NOT, or `--implicit-check-not`) or (ii) does match
// that order but does not match the order of those
// diagnostics along an input line (due to, for example,
// CHECK-DAG).
//
// On the other hand, because our presentation format presents
// input lines in order, there's no clear way to offer the
// same benefit across input lines. For consistency, it might
// then seem worthwhile to have annotations on a single line
// also sorted in input order (that is, by input column).
// However, in practice, this appears to be more confusing
// than helpful. Perhaps it's intuitive to expect annotations
// to be listed in the temporal order in which they were
// produced except in cases the presentation format obviously
// and inherently cannot support it (that is, across input
// lines).
//
// B. When diagnostics' annotations are split among multiple
// input lines, the user must track them from one input line
// to the next. One property of the sort chosen here is that
// it facilitates the user in this regard by ensuring the
// following: when comparing any two input lines, a
// diagnostic's annotations are sorted in the same position
// relative to all other diagnostics' annotations.
return A.DiagIndex < B.DiagIndex;
});
// Compute the width of the label column.
const unsigned char *InputFilePtr = InputFileText.bytes_begin(),
*InputFileEnd = InputFileText.bytes_end();
unsigned LineCount = InputFileText.count('\n');
if (InputFileEnd[-1] != '\n')
++LineCount;
unsigned LineNoWidth = std::log10(LineCount) + 1;
// +3 below adds spaces (1) to the left of the (right-aligned) line numbers
// on input lines and (2) to the right of the (left-aligned) labels on
// annotation lines so that input lines and annotation lines are more
// visually distinct. For example, the spaces on the annotation lines ensure
// that input line numbers and check directive line numbers never align
// horizontally. Those line numbers might not even be for the same file.
// One space would be enough to achieve that, but more makes it even easier
// to see.
LabelWidth = std::max(LabelWidth, LineNoWidth) + 3;
// Print annotated input lines.
unsigned PrevLineInFilter = 0; // 0 means none so far
unsigned NextLineInFilter = 0; // 0 means uncomputed, UINT_MAX means none
std::string ElidedLines;
raw_string_ostream ElidedLinesOS(ElidedLines);
ColorMode TheColorMode =
WithColor(OS).colorsEnabled() ? ColorMode::Enable : ColorMode::Disable;
if (TheColorMode == ColorMode::Enable)
ElidedLinesOS.enable_colors(true);
auto AnnotationItr = Annotations.begin(), AnnotationEnd = Annotations.end();
for (unsigned Line = 1;
InputFilePtr != InputFileEnd || AnnotationItr != AnnotationEnd; ++Line) {
const unsigned char *InputFileLine = InputFilePtr;
// Compute the previous and next line included by the filter.
if (NextLineInFilter < Line)
NextLineInFilter = FindInputLineInFilter(DumpInputFilter, Line,
AnnotationItr, AnnotationEnd);
assert(NextLineInFilter && "expected NextLineInFilter to be computed");
if (NextLineInFilter == Line)
PrevLineInFilter = Line;
// Elide this input line and its annotations if it's not within the
// context specified by -dump-input-context of an input line included by
// -dump-input-filter. However, in case the resulting ellipsis would occupy
// more lines than the input lines and annotations it elides, buffer the
// elided lines and annotations so we can print them instead.
raw_ostream *LineOS = &OS;
if ((!PrevLineInFilter || PrevLineInFilter + DumpInputContext < Line) &&
(NextLineInFilter == UINT_MAX ||
Line + DumpInputContext < NextLineInFilter))
LineOS = &ElidedLinesOS;
else {
LineOS = &OS;
DumpEllipsisOrElidedLines(OS, ElidedLinesOS.str(), LabelWidth);
}
// Print right-aligned line number.
WithColor(*LineOS, raw_ostream::BLACK, /*Bold=*/true, /*BF=*/false,
TheColorMode)
<< format_decimal(Line, LabelWidth) << ": ";
// For the case where -v and colors are enabled, find the annotations for
// good matches for expected patterns in order to highlight everything
// else in the line. There are no such annotations if -v is disabled.
std::vector<InputAnnotation> FoundAndExpectedMatches;
if (Req.Verbose && TheColorMode == ColorMode::Enable) {
for (auto I = AnnotationItr; I != AnnotationEnd && I->InputLine == Line;
++I) {
if (I->FoundAndExpectedMatch)
FoundAndExpectedMatches.push_back(*I);
}
}
// Print numbered line with highlighting where there are no matches for
// expected patterns.
bool Newline = false;
{
WithColor COS(*LineOS, raw_ostream::SAVEDCOLOR, /*Bold=*/false,
/*BG=*/false, TheColorMode);
bool InMatch = false;
if (Req.Verbose)
COS.changeColor(raw_ostream::CYAN, true, true);
for (unsigned Col = 1; InputFilePtr != InputFileEnd && !Newline; ++Col) {
bool WasInMatch = InMatch;
InMatch = false;
for (auto M : FoundAndExpectedMatches) {
if (M.InputStartCol <= Col && Col < M.InputEndCol) {
InMatch = true;
break;
}
}
if (!WasInMatch && InMatch)
COS.resetColor();
else if (WasInMatch && !InMatch)
COS.changeColor(raw_ostream::CYAN, true, true);
if (*InputFilePtr == '\n') {
Newline = true;
COS << ' ';
} else
COS << *InputFilePtr;
++InputFilePtr;
}
}
*LineOS << '\n';
unsigned InputLineWidth = InputFilePtr - InputFileLine;
// Print any annotations.
while (AnnotationItr != AnnotationEnd && AnnotationItr->InputLine == Line) {
WithColor COS(*LineOS, AnnotationItr->Marker.Color, /*Bold=*/true,
/*BG=*/false, TheColorMode);
// The two spaces below are where the ": " appears on input lines.
COS << left_justify(AnnotationItr->Label, LabelWidth) << " ";
unsigned Col;
for (Col = 1; Col < AnnotationItr->InputStartCol; ++Col)
COS << ' ';
COS << AnnotationItr->Marker.Lead;
// If InputEndCol=UINT_MAX, stop at InputLineWidth.
for (++Col; Col < AnnotationItr->InputEndCol && Col <= InputLineWidth;
++Col)
COS << '~';
const std::string &Note = AnnotationItr->Marker.Note;
if (!Note.empty()) {
// Put the note at the end of the input line. If we were to instead
// put the note right after the marker, subsequent annotations for the
// same input line might appear to mark this note instead of the input
// line.
for (; Col <= InputLineWidth; ++Col)
COS << ' ';
COS << ' ' << Note;
}
COS << '\n';
++AnnotationItr;
}
}
DumpEllipsisOrElidedLines(OS, ElidedLinesOS.str(), LabelWidth);
OS << ">>>>>>\n";
}
int main(int argc, char **argv) {
// Enable use of ANSI color codes because FileCheck is using them to
// highlight text.
llvm::sys::Process::UseANSIEscapeCodes(true);
InitLLVM X(argc, argv);
cl::ParseCommandLineOptions(argc, argv, /*Overview*/ "", /*Errs*/ nullptr,
"FILECHECK_OPTS");
// Select -dump-input* values. The -help documentation specifies the default
// value and which value to choose if an option is specified multiple times.
// In the latter case, the general rule of thumb is to choose the value that
// provides the most information.
DumpInputValue DumpInput =
DumpInputs.empty()
? DumpInputFail
: *std::max_element(DumpInputs.begin(), DumpInputs.end());
DumpInputFilterValue DumpInputFilter;
if (DumpInputFilters.empty())
DumpInputFilter = DumpInput == DumpInputAlways ? DumpInputFilterAll
: DumpInputFilterError;
else
DumpInputFilter =
*std::max_element(DumpInputFilters.begin(), DumpInputFilters.end());
unsigned DumpInputContext = DumpInputContexts.empty()
? 5
: *std::max_element(DumpInputContexts.begin(),
DumpInputContexts.end());
if (DumpInput == DumpInputHelp) {
DumpInputAnnotationHelp(outs());
return 0;
}
if (CheckFilename.empty()) {
errs() << "<check-file> not specified\n";
return 2;
}
FileCheckRequest Req;
append_range(Req.CheckPrefixes, CheckPrefixes);
append_range(Req.CommentPrefixes, CommentPrefixes);
append_range(Req.ImplicitCheckNot, ImplicitCheckNot);
bool GlobalDefineError = false;
for (StringRef G : GlobalDefines) {
size_t EqIdx = G.find('=');
if (EqIdx == std::string::npos) {
errs() << "Missing equal sign in command-line definition '-D" << G
<< "'\n";
GlobalDefineError = true;
continue;
}
if (EqIdx == 0) {
errs() << "Missing variable name in command-line definition '-D" << G
<< "'\n";
GlobalDefineError = true;
continue;
}
Req.GlobalDefines.push_back(G);
}
if (GlobalDefineError)
return 2;
Req.AllowEmptyInput = AllowEmptyInput;
Req.AllowUnusedPrefixes = AllowUnusedPrefixes;
Req.EnableVarScope = EnableVarScope;
Req.AllowDeprecatedDagOverlap = AllowDeprecatedDagOverlap;
Req.Verbose = Verbose;
Req.VerboseVerbose = VerboseVerbose;
Req.NoCanonicalizeWhiteSpace = NoCanonicalizeWhiteSpace;
Req.MatchFullLines = MatchFullLines;
Req.IgnoreCase = IgnoreCase;
if (VerboseVerbose)
Req.Verbose = true;
FileCheck FC(Req);
if (!FC.ValidateCheckPrefixes())
return 2;
Regex PrefixRE = FC.buildCheckPrefixRegex();
std::string REError;
if (!PrefixRE.isValid(REError)) {
errs() << "Unable to combine check-prefix strings into a prefix regular "
"expression! This is likely a bug in FileCheck's verification of "
"the check-prefix strings. Regular expression parsing failed "
"with the following error: "
<< REError << "\n";
return 2;
}
SourceMgr SM;
// Read the expected strings from the check file.
ErrorOr<std::unique_ptr<MemoryBuffer>> CheckFileOrErr =
MemoryBuffer::getFileOrSTDIN(CheckFilename, /*IsText=*/true);
if (std::error_code EC = CheckFileOrErr.getError()) {
errs() << "Could not open check file '" << CheckFilename
<< "': " << EC.message() << '\n';
return 2;
}
MemoryBuffer &CheckFile = *CheckFileOrErr.get();
SmallString<4096> CheckFileBuffer;
StringRef CheckFileText = FC.CanonicalizeFile(CheckFile, CheckFileBuffer);
unsigned CheckFileBufferID =
SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(
CheckFileText, CheckFile.getBufferIdentifier()),
SMLoc());
std::pair<unsigned, unsigned> ImpPatBufferIDRange;
if (FC.readCheckFile(SM, CheckFileText, PrefixRE, &ImpPatBufferIDRange))
return 2;
// Open the file to check and add it to SourceMgr.
ErrorOr<std::unique_ptr<MemoryBuffer>> InputFileOrErr =
MemoryBuffer::getFileOrSTDIN(InputFilename, /*IsText=*/true);
if (InputFilename == "-")
InputFilename = "<stdin>"; // Overwrite for improved diagnostic messages
if (std::error_code EC = InputFileOrErr.getError()) {
errs() << "Could not open input file '" << InputFilename
<< "': " << EC.message() << '\n';
return 2;
}
MemoryBuffer &InputFile = *InputFileOrErr.get();
if (InputFile.getBufferSize() == 0 && !AllowEmptyInput) {
errs() << "FileCheck error: '" << InputFilename << "' is empty.\n";
DumpCommandLine(argc, argv);
return 2;
}
SmallString<4096> InputFileBuffer;
StringRef InputFileText = FC.CanonicalizeFile(InputFile, InputFileBuffer);
SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(
InputFileText, InputFile.getBufferIdentifier()),
SMLoc());
std::vector<FileCheckDiag> Diags;
int ExitCode = FC.checkInput(SM, InputFileText,
DumpInput == DumpInputNever ? nullptr : &Diags)
? EXIT_SUCCESS
: 1;
if (DumpInput == DumpInputAlways ||
(ExitCode == 1 && DumpInput == DumpInputFail)) {
errs() << "\n"
<< "Input file: " << InputFilename << "\n"
<< "Check file: " << CheckFilename << "\n"
<< "\n"
<< "-dump-input=help explains the following input dump.\n"
<< "\n";
std::vector<InputAnnotation> Annotations;
unsigned LabelWidth;
BuildInputAnnotations(SM, CheckFileBufferID, ImpPatBufferIDRange, Diags,
Annotations, LabelWidth);
DumpAnnotatedInput(errs(), Req, DumpInputFilter, DumpInputContext,
InputFileText, Annotations, LabelWidth);
}
return ExitCode;
}

42
bin/triton-opt.cpp Normal file
View File

@@ -0,0 +1,42 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Conversion/Passes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Support/MlirOptMain.h"
namespace mlir {
namespace test {
void registerTestAliasPass();
void registerTestAlignmentPass();
void registerTestAllocationPass();
void registerTestMembarPass();
} // namespace test
} // namespace mlir
int main(int argc, char **argv) {
mlir::registerAllPasses();
mlir::registerTritonPasses();
mlir::registerTritonGPUPasses();
mlir::test::registerTestAliasPass();
mlir::test::registerTestAlignmentPass();
mlir::test::registerTestAllocationPass();
mlir::test::registerTestMembarPass();
mlir::triton::registerConvertTritonToTritonGPUPass();
mlir::triton::registerConvertTritonGPUToLLVMPass();
// TODO: register Triton & TritonGPU passes
mlir::DialectRegistry registry;
registry.insert<mlir::triton::TritonDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect,
mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();
return mlir::asMainReturnCode(mlir::MlirOptMain(
argc, argv, "Triton (GPU) optimizer driver\n", registry));
}

130
bin/triton-translate.cpp Normal file
View File

@@ -0,0 +1,130 @@
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/driver/llvm.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include <iostream>
namespace mlir {
namespace triton {
OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
MLIRContext &context) {
std::string errorMessage;
auto input = openInputFile(inputFilename, &errorMessage);
if (!input) {
llvm::errs() << errorMessage << "\n";
return nullptr;
}
mlir::DialectRegistry registry;
registry.insert<TritonDialect, triton::gpu::TritonGPUDialect,
mlir::math::MathDialect, arith::ArithmeticDialect,
StandardOpsDialect, scf::SCFDialect>();
context.appendDialectRegistry(registry);
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer)
-> OwningOpRef<ModuleOp> {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
context.loadAllAvailableDialects();
context.allowUnregisteredDialects();
OwningOpRef<ModuleOp> module(parseSourceFile(sourceMgr, &context));
if (!module) {
llvm::errs() << "Parse MLIR file failed.";
return nullptr;
}
return module;
};
auto module = processBuffer(std::move(input));
if (!module) {
return nullptr;
}
return module;
}
LogicalResult tritonTranslateMain(int argc, char **argv,
llvm::StringRef toolName) {
static llvm::cl::opt<std::string> inputFilename(
llvm::cl::Positional, llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
static llvm::cl::opt<std::string> outputFilename(
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
static llvm::cl::opt<std::string> targetKind(
"target", llvm::cl::desc("<translation target, options: llvmir/ptx>"),
llvm::cl::value_desc("target"), llvm::cl::init("llvmir"));
static llvm::cl::opt<int> SMArch("sm", llvm::cl::desc("sm arch"),
llvm::cl::init(80));
static llvm::cl::opt<int> ptxVersion(
"ptx-version", llvm::cl::desc("PTX version"), llvm::cl::init(10000));
llvm::InitLLVM y(argc, argv);
registerAsmPrinterCLOptions();
registerMLIRContextCLOptions();
llvm::cl::ParseCommandLineOptions(argc, argv, toolName);
mlir::MLIRContext context;
auto module = loadMLIRModule(inputFilename, context);
if (!module) {
return failure();
}
std::string errorMessage;
auto output = openOutputFile(outputFilename, &errorMessage);
if (!output) {
llvm::errs() << errorMessage << "\n";
return failure();
}
llvm::LLVMContext llvmContext;
auto llvmir = translateTritonGPUToLLVMIR(&llvmContext, *module);
if (!llvmir) {
llvm::errs() << "Translate to LLVM IR failed";
}
if (targetKind == "llvmir")
llvm::outs() << *llvmir << '\n';
else if (targetKind == "ptx")
llvm::outs() << ::triton::driver::llir_to_ptx(
llvmir.get(), SMArch.getValue(), ptxVersion.getValue());
return success();
}
} // namespace triton
} // namespace mlir
int main(int argc, char **argv) {
return failed(mlir::triton::tritonTranslateMain(
argc, argv, "Triton Translate Testing Tool."));
}

View File

@@ -25,7 +25,7 @@
# LLVM_VERSION_STRING - Full LLVM version string (e.g. 6.0.0svn).
# LLVM_VERSION_BASE_STRING - Base LLVM version string without git/svn suffix (e.g. 6.0.0).
#
# Note: The variable names were chosen in conformance with the offical CMake
# Note: The variable names were chosen in conformance with the official CMake
# guidelines, see ${CMAKE_ROOT}/Modules/readme.txt.
# Try suffixed versions to pick up the newest LLVM install available on Debian
@@ -196,4 +196,4 @@ include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(LLVM
REQUIRED_VARS LLVM_ROOT_DIR
VERSION_VAR LLVM_VERSION_STRING)
VERSION_VAR LLVM_VERSION_STRING)

27
docs/_templates/versions.html vendored Normal file
View File

@@ -0,0 +1,27 @@
{%- if current_version %}
<div class="rst-versions" data-toggle="rst-versions" role="note" aria-label="versions">
<span class="rst-current-version" data-toggle="rst-current-version">
<span class="fa fa-book"> Other Versions</span>
v: {{ current_version.name }}
<span class="fa fa-caret-down"></span>
</span>
<div class="rst-other-versions">
{%- if versions.tags %}
<dl>
<dt>Tags</dt>
{%- for item in versions.tags %}
<dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
{%- endfor %}
</dl>
{%- endif %}
{%- if versions.branches %}
<dl>
<dt>Branches</dt>
{%- for item in versions.branches %}
<dd><a href="{{ item.url }}">{{ item.name }}</a></dd>
{%- endfor %}
</dl>
{%- endif %}
</div>
</div>
{%- endif %}

View File

@@ -34,14 +34,17 @@ def process_sig(app, what, name, obj, options, signature, return_annotation):
def setup(app):
"""Customize function args retrieving to get args under decorator."""
import sphinx
import triton
import os
app.connect("autodoc-process-signature", process_sig)
os.system("pip install -e ../python")
def forward_jit_fn(func):
old = func
def wrapped(obj, **kwargs):
import triton
if isinstance(obj, triton.code_gen.JITFunction):
obj = obj.fn
return old(obj)
@@ -52,6 +55,7 @@ def setup(app):
old_documenter = sphinx.ext.autosummary.get_documenter
def documenter(app, obj, parent):
import triton
if isinstance(obj, triton.code_gen.JITFunction):
obj = obj.fn
return old_documenter(app, obj, parent)
@@ -66,9 +70,17 @@ def setup(app):
import sys
import os
sys.path.insert(0, os.path.abspath('../python/'))
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon']
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon', 'sphinx_multiversion']
autosummary_generate = True
# versioning config
smv_tag_whitelist = r'^(v1.1.2)$'
smv_branch_whitelist = r'^master$'
smv_remote_whitelist = None
smv_released_pattern = r'^tags/.*$'
smv_outputdir_format = '{ref.name}'
smv_prefer_remote_refs = False
# Sphinx gallery
extensions += ['sphinx_gallery.gen_gallery']
from sphinx_gallery.sorting import FileNameSortKey
@@ -85,6 +97,11 @@ sphinx_gallery_conf = {
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
html_sidebars = {
'**': [
'_templates/versions.html',
],
}
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:

View File

@@ -44,7 +44,7 @@ You can then test your installation by running the unit tests:
.. code-block:: bash
pip install -r requirements-test.txt
pip install -e '.[tests]'
pytest -vs test/unit/
and the benchmarks

View File

@@ -14,7 +14,7 @@ Traditional compilers typically rely on intermediate representations, such as LL
Program Representation
+++++++++++++++++++++++
Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample litterature on linear and integer programming.
Polyhedral compilation is a vast area of research. In this section we only outline the most basic aspects of this topic, but readers interested in the solid mathematical foundations underneath may refer to the ample literature on linear and integer programming.
.. table::
:widths: 50 50

1
include/CMakeLists.txt Normal file
View File

@@ -0,0 +1 @@
add_subdirectory(triton)

View File

@@ -0,0 +1,80 @@
#ifndef TRITON_ANALYSIS_ALIAS_H
#define TRITON_ANALYSIS_ALIAS_H
#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "llvm/ADT/DenseSet.h"
namespace mlir {
class AliasInfo {
public:
AliasInfo() = default;
AliasInfo(Value value) { insert(value); }
void insert(Value value) { allocs.insert(value); }
const DenseSet<Value> &getAllocs() const { return allocs; }
bool operator==(const AliasInfo &other) const {
return allocs == other.allocs;
}
/// The pessimistic value state of a value without alias
static AliasInfo getPessimisticValueState(MLIRContext *context) {
return AliasInfo();
}
static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); }
/// The union of both arguments
static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs);
private:
/// The set of allocated values that are aliased by this lattice.
/// For now, we only consider aliased value produced by the following
/// situations:
/// 1. values returned by scf.yield
/// 2. block arguments in scf.for
/// Example:
/// alloc v1 alloc v2
/// | |
/// |--------------| |------------|
/// scf.for v3 scf.for v4 scf.for v5
/// |
/// scf.yield v6
///
/// v1's alloc [v1]
/// v2's alloc [v2]
/// v3's alloc [v1]
/// v4's alloc [v1, v2]
/// v5's alloc [v2]
/// v6's alloc [v1]
///
/// Therefore, v1's liveness range is the union of v3, v4, and v6
/// v2's liveness range is the union of v4 and v5.
DenseSet<Value> allocs;
};
//===----------------------------------------------------------------------===//
// Shared Memory Alias Analysis
//===----------------------------------------------------------------------===//
class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis<AliasInfo> {
public:
using ForwardDataFlowAnalysis<AliasInfo>::ForwardDataFlowAnalysis;
/// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use.
/// Given two values, returns their aliasing behavior.
AliasResult alias(Value lhs, Value rhs);
/// Returns the modify-reference behavior of `op` on `location`.
ModRefResult getModRef(Operation *op, Value location);
/// Computes if the alloc set of the results are changed.
ChangeResult
visitOperation(Operation *op,
ArrayRef<LatticeElement<AliasInfo> *> operands) override;
};
} // namespace mlir
#endif // TRITON_ANALYSIS_ALIAS_H

View File

@@ -0,0 +1,194 @@
#ifndef TRITON_ANALYSIS_ALLOCATION_H
#define TRITON_ANALYSIS_ALLOCATION_H
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/raw_ostream.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <atomic>
#include <limits>
namespace mlir {
namespace triton {
class AllocationAnalysis;
SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec);
SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op);
} // namespace triton
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
/// A class that represents an interval, specified using a start and an end
/// values: [Start, End).
template <typename T> class Interval {
public:
Interval() {}
Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); }
T start() const { return Start; }
T end() const { return End; }
T size() const { return End - Start; }
bool contains(T Addr) const { return Start <= Addr && Addr < End; }
bool intersects(const Interval &R) const {
return Start < R.End && R.Start < End;
}
bool operator==(const Interval &R) const {
return Start == R.Start && End == R.End;
}
bool operator!=(const Interval &R) const { return !(*this == R); }
bool operator<(const Interval &R) const {
return std::make_pair(Start, End) < std::make_pair(R.Start, R.End);
}
private:
T Start = std::numeric_limits<T>::min();
T End = std::numeric_limits<T>::max();
};
class Allocation {
public:
/// A unique identifier for shared memory buffers
using BufferId = size_t;
using BufferIdSetT = DenseSet<BufferId>;
static constexpr BufferId InvalidBufferId =
std::numeric_limits<BufferId>::max();
/// Creates a new Allocation analysis that computes the shared memory
/// information for all associated shared memory values.
Allocation(Operation *operation) : operation(operation) { run(); }
/// Returns the operation this analysis was constructed from.
Operation *getOperation() const { return operation; }
/// Returns the offset of the given buffer in the shared memory.
size_t getOffset(BufferId bufferId) const {
return bufferSet.lookup(bufferId).offset;
}
/// Returns the size of the given buffer in the shared memory.
size_t getAllocatedSize(BufferId bufferId) const {
return bufferSet.lookup(bufferId).size;
}
/// Returns the buffer id of the given value.
/// This interface only returns the allocated buffer id.
/// If you want to get all the buffer ids that are associated with the given
/// value, including alias buffers, use getBufferIds.
BufferId getBufferId(Value value) const {
if (valueBuffer.count(value)) {
return valueBuffer.lookup(value)->id;
} else {
return InvalidBufferId;
}
}
/// Returns all the buffer ids of the given value, including alias buffers.
BufferIdSetT getBufferIds(Value value) const {
BufferIdSetT bufferIds;
auto allocBufferId = getBufferId(value);
if (allocBufferId != InvalidBufferId)
bufferIds.insert(allocBufferId);
for (auto *buffer : aliasBuffer.lookup(value)) {
if (buffer->id != InvalidBufferId)
bufferIds.insert(buffer->id);
}
return bufferIds;
}
/// Returns the scratch buffer id of the given value.
BufferId getBufferId(Operation *operation) const {
if (opScratch.count(operation)) {
return opScratch.lookup(operation)->id;
} else {
return InvalidBufferId;
}
}
/// Returns the size of total shared memory allocated
size_t getSharedMemorySize() const { return sharedMemorySize; }
bool isIntersected(BufferId lhsId, BufferId rhsId) const {
if (lhsId == InvalidBufferId || rhsId == InvalidBufferId)
return false;
auto lhsBuffer = bufferSet.lookup(lhsId);
auto rhsBuffer = bufferSet.lookup(rhsId);
return lhsBuffer.intersects(rhsBuffer);
}
private:
/// A class that represents a shared memory buffer
struct BufferT {
enum class BufferKind { Explicit, Scratch };
/// MT: thread-safe
inline static std::atomic<BufferId> nextId = 0;
BufferKind kind;
BufferId id;
size_t size;
size_t offset;
bool operator==(const BufferT &other) const { return id == other.id; }
bool operator<(const BufferT &other) const { return id < other.id; }
BufferT() : BufferT(BufferKind::Explicit) {}
BufferT(BufferKind kind) : BufferT(kind, 0, 0) {}
BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
BufferT(BufferKind kind, size_t size, size_t offset)
: kind(kind), id(nextId++), size(size), offset(offset) {}
bool intersects(const BufferT &other) const {
return Interval<size_t>(offset, offset + size)
.intersects(
Interval<size_t>(other.offset, other.offset + other.size));
}
};
/// Op -> Scratch Buffer
using OpScratchMapT = DenseMap<Operation *, BufferT *>;
/// Value -> Explicit Buffer
using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
/// Value -> Alias Buffer
using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
/// BufferId -> Buffer
using BufferSetT = DenseMap<BufferId, BufferT>;
/// Runs allocation analysis on the given top-level operation.
void run();
private:
template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
void addBuffer(KeyType &key, Args &&...args) {
auto buffer = BufferT(Kind, std::forward<Args>(args)...);
bufferSet[buffer.id] = std::move(buffer);
if constexpr (Kind == BufferT::BufferKind::Explicit) {
valueBuffer[key] = &bufferSet[buffer.id];
} else {
opScratch[key] = &bufferSet[buffer.id];
}
}
void addAlias(Value value, Value alloc) {
aliasBuffer[value].insert(valueBuffer[alloc]);
}
private:
Operation *operation;
OpScratchMapT opScratch;
ValueBufferMapT valueBuffer;
AliasBufferMapT aliasBuffer;
BufferSetT bufferSet;
size_t sharedMemorySize = 0;
friend class triton::AllocationAnalysis;
};
} // namespace mlir
#endif // TRITON_ANALYSIS_ALLOCATION_H

View File

@@ -0,0 +1,138 @@
#ifndef TRITON_ANALYSIS_AXISINFO_H
#define TRITON_ANALYSIS_AXISINFO_H
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
namespace mlir {
//===----------------------------------------------------------------------===//
// AxisInfo
//===----------------------------------------------------------------------===//
/// This lattice value represents known information on the axes of a lattice.
/// Axis information is represented by a std::map<int, int>
class AxisInfo {
public:
typedef SmallVector<int, 4> DimVectorT;
public:
// Default constructor
AxisInfo() : AxisInfo({}, {}, {}) {}
// Construct contiguity info with known contiguity
AxisInfo(DimVectorT knownContiguity, DimVectorT knownDivisibility,
DimVectorT knownConstancy)
: contiguity(knownContiguity), divisibility(knownDivisibility),
constancy(knownConstancy), rank(contiguity.size()) {
assert(knownDivisibility.size() == (size_t)rank);
assert(knownConstancy.size() == (size_t)rank);
}
// Accessors
int getContiguity(size_t d) const { return contiguity[d]; }
const DimVectorT &getContiguity() const { return contiguity; }
int getDivisibility(size_t d) const { return divisibility[d]; }
const DimVectorT &getDivisibility() const { return divisibility; }
int getConstancy(size_t d) const { return constancy[d]; }
const DimVectorT &getConstancy() const { return constancy; }
int getRank() const { return rank; }
// Comparison
bool operator==(const AxisInfo &other) const {
return (contiguity == other.contiguity) &&
(divisibility == other.divisibility) &&
(constancy == other.constancy);
}
/// The pessimistic value state of the contiguity is unknown.
static AxisInfo getPessimisticValueState(MLIRContext *context) {
return AxisInfo();
}
static AxisInfo getPessimisticValueState(Value value);
// The gcd of both arguments for each dimension
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
private:
/// The _contiguity_ information maps the `d`-th
/// dimension to the length of the shortest
/// sequence of contiguous integers along it
/// For example:
/// [10, 11, 12, 13, 18, 19, 20, 21]
/// [20, 21, 22, 23, 28, 29, 30, 31]
/// Would have contiguity [1, 4].
/// and
/// [12, 16, 20, 24]
/// [13, 17, 21, 25]
/// [14, 18, 22, 26]
/// [15, 19, 23, 27]
/// [18, 22, 26, 30]
/// [19, 23, 27, 31]
/// Would have contiguity [2, 1].
DimVectorT contiguity;
/// The _divisibility_ information maps the `d`-th
/// dimension to the largest power-of-two that
/// divides the first element of all the values along it
/// For example:
/// [10, 11, 12, 13, 18, 19, 20, 21]
/// [20, 21, 22, 23, 28, 29, 30, 31]
// would have divisibility [1, 2]
// and
/// [12, 16, 20, 24]
/// [13, 17, 21, 25]
/// [14, 18, 22, 26]
/// [15, 19, 23, 27]
// would have divisibility [4, 1]
DimVectorT divisibility;
/// The _constancy_ information maps the `d`-th
/// dimension to the length of the shortest
/// sequence of constant integer along it. This is
/// particularly useful to infer the contiguity
/// of operations (e.g., add) involving a constant
/// For example
/// [8, 8, 8, 8, 12, 12, 12, 12]
/// [16, 16, 16, 16, 20, 20, 20, 20]
/// would have constancy [1, 4]
DimVectorT constancy;
// number of dimensions of the lattice
int rank;
};
class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
private:
static const int maxPow2Divisor = 65536;
int highestPowOf2Divisor(int n) {
if (n == 0)
return maxPow2Divisor;
return (n & (~(n - 1)));
}
AxisInfo visitBinaryOp(
Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy);
public:
using ForwardDataFlowAnalysis<AxisInfo>::ForwardDataFlowAnalysis;
ChangeResult
visitOperation(Operation *op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
};
} // namespace mlir
#endif

View File

@@ -0,0 +1,115 @@
#ifndef TRITON_ANALYSIS_MEMBAR_H
#define TRITON_ANALYSIS_MEMBAR_H
#include "Allocation.h"
#include "llvm/ADT/SmallPtrSet.h"
namespace mlir {
class OpBuilder;
//===----------------------------------------------------------------------===//
// Shared Memory Barrier Analysis
//===----------------------------------------------------------------------===//
class MembarAnalysis {
public:
/// Creates a new Membar analysis that generates the shared memory barrier
/// in the following circumstances:
/// - RAW: If a shared memory write is followed by a shared memory read, and
/// their addresses are intersected, a barrier is inserted.
/// - WAR: If a shared memory read is followed by a shared memory read, and
/// their addresses are intersected, a barrier is inserted.
/// The following circumstances do not require a barrier:
/// - WAW: not possible because overlapped memory allocation is not allowed.
/// - RAR: no write is performed.
/// Temporary storage of operations such as Reduce are considered as both
/// a shared memory read. If the temporary storage is written but not read,
/// it is considered as the problem of the operation itself but not the membar
/// analysis.
/// The following circumstances are not considered yet:
/// - Double buffers
/// - N buffers
MembarAnalysis(Allocation *allocation) : allocation(allocation) { run(); }
private:
struct RegionInfo {
using BufferIdSetT = Allocation::BufferIdSetT;
BufferIdSetT syncReadBuffers;
BufferIdSetT syncWriteBuffers;
RegionInfo() = default;
RegionInfo(const BufferIdSetT &syncReadBuffers,
const BufferIdSetT &syncWriteBuffers)
: syncReadBuffers(syncReadBuffers), syncWriteBuffers(syncWriteBuffers) {
}
/// Unions two RegionInfo objects.
void join(const RegionInfo &other) {
syncReadBuffers.insert(other.syncReadBuffers.begin(),
other.syncReadBuffers.end());
syncWriteBuffers.insert(other.syncWriteBuffers.begin(),
other.syncWriteBuffers.end());
}
/// Returns true if buffers in two RegionInfo objects are intersected.
bool isIntersected(const RegionInfo &other, Allocation *allocation) const {
return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers,
allocation) ||
/*WAR*/ isIntersected(syncReadBuffers, other.syncWriteBuffers,
allocation);
}
/// Clears the buffers because a barrier is inserted.
void sync() {
syncReadBuffers.clear();
syncWriteBuffers.clear();
}
private:
/// Returns true if buffers in two sets are intersected.
bool isIntersected(const BufferIdSetT &lhs, const BufferIdSetT &rhs,
Allocation *allocation) const {
return std::any_of(lhs.begin(), lhs.end(), [&](auto lhsId) {
return std::any_of(rhs.begin(), rhs.end(), [&](auto rhsId) {
return allocation->isIntersected(lhsId, rhsId);
});
});
}
};
/// Runs the membar analysis to the given operation, inserts a barrier if
/// necessary.
void run();
/// Applies the barrier analysis based on the SCF dialect, in which each
/// region has a single basic block only.
/// Example:
/// region1
/// op1
/// op2 (scf.if)
/// region2
/// op3
/// op4
/// region3
/// op5
/// op6
/// op7
/// region2 and region3 started with the information of region1.
/// Each region is analyzed separately and keeps their own copy of the
/// information. At op7, we union the information of the region2 and region3
/// and update the information of region1.
void dfsOperation(Operation *operation, RegionInfo *blockInfo,
OpBuilder *builder);
/// Updates the RegionInfo operation based on the operation.
void transfer(Operation *operation, RegionInfo *blockInfo,
OpBuilder *builder);
private:
Allocation *allocation;
};
} // namespace mlir
#endif // TRITON_ANALYSIS_MEMBAR_H

View File

@@ -0,0 +1,37 @@
#ifndef TRITON_ANALYSIS_UTILITY_H
#define TRITON_ANALYSIS_UTILITY_H
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
#include <numeric>
#include <string>
namespace mlir {
bool isSharedEncoding(Value value);
bool maybeSharedAllocationOp(Operation *op);
std::string getValueOperandName(Value value, AsmState &state);
template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}
template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }
// output[i] = input[order[i]]
template <typename T>
SmallVector<T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
size_t rank = order.size();
assert(input.size() == rank);
SmallVector<T> result(rank);
for (auto it : llvm::enumerate(order)) {
result[it.index()] = input[it.value()];
}
return result;
}
} // namespace mlir
#endif // TRITON_ANALYSIS_UTILITY_H

View File

@@ -0,0 +1,2 @@
add_subdirectory(Conversion)
add_subdirectory(Dialect)

View File

@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(TritonConversionPassIncGen)

View File

@@ -0,0 +1,38 @@
#ifndef TRITON_CONVERSION_MLIR_TYPES_H_
#define TRITON_CONVERSION_MLIR_TYPES_H_
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
// This file redefines some common MLIR types for easy usage.
namespace mlir {
namespace triton {
namespace type {
// Integer types
Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); }
Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); }
Type u32Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 32, IntegerType::Unsigned);
}
Type u1Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 1, IntegerType::Unsigned);
}
// Float types
Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
static bool isFloat(Type type) {
return type.isF32() || type.isF64() || type.isF16() || type.isF128();
}
static bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
} // namespace type
} // namespace triton
} // namespace mlir
#endif // TRITON_CONVERSION_MLIR_TYPES_H_

View File

@@ -0,0 +1,17 @@
#ifndef TRITON_CONVERSION_PASSES_H
#define TRITON_CONVERSION_PASSES_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
namespace mlir {
namespace triton {
#define GEN_PASS_REGISTRATION
#include "triton/Conversion/Passes.h.inc"
} // namespace triton
} // namespace mlir
#endif

View File

@@ -0,0 +1,47 @@
#ifndef TRITON_CONVERSION_PASSES
#define TRITON_CONVERSION_PASSES
include "mlir/Pass/PassBase.td"
def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> {
let summary = "Convert Triton to TritonGPU";
let description = [{
}];
let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()";
let dependentDialects = ["mlir::arith::ArithmeticDialect",
"mlir::math::MathDialect",
"mlir::StandardOpsDialect",
// TODO: Does this pass depend on SCF?
"mlir::scf::SCFDialect",
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect"];
let options = [
Option<"numWarps", "num-warps",
"int32_t", /*default*/"4",
"number of warps">
];
}
def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert TritonGPU to LLVM";
let description = [{
}];
let constructor = "mlir::triton::createConvertTritonGPUToLLVMPass()";
let dependentDialects = ["mlir::arith::ArithmeticDialect",
"mlir::math::MathDialect",
"mlir::gpu::GPUDialect",
"mlir::scf::SCFDialect",
"mlir::LLVM::LLVMDialect",
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect",
"mlir::NVVM::NVVMDialect",
"mlir::StandardOpsDialect"];
}
#endif

View File

@@ -0,0 +1,338 @@
#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
#include <string>
namespace mlir {
class ConversionPatternRewriter;
class Location;
namespace triton {
using llvm::StringRef;
struct PTXInstr;
struct PTXInstrCommon;
struct PTXInstrExecution;
// PTXBuilder helps to manage a PTX asm program consists of one or multiple
// instructions.
//
// A helper for building a ASM program, the objective of PTXBuilder is to give a
// thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
// Currently, several factors are introduced to reduce the need for mixing
// string and C++ if-else code.
//
// Usage:
// To build: @$3 asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k),
// "b"(p));
//
// PTXBuilder builder;
// auto& add = builder.create<>();
// add.predicate(pVal).o("lo").o("u32"); // add any suffix
// // predicate here binds %0 to pVal, pVal is a mlir::Value
//
// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal
// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal
// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal
// add(iOpr, jOpr, kOpr).predicate(predVal); // set operands and predicate
//
// To get the asm code:
// builder.dump()
//
// To get all the mlir::Value used in the PTX code,
//
// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal}
//
// To get the string containing all the constraints with "," separated,
// builder.getConstraints() // get "=r,r,k"
//
// PTXBuilder can build a PTX asm with multiple instructions, sample code:
//
// PTXBuilder builder;
// auto& mov = builder.create("mov");
// auto& cp = builder.create("cp");
// mov(...);
// cp(...);
// This will get a PTX code with two instructions.
//
// Similar to a C function, a declared PTXInstr instance can be launched
// multiple times with different operands, e.g.
//
// auto& mov = builder.create("mov");
// mov(... some operands ...);
// mov(... some different operands ...);
//
// Finally, we will get a PTX code with two mov instructions.
//
// There are several derived instruction type for typical instructions, for
// example, the PtxIOInstr for ld and st instructions.
struct PTXBuilder {
struct Operand {
std::string constraint;
Value value;
int idx{-1};
llvm::SmallVector<Operand *> list;
std::function<std::string(int idx)> repr;
// for list
Operand() = default;
Operand(const Operation &) = delete;
Operand(Value value, StringRef constraint)
: constraint(constraint), value(value) {}
bool isList() const { return !value && constraint.empty(); }
Operand *listAppend(Operand *arg) {
list.push_back(arg);
return this;
}
Operand *listGet(size_t nth) const {
assert(nth < list.size());
return list[nth];
}
std::string dump() const;
};
template <typename INSTR = PTXInstr, typename... Args>
INSTR *create(Args &&...args) {
instrs.emplace_back(std::make_unique<INSTR>(this, args...));
return static_cast<INSTR *>(instrs.back().get());
}
// Create a list of operands.
Operand *newListOperand() { return newOperand(); }
Operand *newListOperand(ArrayRef<std::pair<mlir::Value, std::string>> items) {
auto *list = newOperand();
for (auto &item : items) {
list->listAppend(newOperand(item.first, item.second));
}
return list;
}
Operand *newListOperand(unsigned count, mlir::Value val,
const std::string &constraint) {
auto *list = newOperand();
for (unsigned i = 0; i < count; ++i) {
list->listAppend(newOperand(val, constraint));
}
return list;
}
Operand *newListOperand(unsigned count, const std::string &constraint) {
auto *list = newOperand();
for (unsigned i = 0; i < count; ++i) {
list->listAppend(newOperand(constraint));
}
return list;
}
// Create a new operand. It will not add to operand list.
// @value: the MLIR value bind to this operand.
// @constraint: ASM operand constraint, .e.g. "=r"
// @formatter: extra format to represent this operand in ASM code, default is
// "%{0}".format(operand.idx).
Operand *newOperand(mlir::Value value, StringRef constraint,
std::function<std::string(int idx)> formatter = nullptr);
// Create a new operand which is written to, that is, the constraint starts
// with "=", e.g. "=r".
Operand *newOperand(StringRef constraint);
// Create a constant integer operand.
Operand *newConstantOperand(int v);
// Create a constant operand with explicit code specified.
Operand *newConstantOperand(const std::string &v);
Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0);
llvm::SmallVector<Operand *, 4> getAllArgs() const;
llvm::SmallVector<Value, 4> getAllMLIRArgs() const;
std::string getConstraints() const;
std::string dump() const;
mlir::Value launch(ConversionPatternRewriter &rewriter, Location loc,
Type resTy, bool hasSideEffect = true,
bool isAlignStack = false,
ArrayRef<Attribute> attrs = {}) const;
private:
Operand *newOperand() {
argArchive.emplace_back(std::make_unique<Operand>());
return argArchive.back().get();
}
friend struct PTXInstr;
friend struct PTXInstrCommon;
protected:
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
llvm::SmallVector<std::unique_ptr<PTXInstrCommon>, 2> instrs;
llvm::SmallVector<std::unique_ptr<PTXInstrExecution>, 4> executions;
int oprCounter{};
};
// PTX instruction common interface.
// Put the generic logic for all the instructions here.
struct PTXInstrCommon {
explicit PTXInstrCommon(PTXBuilder *builder) : builder(builder) {}
using Operand = PTXBuilder::Operand;
// clang-format off
PTXInstrExecution& operator()() { return call({}); }
PTXInstrExecution& operator()(Operand* a) { return call({a}); }
PTXInstrExecution& operator()(Operand* a, Operand* b) { return call({a, b}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c) { return call({a, b, c}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d) { return call({a, b, c, d}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e) { return call({a, b, c, d, e}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f) { return call({a, b, c, d, e, f}); }
PTXInstrExecution& operator()(Operand* a, Operand* b, Operand* c, Operand* d, Operand * e, Operand* f, Operand* g) { return call({a, b, c, d, e, f, g}); }
// clang-format on
// Set operands of this instruction.
PTXInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs);
protected:
PTXInstrExecution &call(llvm::ArrayRef<Operand *> oprs);
PTXBuilder *builder{};
llvm::SmallVector<std::string, 4> instrParts;
friend struct PTXInstrExecution;
};
template <class ConcreteT> struct PTXInstrBase : public PTXInstrCommon {
using Operand = PTXBuilder::Operand;
explicit PTXInstrBase(PTXBuilder *builder, const std::string &name)
: PTXInstrCommon(builder) {
o(name);
}
// Append a suffix to the instruction.
// e.g. PTXInstr("add").o("s32") get a add.s32.
// A predicate is used to tell whether to apply the suffix, so that no if-else
// code needed. e.g. `PTXInstr("add").o("s32", isS32).o("u32", !isS32);` will
// get a `add.s32` if isS32 is true.
ConcreteT &o(const std::string &suffix, bool predicate = true) {
if (predicate)
instrParts.push_back(suffix);
return *static_cast<ConcreteT *>(this);
}
};
struct PTXInstr : public PTXInstrBase<PTXInstr> {
using PTXInstrBase<PTXInstr>::PTXInstrBase;
};
// A helper for PTX ld/st instruction.
// Usage:
// PtxIOInstr store("st");
// store.predicate(pValue).global().v(32).b(1); // @%0 st.global.v32.b1
// store.addAddr(addrValue, "l", off);
struct PTXIOInstr : public PTXInstrBase<PTXIOInstr> {
using PTXInstrBase<PTXIOInstr>::PTXInstrBase;
// Add ".global" suffix to instruction
PTXIOInstr &global(bool predicate = true) {
o("global", predicate);
return *this;
}
// Add ".shared" suffix to instruction
PTXIOInstr &shared(bool predicate = true) {
o("shared", predicate);
return *this;
}
// Add ".v" suffix to instruction
PTXIOInstr &v(int vecWidth, bool predicate = true) {
if (vecWidth > 1) {
o("v" + std::to_string(vecWidth), predicate);
}
return *this;
}
// Add ".b" suffix to instruction
PTXIOInstr &b(int width) {
o("b" + std::to_string(width));
return *this;
}
};
struct PTXCpAsyncInstrBase : public PTXInstrBase<PTXCpAsyncInstrBase> {
explicit PTXCpAsyncInstrBase(PTXBuilder *builder)
: PTXInstrBase(builder, "cp.async") {}
};
struct PTXCpAsyncCommitGroupInstr : public PTXCpAsyncInstrBase {
explicit PTXCpAsyncCommitGroupInstr(PTXBuilder *builder)
: PTXCpAsyncInstrBase(builder) {
o("commit_group");
}
};
struct PTXCpAsyncWaitGroupInstr : public PTXCpAsyncInstrBase {
explicit PTXCpAsyncWaitGroupInstr(PTXBuilder *builder)
: PTXCpAsyncInstrBase(builder) {
o("wait_group");
}
};
struct PTXCpAsyncLoadInstr : public PTXCpAsyncInstrBase {
explicit PTXCpAsyncLoadInstr(PTXBuilder *builder,
triton::CacheModifier modifier)
: PTXCpAsyncInstrBase(builder) {
o(triton::stringifyCacheModifier(modifier).str());
o("shared");
o("global");
}
};
// Record the operands and context for "launching" a PtxInstr.
struct PTXInstrExecution {
using Operand = PTXBuilder::Operand;
llvm::SmallVector<Operand *> argsInOrder;
PTXInstrExecution() = default;
explicit PTXInstrExecution(PTXInstrCommon *instr,
llvm::ArrayRef<Operand *> oprs)
: argsInOrder(oprs.begin(), oprs.end()), instr(instr) {}
// Prefix a predicate to the instruction.
PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
pred = instr->builder->newOperand(value, constraint);
return *this;
}
// Prefix a !predicate to the instruction.
PTXInstrExecution &predicateNot(mlir::Value value, StringRef constraint) {
pred = instr->builder->newOperand(value, constraint);
pred->repr = [](int idx) { return "@!%" + std::to_string(idx); };
return *this;
}
std::string dump() const;
SmallVector<Operand *> getArgList() const;
PTXInstrCommon *instr{};
Operand *pred{};
};
} // namespace triton
} // namespace mlir
#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_

View File

@@ -0,0 +1,42 @@
#ifndef TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
#define TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
class TritonLLVMConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMConversionTarget(MLIRContext &ctx,
mlir::LLVMTypeConverter &typeConverter);
};
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMFunctionConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter);
};
namespace triton {
// Names for identifying different NVVM annotations. It is used as attribute
// names in MLIR modules. Refer to
// https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#supported-properties for
// the full list.
struct NVVMMetadataField {
static constexpr char MaxNTid[] = "nvvm.maxntid";
static constexpr char Kernel[] = "nvvm.kernel";
};
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonGPUToLLVMPass();
} // namespace triton
} // namespace mlir
#endif

View File

@@ -0,0 +1,25 @@
#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_
#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
namespace triton {
constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
// Create the pass with numWarps passed from cl::opt.
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
// Create the pass with numWarps set explicitly.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUPass(int numWarps);
} // namespace triton
} // namespace mlir
#endif

View File

@@ -0,0 +1,2 @@
add_subdirectory(Triton)
add_subdirectory(TritonGPU)

View File

@@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,19 @@
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
set(LLVM_TARGET_DEFINITIONS TritonDialect.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
set(LLVM_TARGET_DEFINITIONS TritonOps.td)
mlir_tablegen(Types.h.inc -gen-typedef-decls)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td)
mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
add_public_tablegen_target(TritonTableGen)

View File

@@ -0,0 +1,40 @@
#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
#define TRITON_DIALECT_TRITON_IR_DIALECT_H_
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
#include "triton/Dialect/Triton/IR/Traits.h"
#include "triton/Dialect/Triton/IR/Types.h"
#define GET_OP_CLASSES
#include "triton/Dialect/Triton/IR/Ops.h.inc"
namespace mlir {
namespace triton {
class DialectInferLayoutInterface
: public DialectInterface::Base<DialectInferLayoutInterface> {
public:
DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}
virtual LogicalResult
inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
Attribute &resultEncoding) const = 0;
virtual LogicalResult
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
Attribute &resultEncoding) const = 0;
};
} // namespace triton
} // namespace mlir
#endif // TRITON_IR_DIALECT_H_

View File

@@ -0,0 +1,9 @@
#ifndef TRITON_IR_INTERFACES_H_
#define TRITON_IR_INTERFACES_H_
#include "mlir/IR/OpDefinition.h"
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
#endif // TRITON_IR_TYPES_H_

View File

@@ -0,0 +1,60 @@
#ifndef TRITON_IR_TRAITS_H_
#define TRITON_IR_TRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include <iostream>
namespace mlir {
namespace OpTrait {
// These functions are out-of-line implementations of the methods in the
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
LogicalResult verifySameOperandsAndResultEncoding(Operation *op);
LogicalResult verifySameOperandsEncoding(Operation *op);
// The rationale for this trait is to prevent users from creating programs
// that would have catastrophic register pressure and cause the compiler to
// hang.
// Since H100 has 256KB registers, we should allow users to create tensors
// of size up to 256K elements. It will spill for datatypes wider than 1B,
// but we probably should limit number of elements (rather than bytes) to
// keep specs simple
int constexpr maxTensorNumElements = 1048576;
LogicalResult verifyTensorSize(Operation *op);
} // namespace impl
template <class ConcreteType>
class TensorSizeTrait : public TraitBase<ConcreteType, TensorSizeTrait> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyTensorSize(op);
}
};
template <typename ConcreteType>
class SameOperandsAndResultEncoding
: public TraitBase<ConcreteType, SameOperandsAndResultEncoding> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifySameOperandsAndResultEncoding(op);
}
};
template <typename ConcreteType>
class SameOperandsEncoding
: public TraitBase<ConcreteType, SameOperandsEncoding> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifySameOperandsEncoding(op);
}
};
} // namespace OpTrait
} // namespace mlir
#endif

View File

@@ -0,0 +1,67 @@
#ifndef TRITON_ATTR_DEFS
#define TRITON_ATTR_DEFS
include "mlir/IR/EnumAttr.td"
// Attrs for LoadOp
def TT_CacheModifierAttr : I32EnumAttr<
"CacheModifier", "",
[
I32EnumAttrCase<"NONE", 1, "none">,
I32EnumAttrCase<"CA", 2, "ca">,
I32EnumAttrCase<"CG", 3, "cg">,
]> {
let cppNamespace = "::mlir::triton";
}
def TT_EvictionPolicyAttr : I32EnumAttr<
"EvictionPolicy", "",
[
I32EnumAttrCase<"NORMAL", 1, "evict_normal">,
I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">,
I32EnumAttrCase<"EVICT_LAST", 3, "evict_last">
]> {
let cppNamespace = "::mlir::triton";
}
// reduction
def TT_RedOpAttr : I32EnumAttr<
/*name*/"RedOp", /*summary*/"",
/*case*/
[
I32EnumAttrCase</*sym*/"ADD", 1, /*str*/"add">,
I32EnumAttrCase<"FADD", 2, "fadd">,
I32EnumAttrCase<"MIN", 3, "min">,
I32EnumAttrCase<"MAX", 4, "max">,
I32EnumAttrCase<"UMIN", 5, "umin">,
I32EnumAttrCase<"UMAX", 6, "umax">,
I32EnumAttrCase<"ARGMIN", 7, "argmin">,
I32EnumAttrCase<"ARGMAX", 8, "argmax">,
I32EnumAttrCase<"ARGUMIN", 9, "argumin">,
I32EnumAttrCase<"ARGUMAX", 10, "argumax">,
I32EnumAttrCase<"FMIN", 11, "fmin">,
I32EnumAttrCase<"FMAX", 12, "fmax">,
I32EnumAttrCase<"ARGFMIN", 13, "argfmin">,
I32EnumAttrCase<"ARGFMAX", 14, "argfmax">,
I32EnumAttrCase<"XOR", 15, "xor">
]> {
let cppNamespace = "::mlir::triton";
}
// atomic
def TT_AtomicRMWAttr : I32EnumAttr<
"RMWOp", "",
[
I32EnumAttrCase<"AND", 1, "and">,
I32EnumAttrCase<"OR", 2, "or">,
I32EnumAttrCase<"XOR", 3, "xor">,
I32EnumAttrCase<"ADD", 4, "add">,
I32EnumAttrCase<"FADD", 5, "fadd">,
I32EnumAttrCase<"MAX", 6, "max">,
I32EnumAttrCase<"MIN", 7, "min">,
I32EnumAttrCase<"UMAX", 8, "umax">,
I32EnumAttrCase<"UMIN", 9, "umin">
]> {
let cppNamespace = "::mlir::triton";
}
#endif

View File

@@ -0,0 +1,47 @@
#ifndef TRITON_DIALECT
#define TRITON_DIALECT
include "mlir/IR/OpBase.td"
def Triton_Dialect : Dialect {
let name = "tt";
let cppNamespace = "::mlir::triton";
let summary = "The Triton IR in MLIR";
let description = [{
Triton Dialect.
Dependent Dialects:
* Arithmetic:
* addf, addi, andi, cmpf, cmpi, divf, fptosi, ...
* Math:
* exp, sin, cos, log, ...
* StructuredControlFlow:
* ForOp, IfOp, WhileOp, YieldOp, ConditionOp
}];
let dependentDialects = [
"arith::ArithmeticDialect",
"math::MathDialect",
"StandardOpsDialect",
"scf::SCFDialect",
"gpu::GPUDialect",
// Since LLVM 15
// "cf::ControlFlowDialect",
// "func::FuncDialect"
];
let extraClassDeclaration = [{
void registerTypes();
}];
let hasConstantMaterializer = 1;
}
include "triton/Dialect/Triton/IR/TritonTypes.td"
#endif // TRITON_DIALECT

View File

@@ -0,0 +1,6 @@
#ifndef TRITON_INTERFACES
#define TRITON_INTERFACES
include "mlir/IR/OpBase.td"
#endif // TRITON_INTERFACES

View File

@@ -0,0 +1,403 @@
#ifndef TRITON_OPS
#define TRITON_OPS
include "triton/Dialect/Triton/IR/TritonDialect.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
//
// Op Base
//
class TT_Op<string mnemonic, list<Trait> traits = []> :
Op<Triton_Dialect, mnemonic, !listconcat(traits, [TensorSizeTrait])> {
}
//
// CastOps
//
// Use cast ops in arith:
// bitcast
// fptoui, fptosi, uitofp, sitofp,
// extf, tructf,
// extui, extsi, tructi
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast int64 to pointer";
let arguments = (ins TT_I64Like:$from);
let results = (outs TT_PtrLike:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast pointer to int64";
let arguments = (ins TT_PtrLike:$from);
let results = (outs TT_I64Like:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
}
// arith.bitcast doesn't support pointers
def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Cast between types of the same bitwidth";
let arguments = (ins TT_Type:$from);
let results = (outs TT_Type:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
// TODO: Add verifier
}
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
NoSideEffect,
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
let summary = "Floating point casting for custom types";
let description = [{
Floating point casting for custom types (F8, BF8).
F8 <-> BF8, FP16, FP32
BF8 <-> F8, FP16, FP32
}];
let arguments = (ins TT_FloatLike:$from);
let results = (outs TT_FloatLike:$result);
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
// TODO: We need a verifier here.
}
//
// Pointer Arith Ops
//
def TT_AddPtrOp : TT_Op<"addptr",
[NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
TypesMatchWith<"result type matches ptr type",
"result", "ptr", "$_self">,
TypesMatchWith<"result shape matches offset shape",
"result", "offset",
"getI32SameShape($_self)">]> {
let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
let results = (outs TT_PtrLike:$result);
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
}
//
// Load/Store Ops
//
def TT_LoadOp : TT_Op<"load",
[SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
AttrSizedOperandSegments,
MemoryEffects<[MemRead]>,
TypesMatchWith<"infer ptr type from result type",
"result", "ptr", "getPointerTypeSameShape($_self)">,
TypesMatchWith<"infer mask type from result type or none",
"result", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
TypesMatchWith<"infer other type from result type or none",
"result", "other", "$_self",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "load";
let arguments = (ins TT_PtrLike:$ptr, Optional<TT_BoolLike>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile);
let results = (outs TT_Type:$result);
let builders = [
OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
];
// let assemblyFormat = "operands attr-dict `:` type($result)";
let parser = [{ return mlir::triton::parseLoadOp(parser, result); }];
let printer = [{ return mlir::triton::printLoadOp(p, *this); }];
let hasCanonicalizer = 1;
}
def TT_StoreOp : TT_Op<"store",
[SameOperandsShape,
SameOperandsEncoding,
MemoryEffects<[MemWrite]>,
TypesMatchWith<"infer ptr type from value type",
"value", "ptr",
"getPointerTypeSameShape($_self)">,
TypesMatchWith<"infer mask type from value type",
"value", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
let summary = "store";
let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask);
let builders = [
OpBuilder<(ins "Value":$ptr, "Value":$value)>,
];
// let assemblyFormat = "operands attr-dict `:` type($value)";
let parser = [{ return mlir::triton::parseStoreOp(parser, result); }];
let printer = [{ return mlir::triton::printStoreOp(p, *this); }];
let hasCanonicalizer = 1;
}
//
// Atomic Op
//
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "atomic rmw";
let description = [{
load data at $ptr, do $rmw_op with $val, and store result to $ptr.
return old value at $ptr
}];
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrTensor:$ptr,
TT_Type:$val, I1Tensor:$mask);
let results = (outs TT_Type:$result);
}
def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "atomic cas";
let description = [{
compare $cmp with data $old at location $ptr,
if $old == $cmp, store $val to $ptr,
else store $old to $ptr,
return $old
}];
let arguments = (ins TT_Ptr:$ptr, TT_Type:$cmp, TT_Type:$val);
let results = (outs TT_Type:$result);
}
//
// Shape Manipulation Ops
//
def TT_SplatOp : TT_Op<"splat", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "splat";
let arguments = (ins TT_Type:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
}
def TT_ExpandDimsOp : TT_Op<"expand_dims", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
SameOperandsAndResultElementType]> {
let summary = "expand_dims";
let arguments = (ins TT_Tensor:$src, I32Attr:$axis);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TT_ViewOp : TT_Op<"view", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "view";
let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "broadcast. No left-padding as of now.";
let arguments = (ins TT_Type:$src);
let results = (outs TT_Type:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
}
def TT_CatOp : TT_Op<"cat", [NoSideEffect,
SameOperandsAndResultElementType]> {
let summary = "concatenate 2 tensors";
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
}
//
// SPMD Ops
//
def TT_GetProgramIdOp : TT_Op<"get_program_id", [NoSideEffect]> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [NoSideEffect]> {
let arguments = (ins I32Attr:$axis);
let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
//
// Dot Op
//
def TT_DotOp : TT_Op<"dot", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"result's type matches accumulator's type",
"d", "c", "$_self">]> {
let summary = "dot";
let description = [{
$d = matrix_multiply($a, $b) + $c
}];
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32, BoolAttr:$transA, BoolAttr:$transB);
let results = (outs TT_FpIntTensor:$d);
let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
}
//
// Reduce Op
//
def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "reduce";
let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
let results = (outs TT_Type:$result);
let builders = [
OpBuilder<(ins "triton::RedOp":$redOp, "Value":$operand, "int":$axis)>,
];
let assemblyFormat = "$operand attr-dict `:` type($operand) `->` type($result)";
}
//
// External elementwise op
//
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
SameVariadicOperandSize]> {
let summary = "ext_elemwise";
let description = [{
call an external function $symbol implemented in $libpath/$libname with $args
return $libpath/$libname:$symbol($args...)
}];
let arguments = (ins Variadic<TT_Type>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol);
let results = (outs TT_Type:$result);
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
}
//
// Make Range Op
//
// TODO: should have ConstantLike as Trait
def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
let summary = "make range";
let description = [{
Returns an 1D int32 tensor.
Values span from $start to $end (exclusive), with step = 1
}];
let arguments = (ins I32Attr:$start, I32Attr:$end);
let results = (outs TT_IntTensor:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
//
// Make PrintfOp
//
def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
Arguments<(ins StrAttr:$prefix,
Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
let summary = "Device-side printf, as in CUDA for debugging";
let description = [{
`tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
format are generated automatically from the arguments.
}];
let assemblyFormat = [{
$prefix attr-dict ($args^ `:` type($args))?
}];
}
#endif // Triton_OPS

View File

@@ -0,0 +1,72 @@
#ifndef TRITON_TYPES
#define TRITON_TYPES
include "triton/Dialect/Triton/IR/TritonDialect.td"
//
// Types
//
class TritonTypeDef<string name, string _mnemonic>
: TypeDef<Triton_Dialect, name> {
// Used by printer/parser
let mnemonic = _mnemonic;
}
// Floating-point Type
def F8 : TritonTypeDef<"Float8", "f8">;
def BF8 : TritonTypeDef<"BFloat8", "bf8">;
def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
// Boolean Type
// TT_Bool -> I1
def TT_BoolTensor : TensorOf<[I1]>;
def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;
// Integer Type
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
def TT_IntTensor : TensorOf<[TT_Int]>;
def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>;
// I32 Type
// TT_I32 -> I32
// TT_I32Tensor -> I32Tensor
def TT_I32Like: AnyTypeOf<[I32, I32Tensor]>;
// I64 Type
// TT_I64 -> I64
// TT_I64Tensor -> I64Tensor
def TT_I64Like: AnyTypeOf<[I64, I64Tensor]>;
// Pointer Type
def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
let summary = "pointer type";
let description = [{
Triton PointerType
}];
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
let builders = [
TypeBuilderWithInferredContext<(ins
"Type":$pointeeType,
"int":$addressSpace
), [{
return $_get(pointeeType.getContext(), pointeeType, addressSpace);
}]>
];
let skipDefaultBuilders = 1;
}
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>;
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntTensor]>;
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike]>;
#endif

View File

@@ -0,0 +1,10 @@
#ifndef TRITON_IR_TYPES_H_
#define TRITON_IR_TYPES_H_
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#define GET_TYPEDEF_CLASSES
#include "triton/Dialect/Triton/IR/Types.h.inc"
#endif // TRITON_IR_TYPES_H_

View File

@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton)
add_public_tablegen_target(TritonTransformsIncGen)

View File

@@ -0,0 +1,18 @@
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
#define TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace triton {
std::unique_ptr<Pass> createCombineOpsPass();
} // namespace triton
#define GEN_PASS_REGISTRATION
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
} // namespace mlir
#endif

View File

@@ -0,0 +1,23 @@
#ifndef TRITON_PASSES
#define TRITON_PASSES
include "mlir/Pass/PassBase.td"
def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp"> {
let summary = "combine ops";
let description = [{
dot(a, b, 0) + c => dot(a, b, c)
addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1))
select(cond, load(ptrs, broadcast(cond), ???), other) =>
load(ptrs, broadcast(cond), other)
}];
let constructor = "mlir::triton::createCombineOpsPass()";
let dependentDialects = ["mlir::arith::ArithmeticDialect",
/*SelectOp*/"mlir::StandardOpsDialect"];
}
#endif

View File

@@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,12 @@
set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
add_public_tablegen_target(TritonGPUTableGen)
set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(TritonGPUAttrDefsIncGen)

View File

@@ -0,0 +1,38 @@
#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
// TritonGPU depends on Triton
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"
namespace mlir {
namespace triton {
namespace gpu {
unsigned getElemsPerThread(Type type);
SmallVector<unsigned> getSizePerThread(Attribute layout);
SmallVector<unsigned> getThreadsPerCTA(const Attribute &layout);
SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
SmallVector<unsigned> getOrder(const Attribute &layout);
} // namespace gpu
} // namespace triton
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_

View File

@@ -0,0 +1,352 @@
#ifndef TRITONGPU_ATTRDEFS
#define TRITONGPU_ATTRDEFS
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
//===----------------------------------------------------------------------===//
// TritonGPU Attribute Definitions
//===----------------------------------------------------------------------===//
class TritonGPU_Attr<string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<TritonGPU_Dialect, name, traits, baseCppClass> {
let description = [{
TritonGPU Tensors differ from usual tensors in that they contain a _layout_ attribute which determines
how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function
\mathcal{L} that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers T corresponding
to the indices of the CUDA threads allowed to access some data at index $i$.
For example, let us consider the layout function:
\mathcal{L}(0, 0) = {0, 4}
\mathcal{L}(0, 1) = {1, 5}
\mathcal{L}(1, 0) = {2, 6}
\mathcal{L}(1, 1) = {3, 7}
Then, attaching $\mathcal{L} to a tensor $T$ would mean that:
- T[0,0] is owned by both cuda thread 0 and 4
- T[0,1] is owned by both cuda thread 1 and 5
- T[1,0] is owned by both cuda thread 2 and 6
- T[1,1] is owned by both cuda thread 3 and 7
Right now, Triton implements two classes of layouts: shared, and distributed.
}];
code extraBaseClassDeclaration = [{
unsigned getElemsPerThread(ArrayRef<int64_t> shape) const;
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
}];
}
//===----------------------------------------------------------------------===//
// Shared Layout Encoding
//===----------------------------------------------------------------------===//
def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding"> {
let mnemonic = "shared";
let description = [{
An encoding for tensors whose elements may be simultaneously accessed by
different cuda threads in the programs, via shared memory. In other words,
for all indices i \in R^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.
In order to avoid shared memory bank conflicts, elements may be swizzled
in memory. For example, a swizzled row-major layout could store its data
as follows:
A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2
A_{1, 0} A_{1, 1} A_{1, 2} A_{1, 3} ... [phase 0] /
groups of vec=2 elements
are stored contiguously
_ _ _ _ /\_ _ _ _
A_{2, 2} A_{2, 3} A_{2, 0} A_{2, 1} ... [phase 1] \ per phase = 2
A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
}];
let parameters = (
ins
// swizzle info
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase,
ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
);
let extraClassDeclaration = extraBaseClassDeclaration;
}
//===----------------------------------------------------------------------===//
// Distributed Layout Encoding
//===----------------------------------------------------------------------===//
class DistributedEncoding<string name> : TritonGPU_Attr<name> {
let description = [{
Distributed encodings have a layout function that is entirely characterized
by a d-dimensional tensor L. Note that L doesn't need to have the same shape
(or even the same rank) as the tensor it is encoding.
The layout function \mathcal{L} of this layout is then defined, for an
index `i` \in R^D, as follows:
\mathcal{L}(A)[i_d] = L[(i_d + k_d*A.shape[d]) % L.shape[d]] \forall k_d such as i_d + k_d*A.shape[d] < L.shape[d]
For example, for a tensor/layout pair
A = [x x x x x x x x]
[x x x x x x x x]
L = [0 1 2 3 ]
[4 5 6 7 ]
[8 9 10 11]
[12 13 14 15]
Then the data of A would be distributed as follow between the 16 CUDA threads:
L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
{4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ]
}];
let extraClassDeclaration = extraBaseClassDeclaration;
}
//===----------------------------------------------------------------------===//
// Blocked Layout Encoding
//===----------------------------------------------------------------------===//
def BlockedEncodingAttr : DistributedEncoding<"BlockedEncoding"> {
let mnemonic = "blocked";
let description = [{
An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout
used to promote memory coalescing in LoadInst and StoreInst.
It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which
specify the amount of elements owned by each CUDA thread, warp and CTA respectively.
For example, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows.
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
...
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
for
#triton_gpu.blocked_layout<{
sizePerThread = {2, 2}
threadsPerWarp = {8, 4}
warpsPerCTA = {1, 2}
}>
}];
let builders = [
// Custom builder initializes sizePerWarp and sizePerCTA automatically
// TODO: compiles on MacOS but not linux?
// AttrBuilder<(ins "ArrayRef<unsigned>":$sizePerThread,
// "ArrayRef<unsigned>":$threadsPerWarp,
// "ArrayRef<unsigned>":$warpsPerCTA,
// "ArrayRef<unsigned>":$order), [{
// int rank = threadsPerWarp.size();
// SmallVector<unsigned, 4> sizePerWarp(rank);
// SmallVector<unsigned, 4> sizePerCTA(rank);
// for (unsigned i = 0; i < rank; i++) {
// sizePerWarp.push_back(sizePerThread[i] * threadsPerWarp[i]);
// sizePerCTA.push_back(sizePerWarp[i] * warpsPerCTA[i]);
// }
// return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, sizePerWarp, sizePerCTA);
// }]>,
// Custom builder initializes sizePerWarp and sizePerCTA automatically
// Default builder takes sizePerThread, order and numWarps, and tries to
// pack numWarps*32 threads in the provided order for use in a type
// of the given shape.
AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$sizePerThread,
"ArrayRef<unsigned>":$order,
"unsigned":$numWarps), [{
int rank = sizePerThread.size();
int remainingWarps = numWarps;
int remainingLanes = 32;
SmallVector<unsigned, 4> threadsPerWarp(rank);
SmallVector<unsigned, 4> warpsPerCTA(rank);
for (int _dim = 0; _dim < rank; ++_dim) {
int dim = order[_dim];
int maxNumThreads = int(shape[dim]) / sizePerThread[dim];
warpsPerCTA[dim] = std::clamp(remainingWarps, 1, maxNumThreads);
maxNumThreads = maxNumThreads / warpsPerCTA[dim];
threadsPerWarp[dim] = std::clamp(remainingLanes, 1, maxNumThreads);
remainingWarps /= warpsPerCTA[dim];
remainingLanes /= threadsPerWarp[dim];
}
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
SliceEncodingAttr squeeze(int axis);
}];
let parameters = (
ins
ArrayRefParameter<"unsigned">:$sizePerThread,
ArrayRefParameter<"unsigned">:$threadsPerWarp,
ArrayRefParameter<"unsigned">:$warpsPerCTA,
// fastest-changing axis first
ArrayRefParameter<
"unsigned",
"order of axes by the rate of changing"
>:$order
// These attributes can be inferred from the rest
// ArrayRefParameter<"unsigned">:$sizePerWarp,
// ArrayRefParameter<"unsigned">:$sizePerCTA
);
}
//===----------------------------------------------------------------------===//
// MMA Layout Encoding
//===----------------------------------------------------------------------===//
// TODO: MMAv1 and MMAv2 should be two instances of the same class
def MmaEncodingAttr : DistributedEncoding<"MmaEncoding"> {
let mnemonic = "mma";
let description = [{
An encoding for tensors that have been produced by tensor cores.
It is characterized by two parameters:
- A 'version' which specifies the generation the tensor cores
whose output is being partitioned: 1 for first-gen tensor cores (Volta),
and 2 for second-gen tensor cores (Turing/Ampere).
- A `blockTileSize` to indicate how data should be
partitioned between warps.
// -------------------------------- version = 1 --------------------------- //
For first-gen tensor cores, the implicit warpTileSize is [16, 16].
Information about this layout can be found in the official PTX documentation
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
(mma.884 section, FP32 accumulator).
For example, the matrix L corresponding to blockTileSize=[32,16] is:
warp 0
--------------------------------/\-------------------------------
[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ]
[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ]
[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ]
[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ]
[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22]
[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23]
[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22]
[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23]
[ 8 8 10 10 8 8 10 10 12 12 14 14 12 12 14 14]
[ 9 9 11 11 9 9 11 11 13 13 15 15 13 13 15 15]
[ ..............................................................
[ ..............................................................
[ 24 24 26 26 24 24 26 26 28 28 30 30 28 28 30 30]
[ 25 25 27 27 25 25 27 27 29 29 31 31 29 29 31 31]
warp 1 = warp0 + 32
--------------------------------/\-------------------------------
[ 32 32 34 34 32 32 34 34 36 36 38 38 36 36 38 38]
[ 33 33 35 35 33 33 35 35 37 37 39 39 37 37 39 39]
[ ..............................................................
[ ..............................................................
[ 56 56 58 58 56 56 58 58 60 60 62 62 60 60 62 62]
[ 57 57 59 59 57 57 59 59 61 61 63 63 61 61 63 63]
// -------------------------------- version = 2 --------------------------- //
For second-gen tensor cores, the implicit warpTileSize is [16, 8].
Information about this layout can be found in the official PTX documentation
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
(mma.16816 section, FP32 accumulator).
For example, the matrix L corresponding to blockTileSize=[32,16] is:
warp 0 warp 1
-----------------/\------------- ----------------/\-------------
[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35
[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39
[ .............................. ..............................
[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63
[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35
[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39
[ .............................. ..............................
[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63
warp 3 warp 4
----------------/\------------- ----------------/\-------------
[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99
[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103
[ .............................. ...............................
[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127
[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99
[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103
[ .............................. ...............................
[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127
}];
let parameters = (
ins
"unsigned":$version,
ArrayRefParameter<"unsigned">:$warpsPerCTA
);
let extraClassDeclaration = extraBaseClassDeclaration;
}
def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
let mnemonic = "slice";
let description = [{
TODO: improve docs
A = [x x x x x x x x]
parent = [0 1 2 3 ]
[4 5 6 7 ]
[8 9 10 11]
[12 13 14 15]
dim = 0
Then the data of A would be distributed as follow between the 16 CUDA threads:
L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ]
This is useful for constructing the inverse layout of an expand_dims operation during some optimization passes.
}];
let parameters = (
ins
"unsigned":$dim,
// TODO: constraint here to only take distributed encodings
"Attribute":$parent
);
let extraClassDeclaration = extraBaseClassDeclaration # [{
SmallVector<int64_t> paddedShape(ArrayRef<int64_t> shape) const;
}];
}
def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
let mnemonic = "dot_op";
let description = [{
In TritonGPU dialect, considering `d = tt.dot a, b, c`
tt.dot's operands a and b must be of DotOperandEncodingAttr layout.
a's opIdx is 0, b's opIdx is 1.
The parend field in DotOperandEncodingAttr is the layout of d.
}];
let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent
);
let extraClassDeclaration = extraBaseClassDeclaration;
}
#endif

View File

@@ -0,0 +1,35 @@
#ifndef TRITONGPU_DIALECT
#define TRITONGPU_DIALECT
include "mlir/IR/OpBase.td"
def TritonGPU_Dialect : Dialect {
let name = "triton_gpu";
let cppNamespace = "::mlir::triton::gpu";
let hasOperationAttrVerify = 1;
let description = [{
Triton GPU Dialect.
}];
let dependentDialects = [
"triton::TritonDialect",
"mlir::gpu::GPUDialect"
];
let extraClassDeclaration = [{
static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; }
static int getNumWarps(ModuleOp mod) {
if(!mod->hasAttr("triton_gpu.num-warps"))
llvm::report_fatal_error(
"TritonGPU module should contain a triton_gpu.num-warps attribute");
return mod->getAttr("triton_gpu.num-warps").cast<IntegerAttr>().getInt();
}
}];
}
#endif

View File

@@ -0,0 +1,211 @@
#ifndef TRITONGPU_OPS
#define TRITONGPU_OPS
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
class TTG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonGPU_Dialect, mnemonic, traits>;
def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
[SameOperandsAndResultShape, NoSideEffect]> {
let summary = "convert layout";
let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
let summary = "async wait";
let arguments = (ins I32Attr:$num);
let assemblyFormat = "attr-dict";
}
// Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU.
// This is needed because these ops don't
// handle encodings
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
let summary = "integer comparison operation";
let description = [{}];
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
TT_IntLike:$lhs,
TT_IntLike:$rhs);
let results = (outs TT_BoolLike:$result);
}
def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect]> {
let summary = "floating-point comparison operation";
let description = [{}];
let arguments = (ins Arith_CmpFPredicateAttr:$predicate,
TT_FloatLike:$lhs,
TT_FloatLike:$rhs);
let results = (outs TT_BoolLike:$result);
}
// TODO: migrate to arith::SelectOp on LLVM16
def TTG_SelectOp : TTG_Op<"select", [NoSideEffect]> {
let summary = "select operation";
let description = [{}];
let arguments = (ins TT_BoolLike:$condition,
TT_Tensor:$true_value,
TT_Tensor:$false_value);
let results = (outs TT_Tensor:$result);
}
def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
[SameVariadicOperandSize,
// MemoryEffects<[MemRead]>, doesn't work with CSE but seems like it should?
NoSideEffect,
TypesMatchWith<"infer mask type from src type",
"src", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
TypesMatchWith<"infer other type from src type",
"src", "other", "getPointeeType($_self)",
"($_op.getOperands().size() <= 4) || std::equal_to<>()">]> {
let summary = "insert slice async";
let description = [{
This operation inserts a tensor `$src` into another tensor `$dst` as specified by the operations
`$index` argument and `$axis` attribute.
It returns a copy of `$dst` with the proper slice updated asynchronously with the value of `$src`.
This operation is non-blocking, and `$results` will have the updated value after the corresponding async_wait.
The insert_slice_async operation supports the following arguments:
* src: the tensor that is inserted.
* dst: the tensor into which the `$src` tensor is inserted.
* index: the index of the `$src` tensor at the given `$axis` from which the `$dst` tensor is inserted into
* mask: optional tensor-rank number of boolean masks which specify which
elements of the `$src` tensor are inserted into the `$dst` tensor.
* other: optional tensor-rank number of other tensors which specify what
values are inserted into the `$dst` tensor if the corresponding
element of the `$mask` tensor is false.
In the future, we may decompose this operation into a sequence of:
* `async` operation to specify a sequence of asynchronous operations
* `load` operation to load a tensor from global memory
* `insert_slice` operations to insert the `$src` tensor into the `$dst` tensor
Example:
```
%1 = triton_gpu.alloc_tensor : tensor<2x32xf32>
%2 = triton_gpu.insert_slice_async %0, %1, %index { axis = 0 } : tensor<32x!tt.ptr<f32>, #AL> -> tensor<2x32xf32, #A>
triiton_gpu.async_wait { num = 0 : i32 }
```
}];
let arguments = (ins TT_PtrTensor:$src, TT_Tensor:$dst, I32:$index,
Optional<I1Tensor>:$mask, Optional<TT_Type>:$other,
TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict,
BoolAttr:$isVolatile, I32Attr:$axis);
let builders = [
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index, "Value":$mask,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
OpBuilder<(ins "Value":$src, "Value":$dst, "Value":$index,
"Value":$mask, "Value":$other,
"triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile, "int":$axis)>,
];
let results = (outs TT_Tensor:$result);
//let assemblyFormat = [{
// $src `,` $dst ``
// $index, $mask, $other
// attr-dict `:` type($src) `->` type($dst)
//}];
// The custom parser could be replaced with oilist in LLVM-16
let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
// result needs to be of shared layout
let verifier = [{ return ::verify(*this); }];
}
def TTG_ExtractSliceOp : TTG_Op<"extract_slice", [NoSideEffect, InferTypeOpInterface]> {
let summary = "extract slice";
let description = [{
The "extract_slice" operation extracts a `$result` tensor from a `$src` tensor as
specified by the operation's `$index` and `$axis` arguments.
The extract_slice operation supports the following arguments:
* src: the tensor that is extracted from.
* index: the index at the given `$axis` from which the `$src` tensor is extracted
Example:
```
// Rank-reducing extract_slice.
%1 = tensor.extract_slice %0, %index {axis = 0} : tensor<8x16x4xf32> -> tensor<1x16x4xf32>
```
}];
let arguments = (ins TT_Tensor:$src, I32:$index, I32Attr:$axis);
let results = (outs TT_Tensor:$result);
let assemblyFormat = [{$src `,` $index attr-dict `:` type($src) `->` type($result)}];
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes);
}];
// result needs to be of shared layout
let verifier = [{ return ::verify(*this); }];
}
def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [NoSideEffect]> {
let summary = "allocate tensor";
let description = [{
This operation defines a tensor of a particular shape.
The contents of the tensor are supposed to be in shared memory.
Note: This op can be repalced to a `bufferization.alloc_tensor` in LLVM 16.
}];
let assemblyFormat = [{attr-dict `:` type($result)}];
let results = (outs TT_Tensor:$result);
// result needs to be of shared layout
let verifier = [{ return ::verify(*this); }];
}
#endif

View File

@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonGPU)
add_public_tablegen_target(TritonGPUTransformsIncGen)

View File

@@ -0,0 +1,24 @@
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
std::unique_ptr<Pass> createTritonGPUSwizzlePass();
std::unique_ptr<Pass> createTritonGPUCoalescePass();
std::unique_ptr<Pass> createTritonGPUCombineOpsPass();
std::unique_ptr<Pass> createTritonGPUVerifier();
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
} // namespace mlir
#endif

View File

@@ -0,0 +1,79 @@
#ifndef TRITONGPU_PASSES
#define TRITONGPU_PASSES
include "mlir/Pass/PassBase.td"
def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
let summary = "pipeline";
let description = [{
TODO
}];
let constructor = "mlir::createTritonGPUPipelinePass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::scf::SCFDialect",
"mlir::arith::ArithmeticDialect"];
let options = [
Option<"numStages", "num-stages",
"int32_t", /*default*/"2",
"number of pipeline stages">
];
}
def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
let summary = "coalesce";
let description = [{
TODO
}];
let constructor = "mlir::createTritonGPUCoalescePass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
let summary = "combine triton gpu ops";
let description = [{
convert_layout(convert_layout(%src, #LAYOUT_0), #LAYOUT_1) =>
convert_layout(%src, #LAYOUT_1)
convert_layout(%src, #LAYOUT) => %src if %src.layout() == #LAYOUT
}];
let constructor = "mlir::createTritonGPUCombineOpsPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
}
def TritonGPUSwizzle : Pass<"tritongpu-swizzle", "mlir::ModuleOp"> {
let summary = "swizzle";
let description = [{
Inserts conversions to swizzled layout so as to avoid shared memory bank conflicts.
}];
let constructor = "mlir::createTritonGPUSwizzlePass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
let summary = "canonicalize scf.ForOp ops";
let description = [{
This implements some optimizations that are missing in the standard scf.ForOp
canonicalizer.
}];
let constructor = "mlir::createTritonGPUCanonicalizeLoopsPass()";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
#endif

View File

@@ -0,0 +1,32 @@
//===----------------------------------------------------------------------===//
//
// Defines utilities to use while converting to the TritonGPU dialect.
//
//===----------------------------------------------------------------------===//
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
class TritonGPUTypeConverter : public TypeConverter {
public:
TritonGPUTypeConverter(MLIRContext *context, int numWarps);
private:
MLIRContext *context;
int numWarps;
};
class TritonGPUConversionTarget : public ConversionTarget {
public:
explicit TritonGPUConversionTarget(MLIRContext &ctx,
TritonGPUTypeConverter &typeConverter);
};
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_

View File

@@ -0,0 +1,35 @@
#ifndef TRITON_TARGET_LLVMIRTRANSLATION_H
#define TRITON_TARGET_LLVMIRTRANSLATION_H
#include <memory>
#include <vector>
namespace llvm {
class Module;
class LLVMContext;
} // namespace llvm
namespace mlir {
class ModuleOp;
} // namespace mlir
namespace mlir {
namespace triton {
// add external dependent libs
void addExternalLibs(mlir::ModuleOp &module,
const std::vector<std::string> &names,
const std::vector<std::string> &paths);
// Translate TritonGPU dialect to LLVMIR, return null if failed.
std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module);
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
} // namespace triton
} // namespace mlir
#endif // TRITON_TARGET_LLVMIRTRANSLATION_H

View File

@@ -0,0 +1,18 @@
#ifndef TRITON_TARGET_PTXTRANSLATION_H
#define TRITON_TARGET_PTXTRANSLATION_H
#include <memory>
#include <string>
namespace llvm {
class Module;
} // namespace llvm
namespace triton {
// Translate TritonGPU IR to PTX code.
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version);
} // namespace triton
#endif

View File

@@ -1,80 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
#define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
#include <map>
#include <vector>
namespace triton {
namespace ir {
class value;
class module;
class phi_node;
class splat_inst;
class cast_inst;
class reshape_inst;
class broadcast_inst;
class binary_operator;
class getelementptr_inst;
}
namespace codegen{
namespace analysis{
class align {
private:
struct cst_info {
unsigned num_cst;
unsigned value;
};
// helpers
std::vector<unsigned> get_shapes(ir::value *v);
// populate is_constant
std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
std::vector<cst_info> populate_is_constant_default(ir::value* v);
std::vector<cst_info> populate_is_constant(ir::value *v);
// populate max_contiguous
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_max_contiguous_cast(ir::cast_inst* x);
std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
std::vector<unsigned> populate_max_contiguous(ir::value *v);
// populate starting_multiple
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_starting_multiple_cast(ir::cast_inst* x);
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
std::vector<unsigned> populate_starting_multiple(ir::value *v);
// populate all maps
void populate(ir::value *v);
public:
void run(ir::module &mod);
unsigned get(ir::value* v, unsigned ax) const;
std::vector<unsigned> contiguous(ir::value* v) const;
private:
std::map<ir::value*, std::vector<cst_info>> is_constant_;
std::map<ir::value*, std::vector<unsigned>> max_contiguous_;
std::map<ir::value*, std::vector<unsigned>> starting_multiple_;
};
}
}
}
#endif

View File

@@ -1,47 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
#define TDL_INCLUDE_IR_CODEGEN_STORAGE_ALLOC_H
#include <map>
#include <set>
#include <iostream>
#include "triton/codegen/analysis/liveness.h"
namespace triton{
namespace ir{
class value;
class function;
class module;
}
namespace codegen{
namespace analysis{
class tiles;
class liveness;
class cts;
class allocation {
public:
allocation(liveness *live)
: liveness_(live) { }
// accessors
bool has_offset(const data_layout *x) const { return offsets_.find(x) != offsets_.end(); }
unsigned offset(const data_layout *x) const { return offsets_.at(x); }
unsigned allocated_size() const { return allocated_size_; }
// run
void run(ir::module& mod);
private:
std::map<const data_layout*, unsigned> offsets_;
size_t allocated_size_;
// dependences
liveness *liveness_;
};
}
}
}
#endif

View File

@@ -1,52 +0,0 @@
#ifndef _TRITON_CODEGEN_ANALYSIS_AXES_H_
#define _TRITON_CODEGEN_ANALYSIS_AXES_H_
#include "triton/tools/graph.h"
#include <map>
#include <vector>
namespace triton{
namespace ir{
class value;
class module;
class instruction;
}
namespace codegen{
namespace analysis{
class axes {
typedef std::pair<ir::value*, unsigned> node_t;
private:
// update graph
void update_graph_store(ir::instruction *i);
void update_graph_reduce(ir::instruction *i);
void update_graph_reshape(ir::instruction *i);
void update_graph_trans(ir::instruction *i);
void update_graph_broadcast(ir::instruction *i);
void update_graph_dot(ir::instruction *i);
void update_graph_elementwise(ir::instruction *i,
bool is_masked_load_async=false);
void update_graph_no_edge(ir::instruction *i);
void update_graph(ir::instruction *i);
public:
axes();
void run(ir::module &mod);
// accessors
int get(ir::value *value, unsigned dim);
std::vector<int> get(ir::value *value);
private:
tools::graph<node_t> graph_;
std::map<node_t, size_t> axes_;
};
}
}
}
#endif

View File

@@ -1,265 +0,0 @@
#ifndef _TRITON_CODEGEN_ANALYSIS_GRID_H_
#define _TRITON_CODEGEN_ANALYSIS_GRID_H_
#include <map>
#include <set>
#include <vector>
#include <memory>
#include "triton/tools/graph.h"
#include "triton/codegen/target.h"
namespace triton{
namespace ir{
class value;
class type;
class module;
class instruction;
class phi_node;
}
namespace codegen{
namespace analysis{
class axes;
class align;
class layout_visitor;
class data_layout;
class mma_layout;
class scanline_layout;
class shared_layout;
class layout_visitor {
public:
virtual void visit_layout(data_layout *);
virtual void visit_layout_mma(mma_layout*) = 0;
virtual void visit_layout_scanline(scanline_layout*) = 0;
virtual void visit_layout_shared(shared_layout*) = 0;
};
class data_layout {
protected:
enum id_t {
MMA,
SCANLINE,
SHARED
};
typedef std::vector<int> axes_t;
typedef std::vector<unsigned> shape_t;
typedef std::vector<int> order_t;
typedef std::vector<ir::value*> values_t;
private:
template<typename T>
T* downcast(id_t id) {
if(id_ == id)
return static_cast<T*>(this);
return nullptr;
}
public:
data_layout(id_t id,
const std::vector<int>& axes,
const std::vector<unsigned> &shape,
const std::vector<ir::value *> &values,
analysis::align* align);
// visitor
virtual void accept(layout_visitor* vst) = 0;
// downcast
mma_layout* to_mma() { return downcast<mma_layout>(MMA); }
scanline_layout* to_scanline() { return downcast<scanline_layout>(SCANLINE); }
shared_layout* to_shared() { return downcast<shared_layout>(SHARED); }
// accessors
size_t get_rank() { return shape_.size(); }
const shape_t& get_shape() const { return shape_; }
const order_t& get_order() const { return order_; }
const values_t& get_values() const { return values_;}
int get_axis(size_t k) const { return axes_.at(k); }
std::vector<int> get_axes() const { return axes_; }
const int get_order(size_t k) const { return order_.at(k); }
// find the position of given axis
int find_axis(int to_find) const;
private:
id_t id_;
axes_t axes_;
values_t values_;
protected:
order_t order_;
shape_t shape_;
};
class distributed_layout: public data_layout{
public:
distributed_layout(id_t id,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value*>& values,
analysis::align* align);
int shape_per_cta(size_t k) { return shape_per_cta_.at(k); }
int rep_per_cta(size_t k) { return shape_[k] / shape_per_cta_[k]; }
protected:
std::vector<int> shape_per_cta_;
};
class mma_layout: public distributed_layout {
public:
mma_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values,
analysis::align* align, target *tgt,
shared_layout* layout_a,
shared_layout* layout_b);
void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
// accessor
int fpw(size_t k) { return fpw_.at(k); }
int wpt(size_t k) { return wpt_.at(k); }
int spw(size_t k) { return spw_.at(k); }
int rep(size_t k) { return rep_.at(k); }
private:
// fragment per warp
std::vector<int> fpw_;
// shape per warp
std::vector<int> spw_;
// warp per tile
std::vector<int> wpt_;
// shape per tile
std::vector<int> spt_;
// repetitions
std::vector<int> rep_;
};
struct scanline_layout: public distributed_layout {
scanline_layout(size_t num_warps,
const std::vector<int>& axes,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align,
target* tgt);
void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); }
// accessor
int mts(size_t k) { return mts_.at(k); }
int nts(size_t k) { return nts_.at(k); }
public:
// micro tile size. The size of a tile held by a thread block.
std::vector<int> mts_;
// nano tile size. The size of a tile held by a thread.
std::vector<int> nts_;
};
struct double_buffer_info_t {
ir::value* first;
ir::value* latch;
ir::phi_node* phi;
};
struct N_buffer_info_t {
std::vector<ir::value*> firsts; // not necessarily ordered as input order
ir::value* latch;
ir::phi_node* phi;
std::map<ir::value*, int> firsts_idx;
};
// abstract for dot and coresponding smem values
class shared_layout: public data_layout {
private:
static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator);
static void extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res);
static void extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t>& res, int &prev_stages);
public:
shared_layout(data_layout *arg,
const std::vector<int>& axes,
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values_,
ir::type *ty,
analysis::align* align);
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
// accessors
size_t get_size() { return size_; }
ir::type* get_type() { return ty_; }
double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); }
N_buffer_info_t* get_N_buffer() { return N_buffer_.get(); }
int get_num_stages() const;
size_t get_per_stage_size() const { return size_ / get_num_stages(); }
size_t get_per_stage_elements() const;
size_t get_num_per_phase() { return num_per_phase_; }
ir::value* hmma_dot_a() { return hmma_dot_a_; }
ir::value* hmma_dot_b() { return hmma_dot_b_; }
void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
int get_mma_vec() { return mma_vec_;}
data_layout* get_arg_layout() { return arg_layout_; }
private:
size_t size_;
ir::type *ty_;
std::shared_ptr<double_buffer_info_t> double_buffer_;
std::shared_ptr<N_buffer_info_t> N_buffer_;
size_t num_per_phase_;
ir::value* hmma_dot_a_;
ir::value* hmma_dot_b_;
data_layout* arg_layout_;
int mma_vec_;
};
class layouts {
typedef ir::value* node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
private:
// graph creation
void connect(ir::value *x, ir::value *y);
void make_graph(ir::instruction *i);
void init_hmma_tile(data_layout& layouts);
void init_scanline_tile(data_layout &layouts);
void create(size_t id, const std::vector<ir::value*>& values);
public:
// constructor
layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt);
// accessors
unsigned layout_of(ir::value *value) const { return groups_.at(value); }
bool has(ir::value* value) const { return groups_.find(value) != groups_.end(); }
const std::vector<ir::value*>& values_of(unsigned id) const { return values_.at(id); }
size_t num_layouts() const { return values_.size();}
data_layout* get(size_t id) { return layouts_.at(id); }
data_layout* get(ir::value *v) { return get(layout_of(v));}
std::map<size_t, data_layout*> &get_all() { return layouts_; }
bool has_tmp(ir::value* i) { return tmp_.find(i) != tmp_.end(); }
int tmp(ir::value* i) { return tmp_.at(i);}
void copy(ir::value* dst, ir::value* src) { groups_[dst] = groups_[src]; }
// execution
void run(ir::module &mod);
private:
analysis::axes* axes_;
analysis::align* align_;
size_t num_warps_;
target* tgt_;
tools::graph<ir::value*> graph_;
std::map<ir::value*, size_t> groups_;
std::map<size_t, std::vector<ir::value*>> values_;
std::map<size_t, data_layout*> layouts_;
std::map<ir::value*, size_t> tmp_;
};
}
}
}
#endif

View File

@@ -1,67 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#define TDL_INCLUDE_IR_CODEGEN_LIVENESS_H
#include <map>
#include <set>
#include <vector>
#include "triton/codegen/analysis/layout.h"
#include "triton/tools/graph.h"
namespace triton{
namespace ir{
class value;
class phi_node;
class function;
class module;
class instruction;
}
namespace codegen{
namespace analysis{
typedef unsigned slot_index;
class tiles;
class layouts;
class data_layout;
struct segment {
slot_index start;
slot_index end;
bool contains(slot_index idx) const {
return start <= idx && idx < end;
}
bool intersect(const segment &Other){
return contains(Other.start) || Other.contains(start);
}
};
class liveness {
private:
typedef std::map<shared_layout*, segment> intervals_map_t;
public:
// constructor
liveness(layouts *l): layouts_(l){ }
// accessors
const intervals_map_t& get() const { return intervals_; }
segment get(shared_layout* v) const { return intervals_.at(v); }
// run
void run(ir::module &mod);
private:
// analysis
layouts *layouts_;
intervals_map_t intervals_;
};
}
}
}
#endif

View File

@@ -1,43 +0,0 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
#define TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H
#include <map>
namespace triton{
namespace ir{
class module;
}
namespace codegen{
class target;
namespace analysis{
class layouts;
class data_layout;
class swizzle {
public:
// constructor
swizzle(layouts *l, target* tgt): layouts_(l), tgt_(tgt){ }
// accessors
int get_per_phase(data_layout* layout) { return per_phase_.at(layout); }
int get_max_phase(data_layout* layout) { return max_phase_.at(layout); }
int get_vec (data_layout* layout) { return vec_.at(layout); }
// run
void run(ir::module &mod);
private:
layouts* layouts_;
target* tgt_;
std::map<data_layout*, int> per_phase_;
std::map<data_layout*, int> max_phase_;
std::map<data_layout*, int> vec_;
};
}
}
}
#endif

View File

@@ -1,42 +0,0 @@
#ifndef _TRITON_CODEGEN_PASS_H_
#define _TRITON_CODEGEN_PASS_H_
#include <memory>
namespace llvm{
class Module;
class LLVMContext;
}
namespace triton{
namespace codegen {
class target;
}
namespace ir{
class module;
}
namespace driver{
class device;
class module;
class kernel;
}
}
namespace triton{
namespace codegen{
// TODO:
// There should be a proper pass manager there!
std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx,
codegen::target* target,
int sm, int num_warps,
int num_stages, int &shared_static);
}
}
#endif

View File

@@ -1,258 +0,0 @@
#pragma once
#ifndef _TRITON_SELECTION_GENERATOR_H_
#define _TRITON_SELECTION_GENERATOR_H_
#include "triton/ir/visitor.h"
#include "triton/codegen/analysis/layout.h"
#include <functional>
// forward
namespace llvm{
class Type;
class Value;
class PHINode;
class BasicBlock;
class Attribute;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
namespace triton{
namespace ir{
class attribute;
class load_inst;
class store_inst;
}
namespace codegen{
// forward
namespace analysis{
class liveness;
class tiles;
class align;
class allocation;
class cts;
class axes;
class layouts;
class swizzle;
}
// typedef
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Attribute Attribute;
typedef llvm::BasicBlock BasicBlock;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
typedef std::vector<Value*> indices_t;
class target;
}
}
namespace triton{
namespace codegen{
struct distributed_axis {
int contiguous;
std::vector<Value*> values;
Value* thread_id;
};
class adder{
public:
adder(Builder** builder): builder_(builder) { }
Value* operator()(Value* x, Value* y, const std::string& name = "");
private:
Builder** builder_;
};
class multiplier{
public:
multiplier(Builder** builder): builder_(builder) { }
Value* operator()(Value* x, Value* y, const std::string& name = "");
private:
Builder** builder_;
};
class geper{
public:
geper(Builder** builder): builder_(builder) { }
Value* operator()(Value *ptr, Value* off, const std::string& name = "");
Value* operator()(Type* ty, Value*ptr, std::vector<Value*> vals, const std::string& name = "");
private:
Builder** builder_;
};
class generator: public ir::visitor, public analysis::layout_visitor {
private:
void init_idx(ir::value *x);
Instruction* add_barrier();
Value* shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx);
void finalize_shared_layout(analysis::shared_layout*);
void finalize_function(ir::function*);
void finalize_phi_node(ir::phi_node*);
private:
Type *cvt(ir::type *ty);
llvm::Attribute cvt(ir::attribute attr);
public:
generator(analysis::axes *a_axes,
analysis::layouts *layouts,
analysis::align *alignment,
analysis::allocation *alloc,
analysis::swizzle *swizzle,
target *tgt,
unsigned num_warps);
void visit_value(ir::value* v);
void visit_phi_node(ir::phi_node*);
void visit_binary_operator(ir::binary_operator*);
void visit_getelementptr_inst(ir::getelementptr_inst*);
void visit_icmp_inst(ir::icmp_inst*);
void visit_fcmp_inst(ir::fcmp_inst*);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3);
std::tuple<Value*, Value*, Value*, Value*> fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3);
Value* bf16_to_fp32(Value *in0);
Value* fp32_to_bf16(Value *in0);
void visit_cast_inst(ir::cast_inst*);
void visit_return_inst(ir::return_inst*);
void visit_cond_branch_inst(ir::cond_branch_inst*);
void visit_uncond_branch_inst(ir::uncond_branch_inst*);
void visit_load_inst(ir::load_inst*);
void visit_unmasked_load_inst(ir::unmasked_load_inst*);
void visit_masked_load_inst(ir::masked_load_inst*);
void visit_store_inst(ir::store_inst*);
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
void visit_masked_store_inst(ir::masked_store_inst*);
void visit_cat_inst(ir::cat_inst*);
void visit_reshape_inst(ir::reshape_inst*);
void visit_splat_inst(ir::splat_inst*);
void visit_broadcast_inst(ir::broadcast_inst*);
void visit_downcast_inst(ir::downcast_inst*);
void visit_exp_inst(ir::exp_inst*);
void visit_cos_inst(ir::cos_inst*);
void visit_umulhi_inst(ir::umulhi_inst* x);
void visit_sin_inst(ir::sin_inst*);
void visit_log_inst(ir::log_inst*);
void visit_get_program_id_inst(ir::get_program_id_inst*);
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
void visit_atomic_rmw_inst(ir::atomic_rmw_inst*);
void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK);
void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add);
void visit_dot_inst(ir::dot_inst*);
void visit_trans_inst(ir::trans_inst*);
void visit_sqrt_inst(ir::sqrt_inst*);
Value* shfl_sync(Value* acc, int32_t i);
void visit_reduce1d_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reducend_inst(ir::reduce_inst*, std::function<Value*(Value*,Value*)>, Value*);
void visit_reduce_inst(ir::reduce_inst*);
void visit_select_inst(ir::select_inst*);
void visit_layout_convert(ir::value *out, ir::value *in);
void visit_cvt_layout_inst(ir::cvt_layout_inst*);
void visit_masked_load_async_inst(ir::masked_load_async_inst*);
void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*);
void visit_barrier_inst(ir::barrier_inst*);
void visit_prefetch_s_inst(ir::prefetch_s_inst*);
void visit_async_wait_inst(ir::async_wait_inst*);
// void visit_make_range_dyn(ir::make_range_dyn*);
void visit_make_range(ir::make_range*);
// void visit_make_range_sta(ir::make_range_sta*);
void visit_undef_value(ir::undef_value*);
void visit_constant_int(ir::constant_int*);
void visit_constant_fp(ir::constant_fp*);
void visit_alloc_const(ir::alloc_const*);
void visit_function(ir::function*);
void visit_basic_block(ir::basic_block*);
void visit_argument(ir::argument*);
void visit(ir::module &, llvm::Module &);
// layouts
void visit_layout_mma(analysis::mma_layout*);
void visit_layout_scanline(analysis::scanline_layout*);
void visit_layout_shared(analysis::shared_layout*);
private:
LLVMContext *ctx_;
Builder* builder_;
Module *mod_;
analysis::axes *a_axes_;
analysis::swizzle *swizzle_;
std::map<unsigned, distributed_axis> axes_;
target *tgt_;
analysis::layouts *layouts_;
analysis::align *alignment_;
analysis::allocation *alloc_;
Value *shmem_;
std::set<ir::value*> seen_;
unsigned num_warps_;
std::map<analysis::data_layout*, Value*> offset_a_m_;
std::map<analysis::data_layout*, Value*> offset_a_k_;
std::map<analysis::data_layout*, Value*> offset_b_k_;
std::map<analysis::data_layout*, Value*> offset_b_n_;
/// layout -> base ptr
std::map<analysis::data_layout*, Value*> shared_ptr_;
std::map<analysis::data_layout*, Value*> shared_pre_ptr_;
std::map<analysis::data_layout*, Value*> shared_next_ptr_;
/// offset for double-buffered layout
std::map<analysis::data_layout*, Value*> shared_off_;
/// Base shmem pointer of ir value
std::map<ir::value*, Value*> shmems_;
std::map<ir::value*, Value*> shoffs_;
std::map<ir::value*, std::vector<indices_t>> idxs_;
std::map<ir::value*, std::map<indices_t, Value*>> vals_;
/// idx for multi-stage pipeline
std::map<analysis::data_layout*, Value*> read_smem_idx_;
std::map<analysis::data_layout*, Value*> write_smem_idx_;
/// triton bb -> llvm bb
std::map<ir::value*, BasicBlock *> bbs_;
std::map<ir::value*, std::vector<int>> ords_;
// helper for creating llvm values
adder add;
multiplier mul;
geper gep;
/// PHI nodes
std::vector<std::tuple<llvm::PHINode*, Value*, ir::basic_block*>> lazy_phi_incs_;
/// Record prefetch instrs that needs to be moved
std::map<ir::value*, std::vector<Value*>> prefetch_latch_to_bb_;
};
}
}
#endif

View File

@@ -1,105 +0,0 @@
#ifndef TDL_INCLUDE_IR_CODEGEN_TARGET_H
#define TDL_INCLUDE_IR_CODEGEN_TARGET_H
namespace llvm{
class Type;
class Value;
class Instruction;
class Constant;
class LLVMContext;
class Module;
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename T, typename Inserter>
class IRBuilder;
class ArrayType;
class Function;
}
// typedefs
namespace triton{
namespace codegen{
typedef llvm::IRBuilder<llvm::ConstantFolder,
llvm::IRBuilderDefaultInserter> Builder;
typedef llvm::LLVMContext LLVMContext;
typedef llvm::Type Type;
typedef llvm::Value Value;
typedef llvm::Module Module;
typedef llvm::Instruction Instruction;
typedef llvm::Constant Constant;
typedef llvm::ArrayType ArrayType;
typedef llvm::Function Function;
}
}
namespace triton{
namespace codegen{
class nvidia_cu_target;
class target {
public:
target(bool is_gpu): is_gpu_(is_gpu){}
virtual ~target() {}
virtual void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn) = 0;
virtual Instruction* add_barrier(Module *module, Builder& builder) = 0;
virtual Instruction* add_memfence(Module *module, Builder& builder) = 0;
virtual Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax) = 0;
virtual Value* get_local_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0;
virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0;
virtual unsigned guaranteed_alignment() = 0;
nvidia_cu_target* as_nvidia();
bool is_gpu() const;
private:
bool is_gpu_;
};
class amd_cl_target: public target {
public:
amd_cl_target(): target(true){}
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
unsigned guaranteed_alignment() { return 16; }
};
class nvidia_cu_target: public target {
public:
nvidia_cu_target(int sm): target(true), sm_(sm){}
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
int sm() { return sm_; }
unsigned guaranteed_alignment() { return 16; }
private:
int sm_;
};
class cpu_target: public target {
public:
cpu_target(): target(false){}
void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn);
Instruction* add_barrier(Module *module, Builder& builder);
Instruction* add_memfence(Module *module, Builder& builder);
Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax);
Value* get_local_id(Module *module, Builder& builder, unsigned ax);
Value* get_block_id(Module *module, Builder& builder, unsigned ax);
Value* get_num_blocks(Module *module, Builder& builder, unsigned ax);
unsigned guaranteed_alignment() { return 1; }
};
}
}
#endif

View File

@@ -1,48 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H
#include <map>
#include <set>
#include <vector>
namespace triton {
namespace ir {
class module;
class value;
class io_inst;
class instruction;
class builder;
}
namespace codegen{
namespace analysis{
class align;
class layouts;
class cts;
}
namespace transform{
class coalesce {
private:
void extract_io_use(ir::value *v, std::set<ir::io_inst*>& result);
void extract_ld(ir::io_inst *i, std::map<int, std::vector<triton::ir::io_inst *> > &result);
ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map<ir::value*, ir::value*>& seen);
public:
coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts);
triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder);
void run(ir::module &mod);
private:
analysis::align* align_;
analysis::layouts* layout_;
};
}
}
}
#endif

View File

@@ -1,36 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
#define TDL_INCLUDE_CODEGEN_BUFFER_INFO_PASS_H
#include <set>
#include <map>
namespace triton {
namespace ir {
class module;
class value;
class phi_node;
class instruction;
class builder;
}
namespace codegen{
namespace transform{
class cts {
private:
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared);
public:
cts(bool use_async = false): use_async_(use_async) {}
void run(ir::module &mod);
private:
bool use_async_;
};
}
}
}
#endif

View File

@@ -1,24 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H
namespace triton {
namespace ir {
class module;
}
namespace codegen{
namespace transform{
class dce {
public:
dce() {}
void run(ir::module &mod);
};
}
}
}
#endif

View File

@@ -1,22 +0,0 @@
#ifndef _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
#define _TRITON_SELECTION_TRANSFORM_DISASSOCIATE_H_
namespace triton {
namespace ir {
class module;
}
namespace codegen{
namespace transform{
class disassociate {
public:
void run(ir::module &mod);
};
}
}
}
#endif

View File

@@ -1,72 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_BARRIERS_H
#define TDL_INCLUDE_CODEGEN_BARRIERS_H
#include <vector>
#include <map>
#include <list>
#include <set>
#include "triton/codegen/target.h"
namespace triton {
namespace ir {
class module;
class basic_block;
class instruction;
class masked_load_async_inst;
class value;
class builder;
}
namespace codegen{
namespace analysis{
class allocation;
class liveness;
class layouts;
class cts;
class shared_layout;
}
namespace transform{
class prefetch;
class membar {
private:
typedef std::pair<unsigned, unsigned> interval_t;
typedef std::set<ir::value*> val_set_t;
typedef std::vector<ir::value*> val_vec_t;
private:
bool intersect(const val_set_t &X, const val_set_t &Y);
bool check_safe_war(ir::instruction* i);
int group_of(triton::ir::value *i, std::vector<triton::ir::value *> &async_write);
bool intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout);
val_set_t intersect_with(const val_set_t& as, const val_set_t& bs);
void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read,
std::set<triton::ir::value *> &safe_war, bool &inserted, ir::builder &builder);
public:
membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc,
transform::prefetch *prefetch, target* tgt):
liveness_(liveness), layouts_(layouts), alloc_(alloc), prefetch_(prefetch), tgt_(tgt) {}
void run(ir::module &mod);
private:
analysis::liveness *liveness_;
analysis::layouts *layouts_;
analysis::allocation *alloc_;
transform::prefetch *prefetch_;
target* tgt_;
};
}
}
}
#endif

View File

@@ -1,53 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
#define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H
#include "triton/codegen/target.h"
namespace triton {
namespace ir {
class module;
class value;
class instruction;
class trans_inst;
class builder;
class constant_int;
class dot_inst;
}
namespace codegen{
namespace analysis{
class layouts;
}
namespace transform{
class peephole {
private:
// bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder);
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
bool rewrite_dot(ir::instruction *value, ir::builder& builder);
bool rewrite_mult(ir::instruction *value, ir::builder& builder);
bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
bool rewrite_cvt_layout(ir::instruction *value, ir::builder& builder);
public:
peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {}
void run(ir::module &mod);
private:
target* tgt_;
analysis::layouts* layouts_;
};
}
}
}
#endif

View File

@@ -1,30 +0,0 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H
// forward declaration
namespace triton {
namespace ir {
class module;
}
} // namespace triton
namespace triton {
namespace codegen {
namespace transform {
class pipeline {
public:
pipeline(bool has_copy_async, int num_stages)
: has_copy_async_(has_copy_async), num_stages_(num_stages) {}
void run(ir::module &module);
private:
bool has_copy_async_;
int num_stages_;
};
} // namespace transform
} // namespace codegen
} // namespace triton
#endif

View File

@@ -1,27 +0,0 @@
#ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
#define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H
#include <set>
// forward dclaration
namespace triton::ir{
class module;
class value;
}
namespace triton::codegen {
class target;
}
namespace triton::codegen::transform {
class prefetch {
target* tgt_;
std::set<ir::value*> prefetched_vals_;
public:
prefetch(target *tgt) : tgt_(tgt) {}
void run(ir::module &module);
bool is_prefetched(ir::value* v) { return prefetched_vals_.find(v) != prefetched_vals_.end(); }
};
}
#endif

View File

@@ -1,26 +0,0 @@
#ifndef TRITON_INCLUDE_IR_CODEGEN_REORDER_H
#define TRITON_INCLUDE_IR_CODEGEN_REORDER_H
namespace triton {
// forward declaration
namespace ir {
class module;
}
namespace codegen{
namespace transform{
class reorder {
public:
void run(ir::module& module);
};
}
}
}
#endif

View File

@@ -1,316 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_DISPATCH_H_
#define _TRITON_DRIVER_DISPATCH_H_
#include <type_traits>
#include <dlfcn.h>
//CUDA Backend
#include "triton/external/CUDA/cuda.h"
#include "triton/external/CUDA/nvml.h"
//// HIP backend
//#define __HIP_PLATFORM_AMD__
#include "triton/external/hip.h"
//Exceptions
#include <iostream>
#include <stdexcept>
namespace llvm {
class PassRegistry;
class Module;
}
namespace triton
{
namespace driver
{
class cu_context;
template<class T> void check(T){}
void check(CUresult err);
void check(hipError_t err);
class dispatch
{
protected:
template <class F>
struct return_type;
template <class R, class... A>
struct return_type<R (*)(A...)>
{ typedef R type; };
typedef bool (*f_init_t)();
template<f_init_t initializer, typename FunPtrT, typename... Args>
static typename return_type<FunPtrT>::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args)
{
initializer();
if(cache == nullptr){
cache = dlsym(lib_h, name);
if(cache == 0)
throw std::runtime_error("dlsym unable to load function");
}
FunPtrT fptr;
*reinterpret_cast<void **>(&fptr) = cache;
typename return_type<FunPtrT>::type res = (*fptr)(args...);
check(res);
return res;
}
public:
static void release();
// Nvidia
static bool nvmlinit();
static bool cuinit();
// AMD
static bool hipinit();
/* ------------------- *
* CUDA
* ------------------- */
// context management
static CUresult cuInit(unsigned int Flags);
static CUresult cuCtxDestroy_v2(CUcontext ctx);
static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
static CUresult cuCtxGetDevice(CUdevice* result);
static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
static CUresult cuDriverGetVersion(int *driverVersion);
// device management
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
static CUresult cuDeviceGetCount(int *count);
// link management
static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
static CUresult cuLinkDestroy(CUlinkState state);
// module management
static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
static CUresult cuModuleLoad(CUmodule *module, const char *fname);
static CUresult cuModuleLoadData(CUmodule* module, const void* image);
static CUresult cuModuleUnload(CUmodule hmod);
static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
// stream management
static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
static CUresult cuStreamSynchronize(CUstream hStream);
static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
static CUresult cuStreamDestroy_v2(CUstream hStream);
static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
// function management
static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config);
// memory management
static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
static CUresult cuMemFree_v2(CUdeviceptr dptr);
static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
// event management
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
static CUresult cuEventDestroy_v2(CUevent hEvent);
/* ------------------- *
* NVML
* ------------------- */
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
/* ------------------- *
* HIP
* ------------------- */
// context management
static hipError_t hipInit(unsigned int Flags);
static hipError_t hipCtxDestroy(hipCtx_t ctx);
static hipError_t hipCtxCreate(hipCtx_t *pctx, unsigned int flags, hipDevice_t dev);
static hipError_t hipCtxPushCurrent(hipCtx_t ctx);
static hipError_t hipCtxPopCurrent(hipCtx_t *pctx);
static hipError_t hipCtxGetDevice(hipDevice_t* result);
static hipError_t hipCtxEnablePeerAccess(hipCtx_t peerContext, unsigned int flags);
static hipError_t hipDriverGetVersion(int *driverVersion);
// device management
static hipError_t hipGetDevice(hipDevice_t *device, int ordinal);
static hipError_t hipDeviceGetName(char *name, int len, hipDevice_t dev);
static hipError_t hipDeviceGetPCIBusId(char *id, int len, hipDevice_t dev);
static hipError_t hipDeviceGetAttribute(int *pi, hipDeviceAttribute_t attrib, hipDevice_t dev);
static hipError_t hipGetDeviceCount(int *count);
// module management
static hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t* bytes, hipModule_t hmod, const char *name);
static hipError_t hipModuleLoad(hipModule_t *module, const char *fname);
static hipError_t hipModuleLoadData(hipModule_t* module, const void* image);
static hipError_t hipModuleUnload(hipModule_t hmod);
static hipError_t hipModuleLoadDataEx(hipModule_t *module, const void *image, unsigned int numOptions, hipJitOption *options, void **optionValues);
static hipError_t hipModuleGetFunction(hipFunction_t *hfunc, hipModule_t hmod, const char *name);
// stream management
static hipError_t hipStreamCreate(hipStream_t *phStream, unsigned int Flags);
static hipError_t hipStreamSynchronize(hipStream_t hStream);
static hipError_t hipStreamDestroy(hipStream_t hStream);
static hipError_t hipModuleLaunchKernel(hipFunction_t f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, hipStream_t hStream, void **kernelParams, void **extra);
// function management
static hipError_t hipFuncGetAttributes(hipFuncAttributes* attrib, void* hfunc);
static hipError_t hipFuncSetAttribute(hipFunction_t hfunc, hipFuncAttribute attrib, int value);
static hipError_t hipFuncSetCacheConfig(hipFunction_t hfunc, hipFuncCache_t config);
// memory management
static hipError_t hipMalloc(hipDeviceptr_t *dptr, size_t bytesize);
static hipError_t hipPointerGetAttribute(void * data, CUpointer_attribute attribute, hipDeviceptr_t ptr);
static hipError_t hipMemsetD8Async(hipDeviceptr_t dst, unsigned char x, size_t N, hipStream_t stream);
static hipError_t hipMemcpyDtoH(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount);
static hipError_t hipFree(hipDeviceptr_t dptr);
static hipError_t hipMemcpyDtoHAsync(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount, hipStream_t hStream);
static hipError_t hipMemcpyHtoDAsync(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount, hipStream_t hStream);
static hipError_t hipMemcpyHtoD(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount);
// event management
static hipError_t hipEventCreate(hipEvent_t *phEvent, unsigned int Flags);
static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart, hipEvent_t hEnd);
static hipError_t hipEventRecord(hipEvent_t hEvent, hipStream_t hStream);
static hipError_t hipEventDestroy(hipEvent_t hEvent);
private:
// Libraries
static void* cuda_;
static void* nvml_;
static void* hip_;
/* ------------------- *
* CUDA
* ------------------- */
// context management
static void* cuCtxGetCurrent_;
static void* cuCtxSetCurrent_;
static void* cuCtxDestroy_v2_;
static void* cuCtxCreate_v2_;
static void* cuCtxGetDevice_;
static void* cuCtxPushCurrent_v2_;
static void* cuCtxPopCurrent_v2_;
static void* cuCtxEnablePeerAccess_;
static void* cuDriverGetVersion_;
static void* cuInit_;
// device management
static void* cuDeviceGet_;
static void* cuDeviceGetName_;
static void* cuDeviceGetPCIBusId_;
static void* cuDeviceGetAttribute_;
static void* cuDeviceGetCount_;
// link management
static void* cuLinkAddData_v2_;
static void* cuLinkCreate_v2_;
static void* cuLinkDestroy_;
static void* cuLinkComplete_;
// module management
static void* cuModuleGetGlobal_v2_;
static void* cuModuleLoad_;
static void* cuModuleUnload_;
static void* cuModuleLoadDataEx_;
static void* cuModuleLoadData_;
static void* cuModuleGetFunction_;
// stream management
static void* cuStreamCreate_;
static void* cuStreamSynchronize_;
static void* cuStreamDestroy_v2_;
static void* cuStreamGetCtx_;
static void* cuLaunchKernel_;
// function management
static void* cuFuncGetAttribute_;
static void* cuFuncSetAttribute_;
static void* cuFuncSetCacheConfig_;
// memory management
static void* cuMemcpyDtoH_v2_;
static void* cuMemFree_v2_;
static void* cuMemcpyDtoHAsync_v2_;
static void* cuMemcpyHtoDAsync_v2_;
static void* cuMemcpyHtoD_v2_;
static void* cuMemAlloc_v2_;
static void* cuMemsetD8Async_;
static void* cuPointerGetAttribute_;
// event management
static void* cuEventCreate_;
static void* cuEventElapsedTime_;
static void* cuEventRecord_;
static void* cuEventDestroy_v2_;
/* ------------------- *
* NVML
* ------------------- */
static void* nvmlInit_v2_;
static void* nvmlDeviceGetHandleByPciBusId_v2_;
static void* nvmlDeviceGetClockInfo_;
static void* nvmlDeviceGetMaxClockInfo_;
static void* nvmlDeviceSetApplicationsClocks_;
/* ------------------- *
* HIP
* ------------------- */
// context management
static void* hipInit_;
static void* hipCtxDestroy_;
static void* hipCtxCreate_;
static void* hipCtxPushCurrent_;
static void* hipCtxPopCurrent_;
static void* hipCtxGetDevice_;
static void* hipCtxEnablePeerAccess_;
static void* hipDriverGetVersion_;
// device management
static void* hipGetDevice_;
static void* hipDeviceGetName_;
static void* hipDeviceGetPCIBusId_;
static void* hipDeviceGetAttribute_;
static void* hipGetDeviceCount_;
// module management
static void* hipModuleGetGlobal_;
static void* hipModuleLoad_;
static void* hipModuleLoadData_;
static void* hipModuleUnload_;
static void* hipModuleLoadDataEx_;
static void* hipModuleGetFunction_;
// stream management
static void* hipStreamCreate_;
static void* hipStreamSynchronize_;
static void* hipStreamDestroy_;
static void* hipModuleLaunchKernel_;;
// function management
static void* hipFuncGetAttributes_;
static void* hipFuncSetAttribute_;
static void* hipFuncSetCacheConfig_;
// memory management
static void* hipMalloc_;
static void* hipPointerGetAttribute_;
static void* hipMemsetD8Async_;
static void* hipMemcpyDtoH_;
static void* hipFree_;
static void* hipMemcpyDtoHAsync_;
static void* hipMemcpyHtoDAsync_;
static void* hipMemcpyHtoD_;
// event management
static void* hipEventCreate_;
static void* hipEventElapsedTime_;
static void* hipEventRecord_;
static void* hipEventDestroy_;
};
}
}
#endif

View File

@@ -1,220 +0,0 @@
#pragma once
#ifndef _TRITON_DRIVER_ERROR_H_
#define _TRITON_DRIVER_ERROR_H_
#include <exception>
#include "triton/driver/dispatch.h"
namespace triton
{
namespace driver
{
namespace exception
{
namespace nvrtc
{
#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "NVRTC: Error- " msg; } }
TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program");
TRITON_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option");
TRITON_CREATE_NVRTC_EXCEPTION(compilation ,"compilation");
TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure");
TRITON_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error");
#undef TRITON_CREATE_NVRTC_EXCEPTION
}
namespace cuda
{
class base: public std::exception{};
#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "CUDA: Error- " msg; } }
TRITON_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized");
TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled");
TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started");
TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
TRITON_CREATE_CUDA_EXCEPTION(no_device ,"no device");
TRITON_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device");
TRITON_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image");
TRITON_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context");
TRITON_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current");
TRITON_CREATE_CUDA_EXCEPTION(map_failed ,"map failed");
TRITON_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed");
TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped");
TRITON_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped");
TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
TRITON_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array");
TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit");
TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx");
TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
TRITON_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source");
TRITON_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found");
TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed");
TRITON_CREATE_CUDA_EXCEPTION(operating_system ,"operating system");
TRITON_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle");
TRITON_CREATE_CUDA_EXCEPTION(not_found ,"not found");
TRITON_CREATE_CUDA_EXCEPTION(not_ready ,"not ready");
TRITON_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address");
TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources");
TRITON_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout");
TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
TRITON_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active");
TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed");
TRITON_CREATE_CUDA_EXCEPTION(assert_error ,"assert");
TRITON_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers");
TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered");
TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error");
TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction");
TRITON_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address");
TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space");
TRITON_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc");
TRITON_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed");
TRITON_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted");
TRITON_CREATE_CUDA_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUDA_EXCEPTION(unknown ,"unknown");
#undef TRITON_CREATE_CUDA_EXCEPTION
}
namespace cublas
{
class base: public std::exception{};
#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "CUBLAS: Error- " msg; } }
TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed");
TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch");
TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error");
TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed");
TRITON_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error");
TRITON_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUBLAS_EXCEPTION(license_error ,"license error");
TRITON_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown");
#undef TRITON_CREATE_CUBLAS_EXCEPTION
}
namespace cudnn
{
#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "CUDNN: Error- " msg; } }
TRITON_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed");
TRITON_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param");
TRITON_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error");
TRITON_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch");
TRITON_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error");
TRITON_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed");
TRITON_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_CUDNN_EXCEPTION(license_error ,"license error");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress");
TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow");
}
namespace hip
{
class base: public std::exception{};
#define TRITON_CREATE_HIP_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "HIP: Error- " msg; } }
TRITON_CREATE_HIP_EXCEPTION(invalid_value ,"invalid value");
TRITON_CREATE_HIP_EXCEPTION(out_of_memory ,"out of memory");
TRITON_CREATE_HIP_EXCEPTION(not_initialized ,"not initialized");
TRITON_CREATE_HIP_EXCEPTION(deinitialized ,"deinitialized");
TRITON_CREATE_HIP_EXCEPTION(profiler_disabled ,"profiler disabled");
TRITON_CREATE_HIP_EXCEPTION(profiler_not_initialized ,"profiler not initialized");
TRITON_CREATE_HIP_EXCEPTION(profiler_already_started ,"profiler already started");
TRITON_CREATE_HIP_EXCEPTION(profiler_already_stopped ,"profiler already stopped");
TRITON_CREATE_HIP_EXCEPTION(no_device ,"no device");
TRITON_CREATE_HIP_EXCEPTION(invalid_device ,"invalid device");
TRITON_CREATE_HIP_EXCEPTION(invalid_image ,"invalid image");
TRITON_CREATE_HIP_EXCEPTION(invalid_context ,"invalid context");
TRITON_CREATE_HIP_EXCEPTION(context_already_current ,"context already current");
TRITON_CREATE_HIP_EXCEPTION(map_failed ,"map failed");
TRITON_CREATE_HIP_EXCEPTION(unmap_failed ,"unmap failed");
TRITON_CREATE_HIP_EXCEPTION(array_is_mapped ,"array is mapped");
TRITON_CREATE_HIP_EXCEPTION(already_mapped ,"already mapped");
TRITON_CREATE_HIP_EXCEPTION(no_binary_for_gpu ,"no binary for gpu");
TRITON_CREATE_HIP_EXCEPTION(already_acquired ,"already acquired");
TRITON_CREATE_HIP_EXCEPTION(not_mapped ,"not mapped");
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_array ,"not mapped as array");
TRITON_CREATE_HIP_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer");
TRITON_CREATE_HIP_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable");
TRITON_CREATE_HIP_EXCEPTION(unsupported_limit ,"unsupported limit");
TRITON_CREATE_HIP_EXCEPTION(context_already_in_use ,"context already in use");
TRITON_CREATE_HIP_EXCEPTION(peer_access_unsupported ,"peer access unsupported");
TRITON_CREATE_HIP_EXCEPTION(invalid_ptx ,"invalid ptx");
TRITON_CREATE_HIP_EXCEPTION(invalid_graphics_context ,"invalid graphics context");
TRITON_CREATE_HIP_EXCEPTION(invalid_source ,"invalid source");
TRITON_CREATE_HIP_EXCEPTION(file_not_found ,"file not found");
TRITON_CREATE_HIP_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found");
TRITON_CREATE_HIP_EXCEPTION(shared_object_init_failed ,"shared object init failed");
TRITON_CREATE_HIP_EXCEPTION(operating_system ,"operating system");
TRITON_CREATE_HIP_EXCEPTION(invalid_handle ,"invalid handle");
TRITON_CREATE_HIP_EXCEPTION(not_found ,"not found");
TRITON_CREATE_HIP_EXCEPTION(not_ready ,"not ready");
TRITON_CREATE_HIP_EXCEPTION(illegal_address ,"illegal address");
TRITON_CREATE_HIP_EXCEPTION(launch_out_of_resources ,"launch out of resources");
TRITON_CREATE_HIP_EXCEPTION(launch_timeout ,"launch timeout");
TRITON_CREATE_HIP_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing");
TRITON_CREATE_HIP_EXCEPTION(peer_access_already_enabled ,"peer access already enabled");
TRITON_CREATE_HIP_EXCEPTION(peer_access_not_enabled ,"peer access not enabled");
TRITON_CREATE_HIP_EXCEPTION(primary_context_active ,"primary context active");
TRITON_CREATE_HIP_EXCEPTION(context_is_destroyed ,"context is destroyed");
TRITON_CREATE_HIP_EXCEPTION(assert_error ,"assert");
TRITON_CREATE_HIP_EXCEPTION(too_many_peers ,"too many peers");
TRITON_CREATE_HIP_EXCEPTION(host_memory_already_registered ,"host memory already registered");
TRITON_CREATE_HIP_EXCEPTION(host_memory_not_registered ,"hot memory not registered");
TRITON_CREATE_HIP_EXCEPTION(hardware_stack_error ,"hardware stack error");
TRITON_CREATE_HIP_EXCEPTION(illegal_instruction ,"illegal instruction");
TRITON_CREATE_HIP_EXCEPTION(misaligned_address ,"misaligned address");
TRITON_CREATE_HIP_EXCEPTION(invalid_address_space ,"invalid address space");
TRITON_CREATE_HIP_EXCEPTION(invalid_pc ,"invalid pc");
TRITON_CREATE_HIP_EXCEPTION(launch_failed ,"launch failed");
TRITON_CREATE_HIP_EXCEPTION(not_permitted ,"not permitted");
TRITON_CREATE_HIP_EXCEPTION(not_supported ,"not supported");
TRITON_CREATE_HIP_EXCEPTION(invalid_symbol ,"invalid symbol");
TRITON_CREATE_HIP_EXCEPTION(unknown ,"unknown");
#undef TRITON_CREATE_CUDA_EXCEPTION
}
}
}
}
#endif

View File

@@ -1,19 +0,0 @@
#include <string>
#include "triton/driver/dispatch.h"
namespace llvm{
class Module;
}
namespace triton{
namespace driver{
void init_llvm();
std::string llir_to_ptx(llvm::Module* module, int cc, int version);
std::string ptx_to_cubin(const std::string& ptx, int cc);
CUmodule ptx_to_cumodule(const std::string& ptx, int cc);
std::string llir_to_amdgpu(llvm::Module* module, const std::string& proc);
hipModule_t amdgpu_to_hipmodule(const std::string& path);
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,288 +0,0 @@
/*
* @brief hipError_t
* @enum
* @ingroup Enumerations
*/
// Developer note - when updating these, update the hipErrorName and hipErrorString functions in
// NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path.
// Ignoring error-code return values from hip APIs is discouraged. On C++17,
// we can make that yield a warning
/*
* @brief hipError_t
* @enum
* @ingroup Enumerations
*/
// Developer note - when updating these, update the hipErrorName and hipErrorString functions in
// NVCC and HCC paths Also update the hipCUDAErrorTohipError function in NVCC path.
#include <cstddef>
typedef enum hipError_t {
hipSuccess = 0, ///< Successful completion.
hipErrorInvalidValue = 1, ///< One or more of the parameters passed to the API call is NULL
///< or not in an acceptable range.
hipErrorOutOfMemory = 2,
// Deprecated
hipErrorMemoryAllocation = 2, ///< Memory allocation error.
hipErrorNotInitialized = 3,
// Deprecated
hipErrorInitializationError = 3,
hipErrorDeinitialized = 4,
hipErrorProfilerDisabled = 5,
hipErrorProfilerNotInitialized = 6,
hipErrorProfilerAlreadyStarted = 7,
hipErrorProfilerAlreadyStopped = 8,
hipErrorInvalidConfiguration = 9,
hipErrorInvalidPitchValue = 12,
hipErrorInvalidSymbol = 13,
hipErrorInvalidDevicePointer = 17, ///< Invalid Device Pointer
hipErrorInvalidMemcpyDirection = 21, ///< Invalid memory copy direction
hipErrorInsufficientDriver = 35,
hipErrorMissingConfiguration = 52,
hipErrorPriorLaunchFailure = 53,
hipErrorInvalidDeviceFunction = 98,
hipErrorNoDevice = 100, ///< Call to hipGetDeviceCount returned 0 devices
hipErrorInvalidDevice = 101, ///< DeviceID must be in range 0...#compute-devices.
hipErrorInvalidImage = 200,
hipErrorInvalidContext = 201, ///< Produced when input context is invalid.
hipErrorContextAlreadyCurrent = 202,
hipErrorMapFailed = 205,
// Deprecated
hipErrorMapBufferObjectFailed = 205, ///< Produced when the IPC memory attach failed from ROCr.
hipErrorUnmapFailed = 206,
hipErrorArrayIsMapped = 207,
hipErrorAlreadyMapped = 208,
hipErrorNoBinaryForGpu = 209,
hipErrorAlreadyAcquired = 210,
hipErrorNotMapped = 211,
hipErrorNotMappedAsArray = 212,
hipErrorNotMappedAsPointer = 213,
hipErrorECCNotCorrectable = 214,
hipErrorUnsupportedLimit = 215,
hipErrorContextAlreadyInUse = 216,
hipErrorPeerAccessUnsupported = 217,
hipErrorInvalidKernelFile = 218, ///< In CUDA DRV, it is CUDA_ERROR_INVALID_PTX
hipErrorInvalidGraphicsContext = 219,
hipErrorInvalidSource = 300,
hipErrorFileNotFound = 301,
hipErrorSharedObjectSymbolNotFound = 302,
hipErrorSharedObjectInitFailed = 303,
hipErrorOperatingSystem = 304,
hipErrorInvalidHandle = 400,
// Deprecated
hipErrorInvalidResourceHandle = 400, ///< Resource handle (hipEvent_t or hipStream_t) invalid.
hipErrorNotFound = 500,
hipErrorNotReady = 600, ///< Indicates that asynchronous operations enqueued earlier are not
///< ready. This is not actually an error, but is used to distinguish
///< from hipSuccess (which indicates completion). APIs that return
///< this error include hipEventQuery and hipStreamQuery.
hipErrorIllegalAddress = 700,
hipErrorLaunchOutOfResources = 701, ///< Out of resources error.
hipErrorLaunchTimeOut = 702,
hipErrorPeerAccessAlreadyEnabled =
704, ///< Peer access was already enabled from the current device.
hipErrorPeerAccessNotEnabled =
705, ///< Peer access was never enabled from the current device.
hipErrorSetOnActiveProcess = 708,
hipErrorAssert = 710, ///< Produced when the kernel calls assert.
hipErrorHostMemoryAlreadyRegistered =
712, ///< Produced when trying to lock a page-locked memory.
hipErrorHostMemoryNotRegistered =
713, ///< Produced when trying to unlock a non-page-locked memory.
hipErrorLaunchFailure =
719, ///< An exception occurred on the device while executing a kernel.
hipErrorCooperativeLaunchTooLarge =
720, ///< This error indicates that the number of blocks launched per grid for a kernel
///< that was launched via cooperative launch APIs exceeds the maximum number of
///< allowed blocks for the current device
hipErrorNotSupported = 801, ///< Produced when the hip API is not supported/implemented
hipErrorUnknown = 999, //< Unknown error.
// HSA Runtime Error Codes start here.
hipErrorRuntimeMemory = 1052, ///< HSA runtime memory call returned error. Typically not seen
///< in production systems.
hipErrorRuntimeOther = 1053, ///< HSA runtime call other than memory returned error. Typically
///< not seen in production systems.
hipErrorTbd ///< Marker that more error codes are needed.
} hipError_t;
typedef struct ihipCtx_t* hipCtx_t;
// Note many APIs also use integer deviceIds as an alternative to the device pointer:
typedef int hipDevice_t;
typedef enum hipDeviceP2PAttr {
hipDevP2PAttrPerformanceRank = 0,
hipDevP2PAttrAccessSupported,
hipDevP2PAttrNativeAtomicSupported,
hipDevP2PAttrHipArrayAccessSupported
} hipDeviceP2PAttr;
typedef struct ihipStream_t* hipStream_t;
#define hipIpcMemLazyEnablePeerAccess 0
#define HIP_IPC_HANDLE_SIZE 64
typedef struct hipIpcMemHandle_st {
char reserved[HIP_IPC_HANDLE_SIZE];
} hipIpcMemHandle_t;
typedef struct hipIpcEventHandle_st {
char reserved[HIP_IPC_HANDLE_SIZE];
} hipIpcEventHandle_t;
typedef struct ihipModule_t* hipModule_t;
typedef struct ihipModuleSymbol_t* hipFunction_t;
typedef struct hipFuncAttributes {
int binaryVersion;
int cacheModeCA;
size_t constSizeBytes;
size_t localSizeBytes;
int maxDynamicSharedSizeBytes;
int maxThreadsPerBlock;
int numRegs;
int preferredShmemCarveout;
int ptxVersion;
size_t sharedSizeBytes;
} hipFuncAttributes;
typedef struct ihipEvent_t* hipEvent_t;
/*
* @brief hipDeviceAttribute_t
* @enum
* @ingroup Enumerations
*/
typedef enum hipDeviceAttribute_t {
hipDeviceAttributeMaxThreadsPerBlock, ///< Maximum number of threads per block.
hipDeviceAttributeMaxBlockDimX, ///< Maximum x-dimension of a block.
hipDeviceAttributeMaxBlockDimY, ///< Maximum y-dimension of a block.
hipDeviceAttributeMaxBlockDimZ, ///< Maximum z-dimension of a block.
hipDeviceAttributeMaxGridDimX, ///< Maximum x-dimension of a grid.
hipDeviceAttributeMaxGridDimY, ///< Maximum y-dimension of a grid.
hipDeviceAttributeMaxGridDimZ, ///< Maximum z-dimension of a grid.
hipDeviceAttributeMaxSharedMemoryPerBlock, ///< Maximum shared memory available per block in
///< bytes.
hipDeviceAttributeTotalConstantMemory, ///< Constant memory size in bytes.
hipDeviceAttributeWarpSize, ///< Warp size in threads.
hipDeviceAttributeMaxRegistersPerBlock, ///< Maximum number of 32-bit registers available to a
///< thread block. This number is shared by all thread
///< blocks simultaneously resident on a
///< multiprocessor.
hipDeviceAttributeClockRate, ///< Peak clock frequency in kilohertz.
hipDeviceAttributeMemoryClockRate, ///< Peak memory clock frequency in kilohertz.
hipDeviceAttributeMemoryBusWidth, ///< Global memory bus width in bits.
hipDeviceAttributeMultiprocessorCount, ///< Number of multiprocessors on the device.
hipDeviceAttributeComputeMode, ///< Compute mode that device is currently in.
hipDeviceAttributeL2CacheSize, ///< Size of L2 cache in bytes. 0 if the device doesn't have L2
///< cache.
hipDeviceAttributeMaxThreadsPerMultiProcessor, ///< Maximum resident threads per
///< multiprocessor.
hipDeviceAttributeComputeCapabilityMajor, ///< Major compute capability version number.
hipDeviceAttributeComputeCapabilityMinor, ///< Minor compute capability version number.
hipDeviceAttributeConcurrentKernels, ///< Device can possibly execute multiple kernels
///< concurrently.
hipDeviceAttributePciBusId, ///< PCI Bus ID.
hipDeviceAttributePciDeviceId, ///< PCI Device ID.
hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, ///< Maximum Shared Memory Per
///< Multiprocessor.
hipDeviceAttributeIsMultiGpuBoard, ///< Multiple GPU devices.
hipDeviceAttributeIntegrated, ///< iGPU
hipDeviceAttributeCooperativeLaunch, ///< Support cooperative launch
hipDeviceAttributeCooperativeMultiDeviceLaunch, ///< Support cooperative launch on multiple devices
hipDeviceAttributeMaxTexture1DWidth, ///< Maximum number of elements in 1D images
hipDeviceAttributeMaxTexture2DWidth, ///< Maximum dimension width of 2D images in image elements
hipDeviceAttributeMaxTexture2DHeight, ///< Maximum dimension height of 2D images in image elements
hipDeviceAttributeMaxTexture3DWidth, ///< Maximum dimension width of 3D images in image elements
hipDeviceAttributeMaxTexture3DHeight, ///< Maximum dimensions height of 3D images in image elements
hipDeviceAttributeMaxTexture3DDepth, ///< Maximum dimensions depth of 3D images in image elements
hipDeviceAttributeHdpMemFlushCntl, ///< Address of the HDP_MEM_COHERENCY_FLUSH_CNTL register
hipDeviceAttributeHdpRegFlushCntl, ///< Address of the HDP_REG_COHERENCY_FLUSH_CNTL register
hipDeviceAttributeMaxPitch, ///< Maximum pitch in bytes allowed by memory copies
hipDeviceAttributeTextureAlignment, ///<Alignment requirement for textures
hipDeviceAttributeTexturePitchAlignment, ///<Pitch alignment requirement for 2D texture references bound to pitched memory;
hipDeviceAttributeKernelExecTimeout, ///<Run time limit for kernels executed on the device
hipDeviceAttributeCanMapHostMemory, ///<Device can map host memory into device address space
hipDeviceAttributeEccEnabled, ///<Device has ECC support enabled
hipDeviceAttributeCooperativeMultiDeviceUnmatchedFunc, ///< Supports cooperative launch on multiple
///devices with unmatched functions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedGridDim, ///< Supports cooperative launch on multiple
///devices with unmatched grid dimensions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedBlockDim, ///< Supports cooperative launch on multiple
///devices with unmatched block dimensions
hipDeviceAttributeCooperativeMultiDeviceUnmatchedSharedMem, ///< Supports cooperative launch on multiple
///devices with unmatched shared memories
hipDeviceAttributeAsicRevision, ///< Revision of the GPU in this device
hipDeviceAttributeManagedMemory, ///< Device supports allocating managed memory on this system
hipDeviceAttributeDirectManagedMemAccessFromHost, ///< Host can directly access managed memory on
/// the device without migration
hipDeviceAttributeConcurrentManagedAccess, ///< Device can coherently access managed memory
/// concurrently with the CPU
hipDeviceAttributePageableMemoryAccess, ///< Device supports coherently accessing pageable memory
/// without calling hipHostRegister on it
hipDeviceAttributePageableMemoryAccessUsesHostPageTables, ///< Device accesses pageable memory via
/// the host's page tables
hipDeviceAttributeCanUseStreamWaitValue ///< '1' if Device supports hipStreamWaitValue32() and
///< hipStreamWaitValue64() , '0' otherwise.
} hipDeviceAttribute_t;
typedef void* hipDeviceptr_t;
/*
* @brief hipJitOption
* @enum
* @ingroup Enumerations
*/
typedef enum hipJitOption {
hipJitOptionMaxRegisters = 0,
hipJitOptionThreadsPerBlock,
hipJitOptionWallTime,
hipJitOptionInfoLogBuffer,
hipJitOptionInfoLogBufferSizeBytes,
hipJitOptionErrorLogBuffer,
hipJitOptionErrorLogBufferSizeBytes,
hipJitOptionOptimizationLevel,
hipJitOptionTargetFromContext,
hipJitOptionTarget,
hipJitOptionFallbackStrategy,
hipJitOptionGenerateDebugInfo,
hipJitOptionLogVerbose,
hipJitOptionGenerateLineInfo,
hipJitOptionCacheMode,
hipJitOptionSm3xOpt,
hipJitOptionFastCompile,
hipJitOptionNumOptions
} hipJitOption;
/**
* @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
*/
typedef enum hipFuncAttribute {
hipFuncAttributeMaxDynamicSharedMemorySize = 8,
hipFuncAttributePreferredSharedMemoryCarveout = 9,
hipFuncAttributeMax
} hipFuncAttribute;
/**
* @warning On AMD devices and some Nvidia devices, these hints and controls are ignored.
*/
typedef enum hipFuncCache_t {
hipFuncCachePreferNone, ///< no preference for shared memory or L1 (default)
hipFuncCachePreferShared, ///< prefer larger shared memory and smaller L1 cache
hipFuncCachePreferL1, ///< prefer larger L1 cache and smaller shared memory
hipFuncCachePreferEqual, ///< prefer equal size L1 cache and shared memory
} hipFuncCache_t;
#define HIP_LAUNCH_PARAM_BUFFER_POINTER ((void*)0x01)
#define HIP_LAUNCH_PARAM_BUFFER_SIZE ((void*)0x02)
#define HIP_LAUNCH_PARAM_END ((void*)0x03)

View File

@@ -1,88 +0,0 @@
#pragma once
#ifndef _TRITON_IR_BASIC_BLOCK_H_
#define _TRITON_IR_BASIC_BLOCK_H_
#include <string>
#include <list>
#include "value.h"
#include "visitor.h"
namespace triton{
namespace ir{
class context;
class function;
class instruction;
/* Basic Block */
class basic_block: public value{
public:
// instruction iterator types
typedef std::list<instruction*> inst_list_t;
typedef inst_list_t::iterator iterator;
typedef inst_list_t::const_iterator const_iterator;
typedef inst_list_t::reverse_iterator reverse_iterator;
typedef inst_list_t::const_reverse_iterator const_reverse_iterator;
private:
// constructors
basic_block(context &ctx, const std::string &name, function *parent);
public:
// accessors
function* get_parent() { return parent_; }
context& get_context() { return ctx_; }
// get iterator to first instruction that is not a phi
iterator get_first_non_phi();
// get instruction list
inst_list_t &get_inst_list() { return inst_list_; }
const inst_list_t &get_inst_list() const { return inst_list_; }
void erase(instruction *i) { inst_list_.remove(i); }
// instruction iterator functions
inline iterator begin() { return inst_list_.begin(); }
inline const_iterator begin() const { return inst_list_.begin(); }
inline iterator end () { return inst_list_.end(); }
inline const_iterator end () const { return inst_list_.end(); }
inline reverse_iterator rbegin() { return inst_list_.rbegin(); }
inline const_reverse_iterator rbegin() const { return inst_list_.rbegin(); }
inline reverse_iterator rend () { return inst_list_.rend(); }
inline const_reverse_iterator rend () const { return inst_list_.rend(); }
inline size_t size() const { return inst_list_.size(); }
inline bool empty() const { return inst_list_.empty(); }
inline const instruction &front() const { return *inst_list_.front(); }
inline instruction &front() { return *inst_list_.front(); }
inline const instruction &back() const { return *inst_list_.back(); }
inline instruction &back() { return *inst_list_.back(); }
// predecessors
const std::vector<basic_block*>& get_predecessors() const { return preds_; }
const std::vector<basic_block*>& get_successors() const { return succs_; }
void add_predecessor(basic_block* pred);
// factory functions
static basic_block* create(context &ctx, const std::string &name, function *parent);
void print(std::ostream &os);
// visitor
void accept(visitor *v) { v->visit_basic_block(this); }
private:
context &ctx_;
std::string name_;
function *parent_;
std::vector<basic_block*> preds_;
std::vector<basic_block*> succs_;
inst_list_t inst_list_;
};
}
}
#endif

View File

@@ -1,176 +0,0 @@
#pragma once
#ifndef _TRITON_IR_BUILDER_H_
#define _TRITON_IR_BUILDER_H_
#include <vector>
#include <string>
#include "instructions.h"
#include "basic_block.h"
#include "type.h"
namespace triton{
namespace ir{
class basic_block;
class value;
class type;
class constant_int;
class instruction;
class context;
class phi_node;
/* Builder */
class builder{
typedef basic_block::iterator iterator;
public:
// Constructor
builder(context &ctx);
// Getters
const context& get_context() { return ctx_; }
// Setters
void set_insert_point(iterator instr);
void set_insert_point(instruction* i);
void set_insert_point_after(instruction* i);
void set_insert_point(basic_block* block);
basic_block* get_insert_block() { return block_; }
iterator get_insert_point() { return insert_point_;}
// Constants
value *get_int1(bool val);
value *get_int32(int32_t val);
value *get_int64(int64_t val);
value *get_float16(float val);
value *get_float32(float val);
value *get_range(int32_t lo, int32_t hi);
// Types
type *get_void_ty();
type *get_int1_ty();
type *get_int8_ty();
type *get_int16_ty();
type *get_int32_ty();
type *get_int64_ty();
type *get_half_ty();
type *get_float_ty();
type *get_double_ty();
// Insert
template<typename InstTy>
InstTy* insert(InstTy *inst){
assert(block_);
block_->get_inst_list().insert(insert_point_, inst);
inst->set_parent(block_);
// for(ir::value* op: inst->ops())
// op->add_use(inst);
return inst;
}
// terminator instructions
value* create_br(basic_block *dest);
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
value* create_ret_void();
// Cast instructions
value *create_cast(cast_op_t op, value *v, type *dst_ty);
value* create_ptr_to_int(value *src, type *dst_ty);
value* create_si_to_fp(value *src, type *dst_ty);
value* create_ui_to_fp(value *src, type *dst_ty);
value* create_fp_to_si(value *src, type *dst_ty);
value* create_fp_to_ui(value *src, type *dst_ty);
value* create_fp_ext(value *src, type *dst_ty);
value* create_fp_trunc(value *src, type *dst_ty);
value* create_int_cast(value *src, type *dst_ty, bool is_signed);
value *create_downcast(value *arg);
// Phi instruction
phi_node* create_phi(type *ty, unsigned num_reserved);
// Binary instructions
value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, bool has_nuw, bool has_nsw);
value *create_fmul(value *lhs, value *rhs);
value *create_fdiv(value *lhs, value *rhs);
value *create_frem(value *lhs, value *rhs);
value *create_fadd(value *lhs, value *rhs);
value *create_fsub(value *lhs, value *rhs);
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sdiv(value *lhs, value *rhs);
value *create_udiv(value *lhs, value *rhs);
value *create_srem(value *lhs, value *rhs);
value *create_urem(value *lhs, value *rhs);
value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_lshr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_ashr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
// GEP
value *create_gep(value *ptr, const std::vector<value*>& idx_list);
// Comparison (int)
value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs);
value *create_icmpSLE(value *lhs, value *rhs);
value *create_icmpSLT(value *lhs, value *rhs);
value *create_icmpSGE(value *lhs, value *rhs);
value *create_icmpSGT(value *lhs, value *rhs);
value *create_icmpULE(value *lhs, value *rhs);
value *create_icmpULT(value *lhs, value *rhs);
value *create_icmpUGE(value *lhs, value *rhs);
value *create_icmpUGT(value *lhs, value *rhs);
value *create_icmpEQ(value *lhs, value *rhs);
value *create_icmpNE(value *lhs, value *rhs);
// Comparison (float)
value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs);
value *create_fcmpOLT(value *lhs, value *rhs);
value *create_fcmpOGT(value *lhs, value *rhs);
value *create_fcmpOLE(value *lhs, value *rhs);
value *create_fcmpOGE(value *lhs, value *rhs);
value *create_fcmpOEQ(value *lhs, value *rhs);
value *create_fcmpONE(value *lhs, value *rhs);
value *create_fcmpULT(value *lhs, value *rhs);
value *create_fcmpUGT(value *lhs, value *rhs);
value *create_fcmpULE(value *lhs, value *rhs);
value *create_fcmpUGE(value *lhs, value *rhs);
value *create_fcmpUEQ(value *lhs, value *rhs);
value *create_fcmpUNE(value *lhs, value *rhs);
// Logical
value *create_and(value *lhs, value *rhs);
value *create_xor(value *lhs, value *rhs);
value *create_or(value *lhs, value *rhs);
// Input/Output
value *create_load(value *arg, load_inst::CACHE_MODIFIER cache);
value *create_store(value *ptr, value *val);
value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
value *create_masked_store(value *ptr, value *val, value *mask);
// Block instruction
value *create_splat(value *arg, const type::block_shapes_t &shapes);
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
value *create_cat(value *lhs, value *rhs);
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
// Built-in instruction
value *create_get_program_id(unsigned axis);
value *create_get_num_programs(unsigned axis);
value *create_atomic_cas(value *ptr, value *cmp, value *val);
value *create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk);
value *create_exp(value* arg);
value *create_cos(value* arg);
value *create_sin(value* arg);
value *create_log(value* arg);
value *create_dot(value *A, value *B, value *C);
value *create_trans(value *A, const std::vector<int> &perm = {});
value *create_sqrt(value *A);
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
value *create_select(value *pred, value *if_value, value *else_value);
// Intrinsics
// These have no place in the IR, and hopefully they can be removed at some point
value *create_umulhi(value* lhs, value* rhs);
value *create_copy_to_shared(value *arg);
value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache);
value *create_copy_from_shared(value *arg);
value *create_barrier(const std::string &name = "");
value *create_async_wait(int N);
value *create_prefetch_s(value *arg, int inc);
private:
context &ctx_;
basic_block *block_;
iterator insert_point_;
};
}
}
#endif

View File

@@ -1,113 +0,0 @@
#pragma once
#ifndef _TRITON_IR_CONSTANT_H_
#define _TRITON_IR_CONSTANT_H_
#include "enums.h"
#include "value.h"
#include <cassert>
#include "visitor.h"
namespace triton{
namespace ir{
class type;
class context;
/* Constant */
class constant: public user{
protected:
using user::user;
public:
static constant* get_all_ones_value(type *ty);
static constant* get_null_value(type *ty);
virtual std::string repr() const = 0;
};
/* Undef value */
class undef_value: public constant{
private:
undef_value(type *ty);
public:
static undef_value* get(type* ty);
std::string repr() const { return "undef"; }
void accept(visitor* vst) { vst->visit_undef_value(this); }
};
/* Constant int */
class constant_int: public constant{
protected:
constant_int(type *ty, uint64_t value);
public:
virtual uint64_t get_value() const { return value_; }
static constant_int *get(type *ty, uint64_t value);
std::string repr() const { return std::to_string(value_); }
void accept(visitor* vst) { vst->visit_constant_int(this); }
protected:
uint64_t value_;
};
/* Constant fp */
class constant_fp: public constant{
constant_fp(type *ty, double value);
public:
double get_value() { return value_; }
static constant* get_negative_zero(type *ty);
static constant* get_zero_value_for_negation(type *ty);
static constant* get(context &ctx, double v);
static constant* get(type *ty, double v);
std::string repr() const { return std::to_string(value_); }
void accept(visitor* vst) { vst->visit_constant_fp(this); }
private:
double value_;
};
/* Global Value */
class global_value: public constant {
public:
enum linkage_types_t {
external
};
public:
global_value(type *ty, unsigned num_ops,
linkage_types_t linkage, const std::string &name,
unsigned addr_space);
std::string repr() const { return get_name(); }
private:
linkage_types_t linkage_;
};
/* global object */
class global_object: public global_value {
public:
global_object(type *ty, unsigned num_ops,
linkage_types_t linkage, const std::string &name,
unsigned addr_space = 0);
std::string repr() const { return get_name(); }
};
/* global variable */
class alloc_const: public global_object {
public:
alloc_const(type *ty, constant_int *size,
const std::string &name = "");
std::string repr() const { return get_name(); }
void accept(visitor* vst) { vst->visit_alloc_const(this); }
};
}
}
#endif

View File

@@ -1,31 +0,0 @@
#pragma once
#ifndef _TRITON_IR_CONTEXT_H_
#define _TRITON_IR_CONTEXT_H_
#include <memory>
#include "triton/ir/type.h"
namespace triton{
namespace ir{
class builder;
class type;
class context_impl;
/* Context */
class context {
public:
context();
context(const context&) = delete;
context& operator=(const context&) = delete;
public:
ir::builder* builder = nullptr;
std::shared_ptr<context_impl> p_impl;
};
}
}
#endif

View File

@@ -1,48 +0,0 @@
#pragma once
#ifndef _TRITON_IR_CONTEXT_IMPL_H_
#define _TRITON_IR_CONTEXT_IMPL_H_
#include <map>
#include "triton/ir/type.h"
namespace triton{
namespace ir{
class context;
class constant;
class constant_int;
class constant_fp;
class undef_value;
/* Context impl */
class context_impl {
public:
// constructors
context_impl(context &ctx);
public:
// non-numeric types
type void_ty, label_ty;
// floating point types
type fp8_ty, fp16_ty, bf16_ty, fp32_ty, fp64_ty;
// integer types
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
// Pointer types
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
// Block types
std::map<std::pair<type*, type::block_shapes_t>, block_type*> block_tys;
// Int constants
std::map<std::pair<type*, uint64_t>, constant_int*> int_constants_;
// Float constants
std::map<std::pair<type*, double>, constant_fp*> fp_constants_;
// undef values
std::map<type*, undef_value*> uv_constants_;
};
}
}
#endif

View File

@@ -1,110 +0,0 @@
#pragma once
#ifndef _TRITON_IR_DISPATCH_H_
#define _TRITON_IR_DISPATCH_H_
#include "triton/ir/builder.h"
#include <stdexcept>
namespace triton{
namespace ir{
/*----------------------------------------------
higher level functions that follow the likely
semantics of most expected frontends
----------------------------------------------*/
struct semantic_error: public std::runtime_error {
semantic_error(const std::string& msg):
std::runtime_error(msg) { }
};
struct dispatch{
typedef ir::type::block_shapes_t shape_t;
// programming model
static ir::value *program_id(int axis, ir::builder *builder);
static ir::value *num_programs(int axis, ir::builder *builder);
// binary operators
static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder);
// unary operators
static ir::value *plus(ir::value *input, ir::builder *builder);
static ir::value *minus(ir::value *input, ir::builder *builder);
static ir::value *invert(ir::value *input, ir::builder *builder);
// comparison operators
static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder);
// block creation
static ir::value* arange(int start, int end, ir::builder *builder);
static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder);
// casting ops
static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder);
static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder);
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
// memory operators
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache, ir::builder *builder);
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
static ir::value *atomic_add(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_max(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_min(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_and(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_or(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_xor(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
// linear algebra
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
// indexing
static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);
// reduction
static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
// math
static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder);
static ir::value *exp(ir::value *x, ir::builder *builder);
static ir::value *log(ir::value *x, ir::builder *builder);
static ir::value *cos(ir::value *x, ir::builder *builder);
static ir::value *sin(ir::value *x, ir::builder *builder);
static ir::value *sqrt(ir::value *x, ir::builder *builder);
// internal (debug/optimization)
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
static ir::value *max_contiguous(ir::value *x, int value, ir::builder *builder);
static ir::value *debug_barrier(ir::builder *builder);
};
}
}
#endif

View File

@@ -1,175 +0,0 @@
#pragma once
#ifndef _TRITON_IR_ENUMS_H_
#define _TRITON_IR_ENUMS_H_
namespace triton{
namespace ir{
enum binary_op_t: unsigned int{
Add,
FAdd,
Sub,
FSub,
Mul,
FMul,
UDiv,
SDiv,
FDiv,
URem,
SRem,
FRem,
Shl,
LShr,
AShr,
And,
Or,
Xor
};
enum class atomic_rmw_op_t: unsigned int{
And,
Or,
Xor,
Add,
Max,
Min,
UMax,
UMin,
FAdd,
Xchg,
};
enum cast_op_t: unsigned int {
Trunc,
ZExt,
SExt,
FPTrunc,
FPExt,
UIToFP,
SIToFP,
FPToUI,
FPToSI,
PtrToInt,
IntToPtr,
BitCast,
AddrSpaceCast
};
enum cmp_pred_t: unsigned int {
FIRST_FCMP_PREDICATE,
FCMP_FALSE,
FCMP_OEQ,
FCMP_OGT,
FCMP_OGE,
FCMP_OLT,
FCMP_OLE,
FCMP_ONE,
FCMP_ORD,
FCMP_UNO,
FCMP_UEQ,
FCMP_UGT,
FCMP_UGE,
FCMP_ULT,
FCMP_ULE,
FCMP_UNE,
FCMP_TRUE,
LAST_FCMP_PREDICATE,
FIRST_ICMP_PREDICATE,
ICMP_EQ,
ICMP_NE,
ICMP_UGT,
ICMP_UGE,
ICMP_ULT,
ICMP_ULE,
ICMP_SGT,
ICMP_SGE,
ICMP_SLT,
ICMP_SLE,
LAST_ICMP_PREDICATE
};
enum value_id_t: unsigned {
/* ------------ *
INSTRUCTIONS
* ------------ */
INST_BEGIN,
// phi
INST_PHI,
// arithmetic
INST_BINOP,
INST_GETELEMENTPTR,
INST_SELECT,
INST_SQRT,
// cmp
INST_ICMP,
INST_FCMP,
// cast
INST_CAST_TRUNC,
INST_CAST_ZEXT,
INST_CAST_SEXT,
INST_CAST_FP_TRUNC,
INST_CAST_FP_EXT,
INST_CAST_UI_TO_FP,
INST_CAST_SI_TO_FP,
INST_CAST_FP_TO_UI,
INST_CAST_FP_TO_SI,
INST_CAST_PTR_TO_INT,
INST_CAST_INT_TO_PTR,
INST_CAST_BIT_CAST,
INST_CAST_ADDR_SPACE_CAST,
// terminators
INST_RETURN,
INST_COND_BRANCH,
INST_UNCOND_BRANCH,
// io
INST_UNMASKED_LOAD,
INST_MASKED_LOAD,
INST_MASKED_LOAD_ASYNC,
INST_UNMASKED_STORE,
INST_MASKED_STORE,
// retile
INST_RESHAPE,
INST_SPLAT,
INST_CAT,
INST_BROADCAST,
INST_DOWNCAST,
// builtin
INST_GET_PROGRAM_ID,
INST_GET_NUM_PROGRAMS,
// atomics
INST_ATOMIC_CAS,
INST_ATOMIC_EXCH,
INST_ATOMIC_RMW,
// math
INST_UMULHI,
INST_EXP,
INST_COS,
INST_SIN,
INST_LOG,
// array arithmetic
INST_TRANS,
INST_REDUCE,
INST_DOT,
// intrinsics
INST_COPY_TO_SHARED,
INST_COPY_FROM_SHARED,
INST_CVT_LAYOUT,
INST_CVT_SCANLINE,
INST_DECOALESCE,
INST_RECOALESCE,
INST_BARRIER,
INST_ASYNC_WAIT,
INST_MAKE_RANGE_DYN,
INST_MAKE_RANGE_STA,
INST_MAKE_RANGE,
INST_PREFETCH_S,
};
}
}
#endif

View File

@@ -1,142 +0,0 @@
#pragma once
#ifndef _TRITON_IR_FUNCTION_H_
#define _TRITON_IR_FUNCTION_H_
#include <string>
#include <map>
#include "value.h"
#include "constant.h"
namespace triton{
namespace ir{
class function;
class function_type;
class module;
class basic_block;
/* Argument */
class argument: public value{
argument(type *ty, const std::string &name, function *parent, unsigned arg_no);
public:
static argument* create(type *ty, const std::string &name,
function *parent = nullptr, unsigned arg_no = 0);
function* get_parent() const;
unsigned get_arg_no() const;
void accept(visitor *v);
private:
function *parent_;
unsigned arg_no_;
};
/* Attribute */
enum attribute_kind_t {
readonly = 0,
writeonly,
noalias,
aligned,
multiple_of,
retune,
not_implemented
};
class attribute {
public:
attribute(attribute_kind_t kind, unsigned value = 0):
kind_(kind), value_(value){}
bool operator<(const attribute& other) const {
return std::make_pair(kind_, value_) < std::make_pair(other.kind_, other.value_);
}
attribute_kind_t get_kind() const {
return kind_;
}
unsigned get_value() const {
return value_;
}
bool is_llvm_attr() const {
return kind_ != multiple_of;
}
std::string repr() const {
switch(kind_){
case readonly: return ".readonly";
case writeonly: return ".writeonly";
case noalias: return ".noalias";
case aligned: return ".aligned(" + std::to_string(value_) + ")";
case multiple_of: return ".multipleof(" + std::to_string(value_) + ")";
case retune: return ".retunr";
default: break;
}
assert(false);
return "";
}
private:
attribute_kind_t kind_;
unsigned value_;
};
/* Function */
class function: public global_object{
typedef std::vector<argument*> args_t;
typedef args_t::iterator arg_iterator;
typedef args_t::const_iterator const_arg_iterator;
typedef std::vector<basic_block*> blocks_t;
typedef blocks_t::iterator block_iterator;
typedef blocks_t::const_iterator const_block_iterator;
typedef std::map<unsigned, std::set<attribute>> attr_map_t;
private:
function(function_type *ty, linkage_types_t linkage,
const std::string &name = "", module *parent = nullptr);
public:
// accessors
const args_t &args() const { return args_; }
function_type* get_fn_type() { return fn_ty_; }
const function_type* get_fn_type() const { return fn_ty_; }
module *get_parent() { return parent_; }
const module *get_parent() const { return parent_; }
// factory methods
static function *create(function_type *ty, linkage_types_t linkage,
const std::string &name, module *mod);
// blocks
const blocks_t &blocks() { return blocks_; }
const blocks_t &blocks() const { return blocks_; }
void insert_block(basic_block* block, basic_block *next = nullptr);
// attributes
void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
const attr_map_t &attrs() { return attrs_; }
bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); }
std::set<attribute> get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
void print(std::ostream &os);
// visitor
void accept(visitor *v) { v->visit_function(this); }
private:
module *parent_;
bool init_;
function_type *fn_ty_;
args_t args_;
blocks_t blocks_;
attr_map_t attrs_;
};
}
}
#endif

View File

@@ -1,935 +0,0 @@
#pragma once
#ifndef _TRITON_IR_INSTRUCTIONS_H_
#define _TRITON_IR_INSTRUCTIONS_H_
#include <vector>
#include <map>
#include "triton/ir/enums.h"
#include "triton/ir/constant.h"
#include "triton/ir/value.h"
#include "triton/ir/type.h"
#include "triton/ir/metadata.h"
#include "triton/ir/visitor.h"
#define _TRITON_DEFINE_CLONE(name) \
ir::instruction* clone_impl() const { return new name(*this); }
#define _TRITON_DEFINE_ACCEPT(name) \
void accept(visitor* v) { v->visit_ ## name (this); }
namespace triton{
namespace ir{
class constant_int;
class constant;
class make_range;
class basic_block;
class context;
class visitor;
//===----------------------------------------------------------------------===//
// instruction classes
//===----------------------------------------------------------------------===//
class result_reference;
class instruction: public user{
public:
virtual std::string repr_impl() const = 0;
private:
virtual ir::instruction* clone_impl() const = 0;
protected:
// constructors
instruction(type *ty, value_id_t ity, unsigned num_ops,
const std::string &name = "", instruction *next = nullptr);
public:
// parent
void set_parent(basic_block *block) { parent_ = block; }
const basic_block *get_parent() const { return parent_; }
basic_block *get_parent() { return parent_; }
void erase_from_parent();
// helpers
bool has_tile_result_or_op();
// repr
std::string repr() const { return repr_impl(); }
// metadata
void set_metadata(ir::metadata::kind_t kind,
unsigned value) { metadatas_[kind] = value;}
unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
// cloning
ir::instruction* clone() {
ir::instruction* res = clone_impl();
// for(auto it = op_begin(); it != op_end(); it++)
// (*it)->add_use(res);
res->parent_ = nullptr;
res->users_.clear();
return res;
}
// instruction id
value_id_t get_id() const { return id_; }
void print(std::ostream &os);
private:
basic_block *parent_;
std::map<ir::metadata::kind_t, unsigned> metadatas_;
value_id_t id_;
};
//===----------------------------------------------------------------------===//
// phi_node classes
//===----------------------------------------------------------------------===//
class phi_node: public instruction {
private:
phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next);
std::string repr_impl() const { return "phi"; }
public:
void set_incoming_value(unsigned i, value *v);
void set_incoming_block(unsigned i, basic_block *block);
value *get_value_for_block(basic_block *block);
value *get_incoming_value(unsigned i) { return get_operand(i); }
basic_block *get_incoming_block(unsigned i) { return blocks_[i]; }
unsigned get_num_incoming() { return get_num_operands(); }
void add_incoming(value *v, basic_block *block);
// Type
void set_type(type *ty) { ty_ = ty; }
// Factory methods
static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(phi_node)
_TRITON_DEFINE_ACCEPT(phi_node)
private:
unsigned num_reserved_;
std::vector<basic_block*> blocks_;
};
//===----------------------------------------------------------------------===//
// binary_operator classes
//===----------------------------------------------------------------------===//
class binary_operator: public instruction {
public:
typedef binary_op_t op_t;
private:
std::string repr_impl() const;
protected:
// Constructors
binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next);
public:
// Get operand
binary_op_t get_op() const { return op_; }
// Bool
bool is_terminator() const;
bool is_binary_op() const;
bool is_int_div_rem() const;
bool is_shift() const;
bool is_cast() const;
bool is_int_mult() const;
bool is_int_add_sub() const;
bool is_int_div() const;
bool is_int_rem() const;
bool is_shl() const;
bool is_shr() const;
// Wraps
void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
// Factory methods
static binary_operator *create(binary_op_t op, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr);
// static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr);
// static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr);
// static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(binary_operator)
_TRITON_DEFINE_ACCEPT(binary_operator)
public:
binary_op_t op_;
bool has_no_unsigned_wrap_;
bool has_no_signed_wrap_;
};
//===----------------------------------------------------------------------===//
// cmp_inst classes
//===----------------------------------------------------------------------===//
class cmp_inst: public instruction{
public:
typedef cmp_pred_t pred_t;
private:
std::string repr_impl() const;
protected:
cmp_inst(type *ty, value_id_t id, cmp_pred_t pred,
value *lhs, value *rhs, const std::string &name, instruction *next);
static bool is_fp_predicate(cmp_pred_t pred);
static bool is_int_predicate(cmp_pred_t pred);
static type* make_cmp_result_type(type *ty);
public:
cmp_pred_t get_pred() const { return pred_; }
private:
cmp_pred_t pred_;
};
class icmp_inst: public cmp_inst {
icmp_inst(type *ty, cmp_pred_t pred,
value *lhs, value *rhs, const std::string &name, instruction *next);
public:
static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(icmp_inst)
_TRITON_DEFINE_ACCEPT(icmp_inst)
};
class fcmp_inst: public cmp_inst {
fcmp_inst(type *ty, cmp_pred_t pred,
value *lhs, value *rhs, const std::string &name, instruction *next);
public:
static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(fcmp_inst)
_TRITON_DEFINE_ACCEPT(fcmp_inst)
};
//===----------------------------------------------------------------------===//
// unary_inst classes
//===----------------------------------------------------------------------===//
class unary_inst: public instruction {
protected:
unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next);
};
//===----------------------------------------------------------------------===//
// cast_inst classes
//===----------------------------------------------------------------------===//
class cast_inst: public unary_inst{
private:
std::string repr_impl() const;
protected:
cast_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next, cast_op_t op)
: unary_inst(ty, id, v, name, next), op_(op) { }
private:
static bool is_valid(cast_op_t op, value *arg, type *ty);
public:
// accessors
cast_op_t get_op() const { return op_; }
// factory methods
static cast_inst *create(cast_op_t op, value *arg, type *ty,
const std::string &name = "", instruction *next = nullptr);
static cast_inst *create_integer_cast(value *arg, type *ty, bool is_signed,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_ACCEPT(cast_inst)
private:
cast_op_t op_;
};
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, id, op) \
class name : public cast_inst { \
_TRITON_DEFINE_CLONE(name) \
friend class cast_inst; \
name(type *ty, value *v, const std::string &name, instruction *next) \
: cast_inst(ty, id, v, name, next, op){ } \
};
TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, INST_CAST_TRUNC, cast_op_t::Trunc)
TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, INST_CAST_ZEXT, cast_op_t::ZExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, INST_CAST_SEXT, cast_op_t::SExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, INST_CAST_FP_TRUNC, cast_op_t::FPTrunc)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, INST_CAST_FP_EXT, cast_op_t::FPExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, INST_CAST_UI_TO_FP, cast_op_t::UIToFP)
TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, INST_CAST_SI_TO_FP, cast_op_t::SIToFP)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, INST_CAST_FP_TO_UI, cast_op_t::FPToUI)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, INST_CAST_FP_TO_SI, cast_op_t::FPToSI)
TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, INST_CAST_PTR_TO_INT, cast_op_t::PtrToInt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, INST_CAST_INT_TO_PTR, cast_op_t::IntToPtr)
TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, INST_CAST_BIT_CAST, cast_op_t::BitCast)
TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, INST_CAST_ADDR_SPACE_CAST, cast_op_t::AddrSpaceCast)
//===----------------------------------------------------------------------===//
// terminator_inst classes
//===----------------------------------------------------------------------===//
class terminator_inst: public instruction{
using instruction::instruction;
};
// return instruction
class return_inst: public terminator_inst {
private:
std::string repr_impl() const { return "ret"; }
return_inst(context &ctx, value *ret_val, instruction *next);
public:
// accessors
value *get_return_value()
{ return get_num_operands() ? get_operand(0) : nullptr; }
unsigned get_num_successors() const { return 0; }
// factory methods
static return_inst* create(context &ctx, value *ret_val = nullptr, instruction *next = nullptr);
_TRITON_DEFINE_CLONE(return_inst)
_TRITON_DEFINE_ACCEPT(return_inst)
};
// base branch instruction
class branch_inst: public terminator_inst{
private:
std::string repr_impl() const { return "br"; }
protected:
using terminator_inst::terminator_inst;
public:
static branch_inst* create(basic_block *dest,
instruction *next = nullptr);
static branch_inst* create(value *cond, basic_block *if_dest, basic_block *else_dest,
instruction *next = nullptr);
};
// conditional branch
class cond_branch_inst: public branch_inst {
private:
friend class branch_inst;
cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next);
public:
basic_block *get_true_dest() { return (basic_block*)get_operand(0); }
basic_block *get_false_dest() { return (basic_block*)get_operand(1); }
value *get_cond() { return get_operand(2); }
_TRITON_DEFINE_CLONE(cond_branch_inst)
_TRITON_DEFINE_ACCEPT(cond_branch_inst)
};
// unconditional branch
class uncond_branch_inst: public branch_inst {
private:
friend class branch_inst;
uncond_branch_inst(basic_block *dst, instruction *next);
public:
basic_block *get_dest() { return (basic_block*)get_operand(0); }
_TRITON_DEFINE_CLONE(uncond_branch_inst)
_TRITON_DEFINE_ACCEPT(uncond_branch_inst)
};
//===----------------------------------------------------------------------===//
// getelementptr_inst classes
//===----------------------------------------------------------------------===//
class getelementptr_inst: public instruction {
private:
std::string repr_impl() const { return "getelementptr"; }
getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value*> &idx, const std::string &name, instruction *next);
private:
static type *get_return_type(type *ty, value *ptr, const std::vector<value*> &idx);
static type *get_indexed_type_impl(type *ty, const std::vector<value *> &idx);
static type *get_indexed_type(type *ty, const std::vector<value*> &idx);
public:
// accessors
type *get_source_elt_ty() { return source_elt_ty; }
op_iterator idx_begin() { return op_begin() + 1; }
op_iterator idx_end() { return op_end(); }
value *get_pointer_operand() { return *op_begin(); }
// factory methods
static getelementptr_inst* create(value *ptr, const std::vector<value*> &idx,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(getelementptr_inst)
_TRITON_DEFINE_ACCEPT(getelementptr_inst)
private:
type *source_elt_ty;
type *res_elt_ty;
};
//===----------------------------------------------------------------------===//
// load_inst/store_inst classes
//===----------------------------------------------------------------------===//
class io_inst: public instruction {
protected:
io_inst(type *ty, value_id_t id, unsigned num_ops,
const std::string &name = "", instruction *next = nullptr);
public:
// accessors
value *get_pointer_operand() { return get_operand(0); }
};
// load
class load_inst: public io_inst {
public:
enum CACHE_MODIFIER : uint32_t {
NONE=0,
CA,
CG,
};
CACHE_MODIFIER get_cache_modifier() const { return cache_; }
protected:
load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache,
const std::string &name = "", instruction *next = nullptr);
std::string get_cache_modifier_repr() const {
if (cache_ == CA) return ".ca";
if (cache_ == CG) return ".cg";
return "";
}
CACHE_MODIFIER cache_;
private:
static type *get_pointee_type(type *ty);
};
// unmasked load
class unmasked_load_inst: public load_inst {
private:
std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, const std::string &name, instruction *next);
public:
static unmasked_load_inst* create(value *ptr,
CACHE_MODIFIER cache,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_load_inst)
_TRITON_DEFINE_ACCEPT(unmasked_load_inst)
};
// masked load
class masked_load_inst: public load_inst {
private:
std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
const std::string &name, instruction *next);
public:
// accessors
value *get_mask_operand() { return get_operand(1); }
value *get_false_value_operand() { return get_operand(2); }
// factory method
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
CACHE_MODIFIER cache,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_inst)
_TRITON_DEFINE_ACCEPT(masked_load_inst)
};
// masked load async
class masked_load_async_inst: public load_inst {
private:
std::string repr_impl() const { return "masked_load_async_async" + get_cache_modifier_repr(); }
masked_load_async_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache,
const std::string &name, instruction *next);
public:
// accessors
value *get_mask_operand() { return get_operand(1); }
value *get_false_value_operand() { return get_operand(2); }
// factory method
static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
load_inst::CACHE_MODIFIER cache,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_async_inst)
_TRITON_DEFINE_ACCEPT(masked_load_async_inst)
};
// store
class store_inst: public io_inst {
protected:
store_inst(value *ptr, value_id_t id, unsigned num_ops,
const std::string &name = "", instruction *next = nullptr);
public:
value *get_value_operand() { return get_operand(1); }
};
// unmasked_store
class unmasked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "unmasked_store"; }
unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next);
public:
// factory method
static unmasked_store_inst* create(value* ptr, value *v,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_store_inst)
_TRITON_DEFINE_ACCEPT(unmasked_store_inst)
};
class masked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "masked_store"; }
masked_store_inst(value *ptr, value *v, value *mask,
const std::string &name, instruction *next);
public:
// accessors
value *get_mask_operand() { return get_operand(2); }
// factory method
static masked_store_inst* create(value *ptr, value *v, value *mask,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_store_inst)
_TRITON_DEFINE_ACCEPT(masked_store_inst)
};
//===----------------------------------------------------------------------===//
// retile_inst classes
//===----------------------------------------------------------------------===//
// cat
class cat_inst: public instruction {
private:
std::string repr_impl() const { return "cat"; }
cat_inst(value *x, value *y, const std::string &name, instruction *next);
public:
static instruction* create(value *lhs, value *rhs,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(cat_inst)
_TRITON_DEFINE_ACCEPT(cat_inst)
};
// retile
class retile_inst: public unary_inst {
protected:
retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes, const std::string &name, instruction *next);
};
// reshape
class reshape_inst: public retile_inst {
private:
using retile_inst::retile_inst;
std::string repr_impl() const { return "reshape"; }
public:
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(reshape_inst)
_TRITON_DEFINE_ACCEPT(reshape_inst)
};
// splat
class splat_inst: public retile_inst {
private:
using retile_inst::retile_inst;
std::string repr_impl() const { return "splat"; }
public:
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(splat_inst)
_TRITON_DEFINE_ACCEPT(splat_inst)
};
// broadcast
class broadcast_inst: public retile_inst {
private:
using retile_inst::retile_inst;
std::string repr_impl() const { return "broadcast"; }
public:
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(broadcast_inst)
_TRITON_DEFINE_ACCEPT(broadcast_inst)
};
// downcast
class downcast_inst: public unary_inst {
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "downcast"; }
public:
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(downcast_inst)
_TRITON_DEFINE_ACCEPT(downcast_inst)
};
//===----------------------------------------------------------------------===//
// builtin_inst classes
//===----------------------------------------------------------------------===//
class builtin_inst: public instruction{
protected:
using instruction::instruction;
};
class get_program_id_inst: public builtin_inst {
private:
get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
std::string repr_impl() const { return "get_program_id(" + std::to_string(axis_) + ")"; }
public:
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
_TRITON_DEFINE_CLONE(get_program_id_inst)
_TRITON_DEFINE_ACCEPT(get_program_id_inst)
private:
unsigned axis_;
};
class get_num_programs_inst: public builtin_inst {
private:
get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
std::string repr_impl() const { return "get_num_programs(" + std::to_string(axis_) + ")"; }
public:
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
_TRITON_DEFINE_CLONE(get_num_programs_inst)
_TRITON_DEFINE_ACCEPT(get_num_programs_inst)
private:
unsigned axis_;
};
class atomic_inst: public io_inst {
public:
using io_inst::io_inst;
};
class atomic_rmw_inst: public atomic_inst {
private:
atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "atomic_rmw"; }
_TRITON_DEFINE_CLONE(atomic_rmw_inst)
_TRITON_DEFINE_ACCEPT(atomic_rmw_inst)
public:
static instruction* create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
atomic_rmw_op_t get_op() { return op_; }
private:
atomic_rmw_op_t op_;
};
class atomic_cas_inst: public atomic_inst {
private:
atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
std::string repr_impl() const { return "atomic_cas"; }
_TRITON_DEFINE_CLONE(atomic_cas_inst)
_TRITON_DEFINE_ACCEPT(atomic_cas_inst)
public:
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
};
class umulhi_inst: public builtin_inst {
private:
umulhi_inst(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "umulhi"; }
_TRITON_DEFINE_CLONE(umulhi_inst)
_TRITON_DEFINE_ACCEPT(umulhi_inst)
public:
static instruction* create(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
};
class exp_inst: public builtin_inst {
private:
exp_inst(value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "exp"; }
_TRITON_DEFINE_CLONE(exp_inst)
_TRITON_DEFINE_ACCEPT(exp_inst)
public:
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
};
class cos_inst: public builtin_inst {
private:
cos_inst(value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "cos"; }
_TRITON_DEFINE_CLONE(cos_inst)
_TRITON_DEFINE_ACCEPT(cos_inst)
public:
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
};
class sin_inst: public builtin_inst {
private:
sin_inst(value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "sin"; }
_TRITON_DEFINE_CLONE(sin_inst)
_TRITON_DEFINE_ACCEPT(sin_inst)
public:
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
};
class log_inst: public builtin_inst {
private:
log_inst(value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "log"; }
_TRITON_DEFINE_CLONE(log_inst)
_TRITON_DEFINE_ACCEPT(log_inst)
public:
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
};
class dot_inst: public builtin_inst {
public:
enum TransT { NoTrans, Trans };
private:
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
std::string repr_impl() const { return "dot"; }
bool is_prefetched_ = false;
public:
bool is_prefetched() const { return is_prefetched_; }
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
public:
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(dot_inst)
_TRITON_DEFINE_ACCEPT(dot_inst)
};
//class outer_inst: public builtin_inst {
//private:
// outer_inst(value *A, value *B, value *C, const std::string &name, instruction *next);
//public:
// static instruction* create(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
//};
class trans_inst: public builtin_inst {
public:
ir::type* get_res_ty(ir::type* in, std::vector<int> perm);
std::vector<int> init_perm(ir::type* ty, const std::vector<int>& perm);
private:
trans_inst(value *arg, const std::vector<int>& perm, const std::string& name, instruction* next);
std::string repr_impl() const { return "trans"; }
public:
static instruction* create(value *arg, const std::vector<int> &perm = {}, const std::string &name = "", instruction *next = nullptr);
const std::vector<int> get_perm() const;
_TRITON_DEFINE_CLONE(trans_inst)
_TRITON_DEFINE_ACCEPT(trans_inst)
private:
std::vector<int> perm_;
};
class sqrt_inst: public builtin_inst {
private:
sqrt_inst(value *arg, const std::string& name, instruction* next);
std::string repr_impl() const { return "sqrt"; }
public:
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(sqrt_inst)
_TRITON_DEFINE_ACCEPT(sqrt_inst)
};
class reduce_inst: public builtin_inst {
public:
enum op_t{
ADD, SUB, MAX, MIN,
FADD, FSUB, FMAX, FMIN
};
private:
static type* get_res_type(value *arg, unsigned axis);
static std::string to_str(op_t op);
private:
reduce_inst(value* arg, op_t op, unsigned axis, const std::string& name, instruction* next);
std::string repr_impl() const { return "reduce"; }
_TRITON_DEFINE_CLONE(reduce_inst)
_TRITON_DEFINE_ACCEPT(reduce_inst)
public:
static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
op_t get_op() const { return op_; }
private:
unsigned axis_;
op_t op_;
};
class select_inst: public builtin_inst {
private:
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
std::string repr_impl() const { return "select"; }
_TRITON_DEFINE_CLONE(select_inst)
_TRITON_DEFINE_ACCEPT(select_inst)
public:
static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr);
value* get_pred_op() { return get_operand(0); }
value* get_if_value_op() { return get_operand(1); }
value* get_else_value_op() { return get_operand(2); }
};
//===----------------------------------------------------------------------===//
// intrinsics classes
//===----------------------------------------------------------------------===//
class copy_to_shared_inst: public unary_inst{
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "copy_to_shared"; }
public:
static copy_to_shared_inst* create(value *arg, const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(copy_to_shared_inst)
_TRITON_DEFINE_ACCEPT(copy_to_shared_inst)
};
class copy_from_shared_inst: public unary_inst{
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "copy_from_shared"; }
public:
static copy_from_shared_inst* create(value *arg, const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(copy_from_shared_inst)
_TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
};
class cvt_layout_inst: public unary_inst {
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "cvt_layout_inst"; }
public:
static cvt_layout_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(cvt_layout_inst)
_TRITON_DEFINE_ACCEPT(cvt_layout_inst)
};
class barrier_inst: public instruction{
private:
barrier_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "barrier"; }
_TRITON_DEFINE_CLONE(barrier_inst)
_TRITON_DEFINE_ACCEPT(barrier_inst)
public:
static barrier_inst* create(context &ctx, const std::string &name = "",
instruction *next = nullptr);
};
class async_wait_inst: public instruction{
private:
async_wait_inst(context &ctx, int N, const std::string &name, instruction *next);
std::string repr_impl() const { return "async_wait_group " + std::to_string(N_) ; }
_TRITON_DEFINE_CLONE(async_wait_inst)
_TRITON_DEFINE_ACCEPT(async_wait_inst)
public:
static async_wait_inst* create(context &ctx, int N,
const std::string &name = "", instruction *next = nullptr);
int get_N() { return N_; }
void set_N(int n) { N_ = n; }
private:
int N_;
};
class prefetch_s_inst : public instruction {
std::string repr_impl() const { return "prefetch_s"; }
_TRITON_DEFINE_CLONE(prefetch_s_inst)
_TRITON_DEFINE_ACCEPT(prefetch_s_inst)
/// inc_: 0->first, 1->latch
int inc_ = 0;
public:
prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
: instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) {
set_operand(0, arg);
}
int get_inc() const { return inc_; }
static prefetch_s_inst *create(context &ctx, value *arg, int inc, const std::string &name = "",
instruction *next=nullptr);
};
/* constant range */
class make_range: public instruction{
make_range(type *ty, constant_int* first, constant_int* last);
std::string repr_impl() const { return "make_range[" + first_->repr() + " : " + last_->repr() + "]"; }
_TRITON_DEFINE_CLONE(make_range)
_TRITON_DEFINE_ACCEPT(make_range)
public:
static make_range *create(constant_int *first, constant_int *last);
const constant_int* get_first() const;
const constant_int* get_last() const;
private:
constant_int* first_;
constant_int* last_;
};
}
}
#endif

View File

@@ -1,32 +0,0 @@
#pragma once
#ifndef _TRITON_IR_METADATA_H_
#define _TRITON_IR_METADATA_H_
namespace triton{
namespace ir{
/* Metadata */
class metadata{
public:
enum kind_t{
multiple_of,
max_contiguous
};
private:
metadata(kind_t kind, unsigned value);
public:
static metadata* get(kind_t kind, unsigned value);
private:
kind_t kind_;
unsigned value_;
};
}
}
#endif

View File

@@ -1,112 +0,0 @@
#pragma once
#ifndef _TRITON_IR_MODULE_H_
#define _TRITON_IR_MODULE_H_
#include <map>
#include <set>
#include <stack>
#include <string>
#include <functional>
#include "triton/ir/builder.h"
#include "triton/ir/metadata.h"
#include "triton/ir/context.h"
namespace triton{
namespace lang{
class iteration_statement;
class compound_statement;
}
namespace ir{
class basic_block;
class phi_node;
class value;
class context;
class function;
class attribute;
class function_type;
class constant;
class global_value;
class alloc_const;
/* Module */
class module {
typedef std::pair<std::string, basic_block*> val_key_t;
friend class function;
typedef std::pair<ir::metadata::kind_t, unsigned> md_pair_t;
public:
typedef std::map<std::string, global_value*> symbols_map_t;
typedef std::vector<function*> functions_list_t;
struct current_iteration_info_t{
lang::iteration_statement *statement;
basic_block *block;
};
private:
phi_node *make_phi(type *ty, unsigned num_values, basic_block *block);
value *try_remove_trivial_phis(ir::phi_node *&phi);
value *add_phi_operands(const std::string& name, phi_node *&phi);
value *get_value_recursive(const std::string& name, basic_block *block);
void push_function(function *fn) { functions_.push_back(fn); }
public:
module(const std::string &name, builder& builder);
builder& get_builder();
// Setters
void set_value(const std::string& name, basic_block* block, value *x);
void set_value(const std::string& name, value* x);
void set_const(const std::string& name);
void set_continue_fn(std::function<ir::value*()> fn);
// Getters
const std::map<val_key_t, value*>& get_values() { return values_; }
void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
value *get_value(const std::string& name, basic_block* block);
value *get_value(const std::string& name);
void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }
const std::string& get_name();
std::function<ir::value*()> get_continue_fn();
// Seal block -- no more predecessors will be added
void seal_block(basic_block *block);
// Functions
const functions_list_t &get_function_list() const { return functions_; }
functions_list_t &get_function_list() { return functions_; }
function *get_or_insert_function(const std::string &name, function_type *ty);
// Const allocation
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
// Register global
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
const std::map<std::string, ir::value*>& globals() const { return globals_; }
// Metadata
void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; }
void print(std::ostream &os);
private:
std::string name_;
builder& builder_;
std::map<val_key_t, value*> values_;
std::map<std::string, type*> types_;
std::set<std::string> const_;
std::set<basic_block*> sealed_blocks_;
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
functions_list_t functions_;
symbols_map_t symbols_;
std::function<ir::value*()> continue_fn_;
std::map<value*, value**> current_phi_;
std::vector<ir::alloc_const*> allocs_;
std::map<std::string, ir::value*> globals_;
std::map<std::string, md_pair_t> metadatas_;
};
}
}
#endif

View File

@@ -1,22 +0,0 @@
#ifndef _TRITON_IR_PRINT_H_
#define _TRITON_IR_PRINT_H_
#include "builder.h"
namespace triton{
namespace ir{
class module;
class function;
class basic_block;
class instruction;
void print(module &mod, std::ostream& os);
void print(function &func, std::ostream& os);
void print(basic_block &bb, std::ostream& os);
void print(instruction &instr, std::ostream& os);
}
}
#endif

View File

@@ -1,240 +0,0 @@
#pragma once
#ifndef _TRITON_IR_TYPE_H_
#define _TRITON_IR_TYPE_H_
#include <cassert>
#include <vector>
#include <string>
namespace triton{
namespace ir{
class context;
class value;
class integer_type;
class constant_int;
/* Type */
class type {
public:
typedef std::vector<unsigned> block_shapes_t;
protected:
typedef std::vector<type*> contained_tys_vec_t;
typedef contained_tys_vec_t::iterator ty_iterator;
typedef contained_tys_vec_t::const_iterator const_ty_iterator;
public:
enum id_t {
// primitive types
VoidTyID = 0, ///< type with no size
FP8TyID, ///< 8-bit floating point type (3 bits mantissa)
FP16TyID, ///< 16-bit floating point type (10 bits mantissa)
BF16TyID, ///< 16-bit floating point type (7 bits mantissa)
FP32TyID, ///< 32-bit floating point type
FP64TyID, ///< 64-bit floating point type
LabelTyID, ///< Labels
MetadataTyID, ///< Metadata
TokenTyID, ///< Token
// derived types
IntegerTyID, ///< Arbitrary bit width integers
FunctionTyID, ///< Functions
PointerTyID, ///< Pointers
StructTyID, ///< Struct
BlockTyID, ///< Block
};
public:
//constructors
type(context &ctx, id_t id) : ctx_(ctx), id_(id) { }
//destructor
virtual ~type(){}
// accessors
context &get_context() const { return ctx_; }
id_t get_type_id() const { return id_; }
// type attributes
unsigned get_fp_mantissa_width() const;
unsigned get_integer_bitwidth() const;
unsigned get_tile_bitwidth() const;
unsigned get_primitive_size_in_bits() const;
type *get_scalar_ty() const;
block_shapes_t get_block_shapes() const;
const size_t get_tile_rank() const;
const size_t get_tile_ranks1() const;
unsigned get_tile_num_elements() const;
type *get_tile_element_ty() const;
unsigned get_pointer_address_space() const;
type *get_pointer_element_ty() const;
// primitive predicates
bool is_void_ty() const { return id_ == VoidTyID; }
bool is_fp8_ty() const { return id_ == FP8TyID; }
bool is_fp16_ty() const { return id_ == FP16TyID; }
bool is_bf16_ty() const { return id_ == BF16TyID; }
bool is_fp32_ty() const { return id_ == FP32TyID; }
bool is_fp64_ty() const { return id_ == FP64TyID; }
bool is_label_ty() const { return id_ == LabelTyID;}
bool is_metadata_ty() const { return id_ == MetadataTyID; }
bool is_token_ty() const { return id_ == TokenTyID; }
bool is_integer_ty() const { return id_ == IntegerTyID; }
bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() &&
get_integer_bitwidth() == bitwidth;}
bool is_bool_ty() const { return is_integer_ty(1); }
bool is_pointer_ty() const { return id_ == PointerTyID; }
bool is_block_ty() const { return id_ == BlockTyID; }
// Composite predicates
bool is_int_or_tileint_ty();
bool is_integer_ty(unsigned width) const;
bool is_floating_point_ty() const;
bool is_sized() const ;
// Factory methods
// primitive types
static type *get_void_ty(context &ctx);
static type *get_label_ty(context &ctx);
// half
static type *get_fp8_ty(context &ctx);
static type *get_fp16_ty(context &ctx);
static type *get_bf16_ty(context &ctx);
static type *get_fp32_ty(context &ctx);
static type *get_fp64_ty(context &ctx);
// integer types
static integer_type *get_int1_ty(context &ctx);
static integer_type *get_int8_ty(context &ctx);
static integer_type *get_int16_ty(context &ctx);
static integer_type *get_int32_ty(context &ctx);
static integer_type *get_int64_ty(context &ctx);
static integer_type *get_int128_ty(context &ctx);
// repr
std::string tile_repr() const {
std::string res = get_tile_element_ty()->repr();
auto shapes = get_block_shapes();
res += "<";
for(size_t i = 0; i < shapes.size(); i++){
if(i > 0)
res += ", ";
res += std::to_string(shapes[i]);
}
res+= ">";
return res;
}
std::string repr() const {
switch(id_) {
case VoidTyID: return "void";
case FP8TyID: return "fp8";
case FP16TyID: return "f16";
case FP32TyID: return "f32";
case FP64TyID: return "f64";
case LabelTyID: return "label";
case MetadataTyID: return "md";
case TokenTyID: return "tok";
case IntegerTyID: return "i" + std::to_string(get_integer_bitwidth());
case FunctionTyID: return "fn";
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
case StructTyID: return "struct";
case BlockTyID: return tile_repr();
default: break;
}
assert(false);
return "";
};
private:
context &ctx_;
id_t id_;
protected:
contained_tys_vec_t contained_tys_;
};
class integer_type: public type {
friend class context_impl;
private:
// constructors
integer_type(context &ctx, unsigned bitwidth)
: type(ctx, IntegerTyID), bitwidth_(bitwidth){ }
public:
// accessors
unsigned get_bitwidth() const { return bitwidth_; }
// factory methods
static integer_type* get(context &ctx, unsigned width);
private:
unsigned bitwidth_;
};
class composite_type: public type{
protected:
using type::type;
public:
bool index_valid(value *idx) const;
type* get_type_at_index(value *idx) const;
};
class block_type: public composite_type {
private:
block_type(type *ty, const block_shapes_t &shapes);
static bool is_valid_elt_ty(type *ty);
public:
// accessors
const block_shapes_t& get_shapes() const { return shapes_; }
unsigned get_num_elements() const;
unsigned get_bitwidth() const;
// factory methods
static block_type* get(type *ty, const block_shapes_t &shapes);
static block_type* get_same_shapes(type *ty, type *ref);
private:
block_shapes_t shapes_;
};
class pointer_type: public type {
private:
pointer_type(type *ty, unsigned address_space);
static bool is_valid_elt_ty(type *ty);
public:
// accessors
unsigned get_address_space() const { return address_space_; }
type *get_element_ty() const { return contained_tys_[0]; }
// factory methods
static pointer_type* get(type *ty, unsigned address_space);
private:
unsigned address_space_;
};
class function_type: public type {
private:
function_type(type *ret_ty, const std::vector<type *> &param_tys);
public:
// accessors
unsigned get_num_params() const { return contained_tys_.size() - 1; }
const_ty_iterator params_begin() const { return contained_tys_.begin() + 1; }
const_ty_iterator params_end() const { return contained_tys_.end(); }
ty_iterator params_begin() { return contained_tys_.begin() + 1; }
ty_iterator params_end() { return contained_tys_.end(); }
type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); }
type* get_return_ty() const { return contained_tys_.at(0); }
// factory methods
static function_type* get(type *ret_ty, const std::vector<type*>& param_tys);
};
}
}
#endif

Some files were not shown because too many files have changed in this diff Show More