[Triton-MLIR] Add ex2.approx implementation for ExpOp and fix smem allocation for ReduceOpConversion (#875)
This commit is contained in:
@@ -161,9 +161,16 @@ private:
|
||||
// TODO(Keren): Reduce with index is not supported yet.
|
||||
auto value = op->getOperand(0);
|
||||
if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
|
||||
auto srcLayout = tensorType.getEncoding();
|
||||
bool fastReduce = reduceOp.axis() == getOrder(srcLayout)[0];
|
||||
auto smemShape = getScratchConfigForReduce(reduceOp);
|
||||
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
|
||||
std::multiplies{});
|
||||
if (fastReduce) {
|
||||
auto mod = op->getParentOfType<ModuleOp>();
|
||||
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
elems = std::max<unsigned>(elems, numWarps * 32);
|
||||
}
|
||||
auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
|
||||
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
}
|
||||
|
@@ -64,6 +64,12 @@ Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
|
||||
rewriter.getF32FloatAttr(v));
|
||||
}
|
||||
|
||||
Value createConstantF64(Location loc, PatternRewriter &rewriter, float v) {
|
||||
auto type = type::f64Ty(rewriter.getContext());
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, type,
|
||||
rewriter.getF64FloatAttr(v));
|
||||
}
|
||||
|
||||
// Create an index type constant.
|
||||
static Value createIndexConstant(OpBuilder &builder, Location loc,
|
||||
TypeConverter *converter, int64_t value) {
|
||||
@@ -132,6 +138,7 @@ void llPrintf(StringRef msg, ValueRange args,
|
||||
#define f64_ty rewriter.getF64Type()
|
||||
#define vec_ty(type, num) VectorType::get(num, type)
|
||||
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
|
||||
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
|
||||
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
|
||||
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
|
||||
|
||||
@@ -2592,6 +2599,8 @@ public:
|
||||
for (unsigned i = 0; i < elems; ++i) {
|
||||
resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy,
|
||||
operands[i], loc);
|
||||
if (!bool(resultVals[i]))
|
||||
return failure();
|
||||
}
|
||||
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
@@ -5834,6 +5843,32 @@ struct FDivOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
struct ExpOpConversionApprox
|
||||
: ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp,
|
||||
ExpOpConversionApprox> {
|
||||
using Base = ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp,
|
||||
ExpOpConversionApprox>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
|
||||
Value createDestOp(mlir::math::ExpOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
// For FP64 input, call __nv_expf for higher-precision calculation
|
||||
if (elemTy.getIntOrFloatBitWidth() == 64)
|
||||
return {};
|
||||
const double log2e = 1.4426950408889634;
|
||||
Value prod =
|
||||
rewriter.create<LLVM::FMulOp>(loc, f32_ty, operands[0], f32_val(log2e));
|
||||
PTXBuilder ptxBuilder;
|
||||
auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
|
||||
auto output = ptxBuilder.newOperand("=f");
|
||||
auto input = ptxBuilder.newOperand(prod, "f");
|
||||
exp2(output, input);
|
||||
return ptxBuilder.launch(rewriter, loc, f32_ty, false);
|
||||
}
|
||||
};
|
||||
|
||||
/// ====================== atomic_rmw codegen begin ==========================
|
||||
struct AtomicRMWOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>,
|
||||
@@ -5994,6 +6029,13 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
|
||||
patterns.add<CmpIOpConversion>(typeConverter, benefit);
|
||||
patterns.add<CmpFOpConversion>(typeConverter, benefit);
|
||||
|
||||
// ExpOpConversionApprox will try using ex2.approx if the input type is FP32.
|
||||
// For FP64 input type, ExpOpConversionApprox will return failure and
|
||||
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
|
||||
// __nv_expf for higher-precision calculation
|
||||
patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
|
||||
|
||||
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
|
||||
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
|
||||
|
@@ -61,6 +61,7 @@ def get_tensor(shape, data_type, b_positive=False):
|
||||
('sqrt', 'float64', 'float64'),
|
||||
('abs', 'float32', 'float32'),
|
||||
('exp', 'float32', 'float32'),
|
||||
('exp', 'float64', 'float64'),
|
||||
('sigmoid', 'float32', 'float32'),
|
||||
])
|
||||
def test_single_input(expr, output_type, input0_type):
|
||||
|
@@ -9,6 +9,8 @@
|
||||
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
|
||||
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// CHECK-LABEL: matmul_loop
|
||||
func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
@@ -313,3 +315,5 @@ func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
return
|
||||
// CHECK-NEXT: size = 40960
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -9,6 +9,8 @@
|
||||
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
|
||||
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// CHECK-LABEL: matmul_loop
|
||||
// There shouldn't be any membar with the dot op encoding.
|
||||
func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
|
||||
@@ -250,3 +252,5 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
|
||||
%cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user