[LANG] Added support for device functions (#484)

This commit is contained in:
Philippe Tillet
2022-04-03 20:58:16 -07:00
committed by GitHub
parent e85c7a7fc7
commit 2bed6fc850
39 changed files with 1213 additions and 379 deletions

View File

@@ -659,6 +659,8 @@ void init_triton_ir(py::module &&m) {
py::class_<ir::type>(m, "type")
.def("is_ptr", &ir::type::is_pointer_ty)
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
.def("get_int_width", &ir::type::get_integer_bitwidth)
.def("is_floating", &ir::type::is_floating_point_ty)
.def("is_block", &ir::type::is_block_ty)
.def("make_ptr", &ir::pointer_type::get, ret::reference)
@@ -695,6 +697,7 @@ void init_triton_ir(py::module &&m) {
.def("is_uint16", [](ir::type *self) { return self->is_integer_ty(16, ir::signedness::UNSIGNED); })
.def("is_uint32", [](ir::type *self) { return self->is_integer_ty(32, ir::signedness::UNSIGNED); })
.def("is_uint64", [](ir::type *self) { return self->is_integer_ty(64, ir::signedness::UNSIGNED); })
.def("is_struct", &ir::type::is_struct_ty)
.def("repr", &ir::type::repr)
.def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
@@ -704,23 +707,37 @@ void init_triton_ir(py::module &&m) {
py::class_<ir::pointer_type, ir::type>(m, "pointer_type")
.def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference);
py::class_<ir::function_type, ir::type>(m, "function_type");
py::class_<ir::function_type, ir::type>(m, "function_type")
.def_property_readonly("ret_ty", &ir::function_type::get_return_ty)
.def_property_readonly("arg_tys", [](ir::function_type* self){
return std::vector<ir::type*>(self->params_begin(), self->params_end());
});
py::class_<ir::integer_type, ir::type>(m, "integer_type");
py::class_<ir::block_type, ir::type>(m, "block_type")
.def_property_readonly("shape", &ir::block_type::get_shapes)
.def_property_readonly("numel", &ir::type::get_tile_num_elements);
py::class_<ir::struct_type, ir::type>(m, "struct_type")
.def("get", &ir::struct_type::get, ret::reference)
.def_property_readonly("num_types", &ir::struct_type::get_num_types);
py::class_<ir::value_constructor>(m, "value_constructor")
.def(py::init<ir::builder&>())
.def("seal_block", &ir::value_constructor::seal_block)
.def("set_value", (void (ir::value_constructor::*)(const std::string &, ir::value *)) & ir::value_constructor::set_value)
.def("set_type", &ir::value_constructor::set_type)
.def("get_value", (ir::value * (ir::value_constructor::*)(const std::string &)) & ir::value_constructor::get_value, ret::reference)
.def("get_values", &ir::value_constructor::get_values, ret::reference)
.def("set_values", &ir::value_constructor::set_values);
py::class_<ir::module>(m, "module")
.def(py::init<std::string, ir::builder &>())
.def("has_function", &ir::module::has_function)
.def("get_function", &ir::module::get_function, ret::reference)
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
.def("seal_block", &ir::module::seal_block)
.def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value)
.def("set_type", &ir::module::set_type)
.def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
.def("get_values", &ir::module::get_values, ret::reference)
.def("set_values", &ir::module::set_values)
.def("get_types", &ir::module::get_types, ret::reference)
.def("set_types", &ir::module::set_types)
.def("reset_ret_ty", &ir::module::reset_ret_ty)
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
using eattr = ir::attribute_kind_t;
@@ -734,29 +751,45 @@ void init_triton_ir(py::module &&m) {
.value("not_implemented", eattr::not_implemented);
py::class_<ir::attribute>(m, "attribute")
.def(py::init<eattr, int>());
.def(py::init<eattr, int>())
.def_property_readonly("value", &ir::attribute::get_value);
py::class_<ir::function>(m, "function")
.def_property_readonly("args", &ir::function::args)
.def_property_readonly("attrs", &ir::function::attrs)
.def("add_attr", &ir::function::add_attr);
.def("set_is_kernel", &ir::function::set_is_kernel)
.def("add_attr", &ir::function::add_attr)
.def("has_attr", &ir::function::has_attr)
.def("get_attrs", &ir::function::get_attributes);
py::class_<ir::argument, ir::value>(m, "argument");
py::class_<ir::argument, ir::value>(m, "argument")
.def_property_readonly("parent", &ir::argument::get_parent, ret::reference)
.def_property_readonly("arg_no", &ir::argument::get_arg_no);
py::class_<ir::basic_block, ir::value>(m, "basic_block")
.def("create", &ir::basic_block::create, ret::reference)
.def("create", &ir::basic_block::create, ret::reference, py::arg(), py::arg(), py::arg() = nullptr)
.def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);
py::class_<ir::builder::iterator>(m, "bb_iterator");
py::class_<ir::builder>(m, "builder", py::dynamic_attr())
.def(py::init<ir::context &>())
// getters
.def_property_readonly("context", &ir::builder::get_context, ret::reference)
// control flow
.def("call", &ir::builder::create_call, ret::reference)
.def("launch", &ir::builder::create_launch, ret::reference)
.def("br", &ir::builder::create_br, ret::reference)
.def("cond_br", &ir::builder::create_cond_br, ret::reference)
.def("ret_void", &ir::builder::create_ret_void, ret::reference)
.def("ret", &ir::builder::create_ret, ret::reference)
.def("get_insert_point", &ir::builder::get_insert_point)
.def("set_insert_point", (void (ir::builder::*)(ir::builder::iterator))&ir::builder::set_insert_point)
.def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
.def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point)
// struct
.def("insert_value", &ir::builder::create_insert_value, ret::reference)
.def("extract_value", &ir::builder::create_extract_value, ret::reference)
// constants
.def("get_int1", &ir::builder::get_int1, ret::reference)
.def("get_int32", &ir::builder::get_int32, ret::reference)