Fix visit_While issues

This commit is contained in:
Yan Da
2022-04-10 16:16:13 +08:00
parent 19f81b7dea
commit fcbbb3c10e
4 changed files with 47 additions and 18 deletions

View File

@@ -293,6 +293,7 @@ class CodeGenerator(ast.NodeVisitor):
if then_defs or node.orelse:
if node.orelse:
self.lscope = liveins
self.local_defs = {}
else_block = self.builder.create_block()
self.builder.set_insertion_point_to_end(else_block)
@@ -303,7 +304,6 @@ class CodeGenerator(ast.NodeVisitor):
else_defs = {}
for name in then_defs:
if name in liveins:
# TODO: what if this is constexpr?
assert self.is_triton_tensor(then_defs[name])
assert self.is_triton_tensor(liveins[name])
else_defs[name] = liveins[name]
@@ -452,16 +452,16 @@ class CodeGenerator(ast.NodeVisitor):
self.builder.set_insertion_point_to_end(after_block)
self.builder.create_yield_op([y.handle for y in yields])
# update global uses in while_op
for i, name in enumerate(names):
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
# update global uses in while_op
for i, name in enumerate(names):
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
# WhileOp defines new values, update the symbol table (lscope, local_defs)
for i, name in enumerate(names):
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
self.lscope[name] = new_def
self.local_defs[name] = new_def
# WhileOp defines new values, update the symbol table (lscope, local_defs)
for i, name in enumerate(names):
new_def = triton.language.core.tensor(while_op.get_result(i), ret_types[i])
self.lscope[name] = new_def
self.local_defs[name] = new_def
for stmt in node.orelse:
assert False, "Not implemented"