[FRONTEND] Made test_if/test_default pass (#823)
This commit is contained in:
@@ -196,7 +196,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return tuple(ret_types)
|
||||
else:
|
||||
ret = triton.language.core._to_tensor(ret_value, self.builder)
|
||||
self.builder.ret([ret_value.handle])
|
||||
self.builder.ret([ret.handle])
|
||||
return ret.type
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
@@ -399,13 +399,15 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||
self.builder.create_yield_op([then_defs[n].handle for n in names])
|
||||
if len(names) > 0:
|
||||
self.builder.create_yield_op([then_defs[n].handle for n in names])
|
||||
if not node.orelse:
|
||||
else_block = if_op.get_else_block()
|
||||
else:
|
||||
else_block.merge_block_before(if_op.get_else_block())
|
||||
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
||||
self.builder.create_yield_op([else_defs[n].handle for n in names])
|
||||
if len(names) > 0:
|
||||
self.builder.create_yield_op([else_defs[n].handle for n in names])
|
||||
else: # no else block
|
||||
if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, False)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
@@ -526,7 +528,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
[ty.to_ir(self.builder) for ty in ret_types])
|
||||
loop_block.merge_block_before(after_block)
|
||||
self.builder.set_insertion_point_to_end(after_block)
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
if len(yields) > 0:
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
|
||||
# update global uses in while_op
|
||||
for i, name in enumerate(names):
|
||||
|
Reference in New Issue
Block a user