Files
triton/bin/triton-opt.cpp
Shintaro Iwasaki 3c635449e5 [Triton] Support math and libdevice ops (#91)
This PR adds basic math ops using `MathDialect` and `libdevice` ops using `extern_elementwise`. This is needed to compile some tutorial code (e.g., `softmax`). This PR implements only the interface up to, but not including, PTX (i.e., from the frontend to TritonGPU-MLIR).
- Lowering currently stops at TritonGPU; these ops cannot be lowered to PTX yet.
- No special optimizations (e.g., constant folding) are applied.
  - MLIR 14.x does not define folders for many math ops, but 15.x appears to extend that coverage: https://github.com/llvm/llvm-project/blob/llvmorg-15.0.0-rc3/mlir/include/mlir/Dialect/Math/IR/MathOps.td
  - There is no constant folding (or similar optimization) for `libdevice` ops.

```py
import triton
import triton.language as tl
import sys

@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    offsets = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offsets)
    x = tl.sin(x)
    output = tl.libdevice.sin(x)
    output = tl.libdevice.fdiv_rn(output, output)
    output = tl.libdevice.fmaf_rd(output, output, output)
    tl.store(y_ptr + offsets, output)


if __name__ == "__main__" and len(sys.argv) >= 2:
    signature = "*fp32,*fp32"
    constants = {'BLOCK_SIZE': 1024}
    output = triton.compile(add_kernel, signature, device=0, constants=constants, output="ttgir")
    print(output)
```
->
```llvm
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func @add_kernel__Pfp32_Pfp32__2c1024(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) {
    %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
    %1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %2 = tt.getelementptr %1, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
    %4 = math.sin %3 : tensor<1024xf32, #blocked>
    %5 = tt.ext_elemwise %4 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_sinf"} : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %6 = tt.ext_elemwise %5, %5 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fdiv_rn"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %7 = tt.ext_elemwise %6, %6, %6 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fmaf_rd"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %8 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %9 = tt.getelementptr %8, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    tt.store %9, %7 : tensor<1024xf32, #blocked>
    return
  }
}
```
2022-09-01 16:34:27 -07:00

43 lines
1.4 KiB
C++

#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Conversion/Passes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Support/MlirOptMain.h"
namespace mlir {
namespace test {
// Registration hooks for Triton's analysis test passes, called from main().
// Only the declarations live in this file; the definitions are presumably
// provided by the test sources linked into triton-opt (not visible here).
// Each hook takes no arguments and returns nothing.
void registerTestAlignmentPass();
void registerTestAliasPass();
void registerTestMembarPass();
void registerTestAllocationPass();
} // namespace test
} // namespace mlir
namespace {

// Registers every pass that triton-opt exposes on its command line, in this
// order: all upstream MLIR passes, the Triton and TritonGPU transformation
// passes, the mlir::test analysis test passes, and finally the
// Triton->TritonGPU and TritonGPU->LLVM conversion passes.
void registerTritonOptPasses() {
  // Upstream MLIR passes.
  mlir::registerAllPasses();

  // Triton / TritonGPU transformation passes.
  mlir::registerTritonPasses();
  mlir::registerTritonGPUPasses();

  // Analysis test passes.
  mlir::test::registerTestAliasPass();
  mlir::test::registerTestAlignmentPass();
  mlir::test::registerTestAllocationPass();
  mlir::test::registerTestMembarPass();

  // Conversion passes.
  mlir::triton::registerConvertTritonToTritonGPUPass();
  mlir::triton::registerConvertTritonGPUToLLVMPass();
}

} // namespace

// Entry point of the Triton (GPU) optimizer driver: registers the available
// passes and dialects, then hands control to MLIR's generic opt driver, which
// parses argc/argv. The process exit code is derived from the driver's
// result via mlir::asMainReturnCode.
int main(int argc, char **argv) {
  registerTritonOptPasses();

  mlir::DialectRegistry registry;
  // Triton's own dialects.
  registry.insert<mlir::triton::TritonDialect,
                  mlir::triton::gpu::TritonGPUDialect>();
  // Upstream dialects whose ops may appear alongside Triton ops in the IR
  // (e.g. math ops produced for the math/libdevice frontend functions).
  registry.insert<mlir::math::MathDialect, mlir::arith::ArithmeticDialect,
                  mlir::StandardOpsDialect, mlir::scf::SCFDialect,
                  mlir::gpu::GPUDialect>();

  const auto driverResult = mlir::MlirOptMain(
      argc, argv, "Triton (GPU) optimizer driver\n", registry);
  return mlir::asMainReturnCode(driverResult);
}