[Triton-MLIR] Add ex2.approx implementation for ExpOp and fix smem allocation for ReduceOpConversion (#875)
This commit is contained in:
@@ -64,6 +64,12 @@ Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
|
||||
rewriter.getF32FloatAttr(v));
|
||||
}
|
||||
|
||||
/// Create an LLVM dialect constant of type f64 holding `v`.
///
/// `v` is accepted as `double` so that the value reaches the F64 attribute
/// unmodified; taking it as `float` would silently round every constant to
/// single precision before it is widened again. Callers that pass a `float`
/// keep working through the implicit float -> double conversion.
///
/// \param loc      Location attached to the created constant op.
/// \param rewriter Rewriter used to build the op and its attribute.
/// \param v        Constant value.
/// \return The result of the newly created LLVM::ConstantOp.
Value createConstantF64(Location loc, PatternRewriter &rewriter, double v) {
  auto type = type::f64Ty(rewriter.getContext());
  return rewriter.create<LLVM::ConstantOp>(loc, type,
                                           rewriter.getF64FloatAttr(v));
}
|
||||
|
||||
// Create an index type constant.
|
||||
static Value createIndexConstant(OpBuilder &builder, Location loc,
|
||||
TypeConverter *converter, int64_t value) {
|
||||
@@ -132,6 +138,7 @@ void llPrintf(StringRef msg, ValueRange args,
|
||||
#define f64_ty rewriter.getF64Type()
|
||||
#define vec_ty(type, num) VectorType::get(num, type)
|
||||
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
|
||||
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
|
||||
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
|
||||
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
|
||||
|
||||
@@ -2592,6 +2599,8 @@ public:
|
||||
for (unsigned i = 0; i < elems; ++i) {
|
||||
resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy,
|
||||
operands[i], loc);
|
||||
if (!bool(resultVals[i]))
|
||||
return failure();
|
||||
}
|
||||
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
@@ -5834,6 +5843,32 @@ struct FDivOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
struct ExpOpConversionApprox
|
||||
: ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp,
|
||||
ExpOpConversionApprox> {
|
||||
using Base = ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp,
|
||||
ExpOpConversionApprox>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
Value createDestOp(mlir::math::ExpOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
// For FP64 input, call __nv_expf for higher-precision calculation
|
||||
if (elemTy.getIntOrFloatBitWidth() == 64)
|
||||
return {};
|
||||
const double log2e = 1.4426950408889634;
|
||||
Value prod =
|
||||
rewriter.create<LLVM::FMulOp>(loc, f32_ty, operands[0], f32_val(log2e));
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
|
||||
auto output = ptxBuilder.newOperand("=f");
|
||||
auto input = ptxBuilder.newOperand(prod, "f");
|
||||
exp2(output, input);
|
||||
return ptxBuilder.launch(rewriter, loc, f32_ty, false);
|
||||
}
|
||||
};
|
||||
|
||||
/// ====================== atomic_rmw codegen begin ==========================
|
||||
struct AtomicRMWOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>,
|
||||
@@ -5994,6 +6029,13 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
|
||||
patterns.add<CmpIOpConversion>(typeConverter, benefit);
|
||||
patterns.add<CmpFOpConversion>(typeConverter, benefit);
|
||||
|
||||
// ExpOpConversionApprox will try using ex2.approx if the input type is FP32.
|
||||
// For FP64 input type, ExpOpConversionApprox will return failure and
|
||||
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
|
||||
// __nv_exp for higher-precision calculation
|
||||
patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
|
||||
|
||||
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
|
||||
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
|
||||
|
Reference in New Issue
Block a user