Let python manage created objects
This commit is contained in:
@@ -661,7 +661,7 @@ void init_triton_ir(py::module &&m) {
|
||||
throw std::runtime_error("Only FuncOp can call add_entry_block");
|
||||
} else
|
||||
throw std::runtime_error("Unknown error");
|
||||
}, ret::reference) // this should be automatic?
|
||||
}) // this should be automatic?
|
||||
.def("dump", [](MlirOperation &self) -> void {
|
||||
unwrap(self)->dump();
|
||||
})
|
||||
@@ -765,7 +765,7 @@ void init_triton_ir(py::module &&m) {
|
||||
return wrap(mlir::Value(self.create<mlir::arith::ConstantIntOp>(
|
||||
loc, v, self.getI32Type()
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
// .def("get_uint32", &ir::builder::get_int32, ret::reference)
|
||||
// .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference)
|
||||
// .def("get_uint64", &ir::builder::get_int64, ret::reference)
|
||||
@@ -776,40 +776,40 @@ void init_triton_ir(py::module &&m) {
|
||||
// Types
|
||||
.def("get_void_ty", [](mlir::OpBuilder &self) ->MlirType {
|
||||
return wrap(self.getNoneType());
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("get_int1_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getI1Type());
|
||||
}, ret::reference) // or ret::copy?
|
||||
}) // or ret::copy?
|
||||
.def("get_int8_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getI8Type());
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("get_int16_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getType<mlir::IntegerType>(16));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("get_int32_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getI32Type());
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("get_int64_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getI64Type());
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("get_fp8_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getType<mlir::triton::Float8Type>());
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("get_bf8_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getType<mlir::triton::BFloat8Type>());
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("get_half_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getF16Type());
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("get_bf16_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getBF16Type());
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("get_float_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getF32Type());
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("get_double_ty", [](mlir::OpBuilder &self) -> MlirType {
|
||||
return wrap(self.getF64Type());
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("get_ptr_ty", [](mlir::OpBuilder &self, MlirType &type) -> MlirType {
|
||||
return wrap(
|
||||
mlir::triton::PointerType::get(unwrap(type))
|
||||
@@ -823,7 +823,7 @@ void init_triton_ir(py::module &&m) {
|
||||
(void)unwrapList(inTypes.size(), inTypes.data(), inputsTypeList);
|
||||
(void)unwrapList(outTypes.size(), outTypes.data(), resultsTypeList);
|
||||
return wrap(self.getFunctionType(inputsTypeList, resultsTypeList));
|
||||
}, ret::reference)
|
||||
})
|
||||
|
||||
// Ops
|
||||
.def("create_function", [](mlir::OpBuilder &self, std::string name, MlirType funcType) -> MlirOperation {
|
||||
@@ -833,7 +833,7 @@ void init_triton_ir(py::module &&m) {
|
||||
return wrap(self.create<mlir::FuncOp>(loc, name, funcTy));
|
||||
}
|
||||
throw std::runtime_error("invalid function type");
|
||||
}, ret::reference)
|
||||
})
|
||||
// // Structured control flow
|
||||
// .def("create_scf_for", [](mlir::OpBuilder &self) {
|
||||
// return self.create<mlir::scf::ForOp>(/*fill this*/);
|
||||
@@ -849,7 +849,7 @@ void init_triton_ir(py::module &&m) {
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::triton::MakeRangeOp>(loc, retType, start, end))
|
||||
);
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(
|
||||
@@ -876,80 +876,80 @@ void init_triton_ir(py::module &&m) {
|
||||
return wrap(mlir::Value(
|
||||
self.create<mlir::arith::MulFOp>(loc, unwrap(lhs), unwrap(rhs))
|
||||
));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_fdiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(
|
||||
self.create<mlir::arith::DivFOp>(loc, unwrap(lhs), unwrap(rhs))
|
||||
));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_frem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(
|
||||
self.create<mlir::arith::RemFOp>(loc, unwrap(lhs), unwrap(rhs))
|
||||
));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_fadd", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(
|
||||
self.create<mlir::arith::AddFOp>(loc, unwrap(lhs), unwrap(rhs))
|
||||
));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_fsub", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(
|
||||
self.create<mlir::arith::SubFOp>(loc, unwrap(lhs), unwrap(rhs))
|
||||
));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_mul", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
// Check lhs & rhs have single result (?)
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::arith::MulIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
||||
);
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_sdiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::arith::DivSIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
||||
);
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_udiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::arith::DivUIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
||||
);
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_srem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::arith::RemSIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
||||
);
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_urem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::arith::RemUIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
||||
);
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_add", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::arith::AddIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
||||
);
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_sub", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::arith::SubIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
||||
);
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_shl", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::arith::ShLIOp>(loc, unwrap(lhs), unwrap(rhs)))
|
||||
);
|
||||
}, ret::reference)
|
||||
})
|
||||
// .def("create_lshr", &ir::builder::create_lshr, ret::reference,
|
||||
// py::arg("lhs"), py::arg("rhs"),
|
||||
// py::arg("has_nuw")=false, py::arg("has_nsw")=false)
|
||||
@@ -962,7 +962,7 @@ void init_triton_ir(py::module &&m) {
|
||||
return wrap(
|
||||
mlir::Value(self.create<mlir::triton::GEPOp>(loc, unwrap(ptr).getType(), unwrap(ptr), unwrap(offset)))
|
||||
);
|
||||
}, ret::reference)
|
||||
})
|
||||
// Comparison (int)
|
||||
.def("create_icmpSLE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
@@ -970,70 +970,70 @@ void init_triton_ir(py::module &&m) {
|
||||
loc, mlir::arith::CmpIPredicate::sle,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_icmpSLT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::slt,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_icmpSGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::sge,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_icmpSGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::sgt,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_icmpULE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ule,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_icmpULT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ult,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_icmpUGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::uge,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_icmpUGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ugt,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_icmpEQ", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::eq,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_icmpNE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpIOp>(
|
||||
loc, mlir::arith::CmpIPredicate::ne,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
// Comparison (float)
|
||||
.def("create_fcmpOLT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
@@ -1041,103 +1041,103 @@ void init_triton_ir(py::module &&m) {
|
||||
loc, mlir::arith::CmpFPredicate::OLT,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_fcmpOGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
||||
loc, mlir::arith::CmpFPredicate::OGT,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_fcmpOLE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
||||
loc, mlir::arith::CmpFPredicate::OLE,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_fcmpOGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
||||
loc, mlir::arith::CmpFPredicate::OGE,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_fcmpOEQ", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
||||
loc, mlir::arith::CmpFPredicate::OEQ,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_fcmpONE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
||||
loc, mlir::arith::CmpFPredicate::ONE,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_fcmpULT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
||||
loc, mlir::arith::CmpFPredicate::ULT,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_fcmpUGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
||||
loc, mlir::arith::CmpFPredicate::UGT,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_fcmpULE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
||||
loc, mlir::arith::CmpFPredicate::ULE,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_fcmpUGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
||||
loc, mlir::arith::CmpFPredicate::UGE,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_fcmpUEQ", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
||||
loc, mlir::arith::CmpFPredicate::UEQ,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_fcmpUNE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::CmpFOp>(
|
||||
loc, mlir::arith::CmpFPredicate::UNE,
|
||||
unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
// // Logical
|
||||
.def("create_and", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::AndIOp>(
|
||||
loc, unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_xor", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::XOrIOp>(
|
||||
loc, unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
.def("create_or", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return wrap(mlir::Value(self.create<mlir::arith::OrIOp>(
|
||||
loc, unwrap(lhs), unwrap(rhs)
|
||||
)));
|
||||
}, ret::reference)
|
||||
})
|
||||
// // Input/Output
|
||||
.def("create_load", [](mlir::OpBuilder &self, MlirValue &ptrs) -> MlirValue {
|
||||
auto loc = self.getUnknownLoc();
|
||||
|
Reference in New Issue
Block a user