[FRONTEND] Fixed up type cast in atomics codegen (#853)
This commit is contained in:
@@ -1080,7 +1080,8 @@ void init_triton_ir(py::module &&m) {
|
|||||||
[](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &cmp,
|
[](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &cmp,
|
||||||
mlir::Value &val) -> mlir::Value {
|
mlir::Value &val) -> mlir::Value {
|
||||||
auto loc = self.getUnknownLoc();
|
auto loc = self.getUnknownLoc();
|
||||||
auto ptrType = ptr.getType().dyn_cast<mlir::triton::PointerType>();
|
auto ptrType = mlir::getElementTypeOrSelf(ptr)
|
||||||
|
.cast<mlir::triton::PointerType>();
|
||||||
mlir::Type dstType = ptrType.getPointeeType();
|
mlir::Type dstType = ptrType.getPointeeType();
|
||||||
return self.create<mlir::triton::AtomicCASOp>(loc, dstType, ptr,
|
return self.create<mlir::triton::AtomicCASOp>(loc, dstType, ptr,
|
||||||
cmp, val);
|
cmp, val);
|
||||||
@@ -1090,7 +1091,8 @@ void init_triton_ir(py::module &&m) {
|
|||||||
mlir::Value &ptr, mlir::Value &val,
|
mlir::Value &ptr, mlir::Value &val,
|
||||||
mlir::Value &mask) -> mlir::Value {
|
mlir::Value &mask) -> mlir::Value {
|
||||||
auto loc = self.getUnknownLoc();
|
auto loc = self.getUnknownLoc();
|
||||||
auto ptrType = ptr.getType().dyn_cast<mlir::triton::PointerType>();
|
auto ptrType = mlir::getElementTypeOrSelf(ptr)
|
||||||
|
.cast<mlir::triton::PointerType>();
|
||||||
mlir::Type dstType = ptrType.getPointeeType();
|
mlir::Type dstType = ptrType.getPointeeType();
|
||||||
return self.create<mlir::triton::AtomicRMWOp>(loc, dstType, rmwOp,
|
return self.create<mlir::triton::AtomicRMWOp>(loc, dstType, rmwOp,
|
||||||
ptr, val, mask);
|
ptr, val, mask);
|
||||||
|
Reference in New Issue
Block a user