[Triton-MLIR] Add ex2.approx implementation for ExpOp and fix smem allocation for ReduceOpConversion (#875)

Author: Qingyi Liu
Date: 2022-11-15 09:27:32 +08:00
Committed by: GitHub
Parent: c28cfd821b
Commit: 4c4159c6fa
5 changed files with 58 additions and 0 deletions

View File

@@ -161,9 +161,16 @@ private:
  // TODO(Keren): Reduce with index is not supported yet.
  auto value = op->getOperand(0);
  if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
    auto srcLayout = tensorType.getEncoding();
    bool fastReduce = reduceOp.axis() == getOrder(srcLayout)[0];
    auto smemShape = getScratchConfigForReduce(reduceOp);
    unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
                                     std::multiplies{});
    if (fastReduce) {
      auto mod = op->getParentOfType<ModuleOp>();
      unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
      elems = std::max<unsigned>(elems, numWarps * 32);
    }
    auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
    allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
  }
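
For reference, a standalone sketch of the sizing rule this hunk implements (the helper name reduceScratchBytes and the example values below are illustrative, not part of the change): the scratch element count is the product of the scratch-buffer shape, padded up to at least numWarps * 32 elements when the reduction runs along the fastest-varying dimension of the source layout, and then converted to bytes via the element bit-width.

#include <algorithm>
#include <functional>
#include <numeric>
#include <vector>

// Mirrors the arithmetic above: product of the scratch shape, padded to
// numWarps * 32 elements on the fast-reduce path, converted to bytes.
unsigned reduceScratchBytes(const std::vector<unsigned> &smemShape,
                            bool fastReduce, unsigned numWarps,
                            unsigned elemBitWidth) {
  unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1u,
                                   std::multiplies<>{});
  if (fastReduce)
    elems = std::max<unsigned>(elems, numWarps * 32);
  return elems * elemBitWidth / 8;
}

int main() {
  // E.g. a fast reduce with a 2-element scratch shape, 4 warps, and f32
  // elements reserves max(2, 4 * 32) * 32 / 8 = 512 bytes of shared memory.
  return reduceScratchBytes({2}, /*fastReduce=*/true, /*numWarps=*/4,
                            /*elemBitWidth=*/32) == 512 ? 0 : 1;
}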

View File

@@ -64,6 +64,12 @@ Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
                                           rewriter.getF32FloatAttr(v));
}
// Create a 64-bit float constant.
Value createConstantF64(Location loc, PatternRewriter &rewriter, double v) {
  auto type = type::f64Ty(rewriter.getContext());
  return rewriter.create<LLVM::ConstantOp>(loc, type,
                                           rewriter.getF64FloatAttr(v));
}
// Create an index type constant.
static Value createIndexConstant(OpBuilder &builder, Location loc,
                                 TypeConverter *converter, int64_t value) {
@@ -132,6 +138,7 @@ void llPrintf(StringRef msg, ValueRange args,
#define f64_ty rewriter.getF64Type()
#define vec_ty(type, num) VectorType::get(num, type)
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
@@ -2592,6 +2599,8 @@ public:
    for (unsigned i = 0; i < elems; ++i) {
      resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy,
                                                 operands[i], loc);
      // A null Value means the concrete pattern declined to handle this
      // element type; report failure so another pattern can convert the op.
      if (!bool(resultVals[i]))
        return failure();
    }
    Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
    rewriter.replaceOp(op, view);
@@ -5834,6 +5843,32 @@ struct FDivOpConversion
}
};
struct ExpOpConversionApprox
    : ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp,
                                  ExpOpConversionApprox> {
  using Base = ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp,
                                           ExpOpConversionApprox>;
  using Base::Base;
  using Adaptor = typename Base::OpAdaptor;

  Value createDestOp(mlir::math::ExpOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Type elemTy,
                     ValueRange operands, Location loc) const {
    // Only lower FP32 inputs to the fast ex2.approx path. For FP64 input,
    // return a null Value so this pattern fails and the generic elementwise
    // conversion falls back to the full-precision libdevice exp.
    if (elemTy.getIntOrFloatBitWidth() == 64)
      return {};
    // exp(x) == 2^(x * log2(e)): scale the input, then emit ex2.approx.f32.
    const double log2e = 1.4426950408889634;
    Value prod =
        rewriter.create<LLVM::FMulOp>(loc, f32_ty, operands[0], f32_val(log2e));

    PTXBuilder ptxBuilder;
    auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
    auto output = ptxBuilder.newOperand("=f");
    auto input = ptxBuilder.newOperand(prod, "f");
    exp2(output, input);
    return ptxBuilder.launch(rewriter, loc, f32_ty, false);
  }
};
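
The conversion above relies on the identity exp(x) = 2^(x · log2 e); ex2.approx is the hardware approximation of the base-2 exponential. A host-side C++ sketch of the same rewrite (purely illustrative, not part of the diff):

#include <cmath>
#include <cstdio>

// exp(x) rewritten as 2^(x * log2(e)), the same algebraic identity the
// ex2.approx lowering uses; exp2 here plays the role of ex2.approx.
float fastExp(float x) {
  const float log2e = 1.4426950408889634f;
  return std::exp2(x * log2e);
}

int main() {
  // The two paths agree up to the approximation error of the exp2 step.
  std::printf("%f vs %f\n", fastExp(1.0f), std::exp(1.0f));
  return 0;
}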
/// ====================== atomic_rmw codegen begin ==========================
struct AtomicRMWOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>,
@@ -5994,6 +6029,13 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
  patterns.add<CmpIOpConversion>(typeConverter, benefit);
  patterns.add<CmpFOpConversion>(typeConverter, benefit);
  // ExpOpConversionApprox lowers math::ExpOp to the fast ex2.approx PTX
  // instruction when the input type is FP32. For FP64 input it returns
  // failure, and the ElementwiseOpConversion<math::ExpOp, math::ExpOp>
  // pattern registered below takes over, calling the full-precision
  // libdevice exp instead.
  patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
#define POPULATE_UNARY_OP(SRC_OP, DST_OP)                                      \
  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);

View File

@@ -61,6 +61,7 @@ def get_tensor(shape, data_type, b_positive=False):
('sqrt', 'float64', 'float64'),
('abs', 'float32', 'float32'),
('exp', 'float32', 'float32'),
('exp', 'float64', 'float64'),
('sigmoid', 'float32', 'float32'),
])
def test_single_input(expr, output_type, input0_type):

View File

@@ -9,6 +9,8 @@
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_loop
func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
@@ -313,3 +315,5 @@ func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
return
// CHECK-NEXT: size = 40960
}
}

View File

@@ -9,6 +9,8 @@
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_loop
// There shouldn't be any membar with the dot op encoding.
func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
@@ -250,3 +252,5 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
%cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
return
}
}