[TRITON-MLIR][BACKEND]AtomicRMWOp supports scalar (#903)
AtomicRMWOp supports scalar Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
@@ -112,6 +112,20 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
|
||||
return smemShape;
|
||||
}
|
||||
|
||||
// TODO: extend beyond scalars
|
||||
SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
|
||||
SmallVector<unsigned> smemShape;
|
||||
auto ptrTy = op.ptr().getType();
|
||||
if (auto tensorType = ptrTy.dyn_cast<RankedTensorType>()) {
|
||||
// do nothing or just assert because shared memory is not used in tensor
|
||||
} else {
|
||||
// need only bytes for scalar
|
||||
// always vec = 1 and elemsPerThread = 1 for scalar?
|
||||
smemShape.push_back(1);
|
||||
}
|
||||
return smemShape;
|
||||
}
|
||||
|
||||
class AllocationAnalysis {
|
||||
public:
|
||||
AllocationAnalysis(Operation *operation, Allocation *allocation)
|
||||
@@ -200,6 +214,23 @@ private:
|
||||
elems * kPtrBitWidth / 8 :
|
||||
elems * srcTy.getElementTypeBitWidth() / 8;
|
||||
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
} else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
|
||||
auto value = op->getOperand(0);
|
||||
// only scalar requires scratch memory
|
||||
// make it explicit for readability
|
||||
if (value.getType().dyn_cast<RankedTensorType>()) {
|
||||
// nothing to do
|
||||
} else {
|
||||
auto smemShape = getScratchConfigForAtomicRMW(atomicRMWOp);
|
||||
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
|
||||
std::multiplies{});
|
||||
auto elemTy =
|
||||
value.getType().cast<triton::PointerType>().getPointeeType();
|
||||
auto bytes = elemTy.isa<triton::PointerType>()
|
||||
? elems * kPtrBitWidth / 8
|
||||
: elems * elemTy.getIntOrFloatBitWidth() / 8;
|
||||
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user