166 Commits

Author SHA1 Message Date
Da Yan
0f5c6e619c [BUILD] Add the missing triton/impl to setup.py (#1042) 2023-01-09 19:03:45 +00:00
Connor Baker
c20215dad1 [FRONTEND] Update PTX/SM support for LLVM14 (PR #1038 redux) (#1039)
=
2023-01-09 10:31:55 -08:00
Keren Zhou
733301ff31 [Backend] Rewrite code for linking external library to expose more inlining opportunities (#1037)
- Also make it cleaner. 
- And mark out the code needs to be fixed in `semantic.py`.
2023-01-08 13:44:29 -08:00
Shintaro Iwasaki
ff399fbc20 [Build] Support GCC 8.x to build Triton (#1036) 2023-01-06 19:36:14 -08:00
Keren Zhou
4023149ee3 [Frontend] Convert constexpr to value for store and load ops (#1030)
Fixing problem 2 in https://github.com/openai/triton/issues/1017

Co-authored-by: Philippe Tillet <phil@openai.com>
2023-01-05 14:40:16 -05:00
Gregory Axler
2193bee94e [Example] Fix the compile function in copy_strided.py (#1029) 2023-01-05 10:37:41 -08:00
Sophia Wisdom
411bacb2a8 [FRONTEND] Add logical operations on constexprs (#1033) 2023-01-04 18:06:32 -08:00
Sharad Vikram
bc73bbb12c [FRONTEND] Fix argmin/max output type (#1012)
Currently Triton returns tensors with the input types rather than i32
when doing reduce argmax/argmin.
2023-01-03 23:12:16 -08:00
Keren Zhou
8460ea3df1 [Frontend] Fix import for libdevice (#1028)
This is a hotfix for issue 1 in
https://github.com/openai/triton/issues/1017
2023-01-03 15:48:05 -08:00
Keren Zhou
678b9f53a2 [Backend] Use post-order traversal for liveness numbering (#1027)
Also add tests for `tt.trans`.
2023-01-03 15:11:54 -08:00
goostavz
0e8590f1c9 [BACKEND] Add generic support of convert_layout from distributed to shared (#1025) 2022-12-30 11:29:58 -08:00
fdrocha
194ba103b1 [BUILD] Fixed error when compiling in systems with multiple versions of python installed (#1019) 2022-12-29 15:10:34 -08:00
goostavz
1d3029faf8 [Backend] Add value cache in emitting indices calculation and some refinement (#1018)
1, add explicit value cache in emitting indices calculation;
2, move the indices calculation emitting logics into
ConvertTritonGPUOpToLLVMPatternBase to avoid the redundant build cost by
templates. Refer to the discussion in this thread by @LyricZhao :
https://triton-lang.slack.com/archives/C042VBSQWNS/p1671336755922969
2022-12-29 11:19:59 -08:00
Yan Chunwei
2ba74d2729 [OPTIMIZER] Update the versionMinor in MMA layout for volta (#1014)
Continue the work https://github.com/openai/triton/pull/990

# Background
The `versionMinor` in MmaEncodingAttr holds some states of DotOp's
operands in Volta, while such operands will be modified by some
patterns, making the states out-of-date.

This PR helps to correct the states.

# Implementation
It adds three new patterns:

1. `CollectMmaToUpdateForVolta` helps to collect and build a map holding
the MmaEncodingAttr instances with wrong states and create new correct
ones for them,
2. `UpdateMMAVersionMinorForVolta` helps to replace the Ops generating
the wrong MmaEncodingAttr instances with new correct ones, currently it
supports the following Ops
    a. `convert_layout[X -> mma]`
    b. `arith.constant SplatAttr : !tensor<mma>`
    c. `dot ... : !tensor<mma>`

# Limitation
This PR chooses the mapping way to bypass the IR walk complexity from
the circular dependency between dot_operand[parent] and mma.
We use the MmaEncodingAttr instance as the mapping key, but there might
be multiple DotOp holding different DotOprand(IsMMAv1Row) that have the
same wrong MmaEncodingAttr instance.
To make each DotOp's (wrong) MmaEncodingAttr unique, we might need an ID
field to MmaEncodingAttr.
2022-12-28 12:24:01 +08:00
Keren Zhou
fd2da4aff6 [BACKEND] Support splat constant on the DotOperandLayout (#1008) 2022-12-22 00:48:46 -08:00
Sharad Vikram
925d3d7f98 [FRONTEND] Export broadcast and broadcast_to in triton.language (#1007) 2022-12-22 01:57:33 +00:00
Keren Zhou
b5aafb0dab [FRONTEND] Fix 3d indexing (#1006) 2022-12-21 12:52:32 -08:00
Philippe Tillet
20100a7254 Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)
This PR merges the `triton-mlir` branch, in which we have been quietly
rewriting the Triton backend from scratch to increase maintainability,
stability and ultimately performance. Changes to the runtime are
minimal, and this new version aims to remain backward-compatible with
the previous commit. The legacy backend is now officially deprecated,
but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-12-21 01:30:50 -08:00
Yang Hau
8650b4d1cb [DRIVER] Fix typos (#939) 2022-12-02 11:13:46 -08:00
Crutcher Dunnavant
44f577984d Fix format double substitution bug: {i} => {{i}} (#886)
The previous `{i}` was silently expanding to the `i` from the
enumeration loop on `regular_args` (when it wasn't empty).
2022-11-20 11:44:42 -08:00
Crutcher Dunnavant
0e4691e6dd [FRONTEND] Fix ExternLibrary(format=) bug; type annotate build_extern.py (#883)
Ran mypy over `build_extern.py`, cleaned up type annotations.

Found a fixed a bug where `ExternLibrary(format=)` was being ignored.
2022-11-17 18:45:30 +01:00
Natalia Gimelshein
0d7e753227 [TESTING] use torch.int for autotuning cache (#840)
For stupid reasons, ops on int8 are 3 times slower than on int, and for
another set of stupid reasons we are not using cudaMemset for `zero_`,
so using `int8` buffer in `do_bench` makes it slow.

Co-authored-by: Philippe Tillet <phil@openai.com>
2022-11-04 18:05:16 -07:00
Shintaro Iwasaki
77bc5187b5 Better NVIDIA Pascal GPU Support (#827)
This PR clarifies which features are supported on P100 via its tests,
though Pascal is not officially and fully supported by Triton.

## What this PR does

- Skip unsupported tests on P100.
  - Atomic RMW
- `tl.dot()` (perhaps not all patterns, but basically most `tl.dot()`
tests do not work on P100).
- Add an explicit error if shared memory size >= 64K on P100.
- Otherwise it causes `Invalid CUDA argument` error at
`cuLaunchKernel()`, but this error is not very straightforward to
understand. Instead of this generic CUDA argument error, this PR makes
Triton show an error during codegen when `sm < 70`. This check happens
in C/C++ so won't add an overhead in Triton's Python runtime.
- 3 tests (see below) are currently failing, but these are not marked as
skipped because any codegen update in the future can change the kernel
size of the other tests.
- This change won't affect Triton-MLIR. Hopefully Triton-MLIR's generic
`tl.dot()` implementation would support P100.

Importantly, Triton passed all the other tests on P100. Though this
support is not official, it is great for, for example, PyTorch's
TorchDynamo/Inductor, which can use Triton (without `tl.dot()`) for its
backend (https://github.com/pytorch/torchdynamo/issues/1591).

### Results on P100 (Google Cloud)

```sh
$ pytest test/unit
...
================================================================================== short test summary info ==================================================================================
FAILED test/unit/language/test_core.py::test_reduce2d[argmin-float32-shape99-1] - RuntimeError: Device does not support shared memory of 65536bytes
FAILED test/unit/language/test_core.py::test_reduce2d[argmax-float32-shape113-1] - RuntimeError: Device does not support shared memory of 65536bytes
FAILED test/unit/language/test_core.py::test_permute[float32-shape5-perm5] - RuntimeError: Device does not support shared memory of 67584bytes
================================================================== 3 failed, 3824 passed, 952 skipped in 470.90s (0:07:50) ==================================================================
```

<details><summary> <b>Environment Details (collapsed)</b></summary>
<p>

### VM details (Google Cloud)
https://cloud.google.com/
```
# You need a paid account (free trial does not cover GPUs)
Google Cloud -> New Project -> Compute-Engine -> VM Instance
Machine:
GPU: NVIDIA Tesla P100 x 1
CPU: 2 vCPUs, 7.5GB memory
Boot disk:
  OS: Ubuntu 18.04 LTS
  Disk: 40GB (cannot build Triton on the default 10GB disk)
- When I tried, about $1.2 per hour.
- US instances were full when I tried.  I used Asia or Australia.
- Needed a paid account (GPU is not covered by free trial)
- Needed quota request for any GPU instance (by default, no GPU instance is allowed).  Needed to wait an hour for approval
```

### Reproducer
```sh
## 1. Install CUDA and a driver
# Update the apt key (https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key/)
sudo apt-key del 7fa2af80
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb
# Download CUDA as instructed
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-ubuntu1804.pin
sudo mv cuda-ubuntu1804.pin /etc/apt/preferences.d/cuda-repository-pin-600
sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
sudo add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/ /"
sudo apt-get update
sudo apt-get -y install cuda
# Are you using P100?
nvidia-smi | grep "Tesla P100"

## 2. Setup the build environment
sudo apt update
sudo apt install -y build-essential wget git libz-dev
wget https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh
bash Anaconda3-2022.05-Linux-x86_64.sh -b -p $(pwd)/anaconda3
eval "$($(pwd)/anaconda3/bin/conda shell.bash hook)"
conda create -y --name triton_base
conda activate triton_base
conda install -y cmake setuptools

## 3. Build Triton
git clone https://github.com/openai/triton.git
cd triton/python
pip3 install -e '.[tests]'

## 4. Test
pytest test/unit
```

### Environment
```sh
$ nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  On   | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    25W / 250W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
```

</p></details>
2022-11-03 00:11:52 -07:00
Chenggang Zhao
f16138d447 [Frontend] Interface fixes for libdevice (#830)
- Unifying several interfaces with different types to a single one, e.g.
`fsub_ru` and `dsub_ru` -> `sub_ru`;
- Minor bug fix: `fast_pow` is incorrectly classified into the `pow`
interface, of which arguments are the same as `powf`;
- Explicit interfaces for casting functions, e.g. decoupling
`ll2float_ru` to `ll2float_ru` and `ull2float_ru`;
- Removing interfaces that are not in NVIDIA's official documents, e.g.
`fmaf_ieee_rn`, which is confusing together with `fmaf_rn`.

Note that this PR for the master branch is different from #829, which is
for the MLIR branch.
2022-11-01 10:51:58 -07:00
Mark Saroufim
578ada7740 [DOCS] Add install from source instructions to README (#821) 2022-10-31 11:08:18 -07:00
Phil Tillet
6311d70406 Revert "[BUILD] Now using cibuildwheel default"
This reverts commit 584086f08c.
2022-10-29 17:15:47 -07:00
Phil Tillet
584086f08c [BUILD] Now using cibuildwheel default 2022-10-29 16:59:06 -07:00
Keren Zhou
3ca667dfa8 [Frontend] Return a scalar if all input args are scalar (#816) 2022-10-28 23:27:06 -07:00
Yanbo Liang
5ca1ed0101 Add bf16/fp16/fp64 support for ty_to_cpp (#800)
In ```torch._inductor```, we [convert 0d CPU tensor to scalar during
triton codegen](https://github.com/pytorch/pytorch/pull/87329), so need
add missing triton support for bf16/fp16/fp64.
2022-10-24 19:41:25 -07:00
Keren Zhou
db3aa1d1fb [FRONTEND] Fix libdevice (#776)
Fix two problems in libdevice and external dispatch:

1. Use static triton types (e.g., tl.int32) instead of creating new
types. Otherwise, `tl.int32` and `tl.dtype('int32')` are not the same
thing.

2. The name of an extern inst should be empty but not the symbol name of
the inst. TTIR generator will assign names automatically. Otherwise, we
have the same variable name when there are multiple same extern insts.

Before the PR:

```bash
  __nv_exp = extern_elementwise f64<1024> %11;
  __nv_exp = extern_elementwise f64<1024> %11;
```

After the PR:

```bash
  %12 = extern_elementwise f64<1024> %11;
  %13 = extern_elementwise f64<1024> %11;
```
2022-10-13 17:18:16 -07:00
Twizzes
ddae106c0e [DOCS] Update installation.rst to fix windows build error (#747) 2022-10-13 13:27:15 -07:00
Keren Zhou
bc98aead33 [Backend] Fix for mov.u8 (#766)
Init a potential fix for mov.u8 which is not supported by ptx for now.
Use mov.u16 instead and cast it to u8.
2022-10-12 14:32:27 -07:00
Yu Guo
71b46acc42 [IR] Added special-purpose dequantize instruction (#759)
It is currently necessary for optimal performance in quantized workloads to add a special-purpose instruction in the IR. Backward compatibility with this instruction is *NOT* guaranteed.
2022-10-12 14:14:45 -07:00
Philippe Tillet
33e6f0df7f [DRIVER] Bumped CUDA requirement to 11.4+. This is to avoid bad performance surprises as older ptxas are much slower. (#769)
This also makes codegen simpler by avoiding special handling of eviction policies
2022-10-12 12:02:30 -07:00
Philippe Tillet
af76c989eb [RUNTIME] Make entry point cache key depend on triton version hash (#765) 2022-10-11 13:24:30 -07:00
Bin Bao
09cc2d454b [FRONTEND] Fix a bool tensor storing problem (#746) 2022-10-10 12:11:50 -07:00
Felipe Petroski Such
5d4b26d380 [RUNTIME] support multiple devices in the same process (#757) 2022-10-09 20:30:04 -07:00
Chris
9a11a567ce [DOCS] Fixed typos in 01-vector-add.py (#751) 2022-10-09 18:12:46 -07:00
Keren Zhou
11345e9b74 [RUNTIME] Add callback functions for external tools (#738) 2022-10-05 14:46:55 -07:00
Philippe Tillet
bdfdb9a1d2 [RUNTIME] Fixed JIT bug that leg some constexpr values to be overriden by specialization parameters (#742) 2022-10-05 11:00:32 -07:00
shenggan
77c752dc78 [RUNTIME] remove fixed cu_include_dir (#739)
Use environment variable `CUDA_HOME` with default value`/usr/local/cuda` for `cu_include_dir` #731
2022-10-04 19:49:57 -07:00
Natalia Gimelshein
d3c925db8a [FRONTEND] properly broadcast scalar where condition (#736) 2022-10-04 12:44:03 -07:00
fdrocha
2b0f877fad [RUNTIME] Support environments with multiple cudalibs (#733) 2022-10-03 18:36:24 +00:00
Keren Zhou
4a2d3b7d79 [RUNTIME] Dump llvm, ttir, and sass to help debugging (#732) 2022-10-03 00:39:52 +00:00
Natalia Gimelshein
f55960e773 [FRONTEND] fix broadcasting for where (#729)
Fixes #532, all 3 inputs to where have to be broadcast together.
2022-10-01 13:18:47 -07:00
Phil Tillet
b244db06da [TUTORIALS] Attention tutorial fixup 2022-09-30 19:31:43 -07:00
Shintaro Iwasaki
7b61303ea1 [CODEGEN] Fix extract_N_bufferable in layout analysis (#728) 2022-09-30 12:21:22 -07:00
Shintaro Iwasaki
ae59f51c2d [CODEGEN] Fix an inliner to call a function with a phi-node (#727) 2022-09-29 21:36:40 -07:00
albanD
f45e31ba7c [FRONTEND] Make sure to hold the gil when creating python objects (#726)
Without this patch, a debug version of python complains that:
```
Fatal Python error: Python memory allocator called without holding the GIL
Python runtime state: initialized
```
2022-09-29 18:06:22 -07:00
Philippe Tillet
dad97528b2 [TESTING] allclose fixup (#724) 2022-09-28 22:49:05 +00:00
Jason Ansel
998fd5f9af [FRONTEND] Make triton.compile work without a cuda context (#708)
This allows compiling in a subprocess. I'm not seeing a ton of speedup from this, but figure it is a good change anyway.
2022-09-24 13:41:47 -07:00
Shintaro Iwasaki
3ac929b48b [BUILD] Download pybind11 in setup.py (#703)
Based on the discussion in #700, this PR enables downloading pybind11 in
`setup.py` without `git submodule` instead of copy-pasting pybind11
code. The downloaded pybind11 will be in `~/.triton/pybind` (like
`llvm`).
2022-09-23 15:54:07 -07:00
Jason Ansel
579c03615d [FRONTEND] Reduce number of compiles in JITFunction (#704)
I suspect this was the cause of the "new compiles even on a warm cache"
behavior I was seeing, though haven't 100% confirmed it.

Python `set()` iteration order is nondeterministic when you create a new
process. So the same args could produce different `instance_descriptor`s
and have false cache misses.
2022-09-23 21:44:52 +00:00
Philippe Tillet
25e1b36785 Revert "[pybind11] Use git-submodule for pybind11" (#701)
Reverts openai/triton#699
2022-09-23 12:25:38 -07:00
Shintaro Iwasaki
61d104ab3a [FRONTEND] Use git-submodule for pybind11 (#699)
This PR changes the `pybind11` source code management from copy-paste to
a package controlled by git-submodule.

See the discussion in #694 for details.
2022-09-23 09:55:03 -07:00
Philippe Tillet
8c3d4d5749 [RUNTIME] now decoupling entry point from cubin (#696) 2022-09-22 16:44:22 -07:00
Shintaro Iwasaki
df67068bb0 [pybind11] Update pybind11 to 2.10.0 (#691)
This PR updates the version of pybind11 to 2.10.0 (the latest stable).
2022-09-21 20:18:02 -07:00
Philippe Tillet
677ddae618 [FRONTEND] Add warmup for triton.jit() (#684)
This revives #671 , removing the static functions that may unnecessarily hold a reference to the grid and the JITFunction object

Co-authored-by: Jason Ansel <jansel@jansel.net>
2022-09-21 19:13:20 +00:00
Jason Ansel
6abe813d1c Fix issue breaking cudagraphs (#685)
@ngimel figured this one out. 

The errors we were seeing from cudagraphs capture were coming from
`cuStreamGetCtx` which is not allowed while a stream is capturing.

It appears the result of `cuStreamGetCtx()` isn't even used, so I
believe it can just be removed.
2022-09-21 10:20:48 -07:00
Philippe Tillet
e318185eb4 [DOCS] Improved README.md wording (#683)
Initial wording dates from a time where nobody knew Triton, and
comparing it to CUDA helped differentiate it from other existing DSLs.
But nowadays this comparison doesn't make much sense; Triton is its own
thing, and some people may even still be more productive in CUDA than
Triton -- language preferences are subjective after all.
2022-09-20 18:09:43 -07:00
Philippe Tillet
7dc2a70edb Revert "Add .warmup() for triton.jit()" (#682)
Reverts openai/triton#671

It seems like for some reason this caused out-of-memory errors on some
of our internal workloads. I'm reverting this so that HEAD can be used
in production at OpenAI, and I will work on digging into this issue
asynchronously.
2022-09-20 16:05:14 -07:00
Philippe Tillet
48f30550f1 [FRONTEND] Now using raw compiler syscalls when possible (#678) 2022-09-19 21:01:36 -07:00
Jason Ansel
93b1adc53b [FRONTEND] Add .warmup() for triton.jit() (#671) 2022-09-18 23:09:34 -07:00
Phil Tillet
82956e5d6b [PACKAGING] Added missing package 2022-09-18 17:34:05 -07:00
Philippe Tillet
2baf333d44 [DOCS] Fixed typos (#670) 2022-09-18 17:13:12 -07:00
Jason Ansel
49f6bc3f2b [FRONTEND] Fix filename too long error in new runtime (#669) 2022-09-18 21:26:29 +00:00
Phil Tillet
00f4ef6958 [CI] wheel/docs workflows now only run on V100 machine 2022-09-18 13:28:35 -07:00
Jason Ansel
e647402fd3 Fix warning in generated C code (#667) 2022-09-18 12:57:32 -07:00
Philippe Tillet
4a77dfb042 [FRONTEND] Complete rewrite of the runtime (#644)
This PR completely rewrites the runtime of Triton to be more lean and
clearly separate the compilation step from the just-in-time caching logic.
This should substantially reduce launch overhead.
2022-09-18 08:51:48 -07:00
Ian Bearman
889d9e34a1 [REPO] update gitignore (#666)
Update `.gitignore` to include `.vs` and `.vscode`
2022-09-17 14:25:28 -07:00
Shintaro Iwasaki
c668d6596e [DOCS] Fix spelling (#664)
This PR applies minor spelling fix in comments and string literals to
`master`. It shouldn't hurt anything.
2022-09-16 12:26:40 -07:00
Sophia Wisdom
4580a04710 [FRONTEND] Improve error message for CPU tensors (#654)
Redo of #651 against master. Fixes #525 by catching CUDA error when we
check pytorch tensor size and rethrowing a more informative error that
says why we failed.
2022-09-14 14:26:42 -07:00
Philippe Tillet
cfbbc7b43a [CI] Added V100 tag to disambiguate self-hosted runners (#653) 2022-09-14 13:47:50 -07:00
Yunxing Dai
59a8e25f43 [DOCS] Fix typo (#650) 2022-09-14 12:17:05 -07:00
Da Yan
437ced38c2 fp8 <> bf16 conversion (#637)
Co-authored-by: Philippe Tillet <phil@openai.com>
2022-08-30 14:20:12 -07:00
Da Yan
210a296699 [BACKEND] bf16 flash-attention (#636) 2022-08-26 20:40:55 -07:00
Daniil Fukalov
fe0c29b9ec Fix inconsistent struct declaration instead of class. (#632)
Looks like typo.
2022-08-26 16:20:21 -07:00
Phil Wang
7394d732ad [DOCS] support for variable head dimensions in flash attention triton tutorial (#623) 2022-08-15 19:16:49 -07:00
Da Yan
3e2953f357 Allow multiple_of and max_contiguous to accept n-d values (#617) 2022-08-10 09:59:32 -07:00
Daniil Fukalov
cc79376222 Fix deprectaion warning on CreateGEP(Value *, ArrayRef<Value *>, const Twine &) (#608)
This variant of CreateGEP() is already removed in LLVM 14.
2022-08-07 17:10:18 -07:00
Daniil Fukalov
7b91c7befd Fix "warning: control reaches end of non-void function". (#607) 2022-08-02 16:12:48 -07:00
Sharad Vikram
968f59027e Expose module.print in pybind (#604) 2022-07-29 21:36:08 -07:00
Anton Kostin
923d468187 Update LICENSE (#602) 2022-07-25 09:30:03 -07:00
Jason Ansel
027321cdcf [FRONTEND] Make tl.rand() 1-exclusive (#601) 2022-07-24 17:47:23 -07:00
Jason Ansel
e02e56dc63 [FRONTEND] Add missing rfloordiv (#598)
* [FRONTEND] Add missing rfloordiv

* fix tests
2022-07-23 21:54:12 -07:00
Philippe Tillet
ab56d310dd [BACKEND][IR] Fixed up internal dtype size for booleans (1bit -> 8bit) (#600) 2022-07-23 20:08:03 -07:00
Da Yan
f28caddbf8 [FRONTEND] Allow tl.where to select pointers (#595) 2022-07-21 09:54:27 -07:00
Keren Zhou
af85f5fa46 [FRONTEND] Refresh cache when the source code of outlined functions are changed (#590) 2022-07-20 17:34:07 -07:00
daadaada
9b2bc88d11 [BACKEND] Better bf16 support (#588) 2022-07-19 21:22:37 -07:00
Philippe Tillet
86cab58d89 [CI] Changed dev wheel date to UTC time to match CRON schedule (#587) 2022-07-18 14:54:13 -07:00
Phil Tillet
5b04331dd2 [TUTORIALS] Added more credits in fused attention tutorial 2022-07-13 23:48:58 -07:00
Jason Ansel
0a3f3d5f25 [PACKAGING] Include triton/language/libdevice.10.bc in package data (#582) 2022-07-13 23:45:27 -07:00
Keren Zhou
4912916c11 [FRONTEND] Added support for element-wise function defined in external LLVM bitcode (e.g., libdevice) (#562) 2022-07-13 15:52:21 -07:00
Phil Tillet
971f5782b4 [tutorials] Added flash attention credits in tutorial 2022-07-11 18:56:48 -07:00
Philippe Tillet
d5eb9bc230 [tutorial] Added bwd in fused attention example (#579)
Doesn't work on V100
2022-07-11 15:43:46 -07:00
Jason Ansel
c9a2b9c7d4 [FRONTEND] Add missing args to get_simd_tflops() (#578) 2022-07-11 14:37:59 -07:00
Philippe Tillet
4a399a7e40 [BACKEND] Fix some bugs (atomics, a segfault...) (#577)
This should fix #558 , #573 and #574
2022-07-06 20:03:04 -07:00
vesuppi
22105bc33b [FRONTEND] Added type check in semantic arange (#572) 2022-07-03 15:25:37 -07:00
Keren Zhou
4bf509889b [BUILD] Change the default build type to Release (#571) 2022-07-01 12:17:22 -07:00
Keren Zhou
a74cce375f [FRONTEND] Raise broadcast error (#555) 2022-06-30 17:32:07 -07:00
Philippe Tillet
f733327ba4 [BACKEND][CODEGEN] Disabling L2 residency control by default (#570) 2022-06-29 17:05:13 -07:00
Natalia Gimelshein
1bbb2430d9 [TUTORIALS] adjust heuristics for dwdb kernel (#565) 2022-06-29 17:00:22 -07:00
Kashif Rasul
1895ceaa2d [TUTORIAL] Fix f-string for older python (#569)
fixes issue #568
2022-06-29 09:39:10 -07:00
Philippe Tillet
feb7a2a0dc [FRONTEND] Hotfix for store argument order (#567) 2022-06-28 00:24:02 -07:00
Philippe Tillet
5b4c8f221e [BACKEND] Compiler improvements (#557)
This PR adds several optimization capabilities in the compiler backend:
- Now using inline PTX for `tl.store`, making it possible to use things like evict_last
- For A100, mma layout can be directly converted to shared memory
- For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used both row- and col- major.
- Fixed liveness analysis; this was broken.
- Now can load/store directly mma layout without converting. Useful for when tl.dot accumulator is initialized with DRAM data inside of an inner loop.
- `tl.dot` can now take LHS inputs in registers when it comes from a previous `tl.dot` instruction. Useful for e.g. fused attention.
2022-06-27 11:49:19 -07:00
Keren Zhou
87413bc925 [BACKEND] Fix layout convert for non-contiguous input (#564) 2022-06-25 23:12:03 -07:00
Keren Zhou
d345ddf837 [DOCS] Separate atomic cas from other atomic operations since operands are very different (#559) 2022-06-22 17:51:17 -07:00
Keren Zhou
b02bac41ba [CI] Change cache dir (#561) 2022-06-22 11:44:35 -07:00
Keren Zhou
a428cf0bb2 [FRONTEND] Fix pytorch warning. (#560)
UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc').
2022-06-20 20:12:09 -07:00
Keren Zhou
b5e728cb14 Add argmin argmax (#552) 2022-06-15 13:55:20 -07:00
Jason Ansel
6b9756532f [BACKEND] Remove print in coalesce.cc (#551) 2022-06-15 13:13:20 -07:00
Madeleine Thompson
8ce2c12e33 [PYTHON] move ephemeral files to homedir (#549)
This prevents potential conflicts with other users on shared machines.
2022-06-13 19:37:52 -07:00
Keren Zhou
93209c07e0 [BACKEND][CODEGEN] Fix reduce uint (#547) 2022-06-13 16:43:57 -07:00
Philippe Tillet
58c8889235 [FRONTEND] Fix scanline layout (#548) 2022-06-13 16:21:10 -07:00
Natalia Gimelshein
7094657aa9 [FRONTEND] fix bool conversion of floating types (#545) 2022-06-13 15:52:37 -07:00
Keren Zhou
38573d1261 [FRONTEND] Return allocated registers and spilled registers for users (#541) 2022-06-07 18:37:12 -07:00
Mengchi Zhang
2cdc6d35c4 [FRONTEND] Give col_per_thread an initial value to make the compiler happy (#535)
Signed-off-by: Mengchi Zhang <mengchi@fb.com>
2022-06-06 12:48:23 -07:00
TC
f13cbaab9f [FRONTEND] assert that num_warps is a power of 2 (#539) 2022-06-06 11:37:08 -07:00
Philippe Tillet
751e325d2e [TUTORIALS] Fixed typo 2022-06-05 13:33:21 -07:00
Philippe Tillet
801c8a4c92 [TUTORIALS] Fixed typo 2022-06-05 12:32:07 -07:00
Philippe Tillet
8876e53206 [BACKEND] Restored reduction bugfixes 2022-06-03 11:38:52 -07:00
Philippe Tillet
a60374a597 Revert "[BACKEND] Various bug fixes; making reductions faster (#533)".
This is a more stable commit that produce bitwise identical code to earlier
versions. Using commits after this one may lead to slightly different numerics
2022-06-03 11:36:06 -07:00
Philippe Tillet
efa04cac1f [FRONTEND] A couple of bugfixes (#534) 2022-06-02 16:57:37 -07:00
Philippe Tillet
3e7500dfe6 [BACKEND] Various bug fixes; making reductions faster (#533) 2022-05-31 17:14:44 -07:00
Bert Maher
37037bb3be [FRONTEND] Default cache dir to /tmp/triton_$USER (#527) 2022-05-27 13:51:05 -07:00
Philippe Tillet
c82a206684 [FRONTEND] Better dot error message (#531) 2022-05-26 17:41:09 -07:00
Philippe Tillet
0e2883020a [BACKEND] Fixed typo in alignment analysis (#528) 2022-05-25 20:01:19 -07:00
Bert Maher
43fec2adca [FRONTEND] Add binding for create_int_to_ptr (#526) 2022-05-25 15:26:18 -07:00
Philippe Tillet
011bc83c1b [FRONTEND] For loops now promote initial value (#524) 2022-05-24 13:20:10 -07:00
Natalia Gimelshein
96bff90471 [FRONTEND] faster jit function launch (#523)
With fast (200 ns) get_stream function soon to be available from pytorch this shaves off approx 25-30 us from function launch, but even without that function due to caching device properties we are saving ~15-20us.
2022-05-24 12:08:49 -07:00
daadaada
d5eaa8dfa0 Making the generated Triton IR deterministic & a script to compare cached assembly (#522) 2022-05-24 08:56:36 -07:00
Shantanu
80f6a2698b [FRONTEND] Ensure version_key is called at most once (#519)
Co-authored-by: hauntsaninja <>
2022-05-23 13:40:08 -07:00
daadaada
205a493b10 [FRONTEND] Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas (#520)
Fix a bug in atomic_cas (correct cmp to val) & more tests on atomic_cas
2022-05-21 09:45:54 -07:00
Jiabao Lei
abea3dc2c6 [FRONTEND] provide device kwargs && fix fstring error for py<3.8 (#515)
Co-authored-by: Philippe Tillet <phil@openai.com>
2022-05-14 16:21:46 -07:00
Philippe Tillet
d35617bea1 [BACKEND][CODEGEN] Faster reduction for scanline layout (#516) 2022-05-14 15:26:13 -07:00
Mengchi Zhang
d1a22a94e6 [FRONTEND] Add empty return value and remove protect to open the access to contained_tys_vec_t (#514)
Signed-off-by: Mengchi Zhang <mengchi@fb.com>
2022-05-13 11:46:12 -07:00
Jason Ansel
d954a05989 [FRONTEND] Handle torch.uint8 args (#513)
Co-authored-by: Philippe Tillet <Phil.Tillet@gmail.com>
2022-05-12 13:07:39 -07:00
Philippe Tillet
0835a4fb05 [TUTORIALS] Removed #noformat in layer norm tutorial 2022-05-12 12:41:25 -07:00
Philippe Tillet
c736ba7c3e [TUTORIALS] Fixed formatting 2022-05-12 12:31:23 -07:00
Philippe Tillet
cd30a99aa2 [TUTORIALS] fixed formatting 2022-05-12 12:28:22 -07:00
Philippe Tillet
d87435e536 [TUTORIALS] Layer norm tutorial now uses residency control (#510) 2022-05-05 19:53:54 -07:00
Sriram Murali
7c9bc5a47b [CODEGEN] Change return type of generator::packed_type to appease build warnings (#507) 2022-05-04 20:03:37 -07:00
Philippe Tillet
95feb10ec9 [FRONTEND] fixup (#505) 2022-04-30 14:25:06 -07:00
Philippe Tillet
11a908655d [FRONTEND] Fixup 2022-04-29 14:35:09 -07:00
Phil Tillet
cd78ce4888 [FRONTEND] Improved error message when assigning None to non-constexpr 2022-04-29 09:17:54 -07:00
Philippe Tillet
ae2a1ab225 [BACKEND] Alignment pass improvements (#503) 2022-04-25 21:16:00 -07:00
Philippe Tillet
7d544799a0 [BACKEND] Now disabling L2 eviction policy for sm < 80 2022-04-25 09:35:36 -07:00
Philippe Tillet
3ca792043f [TEST] Added test for vectorization 2022-04-24 13:50:48 -07:00
Philippe Tillet
bda209002e [BACKEND][CODEGEN] vectorization bugfix (#502) 2022-04-23 13:18:33 -07:00
Philippe Tillet
0cc3b1129b [BACKEND][CODE_GEN] eviction policies now also apply to L2 (#501) 2022-04-21 23:56:01 -07:00
Philippe Tillet
7d6c504e8d [TESTING] Added testing utilities for fixing clock and using cuda-memcheck (#500) 2022-04-21 22:40:10 -07:00
Philippe Tillet
073be1d2ee [FRONTEND] check that tensors have power-of-two number of elements (#499) 2022-04-14 19:30:02 -07:00
Philippe Tillet
5c7122004c [TUTORIALS] Tutorial shouldn't expose clock. Just removed it. 2022-04-14 17:33:44 -07:00
Philippe Tillet
dc4d40faec [FRONTEND] now mangle constexpr float containing "e-" 2022-04-14 10:26:48 -07:00
Philippe Tillet
25f6689508 [FRONTEND] rename current stream monkey patch (#495) 2022-04-13 11:45:55 -07:00
Philippe Tillet
76bfac9f15 [FRONTEND] Improved constexpr handling (#493) 2022-04-12 00:02:54 -07:00
Philippe Tillet
14b0fd4cfb [FRONTEND] Added possibility for users to customize current stream query (#492) 2022-04-07 12:11:32 -07:00
Philippe Tillet
6424771f55 [CI] Documentation fixup 2022-04-07 09:42:35 -07:00
Philippe Tillet
9f08ecd684 [FRONTEND] Semantic analysis refactor (#491)
Moved dispatch.cc to semantic.py (@ptillet)
Integer signedness analysis was moved from C++ to python (@daadaada)
Cleaner frontend types (@daadaada)
Moved SSA construction to a separate object (@ptillet)


Co-authored-by: Yan Da <dyanab@connect.ust.hk>
2022-04-06 16:13:53 -07:00
Philippe Tillet
2bed6fc850 [LANG] Added support for device functions (#484) 2022-04-03 20:58:16 -07:00
apd10
e85c7a7fc7 Bugfix in ptxas path. (#487)
Bug: "ret" value is destroyed when a failing "ptxas --version" is run
overwriting the previous valid "ret" value.

Fix: keep rets only for those runs which are successful. Pick the first
one
2022-03-30 20:45:41 -07:00
Philippe Tillet
bace26143d [TUTORIALS] Removed leftover print 2022-03-28 16:53:23 -07:00
Philippe Tillet
e0cc488055 [FRONTEND] Added tl.clock and tl.globaltimer (#485) 2022-03-28 16:15:43 -07:00
Philippe Tillet
76a9ee50a8 Revert "[FRONTEND] Semantic analysis refactor (#473)" (#483)
This reverts commit 539961072c.
2022-03-24 17:16:50 -07:00
Philippe Tillet
ea6d1f1b85 [DRIVER] LLVM driver fixup (#482)
Current way of doing things is probably not super thread safe. init is shared between threads and some threads my not call the LLVMInitialize* function.
2022-03-23 00:24:45 -07:00
Keren Zhou
a4f68165cd [FRONTEND] Hot fix for lineno (#481)
Override __reduce__ to make CompilationError pickable and print out error messages
2022-03-22 22:09:49 -07:00
97 changed files with 7477 additions and 7226 deletions

View File

@@ -4,7 +4,7 @@ on:
workflow_dispatch:
pull_request:
branches:
- main
- master
- triton-mlir
jobs:
@@ -40,26 +40,26 @@ jobs:
rm -rf ~/.triton/cache/
- name: Check imports
if: startsWith(matrix.runner, 'ubuntu')
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install isort
isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )
- name: Check python style
if: startsWith(matrix.runner, 'ubuntu')
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install autopep8
autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )
- name: Check cpp style
if: startsWith(matrix.runner, 'ubuntu')
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install clang-format
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)
- name: Flake8
if: startsWith(matrix.runner, 'ubuntu')
if: ${{ matrix.runner != 'macos-10.15' }}
run: |
pip install flake8
flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )
@@ -79,24 +79,11 @@ jobs:
lit -v "$LIT_TEST_DIR"
- name: Run python tests
if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'A10'}}
if: ${{matrix.runner[0] == 'self-hosted'}}
run: |
cd python/tests
cd python/test/unit/
pytest
# TODO[Superjomn] Enable all the tests on V100 if available
- name: Run python tests on V100
if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
run: |
cd python/tests
pytest -k "not test_where_broadcast and not test_dot" test_core.py
pytest test_gemm.py
pytest test_backend.py
pytest test_reduce.py
pytest test_vecadd.py
pytest test_elementwise.py
pytest test_ext_elemwise.py
pytest test_transpose.py
- name: Run CXX unittests
run: |

View File

@@ -19,6 +19,10 @@ option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
# used conditionally in this file and by lit tests
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
# Customized release build type with assertions: TritonRelBuildWithAsserts
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
# Default build type
if(NOT CMAKE_BUILD_TYPE)
message(STATUS "Default build type: Release")
@@ -218,8 +222,10 @@ target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
if(WIN32)
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} dl) # dl is from dlfcn-win32
else()
elseif(APPLE)
target_link_libraries(triton ${LLVM_LIBRARIES} z)
else()
target_link_libraries(triton ${LLVM_LIBRARIES} z stdc++fs)
endif()

View File

@@ -33,6 +33,15 @@ And the latest nightly release:
pip install -U --pre triton
```
# Install from source
```
git clone https://github.com/openai/triton.git;
cd triton/python;
pip install cmake; # build time dependency
pip install -e .
```
# Changelog
Version 1.1 is out! New features include:

View File

@@ -10,8 +10,8 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"

View File

@@ -45,6 +45,10 @@ bool maybeAliasOp(Operation *op);
bool supportMMA(triton::DotOp op, int version);
bool supportMMA(Value value, int version);
Type getElementType(Value value);
std::string getValueOperandName(Value value, AsmState &state);
template <typename T_OUT, typename T_IN>

View File

@@ -10,20 +10,22 @@ namespace triton {
namespace type {
// Integer types
Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); }
Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); }
Type u32Ty(MLIRContext *ctx) {
// TODO(Superjomn): may change `static` into better implementations
static Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); }
static Type i16Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 16); }
static Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); }
static Type u32Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 32, IntegerType::Unsigned);
}
Type u1Ty(MLIRContext *ctx) {
static Type u1Ty(MLIRContext *ctx) {
return IntegerType::get(ctx, 1, IntegerType::Unsigned);
}
// Float types
Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
static Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
static Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
static Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
static Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }
static bool isFloat(Type type) {
return type.isF32() || type.isF64() || type.isF16() || type.isF128();

View File

@@ -2,8 +2,8 @@
#define TRITON_CONVERSION_PASSES_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
namespace mlir {
namespace triton {

View File

@@ -1,5 +1,5 @@
#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_ASM_FORMAT_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_ASM_FORMAT_H
#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -172,11 +172,11 @@ private:
return argArchive.back().get();
}
// Make the oprands in argArchive follow the provided \param order.
// Make the operands in argArchive follow the provided \param order.
void reorderArgArchive(ArrayRef<Operand *> order) {
assert(order.size() == argArchive.size());
// The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but
// it do necessary when onlyAttachMLIRArgs is true for the $0,$1.. are
// it does necessary when onlyAttachMLIRArgs is true for the $0, $1... are
// determined by PTX code snippet passed from external.
sort(argArchive.begin(), argArchive.end(),
[&](std::unique_ptr<Operand> &a, std::unique_ptr<Operand> &b) {
@@ -306,8 +306,7 @@ struct PTXInstrExecution {
bool onlyAttachMLIRArgs{};
};
//// =============================== Some instruction wrappers
///===============================
/// ====== Some instruction wrappers ======
// We add the wrappers to make the usage more intuitive by avoiding mixing the
// PTX code with some trivial C++ code.
@@ -324,4 +323,4 @@ struct PTXCpAsyncLoadInstr : PTXInstrBase<PTXCpAsyncLoadInstr> {
} // namespace triton
} // namespace mlir
#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
#endif

View File

@@ -1,43 +0,0 @@
#ifndef TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
#define TRITON_CONVERSION_TRITONGPUTOLLVM_TRITONGPUTOLLVMPASS_H_
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
class TritonLLVMConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMConversionTarget(MLIRContext &ctx,
mlir::LLVMTypeConverter &typeConverter);
};
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMFunctionConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter);
};
namespace triton {
// Names for identifying different NVVM annotations. It is used as attribute
// names in MLIR modules. Refer to
// https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#supported-properties for
// the full list.
struct NVVMMetadataField {
static constexpr char MaxNTid[] = "nvvm.maxntid";
static constexpr char Kernel[] = "nvvm.kernel";
};
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonGPUToLLVMPass(int computeCapability = 80);
} // namespace triton
} // namespace mlir
#endif

View File

@@ -0,0 +1,22 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_PASS_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_PASS_H
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonGPUToLLVMPass(int computeCapability = 80);
} // namespace triton
} // namespace mlir
#endif

View File

@@ -1,5 +1,5 @@
#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_
#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H_
#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H
#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H
#include <memory>

View File

@@ -3,4 +3,9 @@
include "mlir/IR/OpBase.td"
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
#endif // TRITON_INTERFACES

View File

@@ -12,10 +12,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
//
// Op Base
//
@@ -293,7 +289,7 @@ def TT_CatOp : TT_Op<"cat", [NoSideEffect,
}
def TT_TransOp : TT_Op<"trans", [NoSideEffect,
SameOperandsAndResultElementType]> {
SameOperandsAndResultElementType]> {
let summary = "transpose a tensor";

View File

@@ -39,6 +39,8 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout);
SmallVector<unsigned> getOrder(const Attribute &layout);
bool isaDistributedLayout(const Attribute &layout);
} // namespace gpu
} // namespace triton
} // namespace mlir

View File

@@ -81,7 +81,6 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
if(!mmaEnc)
return $_get(context, 1, 1, 1, order);
int version = mmaEnc.getVersion();
int opIdx = dotOpEnc.getOpIdx();
// number of rows per phase
@@ -91,8 +90,8 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
// index of the inner dimension in `order`
unsigned inner = (opIdx == 0) ? 0 : 1;
// ---- begin version 1 ----
if (version == 1) {
// ---- begin Volta ----
if (mmaEnc.isVolta()) {
bool is_row = order[0] != 0;
bool is_vec4 = opIdx == 0 ? !is_row && (shape[order[0]] <= 16) :
is_row && (shape[order[0]] <= 16);
@@ -107,8 +106,8 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
return $_get(context, vec, perPhase, maxPhase, order);
}
// ---- begin version 2 ----
if (version == 2) {
// ---- begin Ampere ----
if (mmaEnc.isAmpere()) {
std::vector<size_t> matShape = {8, 8,
2 * 64 / eltTy.getIntOrFloatBitWidth()};
// for now, disable swizzle when using transposed int8 tensor cores
@@ -292,9 +291,12 @@ def MmaEncodingAttr : DistributedEncoding<"MmaEncoding"> {
let description = [{
An encoding for tensors that have been produced by tensor cores.
It is characterized by two parameters:
- A 'version' which specifies the generation the tensor cores
- A 'versionMajor' which specifies the generation the tensor cores
whose output is being partitioned: 1 for first-gen tensor cores (Volta),
and 2 for second-gen tensor cores (Turing/Ampere).
- A 'versionMinor' which indicates the specific layout of a tensor core
generation, e.g. for Volta, there might be multiple kinds of layouts annotated
by 0,1,2 and so on.
- A `blockTileSize` to indicate how data should be
partitioned between warps.
@@ -305,7 +307,8 @@ Note: the layout is different from the recommended in PTX ISA
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
(mma.884 section, FP32 accumulator).
For example, the matrix L corresponding to blockTileSize=[32,16] is:
For example, when versionMinor=1, the matrix L corresponding to
blockTileSize=[32,16] is:
warp 0
--------------------------------/\-------------------------------
@@ -367,11 +370,39 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
let parameters = (
ins
"unsigned":$version,
"unsigned":$versionMajor,
"unsigned":$versionMinor,
ArrayRefParameter<"unsigned">:$warpsPerCTA
);
let extraClassDeclaration = extraBaseClassDeclaration;
let builders = [
// specific for MMAV1(Volta)
AttrBuilder<(ins "int":$versionMajor,
"ArrayRef<unsigned>":$warpsPerCTA,
"ArrayRef<int64_t>":$shapeA,
"ArrayRef<int64_t>":$shapeB,
"bool":$isARow,
"bool":$isBRow), [{
assert(versionMajor == 1 && "Only MMAv1 has multiple versionMinor.");
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
// 4-bits to encode 4 booleans: [isARow, isBRow, isAVec4, isBVec4]
int versionMinor = (isARow * (1<<0)) |\
(isBRow * (1<<1)) |\
(isAVec4 * (1<<2)) |\
(isBVec4 * (1<<3));
return $_get(context, versionMajor, versionMinor, warpsPerCTA);
}]>
];
let extraClassDeclaration = extraBaseClassDeclaration # [{
bool isVolta() const;
bool isAmpere() const;
// Get [isARow, isBRow, isAVec4, isBVec4] from versionMinor
std::tuple<bool, bool, bool, bool> decodeVoltaLayoutStates() const;
}];
}
def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
@@ -434,7 +465,7 @@ section 9.7.13.4.1 for more details.
"Attribute":$parent), [{
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().getVersion() == 1){
parent.cast<MmaEncodingAttr>().isVolta()){
isMMAv1Row = BoolAttr::get(context, true);
}
return $_get(context, opIdx, parent, isMMAv1Row);

View File

@@ -44,7 +44,9 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
// This is needed because these ops don't
// handle encodings
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "integer comparison operation";
let description = [{}];
@@ -56,7 +58,9 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
let results = (outs TT_BoolLike:$result);
}
def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect]> {
def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "floating-point comparison operation";
let description = [{}];
@@ -69,7 +73,9 @@ def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect]> {
}
// TODO: migrate to arith::SelectOp on LLVM16
def TTG_SelectOp : TTG_Op<"select", [NoSideEffect]> {
def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "select operation";
let description = [{}];

View File

@@ -14,6 +14,7 @@ namespace mlir {
class TritonGPUTypeConverter : public TypeConverter {
public:
TritonGPUTypeConverter(MLIRContext *context, int numWarps);
int getNumWarps() const { return numWarps; }
private:
MLIRContext *context;

View File

@@ -25,15 +25,12 @@ void addExternalLibs(mlir::ModuleOp &module,
// Translate TritonGPU dialect to LLVMIR, return null if failed.
std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module,
int computeCapability);
mlir::ModuleOp module, int computeCapability);
// Translate mlir LLVM dialect to LLVMIR, return null if failed.
std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module);
bool linkExternLib(llvm::Module &module, llvm::StringRef path);
} // namespace triton
} // namespace mlir

View File

@@ -25,13 +25,14 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
if (maybeSharedAllocationOp(op)) {
// These ops may allocate a new shared memory buffer.
auto result = op->getResult(0);
// FIXME(Keren): extract and insert are always alias for now
// XXX(Keren): the following ops are always aliasing for now
if (isa<tensor::ExtractSliceOp, triton::TransOp>(op)) {
// extract_slice %src
// trans %src
aliasInfo = AliasInfo(operands[0]->getValue());
pessimistic = false;
} else if (isa<tensor::InsertSliceOp>(op) ||
isa<triton::gpu::InsertSliceAsyncOp>(op)) {
} else if (isa<tensor::InsertSliceOp, triton::gpu::InsertSliceAsyncOp>(
op)) {
// insert_slice_async %src, %dst, %index
// insert_slice %src into %dst[%offsets]
aliasInfo = AliasInfo(operands[1]->getValue());

View File

@@ -177,9 +177,10 @@ private:
auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto bytes = srcTy.getElementType().isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * srcTy.getElementTypeBitWidth() / 8;
auto bytes =
srcTy.getElementType().isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
auto value = op->getOperand(0);
@@ -193,9 +194,10 @@ private:
std::multiplies{});
auto elemTy =
value.getType().cast<triton::PointerType>().getPointeeType();
auto bytes = elemTy.isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * elemTy.getIntOrFloatBitWidth() / 8;
auto bytes =
elemTy.isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
} else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
@@ -296,10 +298,24 @@ private:
/// Resolves liveness of all values involved under the root operation.
void resolveLiveness() {
// In the SCF dialect, we always have a sequentially nested structure of
// blocks
// Assign an ID to each operation using post-order traversal.
// To achieve the correct liveness range, the parent operation's ID
// should be greater than each of its child operation's ID .
// Example:
// ...
// %5 = triton.convert_layout %4
// %6 = scf.for ... iter_args(%arg0 = %0) -> (i32) {
// %2 = triton.convert_layout %5
// ...
// scf.yield %arg0
// }
// For example, %5 is defined in the parent region and used in
// the child region, and is not passed as a block argument.
// %6 should should have an ID greater than its child operations,
// otherwise %5 liveness range ends before the child operation's liveness
// range ends.
DenseMap<Operation *, size_t> operationId;
operation->walk<WalkOrder::PreOrder>(
operation->walk<WalkOrder::PostOrder>(
[&](Operation *op) { operationId[op] = operationId.size(); });
// Analyze liveness of explicit buffers

View File

@@ -116,11 +116,29 @@ bool supportMMA(triton::DotOp op, int version) {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
auto aElemTy = op.a().getType().cast<RankedTensorType>().getElementType();
auto bElemTy = op.b().getType().cast<RankedTensorType>().getElementType();
return (aElemTy.isF16() && bElemTy.isF16()) ||
(aElemTy.isBF16() && bElemTy.isBF16()) ||
(aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() &&
version >= 2) ||
(aElemTy.isInteger(8) && bElemTy.isInteger(8) && version >= 2);
if (aElemTy.isF32() && bElemTy.isF32()) {
return op.allowTF32() && version >= 2;
}
return supportMMA(op.a(), version) && supportMMA(op.b(), version);
}
bool supportMMA(Value value, int version) {
// Tell whether a DotOp support HMMA by the operand type(either $a or $b).
// We cannot get both the operand types(in TypeConverter), here we assume the
// types of both the operands are identical here.
assert((version == 1 || version == 2) &&
"Unexpected MMA layout version found");
auto elemTy = value.getType().cast<RankedTensorType>().getElementType();
return elemTy.isF16() || elemTy.isBF16() ||
(elemTy.isF32() && version >= 2) ||
(elemTy.isInteger(8) && version >= 2);
}
Type getElementType(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return tensorType.getElementType();
return type;
}
std::string getValueOperandName(Value value, AsmState &state) {

View File

@@ -1,20 +0,0 @@
#ifndef TRITON_CONVERSION_PASSDETAIL_H
#define TRITON_CONVERSION_PASSDETAIL_H
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
namespace mlir {
namespace triton {
#define GEN_PASS_CLASSES
#include "triton/Conversion/Passes.h.inc"
} // namespace triton
} // namespace mlir
#endif

View File

@@ -1,6 +1,13 @@
add_mlir_conversion_library(TritonGPUToLLVM
TritonGPUToLLVM.cpp
PtxAsmFormat.cpp
TritonGPUToLLVMPass.cpp
PTXAsmFormat.cpp
ConvertLayoutOpToLLVM.cpp
ElementwiseOpToLLVM.cpp
ViewOpToLLVM.cpp
LoadStoreOpToLLVM.cpp
DotOpToLLVM.cpp
ReduceOpToLLVM.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM

View File

@@ -0,0 +1,635 @@
#include "ConvertLayoutOpToLLVM.h"
#include "DotOpHelpers.h"
using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::isaDistributedLayout;
using ::mlir::triton::gpu::SharedEncodingAttr;
bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
DotOperandEncodingAttr &dotOperandLayout) {
// dot_op<opIdx=0, parent=#mma> = #mma
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
return mmaLayout.getWarpsPerCTA()[1] == 1 &&
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getParent() == mmaLayout;
}
void storeDistributedToShared(Value src, Value llSrc,
ArrayRef<Value> dstStrides,
ArrayRef<SmallVector<Value>> srcIndices,
Value dst, Value smemBase, Type elemTy,
Location loc,
ConversionPatternRewriter &rewriter) {
auto srcTy = src.getType().cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
assert(srcShape.size() == 2 && "Unexpected rank of storeDistributedToShared");
auto dstTy = dst.getType().cast<RankedTensorType>();
auto srcDistributedLayout = srcTy.getEncoding();
if (auto mmaLayout = srcDistributedLayout.dyn_cast<MmaEncodingAttr>()) {
assert((!mmaLayout.isVolta()) &&
"ConvertLayout MMAv1->Shared is not suppported yet");
}
auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
auto inOrd = getOrder(srcDistributedLayout);
auto outOrd = dstSharedLayout.getOrder();
unsigned inVec =
inOrd == outOrd ? getContigPerThread(srcDistributedLayout)[inOrd[0]] : 1;
unsigned outVec = dstSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned perPhase = dstSharedLayout.getPerPhase();
unsigned maxPhase = dstSharedLayout.getMaxPhase();
unsigned numElems = getElemsPerThread(srcTy);
assert(numElems == srcIndices.size());
auto inVals = getElementsFromStruct(loc, llSrc, rewriter);
auto wordTy = vec_ty(elemTy, minVec);
auto elemPtrTy = ptr_ty(elemTy);
Value outVecVal = i32_val(outVec);
Value minVecVal = i32_val(minVec);
Value word;
for (unsigned i = 0; i < numElems; ++i) {
if (i % minVec == 0)
word = undef(wordTy);
word = insert_element(wordTy, word, inVals[i], i32_val(i % minVec));
if (i % minVec == minVec - 1) {
// step 1: recover the multidim_index from the index of
SmallVector<Value> multiDimIdx = srcIndices[i];
SmallVector<Value> dbgVal = srcIndices[i];
// step 2: do swizzling
Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
Value off_1 = mul(multiDimIdx[outOrd[1]], dstStrides[outOrd[1]]);
Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase));
phaseId = urem(phaseId, i32_val(maxPhase));
Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
off_0 = mul(off_0, outVecVal);
remained = udiv(remained, minVecVal);
off_0 = add(off_0, mul(remained, minVecVal));
Value offset = add(off_1, mul(off_0, dstStrides[outOrd[0]]));
// step 3: store
Value smemAddr = gep(elemPtrTy, smemBase, offset);
smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3));
store(word, smemAddr);
}
}
}
struct ConvertLayoutOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
public:
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::ConvertLayoutOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if (isaDistributedLayout(srcLayout) &&
dstLayout.isa<SharedEncodingAttr>()) {
return lowerDistributedToShared(op, adaptor, rewriter);
}
if (srcLayout.isa<SharedEncodingAttr>() &&
dstLayout.isa<DotOperandEncodingAttr>()) {
return lowerSharedToDotOperand(op, adaptor, rewriter);
}
if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) {
return lowerDistributedToDistributed(op, adaptor, rewriter);
}
if (srcLayout.isa<MmaEncodingAttr>() &&
dstLayout.isa<DotOperandEncodingAttr>()) {
return lowerMmaToDotOperand(op, adaptor, rewriter);
}
// TODO: to be implemented
llvm_unreachable("unsupported layout conversion");
return failure();
}
private:
SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
ConversionPatternRewriter &rewriter,
unsigned elemId, ArrayRef<int64_t> shape,
ArrayRef<unsigned> multiDimCTAInRepId,
ArrayRef<unsigned> shapePerCTA) const {
unsigned rank = shape.size();
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
auto multiDimOffsetFirstElem =
emitBaseIndexForLayout(loc, rewriter, blockedLayout, shape);
SmallVector<Value> multiDimOffset(rank);
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
elemId, getSizePerThread(layout), getOrder(layout));
for (unsigned d = 0; d < rank; ++d) {
multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
multiDimElemId[d]));
}
return multiDimOffset;
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto multiDimOffsetParent =
getMultiDimOffset(sliceLayout.getParent(), loc, rewriter, elemId,
sliceLayout.paddedShape(shape),
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
SmallVector<Value> multiDimOffset(rank);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d == dim)
continue;
unsigned slicedD = d < dim ? d : (d - 1);
multiDimOffset[slicedD] = multiDimOffsetParent[d];
}
return multiDimOffset;
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
SmallVector<Value> mmaColIdx(4);
SmallVector<Value> mmaRowIdx(2);
Value threadId = getThreadId(rewriter, loc);
Value warpSize = idx_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
// TODO: fix the bug in MMAEncodingAttr document
SmallVector<Value> multiDimWarpId(2);
multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
Value _1 = idx_val(1);
Value _2 = idx_val(2);
Value _4 = idx_val(4);
Value _8 = idx_val(8);
Value _16 = idx_val(16);
if (mmaLayout.isAmpere()) {
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 8));
Value mmaGrpId = udiv(laneId, _4);
Value mmaGrpIdP8 = add(mmaGrpId, _8);
Value mmaThreadIdInGrp = urem(laneId, _4);
Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2);
Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1);
Value rowWarpOffset = mul(multiDimWarpId[0], _16);
mmaRowIdx[0] = add(mmaGrpId, rowWarpOffset);
mmaRowIdx[1] = add(mmaGrpIdP8, rowWarpOffset);
Value colWarpOffset = mul(multiDimWarpId[1], _8);
mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset);
mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset);
} else if (mmaLayout.isVolta()) {
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 16));
Value laneIdDiv16 = udiv(laneId, _16);
Value laneIdRem16 = urem(laneId, _16);
Value laneIdRem2 = urem(laneId, _2);
Value laneIdRem16Div8 = udiv(laneIdRem16, _8);
Value laneIdRem16Div4 = udiv(laneIdRem16, _4);
Value laneIdRem16Div4Rem2 = urem(laneIdRem16Div4, _2);
Value laneIdRem4Div2 = udiv(urem(laneId, _4), _2);
Value rowWarpOffset = mul(multiDimWarpId[0], _16);
Value colWarpOffset = mul(multiDimWarpId[1], _16);
mmaRowIdx[0] =
add(add(mul(laneIdDiv16, _8), mul(laneIdRem16Div4Rem2, _4)),
laneIdRem2);
mmaRowIdx[0] = add(mmaRowIdx[0], rowWarpOffset);
mmaRowIdx[1] = add(mmaRowIdx[0], _2);
mmaColIdx[0] = add(mul(laneIdRem16Div8, _4), mul(laneIdRem4Div2, _2));
mmaColIdx[0] = add(mmaColIdx[0], colWarpOffset);
mmaColIdx[1] = add(mmaColIdx[0], _1);
mmaColIdx[2] = add(mmaColIdx[0], _8);
mmaColIdx[3] = add(mmaColIdx[0], idx_val(9));
} else {
llvm_unreachable("Unexpected MMALayout version");
}
assert(rank == 2);
SmallVector<Value> multiDimOffset(rank);
if (mmaLayout.isAmpere()) {
multiDimOffset[0] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
multiDimOffset[1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
multiDimOffset[0] = add(
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
} else if (mmaLayout.isVolta()) {
// the order of elements in a thread:
// c0, c1, ... c4, c5
// c2, c3, ... c6, c7
if (elemId < 2) {
multiDimOffset[0] = mmaRowIdx[0];
multiDimOffset[1] = mmaColIdx[elemId % 2];
} else if (elemId >= 2 && elemId < 4) {
multiDimOffset[0] = mmaRowIdx[1];
multiDimOffset[1] = mmaColIdx[elemId % 2];
} else if (elemId >= 4 && elemId < 6) {
multiDimOffset[0] = mmaRowIdx[0];
multiDimOffset[1] = mmaColIdx[elemId % 2 + 2];
} else if (elemId >= 6) {
multiDimOffset[0] = mmaRowIdx[1];
multiDimOffset[1] = mmaColIdx[elemId % 2 + 2];
}
multiDimOffset[0] = add(
multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
multiDimOffset[1] = add(
multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
} else {
llvm_unreachable("Unexpected MMALayout version");
}
return multiDimOffset;
}
llvm_unreachable("unexpected layout in getMultiDimOffset");
}
// shared memory rd/st for blocked or mma layout with data padding
void processReplica(Location loc, ConversionPatternRewriter &rewriter,
bool stNotRd, RankedTensorType type,
ArrayRef<unsigned> numCTAsEachRep,
ArrayRef<unsigned> multiDimRepId, unsigned vec,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
Value smemBase) const {
auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
auto layout = type.getEncoding();
auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
auto rank = type.getRank();
auto sizePerThread = getSizePerThread(layout);
auto accumSizePerThread = product<unsigned>(sizePerThread);
SmallVector<unsigned> numCTAs(rank);
auto shapePerCTA = getShapePerCTA(layout);
auto order = getOrder(layout);
for (unsigned d = 0; d < rank; ++d) {
numCTAs[d] = ceil<unsigned>(type.getShape()[d], shapePerCTA[d]);
}
auto elemTy = type.getElementType();
bool isInt1 = elemTy.isInteger(1);
bool isPtr = elemTy.isa<triton::PointerType>();
auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy);
if (isInt1)
elemTy = IntegerType::get(elemTy.getContext(), 8);
else if (isPtr)
elemTy = IntegerType::get(elemTy.getContext(), 64);
auto llvmElemTy = getTypeConverter()->convertType(elemTy);
for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) {
auto multiDimCTAInRepId =
getMultiDimIndex<unsigned>(ctaId, numCTAsEachRep, order);
SmallVector<unsigned> multiDimCTAId(rank);
for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) {
auto d = it.index();
multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value();
}
auto linearCTAId =
getLinearIndex<unsigned>(multiDimCTAId, numCTAs, order);
// TODO: This is actually redundant index calculation, we should
// consider of caching the index calculation result in case
// of performance issue observed.
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
multiDimCTAInRepId, shapePerCTA);
Value offset =
linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
Value ptr = gep(elemPtrTy, smemBase, offset);
auto vecTy = vec_ty(llvmElemTy, vec);
ptr = bitcast(ptr, ptr_ty(vecTy, 3));
if (stNotRd) {
Value valVec = undef(vecTy);
for (unsigned v = 0; v < vec; ++v) {
auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v];
if (isInt1)
currVal = zext(llvmElemTy, currVal);
else if (isPtr)
currVal = ptrtoint(llvmElemTy, currVal);
valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
}
store(valVec, ptr);
} else {
Value valVec = load(ptr);
for (unsigned v = 0; v < vec; ++v) {
Value currVal = extract_element(llvmElemTy, valVec, idx_val(v));
if (isInt1)
currVal = icmp_ne(currVal,
rewriter.create<LLVM::ConstantOp>(
loc, i8_ty, rewriter.getI8IntegerAttr(0)));
else if (isPtr)
currVal = inttoptr(llvmElemTyOrig, currVal);
vals[elemId + linearCTAId * accumSizePerThread + v] = currVal;
}
}
}
}
}
// blocked/mma -> blocked/mma.
// Data padding in shared memory to avoid bank conflict.
LogicalResult
lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto shape = dstTy.getShape();
unsigned rank = dstTy.getRank();
SmallVector<unsigned> numReplicates(rank);
SmallVector<unsigned> inNumCTAsEachRep(rank);
SmallVector<unsigned> outNumCTAsEachRep(rank);
SmallVector<unsigned> inNumCTAs(rank);
SmallVector<unsigned> outNumCTAs(rank);
auto srcShapePerCTA = getShapePerCTA(srcLayout);
auto dstShapePerCTA = getShapePerCTA(dstLayout);
for (unsigned d = 0; d < rank; ++d) {
unsigned inPerCTA = std::min<unsigned>(shape[d], srcShapePerCTA[d]);
unsigned outPerCTA = std::min<unsigned>(shape[d], dstShapePerCTA[d]);
unsigned maxPerCTA = std::max(inPerCTA, outPerCTA);
numReplicates[d] = ceil<unsigned>(shape[d], maxPerCTA);
inNumCTAsEachRep[d] = maxPerCTA / inPerCTA;
outNumCTAsEachRep[d] = maxPerCTA / outPerCTA;
assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0);
inNumCTAs[d] = ceil<unsigned>(shape[d], inPerCTA);
outNumCTAs[d] = ceil<unsigned>(shape[d], outPerCTA);
}
// Potentially we need to store for multiple CTAs in this replication
auto accumNumReplicates = product<unsigned>(numReplicates);
// unsigned elems = getElemsPerThread(srcTy);
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
unsigned inVec = 0;
unsigned outVec = 0;
auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
unsigned outElems = getElemsPerThread(dstTy);
auto outOrd = getOrder(dstLayout);
SmallVector<Value> outVals(outElems);
for (unsigned repId = 0; repId < accumNumReplicates; ++repId) {
auto multiDimRepId =
getMultiDimIndex<unsigned>(repId, numReplicates, outOrd);
if (repId != 0)
barrier();
if (srcLayout.isa<BlockedEncodingAttr>() ||
srcLayout.isa<SliceEncodingAttr>() ||
srcLayout.isa<MmaEncodingAttr>()) {
processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
multiDimRepId, inVec, paddedRepShape, outOrd, vals,
smemBase);
} else {
assert(0 && "ConvertLayout with input layout not implemented");
return failure();
}
barrier();
if (dstLayout.isa<BlockedEncodingAttr>() ||
dstLayout.isa<SliceEncodingAttr>() ||
dstLayout.isa<MmaEncodingAttr>()) {
processReplica(loc, rewriter, /*stNotRd*/ false, dstTy,
outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape,
outOrd, outVals, smemBase);
} else {
assert(0 && "ConvertLayout with output layout not implemented");
return failure();
}
}
SmallVector<Type> types(outElems, llvmElemTy);
auto *ctx = llvmElemTy.getContext();
Type structTy = struct_ty(types);
Value result = getStructFromElements(loc, outVals, rewriter, structTy);
rewriter.replaceOp(op, result);
return success();
}
// blocked -> shared.
// Swizzling in shared memory to avoid bank conflict. Normally used for
// A/B operands of dots.
LogicalResult
lowerDistributedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto srcTy = src.getType().cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
assert(srcShape.size() == 2 &&
"Unexpected rank of ConvertLayout(blocked->shared)");
auto srcLayout = srcTy.getEncoding();
auto dstSharedLayout = dstTy.getEncoding().cast<SharedEncodingAttr>();
auto inOrd = getOrder(srcLayout);
auto outOrd = dstSharedLayout.getOrder();
Value smemBase = getSharedMemoryBase(loc, rewriter, dst);
auto elemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3);
smemBase = bitcast(smemBase, elemPtrTy);
auto dstStrides =
getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
storeDistributedToShared(src, adaptor.src(), dstStrides, srcIndices, dst,
smemBase, elemTy, loc, rewriter);
auto smemObj =
SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}
// shared -> mma_operand
LogicalResult
lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
auto dstTensorTy = dst.getType().cast<RankedTensorType>();
auto srcTensorTy = src.getType().cast<RankedTensorType>();
auto dotOperandLayout =
dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
auto sharedLayout = srcTensorTy.getEncoding().cast<SharedEncodingAttr>();
bool isOuter{};
int K{};
if (dotOperandLayout.getOpIdx() == 0) // $a
K = dstTensorTy.getShape()[sharedLayout.getOrder()[0]];
else // $b
K = dstTensorTy.getShape()[sharedLayout.getOrder()[1]];
isOuter = K == 1;
Value res;
if (auto mmaLayout =
dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>()) {
res = lowerSharedToDotOperandMMA(op, adaptor, rewriter, mmaLayout,
dotOperandLayout, isOuter);
} else if (auto blockedLayout =
dotOperandLayout.getParent()
.dyn_cast_or_null<BlockedEncodingAttr>()) {
auto dotOpLayout =
dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
DotOpFMAConversionHelper helper(blockedLayout);
auto thread = getThreadId(rewriter, loc);
if (dotOpLayout.getOpIdx() == 0) { // $a
res = helper.loadA(src, adaptor.src(), blockedLayout, thread, loc,
rewriter);
} else { // $b
res = helper.loadB(src, adaptor.src(), blockedLayout, thread, loc,
rewriter);
}
} else {
assert(false && "Unsupported dot operand layout found");
}
rewriter.replaceOp(op, res);
return success();
}
// mma -> dot_operand
LogicalResult
lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto srcTy = op.src().getType().cast<RankedTensorType>();
auto dstTy = op.result().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding();
auto dstLayout = dstTy.getEncoding();
auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
if (isMmaToDotShortcut(srcMmaLayout, dstDotLayout)) {
// get source values
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
unsigned elems = getElemsPerThread(srcTy);
Type elemTy =
this->getTypeConverter()->convertType(srcTy.getElementType());
// for the destination type, we need to pack values together
// so they can be consumed by tensor core operations
unsigned vecSize =
std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
Type vecTy = vec_ty(elemTy, vecSize);
SmallVector<Type> types(elems / vecSize, vecTy);
SmallVector<Value> vecVals;
for (unsigned i = 0; i < elems; i += vecSize) {
Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (unsigned j = 0; j < vecSize; j++)
packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
vecVals.push_back(packed);
}
// This needs to be ordered the same way that
// ldmatrix.x4 would order it
// TODO: this needs to be refactor so we don't
// implicitly depends on how emitOffsetsForMMAV2
// is implemented
SmallVector<Value> reorderedVals;
for (unsigned i = 0; i < vecVals.size(); i += 4) {
reorderedVals.push_back(vecVals[i]);
reorderedVals.push_back(vecVals[i + 2]);
reorderedVals.push_back(vecVals[i + 1]);
reorderedVals.push_back(vecVals[i + 3]);
}
// return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
Type structTy =
LLVM::LLVMStructType::getLiteral(this->getContext(), types);
Value view =
getStructFromElements(loc, reorderedVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
return failure();
}
// shared -> dot_operand if the result layout is mma
Value lowerSharedToDotOperandMMA(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const {
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.result();
bool isHMMA = supportMMA(dst, mmaLayout.getVersionMajor());
auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
Value res;
if (!isOuter && mmaLayout.isAmpere() && isHMMA) { // tensor core v2
MMA16816ConversionHelper mmaHelper(src.getType(), mmaLayout,
getThreadId(rewriter, loc), rewriter,
getTypeConverter(), op.getLoc());
if (dotOperandLayout.getOpIdx() == 0) {
// operand $a
res = mmaHelper.loadA(src, smemObj);
} else if (dotOperandLayout.getOpIdx() == 1) {
// operand $b
res = mmaHelper.loadB(src, smemObj);
}
} else if (!isOuter && mmaLayout.isVolta() && isHMMA) { // tensor core v1
DotOpMmaV1ConversionHelper helper(mmaLayout);
bool isMMAv1Row =
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto srcSharedLayout = src.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<SharedEncodingAttr>();
// Can only convert [1, 0] to row or [0, 1] to col for now
if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
(srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
return Value();
}
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
// TODO[Superjomn]: transA is not available here.
bool transA = false;
res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc),
loc, rewriter);
} else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
// TODO[Superjomn]: transB is not available here.
bool transB = false;
res = helper.loadB(src, transB, smemObj, getThreadId(rewriter, loc),
loc, rewriter);
}
} else {
assert(false && "Unsupported mma layout found");
}
return res;
}
};
void populateConvertLayoutOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
indexCacheInfo, benefit);
}

View File

@@ -0,0 +1,28 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_CONVERT_LAYOUT_OP_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_CONVERT_LAYOUT_OP_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
DotOperandEncodingAttr &dotOperandLayout);
void storeDistributedToShared(Value src, Value llSrc,
ArrayRef<Value> srcStrides,
ArrayRef<SmallVector<Value>> srcIndices,
Value dst, Value smemBase, Type elemPtrTy,
Location loc,
ConversionPatternRewriter &rewriter);
void populateConvertLayoutOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);
#endif

View File

@@ -0,0 +1,311 @@
#include "DotOpToLLVM.h"
#include "DotOpHelpers.h"
#include "Utility.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::DotOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// D = A * B + C
Value A = op.a();
Value D = op.getResult();
// Here we assume the DotOp's operands always comes from shared memory.
auto AShape = A.getType().cast<RankedTensorType>().getShape();
size_t reduceAxis = 1;
unsigned K = AShape[reduceAxis];
bool isOuter = K == 1;
MmaEncodingAttr mmaLayout = D.getType()
.cast<RankedTensorType>()
.getEncoding()
.dyn_cast<MmaEncodingAttr>();
if (!isOuter && mmaLayout && supportMMA(op, mmaLayout.getVersionMajor())) {
if (mmaLayout.isVolta())
return convertMMA884(op, adaptor, rewriter);
if (mmaLayout.isAmpere())
return convertMMA16816(op, adaptor, rewriter);
llvm::report_fatal_error(
"Unsupported MMA kind found when converting DotOp to LLVM.");
}
if (D.getType()
.cast<RankedTensorType>()
.getEncoding()
.isa<BlockedEncodingAttr>())
return convertFMADot(op, adaptor, rewriter);
llvm::report_fatal_error(
"Unsupported DotOp found when converting TritonGPU to LLVM.");
}
private:
// Convert to mma.m16n8k16
LogicalResult convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto mmaLayout = op.getResult()
.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<MmaEncodingAttr>();
Value A = op.a();
Value B = op.b();
Value C = op.c();
MMA16816ConversionHelper mmaHelper(A.getType(), mmaLayout,
getThreadId(rewriter, loc), rewriter,
getTypeConverter(), loc);
auto ATensorTy = A.getType().cast<RankedTensorType>();
auto BTensorTy = B.getType().cast<RankedTensorType>();
assert(ATensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
BTensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
"Both $a and %b should be DotOperand layout.");
Value loadedA, loadedB, loadedC;
loadedA = adaptor.a();
loadedB = adaptor.b();
loadedC = mmaHelper.loadC(op.c(), adaptor.c());
return mmaHelper.convertDot(A, B, C, op.d(), loadedA, loadedB, loadedC, op,
adaptor);
}
/// Convert to mma.m8n8k4
LogicalResult convertMMA884(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto *ctx = op.getContext();
auto loc = op.getLoc();
Value A = op.a();
Value B = op.b();
Value D = op.getResult();
auto mmaLayout = D.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<MmaEncodingAttr>();
auto ALayout = A.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<DotOperandEncodingAttr>();
auto BLayout = B.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<DotOperandEncodingAttr>();
auto ATensorTy = A.getType().cast<RankedTensorType>();
auto BTensorTy = B.getType().cast<RankedTensorType>();
auto DTensorTy = D.getType().cast<RankedTensorType>();
auto AShape = ATensorTy.getShape();
auto BShape = BTensorTy.getShape();
auto DShape = DTensorTy.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
DotOpMmaV1ConversionHelper helper(mmaLayout);
unsigned numM = helper.getNumM(AShape, isARow);
unsigned numN = helper.getNumN(BShape, isBRow);
unsigned NK = AShape[1];
auto has = helper.extractLoadedOperand(adaptor.a(), NK, rewriter);
auto hbs = helper.extractLoadedOperand(adaptor.b(), NK, rewriter);
// Initialize accumulators with external values, the acc holds the
// accumulator value that is shared between the MMA instructions inside a
// DotOp, we can call the order of the values the accumulator-internal
// order.
SmallVector<Value> acc = getElementsFromStruct(loc, adaptor.c(), rewriter);
size_t resSize = acc.size();
// The resVals holds the final result of the DotOp.
// NOTE The current order of resVals is different from acc, we call it the
// accumulator-external order. and
SmallVector<Value> resVals(resSize);
auto getIdx = [&](int m, int n) {
std::vector<size_t> idx{{
(m * 2 + 0) + (n * 4 + 0) * numM, // row0
(m * 2 + 0) + (n * 4 + 1) * numM,
(m * 2 + 1) + (n * 4 + 0) * numM, // row1
(m * 2 + 1) + (n * 4 + 1) * numM,
(m * 2 + 0) + (n * 4 + 2) * numM, // row2
(m * 2 + 0) + (n * 4 + 3) * numM,
(m * 2 + 1) + (n * 4 + 2) * numM, // row3
(m * 2 + 1) + (n * 4 + 3) * numM,
}};
return idx;
};
{ // convert the acc's value from accumuator-external order to
// accumulator-internal order.
SmallVector<Value> accInit(acc.size());
for (unsigned m = 0; m < numM / 2; ++m)
for (unsigned n = 0; n < numN / 2; ++n) {
auto idx = getIdx(m, n);
for (unsigned i = 0; i < 8; ++i)
accInit[idx[i]] = acc[(m * numN / 2 + n) * 8 + i];
}
acc = accInit;
}
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
auto ha = has.at({m, k});
auto hb = hbs.at({n, k});
PTXBuilder builder;
auto idx = getIdx(m, n);
auto *resOprs = builder.newListOperand(8, "=f");
auto *AOprs = builder.newListOperand({
{ha.first, "r"},
{ha.second, "r"},
});
auto *BOprs = builder.newListOperand({
{hb.first, "r"},
{hb.second, "r"},
});
auto *COprs = builder.newListOperand();
for (int i = 0; i < 8; ++i)
COprs->listAppend(builder.newOperand(acc[idx[i]], std::to_string(i)));
auto mma = builder.create("mma.sync.aligned.m8n8k4")
->o(isARow ? "row" : "col")
.o(isBRow ? "row" : "col")
.o("f32.f16.f16.f32");
mma(resOprs, AOprs, BOprs, COprs);
Value res =
builder.launch(rewriter, loc, helper.getMmaRetType(ATensorTy));
auto getIntAttr = [&](int v) {
return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
};
for (unsigned i = 0; i < 8; i++) {
Value elem = extract_val(f32_ty, res, getIntAttr(i));
acc[idx[i]] = elem;
resVals[(m * numN / 2 + n) * 8 + i] = elem;
}
};
for (unsigned k = 0; k < NK; k += 4)
for (unsigned m = 0; m < numM / 2; ++m)
for (unsigned n = 0; n < numN / 2; ++n) {
callMMA(m, n, k);
}
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(resSize, type::f32Ty(ctx)));
Value res = getStructFromElements(loc, resVals, rewriter, structTy);
rewriter.replaceOp(op, res);
return success();
}
LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto *ctx = rewriter.getContext();
auto loc = op.getLoc();
auto threadId = getThreadId(rewriter, loc);
auto A = op.a();
auto B = op.b();
auto C = op.c();
auto D = op.getResult();
auto aTensorTy = A.getType().cast<RankedTensorType>();
auto bTensorTy = B.getType().cast<RankedTensorType>();
auto cTensorTy = C.getType().cast<RankedTensorType>();
auto dTensorTy = D.getType().cast<RankedTensorType>();
auto aShape = aTensorTy.getShape();
auto bShape = bTensorTy.getShape();
auto cShape = cTensorTy.getShape();
BlockedEncodingAttr dLayout =
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
auto order = dLayout.getOrder();
auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
DotOpFMAConversionHelper helper(dLayout);
Value llA = adaptor.a();
Value llB = adaptor.b();
auto sizePerThread = getSizePerThread(dLayout);
auto shapePerCTA = getShapePerCTA(dLayout);
int K = aShape[1];
int M = aShape[0];
int N = bShape[1];
int mShapePerCTA =
order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
int mSizePerThread =
order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
int nShapePerCTA =
order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
int nSizePerThread =
order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
auto has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA,
mSizePerThread, rewriter, loc);
auto hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA,
nSizePerThread, rewriter, loc);
SmallVector<Value> ret = cc;
bool isCRow = order[0] == 1;
for (unsigned k = 0; k < K; k++) {
for (unsigned m = 0; m < M; m += mShapePerCTA)
for (unsigned n = 0; n < N; n += nShapePerCTA)
for (unsigned mm = 0; mm < mSizePerThread; ++mm)
for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
int mIdx = m / mShapePerCTA * mSizePerThread + mm;
int nIdx = n / nShapePerCTA * nSizePerThread + nn;
int z = isCRow ? mIdx * N / nShapePerCTA * mSizePerThread + nIdx
: nIdx * M / mShapePerCTA * nSizePerThread + mIdx;
ret[z] = rewriter.create<LLVM::FMulAddOp>(
loc, has[{m + mm, k}], hbs[{n + nn, k}], ret[z]);
}
}
auto res = getStructFromElements(
loc, ret, rewriter,
struct_ty(SmallVector<Type>(ret.size(), ret[0].getType())));
rewriter.replaceOp(op, res);
return success();
}
};
void populateDotOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit) {
patterns.add<DotOpConversion>(typeConverter, allocation, smem, benefit);
}

View File

@@ -0,0 +1,15 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_DOT_OP_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_DOT_OP_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
void populateDotOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit);
#endif

View File

@@ -0,0 +1,865 @@
#include "ElementwiseOpToLLVM.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;
struct FpToFpOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::FpToFpOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;
static SmallVector<Value>
convertFp8x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto ctx = rewriter.getContext();
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value fp8x4Vec = undef(fp8x4VecTy);
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
PTXBuilder builder;
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
"shr.b32 b0, b0, 1; \n"
"shr.b32 b1, b1, 1; \n"
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o0 = builder.newOperand("=r");
auto *o1 = builder.newOperand("=r");
auto *i = builder.newOperand(fp8x4Vec, "r");
call({o0, o1, i}, /*onlyAttachMLIRArgs=*/true);
auto fp16x2VecTy = vec_ty(f16_ty, 2);
auto fp16x2x2StructTy =
struct_ty(SmallVector<Type>{fp16x2VecTy, fp16x2VecTy});
auto fp16x2x2Struct =
builder.launch(rewriter, loc, fp16x2x2StructTy, false);
auto fp16x2Vec0 =
extract_val(fp16x2VecTy, fp16x2x2Struct, rewriter.getI32ArrayAttr({0}));
auto fp16x2Vec1 =
extract_val(fp16x2VecTy, fp16x2x2Struct, rewriter.getI32ArrayAttr({1}));
return {extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
extract_element(f16_ty, fp16x2Vec0, i32_val(1)),
extract_element(f16_ty, fp16x2Vec1, i32_val(0)),
extract_element(f16_ty, fp16x2Vec1, i32_val(1))};
}
static SmallVector<Value>
convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto ctx = rewriter.getContext();
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);
PTXBuilder builder;
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, b<2>; \n"
"shl.b32 a0, $1, 1; \n"
"shl.b32 a1, $2, 1; \n"
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
"add.u32 a0, a0, 0x00800080; \n"
"add.u32 a1, a1, 0x00800080; \n"
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
"prmt.b32 $0, b0, b1, 0x7531; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(fp16x2Vec0, "r");
auto *i1 = builder.newOperand(fp16x2Vec1, "r");
call({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);
auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
}
static SmallVector<Value>
convertFp8x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto ctx = rewriter.getContext();
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value fp8x4Vec = undef(fp8x4VecTy);
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v2, i32_val(2));
fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v3, i32_val(3));
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);
PTXBuilder builder;
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"and.b32 sign0, a0, 0x80008000; \n"
"and.b32 sign1, a1, 0x80008000; \n"
"and.b32 nosign0, a0, 0x7fff7fff; \n"
"and.b32 nosign1, a1, 0x7fff7fff; \n"
"shr.b32 nosign0, nosign0, 4; \n"
"shr.b32 nosign1, nosign1, 4; \n"
"add.u32 nosign0, nosign0, 0x38003800; \n"
"add.u32 nosign1, nosign1, 0x38003800; \n"
"or.b32 $0, sign0, nosign0; \n"
"or.b32 $1, sign1, nosign1; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o0 = builder.newOperand("=r");
auto *o1 = builder.newOperand("=r");
auto *i = builder.newOperand(fp8x4Vec, "r");
call({o0, o1, i}, /* onlyAttachMLIRArgs */ true);
auto bf16x2VecTy = vec_ty(i16_ty, 2);
auto bf16x2x2StructTy =
struct_ty(SmallVector<Type>{bf16x2VecTy, bf16x2VecTy});
auto bf16x2x2Struct =
builder.launch(rewriter, loc, bf16x2x2StructTy, false);
auto bf16x2Vec0 =
extract_val(bf16x2VecTy, bf16x2x2Struct, rewriter.getI32ArrayAttr({0}));
auto bf16x2Vec1 =
extract_val(bf16x2VecTy, bf16x2x2Struct, rewriter.getI32ArrayAttr({1}));
return {extract_element(i16_ty, bf16x2Vec0, i32_val(0)),
extract_element(i16_ty, bf16x2Vec0, i32_val(1)),
extract_element(i16_ty, bf16x2Vec1, i32_val(0)),
extract_element(i16_ty, bf16x2Vec1, i32_val(1))};
}
static SmallVector<Value>
convertBf16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto ctx = rewriter.getContext();
auto bf16x2VecTy = vec_ty(i16_ty, 2);
Value bf16x2Vec0 = undef(bf16x2VecTy);
Value bf16x2Vec1 = undef(bf16x2VecTy);
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v0, i32_val(0));
bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v1, i32_val(1));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v2, i32_val(0));
bf16x2Vec1 = insert_element(bf16x2VecTy, bf16x2Vec1, v3, i32_val(1));
bf16x2Vec0 = bitcast(bf16x2Vec0, i32_ty);
bf16x2Vec1 = bitcast(bf16x2Vec1, i32_ty);
PTXBuilder builder;
auto *ptxAsm = "{ \n"
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n"
".reg .u32 fp8_min, fp8_max, rn_, zero; \n"
"mov.u32 fp8_min, 0x38003800; \n"
"mov.u32 fp8_max, 0x3ff03ff0; \n"
"mov.u32 rn_, 0x80008; \n"
"mov.u32 zero, 0; \n"
"and.b32 sign0, $1, 0x80008000; \n"
"and.b32 sign1, $2, 0x80008000; \n"
"prmt.b32 sign, sign0, sign1, 0x7531; \n"
"and.b32 nosign0, $1, 0x7fff7fff; \n"
"and.b32 nosign1, $2, 0x7fff7fff; \n"
".reg .u32 nosign_0_<2>, nosign_1_<2>; \n"
"and.b32 nosign_0_0, nosign0, 0xffff0000; \n"
"max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n"
"min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n"
"and.b32 nosign_0_1, nosign0, 0x0000ffff; \n"
"max.u32 nosign_0_1, nosign_0_1, 0x3800; \n"
"min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n"
"or.b32 nosign0, nosign_0_0, nosign_0_1; \n"
"and.b32 nosign_1_0, nosign1, 0xffff0000; \n"
"max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n"
"min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n"
"and.b32 nosign_1_1, nosign1, 0x0000ffff; \n"
"max.u32 nosign_1_1, nosign_1_1, 0x3800; \n"
"min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n"
"or.b32 nosign1, nosign_1_0, nosign_1_1; \n"
"add.u32 nosign0, nosign0, rn_; \n"
"add.u32 nosign1, nosign1, rn_; \n"
"sub.u32 nosign0, nosign0, 0x38003800; \n"
"sub.u32 nosign1, nosign1, 0x38003800; \n"
"shr.u32 nosign0, nosign0, 4; \n"
"shr.u32 nosign1, nosign1, 4; \n"
"prmt.b32 nosign, nosign0, nosign1, 0x6420; \n"
"or.b32 $0, nosign, sign; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(bf16x2Vec0, "r");
auto *i1 = builder.newOperand(bf16x2Vec1, "r");
call({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);
auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
}
static SmallVector<Value>
convertFp8x4ToFp32x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[0]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[1]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[2]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[3])};
}
static SmallVector<Value>
convertFp32x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
}
static SmallVector<Value>
convertFp8x4ToFp64x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[1]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[2]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[3])};
}
static SmallVector<Value>
convertFp64x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
}
static Value convertBf16ToFp32(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.rn.f32.bf16");
auto res = builder.newOperand("=r");
auto operand = builder.newOperand(v, "h");
cvt(res, operand);
return builder.launch(rewriter, loc, f32_ty, false);
}
static Value convertFp32ToBf16(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.rn.bf16.f32");
auto res = builder.newOperand("=h");
auto operand = builder.newOperand(v, "r");
cvt(res, operand);
// TODO: This is a hack to get the right type. We should be able to invoke
// the type converter
return builder.launch(rewriter, loc, i16_ty, false);
}
LogicalResult
matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcTensorType = op.from().getType().cast<mlir::RankedTensorType>();
auto dstTensorType = op.result().getType().cast<mlir::RankedTensorType>();
auto srcEltType = srcTensorType.getElementType();
auto dstEltType = dstTensorType.getElementType();
auto loc = op->getLoc();
auto elems = getElemsPerThread(dstTensorType);
SmallVector<Value> resultVals;
// Select convertor
if (srcEltType.isa<triton::Float8Type>() ||
dstEltType.isa<triton::Float8Type>()) {
std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
const Value &, const Value &,
const Value &, const Value &)>
convertor;
if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF16()) {
convertor = convertFp8x4ToFp16x4;
} else if (srcEltType.isF16() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp16x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isBF16()) {
convertor = convertFp8x4ToBf16x4;
} else if (srcEltType.isBF16() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertBf16x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF32()) {
convertor = convertFp8x4ToFp32x4;
} else if (srcEltType.isF32() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp32x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF64()) {
convertor = convertFp8x4ToFp64x4;
} else if (srcEltType.isF64() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp64x4ToFp8x4;
} else {
assert(false && "unsupported fp8 casting");
}
// Vectorized casting
assert(elems % 4 == 0 &&
"FP8 casting only support tensors with 4-aligned sizes");
auto elements = getElementsFromStruct(loc, adaptor.from(), rewriter);
for (size_t i = 0; i < elems; i += 4) {
auto converted = convertor(loc, rewriter, elements[i], elements[i + 1],
elements[i + 2], elements[i + 3]);
resultVals.append(converted);
}
} else if (srcEltType.isBF16() && dstEltType.isF32()) {
resultVals.emplace_back(convertBf16ToFp32(loc, rewriter, adaptor.from()));
} else if (srcEltType.isF32() && dstEltType.isBF16()) {
resultVals.emplace_back(convertFp32ToBf16(loc, rewriter, adaptor.from()));
} else {
assert(false && "unsupported type casting");
}
assert(resultVals.size() == elems);
auto convertedDstTensorType =
this->getTypeConverter()->convertType(dstTensorType);
auto result = getStructFromElements(loc, resultVals, rewriter,
convertedDstTensorType);
rewriter.replaceOp(op, result);
return success();
}
};
template <typename SourceOp, typename ConcreteT>
class ElementwiseOpConversionBase
: public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit ElementwiseOpConversionBase(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultTy = op.getType();
Location loc = op->getLoc();
unsigned elems = getElemsPerThread(resultTy);
auto resultElementTy = getElementTypeOrSelf(resultTy);
Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
SmallVector<Type> types(elems, elemTy);
Type structTy = this->getTypeConverter()->convertType(resultTy);
auto *concreteThis = static_cast<const ConcreteT *>(this);
auto operands = getOperands(rewriter, adaptor, elems, loc);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy,
operands[i], loc);
if (!bool(resultVals[i]))
return failure();
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
protected:
SmallVector<SmallVector<Value>>
getOperands(ConversionPatternRewriter &rewriter, OpAdaptor adaptor,
const unsigned elems, Location loc) const {
SmallVector<SmallVector<Value>> operands(elems);
for (auto operand : adaptor.getOperands()) {
auto sub_operands = getElementsFromStruct(loc, operand, rewriter);
for (size_t i = 0; i < elems; ++i) {
operands[i].push_back(sub_operands[i]);
}
}
return operands;
}
};
template <typename SourceOp, typename DestOp>
struct ElementwiseOpConversion
: public ElementwiseOpConversionBase<
SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
using Base =
ElementwiseOpConversionBase<SourceOp,
ElementwiseOpConversion<SourceOp, DestOp>>;
using Base::Base;
using OpAdaptor = typename Base::OpAdaptor;
explicit ElementwiseOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ElementwiseOpConversionBase<SourceOp, ElementwiseOpConversion>(
typeConverter, benefit) {}
// An interface to support variant DestOp builder.
DestOp createDestOp(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
return rewriter.create<DestOp>(loc, elemTy, operands,
adaptor.getAttributes().getValue());
}
};
struct CmpIOpConversion
: public ElementwiseOpConversionBase<triton::gpu::CmpIOp,
CmpIOpConversion> {
using Base =
ElementwiseOpConversionBase<triton::gpu::CmpIOp, CmpIOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
// An interface to support variant DestOp builder.
LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
return rewriter.create<LLVM::ICmpOp>(
loc, elemTy, ArithCmpIPredicateToLLVM(op.predicate()), operands[0],
operands[1]);
}
static LLVM::ICmpPredicate
ArithCmpIPredicateToLLVM(arith::CmpIPredicate predicate) {
switch (predicate) {
#define __PRED_ENUM(item__) \
case arith::CmpIPredicate::item__: \
return LLVM::ICmpPredicate::item__
__PRED_ENUM(eq);
__PRED_ENUM(ne);
__PRED_ENUM(sgt);
__PRED_ENUM(sge);
__PRED_ENUM(slt);
__PRED_ENUM(sle);
__PRED_ENUM(ugt);
__PRED_ENUM(uge);
__PRED_ENUM(ult);
__PRED_ENUM(ule);
#undef __PRED_ENUM
}
return LLVM::ICmpPredicate::eq;
}
};
struct CmpFOpConversion
: public ElementwiseOpConversionBase<triton::gpu::CmpFOp,
CmpFOpConversion> {
using Base =
ElementwiseOpConversionBase<triton::gpu::CmpFOp, CmpFOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
// An interface to support variant DestOp builder.
static LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, ValueRange operands,
Location loc) {
return rewriter.create<LLVM::FCmpOp>(
loc, elemTy, ArithCmpFPredicateToLLVM(op.predicate()), operands[0],
operands[1]);
}
static LLVM::FCmpPredicate
ArithCmpFPredicateToLLVM(arith::CmpFPredicate predicate) {
switch (predicate) {
#define __PRED_ENUM(item__, item1__) \
case arith::CmpFPredicate::item__: \
return LLVM::FCmpPredicate::item1__
__PRED_ENUM(OEQ, oeq);
__PRED_ENUM(ONE, one);
__PRED_ENUM(OGT, ogt);
__PRED_ENUM(OGE, oge);
__PRED_ENUM(OLT, olt);
__PRED_ENUM(OLE, ole);
__PRED_ENUM(ORD, ord);
__PRED_ENUM(UEQ, ueq);
__PRED_ENUM(UGT, ugt);
__PRED_ENUM(UGE, uge);
__PRED_ENUM(ULT, ult);
__PRED_ENUM(ULE, ule);
__PRED_ENUM(UNE, une);
__PRED_ENUM(UNO, uno);
__PRED_ENUM(AlwaysTrue, _true);
__PRED_ENUM(AlwaysFalse, _false);
#undef __PRED_ENUM
}
return LLVM::FCmpPredicate::_true;
}
};
struct ExtElemwiseOpConversion
: public ElementwiseOpConversionBase<triton::ExtElemwiseOp,
ExtElemwiseOpConversion> {
using Base = ElementwiseOpConversionBase<triton::ExtElemwiseOp,
ExtElemwiseOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(triton::ExtElemwiseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
StringRef funcName = op.symbol();
if (funcName.empty())
llvm::errs() << "ExtElemwiseOpConversion";
Type funcType = getFunctionType(elemTy, operands);
LLVM::LLVMFuncOp funcOp =
appendOrGetFuncOp(rewriter, op, funcName, funcType);
return rewriter.create<LLVM::CallOp>(loc, funcOp, operands).getResult(0);
}
private:
Type getFunctionType(Type resultType, ValueRange operands) const {
SmallVector<Type> operandTypes(operands.getTypes());
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}
LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter,
triton::ExtElemwiseOp op,
StringRef funcName, Type funcType) const {
using LLVM::LLVMFuncOp;
auto funcAttr = StringAttr::get(op->getContext(), funcName);
Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
if (funcOp)
return cast<LLVMFuncOp>(*funcOp);
mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
auto ret = b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
ret.getOperation()->setAttr(
"libname", StringAttr::get(op->getContext(), op.libname()));
ret.getOperation()->setAttr(
"libpath", StringAttr::get(op->getContext(), op.libpath()));
return ret;
}
};
struct FDivOpConversion
: ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion> {
using Base =
ElementwiseOpConversionBase<mlir::arith::DivFOp, FDivOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::DivFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
PTXBuilder ptxBuilder;
auto &fdiv = *ptxBuilder.create<PTXInstr>("div");
unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
if (32 == bitwidth) {
fdiv.o("full").o("f32");
} else if (64 == bitwidth) {
fdiv.o("rn").o("f64");
} else {
assert(0 && bitwidth && "not supported");
}
auto res = ptxBuilder.newOperand(bitwidth == 32 ? "=r" : "=l");
auto lhs = ptxBuilder.newOperand(operands[0], bitwidth == 32 ? "r" : "l");
auto rhs = ptxBuilder.newOperand(operands[1], bitwidth == 32 ? "r" : "l");
fdiv(res, lhs, rhs);
Value ret = ptxBuilder.launch(rewriter, loc, elemTy, false);
return ret;
}
};
struct FMulOpConversion
: ElementwiseOpConversionBase<mlir::arith::MulFOp, FMulOpConversion> {
using Base =
ElementwiseOpConversionBase<mlir::arith::MulFOp, FMulOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::MulFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
auto lhsElemTy = getElementType(op.getLhs());
auto rhsElemTy = getElementType(op.getRhs());
if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
PTXBuilder builder;
auto ptxAsm = " { .reg .b16 c; \n"
" mov.b16 c, 0x8000U; \n" // 0.0
" fma.rn.bf16 $0, $1, $2, c; } \n";
auto &fMul = *builder.create<PTXInstr>(ptxAsm);
auto res = builder.newOperand("=h");
auto lhs = builder.newOperand(operands[0], "h");
auto rhs = builder.newOperand(operands[1], "h");
fMul({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
return builder.launch(rewriter, loc, i16_ty, false);
} else {
return rewriter.create<LLVM::FMulOp>(loc, elemTy, operands[0],
operands[1]);
}
}
};
struct FAddOpConversion
: ElementwiseOpConversionBase<mlir::arith::AddFOp, FAddOpConversion> {
using Base =
ElementwiseOpConversionBase<mlir::arith::AddFOp, FAddOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::AddFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
auto lhsElemTy = getElementType(op.getLhs());
auto rhsElemTy = getElementType(op.getRhs());
if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
PTXBuilder builder;
auto ptxAsm = "{ .reg .b16 c; \n"
" mov.b16 c, 0x3f80U; \n" // 1.0
" fma.rn.bf16 $0, $1, c, $2; } \n";
auto &fAdd = *builder.create<PTXInstr>(ptxAsm);
auto res = builder.newOperand("=h");
auto lhs = builder.newOperand(operands[0], "h");
auto rhs = builder.newOperand(operands[1], "h");
fAdd({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
return builder.launch(rewriter, loc, i16_ty, false);
} else {
return rewriter.create<LLVM::FAddOp>(loc, elemTy, operands[0],
operands[1]);
}
}
};
struct FSubOpConversion
: ElementwiseOpConversionBase<mlir::arith::SubFOp, FSubOpConversion> {
using Base =
ElementwiseOpConversionBase<mlir::arith::SubFOp, FSubOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::SubFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
auto lhsElemTy = getElementType(op.getLhs());
auto rhsElemTy = getElementType(op.getRhs());
if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) {
PTXBuilder builder;
auto ptxAsm = " { .reg .b16 c; \n"
" mov.b16 c, 0xbf80U; \n" // -1.0
" fma.rn.bf16 $0, $2, c, $1;} \n";
auto &fSub = *builder.create<PTXInstr>(ptxAsm);
auto res = builder.newOperand("=h");
auto lhs = builder.newOperand(operands[0], "h");
auto rhs = builder.newOperand(operands[1], "h");
fSub({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true);
return builder.launch(rewriter, loc, i16_ty, false);
} else {
return rewriter.create<LLVM::FSubOp>(loc, elemTy, operands[0],
operands[1]);
}
}
};
struct SIToFPOpConversion
: ElementwiseOpConversionBase<mlir::arith::SIToFPOp, SIToFPOpConversion> {
using Base =
ElementwiseOpConversionBase<mlir::arith::SIToFPOp, SIToFPOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::SIToFPOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
auto outElemTy = getElementType(op.getOut());
if (outElemTy.isBF16()) {
auto value = rewriter.create<LLVM::SIToFPOp>(loc, f32_ty, operands[0]);
return FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, value);
} else {
return rewriter.create<LLVM::SIToFPOp>(loc, elemTy, operands[0]);
}
}
};
struct FPToSIOpConversion
: ElementwiseOpConversionBase<mlir::arith::FPToSIOp, FPToSIOpConversion> {
using Base =
ElementwiseOpConversionBase<mlir::arith::FPToSIOp, FPToSIOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::FPToSIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
auto inElemTy = getElementType(op.getIn());
if (inElemTy.isBF16()) {
auto value =
FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0]);
return rewriter.create<LLVM::FPToSIOp>(loc, elemTy, value);
} else {
return rewriter.create<LLVM::FPToSIOp>(loc, elemTy, operands[0]);
}
}
};
struct ExtFOpConversion
: ElementwiseOpConversionBase<mlir::arith::ExtFOp, ExtFOpConversion> {
using Base =
ElementwiseOpConversionBase<mlir::arith::ExtFOp, ExtFOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::ExtFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
auto inElemTy = getElementType(op.getIn());
if (inElemTy.isBF16()) {
auto outElemTy = getElementType(op.getOut());
assert(outElemTy.isF32() && "unsupported conversion");
return FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0]);
} else {
return rewriter.create<LLVM::FPExtOp>(loc, elemTy, operands[0]);
}
}
};
struct TruncFOpConversion
: ElementwiseOpConversionBase<mlir::arith::TruncFOp, TruncFOpConversion> {
using Base =
ElementwiseOpConversionBase<mlir::arith::TruncFOp, TruncFOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::arith::TruncFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
auto outElemTy = getElementType(op.getOut());
if (outElemTy.isBF16()) {
auto inElemTy = getElementType(op.getIn());
assert(inElemTy.isF32() && "unsupported conversion");
return FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, operands[0]);
} else {
return rewriter.create<LLVM::FPTruncOp>(loc, elemTy, operands[0]);
}
}
};
struct ExpOpConversionApprox
: ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox> {
using Base =
ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;
Value createDestOp(mlir::math::ExpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
// For FP64 input, call __nv_expf for higher-precision calculation
if (elemTy.getIntOrFloatBitWidth() == 64)
return {};
const double log2e = 1.4426950408889634;
Value prod = fmul(f32_ty, operands[0], f32_val(log2e));
PTXBuilder ptxBuilder;
auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
auto output = ptxBuilder.newOperand("=f");
auto input = ptxBuilder.newOperand(prod, "f");
exp2(output, input);
return ptxBuilder.launch(rewriter, loc, f32_ty, false);
}
};
void populateElementwiseOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation,
Value smem, PatternBenefit benefit) {
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp)
#undef POPULATE_TERNARY_OP
#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // *
POPULATE_BINARY_OP(arith::DivSIOp, LLVM::SDivOp)
POPULATE_BINARY_OP(arith::DivUIOp, LLVM::UDivOp)
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
#undef POPULATE_BINARY_OP
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp)
POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp)
POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp)
POPULATE_UNARY_OP(math::LogOp, math::LogOp)
POPULATE_UNARY_OP(math::CosOp, math::CosOp)
POPULATE_UNARY_OP(math::SinOp, math::SinOp)
POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp)
POPULATE_UNARY_OP(math::ExpOp, math::ExpOp)
POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
#undef POPULATE_UNARY_OP
patterns.add<CmpIOpConversion>(typeConverter, benefit);
patterns.add<CmpFOpConversion>(typeConverter, benefit);
patterns.add<FDivOpConversion>(typeConverter, benefit);
patterns.add<FSubOpConversion>(typeConverter, benefit);
patterns.add<FAddOpConversion>(typeConverter, benefit);
patterns.add<FMulOpConversion>(typeConverter, benefit);
patterns.add<ExtFOpConversion>(typeConverter, benefit);
patterns.add<TruncFOpConversion>(typeConverter, benefit);
patterns.add<FPToSIOpConversion>(typeConverter, benefit);
patterns.add<SIToFPOpConversion>(typeConverter, benefit);
patterns.add<FpToFpOpConversion>(typeConverter, benefit);
patterns.add<ExtElemwiseOpConversion>(typeConverter, benefit);
// ExpOpConversionApprox will try using ex2.approx if the input type is FP32.
// For FP64 input type, ExpOpConversionApprox will return failure and
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
// __nv_expf for higher-precision calculation
patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
}

View File

@@ -0,0 +1,16 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H
#define TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
void populateElementwiseOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation,
Value smem, PatternBenefit benefit);
#endif

View File

@@ -0,0 +1,884 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "ConvertLayoutOpToLLVM.h"
#include "LoadStoreOpToLLVM.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
// Contains some helper functions for both Load and Store conversions.
struct LoadStoreConversionBase {
explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
: axisAnalysisPass(axisAnalysisPass) {}
// Get corresponding LLVM element values of \param value.
static SmallVector<Value> getLLVMElems(Value value, Value llValue,
ConversionPatternRewriter &rewriter,
Location loc) {
if (!value)
return {};
if (!llValue.getType().isa<LLVM::LLVMStructType>())
return {llValue};
// Here, we assume that all inputs should have a blockedLayout
auto valueVals = getElementsFromStruct(loc, llValue, rewriter);
return valueVals;
}
unsigned getVectorSize(Value ptr) const {
return axisAnalysisPass.getPtrVectorSize(ptr);
}
unsigned getMaskAlignment(Value mask) const {
return axisAnalysisPass.getMaskAlignment(mask);
}
protected:
AxisInfoAnalysis &axisAnalysisPass;
};
struct LoadOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern<
triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
LoadOpConversion(LLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
// original values
Value ptr = op.ptr();
Value mask = op.mask();
Value other = op.other();
// adaptor values
Value llPtr = adaptor.ptr();
Value llMask = adaptor.mask();
Value llOther = adaptor.other();
// Determine the vectorization size
Type valueTy = op.getResult().getType();
Type valueElemTy =
typeConverter->convertType(getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
if (llMask)
vec = std::min<size_t>(vec, getMaskAlignment(mask));
// Get the LLVM values for pointers
auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
assert(ptrElems.size() == numElems);
// Get the LLVM values for mask
SmallVector<Value> maskElems;
if (llMask) {
maskElems = getLLVMElems(mask, llMask, rewriter, loc);
assert(maskElems.size() == numElems);
}
// Get the LLVM values for `other`
// TODO: (goostavz) handle when other is const but not splat, which
// should be rarely seen
bool otherIsSplatConstInt = false;
DenseElementsAttr constAttr;
int64_t splatVal = 0;
if (other && valueElemTy.isa<IntegerType>() &&
matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat()) {
otherIsSplatConstInt = true;
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
}
auto otherElems = getLLVMElems(other, llOther, rewriter, loc);
// vectorized iteration through all the pointer/mask/other elements
const int valueElemNbits =
std::max(8u, valueElemTy.getIntOrFloatBitWidth());
const int numVecs = numElems / vec;
SmallVector<Value> loadedVals;
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
// TODO: optimization when ptr is GEP with constant offset
size_t in_off = 0;
const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
const size_t totalWidth = valueElemNbits * vec;
const size_t width = std::min(totalWidth, maxWordWidth);
const size_t nWords = std::max<size_t>(1, totalWidth / width);
const size_t wordNElems = width / valueElemNbits;
assert(wordNElems * nWords * numVecs == numElems);
// TODO(Superjomn) Add cache policy fields to StoreOp.
// TODO(Superjomn) Deal with cache policy here.
const bool hasL2EvictPolicy = false;
PTXBuilder ptxBuilder;
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
const std::string readConstraint =
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
const std::string writeConstraint =
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
// prepare asm operands
auto *dstsOpr = ptxBuilder.newListOperand();
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
dstsOpr->listAppend(opr);
}
auto *addrOpr =
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
// Define the instruction opcode
auto &ld = ptxBuilder.create<>("ld")
->o("volatile", op.isVolatile())
.global()
.o("ca", op.cache() == triton::CacheModifier::CA)
.o("cg", op.cache() == triton::CacheModifier::CG)
.o("L1::evict_first",
op.evict() == triton::EvictionPolicy::EVICT_FIRST)
.o("L1::evict_last",
op.evict() == triton::EvictionPolicy::EVICT_LAST)
.o("L1::cache_hint", hasL2EvictPolicy)
.v(nWords)
.b(width);
PTXBuilder::Operand *evictOpr{};
// Here lack a mlir::Value to bind to this operation, so disabled.
// if (has_l2_evict_policy)
// evictOpr = ptxBuilder.newOperand(l2Evict, "l");
if (!evictOpr)
ld(dstsOpr, addrOpr).predicate(pred, "b");
else
ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
if (other) {
for (size_t ii = 0; ii < nWords; ++ii) {
// PTX doesn't support mov.u8, so we need to use mov.u16
auto movWidth = width < 16 ? 16 : width;
PTXInstr &mov =
ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));
size_t size = width / valueElemNbits;
auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
Value v = undef(vecTy);
for (size_t s = 0; s < size; ++s) {
Value falseVal = otherElems[vecStart + ii * size + s];
Value sVal = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
v = insert_element(vecTy, v, falseVal, sVal);
}
v = bitcast(v, IntegerType::get(getContext(), width));
PTXInstr::Operand *opr{};
if (otherIsSplatConstInt)
opr = ptxBuilder.newConstantOperand(splatVal);
else
opr = ptxBuilder.newOperand(v, readConstraint);
mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
}
}
// Create inline ASM signature
SmallVector<Type> retTys(nWords, IntegerType::get(getContext(), width));
Type retTy = retTys.size() > 1
? LLVM::LLVMStructType::getLiteral(getContext(), retTys)
: retTys[0];
// TODO: if (has_l2_evict_policy)
// auto asmDialectAttr =
// LLVM::AsmDialectAttr::get(rewriter.getContext(),
// LLVM::AsmDialect::AD_ATT);
Value ret = ptxBuilder.launch(rewriter, loc, retTy);
// Extract and store return values
SmallVector<Value> rets;
for (unsigned int ii = 0; ii < nWords; ++ii) {
Value curr;
if (retTy.isa<LLVM::LLVMStructType>()) {
curr = extract_val(IntegerType::get(getContext(), width), ret,
rewriter.getI64ArrayAttr(ii));
} else {
curr = ret;
}
curr = bitcast(curr, LLVM::getFixedVectorType(valueElemTy,
width / valueElemNbits));
rets.push_back(curr);
}
int tmp = width / valueElemNbits;
for (size_t ii = 0; ii < vec; ++ii) {
Value vecIdx = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp);
Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx);
loadedVals.push_back(loaded);
}
} // end vec
Type llvmResultStructTy = getTypeConverter()->convertType(valueTy);
Value resultStruct =
getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
rewriter.replaceOp(op, {resultStruct});
return success();
}
};
struct StoreOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern<
triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
StoreOpConversion(LLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value ptr = op.ptr();
Value mask = op.mask();
Value value = op.value();
Value llPtr = adaptor.ptr();
Value llMask = adaptor.mask();
Value llValue = adaptor.value();
auto loc = op->getLoc();
MLIRContext *ctx = rewriter.getContext();
auto valueTy = value.getType();
Type valueElemTy =
typeConverter->convertType(getElementTypeOrSelf(valueTy));
unsigned vec = getVectorSize(ptr);
unsigned numElems = getElemsPerThread(ptr.getType());
auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
auto valueElems = getLLVMElems(value, llValue, rewriter, loc);
assert(ptrElems.size() == valueElems.size());
// Determine the vectorization size
SmallVector<Value> maskElems;
if (llMask) {
maskElems = getLLVMElems(mask, llMask, rewriter, loc);
assert(valueElems.size() == maskElems.size());
unsigned maskAlign = getMaskAlignment(mask);
vec = std::min(vec, maskAlign);
}
const size_t dtsize =
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
const size_t valueElemNbits = dtsize * 8;
const int numVecs = numElems / vec;
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
// TODO: optimization when ptr is AddPtr with constant offset
size_t in_off = 0;
const size_t maxWordWidth = std::max<size_t>(32, valueElemNbits);
const size_t totalWidth = valueElemNbits * vec;
const size_t width = std::min(totalWidth, maxWordWidth);
const size_t nWords = std::max<size_t>(1, totalWidth / width);
const size_t wordNElems = width / valueElemNbits;
assert(wordNElems * nWords * numVecs == numElems);
// TODO(Superjomn) Add cache policy fields to StoreOp.
// TODO(Superjomn) Deal with cache policy here.
Type valArgTy = IntegerType::get(ctx, width);
auto wordTy = vec_ty(valueElemTy, wordNElems);
SmallVector<std::pair<Value, std::string>> asmArgs;
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
// llWord is a width-len composition
Value llWord = undef(wordTy);
// Insert each value element to the composition
for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
assert(elemOffset < valueElems.size());
Value elem = valueElems[elemOffset];
if (elem.getType().isInteger(1))
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
elem = bitcast(elem, valueElemTy);
Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
}
llWord = bitcast(llWord, valArgTy);
std::string constraint =
(width == 64) ? "l" : ((width == 32) ? "r" : "c");
asmArgs.emplace_back(llWord, constraint);
}
// Prepare the PTX inline asm.
PTXBuilder ptxBuilder;
auto *asmArgList = ptxBuilder.newListOperand(asmArgs);
Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
auto *asmAddr =
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
auto &ptxStoreInstr =
ptxBuilder.create<>("st")->global().v(nWords).b(width);
ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");
Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
llvm::SmallVector<Type> argTys({boolTy, ptr.getType()});
argTys.insert(argTys.end(), nWords, valArgTy);
auto asmReturnTy = void_ty(ctx);
ptxBuilder.launch(rewriter, loc, asmReturnTy);
}
rewriter.eraseOp(op);
return success();
}
};
struct AtomicCASOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern<
triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;
AtomicCASOpConversion(LLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
AxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>(
converter, allocation, smem, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();
Value ptr = op.ptr();
Value llPtr = adaptor.ptr();
Value llCmp = adaptor.cmp();
Value llVal = adaptor.val();
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
auto cmpElements = getElementsFromStruct(loc, llCmp, rewriter);
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
Type valueElemTy =
valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
: op.getResult().getType();
auto tid = tid_val();
Value pred = icmp_eq(tid, i32_val(0));
PTXBuilder ptxBuilderMemfence;
auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
memfence();
auto ASMReturnTy = void_ty(ctx);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
Value casPtr = ptrElements[0];
Value casCmp = cmpElements[0];
Value casVal = valElements[0];
PTXBuilder ptxBuilderAtomicCAS;
auto *dstOpr = ptxBuilderAtomicCAS.newOperand("=r");
auto *ptrOpr = ptxBuilderAtomicCAS.newAddrOperand(casPtr, "l");
auto *cmpOpr = ptxBuilderAtomicCAS.newOperand(casCmp, "r");
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
atom.global().o("cas").o("b32");
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(pred);
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
barrier();
PTXBuilder ptxBuilderStore;
auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "l");
auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
st.shared().o("b32");
st(dstOprStore, valOprStore).predicate(pred);
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
barrier();
Value ret = load(atomPtr);
barrier();
rewriter.replaceOp(op, {ret});
return success();
}
};
struct AtomicRMWOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern<
triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
AtomicRMWOpConversion(LLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
AxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(
converter, allocation, smem, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();
auto atomicRmwAttr = op.atomic_rmw_op();
Value ptr = op.ptr();
Value val = op.val();
Value llPtr = adaptor.ptr();
Value llVal = adaptor.val();
Value llMask = adaptor.mask();
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
Type valueElemTy =
valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
: op.getResult().getType();
const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getElemsPerThread(val.getType());
// vec = 1 for scalar
auto vec = getVectorSize(ptr);
Value mask = int_val(1, 1);
auto tid = tid_val();
// tensor
if (valueTy) {
auto valTy = val.getType().cast<RankedTensorType>();
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
// mask
auto shape = valueTy.getShape();
auto numElements = product(shape);
mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
i32_val(numElements)));
}
auto vecTy = vec_ty(valueElemTy, vec);
SmallVector<Value> resultVals(elemsPerThread);
for (size_t i = 0; i < elemsPerThread; i += vec) {
Value rmwVal = undef(vecTy);
for (int ii = 0; ii < vec; ++ii) {
Value iiVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), ii);
rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
}
Value rmwPtr = ptrElements[i];
Value rmwMask = maskElements[i];
rmwMask = and_(rmwMask, mask);
std::string sTy;
PTXBuilder ptxBuilderAtomicRMW;
std::string tyId = valueElemNbits * vec == 64
? "l"
: (valueElemNbits * vec == 32 ? "r" : "h");
auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId);
auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l");
auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);
auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o("gpu");
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
auto sBits = std::to_string(valueElemNbits);
switch (atomicRmwAttr) {
case RMWOp::AND:
sTy = "b" + sBits;
break;
case RMWOp::OR:
sTy = "b" + sBits;
break;
case RMWOp::XOR:
sTy = "b" + sBits;
break;
case RMWOp::ADD:
sTy = "s" + sBits;
break;
case RMWOp::FADD:
rmwOp = "add";
rmwOp += (valueElemNbits == 16 ? ".noftz" : "");
sTy = "f" + sBits;
sTy += (vec == 2 && valueElemNbits == 16) ? "x2" : "";
break;
case RMWOp::MAX:
sTy = "s" + sBits;
break;
case RMWOp::MIN:
sTy = "s" + sBits;
break;
case RMWOp::UMAX:
rmwOp = "max";
sTy = "u" + sBits;
break;
case RMWOp::UMIN:
rmwOp = "min";
sTy = "u" + sBits;
break;
case RMWOp::XCHG:
sTy = "b" + sBits;
break;
default:
return failure();
}
atom.o(rmwOp).o(sTy);
if (valueTy) {
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto retType = vec == 1 ? valueElemTy : vecTy;
auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType);
for (int ii = 0; ii < vec; ++ii) {
resultVals[i + ii] =
vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
}
} else {
PTXBuilder ptxBuilderMemfence;
auto memfenc = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
memfenc();
auto ASMReturnTy = void_ty(ctx);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
store(old, atomPtr);
barrier();
Value ret = load(atomPtr);
barrier();
rewriter.replaceOp(op, {ret});
}
}
if (valueTy) {
Type structTy = getTypeConverter()->convertType(valueTy);
Value resultStruct =
getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, {resultStruct});
}
return success();
}
};
struct InsertSliceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<tensor::InsertSliceOp> {
using ConvertTritonGPUOpToLLVMPattern<
tensor::InsertSliceOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// %dst = insert_slice %src into %dst[%offsets]
Location loc = op->getLoc();
Value dst = op.dest();
Value src = op.source();
Value res = op.result();
assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
"Only support in-place insert_slice for now");
auto srcTy = src.getType().dyn_cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
auto srcShape = srcTy.getShape();
assert(srcLayout && "Unexpected srcLayout in InsertSliceOpConversion");
auto dstTy = dst.getType().dyn_cast<RankedTensorType>();
auto dstLayout = dstTy.getEncoding().dyn_cast<SharedEncodingAttr>();
auto llDst = adaptor.dest();
assert(dstLayout && "Unexpected dstLayout in InsertSliceOpConversion");
assert(op.hasUnitStride() &&
"Only unit stride supported by InsertSliceOpConversion");
// newBase = base + offset
// Triton support either static and dynamic offsets
auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);
SmallVector<Value, 4> offsets;
SmallVector<Value, 4> srcStrides;
auto mixedOffsets = op.getMixedOffsets();
for (auto i = 0; i < mixedOffsets.size(); ++i) {
if (op.isDynamicOffset(i)) {
offsets.emplace_back(adaptor.offsets()[i]);
} else {
offsets.emplace_back(i32_val(op.getStaticOffset(i)));
}
// Like insert_slice_async, we only support slice from one dimension,
// which has a slice size of 1
if (op.getStaticSize(i) != 1) {
srcStrides.emplace_back(smemObj.strides[i]);
}
}
// Compute the offset based on the original strides of the shared memory
// object
auto offset = dot(rewriter, loc, offsets, smemObj.strides);
auto elemTy = getTypeConverter()->convertType(dstTy.getElementType());
auto elemPtrTy = ptr_ty(elemTy, 3);
auto smemBase = gep(elemPtrTy, smemObj.base, offset);
auto llSrc = adaptor.source();
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
storeDistributedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
elemTy, loc, rewriter);
// Barrier is not necessary.
// The membar pass knows that it writes to shared memory and will handle it
// properly.
rewriter.replaceOp(op, llDst);
return success();
}
};
struct InsertSliceAsyncOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
InsertSliceAsyncOpConversion(
LLVMTypeConverter &converter, const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
converter, allocation, smem, indexCacheInfo, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::gpu::InsertSliceAsyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// insert_slice_async %src, %dst, %index, %mask, %other
auto loc = op.getLoc();
Value src = op.src();
Value dst = op.dst();
Value res = op.result();
Value mask = op.mask();
Value other = op.other();
assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
"Only support in-place insert_slice_async for now");
auto srcTy = src.getType().cast<RankedTensorType>();
auto resTy = dst.getType().cast<RankedTensorType>();
auto resElemTy = getTypeConverter()->convertType(resTy.getElementType());
auto srcBlockedLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto resSharedLayout = resTy.getEncoding().cast<SharedEncodingAttr>();
auto srcShape = srcTy.getShape();
assert(srcShape.size() == 2 &&
"insert_slice_async: Unexpected rank of %src");
Value llDst = adaptor.dst();
Value llSrc = adaptor.src();
Value llMask = adaptor.mask();
Value llOther = adaptor.other();
Value llIndex = adaptor.index();
// %src
auto srcElems = getLLVMElems(src, llSrc, rewriter, loc);
// %dst
auto dstTy = dst.getType().cast<RankedTensorType>();
auto dstShape = dstTy.getShape();
auto smemObj = getSharedMemoryObjectFromStruct(loc, llDst, rewriter);
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
SmallVector<Value, 4> offsetVals;
SmallVector<Value, 4> srcStrides;
for (auto i = 0; i < dstShape.size(); ++i) {
if (i == axis) {
offsetVals.emplace_back(llIndex);
} else {
offsetVals.emplace_back(i32_val(0));
srcStrides.emplace_back(smemObj.strides[i]);
}
}
// Compute the offset based on the original dimensions of the shared
// memory object
auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
auto dstPtrTy = ptr_ty(resElemTy, 3);
Value dstPtrBase = gep(dstPtrTy, smemObj.base, dstOffset);
// %mask
SmallVector<Value> maskElems;
if (llMask) {
maskElems = getLLVMElems(mask, llMask, rewriter, loc);
assert(srcElems.size() == maskElems.size());
}
// %other
SmallVector<Value> otherElems;
if (llOther) {
// FIXME(Keren): always assume other is 0 for now
// It's not necessary for now because the pipeline pass will skip
// generating insert_slice_async if the load op has any "other" tensor.
// assert(false && "insert_slice_async: Other value not supported yet");
otherElems = getLLVMElems(other, llOther, rewriter, loc);
assert(srcElems.size() == otherElems.size());
}
unsigned inVec = getVectorSize(src);
unsigned outVec = resSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
unsigned numElems = getElemsPerThread(srcTy);
unsigned perPhase = resSharedLayout.getPerPhase();
unsigned maxPhase = resSharedLayout.getMaxPhase();
auto sizePerThread = srcBlockedLayout.getSizePerThread();
auto threadsPerCTA = getThreadsPerCTA(srcBlockedLayout);
auto inOrder = srcBlockedLayout.getOrder();
// If perPhase * maxPhase > threadsPerCTA, we will have elements
// that share the same tile indices. The index calculation will
// be cached.
auto numSwizzleRows = std::max<unsigned>(
(perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1);
// A sharedLayout encoding has a "vec" parameter.
// On the column dimension, if inVec > outVec, it means we have to divide
// single vector read into multiple ones
auto numVecCols = std::max<unsigned>(inVec / outVec, 1);
auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcShape);
// <<tileVecIdxRow, tileVecIdxCol>, TileOffset>
DenseMap<std::pair<unsigned, unsigned>, Value> tileOffsetMap;
for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
// minVec = 2, inVec = 4, outVec = 2
// baseOffsetCol = 0 baseOffsetCol = 0
// tileVecIdxCol = 0 tileVecIdxCol = 1
// -/\- -/\-
// [|x x| |x x| x x x x x]
// [|x x| |x x| x x x x x]
// baseOffsetRow [|x x| |x x| x x x x x]
// [|x x| |x x| x x x x x]
auto vecIdx = elemIdx / minVec;
auto vecIdxCol = vecIdx % (sizePerThread[inOrder[0]] / minVec);
auto vecIdxRow = vecIdx / (sizePerThread[inOrder[0]] / minVec);
auto baseOffsetCol =
vecIdxCol / numVecCols * numVecCols * threadsPerCTA[inOrder[0]];
auto baseOffsetRow = vecIdxRow / numSwizzleRows * numSwizzleRows *
threadsPerCTA[inOrder[1]];
auto tileVecIdxCol = vecIdxCol % numVecCols;
auto tileVecIdxRow = vecIdxRow % numSwizzleRows;
if (!tileOffsetMap.count({tileVecIdxRow, tileVecIdxCol})) {
// Swizzling
// Since the swizzling index is related to outVec, and we know minVec
// already, inVec doesn't matter
//
// (Numbers represent row indices)
// Example1:
// outVec = 2, inVec = 2, minVec = 2
// outVec = 2, inVec = 4, minVec = 2
// | [1 2] [3 4] [5 6] ... |
// | [3 4] [1 2] [7 8] ... |
// | [5 6] [7 8] [1 2] ... |
// Example2:
// outVec = 4, inVec = 2, minVec = 2
// | [1 2 3 4] [5 6 7 8] [9 10 11 12] ... |
// | [5 6 7 8] [1 2 3 4] [13 14 15 16] ... |
// | [9 10 11 12] [13 14 15 16] [1 2 3 4] ... |
auto srcIdx = srcIndices[tileVecIdxRow * sizePerThread[inOrder[0]]];
Value phase = urem(udiv(srcIdx[inOrder[1]], i32_val(perPhase)),
i32_val(maxPhase));
// srcShape and smemObj.shape maybe different if smemObj is a
// slice of the original shared memory object.
// So we need to use the original shape to compute the offset
Value rowOffset = mul(srcIdx[inOrder[1]], srcStrides[inOrder[1]]);
Value colOffset =
add(srcIdx[inOrder[0]], i32_val(tileVecIdxCol * minVec));
Value swizzleIdx = udiv(colOffset, i32_val(outVec));
Value swizzleColOffset =
add(mul(xor_(swizzleIdx, phase), i32_val(outVec)),
urem(colOffset, i32_val(outVec)));
Value tileOffset = add(rowOffset, swizzleColOffset);
tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}] =
gep(dstPtrTy, dstPtrBase, tileOffset);
}
// 16 * 8 = 128bits
auto maxBitWidth =
std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
auto bitWidth = std::min<unsigned>(maxBitWidth, vecBitWidth);
auto numWords = vecBitWidth / bitWidth;
auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth();
// Tune CG and CA here.
auto byteWidth = bitWidth / 8;
CacheModifier srcCacheModifier =
byteWidth == 16 ? CacheModifier::CG : CacheModifier::CA;
assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4);
auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8;
Value tileOffset = tileOffsetMap[{tileVecIdxRow, tileVecIdxCol}];
Value baseOffset =
add(mul(i32_val(baseOffsetRow), srcStrides[inOrder[1]]),
i32_val(baseOffsetCol));
Value basePtr = gep(dstPtrTy, tileOffset, baseOffset);
for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) {
PTXBuilder ptxBuilder;
auto wordElemIdx = wordIdx * numWordElems;
auto &copyAsyncOp =
*ptxBuilder.create<PTXCpAsyncLoadInstr>(srcCacheModifier);
auto *dstOperand =
ptxBuilder.newAddrOperand(basePtr, "r", wordElemIdx * resByteWidth);
auto *srcOperand =
ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l");
auto *copySize = ptxBuilder.newConstantOperand(byteWidth);
auto *srcSize = copySize;
if (op.mask()) {
// We don't use predicate in this case, setting src-size to 0
// if there's any mask. cp.async will automatically fill the
// remaining slots with 0 if cp-size > src-size.
// XXX(Keren): Always assume other = 0 for now.
auto selectOp = select(maskElems[elemIdx + wordElemIdx],
i32_val(byteWidth), i32_val(0));
srcSize = ptxBuilder.newOperand(selectOp, "r");
}
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
}
}
PTXBuilder ptxBuilder;
ptxBuilder.create<>("cp.async.commit_group")->operator()();
ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
rewriter.replaceOp(op, llDst);
return success();
}
};
void populateLoadStoreOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<AtomicCASOpConversion>(typeConverter, allocation, smem,
axisInfoAnalysis, benefit);
patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
axisInfoAnalysis, benefit);
patterns.add<InsertSliceOpConversion>(typeConverter, allocation, smem,
indexCacheInfo, benefit);
patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
indexCacheInfo, axisInfoAnalysis,
benefit);
}

View File

@@ -0,0 +1,16 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_LOAD_STORE_OP_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_LOAD_STORE_OP_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
void populateLoadStoreOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);
#endif

View File

@@ -1,8 +1,10 @@
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/raw_ostream.h"
#include <sstream> // unify to llvm::raw_string_ostream ?
// TODO(Superjomn): unify to llvm::raw_string_ostream
#include <sstream>
namespace mlir {
namespace triton {

View File

@@ -0,0 +1,490 @@
#include "ReduceOpToLLVM.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::shflSync;
using ::mlir::LLVM::storeShared;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::getOrder;
struct ReduceOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::ReduceOp> {
public:
using ConvertTritonGPUOpToLLVMPattern<
triton::ReduceOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (ReduceOpHelper(op).isFastReduction())
return matchAndRewriteFast(op, adaptor, rewriter);
return matchAndRewriteBasic(op, adaptor, rewriter);
}
private:
void accumulate(ConversionPatternRewriter &rewriter, Location loc,
RedOp redOp, Value &acc, Value cur, bool isFirst) const {
if (isFirst) {
acc = cur;
return;
}
switch (redOp) {
case RedOp::ADD:
acc = add(acc, cur);
break;
case RedOp::FADD:
acc = fadd(acc.getType(), acc, cur);
break;
case RedOp::MIN:
acc = smin(acc, cur);
break;
case RedOp::MAX:
acc = smax(acc, cur);
break;
case RedOp::UMIN:
acc = umin(acc, cur);
break;
case RedOp::UMAX:
acc = umax(acc, cur);
break;
case RedOp::FMIN:
acc = fmin(acc, cur);
break;
case RedOp::FMAX:
acc = fmax(acc, cur);
break;
case RedOp::XOR:
acc = xor_(acc, cur);
break;
case RedOp::ARGMIN:
case RedOp::ARGMAX:
case RedOp::ARGUMIN:
case RedOp::ARGUMAX:
case RedOp::ARGFMIN:
case RedOp::ARGFMAX:
llvm::report_fatal_error(
"This accumulate implementation is not for argmin / argmax");
default:
llvm::report_fatal_error("Unsupported reduce op");
}
}
void accumulateWithIndex(ConversionPatternRewriter &rewriter, Location loc,
RedOp redOp, Value &acc, Value &accIndex, Value cur,
Value curIndex, bool isFirst) const {
if (isFirst) {
acc = cur;
accIndex = curIndex;
return;
}
switch (redOp) {
case RedOp::ARGMIN:
accIndex = select(
icmp_slt(acc, cur), accIndex,
select(icmp_sgt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = smin(acc, cur);
break;
case RedOp::ARGMAX:
accIndex = select(
icmp_sgt(acc, cur), accIndex,
select(icmp_slt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = smax(acc, cur);
break;
case RedOp::ARGUMIN:
accIndex = select(
icmp_ult(acc, cur), accIndex,
select(icmp_ugt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = umin(acc, cur);
break;
case RedOp::ARGUMAX:
accIndex = select(
icmp_ugt(acc, cur), accIndex,
select(icmp_ult(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = umax(acc, cur);
break;
case RedOp::ARGFMIN:
accIndex = select(
fcmp_olt(acc, cur), accIndex,
select(fcmp_ogt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = fmin(acc, cur);
break;
case RedOp::ARGFMAX:
accIndex = select(
fcmp_ogt(acc, cur), accIndex,
select(fcmp_olt(acc, cur), curIndex, smin(accIndex, curIndex)));
acc = fmax(acc, cur);
break;
case RedOp::ADD:
case RedOp::FADD:
case RedOp::MIN:
case RedOp::MAX:
case RedOp::UMIN:
case RedOp::UMAX:
case RedOp::FMIN:
case RedOp::FMAX:
case RedOp::XOR:
llvm::report_fatal_error(
"This accumulate implementation is only for argmin / argmax");
default:
llvm::report_fatal_error("Unsupported reduce op");
}
}
// Use shared memory for reduction within warps and across warps
LogicalResult
matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
unsigned axis = op.axis();
bool withIndex = triton::ReduceOp::withIndex(op.redOp());
auto srcTy = op.operand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
auto srcOrd = srcLayout.getOrder();
auto srcShape = srcTy.getShape();
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto llvmIndexTy = getTypeConverter()->getIndexType();
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
smemBase = bitcast(smemBase, elemPtrTy);
ReduceOpHelper helper(op);
auto smemShape = helper.getScratchConfigBasic();
unsigned elems = product<unsigned>(smemShape);
Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(elems));
indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
unsigned srcElems = getElemsPerThread(srcTy);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
SmallVector<SmallVector<unsigned>> offset =
emitOffsetForLayout(srcLayout, srcShape);
std::map<SmallVector<unsigned>, Value> accs;
std::map<SmallVector<unsigned>, Value> accIndices;
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
// reduce within threads
for (unsigned i = 0; i < srcElems; ++i) {
SmallVector<unsigned> key = offset[i];
key[axis] = 0;
bool isFirst = accs.find(key) == accs.end();
if (!withIndex) {
accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
} else {
Value curIndex = srcIndices[i][axis];
accumulateWithIndex(rewriter, loc, op.redOp(), accs[key],
accIndices[key], srcValues[i], curIndex, isFirst);
}
if (isFirst)
indices[key] = srcIndices[i];
}
// cached int32 constants
std::map<int, Value> ints;
ints[0] = i32_val(0);
for (int N = smemShape[axis] / 2; N > 0; N >>= 1)
ints[N] = i32_val(N);
Value sizePerThread = i32_val(srcLayout.getSizePerThread()[axis]);
// reduce across threads
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
Value acc = it.second;
Value accIndex;
if (withIndex)
accIndex = accIndices[key];
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd);
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
store(acc, writePtr);
if (withIndex)
store(accIndex, indexWritePtr);
SmallVector<Value> readIdx(writeIdx.size(), ints[0]);
for (int N = smemShape[axis] / 2; N > 0; N >>= 1) {
readIdx[axis] = ints[N];
Value readMask = icmp_slt(writeIdx[axis], ints[N]);
Value readOffset = select(
readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd),
ints[0]);
Value readPtr = gep(elemPtrTy, writePtr, readOffset);
barrier();
if (!withIndex) {
Value cur = load(readPtr);
accumulate(rewriter, loc, op.redOp(), acc, cur, false);
barrier();
store(acc, writePtr);
} else {
Value cur = load(readPtr);
Value indexReadPtr = gep(indexPtrTy, indexWritePtr, readOffset);
Value curIndex = load(indexReadPtr);
accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, cur,
curIndex, false);
barrier();
store(acc, writePtr);
store(accIndex, indexWritePtr);
}
}
}
barrier();
// set output values
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
// nd-tensor where n >= 1
auto resultLayout = resultTy.getEncoding();
auto resultShape = resultTy.getShape();
unsigned resultElems = getElemsPerThread(resultTy);
auto resultIndices =
emitIndices(loc, rewriter, resultLayout, resultShape);
assert(resultIndices.size() == resultElems);
SmallVector<Value> resultVals(resultElems);
for (unsigned i = 0; i < resultElems; ++i) {
SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, ints[0]);
Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd);
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
}
SmallVector<Type> resultTypes(resultElems,
withIndex ? llvmIndexTy : llvmElemTy);
Type structTy =
LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, ret);
} else {
// 0d-tensor -> scalar
Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
rewriter.replaceOp(op, resultVal);
}
return success();
}
// Use warp shuffle for reduction within warps and shared memory for data
// exchange across warps
LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
unsigned axis = adaptor.axis();
bool withIndex = triton::ReduceOp::withIndex(op.redOp());
auto srcTy = op.operand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto srcRank = srcTy.getRank();
auto order = getOrder(srcLayout);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout);
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto llvmIndexTy = getTypeConverter()->getIndexType();
auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
smemBase = bitcast(smemBase, elemPtrTy);
ReduceOpHelper helper(op);
auto smemShapes = helper.getScratchConfigsFast();
unsigned elems = product<unsigned>(smemShapes[0]);
unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems));
indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
unsigned sizeIntraWarps = helper.getIntraWarpSize();
unsigned sizeInterWarps = helper.getInterWarpSize();
unsigned srcElems = getElemsPerThread(srcTy);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);
SmallVector<SmallVector<unsigned>> offset =
emitOffsetForLayout(srcLayout, srcShape);
std::map<SmallVector<unsigned>, Value> accs;
std::map<SmallVector<unsigned>, Value> accIndices;
std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
// reduce within threads
for (unsigned i = 0; i < srcElems; ++i) {
SmallVector<unsigned> key = offset[i];
key[axis] = 0;
bool isFirst = accs.find(key) == accs.end();
if (!withIndex) {
accumulate(rewriter, loc, op.redOp(), accs[key], srcValues[i], isFirst);
} else {
Value curIndex = srcIndices[i][axis];
accumulateWithIndex(rewriter, loc, op.redOp(), accs[key],
accIndices[key], srcValues[i], curIndex, isFirst);
}
if (isFirst)
indices[key] = srcIndices[i];
}
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(32);
Value warpId = udiv(threadId, warpSize);
Value laneId = urem(threadId, warpSize);
SmallVector<Value> multiDimLaneId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
Value laneIdAxis = multiDimLaneId[axis];
Value warpIdAxis = multiDimWarpId[axis];
Value zero = i32_val(0);
Value laneZero = icmp_eq(laneIdAxis, zero);
Value warpZero = icmp_eq(warpIdAxis, zero);
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
Value acc = it.second;
Value accIndex;
if (withIndex)
accIndex = accIndices[key];
// Reduce within warps
for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
Value shfl = shflSync(loc, rewriter, acc, N);
if (!withIndex) {
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
} else {
Value shflIndex = shflSync(loc, rewriter, accIndex, N);
accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
shflIndex, false);
}
}
SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
Value writeOffset =
linearize(rewriter, loc, writeIdx, smemShapes[0], order);
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
storeShared(rewriter, loc, writePtr, acc, laneZero);
if (withIndex) {
Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
storeShared(rewriter, loc, indexWritePtr, accIndex, laneZero);
}
}
barrier();
// The second round of shuffle reduction
// now the problem size: sizeInterWarps, s1, s2, .. , sn
// where sizeInterWarps is 2^m
//
// Each thread needs to process:
// elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
unsigned numThreads =
product<unsigned>(triton::gpu::getWarpsPerCTA(srcLayout)) * 32;
unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
Value readOffset = threadId;
for (unsigned round = 0; round < elemsPerThread; ++round) {
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
// FIXME(Qingyi): need predicate icmp_slt(threadId,
// i32_val(sizeInerWarps))
Value acc = load(readPtr);
Value accIndex;
if (withIndex) {
Value readIndexPtr = gep(indexPtrTy, indexSmemBase, readOffset);
accIndex = load(readIndexPtr);
}
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
Value shfl = shflSync(loc, rewriter, acc, N);
if (!withIndex) {
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
} else {
Value shflIndex = shflSync(loc, rewriter, accIndex, N);
accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
shflIndex, false);
}
}
// only the first thread in each sizeInterWarps is writing
Value writeOffset = readOffset;
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
Value laneIdModSizeInterWarpsIsZero =
icmp_eq(laneIdModSizeInterWarps, zero);
Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);
storeShared(rewriter, loc, writePtr, acc, pred);
if (withIndex) {
Value writeIndexPtr = gep(indexPtrTy, indexSmemBase, writeOffset);
storeShared(rewriter, loc, writeIndexPtr, accIndex, pred);
}
if (round != elemsPerThread - 1) {
readOffset = add(readOffset, i32_val(numThreads));
}
}
// We could avoid this barrier in some of the layouts, however this is not
// the general case.
// TODO: optimize the barrier incase the layouts are accepted.
barrier();
// set output values
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
// nd-tensor where n >= 1
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
auto resultShape = resultTy.getShape();
unsigned resultElems = getElemsPerThread(resultTy);
auto resultIndices =
emitIndices(loc, rewriter, resultLayout, resultShape);
assert(resultIndices.size() == resultElems);
SmallVector<Value> resultVals(resultElems);
for (size_t i = 0; i < resultElems; ++i) {
SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, i32_val(0));
Value readOffset =
linearize(rewriter, loc, readIdx, smemShapes[0], order);
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
}
SmallVector<Type> resultTypes(resultElems,
withIndex ? llvmIndexTy : llvmElemTy);
Type structTy =
LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, ret);
} else {
// 0d-tensor -> scalar
Value resultVal = withIndex ? load(indexSmemBase) : load(smemBase);
rewriter.replaceOp(op, resultVal);
}
return success();
}
};
void populateReduceOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
patterns.add<ReduceOpConversion>(typeConverter, allocation, smem,
indexCacheInfo, benefit);
}

View File

@@ -0,0 +1,16 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCE_OP_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCE_OP_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
void populateReduceOpToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,16 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
void populateTritonGPUToLLVMPatterns(
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);
#endif

View File

@@ -0,0 +1,661 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_BASE_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_BASE_H
// TODO: refactor so that it doesn't fail if Allocation.h
// is included after utility.h (due to conflict in `store` macro
// and <atomic>
#include "triton/Analysis/Allocation.h"
//
#include "Utility.h"
#include "mlir/IR/TypeUtilities.h"
#include "triton/Analysis/AxisInfo.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::SharedMemoryObject;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
// FuncOpConversion/FuncOpConversionBase is borrowed from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
// since it is not exposed on header files in mlir v14
// TODO(Superjomn): remove the code when MLIR v15.0 is included.
// All the rights are reserved by the LLVM community.
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
private:
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
bool filterArgAttrs,
SmallVectorImpl<NamedAttribute> &result) {
for (const auto &attr : attrs) {
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
attr.getName() == FunctionOpInterface::getTypeAttrName() ||
attr.getName() == "std.varargs" ||
(filterArgAttrs &&
attr.getName() == FunctionOpInterface::getArgDictAttrName()))
continue;
result.push_back(attr);
}
}
/// Helper function for wrapping all attributes into a single DictionaryAttr
static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
return DictionaryAttr::get(b.getContext(),
b.getNamedAttr("llvm.struct_attrs", attrs));
}
protected:
using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
// Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
// to this legalization pattern.
LLVM::LLVMFuncOp
convertFuncOpToLLVMFuncOp(FuncOp funcOp,
ConversionPatternRewriter &rewriter) const {
// Convert the original function arguments. They are converted using the
// LLVMTypeConverter provided to this legalization pattern.
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
auto llvmType = getTypeConverter()->convertFunctionSignature(
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
if (!llvmType)
return nullptr;
// Propagate argument/result attributes to all converted arguments/result
// obtained after converting a given original argument/result.
SmallVector<NamedAttribute, 4> attributes;
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
attributes);
if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
assert(!resAttrDicts.empty() && "expected array to be non-empty");
auto newResAttrDicts =
(funcOp.getNumResults() == 1)
? resAttrDicts
: rewriter.getArrayAttr(
{wrapAsStructAttrs(rewriter, resAttrDicts)});
attributes.push_back(rewriter.getNamedAttr(
FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
}
if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
SmallVector<Attribute, 4> newArgAttrs(
llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
auto mapping = result.getInputMapping(i);
assert(mapping && "unexpected deletion of function argument");
for (size_t j = 0; j < mapping->size; ++j)
newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
}
attributes.push_back(
rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
rewriter.getArrayAttr(newArgAttrs)));
}
for (const auto &pair : llvm::enumerate(attributes)) {
if (pair.value().getName() == "llvm.linkage") {
attributes.erase(attributes.begin() + pair.index());
break;
}
}
// Create an LLVM function, use external linkage by default until MLIR
// functions have linkage.
LLVM::Linkage linkage = LLVM::Linkage::External;
if (funcOp->hasAttr("llvm.linkage")) {
auto attr =
funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
if (!attr) {
funcOp->emitError()
<< "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
return nullptr;
}
linkage = attr.getLinkage();
}
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal*/ false, attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
&result)))
return nullptr;
return newFuncOp;
}
};
using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
struct CacheKeyDenseMapInfo {
static IndexCacheKeyT getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return std::make_pair(
mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
SmallVector<int64_t>{});
}
static IndexCacheKeyT getTombstoneKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return std::make_pair(
mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
SmallVector<int64_t>{std::numeric_limits<int64_t>::max()});
}
static unsigned getHashValue(IndexCacheKeyT key) {
return llvm::hash_combine(
mlir::hash_value(key.first),
llvm::hash_combine_range(key.second.begin(), key.second.end()));
}
static bool isEqual(IndexCacheKeyT LHS, IndexCacheKeyT RHS) {
return LHS == RHS;
}
};
class ConvertTritonGPUOpToLLVMPatternBase {
public:
// Two levels of value cache in emitting indices calculation:
// Key: pair<layout, shape>
struct IndexCacheInfo {
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
*baseIndexCache;
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
CacheKeyDenseMapInfo> *indexCache;
OpBuilder::InsertPoint *indexInsertPoint;
};
explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter)
: converter(&typeConverter) {}
explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem)
: converter(&typeConverter), allocation(allocation), smem(smem) {}
explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem,
IndexCacheInfo indexCacheInfo)
: converter(&typeConverter), indexCacheInfo(indexCacheInfo),
allocation(allocation), smem(smem) {}
LLVMTypeConverter *getTypeConverter() const { return converter; }
static Value
getStructFromSharedMemoryObject(Location loc,
const SharedMemoryObject &smemObj,
ConversionPatternRewriter &rewriter) {
auto elems = smemObj.getElems();
auto types = smemObj.getTypes();
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
return getStructFromElements(loc, elems, rewriter, structTy);
}
Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
auto cast = rewriter.create<UnrealizedConversionCastOp>(
loc, TypeRange{llvmIndexTy},
ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)});
Value threadId = cast.getResult(0);
return threadId;
}
// -----------------------------------------------------------------------
// Shared memory utilities
// -----------------------------------------------------------------------
template <typename T>
Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
T value) const {
auto ptrTy = LLVM::LLVMPointerType::get(
this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
auto bufferId = allocation->getBufferId(value);
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
size_t offset = allocation->getOffset(bufferId);
Value offVal = idx_val(offset);
Value base = gep(ptrTy, smem, offVal);
return base;
}
// -----------------------------------------------------------------------
// Utilities
// -----------------------------------------------------------------------
// Convert an \param index to a multi-dim coordinate given \param shape and
// \param order.
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
Location loc, Value linear,
ArrayRef<unsigned> shape,
ArrayRef<unsigned> order) const {
unsigned rank = shape.size();
assert(rank == order.size());
auto reordered = reorder(shape, order);
auto reorderedMultiDim = delinearize(rewriter, loc, linear, reordered);
SmallVector<Value> multiDim(rank);
for (unsigned i = 0; i < rank; ++i) {
multiDim[order[i]] = reorderedMultiDim[i];
}
return multiDim;
}
SmallVector<Value> delinearize(ConversionPatternRewriter &rewriter,
Location loc, Value linear,
ArrayRef<unsigned> shape) const {
unsigned rank = shape.size();
assert(rank > 0);
SmallVector<Value> multiDim(rank);
if (rank == 1) {
multiDim[0] = linear;
} else {
Value remained = linear;
for (auto &&en : llvm::enumerate(shape.drop_back())) {
Value dimSize = idx_val(en.value());
multiDim[en.index()] = urem(remained, dimSize);
remained = udiv(remained, dimSize);
}
multiDim[rank - 1] = remained;
}
return multiDim;
}
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order) const {
return linearize(rewriter, loc, reorder<Value>(multiDim, order),
reorder<unsigned>(shape, order));
}
Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
auto rank = multiDim.size();
Value linear = idx_val(0);
if (rank > 0) {
linear = multiDim.back();
for (auto [dim, dimShape] :
llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) {
Value dimSize = idx_val(dimShape);
linear = add(mul(linear, dimSize), dim);
}
}
return linear;
}
Value dot(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> offsets, ArrayRef<Value> strides) const {
assert(offsets.size() == strides.size());
Value ret = idx_val(0);
for (auto [offset, stride] : llvm::zip(offsets, strides)) {
ret = add(ret, mul(offset, stride));
}
return ret;
}
struct SmallVectorKeyInfo {
static unsigned getHashValue(const SmallVector<unsigned> &key) {
return llvm::hash_combine_range(key.begin(), key.end());
}
static bool isEqual(const SmallVector<unsigned> &lhs,
const SmallVector<unsigned> &rhs) {
return lhs == rhs;
}
static SmallVector<unsigned> getEmptyKey() {
return SmallVector<unsigned>();
}
static SmallVector<unsigned> getTombstoneKey() {
return {std::numeric_limits<unsigned>::max()};
}
};
// -----------------------------------------------------------------------
// Get offsets / indices for any layout
// -----------------------------------------------------------------------
SmallVector<Value> emitBaseIndexForLayout(Location loc,
ConversionPatternRewriter &rewriter,
const Attribute &layout,
ArrayRef<int64_t> shape) const {
IndexCacheKeyT key = std::make_pair(layout, llvm::to_vector(shape));
auto cache = indexCacheInfo.baseIndexCache;
assert(cache && "baseIndexCache is nullptr");
auto insertPt = indexCacheInfo.indexInsertPoint;
if (cache->count(key) > 0) {
return cache->lookup(key);
} else {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
restoreInsertionPointIfSet(insertPt, rewriter);
SmallVector<Value> result;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
result =
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isVolta())
result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
if (mmaLayout.isAmpere())
result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
} else {
llvm_unreachable("unsupported emitBaseIndexForLayout");
}
cache->insert(std::make_pair(key, result));
*insertPt = rewriter.saveInsertionPoint();
return result;
}
}
SmallVector<SmallVector<unsigned>>
emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
return emitOffsetForBlockedLayout(blockedLayout, shape);
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isVolta())
return emitOffsetForMmaLayoutV1(mmaLayout, shape);
if (mmaLayout.isAmpere())
return emitOffsetForMmaLayoutV2(mmaLayout, shape);
}
llvm_unreachable("unsupported emitOffsetForLayout");
}
// -----------------------------------------------------------------------
// Emit indices
// -----------------------------------------------------------------------
SmallVector<SmallVector<Value>> emitIndices(Location loc,
ConversionPatternRewriter &b,
const Attribute &layout,
ArrayRef<int64_t> shape) const {
IndexCacheKeyT key(layout, llvm::to_vector(shape));
auto cache = indexCacheInfo.indexCache;
assert(cache && "indexCache is nullptr");
auto insertPt = indexCacheInfo.indexInsertPoint;
if (cache->count(key) > 0) {
return cache->lookup(key);
} else {
ConversionPatternRewriter::InsertionGuard guard(b);
restoreInsertionPointIfSet(insertPt, b);
SmallVector<SmallVector<Value>> result;
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, blocked, shape);
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, mma, shape);
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
result = emitIndicesForSliceLayout(loc, b, slice, shape);
} else {
llvm_unreachable(
"emitIndices for layouts other than blocked & slice not "
"implemented yet");
}
cache->insert(std::make_pair(key, result));
*insertPt = b.saveInsertionPoint();
return result;
}
}
private:
void restoreInsertionPointIfSet(OpBuilder::InsertPoint *insertPt,
ConversionPatternRewriter &rewriter) const {
if (insertPt->isSet()) {
rewriter.restoreInsertionPoint(*insertPt);
} else {
auto func =
rewriter.getInsertionPoint()->getParentOfType<LLVM::LLVMFuncOp>();
rewriter.setInsertionPointToStart(&func.getBody().front());
}
}
// -----------------------------------------------------------------------
// Blocked layout indices
// -----------------------------------------------------------------------
// Get an index-base for each dimension for a \param blocked_layout.
SmallVector<Value>
emitBaseIndexForBlockedLayout(Location loc,
ConversionPatternRewriter &rewriter,
const BlockedEncodingAttr &blocked_layout,
ArrayRef<int64_t> shape) const {
Value threadId = getThreadId(rewriter, loc);
Value warpSize = idx_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
auto sizePerThread = blocked_layout.getSizePerThread();
auto threadsPerWarp = blocked_layout.getThreadsPerWarp();
auto warpsPerCTA = blocked_layout.getWarpsPerCTA();
auto order = blocked_layout.getOrder();
unsigned rank = shape.size();
// delinearize threadId to get the base index
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
SmallVector<Value> multiDimBase(rank);
for (unsigned k = 0; k < rank; ++k) {
// Wrap around multiDimWarpId/multiDimThreadId incase
// shape[k] > shapePerCTA[k]
auto maxWarps =
ceil<unsigned>(shape[k], sizePerThread[k] * threadsPerWarp[k]);
auto maxThreads = ceil<unsigned>(shape[k], sizePerThread[k]);
multiDimWarpId[k] = urem(multiDimWarpId[k], idx_val(maxWarps));
multiDimThreadId[k] = urem(multiDimThreadId[k], idx_val(maxThreads));
// multiDimBase[k] = (multiDimThreadId[k] +
// multiDimWarpId[k] * threadsPerWarp[k]) *
// sizePerThread[k];
Value threadsPerWarpK = idx_val(threadsPerWarp[k]);
Value sizePerThreadK = idx_val(sizePerThread[k]);
multiDimBase[k] =
mul(sizePerThreadK, add(multiDimThreadId[k],
mul(multiDimWarpId[k], threadsPerWarpK)));
}
return multiDimBase;
}
SmallVector<SmallVector<unsigned>>
emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
ArrayRef<int64_t> shape) const {
auto sizePerThread = blockedLayout.getSizePerThread();
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
auto order = blockedLayout.getOrder();
unsigned rank = shape.size();
SmallVector<unsigned> shapePerCTA = getShapePerCTA(blockedLayout);
SmallVector<unsigned> tilesPerDim(rank);
for (unsigned k = 0; k < rank; ++k)
tilesPerDim[k] = ceil<unsigned>(shape[k], shapePerCTA[k]);
SmallVector<SmallVector<unsigned>> offset(rank);
for (unsigned k = 0; k < rank; ++k) {
// 1 block in minimum if shape[k] is less than shapePerCTA[k]
for (unsigned blockOffset = 0; blockOffset < tilesPerDim[k];
++blockOffset)
for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset)
for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k];
++threadOffset)
for (unsigned elemOffset = 0; elemOffset < sizePerThread[k];
++elemOffset)
offset[k].push_back(blockOffset * sizePerThread[k] *
threadsPerWarp[k] * warpsPerCTA[k] +
warpOffset * sizePerThread[k] *
threadsPerWarp[k] +
threadOffset * sizePerThread[k] + elemOffset);
}
unsigned elemsPerThread = blockedLayout.getElemsPerThread(shape);
unsigned totalSizePerThread = product<unsigned>(sizePerThread);
SmallVector<SmallVector<unsigned>> reorderedOffset(elemsPerThread);
for (unsigned n = 0; n < elemsPerThread; ++n) {
unsigned linearNanoTileId = n / totalSizePerThread;
unsigned linearNanoTileElemId = n % totalSizePerThread;
SmallVector<unsigned> multiDimNanoTileId =
getMultiDimIndex<unsigned>(linearNanoTileId, tilesPerDim, order);
SmallVector<unsigned> multiDimNanoTileElemId = getMultiDimIndex<unsigned>(
linearNanoTileElemId, sizePerThread, order);
for (unsigned k = 0; k < rank; ++k) {
unsigned reorderedMultiDimId =
multiDimNanoTileId[k] *
(sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) +
multiDimNanoTileElemId[k];
reorderedOffset[n].push_back(offset[k][reorderedMultiDimId]);
}
}
return reorderedOffset;
}
// -----------------------------------------------------------------------
// Mma layout indices
// -----------------------------------------------------------------------
SmallVector<Value>
emitBaseIndexForMmaLayoutV1(Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
llvm_unreachable("emitIndicesForMmaLayoutV1 not implemented");
}
SmallVector<SmallVector<unsigned>>
emitOffsetForMmaLayoutV1(const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
SmallVector<SmallVector<unsigned>> ret;
for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
ret.push_back({i, j});
ret.push_back({i, j + 1});
ret.push_back({i + 2, j});
ret.push_back({i + 2, j + 1});
ret.push_back({i, j + 8});
ret.push_back({i, j + 9});
ret.push_back({i + 2, j + 8});
ret.push_back({i + 2, j + 9});
}
}
return ret;
}
SmallVector<Value>
emitBaseIndexForMmaLayoutV2(Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
assert(_warpsPerCTA.size() == 2);
SmallVector<Value> warpsPerCTA = {idx_val(_warpsPerCTA[0]),
idx_val(_warpsPerCTA[1])};
Value threadId = getThreadId(rewriter, loc);
Value warpSize = idx_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
Value warpId0 = urem(warpId, warpsPerCTA[0]);
Value warpId1 = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]);
Value offWarp0 = mul(warpId0, idx_val(16));
Value offWarp1 = mul(warpId1, idx_val(8));
SmallVector<Value> multiDimBase(2);
multiDimBase[0] = add(udiv(laneId, idx_val(4)), offWarp0);
multiDimBase[1] = add(mul(idx_val(2), urem(laneId, idx_val(4))), offWarp1);
return multiDimBase;
}
SmallVector<SmallVector<unsigned>>
emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
SmallVector<SmallVector<unsigned>> ret;
for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
for (unsigned j = 0; j < shape[1]; j += getShapePerCTA(mmaLayout)[1]) {
ret.push_back({i, j});
ret.push_back({i, j + 1});
ret.push_back({i + 8, j});
ret.push_back({i + 8, j + 1});
}
}
return ret;
}
// Emit indices calculation within each ConversionPattern, and returns a
// [elemsPerThread X rank] index matrix.
// TODO: [phil] redundant indices computation do not appear to hurt
// performance much, but they could still significantly slow down
// computations.
SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
Location loc, ConversionPatternRewriter &rewriter,
const Attribute &layout, ArrayRef<int64_t> shape) const {
// step 1, delinearize threadId to get the base index
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, shape);
// step 2, get offset of each element
auto offset = emitOffsetForLayout(layout, shape);
// step 3, add offset to base, and reorder the sequence of indices to
// guarantee that elems in the same sizePerThread are adjacent in order
unsigned rank = shape.size();
unsigned elemsPerThread = offset.size();
SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
SmallVector<Value>(rank));
for (unsigned n = 0; n < elemsPerThread; ++n)
for (unsigned k = 0; k < rank; ++k)
multiDimIdx[n][k] = add(multiDimBase[k], idx_val(offset[n][k]));
return multiDimIdx;
}
SmallVector<SmallVector<Value>>
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
const SliceEncodingAttr &sliceLayout,
ArrayRef<int64_t> shape) const {
auto parent = sliceLayout.getParent();
unsigned dim = sliceLayout.getDim();
size_t rank = shape.size();
auto parentIndices =
emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
unsigned numIndices = parentIndices.size();
SmallVector<SmallVector<Value>> resultIndices;
for (unsigned i = 0; i < numIndices; ++i) {
SmallVector<Value> indices = parentIndices[i];
indices.erase(indices.begin() + dim);
resultIndices.push_back(indices);
}
return resultIndices;
}
protected:
LLVMTypeConverter *converter;
const Allocation *allocation;
Value smem;
IndexCacheInfo indexCacheInfo;
};
template <typename SourceOp>
class ConvertTritonGPUOpToLLVMPattern
: public ConvertOpToLLVMPattern<SourceOp>,
public ConvertTritonGPUOpToLLVMPatternBase {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {}
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem) {}
explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem,
IndexCacheInfo indexCacheInfo,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem,
indexCacheInfo) {}
protected:
LLVMTypeConverter *getTypeConverter() const {
return ((ConvertTritonGPUOpToLLVMPatternBase *)this)->getTypeConverter();
}
};
#endif

View File

@@ -0,0 +1,417 @@
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Membar.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "ConvertLayoutOpToLLVM.h"
#include "DotOpToLLVM.h"
#include "ElementwiseOpToLLVM.h"
#include "LoadStoreOpToLLVM.h"
#include "ReduceOpToLLVM.h"
#include "TritonGPUToLLVM.h"
#include "TypeConverter.h"
#include "ViewOpToLLVM.h"
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Conversion/Passes.h.inc"
namespace mlir {
class TritonLLVMConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMConversionTarget(MLIRContext &ctx)
: ConversionTarget(ctx) {
addLegalDialect<LLVM::LLVMDialect>();
addLegalDialect<NVVM::NVVMDialect>();
addIllegalDialect<triton::TritonDialect>();
addIllegalDialect<triton::gpu::TritonGPUDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
addIllegalDialect<mlir::StandardOpsDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
};
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx)
: ConversionTarget(ctx) {
addLegalDialect<LLVM::LLVMDialect>();
addLegalDialect<NVVM::NVVMDialect>();
addIllegalOp<mlir::FuncOp>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
};
} // namespace mlir
namespace {
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
struct FuncOpConversion : public FuncOpConversionBase {
FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
PatternBenefit benefit)
: FuncOpConversionBase(converter, benefit), numWarps(numWarps) {}
LogicalResult
matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
if (!newFuncOp)
return failure();
auto ctx = funcOp->getContext();
// Set an attribute to indicate this function is a kernel entry.
newFuncOp->setAttr("nvvm.kernel",
rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
// Set an attribute for maxntidx, it could be used in latter LLVM codegen
// for `nvvm.annotation` metadata.
newFuncOp->setAttr("nvvm.maxntid",
rewriter.getIntegerAttr(i32_ty, 32 * numWarps));
rewriter.eraseOp(funcOp);
return success();
}
private:
int numWarps{0};
};
class ConvertTritonGPUToLLVM
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
public:
explicit ConvertTritonGPUToLLVM(int computeCapability)
: computeCapability(computeCapability) {}
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();
mlir::LowerToLLVMOptions option(context);
option.overrideIndexBitwidth(32);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMFunctionConversionTarget funcTarget(*context);
TritonLLVMConversionTarget target(*context);
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
// Step 1: Decompose unoptimized layout conversions to use shared memory
// Step 2: Decompose insert_slice_async to use load + insert_slice for
// pre-Ampere architectures or unsupported vectorized load sizes
// Step 3: Allocate shared memories and insert barriers
// Step 4: Convert SCF to CFG
// Step 5: Convert FuncOp to LLVMFuncOp via partial conversion
// Step 6: Get axis and shared memory info
// Step 7: Convert the rest of ops via partial conversion
//
// The reason for putting step 3 before step 4 is that the membar
// analysis currently only supports SCF but not CFG. The reason for a
// separation between 5/7 is that, step 6 is out of the scope of Dialect
// Conversion, thus we need to make sure the smem is not revised during the
// conversion of step 7.
// Step 1
decomposeMmaToDotOperand(mod, numWarps);
decomposeBlockedToDotOperand(mod);
// Step 2
decomposeInsertSliceAsyncOp(mod);
// Step 3
Allocation allocation(mod);
MembarAnalysis membarPass(&allocation);
membarPass.run();
// Step 4
RewritePatternSet scf_patterns(context);
mlir::populateLoopToStdConversionPatterns(scf_patterns);
mlir::ConversionTarget scf_target(*context);
scf_target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp,
scf::WhileOp, scf::ExecuteRegionOp>();
scf_target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(mod, scf_target, std::move(scf_patterns))))
return signalPassFailure();
// Step 5
RewritePatternSet func_patterns(context);
func_patterns.add<FuncOpConversion>(typeConverter, numWarps, /*benefit=*/1);
if (failed(
applyPartialConversion(mod, funcTarget, std::move(func_patterns))))
return signalPassFailure();
// Step 6 - get axis and shared memory info
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
axisInfoAnalysis.run(mod);
initSharedMemory(allocation.getSharedMemorySize(), typeConverter);
mod->setAttr("triton_gpu.shared",
mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
allocation.getSharedMemorySize()));
// Step 7 - rewrite rest of ops
// We set a higher benefit here to ensure triton's patterns runs before
// arith patterns for some encoding not supported by the community
// patterns.
OpBuilder::InsertPoint indexInsertPoint;
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
&baseIndexCache, &indexCache, &indexInsertPoint};
RewritePatternSet patterns(context);
// Normal conversions
populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
indexCacheInfo, /*benefit=*/10);
// ConvertLayoutOp
populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
indexCacheInfo, /*benefit=*/10);
// DotOp
populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// ElementwiseOp
populateElementwiseOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// LoadStoreOp
populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
indexCacheInfo, /*benefit=*/10);
// ReduceOp
populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
indexCacheInfo, /*benefit=*/10);
// ViewOp
populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// Add arith/math's patterns to help convert scalar expression to LLVM.
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
}
private:
Value smem;
using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
baseIndexCache;
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
CacheKeyDenseMapInfo>
indexCache;
int computeCapability{};
void initSharedMemory(size_t size,
TritonGPUToLLVMTypeConverter &typeConverter) {
ModuleOp mod = getOperation();
OpBuilder b(mod.getBodyRegion());
auto loc = mod.getLoc();
auto elemTy = typeConverter.convertType(b.getIntegerType(8));
// Set array size 0 and external linkage indicates that we use dynamic
// shared allocation to allow a larger shared memory size for each kernel.
auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0);
auto global = b.create<LLVM::GlobalOp>(
loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External,
"global_smem", /*value=*/Attribute(), /*alignment=*/0,
mlir::gpu::GPUDialect::getWorkgroupAddressSpace());
SmallVector<LLVM::LLVMFuncOp> funcs;
mod.walk([&](LLVM::LLVMFuncOp func) { funcs.push_back(func); });
assert(funcs.size() == 1 &&
"Inliner pass is expected before TritonGPUToLLVM");
b.setInsertionPointToStart(&funcs[0].getBody().front());
smem = b.create<LLVM::AddressOfOp>(loc, global);
auto ptrTy =
LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()), 3);
smem = b.create<LLVM::BitcastOp>(loc, ptrTy, smem);
}
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps) const {
// Replace `mma -> dot_op` with `mma -> blocked -> dot_op`
// unless certain conditions are met
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
OpBuilder builder(cvtOp);
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvtOp.getType().cast<RankedTensorType>();
auto srcMma =
srcType.getEncoding().dyn_cast<triton::gpu::MmaEncodingAttr>();
auto dstDotOp =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (srcMma && dstDotOp && !isMmaToDotShortcut(srcMma, dstDotOp)) {
auto tmpType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(),
triton::gpu::BlockedEncodingAttr::get(
mod.getContext(), srcType.getShape(), getSizePerThread(srcMma),
getOrder(srcMma), numWarps));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), dstType, tmp);
cvtOp.replaceAllUsesWith(newConvert.getResult());
cvtOp.erase();
}
});
}
void decomposeBlockedToDotOperand(ModuleOp mod) const {
// Replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
// because the codegen doesn't handle `blocked -> dot_op` directly
mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
OpBuilder builder(cvtOp);
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
auto dstType = cvtOp.getType().cast<RankedTensorType>();
auto srcBlocked =
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
auto dstDotOp =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (srcBlocked && dstDotOp) {
auto tmpType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(),
triton::gpu::SharedEncodingAttr::get(
mod.getContext(), dstDotOp, srcType.getShape(),
getOrder(srcBlocked), srcType.getElementType()));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), dstType, tmp);
cvtOp.replaceAllUsesWith(newConvert.getResult());
cvtOp.erase();
}
});
}
void decomposeInsertSliceAsyncOp(ModuleOp mod) const {
AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
axisInfoAnalysis.run(mod);
// TODO(Keren): This is a hacky knob that may cause performance regression
// when decomposition has been performed. We should remove this knob once we
// have thorough analysis on async wait. Currently, we decompose
// `insert_slice_async` into `load` and `insert_slice` without knowing which
// `async_wait` is responsible for the `insert_slice_async`. To guarantee
// correctness, we blindly set the `async_wait` to wait for all async ops.
//
// There are two options to improve this:
// 1. We can perform a dataflow analysis to find the `async_wait` that is
// responsible for the `insert_slice_async` in the backend.
// 2. We can modify the pipeline to perform the decomposition before the
// `async_wait` is inserted. However, it is also risky because we don't know
// the correct vectorized shape yet in the pipeline pass. Making the
// pipeline pass aware of the vectorization could introduce additional
// dependencies on the AxisInfoAnalysis and the Coalesce analysis.
bool decomposed = false;
// insert_slice_async %src, %dst, %idx, %mask, %other
// =>
// %tmp = load %src, %mask, %other
// %res = insert_slice %tmp into %dst[%idx]
mod.walk([&](triton::gpu::InsertSliceAsyncOp insertSliceAsyncOp) -> void {
OpBuilder builder(insertSliceAsyncOp);
// Get the vectorized load size
auto src = insertSliceAsyncOp.src();
auto dst = insertSliceAsyncOp.dst();
auto srcTy = src.getType().cast<RankedTensorType>();
auto dstTy = dst.getType().cast<RankedTensorType>();
auto srcBlocked =
srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
auto resSharedLayout =
dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
auto resElemTy = dstTy.getElementType();
unsigned inVec = axisInfoAnalysis.getPtrVectorSize(src);
unsigned outVec = resSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
auto maxBitWidth =
std::max<unsigned>(128, resElemTy.getIntOrFloatBitWidth());
auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec;
auto bitWidth = std::min<unsigned>(maxBitWidth, vecBitWidth);
auto byteWidth = bitWidth / 8;
// If the load byte width is not eligible or the current compute
// capability does not support async copy, then we do decompose
if (triton::gpu::InsertSliceAsyncOp::getEligibleLoadByteWidth(
computeCapability)
.contains(byteWidth))
return;
// load
auto tmpTy =
RankedTensorType::get(srcTy.getShape(), resElemTy, srcBlocked);
auto loadOp = builder.create<triton::LoadOp>(
insertSliceAsyncOp.getLoc(), tmpTy, insertSliceAsyncOp.src(),
insertSliceAsyncOp.mask(), insertSliceAsyncOp.other(),
insertSliceAsyncOp.cache(), insertSliceAsyncOp.evict(),
insertSliceAsyncOp.isVolatile());
// insert_slice
auto axis = insertSliceAsyncOp.axis();
auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); };
auto offsets = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(0));
auto sizes = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(1));
auto strides = SmallVector<OpFoldResult>(dstTy.getRank(), intAttr(1));
offsets[axis] = insertSliceAsyncOp.index();
for (size_t i = 0; i < dstTy.getRank(); i++) {
if (i != axis)
sizes[i] = intAttr(dstTy.getShape()[i]);
}
auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
insertSliceAsyncOp.getLoc(), loadOp, insertSliceAsyncOp.dst(),
offsets, sizes, strides);
// Replace
insertSliceAsyncOp.replaceAllUsesWith(insertSliceOp.getResult());
insertSliceAsyncOp.erase();
decomposed = true;
});
mod.walk([&](triton::gpu::AsyncWaitOp asyncWaitOp) -> void {
if (!triton::gpu::AsyncWaitOp::isSupported(computeCapability)) {
// async wait is supported in Ampere and later
asyncWaitOp.erase();
} else if (decomposed) {
// Wait for all previous async ops
OpBuilder builder(asyncWaitOp);
auto newAsyncWaitOp =
builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
asyncWaitOp.erase();
}
});
}
};
} // anonymous namespace
namespace mlir {
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonGPUToLLVMPass(int computeCapability) {
return std::make_unique<::ConvertTritonGPUToLLVM>(computeCapability);
}
} // namespace triton
} // namespace mlir

View File

@@ -0,0 +1,150 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/MLIRTypes.h"
#include "DotOpHelpers.h"
#include "Utility.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
public:
using TypeConverter::convertType;
TritonGPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option,
const DataLayoutAnalysis *analysis = nullptr)
: LLVMTypeConverter(ctx, option, analysis) {
addConversion([&](triton::PointerType type) -> llvm::Optional<Type> {
return convertTritonPointerType(type);
});
addConversion([&](RankedTensorType type) -> llvm::Optional<Type> {
return convertTritonTensorType(type);
});
// Internally store float8 as int8
addConversion([&](triton::Float8Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
// Internally store bfloat16 as int16
addConversion([&](BFloat16Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 16);
});
}
Type convertTritonPointerType(triton::PointerType type) {
// Recursively translate pointee type
return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()),
type.getAddressSpace());
}
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
auto ctx = type.getContext();
Attribute layout = type.getEncoding();
SmallVector<int64_t> shape(type.getShape().begin(), type.getShape().end());
if (layout &&
(layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
layout.isa<MmaEncodingAttr>())) {
unsigned numElementsPerThread = getElemsPerThread(type);
SmallVector<Type, 4> types(numElementsPerThread,
convertType(type.getElementType()));
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto shared_layout =
layout.dyn_cast_or_null<SharedEncodingAttr>()) {
SmallVector<Type, 4> types;
// base ptr
auto ptrType =
LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
types.push_back(ptrType);
// shape dims
auto rank = type.getRank();
// offsets + strides
for (auto i = 0; i < rank * 2; i++) {
types.push_back(IntegerType::get(ctx, 32));
}
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto dotOpLayout =
layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
if (dotOpLayout.getParent()
.isa<BlockedEncodingAttr>()) { // for parent is blocked layout
int numElemsPerThread =
DotOpFMAConversionHelper::getNumElemsPerThread(shape, dotOpLayout);
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(numElemsPerThread, type::f32Ty(ctx)));
} else { // for parent is MMA layout
auto mmaLayout = dotOpLayout.getParent().cast<MmaEncodingAttr>();
auto wpt = mmaLayout.getWarpsPerCTA();
Type elemTy = convertType(type.getElementType());
if (mmaLayout.isAmpere()) {
const llvm::DenseMap<int, Type> targetTyMap = {
{32, elemTy},
{16, vec_ty(elemTy, 2)},
{8, vec_ty(elemTy, 4)},
};
Type targetTy;
if (targetTyMap.count(elemTy.getIntOrFloatBitWidth())) {
targetTy = targetTyMap.lookup(elemTy.getIntOrFloatBitWidth());
} else {
assert(false && "Unsupported element type");
}
if (dotOpLayout.getOpIdx() == 0) { // $a
auto elems =
MMA16816ConversionHelper::getANumElemsPerThread(type, wpt[0]);
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(elems, targetTy));
}
if (dotOpLayout.getOpIdx() == 1) { // $b
auto elems =
MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt[1]);
return struct_ty(SmallVector<Type>(elems, targetTy));
}
}
if (mmaLayout.isVolta()) {
DotOpMmaV1ConversionHelper helper(mmaLayout);
// TODO[Superjomn]: Both transA and transB are not available here.
bool trans = false;
// TODO[Superjomn]: The order of A and B are not available here.
SmallVector<unsigned> order({1, 0});
if (trans) {
std::swap(shape[0], shape[1]);
std::swap(order[0], order[1]);
}
if (dotOpLayout.getOpIdx() == 0) { // $a
int elems = helper.numElemsPerThreadA(shape, order);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
}
if (dotOpLayout.getOpIdx() == 1) { // $b
int elems = helper.numElemsPerThreadB(shape, order);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
}
}
}
llvm::errs() << "Unexpected dot operand layout detected in "
"TritonToLLVMTypeConverter";
return llvm::None;
}
return llvm::None;
}
};
#endif

View File

@@ -1,34 +1,11 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Membar.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include <memory>
#include <numeric>
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
// Operators
@@ -95,6 +72,7 @@
// Types
#define i32_ty rewriter.getIntegerType(32)
#define i16_ty rewriter.getIntegerType(16)
#define ui32_ty rewriter.getIntegerType(32, false)
#define f16_ty rewriter.getF16Type()
#define bf16_ty rewriter.getBF16Type()
@@ -115,16 +93,79 @@
#define idx_val(...) \
LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(), \
__VA_ARGS__)
#define tid_val() getThreadId(rewriter, loc)
namespace mlir {
namespace triton {
// Delinearize supposing order is [0, 1, .. , n]
template <typename T>
llvm::SmallVector<T> getMultiDimIndexImpl(T linearIndex,
llvm::ArrayRef<T> shape) {
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
size_t rank = shape.size();
T accMul = product(shape.drop_back());
T linearRemain = linearIndex;
llvm::SmallVector<T> multiDimIndex(rank);
for (int i = rank - 1; i >= 0; --i) {
multiDimIndex[i] = linearRemain / accMul;
linearRemain = linearRemain % accMul;
if (i != 0) {
accMul = accMul / shape[i - 1];
}
}
return multiDimIndex;
}
template <typename T>
llvm::SmallVector<T> getMultiDimIndex(T linearIndex, llvm::ArrayRef<T> shape,
llvm::ArrayRef<unsigned> order) {
size_t rank = shape.size();
assert(rank == order.size());
auto reordered = reorder(shape, order);
auto reorderedMultiDim = getMultiDimIndexImpl<T>(linearIndex, reordered);
llvm::SmallVector<T> multiDim(rank);
for (unsigned i = 0; i < rank; ++i) {
multiDim[order[i]] = reorderedMultiDim[i];
}
return multiDim;
}
// Linearize supposing order is [0, 1, .. , n]
template <typename T>
static T getLinearIndexImpl(llvm::ArrayRef<T> multiDimIndex,
llvm::ArrayRef<T> shape) {
assert(multiDimIndex.size() == shape.size());
// shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
size_t rank = shape.size();
T accMul = product(shape.drop_back());
T linearIndex = 0;
for (int i = rank - 1; i >= 0; --i) {
linearIndex += multiDimIndex[i] * accMul;
if (i != 0) {
accMul = accMul / shape[i - 1];
}
}
return linearIndex;
}
template <typename T>
static T getLinearIndex(llvm::ArrayRef<T> multiDimIndex,
llvm::ArrayRef<T> shape,
llvm::ArrayRef<unsigned> order) {
assert(shape.size() == order.size());
return getLinearIndexImpl<T>(reorder(multiDimIndex, order),
reorder(shape, order));
}
} // namespace triton
namespace LLVM {
using namespace mlir::triton;
Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType) {
static Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType) {
if (!structType.isa<LLVM::LLVMStructType>()) {
return *resultVals.begin();
}
@@ -138,8 +179,9 @@ Value getStructFromElements(Location loc, ValueRange resultVals,
return llvmStruct;
}
SmallVector<Value> getElementsFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
static SmallVector<Value>
getElementsFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
if (llvmStruct.getType().isIntOrIndexOrFloat() ||
llvmStruct.getType().isa<triton::PointerType>() ||
llvmStruct.getType().isa<LLVM::LLVMPointerType>())
@@ -155,47 +197,50 @@ SmallVector<Value> getElementsFromStruct(Location loc, Value llvmStruct,
}
// Create a 32-bit integer constant.
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
static Value createConstantI32(Location loc, PatternRewriter &rewriter,
int32_t v) {
auto i32ty = rewriter.getIntegerType(32);
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
IntegerAttr::get(i32ty, v));
}
Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
static Value createConstantF32(Location loc, PatternRewriter &rewriter,
float v) {
auto type = type::f32Ty(rewriter.getContext());
return rewriter.create<LLVM::ConstantOp>(loc, type,
rewriter.getF32FloatAttr(v));
}
Value createConstantF64(Location loc, PatternRewriter &rewriter, float v) {
static Value createConstantF64(Location loc, PatternRewriter &rewriter,
float v) {
auto type = type::f64Ty(rewriter.getContext());
return rewriter.create<LLVM::ConstantOp>(loc, type,
rewriter.getF64FloatAttr(v));
}
// Create an index type constant.
Value createIndexConstant(OpBuilder &builder, Location loc,
TypeConverter *converter, int64_t value) {
static Value createIndexConstant(OpBuilder &builder, Location loc,
TypeConverter *converter, int64_t value) {
Type ty = converter->convertType(builder.getIndexType());
return builder.create<LLVM::ConstantOp>(loc, ty,
builder.getIntegerAttr(ty, value));
}
// Create an integer constant of \param width bits.
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
int64_t value) {
static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
short width, int64_t value) {
Type ty = builder.getIntegerType(width);
return builder.create<LLVM::ConstantOp>(loc, ty,
builder.getIntegerAttr(ty, value));
}
/// Helper function to get strides from a given shape and its order
SmallVector<Value>
static SmallVector<Value>
getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
Location loc, ConversionPatternRewriter &rewriter) {
auto rank = shape.size();
SmallVector<Value> strides(rank);
auto stride = 1;
int64_t stride = 1;
for (auto idx : order) {
strides[idx] = i32_val(stride);
stride *= shape[idx];
@@ -206,7 +251,7 @@ getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
struct SharedMemoryObject {
Value base; // i32 ptr. The start address of the shared memory object.
// We need to store strides as Values but not integers because the
// extract_slice instruction can take a slice at artibary offsets.
// extract_slice instruction can take a slice at arbitrary offsets.
// Take $a[16:32, 16:32] as an example, though we know the stride of $a[0] is
// 32, we need to let the instruction that uses $a to be aware of that.
// Otherwise, when we use $a, we only know that the shape of $a is 16x16. If
@@ -266,7 +311,7 @@ struct SharedMemoryObject {
}
};
SharedMemoryObject
static SharedMemoryObject
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
@@ -276,8 +321,8 @@ getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
/*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
}
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
Value val, Value pred) {
static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
Value ptr, Value val, Value pred) {
MLIRContext *ctx = rewriter.getContext();
unsigned bits = val.getType().getIntOrFloatBitWidth();
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
@@ -290,8 +335,8 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
return builder.launch(rewriter, loc, void_ty(ctx));
}
Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
int i) {
static Value shflSync(Location loc, ConversionPatternRewriter &rewriter,
Value val, int i) {
unsigned bits = val.getType().getIntOrFloatBitWidth();
if (bits == 64) {

View File

@@ -0,0 +1,268 @@
#include "ViewOpToLLVM.h"
#include "DotOpHelpers.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::DotOpMmaV2ConversionHelper;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::getElemsPerThread;
struct SplatOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::SplatOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::SplatOp>::ConvertTritonGPUOpToLLVMPattern;
// Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
// LLVM::StructType value.
//
// @elemType: the element type in operand.
// @resType: the return type of the Splat-like op.
// @constVal: a LLVM::ConstantOp or other scalar value.
static Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
if (tensorTy.getEncoding().isa<BlockedEncodingAttr>() ||
tensorTy.getEncoding().isa<SliceEncodingAttr>()) {
auto srcType = typeConverter->convertType(elemType);
auto llSrc = bitcast(constVal, srcType);
size_t elemsPerThread = getElemsPerThread(tensorTy);
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
llvm::SmallVector<Type> elemTypes(elems.size(), srcType);
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
return getStructFromElements(loc, elems, rewriter, structTy);
} else if (auto dotLayout =
tensorTy.getEncoding()
.dyn_cast<triton::gpu::DotOperandEncodingAttr>()) {
return convertSplatLikeOpWithDotOperandLayout(
dotLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
} else if (auto mmaLayout =
tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>()) {
return convertSplatLikeOpWithMmaLayout(
mmaLayout, resType, elemType, constVal, typeConverter, rewriter, loc);
} else
assert(false && "Unsupported layout found in ConvertSplatLikeOp");
return {};
}
static Value convertSplatLikeOpWithDotOperandLayout(
const triton::gpu::DotOperandEncodingAttr &layout, Type resType,
Type elemType, Value constVal, TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
auto shape = tensorTy.getShape();
auto parent = layout.getParent();
int numElems{};
if (auto mmaLayout = parent.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isAmpere()) {
numElems = layout.getOpIdx() == 0
? MMA16816ConversionHelper::getANumElemsPerThread(
tensorTy, mmaLayout.getWarpsPerCTA()[0])
: MMA16816ConversionHelper::getBNumElemsPerThread(
tensorTy, mmaLayout.getWarpsPerCTA()[1]);
} else if (mmaLayout.isVolta()) {
DotOpMmaV1ConversionHelper helper(mmaLayout);
numElems = layout.getOpIdx() == 0
? helper.numElemsPerThreadA(shape, {0, 1})
: helper.numElemsPerThreadB(shape, {0, 1});
}
} else if (auto blockedLayout = parent.dyn_cast<BlockedEncodingAttr>()) {
numElems = DotOpFMAConversionHelper::getNumElemsPerThread(shape, layout);
} else {
assert(false && "Unsupported layout found");
}
auto structTy = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), SmallVector<Type>(numElems, elemType));
return getStructFromElements(loc, SmallVector<Value>(numElems, constVal),
rewriter, structTy);
}
static Value convertSplatLikeOpWithMmaLayout(
const MmaEncodingAttr &layout, Type resType, Type elemType,
Value constVal, TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
auto shape = tensorTy.getShape();
if (layout.isAmpere()) {
auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(tensorTy);
size_t fcSize = 4 * repM * repN;
auto structTy = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), SmallVector<Type>(fcSize, elemType));
return getStructFromElements(loc, SmallVector<Value>(fcSize, constVal),
rewriter, structTy);
}
if (layout.isVolta()) {
DotOpMmaV1ConversionHelper helper(layout);
int repM = helper.getRepM(shape[0]);
int repN = helper.getRepN(shape[1]);
// According to mma layout of v1, each thread process 8 elements.
int elems = 8 * repM * repN;
auto structTy = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), SmallVector<Type>(elems, elemType));
return getStructFromElements(loc, SmallVector<Value>(elems, constVal),
rewriter, structTy);
}
assert(false && "Unsupported mma layout found");
return {};
}
LogicalResult matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op->getLoc();
auto src = adaptor.src();
auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src,
getTypeConverter(), rewriter, loc);
rewriter.replaceOp(op, {llStruct});
return success();
}
};
// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr),
// the logic is the same as triton::SplatOp, so the underlying implementation
// is reused.
struct ArithConstantSplatOpConversion
: public ConvertTritonGPUOpToLLVMPattern<arith::ConstantOp> {
using ConvertTritonGPUOpToLLVMPattern<
arith::ConstantOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto value = op.getValue();
if (!value.dyn_cast<SplatElementsAttr>())
return failure();
auto loc = op->getLoc();
LLVM::ConstantOp arithConstantOp;
auto values = op.getValue().dyn_cast<SplatElementsAttr>();
auto elemType = values.getElementType();
Attribute val;
if (elemType.isBF16() || type::isFloat(elemType)) {
val = values.getValues<FloatAttr>()[0];
} else if (type::isInt(elemType)) {
val = values.getValues<IntegerAttr>()[0];
} else {
llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: "
<< value.getType() << "\n";
return failure();
}
auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
auto llStruct = SplatOpConversion::convertSplatLikeOp(
elemType, op.getType(), constOp, getTypeConverter(), rewriter, loc);
rewriter.replaceOp(op, llStruct);
return success();
}
};
struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
using OpAdaptor = typename CatOp::Adaptor;
explicit CatOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<CatOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(CatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
// unpack input values
auto lhsVals = getElementsFromStruct(loc, adaptor.lhs(), rewriter);
auto rhsVals = getElementsFromStruct(loc, adaptor.rhs(), rewriter);
// concatenate (and potentially reorder) values
SmallVector<Value> retVals;
for (Value v : lhsVals)
retVals.push_back(v);
for (Value v : rhsVals)
retVals.push_back(v);
// pack and replace
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
Value ret = getStructFromElements(loc, retVals, rewriter, structTy);
rewriter.replaceOp(op, ret);
return success();
}
};
template <typename SourceOp>
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
using OpAdaptor = typename SourceOp::Adaptor;
explicit ViewLikeOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We cannot directly run `rewriter.replaceOp(op, adaptor.src())`
// due to MLIR's restrictions
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
Value view = getStructFromElements(loc, vals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
};
struct TransOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::TransOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::TransOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto srcSmemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
SmallVector<Value> dstStrides = {srcSmemObj.strides[1],
srcSmemObj.strides[0]};
SmallVector<Value> dstOffsets = {srcSmemObj.offsets[1],
srcSmemObj.offsets[0]};
auto dstSmemObj =
SharedMemoryObject(srcSmemObj.base, dstStrides, dstOffsets);
auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}
};
void populateViewOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit) {
patterns.add<ViewLikeOpConversion<triton::ViewOp>>(typeConverter, benefit);
patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<CatOpConversion>(typeConverter, benefit);
patterns.add<TransOpConversion>(typeConverter, benefit);
}

View File

@@ -0,0 +1,15 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_VIEW_OP_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_VIEW_OP_H
#include "TritonGPUToLLVMBase.h"
using namespace mlir;
using namespace mlir::triton;
void populateViewOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit);
#endif

View File

@@ -1,5 +1,5 @@
add_mlir_conversion_library(TritonToTritonGPU
TritonToTritonGPU.cpp
TritonToTritonGPUPass.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonToTritonGPU

View File

@@ -1,16 +1,24 @@
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "../PassDetail.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "llvm/ADT/APSInt.h"
#include <numeric>
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Conversion/Passes.h.inc"
namespace {
template <class Op> class GenericOpPattern : public OpConversionPattern<Op> {
@@ -221,8 +229,21 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
Attribute dEncoding = retType.cast<RankedTensorType>().getEncoding();
RankedTensorType origType = op.getType().cast<RankedTensorType>();
auto origShape = origType.getShape();
auto typeConverter = getTypeConverter<TritonGPUTypeConverter>();
int numWarps = typeConverter->getNumWarps();
SmallVector<unsigned> retSizePerThread = {1, 1};
if (origShape[0] * origShape[1] / (numWarps * 32) >= 4)
retSizePerThread = {2, 2};
if (origShape[0] * origShape[1] / (numWarps * 32) >= 16)
retSizePerThread = {4, 4};
SmallVector<unsigned> retOrder = {1, 0};
Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get(
getContext(), origShape, retSizePerThread, retOrder, numWarps);
RankedTensorType retType =
RankedTensorType::get(origShape, origType.getElementType(), dEncoding);
// a & b must be of smem layout
auto aType = adaptor.a().getType().cast<RankedTensorType>();
auto bType = adaptor.b().getType().cast<RankedTensorType>();
@@ -232,6 +253,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
return failure();
Value a = adaptor.a();
Value b = adaptor.b();
Value c = adaptor.c();
if (!aEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
Attribute encoding =
triton::gpu::DotOperandEncodingAttr::get(getContext(), 0, dEncoding);
@@ -246,7 +268,9 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
bType.getElementType(), encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, adaptor.c(),
c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);
rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, c,
adaptor.allowTF32());
return success();
}

View File

@@ -77,9 +77,9 @@ SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout) {
blockedLayout.getThreadsPerWarp().end());
}
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 1)
if (mmaLayout.isVolta())
return {4, 8};
if (mmaLayout.getVersion() == 2)
if (mmaLayout.isAmpere())
return {8, 4};
}
assert(0 && "getThreadsPerWarp not implemented");
@@ -106,9 +106,9 @@ SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
return getSizePerThread(sliceLayout.getParent());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 2) {
if (mmaLayout.isAmpere()) {
return {2, 2};
} else if (mmaLayout.getVersion() == 1) {
} else if (mmaLayout.isVolta()) {
// Note: here the definition of sizePerThread is obscure, which doesn't
// mean vecSize=4 can be supported in the last dimension.
return {2, 4};
@@ -119,7 +119,7 @@ SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
auto parentLayout = dotLayout.getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
if (auto parentMmaLayout = parentLayout.dyn_cast<MmaEncodingAttr>()) {
assert(parentMmaLayout.getVersion() == 2 &&
assert(parentMmaLayout.isAmpere() &&
"mmaLayout version = 1 is not implemented yet");
auto parentShapePerCTA = getShapePerCTA(parentLayout);
auto opIdx = dotLayout.getOpIdx();
@@ -144,7 +144,7 @@ SmallVector<unsigned> getSizePerThread(const Attribute &layout) {
SmallVector<unsigned> getContigPerThread(Attribute layout) {
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.getVersion() == 1 || mmaLayout.getVersion() == 2);
assert(mmaLayout.isVolta() || mmaLayout.isAmpere());
return {1, 2};
} else {
return getSizePerThread(layout);
@@ -179,14 +179,13 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
for (unsigned d = 0, n = getOrder(parent).size(); d < n; ++d) {
if (d == dim)
continue;
shape.push_back(getSizePerThread(parent)[d] *
getThreadsPerWarp(parent)[d] * getWarpsPerCTA(parent)[d]);
shape.push_back(getShapePerCTA(parent)[d]);
}
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 2)
if (mmaLayout.isAmpere())
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
if (mmaLayout.getVersion() == 1)
if (mmaLayout.isVolta())
return {16 * mmaLayout.getWarpsPerCTA()[0],
16 * mmaLayout.getWarpsPerCTA()[1]};
assert(0 && "Unexpected MMA layout version found");
@@ -194,7 +193,7 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
auto parentLayout = dotLayout.getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
if (auto parentMmaLayout = parentLayout.dyn_cast<MmaEncodingAttr>()) {
assert(parentMmaLayout.getVersion() == 2 &&
assert(parentMmaLayout.isAmpere() &&
"mmaLayout version = 1 is not implemented yet");
auto parentShapePerCTA = getShapePerCTA(parentLayout);
auto opIdx = dotLayout.getOpIdx();
@@ -210,10 +209,10 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
"supported yet");
}
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersion() == 2) {
if (mmaLayout.isAmpere()) {
return {16 * mmaLayout.getWarpsPerCTA()[0],
8 * mmaLayout.getWarpsPerCTA()[1]};
} else if (mmaLayout.getVersion() == 1) {
} else if (mmaLayout.isVolta()) {
return {16 * mmaLayout.getWarpsPerCTA()[0],
16 * mmaLayout.getWarpsPerCTA()[1]};
} else {
@@ -255,6 +254,11 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
}
};
bool isaDistributedLayout(const Attribute &layout) {
return layout.isa<BlockedEncodingAttr>() || layout.isa<MmaEncodingAttr>() ||
layout.isa<SliceEncodingAttr>();
}
} // namespace gpu
} // namespace triton
} // namespace mlir
@@ -369,17 +373,16 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
size_t rank = shape.size();
assert(rank == 2 && "Unexpected rank of mma layout");
assert((getVersion() == 1 || getVersion() == 2) &&
"Only version 1 and 2 is supported");
assert((isVolta() || isAmpere()) && "Only version 1 and 2 is supported");
int res = 0;
if (getVersion() == 1) {
if (isVolta()) {
unsigned mmasRow = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]);
unsigned mmasCol = ceil<unsigned>(shape[1], 16 * getWarpsPerCTA()[1]);
// Each warp-level mma884 will perform a m16xn16xk4 mma, thus get a m16xn16
// matrix as result.
res = mmasRow * mmasCol * (16 * 16 / 32);
} else if (getVersion() == 2) {
} else if (isAmpere()) {
unsigned elemsCol = ceil<unsigned>(shape[0], 16 * getWarpsPerCTA()[0]) * 2;
unsigned elemsRow = ceil<unsigned>(shape[1], 8 * getWarpsPerCTA()[1]) * 2;
res = elemsCol * elemsRow;
@@ -477,12 +480,17 @@ Attribute MmaEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseGreater().failed())
return {};
unsigned version = 0;
unsigned versionMajor = 0;
unsigned versionMinor = 0;
SmallVector<unsigned, 2> warpsPerCTA;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "version") {
if (parseUInt(parser, attr, version, "version").failed())
if (attr.getName() == "versionMajor") {
if (parseUInt(parser, attr, versionMajor, "versionMajor").failed())
return {};
}
if (attr.getName() == "versionMinor") {
if (parseUInt(parser, attr, versionMinor, "versionMinor").failed())
return {};
}
if (attr.getName() == "warpsPerCTA") {
@@ -491,13 +499,14 @@ Attribute MmaEncodingAttr::parse(AsmParser &parser, Type type) {
}
}
return parser.getChecked<MmaEncodingAttr>(parser.getContext(), version,
warpsPerCTA);
return parser.getChecked<MmaEncodingAttr>(parser.getContext(), versionMajor,
versionMinor, warpsPerCTA);
}
void MmaEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{"
<< "version = " << getVersion() << ", "
<< "versionMajor = " << getVersionMajor() << ", "
<< "versionMinor = " << getVersionMinor() << ", "
<< "warpsPerCTA = [" << getWarpsPerCTA() << "]"
<< "}>";
}
@@ -576,6 +585,25 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
<< "}>";
}
//===----------------------------------------------------------------------===//
// Mma encoding
//===----------------------------------------------------------------------===//
bool MmaEncodingAttr::isVolta() const { return getVersionMajor() == 1; }
bool MmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; }
// Get [isARow, isBRow, isAVec4, isBVec4] from versionMinor
std::tuple<bool, bool, bool, bool>
MmaEncodingAttr::decodeVoltaLayoutStates() const {
unsigned versionMinor = getVersionMinor();
bool isARow = versionMinor & (1 << 0);
bool isBRow = versionMinor & (1 << 1);
bool isAVec4 = versionMinor & (1 << 2);
bool isBVec4 = versionMinor & (1 << 3);
return std::make_tuple(isARow, isBRow, isAVec4, isBVec4);
}
//===----------------------------------------------------------------------===//
// DotOperand Encoding
//===----------------------------------------------------------------------===//
@@ -590,10 +618,10 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
Attribute isMMAv1Row;
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().getVersion() == 1){
if (parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().isVolta()) {
isMMAv1Row = attrs.get("isMMAv1Row");
if(!isMMAv1Row)
if (!isMMAv1Row)
llvm::report_fatal_error("isMMAv1Row attribute is missing");
}
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
@@ -604,8 +632,8 @@ void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent();
if(getIsMMAv1Row())
printer << ", isMMAv1Row = " << getIsMMAv1Row();
if (getIsMMAv1Row())
printer << ", isMMAv1Row = " << getIsMMAv1Row();
printer << "}>";
}

View File

@@ -22,6 +22,10 @@
using namespace mlir;
namespace {
#include "TritonGPUCombine.inc"
using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
// -----------------------------------------------------------------------------
//
@@ -57,8 +61,7 @@ public:
!dstParent.isa<triton::gpu::MmaEncodingAttr>())
return mlir::failure();
auto dstParentMma = dstParent.cast<triton::gpu::MmaEncodingAttr>();
if (dstParentMma.getVersion() == 1 ||
dstParentMma.getWarpsPerCTA()[1] > 1)
if (dstParentMma.isVolta() || dstParentMma.getWarpsPerCTA()[1] > 1)
return mlir::failure();
SetVector<Operation *> bwdSlices;
mlir::getBackwardSlice(convert.getResult(), &bwdSlices);
@@ -80,6 +83,45 @@ public:
}
};
class SimplifyReduceCvt : public mlir::RewritePattern {
public:
explicit SimplifyReduceCvt(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::ReduceOp::getOperationName(), 2, context) {
}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto reduce = cast<triton::ReduceOp>(*op);
auto reduceArg = dyn_cast<triton::gpu::ConvertLayoutOp>(
reduce.getOperand().getDefiningOp());
if (!reduceArg)
return mlir::failure();
// this may generate unsupported conversions in the LLVM codegen
if (reduceArg.getOperand()
.getType()
.cast<RankedTensorType>()
.getEncoding()
.isa<triton::gpu::MmaEncodingAttr>())
return mlir::failure();
auto newReduce = rewriter.create<triton::ReduceOp>(
op->getLoc(), reduce.redOp(), reduceArg.getOperand(), reduce.axis());
if (isa<triton::gpu::ConvertLayoutOp>(
*reduceArg.getOperand().getDefiningOp()))
return mlir::failure();
Value newRet = newReduce.getResult();
// it's still beneficial to move the conversion
// to after the reduce if necessary since it will be
// done on a rank-reduced tensor hence cheaper
if (newRet.getType() != reduce.getResult().getType())
newRet = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), reduce.getResult().getType(), newRet);
rewriter.replaceOp(op, newRet);
return success();
}
};
// Layout conversions can't deduce their return type automatically.
// IIUC they are therefore not handled by DRR right now
class SimplifyConversion : public mlir::RewritePattern {
@@ -219,6 +261,7 @@ public:
//
// -----------------------------------------------------------------------------
// TODO: Interface
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
Attribute &ret) {
ret = targetEncoding;
@@ -236,6 +279,20 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
return success();
}
// TODO: Interface
LogicalResult getForwardEncoding(Attribute sourceEncoding, Operation *op,
Attribute &ret) {
if (op->hasTrait<mlir::OpTrait::Elementwise>()) {
ret = sourceEncoding;
return success();
}
if (isa<triton::ReduceOp>(op)) {
ret = Attribute();
return success();
}
return failure();
}
inline bool expensive_to_remat(Operation *op) {
if (!op)
return true;
@@ -248,6 +305,64 @@ inline bool expensive_to_remat(Operation *op) {
return false;
}
LogicalResult simulateBackwardRematerialization(
Operation *initOp, SetVector<Operation *> &processed,
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
Attribute targetEncoding) {
// DFS
std::vector<std::pair<Operation *, Attribute>> queue;
queue.emplace_back(initOp, targetEncoding);
// We want to see the effect of converting `initOp` to a new layout
// so we initialize `numCvts = 1`.
int numCvts = 1;
while (!queue.empty()) {
Operation *currOp;
Attribute currLayout;
std::tie(currOp, currLayout) = queue.back();
queue.pop_back();
// If the current operation is expensive to rematerialize,
// we stop everything
if (expensive_to_remat(currOp))
return mlir::failure();
// we would propagate the conversion here
numCvts -= 1;
// check if the conversion could be folded at this operation
if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp>(*currOp))
continue;
// done processing
processed.insert(currOp);
layout.insert(currLayout);
// add all operands to the queue
for (Value argI : currOp->getOperands()) {
Attribute newEncoding;
// cannot invert the current encoding for this operand
// we stop everything
if (failed(invertEncoding(currLayout, currOp, newEncoding))) {
return mlir::failure();
}
if (toConvert.count(argI) && toConvert[argI] != newEncoding)
return mlir::failure();
//
Operation *opArgI = argI.getDefiningOp();
toConvert.insert({argI, newEncoding});
if (!opArgI || processed.contains(opArgI) ||
(opArgI->getBlock() != initOp->getBlock()))
continue;
// we add one expensive conversion for the current operand
numCvts += 1;
queue.emplace_back(opArgI, newEncoding);
}
}
// if rematerialization would add more conversions than it removes
// then we don't do it
if (numCvts > 0)
return mlir::failure();
return mlir::success();
}
//
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
BlockAndValueMapping &mapping) {
Operation *newOp = rewriter.clone(*op, mapping);
@@ -268,6 +383,167 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
return newOp;
}
//
class MoveConvertOutOfIf : public mlir::RewritePattern {
public:
explicit MoveConvertOutOfIf(mlir::MLIRContext *context)
: mlir::RewritePattern(scf::IfOp::getOperationName(), 2, context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto ifOp = cast<scf::IfOp>(*op);
auto thenYield = ifOp.thenYield();
auto elseYield = ifOp.elseYield();
int numOps = thenYield.getNumOperands();
SmallVector<Value> newThenYieldOps = thenYield.getOperands();
SmallVector<Value> newElseYieldOps = elseYield.getOperands();
SetVector<Operation *> thenCvts;
SetVector<Operation *> elseCvts;
SmallVector<Type> newRetTypes;
BlockAndValueMapping mapping;
for (size_t i = 0; i < numOps; i++) {
auto thenCvt = dyn_cast<triton::gpu::ConvertLayoutOp>(
thenYield.getOperand(i).getDefiningOp());
auto elseCvt = dyn_cast<triton::gpu::ConvertLayoutOp>(
elseYield.getOperand(i).getDefiningOp());
if (thenCvt && elseCvt &&
std::distance(thenCvt->user_begin(), thenCvt->user_end()) == 1 &&
std::distance(elseCvt->user_begin(), elseCvt->user_end()) == 1 &&
thenCvt.getOperand().getType() == elseCvt.getOperand().getType()) {
mapping.map(thenCvt.getResult(), thenCvt.getOperand());
mapping.map(elseCvt.getResult(), elseCvt.getOperand());
newRetTypes.push_back(thenCvt.getOperand().getType());
thenCvts.insert((Operation *)thenCvt);
elseCvts.insert((Operation *)elseCvt);
} else
newRetTypes.push_back(thenYield.getOperand(i).getType());
}
if (mapping.getValueMap().empty())
return mlir::failure();
rewriter.setInsertionPoint(op);
auto newIfOp = rewriter.create<scf::IfOp>(ifOp.getLoc(), newRetTypes,
ifOp.getCondition(), true);
// rematerialize `then` block
rewriter.setInsertionPointToEnd(newIfOp.thenBlock());
for (Operation &op : ifOp.thenBlock()->getOperations()) {
if (thenCvts.contains(&op)) {
mapping.map(op.getResult(0), mapping.lookup(op.getOperand(0)));
continue;
}
rewriter.clone(op, mapping);
}
// rematerialize `else` block
rewriter.setInsertionPointToEnd(newIfOp.elseBlock());
for (Operation &op : ifOp.elseBlock()->getOperations()) {
if (elseCvts.contains(&op)) {
mapping.map(op.getResult(0), mapping.lookup(op.getOperand(0)));
continue;
}
rewriter.clone(op, mapping);
}
rewriter.setInsertionPointAfter(newIfOp);
SmallVector<Value> newRetValues = newIfOp.getResults();
for (size_t i = 0; i < numOps; i++) {
if (newIfOp.getResult(i).getType() != ifOp.getResult(i).getType()) {
newRetValues[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
newIfOp.getLoc(), ifOp.getResult(i).getType(),
newIfOp.getResult(i));
}
}
rewriter.replaceOp(op, newRetValues);
return mlir::success();
}
};
//
class FoldConvertAndReduce : public mlir::RewritePattern {
public:
explicit FoldConvertAndReduce(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *cvtOp,
mlir::PatternRewriter &rewriter) const override {
auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(*cvtOp);
auto srcEncoding =
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
auto dstEncoding =
cvt.getResult().getType().cast<RankedTensorType>().getEncoding();
if (srcEncoding.isa<triton::gpu::SliceEncodingAttr>())
return failure();
SetVector<Operation *> cvtSlices;
auto filter = [&](Operation *op) {
return op->getBlock() == cvt->getBlock() &&
!(isa<triton::ReduceOp>(op) &&
!op->getResult(0).getType().isa<RankedTensorType>()) &&
!isa<scf::YieldOp>(op);
};
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
if (cvtSlices.empty())
return failure();
llvm::MapVector<Value, Attribute> toConvert;
for (Operation *op : cvtSlices) {
// don't rematerialize anything expensive
if (expensive_to_remat(op))
return failure();
// don't rematerialize non-element-wise
if (!op->hasTrait<mlir::OpTrait::Elementwise>())
return failure();
Attribute dstEncoding =
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
// don't rematerialize if it adds an extra conversion that can't
// be removed
for (Value arg : op->getOperands()) {
Operation *argOp = arg.getDefiningOp();
SetVector<Operation *> processed;
SetVector<Attribute> layout;
llvm::MapVector<Value, Attribute> toConvert;
if (argOp && (argOp != cvt) && cvtSlices.count(argOp) == 0 &&
failed(simulateBackwardRematerialization(argOp, processed, layout,
toConvert, dstEncoding))) {
return failure();
}
}
}
BlockAndValueMapping mapping;
auto op = cvtSlices.front();
for (Value arg : op->getOperands()) {
if (arg.getDefiningOp() == cvt)
mapping.map(arg, cvt.getOperand());
else {
auto cvtI = rewriter.create<triton::gpu::ConvertLayoutOp>(
arg.getLoc(), cvt.getOperand().getType(), arg);
if (Operation *argOp = arg.getDefiningOp())
cvtI->moveAfter(argOp);
mapping.map(arg, cvtI);
}
}
rewriter.setInsertionPoint(op);
Operation *newOp = rewriter.clone(*op, mapping);
auto oldType = op->getResult(0).getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
oldType.getShape(), oldType.getElementType(),
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding());
newOp->getResult(0).setType(newType);
auto newCvtType = RankedTensorType::get(
oldType.getShape(), oldType.getElementType(),
cvt.getResult().getType().cast<RankedTensorType>().getEncoding());
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
newOp->getLoc(), newCvtType, newOp->getResult(0));
rewriter.replaceOp(op, newCvt->getResults());
return success();
}
};
// Layout conversions are expensive. They require going through
// shared memory, which is orders of magnitude slower than
// other non-i/o operations in the dialect.
@@ -495,7 +771,6 @@ public:
continue;
}
// check
// llvm::outs() << "replacing " << iterArg.index() << "\n";
for (auto op : iterArg.value().getUsers()) {
auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
if (!cvt)
@@ -585,13 +860,15 @@ public:
// -----------------------------------------------------------------------------
namespace {
int computeCapabilityToMMAVersion(int computeCapability) {
if (computeCapability < 80) {
if (computeCapability < 70) {
return 0;
} else if (computeCapability < 80) {
return 1;
} else if (computeCapability < 90) {
return 2;
} else {
assert(false && "computeCapability > 90 not supported");
return 0;
return 3;
}
}
@@ -746,6 +1023,7 @@ public:
dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
return failure();
auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
op->getContext(), dstDotOperandLayout.getOpIdx(),
@@ -787,25 +1065,55 @@ public:
auto dotOp = cast<triton::DotOp>(op);
// TODO: Check data-types and SM compatibility
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
if (!oldRetType.getEncoding() ||
oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return failure();
int version = computeCapabilityToMMAVersion(computeCapability);
auto AType = dotOp.getOperand(0).getType().cast<RankedTensorType>();
auto BType = dotOp.getOperand(1).getType().cast<RankedTensorType>();
// for FMA, should retain the blocked layout.
if (!supportMMA(dotOp, version))
int versionMajor = computeCapabilityToMMAVersion(computeCapability);
if (!supportMMA(dotOp, versionMajor))
return failure();
auto AOrder = AType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
auto BOrder = BType.getEncoding()
.cast<triton::gpu::DotOperandEncodingAttr>()
.getParent()
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
// get MMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
auto newRetType = RankedTensorType::get(
retShape, oldRetType.getElementType(),
triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), version,
getWarpsPerTile(dotOp, retShape, version, numWarps)));
auto warpsPerTile =
getWarpsPerTile(dotOp, retShape, versionMajor, numWarps);
triton::gpu::MmaEncodingAttr mmaEnc;
if (versionMajor == 1) {
auto shapeA = AType.getShape();
auto shapeB = BType.getShape();
bool isARow = AOrder[0] != 0;
bool isBRow = BOrder[0] != 0;
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, warpsPerTile, shapeA, shapeB,
isARow, isBRow);
} else if (versionMajor == 2) {
mmaEnc = triton::gpu::MmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, 0 /*versionMinor*/,
warpsPerTile);
} else {
assert(false && "Mma layout only support versionMajor of 1 or 2");
}
auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(), mmaEnc);
// convert accumulator
auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -826,7 +1134,7 @@ public:
.getOrder();
Attribute isMMAv1RowA;
Attribute isMMAv1RowB;
if (version == 1) {
if (versionMajor == 1) {
isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
}
@@ -868,7 +1176,8 @@ public:
for (size_t i = 0; i < newInitArgs.size(); i++) {
auto initArg = newInitArgs[i];
auto regionArg = forOp.getRegionIterArgs()[i];
if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()) {
if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType() ||
newInitArgs[i].getType() != forOp.getResultTypes()[i]) {
shouldRematerialize = true;
break;
}
@@ -884,15 +1193,207 @@ public:
BlockAndValueMapping mapping;
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
for (Operation &op : forOp.getBody()->getOperations()) {
Operation *newOp = rewriter.clone(op, mapping);
rewriter.clone(op, mapping);
}
rewriter.replaceOp(forOp, newForOp.getResults());
return success();
}
};
// This pattern collects the wrong Mma those need to update and create the right
// ones for each.
class CollectMmaToUpdateForVolta : public mlir::RewritePattern {
DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
public:
CollectMmaToUpdateForVolta(
mlir::MLIRContext *ctx,
DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
: mlir::RewritePattern(triton::DotOp::getOperationName(), 1, ctx),
mmaToUpdate(mmaToUpdate) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dotOp = cast<triton::DotOp>(op);
auto *ctx = dotOp->getContext();
auto AT = dotOp.a().getType().cast<RankedTensorType>();
auto BT = dotOp.b().getType().cast<RankedTensorType>();
auto DT = dotOp.d().getType().cast<RankedTensorType>();
if (!DT.getEncoding())
return failure();
auto mmaLayout = DT.getEncoding().dyn_cast<MmaEncodingAttr>();
if (!(mmaLayout && mmaLayout.isVolta()))
return failure();
// Has processed.
if (mmaToUpdate.count(mmaLayout))
return failure();
auto dotOperandA = AT.getEncoding().cast<DotOperandEncodingAttr>();
auto dotOperandB = BT.getEncoding().cast<DotOperandEncodingAttr>();
bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto [isARow_, isBRow_, isAVec4, isBVec4] =
mmaLayout.decodeVoltaLayoutStates();
if (isARow_ == isARow && isBRow_ == isBRow) {
return failure(); // No need to update
}
auto newMmaLayout = MmaEncodingAttr::get(
ctx, mmaLayout.getVersionMajor(), mmaLayout.getWarpsPerCTA(),
AT.getShape(), BT.getShape(), isARow, isBRow);
// Collect the wrong MMA Layouts, and mark need to update.
mmaToUpdate.try_emplace(mmaLayout, newMmaLayout);
return failure();
}
};
// Correct the versionMinor field in MmaEncodingAttr for Volta.
class UpdateMMAVersionMinorForVolta : public mlir::RewritePattern {
const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate;
enum class Kind {
kUnk,
kCvtToMma,
kCvtToDotOp,
kDot,
kConstant,
};
mutable Kind rewriteKind{Kind::kUnk};
public:
UpdateMMAVersionMinorForVolta(
mlir::MLIRContext *ctx, llvm::StringRef opName,
const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &mmaToUpdate)
: RewritePattern(opName, 1 /*benefit*/, ctx), mmaToUpdate(mmaToUpdate) {}
LogicalResult match(Operation *op) const override {
MmaEncodingAttr mma;
if (mmaToUpdate.empty())
return failure();
if (op->getNumResults() != 1)
return failure();
auto tensorTy = op->getResult(0).getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return failure();
// ConvertLayoutOp
if (auto cvt = llvm::dyn_cast<ConvertLayoutOp>(op)) {
// cvt X -> dot_operand
if (auto dotOperand =
tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>()) {
mma = dotOperand.getParent().dyn_cast<MmaEncodingAttr>();
rewriteKind = Kind::kCvtToDotOp;
if (mma && mmaToUpdate.count(mma))
return success();
}
if ((mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>())) {
// cvt X -> mma
rewriteKind = Kind::kCvtToMma;
if (mma && mmaToUpdate.count(mma))
return success();
}
} else if (auto dot = llvm::dyn_cast<DotOp>(op)) {
// DotOp
mma = dot.d()
.getType()
.cast<RankedTensorType>()
.getEncoding()
.dyn_cast<MmaEncodingAttr>();
rewriteKind = Kind::kDot;
} else if (auto constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
// ConstantOp
mma = tensorTy.getEncoding().dyn_cast<MmaEncodingAttr>();
rewriteKind = Kind::kConstant;
}
return success(mma && mmaToUpdate.count(mma));
}
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
switch (rewriteKind) {
case Kind::kDot:
rewriteDot(op, rewriter);
break;
case Kind::kConstant:
rewriteConstant(op, rewriter);
break;
case Kind::kCvtToDotOp:
rewriteCvtDotOp(op, rewriter);
break;
case Kind::kCvtToMma:
rewriteCvtToMma(op, rewriter);
break;
default:
llvm::report_fatal_error("Not supported rewrite kind");
}
}
private:
void rewriteCvtDotOp(Operation *op, PatternRewriter &rewriter) const {
auto *ctx = op->getContext();
auto cvt = llvm::cast<ConvertLayoutOp>(op);
auto tensorTy = cvt.result().getType().cast<RankedTensorType>();
auto dotOperand = tensorTy.getEncoding().cast<DotOperandEncodingAttr>();
MmaEncodingAttr newMma =
mmaToUpdate.lookup(dotOperand.getParent().cast<MmaEncodingAttr>());
auto newDotOperand = DotOperandEncodingAttr::get(
ctx, dotOperand.getOpIdx(), newMma, dotOperand.getIsMMAv1Row());
auto newTensorTy = RankedTensorType::get(
tensorTy.getShape(), tensorTy.getElementType(), newDotOperand);
rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
cvt.getOperand());
}
void rewriteDot(Operation *op, PatternRewriter &rewriter) const {
auto *ctx = op->getContext();
auto dot = llvm::cast<DotOp>(op);
auto tensorTy = dot.d().getType().cast<RankedTensorType>();
auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
auto newMma = mmaToUpdate.lookup(mma);
auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
tensorTy.getElementType(), newMma);
rewriter.replaceOpWithNewOp<DotOp>(op, newTensorTy, dot.a(), dot.b(),
dot.c(), dot.allowTF32());
}
void rewriteCvtToMma(Operation *op, PatternRewriter &rewriter) const {
auto *ctx = op->getContext();
auto cvt = llvm::cast<ConvertLayoutOp>(op);
auto tensorTy = cvt.result().getType().cast<RankedTensorType>();
auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
auto newMma = mmaToUpdate.lookup(mma);
auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
tensorTy.getElementType(), newMma);
rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, newTensorTy,
cvt.getOperand());
}
void rewriteConstant(Operation *op, PatternRewriter &rewriter) const {
auto *ctx = op->getContext();
auto constant = llvm::cast<arith::ConstantOp>(op);
auto tensorTy = constant.getResult().getType().dyn_cast<RankedTensorType>();
auto mma = tensorTy.getEncoding().cast<MmaEncodingAttr>();
auto newMma = mmaToUpdate.lookup(mma);
auto newTensorTy = RankedTensorType::get(tensorTy.getShape(),
tensorTy.getElementType(), newMma);
if (auto attr = constant.getValue().dyn_cast<SplatElementsAttr>()) {
auto newRet =
SplatElementsAttr::get(newTensorTy, attr.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newTensorTy, newRet);
return;
}
assert(false && "Not supported ConstantOp value type");
}
};
} // namespace
#define GEN_PASS_CLASSES
@@ -914,17 +1415,41 @@ public:
patterns.add<OptimizeBlockedToShared>(context);
patterns.add<OptimizeConvertToDotOperand>(context);
patterns.add<SimplifyConversion>(context);
patterns.add<SimplifyReduceCvt>(context);
patterns.add<FoldConvertAndReduce>(context);
patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);
patterns.add<RematerializeForward>(context);
patterns.add<MoveConvertOutOfLoop>(context);
patterns.add<MoveConvertOutOfIf>(context);
patterns.add<BlockedToMMA>(context, computeCapability);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();
}
// llvm::outs() << m << "\n";
llvm::DenseMap<MmaEncodingAttr, MmaEncodingAttr> mmaToUpdate;
{
mlir::RewritePatternSet patterns(context);
patterns.add<CollectMmaToUpdateForVolta>(context, mmaToUpdate);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
signalPassFailure();
}
{
mlir::RewritePatternSet patterns(context);
patterns.add<UpdateMMAVersionMinorForVolta>(
context, DotOp::getOperationName(), mmaToUpdate);
patterns.add<UpdateMMAVersionMinorForVolta>(
context, ConvertLayoutOp::getOperationName(), mmaToUpdate);
patterns.add<UpdateMMAVersionMinorForVolta>(
context, arith::ConstantOp::getOperationName(), mmaToUpdate);
mlir::GreedyRewriteConfig config;
config.useTopDownTraversal = true;
if (applyPatternsAndFoldGreedily(m, std::move(patterns), config).failed())
signalPassFailure();
}
mlir::RewritePatternSet loopFixup(context);
loopFixup.add<FixupLoop>(context);
if (applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed()) {

View File

@@ -1,4 +1,5 @@
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
@@ -11,12 +12,13 @@
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/tools/sys/getenv.hpp"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/IR/Constants.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Support/SourceMgr.h"
#include <filesystem>
namespace mlir {
namespace triton {
@@ -25,19 +27,18 @@ namespace triton {
// information from mlir module.
struct NVVMMetadata {
int maxntidx{-1};
bool is_kernel{};
bool isKernel{};
// Free to extend with other information.
};
// Add the nvvm related metadata to LLVM IR.
void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
static void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
auto *module = func->getParent();
auto &ctx = func->getContext();
if (metadata.maxntidx > 0) {
auto i32_ty = llvm::IntegerType::get(ctx, 32);
auto warps =
llvm::ConstantInt::get(i32_ty, llvm::APInt(32, metadata.maxntidx));
auto warps = llvm::ConstantInt::get(llvm::IntegerType::get(ctx, 32),
llvm::APInt(32, metadata.maxntidx));
llvm::Metadata *md_args[] = {llvm::ValueAsMetadata::get(func),
llvm::MDString::get(ctx, "maxntidx"),
@@ -47,33 +48,34 @@ void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
->addOperand(llvm::MDNode::get(ctx, md_args));
}
if (metadata.is_kernel) {
llvm::Metadata *md_args[] = {
if (metadata.isKernel) {
llvm::Metadata *mdArgs[] = {
llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"),
llvm::ValueAsMetadata::get(
llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1))};
module->getOrInsertNamedMetadata("nvvm.annotations")
->addOperand(llvm::MDNode::get(ctx, md_args));
->addOperand(llvm::MDNode::get(ctx, mdArgs));
}
}
void extractNVVMMetadata(mlir::ModuleOp module,
llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
static void
extractNVVMMetadata(mlir::ModuleOp module,
llvm::DenseMap<llvm::StringRef, NVVMMetadata> *dic) {
for (auto op : module.getOps<LLVM::LLVMFuncOp>()) {
NVVMMetadata meta;
bool hasMetadata{};
// maxntid
if (op->hasAttr(NVVMMetadataField::MaxNTid)) {
auto attr = op->getAttr(NVVMMetadataField::MaxNTid);
if (op->hasAttr("nvvm.maxntid")) {
auto attr = op->getAttr("nvvm.maxntid");
meta.maxntidx = attr.dyn_cast<IntegerAttr>().getInt();
hasMetadata = true;
}
// kernel
if (op->hasAttr(NVVMMetadataField::Kernel)) {
meta.is_kernel = true;
if (op->hasAttr("nvvm.kernel")) {
meta.isKernel = true;
hasMetadata = true;
}
@@ -82,13 +84,109 @@ void extractNVVMMetadata(mlir::ModuleOp module,
}
}
static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
std::map<std::string, std::string> externLibs;
SmallVector<LLVM::LLVMFuncOp> funcs;
module.walk([&](LLVM::LLVMFuncOp func) {
if (func.isExternal())
funcs.push_back(func);
});
for (auto &func : funcs) {
if (func.getOperation()->hasAttr("libname")) {
auto name =
func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
auto path =
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
if (name) {
std::string libName = name.str();
externLibs[libName] = path.str();
}
}
}
if (module.getOperation()->hasAttr("triton_gpu.externs")) {
auto dict = module.getOperation()
->getAttr("triton_gpu.externs")
.dyn_cast<DictionaryAttr>();
for (auto &attr : dict) {
externLibs[attr.getName().strref().trim().str()] =
attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
}
}
if (!funcs.empty()) {
// When using the Math Dialect, it is possible that some ops (e.g., log) are
// lowered to a function call. In this case, we need to link libdevice
// using its default path:
// [triton root dir]/python/triton/language/libdevice.10.bc
// TODO(Keren): handle external linkage other than libdevice?
namespace fs = std::filesystem;
static const std::string libdevice = "libdevice";
static const std::filesystem::path path = std::filesystem::path(__FILE__)
.parent_path()
.parent_path()
.parent_path()
.parent_path() /
"python" / "triton" / "language" /
"libdevice.10.bc";
externLibs.try_emplace(libdevice, path.string());
}
return externLibs;
}
static void linkLibdevice(llvm::Module &module) {
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
// this will enable fast math path in libdevice
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
// sqrt.approx.ftz.f32
auto &ctx = module.getContext();
llvm::Type *i32 = llvm::Type::getInt32Ty(ctx);
llvm::Metadata *mdFour =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(i32, 4));
llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
llvm::Metadata *mdOne =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(i32, 1));
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
module.addModuleFlag(reflect);
}
static bool linkExternLib(llvm::Module &module, llvm::StringRef name,
llvm::StringRef path) {
llvm::SMDiagnostic err;
auto &ctx = module.getContext();
auto extMod = llvm::parseIRFile(path, err, ctx);
if (!extMod) {
llvm::errs() << "Failed to load " << path;
return true;
}
extMod->setTargetTriple(module.getTargetTriple());
extMod->setDataLayout(module.getDataLayout());
if (llvm::Linker::linkModules(module, std::move(extMod),
llvm::Linker::Flags::LinkOnlyNeeded)) {
llvm::errs() << "Failed to link " << path;
return true;
}
if (name == "libdevice") {
linkLibdevice(module);
} else {
assert(false && "unknown extern lib: ");
}
return false;
}
std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
auto context = module->getContext();
DialectRegistry registry;
mlir::registerLLVMDialectTranslation(registry);
mlir::registerNVVMDialectTranslation(registry);
context->appendDialectRegistry(registry);
module->getContext()->appendDialectRegistry(registry);
llvm::DenseMap<llvm::StringRef, NVVMMetadata> nvvmMetadata;
extractNVVMMetadata(module, &nvvmMetadata);
@@ -99,6 +197,20 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
return nullptr;
}
// Link external libraries before perform optimizations
// Note from libdevice users guide:
// https://docs.nvidia.com/cuda/libdevice-users-guide/basic-usage.html
// The standard process for linking with libdevice is to first link it with
// the target module, then run the standard LLVM optimization and code
// generation passes. This allows the optimizers to inline and perform
// analyses on the used library functions, and eliminate any used functions as
// dead code.
auto externLibs = getExternLibs(module);
for (auto &lib : externLibs) {
if (linkExternLib(*llvmModule, lib.first, lib.second))
return nullptr;
}
auto optPipeline = mlir::makeOptimizingTransformer(
/*optLevel=*/3, /*sizeLevel=*/0,
/*targetMachine=*/nullptr);
@@ -146,49 +258,12 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
return nullptr;
}
std::map<std::string, std::string> externLibs;
SmallVector<LLVM::LLVMFuncOp> funcs;
module.walk([&](LLVM::LLVMFuncOp func) {
if (func.isExternal())
funcs.push_back(func);
});
for (auto &func : funcs) {
if (func.getOperation()->hasAttr("libname")) {
auto name =
func.getOperation()->getAttr("libname").dyn_cast<StringAttr>();
auto path =
func.getOperation()->getAttr("libpath").dyn_cast<StringAttr>();
if (name) {
std::string lib_name = name.str();
externLibs[lib_name] = path.str();
}
}
}
if (module.getOperation()->hasAttr("triton_gpu.externs")) {
auto dict = module.getOperation()
->getAttr("triton_gpu.externs")
.dyn_cast<DictionaryAttr>();
for (auto &attr : dict) {
externLibs[attr.getName().strref().trim().str()] =
attr.getValue().dyn_cast<StringAttr>().strref().trim().str();
}
}
auto llvmir = translateLLVMToLLVMIR(llvmContext, module);
if (!llvmir) {
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module);
if (!llvmIR) {
llvm::errs() << "Translate to LLVM IR failed";
return nullptr;
}
llvm::SMDiagnostic err;
for (auto &lib : externLibs) {
if (linkExternLib(*llvmir, lib.second))
return nullptr;
}
return llvmir;
return llvmIR;
}
void addExternalLibs(mlir::ModuleOp &module,
@@ -208,29 +283,6 @@ void addExternalLibs(mlir::ModuleOp &module,
DictionaryAttr dict = DictionaryAttr::get(module->getContext(), attrs);
module.getOperation()->setAttr("triton_gpu.externs", dict);
return;
}
bool linkExternLib(llvm::Module &module, llvm::StringRef path) {
llvm::SMDiagnostic err;
auto &ctx = module.getContext();
auto extMod = llvm::parseIRFile(path, err, ctx);
if (!extMod) {
llvm::errs() << "Failed to load " << path;
return true;
}
extMod->setTargetTriple(module.getTargetTriple());
extMod->setDataLayout(module.getDataLayout());
if (llvm::Linker::linkModules(module, std::move(extMod),
llvm::Linker::Flags::LinkOnlyNeeded)) {
llvm::errs() << "Failed to link " << path;
return true;
}
return false;
}
} // namespace triton

View File

@@ -8,7 +8,6 @@
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include <filesystem>
namespace triton {
@@ -31,68 +30,29 @@ static bool findAndReplace(std::string &str, const std::string &begin,
return true;
}
static void linkExternal(llvm::Module &module) {
bool hasExternal = false;
for (auto &func : module) {
if (func.hasExternalLinkage()) {
hasExternal = true;
break;
}
}
if (hasExternal) {
namespace fs = std::filesystem;
// [triton root dir]/python/triton/language/libdevice.10.bc
static const fs::path libdevice = fs::path(__FILE__)
.parent_path()
.parent_path()
.parent_path()
.parent_path() /
"python" / "triton" / "language" /
"libdevice.10.bc";
if (mlir::triton::linkExternLib(module, libdevice.string()))
llvm::errs() << "link failed for: " << libdevice.string();
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
// this will enable fast math path in libdevice
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
// sqrt.approx.ftz.f32
auto &ctx = module.getContext();
llvm::Type *I32 = llvm::Type::getInt32Ty(ctx);
llvm::Metadata *mdFour =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
llvm::Metadata *mdName = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
llvm::Metadata *mdOne =
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
llvm::MDNode *reflect = llvm::MDNode::get(ctx, {mdFour, mdName, mdOne});
module.addModuleFlag(reflect);
}
}
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
linkExternal(module);
// LLVM version in use may not officially support target hardware
int maxNNVMCC = 75;
// LLVM version in use may not officially support target hardware.
// Supported versions for LLVM 14 are here:
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/include/clang/Basic/BuiltinsNVPTX.def
int maxPTX = std::min(75, version);
int maxCC = std::min(86, cc);
// options
auto options = llvm::cl::getRegisteredOptions();
auto *shortPtr =
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(shortPtr);
shortPtr->setValue(true);
// compute capability
std::string sm = "sm_" + std::to_string(cc);
std::string sm = "sm_" + std::to_string(maxCC);
// max PTX version
int ptxMajor = version / 10;
int ptxMinor = version % 10;
int ptxMajor = maxPTX / 10;
int ptxMinor = maxPTX % 10;
// create
llvm::SmallVector<char, 0> buffer;
std::string triple = "nvptx64-nvidia-cuda";
std::string proc = "sm_" + std::to_string(std::min(cc, maxNNVMCC));
std::string proc = "sm_" + std::to_string(maxCC);
std::string layout = "";
std::string features = "";
// std::string features = "+ptx" + std::to_string(std::min(ptx,
// max_nvvm_ptx));
// std::string features = "+ptx" + std::to_string(maxPTX);
initLLVM();
// verify and store llvm
llvm::legacy::PassManager pm;

View File

@@ -15,5 +15,5 @@ def kernel(X, stride_xm,
tl.store(Zs, tl.load(Xs))
ret = triton.compile(kernel, "*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
ret = triton.compile(kernel, signature="*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}, output="ttgir")
print(ret)

View File

@@ -24,10 +24,11 @@ def get_build_type():
return "Debug"
elif check_env_flag("REL_WITH_DEB_INFO"):
return "RelWithDebInfo"
elif check_env_flag("TRITON_REL_BUILD_WITH_ASSERTS"):
return "TritonRelBuildWithAsserts"
else:
return "RelWithDebInfo"
# TODO: change to release when stable enough
#return "Release"
return "TritonRelBuildWithAsserts"
# --- third party packages -----
@@ -140,10 +141,10 @@ class CMakeBuild(build_ext):
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DTRITON_BUILD_TUTORIALS=OFF",
"-DTRITON_BUILD_PYTHON_MODULE=ON",
# '-DPYTHON_EXECUTABLE=' + sys.executable,
'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
"-DPython3_EXECUTABLE:FILEPATH=" + sys.executable,
"-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON",
"-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
"-DLLVM_EXTERNAL_LIT=" + lit_dir
"-DLLVM_EXTERNAL_LIT=" + lit_dir,
] + thirdparty_cmake_args
# configuration
@@ -172,7 +173,7 @@ setup(
author_email="phil@openai.com",
description="A language and compiler for custom Deep Learning operations",
long_description="",
packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/ops", "triton/runtime", "triton/ops/blocksparse"],
packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/impl", "triton/ops", "triton/runtime", "triton/ops/blocksparse"],
install_requires=[
"cmake",
"filelock",

View File

@@ -13,15 +13,15 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/Target/PTX/PTXTranslation.h"
#include "triton/tools/sys/getenv.hpp"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
@@ -1267,12 +1267,11 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<::mlir::LLVM::UndefOp>(loc, type);
})
// Force GPU barrier
.def("create_barrier",
[](mlir::OpBuilder &self) {
auto loc = self.getUnknownLoc();
self.create<mlir::gpu::BarrierOp>(loc);
});
// Force GPU barrier
.def("create_barrier", [](mlir::OpBuilder &self) {
auto loc = self.getUnknownLoc();
self.create<mlir::gpu::BarrierOp>(loc);
});
py::class_<mlir::PassManager>(m, "pass_manager")
.def(py::init<mlir::MLIRContext *>())

View File

@@ -6,7 +6,7 @@ import torch
import triton
import triton.language as tl
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops, set_gpu_clock
from triton.testing import get_dram_gbps, get_max_tensorcore_tflops
DEVICE_NAME = 'v100'
@@ -87,8 +87,8 @@ def test_matmul(M, N, K, dtype_str):
dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str]
torch.manual_seed(0)
ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str]
ref_sm_clock = sm_clocks[DEVICE_NAME]
cur_sm_clock = nvsmi(['clocks.current.sm'])[0]
ref_sm_clock = sm_clocks[DEVICE_NAME]
max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3)
assert abs(cur_sm_clock - ref_sm_clock) < 10, f'GPU SMs must run at {ref_sm_clock} MHz'
if dtype == torch.int8:

View File

@@ -1,5 +1,6 @@
# flake8: noqa: F821,F841
import itertools
import os
import re
from typing import Optional, Union
@@ -12,15 +13,13 @@ import triton
import triton._C.libtriton.triton as _triton
import triton.language as tl
from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret
from tests.libdevice_testutil import system_libdevice_path
int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + uint_dtypes + float_dtypes
# TODO: handle bfloat16
dtypes_with_bfloat16 = dtypes # + ['bfloat16']
torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes # + ['bfloat16']
dtypes_with_bfloat16 = dtypes + ['bfloat16']
torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
def _bitwidth(dtype: str) -> int:
@@ -250,7 +249,7 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool:
@pytest.mark.parametrize("dtype_x, dtype_y, op", [
(dtype_x, dtype_y, op)
for op in ['+', '-', '*', '/'] # , '%'] #TODO: handle remainder
for op in ['+', '-', '*', '/', '%']
for dtype_x in dtypes_with_bfloat16
for dtype_y in dtypes_with_bfloat16
])
@@ -448,9 +447,9 @@ def test_where_broadcast():
z = np.where(0, x, 0)
assert (z == to_numpy(z_tri)).all()
# # ---------------
# # test unary ops
# # ---------------
# ---------------
# test unary ops
# ---------------
@pytest.mark.parametrize("dtype_x, expr", [
@@ -461,9 +460,9 @@ def test_where_broadcast():
def test_unary_op(dtype_x, expr, device='cuda'):
_test_unary(dtype_x, expr, device=device)
# # ----------------
# # test math ops
# # ----------------
# ----------------
# test math ops
# ----------------
@pytest.mark.parametrize("expr", [
@@ -473,9 +472,9 @@ def test_math_op(expr, device='cuda'):
_test_unary('float32', f'tl.{expr}(x)', f'np.{expr}(x) ', device=device)
# # ----------------
# # test indexing
# # ----------------
# ----------------
# test indexing
# ----------------
def make_ptr_str(name, shape):
@@ -493,10 +492,8 @@ def make_ptr_str(name, shape):
@pytest.mark.parametrize("expr, dtype_str", [
(f'x[{s}]', d)
for s in ['None, :', ':, None',
# TODO: 3D
# 'None, :, :',
# ':, :, None'
]
'None, :, :',
':, :, None']
for d in ['int32', 'uint32', 'uint16']
])
def test_index1d(expr, dtype_str, device='cuda'):
@@ -551,9 +548,9 @@ def test_index1d(expr, dtype_str, device='cuda'):
catch_compilation_error(kernel_rank_mismatch)
# # ---------------
# # test tuples
# # ---------------
# ---------------
# test tuples
# ---------------
@triton.jit
@@ -609,6 +606,10 @@ def test_tuples():
]
for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']]))
def test_atomic_rmw(op, dtype_x_str, mode, device='cuda'):
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
if dtype_x_str == 'float16':
pytest.skip("Only test atomic float16 ops on devices with sm >= 70")
n_programs = 5
# triton kernel
@@ -709,9 +710,9 @@ def test_atomic_cas():
triton.testing.assert_almost_equal(data, ref)
# # ---------------
# # test cast
# # ---------------
# ---------------
# test cast
# ---------------
@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
@@ -719,11 +720,9 @@ def test_atomic_cas():
for dtype_x in dtypes
for dtype_z in dtypes
] + [
# TODO:
# ('float32', 'bfloat16', False),
# ('bfloat16', 'float32', False),
('float32', 'bfloat16', False),
('bfloat16', 'float32', False),
('float32', 'int32', True),
# TODO:
('float32', 'int1', False),
] + [
(f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
@@ -731,6 +730,10 @@ def test_atomic_cas():
(f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
# bfloat16 on cc < 80 will not be tested
check_type_supported(dtype_x)
check_type_supported(dtype_z)
# This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
x0 = 43 if dtype_x in int_dtypes else 43.5
if dtype_x in float_dtypes and dtype_z == 'int1':
@@ -873,9 +876,9 @@ def test_f16_to_f8_rounding():
), f"f16_input[mismatch]={f16_input[mismatch]} f16_output[mismatch]={f16_output[mismatch]} abs_error[mismatch]={abs_error[mismatch]} min_error[mismatch]={min_error[mismatch]}"
# # ---------------
# # test reduce
# # ---------------
# ---------------
# test reduce
# ---------------
def get_reduced_dtype(dtype_str, op):
@@ -888,7 +891,6 @@ def get_reduced_dtype(dtype_str, op):
return dtype_str
# TODO: [Qingyi] Fix argmin / argmax
@pytest.mark.parametrize("op, dtype_str, shape",
[(op, dtype, shape)
for op in ['min', 'max', 'sum']
@@ -951,7 +953,7 @@ reduce_configs1 = [
# exceeds the limit of 99KB
reduce2d_shapes = [(2, 32), (4, 32), (4, 128)]
# TODO: fix and uncomment
#, (32, 64), (64, 128)]
# , (32, 64), (64, 128)]
if 'V100' in torch.cuda.get_device_name(0):
reduce2d_shapes += [(128, 256) and (32, 1024)]
@@ -966,6 +968,8 @@ reduce_configs2 = [
@pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2)
def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
check_type_supported(dtype_str) # bfloat16 on cc < 80 will not be tested
# triton kernel
@triton.jit
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
@@ -1017,9 +1021,9 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
else:
np.testing.assert_equal(z_ref, z_tri)
# # ---------------
# # test permute
# # ---------------
# ---------------
# test permute
# ---------------
@pytest.mark.parametrize("dtype_str, shape, perm",
@@ -1066,31 +1070,43 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
# # ---------------
# # test dot
# # ---------------
# ---------------
# test dot
# ---------------
@pytest.mark.parametrize("M, N, K, epilogue, allow_tf32, dtype",
[(*shape, epilogue, allow_tf32, dtype)
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
[(*shape, 4, False, False, epilogue, allow_tf32, dtype)
for shape in [(64, 64, 64)]
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
for allow_tf32 in [True, False]
for dtype in ['float16', 'float32']
if not (allow_tf32 and (dtype in ['float16']))])
def test_dot(M, N, K, epilogue, allow_tf32, dtype, device='cuda'):
if not (allow_tf32 and (dtype in ['float16']))] +
[(*shape_nw, col_a, col_b, 'none', allow_tf32, dtype)
for shape_nw in [[128, 256, 32, 8],
[128, 16, 32, 4],
[32, 128, 64, 4],
[128, 128, 64, 4],
[64, 128, 128, 4],
[32, 128, 64, 2],
[128, 128, 64, 2],
[64, 128, 128, 4]]
for allow_tf32 in [True]
for col_a in [True, False]
for col_b in [True, False]
for dtype in ['int8', 'float16', 'float32']])
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, device='cuda'):
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
if capability[0] < 8:
if dtype == 'int8':
pytest.skip("Only test int8 on devices with sm >= 80")
elif dtype == 'float32' and allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
num_warps = 4
trans_a, trans_b = False, False
# triton kernel
@triton.jit
def kernel(X, stride_xm, stride_xk,
@@ -1101,7 +1117,7 @@ def test_dot(M, N, K, epilogue, allow_tf32, dtype, device='cuda'):
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
ALLOW_TF32: tl.constexpr,
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
COL_A: tl.constexpr, COL_B: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
off_l = tl.arange(0, BLOCK_N)
@@ -1112,8 +1128,6 @@ def test_dot(M, N, K, epilogue, allow_tf32, dtype, device='cuda'):
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
x = tl.load(Xs)
y = tl.load(Ys)
x = tl.trans(x) if TRANS_A else x
y = tl.trans(y) if TRANS_B else y
z = tl.dot(x, y, allow_tf32=ALLOW_TF32)
if ADD_MATRIX:
z += tl.load(Zs)
@@ -1130,17 +1144,24 @@ def test_dot(M, N, K, epilogue, allow_tf32, dtype, device='cuda'):
den = tl.sum(num, 1)
z = num / den[:, None]
if CHAIN_DOT:
# tl.store(Zs, z)
# tl.debug_barrier()
w = tl.load(Ws)
z = tl.dot(z.to(w.dtype), w)
tl.store(Zs, z)
# input
rs = RandomState(17)
x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
if allow_tf32:
if col_a:
x = numpy_random((K, M), dtype_str=dtype, rs=rs).T
else:
x = numpy_random((M, K), dtype_str=dtype, rs=rs)
if col_b:
y = numpy_random((N, K), dtype_str=dtype, rs=rs).T
else:
y = numpy_random((K, N), dtype_str=dtype, rs=rs)
w = numpy_random((N, N), dtype_str=dtype, rs=rs)
if 'int' not in dtype:
x *= .1
y *= .1
if dtype == 'float32' and allow_tf32:
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
@@ -1148,7 +1169,11 @@ def test_dot(M, N, K, epilogue, allow_tf32, dtype, device='cuda'):
y_tri = to_triton(y, device=device)
w_tri = to_triton(w, device=device)
# triton result
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
if dtype == 'int8':
z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs)
else:
z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
z_tri = to_triton(z, device=device)
if epilogue == 'trans':
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
@@ -1156,7 +1181,7 @@ def test_dot(M, N, K, epilogue, allow_tf32, dtype, device='cuda'):
y_tri, y_tri.stride(0), y_tri.stride(1),
w_tri, w_tri.stride(0), w_tri.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1),
TRANS_A=trans_a, TRANS_B=trans_b,
COL_A=col_a, COL_B=col_b,
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
ADD_MATRIX=epilogue == 'add-matrix',
ADD_ROWS=epilogue == 'add-rows',
@@ -1166,9 +1191,12 @@ def test_dot(M, N, K, epilogue, allow_tf32, dtype, device='cuda'):
ALLOW_TF32=allow_tf32,
num_warps=num_warps)
# torch result
x_ref = x.T if trans_a else x
y_ref = y.T if trans_b else y
z_ref = np.matmul(x_ref, y_ref)
if dtype == 'int8':
z_ref = np.matmul(x.astype(np.float32),
y.astype(np.float32())).astype(np.int32)
else:
z_ref = np.matmul(x, y)
if epilogue == 'add-matrix':
z_ref += z
if epilogue == 'add-rows':
@@ -1192,7 +1220,7 @@ def test_dot(M, N, K, epilogue, allow_tf32, dtype, device='cuda'):
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if allow_tf32:
if dtype == 'float32' and allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
elif dtype == 'float32' and allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
@@ -1200,23 +1228,23 @@ def test_dot(M, N, K, epilogue, allow_tf32, dtype, device='cuda'):
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
# def test_dot_without_load():
# @triton.jit
# def kernel(out):
# pid = tl.program_id(axis=0)
# a = tl.zeros((32, 32), tl.float32)
# b = tl.zeros((32, 32), tl.float32)
# c = tl.zeros((32, 32), tl.float32)
# c = tl.dot(a, b)
# pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
# tl.store(pout, c)
def test_dot_without_load():
@triton.jit
def kernel(out):
pid = tl.program_id(axis=0)
a = tl.zeros((32, 32), tl.float32)
b = tl.zeros((32, 32), tl.float32)
c = tl.zeros((32, 32), tl.float32)
c = tl.dot(a, b)
pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
tl.store(pout, c)
# out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
# kernel[(1,)](out)
out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
kernel[(1,)](out)
# # ---------------
# # test arange
# # ---------------
# ---------------
# test arange
# ---------------
@pytest.mark.parametrize("start", [0, 1, 7, 16])
@@ -1239,7 +1267,7 @@ def test_arange(start, device='cuda'):
# ---------------
@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [1, 2, 3, 4]])
@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [0, 1, 2, 3, 4]])
def test_masked_load(dtype_str, size, size_diff, device='cuda'):
dtype = getattr(torch, dtype_str)
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
@@ -1258,68 +1286,68 @@ def test_masked_load(dtype_str, size, size_diff, device='cuda'):
def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
in_offsets = tl.arange(0, out_size)
# Load inputs.
x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1)
x = GENERATE_TEST_HERE
# Store output
output_offsets = tl.arange(0, out_size)
tl.store(out_ptr + output_offsets, x)
_kernel[(1,)](input, output, input_size, output_size)
mask_str = "mask=in_offsets < in_size, other=1" if size_diff > 0 else "None"
kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"})
kernel[(1,)](input, output, input_size, output_size)
reference_out = input
reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device)))
triton.testing.allclose(output, reference_out)
# # 'bfloat16': torch.bfloat16,
# # Testing masked loads with an intermate copy to shared memory run.
# Testing masked loads with an intermate copy to shared memory run.
# @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
# def test_masked_load_shared_memory(dtype, device='cuda'):
# check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'):
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
# M = 32
# N = 32
# K = 16
M = 32
N = 32
K = 16
# in1 = torch.rand((M, K), dtype=dtype, device=device)
# in2 = torch.rand((K, N), dtype=dtype, device=device)
# out = torch.zeros((M, N), dtype=dtype, device=device)
in1 = torch.rand((M, K), dtype=dtype, device=device)
in2 = torch.rand((K, N), dtype=dtype, device=device)
out = torch.zeros((M, N), dtype=dtype, device=device)
# @triton.jit
# def _kernel(in1_ptr, in2_ptr, output_ptr,
# in_stride, in2_stride, out_stride,
# in_numel, in2_numel, out_numel,
# M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
@triton.jit
def _kernel(in1_ptr, in2_ptr, output_ptr,
in_stride, in2_stride, out_stride,
in_numel, in2_numel, out_numel,
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
# M_offsets = tl.arange(0, M)
# N_offsets = tl.arange(0, N)
# K_offsets = tl.arange(0, K)
M_offsets = tl.arange(0, M)
N_offsets = tl.arange(0, N)
K_offsets = tl.arange(0, K)
# in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :]
# in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :]
in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :]
# # Load inputs.
# x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
# w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel)
# Load inputs.
x = tl.load(in1_ptr + in_offsets, mask=in_offsets < in_numel)
w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < in2_numel)
# # Without a dot product the memory doesn't get promoted to shared.
# o = tl.dot(x, w)
# Without a dot product the memory doesn't get promoted to shared.
o = tl.dot(x, w)
# # Store output
# output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
# tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
# Store output
output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :]
tl.store(output_ptr + output_offsets, o, mask=output_offsets < in2_numel)
# pgm = _kernel[(1,)](in1, in2, out,
# in1.stride()[0],
# in2.stride()[0],
# out.stride()[0],
# in1.numel(),
# in2.numel(),
# out.numel(),
# M=M, N=N, K=K)
pgm = _kernel[(1,)](in1, in2, out,
in1.stride()[0],
in2.stride()[0],
out.stride()[0],
in1.numel(),
in2.numel(),
out.numel(),
M=M, N=N, K=K)
# reference_out = torch.matmul(in1, in2)
# triton.testing.allclose(out, reference_out)
reference_out = torch.matmul(in1, in2)
triton.testing.allclose(out, reference_out)
@pytest.mark.parametrize("cache", ["", ".ca", ".cg"])
@@ -1363,26 +1391,27 @@ def test_vectorization(N):
else:
assert "ld.global.b32" in ptx
# triton.testing.assert_almost_equal(dst, src[:N])
# # ---------------
# # test store
# # ---------------
# # ---------------
# # test if
# # ---------------
# ---------------
# test store
# ---------------
# # ---------------
# # test for
# # ---------------
# ---------------
# test if
# ---------------
# # ---------------
# # test while
# # ---------------
# ---------------
# test for
# ---------------
# # ---------------
# # test default
# # ---------------
# # TODO: can't be local to test_default
# ---------------
# test while
# ---------------
# ---------------
# test default
# ---------------
# TODO: can't be local to test_default
@triton.jit
@@ -1404,9 +1433,9 @@ def test_default():
assert ret0.item() == 10
assert ret1.item() == value
# # ---------------
# # test noop
# # ----------------
# ---------------
# test noop
# ----------------
def test_noop(device='cuda'):
@@ -1440,9 +1469,9 @@ def test_value_specialization(value: int, value_type: str, device='cuda') -> Non
JITFunction.cache_hook = None
assert spec_type == value_type
# # --------------------
# # value specialization
# # --------------------
# --------------------
# value specialization
# --------------------
@pytest.mark.parametrize(
@@ -1464,9 +1493,9 @@ def test_value_specialization_overflow(value: int, overflow: bool, device='cuda'
kernel[(1, )](value, x)
# # ----------------
# # test constexpr
# # ----------------
# ----------------
# test constexpr
# ----------------
@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>'])
@pytest.mark.parametrize("is_lhs_constexpr", [False, True])
@@ -1517,9 +1546,9 @@ def test_constexpr_scalar_shape():
kernel[(1,)](x_tri, 32)
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
# # -------------
# # test call
# # -------------
# -------------
# test call
# -------------
@triton.jit
@@ -1553,9 +1582,9 @@ def test_call():
ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
np.testing.assert_equal(to_numpy(rand_val_tri), ans)
# # -------------
# # test if
# # -------------
# -------------
# test if
# -------------
def test_if():
@@ -1589,9 +1618,23 @@ def test_num_warps_pow2():
_kernel[(1,)](dst=dst, num_warps=2)
_kernel[(1,)](dst=dst, num_warps=4)
# # -------------
# # test extern
# # -------------
# -------------
# test extern
# -------------
def system_libdevice_path() -> str:
_SYSTEM_LIBDEVICE_SEARCH_PATHS = [
'/usr/lib/cuda/nvvm/libdevice/libdevice.10.bc',
'/usr/local/cuda/nvvm/libdevice/libdevice.10.bc',
]
SYSTEM_LIBDEVICE_PATH: Optional[str] = None
for _p in _SYSTEM_LIBDEVICE_SEARCH_PATHS:
if os.path.exists(_p):
SYSTEM_LIBDEVICE_PATH = _p
assert SYSTEM_LIBDEVICE_PATH is not None, \
"Could not find libdevice.10.bc path"
return SYSTEM_LIBDEVICE_PATH
@pytest.mark.parametrize("dtype_str, expr, lib_path",
@@ -1663,3 +1706,95 @@ def test_libdevice_scalar(dtype_str, expr, lib_path):
kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path})
# compare
np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01)
# -----------------------
# test layout conversions
# -----------------------
# TODO: backend hsould be tested separately
class MmaLayout:
def __init__(self, version, warps_per_cta):
self.version = version
self.warps_per_cta = str(warps_per_cta)
def __str__(self):
return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}}}>"
class BlockedLayout:
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
self.sz_per_thread = str(size_per_thread)
self.threads_per_warp = str(threads_per_warp)
self.warps_per_cta = str(warps_per_cta)
self.order = str(order)
def __str__(self):
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
layouts = [
# MmaLayout(version=1, warps_per_cta=[1, 4]),
MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
# MmaLayout(version=1, warps_per_cta=[4, 1]),
MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]),
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
]
@pytest.mark.parametrize("shape", [(128, 128)])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("dst_layout", layouts)
def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
if str(src_layout) == str(dst_layout):
pytest.skip()
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
pytest.skip()
ir = f"""
#src = {src_layout}
#dst = {dst_layout}
""" + """
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #src>
%4 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>) -> tensor<128x1xi32, #src>
%5 = arith.muli %4, %cst : tensor<128x1xi32, #src>
%6 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>) -> tensor<1x128xi32, #src>
%7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src>
%8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src>
%9 = arith.addi %8, %7 : tensor<128x128xi32, #src>
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
tt.store %14, %13 : tensor<128x128xf16, #dst>
return
}
}
"""
x = to_triton(numpy_random(shape, dtype_str=dtype))
z = torch.empty_like(x)
# write the IR to a temporary file using mkstemp
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)
kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr())
assert torch.equal(z, x)

View File

@@ -1,12 +1,13 @@
import os
import subprocess
import sys
dir_path = os.path.dirname(os.path.realpath(__file__))
printf_path = os.path.join(dir_path, "printf_helper.py")
def test_printf():
proc = subprocess.Popen(["python", printf_path], stdout=subprocess.PIPE, shell=False)
proc = subprocess.Popen([sys.executable, printf_path], stdout=subprocess.PIPE, shell=False)
(outs, err) = proc.communicate()
outs = outs.split()
new_lines = set()

View File

@@ -3,14 +3,14 @@ import torch
import triton
# TODO: float32 fails
@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"])
@pytest.mark.parametrize("TRANS_B", [False, True])
@pytest.mark.parametrize("TRANS_A", [False, True])
@pytest.mark.parametrize("TRANS_B", [False, True])
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
# TODO: float32 fails
@pytest.mark.parametrize("DTYPE", [torch.float16])
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=256, K=384):
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=256):
seed = 0
torch.manual_seed(seed)
is_sdd = MODE == "sdd"
@@ -39,8 +39,8 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=256, K=
dc_ref = do_mask(dc_ref) if is_sdd else dc_ref
a_ref = do_mask(a_ref) if is_dsd else a_ref
b_ref = do_mask(b_ref) if is_dds else b_ref
a_ref.requires_grad_().retain_grad()
b_ref.requires_grad_().retain_grad()
a_ref.retain_grad()
b_ref.retain_grad()
c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref,
b_ref.transpose(2, 3) if TRANS_B else b_ref)
c_ref.backward(dc_ref)
@@ -51,8 +51,8 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=256, K=
dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri
a_tri = do_sparsify(a_tri) if is_dsd else a_tri
b_tri = do_sparsify(b_tri) if is_dds else b_tri
a_tri.requires_grad_().retain_grad()
b_tri.requires_grad_().retain_grad()
a_tri.retain_grad()
b_tri.retain_grad()
op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
c_tri = triton.testing.catch_oor(lambda: op(a_tri, b_tri), pytest)
triton.testing.catch_oor(lambda: c_tri.backward(dc_tri), pytest)
@@ -116,7 +116,7 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
@pytest.mark.parametrize("block", [16, 32, 64])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_attention_fwd_bwd(
block,
dtype,
@@ -126,6 +126,10 @@ def test_attention_fwd_bwd(
batch_size=2,
n_heads=2,
):
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
# inputs
qkv_shape = (batch_size, n_heads, n_ctx, 64)
qkvs = [

View File

@@ -0,0 +1,38 @@
import pytest
import torch
import triton
@pytest.mark.parametrize("M, N, dtype, mode",
[
(M, N, dtype, mode) for M in [1024, 821]
for N in [512, 857, 1871, 2089, 8573, 31000]
for dtype in ['float16', 'float32']
for mode in ['forward', 'backward']
]
)
def test_op(M, N, dtype, mode):
capability = torch.cuda.get_device_capability()
if capability[0] < 8 and dtype == "bfloat16":
pytest.skip("Only test bfloat16 on devices with sm >= 80")
dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype]
# create inputs
x = torch.randn(M, N, dtype=dtype, device='cuda', requires_grad=True)
idx = 4 + torch.ones(M, dtype=torch.int64, device='cuda')
# forward pass
tt_y = triton.ops.cross_entropy(x, idx)
th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx)
if mode == 'forward':
triton.testing.assert_almost_equal(th_y, tt_y)
# backward pass
elif mode == 'backward':
dy = torch.randn_like(tt_y)
# triton backward
tt_y.backward(dy)
tt_dx = x.grad.clone()
# torch backward
x.grad.zero_()
th_y.backward(dy)
th_dx = x.grad.clone()
triton.testing.assert_almost_equal(th_dx, tt_dx)

View File

@@ -4,7 +4,6 @@ import pytest
import torch
import triton
import triton._C.libtriton.triton as _triton
@pytest.mark.parametrize(
@@ -72,10 +71,8 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
pytest.skip("Only test tl.dot() on devices with sm >= 70")
if capability[0] < 8 and DTYPE == "bfloat16":
pytest.skip("Only test bfloat16 on devices with sm >= 80")
#if DTYPE == "bfloat16" and SPLIT_K != 1:
# pytest.skip("bfloat16 matmuls don't allow split_k for now")
if DTYPE == "bfloat16":
pytest.skip("bfloat16 matmuls doesn't support for now")
if DTYPE == "bfloat16" and SPLIT_K != 1:
pytest.skip("bfloat16 matmuls don't allow split_k for now")
torch.manual_seed(0)
# nuke kernel decorators -- will set meta-parameters manually
kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}

View File

@@ -0,0 +1,206 @@
import multiprocessing
import os
import re
import shutil
from collections import namedtuple
import pytest
import torch
import triton
import triton.language as tl
from triton.runtime.jit import JITFunction
tmpdir = ".tmp"
@triton.jit
def function_1(i):
i = i + 1
i = function_2(i)
return i
@triton.jit
def function_2(i):
i = i + 1
return i
@triton.jit
def kernel(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)
@triton.jit(do_not_specialize=["i"])
def kernel_nospec(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)
def apply_src_change(target, old, new):
kernel.hash = None
function_1.hash = None
function_2.hash = None
function_1.src = function_1.src.replace(old, new)
target.src = target.src.replace(old, new)
ret = target.cache_key
target.src = target.src.replace(new, old)
return ret
def test_nochange():
baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 1')
assert baseline == updated
def test_toplevel_change():
baseline = kernel.cache_key
updated = apply_src_change(kernel, 'i + 1', 'i + 2')
assert baseline != updated
def test_nested1_change():
baseline = kernel.cache_key
updated = apply_src_change(function_1, 'i + 1', 'i + 2')
assert baseline != updated
def reset_tmp_dir():
os.environ["TRITON_CACHE_DIR"] = tmpdir
if os.path.exists(tmpdir):
shutil.rmtree(tmpdir)
def test_reuse():
counter = 0
def inc_counter(*args, **kwargs):
nonlocal counter
counter += 1
JITFunction.cache_hook = inc_counter
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda')
for i in range(10):
kernel[(1,)](x, 1, BLOCK=1024)
assert counter == 1
@pytest.mark.parametrize('mode', ['enable', 'disable'])
def test_specialize(mode):
counter = 0
def inc_counter(*args, **kwargs):
nonlocal counter
counter += 1
JITFunction.cache_hook = inc_counter
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda')
function = {'enable': kernel, 'disable': kernel_nospec}[mode]
target = {'enable': 3, 'disable': 1}[mode]
for i in [1, 2, 4, 8, 16, 32]:
function[(1,)](x, i, BLOCK=512)
assert counter == target
@pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (1, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'),
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
@triton.jit
def kernel(VALUE, X):
pass
cache_str = None
def get_cache_str(*args, **kwargs):
nonlocal cache_str
cache_str = kwargs["repr"]
triton.JITFunction.cache_hook = get_cache_str
reset_tmp_dir()
x = torch.tensor([3.14159], device='cuda')
kernel[(1, )](value, x)
triton.JITFunction.cache_hook = None
cache_str_match = re.match(r".*VALUE: (\w+).*", cache_str)
spec_type = None if cache_str_match is None else cache_str_match.group(1)
assert spec_type == value_type
def test_constexpr_not_callable() -> None:
@triton.jit
def kernel(X, c: tl.constexpr):
tl.store(X, 2)
x = torch.empty(1, dtype=torch.int32, device='cuda')
error = False
try:
kernel[(1, )](x, c="str")
except BaseException:
error = True
assert error is False
# try and catch
try:
kernel[(1, )](x, c=tl.abs)
except BaseException:
error = True
assert error is True
def test_jit_warmup_cache() -> None:
@triton.jit
def kernel_add(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx,
tl.load(a + idx) + tl.load(b + idx))
args = [
torch.randn(32, dtype=torch.float32, device="cuda"),
torch.randn(32, dtype=torch.float32, device="cuda"),
torch.randn(32, dtype=torch.float32, device="cuda"),
32,
]
assert len(kernel_add.cache) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(kernel_add.cache) == 1
kernel_add.warmup(*args, grid=(1,))
assert len(kernel_add.cache) == 1
kernel_add.warmup(*args, grid=(1,))
assert len(kernel_add.cache) == 1
def test_compile_in_subproc() -> None:
@triton.jit
def kernel_sub(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx,
tl.load(a + idx) - tl.load(b + idx) * 777)
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = namedtuple("instance_descriptor", [
"divisible_by_16", "equal_to_1"])(
tuple(range(4)),
())
proc = multiprocessing.Process(
target=triton.compile,
kwargs=dict(
fn=kernel_sub,
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
device=0,
constants={3: 32},
configs=[config],
warm_cache_only=True,
cc=cc,
))
proc.start()
proc.join()
assert proc.exitcode == 0

View File

@@ -1,18 +0,0 @@
import os
from typing import Optional
_SYSTEM_LIBDEVICE_SEARCH_PATHS = [
'/usr/lib/cuda/nvvm/libdevice/libdevice.10.bc',
'/usr/local/cuda/nvvm/libdevice/libdevice.10.bc',
]
SYSTEM_LIBDEVICE_PATH: Optional[str] = None
for _p in _SYSTEM_LIBDEVICE_SEARCH_PATHS:
if os.path.exists(_p):
SYSTEM_LIBDEVICE_PATH = _p
def system_libdevice_path() -> str:
assert SYSTEM_LIBDEVICE_PATH is not None, \
"Could not find libdevice.10.bc path"
return SYSTEM_LIBDEVICE_PATH

View File

@@ -1,156 +0,0 @@
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [4, 2]}>
#shared0 = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 8 : i32} {
func public @_kernel_0d1d2d3d4d5d6d7d8d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
%c2_i32 = arith.constant 2 : i32
%c1_i32 = arith.constant 1 : i32
%c0_i32 = arith.constant 0 : index
%cst = arith.constant dense<32> : tensor<256x32xi32, #blocked0>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
%c8_i32 = arith.constant 8 : i32
%c255_i32 = arith.constant 255 : i32
%c127_i32 = arith.constant 127 : i32
%c32_i32 = arith.constant 32 : i32
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c256_i32 = arith.constant 256 : i32
%c128_i32 = arith.constant 128 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = tt.get_program_id {axis = 1 : i32} : i32
%2 = arith.addi %arg3, %c255_i32 : i32
%3 = arith.divsi %2, %c256_i32 : i32
%4 = arith.addi %arg4, %c127_i32 : i32
%5 = arith.divsi %4, %c128_i32 : i32
%6 = arith.muli %5, %c8_i32 : i32
%7 = arith.divsi %0, %6 : i32
%8 = arith.muli %7, %c8_i32 : i32
%9 = arith.subi %3, %8 : i32
%10 = arith.cmpi slt, %9, %c8_i32 : i32
%11 = select %10, %9, %c8_i32 : i32
%12 = arith.remsi %0, %11 : i32
%13 = arith.addi %8, %12 : i32
%14 = arith.remsi %0, %6 : i32
%15 = arith.divsi %14, %11 : i32
%16 = arith.muli %13, %c256_i32 : i32
%17 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%18 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%19 = tt.splat %16 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%20 = tt.splat %16 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%21 = arith.muli %15, %c128_i32 : i32
%22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%23 = tt.splat %21 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%24 = tt.splat %arg3 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%25 = tt.splat %arg3 : (i32) -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%26 = tt.splat %arg4 : (i32) -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%27 = arith.muli %1, %c32_i32 : i32
%28 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
%29 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%30 = tt.splat %27 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
%31 = tt.splat %27 : (i32) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%32 = arith.addi %19, %17 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%33 = arith.remsi %32, %24 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>
%34 = tt.splat %arg6 : (i32) -> tensor<256x1xi32, #blocked0>
%35 = arith.addi %30, %28 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
%36 = tt.expand_dims %35 {axis = 0 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>) -> tensor<1x32xi32, #blocked0>
%37 = tt.broadcast %36 : (tensor<1x32xi32, #blocked0>) -> tensor<256x32xi32, #blocked0>
%38 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<256x32x!tt.ptr<f16>, #blocked0>
%39 = arith.addi %31, %29 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%40 = tt.expand_dims %39 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<32x1xi32, #blocked1>
%41 = tt.splat %arg7 : (i32) -> tensor<32x1xi32, #blocked1>
%42 = arith.addi %23, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%43 = arith.remsi %42, %26 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%44 = tt.expand_dims %43 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi32, #blocked1>
%45 = tt.broadcast %44 : (tensor<1x128xi32, #blocked1>) -> tensor<32x128xi32, #blocked1>
%46 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #blocked1>
%47 = arith.index_cast %arg5 : i32 to index
%48 = arith.muli %arg7, %c32_i32 : i32
%49 = tt.splat %48 : (i32) -> tensor<32x128xi32, #blocked1>
%50 = tt.expand_dims %33 {axis = 1 : i32} : (tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked0}>>) -> tensor<256x1xi32, #blocked0>
%51 = arith.muli %50, %34 : tensor<256x1xi32, #blocked0>
%52 = tt.broadcast %51 : (tensor<256x1xi32, #blocked0>) -> tensor<256x32xi32, #blocked0>
%53 = arith.addi %52, %37 : tensor<256x32xi32, #blocked0>
%54 = tt.addptr %38, %53 : tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<256x32xi32, #blocked0>
%55 = arith.muli %40, %41 : tensor<32x1xi32, #blocked1>
%56 = tt.broadcast %55 : (tensor<32x1xi32, #blocked1>) -> tensor<32x128xi32, #blocked1>
%57 = arith.addi %56, %45 : tensor<32x128xi32, #blocked1>
%58 = tt.addptr %46, %57 : tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<32x128xi32, #blocked1>
%59 = arith.cmpi slt, %c0, %47 : index
%60 = triton_gpu.alloc_tensor : tensor<2x256x32xf16, #shared0>
%64 = triton_gpu.alloc_tensor : tensor<2x32x128xf16, #shared1>
%61 = tt.splat %59 : (i1) -> tensor<256x32xi1, #blocked0>
%65 = tt.splat %59 : (i1) -> tensor<32x128xi1, #blocked1>
%62 = tt.load %54, %61 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x32xf16, #blocked0>
%66 = tt.load %58, %65 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked1>
%63 = tensor.insert_slice %62 into %60[%c0_i32, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<256x32xf16, #blocked0> into tensor<2x256x32xf16, #shared0>
%67 = tensor.insert_slice %66 into %64[%c0_i32, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<32x128xf16, #blocked1> into tensor<2x32x128xf16, #shared1>
%68 = tt.addptr %54, %cst : tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<256x32xi32, #blocked0>
%69 = tt.addptr %58, %49 : tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<32x128xi32, #blocked1>
%70 = tensor.extract_slice %63[0, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<2x256x32xf16, #shared0> to tensor<256x32xf16, #shared0>
%71 = tensor.extract_slice %67[0, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<2x32x128xf16, #shared1> to tensor<32x128xf16, #shared1>
%72 = tensor.extract_slice %70[0, 0] [256, 16] [1, 1] : tensor<256x32xf16, #shared0> to tensor<256x16xf16, #shared0>
gpu.barrier
%73 = triton_gpu.convert_layout %72 : (tensor<256x16xf16, #shared0>) -> tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>
%74 = tensor.extract_slice %71[0, 0] [16, 128] [1, 1] : tensor<32x128xf16, #shared1> to tensor<16x128xf16, #shared1>
%75 = triton_gpu.convert_layout %74 : (tensor<16x128xf16, #shared1>) -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>
%76:14 = scf.for %arg9 = %c0 to %47 step %c32 iter_args(%arg10 = %cst_0, %arg11 = %54, %arg12 = %58, %arg13 = %63, %arg14 = %67, %arg15 = %70, %arg16 = %71, %arg17 = %68, %arg18 = %69, %arg19 = %c0, %arg20 = %c1_i32, %arg21 = %c1_i32, %arg22 = %73, %arg23 = %75) -> (tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<2x256x32xf16, #shared0>, tensor<2x32x128xf16, #shared1>, tensor<256x32xf16, #shared0>, tensor<32x128xf16, #shared1>, tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<32x128x!tt.ptr<f16>, #blocked1>, index, i32, i32, tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>, tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>) {
%104 = arith.addi %arg19, %c32 : index
%105 = arith.cmpi slt, %104, %47 : index
%106 = arith.remsi %arg20, %c2_i32 : i32
%107 = arith.remsi %arg21, %c2_i32 : i32
%108 = arith.index_cast %107 : i32 to index
%200 = arith.index_cast %106 : i32 to index
%109 = tt.splat %105 : (i1) -> tensor<256x32xi1, #blocked0>
%112 = tt.splat %105 : (i1) -> tensor<32x128xi1, #blocked1>
%110 = tt.load %arg17, %109 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x32xf16, #blocked0>
%113 = tt.load %arg18, %112 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked1>
%96 = tt.dot %arg22, %arg23, %arg10 {allowTF32 = true} : tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>> * tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>> -> tensor<256x128xf32, #mma>
%97 = tensor.extract_slice %arg15[0, 16] [256, 16] [1, 1] : tensor<256x32xf16, #shared0> to tensor<256x16xf16, #shared0>
%98 = triton_gpu.convert_layout %97 : (tensor<256x16xf16, #shared0>) -> tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>
%99 = tensor.extract_slice %arg16[16, 0] [16, 128] [1, 1] : tensor<32x128xf16, #shared1> to tensor<16x128xf16, #shared1>
%100 = triton_gpu.convert_layout %99 : (tensor<16x128xf16, #shared1>) -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>
%101 = tt.dot %98, %100, %96 {allowTF32 = true} : tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>> * tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>> -> tensor<256x128xf32, #mma>
%102 = tt.addptr %arg11, %cst : tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<256x32xi32, #blocked0>
%103 = tt.addptr %arg12, %49 : tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<32x128xi32, #blocked1>
gpu.barrier
%111 = tensor.insert_slice %110 into %arg13[%200, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<256x32xf16, #blocked0> into tensor<2x256x32xf16, #shared0>
%114 = tensor.insert_slice %113 into %arg14[%200, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<32x128xf16, #blocked1> into tensor<2x32x128xf16, #shared1>
gpu.barrier
%115 = tt.addptr %arg17, %cst : tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<256x32xi32, #blocked0>
%116 = tt.addptr %arg18, %49 : tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<32x128xi32, #blocked1>
%117 = tensor.extract_slice %111[%108, 0, 0] [1, 256, 32] [1, 1, 1] : tensor<2x256x32xf16, #shared0> to tensor<256x32xf16, #shared0>
%118 = tensor.extract_slice %114[%108, 0, 0] [1, 32, 128] [1, 1, 1] : tensor<2x32x128xf16, #shared1> to tensor<32x128xf16, #shared1>
%119 = arith.addi %arg20, %c1_i32 : i32
%120 = arith.addi %arg21, %c1_i32 : i32
%121 = tensor.extract_slice %117[0, 0] [256, 16] [1, 1] : tensor<256x32xf16, #shared0> to tensor<256x16xf16, #shared0>
%122 = triton_gpu.convert_layout %121 : (tensor<256x16xf16, #shared0>) -> tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>
%123 = tensor.extract_slice %118[0, 0] [16, 128] [1, 1] : tensor<32x128xf16, #shared1> to tensor<16x128xf16, #shared1>
%124 = triton_gpu.convert_layout %123 : (tensor<16x128xf16, #shared1>) -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>
scf.yield %101, %102, %103, %111, %114, %117, %118, %115, %116, %104, %119, %120, %122, %124 : tensor<256x128xf32, #mma>, tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<32x128x!tt.ptr<f16>, #blocked1>, tensor<2x256x32xf16, #shared0>, tensor<2x32x128xf16, #shared1>, tensor<256x32xf16, #shared0>, tensor<32x128xf16, #shared1>, tensor<256x32x!tt.ptr<f16>, #blocked0>, tensor<32x128x!tt.ptr<f16>, #blocked1>, index, i32, i32, tensor<256x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, isMMAv1Row = true}>>, tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, isMMAv1Row = true}>>
}
gpu.barrier
%77 = triton_gpu.convert_layout %76#0 : (tensor<256x128xf32, #mma>) -> tensor<256x128xf32, #blocked1>
%78 = arith.addi %20, %18 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%79 = tt.splat %arg8 : (i32) -> tensor<256x1xi32, #blocked1>
%80 = tt.expand_dims %42 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi32, #blocked1>
%81 = tt.broadcast %80 : (tensor<1x128xi32, #blocked1>) -> tensor<256x128xi32, #blocked1>
%82 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<256x128x!tt.ptr<f16>, #blocked1>
%83 = "triton_gpu.cmpi"(%78, %25) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%84 = "triton_gpu.cmpi"(%42, %26) {predicate = 2 : i64} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%85 = tt.expand_dims %84 {axis = 0 : i32} : (tensor<128xi1, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x128xi1, #blocked1>
%86 = tt.broadcast %85 : (tensor<1x128xi1, #blocked1>) -> tensor<256x128xi1, #blocked1>
%87 = tt.expand_dims %78 {axis = 1 : i32} : (tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<256x1xi32, #blocked1>
%88 = arith.muli %87, %79 : tensor<256x1xi32, #blocked1>
%89 = tt.broadcast %88 : (tensor<256x1xi32, #blocked1>) -> tensor<256x128xi32, #blocked1>
%90 = arith.addi %89, %81 : tensor<256x128xi32, #blocked1>
%91 = tt.addptr %82, %90 : tensor<256x128x!tt.ptr<f16>, #blocked1>, tensor<256x128xi32, #blocked1>
%92 = arith.truncf %77 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1>
%93 = tt.expand_dims %83 {axis = 1 : i32} : (tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<256x1xi1, #blocked1>
%94 = tt.broadcast %93 : (tensor<256x1xi1, #blocked1>) -> tensor<256x128xi1, #blocked1>
%95 = arith.andi %94, %86 : tensor<256x128xi1, #blocked1>
tt.store %91, %92, %95 : tensor<256x128xf16, #blocked1>
return
}
}

View File

@@ -1,91 +0,0 @@
import triton
import triton.language as tl
import torch
import pytest
from .test_core import numpy_random, to_triton
class MmaLayout:
def __init__(self, version, warps_per_cta):
self.version = version
self.warps_per_cta = str(warps_per_cta)
def __str__(self):
return f"#triton_gpu.mma<{{version={self.version}, warpsPerCTA={self.warps_per_cta}}}>"
class BlockedLayout:
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
self.sz_per_thread = str(size_per_thread)
self.threads_per_warp = str(threads_per_warp)
self.warps_per_cta = str(warps_per_cta)
self.order = str(order)
def __str__(self):
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
layouts = [
# MmaLayout(version=1, warps_per_cta=[1, 4]),
MmaLayout(version=2, warps_per_cta=[1, 4]),
# MmaLayout(version=1, warps_per_cta=[4, 1]),
MmaLayout(version=2, warps_per_cta=[4, 1]),
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1])
]
@pytest.mark.parametrize("shape", [(128, 128)])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("dst_layout", layouts)
def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
if str(src_layout) == str(dst_layout):
pytest.skip()
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
pytest.skip()
ir = f"""
#src = {src_layout}
#dst = {dst_layout}
""" + """
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #src>
%4 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>) -> tensor<128x1xi32, #src>
%5 = arith.muli %4, %cst : tensor<128x1xi32, #src>
%6 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>) -> tensor<1x128xi32, #src>
%7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src>
%8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src>
%9 = arith.addi %8, %7 : tensor<128x128xi32, #src>
%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
tt.store %14, %13 : tensor<128x128xf16, #dst>
return
}
}
"""
x = to_triton(numpy_random(shape, dtype_str=dtype))
z = torch.empty_like(x)
# write the IR to a temporary file using mkstemp
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)
kernel[(1,1,1)](x.data_ptr(), z.data_ptr())
assert torch.equal(z, x)

View File

@@ -1,32 +0,0 @@
import torch
import triton
import triton.language as tl
# trigger the torch.device implicitly to ensure cuda context initialization
torch.zeros([10], device=torch.device('cuda'))
@triton.jit
def empty_kernel(X, stride_xm, BLOCK: tl.constexpr):
pass
def test_empty_kernel_cubin_compile():
device = torch.cuda.current_device()
kernel = triton.compile(empty_kernel,
signature="*fp32,i32,i32",
device=device,
constants={"BLOCK": 256})
assert len(kernel.asm["cubin"]) > 0
def test_empty_kernel_launch():
grid = lambda META: (
triton.cdiv(1024, META['BLOCK']) * triton.cdiv(1024, META['BLOCK']),
)
A = torch.zeros([1024], device="cuda")
empty_kernel[grid](X=A, stride_xm=256, BLOCK=256)

View File

@@ -1,201 +0,0 @@
import tempfile
from inspect import Parameter, Signature
import _testcapi
import pytest
import torch
from torch.testing import assert_close
from tests.libdevice_testutil import system_libdevice_path
import triton
import triton.language as tl
torch_type = {
"bool": torch.bool,
"int32": torch.int32,
"float32": torch.float32,
"float64": torch.float64
}
torch_ops = {
"log": "log",
"cos": "cos",
"sin": "sin",
"sqrt": "sqrt",
"abs": "abs",
"exp": "exp",
"sigmoid": "sigmoid",
"umulhi": None,
"cdiv": None,
"fdiv": "div",
"minimum": "minimum",
"maximum": "maximum",
"where": "where",
}
def get_tensor(shape, data_type, b_positive=False):
x = None
if data_type.startswith('int'):
x = torch.randint(2**31 - 1, shape, dtype=torch_type[data_type], device='cuda')
elif data_type.startswith('bool'):
x = torch.randint(1, shape, dtype=torch_type[data_type], device='cuda')
else:
x = torch.randn(shape, dtype=torch_type[data_type], device='cuda')
if b_positive:
x = torch.abs(x)
return x
@pytest.mark.parametrize('expr, output_type, input0_type',
[('log', 'float32', 'float32'),
('log', 'float64', 'float64'),
('cos', 'float32', 'float32'),
('cos', 'float64', 'float64'),
('sin', 'float32', 'float32'),
('sin', 'float64', 'float64'),
('sqrt', 'float32', 'float32'),
('sqrt', 'float64', 'float64'),
('abs', 'float32', 'float32'),
('exp', 'float32', 'float32'),
('exp', 'float64', 'float64'),
('sigmoid', 'float32', 'float32'),
])
def test_single_input(expr, output_type, input0_type):
src = f"""
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
y = tl.{expr}(x)
tl.store(Y + tl.arange(0, BLOCK), y)
"""
fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
fp.write(src)
fp.flush()
def kernel(X, Y, BLOCK: tl.constexpr):
pass
kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
parameters = []
parameters.append(Parameter("X", 1))
parameters.append(Parameter("Y", 1))
parameters.append(Parameter("BLOCK", 1))
kernel.__signature__ = Signature(parameters=parameters)
kernel = triton.jit(kernel)
shape = (128, )
# limit the range of integers so that the sum does not overflow
x = get_tensor(shape, input0_type, expr == 'log' or expr == 'sqrt')
# triton result
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
kernel[(1,)](
x, y,
BLOCK=shape[0],
extern_libs={"libdevice": system_libdevice_path()},
)
# reference result
y_ref = getattr(torch, torch_ops[expr])(x)
# compare
assert_close(y, y_ref)
@pytest.mark.parametrize('expr, output_type, input0_type, input1_type',
[('umulhi', 'int32', 'int32', 'int32'),
('cdiv', 'int32', 'int32', 'int32'),
('fdiv', 'float32', 'float32', 'float32'),
('minimum', 'float32', 'float32', 'float32'),
('maximum', 'float32', 'float32', 'float32'),
])
def test_two_input(expr, output_type, input0_type, input1_type):
src = f"""
def kernel(X0, X1, Y, BLOCK: tl.constexpr):
x0 = tl.load(X0 + tl.arange(0, BLOCK))
x1 = tl.load(X1 + tl.arange(0, BLOCK))
y = tl.{expr}(x0, x1)
tl.store(Y + tl.arange(0, BLOCK), y)
"""
fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
fp.write(src)
fp.flush()
def kernel(X0, X1, Y, BLOCK: tl.constexpr):
pass
kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
parameters = []
parameters.append(Parameter("X0", 1))
parameters.append(Parameter("X1", 1))
parameters.append(Parameter("Y", 1))
parameters.append(Parameter("BLOCK", 1))
kernel.__signature__ = Signature(parameters=parameters)
kernel = triton.jit(kernel)
shape = (128, )
# limit the range of integers so that the sum does not overflow
x0 = get_tensor(shape, input0_type)
x1 = get_tensor(shape, input1_type)
# triton result
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
kernel[(1,)](
x0, x1, y,
BLOCK=shape[0],
extern_libs={"libdevice": system_libdevice_path()},
)
# reference result
if expr == "cdiv":
y_ref = torch.div(x0 + x1 - 1, x1, rounding_mode='trunc')
elif expr == "umulhi":
y_ref = ((x0.to(torch.int64) * x1) >> 32).to(torch.int32)
else:
y_ref = getattr(torch, torch_ops[expr])(x0, x1)
# compare
assert_close(y, y_ref)
@pytest.mark.parametrize('expr, output_type, input0_type, input1_type, input2_type',
[('where', "int32", "bool", "int32", "int32"), ])
def test_three_input(expr, output_type, input0_type, input1_type, input2_type):
src = f"""
def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr):
x0 = tl.load(X0 + tl.arange(0, BLOCK))
x1 = tl.load(X1 + tl.arange(0, BLOCK))
x2 = tl.load(X2 + tl.arange(0, BLOCK))
y = tl.{expr}(x0, x1, x2)
tl.store(Y + tl.arange(0, BLOCK), y)
"""
fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
fp.write(src)
fp.flush()
def kernel(X0, X1, X2, Y, BLOCK: tl.constexpr):
pass
kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
parameters = []
parameters.append(Parameter("X0", 1))
parameters.append(Parameter("X1", 1))
parameters.append(Parameter("X2", 1))
parameters.append(Parameter("Y", 1))
parameters.append(Parameter("BLOCK", 1))
kernel.__signature__ = Signature(parameters=parameters)
kernel = triton.jit(kernel)
shape = (128, )
# limit the range of integers so that the sum does not overflow
x0 = get_tensor(shape, input0_type)
x1 = get_tensor(shape, input1_type)
x2 = get_tensor(shape, input1_type)
# triton result
y = torch.zeros(shape, dtype=torch_type[output_type], device="cuda")
kernel[(1,)](
x0, x1, x2, y,
BLOCK=shape[0],
extern_libs={"libdevice": system_libdevice_path()},
)
# reference result
y_ref = getattr(torch, torch_ops[expr])(x0, x1, x2)
# compare
assert_close(y, y_ref)

View File

@@ -1,179 +0,0 @@
import pytest
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
from tests.libdevice_testutil import system_libdevice_path
@pytest.mark.parametrize('num_warps, block_size, iter_size', [
[4, 256, 1],
[4, 1024, 256],
])
def test_sin_no_mask(num_warps, block_size, iter_size):
@triton.jit
def kernel(x_ptr,
y_ptr,
block_size,
iter_size: tl.constexpr):
pid = tl.program_id(axis=0)
for i in range(0, block_size, iter_size):
offset = pid * block_size + tl.arange(0, iter_size)
x_ptrs = x_ptr + offset
x = tl.load(x_ptrs)
y = tl.libdevice.sin(x)
y_ptrs = y_ptr + offset
tl.store(y_ptrs, y)
x_ptr += iter_size
y_ptr += iter_size
x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
y = torch.empty((block_size,), device=x.device, dtype=x.dtype)
grid = lambda EA: (x.shape.numel() // (block_size),)
kernel[grid](x_ptr=x, y_ptr=y,
block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)
golden_y = torch.sin(x)
assert_close(y, golden_y, rtol=1e-7, atol=1e-7)
@pytest.mark.parametrize('num_warps, block_size, iter_size', [
[4, 256, 1],
[4, 1024, 256],
])
def test_fmin_no_mask(num_warps, block_size, iter_size):
@triton.jit
def kernel(x_ptr,
y_ptr,
z_ptr,
block_size,
iter_size: tl.constexpr):
pid = tl.program_id(axis=0)
for i in range(0, block_size, iter_size):
offset = pid * block_size + tl.arange(0, iter_size)
x_ptrs = x_ptr + offset
y_ptrs = y_ptr + offset
x = tl.load(x_ptrs)
y = tl.load(y_ptrs)
z = tl.libdevice.min(x, y)
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z)
x_ptr += iter_size
y_ptr += iter_size
z_ptr += iter_size
x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
grid = lambda EA: (x.shape.numel() // (block_size),)
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)
golden_z = torch.minimum(x, y)
assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
@pytest.mark.parametrize('num_warps, block_size, iter_size', [
[4, 256, 1],
[4, 1024, 256],
])
def test_fmad_rn_no_mask(num_warps, block_size, iter_size):
@triton.jit
def kernel(x_ptr,
y_ptr,
z_ptr,
w_ptr,
block_size,
iter_size: tl.constexpr):
pid = tl.program_id(axis=0)
for i in range(0, block_size, iter_size):
offset = pid * block_size + tl.arange(0, iter_size)
x_ptrs = x_ptr + offset
y_ptrs = y_ptr + offset
z_ptrs = z_ptr + offset
x = tl.load(x_ptrs)
y = tl.load(y_ptrs)
z = tl.load(z_ptrs)
w = tl.libdevice.fma_rn(x, y, z)
w_ptrs = w_ptr + offset
tl.store(w_ptrs, w)
x_ptr += iter_size
y_ptr += iter_size
z_ptr += iter_size
w_ptr += iter_size
x = torch.randn((block_size,), device='cuda', dtype=torch.float64)
y = torch.randn((block_size,), device='cuda', dtype=torch.float64)
z = torch.randn((block_size,), device='cuda', dtype=torch.float64)
w = torch.empty((block_size,), device=x.device, dtype=x.dtype)
grid = lambda EA: (x.shape.numel() // (block_size),)
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, w_ptr=w,
block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)
golden_w = x * y + z
assert_close(w, golden_w, rtol=1e-7, atol=1e-7)
@pytest.mark.parametrize("dtype_str, expr, lib_path",
[('int32', 'libdevice.ffs', system_libdevice_path()),
('int32', 'libdevice.ffs', '')])
def test_libdevice(dtype_str, expr, lib_path):
src = f"""
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
y = tl.{expr}(x)
tl.store(Y + tl.arange(0, BLOCK), y)
"""
import tempfile
from inspect import Parameter, Signature
import _testcapi
fp = tempfile.NamedTemporaryFile(mode='w', suffix=".py")
fp.write(src)
fp.flush()
def kernel(X, Y, BLOCK: tl.constexpr):
pass
kernel.__code__ = _testcapi.code_newempty(fp.name, "kernel", 1)
parameters = []
parameters.append(Parameter("X", 1))
parameters.append(Parameter("Y", 1))
parameters.append(Parameter("BLOCK", 1))
kernel.__signature__ = Signature(parameters=parameters)
kernel = triton.jit(kernel)
torch_type = {
"int32": torch.int32,
"float32": torch.float32,
"float64": torch.float64
}
shape = (128, )
# limit the range of integers so that the sum does not overflow
x = None
if dtype_str == "int32":
x = torch.randint(2**31 - 1, shape, dtype=torch_type[dtype_str], device="cuda")
else:
x = torch.randn(shape, dtype=torch_type[dtype_str], device="cuda")
if expr == 'libdevice.ffs':
y_ref = torch.zeros(shape, dtype=x.dtype, device="cuda")
for i in range(shape[0]):
y_ref[i] = (int(x[i]) & int(-x[i])).bit_length()
# triton result
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
kernel[(1,)](x, y, BLOCK=shape[0], extern_libs={"libdevice": lib_path})
# compare
assert_close(y, y_ref)

View File

@@ -1,314 +0,0 @@
import pytest
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
@triton.jit
def matmul_no_scf_kernel(
a_ptr, b_ptr, c_ptr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr
):
offs_m = tl.arange(0, M)
offs_n = tl.arange(0, N)
offs_k = tl.arange(0, K)
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
c = tl.dot(a, b)
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, c)
@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
(shape, num_warps, trans_a, trans_b)
for shape in [
[128, 256, 32],
# [256, 128, 16],
[128, 16, 32],
[32, 128, 64],
[128, 128, 64],
[64, 128, 128],
]
for num_warps in [2, 4]
for trans_a in [False, True]
for trans_b in [False, True]
])
def test_gemm_no_scf(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
SIZE_M, SIZE_N, SIZE_K = SHAPE
if (TRANS_A):
a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
else:
a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
if (TRANS_B):
b = torch.randn((SIZE_N, SIZE_K), device='cuda', dtype=torch.float16).T
else:
b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
grid = lambda META: (1, )
matmul_no_scf_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1),
M=SIZE_M, N=SIZE_N, K=SIZE_K,
num_warps=NUM_WARPS)
golden = torch.matmul(a, b)
torch.set_printoptions(profile="full")
assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
@pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
(shape, num_warps, trans_a, trans_b)
for shape in [
[64, 128, 128],
[128, 128, 128],
[16, 16, 32],
[32, 16, 64],
[32, 16, 64],
]
for num_warps in [1, 2, 4]
for trans_a in [False, True]
for trans_b in [False, True]
])
def test_gemm_no_scf_int8(SHAPE, NUM_WARPS, TRANS_A, TRANS_B):
guard_for_volta(is_int8=True)
SIZE_M, SIZE_N, SIZE_K = SHAPE
if (TRANS_A):
a = torch.randint(-5, 5, (SIZE_K, SIZE_M), device='cuda', dtype=torch.int8).T
else:
a = torch.randint(-5, 5, (SIZE_M, SIZE_K), device='cuda', dtype=torch.int8)
if (TRANS_B):
b = torch.randint(-5, 5, (SIZE_N, SIZE_K), device='cuda', dtype=torch.int8).T
else:
b = torch.randint(-5, 5, (SIZE_K, SIZE_N), device='cuda', dtype=torch.int8)
c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.int32)
grid = lambda META: (1, )
matmul_no_scf_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1),
M=SIZE_M, N=SIZE_N, K=SIZE_K,
num_warps=NUM_WARPS)
aa = a.cpu()
bb = b.cpu()
golden = torch.matmul(aa.float(), bb.float()).int()
torch.set_printoptions(profile="full")
torch.testing.assert_close(c.cpu(), golden, check_dtype=False)
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
offs_m = tl.arange(0, BLOCK_SIZE_M)
offs_n = tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, accumulator)
def get_variant_golden(a, b):
SIZE_M = a.shape[0]
SIZE_K = a.shape[1]
SIZE_N = b.shape[1]
assert a.shape[1] == b.shape[0]
zero_M_K = torch.zeros((SIZE_M, SIZE_K)).cuda()
zero_3M_K = torch.zeros((3 * SIZE_M, SIZE_K)).cuda()
zero_K_N = torch.zeros((SIZE_K, SIZE_N)).cuda()
zero_3K_N = torch.zeros((3 * SIZE_K, SIZE_N)).cuda()
a_padded = torch.cat((a, zero_M_K, zero_M_K), 0)
a_padded = torch.cat((a_padded, zero_3M_K, zero_3M_K), 1)
b_padded = torch.cat((b, zero_K_N, zero_K_N), 0)
b_padded = torch.cat((b_padded, zero_3K_N, zero_3K_N), 1)
c_padded = torch.matmul(a_padded, b_padded)
return c_padded[:SIZE_M, :SIZE_N]
# It's not easy to get a proper error threshold in different size
# Here the gemm calculation is padded to a different size in order to get
# a variant version of the golden result. And the error between golden and
# golden_variant provide reference on selecting the proper rtol / atol.
def get_proper_err(a, b, golden):
golden_variant = get_variant_golden(a, b)
golden_diff = golden - golden_variant
golden_abs_err = torch.max(torch.abs(golden_diff)).item()
golden_rel_err = torch.max(torch.abs(golden_diff / golden)).item()
return (golden_abs_err, golden_rel_err)
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,TRANS_A,TRANS_B', [
# Non-forloop
[64, 32, 64, 4, 64, 32, 64, False, False],
[128, 64, 128, 4, 128, 64, 128, False, False],
[16, 16, 16, 16, 16, 16, 16, False, False], # wpt overflow issue
# K-Forloop
# [16, 16, 64, 4, 8, 8, 8, False, False], # Wrap threads
[32, 32, 64, 4, 32, 32, 32, False, False], # Single shared encoding
[16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k
[64, 32, 128, 4, 64, 32, 64, False, False],
[128, 16, 128, 4, 128, 16, 32, False, False],
[32, 16, 128, 4, 32, 16, 32, False, False],
[32, 64, 128, 4, 32, 64, 32, False, False],
[32, 128, 256, 4, 32, 128, 64, False, False],
[64, 128, 64, 4, 64, 128, 32, False, False],
[64, 64, 128, 4, 64, 64, 32, False, False],
[128, 128, 64, 4, 128, 128, 32, False, False],
[128, 128, 128, 4, 128, 128, 32, False, False],
[128, 128, 256, 4, 128, 128, 64, False, False],
[128, 256, 128, 4, 128, 256, 32, False, False],
[256, 128, 64, 4, 256, 128, 16, False, False],
[128, 64, 128, 4, 128, 64, 32, False, False],
[16, 16, 64, 4, 16, 16, 16, False, False],
[32, 32, 64, 4, 32, 32, 32, False, False],
# trans
[128, 64, 128, 4, 128, 64, 32, True, False],
[128, 64, 128, 4, 128, 64, 32, False, True],
])
def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
if (TRANS_A):
a = torch.randn((SIZE_K, SIZE_M), device='cuda', dtype=torch.float16).T
else:
a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
if (TRANS_B):
b = torch.randn((SIZE_N, SIZE_K), device='cuda', dtype=torch.float16).T
else:
b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
grid = lambda META: (1, )
matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1),
M=a.shape[0], N=b.shape[1], K=a.shape[1],
BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
num_warps=NUM_WARPS)
golden = torch.matmul(a, b)
golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
torch.set_printoptions(profile="full")
assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K,allow_tf32', [
[32, 32, 16, 4, 32, 32, 16, False],
[32, 32, 16, 4, 32, 32, 16, True],
[32, 16, 16, 4, 32, 32, 16, False],
[32, 16, 16, 4, 32, 32, 16, True],
[127, 41, 43, 4, 32, 32, 16, False],
[127, 41, 43, 4, 32, 32, 16, True],
[128, 8, 8, 4, 32, 32, 16, False],
[128, 8, 8, 4, 32, 32, 16, True]
])
def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
ALLOW_TF32: tl.constexpr
):
pid = tl.program_id(axis=0)
# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K)
b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N)
a = tl.load(a_ptrs, a_mask, other=0.0)
b = tl.load(b_ptrs, b_mask, other=0.0)
accumulator += tl.dot(a, b, allow_tf32=ALLOW_TF32)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
offs_k += BLOCK_SIZE_K
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, c_mask)
guard_for_volta(is_tf32=allow_tf32)
# Configure the pytorch counterpart
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
a = torch.randn((M, K), device='cuda', dtype=torch.float32)
b = torch.randn((K, N), device='cuda', dtype=torch.float32)
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
matmul_kernel[grid](a, b, c,
M, N, K,
stride_am=a.stride(0), stride_ak=a.stride(1),
stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1),
BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K, ALLOW_TF32=allow_tf32)
golden = torch.matmul(a, b)
golden_abs_err, golden_rel_err = get_proper_err(a, b, golden)
if allow_tf32:
# TF32 is not accurate enough
torch.testing.assert_close(c, golden, rtol=max(1e-2, 1.5 * golden_rel_err), atol=max(1e-2, 1.5 * golden_abs_err))
else:
torch.testing.assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err))
def guard_for_volta(is_int8=False, is_tf32=False):
'''
Tell whether the test case is valid on Volta GPU.
Some features are WIP, so the corresponding support are missing.
'''
capability = torch.cuda.get_device_capability()
is_on_Volta = capability[0] < 8
# TODO[Superjomn]: Remove the constraints below when features are ready
is_feature_supported = not (is_int8 or is_tf32)
if is_on_Volta:
if (not is_feature_supported):
pytest.skip("Not valid on Volta")

View File

@@ -1,164 +0,0 @@
import pytest
import numpy as np
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
int_dtypes = ['int8', 'int16', 'int32', 'int64']
uint_dtypes = ['uint8'] # PyTorch does not support uint16/uint32/uint64
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + uint_dtypes + float_dtypes
dtypes_with_bfloat16 = int_dtypes + uint_dtypes + float_dtypes
dtype_mapping = {dtype_str: torch.__dict__[dtype_str] for dtype_str in dtypes}
def get_reduced_dtype(op, dtype):
if op in ['argmin', 'argmax']:
return torch.int32
if dtype in [torch.int8, torch.int16, torch.uint8]:
return torch.int32
if dtype in [torch.bfloat16]:
return torch.float32
return dtype
def patch_kernel(template, to_replace):
kernel = triton.JITFunction(template.fn)
for key, value in to_replace.items():
kernel.src = kernel.src.replace(key, value)
return kernel
@triton.jit
def reduce1d_kernel(x_ptr, z_ptr, block: tl.constexpr):
x = tl.load(x_ptr + tl.arange(0, block))
tl.store(z_ptr, tl.OP(x, axis=0))
@triton.jit
def reduce2d_kernel(x_ptr, z_ptr, axis: tl.constexpr, block_m: tl.constexpr, block_n: tl.constexpr):
range_m = tl.arange(0, block_m)
range_n = tl.arange(0, block_n)
x = tl.load(x_ptr + range_m[:, None] * block_n + range_n[None, :])
z = tl.OP(x, axis=axis)
if axis == 0:
tl.store(z_ptr + range_n, z)
else:
tl.store(z_ptr + range_m, z)
reduce1d_configs = [
(op, dtype, shape)
for op in ['sum', 'min', 'max', 'argmin', 'argmax', 'xor_sum']
for dtype in dtypes
for shape in [4, 8, 16, 32, 64, 128, 512, 1024]
]
@pytest.mark.parametrize('op, dtype, shape', reduce1d_configs)
def test_reduce1d(op, dtype, shape):
if op == 'xor_sum' and dtype in float_dtypes:
return
dtype = dtype_mapping[dtype]
reduced_dtype = get_reduced_dtype(op, dtype)
if dtype.is_floating_point:
x = torch.randn((shape,), device='cuda', dtype=dtype)
elif dtype is torch.uint8:
x = torch.randint(0, 20, (shape,), device='cuda', dtype=dtype)
else:
x = torch.randint(-20, 20, (shape,), device='cuda', dtype=dtype)
z = torch.empty(
tuple(),
device=x.device,
dtype=reduced_dtype,
)
kernel = patch_kernel(reduce1d_kernel, {'OP': op})
grid = (1,)
kernel[grid](x_ptr=x, z_ptr=z, block=shape)
if op == 'sum':
golden_z = torch.sum(x, dtype=reduced_dtype)
elif op == 'min':
golden_z = torch.min(x).to(reduced_dtype)
elif op == 'max':
golden_z = torch.max(x).to(reduced_dtype)
elif op == 'argmin':
golden_z = torch.argmin(x).to(reduced_dtype)
elif op == 'argmax':
golden_z = torch.argmax(x).to(reduced_dtype)
elif op == 'xor_sum':
sum_npy = np.bitwise_xor.reduce(x.cpu().numpy())
golden_z = torch.tensor(sum_npy, dtype=reduced_dtype).cuda()
else:
raise RuntimeError(f'Unknwon reduce op {op}')
if dtype.is_floating_point and op == 'sum':
if shape >= 256:
assert_close(z, golden_z, rtol=0.05, atol=0.1)
elif shape >= 32:
assert_close(z, golden_z, rtol=0.05, atol=0.02)
else:
assert_close(z, golden_z, rtol=0.01, atol=0.01)
else:
assert_close(z, golden_z, rtol=0.001, atol=0.001)
reduce2d_configs = [
(op, dtype, shape, axis)
for op in ['sum', 'min', 'max', 'argmin', 'argmax', 'xor_sum']
for dtype in dtypes
for shape in [(1, 4), (1, 8), (1, 16), (1, 32), (2, 32), (4, 32), (4, 128), (32, 64)]
for axis in [0, 1]
]
@pytest.mark.parametrize('op, dtype, shape, axis', reduce2d_configs)
def test_reduce2d(op, dtype, shape, axis):
if op == 'xor_sum' and dtype in float_dtypes:
return
dtype = dtype_mapping[dtype]
reduced_dtype = get_reduced_dtype(op, dtype)
reduced_shape = (shape[1 - axis],)
if dtype.is_floating_point:
x = torch.randn(shape, device='cuda', dtype=dtype)
elif dtype is torch.uint8:
x = torch.randint(0, 20, shape, device='cuda', dtype=dtype)
else:
x = torch.randint(-20, 20, shape, device='cuda', dtype=dtype)
z = torch.empty(reduced_shape, device=x.device, dtype=reduced_dtype)
kernel = patch_kernel(reduce2d_kernel, {'OP': op})
kernel[(1,)](x_ptr=x, z_ptr=z, axis=axis, block_m=shape[0], block_n=shape[1])
if op == 'sum':
golden_z = torch.sum(x, dim=axis, keepdim=False, dtype=reduced_dtype)
elif op == 'min':
golden_z = torch.min(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
elif op == 'max':
golden_z = torch.max(x, dim=axis, keepdim=False)[0].to(reduced_dtype)
elif op == 'argmin':
golden_z = torch.argmin(x, dim=axis, keepdim=False).to(reduced_dtype)
elif op == 'argmax':
golden_z = torch.argmax(x, dim=axis, keepdim=False).to(reduced_dtype)
elif op == 'xor_sum':
sum_npy = np.bitwise_xor.reduce(x.cpu().numpy(), axis=axis, keepdims=False)
golden_z = torch.tensor(sum_npy, dtype=reduced_dtype).cuda()
else:
raise RuntimeError(f'Unknwon reduce op {op}')
if dtype.is_floating_point and op == 'sum':
if shape[axis] >= 256:
assert_close(z, golden_z, rtol=0.05, atol=0.1)
elif shape[axis] >= 32:
assert_close(z, golden_z, rtol=0.05, atol=0.02)
else:
assert_close(z, golden_z, rtol=0.01, atol=0.01)
else:
assert_close(z, golden_z, rtol=0.001, atol=0.001)

View File

@@ -1,47 +0,0 @@
import pytest
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
@triton.jit
def kernel(x_ptr, stride_xm,
z_ptr, stride_zn,
SIZE_M: tl.constexpr, SIZE_N: tl.constexpr):
off_m = tl.arange(0, SIZE_M)
off_n = tl.arange(0, SIZE_N)
Xs = x_ptr + off_m[:, None] * stride_xm + off_n[None, :] * 1
Zs = z_ptr + off_m[:, None] * 1 + off_n[None, :] * stride_zn
tl.store(Zs, tl.load(Xs))
# These sizes cover the case of:
# - blocked layout and sliced layout with block parent
# -- blocked layout in which sizePerThread/threadsPerWarp/warpsPerCTA
# need/need not to be wrapped
# -- sliced layout incase sizePerThread need to be wrapped
# -- different orders
# - LayoutConversion from blocked -> blocked
# - tt.Broadcast which requires for broadcast in either/both of
# CTA/perThread level
# What is not covered and requires for TODO:
# - vectorization load/store of shared memory
# - multiple replication of layout conversion
@pytest.mark.parametrize('NUM_WARPS,SIZE_M,SIZE_N', [
[1, 16, 16],
[1, 32, 32],
[1, 32, 64],
[2, 64, 128],
[2, 128, 64]
])
def test_convert_layout_impl(NUM_WARPS, SIZE_M, SIZE_N):
grid = lambda META: (1, )
x = torch.randn((SIZE_M, SIZE_N), device='cuda', dtype=torch.float32)
z = torch.empty((SIZE_N, SIZE_M), device=x.device, dtype=x.dtype)
kernel[grid](x_ptr=x, stride_xm=x.stride(0), z_ptr=z, stride_zn=z.stride(0), SIZE_M=SIZE_M, SIZE_N=SIZE_N, num_warps=NUM_WARPS)
golden_z = torch.t(x)
assert_close(z, golden_z, rtol=1e-7, atol=1e-7, check_dtype=False)

View File

@@ -1,215 +0,0 @@
import math
import random
import pytest
import torch
from torch.testing import assert_close
import triton
import triton.language as tl
@pytest.mark.parametrize('num_warps, block_size, iter_size', [
[4, 256, 1],
[4, 1024, 256],
])
def test_vecadd_scf_no_mask(num_warps, block_size, iter_size):
@triton.jit
def kernel(x_ptr,
y_ptr,
z_ptr,
block_size,
iter_size: tl.constexpr):
pid = tl.program_id(axis=0)
for i in range(0, block_size, iter_size):
offset = pid * block_size + tl.arange(0, iter_size)
x_ptrs = x_ptr + offset
y_ptrs = y_ptr + offset
x = tl.load(x_ptrs)
y = tl.load(y_ptrs)
z = x + y
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z)
x_ptr += iter_size
y_ptr += iter_size
z_ptr += iter_size
x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
grid = lambda EA: (x.shape.numel() // (block_size),)
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)
golden_z = x + y
assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
@pytest.mark.parametrize('shape, num_warps, block_size, iter_size', [
[(127, 3), 2, 128, 1],
[(127, 3), 2, 128, 32],
])
def test_vecadd_scf_mask(shape, num_warps, block_size, iter_size):
@triton.jit
def kernel(x_ptr,
y_ptr,
z_ptr,
num_elements,
block_size: tl.constexpr,
iter_size: tl.constexpr
):
'''
@block_size: size of a block
@iter_size: size of the iteration, a block has multiple iterations
@num_elements: number of elements
'''
pid = tl.program_id(axis=0)
for i in range(tl.cdiv(block_size, iter_size)):
# TODO: a bug here, if put the offset outside the forloop, there will be a GPU mis-aligned error.
offset = pid * block_size + tl.arange(0, iter_size)
x_ptrs = x_ptr + offset
y_ptrs = y_ptr + offset
x = tl.load(x_ptrs, mask=offset < num_elements)
y = tl.load(y_ptrs, mask=offset < num_elements)
z = x + y
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z, mask=offset < num_elements)
x_ptr += iter_size
y_ptr += iter_size
z_ptr += iter_size
x = torch.randn(shape, device='cuda', dtype=torch.float32)
y = torch.randn(shape, device='cuda', dtype=torch.float32)
z = torch.empty(shape, device=x.device, dtype=x.dtype)
grid = lambda EA: (math.ceil(x.numel() / block_size),)
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps,
num_elements=x.numel())
golden_z = x + y
assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
def vecadd_no_scf_tester(num_warps, block_size, shape):
@triton.jit
def kernel(x_ptr,
y_ptr,
z_ptr,
n_elements,
block_size_N: tl.constexpr):
pid = tl.program_id(axis=0)
offset = pid * block_size_N + tl.arange(0, block_size_N)
x_ptrs = x_ptr + offset
y_ptrs = y_ptr + offset
mask = offset < n_elements
x = tl.load(x_ptrs, mask=mask)
y = tl.load(y_ptrs, mask=mask)
z = x + y
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z, mask=mask)
x = torch.randn(shape, device='cuda', dtype=torch.float32)
y = torch.randn(shape, device='cuda', dtype=torch.float32)
z = torch.empty(shape, device=x.device, dtype=x.dtype)
grid = lambda EA: (math.ceil(x.shape.numel() / block_size),)
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, n_elements=x.shape.numel(), block_size_N=block_size, num_warps=num_warps)
golden_z = x + y
assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
def vecadd_fcmp_no_scf_tester(num_warps, block_size, shape):
'''
vecadd tester with float comparison as load/store mask.
'''
@triton.jit
def kernel(x_ptr,
y_ptr,
z_ptr,
n_elements,
block_size_N: tl.constexpr):
pid = tl.program_id(axis=0)
offset = pid * block_size_N + tl.arange(0, block_size_N)
x_ptrs = x_ptr + offset
y_ptrs = y_ptr + offset
io_mask = offset < n_elements
x = tl.load(x_ptrs, mask=io_mask)
y = tl.load(y_ptrs, mask=io_mask)
z = x + y
val_mask = offset < n_elements and (z < 0. or z > 1.)
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z, mask=val_mask)
x = torch.randn(shape, device='cuda', dtype=torch.float32)
y = torch.randn(shape, device='cuda', dtype=torch.float32)
z = torch.zeros(shape, device=x.device, dtype=x.dtype)
grid = lambda EA: (math.ceil(x.shape.numel() / block_size),)
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, n_elements=x.shape.numel(), block_size_N=block_size, num_warps=num_warps)
golden_z: torch.Tensor = x + y
gz_data = torch.flatten(golden_z)
for i in range(golden_z.numel()):
gz_data[i] = gz_data[i] if gz_data[i] < 0. or gz_data[i] > 1. else 0.
assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
@pytest.mark.parametrize('num_warps, block_size, shape', [
[4, 256, (256,)],
[2, 256, (256,)],
[1, 256, (256,)],
[4, 16, (256,)],
[2, 64, (256,)],
[1, 128, (256,)],
])
def test_vecadd_no_scf(num_warps, block_size, shape):
vecadd_no_scf_tester(num_warps, block_size, shape)
@pytest.mark.parametrize('num_warps, block_size, shape', [
[1, 128, (256 + 1,)],
[1, 256, (256 + 1,)],
[2, 256, (3, 256 + 7)],
[4, 256, (3, 256 + 7)],
])
def test_vecadd_no_scf_masked(num_warps, block_size, shape):
vecadd_no_scf_tester(num_warps, block_size, shape)
def test_vecadd_no_scf_masked_randomly():
random.seed(0) # fix seed to make random test reproducible
for i in range(10):
num_elements = random.randint(128, 2048)
shape = (num_elements,)
max_warps = num_elements // 32 # floor div
for num_warps in range(1, max_warps):
is_power2 = num_warps & (num_warps - 1) == 0 and num_warps != 0
if not is_power2: continue
block_size = min(32, num_warps * 32)
vecadd_no_scf_tester(num_warps, block_size, shape)
@pytest.mark.parametrize('num_warps, block_size, shape', [
[1, 128, (256 + 1,)],
[1, 256, (256 + 1,)],
[2, 256, (3, 256 + 7)],
[4, 256, (3, 256 + 7)],
])
def test_vecadd_fcmp_no_scf_masked(num_warps, block_size, shape):
vecadd_fcmp_no_scf_tester(num_warps, block_size, shape)

View File

@@ -25,7 +25,6 @@ from filelock import FileLock
import triton
import triton._C.libtriton.triton as _triton
from . import impl
from .tools.disasm import extract
@@ -392,7 +391,7 @@ class CodeGenerator(ast.NodeVisitor):
if then_defs[then_name].type == else_defs[else_name].type:
names.append(then_name)
ret_types.append(then_defs[then_name].type)
# defined in else block but not in then block
# to find in parent scope and yield them
for else_name in else_defs:
@@ -642,7 +641,7 @@ class CodeGenerator(ast.NodeVisitor):
ub_si = self.builder.create_index_to_si(ub)
iv = self.builder.create_sub(ub_si, iv)
self.lscope[node.target.id].handle.replace_all_uses_with(iv)
self.set_value(name, triton.language.core.tensor(iv, triton.language.core.int32))
self.set_value(node.target.id, triton.language.core.tensor(iv, triton.language.core.int32))
# create YieldOp
self.builder.set_insertion_point_to_end(for_op.get_body(0))
@@ -735,10 +734,6 @@ class CodeGenerator(ast.NodeVisitor):
assert len(node.values) == 2
lhs = self.visit(node.values[0])
rhs = self.visit(node.values[1])
if isinstance(lhs, triton.language.constexpr):
lhs = lhs.value
if isinstance(rhs, triton.language.constexpr):
rhs = rhs.value
fn = {
ast.And: 'logical_and',
@@ -766,8 +761,8 @@ class CodeGenerator(ast.NodeVisitor):
def visit_Attribute(self, node):
lhs = self.visit(node.value)
if isinstance(lhs, triton.language.tensor):
if node.attr == "T":
return triton.language.semantic.trans(lhs, builder=self.builder)
if node.attr == "T":
return triton.language.semantic.trans(lhs, builder=self.builder)
return getattr(lhs, node.attr)
def visit_Expr(self, node):
@@ -810,6 +805,7 @@ class OutOfResources(Exception):
self.message = f'out of resource: {name}, '\
f'Required: {required}, '\
f'Hardware limit: {limit}'
self.message += '. Reducing block sizes or `num_stages` may help.'
self.required = required
self.limit = limit
self.name = name
@@ -849,7 +845,7 @@ def build_triton_ir(fn, signature, specialization, constants):
gscope = fn.__globals__.copy()
function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)])
tys = list(signature.values())
new_constants = {k: True if tys[k] == "i1" else 1 for k in specialization.equal_to_1}
new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in specialization.equal_to_1}
new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16}
all_constants = constants.copy()
all_constants.update(new_constants)
@@ -967,23 +963,12 @@ def ptx_get_version(cuda_version) -> int:
'''
assert isinstance(cuda_version, str)
major, minor = map(int, cuda_version.split('.'))
version = major * 1000 + minor * 10
if version >= 11040:
return 74
if version >= 11030:
return 73
if version >= 11020:
return 72
if version >= 11010:
return 71
if version >= 11000:
return 70
if version >= 10020:
return 65
if version >= 10010:
return 64
if version >= 10000:
return 63
if major == 12:
return 80 + minor
if major == 11:
return 70 + minor
if major == 10:
return 63 + minor
raise RuntimeError("Triton only support CUDA 10.0 or higher")
@@ -1024,8 +1009,11 @@ def ty_to_cpp(ty):
"i64": "int64_t",
"u32": "uint32_t",
"u64": "uint64_t",
"fp16": "float",
"bf16": "float",
"fp32": "float",
"f32": "float",
"fp64": "double",
}[ty]
@@ -1055,6 +1043,8 @@ def generate_launcher(constants, signature):
'i64': 'int64_t',
'u32': 'uint32_t',
'u64': 'uint64_t',
'fp16': 'float',
'bf16': 'float',
'fp32': 'float',
'f32': 'float',
'fp64': 'double',
@@ -1072,7 +1062,7 @@ def generate_launcher(constants, signature):
"int64_t": "L",
}[ty]
format = "iiiiiKK" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
# generate glue code
src = f"""
@@ -1130,11 +1120,37 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
uint64_t _function;
int num_warps;
int shared_memory;
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
PyObject *compiled_kernel = NULL;
PyObject *hook_ret = NULL;
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
return NULL;
}}
if (launch_enter_hook != Py_None) {{
PyObject *new_args = PyTuple_Pack(1, compiled_kernel);
hook_ret = PyObject_CallObject(launch_enter_hook, new_args);
Py_DECREF(new_args);
}}
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
if (launch_exit_hook != Py_None) {{
PyObject *new_args = NULL;
if (hook_ret) {{
new_args = PyTuple_Pack(2, compiled_kernel, hook_ret);
}} else {{
new_args = PyTuple_Pack(1, compiled_kernel);
}}
hook_ret = PyObject_CallObject(launch_exit_hook, new_args);
Py_DECREF(new_args);
}}
if (hook_ret) {{
Py_DECREF(hook_ret);
}}
if(PyErr_Occurred()) {{
return NULL;
}}
@@ -1174,7 +1190,8 @@ def default_cache_dir():
def default_cuda_dir():
return os.path.join("/usr", "local", "cuda")
default_dir = "/usr/local/cuda"
return os.getenv("CUDA_HOME", default=default_dir)
class CacheManager:
@@ -1217,9 +1234,9 @@ class CacheManager:
@functools.lru_cache()
def libcuda_dir():
loc = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[-1]
return os.path.dirname(loc)
def libcuda_dirs():
locs = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[1:]
return [os.path.dirname(loc) for loc in locs]
@contextlib.contextmanager
@@ -1233,7 +1250,7 @@ def quiet():
def _build(name, src, srcdir):
cuda_lib_dir = libcuda_dir()
cuda_lib_dirs = libcuda_dirs()
cuda_path = os.environ.get('CUDA_PATH', default_cuda_dir())
cu_include_dir = os.path.join(cuda_path, "include")
suffix = sysconfig.get_config_var('EXT_SUFFIX')
@@ -1246,12 +1263,16 @@ def _build(name, src, srcdir):
gcc = shutil.which("gcc")
cc = gcc if gcc is not None else clang
py_include_dir = get_paths()["include"]
ret = subprocess.check_call([cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{cuda_lib_dir}", "-lcuda", "-o", so])
cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
ret = subprocess.check_call(cc_cmd)
if ret == 0:
return so
# fallback on setuptools
extra_compile_args = []
library_dirs = [cuda_lib_dir]
library_dirs = cuda_lib_dirs
include_dirs = [srcdir, cu_include_dir]
libraries = ['cuda']
# extra arguments
@@ -1282,10 +1303,10 @@ def _build(name, src, srcdir):
return so
def make_so_cache_key(signature, constants):
def make_so_cache_key(version_hash, signature, constants):
# Get unique key for the compiled code
signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
key = f"{''.join(signature.values())}{constants}"
key = f"{version_hash}-{''.join(signature.values())}{constants}"
key = hashlib.md5(key.encode("utf-8")).hexdigest()
return key
@@ -1320,7 +1341,7 @@ def read_or_execute(cache_manager, force_compile, file_name, metadata,
def make_stub(name, signature, constants):
# name of files that are cached
so_cache_key = make_so_cache_key(signature, constants)
so_cache_key = make_so_cache_key(triton.runtime.jit.version_key(), signature, constants)
so_cache_manager = CacheManager(so_cache_key)
so_name = f"{name}.so"
# retrieve stub from cache if it exists
@@ -1385,8 +1406,11 @@ arg_type_pattern = {
# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
def compile(fn, **kwargs):
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
capability = kwargs.get("cc", None)
if capability is None:
device = torch.cuda.current_device()
capability = torch.cuda.get_device_capability(device)
capability = capability[0] * 10 + capability[1]
# we get the kernel, i.e. the first function generated in the module
# if fn is not a JITFunction, then it
# has to be a path to a file
@@ -1396,7 +1420,6 @@ def compile(fn, **kwargs):
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3 if capability >= 75 else 2)
extern_libs = kwargs.get("extern_libs", dict())
device = kwargs.get("device", torch.cuda.current_device())
# build compilation stages
stages = {
"ast": (lambda path: fn, None),
@@ -1431,7 +1454,9 @@ def compile(fn, **kwargs):
import re
match = re.search(prototype_pattern[ir], src, re.MULTILINE)
name, signature = match.group(1), match.group(2)
print(name, signature)
types = re.findall(arg_type_pattern[ir], signature)
print(types)
param_tys = [convert_type_repr(ty) for ty in types]
signature = {k: v for k, v in enumerate(param_tys)}
first_stage = list(stages.keys()).index(ir)
@@ -1444,7 +1469,7 @@ def compile(fn, **kwargs):
if isinstance(fn, triton.runtime.JITFunction):
name, ext = fn.__name__, "ast"
else:
name, ext = os.path.basename(fn).split(".")
name, ext = os.path.basename(fn).split(".")
# load metadata if any
metadata = None
@@ -1452,34 +1477,34 @@ def compile(fn, **kwargs):
with open(fn_cache_manager._make_path(f"{name}.json")) as f:
metadata = json.load(f)
else:
metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
if ext == "ptx":
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
metadata["shared"] = kwargs["shared"]
metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
if ext == "ptx":
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
metadata["shared"] = kwargs["shared"]
first_stage = list(stages.keys()).index(ext)
asm = dict()
module = fn
# run compilation pipeline and populate metadata
for ir, (parse, compile) in list(stages.items())[first_stage:]:
path = fn_cache_manager._make_path(f"{name}.{ir}")
if ir == ext:
next_module = parse(fn)
elif os.path.exists(path) and\
ir in metadata["ctime"] and\
os.path.getctime(path) == metadata["ctime"][ir]:
next_module = parse(path)
else:
next_module = compile(module)
fn_cache_manager.put(next_module, f"{name}.{ir}")
if os.path.exists(path):
metadata["ctime"][ir] = os.path.getctime(path)
asm[ir] = next_module if ir == "cubin" else str(next_module)
if ir == "llir" and "shared" not in metadata:
metadata["shared"] = _triton.get_shared_memory_size(module)
if ir == "ptx":
metadata["name"] = ptx_get_kernel_name(next_module)
module = next_module
path = fn_cache_manager._make_path(f"{name}.{ir}")
if ir == ext:
next_module = parse(fn)
elif os.path.exists(path) and\
ir in metadata["ctime"] and\
os.path.getctime(path) == metadata["ctime"][ir]:
next_module = parse(path)
else:
next_module = compile(module)
fn_cache_manager.put(next_module, f"{name}.{ir}")
if os.path.exists(path):
metadata["ctime"][ir] = os.path.getctime(path)
asm[ir] = next_module if ir == "cubin" else str(next_module)
if ir == "llir" and "shared" not in metadata:
metadata["shared"] = _triton.get_shared_memory_size(module)
if ir == "ptx":
metadata["name"] = ptx_get_kernel_name(next_module)
module = next_module
# write-back metadata
fn_cache_manager.put(json.dumps(metadata), f"{name}.json", binary=False)
# return handle to compiled kernel
@@ -1488,6 +1513,10 @@ def compile(fn, **kwargs):
class CompiledKernel:
# Hooks for external tools to monitor the execution of triton kernels
launch_enter_hook = None
launch_exit_hook = None
def __init__(self, so_path, metadata, asm):
# initialize launcher
import importlib.util
@@ -1501,18 +1530,39 @@ class CompiledKernel:
self.num_stages = metadata["num_stages"]
# initialize asm dict
self.asm = asm
# binaries are lazily initialized
# because it involves doing runtime things
# (e.g., checking amount of shared memory on current device)
self.metadata = metadata
self.cu_module = None
self.cu_function = None
def _init_handles(self):
if self.cu_module is not None:
return
device = torch.cuda.current_device()
global cuda_utils
init_cuda_utils()
mod, func, n_regs, n_spills = cuda_utils.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
max_shared = cuda_utils.get_device_properties(device)["max_shared_mem"]
if self.shared > max_shared:
raise OutOfResources(self.shared, max_shared, "shared memory")
mod, func, n_regs, n_spills = cuda_utils.load_binary(self.metadata["name"], self.asm["cubin"], self.shared, device)
self.cu_module = mod
self.cu_function = func
def __getattribute__(self, name):
if name == 'c_wrapper':
self._init_handles()
return super().__getattribute__(name)
def __getitem__(self, grid):
self._init_handles()
def runner(*args, stream=None):
if stream is None:
stream = torch.cuda.current_stream().cuda_stream
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function,
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args)
return runner
def get_sass(self, fun=None):
@@ -1574,7 +1624,7 @@ class CudaUtils(object):
int sm_clock_rate;
int mem_clock_rate;
int mem_bus_width;
CUDA_CHECK(cuDeviceGetAttribute(&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK, device));
CUDA_CHECK(cuDeviceGetAttribute(&max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device));
CUDA_CHECK(cuDeviceGetAttribute(&multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
CUDA_CHECK(cuDeviceGetAttribute(&mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));

View File

@@ -7,12 +7,8 @@ APIs defined in the `triton.impl` module which are public will be re-exported
in other relevant `triton` module namespaces.
"""
from .base import builtin, extern, is_builtin
from triton._C.libtriton.triton import ir
from .base import (
builtin,
extern,
is_builtin,
)
__all__ = [
"builtin",

View File

@@ -5,7 +5,7 @@ from ..impl import (
ir,
builtin,
)
from . import core, extern, libdevice, random
from . import libdevice
from .core import (
abs,
arange,
@@ -21,7 +21,8 @@ from .core import (
atomic_xor,
bfloat16,
block_type,
builtin,
broadcast,
broadcast_to,
cat,
cdiv,
constexpr,
@@ -107,6 +108,8 @@ __all__ = [
"atomic_xor",
"bfloat16",
"block_type",
"broadcast",
"broadcast_to",
"builtin",
"cat",
"cdiv",
@@ -128,6 +131,7 @@ __all__ = [
"int64",
"int8",
"ir",
"libdevice",
"load",
"log",
"max",

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from enum import Enum
from typing import List, Callable, TypeVar
from typing import Callable, List, TypeVar
import triton
from . import builtin, semantic
@@ -403,6 +403,18 @@ class constexpr:
def __neg__(self):
return constexpr(-self.value)
def __and__(self, other):
return constexpr(self.value & other.value)
def logical_and(self, other):
return constexpr(self.value and other.value)
def __or__(self, other):
return constexpr(self.value | other.value)
def logical_or(self, other):
return constexpr(self.value or other.value)
def __pos__(self):
return constexpr(+self.value)
@@ -596,11 +608,9 @@ class tensor:
if isinstance(slices, slice):
slices = [slices]
ret = self
n_inserted = 0
for dim, sl in enumerate(slices):
if isinstance(sl, constexpr) and sl.value is None:
ret = semantic.expand_dims(ret, dim + n_inserted, _builder)
n_inserted += 1
ret = semantic.expand_dims(ret, dim, _builder)
elif sl == slice(None, None, None):
pass
else:
@@ -820,9 +830,9 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",
'type cache_modifier: str, optional
"""
# mask, other can be constexpr
if mask is not None:
if _constexpr_to_value(mask) is not None:
mask = _to_tensor(mask, _builder)
if other is not None:
if _constexpr_to_value(other) is not None:
other = _to_tensor(other, _builder)
cache_modifier = _constexpr_to_value(cache_modifier)
eviction_policy = _constexpr_to_value(eviction_policy)
@@ -846,7 +856,7 @@ def store(pointer, value, mask=None, _builder=None):
"""
# value can be constexpr
value = _to_tensor(value, _builder)
if mask is not None:
if _constexpr_to_value(mask) is not None:
mask = _to_tensor(mask, _builder)
return semantic.store(pointer, value, mask, _builder)

View File

@@ -237,7 +237,11 @@ def mod(input: tl.tensor,
other_scalar_ty = other.type.scalar
# float % float
if scalar_ty.is_floating():
return tl.tensor(builder.create_frem(input.handle, other.handle), input.type)
# input - input.div(other, rounding_mode="floor") * other
ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder),
other, builder),
builder)
return ret
# % int
elif scalar_ty.is_int():
if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
@@ -481,7 +485,7 @@ def view(input: tl.tensor,
dst_shape: List[int],
builder: ir.builder) -> tl.tensor:
# TODO: disable when TritonToTritonGPU handles views properly
# assert len(input.shape) == len(dst_shape)
numel = 1
for s in dst_shape:
@@ -505,12 +509,14 @@ def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder)
ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]])
return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type)
def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor:
if len(input.shape) != 2:
raise ValueError("Only 2D tensors can be transposed")
ret_type = tl.block_type(input.type.scalar, [input.shape[1], input.shape[0]])
return tl.tensor(builder.create_trans(input.handle), ret_type)
def broadcast_impl_shape(input: tl.tensor,
shape: List[int],
builder: ir.builder) -> tl.tensor:
@@ -631,9 +637,9 @@ def cast(input: tl.tensor,
return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)),
dst_ty)
# Casting types of the same bit width: fp16 <=> bf16
if (src_sca_ty.is_fp16() and dst_sca_ty.is_bf16()) or \
(src_sca_ty.is_bf16() and dst_sca_ty.is_fp16()):
# bf16 <=> (not fp32)
if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \
(src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()):
return cast(cast(input, tl.float32, builder), dst_sca_ty, builder)
# Standard floating types' casting: truncation
@@ -986,7 +992,7 @@ def dot(lhs: tl.tensor,
assert len(lhs.shape) == 2 and len(rhs.shape) == 2
assert lhs.shape[1].value == rhs.shape[0].value
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
and rhs.shape[1].value >= 16,\
and rhs.shape[1].value >= 16,\
"small blocks not supported!"
if lhs.type.scalar.is_int():
_0 = builder.get_int32(0)
@@ -1051,6 +1057,13 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
if INT_OP in int_op_to_unit:
INT_OP = int_op_to_unit[INT_OP]
# If we are doing an argmin or argmax we want to use an int32 output type
out_scalar_ty = scalar_ty
if FLOAT_OP is ir.REDUCE_OP.ARGFMAX or INT_OP is ir.REDUCE_OP.ARGMAX:
out_scalar_ty = tl.int32
elif FLOAT_OP is ir.REDUCE_OP.ARGFMIN or INT_OP is ir.REDUCE_OP.ARGMIN:
out_scalar_ty = tl.int32
# get result type
shape = input.type.shape
ret_shape = []
@@ -1058,10 +1071,10 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
if i != axis:
ret_shape.append(s)
if ret_shape:
res_ty = tl.block_type(scalar_ty, ret_shape)
res_ty = tl.block_type(out_scalar_ty, ret_shape)
else:
# 0d-tensor -> scalar
res_ty = scalar_ty
res_ty = out_scalar_ty
if scalar_ty.is_floating():
return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
@@ -1103,10 +1116,17 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
x, y = binary_op_type_checking_impl(x, y, builder)
# FIXME(Keren): not portable, should be fixed
from . import libdevice
return libdevice.mulhi(x, y, _builder=builder)
def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor:
# FIXME(Keren): not portable, should be fixed
from . import libdevice
return libdevice.floor(x, _builder=builder)
def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_exp(x.handle), x.type)
@@ -1130,12 +1150,12 @@ def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor:
##
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
if len(x.shape) != len(values):
raise ValueError("Shape of input to multiple_of does not match the length of values")
x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
return x
if len(x.shape) != len(values):
raise ValueError("Shape of input to multiple_of does not match the length of values")
x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
return x
def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
if len(x.shape) != len(values):
raise ValueError("Shape of input to max_contiguous does not match the length of values")

View File

@@ -27,29 +27,28 @@ def get_configs_io_bound():
@triton.autotune(
#configs=[
# # basic configs for compute-bound matmuls
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# # good for int8
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
# triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
#] + get_configs_io_bound(),
configs=[triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=2, num_warps=8)],
configs=[
# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
] + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
@@ -114,7 +113,7 @@ def _kernel(A, B, C, M, N, K,
class _matmul(torch.autograd.Function):
kernel = None
kernel = _kernel
_locks = dict()
@@ -135,17 +134,12 @@ class _matmul(torch.autograd.Function):
# accumulator types
ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch kernel
#grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
if _matmul.kernel is None:
_matmul.kernel = triton.compile("/root/code/triton-mlir/python/tests/matmul.ttgir", num_stages=2, num_warps=8)
#_matmul.kernel = _kernel
_matmul.kernel[(8192//256 * 8192//128, 1, 1,)](a.data_ptr(), b.data_ptr(), c.data_ptr(),
M, N, K,
a.stride(0), b.stride(0), c.stride(0))
#_matmul.kernel[grid](a, b, c,
# M, N, K,
# a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
# GROUP_M=8, ACC_TYPE=ACC_TYPE)
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
GROUP_M=8, ACC_TYPE=ACC_TYPE)
return c
@staticmethod

View File

@@ -99,7 +99,6 @@ def estimate_matmul_time(
def early_config_prune(configs, named_args):
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
capability = torch.cuda.get_device_capability()
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages

View File

@@ -4,6 +4,7 @@ import builtins
import time
from typing import Dict
from ..compiler import OutOfResources
from ..testing import do_bench
from .jit import KernelInterface
@@ -60,7 +61,10 @@ class Autotuner(KernelInterface):
config.pre_hook(self.nargs)
self.hook(args)
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
return do_bench(kernel_call)
try:
return do_bench(kernel_call)
except OutOfResources:
return float('inf')
def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
@@ -74,8 +78,6 @@ class Autotuner(KernelInterface):
for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
for config, ttime in timings.items():
print(f"config: {config}, time: {ttime}")
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.configs_timings = timings

View File

@@ -7,8 +7,8 @@ import inspect
import os
import subprocess
import textwrap
from collections import namedtuple
from typing import TypeVar, Generic, cast, Callable, overload, Optional, Iterable, Union
from collections import defaultdict, namedtuple
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, cast, overload
import torch
@@ -110,9 +110,9 @@ class KernelInterface(Generic[T]):
return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
class JITFunction(KernelInterface[T]):
# Hook for inspecting compiled functions and modules
cache_hook = None
divisibility = 16
@@ -258,31 +258,30 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
if stream is None and not warmup:
stream = get_cuda_stream(device)
try:
bin = cache[key]
bin = cache[device][key]
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, {args})
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
return bin
# kernel not cached -- compile
except KeyError:
# build dict of constant values
args = [{args}]
configs = self._get_config(*args),
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
configs = self._get_config(*all_args),
constants = self._make_constants(constexpr_key)
constants.update({{i: None for i, arg in enumerate(args) if arg is None}})
constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
constants.update({{i: 1 for i in configs[0].equal_to_1}})
# build kernel signature -- doesn't include specialized arguments
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
# build stub signature -- includes arguments that are specialized
for i, arg in constants.items():
if callable(arg):
raise TypeError(f"Callable constexpr at index {i} is not supported")
device = 0
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs)
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, *args)
self.cache[key] = bin
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
self.cache[device][key] = bin
return bin
return None
"""
@@ -307,7 +306,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def"):]
# cache of just-in-time compiled kernels
self.cache = dict()
self.cache = defaultdict(dict)
self.hash = None
# JITFunction can be instantiated as kernel
# when called with a grid using __getitem__

View File

@@ -105,7 +105,6 @@ def allclose(x, y, tol=1e-2):
diff = abs(x - y)
x_max = torch.max(x)
y_max = torch.max(y)
tol = 1e-2
err = torch.max(diff) / torch.max(x_max, y_max)
return err <= tol
@@ -119,7 +118,9 @@ def nvsmi(attrs):
return ret
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0.8], record_clocks=False):
def do_bench(fn, warmup=25, rep=100, grad_to_none=None,
percentiles=(0.5, 0.2, 0.8),
record_clocks=False, fast_flush=False):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -134,6 +135,8 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0
:type grad_to_none: torch.tensor, optional
:param percentiles: Performance percentile to return in addition to the median.
:type percentiles: list[float]
:param fast_flush: Use faster kernel to flush L2 between measurements
:type fast_flush: bool
"""
# Estimate the runtime of the function
@@ -155,7 +158,10 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, percentiles=[0.5, 0.2, 0
# doesn't contain any input data before the run
start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
if fast_flush:
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
else:
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
# Warm-up
for _ in range(n_warmup):
fn()
@@ -333,8 +339,8 @@ def get_dram_gbps(backend=None, device=None):
backend = _triton.runtime.backend.CUDA
if not device:
device = torch.cuda.current_device()
mem_clock_khz = triton.compiler.cuda_utils.get_device_properties(device)["mem_clock_rate"] # in kHz
bus_width = triton.compiler.cuda_utils.get_device_properties(device)["mem_bus_width"]
mem_clock_khz = triton.compiler.cuda_utils.get_device_properties(device)["mem_clock_rate"] # in kHz
bus_width = triton.compiler.cuda_utils.get_device_properties(device)["mem_bus_width"]
bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s
return bw_gbps
@@ -348,7 +354,7 @@ def get_max_tensorcore_tflops(dtype: torch.dtype, backend=None, device=None, clo
triton.compiler.init_cuda_utils()
num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4
if not clock_rate:
clock_rate = triton.compiler.cuda_utils.get_device_properties(device)["sm_clock_rate"] # in kHz
clock_rate = triton.compiler.cuda_utils.get_device_properties(device)["sm_clock_rate"] # in kHz
capability = torch.cuda.get_device_capability(device)
if capability[0] < 8:
assert dtype == torch.float16

View File

@@ -1,10 +1,24 @@
import argparse
import subprocess
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
class Symbol:
def __init__(self, name: str, op_name: str, ret_type: str, arg_names: list, arg_types: list) -> None:
_name: str
_op_name: str
_ret_type: str
_arg_names: List[str]
_arg_types: List[str]
def __init__(
self,
name: str,
op_name: str,
ret_type: str,
arg_names: List[str],
arg_types: List[str],
) -> None:
'''
A symbol is a function declaration.
:param name: name of the symbol
@@ -16,31 +30,31 @@ class Symbol:
self._name = name
self._op_name = op_name
self._ret_type = ret_type
self._arg_names = arg_names
self._arg_types = arg_types
self._arg_names = list(arg_names)
self._arg_types = list(arg_types)
@property
def name(self):
def name(self) -> str:
return self._name
@property
def op_name(self):
def op_name(self) -> str:
return self._op_name
@property
def ret_type(self):
def ret_type(self) -> str:
return self._ret_type
@property
def arg_names(self):
def arg_names(self) -> List[str]:
return self._arg_names
@property
def arg_types(self):
def arg_types(self) -> List[str]:
return self._arg_types
def convert_type(type_str):
def convert_type(type_str) -> Optional[str]:
if type_str == "i32":
return "int32"
elif type_str == "u32":
@@ -58,7 +72,7 @@ def convert_type(type_str):
return None
def to_unsigned(type_str):
def to_unsigned(type_str) -> str:
if type_str == "int32":
return "uint32"
elif type_str == "int64":
@@ -68,7 +82,19 @@ def to_unsigned(type_str):
class ExternLibrary(ABC):
def __init__(self, name: str, path: str, format: bool = True, grouping: bool = True) -> None:
_name: str
_path: str
_symbols: Dict[str, Symbol]
_format: bool
_grouping: bool
def __init__(
self,
name: str,
path: str,
format: bool = True,
grouping: bool = True,
) -> None:
'''
Abstract class for extern library.
:param name: name of the library
@@ -78,34 +104,34 @@ class ExternLibrary(ABC):
self._name = name
self._path = path
self._symbols = {}
self._format = True
self._format = format
self._grouping = grouping
@property
def name(self):
def name(self) -> str:
return self._name
@property
def path(self):
def path(self) -> str:
return self._path
@property
def symbols(self):
def symbols(self) -> Dict[str, Symbol]:
return self._symbols
@property
def grouping(self):
def grouping(self) -> bool:
return self._grouping
@abstractmethod
def parse_symbols(self, input_file):
def parse_symbols(self, input_file) -> None:
pass
@abstractmethod
def _output_stubs(self) -> str:
pass
def generate_stub_file(self, output_dir):
def generate_stub_file(self, output_dir) -> None:
file_str = self._output_stubs()
if file_str is None or len(file_str) == 0:
raise Exception("file_str is empty")
@@ -121,6 +147,8 @@ class ExternLibrary(ABC):
class Libdevice(ExternLibrary):
_symbol_groups: Dict[str, List[Symbol]]
def __init__(self, path) -> None:
'''
Constructor for Libdevice.
@@ -129,7 +157,7 @@ class Libdevice(ExternLibrary):
super().__init__("libdevice", path)
self._symbol_groups = {}
def _extract_symbol(self, line):
def _extract_symbol(self, line) -> Optional[Symbol]:
# Extract symbols from line in the following format:
# "define [internal] <ret_type> @<name>(<arg_types>,)"
entries = line.split("@")
@@ -170,7 +198,7 @@ class Libdevice(ExternLibrary):
arg_types[i] = to_unsigned(arg_type)
return Symbol(func_name, op_name, ret_type, arg_names, arg_types)
def _group_symbols(self):
def _group_symbols(self) -> None:
symbol_set = {}
for symbol in self._symbols.values():
op_name = symbol.op_name
@@ -240,7 +268,7 @@ class Libdevice(ExternLibrary):
else:
self._symbol_groups[op_name] = [symbol]
def parse_symbols(self, input_file):
def parse_symbols(self, input_file) -> None:
if len(self.symbols) > 0:
return
output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines()
@@ -252,7 +280,7 @@ class Libdevice(ExternLibrary):
self._group_symbols()
def _output_stubs(self):
def _output_stubs(self) -> str:
# Generate python functions in the following format:
# @extern.extern
# def <op_name>(<args>, _builder=None):
@@ -293,7 +321,10 @@ class Libdevice(ExternLibrary):
class LLVMDisassembler:
def __init__(self, path):
_path: str
_ll_file: str
def __init__(self, path) -> None:
'''
Invoke llvm-dis to disassemble the given file.
:param path: path to llvm-dis
@@ -301,23 +332,28 @@ class LLVMDisassembler:
self._path = path
self._ll_file = "/tmp/extern_lib.ll"
def disasm(self, lib_path):
def disasm(self, lib_path: str) -> None:
subprocess.Popen([self._path, lib_path, "-o", self.ll_file],
stdout=subprocess.PIPE).communicate()
@property
def ll_file(self):
def ll_file(self) -> str:
return self._ll_file
@property
def path(self):
def path(self) -> str:
return self._path
extern_libs = ["libdevice"]
def build(llvm_dis_path, lib_path, lib_name, output_dir):
def build(
llvm_dis_path: str,
lib_path: str,
lib_name: str,
output_dir: str,
) -> None:
'''
Interface function to build the library file.
:param llvm_dis_path: path to the llvm-dis binary

View File

@@ -3,9 +3,9 @@ Vector Addition
=================
In this tutorial, you will write a simple vector addition using Triton and learn about:
- The basic programming model of Triton
- The basic programming model of Triton.
- The `triton.jit` decorator, which is used to define Triton kernels.
- The best practices for validating and benchmarking your custom ops against native reference implementations
- The best practices for validating and benchmarking your custom ops against native reference implementations.
"""
# %%
@@ -20,51 +20,51 @@ import triton.language as tl
@triton.jit
def add_kernel(
x_ptr, # *Pointer* to first input vector
y_ptr, # *Pointer* to second input vector
output_ptr, # *Pointer* to output vector
n_elements, # Size of the vector
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process
# NOTE: `constexpr` so it can be used as a shape value
x_ptr, # *Pointer* to first input vector.
y_ptr, # *Pointer* to second input vector.
output_ptr, # *Pointer* to output vector.
n_elements, # Size of the vector.
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
# NOTE: `constexpr` so it can be used as a shape value.
):
# There are multiple 'program's processing different data. We identify which program
# we are here
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0
# There are multiple 'programs' processing different data. We identify which program
# we are here:
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
# This program will process inputs that are offset from the initial data.
# for instance, if you had a vector of length 256 and block_size of 64, the programs
# For instance, if you had a vector of length 256 and block_size of 64, the programs
# would each access the elements [0:64, 64:128, 128:192, 192:256].
# Note that offsets is a list of pointers
# Note that offsets is a list of pointers:
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses
# Create a mask to guard memory operations against out-of-bounds accesses.
mask = offsets < n_elements
# Load x and y from DRAM, masking out any extra elements in case the input is not a
# multiple of the block size
# multiple of the block size.
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
# Write x + y back to DRAM
# Write x + y back to DRAM.
tl.store(output_ptr + offsets, output, mask=mask)
# %%
# Let's also declare a helper function to (1) allocate the `z` tensor
# and (2) enqueue the above kernel with appropriate grid/block sizes.
# and (2) enqueue the above kernel with appropriate grid/block sizes:
def add(x: torch.Tensor, y: torch.Tensor):
# We need to preallocate the output
# We need to preallocate the output.
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
n_elements = output.numel()
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
# In this case, we use a 1D grid where the size is the number of blocks
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
# In this case, we use a 1D grid where the size is the number of blocks:
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# NOTE:
# - each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel
# - don't forget to pass meta-parameters as keywords arguments
# - Each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
# - Don't forget to pass meta-parameters as keywords arguments.
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
# running asynchronously at this point.
@@ -94,24 +94,24 @@ print(
# Benchmark
# -----------
# We can now benchmark our custom op on vectors of increasing sizes to get a sense of how it does relative to PyTorch.
# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of your custom ops
# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of your custom ops.
# for different problem sizes.
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['size'], # argument names to use as an x-axis for the plot
x_names=['size'], # Argument names to use as an x-axis for the plot.
x_vals=[
2 ** i for i in range(12, 28, 1)
], # different possible values for `x_name`
x_log=True, # x axis is logarithmic
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['triton', 'torch'], # possible values for `line_arg`
line_names=['Triton', 'Torch'], # label name for the lines
styles=[('blue', '-'), ('green', '-')], # line styles
ylabel='GB/s', # label name for the y-axis
plot_name='vector-add-performance', # name for the plot. Used also as a file name for saving the plot.
args={}, # values for function arguments not in `x_names` and `y_name`
], # Different possible values for `x_name`.
x_log=True, # x axis is logarithmic.
line_arg='provider', # Argument name whose value corresponds to a different line in the plot.
line_vals=['triton', 'torch'], # Possible values for `line_arg`.
line_names=['Triton', 'Torch'], # Label name for the lines.
styles=[('blue', '-'), ('green', '-')], # Line styles.
ylabel='GB/s', # Label name for the y-axis.
plot_name='vector-add-performance', # Name for the plot. Used also as a file name for saving the plot.
args={}, # Values for function arguments not in `x_names` and `y_name`.
)
)
def benchmark(size, provider):
@@ -127,5 +127,5 @@ def benchmark(size, provider):
# %%
# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or
# `save_path='/path/to/results/' to save them to disk along with raw CSV data
# benchmark.run(print_data=True, show_plots=True)
# `save_path='/path/to/results/' to save them to disk along with raw CSV data:
benchmark.run(print_data=True, show_plots=True)

View File

@@ -156,7 +156,7 @@ import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
],
key=['M', 'N', 'K'],
)

View File

@@ -17,35 +17,51 @@ except ModuleNotFoundError:
HAS_APEX = False
# Forward Pass
@triton.jit
def _layer_norm_fwd_fused(X, Y, W, B, M, V, stride, N, eps,
BLOCK_SIZE: tl.constexpr):
def _layer_norm_fwd_fused(
A,
Out,
Weight,
Bias,
Mean, Rstd,
stride, N, eps,
BLOCK_SIZE: tl.constexpr,
):
# position of elements processed by this program
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE)
mask = cols < N
# offset data pointers to start at the row of interest
X += row * stride
Y += row * stride
# load data and cast to float32
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
Out += row * stride
A += row * stride
# compute mean
mean = tl.sum(x, axis=0) / N
# compute std
xmean = tl.where(mask, x - mean, 0.)
var = tl.sum(xmean * xmean, axis=0) / N
mean = 0
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
_mean += a
mean = tl.sum(_mean, axis=0) / N
# compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
a = tl.where(cols < N, a - mean, 0.)
_var += a * a
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
xhat = xmean * rstd
# write-back mean/rstd
tl.store(M + row, mean)
tl.store(V + row, rstd)
tl.store(Mean + row, mean)
tl.store(Rstd + row, rstd)
# multiply by weight and add bias
w = tl.load(W + cols, mask=mask)
b = tl.load(B + cols, mask=mask)
y = xhat * w + b
# write-back
tl.store(Y + cols, y, mask=mask)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
weight = tl.load(Weight + cols, mask=mask)
bias = tl.load(Bias + cols, mask=mask)
a = tl.load(A + cols, mask=mask, other=0.).to(tl.float32)
a_hat = (a - mean) * rstd
out = a_hat * weight + bias
# # write-back
tl.store(Out + cols, out, mask=mask)
# Backward pass (DX + partial DW + partial DB)
@@ -257,5 +273,6 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c
grad_to_none=[x], rep=500)
return gbps(ms), gbps(max_ms), gbps(min_ms)
# test_layer_norm(1151, 8192, torch.float16)
bench_layer_norm.run(save_path='.', print_data=True)
test_layer_norm(1151, 8192, torch.float16)
# bench_layer_norm.run(save_path='.', print_data=True)

View File

@@ -196,6 +196,8 @@ def _bwd_kernel(
empty = torch.empty(128, device="cuda")
class _attention(torch.autograd.Function):
@staticmethod
@@ -248,7 +250,8 @@ class _attention(torch.autograd.Function):
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
num_warps = 4 if ctx.BLOCK_DMODEL <= 64 else 8
# NOTE: kernel currently buggy for other values of `num_warps`
num_warps = 8
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
@@ -305,14 +308,21 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)
try:
from flash_attn.flash_attn_interface import flash_attn_func
HAS_FLASH = True
except BaseException:
HAS_FLASH = False
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 16)],
line_arg='provider',
line_vals=['triton'],
line_names=['Triton'],
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
@@ -350,4 +360,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
# bench_flash_attention.run(save_path='.', print_data=True)
# bench_flash_attention.run(save_path='.', print_data=True)

View File

@@ -4,7 +4,7 @@
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
@@ -52,6 +52,15 @@ func @convert(%A : !tt.ptr<f16>) {
return
}
// CHECK-LABEL: trans
func @trans(%A : !tt.ptr<f16>) {
// CHECK: %cst -> %cst
%tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
// CHECK: %0 -> %cst
%b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
return
}
// CHECK-LABEL: insert_slice_async
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>

View File

@@ -5,7 +5,7 @@
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
@@ -174,6 +174,14 @@ func @scratch() {
// CHECK-NEXT: size = 512
}
// CHECK-LABEL: trans
func @trans(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 1024
%tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
%b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
return
}
// CHECK-LABEL: insert_slice_async
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
@@ -285,6 +293,25 @@ func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %
// CHECK-NEXT: size = 24576
}
// c0 cannot be released in the loop
// CHECK-LABEL: for_use_ancestor
func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: offset = 0, size = 8192
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: offset = 8192, size = 8192
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: offset = 16384, size = 8192
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
%c0 = tt.trans %c_shared_init : (tensor<128x32xf16, #A_SHARED>) -> tensor<32x128xf16, #A_SHARED>
// CHECK-NEXT: offset = 24576, size = 8192
%c1 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
scf.yield %b_shared, %a_shared: tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
}
return
// CHECK-NEXT: size = 32768
}
// a_shared_init, b_shared_init, and c_shared_init's liveness ranges are span over the entire function before cst2.
// So they cannot be reused by cst0 and cst1, but can be reused by cst2.
// CHECK-LABEL: for_if_for

View File

@@ -5,7 +5,7 @@
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
@@ -111,6 +111,13 @@ func @extract_slice() {
return
}
// CHECK-LABEL: trans
func @trans() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
%b = tt.trans %cst0 : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
return
}
// CHECK-LABEL: insert_slice_async
func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>

View File

@@ -710,7 +710,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}>
#mma0 = #triton_gpu.mma<{version=2, warpsPerCTA=[1,1]}>
#mma0 = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[1,1]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
@@ -748,7 +748,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mmav2_block
@@ -768,7 +768,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 1]}>
#mma = #triton_gpu.mma<{versionMajor = 1, warpsPerCTA = [2, 1]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mmav1_block
@@ -853,7 +853,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
@@ -878,7 +878,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 2]}>
#mma = #triton_gpu.mma<{versionMajor = 1, warpsPerCTA = [2, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
@@ -923,7 +923,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// -----
#mma = #triton_gpu.mma<{version=2, warpsPerCTA=[2, 2]}>
#mma = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[2, 2]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
@@ -997,20 +997,61 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
// CHECK: nvvm.read.ptx.sreg.nctaid.x
// CHECK: nvvm.read.ptx.sreg.nctaid.y
// CHECK: nvvm.read.ptx.sreg.nctaid.z
%blockdimx = tt.get_num_programs {axis=0:i32} : i32
%blockdimy = tt.get_num_programs {axis=1:i32} : i32
%blockdimz = tt.get_num_programs {axis=2:i32} : i32
%v0 = arith.addi %blockdimx, %blockdimy : i32
%v1 = arith.addi %v0, %blockdimz : i32
%0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
tt.store %a, %0 : tensor<32xi32, #blocked0>
return
func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
// CHECK: nvvm.read.ptx.sreg.nctaid.x
// CHECK: nvvm.read.ptx.sreg.nctaid.y
// CHECK: nvvm.read.ptx.sreg.nctaid.z
%blockdimx = tt.get_num_programs {axis=0:i32} : i32
%blockdimy = tt.get_num_programs {axis=1:i32} : i32
%blockdimz = tt.get_num_programs {axis=2:i32} : i32
%v0 = arith.addi %blockdimx, %blockdimy : i32
%v1 = arith.addi %v0, %blockdimz : i32
%0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0>
tt.store %a, %0 : tensor<32xi32, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: test_index_cache
func @test_index_cache() {
// CHECK: nvvm.read.ptx.sreg.tid.x
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
%1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_base_index_cache
func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
// CHECK: nvvm.read.ptx.sreg.tid.x
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
%1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_index_cache_different_block
func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
// CHECK: nvvm.read.ptx.sreg.tid.x
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
scf.if %arg1 {
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
%1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
}
return
}
}

View File

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -tritongpu-combine 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file -tritongpu-combine 2>&1 | FileCheck %s
#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
@@ -7,7 +7,6 @@
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK: [[col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
func @cst() -> tensor<1024xi32, #layout1> {
%cst = arith.constant dense<0> : tensor<1024xi32, #layout0>
%1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
@@ -62,9 +61,9 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
// CHECK-LABEL: transpose
func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
// CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
// CHECK: tt.store {{.*}}, [[cvt_val]], %cst_1 : tensor<64x64xf32, [[col_layout]]>
// CHECK: tt.store {{.*}}, [[cvt_val]], {{%cst.*}} : tensor<64x64xf32, [[col_layout]]>
// CHECK: return
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
%cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
@@ -184,3 +183,32 @@ func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f3
tt.store %21, %22 : tensor<256xf32, #layout1>
return
}
// -----
// check the UpdateMMAVersionMinorForVolta pattern
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor=1, versionMinor=0, warpsPerCTA=[1,1]}>
// Here, the isMMAv1Row of a and b's dot_operands mismatch #mma0's versionMinor,
// and the pattern should update the versionMinor.
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, isMMAv1Row=false}>
// It creates a new MMA layout to fit with $a and $b's dot_operand
// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 11, warpsPerCTA = [1, 1]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: dot_mmav1
func @dot_mmav1(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) -> tensor<16x16xf32, #blocked0> {
%C = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked0>
%AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #dot_operand_a>
%BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #dot_operand_b>
%CC = triton_gpu.convert_layout %C : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #mma0>
// CHECK: {{.*}} = tt.dot {{.*}}, {{.*}}, %cst {allowTF32 = true} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[new_mma]], isMMAv1Row = true}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[new_mma]], isMMAv1Row = true}>> -> tensor<16x16xf32, [[new_mma]]>
%D = tt.dot %AA, %BB, %CC {allowTF32 = true} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
%res = triton_gpu.convert_layout %D : (tensor<16x16xf32, #mma0>) -> tensor<16x16xf32, #blocked0>
return %res : tensor<16x16xf32, #blocked0>
}
}

View File

@@ -4,7 +4,7 @@
// matmul: 128x32 @ 32x128 -> 128x128
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>

View File

@@ -1,5 +1,5 @@
add_triton_ut(
NAME TestPtxAsmFormat
SRCS PtxAsmFormatTest.cpp
SRCS PTXAsmFormatTest.cpp
LIBS TritonGPUToLLVM
)

View File

@@ -1,16 +1,17 @@
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/IR/Builders.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include <gtest/gtest.h>
namespace mlir {
namespace triton {
class PtxAsmFormatTest : public ::testing::Test {
class PTXAsmFormatTest : public ::testing::Test {
protected:
static constexpr int numValues = 4;
PtxAsmFormatTest() {
PTXAsmFormatTest() {
ctx.loadDialect<arith::ArithmeticDialect>();
createValues();
@@ -34,7 +35,7 @@ protected:
Value v[numValues + 1];
};
TEST_F(PtxAsmFormatTest, basic) {
TEST_F(PTXAsmFormatTest, basic) {
PTXBuilder builder;
// Create the operands needed by the instructions in the PTX code.
@@ -55,7 +56,7 @@ TEST_F(PtxAsmFormatTest, basic) {
ASSERT_EQ(constraints, "=r,b"); // $0 -> =r, $1 -> b
}
TEST_F(PtxAsmFormatTest, complexInstruction) {
TEST_F(PTXAsmFormatTest, complexInstruction) {
using triton::CacheModifier;
using triton::EvictionPolicy;
@@ -99,7 +100,7 @@ TEST_F(PtxAsmFormatTest, complexInstruction) {
EXPECT_EQ(builder.getConstraints(), "l,b");
}
TEST_F(PtxAsmFormatTest, MultiLinePTX) {
TEST_F(PTXAsmFormatTest, MultiLinePTX) {
PTXBuilder builder;
auto *constVal = builder.newConstantOperand(1);
@@ -121,7 +122,7 @@ TEST_F(PtxAsmFormatTest, MultiLinePTX) {
EXPECT_EQ(values[1], v[2]); // $1 -> v[2]
}
TEST_F(PtxAsmFormatTest, onlyAttachMLIRArgs) {
TEST_F(PTXAsmFormatTest, onlyAttachMLIRArgs) {
PTXBuilder builder;
const char *ptxCode =
".param .b64 param0;\n" // prepare param0 (format string)

View File

@@ -28,7 +28,7 @@ TEST_P(SwizzleDotOperandTestFixture, DotOperands) {
MLIRContext ctx;
ctx.loadDialect<triton::gpu::TritonGPUDialect>();
// create encoding
auto parent = triton::gpu::MmaEncodingAttr::get(&ctx, 2, {1, 1});
auto parent = triton::gpu::MmaEncodingAttr::get(&ctx, 2, 0, {1, 1});
auto encoding =
triton::gpu::DotOperandEncodingAttr::get(&ctx, params.opIdx, parent);
@@ -50,4 +50,4 @@ INSTANTIATE_TEST_SUITE_P(TestDotOperands, SwizzleDotOperandTestFixture,
ParamT{{32, 32}, 0, 16, {8, 2, 4}},
ParamT{{32, 32}, 1, 16, {8, 2, 4}},
ParamT{{16, 16}, 0, 16, {8, 4, 2}},
ParamT{{16, 16}, 1, 16, {8, 4, 2}}));
ParamT{{16, 16}, 1, 16, {8, 4, 2}}));