More python bindings

This commit is contained in:
Yan Da
2022-04-01 22:22:39 +08:00
parent 9dafa0e2e3
commit 61413b8a97
2 changed files with 54 additions and 18 deletions

View File

@@ -415,8 +415,10 @@ class CodeGenerator(ast.NodeVisitor):
ub = triton.language.core._to_tensor(ub, self.builder).handle
step = triton.language.core._to_tensor(step, self.builder).handle
loop_body = self.builder.create_block()
self.builder.set_insertion_point_to_start(loop_body)
insert_block = self.builder.get_insertion_block()
block = self.builder.create_block()
self.builder.set_insertion_point_to_start(block)
liveins = self.lscope.copy()
prev_defs = self.local_defs.copy()
@@ -425,20 +427,44 @@ class CodeGenerator(ast.NodeVisitor):
# visit loop body
self.visit_compound_statement(node.body)
# TODO: update insertion point
init_args = []
yields = []
names = []
for name in self.local_defs:
if name in liveins:
assert self.is_triton_tensor(self.local_defs[name])
assert self.is_triton_tensor(liveins[name])
if self.local_defs[name].type == liveins[name].type:
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder).handle)
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder).handle)
for_op = self.builder.create_for_op(lb, ub, step, init_args)
# TODO: better way to do this?
names.append(name)
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
print(f'names: {names}')
print("After insert block body")
self.module.dump()
self.builder.set_insertion_point_to_end(insert_block)
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
print("After create for op")
self.module.dump()
# TODO: revisit blocks & regions of ForOp
block.move_before(for_op.get_body())
self.builder.set_insertion_point_to_end(for_op.get_body())
self.builder.create_yield_op([y.handle for y in yields])
print("After create yield op")
self.module.dump()
self.builder.set_insertion_point_to_end(insert_block)
print("After restoring insert point")
self.module.dump()
self.lscope = liveins
self.local_defs = prev_defs
# ForOp defines new values
for i, name in enumerate(names):
self.lscope[name] = for_op.get_result(i)
self.local_defs[name] = for_op.get_result(i)
for stmt in node.orelse:
assert False