From e381dc72c50cf79b75877e2e1990496959575a0f Mon Sep 17 00:00:00 2001 From: Yan Da Date: Wed, 30 Mar 2022 16:31:03 +0800 Subject: [PATCH] Use mlir::Block to replace MlirBlock --- python/src/triton.cc | 44 ++++++++++++++-------- python/triton/code_gen.py | 68 ++++++++++++++++++---------------- python/triton/language/core.py | 15 ++++---- rewrite-test/test_ir.py | 6 ++- 4 files changed, 76 insertions(+), 57 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index c486ba299..d56170363 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -668,12 +668,12 @@ void init_triton_ir(py::module &&m) { ; py::class_(m, "operation") - .def("add_entry_block", [](MlirOperation &self) -> MlirBlock { + .def("add_entry_block", [](MlirOperation &self) -> mlir::Block { if (auto info = unwrap(self)->getRegisteredInfo()) { if (mlir::TypeID::get() == info->getTypeID()) { auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self)); mlir::Block *entry = funcOp.addEntryBlock(); - return wrap(entry); + return *entry; } throw std::runtime_error("Only FuncOp can call add_entry_block"); } else @@ -684,12 +684,14 @@ void init_triton_ir(py::module &&m) { }) ; - py::class_(m, "value") + py::class_(m, "value") + ; + py::class_(m, "block_arguement") ; - py::class_(m, "block") - .def("arg", [](MlirBlock &self, int index) -> MlirValue { - return wrap(unwrap(self)->getArgument(index)); + py::class_(m, "block") + .def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument { + return self.getArgument(index); }) ; @@ -741,12 +743,16 @@ void init_triton_ir(py::module &&m) { // .def("br", &ir::builder::create_br, ret::reference) // .def("cond_br", &ir::builder::create_cond_br, ret::reference) // .def("ret_void", &ir::builder::create_ret_void, ret::reference) - // // insertion block/point, insert points are represented as (*bb, *instr) - .def("set_insertion_point_to_start", [](mlir::OpBuilder &self, MlirBlock &block) -> void{ - self.setInsertionPointToStart(unwrap(block)); + // insertion block/point + .def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void { + self.setInsertionPointToStart(&block); }) - // .def("get_insert_block", &ir::builder::get_insert_block, ret::reference) - // .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point) + .def("set_insertion_point_to_end", [](mlir::OpBuilder &self, mlir::Block &block) { + self.setInsertionPointToEnd(&block); + }) + .def("get_insertion_block", [](mlir::OpBuilder &self) -> mlir::Block & { + return *self.getInsertionBlock(); + }, ret::reference) // .def("get_insert_point", [](ir::builder *self) { // ir::basic_block *bb = self->get_insert_block(); // ir::basic_block::iterator it = self->get_insert_point(); @@ -768,11 +774,11 @@ void init_triton_ir(py::module &&m) { // Use arith.ConstantOp to create constants // // Constants // .def("get_int1", &ir::builder::get_int1, ret::reference) - .def("get_int32", [](mlir::OpBuilder &self, int64_t v) -> MlirValue { + .def("get_int32", [](mlir::OpBuilder &self, int64_t v) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( + return mlir::Value(self.create( loc, v, self.getI32Type() - ))); + )); }) // .def("get_uint32", &ir::builder::get_int32, ret::reference) // .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference) @@ -818,9 +824,15 @@ void init_triton_ir(py::module &&m) { .def("get_double_ty", [](mlir::OpBuilder &self) -> MlirType { return wrap(self.getF64Type()); }) - .def("get_ptr_ty", [](mlir::OpBuilder &self, MlirType &type) -> MlirType { + .def("get_ptr_ty", [](mlir::OpBuilder &self, MlirType &type, int addrSpace) -> MlirType { return wrap( - mlir::triton::PointerType::get(unwrap(type)) + mlir::triton::PointerType::get(unwrap(type), addrSpace) + ); + }) + .def("get_block_ty", [](mlir::OpBuilder &self, MlirType &elementType, + std::vector &shape) -> MlirType { + return wrap( + mlir::RankedTensorType::get(shape, unwrap(elementType)) ); }) .def("get_function_ty", [](mlir::OpBuilder &self, diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index e61eef451..1e743ffb0 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -30,7 +30,6 @@ class CodeGenerator(ast.NodeVisitor): self.prototype = prototype self.gscope = gscope self.lscope = dict() - self.is_arg_lscope = dict() # name => is_arg: {str: bool} self.attributes = attributes self.constants = constants self.kwargs = kwargs @@ -69,33 +68,32 @@ class CodeGenerator(ast.NodeVisitor): ret = self.builtins[name] else: raise ValueError(f'{name} is not defined') - if self.is_triton_tensor(ret) and not self.is_arg_lscope[name]: + if self.is_triton_tensor(ret): return self._get_tensor(name) return ret def set_value(self, name: str, - value: Union[triton.language.tensor, triton.language.constexpr], - is_arg: bool = False) -> None: + value: Union[triton.language.tensor, triton.language.constexpr]) -> None: ''' This function: called by visit_Assign() & visit_FuncDef() to store left value (lvalue) 1. record local defined name (FIXME: should consider control flow) 2. store tensor in self.lvalue ''' self.lscope[name] = value - # if this value is an argument, we don't need to create phis for it - self.is_arg_lscope[name] = is_arg - if isinstance(value, triton.language.tensor) and not is_arg: - self._set_value(name, self.builder.get_insert_block(), value) + if isinstance(value, triton.language.tensor): + self._set_value(name, self.builder.get_insertion_block(), value) # # SSA-construction # def _get_tensor(self, name: str, bb: Optional[_triton.ir.basic_block] = None) -> triton.language.tensor: if not bb: - bb = self.builder.get_insert_block() + bb = self.builder.get_insertion_block() # local value numbering if (name, bb) in self.lvalues: return self.lvalues[(name, bb)] + print(self.lvalues) + assert False, f'Cannot find {name} in {bb}' # global value numbering saved_insert_point = self.builder.get_insert_point() result = self._get_tensor_recursive(name, bb) @@ -115,8 +113,9 @@ class CodeGenerator(ast.NodeVisitor): elif len(preds) == 1: # one predecessor: no phi needed, try get value from pred result = self._get_tensor(name, preds[0]) + elif len(preds) == 0: + result = self._get_tensor(name, None) else: # multiple preds - assert len(preds) > 1, f'{name} is an undefined name (cannot find in the entry block)' phi = self._make_phi(type, len(preds), bb) self._set_value(name, bb, phi) result = self._add_phi_operands(name, phi) @@ -148,8 +147,8 @@ class CodeGenerator(ast.NodeVisitor): def _set_value(self, name: str, bb: _triton.ir.basic_block, value: triton.language.tensor) -> None: self.lvalues[(name, bb)] = value - # TODO: why we need this? - self.module.set_instr_metadata(name, value.handle) + # # TODO: why we need this? + # self.module.set_instr_metadata(name, value.handle) def _seal_block(self, bb: _triton.ir.basic_block): # complete all incomplete phis @@ -220,7 +219,8 @@ class CodeGenerator(ast.NodeVisitor): if inline: pass else: - fn = self.module.get_or_insert_function(node.name, self.prototype.to_ir(self.builder)) + fn = self.builder.create_function(node.name, self.prototype.to_ir(self.builder)) + self.module.push_back(fn) arg_values = [] idx = 0 for i, arg_name in enumerate(arg_names): @@ -230,25 +230,27 @@ class CodeGenerator(ast.NodeVisitor): cst = triton.language.constexpr(self.constants[i]) arg_values.append(cst) else: - if i in self.attributes: - is_ptr = fn.args[idx].type.is_ptr() - attr = 'aligned' if is_ptr else 'multiple_of' - attr = getattr(_triton.ir.attribute_kind, attr) - attr = _triton.ir.attribute(attr, self.attributes[i]) - fn.add_attr(idx + 1, attr) - fn.args[idx].name = arg_name - arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx])) - idx += 1 + pass + # TODO: ... + # if i in self.attributes: + # is_ptr = fn.args[idx].type.is_ptr() + # attr = 'aligned' if is_ptr else 'multiple_of' + # attr = getattr(_triton.ir.attribute_kind, attr) + # attr = _triton.ir.attribute(attr, self.attributes[i]) + # fn.add_attr(idx + 1, attr) + # fn.args[idx].name = arg_name + # arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx])) + # idx += 1 for arg_name, arg_value in zip(arg_names, arg_values): - self.set_value(arg_name, arg_value, is_arg=True) + self.set_value(arg_name, arg_value) if inline: self.visit_compound_statement(node.body) return self.last_ret else: - entry = _triton.ir.basic_block.create(self.builder.context, "entry", fn) + entry = fn.add_entry_block() self._seal_block(entry) - self.builder.set_insert_block(entry) + self.builder.set_insertion_point_to_start(entry) # visit function body self.visit_compound_statement(node.body) # finalize function @@ -358,7 +360,7 @@ class CodeGenerator(ast.NodeVisitor): cond = self.visit(node.test) if isinstance(cond, triton.language.tensor): cond = cond.to(triton.language.int1, _builder=self.builder) - current_bb = self.builder.get_insert_block() + current_bb = self.builder.get_insertion_block() then_bb = _triton.ir.basic_block.create(self.builder.context, "then", current_bb.parent) else_bb = _triton.ir.basic_block.create(self.builder.context, "else", current_bb.parent) if node.orelse else None endif_bb = _triton.ir.basic_block.create(self.builder.context, "endif", current_bb.parent) @@ -445,7 +447,7 @@ class CodeGenerator(ast.NodeVisitor): return getattr(op, fn)() def visit_While(self, node): - current_bb = self.builder.get_insert_block() + current_bb = self.builder.get_insertion_block() loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) @@ -457,7 +459,7 @@ class CodeGenerator(ast.NodeVisitor): self.builder.set_insert_block(loop_bb) self.visit_compound_statement(node.body) continue_fn() - stop_bb = self.builder.get_insert_block() + stop_bb = self.builder.get_insertion_block() self._seal_block(stop_bb) self._seal_block(loop_bb) self._seal_block(next_bb) @@ -512,7 +514,7 @@ class CodeGenerator(ast.NodeVisitor): # cond_node = neg_cond_node step_node = ast.AugAssign(target=st_target, op=ast.Add(), value=arg_2) # code generation - current_bb = self.builder.get_insert_block() + current_bb = self.builder.get_insertion_block() loop_bb = _triton.ir.basic_block.create(self.builder.context, "loop", current_bb.parent) next_bb = _triton.ir.basic_block.create(self.builder.context, "postloop", current_bb.parent) @@ -528,7 +530,7 @@ class CodeGenerator(ast.NodeVisitor): self.visit_compound_statement(node.body) # TODO: handle case where body breaks control flow continue_fn() - stop_bb = self.builder.get_insert_block() + stop_bb = self.builder.get_insertion_block() self._seal_block(stop_bb) self._seal_block(loop_bb) self._seal_block(next_bb) @@ -845,10 +847,11 @@ class Kernel: # create IR module context = _triton.ir.context() + context.load_triton() # get just-in-time proto-type of kernel arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types] ret_type = triton.language.void - prototype = triton.language.function_type(ret_type, arg_types) + prototype = triton.language.function_type([ret_type], arg_types) # generate Triton-IR # export symbols visible from self into code-generator object gscope = self.__globals__ @@ -1179,10 +1182,11 @@ class JITFunction: def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages): # create IR module context = _triton.ir.context() + context.load_triton() # get just-in-time proto-type of kernel arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types] ret_type = triton.language.void - prototype = triton.language.function_type(ret_type, arg_types) + prototype = triton.language.function_type([ret_type], arg_types) # generate Triton-IR # export symbols visible from self into code-generator object gscope = self.__globals__ diff --git a/python/triton/language/core.py b/python/triton/language/core.py index ea1b24940..c07c44cfd 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -209,7 +209,7 @@ class pointer_type(dtype): self.name = self.__str__() def to_ir(self, builder: ir.builder) -> ir.pointer_type: - return ir.type.make_ptr(self.element_ty.to_ir(builder), 1) + return builder.get_ptr_ty(self.element_ty.to_ir(builder), 1) def __str__(self): return f'pointer<{self.element_ty}>' @@ -247,7 +247,7 @@ class block_type(dtype): self.name = self.__str__() def to_ir(self, builder: ir.builder) -> ir.block_type: - return ir.type.make_block(self.element_ty.to_ir(builder), self.shape) + return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) def __str__(self): return f'<{self.shape}, {self.element_ty}>' @@ -275,8 +275,8 @@ class block_type(dtype): class function_type(dtype): - def __init__(self, ret_type: dtype, param_types: List[dtype]) -> None: - self.ret_type = ret_type + def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: + self.ret_types = ret_types self.param_types = param_types def __str__(self): @@ -284,7 +284,8 @@ class function_type(dtype): def to_ir(self, builder: ir.builder): ir_param_types = [ty.to_ir(builder) for ty in self.param_types] - return ir.type.make_function(self.ret_type.to_ir(builder), ir_param_types) + ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] + return builder.get_function_ty(ir_param_types, ret_types) # scalar types @@ -425,8 +426,8 @@ class tensor: self.handle = handle # Block shape self.shape = (1, ) - if self.handle.type.is_block(): - self.shape = self.handle.type.shape + if type.is_block(): + self.shape = type.shape self.numel = 1 for s in self.shape: self.numel *= s diff --git a/rewrite-test/test_ir.py b/rewrite-test/test_ir.py index 2008f9c4e..a3077f4d5 100644 --- a/rewrite-test/test_ir.py +++ b/rewrite-test/test_ir.py @@ -17,11 +17,13 @@ i64_ty = builder.get_int64_ty() f16_ty = builder.get_half_ty() -f16_ptr_ty = builder.get_ptr_ty(f16_ty) +f16_ptr_ty = builder.get_ptr_ty(f16_ty, 1) func_ty = builder.get_function_ty([f16_ptr_ty, f16_ptr_ty, f16_ptr_ty], []) func = builder.create_function('foo', func_ty) +module.push_back(func) + # ... entry = func.add_entry_block() builder.set_insertion_point_to_start(entry) @@ -51,5 +53,5 @@ builder.create_store(c_ptrs, c) # func.dump() -module.push_back(func) + module.dump()