From b6f15e214bca521768ce0ab818d3144bd2e980b0 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 7 Nov 2022 05:46:24 -0800 Subject: [PATCH] [FRONTEND] Fixed up type cast in atomics codegen (#853) --- python/src/triton.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index b3b745c43..6f60d03c0 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1080,7 +1080,8 @@ void init_triton_ir(py::module &&m) { [](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &cmp, mlir::Value &val) -> mlir::Value { auto loc = self.getUnknownLoc(); - auto ptrType = ptr.getType().dyn_cast(); + auto ptrType = mlir::getElementTypeOrSelf(ptr) + .cast(); mlir::Type dstType = ptrType.getPointeeType(); return self.create(loc, dstType, ptr, cmp, val); @@ -1090,7 +1091,8 @@ void init_triton_ir(py::module &&m) { mlir::Value &ptr, mlir::Value &val, mlir::Value &mask) -> mlir::Value { auto loc = self.getUnknownLoc(); - auto ptrType = ptr.getType().dyn_cast(); + auto ptrType = mlir::getElementTypeOrSelf(ptr) + .cast(); mlir::Type dstType = ptrType.getPointeeType(); return self.create(loc, dstType, rmwOp, ptr, val, mask);