diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index 7e4a08209..fb4e6455b 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -35,15 +35,6 @@ class global_value; class alloc_const; /* Module */ -struct scope { -public: - const std::map& get_values() { return values; } - void set_type(const std::string& name, ir::type* ty) { types[name] = ty; } - ir::type* get_type(const std::string& name) { return types.at(name); } -private: - std::map types; - std::map values; -}; class module { typedef std::pair val_key_t; @@ -74,8 +65,11 @@ public: void set_const(const std::string& name); void set_continue_fn(std::function fn); // Getters + const std::map& get_values() { return values_; } + void set_values(const std::map& values) { values_ = values; } value *get_value(const std::string& name, basic_block* block); value *get_value(const std::string& name); + void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; } const std::string& get_name(); std::function get_continue_fn(); // Seal block -- no more predecessors will be added @@ -84,10 +78,6 @@ public: const functions_list_t &get_function_list() const { return functions_; } functions_list_t &get_function_list() { return functions_; } function *get_or_insert_function(const std::string &name, function_type *ty); - // Scope - void add_new_scope() { if(scopes_.empty()) scopes_.push(scope()); else scopes_.push(scope(get_scope())); } - void pop_scope() { scopes_.pop(); } - scope& get_scope() { return scopes_.top(); } // Const allocation void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); } const std::vector& allocs() { return allocs_; } @@ -101,7 +91,7 @@ private: std::string name_; builder& builder_; std::map values_; - std::map types_; + std::map types_; std::set const_; std::set sealed_blocks_; std::map> incomplete_phis_; @@ -109,7 +99,6 @@ private: symbols_map_t symbols_; std::function continue_fn_; std::map current_phi_; - std::stack scopes_; std::vector allocs_; std::map globals_; std::map metadatas_; diff --git a/lib/ir/module.cc b/lib/ir/module.cc index 626065224..33b39de3a 100644 --- a/lib/ir/module.cc +++ b/lib/ir/module.cc @@ -94,7 +94,7 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block ir::value *result; bool is_const = const_.find(name) != const_.end(); auto &preds = block->get_predecessors(); - ir::type *ty = get_scope().get_type(name); + ir::type *ty = types_.at(name); if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){ incomplete_phis_[block][name] = make_phi(ty, 1, block); result = (ir::value*)incomplete_phis_[block][name]; diff --git a/python/src/triton.cc b/python/src/triton.cc index 45d0e2704..c07fade36 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -228,20 +228,15 @@ void init_triton_ir(py::module &&m) { .def_property_readonly("shape", &ir::block_type::get_shapes) .def_property_readonly("numel", &ir::type::get_tile_num_elements); - py::class_(m, "scope") - .def(py::init<>()) - .def_property_readonly("values", &ir::scope::get_values) - .def("set_type", &ir::scope::set_type); - py::class_(m, "module") .def(py::init()) .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference) - .def("add_new_scope", &ir::module::add_new_scope, ret::reference) .def("seal_block", &ir::module::seal_block) .def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value) + .def("set_type", &ir::module::set_type) .def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference) - .def("pop_scope", &ir::module::pop_scope) - .def_property_readonly("scope", &ir::module::get_scope, ret::reference) + .def("get_values", &ir::module::get_values, ret::reference) + .def("set_values", &ir::module::set_values) .def_property_readonly("builder", &ir::module::get_builder, ret::reference); using eattr = ir::attribute_kind_t; diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 9bdec6129..014104b99 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -38,21 +38,17 @@ class CodeGenerator(ast.NodeVisitor): value = triton.language.block(value) if isinstance(value, triton.language.block): self.module.set_value(name, value.handle) - self.module.scope.set_type(name, value.handle.type) + self.module.set_type(name, value.handle.type) self.lscope[name] = value def is_triton_object(self, value): return isinstance(value, triton.language.block) - def visit_compound_statement(self, stmts, add_scope=False): - if add_scope: - self.module.add_new_scope() + def visit_compound_statement(self, stmts): for stmt in stmts: self.last_ret = self.visit(stmt) if isinstance(stmt, ast.Return): break - if add_scope: - self.module.pop_scope() return stmts and isinstance(stmt, ast.Return) def __init__(self, context, prototype, gscope, attributes, constants, kwargs): @@ -75,9 +71,7 @@ class CodeGenerator(ast.NodeVisitor): } def visit_Module(self, node): - self.module.add_new_scope() ast.NodeVisitor.generic_visit(self, node) - self.module.pop_scope() def visit_List(self, node): ctx = self.visit(node.ctx) @@ -117,14 +111,14 @@ class CodeGenerator(ast.NodeVisitor): for arg_name, arg_value in zip(arg_names, arg_values): self.set_value(arg_name, arg_value) if inline: - self.visit_compound_statement(node.body, add_scope=True) + self.visit_compound_statement(node.body) return self.last_ret else: entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) self.module.seal_block(entry) self.builder.set_insert_block(entry) # visit function body - self.visit_compound_statement(node.body, add_scope=True) + self.visit_compound_statement(node.body) # finalize function self.builder.ret_void() @@ -216,13 +210,13 @@ class CodeGenerator(ast.NodeVisitor): else: self.builder.cond_br(cond.handle, then_bb, endif_bb) self.builder.set_insert_block(then_bb) - is_terminator = self.visit_compound_statement(node.body, add_scope=True) + is_terminator = self.visit_compound_statement(node.body) # TODO: last statement is a terminator? if not is_terminator: self.builder.br(endif_bb) if else_bb: self.builder.set_insert_block(else_bb) - is_terminator = self.visit_compound_statement(node.orelse, add_scope=True) + is_terminator = self.visit_compound_statement(node.orelse) #TODO: last statement is a terminator? if not is_terminator: self.builder.br(endif_bb) @@ -289,7 +283,7 @@ class CodeGenerator(ast.NodeVisitor): continue_fn() self.builder.set_insert_block(loop_bb) - self.visit_compound_statement(node.body, add_scope=True) + self.visit_compound_statement(node.body) continue_fn() stop_bb = self.builder.get_insert_block() self.module.seal_block(stop_bb) @@ -344,7 +338,7 @@ class CodeGenerator(ast.NodeVisitor): cond = build_cond() self.builder.cond_br(cond.handle, loop_bb, next_bb) self.builder.set_insert_block(loop_bb) - self.visit_compound_statement(node.body, add_scope=True) + self.visit_compound_statement(node.body) # TODO: handle case where body breaks control flow continue_fn() stop_bb = self.builder.get_insert_block() @@ -643,7 +637,12 @@ class JITFunction: def __call__(self, *args, generator: CodeGenerator, **meta): try: - return generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args) + lscope = generator.lscope.copy() + values = generator.module.get_values().copy() + ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args) + generator.lscope = lscope + generator.module.set_values(values) + return ret except Exception as e: node = generator.last_node if node is None or isinstance(e, (NotImplementedError, CompilationError)):