[TRITON-MLIR][BACKEND]support atomic_cas (#914)

1. support atomics-cas
2. add xchg support in atomic_rmw

Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
donproc
2022-11-25 12:02:08 +08:00
committed by GitHub
parent 153aecb339
commit f63be0e9b5
6 changed files with 150 additions and 15 deletions

View File

@@ -105,7 +105,7 @@ void init_triton_ir(py::module &&m) {
.value("AND", mlir::triton::RMWOp::AND)
.value("OR", mlir::triton::RMWOp::OR)
.value("XOR", mlir::triton::RMWOp::XOR)
// .value("XCHG", mlir::triton::RMWOp::Xchg)
.value("XCHG", mlir::triton::RMWOp::XCHG)
.value("MAX", mlir::triton::RMWOp::MAX)
.value("MIN", mlir::triton::RMWOp::MIN)
.value("UMIN", mlir::triton::RMWOp::UMIN)
@@ -1095,9 +1095,18 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &cmp,
mlir::Value &val) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto ptrType = mlir::getElementTypeOrSelf(ptr)
.cast<mlir::triton::PointerType>();
mlir::Type dstType = ptrType.getPointeeType();
mlir::Type dstType;
if (auto srcTensorType = ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
mlir::Type dstElemType = srcTensorType.getElementType()
.cast<mlir::triton::PointerType>()
.getPointeeType();
dstType = mlir::RankedTensorType::get(srcTensorType.getShape(),
dstElemType);
} else {
auto ptrType = mlir::getElementTypeOrSelf(ptr)
.cast<mlir::triton::PointerType>();
dstType = ptrType.getPointeeType();
}
return self.create<mlir::triton::AtomicCASOp>(loc, dstType, ptr,
cmp, val);
})