[FRONTEND] Added default arguments for range (#203)

This commit is contained in:
Philippe Tillet
2021-08-14 10:11:18 -07:00
committed by GitHub
parent b120d70a0a
commit c7a272cb91

View File

@@ -316,20 +316,24 @@ class CodeGenerator(ast.NodeVisitor):
def visit_For(self, node): def visit_For(self, node):
iterator = self.visit(node.iter.func) iterator = self.visit(node.iter.func)
assert iterator == self.builtins['range'] if iterator != self.builtins['range']:
raise RuntimeError('Only `range` iterator currently supported')
# create nodes # create nodes
st_target = ast.Name(id=node.target.id, ctx=ast.Store()) st_target = ast.Name(id=node.target.id, ctx=ast.Store())
ld_target = ast.Name(id=node.target.id, ctx=ast.Load()) ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
init_node = ast.Assign(targets=[st_target], value=node.iter.args[0]) arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]]) arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]]) arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)
pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)]) init_node = ast.Assign(targets=[st_target], value=arg_0)
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\ build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
self.visit(pos_cond_node),\ self.visit(pos_cond_node),\
self.visit(neg_cond_node),\ self.visit(neg_cond_node),\
builder=self.builder) builder=self.builder)
#cond_node = neg_cond_node #cond_node = neg_cond_node
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=node.iter.args[2]) step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
# code generation # code generation
current_bb = self.builder.get_insert_block() current_bb = self.builder.get_insert_block()
loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent) loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)