Add more comments
This commit is contained in:
@@ -267,6 +267,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
cond = self.visit(node.test)
|
||||
if isinstance(cond, triton.language.tensor):
|
||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||
# record lscope & local_defs in the parent scope
|
||||
liveins = self.lscope.copy()
|
||||
parent_defs = self.local_defs.copy()
|
||||
self.local_defs = {}
|
||||
@@ -446,7 +447,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.builder.set_insertion_point_to_end(after_block)
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
|
||||
# update global_uses in while_op
|
||||
# update global uses in while_op
|
||||
for i, name in enumerate(names):
|
||||
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
|
||||
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
|
||||
@@ -494,23 +495,27 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
return
|
||||
|
||||
# collect lower bound (lb), upper bound (ub), and step
|
||||
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0))
|
||||
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0])
|
||||
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
|
||||
# TODO: better way to do this?
|
||||
# lb/ub/step might be constexpr, we need to cast them to tensor
|
||||
lb = triton.language.core._to_tensor(lb, self.builder).handle
|
||||
ub = triton.language.core._to_tensor(ub, self.builder).handle
|
||||
step = triton.language.core._to_tensor(step, self.builder).handle
|
||||
|
||||
# ForOp can only accept IndexType as lb/ub/step. Cast integer to Index
|
||||
lb = self.builder.create_to_index(lb)
|
||||
ub = self.builder.create_to_index(ub)
|
||||
step = self.builder.create_to_index(step)
|
||||
|
||||
# cache current insertion block
|
||||
insert_block = self.builder.get_insertion_block()
|
||||
|
||||
# create loop body block
|
||||
block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(block)
|
||||
|
||||
# record lscope & local_defs in the parent scope
|
||||
liveins = self.lscope.copy()
|
||||
prev_defs = self.local_defs.copy()
|
||||
self.local_defs = {}
|
||||
@@ -518,6 +523,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# visit loop body
|
||||
self.visit_compound_statement(node.body)
|
||||
|
||||
# If a variable (name) is defined in both its parent & itself, then it's
|
||||
# a loop-carried variable. (They must be of the same type)
|
||||
init_args = []
|
||||
yields = []
|
||||
names = []
|
||||
@@ -526,20 +533,19 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
assert self.is_triton_tensor(self.local_defs[name])
|
||||
assert self.is_triton_tensor(liveins[name])
|
||||
if self.local_defs[name].type == liveins[name].type:
|
||||
# TODO: better way to do this?
|
||||
names.append(name)
|
||||
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
|
||||
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
|
||||
|
||||
# create ForOp
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
|
||||
# FIXME: the body should be a region (?)
|
||||
# FIXME: this won't work for nested control flow
|
||||
block.merge_block_before(for_op.get_body(0))
|
||||
# create YieldOp
|
||||
self.builder.set_insertion_point_to_end(for_op.get_body(0))
|
||||
self.builder.create_yield_op([y.handle for y in yields])
|
||||
for_op_region = for_op.get_body(0).get_parent()
|
||||
assert for_op_region.size() == 1, "(For developer) Should use region here"
|
||||
assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
|
||||
# replace global uses with block arguments
|
||||
for i, name in enumerate(names):
|
||||
# arg0 is the induction variable
|
||||
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
|
||||
@@ -548,7 +554,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
self.lscope = liveins
|
||||
self.local_defs = prev_defs
|
||||
# ForOp defines new values
|
||||
# update lscope & local_defs (ForOp defines new values)
|
||||
for i, name in enumerate(names):
|
||||
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
||||
self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
||||
|
@@ -396,31 +396,6 @@ class constexpr:
|
||||
|
||||
|
||||
class tensor:
|
||||
# # infer dtype from ir type
|
||||
# @staticmethod
|
||||
# def _to_dtype(ir_type):
|
||||
# # block type
|
||||
# if ir_type.is_block():
|
||||
# scalar_ty = tensor._to_dtype(ir_type.scalar)
|
||||
# return block_type(scalar_ty, ir_type.get_block_shapes())
|
||||
# # pointer type
|
||||
# if ir_type.is_ptr():
|
||||
# element_ty = tensor._to_dtype(ir_type.element)
|
||||
# return pointer_type(element_ty)
|
||||
# # primitive type
|
||||
# if ir_type.is_void(): return void
|
||||
# if ir_type.is_int1(): return int1
|
||||
# if ir_type.is_int8(): return int8
|
||||
# if ir_type.is_int16(): return int16
|
||||
# if ir_type.is_int32(): return int32
|
||||
# if ir_type.is_int64(): return int64
|
||||
# if ir_type.is_fp8(): return float8
|
||||
# if ir_type.is_fp16(): return float16
|
||||
# if ir_type.is_bf16(): return bfloat16
|
||||
# if ir_type.is_fp32(): return float32
|
||||
# if ir_type.is_fp64(): return float64
|
||||
# raise ValueError(f"Unsupported type {ir_type.repr()}")
|
||||
|
||||
def __init__(self, handle, type: dtype):
|
||||
# IR handle
|
||||
self.handle = handle
|
||||
|
Reference in New Issue
Block a user