[FRONTEND] Added default arguments for range
(#203)
This commit is contained in:
@@ -316,20 +316,24 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_For(self, node):
|
||||
iterator = self.visit(node.iter.func)
|
||||
assert iterator == self.builtins['range']
|
||||
if iterator != self.builtins['range']:
|
||||
raise RuntimeError('Only `range` iterator currently supported')
|
||||
# create nodes
|
||||
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
||||
ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
|
||||
init_node = ast.Assign(targets=[st_target], value=node.iter.args[0])
|
||||
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]])
|
||||
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]])
|
||||
pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)])
|
||||
arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)
|
||||
arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]
|
||||
arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)
|
||||
init_node = ast.Assign(targets=[st_target], value=arg_0)
|
||||
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
|
||||
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
|
||||
pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
|
||||
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
|
||||
self.visit(pos_cond_node),\
|
||||
self.visit(neg_cond_node),\
|
||||
builder=self.builder)
|
||||
#cond_node = neg_cond_node
|
||||
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=node.iter.args[2])
|
||||
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
|
||||
# code generation
|
||||
current_bb = self.builder.get_insert_block()
|
||||
loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
|
||||
|
Reference in New Issue
Block a user