diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 2d79611f6..e9d0942df 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -5116,6 +5116,135 @@ struct FDivOpConversion
   }
 };
 
+/// ====================== atomic_rmw codegen begin ==========================
+/// Lowers triton::AtomicRMWOp to an inline "atom.global.gpu.<op>.<ty>" PTX
+/// instruction, one (possibly f16x2-vectorized) atom per pointer element.
+struct AtomicRMWOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>,
+      public LoadStoreConversionBase {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  AtomicRMWOpConversion(LLVMTypeConverter &converter,
+                        AxisInfoAnalysis &axisAnalysisPass,
+                        PatternBenefit benefit)
+      : ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(converter,
+                                                             benefit),
+        LoadStoreConversionBase(axisAnalysisPass) {}
+
+  LogicalResult
+  matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    MLIRContext *ctx = rewriter.getContext();
+
+    auto atomicRmwAttr = op.atomic_rmw_op();
+    Value ptr = op.ptr();
+    Value val = op.val();
+
+    Value llPtr = adaptor.ptr();
+    Value llVal = adaptor.val();
+
+    // Per-thread scalar components of the already-lowered tensor operands.
+    auto valElements = getElementsFromStruct(loc, llVal, rewriter);
+    auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
+
+    // TODO[dongdongl]: Support mask and scalar
+
+    auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    if (!valueTy)
+      return failure();
+    Type valueElemTy =
+        getTypeConverter()->convertType(valueTy.getElementType());
+
+    auto valTy = val.getType().cast<RankedTensorType>();
+    const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
+    // Only f16 may vectorize (as f16x2); every other type uses scalar atoms.
+    // std::min<unsigned>: the ternary is int, vec is unsigned — deduction
+    // would otherwise fail.
+    auto vec = getVectorSize(ptr);
+    vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
+
+    auto vecTy = vec_ty(valueElemTy, vec);
+    auto elemsPerThread = getElemsPerThread(val.getType());
+    SmallVector<Value> resultVals(elemsPerThread);
+    for (size_t i = 0; i < elemsPerThread; i += vec) {
+      // Pack `vec` consecutive per-thread values into one vector operand.
+      Value rmwVal = undef(vecTy);
+      for (int ii = 0; ii < vec; ++ii) {
+        Value iiVal = createIndexAttrConstant(
+            rewriter, loc, getTypeConverter()->getIndexType(), ii);
+        rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
+      }
+      Value rmwPtr = bitcast(ptrElements[i], ptr_ty(valTy.getElementType()));
+      std::string sTy;
+      PTXBuilder ptxBuilder;
+
+      auto *dstOpr = ptxBuilder.newOperand("=r");
+      auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "r");
+      auto *valOpr = ptxBuilder.newOperand(rmwVal, "r");
+
+      auto &atom = *ptxBuilder.create<>("atom");
+
+      // atom.global.gpu.<op>.<ty>; "gpu" scope = atomic w.r.t. whole device.
+      atom.o("global").o("gpu");
+      auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
+      auto sBits = std::to_string(valueElemNbits);
+      switch (atomicRmwAttr) {
+      case RMWOp::AND:
+        sTy = "b" + sBits;
+        break;
+      case RMWOp::OR:
+        sTy = "b" + sBits;
+        break;
+      case RMWOp::XOR:
+        sTy = "b" + sBits;
+        break;
+      case RMWOp::ADD:
+        sTy = "s" + sBits;
+        break;
+      case RMWOp::FADD:
+        rmwOp = "add";
+        rmwOp += (valueElemNbits == 16 ? ".noftz" : "");
+        sTy = "f" + sBits;
+        sTy += (vec == 2 && valueElemNbits == 16) ? "x2" : "";
+        break;
+      case RMWOp::MAX:
+        sTy = "s" + sBits;
+        break;
+      case RMWOp::MIN:
+        sTy = "s" + sBits;
+        break;
+      case RMWOp::UMAX:
+        rmwOp = "max";
+        sTy = "u" + sBits;
+        break;
+      case RMWOp::UMIN:
+        rmwOp = "min";
+        sTy = "u" + sBits;
+        break;
+      default:
+        return failure();
+      }
+      atom.o(rmwOp).o(sTy);
+
+      atom(dstOpr, ptrOpr, valOpr);
+      // vec == 2 produces a <2 x f16> result; otherwise a scalar.
+      auto retType = vec == 1 ? valueElemTy : vecTy;
+      auto ret = ptxBuilder.launch(rewriter, loc, retType, false);
+      for (int ii = 0; ii < vec; ++ii) {
+        resultVals[i + ii] =
+            vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
+      }
+    }
+    Type structTy = getTypeConverter()->convertType(valueTy);
+    Value resultStruct =
+        getStructFromElements(loc, resultVals, rewriter, structTy);
+    rewriter.replaceOp(op, {resultStruct});
+    return success();
+  }
+};
+/// ====================== atomic_rmw codegen end ==========================
+
 void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                   RewritePatternSet &patterns, int numWarps,
                                   AxisInfoAnalysis &axisInfoAnalysis,
@@ -5187,7 +5316,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
                                           benefit);
   patterns.add<DotOpConversion>(typeConverter, allocation, smem, benefit);
-
+  patterns.add<AtomicRMWOpConversion>(typeConverter, axisInfoAnalysis, benefit);
   patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
                                          benefit);
   patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 30da6ad52..c3a6fd63a 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -830,3 +830,15 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     return
   }
 }
+
+// -----
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: atomic_add_f32
+  func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: atom.global.gpu.add.f32
+    %0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
+    return
+  }
+}
\ No newline at end of file