[Triton-MLIR][Backend]add atomic rmw without mask (#842)

add atomic without mask

Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
donproc
2022-11-10 08:15:58 +08:00
committed by GitHub
parent 0c87360657
commit 4640023d9b
2 changed files with 134 additions and 1 deletions

View File

@@ -5116,6 +5116,127 @@ struct FDivOpConversion
}
};
/// ====================== atomic_rmw codegen begin ==========================
struct AtomicRMWOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>,
public LoadStoreConversionBase {
using ConvertTritonGPUOpToLLVMPattern<
triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
AtomicRMWOpConversion(LLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(converter,
benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();
auto atomicRmwAttr = op.atomic_rmw_op();
Value ptr = op.ptr();
Value val = op.val();
Value llPtr = adaptor.ptr();
Value llVal = adaptor.val();
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
// TODO[dongdongl]: Support mask and scalar
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
if (!valueTy)
return failure();
Type valueElemTy =
getTypeConverter()->convertType(valueTy.getElementType());
auto valTy = val.getType().cast<RankedTensorType>();
const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
auto vec = getVectorSize(ptr);
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
auto vecTy = vec_ty(valueElemTy, vec);
auto elemsPerThread = getElemsPerThread(val.getType());
SmallVector<Value> resultVals(elemsPerThread);
for (size_t i = 0; i < elemsPerThread; i += vec) {
Value rmvVal = undef(vecTy);
for (int ii = 0; ii < vec; ++ii) {
Value iiVal = createIndexAttrConstant(
rewriter, loc, getTypeConverter()->getIndexType(), ii);
rmvVal = insert_element(vecTy, rmvVal, valElements[i], iiVal);
}
Value rmwPtr = bitcast(ptrElements[i], ptr_ty(valTy.getElementType()));
std::string sTy;
PTXBuilder ptxBuilder;
auto *dstOpr = ptxBuilder.newOperand("=r");
auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "r");
auto *valOpr = ptxBuilder.newOperand(rmvVal, "r");
auto &atom = *ptxBuilder.create<>("atom");
atom.o("global").o("gpu");
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
auto sBits = std::to_string(valueElemNbits);
switch (atomicRmwAttr) {
case RMWOp::AND:
sTy = "b" + sBits;
break;
case RMWOp::OR:
sTy = "b" + sBits;
break;
case RMWOp::XOR:
sTy = "b" + sBits;
break;
case RMWOp::ADD:
sTy = "s" + sBits;
break;
case RMWOp::FADD:
rmwOp = "add";
rmwOp += (valueElemNbits == 16 ? ".noftz" : "");
sTy = "f" + sBits;
sTy += (vec == 2 && valueElemNbits == 16) ? "x2" : "";
break;
case RMWOp::MAX:
sTy = "s" + sBits;
break;
case RMWOp::MIN:
sTy = "s" + sBits;
break;
case RMWOp::UMAX:
rmwOp = "max";
sTy = "u" + sBits;
break;
case RMWOp::UMIN:
rmwOp = "min";
sTy = "u" + sBits;
break;
default:
return failure();
}
atom.o(rmwOp).o(sTy);
atom(dstOpr, ptrOpr, valOpr);
auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy, false);
for (int ii = 0; ii < vec; ++ii) {
resultVals[i * vec + ii] =
vec == 1 ? ret : extract_element(vecTy, ret, idx_val(ii));
}
}
Type structTy = getTypeConverter()->convertType(valueTy);
Value resultStruct =
getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, {resultStruct});
return success();
}
};
/// ====================== atomic_rmw codegen end ==========================
void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
@@ -5187,7 +5308,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<AtomicRMWOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);