[TRITON-MLIR][BACKEND]support atomic_cas (#914)
1. support atomics-cas 2. add xchg support in atomic_rmw Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
@@ -105,7 +105,7 @@ void init_triton_ir(py::module &&m) {
|
||||
.value("AND", mlir::triton::RMWOp::AND)
|
||||
.value("OR", mlir::triton::RMWOp::OR)
|
||||
.value("XOR", mlir::triton::RMWOp::XOR)
|
||||
// .value("XCHG", mlir::triton::RMWOp::Xchg)
|
||||
.value("XCHG", mlir::triton::RMWOp::XCHG)
|
||||
.value("MAX", mlir::triton::RMWOp::MAX)
|
||||
.value("MIN", mlir::triton::RMWOp::MIN)
|
||||
.value("UMIN", mlir::triton::RMWOp::UMIN)
|
||||
@@ -1095,9 +1095,18 @@ void init_triton_ir(py::module &&m) {
|
||||
[](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &cmp,
|
||||
mlir::Value &val) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto ptrType = mlir::getElementTypeOrSelf(ptr)
|
||||
.cast<mlir::triton::PointerType>();
|
||||
mlir::Type dstType = ptrType.getPointeeType();
|
||||
mlir::Type dstType;
|
||||
if (auto srcTensorType = ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
|
||||
mlir::Type dstElemType = srcTensorType.getElementType()
|
||||
.cast<mlir::triton::PointerType>()
|
||||
.getPointeeType();
|
||||
dstType = mlir::RankedTensorType::get(srcTensorType.getShape(),
|
||||
dstElemType);
|
||||
} else {
|
||||
auto ptrType = mlir::getElementTypeOrSelf(ptr)
|
||||
.cast<mlir::triton::PointerType>();
|
||||
dstType = ptrType.getPointeeType();
|
||||
}
|
||||
return self.create<mlir::triton::AtomicCASOp>(loc, dstType, ptr,
|
||||
cmp, val);
|
||||
})
|
||||
|
Reference in New Issue
Block a user