Add more comments

This commit is contained in:
Yan Da
2022-04-10 14:36:03 +08:00
parent f1cc67bbc3
commit aa6e086881
2 changed files with 15 additions and 34 deletions

View File

@@ -267,6 +267,7 @@ class CodeGenerator(ast.NodeVisitor):
cond = self.visit(node.test) cond = self.visit(node.test)
if isinstance(cond, triton.language.tensor): if isinstance(cond, triton.language.tensor):
cond = cond.to(triton.language.int1, _builder=self.builder) cond = cond.to(triton.language.int1, _builder=self.builder)
# record lscope & local_defs in the parent scope
liveins = self.lscope.copy() liveins = self.lscope.copy()
parent_defs = self.local_defs.copy() parent_defs = self.local_defs.copy()
self.local_defs = {} self.local_defs = {}
@@ -446,7 +447,7 @@ class CodeGenerator(ast.NodeVisitor):
self.builder.set_insertion_point_to_end(after_block) self.builder.set_insertion_point_to_end(after_block)
self.builder.create_yield_op([y.handle for y in yields]) self.builder.create_yield_op([y.handle for y in yields])
# update global_uses in while_op # update global uses in while_op
for i, name in enumerate(names): for i, name in enumerate(names):
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i)) before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i)) after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
@@ -494,23 +495,27 @@ class CodeGenerator(ast.NodeVisitor):
ast.NodeVisitor.generic_visit(self, stmt) ast.NodeVisitor.generic_visit(self, stmt)
return return
# collect lower bound (lb), upper bound (ub), and step
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)) lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0))
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]) ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0])
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)) step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
# TODO: better way to do this? # lb/ub/step might be constexpr, we need to cast them to tensor
lb = triton.language.core._to_tensor(lb, self.builder).handle lb = triton.language.core._to_tensor(lb, self.builder).handle
ub = triton.language.core._to_tensor(ub, self.builder).handle ub = triton.language.core._to_tensor(ub, self.builder).handle
step = triton.language.core._to_tensor(step, self.builder).handle step = triton.language.core._to_tensor(step, self.builder).handle
# ForOp can only accept IndexType as lb/ub/step. Cast integer to Index
lb = self.builder.create_to_index(lb) lb = self.builder.create_to_index(lb)
ub = self.builder.create_to_index(ub) ub = self.builder.create_to_index(ub)
step = self.builder.create_to_index(step) step = self.builder.create_to_index(step)
# cache current insertion block
insert_block = self.builder.get_insertion_block() insert_block = self.builder.get_insertion_block()
# create loop body block
block = self.builder.create_block() block = self.builder.create_block()
self.builder.set_insertion_point_to_start(block) self.builder.set_insertion_point_to_start(block)
# record lscope & local_defs in the parent scope
liveins = self.lscope.copy() liveins = self.lscope.copy()
prev_defs = self.local_defs.copy() prev_defs = self.local_defs.copy()
self.local_defs = {} self.local_defs = {}
@@ -518,6 +523,8 @@ class CodeGenerator(ast.NodeVisitor):
# visit loop body # visit loop body
self.visit_compound_statement(node.body) self.visit_compound_statement(node.body)
# If a variable (name) is defined in both its parent & itself, then it's
# a loop-carried variable. (They must be of the same type)
init_args = [] init_args = []
yields = [] yields = []
names = [] names = []
@@ -526,20 +533,19 @@ class CodeGenerator(ast.NodeVisitor):
assert self.is_triton_tensor(self.local_defs[name]) assert self.is_triton_tensor(self.local_defs[name])
assert self.is_triton_tensor(liveins[name]) assert self.is_triton_tensor(liveins[name])
if self.local_defs[name].type == liveins[name].type: if self.local_defs[name].type == liveins[name].type:
# TODO: better way to do this?
names.append(name) names.append(name)
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder)) init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder)) yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
# create ForOp
self.builder.set_insertion_point_to_end(insert_block) self.builder.set_insertion_point_to_end(insert_block)
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
# FIXME: the body should be a region (?)
# FIXME: this won't work for nested control flow
block.merge_block_before(for_op.get_body(0)) block.merge_block_before(for_op.get_body(0))
# create YieldOp
self.builder.set_insertion_point_to_end(for_op.get_body(0)) self.builder.set_insertion_point_to_end(for_op.get_body(0))
self.builder.create_yield_op([y.handle for y in yields]) self.builder.create_yield_op([y.handle for y in yields])
for_op_region = for_op.get_body(0).get_parent() for_op_region = for_op.get_body(0).get_parent()
assert for_op_region.size() == 1, "(For developer) Should use region here" assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
# replace global uses with block arguments
for i, name in enumerate(names): for i, name in enumerate(names):
# arg0 is the induction variable # arg0 is the induction variable
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1)) for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
@@ -548,7 +554,7 @@ class CodeGenerator(ast.NodeVisitor):
self.lscope = liveins self.lscope = liveins
self.local_defs = prev_defs self.local_defs = prev_defs
# ForOp defines new values # update lscope & local_defs (ForOp defines new values)
for i, name in enumerate(names): for i, name in enumerate(names):
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type) self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type) self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)

View File

@@ -396,31 +396,6 @@ class constexpr:
class tensor: class tensor:
# # infer dtype from ir type
# @staticmethod
# def _to_dtype(ir_type):
# # block type
# if ir_type.is_block():
# scalar_ty = tensor._to_dtype(ir_type.scalar)
# return block_type(scalar_ty, ir_type.get_block_shapes())
# # pointer type
# if ir_type.is_ptr():
# element_ty = tensor._to_dtype(ir_type.element)
# return pointer_type(element_ty)
# # primitive type
# if ir_type.is_void(): return void
# if ir_type.is_int1(): return int1
# if ir_type.is_int8(): return int8
# if ir_type.is_int16(): return int16
# if ir_type.is_int32(): return int32
# if ir_type.is_int64(): return int64
# if ir_type.is_fp8(): return float8
# if ir_type.is_fp16(): return float16
# if ir_type.is_bf16(): return bfloat16
# if ir_type.is_fp32(): return float32
# if ir_type.is_fp64(): return float64
# raise ValueError(f"Unsupported type {ir_type.repr()}")
def __init__(self, handle, type: dtype): def __init__(self, handle, type: dtype):
# IR handle # IR handle
self.handle = handle self.handle = handle