[Triton] Support math and libdevice ops (#91)
This PR adds basic math ops via `MathDialect` and `libdevice` ops via `extern_elementwise`. This is needed to compile some tutorial code (e.g., `softmax`). The PR implements only the interface from the frontend down to TritonGPU-MLIR; lowering to PTX is not covered:

- The pipeline currently stops at TritonGPU, so these ops cannot be lowered to PTX yet.
- No special optimizations (e.g., constant folding) are applied.
  - LLVM 14.x does not define folders for many math ops, but 15.x seems to increase the coverage: https://github.com/llvm/llvm-project/blob/llvmorg-15.0.0-rc3/mlir/include/mlir/Dialect/Math/IR/MathOps.td
  - There is no constant folding etc. for `libdevice` ops either.

For example, this kernel:

```py
import triton
import triton.language as tl
import sys


@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    offsets = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offsets)
    x = tl.sin(x)
    output = tl.libdevice.sin(x)
    output = tl.libdevice.fdiv_rn(output, output)
    output = tl.libdevice.fmaf_rd(output, output, output)
    tl.store(y_ptr + offsets, output)


if __name__ == "__main__" and len(sys.argv) >= 2:
    signature = "*fp32,*fp32"
    constants = {'BLOCK_SIZE': 1024}
    output = triton.compile(add_kernel, signature, device=0, constants=constants, output="ttgir")
    print(output)
```

compiles to the following TTGIR:

```llvm
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  func @add_kernel__Pfp32_Pfp32__2c1024(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) {
    %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
    %1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %2 = tt.getelementptr %1, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
    %4 = math.sin %3 : tensor<1024xf32, #blocked>
    %5 = tt.ext_elemwise %4 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_sinf"} : tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %6 = tt.ext_elemwise %5, %5 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fdiv_rn"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %7 = tt.ext_elemwise %6, %6, %6 {libname = "libdevice", libpath = "/home/siwasaki/triton/python/triton/language/libdevice.10.bc", symbol = "__nv_fmaf_rd"} : tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xf32, #blocked> -> tensor<1024xf32, #blocked>
    %8 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked>
    %9 = tt.getelementptr %8, %0 : tensor<1024x!tt.ptr<f32>, #blocked>
    tt.store %9, %7 : tensor<1024xf32, #blocked>
    return
  }
}
```
```diff
@@ -196,6 +196,7 @@ target_link_libraries(triton
     MLIRSupport
     MLIRTargetLLVMIRExport
     MLIRExecutionEngine
+    MLIRMathToLLVM
     MLIRNVVMToLLVMIRTranslation
     MLIRIR
 )
```
```diff
@@ -52,6 +52,7 @@ target_link_libraries(triton-translate PRIVATE
     MLIRSupport
     MLIRTransforms
     MLIRExecutionEngine
+    MLIRMathToLLVM
     MLIRTransformUtils
     MLIRLLVMToLLVMIRTranslation
     MLIRNVVMToLLVMIRTranslation
```
```diff
@@ -32,8 +32,8 @@ int main(int argc, char **argv) {
 
   // TODO: register Triton & TritonGPU passes
   mlir::DialectRegistry registry;
-  registry
-      .insert<mlir::triton::TritonDialect, mlir::triton::gpu::TritonGPUDialect,
+  registry.insert<mlir::triton::TritonDialect,
+                  mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
                   mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect,
                   mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();
 
```
```diff
@@ -36,9 +36,9 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
   }
 
   mlir::DialectRegistry registry;
-  registry
-      .insert<TritonDialect, triton::gpu::TritonGPUDialect,
-              arith::ArithmeticDialect, StandardOpsDialect, scf::SCFDialect>();
+  registry.insert<TritonDialect, triton::gpu::TritonGPUDialect,
+                  mlir::math::MathDialect, arith::ArithmeticDialect,
+                  StandardOpsDialect, scf::SCFDialect>();
 
   context.appendDialectRegistry(registry);
 
```
```diff
@@ -11,6 +11,7 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp">
     let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()";
 
     let dependentDialects = ["mlir::arith::ArithmeticDialect",
+                             "mlir::math::MathDialect",
                              "mlir::StandardOpsDialect",
                              // TODO: Does this pass depend on SCF?
                              "mlir::scf::SCFDialect",
```
```diff
@@ -33,6 +34,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp">
     let constructor = "mlir::triton::createConvertTritonGPUToLLVMPass()";
 
     let dependentDialects = ["mlir::arith::ArithmeticDialect",
+                             "mlir::math::MathDialect",
                              "mlir::gpu::GPUDialect",
                              "mlir::scf::SCFDialect",
                              "mlir::LLVM::LLVMDialect",
```
```diff
@@ -2,6 +2,7 @@
 #define TRITON_DIALECT_TRITON_IR_DIALECT_H_
 
 #include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BuiltinOps.h"
```
```diff
@@ -16,12 +16,15 @@ def Triton_Dialect : Dialect {
     Dependent Dialects:
       * Arithmetic:
         * addf, addi, andi, cmpf, cmpi, divf, fptosi, ...
+      * Math:
+        * exp, sin, cos, log, ...
       * StructuredControlFlow:
         * ForOp, IfOp, WhileOp, YieldOp, ConditionOp
   }];
 
   let dependentDialects = [
     "arith::ArithmeticDialect",
+    "math::MathDialect",
     "StandardOpsDialect",
     "scf::SCFDialect",
     "gpu::GPUDialect",
```
```diff
@@ -283,6 +283,26 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
   let results = (outs TT_Type:$result);
 }
 
+//
+// External Function Ops
+//
+def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [NoSideEffect, Elementwise, SameOperandsAndResultShape,
+                                              SameVariadicOperandSize]> {
+  let summary = "ext_elemwise";
+
+  let description = [{
+    call an external function $symbol implemented in $libpath/$libname with $args
+
+    return $libpath/$libname:$symbol($args...)
+  }];
+
+  let arguments = (ins Variadic<TT_Tensor>:$args, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol);
+
+  let results = (outs TT_Tensor:$result);
+
+  let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
+}
+
 //
 // Intrinsics
 //
```
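Given the `assemblyFormat` above, a `tt.ext_elemwise` call prints in the form sketched below (a hand-written example; the tensor types and libdevice path are placeholders — compare the `%5`–`%7` lines in the TTGIR dump at the top of this PR):

```llvm
%y = tt.ext_elemwise %x, %x {libname = "libdevice", libpath = "/path/to/libdevice.10.bc", symbol = "__nv_fdiv_rn"} : tensor<1024xf32>, tensor<1024xf32> -> tensor<1024xf32>
```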
```diff
@@ -4,6 +4,7 @@
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
```
```diff
@@ -1328,9 +1329,10 @@ public:
     populateTritonToLLVMPatterns(typeConverter, patterns, numWarps,
                                  *axisAnalysis, 10 /*benefit*/);
 
-    // Add arith's patterns to help convert scalar expression to LLVM.
+    // Add arith/math's patterns to help convert scalar expression to LLVM.
     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                             patterns);
+    mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
 
     mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
 
```
```diff
@@ -7,13 +7,12 @@
 #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
 #include "llvm/ADT/APSInt.h"
 #include <numeric>
 
 using namespace mlir;
 using namespace mlir::triton;
 
 namespace {
-
-template <class Op> class ArithGenericPattern : public OpConversionPattern<Op> {
+template <class Op> class GenericOpPattern : public OpConversionPattern<Op> {
 public:
   using OpConversionPattern<Op>::OpConversionPattern;
 
```
```diff
@@ -23,6 +22,7 @@ public:
     Type retType = this->getTypeConverter()->convertType(op.getType());
     Op res =
         rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
+
     return success();
   }
 };
```
```diff
@@ -98,32 +98,41 @@ void populateArithmeticPatternsAndLegality(
   // Rewrite rule
   // patterns.add<ConvertArithmeticOp>(typeConverter, context);
   patterns.add<
-      ArithConstantPattern, ArithGenericPattern<arith::AddIOp>,
-      ArithGenericPattern<arith::SubIOp>, ArithGenericPattern<arith::MulIOp>,
-      ArithGenericPattern<arith::DivUIOp>, ArithGenericPattern<arith::DivSIOp>,
-      ArithGenericPattern<arith::CeilDivUIOp>,
-      ArithGenericPattern<arith::CeilDivSIOp>,
-      ArithGenericPattern<arith::FloorDivSIOp>,
-      ArithGenericPattern<arith::RemUIOp>, ArithGenericPattern<arith::RemSIOp>,
-      ArithGenericPattern<arith::AndIOp>, ArithGenericPattern<arith::OrIOp>,
-      ArithGenericPattern<arith::XOrIOp>, ArithGenericPattern<arith::ShLIOp>,
-      ArithGenericPattern<arith::ShRUIOp>,
-      ArithGenericPattern<arith::ShRSIOp>, // NegFOp
+      ArithConstantPattern, GenericOpPattern<arith::AddIOp>,
+      GenericOpPattern<arith::SubIOp>, GenericOpPattern<arith::MulIOp>,
+      GenericOpPattern<arith::DivUIOp>, GenericOpPattern<arith::DivSIOp>,
+      GenericOpPattern<arith::CeilDivUIOp>,
+      GenericOpPattern<arith::CeilDivSIOp>,
+      GenericOpPattern<arith::FloorDivSIOp>, GenericOpPattern<arith::RemUIOp>,
+      GenericOpPattern<arith::RemSIOp>, GenericOpPattern<arith::AndIOp>,
+      GenericOpPattern<arith::OrIOp>, GenericOpPattern<arith::XOrIOp>,
+      GenericOpPattern<arith::ShLIOp>, GenericOpPattern<arith::ShRUIOp>,
+      GenericOpPattern<arith::ShRSIOp>, // NegFOp
       // Floating point
-      ArithGenericPattern<arith::AddFOp>, ArithGenericPattern<arith::SubFOp>,
+      GenericOpPattern<arith::AddFOp>, GenericOpPattern<arith::SubFOp>,
       // MaxMin
-      ArithGenericPattern<arith::MaxFOp>, ArithGenericPattern<arith::MaxSIOp>,
-      ArithGenericPattern<arith::MaxUIOp>, ArithGenericPattern<arith::MinFOp>,
-      ArithGenericPattern<arith::MinSIOp>, ArithGenericPattern<arith::MinUIOp>,
+      GenericOpPattern<arith::MaxFOp>, GenericOpPattern<arith::MaxSIOp>,
+      GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinFOp>,
+      GenericOpPattern<arith::MinSIOp>, GenericOpPattern<arith::MinUIOp>,
       // Floating point
-      ArithGenericPattern<arith::MulFOp>, ArithGenericPattern<arith::DivFOp>,
-      ArithGenericPattern<arith::RemFOp>,
+      GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,
+      GenericOpPattern<arith::RemFOp>,
       // Cmp
       ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
       ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
       // Cast Ops
-      ArithGenericPattern<arith::TruncIOp>,
-      ArithGenericPattern<arith::TruncFOp>>(typeConverter, context);
+      GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>>(
+      typeConverter, context);
+}
+
+void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
+                                     RewritePatternSet &patterns,
+                                     TritonGPUConversionTarget &target) {
+  MLIRContext *context = patterns.getContext();
+  // Rewrite rule
+  patterns.add<GenericOpPattern<math::ExpOp>, GenericOpPattern<math::CosOp>,
+               GenericOpPattern<math::SinOp>, GenericOpPattern<math::LogOp>,
+               GenericOpPattern<math::SqrtOp>>(typeConverter, context);
 }
 
 //
```
```diff
@@ -246,6 +255,20 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
   }
 };
 
+struct TritonExtElemwisePattern
+    : public OpConversionPattern<triton::ExtElemwiseOp> {
+  using OpConversionPattern<triton::ExtElemwiseOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::ExtElemwiseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<triton::ExtElemwiseOp>(
+        op, typeConverter->convertType(op.getType()), adaptor.args(),
+        adaptor.libname(), adaptor.libpath(), adaptor.symbol());
+    return success();
+  }
+};
+
 template <class Op>
 struct TritonGenericPattern : public OpConversionPattern<Op> {
   using OpConversionPattern<Op>::OpConversionPattern;
```
```diff
@@ -302,7 +325,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
       TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
       TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
-      TritonLoadPattern, TritonStorePattern>(typeConverter, context);
+      TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern>(
+      typeConverter, context);
 }
 
 //
```
```diff
@@ -389,6 +413,7 @@ public:
     RewritePatternSet patterns(context);
     // add rules
     populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
+    populateMathPatternsAndLegality(typeConverter, patterns, target);
     populateTritonPatterns(typeConverter, patterns);
     // TODO: can we use
     // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
```
```diff
@@ -81,9 +81,9 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
   addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
                scf::ReduceReturnOp>();
 
-  addDynamicallyLegalDialect<arith::ArithmeticDialect, triton::TritonDialect,
-                             StandardOpsDialect, scf::SCFDialect>(
-      [&](Operation *op) {
+  addDynamicallyLegalDialect<arith::ArithmeticDialect, math::MathDialect,
+                             triton::TritonDialect, StandardOpsDialect,
+                             scf::SCFDialect>([&](Operation *op) {
     if (typeConverter.isLegal(op))
       return true;
     return false;
```
```diff
@@ -1542,7 +1542,16 @@ void init_triton_ir(py::module &&m) {
              return self.create<mlir::triton::AtomicRMWOp>(loc, dstType, rmwOp,
                                                            ptr, val, mask);
            })
+      // External
+      .def("create_external_elementwise",
+           [](mlir::OpBuilder &self, const std::string &libName,
+              const std::string &libPath, const std::string &symbol,
+              std::vector<mlir::Value> &argList,
+              mlir::Type retType) -> mlir::Value {
+             auto loc = self.getUnknownLoc();
+             return self.create<mlir::triton::ExtElemwiseOp>(
+                 loc, retType, argList, libName, libPath, symbol);
+           })
       // Built-in instruction
       .def("create_get_program_id",
            [](mlir::OpBuilder &self, int axis) -> mlir::Value {
```
```diff
@@ -1563,12 +1572,32 @@ void init_triton_ir(py::module &&m) {
              return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c,
                                                      allowTF32);
            })
-      // .def("create_exp", &ir::builder::create_exp, ret::reference)
-      // .def("create_cos", &ir::builder::create_cos, ret::reference)
-      // .def("create_sin", &ir::builder::create_sin, ret::reference)
-      // .def("create_log", &ir::builder::create_log, ret::reference)
+      .def("create_exp",
+           [](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
+             auto loc = self.getUnknownLoc();
+             return self.create<mlir::math::ExpOp>(loc, val);
+           })
+      .def("create_cos",
+           [](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
+             auto loc = self.getUnknownLoc();
+             return self.create<mlir::math::CosOp>(loc, val);
+           })
+      .def("create_sin",
+           [](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
+             auto loc = self.getUnknownLoc();
+             return self.create<mlir::math::SinOp>(loc, val);
+           })
+      .def("create_log",
+           [](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
+             auto loc = self.getUnknownLoc();
+             return self.create<mlir::math::LogOp>(loc, val);
+           })
+      .def("create_sqrt",
+           [](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {
+             auto loc = self.getUnknownLoc();
+             return self.create<mlir::math::SqrtOp>(loc, val);
+           })
       // .def("create_trans", &ir::builder::create_trans, ret::reference)
-      // .def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
       .def("create_reduce",
            [](mlir::OpBuilder &self, mlir::Value &operand,
               mlir::triton::RedOp redOp, int axis) -> mlir::Value {
```
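For context, a hedged sketch of how a frontend builtin can route to one of these new bindings; the actual wiring in `triton.language.core`/`semantic` is not part of the visible hunks, so treat the names below as assumptions:

```py
# Hedged sketch, not the actual triton.language.core code: how a builtin
# like tl.sin could reach the create_sin binding added above.
from triton.language import core


@core.builtin
def sin(x, _builder=None):
    # _builder is the MLIR OpBuilder injected by the JIT; create_sin emits
    # a math.sin op on the tensor's underlying MLIR value (x.handle).
    return core.tensor(_builder.create_sin(x.handle), x.type)
```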
python/tests/test_math_ops.py (new file, 33 lines):

```py
import triton
import triton.language as tl


@triton.jit
def math_kernel(x1_ptr, x2_ptr, x3_ptr, x4_ptr, n, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    x1 = tl.load(x1_ptr + offsets, mask=offsets < n)
    x2 = tl.load(x2_ptr + offsets, mask=offsets < n)
    x3 = tl.load(x3_ptr + offsets, mask=offsets < n)
    x4 = tl.load(x4_ptr + offsets, mask=offsets < n)

    y1 = tl.sin(x1)
    y2 = tl.libdevice.sin(x2)
    y3 = tl.libdevice.fdiv_rn(x3, x3)
    y4 = tl.libdevice.fmaf_rd(x4, x4, x4)

    tl.store(x1_ptr + offsets, y1, mask=offsets < n)
    tl.store(x2_ptr + offsets, y2, mask=offsets < n)
    tl.store(x3_ptr + offsets, y3, mask=offsets < n)
    tl.store(x4_ptr + offsets, y4, mask=offsets < n)


def test_empty_kernel_cubin_compile():
    kernel = triton.compile(math_kernel,
                            "*fp32,*fp32,*fp32,*fp32,i32",
                            device=0,
                            constants={"BLOCK_SIZE": 256},
                            output="ttgir")  # "cubin"
    assert kernel
    # TODO: Check if the values are correct.
    # TODO: Cover all the math operators
```
```diff
@@ -671,7 +671,8 @@ class CodeGenerator(ast.NodeVisitor):
                 results.append(triton.language.tensor(call_op.get_result(i), callee_ret_type[i]))
             return tuple(results)
         if hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__) or \
-           sys.modules[fn.__module__] is triton.language.core:
+           sys.modules[fn.__module__] is triton.language.core or \
+           isinstance(fn, triton.language.extern.ExternalFunction):
             return fn(*args, _builder=self.builder, **kws)
         if fn in self.builtins.values():
             args = [arg.value if isinstance(arg, triton.language.constexpr) else arg
```
```diff
@@ -1,4 +1,4 @@
 # flake8: noqa: F401
-from . import core, random
+from . import core, extern, libdevice, random
 from .core import *
 from .random import *
```
```diff
@@ -234,11 +234,15 @@ class pointer_type(dtype):
 
 
 class block_type(dtype):
-    def __init__(self, element_ty: dtype, shape: List[int]):
+    def __init__(self, element_ty: dtype, shape: List):
         self.element_ty = element_ty
-        # FIXME:
-        # block_type's shape is a list of int
-        # while tensor's shape is a list of constexpr
+
+        # Note that block_type's shape is a list of int
+        # while tensor's shape is a list of constexpr.
+        assert shape
+        if isinstance(shape[0], constexpr):
+            shape = [s.value for s in shape]
+
         self.shape = shape
         self.numel = 1
         for s in self.shape:
```
python/triton/language/extern.py (new file, 104 lines):

```py
from __future__ import annotations  # remove after python 3.11

from . import core, semantic


def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, _builder=None):
    '''
        Dispatch a function to a library

        :param func: the function to dispatch
        :param lib_name: the name of the library
        :param lib_path: the path of the library
        :param args: the arguments of the function
        :param arg_type_symbol_dict: the type of the arguments
        :param ret_shape: the shape of the return value
        :param _builder: the builder

        :return: the return value of the function
    '''
    if len(arg_type_symbol_dict) == 0:
        raise ValueError("arg_type_symbol_dict is empty")

    num_args = len(list(arg_type_symbol_dict.keys())[0])
    if len(args) != num_args:
        raise ValueError(f"length of input args does not match."
                         f"Expect {len(args)}, got {num_args}")

    arg_types = []
    arg_list = []
    for arg in args:
        if isinstance(arg, core.tensor):
            arg_types.append(arg.dtype)
            arg_list.append(arg.handle)
        else:
            arg_types.append(type(arg))
            arg_list.append(arg)
    arg_types = tuple(arg_types)

    if arg_types not in arg_type_symbol_dict:
        raise ValueError(f"input arg type does not match."
                         f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")
    else:
        symbol = arg_type_symbol_dict[arg_types][0]
        ret_type = arg_type_symbol_dict[arg_types][1]
        if ret_shape:
            ret_type = core.block_type(ret_type, ret_shape)
        return core.tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder)), ret_type)


def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, _builder=None):
    '''
        Dispatch an elementwise function to a library

        :param lib_name: the name of the library
        :param lib_path: the path of the library
        :param args: the arguments of the function
        :param arg_type_symbol_dict: the type of the arguments
        :param _builder: the builder

        :return: the return value of the function
    '''
    dispatch_args = args.copy()
    if len(args) == 1:
        dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
        ret_shape = dispatch_args[0].shape
    elif len(args) == 2:
        dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder)
        dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder)
        dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl(
            dispatch_args[0], dispatch_args[1], _builder)
        ret_shape = dispatch_args[0].shape
    else:
        for i in range(len(dispatch_args)):
            dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder)
        broadcast_arg = dispatch_args[0]
        # Get the broadcast shape over all the arguments
        for i in range(len(dispatch_args)):
            _, broadcast_arg = semantic.binary_op_type_checking_impl(
                dispatch_args[i], broadcast_arg, _builder)
        # Change the shape of each argument based on the broadcast shape
        for i in range(len(dispatch_args)):
            dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
                dispatch_args[i], broadcast_arg, _builder)
        ret_shape = broadcast_arg.shape
    func = getattr(_builder, "create_external_elementwise")
    return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)


class ExternalFunction:
    '''
        A wrapper for external functions
    '''

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        if '_builder' not in kwargs or \
           kwargs['_builder'] is None:
            raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)")
        return self.fn(*args, **kwargs)


def extern(fn):
    '''
        A decorator for external functions
    '''
    return ExternalFunction(fn)
```
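To see how `extern` and `elementwise` compose, here is a hedged sketch of a `libdevice` wrapper in the style of the generated `python/triton/language/libdevice.py` (whose diff is suppressed below); the libdevice path is a placeholder:

```py
from . import core, extern

LIBDEVICE_PATH = "/path/to/libdevice.10.bc"  # placeholder; compare the libpath in the TTGIR dump above


@extern.extern
def sin(arg0, _builder=None):
    # arg_type_symbol_dict maps an input dtype tuple to the libdevice symbol
    # and its return dtype; dispatch() picks the overload matching arg0.
    return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0],
                              {(core.dtype("fp32"),): ("__nv_sinf", core.dtype("fp32")),
                               (core.dtype("fp64"),): ("__nv_sin", core.dtype("fp64")),
                               }, _builder)
```

A kernel then calls it as `tl.libdevice.sin(x)`, as in the new test above.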
python/triton/language/libdevice.py (new file, 1661 lines): file diff suppressed because it is too large.