|
|
@@ -641,21 +641,15 @@ void init_triton_ir(py::module &&m) {
|
|
|
|
// // py::class_<ir::undef_value, ir::constant>(m, "undef")
|
|
|
|
// // py::class_<ir::undef_value, ir::constant>(m, "undef")
|
|
|
|
// // .def("get", &ir::undef_value::get, ret::reference);
|
|
|
|
// // .def("get", &ir::undef_value::get, ret::reference);
|
|
|
|
|
|
|
|
|
|
|
|
py::class_<MlirModule>(m, "module")
|
|
|
|
py::class_<mlir::ModuleOp>(m, "module")
|
|
|
|
// .def("set_attr")
|
|
|
|
// .def("set_attr")
|
|
|
|
.def("dump", [](MlirModule &self) -> void {
|
|
|
|
.def("dump", [](mlir::ModuleOp &self) -> void {
|
|
|
|
unwrap(self).dump();
|
|
|
|
self.dump();
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("push_back", [](MlirModule &self, MlirOperation &funcOperation) {
|
|
|
|
.def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
|
|
|
|
if (auto info = unwrap(funcOperation)->getRegisteredInfo()) {
|
|
|
|
self.push_back(funcOp);
|
|
|
|
if (mlir::TypeID::get<mlir::FuncOp>() == info->getTypeID()) {
|
|
|
|
|
|
|
|
auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(funcOperation));
|
|
|
|
|
|
|
|
unwrap(self).push_back(funcOp);
|
|
|
|
|
|
|
|
} else
|
|
|
|
|
|
|
|
throw std::runtime_error("Only FuncOp can call push_back");
|
|
|
|
|
|
|
|
} else
|
|
|
|
|
|
|
|
throw std::runtime_error("Unknown error");
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
.def("get_context", &mlir::ModuleOp::getContext)
|
|
|
|
;
|
|
|
|
;
|
|
|
|
|
|
|
|
|
|
|
|
py::class_<MlirType>(m, "type")
|
|
|
|
py::class_<MlirType>(m, "type")
|
|
|
@@ -667,23 +661,6 @@ void init_triton_ir(py::module &&m) {
|
|
|
|
})
|
|
|
|
})
|
|
|
|
;
|
|
|
|
;
|
|
|
|
|
|
|
|
|
|
|
|
py::class_<MlirOperation>(m, "operation")
|
|
|
|
|
|
|
|
.def("add_entry_block", [](MlirOperation &self) -> mlir::Block {
|
|
|
|
|
|
|
|
if (auto info = unwrap(self)->getRegisteredInfo()) {
|
|
|
|
|
|
|
|
if (mlir::TypeID::get<mlir::FuncOp>() == info->getTypeID()) {
|
|
|
|
|
|
|
|
auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self));
|
|
|
|
|
|
|
|
mlir::Block *entry = funcOp.addEntryBlock();
|
|
|
|
|
|
|
|
return *entry;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
throw std::runtime_error("Only FuncOp can call add_entry_block");
|
|
|
|
|
|
|
|
} else
|
|
|
|
|
|
|
|
throw std::runtime_error("Unknown error");
|
|
|
|
|
|
|
|
}) // this should be automatic?
|
|
|
|
|
|
|
|
.def("dump", [](MlirOperation &self) -> void {
|
|
|
|
|
|
|
|
unwrap(self)->dump();
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
py::class_<mlir::Value>(m, "value")
|
|
|
|
py::class_<mlir::Value>(m, "value")
|
|
|
|
;
|
|
|
|
;
|
|
|
|
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement")
|
|
|
|
py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement")
|
|
|
@@ -693,6 +670,7 @@ void init_triton_ir(py::module &&m) {
|
|
|
|
.def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument {
|
|
|
|
.def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument {
|
|
|
|
return self.getArgument(index);
|
|
|
|
return self.getArgument(index);
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
.def("dump", &mlir::Block::dump)
|
|
|
|
;
|
|
|
|
;
|
|
|
|
|
|
|
|
|
|
|
|
// py::class_<mlir::ModuleOp>(m, "module")
|
|
|
|
// py::class_<mlir::ModuleOp>(m, "module")
|
|
|
@@ -720,29 +698,34 @@ void init_triton_ir(py::module &&m) {
|
|
|
|
// py::class_<mlir::Attribute>(m, "attribute");
|
|
|
|
// py::class_<mlir::Attribute>(m, "attribute");
|
|
|
|
// // .def(py::init<eattr, int>());
|
|
|
|
// // .def(py::init<eattr, int>());
|
|
|
|
|
|
|
|
|
|
|
|
// py::class_<mlir::FuncOp>(m, "function")
|
|
|
|
py::class_<mlir::FuncOp>(m, "function")
|
|
|
|
// .def_property_readonly("args", &ir::function::args)
|
|
|
|
// .def_property_readonly("args", &ir::function::args)
|
|
|
|
// .def_property_readonly("attrs", &ir::function::attrs)
|
|
|
|
// .def_property_readonly("attrs", &ir::function::attrs)
|
|
|
|
// .def("add_attr", &ir::function::add_attr);
|
|
|
|
// .def("add_attr", &ir::function::add_attr);
|
|
|
|
|
|
|
|
.def("args", [](mlir::FuncOp &self, unsigned idx) -> mlir::BlockArgument {
|
|
|
|
// // // We don't need to expose mlir::Block (?)
|
|
|
|
return self.getArgument(idx);
|
|
|
|
// // py::class_<mlir::Block>(m, "basic_block")
|
|
|
|
})
|
|
|
|
// // // .def("create", &ir::basic_block::create, ret::reference)
|
|
|
|
.def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
|
|
|
|
// // .def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference)
|
|
|
|
return self.addEntryBlock();
|
|
|
|
// // .def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
|
|
|
|
}, ret::reference)
|
|
|
|
|
|
|
|
.def("dump", [](mlir::FuncOp &self) { self.dump(); })
|
|
|
|
|
|
|
|
;
|
|
|
|
|
|
|
|
|
|
|
|
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
|
|
|
|
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())
|
|
|
|
.def(py::init<mlir::MLIRContext *>())
|
|
|
|
.def(py::init<mlir::MLIRContext *>())
|
|
|
|
// // getters
|
|
|
|
// // getters
|
|
|
|
// .def_property_readonly("context", &ir::builder::get_context, ret::reference);
|
|
|
|
// .def_property_readonly("context", &ir::builder::get_context, ret::reference);
|
|
|
|
.def("create_module", [](mlir::OpBuilder &self) -> MlirModule {
|
|
|
|
.def("create_module", [](mlir::OpBuilder &self) -> mlir::ModuleOp {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(self.create<mlir::ModuleOp>(loc));
|
|
|
|
return self.create<mlir::ModuleOp>(loc);
|
|
|
|
})
|
|
|
|
})
|
|
|
|
// // control flow
|
|
|
|
// // control flow
|
|
|
|
// .def("br", &ir::builder::create_br, ret::reference)
|
|
|
|
// .def("br", &ir::builder::create_br, ret::reference)
|
|
|
|
// .def("cond_br", &ir::builder::create_cond_br, ret::reference)
|
|
|
|
// .def("cond_br", &ir::builder::create_cond_br, ret::reference)
|
|
|
|
// .def("ret_void", &ir::builder::create_ret_void, ret::reference)
|
|
|
|
.def("ret_void", [](mlir::OpBuilder &self) {
|
|
|
|
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
|
|
|
|
self.create<mlir::ReturnOp>(loc);
|
|
|
|
|
|
|
|
}, ret::reference)
|
|
|
|
// insertion block/point
|
|
|
|
// insertion block/point
|
|
|
|
.def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void {
|
|
|
|
.def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void {
|
|
|
|
self.setInsertionPointToStart(&block);
|
|
|
|
self.setInsertionPointToStart(&block);
|
|
|
@@ -750,8 +733,8 @@ void init_triton_ir(py::module &&m) {
|
|
|
|
.def("set_insertion_point_to_end", [](mlir::OpBuilder &self, mlir::Block &block) {
|
|
|
|
.def("set_insertion_point_to_end", [](mlir::OpBuilder &self, mlir::Block &block) {
|
|
|
|
self.setInsertionPointToEnd(&block);
|
|
|
|
self.setInsertionPointToEnd(&block);
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("get_insertion_block", [](mlir::OpBuilder &self) -> mlir::Block & {
|
|
|
|
.def("get_insertion_block", [](mlir::OpBuilder &self) -> mlir::Block* {
|
|
|
|
return *self.getInsertionBlock();
|
|
|
|
return self.getInsertionBlock();
|
|
|
|
}, ret::reference)
|
|
|
|
}, ret::reference)
|
|
|
|
// .def("get_insert_point", [](ir::builder *self) {
|
|
|
|
// .def("get_insert_point", [](ir::builder *self) {
|
|
|
|
// ir::basic_block *bb = self->get_insert_block();
|
|
|
|
// ir::basic_block *bb = self->get_insert_block();
|
|
|
@@ -784,8 +767,10 @@ void init_triton_ir(py::module &&m) {
|
|
|
|
// .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference)
|
|
|
|
// .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference)
|
|
|
|
// .def("get_uint64", &ir::builder::get_int64, ret::reference)
|
|
|
|
// .def("get_uint64", &ir::builder::get_int64, ret::reference)
|
|
|
|
// .def("get_float16", &ir::builder::get_float16, ret::reference)
|
|
|
|
// .def("get_float16", &ir::builder::get_float16, ret::reference)
|
|
|
|
// .def("get_float32", &ir::builder::get_float32, ret::reference)
|
|
|
|
.def("get_float32", [](mlir::OpBuilder &self, float v) -> mlir::Value {
|
|
|
|
// .def("get_range", &ir::builder::get_range, ret::reference)
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
|
|
|
|
return self.create<mlir::arith::ConstantOp>(loc, self.getF32FloatAttr(v));
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
// Types
|
|
|
|
// Types
|
|
|
|
.def("get_void_ty", [](mlir::OpBuilder &self) ->MlirType {
|
|
|
|
.def("get_void_ty", [](mlir::OpBuilder &self) ->MlirType {
|
|
|
@@ -846,22 +831,22 @@ void init_triton_ir(py::module &&m) {
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
// Ops
|
|
|
|
// Ops
|
|
|
|
.def("create_function", [](mlir::OpBuilder &self, std::string name, MlirType funcType) -> MlirOperation {
|
|
|
|
.def("create_function", [](mlir::OpBuilder &self, std::string name, MlirType funcType) -> mlir::FuncOp {
|
|
|
|
// TODO: loc
|
|
|
|
// TODO: loc
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
if (auto funcTy = unwrap(funcType).dyn_cast<mlir::FunctionType>()) {
|
|
|
|
if (auto funcTy = unwrap(funcType).dyn_cast<mlir::FunctionType>()) {
|
|
|
|
return wrap(self.create<mlir::FuncOp>(loc, name, funcTy));
|
|
|
|
return self.create<mlir::FuncOp>(loc, name, funcTy);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
throw std::runtime_error("invalid function type");
|
|
|
|
throw std::runtime_error("invalid function type");
|
|
|
|
})
|
|
|
|
})
|
|
|
|
// // Structured control flow
|
|
|
|
// // Structured control flow
|
|
|
|
// .def("create_for", [](mlir::OpBuilder &self, MlirValue &lb, MlirValue &ub,
|
|
|
|
// .def("create_for", [](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub,
|
|
|
|
// MlirValue &step, std::vector<MlirValue> &initArgs) -> MlirOperation {
|
|
|
|
// mlir::Value &step, std::vector<mlir::Value> &initArgs) -> MlirOperation {
|
|
|
|
// auto loc = self.getUnknownLoc();
|
|
|
|
// auto loc = self.getUnknownLoc();
|
|
|
|
// return wrap(self.create<mlir::scf::ForOp>(
|
|
|
|
// return wrap(self.create<mlir::scf::ForOp>(
|
|
|
|
// loc, unwrap(lb), unwrap(ub), unwrap(step)).getOperation());
|
|
|
|
// loc, unwrap(lb), unwrap(ub), unwrap(step)).getOperation());
|
|
|
|
// })
|
|
|
|
// })
|
|
|
|
// .def("create_if", [](mlir::OpBuilder &self, MlirValue &condition) -> MlirOperation {
|
|
|
|
// .def("create_if", [](mlir::OpBuilder &self, mlir::Value &condition) -> MlirOperation {
|
|
|
|
// auto loc = self.getUnknownLoc();
|
|
|
|
// auto loc = self.getUnknownLoc();
|
|
|
|
// return wrap(self.create<mlir::scf::IfOp>(loc, unwrap(condition)).getOperation());
|
|
|
|
// return wrap(self.create<mlir::scf::IfOp>(loc, unwrap(condition)).getOperation());
|
|
|
|
// })
|
|
|
|
// })
|
|
|
@@ -872,428 +857,334 @@ void init_triton_ir(py::module &&m) {
|
|
|
|
// // .def("create_while")
|
|
|
|
// // .def("create_while")
|
|
|
|
|
|
|
|
|
|
|
|
// miscellaneous
|
|
|
|
// miscellaneous
|
|
|
|
.def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> MlirValue {
|
|
|
|
.def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto retType = mlir::RankedTensorType::get({end-start}, self.getI32Type());
|
|
|
|
auto retType = mlir::RankedTensorType::get({end-start}, self.getI32Type());
|
|
|
|
return wrap(
|
|
|
|
return self.create<mlir::triton::MakeRangeOp>(loc, retType, start, end);
|
|
|
|
mlir::Value(self.create<mlir::triton::MakeRangeOp>(loc, retType, start, end))
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> MlirValue {
|
|
|
|
.def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(
|
|
|
|
return self.create<mlir::triton::GetProgramIdOp>(loc, self.getI32Type(), axis);
|
|
|
|
mlir::Value(self.create<mlir::triton::GetProgramIdOp>(loc, self.getI32Type(), axis))
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
// Cast instructions
|
|
|
|
// Cast instructions
|
|
|
|
.def("create_bitcast", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
|
|
|
|
.def("create_bitcast", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(
|
|
|
|
return self.create<mlir::arith::BitcastOp>(loc, unwrap(dstType), src);
|
|
|
|
self.create<mlir::arith::BitcastOp>(loc, unwrap(dstType), unwrap(src))
|
|
|
|
|
|
|
|
));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
// .def("create_cast", &ir::builder::create_cast)
|
|
|
|
// .def("create_cast", &ir::builder::create_cast)
|
|
|
|
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
|
|
|
|
// .def("create_ptr_to_int", &ir::builder::create_ptr_to_int)
|
|
|
|
.def("create_si_to_fp", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
|
|
|
|
.def("create_si_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(
|
|
|
|
return self.create<mlir::arith::SIToFPOp>(loc, unwrap(dstType), src);
|
|
|
|
self.create<mlir::arith::SIToFPOp>(loc, unwrap(dstType), unwrap(src))
|
|
|
|
|
|
|
|
));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_ui_to_fp", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
|
|
|
|
.def("create_ui_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(
|
|
|
|
return self.create<mlir::arith::UIToFPOp>(loc, unwrap(dstType), src);
|
|
|
|
self.create<mlir::arith::UIToFPOp>(loc, unwrap(dstType), unwrap(src))
|
|
|
|
|
|
|
|
));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fp_to_si", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
|
|
|
|
.def("create_fp_to_si", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(
|
|
|
|
return self.create<mlir::arith::FPToSIOp>(loc, unwrap(dstType), src);
|
|
|
|
self.create<mlir::arith::FPToSIOp>(loc, unwrap(dstType), unwrap(src))
|
|
|
|
|
|
|
|
));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fp_to_ui", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
|
|
|
|
.def("create_fp_to_ui", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(
|
|
|
|
return self.create<mlir::arith::FPToUIOp>(loc, unwrap(dstType), src);
|
|
|
|
self.create<mlir::arith::FPToUIOp>(loc, unwrap(dstType), unwrap(src))
|
|
|
|
|
|
|
|
));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fp_ext", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
|
|
|
|
.def("create_fp_ext", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(
|
|
|
|
return self.create<mlir::arith::ExtFOp>(loc, unwrap(dstType), src);
|
|
|
|
self.create<mlir::arith::ExtFOp>(loc, unwrap(dstType), unwrap(src))
|
|
|
|
|
|
|
|
));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fp_trunc", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue {
|
|
|
|
.def("create_fp_trunc", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(
|
|
|
|
return self.create<mlir::arith::TruncFOp>(loc, unwrap(dstType), src);
|
|
|
|
self.create<mlir::arith::TruncFOp>(loc, unwrap(dstType), unwrap(src))
|
|
|
|
|
|
|
|
));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
// .def("create_int_cast", &ir::builder::create_int_cast)
|
|
|
|
// .def("create_int_cast", &ir::builder::create_int_cast)
|
|
|
|
// .def("create_downcast", &ir::builder::create_downcast)
|
|
|
|
// .def("create_downcast", &ir::builder::create_downcast)
|
|
|
|
.def("create_fmul", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_fmul", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(
|
|
|
|
return self.create<mlir::arith::MulFOp>(loc, lhs, rhs);
|
|
|
|
self.create<mlir::arith::MulFOp>(loc, unwrap(lhs), unwrap(rhs))
|
|
|
|
|
|
|
|
));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fdiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_fdiv", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(
|
|
|
|
return self.create<mlir::arith::DivFOp>(loc, lhs, rhs);
|
|
|
|
self.create<mlir::arith::DivFOp>(loc, unwrap(lhs), unwrap(rhs))
|
|
|
|
|
|
|
|
));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_frem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_frem", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(
|
|
|
|
return self.create<mlir::arith::RemFOp>(loc, lhs, rhs);
|
|
|
|
self.create<mlir::arith::RemFOp>(loc, unwrap(lhs), unwrap(rhs))
|
|
|
|
|
|
|
|
));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fadd", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_fadd", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(
|
|
|
|
return self.create<mlir::arith::AddFOp>(loc, lhs, rhs);
|
|
|
|
self.create<mlir::arith::AddFOp>(loc, unwrap(lhs), unwrap(rhs))
|
|
|
|
|
|
|
|
));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fsub", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_fsub", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(
|
|
|
|
return self.create<mlir::arith::SubFOp>(loc, lhs, rhs);
|
|
|
|
self.create<mlir::arith::SubFOp>(loc, unwrap(lhs), unwrap(rhs))
|
|
|
|
|
|
|
|
));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_mul", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_mul", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
// Check lhs & rhs have single result (?)
|
|
|
|
return self.create<mlir::arith::MulIOp>(loc, lhs, rhs);
|
|
|
|
return wrap(
|
|
|
|
|
|
|
|
mlir::Value(self.create<mlir::arith::MulIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_sdiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_sdiv", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(
|
|
|
|
return self.create<mlir::arith::DivSIOp>(loc, lhs, rhs);
|
|
|
|
mlir::Value(self.create<mlir::arith::DivSIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_udiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_udiv", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(
|
|
|
|
return self.create<mlir::arith::DivUIOp>(loc, lhs, rhs);
|
|
|
|
mlir::Value(self.create<mlir::arith::DivUIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_srem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_srem", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(
|
|
|
|
return self.create<mlir::arith::RemSIOp>(loc, lhs, rhs);
|
|
|
|
mlir::Value(self.create<mlir::arith::RemSIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_urem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_urem", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(
|
|
|
|
return self.create<mlir::arith::RemUIOp>(loc, lhs, rhs);
|
|
|
|
mlir::Value(self.create<mlir::arith::RemUIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_add", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_add", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(
|
|
|
|
return self.create<mlir::arith::AddIOp>(loc, lhs, rhs);
|
|
|
|
mlir::Value(self.create<mlir::arith::AddIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_sub", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_sub", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(
|
|
|
|
return mlir::Value(self.create<mlir::arith::SubIOp>(loc, lhs, rhs));
|
|
|
|
mlir::Value(self.create<mlir::arith::SubIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_shl", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_shl", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(
|
|
|
|
return mlir::Value(self.create<mlir::arith::ShLIOp>(loc, lhs, rhs));
|
|
|
|
mlir::Value(self.create<mlir::arith::ShLIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_lshr", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_lshr", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(
|
|
|
|
return mlir::Value(self.create<mlir::arith::ShRUIOp>(loc, lhs, rhs));
|
|
|
|
mlir::Value(self.create<mlir::arith::ShRUIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_ashr", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_ashr", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(
|
|
|
|
return mlir::Value(self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs));
|
|
|
|
mlir::Value(self.create<mlir::arith::ShRSIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
// GEP
|
|
|
|
// GEP
|
|
|
|
.def("create_gep", [](mlir::OpBuilder &self, MlirValue &ptr, MlirValue &offset) -> MlirValue {
|
|
|
|
.def("create_gep", [](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &offset) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(
|
|
|
|
return self.create<mlir::triton::GEPOp>(loc, ptr.getType(), ptr, offset);
|
|
|
|
mlir::Value(self.create<mlir::triton::GEPOp>(loc, unwrap(ptr).getType(), unwrap(ptr), unwrap(offset)))
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
// Comparison (int)
|
|
|
|
// Comparison (int)
|
|
|
|
.def("create_icmpSLE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_icmpSLE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
|
|
|
return self.create<mlir::arith::CmpIOp>(
|
|
|
|
loc, mlir::arith::CmpIPredicate::sle,
|
|
|
|
loc, mlir::arith::CmpIPredicate::sle, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_icmpSLT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_icmpSLT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
|
|
|
return self.create<mlir::arith::CmpIOp>(
|
|
|
|
loc, mlir::arith::CmpIPredicate::slt,
|
|
|
|
loc, mlir::arith::CmpIPredicate::slt, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_icmpSGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_icmpSGE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
|
|
|
return self.create<mlir::arith::CmpIOp>(
|
|
|
|
loc, mlir::arith::CmpIPredicate::sge,
|
|
|
|
loc, mlir::arith::CmpIPredicate::sge, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_icmpSGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_icmpSGT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
|
|
|
return self.create<mlir::arith::CmpIOp>(
|
|
|
|
loc, mlir::arith::CmpIPredicate::sgt,
|
|
|
|
loc, mlir::arith::CmpIPredicate::sgt, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_icmpULE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_icmpULE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
|
|
|
return self.create<mlir::arith::CmpIOp>(
|
|
|
|
loc, mlir::arith::CmpIPredicate::ule,
|
|
|
|
loc, mlir::arith::CmpIPredicate::ule, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_icmpULT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_icmpULT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
|
|
|
return self.create<mlir::arith::CmpIOp>(
|
|
|
|
loc, mlir::arith::CmpIPredicate::ult,
|
|
|
|
loc, mlir::arith::CmpIPredicate::ult, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_icmpUGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_icmpUGE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
|
|
|
return self.create<mlir::arith::CmpIOp>(
|
|
|
|
loc, mlir::arith::CmpIPredicate::uge,
|
|
|
|
loc, mlir::arith::CmpIPredicate::uge, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_icmpUGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_icmpUGT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
|
|
|
return self.create<mlir::arith::CmpIOp>(
|
|
|
|
loc, mlir::arith::CmpIPredicate::ugt,
|
|
|
|
loc, mlir::arith::CmpIPredicate::ugt, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_icmpEQ", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_icmpEQ", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
|
|
|
return self.create<mlir::arith::CmpIOp>(
|
|
|
|
loc, mlir::arith::CmpIPredicate::eq,
|
|
|
|
loc, mlir::arith::CmpIPredicate::eq, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_icmpNE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_icmpNE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
|
|
|
return self.create<mlir::arith::CmpIOp>(
|
|
|
|
loc, mlir::arith::CmpIPredicate::ne,
|
|
|
|
loc, mlir::arith::CmpIPredicate::ne, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
// Comparison (float)
|
|
|
|
// Comparison (float)
|
|
|
|
.def("create_fcmpOLT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_fcmpOLT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
|
|
|
return self.create<mlir::arith::CmpFOp>(
|
|
|
|
loc, mlir::arith::CmpFPredicate::OLT,
|
|
|
|
loc, mlir::arith::CmpFPredicate::OLT, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fcmpOGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_fcmpOGT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
|
|
|
return self.create<mlir::arith::CmpFOp>(
|
|
|
|
loc, mlir::arith::CmpFPredicate::OGT,
|
|
|
|
loc, mlir::arith::CmpFPredicate::OGT, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fcmpOLE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_fcmpOLE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
|
|
|
return self.create<mlir::arith::CmpFOp>(
|
|
|
|
loc, mlir::arith::CmpFPredicate::OLE,
|
|
|
|
loc, mlir::arith::CmpFPredicate::OLE, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fcmpOGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_fcmpOGE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
|
|
|
return self.create<mlir::arith::CmpFOp>(
|
|
|
|
loc, mlir::arith::CmpFPredicate::OGE,
|
|
|
|
loc, mlir::arith::CmpFPredicate::OGE, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fcmpOEQ", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_fcmpOEQ", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
|
|
|
return self.create<mlir::arith::CmpFOp>(
|
|
|
|
loc, mlir::arith::CmpFPredicate::OEQ,
|
|
|
|
loc, mlir::arith::CmpFPredicate::OEQ, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fcmpONE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_fcmpONE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
|
|
|
return self.create<mlir::arith::CmpFOp>(
|
|
|
|
loc, mlir::arith::CmpFPredicate::ONE,
|
|
|
|
loc, mlir::arith::CmpFPredicate::ONE, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fcmpULT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_fcmpULT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
|
|
|
return self.create<mlir::arith::CmpFOp>(
|
|
|
|
loc, mlir::arith::CmpFPredicate::ULT,
|
|
|
|
loc, mlir::arith::CmpFPredicate::ULT, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fcmpUGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_fcmpUGT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
|
|
|
return self.create<mlir::arith::CmpFOp>(
|
|
|
|
loc, mlir::arith::CmpFPredicate::UGT,
|
|
|
|
loc, mlir::arith::CmpFPredicate::UGT, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fcmpULE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_fcmpULE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
|
|
|
return self.create<mlir::arith::CmpFOp>(
|
|
|
|
loc, mlir::arith::CmpFPredicate::ULE,
|
|
|
|
loc, mlir::arith::CmpFPredicate::ULE, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fcmpUGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_fcmpUGE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
|
|
|
return self.create<mlir::arith::CmpFOp>(
|
|
|
|
loc, mlir::arith::CmpFPredicate::UGE,
|
|
|
|
loc, mlir::arith::CmpFPredicate::UGE, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fcmpUEQ", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_fcmpUEQ", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
|
|
|
return self.create<mlir::arith::CmpFOp>(
|
|
|
|
loc, mlir::arith::CmpFPredicate::UEQ,
|
|
|
|
loc, mlir::arith::CmpFPredicate::UEQ, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_fcmpUNE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_fcmpUNE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
|
|
|
return self.create<mlir::arith::CmpFOp>(
|
|
|
|
loc, mlir::arith::CmpFPredicate::UNE,
|
|
|
|
loc, mlir::arith::CmpFPredicate::UNE, lhs, rhs);
|
|
|
|
unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
// // Logical
|
|
|
|
// // Logical
|
|
|
|
.def("create_and", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_and", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::AndIOp>(
|
|
|
|
return self.create<mlir::arith::AndIOp>(loc, lhs, rhs);
|
|
|
|
loc, unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_xor", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_xor", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::XOrIOp>(
|
|
|
|
return self.create<mlir::arith::XOrIOp>(loc, lhs, rhs);
|
|
|
|
loc, unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_or", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_or", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::arith::OrIOp>(
|
|
|
|
return self.create<mlir::arith::OrIOp>(loc, lhs, rhs);
|
|
|
|
loc, unwrap(lhs), unwrap(rhs)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
// // Input/Output
|
|
|
|
// // Input/Output
|
|
|
|
.def("create_load", [](mlir::OpBuilder &self, MlirValue &ptrs) -> MlirValue {
|
|
|
|
.def("create_load", [](mlir::OpBuilder &self, mlir::Value &ptrs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(
|
|
|
|
return self.create<mlir::triton::LoadOp>(loc, ptrs);
|
|
|
|
self.create<mlir::triton::LoadOp>(loc, unwrap(ptrs))
|
|
|
|
|
|
|
|
));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_store", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &value) -> void {
|
|
|
|
.def("create_store", [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &value) -> void {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
self.create<mlir::triton::StoreOp>(loc, unwrap(ptrs), unwrap(value));
|
|
|
|
self.create<mlir::triton::StoreOp>(loc, ptrs, value);
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_masked_load", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &mask, MlirValue &other) -> MlirValue {
|
|
|
|
.def("create_masked_load", [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &mask, mlir::Value &other,
|
|
|
|
|
|
|
|
mlir::triton::CacheModifier cacheModifier,
|
|
|
|
|
|
|
|
mlir::triton::EvictionPolicy evictionPolicy,
|
|
|
|
|
|
|
|
bool isVolatile) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto ptrType = unwrap(ptrs).getType().dyn_cast<mlir::RankedTensorType>();
|
|
|
|
auto ptrType = ptrs.getType().dyn_cast<mlir::RankedTensorType>();
|
|
|
|
std::vector<int64_t> shape = ptrType.getShape();
|
|
|
|
std::vector<int64_t> shape = ptrType.getShape();
|
|
|
|
mlir::Type elementType = ptrType.getElementType().dyn_cast<mlir::triton::PointerType>().getPointeeType();
|
|
|
|
mlir::Type elementType = ptrType.getElementType().dyn_cast<mlir::triton::PointerType>().getPointeeType();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::triton::LoadOp>(
|
|
|
|
return self.create<mlir::triton::LoadOp>(
|
|
|
|
loc, mlir::RankedTensorType::get(shape, elementType), unwrap(ptrs), unwrap(mask), unwrap(other))
|
|
|
|
loc, mlir::RankedTensorType::get(shape, elementType), ptrs, mask, other,
|
|
|
|
));
|
|
|
|
cacheModifier, evictionPolicy, isVolatile);
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_masked_store", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &val, MlirValue &mask) -> void {
|
|
|
|
.def("create_masked_store", [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val, mlir::Value &mask) -> void {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
self.create<mlir::triton::StoreOp>(loc, unwrap(ptrs), unwrap(val), unwrap(mask));
|
|
|
|
self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask);
|
|
|
|
})
|
|
|
|
})
|
|
|
|
// Block instruction
|
|
|
|
// Block instruction
|
|
|
|
.def("create_reshape", [](mlir::OpBuilder &self, MlirValue &arg, std::vector<int64_t> &shape) -> MlirValue {
|
|
|
|
.def("create_reshape", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto argType = unwrap(arg).getType().dyn_cast<mlir::RankedTensorType>();
|
|
|
|
auto argType = arg.getType().dyn_cast<mlir::RankedTensorType>().getElementType();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::triton::ReshapeOp>(
|
|
|
|
return self.create<mlir::triton::ReshapeOp>(
|
|
|
|
loc, mlir::RankedTensorType::get(shape, argType), unwrap(arg), self.getI64ArrayAttr(shape)
|
|
|
|
loc, mlir::RankedTensorType::get(shape, argType), arg, self.getI64ArrayAttr(shape)
|
|
|
|
)));
|
|
|
|
);
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_cat", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
|
|
|
.def("create_cat", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto lhsType = unwrap(lhs).getType().dyn_cast<mlir::RankedTensorType>();
|
|
|
|
auto lhsType = lhs.getType().dyn_cast<mlir::RankedTensorType>();
|
|
|
|
auto rhsType = unwrap(rhs).getType().dyn_cast<mlir::RankedTensorType>();
|
|
|
|
auto rhsType = rhs.getType().dyn_cast<mlir::RankedTensorType>();
|
|
|
|
if (!(lhsType.getShape().size() == 1 && rhsType.getShape().size() == 1))
|
|
|
|
if (!(lhsType.getShape().size() == 1 && rhsType.getShape().size() == 1))
|
|
|
|
throw std::runtime_error("shape not supported by cat. Expecting rank-1 inputs");
|
|
|
|
throw std::runtime_error("shape not supported by cat. Expecting rank-1 inputs");
|
|
|
|
std::vector<int64_t> shape {lhsType.getShape()[0] + rhsType.getShape()[0]};
|
|
|
|
std::vector<int64_t> shape {lhsType.getShape()[0] + rhsType.getShape()[0]};
|
|
|
|
return wrap(mlir::Value(self.create<mlir::triton::CatOp>(
|
|
|
|
return self.create<mlir::triton::CatOp>(
|
|
|
|
loc, mlir::RankedTensorType::get(shape, lhsType.getElementType()), unwrap(lhs), unwrap(rhs)
|
|
|
|
loc, mlir::RankedTensorType::get(shape, lhsType.getElementType()), lhs, rhs
|
|
|
|
)));
|
|
|
|
);
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_broadcast", [](mlir::OpBuilder &self, MlirValue &arg, std::vector<int64_t> &shape) -> MlirValue {
|
|
|
|
.def("create_broadcast", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto argType = unwrap(arg).getType();
|
|
|
|
// TODO: should be scalar type here
|
|
|
|
return wrap(mlir::Value(self.create<mlir::triton::BroadcastOp>(
|
|
|
|
auto argType = arg.getType();
|
|
|
|
loc, mlir::RankedTensorType::get(shape, argType), unwrap(arg)
|
|
|
|
return self.create<mlir::triton::BroadcastOp>(
|
|
|
|
)));
|
|
|
|
loc, mlir::RankedTensorType::get(shape, argType), arg
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
.def("create_splat", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
|
|
|
|
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
|
|
|
|
auto argType = arg.getType();
|
|
|
|
|
|
|
|
return self.create<mlir::triton::BroadcastOp>(
|
|
|
|
|
|
|
|
loc, mlir::RankedTensorType::get(shape, argType), arg
|
|
|
|
|
|
|
|
);
|
|
|
|
})
|
|
|
|
})
|
|
|
|
// // atomic
|
|
|
|
// // atomic
|
|
|
|
.def("create_atomic_cas", [](mlir::OpBuilder &self, MlirValue &ptr,
|
|
|
|
.def("create_atomic_cas", [](mlir::OpBuilder &self, mlir::Value &ptr,
|
|
|
|
MlirValue &cmp, MlirValue &val) -> MlirValue {
|
|
|
|
mlir::Value &cmp, mlir::Value &val) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto ptrType = unwrap(ptr).getType().dyn_cast<mlir::triton::PointerType>();
|
|
|
|
auto ptrType = ptr.getType().dyn_cast<mlir::triton::PointerType>();
|
|
|
|
mlir::Type dstType = ptrType.getPointeeType();
|
|
|
|
mlir::Type dstType = ptrType.getPointeeType();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::triton::AtomicCASOp>(
|
|
|
|
return self.create<mlir::triton::AtomicCASOp>(
|
|
|
|
loc, dstType, unwrap(ptr), unwrap(cmp), unwrap(val)
|
|
|
|
loc, dstType, ptr, cmp, val
|
|
|
|
)));
|
|
|
|
);
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_atomic_rmw", [](mlir::OpBuilder &self, mlir::triton::RMWOp rmwOp,
|
|
|
|
.def("create_atomic_rmw", [](mlir::OpBuilder &self, mlir::triton::RMWOp rmwOp,
|
|
|
|
MlirValue &ptr, MlirValue &val, MlirValue &mask) -> MlirValue {
|
|
|
|
mlir::Value &ptr, mlir::Value &val, mlir::Value &mask) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto ptrType = unwrap(ptr).getType().dyn_cast<mlir::triton::PointerType>();
|
|
|
|
auto ptrType = ptr.getType().dyn_cast<mlir::triton::PointerType>();
|
|
|
|
mlir::Type dstType = ptrType.getPointeeType();
|
|
|
|
mlir::Type dstType = ptrType.getPointeeType();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::triton::AtomicRMWOp>(
|
|
|
|
return self.create<mlir::triton::AtomicRMWOp>(
|
|
|
|
loc, dstType, rmwOp, unwrap(ptr), unwrap(val), unwrap(mask)
|
|
|
|
loc, dstType, rmwOp, ptr, val, mask
|
|
|
|
)));
|
|
|
|
);
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
// Built-in instruction
|
|
|
|
// Built-in instruction
|
|
|
|
.def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> MlirValue {
|
|
|
|
.def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::triton::GetProgramIdOp>(
|
|
|
|
return self.create<mlir::triton::GetProgramIdOp>(
|
|
|
|
loc, self.getI32Type(), self.getI32IntegerAttr(axis)
|
|
|
|
loc, self.getI32Type(), self.getI32IntegerAttr(axis)
|
|
|
|
)));
|
|
|
|
);
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_get_num_programs", [](mlir::OpBuilder &self, int axis) -> MlirValue {
|
|
|
|
.def("create_get_num_programs", [](mlir::OpBuilder &self, int axis) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::triton::GetNumProgramsOp>(
|
|
|
|
return self.create<mlir::triton::GetNumProgramsOp>(
|
|
|
|
loc, self.getI32Type(), self.getI32IntegerAttr(axis)
|
|
|
|
loc, self.getI32Type(), self.getI32IntegerAttr(axis)
|
|
|
|
)));
|
|
|
|
);
|
|
|
|
})
|
|
|
|
})
|
|
|
|
.def("create_dot", [](mlir::OpBuilder &self, MlirValue &a, MlirValue &b, MlirValue &c) -> MlirValue {
|
|
|
|
.def("create_dot", [](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b, mlir::Value &c) -> mlir::Value {
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
auto loc = self.getUnknownLoc();
|
|
|
|
return wrap(mlir::Value(self.create<mlir::triton::DotOp>(
|
|
|
|
return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c);
|
|
|
|
loc, unwrap(c).getType(), unwrap(a), unwrap(b), unwrap(c)
|
|
|
|
|
|
|
|
)));
|
|
|
|
|
|
|
|
})
|
|
|
|
})
|
|
|
|
// .def("create_exp", &ir::builder::create_exp, ret::reference)
|
|
|
|
// .def("create_exp", &ir::builder::create_exp, ret::reference)
|
|
|
|
// .def("create_cos", &ir::builder::create_cos, ret::reference)
|
|
|
|
// .def("create_cos", &ir::builder::create_cos, ret::reference)
|
|
|
|