#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"

#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"

using namespace mlir;
using namespace mlir::triton::gpu;

//===----------------------------------------------------------------------===//
// Parsing helpers
//===----------------------------------------------------------------------===//

/// Parses the value of a dictionary entry as an array of integers.
///
/// \param parser Parser used to report diagnostics.
/// \param attr   Dictionary entry whose value is expected to be an ArrayAttr.
/// \param res    Output vector; each element (read with IntegerAttr::getUInt)
///               is appended, existing contents are kept.
/// \param desc   Human-readable field name used in error messages.
/// \returns failure() after emitting an error if the value is not an array or
///          any element is not an IntegerAttr; success() otherwise.
static LogicalResult parseIntArrayAttr(AsmParser &parser,
                                       const NamedAttribute &attr,
                                       SmallVector<unsigned> &res,
                                       StringRef desc) {
  auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
  if (!arrayAttr) {
    parser.emitError(parser.getNameLoc(), "expected an array for ") << desc;
    return failure();
  }
  for (Attribute i : arrayAttr) {
    auto intAttr = i.dyn_cast<IntegerAttr>();
    if (!intAttr) {
      parser.emitError(parser.getNameLoc(), "expected an integer value in ")
          << desc;
      return failure();
    }
    res.push_back(intAttr.getUInt());
  }
  return success();
}

/// Parses the value of a dictionary entry as a single integer.
///
/// \param parser Parser used to report diagnostics.
/// \param attr   Dictionary entry whose value is expected to be an IntegerAttr.
/// \param value  Receives the integer (read with IntegerAttr::getUInt).
/// \param desc   Human-readable field name used in error messages.
/// \returns failure() after emitting an error if the value is not an
///          IntegerAttr; success() otherwise.
static LogicalResult parseUIntAttr(AsmParser &parser,
                                   const NamedAttribute &attr,
                                   unsigned &value, StringRef desc) {
  auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
  if (!intAttr) {
    parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
    return failure();
  }
  value = intAttr.getUInt();
  return success();
}

//===----------------------------------------------------------------------===//
// Attribute methods
//===----------------------------------------------------------------------===//
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"

/// Parses `<{threadTileSize = [...], warpTileSize = [...],
///           blockTileSize = [...], order = [...]}>`.
/// Keys may appear in any order. Missing keys leave their array empty, and
/// unknown keys are an error.
Attribute TritonGPUShardedEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  // Parse the data as a dictionary
  DictionaryAttr dict;
  if (parser.parseAttribute(dict).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  SmallVector<unsigned> threadTileSize;
  SmallVector<unsigned> warpTileSize;
  SmallVector<unsigned> blockTileSize;
  SmallVector<unsigned> order;

  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "threadTileSize") {
      if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size")
              .failed())
        return {};
    } else if (attr.getName() == "warpTileSize") {
      if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size")
              .failed())
        return {};
    } else if (attr.getName() == "blockTileSize") {
      if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size")
              .failed())
        return {};
    } else if (attr.getName() == "order") {
      if (parseIntArrayAttr(parser, attr, order, "order").failed())
        return {};
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }

  return parser.getChecked<TritonGPUShardedEncodingAttr>(
      parser.getContext(), threadTileSize, warpTileSize, blockTileSize, order);
}

/// Prints the attribute in the form accepted by `parse`. The braces are
/// required because `parse` reads the body as a DictionaryAttr.
void TritonGPUShardedEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<{"
          << "threadTileSize = [" << getThreadTileSize() << "]"
          << ", warpTileSize = [" << getWarpTileSize() << "]"
          << ", blockTileSize = [" << getBlockTileSize() << "]"
          << ", order = [" << getOrder() << "]"
          << "}>";
}

// Textual syntax for the MMA encoding is not implemented yet.
Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
  llvm_unreachable("Not implemented");
}

void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
  llvm_unreachable("Not implemented");
}

/// Parses `<{vec = N, perPhase = N, maxPhase = N, order = [...]}>`.
/// Keys may appear in any order. Missing integer keys default to 0, and
/// unknown keys are an error.
Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  // Parse the data as a dictionary
  DictionaryAttr dict;
  if (parser.parseAttribute(dict).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  unsigned vec = 0;
  unsigned perPhase = 0;
  unsigned maxPhase = 0;
  SmallVector<unsigned> order;

  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "vec") {
      if (parseUIntAttr(parser, attr, vec, "vec").failed())
        return {};
    } else if (attr.getName() == "perPhase") {
      if (parseUIntAttr(parser, attr, perPhase, "perPhase").failed())
        return {};
    } else if (attr.getName() == "maxPhase") {
      if (parseUIntAttr(parser, attr, maxPhase, "maxPhase").failed())
        return {};
    } else if (attr.getName() == "order") {
      if (parseIntArrayAttr(parser, attr, order, "order").failed())
        return {};
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }

  return parser.getChecked<TritonGPUSharedEncodingAttr>(
      parser.getContext(), vec, perPhase, maxPhase, order);
}

/// Prints the attribute in the form accepted by `parse`. Every field is
/// printed, including maxPhase, so that print/parse round-trips.
void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
  printer << "<{"
          << "vec = " << getVec()
          << ", perPhase = " << getPerPhase()
          << ", maxPhase = " << getMaxPhase()
          << ", order = [" << getOrder() << "]"
          << "}>";
}

/// Registers the dialect's generated attributes and operations.
void TritonGPUDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
      >();

  addOperations<
#define GET_OP_LIST
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
      >();
}

namespace mlir {
namespace triton {

// Type inference helpers. Nothing else in this file uses them, so they are
// presumably referenced by the generated Ops.cpp.inc included below; verify
// before renaming.

/// Returns a ranked tensor type with the same shape and encoding as `type`
/// but an i1 element type. Returns a null Type if `type` is not a ranked
/// tensor.
static Type getI1SameShape(Type type) {
  auto i1Type = IntegerType::get(type.getContext(), 1);
  if (auto tensorType = type.dyn_cast<RankedTensorType>())
    return RankedTensorType::get(tensorType.getShape(), i1Type,
                                 tensorType.getEncoding());
  return Type();
}

/// Returns the pointee type of `type`.
///
/// For a ranked tensor of pointers, returns a tensor of the same shape and
/// encoding whose element type is the pointee type. For a scalar pointer,
/// returns its pointee type. Otherwise, including a tensor whose elements are
/// not pointers, returns a null Type.
static Type getPointeeType(Type type) {
  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    // Tensor of pointers. The element type must be checked: dyn_cast yields a
    // null PointerType for non-pointer elements.
    auto ptrType = tensorType.getElementType().dyn_cast<PointerType>();
    if (!ptrType)
      return Type();
    return RankedTensorType::get(tensorType.getShape(),
                                 ptrType.getPointeeType(),
                                 tensorType.getEncoding());
  }
  if (auto ptrType = type.dyn_cast<PointerType>()) {
    // Scalar pointer
    return ptrType.getPointeeType();
  }
  return Type();
}

} // namespace triton
} // namespace mlir

#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"

// verify TritonGPU ops
// Currently accepts every attribute unconditionally.
LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
                                                         NamedAttribute attr) {
  // TODO: fill this.
  return success();
}