116 Commits

Author SHA1 Message Date
Keren Zhou
733301ff31 [Backend] Rewrite code for linking external library to expose more inlining opportunities (#1037)
- Also make it cleaner. 
- And mark out the code needs to be fixed in `semantic.py`.
2023-01-08 13:44:29 -08:00
Keren Zhou
678b9f53a2 [Backend] Use post-order traversal for liveness numbering (#1027)
Also add tests for `tt.trans`.
2023-01-03 15:11:54 -08:00
goostavz
0e8590f1c9 [BACKEND] Add generic support of convert_layout from distributed to shared (#1025) 2022-12-30 11:29:58 -08:00
Philippe Tillet
20100a7254 Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)
This PR merges the `triton-mlir` branch, in which we have been quietly
rewriting the Triton backend from scratch to increase maintainability,
stability and ultimately performance. Changes to the runtime are
minimal, and this new version aims to remain backward-compatible with
the previous commit. The legacy backend is now officially deprecated,
but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-12-21 01:30:50 -08:00
Yang Hau
8650b4d1cb [DRIVER] Fix typos (#939) 2022-12-02 11:13:46 -08:00
Keren Zhou
db3aa1d1fb [FRONTEND] Fix libdevice (#776)
Fix two problems in libdevice and external dispatch:

1. Use static triton types (e.g., tl.int32) instead of creating new
types. Otherwise, `tl.int32` and `tl.dtype('int32')` are not the same
thing.

2. The name of an extern inst should be empty but not the symbol name of
the inst. TTIR generator will assign names automatically. Otherwise, we
have the same variable name when there are multiple same extern insts.

Before the PR:

```bash
  __nv_exp = extern_elementwise f64<1024> %11;
  __nv_exp = extern_elementwise f64<1024> %11;
```

After the PR:

```bash
  %12 = extern_elementwise f64<1024> %11;
  %13 = extern_elementwise f64<1024> %11;
```
2022-10-13 17:18:16 -07:00
Yu Guo
71b46acc42 [IR] Added special-purpose dequantize instruction (#759)
It is currently necessary for optimal performance in quantized workloads to add a special-purpose instruction in the IR. Backward compatibility with this instruction is *NOT* guaranteed.
2022-10-12 14:14:45 -07:00
Philippe Tillet
4a77dfb042 [FRONTEND] Complete rewrite of the runtime (#644)
This PR completely rewrites the runtime of Triton to be more lean and
clearly separate the compilation step from the just-in-time caching logic.
This should substantially reduce launch overhead.
2022-09-18 08:51:48 -07:00
Shintaro Iwasaki
c668d6596e [DOCS] Fix spelling (#664)
This PR applies minor spelling fix in comments and string literals to
`master`. It shouldn't hurt anything.
2022-09-16 12:26:40 -07:00
Da Yan
437ced38c2 fp8 <> bf16 conversion (#637)
Co-authored-by: Philippe Tillet <phil@openai.com>
2022-08-30 14:20:12 -07:00
Daniil Fukalov
fe0c29b9ec Fix inconsistent struct declaration instead of class. (#632)
Looks like typo.
2022-08-26 16:20:21 -07:00
Da Yan
3e2953f357 Allow multiple_of and max_contiguous to accept n-d values (#617) 2022-08-10 09:59:32 -07:00
Keren Zhou
4912916c11 [FRONTEND] Added support for element-wise function defined in external LLVM bitcode (e.g., libdevice) (#562) 2022-07-13 15:52:21 -07:00
Philippe Tillet
5b4c8f221e [BACKEND] Compiler improvements (#557)
This PR adds several optimization capabilities in the compiler backend:
- Now using inline PTX for `tl.store`, making it possible to use things like evict_last
- For A100, mma layout can be directly converted to shared memory
- For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used both row- and col- major.
- Fixed liveness analysis; this was broken.
- Now can load/store directly mma layout without converting. Useful for when tl.dot accumulator is initialized with DRAM data inside of an inner loop.
- `tl.dot` can now take LHS inputs in registers when it comes from a previous `tl.dot` instruction. Useful for e.g. fused attention.
2022-06-27 11:49:19 -07:00
Keren Zhou
b5e728cb14 Add argmin argmax (#552) 2022-06-15 13:55:20 -07:00
Keren Zhou
93209c07e0 [BACKEND][CODEGEN] Fix reduce uint (#547) 2022-06-13 16:43:57 -07:00
Philippe Tillet
8876e53206 [BACKEND] Restored reduction bugfixes 2022-06-03 11:38:52 -07:00
Philippe Tillet
a60374a597 Revert "[BACKEND] Various bug fixes; making reductions faster (#533)".
This is a more stable commit that produce bitwise identical code to earlier
versions. Using commits after this one may lead to slightly different numerics
2022-06-03 11:36:06 -07:00
Philippe Tillet
3e7500dfe6 [BACKEND] Various bug fixes; making reductions faster (#533) 2022-05-31 17:14:44 -07:00
daadaada
d5eaa8dfa0 Making the generated Triton IR deterministic & a script to compare cached assembly (#522) 2022-05-24 08:56:36 -07:00
Philippe Tillet
d35617bea1 [BACKEND][CODEGEN] Faster reduction for scanline layout (#516) 2022-05-14 15:26:13 -07:00
Mengchi Zhang
d1a22a94e6 [FRONTEND] Add empty return value and remove protect to open the access to contained_tys_vec_t (#514)
Signed-off-by: Mengchi Zhang <mengchi@fb.com>
2022-05-13 11:46:12 -07:00
Sriram Murali
7c9bc5a47b [CODEGEN] Change return type of generator::packed_type to appease build warnings (#507) 2022-05-04 20:03:37 -07:00
Philippe Tillet
bda209002e [BACKEND][CODEGEN] vectorization bugfix (#502) 2022-04-23 13:18:33 -07:00
Philippe Tillet
0cc3b1129b [BACKEND][CODE_GEN] eviction policies now also apply to L2 (#501) 2022-04-21 23:56:01 -07:00
Philippe Tillet
9f08ecd684 [FRONTEND] Semantic analysis refactor (#491)
Moved dispatch.cc to semantic.py (@ptillet)
Integer signedness analysis was moved from C++ to python (@daadaada)
Cleaner frontend types (@daadaada)
Moved SSA construction to a separate object (@ptillet)


Co-authored-by: Yan Da <dyanab@connect.ust.hk>
2022-04-06 16:13:53 -07:00
Philippe Tillet
2bed6fc850 [LANG] Added support for device functions (#484) 2022-04-03 20:58:16 -07:00
Philippe Tillet
e0cc488055 [FRONTEND] Added tl.clock and tl.globaltimer (#485) 2022-03-28 16:15:43 -07:00
Philippe Tillet
76a9ee50a8 Revert "[FRONTEND] Semantic analysis refactor (#473)" (#483)
This reverts commit 539961072c.
2022-03-24 17:16:50 -07:00
daadaada
539961072c [FRONTEND] Semantic analysis refactor (#473)
Moved dispatch.cc to semantic.py
Integer signedness now moved from C++ to python
Cleaner frontend type

Co-authored-by: Phil Tillet <phil@openai.com>
2022-03-16 21:25:30 -07:00
Philippe Tillet
bb5765df5c [CODEGEN] Now padding shared memory for layout conversion (#468) 2022-03-03 22:19:05 -08:00
daadaada
d9dd97492f Use unique_ptr in ir::context_impl (#462)
Co-authored-by: Philippe Tillet <Phil.Tillet@gmail.com>
2022-02-24 16:07:10 -08:00
Philippe Tillet
98ed7db8c1 [CODEGEN] Improvements and bugfixes (#463) 2022-02-24 14:56:24 -08:00
Philippe Tillet
807d8a1945 [ALL] Merge master (#447) 2022-01-30 20:21:20 -08:00
Philippe Tillet
bef76b142a [BACKEND] float division is now approximate by default (#446) 2022-01-29 18:29:29 -08:00
daadaada
59d371c6eb [BACKEND] Added Int8 mma (#440) 2022-01-27 09:12:44 -08:00
Benjamin Lefaudeux
3a23c1dd33 [BACKEND] minor, hotfix for gcc compilation (#439) 2022-01-23 14:24:02 -08:00
Philippe Tillet
4c94359199 [FRONTEND] Alignment fix-up (#428) 2022-01-11 23:11:58 -08:00
daadaada
94a2e10fe5 [BACKEND] Add bf16 & tf32 mma supports (on A100) (#426) 2022-01-11 10:20:31 -08:00
Madeleine Thompson
0ab9d67bad uint8, uint16, uint32, and uint64 in kernels (#413)
A forthcoming PR will update the RNG to use these types.

Also:
- Add tests for the `//`, `<<`, and `>>` operators.
- Change `TensorWrapper` to unwrap objects when the resulting object would be simpler.
- Clean up `throw_unreachable`, since it was triggering compiler warnings.
2022-01-05 15:27:17 -08:00
Philippe Tillet
03f1256f60 [FRONTEND] Added volatile flag for load (#407) 2021-12-30 22:33:24 -08:00
Madeleine Thompson
985798f101 add missing bfloat16 repr and improve assertions (#403)
- `BF16TyID` was missing a repr implementation.
- Throw a better exception on impossible casts.
- Add a few assertions. Tested with a debug build.
- Add `pointer_dtype.__str__` to aid kernel debugging.
2021-12-23 17:01:17 -08:00
daadaada
39d4bfed83 [OPS] Add performance model for gemm/gemv (#397)
Significantly improves the performance of `triton.ops.matmul` in memory-bound settings via the use of many more block configs coupled with a performance model to drive the auto-tuning process.
2021-12-21 09:56:10 -08:00
Madeleine Thompson
fa62b4a8f6 [FRONTEND] better stringification (#394)
- Don't override `self.args` in `CompilationError`, and show the line number and column in error messages. This causes it to generate an easier-to-read backtrace.
- Better `__str__` on `TensorWrapper`, `dtype`, and `block`.
2021-12-17 20:11:45 -08:00
Philippe Tillet
558555630f [FRONTEND] Added xor_sum 2021-12-16 17:55:35 -08:00
Philippe Tillet
5ce1b726dc [CODEGEN] Various bugfixes that make it possible to fuse RNG in a matmul epilogue (#356) 2021-10-24 02:30:46 -07:00
daadaada
858dec8372 [CODEGEN] Add cache modifier to tl.load (#351)
* Add cache modifier to tl.load
* Add comment to cache_modifier
* Remove force_nc_cache
* Update test
2021-10-17 22:14:04 -07:00
Stephen McGroarty
c2e6b90ff1 [CODEGEN] Fixes masked load exception (#342) 2021-10-13 13:31:52 -07:00
Philippe Tillet
6e5b0b4301 [FRONTEND] Added on-disk cache for compiled kernels (#287) 2021-09-18 22:48:26 -07:00
Philippe Tillet
94c83d30ce [GENERAL] Removed deprecated driver files and added basic compatibility with rocm (#268)
- Removed driver module -- accelerator runtime is handled by pytorch
- Added basic support for ROCM based on @micmelesse 's PR -- now can execute empty kernel on AMD devices without any compile-time changes
- Now only using PREFER_SHARED for kernels when the size of shared memory is greater than 49k. Otherwise there can be poor L1 performance for broadcast tensors
2021-09-09 00:04:28 -07:00