Update trtion dependencies
This commit is contained in:
@@ -28,6 +28,7 @@ def Triton_Dialect : Dialect {
|
||||
"arith::ArithmeticDialect",
|
||||
"tensor::TensorDialect",
|
||||
"StandardOpsDialect",
|
||||
"scf::SCFDialect"
|
||||
|
||||
// Since LLVM 15
|
||||
// "cf::ControlFlowDialect",
|
||||
|
@@ -410,6 +410,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0))
|
||||
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0])
|
||||
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
|
||||
# TODO: better way to do this?
|
||||
lb = triton.language.core._to_tensor(lb, self.builder).handle
|
||||
ub = triton.language.core._to_tensor(ub, self.builder).handle
|
||||
step = triton.language.core._to_tensor(step, self.builder).handle
|
||||
|
||||
loop_body = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(loop_body)
|
||||
@@ -422,16 +426,16 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit_compound_statement(node.body)
|
||||
|
||||
# TODO: update insertion point
|
||||
init_args = {}
|
||||
yields = {}
|
||||
init_args = []
|
||||
yields = []
|
||||
for name in self.local_defs:
|
||||
if name in liveins:
|
||||
assert self.is_triton_tensor(self.local_defs[name])
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
if self.local_defs[name].type == liveins[name].type:
|
||||
init_args[name] = liveins[name]
|
||||
yields[name] = self.local_defs[name]
|
||||
# for_op = self.builder.create_for_op(lb, ub, step, [init_args])
|
||||
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder).handle)
|
||||
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder).handle)
|
||||
for_op = self.builder.create_for_op(lb, ub, step, init_args)
|
||||
|
||||
self.lscope = liveins
|
||||
self.local_defs = prev_defs
|
||||
|
Reference in New Issue
Block a user