[TRITON-MLIR][BACKEND]support atomic_cas (#914)
1. support atomics-cas 2. add xchg support in atomic_rmw Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
@@ -105,7 +105,7 @@ void init_triton_ir(py::module &&m) {
|
||||
.value("AND", mlir::triton::RMWOp::AND)
|
||||
.value("OR", mlir::triton::RMWOp::OR)
|
||||
.value("XOR", mlir::triton::RMWOp::XOR)
|
||||
// .value("XCHG", mlir::triton::RMWOp::Xchg)
|
||||
.value("XCHG", mlir::triton::RMWOp::XCHG)
|
||||
.value("MAX", mlir::triton::RMWOp::MAX)
|
||||
.value("MIN", mlir::triton::RMWOp::MIN)
|
||||
.value("UMIN", mlir::triton::RMWOp::UMIN)
|
||||
@@ -1095,9 +1095,18 @@ void init_triton_ir(py::module &&m) {
|
||||
[](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &cmp,
|
||||
mlir::Value &val) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto ptrType = mlir::getElementTypeOrSelf(ptr)
|
||||
.cast<mlir::triton::PointerType>();
|
||||
mlir::Type dstType = ptrType.getPointeeType();
|
||||
mlir::Type dstType;
|
||||
if (auto srcTensorType = ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
|
||||
mlir::Type dstElemType = srcTensorType.getElementType()
|
||||
.cast<mlir::triton::PointerType>()
|
||||
.getPointeeType();
|
||||
dstType = mlir::RankedTensorType::get(srcTensorType.getShape(),
|
||||
dstElemType);
|
||||
} else {
|
||||
auto ptrType = mlir::getElementTypeOrSelf(ptr)
|
||||
.cast<mlir::triton::PointerType>();
|
||||
dstType = ptrType.getPointeeType();
|
||||
}
|
||||
return self.create<mlir::triton::AtomicCASOp>(loc, dstType, ptr,
|
||||
cmp, val);
|
||||
})
|
||||
|
@@ -700,6 +700,16 @@ def test_tensor_atomic_rmw(axis, device="cuda"):
|
||||
# serialized_add[(64,)](data, Lock)
|
||||
# triton.testing.assert_almost_equal(data, ref)
|
||||
|
||||
def test_simple_atomic_cas():
|
||||
# 1. make sure that atomic_cas changes the original value (Lock)
|
||||
@triton.jit
|
||||
def change_value(Lock):
|
||||
tl.atomic_cas(Lock, 0, 1)
|
||||
|
||||
Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
|
||||
change_value[(1,)](Lock)
|
||||
|
||||
assert (Lock[0] == 1)
|
||||
|
||||
# # ---------------
|
||||
# # test cast
|
||||
|
Reference in New Issue
Block a user