Use mlir::Block to replace MlirBlock
This commit is contained in:
@@ -30,7 +30,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.prototype = prototype
|
||||
self.gscope = gscope
|
||||
self.lscope = dict()
|
||||
self.is_arg_lscope = dict() # name => is_arg: {str: bool}
|
||||
self.attributes = attributes
|
||||
self.constants = constants
|
||||
self.kwargs = kwargs
|
||||
@@ -69,33 +68,32 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ret = self.builtins[name]
|
||||
else:
|
||||
raise ValueError(f'{name} is not defined')
|
||||
if self.is_triton_tensor(ret) and not self.is_arg_lscope[name]:
|
||||
if self.is_triton_tensor(ret):
|
||||
return self._get_tensor(name)
|
||||
return ret
|
||||
|
||||
def set_value(self, name: str,
|
||||
value: Union[triton.language.tensor, triton.language.constexpr],
|
||||
is_arg: bool = False) -> None:
|
||||
value: Union[triton.language.tensor, triton.language.constexpr]) -> None:
|
||||
''' This function:
|
||||
called by visit_Assign() & visit_FuncDef() to store left value (lvalue)
|
||||
1. record local defined name (FIXME: should consider control flow)
|
||||
2. store tensor in self.lvalue
|
||||
'''
|
||||
self.lscope[name] = value
|
||||
# if this value is an argument, we don't need to create phis for it
|
||||
self.is_arg_lscope[name] = is_arg
|
||||
if isinstance(value, triton.language.tensor) and not is_arg:
|
||||
self._set_value(name, self.builder.get_insert_block(), value)
|
||||
if isinstance(value, triton.language.tensor):
|
||||
self._set_value(name, self.builder.get_insertion_block(), value)
|
||||
|
||||
#
|
||||
# SSA-construction
|
||||
#
|
||||
def _get_tensor(self, name: str, bb: Optional[_triton.ir.basic_block] = None) -> triton.language.tensor:
|
||||
if not bb:
|
||||
bb = self.builder.get_insert_block()
|
||||
bb = self.builder.get_insertion_block()
|
||||
# local value numbering
|
||||
if (name, bb) in self.lvalues:
|
||||
return self.lvalues[(name, bb)]
|
||||
print(self.lvalues)
|
||||
assert False, f'Cannot find {name} in {bb}'
|
||||
# global value numbering
|
||||
saved_insert_point = self.builder.get_insert_point()
|
||||
result = self._get_tensor_recursive(name, bb)
|
||||
@@ -115,8 +113,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
elif len(preds) == 1:
|
||||
# one predecessor: no phi needed, try get value from pred
|
||||
result = self._get_tensor(name, preds[0])
|
||||
elif len(preds) == 0:
|
||||
result = self._get_tensor(name, None)
|
||||
else: # multiple preds
|
||||
assert len(preds) > 1, f'{name} is an undefined name (cannot find in the entry block)'
|
||||
phi = self._make_phi(type, len(preds), bb)
|
||||
self._set_value(name, bb, phi)
|
||||
result = self._add_phi_operands(name, phi)
|
||||
@@ -148,8 +147,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None:
|
||||
self.lvalues[(name, bb)] = value
|
||||
# TODO: why we need this?
|
||||
self.module.set_instr_metadata(name, value.handle)
|
||||
# # TODO: why we need this?
|
||||
# self.module.set_instr_metadata(name, value.handle)
|
||||
|
||||
def _seal_block(self, bb: _triton.ir.basic_block):
|
||||
# complete all incomplete phis
|
||||
@@ -220,7 +219,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if inline:
|
||||
pass
|
||||
else:
|
||||
fn = self.module.get_or_insert_function(node.name, self.prototype.to_ir(self.builder))
|
||||
fn = self.builder.create_function(node.name, self.prototype.to_ir(self.builder))
|
||||
self.module.push_back(fn)
|
||||
arg_values = []
|
||||
idx = 0
|
||||
for i, arg_name in enumerate(arg_names):
|
||||
@@ -230,25 +230,27 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
cst = triton.language.constexpr(self.constants[i])
|
||||
arg_values.append(cst)
|
||||
else:
|
||||
if i in self.attributes:
|
||||
is_ptr = fn.args[idx].type.is_ptr()
|
||||
attr = 'aligned' if is_ptr else 'multiple_of'
|
||||
attr = getattr(_triton.ir.attribute_kind, attr)
|
||||
attr = _triton.ir.attribute(attr, self.attributes[i])
|
||||
fn.add_attr(idx + 1, attr)
|
||||
fn.args[idx].name = arg_name
|
||||
arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx]))
|
||||
idx += 1
|
||||
pass
|
||||
# TODO: ...
|
||||
# if i in self.attributes:
|
||||
# is_ptr = fn.args[idx].type.is_ptr()
|
||||
# attr = 'aligned' if is_ptr else 'multiple_of'
|
||||
# attr = getattr(_triton.ir.attribute_kind, attr)
|
||||
# attr = _triton.ir.attribute(attr, self.attributes[i])
|
||||
# fn.add_attr(idx + 1, attr)
|
||||
# fn.args[idx].name = arg_name
|
||||
# arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx]))
|
||||
# idx += 1
|
||||
|
||||
for arg_name, arg_value in zip(arg_names, arg_values):
|
||||
self.set_value(arg_name, arg_value, is_arg=True)
|
||||
self.set_value(arg_name, arg_value)
|
||||
if inline:
|
||||
self.visit_compound_statement(node.body)
|
||||
return self.last_ret
|
||||
else:
|
||||
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
|
||||
entry = fn.add_entry_block()
|
||||
self._seal_block(entry)
|
||||
self.builder.set_insert_block(entry)
|
||||
self.builder.set_insertion_point_to_start(entry)
|
||||
# visit function body
|
||||
self.visit_compound_statement(node.body)
|
||||
# finalize function
|
||||
@@ -358,7 +360,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
cond = self.visit(node.test)
|
||||
if isinstance(cond, triton.language.tensor):
|
||||
cond = cond.to(triton.language.int1, _builder=self.builder)
|
||||
current_bb = self.builder.get_insert_block()
|
||||
current_bb = self.builder.get_insertion_block()
|
||||
then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent)
|
||||
else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None
|
||||
endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent)
|
||||
@@ -445,7 +447,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return getattr(op, fn)()
|
||||
|
||||
def visit_While(self, node):
|
||||
current_bb = self.builder.get_insert_block()
|
||||
current_bb = self.builder.get_insertion_block()
|
||||
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
|
||||
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
|
||||
|
||||
@@ -457,7 +459,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.builder.set_insert_block(loop_bb)
|
||||
self.visit_compound_statement(node.body)
|
||||
continue_fn()
|
||||
stop_bb = self.builder.get_insert_block()
|
||||
stop_bb = self.builder.get_insertion_block()
|
||||
self._seal_block(stop_bb)
|
||||
self._seal_block(loop_bb)
|
||||
self._seal_block(next_bb)
|
||||
@@ -512,7 +514,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# cond_node = neg_cond_node
|
||||
step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2)
|
||||
# code generation
|
||||
current_bb = self.builder.get_insert_block()
|
||||
current_bb = self.builder.get_insertion_block()
|
||||
loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent)
|
||||
next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent)
|
||||
|
||||
@@ -528,7 +530,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit_compound_statement(node.body)
|
||||
# TODO: handle case where body breaks control flow
|
||||
continue_fn()
|
||||
stop_bb = self.builder.get_insert_block()
|
||||
stop_bb = self.builder.get_insertion_block()
|
||||
self._seal_block(stop_bb)
|
||||
self._seal_block(loop_bb)
|
||||
self._seal_block(next_bb)
|
||||
@@ -845,10 +847,11 @@ class Kernel:
|
||||
|
||||
# create IR module
|
||||
context = _triton.ir.context()
|
||||
context.load_triton()
|
||||
# get just-in-time proto-type of kernel
|
||||
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
|
||||
ret_type = triton.language.void
|
||||
prototype = triton.language.function_type(ret_type, arg_types)
|
||||
prototype = triton.language.function_type([ret_type], arg_types)
|
||||
# generate Triton-IR
|
||||
# export symbols visible from self into code-generator object
|
||||
gscope = self.__globals__
|
||||
@@ -1179,10 +1182,11 @@ class JITFunction:
|
||||
def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages):
|
||||
# create IR module
|
||||
context = _triton.ir.context()
|
||||
context.load_triton()
|
||||
# get just-in-time proto-type of kernel
|
||||
arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types]
|
||||
ret_type = triton.language.void
|
||||
prototype = triton.language.function_type(ret_type, arg_types)
|
||||
prototype = triton.language.function_type([ret_type], arg_types)
|
||||
# generate Triton-IR
|
||||
# export symbols visible from self into code-generator object
|
||||
gscope = self.__globals__
|
||||
|
||||
Reference in New Issue
Block a user