diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 70799861d..f91d0e78f 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -316,20 +316,24 @@ class CodeGenerator(ast.NodeVisitor): def visit_For(self, node): iterator = self.visit(node.iter.func) - assert iterator == self.builtins['range'] + if iterator != self.builtins['range']: + raise RuntimeError('Only `range` iterator currently supported') # create nodes st_target = ast.Name(id=node.target.id, ctx=ast.Store()) ld_target = ast.Name(id=node.target.id, ctx=ast.Load()) - init_node = ast.Assign(targets=[st_target], value=node.iter.args[0]) - pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]]) - neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]]) - pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)]) + arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0) + arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0] + arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1) + init_node = ast.Assign(targets=[st_target], value=arg_0) + pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1]) + neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1]) + pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)]) build_cond = lambda: triton.language.where(self.visit(pos_step_node),\ self.visit(pos_cond_node),\ self.visit(neg_cond_node),\ builder=self.builder) #cond_node = neg_cond_node - step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=node.iter.args[2]) + step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) # code generation current_bb = self.builder.get_insert_block() loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)