From 2041b67fbf83a00bc931b1b30fcd6c6752d1471f Mon Sep 17 00:00:00 2001 From: Yan Da Date: Wed, 30 Mar 2022 20:21:47 +0800 Subject: [PATCH] Now vecadd works --- include/triton/ir/TritonOps.td | 4 +- python/src/triton.cc | 539 ++++++++++++----------------- python/triton/code_gen.py | 120 ++++--- python/triton/language/core.py | 2 +- python/triton/language/semantic.py | 6 +- 5 files changed, 285 insertions(+), 386 deletions(-) diff --git a/include/triton/ir/TritonOps.td b/include/triton/ir/TritonOps.td index 60b768e32..5c04398a8 100644 --- a/include/triton/ir/TritonOps.td +++ b/include/triton/ir/TritonOps.td @@ -114,7 +114,9 @@ def TT_EvictionPolicyAttr : I32EnumAttr< def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> { let summary = "load"; - let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other); + let arguments = (ins TT_PtrTensor:$ptr, BoolLike:$mask, TT_Type:$other, + TT_CacheModifierAttr:$cache, TT_EvictionPolicyAttr:$evict, + BoolAttr:$isVolatile); let results = (outs TT_Type:$result); diff --git a/python/src/triton.cc b/python/src/triton.cc index d56170363..201f122e1 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -641,21 +641,15 @@ void init_triton_ir(py::module &&m) { // // py::class_(m, "undef") // // .def("get", &ir::undef_value::get, ret::reference); - py::class_(m, "module") + py::class_(m, "module") // .def("set_attr") - .def("dump", [](MlirModule &self) -> void { - unwrap(self).dump(); + .def("dump", [](mlir::ModuleOp &self) -> void { + self.dump(); }) - .def("push_back", [](MlirModule &self, MlirOperation &funcOperation) { - if (auto info = unwrap(funcOperation)->getRegisteredInfo()) { - if (mlir::TypeID::get() == info->getTypeID()) { - auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(funcOperation)); - unwrap(self).push_back(funcOp); - } else - throw std::runtime_error("Only FuncOp can call push_back"); - } else - throw std::runtime_error("Unknown error"); + .def("push_back", [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void { + self.push_back(funcOp); }) + .def("get_context", &mlir::ModuleOp::getContext) ; py::class_(m, "type") @@ -667,23 +661,6 @@ void init_triton_ir(py::module &&m) { }) ; - py::class_(m, "operation") - .def("add_entry_block", [](MlirOperation &self) -> mlir::Block { - if (auto info = unwrap(self)->getRegisteredInfo()) { - if (mlir::TypeID::get() == info->getTypeID()) { - auto funcOp = mlir::FuncOp::getFromOpaquePointer(unwrap(self)); - mlir::Block *entry = funcOp.addEntryBlock(); - return *entry; - } - throw std::runtime_error("Only FuncOp can call add_entry_block"); - } else - throw std::runtime_error("Unknown error"); - }) // this should be automatic? - .def("dump", [](MlirOperation &self) -> void { - unwrap(self)->dump(); - }) - ; - py::class_(m, "value") ; py::class_(m, "block_arguement") @@ -693,6 +670,7 @@ void init_triton_ir(py::module &&m) { .def("arg", [](mlir::Block &self, int index) -> mlir::BlockArgument { return self.getArgument(index); }) + .def("dump", &mlir::Block::dump) ; // py::class_(m, "module") @@ -720,29 +698,34 @@ void init_triton_ir(py::module &&m) { // py::class_(m, "attribute"); // // .def(py::init()); - // py::class_(m, "function") - // .def_property_readonly("args", &ir::function::args) - // .def_property_readonly("attrs", &ir::function::attrs) - // .def("add_attr", &ir::function::add_attr); - - // // // We don't need to expose mlir::Block (?) 
- // // py::class_(m, "basic_block") - // // // .def("create", &ir::basic_block::create, ret::reference) - // // .def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference) - // // .def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference); + py::class_(m, "function") + // .def_property_readonly("args", &ir::function::args) + // .def_property_readonly("attrs", &ir::function::attrs) + // .def("add_attr", &ir::function::add_attr); + .def("args", [](mlir::FuncOp &self, unsigned idx) -> mlir::BlockArgument { + return self.getArgument(idx); + }) + .def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* { + return self.addEntryBlock(); + }, ret::reference) + .def("dump", [](mlir::FuncOp &self) { self.dump(); }) + ; py::class_(m, "builder", py::dynamic_attr()) .def(py::init()) // // getters // .def_property_readonly("context", &ir::builder::get_context, ret::reference); - .def("create_module", [](mlir::OpBuilder &self) -> MlirModule { + .def("create_module", [](mlir::OpBuilder &self) -> mlir::ModuleOp { auto loc = self.getUnknownLoc(); - return wrap(self.create(loc)); + return self.create(loc); }) // // control flow // .def("br", &ir::builder::create_br, ret::reference) // .def("cond_br", &ir::builder::create_cond_br, ret::reference) - // .def("ret_void", &ir::builder::create_ret_void, ret::reference) + .def("ret_void", [](mlir::OpBuilder &self) { + auto loc = self.getUnknownLoc(); + self.create(loc); + }, ret::reference) // insertion block/point .def("set_insertion_point_to_start", [](mlir::OpBuilder &self, mlir::Block &block) -> void { self.setInsertionPointToStart(&block); @@ -750,8 +733,8 @@ void init_triton_ir(py::module &&m) { .def("set_insertion_point_to_end", [](mlir::OpBuilder &self, mlir::Block &block) { self.setInsertionPointToEnd(&block); }) - .def("get_insertion_block", [](mlir::OpBuilder &self) -> mlir::Block & { - return *self.getInsertionBlock(); + .def("get_insertion_block", [](mlir::OpBuilder &self) -> mlir::Block* { + return self.getInsertionBlock(); }, ret::reference) // .def("get_insert_point", [](ir::builder *self) { // ir::basic_block *bb = self->get_insert_block(); @@ -784,8 +767,10 @@ void init_triton_ir(py::module &&m) { // .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference) // .def("get_uint64", &ir::builder::get_int64, ret::reference) // .def("get_float16", &ir::builder::get_float16, ret::reference) - // .def("get_float32", &ir::builder::get_float32, ret::reference) - // .def("get_range", &ir::builder::get_range, ret::reference) + .def("get_float32", [](mlir::OpBuilder &self, float v) -> mlir::Value { + auto loc = self.getUnknownLoc(); + return self.create(loc, self.getF32FloatAttr(v)); + }) // Types .def("get_void_ty", [](mlir::OpBuilder &self) ->MlirType { @@ -846,22 +831,22 @@ void init_triton_ir(py::module &&m) { }) // Ops - .def("create_function", [](mlir::OpBuilder &self, std::string name, MlirType funcType) -> MlirOperation { + .def("create_function", [](mlir::OpBuilder &self, std::string name, MlirType funcType) -> mlir::FuncOp { // TODO: loc auto loc = self.getUnknownLoc(); if (auto funcTy = unwrap(funcType).dyn_cast()) { - return wrap(self.create(loc, name, funcTy)); + return self.create(loc, name, funcTy); } throw std::runtime_error("invalid function type"); }) // // Structured control flow - // .def("create_for", [](mlir::OpBuilder &self, MlirValue &lb, MlirValue &ub, - // MlirValue &step, std::vector &initArgs) -> MlirOperation { + // .def("create_for", 
[](mlir::OpBuilder &self, mlir::Value &lb, mlir::Value &ub, + // mlir::Value &step, std::vector &initArgs) -> MlirOperation { // auto loc = self.getUnknownLoc(); // return wrap(self.create( // loc, unwrap(lb), unwrap(ub), unwrap(step)).getOperation()); // }) - // .def("create_if", [](mlir::OpBuilder &self, MlirValue &condition) -> MlirOperation { + // .def("create_if", [](mlir::OpBuilder &self, mlir::Value &condition) -> MlirOperation { // auto loc = self.getUnknownLoc(); // return wrap(self.create(loc, unwrap(condition)).getOperation()); // }) @@ -872,428 +857,334 @@ void init_triton_ir(py::module &&m) { // // .def("create_while") // miscellious - .def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> MlirValue { + .def("create_make_range", [](mlir::OpBuilder &self, int start, int end) -> mlir::Value { auto loc = self.getUnknownLoc(); auto retType = mlir::RankedTensorType::get({end-start}, self.getI32Type()); - return wrap( - mlir::Value(self.create(loc, retType, start, end)) - ); + return self.create(loc, retType, start, end); }) - .def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> MlirValue { + .def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap( - mlir::Value(self.create(loc, self.getI32Type(), axis)) - ); + return self.create(loc, self.getI32Type(), axis); }) // Cast instructions - .def("create_bitcast", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { + .def("create_bitcast", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value( - self.create(loc, unwrap(dstType), unwrap(src)) - )); + return self.create(loc, unwrap(dstType), src); }) // .def("create_cast", &ir::builder::create_cast) // .def("create_ptr_to_int", &ir::builder::create_ptr_to_int) - .def("create_si_to_fp", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { + .def("create_si_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value( - self.create(loc, unwrap(dstType), unwrap(src)) - )); + return self.create(loc, unwrap(dstType), src); }) - .def("create_ui_to_fp", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { + .def("create_ui_to_fp", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value( - self.create(loc, unwrap(dstType), unwrap(src)) - )); + return self.create(loc, unwrap(dstType), src); }) - .def("create_fp_to_si", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { + .def("create_fp_to_si", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value( - self.create(loc, unwrap(dstType), unwrap(src)) - )); + return self.create(loc, unwrap(dstType), src); }) - .def("create_fp_to_ui", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { + .def("create_fp_to_ui", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value( - self.create(loc, unwrap(dstType), unwrap(src)) - )); + return self.create(loc, unwrap(dstType), src); }) - .def("create_fp_ext", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { + .def("create_fp_ext", [](mlir::OpBuilder 
&self, mlir::Value &src, MlirType &dstType) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value( - self.create(loc, unwrap(dstType), unwrap(src)) - )); + return self.create(loc, unwrap(dstType), src); }) - .def("create_fp_trunc", [](mlir::OpBuilder &self, MlirValue &src, MlirType &dstType) -> MlirValue { + .def("create_fp_trunc", [](mlir::OpBuilder &self, mlir::Value &src, MlirType &dstType) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value( - self.create(loc, unwrap(dstType), unwrap(src)) - )); + return self.create(loc, unwrap(dstType), src); }) // .def("create_int_cast", &ir::builder::create_int_cast) // .def("create_downcast", &ir::builder::create_downcast) - .def("create_fmul", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_fmul", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value( - self.create(loc, unwrap(lhs), unwrap(rhs)) - )); + return self.create(loc, lhs, rhs); }) - .def("create_fdiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_fdiv", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value( - self.create(loc, unwrap(lhs), unwrap(rhs)) - )); + return self.create(loc, lhs, rhs); }) - .def("create_frem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_frem", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value( - self.create(loc, unwrap(lhs), unwrap(rhs)) - )); + return self.create(loc, lhs, rhs); }) - .def("create_fadd", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_fadd", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value( - self.create(loc, unwrap(lhs), unwrap(rhs)) - )); + return self.create(loc, lhs, rhs); }) - .def("create_fsub", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_fsub", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value( - self.create(loc, unwrap(lhs), unwrap(rhs)) - )); + return self.create(loc, lhs, rhs); }) - .def("create_mul", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_mul", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - // Check lhs & rhs have single result (?) 
- return wrap( - mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) - ); + return self.create(loc, lhs, rhs); }) - .def("create_sdiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_sdiv", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap( - mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) - ); + return self.create(loc, lhs, rhs); }) - .def("create_udiv", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_udiv", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap( - mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) - ); + return self.create(loc, lhs, rhs); }) - .def("create_srem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_srem", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap( - mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) - ); + return self.create(loc, lhs, rhs); }) - .def("create_urem", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_urem", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap( - mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) - ); + return self.create(loc, lhs, rhs); }) - .def("create_add", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_add", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap( - mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) - ); + return self.create(loc, lhs, rhs); }) - .def("create_sub", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_sub", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap( - mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) - ); + return mlir::Value(self.create(loc, lhs, rhs)); }) - .def("create_shl", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_shl", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap( - mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) - ); + return mlir::Value(self.create(loc, lhs, rhs)); }) - .def("create_lshr", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_lshr", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap( - mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) - ); + return mlir::Value(self.create(loc, lhs, rhs)); }) - .def("create_ashr", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_ashr", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap( - mlir::Value(self.create(loc, unwrap(lhs), unwrap(rhs))) - ); + return mlir::Value(self.create(loc, lhs, rhs)); }) // GEP - .def("create_gep", [](mlir::OpBuilder &self, MlirValue &ptr, MlirValue &offset) -> MlirValue { + .def("create_gep", [](mlir::OpBuilder &self, mlir::Value &ptr, mlir::Value &offset) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap( - 
mlir::Value(self.create(loc, unwrap(ptr).getType(), unwrap(ptr), unwrap(offset))) - ); + return self.create(loc, ptr.getType(), ptr, offset); }) // Comparison (int) - .def("create_icmpSLE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_icmpSLE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpIPredicate::sle, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpIPredicate::sle, lhs, rhs); }) - .def("create_icmpSLT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_icmpSLT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpIPredicate::slt, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpIPredicate::slt, lhs, rhs); }) - .def("create_icmpSGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_icmpSGE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpIPredicate::sge, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpIPredicate::sge, lhs, rhs); }) - .def("create_icmpSGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_icmpSGT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpIPredicate::sgt, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpIPredicate::sgt, lhs, rhs); }) - .def("create_icmpULE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_icmpULE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpIPredicate::ule, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpIPredicate::ule, lhs, rhs); }) - .def("create_icmpULT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_icmpULT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpIPredicate::ult, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpIPredicate::ult, lhs, rhs); }) - .def("create_icmpUGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_icmpUGE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpIPredicate::uge, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpIPredicate::uge, lhs, rhs); }) - .def("create_icmpUGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_icmpUGT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpIPredicate::ugt, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpIPredicate::ugt, lhs, rhs); 
}) - .def("create_icmpEQ", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_icmpEQ", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpIPredicate::eq, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpIPredicate::eq, lhs, rhs); }) - .def("create_icmpNE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_icmpNE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpIPredicate::ne, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpIPredicate::ne, lhs, rhs); }) // Comparison (float) - .def("create_fcmpOLT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_fcmpOLT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpFPredicate::OLT, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpFPredicate::OLT, lhs, rhs); }) - .def("create_fcmpOGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_fcmpOGT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpFPredicate::OGT, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpFPredicate::OGT, lhs, rhs); }) - .def("create_fcmpOLE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_fcmpOLE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpFPredicate::OLE, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpFPredicate::OLE, lhs, rhs); }) - .def("create_fcmpOGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_fcmpOGE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpFPredicate::OGE, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpFPredicate::OGE, lhs, rhs); }) - .def("create_fcmpOEQ", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_fcmpOEQ", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpFPredicate::OEQ, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpFPredicate::OEQ, lhs, rhs); }) - .def("create_fcmpONE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_fcmpONE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpFPredicate::ONE, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpFPredicate::ONE, lhs, rhs); }) - .def("create_fcmpULT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_fcmpULT", [](mlir::OpBuilder 
&self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpFPredicate::ULT, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpFPredicate::ULT, lhs, rhs); }) - .def("create_fcmpUGT", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_fcmpUGT", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpFPredicate::UGT, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpFPredicate::UGT, lhs, rhs); }) - .def("create_fcmpULE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_fcmpULE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpFPredicate::ULE, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpFPredicate::ULE, lhs, rhs); }) - .def("create_fcmpUGE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_fcmpUGE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpFPredicate::UGE, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpFPredicate::UGE, lhs, rhs); }) - .def("create_fcmpUEQ", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_fcmpUEQ", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpFPredicate::UEQ, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpFPredicate::UEQ, lhs, rhs); }) - .def("create_fcmpUNE", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_fcmpUNE", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, mlir::arith::CmpFPredicate::UNE, - unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::arith::CmpFPredicate::UNE, lhs, rhs); }) // // Logical - .def("create_and", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_and", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, unwrap(lhs), unwrap(rhs) - ))); + return self.create(loc, lhs, rhs); }) - .def("create_xor", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_xor", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, unwrap(lhs), unwrap(rhs) - ))); + return self.create(loc, lhs, rhs); }) - .def("create_or", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_or", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, unwrap(lhs), unwrap(rhs) - ))); + return self.create(loc, lhs, rhs); }) // // Input/Output - .def("create_load", [](mlir::OpBuilder &self, MlirValue &ptrs) -> MlirValue { + 
.def("create_load", [](mlir::OpBuilder &self, mlir::Value &ptrs) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value( - self.create(loc, unwrap(ptrs)) - )); + return self.create(loc, ptrs); }) - .def("create_store", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &value) -> void { + .def("create_store", [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &value) -> void { auto loc = self.getUnknownLoc(); - self.create(loc, unwrap(ptrs), unwrap(value)); + self.create(loc, ptrs, value); }) - .def("create_masked_load", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &mask, MlirValue &other) -> MlirValue { + .def("create_masked_load", [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &mask, mlir::Value &other, + mlir::triton::CacheModifier cacheModifier, + mlir::triton::EvictionPolicy evictionPolicy, + bool isVolatile) -> mlir::Value { auto loc = self.getUnknownLoc(); - auto ptrType = unwrap(ptrs).getType().dyn_cast(); + auto ptrType = ptrs.getType().dyn_cast(); std::vector shape = ptrType.getShape(); mlir::Type elementType = ptrType.getElementType().dyn_cast().getPointeeType(); - return wrap(mlir::Value(self.create( - loc, mlir::RankedTensorType::get(shape, elementType), unwrap(ptrs), unwrap(mask), unwrap(other)) - )); + return self.create( + loc, mlir::RankedTensorType::get(shape, elementType), ptrs, mask, other, + cacheModifier, evictionPolicy, isVolatile); }) - .def("create_masked_store", [](mlir::OpBuilder &self, MlirValue &ptrs, MlirValue &val, MlirValue &mask) -> void { + .def("create_masked_store", [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val, mlir::Value &mask) -> void { auto loc = self.getUnknownLoc(); - self.create(loc, unwrap(ptrs), unwrap(val), unwrap(mask)); + self.create(loc, ptrs, val, mask); }) // Block instruction - .def("create_reshape", [](mlir::OpBuilder &self, MlirValue &arg, std::vector &shape) -> MlirValue { + .def("create_reshape", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector &shape) -> mlir::Value { auto loc = self.getUnknownLoc(); - auto argType = unwrap(arg).getType().dyn_cast(); - return wrap(mlir::Value(self.create( - loc, mlir::RankedTensorType::get(shape, argType), unwrap(arg), self.getI64ArrayAttr(shape) - ))); + auto argType = arg.getType().dyn_cast().getElementType(); + return self.create( + loc, mlir::RankedTensorType::get(shape, argType), arg, self.getI64ArrayAttr(shape) + ); }) - .def("create_cat", [](mlir::OpBuilder &self, MlirValue &lhs, MlirValue &rhs) -> MlirValue { + .def("create_cat", [](mlir::OpBuilder &self, mlir::Value &lhs, mlir::Value &rhs) -> mlir::Value { auto loc = self.getUnknownLoc(); - auto lhsType = unwrap(lhs).getType().dyn_cast(); - auto rhsType = unwrap(rhs).getType().dyn_cast(); + auto lhsType = lhs.getType().dyn_cast(); + auto rhsType = rhs.getType().dyn_cast(); if (!(lhsType.getShape().size() == 1 && rhsType.getShape().size() == 1)) throw std::runtime_error("shape not supported by cat. 
Expecting rank-1 inputs"); std::vector shape {lhsType.getShape()[0] + rhsType.getShape()[0]}; - return wrap(mlir::Value(self.create( - loc, mlir::RankedTensorType::get(shape, lhsType.getElementType()), unwrap(lhs), unwrap(rhs) - ))); + return self.create( + loc, mlir::RankedTensorType::get(shape, lhsType.getElementType()), lhs, rhs + ); }) - .def("create_broadcast", [](mlir::OpBuilder &self, MlirValue &arg, std::vector &shape) -> MlirValue { + .def("create_broadcast", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector &shape) -> mlir::Value { auto loc = self.getUnknownLoc(); - auto argType = unwrap(arg).getType(); - return wrap(mlir::Value(self.create( - loc, mlir::RankedTensorType::get(shape, argType), unwrap(arg) - ))); + // TODO: should be scalar type here + auto argType = arg.getType(); + return self.create( + loc, mlir::RankedTensorType::get(shape, argType), arg + ); + }) + .def("create_splat", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector &shape) -> mlir::Value { + auto loc = self.getUnknownLoc(); + auto argType = arg.getType(); + return self.create( + loc, mlir::RankedTensorType::get(shape, argType), arg + ); }) // // atomic - .def("create_atomic_cas", [](mlir::OpBuilder &self, MlirValue &ptr, - MlirValue &cmp, MlirValue &val) -> MlirValue { + .def("create_atomic_cas", [](mlir::OpBuilder &self, mlir::Value &ptr, + mlir::Value &cmp, mlir::Value &val) -> mlir::Value { auto loc = self.getUnknownLoc(); - auto ptrType = unwrap(ptr).getType().dyn_cast(); + auto ptrType = ptr.getType().dyn_cast(); mlir::Type dstType = ptrType.getPointeeType(); - return wrap(mlir::Value(self.create( - loc, dstType, unwrap(ptr), unwrap(cmp), unwrap(val) - ))); + return self.create( + loc, dstType, ptr, cmp, val + ); }) .def("create_atomic_rmw", [](mlir::OpBuilder &self, mlir::triton::RMWOp rmwOp, - MlirValue &ptr, MlirValue &val, MlirValue &mask) -> MlirValue { + mlir::Value &ptr, mlir::Value &val, mlir::Value &mask) -> mlir::Value { auto loc = self.getUnknownLoc(); - auto ptrType = unwrap(ptr).getType().dyn_cast(); + auto ptrType = ptr.getType().dyn_cast(); mlir::Type dstType = ptrType.getPointeeType(); - return wrap(mlir::Value(self.create( - loc, dstType, rmwOp, unwrap(ptr), unwrap(val), unwrap(mask) - ))); + return self.create( + loc, dstType, rmwOp, ptr, val, mask + ); }) // Built-in instruction - .def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> MlirValue { + .def("create_get_program_id", [](mlir::OpBuilder &self, int axis) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( + return self.create( loc, self.getI32Type(), self.getI32IntegerAttr(axis) - ))); + ); }) - .def("create_get_num_programs", [](mlir::OpBuilder &self, int axis) -> MlirValue { + .def("create_get_num_programs", [](mlir::OpBuilder &self, int axis) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( + return self.create( loc, self.getI32Type(), self.getI32IntegerAttr(axis) - ))); + ); }) - .def("create_dot", [](mlir::OpBuilder &self, MlirValue &a, MlirValue &b, MlirValue &c) -> MlirValue { + .def("create_dot", [](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b, mlir::Value &c) -> mlir::Value { auto loc = self.getUnknownLoc(); - return wrap(mlir::Value(self.create( - loc, unwrap(c).getType(), unwrap(a), unwrap(b), unwrap(c) - ))); + return self.create(loc, c.getType(), a, b, c); }) // .def("create_exp", &ir::builder::create_exp, ret::reference) // .def("create_cos", &ir::builder::create_cos, ret::reference) diff --git 
a/python/triton/code_gen.py b/python/triton/code_gen.py index 1e743ffb0..06b19ce85 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -67,9 +67,10 @@ class CodeGenerator(ast.NodeVisitor): elif name in self.builtins: ret = self.builtins[name] else: + print(self.lscope) raise ValueError(f'{name} is not defined') if self.is_triton_tensor(ret): - return self._get_tensor(name) + return self._get_tensor(name, self.builder.get_insertion_block()) return ret def set_value(self, name: str, @@ -86,12 +87,15 @@ class CodeGenerator(ast.NodeVisitor): # # SSA-construction # - def _get_tensor(self, name: str, bb: Optional[_triton.ir.basic_block] = None) -> triton.language.tensor: + def _get_tensor(self, name: str, bb: _triton.ir.basic_block) -> triton.language.tensor: if not bb: bb = self.builder.get_insertion_block() # local value numbering if (name, bb) in self.lvalues: return self.lvalues[(name, bb)] + # param. FIXME: should delete this + if (name, None) in self.lvalues: + return self.lvalues[(name, None)] print(self.lvalues) assert False, f'Cannot find {name} in {bb}' # global value numbering @@ -217,10 +221,15 @@ class CodeGenerator(ast.NodeVisitor): self.lscope[kwarg_names] = self.kwargs # initialize function if inline: - pass + for arg_name, arg_value in zip(arg_names, arg_values): + self.set_value(arg_name, arg_value) + self.visit_compound_statement(node.body) + return self.last_ret else: fn = self.builder.create_function(node.name, self.prototype.to_ir(self.builder)) self.module.push_back(fn) + entry = fn.add_entry_block() + self._seal_block(entry) arg_values = [] idx = 0 for i, arg_name in enumerate(arg_names): @@ -239,17 +248,11 @@ class CodeGenerator(ast.NodeVisitor): # attr = _triton.ir.attribute(attr, self.attributes[i]) # fn.add_attr(idx + 1, attr) # fn.args[idx].name = arg_name - # arg_values.append(triton.language.tensor(fn.args[idx], self.prototype.param_types[idx])) - # idx += 1 + arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx])) + idx += 1 - for arg_name, arg_value in zip(arg_names, arg_values): - self.set_value(arg_name, arg_value) - if inline: - self.visit_compound_statement(node.body) - return self.last_ret - else: - entry = fn.add_entry_block() - self._seal_block(entry) + for arg_name, arg_value in zip(arg_names, arg_values): + self.set_value(arg_name, arg_value) self.builder.set_insertion_point_to_start(entry) # visit function body self.visit_compound_statement(node.body) @@ -821,50 +824,6 @@ class Kernel: return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid) - # Compile to ttir, for the propose of testing MLIR rewriting - def compile_to_ttir(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): - # TODO: share code with _compile & __call__ - - # preparing args - tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] - # attributes - attributes = dict() - for i, arg in enumerate(wargs): - if i in self.fn.do_not_specialize: - continue - if isinstance(arg, int): - attributes[i] = Kernel.pow2_divisor(arg) - elif i in tensor_idxs: - addr = arg.data_ptr() - range_size = _triton.runtime.get_pointer_range_size(addr) - attributes[i] = min(Kernel.pow2_divisor(addr), - Kernel.pow2_divisor(range_size)) - # transforms ints whose value is one into constants for just-in-time compilation - constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 
and i not in self.fn.do_not_specialize} - constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) - constants.update({i: None for i, arg in enumerate(wargs) if arg is None}) - arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants] - - # create IR module - context = _triton.ir.context() - context.load_triton() - # get just-in-time proto-type of kernel - arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types] - ret_type = triton.language.void - prototype = triton.language.function_type([ret_type], arg_types) - # generate Triton-IR - # export symbols visible from self into code-generator object - gscope = self.__globals__ - generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) - try: - generator.visit(self.parse()) - except Exception as e: - node = generator.last_node - if node is None or isinstance(e, (NotImplementedError, CompilationError)): - raise e - raise CompilationError(self.src, node) from e - return generator.module - class Launcher: def __init__(self, kernel, grid): @@ -1209,6 +1168,53 @@ class JITFunction: raise OutOfResources(shared_mem, max_shared_memory, "shared memory") return Binary(backend, name, asm, shared_mem, num_warps) + # Compile to ttir, for the purpose of testing MLIR rewriting + def compile_to_ttir(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): + # TODO: share code with _compile & __call__ + + # preparing args + tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] + # attributes + attributes = dict() + for i, arg in enumerate(wargs): + if isinstance(arg, int): + attributes[i] = Kernel.pow2_divisor(arg) + elif i in tensor_idxs: + addr = arg.data_ptr() + range_size = _triton.runtime.get_pointer_range_size(addr) + attributes[i] = min(Kernel.pow2_divisor(addr), + Kernel.pow2_divisor(range_size)) + # transforms ints whose value is one into constants for just-in-time compilation + constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize} + constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) + constants.update({i: None for i, arg in enumerate(wargs) if arg is None}) + arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants] + + print(f'wargs: {wargs}') + print(f'constants: {constants}') + print(f'arg_types: {arg_types}') + # create IR module + context = _triton.ir.context() + context.load_triton() + # get just-in-time proto-type of kernel + arg_types = [Kernel._to_triton_ir(arg) for arg in arg_types] + ret_type = triton.language.void + prototype = triton.language.function_type([ret_type], arg_types) + # generate Triton-IR + # export symbols visible from self into code-generator object + gscope = self.__globals__ + generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict()) + try: + generator.visit(self.parse()) + except Exception as e: + node = generator.last_node + if node is None or isinstance(e, (NotImplementedError, CompilationError)): + raise e + raise CompilationError(self.src, node) from e + # FIXME: now we need to return context, otherwise it will be deleted + return generator.module, context + + def __getitem__(self, grid): return Launcher(self._init_kernel(), grid) diff --git a/python/triton/language/core.py index
c07c44cfd..17e653778 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -280,7 +280,7 @@ class function_type(dtype): self.param_types = param_types def __str__(self): - return f'fn ({self.param_types}) -> {self.ret_type}' + return f'fn ({self.param_types}) -> {self.ret_types}' def to_ir(self, builder: ir.builder): ir_param_types = [ty.to_ir(builder) for ty in self.param_types] diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 4063b86fc..53bcc6d3e 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -121,7 +121,7 @@ def add(input: tl.tensor, if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): input, other = other, input if input_scalar_ty.is_ptr(): - return tl.tensor(builder.create_gep(input.handle, [other.handle]), input.type) + return tl.tensor(builder.create_gep(input.handle, other.handle), input.type) # float + float elif input_scalar_ty.is_floating(): return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) @@ -138,7 +138,7 @@ def sub(input: tl.tensor, scalar_ty = input.type.scalar # ptr - offset if scalar_ty.is_ptr(): - return tl.tensor(builder.create_gep(input.handle, [minus(other, builder).handle]), + return tl.tensor(builder.create_gep(input.handle, minus(other, builder).handle), input.type) # float - float if scalar_ty.is_floating(): @@ -438,7 +438,7 @@ def not_equal(input: tl.tensor, def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: shape = [end - start] ret_ty = tl.block_type(tl.int32, shape) - return tl.tensor(builder.get_range(start, end), ret_ty) + return tl.tensor(builder.create_make_range(start, end), ret_ty) def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
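For reference, below is a minimal, hypothetical sketch of how the new JITFunction.compile_to_ttir hook added in this patch might be driven for the vector-add case the commit title refers to. Only the compile_to_ttir(*wargs, grid=...) signature and its (module, context) return value come from the diff itself; the kernel body, tensor sizes, block size, and the positional tl.constexpr argument are illustrative assumptions, not part of the patch.

    # Hypothetical usage sketch (not part of this patch): lower a simple
    # vector-add kernel to Triton-IR (MLIR) via the new compile_to_ttir hook.
    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)            # exercises create_get_program_id
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)  # create_make_range + arithmetic
        x = tl.load(x_ptr + offsets)           # pointer arithmetic goes through create_gep
        y = tl.load(y_ptr + offsets)
        tl.store(out_ptr + offsets, x + y)

    BLOCK = 128   # illustrative block size
    size = 1024   # assumed to be an exact multiple of BLOCK, so no mask is needed
    x = torch.rand(size, device='cuda')
    y = torch.rand(size, device='cuda')
    out = torch.empty_like(x)

    # compile_to_ttir only inspects positional wargs when collecting constants,
    # so the constexpr block size is passed positionally here. It returns
    # (module, context); the context must be kept alive, per the FIXME above.
    mod, ctx = add_kernel.compile_to_ttir(
        x, y, out, tl.constexpr(BLOCK), grid=(size // BLOCK,))
    mod.dump()  # print the generated Triton-IR module

The point of the sketch is only to show the intended call shape: the hook builds the TTIR module through the rebased pybind11 builder bindings above and returns it (plus the MLIRContext) for inspection or rewriting tests, without launching anything on the GPU.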