[FRONTEND] For loops now promote initial value (#524)
This commit is contained in:
@@ -584,13 +584,17 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
for stmt in node.orelse:
|
for stmt in node.orelse:
|
||||||
ast.NodeVisitor.generic_visit(self, stmt)
|
ast.NodeVisitor.generic_visit(self, stmt)
|
||||||
return
|
return
|
||||||
|
|
||||||
# create nodes
|
# create nodes
|
||||||
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
||||||
ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
|
ld_target = ast.Name(id=node.target.id, ctx=ast.Load())
|
||||||
arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)
|
arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)
|
||||||
arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]
|
arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]
|
||||||
arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)
|
arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)
|
||||||
|
# init node
|
||||||
init_node = ast.Assign(targets=[st_target], value=arg_0)
|
init_node = ast.Assign(targets=[st_target], value=arg_0)
|
||||||
|
|
||||||
|
# step node
|
||||||
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
|
pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1])
|
||||||
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
|
neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1])
|
||||||
pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
|
pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)])
|
||||||
@@ -610,7 +614,17 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
cond = build_cond()
|
cond = build_cond()
|
||||||
return self.builder.cond_br(cond.handle, loop_bb, next_bb)
|
return self.builder.cond_br(cond.handle, loop_bb, next_bb)
|
||||||
|
|
||||||
|
# init loop induction variable
|
||||||
self.visit(init_node)
|
self.visit(init_node)
|
||||||
|
# promote it to right type
|
||||||
|
init_val = self.value_constructor.get_value(node.target.id)
|
||||||
|
promote = lambda a, b: triton.language.semantic.computation_type_impl(a, b, False)
|
||||||
|
start_ty = triton.language.core._to_tensor(iter_args[0], self.builder).type
|
||||||
|
stop_ty = triton.language.core._to_tensor(iter_args[1], self.builder).type if len(iter_args) > 1 else None
|
||||||
|
ty = promote(start_ty, stop_ty) if len(iter_args) > 1 else start_ty
|
||||||
|
casted = triton.language.semantic.cast(init_val, ty, self.builder)
|
||||||
|
self.value_constructor.set_value(node.target.id, casted)
|
||||||
|
# create cond
|
||||||
cond = build_cond()
|
cond = build_cond()
|
||||||
self.builder.cond_br(cond.handle, loop_bb, next_bb)
|
self.builder.cond_br(cond.handle, loop_bb, next_bb)
|
||||||
self.builder.set_insert_block(loop_bb)
|
self.builder.set_insert_block(loop_bb)
|
||||||
|
Reference in New Issue
Block a user