Update trtion dependencies
This commit is contained in:
@@ -28,6 +28,7 @@ def Triton_Dialect : Dialect {
|
|||||||
"arith::ArithmeticDialect",
|
"arith::ArithmeticDialect",
|
||||||
"tensor::TensorDialect",
|
"tensor::TensorDialect",
|
||||||
"StandardOpsDialect",
|
"StandardOpsDialect",
|
||||||
|
"scf::SCFDialect"
|
||||||
|
|
||||||
// Since LLVM 15
|
// Since LLVM 15
|
||||||
// "cf::ControlFlowDialect",
|
// "cf::ControlFlowDialect",
|
||||||
|
@@ -410,6 +410,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0))
|
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0))
|
||||||
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0])
|
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0])
|
||||||
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
|
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
|
||||||
|
# TODO: better way to do this?
|
||||||
|
lb = triton.language.core._to_tensor(lb, self.builder).handle
|
||||||
|
ub = triton.language.core._to_tensor(ub, self.builder).handle
|
||||||
|
step = triton.language.core._to_tensor(step, self.builder).handle
|
||||||
|
|
||||||
loop_body = self.builder.create_block()
|
loop_body = self.builder.create_block()
|
||||||
self.builder.set_insertion_point_to_start(loop_body)
|
self.builder.set_insertion_point_to_start(loop_body)
|
||||||
@@ -422,16 +426,16 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
self.visit_compound_statement(node.body)
|
self.visit_compound_statement(node.body)
|
||||||
|
|
||||||
# TODO: update insertion point
|
# TODO: update insertion point
|
||||||
init_args = {}
|
init_args = []
|
||||||
yields = {}
|
yields = []
|
||||||
for name in self.local_defs:
|
for name in self.local_defs:
|
||||||
if name in liveins:
|
if name in liveins:
|
||||||
assert self.is_triton_tensor(self.local_defs[name])
|
assert self.is_triton_tensor(self.local_defs[name])
|
||||||
assert self.is_triton_tensor(liveins[name])
|
assert self.is_triton_tensor(liveins[name])
|
||||||
if self.local_defs[name].type == liveins[name].type:
|
if self.local_defs[name].type == liveins[name].type:
|
||||||
init_args[name] = liveins[name]
|
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder).handle)
|
||||||
yields[name] = self.local_defs[name]
|
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder).handle)
|
||||||
# for_op = self.builder.create_for_op(lb, ub, step, [init_args])
|
for_op = self.builder.create_for_op(lb, ub, step, init_args)
|
||||||
|
|
||||||
self.lscope = liveins
|
self.lscope = liveins
|
||||||
self.local_defs = prev_defs
|
self.local_defs = prev_defs
|
||||||
|
Reference in New Issue
Block a user