From a17fba86b11f910bed1ba84ee158ae1521578ab8 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Sat, 26 Mar 2022 16:16:20 +0800 Subject: [PATCH] Logic Op creation --- python/src/triton.cc | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index d968af181..2e825d7a2 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1120,20 +1120,35 @@ void init_triton_ir(py::module &&m) { ))); }, ret::reference) // // Logical - // .def("create_and", &ir::builder::create_and, ret::reference) - // .def("create_xor", &ir::builder::create_xor, ret::reference) - // .def("create_or", &ir::builder::create_or, ret::reference) + .def("create_and", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value(self.create( + loc, unwrap(lhs), unwrap(rhs) + ))); + }, ret::reference) + .def("create_xor", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value(self.create( + loc, unwrap(lhs), unwrap(rhs) + ))); + }, ret::reference) + .def("create_or", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + auto loc = self.getUnknownLoc(); + return wrap(mlir::Value(self.create( + loc, unwrap(lhs), unwrap(rhs) + ))); + }, ret::reference) // // Input/Output .def("create_load", [](mlir::OpBuilder &self, MlirValue &ptrs) -> MlirValue { auto loc = self.getUnknownLoc(); return wrap(mlir::Value( self.create(loc, unwrap(ptrs)) )); - }, ret::reference) + }) .def("create_store", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &value) -> void { auto loc = self.getUnknownLoc(); self.create(loc, unwrap(ptrs), unwrap(value)); - }, ret::reference) + }) // .def("create_masked_load", &ir::builder::create_masked_load, ret::reference) // .def("create_masked_store", &ir::builder::create_masked_store, ret::reference) // // Block instruction @@ -1146,7 +1161,7 @@ void init_triton_ir(py::module &&m) { return wrap(mlir::Value(self.create( loc, mlir::RankedTensorType::get(shape, argType), unwrap(arg) ))); - }, ret::reference) + }) // // atomic // .def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference) // .def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference)