[TRITON-MLIR][BACKEND] support atomic_cas (#914)

1. Support atomic_cas.
2. Add xchg support in atomic_rmw.
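
For reference, the semantics of the two ops: atomic_cas(ptr, cmp, val) stores val only if the current value equals cmp and returns the value it found; atomic_xchg(ptr, val) stores val unconditionally and returns the old value. Below is a minimal host-side C++ sketch of these semantics; std::atomic is used purely as a stand-in for the PTX atomics the backend actually emits, and the concrete values are illustrative.

#include <atomic>
#include <cstdio>

int main() {
  std::atomic<int> x{0};

  // atomic_cas(&x, /*cmp=*/0, /*val=*/1): x == 0, so the store happens and x
  // becomes 1; 'expected' is left holding the value x had before the op.
  int expected = 0;
  bool swapped = x.compare_exchange_strong(expected, 1);
  std::printf("cas:  swapped=%d x=%d old=%d\n", swapped, x.load(), expected);

  // atomic_xchg(&x, /*val=*/42): unconditionally stores 42, returns old value.
  int old = x.exchange(42);
  std::printf("xchg: x=%d old=%d\n", x.load(), old);
  return 0;
}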

Co-authored-by: dongdongl <dongdongl@nvidia.com>
Author: donproc
Date: 2022-11-25 12:02:08 +08:00
Committed by: GitHub
Parent: 153aecb339
Commit: f63be0e9b5
6 changed files with 150 additions and 15 deletions


@@ -115,9 +115,9 @@ SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
 // TODO: extend beyond scalars
 SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
   SmallVector<unsigned> smemShape;
-  auto ptrTy = op.ptr().getType();
-  if (auto tensorType = ptrTy.dyn_cast<RankedTensorType>()) {
-    // do nothing or just assert because shared memory is not used in tensor
+  if (op.ptr().getType().isa<RankedTensorType>()) {
+    // do nothing or just assert because shared memory is not used in tensor up
+    // to now
   } else {
     // need only bytes for scalar
     // always vec = 1 and elemsPerThread = 1 for scalar?
@@ -126,6 +126,10 @@ SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
   return smemShape;
 }
 
+SmallVector<unsigned> getScratchConfigForAtomicCAS(triton::AtomicCASOp op) {
+  return SmallVector<unsigned>{1};
+}
+
 class AllocationAnalysis {
 public:
   AllocationAnalysis(Operation *operation, Allocation *allocation)
@@ -230,6 +234,17 @@ private:
                        : elems * elemTy.getIntOrFloatBitWidth() / 8;
       allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
-    }
+    } else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
+      auto value = op->getOperand(0);
+      auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp);
+      unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
+                                       std::multiplies{});
+      auto elemTy =
+          value.getType().cast<triton::PointerType>().getPointeeType();
+      auto bytes = elemTy.isa<triton::PointerType>()
+                       ? elems * kPtrBitWidth / 8
+                       : elems * elemTy.getIntOrFloatBitWidth() / 8;
+      allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+    }
   }
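
The new AtomicCAS branch follows the same arithmetic as the AtomicRMW branch above it: multiply out the scratch shape reported by getScratchConfigForAtomicCAS, then scale by the element bit width (kPtrBitWidth when the pointee is itself a pointer). A standalone sketch of that arithmetic; the 64-bit pointer width and the i32 pointee below are assumptions for illustration.

#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  // getScratchConfigForAtomicCAS always reports a single-element buffer.
  std::vector<unsigned> smemShape{1};

  // elems = product of the shape dims, as in the std::accumulate call above.
  unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1u,
                                   std::multiplies<unsigned>{});

  const unsigned kPtrBitWidth = 64; // assumed pointer width
  unsigned elemBitWidth = 32;       // assumed i32/f32 pointee

  // bytes = elems * bit width / 8, matching the ternary in the diff.
  std::printf("scratch: %u bytes (i32 pointee), %u bytes (pointer pointee)\n",
              elems * elemBitWidth / 8, elems * kPtrBitWidth / 8);
  return 0;
}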