Commit Graph

437 Commits

Author SHA1 Message Date
Crutcher Dunnavant
f98aed1258 [Triton-MLIR][RUNTIME] Add /usr/bin/ptxas as a search path (#909)
Make `ptxas` search a bit broader to include `/usr/bin/ptxas`, installed
by the lambda stack repo versions:
https://lambdalabs.com/lambda-stack-deep-learning-software
2022-11-24 18:49:16 +00:00
Crutcher Dunnavant
ace7d28736 [Triton-MLIR][RUNTIME] Fix ir metadata lookup bug (#910) 2022-11-24 09:27:23 +01:00
donproc
8925c2cd11 [TRITON-MLIR][BACKEND]AtomicRMWOp supports scalar (#903)
AtomicRMWOp supports scalar

Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-11-23 07:59:09 +00:00
Keren Zhou
2e33352419 [Triton-MLIR] Fix side effects (#906)
Try to add proper side effects for triton operations. 

The CSE pass could fail, hang, or output incorrect IRs for unknown
reasons, if side effects are not defined properly.

For instance, suppose we have two shared memory tensors:

```
%a = triton_gpu.alloc_tensor shape0, share_encoding0
%b = triton_gpu.alloc_tensor shape0, share_encoding0
```

The CSE pass will consider `%a` and `%b` are the same thing and
eliminate one of them, resulting in mysterious outcomes.
2022-11-22 23:29:18 -08:00
Yan Chunwei
037f9efa95 [Triton-MLIR][BACKEND] Fix wpt overflow issue in mma v2 (#904)
This PR

1. Fix wpt overflow issue in mma v2
2. Refine transpose logic
2022-11-23 11:27:15 +08:00
ben-zhang-609
07786dc932 [Triton-MLIR] Add compute capability (#902)
add compute capability from python frontend to backend.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2022-11-22 11:08:23 -08:00
Keren Zhou
85cccfb81f [BUILD] Fix compilation problems in the release build (#897) 2022-11-21 05:40:36 +00:00
Philippe Tillet
23f71daa27 [OPTIMIZER] Fixed up order of shared layouts (#881) 2022-11-21 06:25:02 +01:00
Philippe Tillet
4d64ffb5fe [FRONTEND] Handle for loops with negative constant steps (#896) 2022-11-20 11:37:38 +01:00
Keren Zhou
6c5f646f4e [WIP][Triton-MLIR] Prefetch pass fixup (#873)
A (potential) problem by directly adopting `tensor.extract_slice`.

Long story short, `tensor.extract_slice` is not aware of swizzling.
Consider the following shared memory tensor and its first three slices,
where each slice includes two tile (the loading unit of LDGSTS) of
elements. Currently, the tiles haven't been swizzled yet, so slicing
seems to work.

<img width="1219" alt="image"
src="https://user-images.githubusercontent.com/2306281/201833023-a7950705-2d50-4c0a-8527-7505261c3a3c.png">

However, now consider the following figure, which is the layout after
applying swizzling on the first figure.

<img width="1244" alt="image"
src="https://user-images.githubusercontent.com/2306281/201834824-7daae360-f5bc-4e6b-a921-20be3f294b78.png">

Note that on phase 2, all tiles have been swizzled out of their
originally slices. This implies that if we use the tile index after
slicing, we can no longer locate the correct tiles. For example, T3 was
in slice 1 but got swapped to slice 0 after swizzling.

Here's a more detailed explanation. In the current `triton-mlir` branch,
we only compute the relative offset of each tile. So T3's index in Slice
1 is *1*, and it will be swizzled using *1* and *phase id*. Whereas the
correct index of T3 should be *3*, which is the relative offset to the
beginning of the shared memory tensor being swizzled, and T3 should be
swizzled using *3* and *phase id*.

This PR proposes a hacky solution for this problem. We restore the
"correct" offset of each tile by **assuming that slicing on a specific
dim only happens at most once on the output of insert_slice_async**. I
admit it's risky and fragile.

The other possible solution is adopting cutlass' swizzling logic that
limits the indices being swizzled in a "bounding box" that matches the
mma instruction executes. For example, in the following tensor layout,
each 4x4 submatrix is a minimum swizzling unit, and the entire tensor
represents the tensor layout of operand A in `mma.16816`.

<img width="565" alt="image"
src="https://user-images.githubusercontent.com/2306281/201836879-4ca7824b-530c-4a06-a3d5-1e74a2de1b42.png">

Co-authored-by: Phil Tillet <phil@openai.com>
2022-11-19 19:57:16 -08:00
Jun Yang
8a5647782d [Triton-MLIR][Testing]Fix tests warning, with small code clean-up (#894)
1.Code clean-up to remove superfluous #includes.
2.Fix two python test warnings, in which one relates to ["#"
formats](https://jira.mongodb.org/browse/PYTHON-2343), the other relates
to regular expression string usage.
2022-11-19 14:33:59 +00:00
donproc
afaf59b0c9 [TRITON-MLIR][BACKEND] Atomic support mask (#889)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-11-19 19:57:19 +08:00
Philippe Tillet
dab4855bdf [TESTING] Added infrastructure for executing TTGIR program and test for layout conversions (#885) 2022-11-18 07:46:45 +01:00
goostavz
9ea6135eb5 [Triton-MLIR][Backend] Some cleanup in getMultiDimIndex/getLinearIndex (#880) 2022-11-18 01:19:21 +00:00
donproc
5eee738df7 [Triton-MLIR][FRONTEND] [BACKEND] fix atomics (#879)
minor fix to backend and frontend of atomics, we can pass 1 test without
mask and the shape aligned with CTA size now

Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-11-16 12:25:15 +08:00
Qingyi Liu
4c4159c6fa [Triton-MLIR] Add ex2.approx implementation for ExpOp and fix smem allocation for ReduceOpConversion (#875) 2022-11-15 01:27:32 +00:00
goostavz
c28cfd821b [Triton-MLIR][Backend] Fix convert_layout blocked->shared in non-default order (#876)
This PR fix the problem of TN/NT GEMM correctness when no SCF involved.
I'll continue to clean up getLinearIndex/getMultiDimIndex in a uniformed
way which should be benifical to avoid different kinds of order issues.
This is not fully done yet, just merge to sync the code.
2022-11-15 09:02:46 +08:00
Yan Chunwei
1eedaf7bec [Triton-MLIR][BACKEND] adapt DotOp layout for FMADot (#872) 2022-11-14 16:56:30 +08:00
Chenggang Zhao
516a241234 [Triton-MLIR] Fix some typos (#874)
Fix some typos
2022-11-13 18:15:53 -08:00
Philippe Tillet
f40c63fb03 [Triton-MLIR][OPTIMIZER] Cleaned up swizzling (#869)
Swizzling is no longer implemented as a separate pass. It is instead
done in a specialized constructor of SharedEncodingAttr, and tested via
google tests instead of triton-opt + filecheck.

In the future we may want to implement it as a pass again once we have
an additional dialect between TritonGPU and LLVM.
2022-11-10 12:05:46 -08:00
Philippe Tillet
2aa538ec2e [BACKEND] Added support for mma layouts in reductions (#863)
Validated hackily by manually modifying the reduction .ttgir in my local
cache. There will be a follow-up PR adding some better testing
infrastructure to test out conversions and reductions on arbitrary
layouts.
2022-11-10 09:58:07 -08:00
Chenggang Zhao
57fd1864a7 [Triton-MLIR] Support FP8 (#864)
Co-authored-by: Superjomn <yanchunwei@outlook.com>
2022-11-10 15:53:06 +08:00
Da Yan
4946167241 [Triton-MLIR] tt.dot operands now must have DotOperand layout; also added prefetch pass prototype (#712)
Co-authored-by: Jokeren <kerenzhou@openai.com>
Co-authored-by: Phil Tillet <phil@openai.com>
Co-authored-by: Superjomn <yanchunwei@outlook.com>
2022-11-10 05:57:27 +00:00
Yan Chunwei
0c87360657 [Triton-MLIR][Backend] Port FMADot conversion for DotOp (#844)
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
2022-11-09 12:57:50 +08:00
Yan Chunwei
de5b84c476 [Triton-MLIR][Backend] Fix mma<v2> int8 precision error (#850)
Fix mma.16816 s8 precision error

Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
2022-11-09 12:23:43 +08:00
Keren Zhou
2da71b2aaa [Triton-MLIR] Increase block size K to completely eliminate shared memory bank conflicts (#862) 2022-11-08 17:39:23 -08:00
goostavz
080b4addf8 [Triton-MLIR][Backend] Fix the order in linear/delinear and a few bugs in reduce conversion (#851)
1, fix the order in linearize/delinearize, which fix the error of order
in emitIndices;
2, fix the selecting of fast implementation in reduce codegen;
3, fix the redundant barrier in reduce codegen;
4, fix the index mapping of the second round of warp_shuffle in shuffle
version of reduce codegen.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2022-11-08 10:10:09 -08:00
Philippe Tillet
976cf12af1 [OPTIMIZER] Fixed memory coalescing (#847) 2022-11-07 06:22:18 -08:00
Philippe Tillet
b6f15e214b [FRONTEND] Fixed up type cast in atomics codegen (#853) 2022-11-07 05:46:24 -08:00
ben-zhang-609
84ad215268 [Triton-MLIR] Enable libdevice for ptx backend when has external functions. (#848)
At the phase from ptx to cubin, check whether llvm::Module has external
functions. if has, link with libdevice at:
https://github.com/openai/triton/blob/triton-mlir/python/triton/language/libdevice.10.bc
2022-11-07 08:01:50 +00:00
Keren Zhou
fdd59900f7 [Triton-MLIR] Replace triton.extract_slice with tensor.extract_slice and support more general tensor slicing (#837)
## Features

- Allow taking a block of tensor slice, as long as each dimension is
contiguous (unit stride).
- Fix some problems in `insert_slice_async`'s semantic.
- More general verification for ops that return shared layout encoding.

## Known Limitations

- `insert_slice_async` still uses the old semantic. May submit another
PR later to support similar semantic like `tensor.extract_slice`.
- No encoding verification for `tensor.extract_slice`.
- 3d tensor ops are broken.
- Strided accesses are not allowed.
- May cause a little performance slowdown since we are passing strides
as values but not constants (e.g., int).
It would be difficult to pass strides as attributes when we have control
flows. A block argument is possible to accept tensors with different
strides.
2022-11-06 22:59:03 -08:00
Philippe Tillet
b6dbe959f0 [RUNTIME] Re-vamped cache so users can manually patch IR / ptx / cubin files (#845)
Also deprecates a couple of tests
2022-11-04 10:57:29 -07:00
Keren Zhou
4218e68d74 [Triton-MLIR] [Frontend] Return a scalar if all input args are scalar (#839) 2022-11-03 20:27:47 -07:00
Philippe Tillet
91a9773b38 [OPTIMIZER] Minor bugfixes that affected matmul codegen performance (#834) 2022-11-02 22:58:09 -07:00
ben-zhang-609
5feb6e24f9 [Triton-MLIR]Add ptx vprintf support (#825)
Not know how to write unit test for this feature.

Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
2022-11-02 16:39:09 +08:00
Philippe Tillet
12d60cb4a3 [BACKEND] Added support for 1D conversion blocked -> slice (#831) 2022-11-01 13:19:58 -07:00
Chenggang Zhao
c9d84237e8 [Triton-MLIR][Frontend] Interface fixes for libdevice (#829)
- Unifying several interfaces with different types to a single one, e.g.
`fsub_ru` and `dsub_ru` -> `sub_ru`;
- Minor bug fix: `fast_pow` is incorrectly classified into the `pow`
interface, of which arguments are the same as `powf`;
- Explicit interfaces for casting functions, e.g. decoupling
`ll2float_ru` to `ll2float_ru` and `ull2float_ru`;
- Removing interfaces that are not in NVIDIA's official documents, e.g.
`fmaf_ieee_rn`, which is confusing together with `fmaf_rn`.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2022-11-01 10:51:32 -07:00
Qingyi Liu
cdc0ec5077 [Triton-MLIR][Backend] Fix reduce conversion and unit tests for int dtypes (#826) 2022-11-01 17:42:59 +08:00
Philippe Tillet
cb1b87a688 [FRONTEND] Made test_if/test_default pass (#823) 2022-10-30 15:32:55 -07:00
Philippe Tillet
e61dc75942 [FRONTEND] Fixed inliner and got more tests to pass (#822)
This adds a `DialectInlinerInterface` to the Triton dialect. This, along
with a few other minor semantic changes, fixes our tests on call
instructions. Also added the option to provide use an "LLVM_SYSPATH"
environment variable to link against locally build of LLVM; this was
useful for debugging this issue.
2022-10-30 14:10:02 -07:00
Philippe Tillet
7dfab26a39 [FRONTEND][BACKEND] Fixed various bugs (#819)
- Fixed bugs on layout conversions for int1 data (we should use int8
internally for int1 data to prevent llvm from using vec<i1> which has
different semantics)
- Fixed semantics of some casts to bool in the frontend
2022-10-29 06:34:14 +00:00
Philippe Tillet
82834d34f9 [BUILD] No longer use include((HandleLLVMOptions) (#818) 2022-10-28 17:02:49 -07:00
Ian Bearman
f2106d0aa2 [BUILD] Fix Warnings and Enable Warnings as Errors (#794) 2022-10-28 12:36:09 -07:00
Philippe Tillet
ac0f6793cc [BACKEND] Added support for scalars in LoadOp / StoreOp / ElementwiseOp (#814)
Also fixed various errors that showed up in `test_core.py`, and added more TODOs for open (hopefully relatively minor) issues
2022-10-28 16:17:55 +08:00
ben-zhang-609
3685194456 [Triton-MLIR][BACKEND] Add elementwise ops and tests (#804)
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2022-10-28 05:26:29 +00:00
Keren Zhou
3b80801dff [Triton-MLIR][Backend] Fix many problems to get the pipeline working (#809)
1. Rewrite code generation of insert_slice_async.
2. Correct the wrong index passed to extract_slice in pipeline.
3. Add a prologue in pipeline to wait for dangling cp.asyncs.  
4. Move scf to cf conversion inside TritonGPUToLLVM because we need to
perform membar before scf to cf. It shouldn't be a technical limitation
and could be improved by a more general membar analysis.
5. Use an attribute to memoize the shared memory size and support
dynamic shared memory.
6. Prevent the combine pass to reorder insert_slice and extract_slice
across async_wait

Co-authored-by: Superjomn <yanchunwei@outlook.com>
2022-10-27 22:09:06 -07:00
Qingyi Liu
42db3538e4 [Triton-MLIR][Backend] Add ReduceOpConversion into TritonGPUToLLVM conversion (#774)
What is done in this PR:
- [x] Add `ConvertLayout`, `getSizePerThread` and `getShapePerCTA`
implementation for `SliceEncodingAttr`
- [x] Split `emitIndices` into two phases:
`emitBaseIndexForBlockedLayout` and `emitOffsetForBlockedLayout`
- [x] Add `ReduceOpConversion::matchAndRewriteBasic` implementation
- [x] Add `ReduceOpConversion::matchAndRewriteFast` implementation with
ptx instruction `shfl.sync`
- [x] Add support for scalar value in `StoreOpConversion`
- [x] Add Reduce1d and Reduce2d unit tests and pass all unit tests

Co-authored-by: Qingyi Liu <liuqingyi1993@gmail.com>
2022-10-28 11:07:45 +08:00
Philippe Tillet
3e6cc6d66c [FRONTEND] Made more tests pass (#805) 2022-10-26 17:47:33 -07:00
goostavz
bb7008651a [Backend] Hacky fix of missing barrier in ConvertLayout blocked->shared (#803)
Barrier should be set by a separate pass, but it seems like there may be some bugs
2022-10-26 13:39:38 -07:00
Yan Chunwei
4dc2396ca0 [Triton-MLIR][BACKEND] Support $c from mma layout in dot (#798)
This PR does

1. Support the case where $c holding a mma layout, this should be useful
in forloop in k-axis in GEMM
2. Fix the `unrealized_conversion_cast` in ConvertLayout[shared->dot_op]

Known issue

1. There is some IO conflict in GEMM with a k-forloop, it is temporarily
solved by [adding a
barrier](https://github.com/openai/triton/pull/798/files#diff-8a9a5a7f4a025fb1299af29d190d5626bd9000406d3ea47c49679272d3d6abe9R3028)
in dot conversion, but we are still working on it, will get a more
generic fix for it in the following PR.
2. The parallel pass will result in a buggy instruction result type
```mlir
%1049 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.commit_group ;", ""  : () -> !llvm.void
%1050 = builtin.unrealized_conversion_cast %1049 : !llvm.void to !llvm.ptr<f16, 3>
```
So we temporarily disable it.
2022-10-26 10:33:04 +08:00