Replace MlirOperation with MlirValue
This commit is contained in:
@@ -643,10 +643,10 @@ void init_triton_ir(py::module &&m) {
|
||||
// // .def("get", &ir::undef_value::get, ret::reference);
|
||||
|
||||
py::class_<MlirType>(m, "type")
|
||||
.def("is_integer", [](MlirType &self) {
|
||||
.def("is_integer", [](MlirType &self) -> bool {
|
||||
return mlirTypeIsAInteger(self);
|
||||
})
|
||||
.def("is_fp16", [](MlirType &self) {
|
||||
.def("is_fp16", [](MlirType &self) -> bool {
|
||||
return mlirTypeIsABF16(self);
|
||||
})
|
||||
;
|
||||
@@ -669,7 +669,13 @@ void init_triton_ir(py::module &&m) {
|
||||
})
|
||||
;
|
||||
|
||||
py::class_<MlirValue>(m, "value")
|
||||
;
|
||||
|
||||
py::class_<MlirBlock>(m, "block")
|
||||
.def("arg", [](MlirBlock &self, int index) -> MlirValue {
|
||||
return wrap(unwrap(self)->getArgument(index));
|
||||
})
|
||||
;
|
||||
|
||||
// py::class_<mlir::triton::Float8Type>(m, "float8_type")
|
||||
@@ -756,7 +762,12 @@ void init_triton_ir(py::module &&m) {
|
||||
// Use arith.ConstantOp to create constants
|
||||
// // Constants
|
||||
// .def("get_int1", &ir::builder::get_int1, ret::reference)
|
||||
// .def("get_int32", [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference)
|
||||
.def("get_int32", [](mlir::OpBuilder &self, int64_t v) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::ConstantIntOp>(
|
||||
loc, v, self.getI32Type()
|
||||
)));
|
||||
}, ret::reference)
|
||||
// .def("get_uint32", &ir::builder::get_int32, ret::reference)
|
||||
// .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference)
|
||||
// .def("get_uint64", &ir::builder::get_int64, ret::reference)
|
||||
@@ -801,6 +812,11 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get_double_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getF64Type());
|
||||
}, ret::reference)
|
||||
.def("get_ptr_ty", [](mlir::OpBuilder &self, MlirType &type) -> MlirType {
|
||||
return wrap(
|
||||
mlir::triton::PointerType::get(unwrap(type))
|
||||
);
|
||||
})
|
||||
.def("get_function_ty", [](mlir::OpBuilder &self,
|
||||
std::vector<MlirType> inTypes,
|
||||
std::vector<MlirType> outTypes) -> MlirType {
|
||||
@@ -829,14 +845,18 @@ void init_triton_ir(py::module &&m) {
|
||||
// .def("create_scf_while")
|
||||
|
||||
// miscellious
|
||||
.def("create_make_range", [](mlir::OpBuilder &self, int start, int end){
|
||||
.def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto retType = mlir::RankedTensorType::get({end-start}, self.getI32Type());
|
||||
return wrap(self.create<mlir::triton::MakeRangeOp>(loc, retType, start, end).getOperation());
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::triton::MakeRangeOp>(loc, retType, start, end))
|
||||
);
|
||||
}, ret::reference)
|
||||
.def("create_get_program_id", [](mlir::OpBuilder &self, int axis) {
|
||||
.def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::triton::GetProgramIdOp>(loc, self.getI32Type(), axis).getOperation());
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::triton::GetProgramIdOp>(loc, self.getI32Type(), axis))
|
||||
);
|
||||
})
|
||||
|
||||
// // Cast instructions
|
||||
@@ -853,43 +873,84 @@ void init_triton_ir(py::module &&m) {
|
||||
// .def("create_downcast", &ir::builder::create_downcast, ret::reference)
|
||||
// // Binary instructions
|
||||
// .def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference)
|
||||
// .def("create_fmul", &ir::builder::create_fmul, ret::reference)
|
||||
// .def("create_fdiv", &ir::builder::create_fdiv, ret::reference)
|
||||
// .def("create_frem", &ir::builder::create_frem, ret::reference)
|
||||
// .def("create_fadd", &ir::builder::create_fadd, ret::reference)
|
||||
// .def("create_fsub", &ir::builder::create_fsub, ret::reference)
|
||||
.def("create_mul", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_fmul", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(
|
||||
self.create<mlir::arith::MulFOp>(loc, unwrap(lhs), unwrap(rhs))
|
||||
));
|
||||
}, ret::reference)
|
||||
.def("create_fdiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(
|
||||
self.create<mlir::arith::DivFOp>(loc, unwrap(lhs), unwrap(rhs))
|
||||
));
|
||||
}, ret::reference)
|
||||
.def("create_frem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(
|
||||
self.create<mlir::arith::RemFOp>(loc, unwrap(lhs), unwrap(rhs))
|
||||
));
|
||||
}, ret::reference)
|
||||
.def("create_fadd", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(
|
||||
self.create<mlir::arith::AddFOp>(loc, unwrap(lhs), unwrap(rhs))
|
||||
));
|
||||
}, ret::reference)
|
||||
.def("create_fsub", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(
|
||||
self.create<mlir::arith::SubFOp>(loc, unwrap(lhs), unwrap(rhs))
|
||||
));
|
||||
}, ret::reference)
|
||||
.def("create_mul", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
// Check lhs & rhs have single result (?)
|
||||
return wrap(self.create<mlir::arith::MulIOp>(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation());
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::arith::MulIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
||||
);
|
||||
}, ret::reference)
|
||||
.def("create_sdiv", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_sdiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::DivSIOp>(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation());
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::arith::DivSIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
||||
);
|
||||
}, ret::reference)
|
||||
.def("create_udiv", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_udiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::DivUIOp>(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation());
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::arith::DivUIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
||||
);
|
||||
}, ret::reference)
|
||||
.def("create_srem", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_srem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::RemSIOp>(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation());
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::arith::RemSIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
||||
);
|
||||
}, ret::reference)
|
||||
.def("create_urem", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_urem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::RemUIOp>(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation());
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::arith::RemUIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
||||
);
|
||||
}, ret::reference)
|
||||
.def("create_add", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_add", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::AddIOp>(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation());
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::arith::AddIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
||||
);
|
||||
}, ret::reference)
|
||||
.def("create_sub", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_sub", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::SubIOp>(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation());
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::arith::SubIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
||||
);
|
||||
}, ret::reference)
|
||||
.def("create_shl", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_shl", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::ShLIOp>(loc, unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)).getOperation());
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::arith::ShLIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
||||
);
|
||||
}, ret::reference)
|
||||
// .def("create_lshr", &ir::builder::create_lshr, ret::reference,
|
||||
// py::arg("lhs"), py::arg("rhs"),
|
||||
@@ -897,88 +958,91 @@ void init_triton_ir(py::module &&m) {
|
||||
// .def("create_ashr", &ir::builder::create_ashr, ret::reference,
|
||||
// py::arg("lhs"), py::arg("rhs"),
|
||||
// py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
// // GEP
|
||||
// .def("create_gep", [](mlir::OpBuilder &self, MlirOperation &ptr, MlirOperation &offset) -> MlirOperation {
|
||||
// auto loc = self.getUnknownLoc();
|
||||
// }, ret::reference)
|
||||
// GEP
|
||||
.def("create_gep", [](mlir::OpBuilder &self, MlirValue &ptr, MlirValue &offset) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::triton::GEPOp>(loc, unwrap(ptr).getType(), unwrap(ptr), unwrap(offset)))
|
||||
);
|
||||
}, ret::reference)
|
||||
// Comparison (int)
|
||||
.def("create_icmpSLE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_icmpSLE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::CmpIOp>(
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::sle,
|
||||
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
|
||||
).getOperation());
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
.def("create_icmpSLT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_icmpSLT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::CmpIOp>(
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::slt,
|
||||
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
|
||||
).getOperation());
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
.def("create_icmpSGE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_icmpSGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::CmpIOp>(
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::sge,
|
||||
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
|
||||
).getOperation());
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
.def("create_icmpSGT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_icmpSGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::CmpIOp>(
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::sgt,
|
||||
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
|
||||
).getOperation());
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
.def("create_icmpULE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_icmpULE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::CmpIOp>(
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ule,
|
||||
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
|
||||
).getOperation());
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
.def("create_icmpULT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_icmpULT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::CmpIOp>(
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ult,
|
||||
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
|
||||
).getOperation());
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
.def("create_icmpUGE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_icmpUGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::CmpIOp>(
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::uge,
|
||||
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
|
||||
).getOperation());
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
.def("create_icmpUGT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_icmpUGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::CmpIOp>(
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ugt,
|
||||
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
|
||||
).getOperation());
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
.def("create_icmpEQ", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_icmpEQ", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::CmpIOp>(
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::eq,
|
||||
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
|
||||
).getOperation());
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
.def("create_icmpNE", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_icmpNE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::CmpIOp>(
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ne,
|
||||
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
|
||||
).getOperation());
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
// Comparison (float)
|
||||
.def("create_fcmpOLT", [](mlir::OpBuilder &self, MlirOperation &lhs, MlirOperation &rhs) -> MlirOperation {
|
||||
.def("create_fcmpOLT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(self.create<mlir::arith::CmpFOp>(
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
||||
loc, mlir::arith::CmpFPredicate::OLT,
|
||||
unwrap(lhs)->getResult(0), unwrap(rhs)->getResult(0)
|
||||
).getOperation());
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
// .def("create_fcmpOGT", &ir::builder::create_fcmpOGT, ret::reference)
|
||||
// .def("create_fcmpOLE", &ir::builder::create_fcmpOLE, ret::reference)
|
||||
@@ -996,15 +1060,29 @@ void init_triton_ir(py::module &&m) {
|
||||
// .def("create_xor", &ir::builder::create_xor, ret::reference)
|
||||
// .def("create_or", &ir::builder::create_or, ret::reference)
|
||||
// // Input/Output
|
||||
// .def("create_load", &ir::builder::create_load, ret::reference)
|
||||
// .def("create_store", &ir::builder::create_store, ret::reference)
|
||||
.def("create_load", [](mlir::OpBuilder &self, MlirValue &ptrs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(
|
||||
self.create<mlir::triton::LoadOp>(loc, unwrap(ptrs))
|
||||
));
|
||||
}, ret::reference)
|
||||
.def("create_store", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &value) -> void {
|
||||
auto loc = self.getUnknownLoc();
|
||||
self.create<mlir::triton::StoreOp>(loc, unwrap(ptrs), unwrap(value));
|
||||
}, ret::reference)
|
||||
// .def("create_masked_load", &ir::builder::create_masked_load, ret::reference)
|
||||
// .def("create_masked_store", &ir::builder::create_masked_store, ret::reference)
|
||||
// // Block instruction
|
||||
// .def("create_splat", &ir::builder::create_splat, ret::reference)
|
||||
// .def("create_reshape", &ir::builder::create_reshape, ret::reference)
|
||||
// .def("create_cat", &ir::builder::create_cat, ret::reference)
|
||||
// .def("create_broadcast", &ir::builder::create_broadcast, ret::reference)
|
||||
.def("create_broadcast", [](mlir::OpBuilder &self, MlirValue &arg, std::vector<int64_t> &shape) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto argType = unwrap(arg).getType();
|
||||
return wrap(mlir::Value(self.create<mlir::triton::BroadcastOp>(
|
||||
loc, mlir::RankedTensorType::get(shape, argType), unwrap(arg)
|
||||
)));
|
||||
}, ret::reference)
|
||||
// // atomic
|
||||
// .def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference)
|
||||
// .def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference)
|
||||
|
Reference in New Issue
Block a user