Documentation

This commit is contained in:
Yan Da
2022-04-07 16:00:53 +08:00
parent 16d44e5c4c
commit 6b4da6f016
3 changed files with 26 additions and 29 deletions

View File

@@ -17,13 +17,14 @@ class TensorSizeTrait : public TraitBase<ConcreteType, TensorSizeTrait> {
public:
// TODO: move impl to .cc files
static LogicalResult verifyTrait(Operation *op) {
int constexpr maxElement = 1048576;
for (auto opType : op->getOperandTypes()) {
if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
int64_t numElements = 1;
for (int64_t s : tensorType.getShape())
numElements *= s;
if (numElements > 1048576)
return op->emitError("Maximum allowed number of elements is 1048576, but ")
if (numElements > maxElement)
return op->emitError("Maximum allowed number of elements is ") << maxElement << ", but "
<< *op << " has more than that";
if ((numElements & (numElements - 1)) != 0)
return op->emitError("Number of elements must be power-of-two, but ")
@@ -36,8 +37,8 @@ public:
int64_t numElements = 1;
for (int64_t s : tensorType.getShape())
numElements *= s;
if (numElements > 1048576)
return op->emitError("Maximum allowed number of elements is 1048576, but ")
if (numElements > maxElement)
return op->emitError("Maximum allowed number of elements is ") << maxElement << ", but "
<< *op << " has more than that";
if ((numElements & (numElements - 1)) != 0)
return op->emitError("Number of elements must be power-of-two, but ")

View File

@@ -195,6 +195,10 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs"> {
def TT_DotOp : TT_Op<"dot", [NoSideEffect, SameOperandsAndResultShape]> {
let summary = "dot";
let description = [{
$d = matrix_multiply($a, $b) + $c
}];
let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c);
let results = (outs TT_FpIntTensor:$d);
@@ -279,6 +283,12 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
let summary = "make range";
let description = [{
Returns an 1D int32 tensor.
Values span from $start to $end (exclusive), with step = 1
}];
let arguments = (ins I32Attr:$start, I32Attr:$end);
let results = (outs TT_IntegerTensor:$result);

View File

@@ -690,18 +690,6 @@ void init_triton_ir(py::module &&m) {
})
;
// py::class_<mlir::ModuleOp>(m, "module")
// .def(py::init<std::string, ir::builder &>())
// .def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) {
// const auto metadatas = self->get_metadatas();
// auto it = metadatas.find(name);
// if (it != metadatas.end())
// if (auto *instr = dynamic_cast<ir::instruction*>(value)) {
// instr->set_metadata(it->second.first, it->second.second);
// }
// })
// .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference);
// using eattr = ir::attribute_kind_t;
// py::enum_<eattr>(m, "attribute_kind")
// .value("readonly", eattr::readonly)
@@ -715,19 +703,6 @@ void init_triton_ir(py::module &&m) {
// py::class_<mlir::Attribute>(m, "attribute");
// // .def(py::init<eattr, int>());
py::class_<mlir::FuncOp>(m, "function")
// .def_property_readonly("args", &ir::function::args)
// .def_property_readonly("attrs", &ir::function::attrs)
// .def("add_attr", &ir::function::add_attr);
.def("args", [](mlir::FuncOp &self, unsigned idx) -> mlir::BlockArgument {
return self.getArgument(idx);
})
.def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
return self.addEntryBlock();
}, ret::reference)
.def("dump", [](mlir::FuncOp &self) { self.dump(); })
;
// Ops
py::class_<mlir::OpState>(m, "OpState")
.def("get_result", [](mlir::OpState &self, unsigned idx) -> mlir::Value {
@@ -771,6 +746,17 @@ void init_triton_ir(py::module &&m) {
})
;
py::class_<mlir::FuncOp, mlir::OpState>(m, "function")
// .def_property_readonly("attrs", &ir::function::attrs)
// .def("add_attr", &ir::function::add_attr);
.def("args", [](mlir::FuncOp &self, unsigned idx) -> mlir::BlockArgument {
return self.getArgument(idx);
})
.def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
return self.addEntryBlock();
}, ret::reference)
;
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
py::class_<mlir::OpBuilder>(m, "builder", py::dynamic_attr())