ForOp's SSA construction
This commit is contained in:
@@ -195,7 +195,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
assert len(_names) == 1
|
||||
names = _names[0]
|
||||
values = self.visit(node.value)
|
||||
print(f'visit_Assign({names}, {values})')
|
||||
if not isinstance(names, tuple):
|
||||
names = [names]
|
||||
if not isinstance(values, tuple):
|
||||
@@ -440,34 +439,30 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
|
||||
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
|
||||
|
||||
print(f'names: {names}')
|
||||
print("After insert block body")
|
||||
self.module.dump()
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
|
||||
print("After create for op")
|
||||
self.module.dump()
|
||||
# TODO: revisit blocks & regions of ForOp
|
||||
block.move_before(for_op.get_body())
|
||||
self.builder.set_insertion_point_to_end(for_op.get_body())
|
||||
# FIXME: the body should be a region (?)
|
||||
# FIXME: this won't work for nested control flow
|
||||
block.merge_block_before(for_op.get_body(0))
|
||||
self.builder.set_insertion_point_to_end(for_op.get_body(0))
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
print("After create yield op")
|
||||
self.module.dump()
|
||||
|
||||
for_op_region = for_op.get_body(0).get_parent()
|
||||
assert for_op_region.size() == 1, "(For developer) Should use region here"
|
||||
for i, name in enumerate(names):
|
||||
# arg0 is the induction variable
|
||||
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
|
||||
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
print("After restoring insert point")
|
||||
self.module.dump()
|
||||
|
||||
self.lscope = liveins
|
||||
self.local_defs = prev_defs
|
||||
# ForOp defines new values
|
||||
for i, name in enumerate(names):
|
||||
self.lscope[name] = for_op.get_result(i)
|
||||
self.local_defs[name] = for_op.get_result(i)
|
||||
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
||||
self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
||||
|
||||
for stmt in node.orelse:
|
||||
assert False
|
||||
assert False, "Don't know what to do with else after for"
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
|
||||
def visit_Slice(self, node):
|
||||
|
Reference in New Issue
Block a user