[PYTHON] Various minor codegen fixes (#95)

This commit is contained in:
Philippe Tillet
2021-04-29 18:54:38 -04:00
committed by Philippe Tillet
parent 2b75158426
commit 4290be1ae8
4 changed files with 38 additions and 23 deletions

View File

@@ -107,9 +107,12 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
ir::phi_node* phi = make_phi(ty, 1, block);
set_value(name, block, phi);
result = add_phi_operands(name, phi);
if(auto *phi = dynamic_cast<ir::phi_node*>(result))
result = try_remove_trivial_phis(phi);
}
if(auto *phi = dynamic_cast<ir::phi_node*>(result))
if(auto *phi = dynamic_cast<ir::phi_node*>(result)){
result = try_remove_trivial_phis(phi);
}
set_value(name, block, result);
return result;
}

View File

@@ -181,6 +181,9 @@ void init_triton_ir(py::module &&m) {
py::class_<ir::constant_fp, ir::constant>(m, "constant_float")
.def_property_readonly("value", &ir::constant_fp::get_value);
py::class_<ir::instruction, ir::user>(m, "instruction");
py::class_<ir::phi_node, ir::user>(m, "phi_node");
py::class_<ir::type>(m, "type")
.def("is_ptr", &ir::type::is_pointer_ty)
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))

View File

@@ -51,7 +51,7 @@ class CodeGenerator(ast.NodeVisitor):
break
if add_scope:
self.module.pop_scope()
return self.last_ret
return stmts and isinstance(stmt, ast.Return)
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
self.builder = _triton.ir.builder(context)
@@ -85,7 +85,10 @@ class CodeGenerator(ast.NodeVisitor):
# By design, only non-kernel functions can return
def visit_Return(self, node):
return self.visit(node.value)
ret = self.visit(node.value)
if ret is None:
return self.builder.ret_void()
return ret
def visit_FunctionDef(self, node, inline=False, arg_values=None):
arg_names, kwarg_names = self.visit(node.args)
@@ -112,7 +115,8 @@ class CodeGenerator(ast.NodeVisitor):
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
if inline:
return self.visit_compound_statement(node.body, add_scope=True)
self.visit_compound_statement(node.body, add_scope=True)
return self.last_ret
else:
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
self.module.seal_block(entry)
@@ -140,6 +144,8 @@ class CodeGenerator(ast.NodeVisitor):
assert len(names) == 1
name = names[0]
value = self.visit(node.value)
if not isinstance(value, triton.language.block):
value = triton.language._to_ir(value, self.builder)
self.set_value(names[0], value)
def visit_AugAssign(self, node):
@@ -208,14 +214,16 @@ class CodeGenerator(ast.NodeVisitor):
else:
self.builder.cond_br(cond.handle, then_bb, endif_bb)
self.builder.set_insert_block(then_bb)
self.visit_compound_statement(node.body, add_scope=True)
is_terminator = self.visit_compound_statement(node.body, add_scope=True)
# TODO: last statement is a terminator?
self.builder.br(endif_bb)
if not is_terminator:
self.builder.br(endif_bb)
if else_bb:
self.builder.set_insert_block(else_bb)
self.visit_compound_statement(node.orelse, add_scope=True)
is_terminator = self.visit_compound_statement(node.orelse, add_scope=True)
#TODO: last statement is a terminator?
self.builder.br(endif_bb)
if not is_terminator:
self.builder.br(endif_bb)
self.module.seal_block(endif_bb)
self.builder.set_insert_block(endif_bb)
else:

View File

@@ -4,22 +4,22 @@ from triton._C.libtriton.triton import frontend
from functools import wraps
# convert block/dtype to ir values
def _to_ir(x, builder):
if isinstance(x, bool):
return builder.get_int1(x)
elif isinstance(x, int):
return builder.get_int32(x)
elif isinstance(x, float):
return builder.get_float32(x)
if isinstance(x, block):
return x.handle
if isinstance(x, dtype):
return x.handle(builder)
return x
def _patch(fn):
# convert block/dtype to ir values
def _to_ir(x, builder):
if isinstance(x, bool):
return builder.get_int1(x)
elif isinstance(x, int):
return builder.get_int32(x)
elif isinstance(x, float):
return builder.get_float32(x)
if isinstance(x, block):
return x.handle
if isinstance(x, dtype):
return x.handle(builder)
return x
def _from_ir(x):
if isinstance(x, ir.value):
if x.type.is_void():
@@ -306,6 +306,7 @@ def zeros(shape, dtype, builder=None):
:param dtype: Data-type of the new array, e.g., :code:`triton.float16`
:type dtype: DType
"""
shape = [int(x.handle) if isinstance(x, block) else x for x in shape]
return frontend.zeros(shape, dtype, builder)