[Triton-MLIR] Add ex2.approx implementation for ExpOp and fix smem allocation for ReduceOpConversion (#875)
@@ -161,9 +161,16 @@ private:
     // TODO(Keren): Reduce with index is not supported yet.
     auto value = op->getOperand(0);
     if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
+      auto srcLayout = tensorType.getEncoding();
+      bool fastReduce = reduceOp.axis() == getOrder(srcLayout)[0];
       auto smemShape = getScratchConfigForReduce(reduceOp);
       unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
                                        std::multiplies{});
+      if (fastReduce) {
+        auto mod = op->getParentOfType<ModuleOp>();
+        unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+        elems = std::max<unsigned>(elems, numWarps * 32);
+      }
       auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
       allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
     }
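For context on the hunk above: the fast-reduce path exchanges partial results through shared memory, and the added floor guarantees scratch space for at least numWarps * 32 elements, presumably one slot per lane. Below is a standalone sketch of the sizing arithmetic (plain C++, not part of the commit; the helper name and the worked numbers are illustrative):

#include <algorithm>
#include <cstdio>

// Mirrors the logic above: floor the element count at numWarps * 32 on the
// fast-reduce path, then convert to bytes using the element bit width.
unsigned reduceScratchBytes(unsigned elems, unsigned numWarps,
                            unsigned elemBits, bool fastReduce) {
  if (fastReduce)
    elems = std::max(elems, numWarps * 32u);
  return elems * elemBits / 8;
}

int main() {
  // 4 warps, f32 elements, fast reduce: at least 4 * 32 = 128 slots,
  // i.e. 512 bytes, even if the smemShape product is smaller.
  std::printf("%u\n", reduceScratchBytes(64, 4, 32, true)); // prints 512
}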
@@ -64,6 +64,12 @@ Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
                                            rewriter.getF32FloatAttr(v));
 }
 
+Value createConstantF64(Location loc, PatternRewriter &rewriter, float v) {
+  auto type = type::f64Ty(rewriter.getContext());
+  return rewriter.create<LLVM::ConstantOp>(loc, type,
+                                           rewriter.getF64FloatAttr(v));
+}
+
 // Create an index type constant.
 static Value createIndexConstant(OpBuilder &builder, Location loc,
                                  TypeConverter *converter, int64_t value) {
@@ -132,6 +138,7 @@ void llPrintf(StringRef msg, ValueRange args,
 #define f64_ty rewriter.getF64Type()
 #define vec_ty(type, num) VectorType::get(num, type)
 #define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
+#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
 #define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
 #define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
 
@@ -2592,6 +2599,8 @@ public:
     for (unsigned i = 0; i < elems; ++i) {
       resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy,
                                                  operands[i], loc);
+      if (!bool(resultVals[i]))
+        return failure();
     }
     Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
     rewriter.replaceOp(op, view);
@@ -5834,6 +5843,32 @@ struct FDivOpConversion
   }
 };
 
+struct ExpOpConversionApprox
+    : ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp,
+                                  ExpOpConversionApprox> {
+  using Base = ElementwiseOpConversionBase<mlir::math::ExpOp, LLVM::InlineAsmOp,
+                                           ExpOpConversionApprox>;
+  using Base::Base;
+  using Adaptor = typename Base::OpAdaptor;
+
+  Value createDestOp(mlir::math::ExpOp op, OpAdaptor adaptor,
+                     ConversionPatternRewriter &rewriter, Type elemTy,
+                     ValueRange operands, Location loc) const {
+    // For FP64 input, call __nv_expf for higher-precision calculation
+    if (elemTy.getIntOrFloatBitWidth() == 64)
+      return {};
+    const double log2e = 1.4426950408889634;
+    Value prod =
+        rewriter.create<LLVM::FMulOp>(loc, f32_ty, operands[0], f32_val(log2e));
+    PTXBuilder ptxBuilder;
+    auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
+    auto output = ptxBuilder.newOperand("=f");
+    auto input = ptxBuilder.newOperand(prod, "f");
+    exp2(output, input);
+    return ptxBuilder.launch(rewriter, loc, f32_ty, false);
+  }
+};
+
 /// ====================== atomic_rmw codegen begin ==========================
 struct AtomicRMWOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>,
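The pattern above leans on the identity exp(x) = 2^(x * log2(e)): it multiplies the input by log2(e) with an FMulOp and then emits the approximate base-2 exponential instruction ex2.approx.f32 via PTXBuilder. A reference-semantics sketch (host-side C++, illustration only; std::exp2 stands in for the PTX instruction):

#include <cmath>
#include <cstdio>

// exp(x) computed the way the lowering does: scale by log2(e), then take a
// base-2 exponential. On the GPU the second step is ex2.approx.f32, a fast
// approximate instruction, rather than a libm call.
float expViaExp2(float x) {
  const float log2e = 1.4426950408889634f;
  return std::exp2(x * log2e);
}

int main() {
  std::printf("%.7f vs %.7f\n", expViaExp2(1.0f), std::exp(1.0f));
}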
@@ -5994,6 +6029,13 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
 
   patterns.add<CmpIOpConversion>(typeConverter, benefit);
   patterns.add<CmpFOpConversion>(typeConverter, benefit);
+
+  // ExpOpConversionApprox will try using ex2.approx if the input type is FP32.
+  // For FP64 input type, ExpOpConversionApprox will return failure and
+  // ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
+  // __nv_expf for higher-precision calculation
+  patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
+
 #define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
   patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
 
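Note how this registration interacts with the null-result check added at the createDestOp call site earlier in this commit: when ExpOpConversionApprox sees a 64-bit element type it returns a null Value, the elementwise base class turns that into failure(), and the conversion driver falls through to the generic ElementwiseOpConversion registered by POPULATE_UNARY_OP. A minimal model of that first-match dispatch (plain C++, an assumption for illustration; the real driver is MLIR's conversion framework):

#include <cstdio>
#include <functional>
#include <optional>
#include <vector>

// Each "pattern" either produces a lowering or declines (nullopt); the
// driver tries candidates in turn and stops at the first success.
using Pattern = std::function<std::optional<const char *>(unsigned)>;

int main() {
  std::vector<Pattern> patterns = {
      // Like ExpOpConversionApprox: declines 64-bit input.
      [](unsigned bits) -> std::optional<const char *> {
        if (bits == 64)
          return std::nullopt;
        return "ex2.approx.f32";
      },
      // Like the generic ElementwiseOpConversion: always applies.
      [](unsigned) -> std::optional<const char *> { return "__nv_expf"; },
  };
  for (unsigned bits : {32u, 64u})
    for (auto &pattern : patterns)
      if (auto lowered = pattern(bits)) {
        std::printf("f%u -> %s\n", bits, *lowered);
        break;
      }
}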
@@ -61,6 +61,7 @@ def get_tensor(shape, data_type, b_positive=False):
     ('sqrt', 'float64', 'float64'),
     ('abs', 'float32', 'float32'),
     ('exp', 'float32', 'float32'),
+    ('exp', 'float64', 'float64'),
     ('sigmoid', 'float32', 'float32'),
 ])
 def test_single_input(expr, output_type, input0_type):
@@ -9,6 +9,8 @@
 #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
 #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
 
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+
 // CHECK-LABEL: matmul_loop
 func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
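Both analysis tests (this allocation test and the membar test below) now wrap their functions in a module carrying the triton_gpu.num-warps attribute; this appears to be required because the updated allocation code above looks up triton::gpu::TritonGPUDialect::getNumWarps on the op's parent ModuleOp.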
@@ -313,3 +315,5 @@ func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
   return
   // CHECK-NEXT: size = 40960
 }
+
+}
@@ -9,6 +9,8 @@
 #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
 #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
 
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+
 // CHECK-LABEL: matmul_loop
 // There shouldn't be any membar with the dot op encoding.
 func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
@@ -250,3 +252,5 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
   %cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
   return
 }
+
+}