diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 1f99a155a..916ba19b9 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -267,6 +267,7 @@ class CodeGenerator(ast.NodeVisitor): cond = self.visit(node.test) if isinstance(cond, triton.language.tensor): cond = cond.to(triton.language.int1, _builder=self.builder) + # record lscope & local_defs in the parent scope liveins = self.lscope.copy() parent_defs = self.local_defs.copy() self.local_defs = {} @@ -446,7 +447,7 @@ class CodeGenerator(ast.NodeVisitor): self.builder.set_insertion_point_to_end(after_block) self.builder.create_yield_op([y.handle for y in yields]) - # update global_uses in while_op + # update global uses in while_op for i, name in enumerate(names): before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i)) after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i)) @@ -494,23 +495,27 @@ class CodeGenerator(ast.NodeVisitor): ast.NodeVisitor.generic_visit(self, stmt) return + # collect lower bound (lb), upper bound (ub), and step lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)) ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]) step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)) - # TODO: better way to do this? + # lb/ub/step might be constexpr, we need to cast them to tensor lb = triton.language.core._to_tensor(lb, self.builder).handle ub = triton.language.core._to_tensor(ub, self.builder).handle step = triton.language.core._to_tensor(step, self.builder).handle - + # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index lb = self.builder.create_to_index(lb) ub = self.builder.create_to_index(ub) step = self.builder.create_to_index(step) + # cache current insertion block insert_block = self.builder.get_insertion_block() + # create loop body block block = self.builder.create_block() self.builder.set_insertion_point_to_start(block) + # record lscope & local_defs in the parent scope liveins = self.lscope.copy() prev_defs = self.local_defs.copy() self.local_defs = {} @@ -518,6 +523,8 @@ class CodeGenerator(ast.NodeVisitor): # visit loop body self.visit_compound_statement(node.body) + # If a variable (name) is defined in both its parent & itself, then it's + # a loop-carried variable. (They must be of the same type) init_args = [] yields = [] names = [] @@ -526,20 +533,19 @@ class CodeGenerator(ast.NodeVisitor): assert self.is_triton_tensor(self.local_defs[name]) assert self.is_triton_tensor(liveins[name]) if self.local_defs[name].type == liveins[name].type: - # TODO: better way to do this? names.append(name) init_args.append(triton.language.core._to_tensor(liveins[name], self.builder)) yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder)) - + # create ForOp self.builder.set_insertion_point_to_end(insert_block) for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) - # FIXME: the body should be a region (?) - # FIXME: this won't work for nested control flow block.merge_block_before(for_op.get_body(0)) + # create YieldOp self.builder.set_insertion_point_to_end(for_op.get_body(0)) self.builder.create_yield_op([y.handle for y in yields]) for_op_region = for_op.get_body(0).get_parent() - assert for_op_region.size() == 1, "(For developer) Should use region here" + assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" + # replace global uses with block arguments for i, name in enumerate(names): # arg0 is the induction variable for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1)) @@ -548,7 +554,7 @@ class CodeGenerator(ast.NodeVisitor): self.lscope = liveins self.local_defs = prev_defs - # ForOp defines new values + # update lscope & local_defs (ForOp defines new values) for i, name in enumerate(names): self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type) self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 17e653778..ab9c194a0 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -396,31 +396,6 @@ class constexpr: class tensor: - # # infer dtype from ir type - # @staticmethod - # def _to_dtype(ir_type): - # # block type - # if ir_type.is_block(): - # scalar_ty = tensor._to_dtype(ir_type.scalar) - # return block_type(scalar_ty, ir_type.get_block_shapes()) - # # pointer type - # if ir_type.is_ptr(): - # element_ty = tensor._to_dtype(ir_type.element) - # return pointer_type(element_ty) - # # primitive type - # if ir_type.is_void(): return void - # if ir_type.is_int1(): return int1 - # if ir_type.is_int8(): return int8 - # if ir_type.is_int16(): return int16 - # if ir_type.is_int32(): return int32 - # if ir_type.is_int64(): return int64 - # if ir_type.is_fp8(): return float8 - # if ir_type.is_fp16(): return float16 - # if ir_type.is_bf16(): return bfloat16 - # if ir_type.is_fp32(): return float32 - # if ir_type.is_fp64(): return float64 - # raise ValueError(f"Unsupported type {ir_type.repr()}") - def __init__(self, handle, type: dtype): # IR handle self.handle = handle