[TRITON-MLIR][BACKEND]AtomicRMWOp supports scalar (#903)
AtomicRMWOp supports scalar Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
@@ -5947,10 +5947,11 @@ struct AtomicRMWOpConversion
|
||||
triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
AtomicRMWOpConversion(LLVMTypeConverter &converter,
|
||||
const Allocation *allocation, Value smem,
|
||||
AxisInfoAnalysis &axisAnalysisPass,
|
||||
PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(converter,
|
||||
benefit),
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(
|
||||
converter, allocation, smem, benefit),
|
||||
LoadStoreConversionBase(axisAnalysisPass) {}
|
||||
|
||||
LogicalResult
|
||||
@@ -5971,30 +5972,29 @@ struct AtomicRMWOpConversion
|
||||
auto valElements = getElementsFromStruct(loc, llVal, rewriter);
|
||||
auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
|
||||
auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
|
||||
|
||||
// TODO[dongdongl]: Support scalar
|
||||
|
||||
|
||||
auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
if (!valueTy)
|
||||
return failure();
|
||||
Type valueElemTy =
|
||||
getTypeConverter()->convertType(valueTy.getElementType());
|
||||
|
||||
auto valTy = val.getType().cast<RankedTensorType>();
|
||||
valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
|
||||
: op.getResult().getType();
|
||||
const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
|
||||
auto elemsPerThread = getElemsPerThread(val.getType());
|
||||
// vec = 1 for scalar
|
||||
auto vec = getVectorSize(ptr);
|
||||
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
|
||||
Value mask = int_val(1, 1);
|
||||
auto tid = tid_val();
|
||||
// tensor
|
||||
if (valueTy) {
|
||||
auto valTy = val.getType().cast<RankedTensorType>();
|
||||
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
|
||||
// mask
|
||||
auto shape = valueTy.getShape();
|
||||
auto numElements = product(shape);
|
||||
mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
|
||||
i32_val(numElements)));
|
||||
}
|
||||
|
||||
auto vecTy = vec_ty(valueElemTy, vec);
|
||||
auto elemsPerThread = getElemsPerThread(val.getType());
|
||||
// mask
|
||||
Value mask = int_val(1, 1);
|
||||
auto shape = valueTy.getShape();
|
||||
auto numElements = product(shape);
|
||||
auto tid = tid_val();
|
||||
mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
|
||||
i32_val(numElements)));
|
||||
|
||||
SmallVector<Value> resultVals(elemsPerThread);
|
||||
for (size_t i = 0; i < elemsPerThread; i += vec) {
|
||||
Value rmwVal = undef(vecTy);
|
||||
@@ -6008,10 +6008,12 @@ struct AtomicRMWOpConversion
|
||||
rmwMask = and_(rmwMask, mask);
|
||||
std::string sTy;
|
||||
PTXBuilder ptxBuilder;
|
||||
|
||||
auto *dstOpr = ptxBuilder.newOperand("=r");
|
||||
std::string tyId = valueElemNbits * vec == 64
|
||||
? "l"
|
||||
: (valueElemNbits * vec == 32 ? "r" : "h");
|
||||
auto *dstOpr = ptxBuilder.newOperand("=" + tyId);
|
||||
auto *ptrOpr = ptxBuilder.newAddrOperand(rmwPtr, "l");
|
||||
auto *valOpr = ptxBuilder.newOperand(rmwVal, "r");
|
||||
auto *valOpr = ptxBuilder.newOperand(rmwVal, tyId);
|
||||
|
||||
auto &atom = ptxBuilder.create<>("atom")->global().o("gpu");
|
||||
auto rmwOp = stringifyRMWOp(atomicRmwAttr).str();
|
||||
@@ -6053,18 +6055,32 @@ struct AtomicRMWOpConversion
|
||||
return failure();
|
||||
}
|
||||
atom.o(rmwOp).o(sTy);
|
||||
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
|
||||
|
||||
auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
|
||||
for (int ii = 0; ii < vec; ++ii) {
|
||||
resultVals[i * vec + ii] =
|
||||
vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
|
||||
if (valueTy) {
|
||||
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
|
||||
auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
|
||||
for (int ii = 0; ii < vec; ++ii) {
|
||||
resultVals[i * vec + ii] =
|
||||
vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
|
||||
}
|
||||
} else {
|
||||
rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
|
||||
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
|
||||
auto old = ptxBuilder.launch(rewriter, loc, valueElemTy);
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
|
||||
store(old, atomPtr);
|
||||
barrier();
|
||||
Value ret = load(atomPtr);
|
||||
barrier();
|
||||
rewriter.replaceOp(op, {ret});
|
||||
}
|
||||
}
|
||||
Type structTy = getTypeConverter()->convertType(valueTy);
|
||||
Value resultStruct =
|
||||
getStructFromElements(loc, resultVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, {resultStruct});
|
||||
if (valueTy) {
|
||||
Type structTy = getTypeConverter()->convertType(valueTy);
|
||||
Value resultStruct =
|
||||
getStructFromElements(loc, resultVals, rewriter, structTy);
|
||||
rewriter.replaceOp(op, {resultStruct});
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -6150,7 +6166,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
patterns.add<ReduceOpConversion>(typeConverter, allocation, smem, benefit);
|
||||
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
|
||||
benefit);
|
||||
patterns.add<AtomicRMWOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem, axisInfoAnalysis, benefit);
|
||||
patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
|
||||
benefit);
|
||||
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
|
||||
|
Reference in New Issue
Block a user