[Triton-MLIR][Backend]add atomic rmw without mask (#842)
add atomic without mask

Co-authored-by: dongdongl <dongdongl@nvidia.com>
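For context, a minimal sketch of the user-facing case this lowering covers: an unmasked atomic add over a whole block, issued through the standard `tl.atomic_add(pointer, val)` call (the kernel and argument names below are hypothetical and not part of this change).

import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, inc_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    inc = tl.load(inc_ptr + offs)
    # No mask is passed, so every lane issues the atomic RMW -- the case
    # handled by the new AtomicRMWOpConversion pattern below.
    tl.atomic_add(x_ptr + offs, inc)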
@@ -5116,6 +5116,127 @@ struct FDivOpConversion
  }
};

/// ====================== atomic_rmw codegen begin ==========================
struct AtomicRMWOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>,
      public LoadStoreConversionBase {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;

  AtomicRMWOpConversion(LLVMTypeConverter &converter,
                        AxisInfoAnalysis &axisAnalysisPass,
                        PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(converter,
                                                             benefit),
        LoadStoreConversionBase(axisAnalysisPass) {}

  LogicalResult
  matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MLIRContext *ctx = rewriter.getContext();

    auto atomicRmwAttr = op.atomic_rmw_op();
    Value ptr = op.ptr();
    Value val = op.val();

    Value llPtr = adaptor.ptr();
    Value llVal = adaptor.val();

    auto valElements = getElementsFromStruct(loc, llVal, rewriter);
    auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);

    // TODO[dongdongl]: Support mask and scalar

    auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
    if (!valueTy)
      return failure();
    Type valueElemTy =
        getTypeConverter()->convertType(valueTy.getElementType());

    auto valTy = val.getType().cast<RankedTensorType>();
    const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
    auto vec = getVectorSize(ptr);
    vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);

    auto vecTy = vec_ty(valueElemTy, vec);
    auto elemsPerThread = getElemsPerThread(val.getType());
    SmallVector<Value> resultVals(elemsPerThread);
    for (size_t i = 0; i < elemsPerThread; i += vec) {
      Value rmvVal = undef(vecTy);
      for (int ii = 0; ii < vec; ++ii) {
        Value iiVal = createIndexAttrConstant(
            rewriter, loc, getTypeConverter()->getIndexType(), ii);
        rmvVal = insert_element(vecTy, rmvVal, valElements[i], iiVal);
      }
      Value rmwPtr = bitcast(ptrElements[i], ptr_ty(valTy.getElementType()));
      std::string sTy;
      PTXBuilder ptxBuilder;

      auto *dstOpr = ptxBuilder.newOperand("=r");
      auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "r");
      auto *valOpr = ptxBuilder.newOperand(rmvVal, "r");

      auto &atom = *ptxBuilder.create<>("atom");

      atom.o("global").o("gpu");
      auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
      auto sBits = std::to_string(valueElemNbits);
      switch (atomicRmwAttr) {
      case RMWOp::AND:
        sTy = "b" + sBits;
        break;
      case RMWOp::OR:
        sTy = "b" + sBits;
        break;
      case RMWOp::XOR:
        sTy = "b" + sBits;
        break;
      case RMWOp::ADD:
        sTy = "s" + sBits;
        break;
      case RMWOp::FADD:
        rmwOp = "add";
        rmwOp += (valueElemNbits == 16 ? ".noftz" : "");
        sTy = "f" + sBits;
        sTy += (vec == 2 && valueElemNbits == 16) ? "x2" : "";
        break;
      case RMWOp::MAX:
        sTy = "s" + sBits;
        break;
      case RMWOp::MIN:
        sTy = "s" + sBits;
        break;
      case RMWOp::UMAX:
        rmwOp = "max";
        sTy = "u" + sBits;
        break;
      case RMWOp::UMIN:
        rmwOp = "min";
        sTy = "u" + sBits;
        break;
      default:
        return failure();
      }
      atom.o(rmwOp).o(sTy);

      atom(dstOpr, ptrOpr, valOpr);
      auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy, false);
      for (int ii = 0; ii < vec; ++ii) {
        resultVals[i * vec + ii] =
            vec == 1 ? ret : extract_element(vecTy, ret, idx_val(ii));
      }
    }
    Type structTy = getTypeConverter()->convertType(valueTy);
    Value resultStruct =
        getStructFromElements(loc, resultVals, rewriter, structTy);
    rewriter.replaceOp(op, {resultStruct});
    return success();
  }
};
/// ====================== atomic_rmw codegen end ==========================

void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                  RewritePatternSet &patterns, int numWarps,
                                  AxisInfoAnalysis &axisInfoAnalysis,
@@ -5187,7 +5308,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
  patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
  patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
                                          benefit);
  patterns.add<AtomicRMWOpConversion>(typeConverter, axisInfoAnalysis, benefit);
  patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
                                         benefit);
  patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
@@ -830,3 +830,15 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
    return
  }
}

// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
  // CHECK-LABEL: atomic_add_f32
  func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
    // CHECK: llvm.inline_asm
    // CHECK-SAME: atom.global.gpu.add.f32
    %0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
    return
  }
}