[PYTHON] Various minor codegen fixes (#95)
This commit is contained in:
committed by
Philippe Tillet
parent
2b75158426
commit
4290be1ae8
@@ -107,9 +107,12 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
|
||||
ir::phi_node* phi = make_phi(ty, 1, block);
|
||||
set_value(name, block, phi);
|
||||
result = add_phi_operands(name, phi);
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(result))
|
||||
result = try_remove_trivial_phis(phi);
|
||||
}
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(result))
|
||||
if(auto *phi = dynamic_cast<ir::phi_node*>(result)){
|
||||
result = try_remove_trivial_phis(phi);
|
||||
}
|
||||
set_value(name, block, result);
|
||||
return result;
|
||||
}
|
||||
|
@@ -181,6 +181,9 @@ void init_triton_ir(py::module &&m) {
|
||||
py::class_<ir::constant_fp, ir::constant>(m, "constant_float")
|
||||
.def_property_readonly("value", &ir::constant_fp::get_value);
|
||||
|
||||
py::class_<ir::instruction, ir::user>(m, "instruction");
|
||||
py::class_<ir::phi_node, ir::user>(m, "phi_node");
|
||||
|
||||
py::class_<ir::type>(m, "type")
|
||||
.def("is_ptr", &ir::type::is_pointer_ty)
|
||||
.def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
|
||||
|
@@ -51,7 +51,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
break
|
||||
if add_scope:
|
||||
self.module.pop_scope()
|
||||
return self.last_ret
|
||||
return stmts and isinstance(stmt, ast.Return)
|
||||
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, kwargs):
|
||||
self.builder = _triton.ir.builder(context)
|
||||
@@ -85,7 +85,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
# By design, only non-kernel functions can return
|
||||
def visit_Return(self, node):
|
||||
return self.visit(node.value)
|
||||
ret = self.visit(node.value)
|
||||
if ret is None:
|
||||
return self.builder.ret_void()
|
||||
return ret
|
||||
|
||||
def visit_FunctionDef(self, node, inline=False, arg_values=None):
|
||||
arg_names, kwarg_names = self.visit(node.args)
|
||||
@@ -112,7 +115,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
for arg_name, arg_value in zip(arg_names, arg_values):
|
||||
self.set_value(arg_name, arg_value)
|
||||
if inline:
|
||||
return self.visit_compound_statement(node.body, add_scope=True)
|
||||
self.visit_compound_statement(node.body, add_scope=True)
|
||||
return self.last_ret
|
||||
else:
|
||||
entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn)
|
||||
self.module.seal_block(entry)
|
||||
@@ -140,6 +144,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
assert len(names) == 1
|
||||
name = names[0]
|
||||
value = self.visit(node.value)
|
||||
if not isinstance(value, triton.language.block):
|
||||
value = triton.language._to_ir(value, self.builder)
|
||||
self.set_value(names[0], value)
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
@@ -208,14 +214,16 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
else:
|
||||
self.builder.cond_br(cond.handle, then_bb, endif_bb)
|
||||
self.builder.set_insert_block(then_bb)
|
||||
self.visit_compound_statement(node.body, add_scope=True)
|
||||
is_terminator = self.visit_compound_statement(node.body, add_scope=True)
|
||||
# TODO: last statement is a terminator?
|
||||
self.builder.br(endif_bb)
|
||||
if not is_terminator:
|
||||
self.builder.br(endif_bb)
|
||||
if else_bb:
|
||||
self.builder.set_insert_block(else_bb)
|
||||
self.visit_compound_statement(node.orelse, add_scope=True)
|
||||
is_terminator = self.visit_compound_statement(node.orelse, add_scope=True)
|
||||
#TODO: last statement is a terminator?
|
||||
self.builder.br(endif_bb)
|
||||
if not is_terminator:
|
||||
self.builder.br(endif_bb)
|
||||
self.module.seal_block(endif_bb)
|
||||
self.builder.set_insert_block(endif_bb)
|
||||
else:
|
||||
|
@@ -4,22 +4,22 @@ from triton._C.libtriton.triton import frontend
|
||||
from functools import wraps
|
||||
|
||||
|
||||
# convert block/dtype to ir values
|
||||
def _to_ir(x, builder):
|
||||
if isinstance(x, bool):
|
||||
return builder.get_int1(x)
|
||||
elif isinstance(x, int):
|
||||
return builder.get_int32(x)
|
||||
elif isinstance(x, float):
|
||||
return builder.get_float32(x)
|
||||
if isinstance(x, block):
|
||||
return x.handle
|
||||
if isinstance(x, dtype):
|
||||
return x.handle(builder)
|
||||
return x
|
||||
|
||||
|
||||
def _patch(fn):
|
||||
|
||||
# convert block/dtype to ir values
|
||||
def _to_ir(x, builder):
|
||||
if isinstance(x, bool):
|
||||
return builder.get_int1(x)
|
||||
elif isinstance(x, int):
|
||||
return builder.get_int32(x)
|
||||
elif isinstance(x, float):
|
||||
return builder.get_float32(x)
|
||||
if isinstance(x, block):
|
||||
return x.handle
|
||||
if isinstance(x, dtype):
|
||||
return x.handle(builder)
|
||||
return x
|
||||
|
||||
def _from_ir(x):
|
||||
if isinstance(x, ir.value):
|
||||
if x.type.is_void():
|
||||
@@ -306,6 +306,7 @@ def zeros(shape, dtype, builder=None):
|
||||
:param dtype: Data-type of the new array, e.g., :code:`triton.float16`
|
||||
:type dtype: DType
|
||||
"""
|
||||
shape = [int(x.handle) if isinstance(x, block) else x for x in shape]
|
||||
return frontend.zeros(shape, dtype, builder)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user