[FRONTEND] Added default arguments for range (#203)

This commit is contained in:
Philippe Tillet
2021-08-14 10:11:18 -07:00
committed by GitHub
parent b120d70a0a
commit c7a272cb91

View File

@@ -316,20 +316,24 @@ class CodeGenerator(ast.NodeVisitor):
def visit_For(self, node):
iterator = self.visit(node.iter.func)
assert iterator == self.builtins['range']
if iterator != self.builtins['range']:
raise RuntimeError('Only `range` iterator currently supported')
# create nodes
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
init_node = ast.Assign(targets=[st_target], value=node.iter.args[0])
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]])
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]])
pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)])
arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)
arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]
arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)
init_node = ast.Assign(targets=[st_target], value=arg_0)
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
self.visit(pos_cond_node),\
self.visit(neg_cond_node),\
builder=self.builder)
#cond_node = neg_cond_node
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=node.iter.args[2])
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
# code generation
current_bb = self.builder.get_insert_block()
loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)