diff --git a/python/src/triton.cc b/python/src/triton.cc index 6bf5c79e0..2bbd0c67d 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -950,12 +950,18 @@ void init_triton_ir(py::module &&m) { mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) ); }) - // .def("create_lshr", &ir::builder::create_lshr, ret::reference, - // py::arg("lhs"), py::arg("rhs"), - // py::arg("has_nuw")=false, py::arg("has_nsw")=false) - // .def("create_ashr", &ir::builder::create_ashr, ret::reference, - // py::arg("lhs"), py::arg("rhs"), - // py::arg("has_nuw")=false, py::arg("has_nsw")=false) + .def("create_lshr", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap( + mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) + ); + }) + .def("create_ashr", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap( + mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) + ); + }) // GEP .def("create_gep", [](mlir::OpBuilder &self, MlirValue &ptr, MlirValue &offset) -> MlirValue { auto loc = self.getUnknownLoc();