Add more comments
This commit is contained in:
@@ -267,6 +267,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
cond = self.visit(node.test)
|
cond = self.visit(node.test)
|
||||||
if isinstance(cond, triton.language.tensor):
|
if isinstance(cond, triton.language.tensor):
|
||||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||||
|
# record lscope & local_defs in the parent scope
|
||||||
liveins = self.lscope.copy()
|
liveins = self.lscope.copy()
|
||||||
parent_defs = self.local_defs.copy()
|
parent_defs = self.local_defs.copy()
|
||||||
self.local_defs = {}
|
self.local_defs = {}
|
||||||
@@ -446,7 +447,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
self.builder.set_insertion_point_to_end(after_block)
|
self.builder.set_insertion_point_to_end(after_block)
|
||||||
self.builder.create_yield_op([y.handle for y in yields])
|
self.builder.create_yield_op([y.handle for y in yields])
|
||||||
|
|
||||||
# update global_uses in while_op
|
# update global uses in while_op
|
||||||
for i, name in enumerate(names):
|
for i, name in enumerate(names):
|
||||||
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
|
before_block.replace_use_in_block_with(init_args[i].handle, before_block.arg(i))
|
||||||
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
|
after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i))
|
||||||
@@ -494,23 +495,27 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
ast.NodeVisitor.generic_visit(self, stmt)
|
ast.NodeVisitor.generic_visit(self, stmt)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# collect lower bound (lb), upper bound (ub), and step
|
||||||
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0))
|
lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0))
|
||||||
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0])
|
ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0])
|
||||||
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
|
step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1))
|
||||||
# TODO: better way to do this?
|
# lb/ub/step might be constexpr, we need to cast them to tensor
|
||||||
lb = triton.language.core._to_tensor(lb, self.builder).handle
|
lb = triton.language.core._to_tensor(lb, self.builder).handle
|
||||||
ub = triton.language.core._to_tensor(ub, self.builder).handle
|
ub = triton.language.core._to_tensor(ub, self.builder).handle
|
||||||
step = triton.language.core._to_tensor(step, self.builder).handle
|
step = triton.language.core._to_tensor(step, self.builder).handle
|
||||||
|
# ForOp can only accept IndexType as lb/ub/step. Cast integer to Index
|
||||||
lb = self.builder.create_to_index(lb)
|
lb = self.builder.create_to_index(lb)
|
||||||
ub = self.builder.create_to_index(ub)
|
ub = self.builder.create_to_index(ub)
|
||||||
step = self.builder.create_to_index(step)
|
step = self.builder.create_to_index(step)
|
||||||
|
|
||||||
|
# cache current insertion block
|
||||||
insert_block = self.builder.get_insertion_block()
|
insert_block = self.builder.get_insertion_block()
|
||||||
|
|
||||||
|
# create loop body block
|
||||||
block = self.builder.create_block()
|
block = self.builder.create_block()
|
||||||
self.builder.set_insertion_point_to_start(block)
|
self.builder.set_insertion_point_to_start(block)
|
||||||
|
|
||||||
|
# record lscope & local_defs in the parent scope
|
||||||
liveins = self.lscope.copy()
|
liveins = self.lscope.copy()
|
||||||
prev_defs = self.local_defs.copy()
|
prev_defs = self.local_defs.copy()
|
||||||
self.local_defs = {}
|
self.local_defs = {}
|
||||||
@@ -518,6 +523,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
# visit loop body
|
# visit loop body
|
||||||
self.visit_compound_statement(node.body)
|
self.visit_compound_statement(node.body)
|
||||||
|
|
||||||
|
# If a variable (name) is defined in both its parent & itself, then it's
|
||||||
|
# a loop-carried variable. (They must be of the same type)
|
||||||
init_args = []
|
init_args = []
|
||||||
yields = []
|
yields = []
|
||||||
names = []
|
names = []
|
||||||
@@ -526,20 +533,19 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
assert self.is_triton_tensor(self.local_defs[name])
|
assert self.is_triton_tensor(self.local_defs[name])
|
||||||
assert self.is_triton_tensor(liveins[name])
|
assert self.is_triton_tensor(liveins[name])
|
||||||
if self.local_defs[name].type == liveins[name].type:
|
if self.local_defs[name].type == liveins[name].type:
|
||||||
# TODO: better way to do this?
|
|
||||||
names.append(name)
|
names.append(name)
|
||||||
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
|
init_args.append(triton.language.core._to_tensor(liveins[name], self.builder))
|
||||||
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
|
yields.append(triton.language.core._to_tensor(self.local_defs[name], self.builder))
|
||||||
|
# create ForOp
|
||||||
self.builder.set_insertion_point_to_end(insert_block)
|
self.builder.set_insertion_point_to_end(insert_block)
|
||||||
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
|
for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args])
|
||||||
# FIXME: the body should be a region (?)
|
|
||||||
# FIXME: this won't work for nested control flow
|
|
||||||
block.merge_block_before(for_op.get_body(0))
|
block.merge_block_before(for_op.get_body(0))
|
||||||
|
# create YieldOp
|
||||||
self.builder.set_insertion_point_to_end(for_op.get_body(0))
|
self.builder.set_insertion_point_to_end(for_op.get_body(0))
|
||||||
self.builder.create_yield_op([y.handle for y in yields])
|
self.builder.create_yield_op([y.handle for y in yields])
|
||||||
for_op_region = for_op.get_body(0).get_parent()
|
for_op_region = for_op.get_body(0).get_parent()
|
||||||
assert for_op_region.size() == 1, "(For developer) Should use region here"
|
assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
|
||||||
|
# replace global uses with block arguments
|
||||||
for i, name in enumerate(names):
|
for i, name in enumerate(names):
|
||||||
# arg0 is the induction variable
|
# arg0 is the induction variable
|
||||||
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
|
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
|
||||||
@@ -548,7 +554,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
|
|
||||||
self.lscope = liveins
|
self.lscope = liveins
|
||||||
self.local_defs = prev_defs
|
self.local_defs = prev_defs
|
||||||
# ForOp defines new values
|
# update lscope & local_defs (ForOp defines new values)
|
||||||
for i, name in enumerate(names):
|
for i, name in enumerate(names):
|
||||||
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
self.lscope[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
||||||
self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
self.local_defs[name] = triton.language.core.tensor(for_op.get_result(i), yields[i].type)
|
||||||
|
@@ -396,31 +396,6 @@ class constexpr:
|
|||||||
|
|
||||||
|
|
||||||
class tensor:
|
class tensor:
|
||||||
# # infer dtype from ir type
|
|
||||||
# @staticmethod
|
|
||||||
# def _to_dtype(ir_type):
|
|
||||||
# # block type
|
|
||||||
# if ir_type.is_block():
|
|
||||||
# scalar_ty = tensor._to_dtype(ir_type.scalar)
|
|
||||||
# return block_type(scalar_ty, ir_type.get_block_shapes())
|
|
||||||
# # pointer type
|
|
||||||
# if ir_type.is_ptr():
|
|
||||||
# element_ty = tensor._to_dtype(ir_type.element)
|
|
||||||
# return pointer_type(element_ty)
|
|
||||||
# # primitive type
|
|
||||||
# if ir_type.is_void(): return void
|
|
||||||
# if ir_type.is_int1(): return int1
|
|
||||||
# if ir_type.is_int8(): return int8
|
|
||||||
# if ir_type.is_int16(): return int16
|
|
||||||
# if ir_type.is_int32(): return int32
|
|
||||||
# if ir_type.is_int64(): return int64
|
|
||||||
# if ir_type.is_fp8(): return float8
|
|
||||||
# if ir_type.is_fp16(): return float16
|
|
||||||
# if ir_type.is_bf16(): return bfloat16
|
|
||||||
# if ir_type.is_fp32(): return float32
|
|
||||||
# if ir_type.is_fp64(): return float64
|
|
||||||
# raise ValueError(f"Unsupported type {ir_type.repr()}")
|
|
||||||
|
|
||||||
def __init__(self, handle, type: dtype):
|
def __init__(self, handle, type: dtype):
|
||||||
# IR handle
|
# IR handle
|
||||||
self.handle = handle
|
self.handle = handle
|
||||||
|
Reference in New Issue
Block a user