Logic Op creation

This commit is contained in:
Yan Da
2022-03-26 16:16:20 +08:00
parent 5e117966d0
commit a17fba86b1

View File

@@ -1120,20 +1120,35 @@ void init_triton_ir(py::module &&m) {
)));
}, ret::reference)
// // Logical
// .def("create_and", &ir::builder::create_and, ret::reference)
// .def("create_xor", &ir::builder::create_xor, ret::reference)
// .def("create_or", &ir::builder::create_or, ret::reference)
.def("create_and", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::AndIOp>(
loc, unwrap(lhs), unwrap(rhs)
)));
}, ret::reference)
.def("create_xor", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::XOrIOp>(
loc, unwrap(lhs), unwrap(rhs)
)));
}, ret::reference)
.def("create_or", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(self.create<mlir::arith::OrIOp>(
loc, unwrap(lhs), unwrap(rhs)
)));
}, ret::reference)
// // Input/Output
.def("create_load", [](mlir::OpBuilder &self, MlirValue &ptrs) -> MlirValue {
auto loc = self.getUnknownLoc();
return wrap(mlir::Value(
self.create<mlir::triton::LoadOp>(loc, unwrap(ptrs))
));
}, ret::reference)
})
.def("create_store", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &value) -> void {
auto loc = self.getUnknownLoc();
self.create<mlir::triton::StoreOp>(loc, unwrap(ptrs), unwrap(value));
}, ret::reference)
})
// .def("create_masked_load", &ir::builder::create_masked_load, ret::reference)
// .def("create_masked_store", &ir::builder::create_masked_store, ret::reference)
// // Block instruction
@@ -1146,7 +1161,7 @@ void init_triton_ir(py::module &&m) {
return wrap(mlir::Value(self.create<mlir::triton::BroadcastOp>(
loc, mlir::RankedTensorType::get(shape, argType), unwrap(arg)
)));
}, ret::reference)
})
// // atomic
// .def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference)
// .def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference)