[TESTING] Added infrastructure for executing TTGIR program and test for layout conversions (#885)

This commit is contained in:
Philippe Tillet
2022-11-18 07:46:45 +01:00
committed by GitHub
parent 9ea6135eb5
commit dab4855bdf
6 changed files with 243 additions and 67 deletions

View File

@@ -163,7 +163,19 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::Type>(m, "type")
.def("is_integer", &mlir::Type::isInteger)
.def("is_fp16", &mlir::Type::isF16);
.def("is_fp16", &mlir::Type::isF16)
.def("__str__", [](mlir::Type &self) {
std::string str;
llvm::raw_string_ostream os(str);
self.print(os);
return os.str();
});
py::class_<mlir::FunctionType>(m, "function_type")
.def("param_types", [](mlir::FunctionType &self) {
return std::vector<mlir::Type>(self.getInputs().begin(),
self.getInputs().end());
});
py::class_<mlir::Value>(m, "value")
.def("set_attr",
@@ -314,7 +326,14 @@ void init_triton_ir(py::module &&m) {
.def("get_function",
[](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
return self.lookupSymbol<mlir::FuncOp>(funcName);
});
})
.def("get_single_function", [](mlir::ModuleOp &self) -> mlir::FuncOp {
llvm::SmallVector<mlir::FuncOp> funcs;
self.walk([&](mlir::FuncOp func) { funcs.push_back(func); });
if (funcs.size() != 1)
throw std::runtime_error("Expected a single function");
return funcs[0];
});
m.def(
"parse_mlir_module",
@@ -363,6 +382,7 @@ void init_triton_ir(py::module &&m) {
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
},
ret::reference)
.def_property_readonly("type", &mlir::FuncOp::getType)
.def("reset_type", &mlir::FuncOp::setType);
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
@@ -1274,8 +1294,8 @@ void init_triton_ir(py::module &&m) {
void init_triton_translation(py::module &m) {
using ret = py::return_value_policy;
m.def("get_shared_memory_size", [](mlir::ModuleOp module) {
auto shared = module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared");
m.def("get_shared_memory_size", [](mlir::ModuleOp mod) {
auto shared = mod->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared");
return shared.getInt();
});