Fix OpBuilder
This commit is contained in:
@@ -1155,9 +1155,12 @@ void init_triton_ir(py::module &&m) {
|
||||
return self.create<mlir::arith::OrIOp>(loc, lhs, rhs);
|
||||
})
|
||||
// // Input/Output
|
||||
.def("create_load", [](mlir::OpBuilder &self, mlir::Value &ptrs) -> mlir::Value {
|
||||
.def("create_load", [](mlir::OpBuilder &self, mlir::Value &ptrs,
|
||||
mlir::triton::CacheModifier cacheModifer,
|
||||
mlir::triton::EvictionPolicy evictionPolicy,
|
||||
bool isVolatile) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::LoadOp>(loc, ptrs);
|
||||
return self.create<mlir::triton::LoadOp>(loc, ptrs, cacheModifer, evictionPolicy, isVolatile);
|
||||
})
|
||||
.def("create_store", [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &value) -> void {
|
||||
auto loc = self.getUnknownLoc();
|
||||
@@ -1200,8 +1203,7 @@ void init_triton_ir(py::module &&m) {
|
||||
})
|
||||
.def("create_broadcast", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
// TODO: should be scalar type here
|
||||
auto argType = arg.getType();
|
||||
auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>().getElementType();
|
||||
return self.create<mlir::triton::BroadcastOp>(
|
||||
loc, mlir::RankedTensorType::get(shape, argType), arg
|
||||
);
|
||||
@@ -1246,9 +1248,9 @@ void init_triton_ir(py::module &&m) {
|
||||
loc, self.getI32Type(), self.getI32IntegerAttr(axis)
|
||||
);
|
||||
})
|
||||
.def("create_dot", [](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b, mlir::Value &c) -> mlir::Value {
|
||||
.def("create_dot", [](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b, mlir::Value &c, bool allowTF32) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c);
|
||||
return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c, allowTF32);
|
||||
})
|
||||
// .def("create_exp", &ir::builder::create_exp, ret::reference)
|
||||
// .def("create_cos", &ir::builder::create_cos, ret::reference)
|
||||
@@ -1257,7 +1259,11 @@ void init_triton_ir(py::module &&m) {
|
||||
// .def("create_trans", &ir::builder::create_trans, ret::reference)
|
||||
// .def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
|
||||
// .def("create_reduce", &ir::builder::create_reduce, ret::reference)
|
||||
// .def("create_select", &ir::builder::create_select, ret::reference)
|
||||
.def("create_select", [](mlir::OpBuilder &self, mlir::Value &condition,
|
||||
mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::SelectOp>(loc, condition, trueValue, falseValue);
|
||||
})
|
||||
// // Intrinsics
|
||||
// // These have no place in the IR, and hopefully they can be removed at some point
|
||||
// .def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
|
||||
|
Reference in New Issue
Block a user