Commit Graph

759 Commits

Author SHA1 Message Date
Yan Chunwei
24fd953f9a [BACKEND] Refine v100 tests and fix mmav1 numwarps>1 hang issue (#971)
This PR

- Fix numWarps>1 hang issue
- add existing test cases in test_gemm.py to CI, and add a common flag
`valid_on_Volta` to determine whether the test case should be activated
on Volta or just skip.
  - Currently, the column-major cases are disabled.
 - Add test_core.py and other tests to Volta CI
   - the `test_printf.py` failed.
2022-12-09 07:41:22 -08:00
goostavz
793012b4c4 [Triton-MLIR][Backend] Fix mmav1 in case of numWarps > 1 (#972) 2022-12-09 18:36:05 +08:00
Keren Zhou
3ed36dcb4d [BACKEND] MMA->DotOperand conversion for chain dot of float32 tensors (#962)
Co-authored-by: Philippe Tillet <phil@openai.com>
2022-12-08 20:11:51 +00:00
Keren Zhou
83f3b9165b [FRONTEND][BACKEND] Fix bool and int8 load when the other operand is given (#968) 2022-12-08 11:52:18 -08:00
Keren Zhou
71c35bcf9c [Triton-MLIR][BACKEND] Mark float to integer in Arithmetic Dialect as legal (#963) 2022-12-08 09:07:01 -08:00
Dongdong Li
c7cf9c6a32 [TRITON-MLIR][BACKEND]fix atomic_rmw for vector (#966)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-12-08 20:03:40 +08:00
Yan Chunwei
f0885e9caf [Triton-MLIR][BACKEND] Tiny patch for MMAv1 and code clean (#964)
This PR:

- Several fix on MMAV1 code
- Remove the env `TRITON_STATIC_LOOP_UNROLLING` in v100 CI since the
pipeline pass works now
- some code clean
2022-12-08 16:39:32 +08:00
Keren Zhou
18e683d9bb [Triton-MLIR][BACKEND] Pass compute capability from the frontend and code cleanup (#961) 2022-12-07 15:03:46 -08:00
Yan Chunwei
4eab9dcedf [Triton-MLIR][BACKEND] make MMAv1 splitk works (#960) 2022-12-07 08:58:38 +00:00
Philippe Tillet
b2b793dfb5 [FRONTEND][BACKEND] Fixes for cat / reshape / addptr (#959)
Most notably, this PR:
- changes the traits (and assembly format) of addptr so it can handle offsets that have arbitrary integer width.
- adds support for `cat`
2022-12-06 23:29:50 -08:00
Philippe Tillet
981aee7f1e [FRONTEND] Frontend fixes for uint / for loops / random (#958) 2022-12-06 20:25:47 -08:00
Philippe Tillet
115cd3ac47 [FRONTEND] Added reshape as an alias for view (for now) (#956) 2022-12-06 09:57:05 -08:00
Philippe Tillet
532e10cf87 [FRONTEND][BACKEND] Clean-up transpositions (#953) 2022-12-06 09:32:13 -08:00
Keren Zhou
16e973edf2 [BACKEND] Fix dependency analysis in pipeline (#946) 2022-12-06 09:08:55 -08:00
Crutcher Dunnavant
9490252261 [FRONTEND] Support alternative install locations of system libdevice.10.bc (#951) 2022-12-06 03:41:44 +00:00
Yan Chunwei
e419781978 [Triton-MLIR][BACKEND] Make mmav1 works on basic cases (#944)
TODO:

- Add more cases
- Currently, we just set vec to 4 to make the basic cases pass

Issue:

- the vec in shared layout is different compared to master branch
- when vec=1, it encounters CUDA misalignment error, it doesn't work in
master branch as well
- when setting vec to the value identical to master branch, the MMA
works
2022-12-06 10:57:08 +08:00
Crutcher Dunnavant
189491727a [FRONTEND] Extract and unify @builtin/@extern (#913)
This change attaches builtin-ness as an explicit attribute, rather than
a module prefix expectation. This permits us to source those builtins
from multiple sub-modules (useful when some builtins are part of the
true cyclic implementation core, and some are just useful library
additions); but also prevents accidental inclusion of non-builtins that
happen to be in the right library.

Once the flag exists, and the compiler is using `is_builtin()` for
decision making; the existence of the current `@extern` interface
becomes isomorphic to `@builtin`; and the interface can be unified.

Leaving `@extern` a thin-wrapper, and encouraging continued use of it,
establishes future-proofing towards adding additional extern tracing,
metric hooks, or scanning in the future.

* Add `triton.impl` package to hold the core, order dependent impl
details.
 * Extract `@builtin` and unify `@extern`; add `is_builtin()`
   * Add sense bit so that `@builtin` detection is less fragile.
 * Modify the compiler to use `is_builtin()`
2022-12-05 22:59:41 +00:00
Crutcher Dunnavant
e0072d210a [FRONTEND] Propagate mypy types through @jit, @builtin, etc (#915)
Changes to make decorated API methods no longer type-opaque.

```
$ echo 'import triton; reveal_type(triton.language.max)' | mypy /dev/stdin
/dev/stdin:1: note: Revealed type is "def (input: Any, axis: Any, _builder: Any =) -> Any"
Success: no issues found in 1 source file
```
2022-12-05 22:41:02 +00:00
Crutcher Dunnavant
2fa17588f7 [FRONTEND] Expand __init__ * imports, add __all__ (#912)
Expand `from .foo import *` to full listings, and `__all__` sections.

This reifies the module export listings, which is useful for code
importing this module; without this, clients will need special `mypy`
control pragmas for this library.

This removes a number of `# flake8` control pragmas.

Verified with `flake8`
2022-12-05 14:22:55 -08:00
goostavz
e057c65cf0 [BACKEND] Porting the legacy heuristic rule in assigning shared layout for A/B of MMAv1 (#948) 2022-12-05 11:30:23 -08:00
Philippe Tillet
99c7e0e008 [BUILD] Change default build type (#945) 2022-12-03 17:47:33 -08:00
Keren Zhou
f2fcaeabf3 [BACKEND] Support dot op when the output is mma encoding and allowtf32 is true (#937) 2022-12-03 19:14:12 +00:00
Philippe Tillet
8edfe813a5 [FRONTEND][BACKEND] Added trans instruction; made flash attention bwd pass work (#943) 2022-12-03 09:58:24 -08:00
goostavz
4d64589b22 [Triton-MLIR][Backend] Fix the definition of MmaEncodingAttr v1, and the output sequence of DotConversion in MMAv1 (#941) 2022-12-03 21:12:48 +08:00
donproc
521ff9ad74 [TRITON-MLIR][FRONTEND]fix scf.if to run through layernorm tutorial (#938)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-12-02 17:45:29 +08:00
Keren Zhou
c280ebda1b [Triton-MLIR][BACKEND] Fix the membar pass to add missing barriers caused by scf.for (#933)
1. Add missing barriers and revert the previous temporary solution
2. Extract the `run` method from membar analysis because the membar
analysis should have two phases, including construction, which doesn't
modify any IR, and modification, which adds barrier IRs. Hope this could
make the use of membar clear.
2022-12-01 11:54:18 -08:00
donproc
9def1bcebf [TRITON-MLIR][FRONTEND]minor fix to run through atomic_cas test (#925)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-12-01 13:43:26 +00:00
Keren Zhou
7d90a07d0b [Triton-MLIR][BACKEND] Refactor decompose insert_slice_async (#929)
1. Improve pipline's comment
2. Decompose insert_slice_async when load vector size is not supported
3. Add a test that could fail our gemm code

Copy my comments here:

There's a knob that may cause performance regression when decomposition
has been performed. We should remove this knob once we have thorough
analysis on async wait. Currently, we decompose `insert_slice_async`
into `load` and `insert_slice` without knowing which `async_wait` is
responsible for the `insert_slice_async`. To guarantee correctness, we
blindly set the `async_wait` to wait for all async ops if any `insert_slice_async` has been decomposed.

There are two options to improve this:
1. We can perform a dataflow analysis to find the `async_wait` that is
responsible for the `insert_slice_async` in the backend.
4. We can modify the pipeline to perform the decomposition before the
`async_wait` is inserted. However, it is also risky because we don't
know the correct vectorized shape yet in the pipeline pass. Making the
pipeline pass aware of the vectorization could introduce additional
dependencies on the AxisInfoAnalysis and the Coalesce analysis.
2022-11-30 10:07:34 -08:00
Philippe Tillet
6461254fb5 [BACKEND] Make flash attention forward pass work (#928)
This also simplifies BroadcastOp codegen
2022-11-30 10:13:24 +00:00
goostavz
4e6a8209ed [Triton-MLIR] Two fixes on allocation and backend related with MMA v1 (#930) 2022-11-30 09:27:26 +00:00
Philippe Tillet
9bb54402b3 [FRONTEND][BACKEND] Small fixes to multiple_of, num_programs, axisinfo; enable block-sparse tests (#927) 2022-11-29 20:00:34 +01:00
Philippe Tillet
66c36c4378 [BACKEND] Fixed bounds-wrapping issues (#926)
This fixes an issue that led to out-of-bounds shared memory accesses on
small matrices
2022-11-29 17:56:45 +01:00
Qingyi Liu
661be523c0 [Triton-MLIR][BACKEND] Minor fixes of shared memory in ReduceOpConversion (#924) 2022-11-29 11:50:31 +08:00
Yan Chunwei
c87fbf886e [Triton-MLIR][BACKEND] Remove static and unnamed namespace in Utility.h (#923)
Reference
https://wiki.sei.cmu.edu/confluence/display/cplusplus/DCL59-CPP.+Do+not+define+an+unnamed+namespace+in+a+header+file
2022-11-29 01:06:06 +00:00
goostavz
0c1d4d764e [Triton-MLIR][Backend] support MMA v1 in ConvertLayout (#922)
The e2e verification of mma v1 is not done yet. 
Get this merged in advance just to prevent more conflicts.
2022-11-28 08:10:30 +00:00
Qingyi Liu
9d31998a9d [Triton-MLIR][BACKEND] Add argmin / argmax implementation for ReduceOp (#918) 2022-11-27 22:59:27 -08:00
Yan Chunwei
04ec5deb41 [Triton-MLIR][BACKEND] decouple the dot code (#921)
This PR
- apply minimal modification to decouple the Dot helper related code
from TritonGPUToLLVM.cpp to a separate local header file to make it
easier to share some data structure for Dot
- add some patch necessary for transA and transB
- add some patch necessary for MMA v1 execution in backend
2022-11-28 13:30:27 +08:00
goostavz
630dc315ee [Triton-MLIR] uncomment the UT in test_gemm that has already been fixed (#920) 2022-11-28 11:23:20 +08:00
Keren Zhou
35c9ec1103 [Triton-MLIR][Backend] Fix number of warps and threads per warp when matrices are small (#917) 2022-11-26 12:30:38 -08:00
donproc
f63be0e9b5 [TRITON-MLIR][BACKEND]support atomic_cas (#914)
1. support atomics-cas
2. add xchg support in atomic_rmw

Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-11-25 12:02:08 +08:00
Keren Zhou
153aecb339 [Triton-MLIR][BACKEND] insert_slice_async on GPUs < sm80 (#908)
`insert_slice_async` is decomposed into `load + insert_slice` in the
backend.

Not sure if V100 perf can match the master branch though in this way.
Maybe the performance can be improved if instructions are arranged in
the following form:

```
%0 = load
%1 = load 
%2 = load 
...
insert_slice %0
insert_slice %1
insert_slice %2
```

Tested on A100 when manually enabling this decomposition.
Tests on V100 haven't been integrated yet, we can divide the tests into
two phases:
1. Test only load, insert_slice, and insert_slice_async, given TritonGPU
IRs in `test_backend.py`.
2. End to end gemm tests on V100.
2022-11-24 14:05:54 -08:00
Crutcher Dunnavant
f98aed1258 [Triton-MLIR][RUNTIME] Add /usr/bin/ptxas as a search path (#909)
Make `ptxas` search a bit broader to include `/usr/bin/ptxas`, installed
by the lambda stack repo versions:
https://lambdalabs.com/lambda-stack-deep-learning-software
2022-11-24 18:49:16 +00:00
Crutcher Dunnavant
ace7d28736 [Triton-MLIR][RUNTIME] Fix ir metadata lookup bug (#910) 2022-11-24 09:27:23 +01:00
ben-zhang-609
b688f7b7b8 [Triton-MLIR] add_volta_warpsPerTile (#907) 2022-11-24 01:44:29 +00:00
donproc
8925c2cd11 [TRITON-MLIR][BACKEND]AtomicRMWOp supports scalar (#903)
AtomicRMWOp supports scalar

Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-11-23 07:59:09 +00:00
Keren Zhou
2e33352419 [Triton-MLIR] Fix side effects (#906)
Try to add proper side effects for triton operations. 

The CSE pass could fail, hang, or output incorrect IRs for unknown
reasons, if side effects are not defined properly.

For instance, suppose we have two shared memory tensors:

```
%a = triton_gpu.alloc_tensor shape0, share_encoding0
%b = triton_gpu.alloc_tensor shape0, share_encoding0
```

The CSE pass will consider `%a` and `%b` are the same thing and
eliminate one of them, resulting in mysterious outcomes.
2022-11-22 23:29:18 -08:00
Yan Chunwei
037f9efa95 [Triton-MLIR][BACKEND] Fix wpt overflow issue in mma v2 (#904)
This PR

1. Fix wpt overflow issue in mma v2
2. Refine transpose logic
2022-11-23 11:27:15 +08:00
ben-zhang-609
07786dc932 [Triton-MLIR] Add compute capability (#902)
add compute capability from python frontend to backend.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
2022-11-22 11:08:23 -08:00
Keren Zhou
2afebcd79b [Triton-MLIR][Backend] Remove unnecessary barriers (#901)
Cross operation barriers are taken care of by the Membar pass. 

Explicit barriers are only required if there's any synchronization
necessary within each operation.
2022-11-22 10:03:29 -08:00
Yan Chunwei
136668bac3 [Triton-MLIR][BACKEND] tiny code cleanup (#899)
- Remove the unnecessary `static` in the anonymous namespace
- Remove several unnecessary functions
- Several simple rewrites to make code more clear
2022-11-21 16:00:46 +08:00