Update trtion dependencies

This commit is contained in:
Yan Da
2022-04-01 20:16:07 +08:00
parent bde103fab0
commit 9dafa0e2e3
2 changed files with 10 additions and 5 deletions

View File

@@ -28,6 +28,7 @@ def Triton_Dialect : Dialect {
"arith::ArithmeticDialect",
"tensor::TensorDialect",
"StandardOpsDialect",
"scf::SCFDialect"
// Since LLVM 15
// "cf::ControlFlowDialect",

View File

@@ -410,6 +410,10 @@ class CodeGenerator(ast.NodeVisitor):
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0))
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0])
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
# TODO: better way to do this?
lb = triton.language.core._to_tensor(lb, self.builder).handle
ub = triton.language.core._to_tensor(ub, self.builder).handle
step = triton.language.core._to_tensor(step, self.builder).handle
loop_body = self.builder.create_block()
self.builder.set_insertion_point_to_start(loop_body)
@@ -422,16 +426,16 @@ class CodeGenerator(ast.NodeVisitor):
self.visit_compound_statement(node.body)
# TODO: update insertion point
init_args = {}
yields = {}
init_args = []
yields = []
for name in self.local_defs:
if name in liveins:
assert self.is_triton_tensor(self.local_defs[name])
assert self.is_triton_tensor(liveins[name])
if self.local_defs[name].type == liveins[name].type:
init_args[name] = liveins[name]
yields[name] = self.local_defs[name]
# for_op = self.builder.create_for_op(lb, ub, step, [init_args])
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder).handle)
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder).handle)
for_op = self.builder.create_for_op(lb, ub, step, init_args)
self.lscope = liveins
self.local_defs = prev_defs