Commit Graph

459 Commits

Author SHA1 Message Date
Philippe Tillet
b6f15e214b [FRONTEND] Fixed up type cast in atomics codegen (#853) 2022-11-07 05:46:24 -08:00
ben-zhang-609
84ad215268 [Triton-MLIR] Enable libdevice for ptx backend when has external functions. (#848)
At the phase from ptx to cubin, check whether llvm::Module has external
functions. if has, link with libdevice at:
https://github.com/openai/triton/blob/triton-mlir/python/triton/language/libdevice.10.bc
2022-11-07 08:01:50 +00:00
Keren Zhou
fdd59900f7 [Triton-MLIR] Replace triton.extract_slice with tensor.extract_slice and support more general tensor slicing (#837)
## Features

- Allow taking a block of tensor slice, as long as each dimension is
contiguous (unit stride).
- Fix some problems in `insert_slice_async`'s semantic.
- More general verification for ops that return shared layout encoding.

## Known Limitations

- `insert_slice_async` still uses the old semantic. May submit another
PR later to support similar semantic like `tensor.extract_slice`.
- No encoding verification for `tensor.extract_slice`.
- 3d tensor ops are broken.
- Strided accesses are not allowed.
- May cause a little performance slowdown since we are passing strides
as values but not constants (e.g., int).
It would be difficult to pass strides as attributes when we have control
flows. A block argument is possible to accept tensors with different
strides.
2022-11-06 22:59:03 -08:00
Philippe Tillet
b6dbe959f0 [RUNTIME] Re-vamped cache so users can manually patch IR / ptx / cubin files (#845)
Also deprecates a couple of tests
2022-11-04 10:57:29 -07:00
Keren Zhou
4218e68d74 [Triton-MLIR] [Frontend] Return a scalar if all input args are scalar (#839) 2022-11-03 20:27:47 -07:00
Philippe Tillet
91a9773b38 [OPTIMIZER] Minor bugfixes that affected matmul codegen performance (#834) 2022-11-02 22:58:09 -07:00
ben-zhang-609
5feb6e24f9 [Triton-MLIR]Add ptx vprintf support (#825)
Not know how to write unit test for this feature.

Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
2022-11-02 16:39:09 +08:00
Philippe Tillet
12d60cb4a3 [BACKEND] Added support for 1D conversion blocked -> slice (#831) 2022-11-01 13:19:58 -07:00
Chenggang Zhao
c9d84237e8 [Triton-MLIR][Frontend] Interface fixes for libdevice (#829)
- Unifying several interfaces with different types to a single one, e.g.
`fsub_ru` and `dsub_ru` -> `sub_ru`;
- Minor bug fix: `fast_pow` is incorrectly classified into the `pow`
interface, of which arguments are the same as `powf`;
- Explicit interfaces for casting functions, e.g. decoupling
`ll2float_ru` to `ll2float_ru` and `ull2float_ru`;
- Removing interfaces that are not in NVIDIA's official documents, e.g.
`fmaf_ieee_rn`, which is confusing together with `fmaf_rn`.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2022-11-01 10:51:32 -07:00
Qingyi Liu
cdc0ec5077 [Triton-MLIR][Backend] Fix reduce conversion and unit tests for int dtypes (#826) 2022-11-01 17:42:59 +08:00
Philippe Tillet
cb1b87a688 [FRONTEND] Made test_if/test_default pass (#823) 2022-10-30 15:32:55 -07:00
Philippe Tillet
e61dc75942 [FRONTEND] Fixed inliner and got more tests to pass (#822)
This adds a `DialectInlinerInterface` to the Triton dialect. This, along
with a few other minor semantic changes, fixes our tests on call
instructions. Also added the option to provide use an "LLVM_SYSPATH"
environment variable to link against locally build of LLVM; this was
useful for debugging this issue.
2022-10-30 14:10:02 -07:00
Philippe Tillet
7dfab26a39 [FRONTEND][BACKEND] Fixed various bugs (#819)
- Fixed bugs on layout conversions for int1 data (we should use int8
internally for int1 data to prevent llvm from using vec<i1> which has
different semantics)
- Fixed semantics of some casts to bool in the frontend
2022-10-29 06:34:14 +00:00
Philippe Tillet
82834d34f9 [BUILD] No longer use include((HandleLLVMOptions) (#818) 2022-10-28 17:02:49 -07:00
Ian Bearman
f2106d0aa2 [BUILD] Fix Warnings and Enable Warnings as Errors (#794) 2022-10-28 12:36:09 -07:00
Philippe Tillet
ac0f6793cc [BACKEND] Added support for scalars in LoadOp / StoreOp / ElementwiseOp (#814)
Also fixed various errors that showed up in `test_core.py`, and added more TODOs for open (hopefully relatively minor) issues
2022-10-28 16:17:55 +08:00
ben-zhang-609
3685194456 [Triton-MLIR][BACKEND] Add elementwise ops and tests (#804)
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2022-10-28 05:26:29 +00:00
Keren Zhou
3b80801dff [Triton-MLIR][Backend] Fix many problems to get the pipeline working (#809)
1. Rewrite code generation of insert_slice_async.
2. Correct the wrong index passed to extract_slice in pipeline.
3. Add a prologue in pipeline to wait for dangling cp.asyncs.  
4. Move scf to cf conversion inside TritonGPUToLLVM because we need to
perform membar before scf to cf. It shouldn't be a technical limitation
and could be improved by a more general membar analysis.
5. Use an attribute to memoize the shared memory size and support
dynamic shared memory.
6. Prevent the combine pass to reorder insert_slice and extract_slice
across async_wait

Co-authored-by: Superjomn <yanchunwei@outlook.com>
2022-10-27 22:09:06 -07:00
Qingyi Liu
42db3538e4 [Triton-MLIR][Backend] Add ReduceOpConversion into TritonGPUToLLVM conversion (#774)
What is done in this PR:
- [x] Add `ConvertLayout`, `getSizePerThread` and `getShapePerCTA`
implementation for `SliceEncodingAttr`
- [x] Split `emitIndices` into two phases:
`emitBaseIndexForBlockedLayout` and `emitOffsetForBlockedLayout`
- [x] Add `ReduceOpConversion::matchAndRewriteBasic` implementation
- [x] Add `ReduceOpConversion::matchAndRewriteFast` implementation with
ptx instruction `shfl.sync`
- [x] Add support for scalar value in `StoreOpConversion`
- [x] Add Reduce1d and Reduce2d unit tests and pass all unit tests

Co-authored-by: Qingyi Liu <liuqingyi1993@gmail.com>
2022-10-28 11:07:45 +08:00
Philippe Tillet
3e6cc6d66c [FRONTEND] Made more tests pass (#805) 2022-10-26 17:47:33 -07:00
goostavz
bb7008651a [Backend] Hacky fix of missing barrier in ConvertLayout blocked->shared (#803)
Barrier should be set by a separate pass, but it seems like there may be some bugs
2022-10-26 13:39:38 -07:00
Yan Chunwei
4dc2396ca0 [Triton-MLIR][BACKEND] Support $c from mma layout in dot (#798)
This PR does

1. Support the case where $c holding a mma layout, this should be useful
in forloop in k-axis in GEMM
2. Fix the `unrealized_conversion_cast` in ConvertLayout[shared->dot_op]

Known issue

1. There is some IO conflict in GEMM with a k-forloop, it is temporarily
solved by [adding a
barrier](https://github.com/openai/triton/pull/798/files#diff-8a9a5a7f4a025fb1299af29d190d5626bd9000406d3ea47c49679272d3d6abe9R3028)
in dot conversion, but we are still working on it, will get a more
generic fix for it in the following PR.
2. The parallel pass will result in a buggy instruction result type
```mlir
%1049 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.commit_group ;", ""  : () -> !llvm.void
%1050 = builtin.unrealized_conversion_cast %1049 : !llvm.void to !llvm.ptr<f16, 3>
```
So we temporarily disable it.
2022-10-26 10:33:04 +08:00
Philippe Tillet
a2cbe7af91 [FRONTEND] Enhanced support for binary operators (#801)
Disabled modulo test (due to change in behavior for `frem` in nvptx
between llvm-11 and llvm-14) and bfloat16 (will require some work to
emulate in software similar to how it's done in `master`)
2022-10-24 19:47:01 -07:00
Philippe Tillet
fcb228d1d4 Merge select commits from master branch into triton-mlir (#799)
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: vesuppi <zt9465@gmail.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: daadaada <dyanab@connect.ust.hk>
Co-authored-by: Anton Kostin <masguit42@users.noreply.github.com>
Co-authored-by: Yunxing Dai <nov503@gmail.com>
Co-authored-by: Shintaro Iwasaki <shintaro.iwasaki.work@gmail.com>
2022-10-24 14:52:37 -07:00
Philippe Tillet
3aa8296b06 [BUILD] Download pybind11 in setup.py (#703) (#797)
Cherry-picks #703 and resolves conflicts

Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
2022-10-23 18:52:48 -07:00
Yan Chunwei
1bf59d315c [Triton-MLIR][FRONTEND] Remove the dangling check-triton call in setup.py (#796) 2022-10-23 18:26:18 -07:00
Philippe Tillet
bb0f9235d1 [OPTIMIZER] Made layout simplification pass efficient for fused attention kernels (#790) 2022-10-21 16:52:15 -07:00
goostavz
c4726333bf [Triton-MLIR] Minor fixes related with scf/swizzling support (#791)
1, Disable static loop unrolling in the frontend by default;
2, A minor fix in axisAnalysis in order to support scf;
3, A minor fix in TritonGPUToLLVM to support swizzling.
2022-10-21 11:46:28 +08:00
Philippe Tillet
dc0588a898 [OPTIMIZER] Improved layout simplification pass so it handles swizzled layouts better (#789)
Note: uncommented `test_gemm`, since backend has an issue with swizzling. This will get uncommented in a subsequent PR.
2022-10-20 19:03:37 -07:00
Shintaro Iwasaki
0d22d2bc03 [TritonMLIR] Disallow 0D tensor (#788) 2022-10-19 10:34:32 -07:00
Yan Chunwei
4464646efb [Triton-MLIR][BACKEND] Fix masked load store op vector size (#785)
Correct the Load/Store Op's vector size with the mask's alignment
correctly considered.

Some cases:

```mlir
// num_warp = 2
// block_size = 128
func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, 
  %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
    // mask = make_range(128) < n_element
}
```
This should get the vec=2 `ld`/`st` instructions.

While the following example

```mlir
// num_warp = 2
// block_size = 128
func @vecadd_mask_align_16(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, 
  %out_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
    // mask = make_range(128) < n_element
}
```
it should get the vec=1 `ld`/`st` instructions.
2022-10-18 11:43:50 +08:00
Philippe Tillet
623c99609f [Triton-IR] Added type inference and verifier for Triton-IR operations (#767) 2022-10-11 18:16:41 -07:00
Yan Chunwei
555f94f9b9 [triton-mlir][BACKEND] Support masked load/store (#657)
This PR does

- fix some bugs to support masked load/store,
- refine frontend, and support the `and` and `or` syntax in mask(by
extending the BoolOp in python ast.visitor), e.g. `tl.store(...,
mask=offset<n and other_conditions)`,
- add `arith.cmpI` and `arith.cmpF` op conversion in backend(required by
mask),
- add more test cases in vecadd.
2022-10-10 13:29:53 +08:00
Ian Bearman
448d14a598 [BUILD] Add TRITON Prefix to build variables (#752) 2022-10-09 10:55:17 -07:00
goostavz
1d772cd843 [Triton-MLIR][Backend] Add SCF lowering in the backend (#750) 2022-10-08 18:36:37 +08:00
goostavz
f9d7f2f126 [Triton-MLIR][Backend] Support ConvertLayout blocked->shared and a few fixes related with mma(#716) 2022-10-03 19:33:25 +08:00
goostavz
61b61755e5 [Triton-MLIR][Backend] Support layout conversion between mmaLayout and blockedLayout (#693) 2022-09-27 03:58:47 +00:00
Philippe Tillet
1e91ed30d0 [RUNTIME] Major code cleanup (#711)
This PR does the following:
- CUDA utilities (e.g., cuGetInfo) won't be compiled as part of libtriton.so anymore.
- Refactoring driver/llvm.cc to split it between PTX codegen and python.
- By extension this will also deprecate include/external so Triton won't have to live with a copy of some CUDA/Hip headers anymore.
- `triton-translate` becomes a `triton.tools.aot` Python utility that re-uses functions from the triton.compile sub-module.
2022-09-26 16:38:06 -07:00
Philippe Tillet
22ec22c257 [FRONTEND] Backport new runtime from master (#706)
This PR merges the new runtime back into the `triton-mlir` branch. This
adds caching and just-in-time compilation functionality to the
triton-mlir project, and paves the way for re-using tests from the
master branch.
2022-09-23 16:09:43 -07:00
Philippe Tillet
c56f0198dd Revert "[Triton-MLIR][pybind11] Update pybind11 to 2.10.0" (#702)
Reverts openai/triton#694
2022-09-23 12:31:33 -07:00
Shintaro Iwasaki
23f424c660 [Triton-MLIR][pybind11] Update pybind11 to 2.10.0 (#694)
This PR applies #691 to the Triton-MLIR branch.
2022-09-22 17:53:42 -07:00
goostavz
15bfd0cb79 [BACKEND] Support of ConvertLayoutOp from blocked to blocked and SliceLayout with blocked parent (#658) 2022-09-17 14:58:42 -07:00
Shintaro Iwasaki
13669b46a6 [DOCS] Correct spelling (#665)
This PR corrects spelling like #664 for Triton-MLIR. It should not break anything.
2022-09-16 15:07:34 -07:00
Shintaro Iwasaki
e9e1a4e682 [FRONTEND] Fix the implicit broadcasting rule (#663)
This PR solves the cast issue that appears in some tutorial code.
2022-09-16 10:49:15 -07:00
Philippe Tillet
80e3fb5270 [CI] Now using clang-format from pip (#662) 2022-09-15 16:24:37 -07:00
Shintaro Iwasaki
43be75ad42 [FRONTEND] Add scalar type support for some ops (#661)
This PR adds basic support for scalar-type inputs to some ops (cast and pointer arithmetics) for Triton-MLIR. Also renames getelementptr -> addptr
2022-09-15 16:12:52 -07:00
Yan Chunwei
a9464f4993 [Backend] Vectorize Load/Store Ops (#86)
This PR does the following things:

- Code refactoring on Load and Store op codegen, rewrite with same logic
and share much code
- Support the vectorized load/store
2022-09-06 12:28:09 -07:00
Shintaro Iwasaki
3c635449e5 [Triton] Support math and libdevice ops (#91)
This PR adds basic math ops by using `MathDialect` and `libdevice` ops by using `extern_elementwise`. This is needed to compile some tutorial code (e.g., `softmax`). This PR implements only interface till PTX (so from frontend to TritonGPU-MLIR) 
- Currently till TritonGPU. It cannot be lowered to PTX now.
- No special optimizations (e.g., constant folding etc) are applied.
  - 14.x does not define folders for many operators for math ops, but 15.x seems to increase its coverage: https://github.com/llvm/llvm-project/blob/llvmorg-15.0.0-rc3/mlir/include/mlir/Dialect/Math/IR/MathOps.td
  - No constant folding etc for `libdevice` ops.

```py
import triton
import triton.language as tl
import sys

@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    offsets = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offsets)
    x = tl.sin(x)
    output = tl.libdevice.sin(x)
    output = tl.libdevice.fdiv_rn(output, output)
    output = tl.libdevice.fmaf_rd(output, output, output)
    tl.store(y_ptr + offsets, output)


if __name__ == "__main__" and len(sys.argv) >= 2:
    signature = "*fp32,*fp32"
    constants = {'BLOCK_SIZE': 1024}
    output = triton.compile(add_kernel, signature, device=0, constants=constants, output="ttgir")
    print(output)
```
->
```llvm
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func @add_kernel__Pfp32_Pfp32__2c1024(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) {
    %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
    %1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %2 = tt.getelementptr %1, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
    %4 = math.sin %3 : tensor<1024xf32, #blocked>
    %5 = tt.ext_elemwise %4 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_sinf"} : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %6 = tt.ext_elemwise %5, %5 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fdiv_rn"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %7 = tt.ext_elemwise %6, %6, %6 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fmaf_rd"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %8 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %9 = tt.getelementptr %8, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    tt.store %9, %7 : tensor<1024xf32, #blocked>
    return
  }
}
```
2022-09-01 16:34:27 -07:00
Shintaro Iwasaki
d01353de07 [CI] add assert-enabled MLIR option (#78)
This deprecates the use of release-build LLVM hosted by the LLVM project, which makes debugging harder for developers.

This PR implements the following solution:
1. Create LLVM release tarballs with assert enabled on our own (using Docker)
2. Host them in our own GitHub repositories
3. Use our LLVM for CI and/or development if `TRITON_USE_ASSERT_ENABLED_LLVM=1` is set.
2022-08-31 18:55:32 -07:00
goostavz
bedbf221c0 [BACKEND] Support optional mask in TritonGPUToLLVM (#80)
Co-authored-by: gzhu <gzhu@nvidia.com>
2022-08-24 17:51:37 -07:00