[Triton] Support math and libdevice ops (#91)
This PR adds basic math ops by using `MathDialect` and `libdevice` ops by using `extern_elementwise`. This is needed to compile some tutorial code (e.g., `softmax`).

This PR implements only the interface up to, but not including, PTX (i.e., from the frontend to TritonGPU-MLIR):

- Currently, lowering stops at TritonGPU. The result cannot be lowered to PTX yet.
- No special optimizations (e.g., constant folding) are applied.
  - LLVM 14.x does not define folders for many of the math ops, but 15.x seems to increase its coverage: https://github.com/llvm/llvm-project/blob/llvmorg-15.0.0-rc3/mlir/include/mlir/Dialect/Math/IR/MathOps.td
  - No constant folding etc. for `libdevice` ops.

```py
import triton
import triton.language as tl
import sys

@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    offsets = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offsets)
    x = tl.sin(x)
    output = tl.libdevice.sin(x)
    output = tl.libdevice.fdiv_rn(output, output)
    output = tl.libdevice.fmaf_rd(output, output, output)
    tl.store(y_ptr + offsets, output)


if __name__ == "__main__" and len(sys.argv) >= 2:
    signature = "*fp32,*fp32"
    constants = {'BLOCK_SIZE': 1024}
    output = triton.compile(add_kernel, signature, device=0, constants=constants, output="ttgir")
    print(output)
```

->

```llvm
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func @add_kernel__Pfp32_Pfp32__2c1024(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) {
    %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
    %1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %2 = tt.getelementptr %1, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
    %4 = math.sin %3 : tensor<1024xf32, #blocked>
    %5 = tt.ext_elemwise %4 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_sinf"} : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %6 = tt.ext_elemwise %5, %5 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fdiv_rn"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %7 = tt.ext_elemwise %6, %6, %6 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fmaf_rd"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %8 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %9 = tt.getelementptr %8, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    tt.store %9, %7 : tensor<1024xf32, #blocked>
    return
  }
}
```
This commit is contained in:
@@ -7,13 +7,12 @@
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "llvm/ADT/APSInt.h"
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
namespace {
|
||||
|
||||
template <class Op> class ArithGenericPattern : public OpConversionPattern<Op> {
|
||||
template <class Op> class GenericOpPattern : public OpConversionPattern<Op> {
|
||||
public:
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
|
||||
@@ -23,6 +22,7 @@ public:
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
Op res =
|
||||
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -98,32 +98,41 @@ void populateArithmeticPatternsAndLegality(
|
||||
// Rewrite rule
|
||||
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
||||
patterns.add<
|
||||
ArithConstantPattern, ArithGenericPattern<arith::AddIOp>,
|
||||
ArithGenericPattern<arith::SubIOp>, ArithGenericPattern<arith::MulIOp>,
|
||||
ArithGenericPattern<arith::DivUIOp>, ArithGenericPattern<arith::DivSIOp>,
|
||||
ArithGenericPattern<arith::CeilDivUIOp>,
|
||||
ArithGenericPattern<arith::CeilDivSIOp>,
|
||||
ArithGenericPattern<arith::FloorDivSIOp>,
|
||||
ArithGenericPattern<arith::RemUIOp>, ArithGenericPattern<arith::RemSIOp>,
|
||||
ArithGenericPattern<arith::AndIOp>, ArithGenericPattern<arith::OrIOp>,
|
||||
ArithGenericPattern<arith::XOrIOp>, ArithGenericPattern<arith::ShLIOp>,
|
||||
ArithGenericPattern<arith::ShRUIOp>,
|
||||
ArithGenericPattern<arith::ShRSIOp>, // NegFOp
|
||||
ArithConstantPattern, GenericOpPattern<arith::AddIOp>,
|
||||
GenericOpPattern<arith::SubIOp>, GenericOpPattern<arith::MulIOp>,
|
||||
GenericOpPattern<arith::DivUIOp>, GenericOpPattern<arith::DivSIOp>,
|
||||
GenericOpPattern<arith::CeilDivUIOp>,
|
||||
GenericOpPattern<arith::CeilDivSIOp>,
|
||||
GenericOpPattern<arith::FloorDivSIOp>, GenericOpPattern<arith::RemUIOp>,
|
||||
GenericOpPattern<arith::RemSIOp>, GenericOpPattern<arith::AndIOp>,
|
||||
GenericOpPattern<arith::OrIOp>, GenericOpPattern<arith::XOrIOp>,
|
||||
GenericOpPattern<arith::ShLIOp>, GenericOpPattern<arith::ShRUIOp>,
|
||||
GenericOpPattern<arith::ShRSIOp>, // NegFOp
|
||||
// Floating point
|
||||
ArithGenericPattern<arith::AddFOp>, ArithGenericPattern<arith::SubFOp>,
|
||||
GenericOpPattern<arith::AddFOp>, GenericOpPattern<arith::SubFOp>,
|
||||
// MaxMin
|
||||
ArithGenericPattern<arith::MaxFOp>, ArithGenericPattern<arith::MaxSIOp>,
|
||||
ArithGenericPattern<arith::MaxUIOp>, ArithGenericPattern<arith::MinFOp>,
|
||||
ArithGenericPattern<arith::MinSIOp>, ArithGenericPattern<arith::MinUIOp>,
|
||||
GenericOpPattern<arith::MaxFOp>, GenericOpPattern<arith::MaxSIOp>,
|
||||
GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinFOp>,
|
||||
GenericOpPattern<arith::MinSIOp>, GenericOpPattern<arith::MinUIOp>,
|
||||
// Floating point
|
||||
ArithGenericPattern<arith::MulFOp>, ArithGenericPattern<arith::DivFOp>,
|
||||
ArithGenericPattern<arith::RemFOp>,
|
||||
GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,
|
||||
GenericOpPattern<arith::RemFOp>,
|
||||
// Cmp
|
||||
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
|
||||
// Cast Ops
|
||||
ArithGenericPattern<arith::TruncIOp>,
|
||||
ArithGenericPattern<arith::TruncFOp>>(typeConverter, context);
|
||||
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>>(
|
||||
typeConverter, context);
|
||||
}
|
||||
|
||||
// Registers the rewrite patterns for the Math dialect ops handled by this
// conversion: exp, cos, sin, log and sqrt. Every op is routed through
// GenericOpPattern, constructed with `typeConverter` and the context owned by
// `patterns`.
// NOTE(review): `target` is not referenced in this function, so no legality
// rules are set up here despite the name -- confirm that this is intended.
void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
                                     RewritePatternSet &patterns,
                                     TritonGPUConversionTarget &target) {
  MLIRContext *ctx = patterns.getContext();
  // Rewrite rules, one math op per line.
  patterns.add<GenericOpPattern<math::ExpOp>,
               GenericOpPattern<math::CosOp>,
               GenericOpPattern<math::SinOp>,
               GenericOpPattern<math::LogOp>,
               GenericOpPattern<math::SqrtOp>>(typeConverter, ctx);
}
|
||||
|
||||
//
|
||||
@@ -246,6 +255,20 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonExtElemwisePattern
|
||||
: public OpConversionPattern<triton::ExtElemwiseOp> {
|
||||
using OpConversionPattern<triton::ExtElemwiseOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ExtElemwiseOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<triton::ExtElemwiseOp>(
|
||||
op, typeConverter->convertType(op.getType()), adaptor.args(),
|
||||
adaptor.libname(), adaptor.libpath(), adaptor.symbol());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <class Op>
|
||||
struct TritonGenericPattern : public OpConversionPattern<Op> {
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
@@ -302,7 +325,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
||||
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
|
||||
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
|
||||
TritonLoadPattern, TritonStorePattern>(typeConverter, context);
|
||||
TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern>(
|
||||
typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
@@ -389,6 +413,7 @@ public:
|
||||
RewritePatternSet patterns(context);
|
||||
// add rules
|
||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateMathPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateTritonPatterns(typeConverter, patterns);
|
||||
// TODO: can we use
|
||||
// mlir::scf::populateSCFStructuralTypeConversionsAndLegality(...) here?
|
||||
|
Reference in New Issue
Block a user