#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"

#include "../PassDetail.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "llvm/ADT/APSInt.h"

#include <numeric>

using namespace mlir;
using namespace mlir::triton;

namespace {

template <class Op> class GenericOpPattern : public OpConversionPattern<Op> {
public:
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
    return success();
  }
};
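
// Illustrative note (a sketch based on the default 4-warp layout, not part of
// the original source): these patterns only re-create an op with its result
// type converted to carry an encoding, e.g.
//   %4 = math.sin %3 : tensor<1024xf32>
// becomes
//   %4 = math.sin %3 : tensor<1024xf32, #blocked>
// with #blocked = #triton_gpu.blocked<{sizePerThread = [1],
// threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>.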

template <class SrcOp, class DstOp>
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
                                       adaptor.getLhs(), adaptor.getRhs());
    return success();
  }
};
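
// Reading of the pattern above (sketch): comparisons also change dialect,
// i.e. arith.cmpi / arith.cmpf are rebuilt as triton_gpu.cmpi /
// triton_gpu.cmpf with the same predicate but encoded operand/result types.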

class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
public:
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = getTypeConverter()->convertType(op.getType());
    auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
    assert(value);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, retType,
        value.reshape(retType) // This is a hack. We just want to add encoding
    );
    return success();
  }
};

class ConvertArithmeticOp : public ConversionPattern {
public:
  ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter,
                      MLIRContext *context)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          context) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Dialect *dialect = op->getDialect();
    if (dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
      return failure();
    return success();
  }
};

void populateArithmeticPatternsAndLegality(
    TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns,
    TritonGPUConversionTarget &target) {
  // --------------
  // Add legality and rewrite pattern rules for operations
  // from the Arithmetic dialect. The basic premise is that
  // arithmetic operations require both inputs to have the same
  // non-null encoding
  // --------------
  MLIRContext *context = patterns.getContext();
  // TODO: there's probably a better way to avoid adding all ops one-by-one
  patterns.add<
      ArithConstantPattern, GenericOpPattern<arith::AddIOp>,
      GenericOpPattern<arith::SubIOp>, GenericOpPattern<arith::MulIOp>,
      GenericOpPattern<arith::DivUIOp>, GenericOpPattern<arith::DivSIOp>,
      GenericOpPattern<arith::CeilDivUIOp>,
      GenericOpPattern<arith::CeilDivSIOp>,
      GenericOpPattern<arith::FloorDivSIOp>, GenericOpPattern<arith::RemUIOp>,
      GenericOpPattern<arith::RemSIOp>, GenericOpPattern<arith::AndIOp>,
      GenericOpPattern<arith::OrIOp>, GenericOpPattern<arith::XOrIOp>,
      GenericOpPattern<arith::ShLIOp>, GenericOpPattern<arith::ShRUIOp>,
      GenericOpPattern<arith::ShRSIOp>, // NegFOp
      // Floating point
      GenericOpPattern<arith::AddFOp>, GenericOpPattern<arith::SubFOp>,
      // MaxMin
      GenericOpPattern<arith::MaxFOp>, GenericOpPattern<arith::MaxSIOp>,
      GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinFOp>,
      GenericOpPattern<arith::MinSIOp>, GenericOpPattern<arith::MinUIOp>,
      // Floating point
      GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,
      GenericOpPattern<arith::RemFOp>,
      // Cmp
      ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
      ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
      // Cast Ops
      GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
      GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>,
      GenericOpPattern<arith::ExtFOp>, GenericOpPattern<arith::SIToFPOp>,
      GenericOpPattern<arith::FPToSIOp>, GenericOpPattern<arith::FPToUIOp>,
      GenericOpPattern<arith::UIToFPOp>>(typeConverter, context);
}

// this shouldn't exist if mlir's SelectOp checked encodings properly
class StdSelectPattern : public OpConversionPattern<SelectOp> {
public:
  using OpConversionPattern<SelectOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
        op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
        adaptor.getFalseValue());
    return success();
  }
};

void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
                                    RewritePatternSet &patterns,
                                    TritonGPUConversionTarget &target) {
  MLIRContext *context = patterns.getContext();
  // Rewrite rule
  patterns.add<StdSelectPattern>(typeConverter, context);
  target.addLegalOp<ReturnOp>(); // this is ok because all functions are
                                 // inlined by the frontend
}

void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
                                     RewritePatternSet &patterns,
                                     TritonGPUConversionTarget &target) {
  MLIRContext *context = patterns.getContext();
  // Rewrite rule
  patterns.add<GenericOpPattern<math::ExpOp>, GenericOpPattern<math::CosOp>,
               GenericOpPattern<math::SinOp>, GenericOpPattern<math::LogOp>,
               GenericOpPattern<math::SqrtOp>>(typeConverter, context);
}

//
// Triton patterns
//
// TODO: Do we need to put them in anonymous namespace?
struct TritonMakeRangePattern
    : public OpConversionPattern<triton::MakeRangeOp> {
  using OpConversionPattern<triton::MakeRangeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
        op, retType, adaptor.start(), adaptor.end());
    return success();
  }
};

struct TritonExpandDimsPattern
    : public OpConversionPattern<triton::ExpandDimsOp> {
  using OpConversionPattern<triton::ExpandDimsOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Type retType = op.getType());
    RankedTensorType argType = adaptor.src().getType().cast<RankedTensorType>();
    Attribute _argEncoding = argType.getEncoding();
    if (!_argEncoding)
      return failure();
    auto argEncoding = _argEncoding.cast<triton::gpu::BlockedEncodingAttr>();
    // return shape
    auto retShape = argType.getShape().vec();
    retShape.insert(retShape.begin() + op.axis(), 1);
    // return encoding
    auto retSizePerThread = argEncoding.getSizePerThread().vec();
    retSizePerThread.insert(retSizePerThread.begin() + op.axis(), 1);
    auto retThreadsPerWarp = argEncoding.getThreadsPerWarp().vec();
    retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.axis(), 1);
    auto retWarpsPerCTA = argEncoding.getWarpsPerCTA().vec();
    retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.axis(), 1);
    SmallVector<unsigned, 4> retOrder(retShape.size());
    std::iota(retOrder.begin(), retOrder.end(), 0);
    triton::gpu::BlockedEncodingAttr retEncoding =
        triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread,
                                              retThreadsPerWarp, retWarpsPerCTA,
                                              retOrder);
    // convert operand to slice of return type
    Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
        getContext(), op.axis(), retEncoding);
    RankedTensorType newArgType = RankedTensorType::get(
        argType.getShape(), argType.getElementType(), newArgEncoding);
    // construct new op
    auto newSrc = rewriter.create<triton::gpu::ConvertLayoutOp>(
        op.getLoc(), newArgType, adaptor.src());
    rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(op, newSrc,
                                                      adaptor.axis());
    return success();
  }
};
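
// Illustrative sketch (hypothetical shapes, not from the original source):
// expanding axis 0 of tensor<32xf32, #blocked> yields tensor<1x32xf32> whose
// blocked encoding gains a 1 at position 0 in sizePerThread, threadsPerWarp
// and warpsPerCTA, while the operand is first converted to the matching
// #triton_gpu.slice<{dim = 0, parent = <result encoding>}> layout.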

struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
  using OpConversionPattern<triton::DotOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = getTypeConverter()->convertType(op.getType());
    Attribute dEncoding = retType.cast<RankedTensorType>().getEncoding();
    // a & b must be of smem layout
    auto aType = adaptor.a().getType().cast<RankedTensorType>();
    auto bType = adaptor.b().getType().cast<RankedTensorType>();
    Attribute aEncoding = aType.getEncoding();
    Attribute bEncoding = bType.getEncoding();
    if (!aEncoding || !bEncoding)
      return failure();
    Value a = adaptor.a();
    Value b = adaptor.b();
    if (!aEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
      Attribute encoding =
          triton::gpu::DotOperandEncodingAttr::get(getContext(), 0, dEncoding);
      auto dstType = RankedTensorType::get(aType.getShape(),
                                           aType.getElementType(), encoding);
      a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
    }
    if (!bEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
      Attribute encoding =
          triton::gpu::DotOperandEncodingAttr::get(getContext(), 1, dEncoding);
      auto dstType = RankedTensorType::get(bType.getShape(),
                                           bType.getElementType(), encoding);
      b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
    }
    rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, adaptor.c(),
                                               adaptor.allowTF32());
    return success();
  }
};
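
// Reading of the pattern above (sketch): both tt.dot operands end up in a
// #triton_gpu.dot_operand<{opIdx = 0 or 1, parent = <d's encoding>}> layout;
// a triton_gpu.convert_layout is inserted for any operand not already there.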

struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
  using OpConversionPattern<triton::CatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::CatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // For now, this behaves like generic, but this will evolve when
    // we add support for `can_reorder=False`
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType,
                                               adaptor.getOperands());
    return success();
  }
};

struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
  using OpConversionPattern<triton::TransOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.src();
    auto srcType = src.getType().cast<RankedTensorType>();
    Attribute srcEncoding = srcType.getEncoding();
    if (!srcEncoding)
      return failure();
    if (!srcEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
      // TODO: end-to-end correctness is broken if
      // the input is blocked and the output is shared
      // with different order. Maybe a backend issue in BlockedToShared?
      SmallVector<unsigned> order = {1, 0};
      if (auto srcBlockedEncoding =
              srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>())
        llvm::copy(srcBlockedEncoding.getOrder(), order.begin());
      srcEncoding =
          triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
      srcType = RankedTensorType::get(srcType.getShape(),
                                      srcType.getElementType(), srcEncoding);
      src = rewriter.create<triton::gpu::ConvertLayoutOp>(src.getLoc(), srcType,
                                                          src);
    }
    auto srcSharedEncoding =
        srcEncoding.cast<triton::gpu::SharedEncodingAttr>();
    SmallVector<unsigned> retOrder(srcSharedEncoding.getOrder().begin(),
                                   srcSharedEncoding.getOrder().end());
    SmallVector<int64_t> retShapes(srcType.getShape().begin(),
                                   srcType.getShape().end());
    std::reverse(retOrder.begin(), retOrder.end());
    std::reverse(retShapes.begin(), retShapes.end());
    auto retEncoding =
        triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, retOrder);
    auto retType =
        RankedTensorType::get(retShapes, srcType.getElementType(), retEncoding);

    rewriter.replaceOpWithNewOp<triton::TransOp>(op, retType, src);
    return success();
  }
};
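
// Sketch of the effect (hypothetical shapes): a tensor<64x32xf32, #shared>
// with order = [1, 0] is transposed to tensor<32x64xf32, #shared> with
// order = [0, 1]; only the shape and the order vector are reversed.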

struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
  using OpConversionPattern<triton::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::LoadOp>(
        op, typeConverter->convertType(op.getType()), adaptor.ptr(),
        adaptor.mask(), adaptor.other(), adaptor.cache(), adaptor.evict(),
        adaptor.isVolatile());
    return success();
  }
};

struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
  using OpConversionPattern<triton::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::StoreOp>(
        op, adaptor.ptr(), adaptor.value(), adaptor.mask());
    return success();
  }
};

struct TritonAtomicCASPattern
    : public OpConversionPattern<triton::AtomicCASOp> {
  using OpConversionPattern<triton::AtomicCASOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
        op, typeConverter->convertType(op.getType()), adaptor.ptr(),
        adaptor.cmp(), adaptor.val());
    return success();
  }
};

struct TritonAtomicRMWPattern
    : public OpConversionPattern<triton::AtomicRMWOp> {
  using OpConversionPattern<triton::AtomicRMWOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::AtomicRMWOp>(
        op, typeConverter->convertType(op.getType()), adaptor.atomic_rmw_op(),
        adaptor.ptr(), adaptor.val(), adaptor.mask());
    return success();
  }
};

struct TritonExtElemwisePattern
    : public OpConversionPattern<triton::ExtElemwiseOp> {
  using OpConversionPattern<triton::ExtElemwiseOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::ExtElemwiseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::ExtElemwiseOp>(
        op, typeConverter->convertType(op.getType()), adaptor.args(),
        adaptor.libname(), adaptor.libpath(), adaptor.symbol());
    return success();
  }
};

template <class Op>
struct TritonGenericPattern : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
    return success();
  }
};

struct TritonBroadcastPattern
    : public OpConversionPattern<triton::BroadcastOp> {
  using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;

  // This creates a tensor with the new shape but the argument's layout
  LogicalResult
  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = adaptor.src().getType().cast<RankedTensorType>();
    auto srcEncoding = srcType.getEncoding();
    if (!srcEncoding)
      return failure();
    auto opType = op.getType().cast<RankedTensorType>();
    Type retType = RankedTensorType::get(opType.getShape(),
                                         opType.getElementType(), srcEncoding);
    // Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<triton::BroadcastOp>(op, retType,
                                                     adaptor.getOperands());
    return success();
  }
};
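
// Illustrative reading (hypothetical shapes): tt.broadcast of
// tensor<1x64xf32, #enc> to shape 32x64 produces tensor<32x64xf32, #enc>,
// i.e. the result keeps the operand's encoding rather than the one the
// type converter would pick for the result type.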

struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
  using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::ReduceOp>(
        op, adaptor.redOp(), adaptor.operand(), adaptor.axis());
    return success();
  }
};

struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
  using OpConversionPattern<PrintfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, op.prefixAttr(),
                                                  adaptor.getOperands());
    return success();
  }
};

void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
                            RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add< // TODO: view should have custom pattern that views the layout
      TritonGenericPattern<triton::ViewOp>,
      TritonGenericPattern<triton::BitcastOp>,
      TritonGenericPattern<triton::FpToFpOp>,
      TritonGenericPattern<triton::IntToPtrOp>,
      TritonGenericPattern<triton::PtrToIntOp>,
      TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
      TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
      TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
      TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
      TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
      TritonAtomicRMWPattern>(typeConverter, context);
}

//
// SCF patterns
//
// This is borrowed from ConvertForOpTypes in
// SCF/Transforms/StructuralTypeConversions.cpp
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
  using OpConversionPattern<scf::ForOp>::OpConversionPattern;
  // Ref: ConvertForOpTypes
  LogicalResult
  matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto newOp =
        cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
    rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
                                newOp.getLoopBody().end());

    // Now, update all the types.

    // Convert the types of block arguments within the given region. This
    // replaces each block with a new block containing the updated signature.
    // The entry block may have a special conversion if `entryConversion` is
    // provided. On success, the new entry block to the region is returned for
    // convenience. Otherwise, failure is returned.
    if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
                                           *getTypeConverter()))) {
      return rewriter.notifyMatchFailure(op, "could not convert body types");
    }
    // Change the clone to use the updated operands. We could have cloned with
    // a BlockAndValueMapping, but this seems a bit more direct.
    newOp->setOperands(adaptor.getOperands());
    // Update the result types to the new converted types.
    SmallVector<Type> newResultTypes;
    for (Type type : op.getResultTypes()) {
      Type newType = typeConverter->convertType(type);
      if (!newType)
        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
      newResultTypes.push_back(newType);
    }
    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
      std::get<0>(t).setType(std::get<1>(t));

    rewriter.replaceOp(op, newOp.getResults());

    return success();
  }
};
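
// Sketch of what this achieves (hypothetical IR): an scf.for whose iter_args
// carry tensor<256xf32> values is rebuilt so that the init operands, block
// arguments, yielded values and results all carry tensor<256xf32, #blocked>
// instead, keeping the loop structurally intact.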

struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
  using OpConversionPattern<scf::YieldOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
    // rewriter.create<scf::YieldOp>(op.getLoc(), adaptor.getOperands());
    // op.erase();
    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
    return success();
  }
};

// This is borrowed from ConvertIfOpTypes in
// SCF/Transforms/StructuralTypeConversions.cpp
class SCFIfPattern : public OpConversionPattern<scf::IfOp> {
public:
  using OpConversionPattern<scf::IfOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(scf::IfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // TODO: Generalize this to any type conversion, not just 1:1.
    //
    // We need to implement something more sophisticated here that tracks which
    // types convert to which other types and does the appropriate
    // materialization logic.
    // For example, it's possible that one result type converts to 0 types and
    // another to 2 types, so newResultTypes would at least be the right size
    // to not crash in the llvm::zip call below, but then we would set the
    // wrong type on the SSA values! These edge cases are also why we cannot
    // safely use the TypeConverter::convertTypes helper here.
    SmallVector<Type> newResultTypes;
    for (auto type : op.getResultTypes()) {
      Type newType = typeConverter->convertType(type);
      if (!newType)
        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
      newResultTypes.push_back(newType);
    }

    // See comments in the ForOp pattern for why we clone without regions and
    // then inline.
    scf::IfOp newOp =
        cast<scf::IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
    rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
                                newOp.getThenRegion().end());
    rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
                                newOp.getElseRegion().end());

    // Update the operands and types.
    newOp->setOperands(adaptor.getOperands());
    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
      std::get<0>(t).setType(std::get<1>(t));
    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};

void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
                         RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<SCFYieldPattern, SCFForPattern, SCFIfPattern>(typeConverter,
                                                             context);
}

class ConvertTritonToTritonGPU
    : public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
public:
  ConvertTritonToTritonGPU() = default;
  // constructor with some parameters set explicitly.
  ConvertTritonToTritonGPU(int numWarps) { this->numWarps = numWarps; }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp mod = getOperation();
    // type converter
    TritonGPUTypeConverter typeConverter(context, numWarps);
    TritonGPUConversionTarget target(*context, typeConverter);
    // rewrite patterns
    RewritePatternSet patterns(context);
    // add rules
    populateStdPatternsAndLegality(typeConverter, patterns, target);
    populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
    populateMathPatternsAndLegality(typeConverter, patterns, target);
    populateTritonPatterns(typeConverter, patterns);
    // TODO: can we use
    // mlir::scf::populateSCFStructuralTypeConversionsAndLegality(...) here?
    populateSCFPatterns(typeConverter, patterns);

    if (failed(applyPartialConversion(mod, target, std::move(patterns))))
      return signalPassFailure();

    auto i32_ty = IntegerType::get(mod->getContext(), 32);

    mod->setAttr(
        AttrNumWarpsName,
        IntegerAttr::get(i32_ty, llvm::APInt(32, numWarps.getValue())));

    // update layouts
    //   broadcast src => multicast, dst => broadcasted
    // if (failed(target.refineLayouts(mod, numWarps)))
    //   return signalPassFailure();
  }
};

} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {
  return std::make_unique<::ConvertTritonToTritonGPU>(numWarps);
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::triton::createConvertTritonToTritonGPUPass() {
  return std::make_unique<::ConvertTritonToTritonGPU>();
}
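
// Usage sketch (assumes a standard mlir::PassManager setup; not part of this
// file):
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::triton::createConvertTritonToTritonGPUPass(/*numWarps=*/4));
//   if (failed(pm.run(module)))
//     /* handle pass failure */;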