[Triton] Support math and libdevice ops (#91)

This PR adds basic math ops backed by MLIR's `MathDialect` and `libdevice` ops backed by `extern_elementwise`. This is needed to compile some tutorial code (e.g., `softmax`). The PR implements only the path from the frontend down to TritonGPU MLIR:
- Lowering currently stops at TritonGPU; the result cannot be lowered to PTX yet.
- No special optimizations (e.g., constant folding) are applied; see the sketch after this list.
  - LLVM 14.x does not define folders for many math ops, though 15.x seems to increase coverage: https://github.com/llvm/llvm-project/blob/llvmorg-15.0.0-rc3/mlir/include/mlir/Dialect/Math/IR/MathOps.td
  - No constant folding or similar is applied for `libdevice` ops either.
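To make the folding point concrete, here is a minimal sketch (not from this PR) of what the gap means in practice: on 14.x, canonicalization leaves `math.sin` of a constant unevaluated.

```mlir
// Hypothetical input for `mlir-opt -canonicalize`; the function name is illustrative.
// Because LLVM 14.x registers no folder for math.sin, the op survives
// canonicalization instead of being folded into a new constant.
func @no_fold() -> tensor<1024xf32> {
  %cst = arith.constant dense<1.0> : tensor<1024xf32>
  %0 = math.sin %cst : tensor<1024xf32>
  return %0 : tensor<1024xf32>
}
```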

```py
import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    offsets = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offsets)
    x = tl.sin(x)  # math op: lowers to MLIR math.sin
    output = tl.libdevice.sin(x)  # libdevice op: lowers to tt.ext_elemwise calling __nv_sinf
    output = tl.libdevice.fdiv_rn(output, output)  # float division, round-to-nearest-even (__nv_fdiv_rn)
    output = tl.libdevice.fmaf_rd(output, output, output)  # fused multiply-add, round-down (__nv_fmaf_rd)
    tl.store(y_ptr + offsets, output)


if __name__ == "__main__":
    signature = "*fp32,*fp32"
    constants = {'BLOCK_SIZE': 1024}
    output = triton.compile(add_kernel, signature, device=0, constants=constants, output="ttgir")
    print(output)
```
This compiles to the following TritonGPU IR:
```mlir
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func @add_kernel__Pfp32_Pfp32__2c1024(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) {
    %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
    %1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %2 = tt.getelementptr %1, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
    %4 = math.sin %3 : tensor<1024xf32, #blocked>
    %5 = tt.ext_elemwise %4 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_sinf"} : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %6 = tt.ext_elemwise %5, %5 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fdiv_rn"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %7 = tt.ext_elemwise %6, %6, %6 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fmaf_rd"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %8 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %9 = tt.getelementptr %8, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    tt.store %9, %7 : tensor<1024xf32, #blocked>
    return
  }
}
```
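Note the two lowering paths above: `tl.sin` maps to the MLIR `math.sin` op, while each `tl.libdevice.*` call becomes a `tt.ext_elemwise` op whose string attributes identify the library and the `__nv_*` symbol to call. Stripped of the layout encoding, the general shape of the op is sketched below (the `__nv_expf` symbol and the placeholder path are assumptions for illustration, not part of this PR's output):

```mlir
// Variadic tensor operands in, one tensor result out; the external function
// is identified solely by the libname/libpath/symbol attributes.
%y = tt.ext_elemwise %x {libname = "libdevice",
                         libpath = "/path/to/libdevice.10.bc",
                         symbol = "__nv_expf"} : tensor<1024xf32> -> tensor<1024xf32>
```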

Key excerpts from the diff:

`include/triton/Conversion/Passes.td`:

```tablegen
@@ -6,11 +6,12 @@ include "mlir/Pass/PassBase.td"
def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> {
  let summary = "Convert Triton to TritonGPU";
  let description = [{
  }];
  let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()";
  let dependentDialects = ["mlir::arith::ArithmeticDialect",
                           "mlir::math::MathDialect",
                           "mlir::StandardOpsDialect",
                           // TODO: Does this pass depend on SCF?
                           "mlir::scf::SCFDialect",
@@ -33,6 +34,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
  let constructor = "mlir::triton::createConvertTritonGPUToLLVMPass()";
  let dependentDialects = ["mlir::arith::ArithmeticDialect",
                           "mlir::math::MathDialect",
                           "mlir::gpu::GPUDialect",
                           "mlir::scf::SCFDialect",
                           "mlir::LLVM::LLVMDialect",
```

`include/triton/Dialect/Triton/IR/Dialect.h`:

```cpp
@@ -2,6 +2,7 @@
#define TRITON_DIALECT_TRITON_IR_DIALECT_H_

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
```

`include/triton/Dialect/Triton/IR/TritonDialect.td`:

```tablegen
@@ -16,12 +16,15 @@ def Triton_Dialect : Dialect {
    Dependent Dialects:
      * Arithmetic:
        * addf, addi, andi, cmpf, cmpi, divf, fptosi, ...
      * Math:
        * exp, sin, cos, log, ...
      * StructuredControlFlow:
        * ForOp, IfOp, WhileOp, YieldOp, ConditionOp
  }];

  let dependentDialects = [
    "arith::ArithmeticDialect",
    "math::MathDialect",
    "StandardOpsDialect",
    "scf::SCFDialect",
    "gpu::GPUDialect",
```

`include/triton/Dialect/Triton/IR/TritonOps.td`:

```tablegen
@@ -283,6 +283,26 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
  let results = (outs TT_Type:$result);
}

//
// External Function Ops
//

def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOperandsAndResultShape,
                                              SameVariadicOperandSize]> {
  let summary = "ext_elemwise";

  let description = [{
    call an external function $symbol implemented in $libpath/$libname with $args
    return $libpath/$libname:$symbol($args...)
  }];

  let arguments = (ins Variadic<TT_Tensor>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol);

  let results = (outs TT_Tensor:$result);

  let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
}

//
// Intrinsics
//
```
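The traits on `TT_ExtElemwiseOp` do real verification work: in particular, `SameOperandsAndResultShape` lets the verifier reject any call whose result shape differs from its operands. A sketch of an op the verifier would reject (the shapes, symbol, and path are hypothetical):

```mlir
// Invalid: the operand is tensor<1024xf32> but the result is tensor<512xf32>;
// the SameOperandsAndResultShape trait fails this op at verification time.
%bad = tt.ext_elemwise %x {libname = "libdevice",
                           libpath = "/path/to/libdevice.10.bc",
                           symbol = "__nv_sinf"} : tensor<1024xf32> -> tensor<512xf32>
```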