[PYTHON] Fixed bug in scoping mechanism (#111)
Inline functions didn't restore scope of parents. Also some control flow structure still had the scoping semantics of C++
This commit is contained in:
committed by
Philippe Tillet
parent
9f30af76fb
commit
325ee38581
@@ -35,15 +35,6 @@ class global_value;
|
|||||||
class alloc_const;
|
class alloc_const;
|
||||||
|
|
||||||
/* Module */
|
/* Module */
|
||||||
struct scope {
|
|
||||||
public:
|
|
||||||
const std::map<std::string, ir::value*>& get_values() { return values; }
|
|
||||||
void set_type(const std::string& name, ir::type* ty) { types[name] = ty; }
|
|
||||||
ir::type* get_type(const std::string& name) { return types.at(name); }
|
|
||||||
private:
|
|
||||||
std::map<std::string, ir::type*> types;
|
|
||||||
std::map<std::string, ir::value*> values;
|
|
||||||
};
|
|
||||||
|
|
||||||
class module {
|
class module {
|
||||||
typedef std::pair<std::string, basic_block*> val_key_t;
|
typedef std::pair<std::string, basic_block*> val_key_t;
|
||||||
@@ -74,8 +65,11 @@ public:
|
|||||||
void set_const(const std::string& name);
|
void set_const(const std::string& name);
|
||||||
void set_continue_fn(std::function<ir::value*()> fn);
|
void set_continue_fn(std::function<ir::value*()> fn);
|
||||||
// Getters
|
// Getters
|
||||||
|
const std::map<val_key_t, value*>& get_values() { return values_; }
|
||||||
|
void set_values(const std::map<val_key_t, value*>& values) { values_ = values; }
|
||||||
value *get_value(const std::string& name, basic_block* block);
|
value *get_value(const std::string& name, basic_block* block);
|
||||||
value *get_value(const std::string& name);
|
value *get_value(const std::string& name);
|
||||||
|
void set_type(const std::string& name, ir::type* ty) { types_[name] = ty; }
|
||||||
const std::string& get_name();
|
const std::string& get_name();
|
||||||
std::function<ir::value*()> get_continue_fn();
|
std::function<ir::value*()> get_continue_fn();
|
||||||
// Seal block -- no more predecessors will be added
|
// Seal block -- no more predecessors will be added
|
||||||
@@ -84,10 +78,6 @@ public:
|
|||||||
const functions_list_t &get_function_list() const { return functions_; }
|
const functions_list_t &get_function_list() const { return functions_; }
|
||||||
functions_list_t &get_function_list() { return functions_; }
|
functions_list_t &get_function_list() { return functions_; }
|
||||||
function *get_or_insert_function(const std::string &name, function_type *ty);
|
function *get_or_insert_function(const std::string &name, function_type *ty);
|
||||||
// Scope
|
|
||||||
void add_new_scope() { if(scopes_.empty()) scopes_.push(scope()); else scopes_.push(scope(get_scope())); }
|
|
||||||
void pop_scope() { scopes_.pop(); }
|
|
||||||
scope& get_scope() { return scopes_.top(); }
|
|
||||||
// Const allocation
|
// Const allocation
|
||||||
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
|
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
|
||||||
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
|
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
|
||||||
@@ -101,7 +91,7 @@ private:
|
|||||||
std::string name_;
|
std::string name_;
|
||||||
builder& builder_;
|
builder& builder_;
|
||||||
std::map<val_key_t, value*> values_;
|
std::map<val_key_t, value*> values_;
|
||||||
std::map<val_key_t, type*> types_;
|
std::map<std::string, type*> types_;
|
||||||
std::set<std::string> const_;
|
std::set<std::string> const_;
|
||||||
std::set<basic_block*> sealed_blocks_;
|
std::set<basic_block*> sealed_blocks_;
|
||||||
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
|
std::map<basic_block*, std::map<std::string, phi_node*>> incomplete_phis_;
|
||||||
@@ -109,7 +99,6 @@ private:
|
|||||||
symbols_map_t symbols_;
|
symbols_map_t symbols_;
|
||||||
std::function<ir::value*()> continue_fn_;
|
std::function<ir::value*()> continue_fn_;
|
||||||
std::map<value*, value**> current_phi_;
|
std::map<value*, value**> current_phi_;
|
||||||
std::stack<scope> scopes_;
|
|
||||||
std::vector<ir::alloc_const*> allocs_;
|
std::vector<ir::alloc_const*> allocs_;
|
||||||
std::map<std::string, ir::value*> globals_;
|
std::map<std::string, ir::value*> globals_;
|
||||||
std::map<std::string, md_pair_t> metadatas_;
|
std::map<std::string, md_pair_t> metadatas_;
|
||||||
|
@@ -94,7 +94,7 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
|
|||||||
ir::value *result;
|
ir::value *result;
|
||||||
bool is_const = const_.find(name) != const_.end();
|
bool is_const = const_.find(name) != const_.end();
|
||||||
auto &preds = block->get_predecessors();
|
auto &preds = block->get_predecessors();
|
||||||
ir::type *ty = get_scope().get_type(name);
|
ir::type *ty = types_.at(name);
|
||||||
if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
|
if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
|
||||||
incomplete_phis_[block][name] = make_phi(ty, 1, block);
|
incomplete_phis_[block][name] = make_phi(ty, 1, block);
|
||||||
result = (ir::value*)incomplete_phis_[block][name];
|
result = (ir::value*)incomplete_phis_[block][name];
|
||||||
|
@@ -228,20 +228,15 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def_property_readonly("shape", &ir::block_type::get_shapes)
|
.def_property_readonly("shape", &ir::block_type::get_shapes)
|
||||||
.def_property_readonly("numel", &ir::type::get_tile_num_elements);
|
.def_property_readonly("numel", &ir::type::get_tile_num_elements);
|
||||||
|
|
||||||
py::class_<ir::scope>(m, "scope")
|
|
||||||
.def(py::init<>())
|
|
||||||
.def_property_readonly("values", &ir::scope::get_values)
|
|
||||||
.def("set_type", &ir::scope::set_type);
|
|
||||||
|
|
||||||
py::class_<ir::module>(m, "module")
|
py::class_<ir::module>(m, "module")
|
||||||
.def(py::init<std::string, ir::builder &>())
|
.def(py::init<std::string, ir::builder &>())
|
||||||
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
|
.def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
|
||||||
.def("add_new_scope", &ir::module::add_new_scope, ret::reference)
|
|
||||||
.def("seal_block", &ir::module::seal_block)
|
.def("seal_block", &ir::module::seal_block)
|
||||||
.def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value)
|
.def("set_value", (void (ir::module::*)(const std::string &, ir::value *)) & ir::module::set_value)
|
||||||
|
.def("set_type", &ir::module::set_type)
|
||||||
.def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
|
.def("get_value", (ir::value * (ir::module::*)(const std::string &)) & ir::module::get_value, ret::reference)
|
||||||
.def("pop_scope", &ir::module::pop_scope)
|
.def("get_values", &ir::module::get_values, ret::reference)
|
||||||
.def_property_readonly("scope", &ir::module::get_scope, ret::reference)
|
.def("set_values", &ir::module::set_values)
|
||||||
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
|
.def_property_readonly("builder", &ir::module::get_builder, ret::reference);
|
||||||
|
|
||||||
using eattr = ir::attribute_kind_t;
|
using eattr = ir::attribute_kind_t;
|
||||||
|
@@ -38,21 +38,17 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
value = triton.language.block(value)
|
value = triton.language.block(value)
|
||||||
if isinstance(value, triton.language.block):
|
if isinstance(value, triton.language.block):
|
||||||
self.module.set_value(name, value.handle)
|
self.module.set_value(name, value.handle)
|
||||||
self.module.scope.set_type(name, value.handle.type)
|
self.module.set_type(name, value.handle.type)
|
||||||
self.lscope[name] = value
|
self.lscope[name] = value
|
||||||
|
|
||||||
def is_triton_object(self, value):
|
def is_triton_object(self, value):
|
||||||
return isinstance(value, triton.language.block)
|
return isinstance(value, triton.language.block)
|
||||||
|
|
||||||
def visit_compound_statement(self, stmts, add_scope=False):
|
def visit_compound_statement(self, stmts):
|
||||||
if add_scope:
|
|
||||||
self.module.add_new_scope()
|
|
||||||
for stmt in stmts:
|
for stmt in stmts:
|
||||||
self.last_ret = self.visit(stmt)
|
self.last_ret = self.visit(stmt)
|
||||||
if isinstance(stmt, ast.Return):
|
if isinstance(stmt, ast.Return):
|
||||||
break
|
break
|
||||||
if add_scope:
|
|
||||||
self.module.pop_scope()
|
|
||||||
return stmts and isinstance(stmt, ast.Return)
|
return stmts and isinstance(stmt, ast.Return)
|
||||||
|
|
||||||
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
|
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
|
||||||
@@ -75,9 +71,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def visit_Module(self, node):
|
def visit_Module(self, node):
|
||||||
self.module.add_new_scope()
|
|
||||||
ast.NodeVisitor.generic_visit(self, node)
|
ast.NodeVisitor.generic_visit(self, node)
|
||||||
self.module.pop_scope()
|
|
||||||
|
|
||||||
def visit_List(self, node):
|
def visit_List(self, node):
|
||||||
ctx = self.visit(node.ctx)
|
ctx = self.visit(node.ctx)
|
||||||
@@ -117,14 +111,14 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
for arg_name, arg_value in zip(arg_names, arg_values):
|
for arg_name, arg_value in zip(arg_names, arg_values):
|
||||||
self.set_value(arg_name, arg_value)
|
self.set_value(arg_name, arg_value)
|
||||||
if inline:
|
if inline:
|
||||||
self.visit_compound_statement(node.body, add_scope=True)
|
self.visit_compound_statement(node.body)
|
||||||
return self.last_ret
|
return self.last_ret
|
||||||
else:
|
else:
|
||||||
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
|
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
|
||||||
self.module.seal_block(entry)
|
self.module.seal_block(entry)
|
||||||
self.builder.set_insert_block(entry)
|
self.builder.set_insert_block(entry)
|
||||||
# visit function body
|
# visit function body
|
||||||
self.visit_compound_statement(node.body, add_scope=True)
|
self.visit_compound_statement(node.body)
|
||||||
# finalize function
|
# finalize function
|
||||||
self.builder.ret_void()
|
self.builder.ret_void()
|
||||||
|
|
||||||
@@ -216,13 +210,13 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
else:
|
else:
|
||||||
self.builder.cond_br(cond.handle, then_bb, endif_bb)
|
self.builder.cond_br(cond.handle, then_bb, endif_bb)
|
||||||
self.builder.set_insert_block(then_bb)
|
self.builder.set_insert_block(then_bb)
|
||||||
is_terminator = self.visit_compound_statement(node.body, add_scope=True)
|
is_terminator = self.visit_compound_statement(node.body)
|
||||||
# TODO: last statement is a terminator?
|
# TODO: last statement is a terminator?
|
||||||
if not is_terminator:
|
if not is_terminator:
|
||||||
self.builder.br(endif_bb)
|
self.builder.br(endif_bb)
|
||||||
if else_bb:
|
if else_bb:
|
||||||
self.builder.set_insert_block(else_bb)
|
self.builder.set_insert_block(else_bb)
|
||||||
is_terminator = self.visit_compound_statement(node.orelse, add_scope=True)
|
is_terminator = self.visit_compound_statement(node.orelse)
|
||||||
#TODO: last statement is a terminator?
|
#TODO: last statement is a terminator?
|
||||||
if not is_terminator:
|
if not is_terminator:
|
||||||
self.builder.br(endif_bb)
|
self.builder.br(endif_bb)
|
||||||
@@ -289,7 +283,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
|
|
||||||
continue_fn()
|
continue_fn()
|
||||||
self.builder.set_insert_block(loop_bb)
|
self.builder.set_insert_block(loop_bb)
|
||||||
self.visit_compound_statement(node.body, add_scope=True)
|
self.visit_compound_statement(node.body)
|
||||||
continue_fn()
|
continue_fn()
|
||||||
stop_bb = self.builder.get_insert_block()
|
stop_bb = self.builder.get_insert_block()
|
||||||
self.module.seal_block(stop_bb)
|
self.module.seal_block(stop_bb)
|
||||||
@@ -344,7 +338,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
cond = build_cond()
|
cond = build_cond()
|
||||||
self.builder.cond_br(cond.handle, loop_bb, next_bb)
|
self.builder.cond_br(cond.handle, loop_bb, next_bb)
|
||||||
self.builder.set_insert_block(loop_bb)
|
self.builder.set_insert_block(loop_bb)
|
||||||
self.visit_compound_statement(node.body, add_scope=True)
|
self.visit_compound_statement(node.body)
|
||||||
# TODO: handle case where body breaks control flow
|
# TODO: handle case where body breaks control flow
|
||||||
continue_fn()
|
continue_fn()
|
||||||
stop_bb = self.builder.get_insert_block()
|
stop_bb = self.builder.get_insert_block()
|
||||||
@@ -643,7 +637,12 @@ class JITFunction:
|
|||||||
|
|
||||||
def __call__(self, *args, generator: CodeGenerator, **meta):
|
def __call__(self, *args, generator: CodeGenerator, **meta):
|
||||||
try:
|
try:
|
||||||
return generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
|
lscope = generator.lscope.copy()
|
||||||
|
values = generator.module.get_values().copy()
|
||||||
|
ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args)
|
||||||
|
generator.lscope = lscope
|
||||||
|
generator.module.set_values(values)
|
||||||
|
return ret
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
node = generator.last_node
|
node = generator.last_node
|
||||||
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
|
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
|
||||||
|
Reference in New Issue
Block a user