[TRITON-MLIR][BACKEND]support atomic_cas (#914)

1. support atomic_cas
2. add xchg support in atomic_rmw

Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
donproc
2022-11-25 12:02:08 +08:00
committed by GitHub
parent 153aecb339
commit f63be0e9b5
6 changed files with 150 additions and 15 deletions

View File

@@ -105,7 +105,7 @@ void init_triton_ir(py::module &&m) {
.value("AND", mlir::triton::RMWOp::AND)
.value("OR", mlir::triton::RMWOp::OR)
.value("XOR", mlir::triton::RMWOp::XOR)
// .value("XCHG", mlir::triton::RMWOp::Xchg)
.value("XCHG", mlir::triton::RMWOp::XCHG)
.value("MAX", mlir::triton::RMWOp::MAX)
.value("MIN", mlir::triton::RMWOp::MIN)
.value("UMIN", mlir::triton::RMWOp::UMIN)
@@ -1095,9 +1095,18 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &cmp,
   mlir::Value &val) -> mlir::Value {
  auto loc = self.getUnknownLoc();
  // Result type of atomic_cas mirrors the pointee type of `ptr`.
  // Two cases: `ptr` is a tensor of pointers (tensor-level CAS) or a
  // single scalar pointer.
  mlir::Type dstType;
  if (auto srcTensorType =
          ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
    // Tensor of pointers: result is a tensor with the same shape whose
    // element type is the pointee type of each pointer element.
    mlir::Type dstElemType = srcTensorType.getElementType()
                                 .cast<mlir::triton::PointerType>()
                                 .getPointeeType();
    dstType =
        mlir::RankedTensorType::get(srcTensorType.getShape(), dstElemType);
  } else {
    // Scalar pointer: result is simply the pointee type.
    auto ptrType = mlir::getElementTypeOrSelf(ptr)
                       .cast<mlir::triton::PointerType>();
    dstType = ptrType.getPointeeType();
  }
  // Emit the AtomicCASOp: atomically compare *ptr with `cmp` and, on
  // match, store `val`; returns the value previously held at *ptr.
  return self.create<mlir::triton::AtomicCASOp>(loc, dstType, ptr,
                                                cmp, val);
})

View File

@@ -700,6 +700,16 @@ def test_tensor_atomic_rmw(axis, device="cuda"):
# serialized_add[(64,)](data, Lock)
# triton.testing.assert_almost_equal(data, ref)
def test_simple_atomic_cas():
    """Smoke test for tl.atomic_cas: a CAS on a zero-initialized int32
    lock must swap it to 1 when the expected value (0) matches."""
    # 1. make sure that atomic_cas changes the original value (Lock)
    @triton.jit
    def change_value(Lock):
        # CAS(Lock, expected=0, new=1): Lock holds 0, so the swap succeeds.
        tl.atomic_cas(Lock, 0, 1)

    Lock = torch.zeros((1,), device='cuda', dtype=torch.int32)
    change_value[(1,)](Lock)
    # The successful CAS must have written 1 into the lock word.
    assert (Lock[0] == 1)
# # ---------------
# # test cast