[TRITON-MLIR][BACKEND]AtomicRMWOp supports scalar (#903)

AtomicRMWOp now supports scalar (non-tensor) pointer operands in addition to tensors of pointers.

Co-authored-by: dongdongl <dongdongl@nvidia.com>
This commit is contained in:
donproc
2022-11-23 15:59:09 +08:00
committed by GitHub
parent 2e33352419
commit 8925c2cd11
5 changed files with 163 additions and 125 deletions

View File

@@ -1106,7 +1106,18 @@ void init_triton_ir(py::module &&m) {
mlir::Value &ptr, mlir::Value &val,
mlir::Value &mask) -> mlir::Value {
auto loc = self.getUnknownLoc();
mlir::Type dstType = val.getType();
mlir::Type dstType;
if (auto srcTensorType = ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
mlir::Type dstElemType = srcTensorType.getElementType()
.cast<mlir::triton::PointerType>()
.getPointeeType();
dstType = mlir::RankedTensorType::get(srcTensorType.getShape(),
dstElemType);
} else {
auto ptrType = mlir::getElementTypeOrSelf(ptr)
.cast<mlir::triton::PointerType>();
dstType = ptrType.getPointeeType();
}
return self.create<mlir::triton::AtomicRMWOp>(loc, dstType, rmwOp,
ptr, val, mask);
})