diff --git a/include/triton/ir/Traits.h b/include/triton/ir/Traits.h index a470e2bb3..19864b12b 100644 --- a/include/triton/ir/Traits.h +++ b/include/triton/ir/Traits.h @@ -17,13 +17,14 @@ class TensorSizeTrait : public TraitBase { public: // TODO: move impl to .cc files static LogicalResult verifyTrait(Operation *op) { + int constexpr maxElement = 1048576; for (auto opType : op->getOperandTypes()) { if (auto tensorType = opType.dyn_cast()) { int64_t numElements = 1; for (int64_t s : tensorType.getShape()) numElements *= s; - if (numElements > 1048576) - return op->emitError("Maximum allowed number of elements is 1048576, but ") + if (numElements > maxElement) + return op->emitError("Maximum allowed number of elements is ") << maxElement << ", but " << *op << " has more than that"; if ((numElements & (numElements - 1)) != 0) return op->emitError("Number of elements must be power-of-two, but ") @@ -36,8 +37,8 @@ public: int64_t numElements = 1; for (int64_t s : tensorType.getShape()) numElements *= s; - if (numElements > 1048576) - return op->emitError("Maximum allowed number of elements is 1048576, but ") + if (numElements > maxElement) + return op->emitError("Maximum allowed number of elements is ") << maxElement << ", but " << *op << " has more than that"; if ((numElements & (numElements - 1)) != 0) return op->emitError("Number of elements must be power-of-two, but ") diff --git a/include/triton/ir/TritonOps.td b/include/triton/ir/TritonOps.td index 4d6319944..620556287 100644 --- a/include/triton/ir/TritonOps.td +++ b/include/triton/ir/TritonOps.td @@ -195,6 +195,10 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> { def TT_DotOp : TT_Op<"dot", [NoSideEffect, SameOperandsAndResultShape]> { let summary = "dot"; + let description = [{ + $d = matrix_multiply($a, $b) + $c + }]; + let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c); let results = (outs TT_FpIntTensor:$d); @@ -279,6 +283,12 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> { 
def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> { let summary = "make range"; + let description = [{ + Returns a 1D int32 tensor. + + Values span from $start to $end (exclusive), with step = 1 + }]; + let arguments = (ins I32Attr:$start, I32Attr:$end); let results = (outs TT_IntegerTensor:$result); diff --git a/python/src/triton.cc b/python/src/triton.cc index 6cc60c761..6e25f898a 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -690,18 +690,6 @@ void init_triton_ir(py::module &&m) { }) ; - // py::class_(m, "module") - // .def(py::init()) - // .def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) { - // const auto metadatas = self->get_metadatas(); - // auto it = metadatas.find(name); - // if (it != metadatas.end()) - // if (auto *instr = dynamic_cast(value)) { - // instr->set_metadata(it->second.first, it->second.second); - // } - // }) - // .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference); - // using eattr = ir::attribute_kind_t; // py::enum_(m, "attribute_kind") // .value("readonly", eattr::readonly) @@ -715,19 +703,6 @@ void init_triton_ir(py::module &&m) { // py::class_(m, "attribute"); // // .def(py::init()); - py::class_(m, "function") - // .def_property_readonly("args", &ir::function::args) - // .def_property_readonly("attrs", &ir::function::attrs) - // .def("add_attr", &ir::function::add_attr); - .def("args", [](mlir::FuncOp &self, unsigned idx) -> mlir::BlockArgument { - return self.getArgument(idx); - }) - .def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* { - return self.addEntryBlock(); - }, ret::reference) - .def("dump", [](mlir::FuncOp &self) { self.dump(); }) - ; - // Ops py::class_(m, "OpState") .def("get_result", [](mlir::OpState &self, unsigned idx) -> mlir::Value { @@ -771,6 +746,17 @@ void init_triton_ir(py::module &&m) { }) ; + py::class_(m, "function") + // .def_property_readonly("attrs", &ir::function::attrs) + // .def("add_attr", 
&ir::function::add_attr); + .def("args", [](mlir::FuncOp &self, unsigned idx) -> mlir::BlockArgument { + return self.getArgument(idx); + }) + .def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* { + return self.addEntryBlock(); + }, ret::reference) + ; + py::class_(m, "InsertPoint"); py::class_(m, "builder", py::dynamic_attr())