[TRITON-MLIR][BACKEND] AtomicRMWOp supports scalar (#903)

Author: donproc
Co-authored-by: dongdongl <dongdongl@nvidia.com>
Date: 2022-11-23 15:59:09 +08:00
Committed by: GitHub
parent 2e33352419
commit 8925c2cd11
5 changed files with 163 additions and 125 deletions
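For context (not part of the commit itself): a minimal sketch of the user-level pattern this change enables, namely an atomic RMW through a scalar pointer rather than a tensor of pointers. The kernel, variable names, and launch shape below are hypothetical.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def count_kernel(counter_ptr):
    # `counter_ptr` is a *scalar* pointer; this commit adds backend
    # support for lowering the resulting scalar AtomicRMWOp.
    tl.atomic_add(counter_ptr, 1)


counter = torch.zeros(1, dtype=torch.int32, device="cuda")
count_kernel[(4,)](counter)  # four programs each add 1
assert counter.item() == 4
```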


@@ -112,6 +112,20 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
  return smemShape;
}

// TODO: extend beyond scalars
SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
  SmallVector<unsigned> smemShape;
  auto ptrTy = op.ptr().getType();
  if (auto tensorType = ptrTy.dyn_cast<RankedTensorType>()) {
    // Tensor case: shared memory is not used, so no scratch is needed
    // (we could also assert that this path is never sized).
  } else {
    // Scalar case: a single element of scratch suffices, since vec = 1
    // and elemsPerThread = 1 for a scalar.
    smemShape.push_back(1);
  }
  return smemShape;
}
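Note that in the allocation pass below, the tensor case is filtered out before this helper is consulted, so the empty shape returned by the tensor branch is never fed into an element-count computation.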
class AllocationAnalysis {
public:
AllocationAnalysis(Operation *operation, Allocation *allocation)
@@ -200,6 +214,23 @@ private:
                      elems * kPtrBitWidth / 8 :
                      elems * srcTy.getElementTypeBitWidth() / 8;
      allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
    } else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
      auto value = op->getOperand(0); // operand 0 is the pointer
      // Only the scalar case requires scratch memory; both branches are
      // spelled out for readability.
      if (value.getType().dyn_cast<RankedTensorType>()) {
        // Tensor case: nothing to do.
      } else {
        auto smemShape = getScratchConfigForAtomicRMW(atomicRMWOp);
        unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
                                         std::multiplies{});
        auto elemTy =
            value.getType().cast<triton::PointerType>().getPointeeType();
        auto bytes = elemTy.isa<triton::PointerType>()
                         ? elems * kPtrBitWidth / 8
                         : elems * elemTy.getIntOrFloatBitWidth() / 8;
        allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
      }
    }
  }
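As a worked example of the sizing logic above (pointee type assumed for illustration): a scalar atomic on an f32 pointer gets smemShape = {1}, so elems = 1 and the scratch allocation is 32 / 8 = 4 bytes; a scalar atomic whose pointee is itself a pointer would instead be sized with kPtrBitWidth bits per element.

```python
# Worked example of the scratch sizing above, assuming an f32 pointee.
elems = 1                    # smemShape == [1] in the scalar branch
elem_bitwidth = 32           # f32
scratch_bytes = elems * elem_bitwidth // 8
assert scratch_bytes == 4    # one element of shared scratch
```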