[FRONTEND] Added default arguments for range
(#203)
This commit is contained in:
@@ -316,20 +316,24 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
|
|
||||||
def visit_For(self, node):
|
def visit_For(self, node):
|
||||||
iterator = self.visit(node.iter.func)
|
iterator = self.visit(node.iter.func)
|
||||||
assert iterator == self.builtins['range']
|
if iterator != self.builtins['range']:
|
||||||
|
raise RuntimeError('Only `range` iterator currently supported')
|
||||||
# create nodes
|
# create nodes
|
||||||
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
||||||
ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
|
ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
|
||||||
init_node = ast.Assign(targets=[st_target], value=node.iter.args[0])
|
arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)
|
||||||
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [node.iter.args[1]])
|
arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]
|
||||||
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [node.iter.args[1]])
|
arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)
|
||||||
pos_step_node = ast.Compare(node.iter.args[2], [ast.Gt()], [ast.Num(0)])
|
init_node = ast.Assign(targets=[st_target], value=arg_0)
|
||||||
|
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
|
||||||
|
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
|
||||||
|
pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
|
||||||
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
|
build_cond = lambda: triton.language.where(self.visit(pos_step_node),\
|
||||||
self.visit(pos_cond_node),\
|
self.visit(pos_cond_node),\
|
||||||
self.visit(neg_cond_node),\
|
self.visit(neg_cond_node),\
|
||||||
builder=self.builder)
|
builder=self.builder)
|
||||||
#cond_node = neg_cond_node
|
#cond_node = neg_cond_node
|
||||||
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=node.iter.args[2])
|
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
|
||||||
# code generation
|
# code generation
|
||||||
current_bb = self.builder.get_insert_block()
|
current_bb = self.builder.get_insert_block()
|
||||||
loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
|
loop_bb = _triton.ir.basic_block.create(self.module.builder.context, "loop", current_bb.parent)
|
||||||
|
Reference in New Issue
Block a user