From 4c4159c6fa61dd48d4766a79abdc473f60a80b21 Mon Sep 17 00:00:00 2001
From: Qingyi Liu
Date: Tue, 15 Nov 2022 09:27:32 +0800
Subject: [PATCH] [Triton-MLIR] Add ex2.approx implementation for ExpOp and
 fix smem allocation for ReduceOpConversion (#875)

---
 lib/Analysis/Allocation.cpp              |  7 ++++
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp   | 42 +++++++++++++++++++
 python/tests/test_elementwise.py         |  1 +
 test/Analysis/test-allocation.mlir       |  4 ++
 test/Analysis/test-membar.mlir           |  4 ++
 5 files changed, 58 insertions(+)

diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 24caa01ef..39659cbf3 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -161,9 +161,16 @@ private:
       // TODO(Keren): Reduce with index is not supported yet.
       auto value = op->getOperand(0);
       if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
+        auto srcLayout = tensorType.getEncoding();
+        bool fastReduce = reduceOp.axis() == getOrder(srcLayout)[0];
         auto smemShape = getScratchConfigForReduce(reduceOp);
         unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
                                          std::multiplies{});
+        if (fastReduce) {
+          auto mod = op->getParentOfType<ModuleOp>();
+          unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+          elems = std::max(elems, numWarps * 32);
+        }
         auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
         allocation->addBuffer(op, bytes);
       }
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 982ba7d37..d31b3df3c 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -64,6 +64,12 @@ Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
                                            rewriter.getF32FloatAttr(v));
 }
 
+Value createConstantF64(Location loc, PatternRewriter &rewriter, double v) {
+  auto type = type::f64Ty(rewriter.getContext());
+  return rewriter.create<LLVM::ConstantOp>(loc, type,
+                                           rewriter.getF64FloatAttr(v));
+}
+
 // Create an index type constant.
 static Value createIndexConstant(OpBuilder &builder, Location loc,
                                  TypeConverter *converter, int64_t value) {
@@ -132,6 +138,7 @@ void llPrintf(StringRef msg, ValueRange args,
 #define f64_ty rewriter.getF64Type()
 #define vec_ty(type, num) VectorType::get(num, type)
 #define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
+#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
 #define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
 #define struct_ty(...)                                                        \
   LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
@@ -2592,6 +2599,8 @@ public:
     for (unsigned i = 0; i < elems; ++i) {
       resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy,
                                                  operands[i], loc);
+      if (!bool(resultVals[i]))
+        return failure();
     }
     Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
     rewriter.replaceOp(op, view);
@@ -5834,6 +5843,32 @@ struct FDivOpConversion
   }
 };
 
+struct ExpOpConversionApprox
+    : ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox> {
+  using Base =
+      ElementwiseOpConversionBase<mlir::math::ExpOp, ExpOpConversionApprox>;
+  using Base::Base;
+  using Adaptor = typename Base::OpAdaptor;
+
+  Value createDestOp(mlir::math::ExpOp op, OpAdaptor adaptor,
+                     ConversionPatternRewriter &rewriter, Type elemTy,
+                     ValueRange operands, Location loc) const {
+    // For FP64 input, call __nv_expf for higher-precision calculation
+    if (elemTy.getIntOrFloatBitWidth() == 64)
+      return {};
+
+    const double log2e = 1.4426950408889634;
+    Value prod =
+        rewriter.create<LLVM::FMulOp>(loc, f32_ty, operands[0], f32_val(log2e));
+
+    PTXBuilder ptxBuilder;
+    auto &exp2 = ptxBuilder.create<PTXInstr>("ex2")->o("approx").o("f32");
+    auto output = ptxBuilder.newOperand("=f");
+    auto input = ptxBuilder.newOperand(prod, "f");
+    exp2(output, input);
+    return ptxBuilder.launch(rewriter, loc, f32_ty, false);
+  }
+};
+
 /// ====================== atomic_rmw codegen begin ==========================
 struct AtomicRMWOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>,
@@ -5994,6 +6029,13 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add(typeConverter, benefit);
   patterns.add(typeConverter, benefit);
+
+  // ExpOpConversionApprox will try using ex2.approx if the input type is FP32.
+  // For FP64 input type, ExpOpConversionApprox will return failure and
+  // ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
+  // __nv_expf for higher-precision calculation
+  patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
+
 #define POPULATE_UNARY_OP(SRC_OP, DST_OP)                                     \
   patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
diff --git a/python/tests/test_elementwise.py b/python/tests/test_elementwise.py
index 8f0b2682f..2ff4acae9 100644
--- a/python/tests/test_elementwise.py
+++ b/python/tests/test_elementwise.py
@@ -61,6 +61,7 @@ def get_tensor(shape, data_type, b_positive=False):
     ('sqrt', 'float64', 'float64'),
     ('abs', 'float32', 'float32'),
     ('exp', 'float32', 'float32'),
+    ('exp', 'float64', 'float64'),
     ('sigmoid', 'float32', 'float32'),
 ])
 def test_single_input(expr, output_type, input0_type):
diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir
index 7e2a7d675..28c73b45f 100644
--- a/test/Analysis/test-allocation.mlir
+++ b/test/Analysis/test-allocation.mlir
@@ -9,6 +9,8 @@
 #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
 #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
 
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+
 // CHECK-LABEL: matmul_loop
 func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
@@ -313,3 +315,5 @@ func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
   return
   // CHECK-NEXT: size = 40960
 }
+
+}
diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir
index 8aeb7b2dd..50c8c22c1 100644
--- a/test/Analysis/test-membar.mlir
+++ b/test/Analysis/test-membar.mlir
@@ -9,6 +9,8 @@
 #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
 #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
 
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+
 // CHECK-LABEL: matmul_loop
 // There shouldn't be any membar with the dot op encoding.
 func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
@@ -250,3 +252,5 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
   %cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED>
   return
 }
+
+}
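
Note (reviewer illustration, not part of the patch): the ex2.approx lowering relies
on the identity exp(x) = 2^(x * log2(e)). The conversion above multiplies the FP32
input by log2e with an LLVM fmul and lets the PTX instruction ex2.approx.f32
evaluate the power of two; FP64 inputs return failure so the higher-precision
fallback conversion handles them. Below is a minimal standalone sketch of that
identity in plain NumPy; the constant mirrors the patch, while the sample values
are arbitrary and independent of the Triton code above.

    import numpy as np

    LOG2E = 1.4426950408889634                 # same log2(e) constant as in the patch
    x = np.linspace(-10.0, 10.0, 5).astype(np.float32)

    # What the lowered code computes: exp2 of (x * log2e) in FP32.
    approx = np.exp2(x * np.float32(LOG2E))

    # Higher-precision reference, analogous to the FP64 fallback path.
    ref = np.exp(x.astype(np.float64))

    # Relative error should stay near FP32 rounding level for these inputs.
    print(np.max(np.abs(approx - ref) / np.abs(ref)))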