diff --git a/python/src/triton.cc b/python/src/triton.cc index 8baffd066..d7e76a3b8 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -4,11 +4,8 @@ #include "triton/driver/llvm.h" #include "mlir/IR/Builders.h" -#include "mlir-c/IR.h" -#include "mlir-c/BuiltinTypes.h" -#include "mlir/CAPI/IR.h" -// #include "mlir/IR/BuiltinOps.h" -// #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" #include "triton/ir/Dialect.h" #include "triton/ir/Types.h" @@ -652,13 +649,9 @@ void init_triton_ir(py::module &&m) { .def("get_context", &mlir::ModuleOp::getContext) ; - py::class_(m, "type") - .def("is_integer", [](MlirType &self) -> bool { - return mlirTypeIsAInteger(self); - }) - .def("is_fp16", [](MlirType &self) -> bool { - return mlirTypeIsABF16(self); - }) + py::class_(m, "type") + .def("is_integer", &mlir::Type::isInteger) + .def("is_fp16", &mlir::Type::isF16) ; py::class_(m, "value") @@ -782,74 +775,77 @@ void init_triton_ir(py::module &&m) { auto loc = self.getUnknownLoc(); return self.create(loc, self.getF32FloatAttr(v)); }) + .def("get_null_value", [](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value { + auto loc = self.getUnknownLoc(); + if (type.isa()) + return self.create(loc, self.getF32FloatAttr(0.0)); + else + throw std::runtime_error("Not implemented"); + }) // Types - .def("get_void_ty", [](mlir::OpBuilder &self) ->MlirType { - return wrap(self.getNoneType()); + .def("get_void_ty", [](mlir::OpBuilder &self) -> mlir::Type { + return self.getNoneType(); }) - .def("get_int1_ty", [](mlir::OpBuilder &self) -> MlirType { - return wrap(self.getI1Type()); + .def("get_int1_ty", [](mlir::OpBuilder &self) -> mlir::Type { + return self.getI1Type(); }) // or ret::copy? - .def("get_int8_ty", [](mlir::OpBuilder &self) -> MlirType { - return wrap(self.getI8Type()); + .def("get_int8_ty", [](mlir::OpBuilder &self) -> mlir::Type { + return self.getI8Type(); }) - .def("get_int16_ty", [](mlir::OpBuilder &self) -> MlirType { - return wrap(self.getType(16)); + .def("get_int16_ty", [](mlir::OpBuilder &self) -> mlir::Type { + return self.getType(16); }) - .def("get_int32_ty", [](mlir::OpBuilder &self) -> MlirType { - return wrap(self.getI32Type()); + .def("get_int32_ty", [](mlir::OpBuilder &self) -> mlir::Type { + return self.getI32Type(); }) - .def("get_int64_ty", [](mlir::OpBuilder &self) -> MlirType { - return wrap(self.getI64Type()); + .def("get_int64_ty", [](mlir::OpBuilder &self) -> mlir::Type { + return self.getI64Type(); }) - .def("get_fp8_ty", [](mlir::OpBuilder &self) -> MlirType { - return wrap(self.getType()); + .def("get_fp8_ty", [](mlir::OpBuilder &self) -> mlir::Type { + return self.getType(); }) - .def("get_bf8_ty", [](mlir::OpBuilder &self) -> MlirType { - return wrap(self.getType()); + .def("get_bf8_ty", [](mlir::OpBuilder &self) -> mlir::Type { + return self.getType(); }) - .def("get_half_ty", [](mlir::OpBuilder &self) -> MlirType { - return wrap(self.getF16Type()); + .def("get_half_ty", [](mlir::OpBuilder &self) -> mlir::Type { + return self.getF16Type(); }) - .def("get_bf16_ty", [](mlir::OpBuilder &self) -> MlirType { - return wrap(self.getBF16Type()); + .def("get_bf16_ty", [](mlir::OpBuilder &self) -> mlir::Type { + return self.getBF16Type(); }) - .def("get_float_ty", [](mlir::OpBuilder &self) -> MlirType { - return wrap(self.getF32Type()); + .def("get_float_ty", [](mlir::OpBuilder &self) -> mlir::Type { + return self.getF32Type(); }) - .def("get_double_ty", [](mlir::OpBuilder &self) -> MlirType { - return wrap(self.getF64Type()); + .def("get_double_ty", [](mlir::OpBuilder &self) -> mlir::Type { + return self.getF64Type(); }) - .def("get_ptr_ty", [](mlir::OpBuilder &self, MlirType &type, int addrSpace) -> MlirType { - return wrap( - mlir::triton::PointerType::get(unwrap(type), addrSpace) - ); + .def("get_ptr_ty", [](mlir::OpBuilder &self, mlir::Type &type, int addrSpace) -> mlir::Type { + return mlir::triton::PointerType::get(type, addrSpace); }) - .def("get_block_ty", [](mlir::OpBuilder &self, MlirType &elementType, - std::vector &shape) -> MlirType { - return wrap( - mlir::RankedTensorType::get(shape, unwrap(elementType)) - ); + .def("get_block_ty", [](mlir::OpBuilder &self, mlir::Type &elementType, + std::vector &shape) -> mlir::Type { + return mlir::RankedTensorType::get(shape, elementType); }) .def("get_function_ty", [](mlir::OpBuilder &self, - std::vector inTypes, - std::vector outTypes) -> MlirType { - llvm::SmallVector inputsTypeList; - llvm::SmallVector resultsTypeList; - (void)unwrapList(inTypes.size(), inTypes.data(), inputsTypeList); - (void)unwrapList(outTypes.size(), outTypes.data(), resultsTypeList); - return wrap(self.getFunctionType(inputsTypeList, resultsTypeList)); + std::vector inTypes, + std::vector outTypes) -> mlir::Type { + return self.getFunctionType(inTypes, outTypes); }) // Ops - .def("create_function", [](mlir::OpBuilder &self, std::string name, MlirType funcType) -> mlir::FuncOp { + .def("create_function", [](mlir::OpBuilder &self, std::string name, mlir::Type &funcType) -> mlir::FuncOp { // TODO: loc auto loc = self.getUnknownLoc(); - if (auto funcTy = unwrap(funcType).dyn_cast()) { + if (auto funcTy = funcType.dyn_cast()) { return self.create(loc, name, funcTy); } throw std::runtime_error("invalid function type"); }) + .def("create_block", [](mlir::OpBuilder &self) -> mlir::Block* { + mlir::Region *parent = self.getBlock()->getParent(); + return self.createBlock(parent); + }, ret::reference) // Structured control flow .def("create_for_op", [](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub, mlir::Value &step, std::vector &initArgs) -> mlir::scf::ForOp { @@ -878,35 +874,35 @@ void init_triton_ir(py::module &&m) { }) // Cast instructions - .def("create_bitcast", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value { + .def("create_bitcast", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value { auto loc = self.getUnknownLoc(); - return self.create(loc, unwrap(dstType), src); + return self.create(loc, dstType, src); }) // .def("create_cast", &ir::builder::create_cast) // .def("create_ptr_to_int", &ir::builder::create_ptr_to_int) - .def("create_si_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value { + .def("create_si_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value { auto loc = self.getUnknownLoc(); - return self.create(loc, unwrap(dstType), src); + return self.create(loc, dstType, src); }) - .def("create_ui_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value { + .def("create_ui_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value { auto loc = self.getUnknownLoc(); - return self.create(loc, unwrap(dstType), src); + return self.create(loc, dstType, src); }) - .def("create_fp_to_si", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value { + .def("create_fp_to_si", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value { auto loc = self.getUnknownLoc(); - return self.create(loc, unwrap(dstType), src); + return self.create(loc, dstType, src); }) - .def("create_fp_to_ui", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value { + .def("create_fp_to_ui", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value { auto loc = self.getUnknownLoc(); - return self.create(loc, unwrap(dstType), src); + return self.create(loc, dstType, src); }) - .def("create_fp_ext", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value { + .def("create_fp_ext", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value { auto loc = self.getUnknownLoc(); - return self.create(loc, unwrap(dstType), src); + return self.create(loc, dstType, src); }) - .def("create_fp_trunc", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value { + .def("create_fp_trunc", [](mlir::OpBuilder &self, mlir::Value &src, mlir::Type &dstType) -> mlir::Value { auto loc = self.getUnknownLoc(); - return self.create(loc, unwrap(dstType), src); + return self.create(loc, dstType, src); }) // .def("create_int_cast", &ir::builder::create_int_cast) // .def("create_downcast", &ir::builder::create_downcast) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 02ba697e7..faf264de2 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -44,11 +44,8 @@ class CodeGenerator(ast.NodeVisitor): 'getattr': getattr, } # SSA-construction - # [name, bb] => triton.language.tensor - self.lvalues: Dict[Tuple[str, _triton.ir.basic_block], triton.language.tensor] = {} - # bb => {name => phi} - self.incomplete_phis = {} - self.sealed_blocks: Set[_triton.ir.basic_block] = set() + # name => triton.language.tensor + self.local_defs: Dict[str, triton.language.tensor] = {} def get_value(self, name): ''' This function: @@ -67,10 +64,7 @@ class CodeGenerator(ast.NodeVisitor): elif name in self.builtins: ret = self.builtins[name] else: - print(self.lscope) raise ValueError(f'{name} is not defined') - if self.is_triton_tensor(ret): - return self._get_tensor(name, self.builder.get_insertion_block()) return ret def set_value(self, name: str, @@ -81,99 +75,7 @@ class CodeGenerator(ast.NodeVisitor): 2. store tensor in self.lvalue ''' self.lscope[name] = value - if isinstance(value, triton.language.tensor): - self._set_value(name, self.builder.get_insertion_block(), value) - - # - # SSA-construction - # - def _get_tensor(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor: - if not bb: - bb = self.builder.get_insertion_block() - # local value numbering - if (name, bb) in self.lvalues: - return self.lvalues[(name, bb)] - # param. FIXME: should delete this - if (name, None) in self.lvalues: - return self.lvalues[(name, None)] - print(self.lvalues) - assert False, f'Cannot find {name} in {bb}' - # global value numbering - saved_insert_point = self.builder.get_insert_point() - result = self._get_tensor_recursive(name, bb) - self.builder.set_insert_point(saved_insert_point) - return result - - def _get_tensor_recursive(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor: - preds = bb.get_predecessors() - type = self.lscope[name].type - # some preds haven't been filled, create a phi as a proxy of the value - if bb not in self.sealed_blocks: - result = self._make_phi(type, len(preds), bb) - if bb in self.incomplete_phis: - self.incomplete_phis[bb][name] = result - else: - self.incomplete_phis[bb] = {name: result} - elif len(preds) == 1: - # one predecessor: no phi needed, try get value from pred - result = self._get_tensor(name, preds[0]) - elif len(preds) == 0: - result = self._get_tensor(name, None) - else: # multiple preds - phi = self._make_phi(type, len(preds), bb) - self._set_value(name, bb, phi) - result = self._add_phi_operands(name, phi) - self._set_value(name, bb, result) - return result - - # returns a new phi tensor, which encausulate an ir.phi_node - def _make_phi(self, - type: triton.language.dtype, - num_values: int, - bb: _triton.ir.basic_block) -> triton.language.tensor: - instr = bb.get_first_non_phi() - self.builder.set_insert_point((bb, instr)) - ir_phi = self.builder.create_phi(type.to_ir(self.builder), num_values) - if instr: - self.builder.set_insert_block(bb) - return triton.language.tensor(ir_phi, type) - - # complete a phi node. (TODO: rename this as _complete_phis?) - # Note: since we try to remove tryival phi, the return tensor might not be a phi - def _add_phi_operands(self, name: str, - phi: triton.language.tensor) -> triton.language.tensor: - bb = phi.handle.get_parent() - for pred in bb.get_predecessors(): - v = self._get_tensor(name, pred) - phi.handle.add_incoming(v.handle, pred) - phi = self._try_remove_trivial_phi(phi) - return phi - - def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None: - self.lvalues[(name, bb)] = value - # # TODO: why we need this? - # self.module.set_instr_metadata(name, value.handle) - - def _seal_block(self, bb: _triton.ir.basic_block): - # complete all incomplete phis - if bb in self.incomplete_phis: - for name, phi in self.incomplete_phis[bb].items(): - result = self._add_phi_operands(name, phi) - # it's possible that this phi is trivial - if self._get_tensor(name, bb).handle == phi.handle: - self._set_value(name, bb, result) - del self.incomplete_phis[bb] - self.sealed_blocks.add(bb) - - def _try_remove_trivial_phi(self, phi: triton.language.tensor) -> triton.language.tensor: - unique_handles = {op for op in phi.handle.ops() if op != phi.handle} - if len(unique_handles) != 1: # non-trivial phi - return phi - v = unique_handles.pop() - phi.handle.replace_all_uses_with(v) - phi.handle.erase_from_parent() - # TODO: remove trivial phis recursively - return triton.language.tensor(v, phi.type) + self.local_defs[name] = value def is_triton_tensor(self, value): return isinstance(value, triton.language.tensor) @@ -229,7 +131,6 @@ class CodeGenerator(ast.NodeVisitor): fn = self.builder.create_function(node.name, self.prototype.to_ir(self.builder)) self.module.push_back(fn) entry = fn.add_entry_block() - self._seal_block(entry) arg_values = [] idx = 0 for i, arg_name in enumerate(arg_names): @@ -294,6 +195,7 @@ class CodeGenerator(ast.NodeVisitor): assert len(_names) == 1 names = _names[0] values = self.visit(node.value) + print(f'visit_Assign({names}, {values})') if not isinstance(names, tuple): names = [names] if not isinstance(values, tuple): @@ -367,9 +269,7 @@ class CodeGenerator(ast.NodeVisitor): # then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) # else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None # endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent) - # self._seal_block(then_bb) # if else_bb: - # self._seal_block(else_bb) # self.builder.cond_br(cond.handle, then_bb, else_bb) # else: # self.builder.cond_br(cond.handle, then_bb, endif_bb) @@ -384,17 +284,16 @@ class CodeGenerator(ast.NodeVisitor): # # TODO: last statement is a terminator? # if not is_terminator: # self.builder.br(endif_bb) - # self._seal_block(endif_bb) # self.builder.set_insert_block(endif_bb) - parent_lvalues = self.lvalues.copy() + parent_values = self.lscope.copy() self.visit_compound_statement(node.body) - then_lvalues = self.lvalues.copy() + then_values = self.lvalues.copy() assert node.orelse - self.lvalues = parent_lvalues + self.lscope = parent_values self.visit_compound_statement(node.orelse) - else_lvalues = self.lvalues.copy() + else_values = self.lscope.copy() - self.lvalues = join_if_lvalues(then_lvalues, else_lvalues) + self.lvalues = join_if_lvalues(then_values, else_values) else: if isinstance(cond, triton.language.constexpr): @@ -473,9 +372,6 @@ class CodeGenerator(ast.NodeVisitor): self.visit_compound_statement(node.body) continue_fn() stop_bb = self.builder.get_insertion_block() - self._seal_block(stop_bb) - self._seal_block(loop_bb) - self._seal_block(next_bb) self.builder.set_insert_block(next_bb) for stmt in node.orelse: @@ -510,47 +406,7 @@ class CodeGenerator(ast.NodeVisitor): for stmt in node.orelse: ast.NodeVisitor.generic_visit(self, stmt) return - # # create nodes - # st_target = ast.Name(id=node.target.id, ctx=ast.Store()) - # ld_target = ast.Name(id=node.target.id, ctx=ast.Load()) - # arg_0 = node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0) - # arg_1 = node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0] - # arg_2 = node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1) - # init_node = ast.Assign(targets=[st_target], value=arg_0) - # pos_cond_node = ast.Compare(ld_target, [ast.Lt()], [arg_1]) - # neg_cond_node = ast.Compare(ld_target, [ast.Gt()], [arg_1]) - # pos_step_node = ast.Compare(arg_2, [ast.Gt()], [ast.Num(0)]) - # build_cond = lambda: triton.language.where(self.visit(pos_step_node), - # self.visit(pos_cond_node), - # self.visit(neg_cond_node), - # _builder=self.builder) - # # cond_node = neg_cond_node - # step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) - # # code generation - # current_bb = self.builder.get_insertion_block() - # loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) - # next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) - # def continue_fn(): - # self.visit(step_node) - # cond = build_cond() - # return self.builder.cond_br(cond.handle, loop_bb, next_bb) - - # self.visit(init_node) - # cond = build_cond() - # self.builder.cond_br(cond.handle, loop_bb, next_bb) - # self.builder.set_insert_block(loop_bb) - # self.visit_compound_statement(node.body) - # # TODO: handle case where body breaks control flow - # continue_fn() - # stop_bb = self.builder.get_insertion_block() - # self._seal_block(stop_bb) - # self._seal_block(loop_bb) - # self._seal_block(next_bb) - # self.builder.set_insert_block(next_bb) - - # for stmt in node.orelse: - # ast.NodeVisitor.generic_visit(self, stmt) lb = self.visit(node.iter.args[0] if len(node.iter.args) > 1 else ast.Num(0)) ub = self.visit(node.iter.args[1] if len(node.iter.args) > 1 else node.iter.args[0]) step = self.visit(node.iter.args[2] if len(node.iter.args) > 2 else ast.Num(1)) @@ -558,22 +414,32 @@ class CodeGenerator(ast.NodeVisitor): loop_body = self.builder.create_block() self.builder.set_insertion_point_to_start(loop_body) + liveins = self.lscope.copy() prev_defs = self.local_defs.copy() - self.local_defs = set() + self.local_defs = {} # visit loop body - parent_lvalues = self.lvalues.copy() - self.visit_compound_statement() - loop_lvalues = self.lvalues.copy() + self.visit_compound_statement(node.body) # TODO: update insertion point - # TODO: create scf.forOp - # self.lvalues = join_loop_lvalues(parent_lvalues, loop_lvalues) - # self.make_for_op(parent_lvalues, loop_lvalues, lb, ub, step) - for_op = self.builder.create_for_op(lb, ub, step, [loop_init_args]) + init_args = {} + yields = {} + for name in self.local_defs: + if name in liveins: + assert self.is_triton_tensor(self.local_defs[name]) + assert self.is_triton_tensor(liveins[name]) + if self.local_defs[name].type == liveins[name].type: + init_args[name] = liveins[name] + yields[name] = self.local_defs[name] + # for_op = self.builder.create_for_op(lb, ub, step, [init_args]) + self.lscope = liveins self.local_defs = prev_defs + for stmt in node.orelse: + assert False + ast.NodeVisitor.generic_visit(self, stmt) + def visit_Slice(self, node): lower = self.visit(node.lower) upper = self.visit(node.upper) @@ -1079,17 +945,12 @@ class JITFunction: lscope = generator.lscope.copy() # TODO: clear values other than args - lvalues = generator.lvalues.copy() - # types = generator.module.get_types().copy() generator.gscope = sys.modules[self.fn.__module__].__dict__ generator.lscope = dict() ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=arg_values) generator.gscope = gscope generator.lscope = lscope - generator.lvalues = lvalues - # generator.module.set_types(types) - return ret except Exception as e: node = generator.last_node @@ -1222,9 +1083,6 @@ class JITFunction: constants.update({i: None for i, arg in enumerate(wargs) if arg is None}) arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants] - print(f'wargs: {wargs}') - print(f'constants: {constants}') - print(f'arg_types: {arg_types}') # create IR module context = _triton.ir.context() context.load_triton() diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 53bcc6d3e..0f026b717 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -312,7 +312,7 @@ def minus(input: tl.tensor, input_sca_ty = input.type.scalar if input_sca_ty.is_ptr(): raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") - _0 = tl.tensor(ir.constant.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) + _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) return sub(_0, input, builder) @@ -442,7 +442,7 @@ def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor: - _0 = ir.constant.get_null_value(dtype.to_ir(builder)) + _0 = builder.get_null_value(dtype.to_ir(builder)) ret_ty = tl.block_type(dtype, shape) return tl.tensor(builder.create_splat(_0, shape), ret_ty)