diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 619e3109e..d2c834f40 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -584,13 +584,17 @@ class CodeGenerator(ast.NodeVisitor): for stmt in node.orelse: ast.NodeVisitor.generic_visit(self, stmt) return + # create nodes st_target = ast.Name(id=node.target.id, ctx=ast.Store()) ld_target = ast.Name(id=node.target.id, ctx=ast.Load()) arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0) arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0] arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1) + # init node init_node = ast.Assign(targets=[st_target], value=arg_0) + + # step node pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1]) neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1]) pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)]) @@ -610,7 +614,17 @@ class CodeGenerator(ast.NodeVisitor): cond = build_cond() return self.builder.cond_br(cond.handle, loop_bb, next_bb) + # init loop induction variable self.visit(init_node) + # promote it to right type + init_val = self.value_constructor.get_value(node.target.id) + promote = lambda a, b: triton.language.semantic.computation_type_impl(a, b, False) + start_ty = triton.language.core._to_tensor(iter_args[0], self.builder).type + stop_ty = triton.language.core._to_tensor(iter_args[1], self.builder).type if len(iter_args) > 1 else None + ty = promote(start_ty, stop_ty) if len(iter_args) > 1 else start_ty + casted = triton.language.semantic.cast(init_val, ty, self.builder) + self.value_constructor.set_value(node.target.id, casted) + # create cond cond = build_cond() self.builder.cond_br(cond.handle, loop_bb, next_bb) self.builder.set_insert_block(loop_bb)