#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/TypeSwitch.h"

#include <limits>

#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"

using namespace mlir;
using namespace mlir::triton::gpu;

// Extracts a non-negative integer from `attr` into `value`.
//
// Accepts IntegerAttr values whose type is signed, signless, index, or
// unsigned. Emits a parser error mentioning `desc` and returns failure when:
// - `attr` is not an IntegerAttr,
// - the value is negative, or
// - the value does not fit in `unsigned`.
static LogicalResult parseIntAttrValue(AsmParser &parser, const Attribute &attr,
                                       unsigned &value, StringRef desc) {
  auto intAttr = attr.dyn_cast<IntegerAttr>();
  if (!intAttr) {
    parser.emitError(parser.getNameLoc(), "expected an integer type in ")
        << desc;
    return failure();
  }

  Type intType = intAttr.getType();
  uint64_t rawValue = 0;
  if (intType.isUnsignedInteger()) {
    rawValue = intAttr.getUInt();
  } else {
    // IntegerAttr::getUInt() is only valid for unsigned integer types. Signed,
    // signless and index-typed values must be read through getSInt()/getInt()
    // and checked for negativity instead.
    int64_t signedValue =
        intType.isSignedInteger() ? intAttr.getSInt() : intAttr.getInt();
    if (signedValue < 0) {
      parser.emitError(parser.getNameLoc(),
                       "expected an unsigned integer value in ")
          << desc;
      return failure();
    }
    rawValue = static_cast<uint64_t>(signedValue);
  }

  // Reject values that would otherwise be silently truncated on assignment.
  if (rawValue > std::numeric_limits<unsigned>::max()) {
    parser.emitError(parser.getNameLoc(),
                     "integer value does not fit in 'unsigned' in ")
        << desc;
    return failure();
  }
  value = static_cast<unsigned>(rawValue);
  return success();
}

// Parses the value of `attr` as an array of non-negative integers.
//
// Each integer is appended to `res`. Taking SmallVectorImpl lets callers
// pass a SmallVector<unsigned, N> of any inline size.
static LogicalResult parseIntArrayAttr(AsmParser &parser,
                                       const NamedAttribute &attr,
                                       SmallVectorImpl<unsigned> &res,
                                       StringRef desc) {
  auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
  if (!arrayAttr) {
    parser.emitError(parser.getNameLoc(), "expected an array for ") << desc;
    return failure();
  }
  for (Attribute element : arrayAttr) {
    unsigned value;
    if (parseIntAttrValue(parser, element, value, desc).failed())
      return failure();
    res.push_back(value);
  }
  return success();
}

// Parses the value of `attr` as a single non-negative integer.
static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
                               unsigned &value, StringRef desc) {
  return parseIntAttrValue(parser, attr.getValue(), value, desc);
}

//===----------------------------------------------------------------------===//
// Attribute methods
//===----------------------------------------------------------------------===//
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
//===----------------------------------------------------------------------===// // Blocked Encoding //===----------------------------------------------------------------------===// Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) { if (parser.parseLess().failed()) return {}; // Parse the data as a dictionary DictionaryAttr dict; if (parser.parseAttribute(dict).failed()) return {}; if (parser.parseGreater().failed()) return {}; SmallVector sizePerThread; SmallVector threadsPerWarp; SmallVector warpsPerCTA; SmallVector order; for (const NamedAttribute &attr : dict) { if (attr.getName() == "sizePerThread") { if (parseIntArrayAttr(parser, attr, sizePerThread, "number of elements per thread") .failed()) return {}; } else if (attr.getName() == "threadsPerWarp") { if (parseIntArrayAttr(parser, attr, threadsPerWarp, "number of threads per warp") .failed()) return {}; } else if (attr.getName() == "warpsPerCTA") { if (parseIntArrayAttr(parser, attr, warpsPerCTA, "number of warps per CTA") .failed()) return {}; } else if (attr.getName() == "order") { if (parseIntArrayAttr(parser, attr, order, "order").failed()) return {}; } else { parser.emitError(parser.getNameLoc(), "unexpected key: ") << attr.getName().strref(); return {}; } } return parser.getChecked( parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order); } void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { printer << "<{" << "sizePerThread = [" << getSizePerThread() << "]" << ", threadsPerWarp = [" << getThreadsPerWarp() << "]" << ", warpsPerCTA = [" << getWarpsPerCTA() << "]" << ", order = [" << getOrder() << "]" << "}>"; } //===----------------------------------------------------------------------===// // MMA encoding //===----------------------------------------------------------------------===// Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) { if (parser.parseLess().failed()) return {}; DictionaryAttr dict; if 
(parser.parseAttribute(dict).failed()) return {}; if (parser.parseGreater().failed()) return {}; unsigned version = 0; SmallVector warpsPerCTA; for (const NamedAttribute &attr : dict) { if (attr.getName() == "version") { if (parseUInt(parser, attr, version, "version").failed()) return {}; } if (attr.getName() == "warpsPerCTA") { if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) return {}; } } return parser.getChecked(parser.getContext(), version, warpsPerCTA); } void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const { printer << "<{" << "version = " << getVersion() << ", " << "warpsPerCTA = [" << getWarpsPerCTA() << "]" << "}>"; } //===----------------------------------------------------------------------===// // Shared encoding //===----------------------------------------------------------------------===// Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) { if (parser.parseLess().failed()) return {}; // Parse the data as a dictionary DictionaryAttr dict; if (parser.parseAttribute(dict).failed()) return {}; if (parser.parseGreater().failed()) return {}; unsigned vec = 0; unsigned perPhase = 0; unsigned maxPhase = 0; SmallVector order; for (const NamedAttribute &attr : dict) { if (attr.getName() == "vec") { if (parseUInt(parser, attr, vec, "vec").failed()) return {}; } else if (attr.getName() == "perPhase") { if (parseUInt(parser, attr, perPhase, "perPhase").failed()) return {}; } else if (attr.getName() == "maxPhase") { if (parseUInt(parser, attr, maxPhase, "maxPhase").failed()) return {}; } else if (attr.getName() == "order") { if (parseIntArrayAttr(parser, attr, order, "order").failed()) return {}; } else { parser.emitError(parser.getNameLoc(), "unexpected key: ") << attr.getName().strref(); return {}; } } return parser.getChecked( parser.getContext(), vec, perPhase, maxPhase, order); } void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const { printer << "<{" << "vec = " << getVec() << ", 
perPhase = " << getPerPhase() << ", maxPhase = " << getMaxPhase() << ", order = [" << getOrder() << "]" << "}>"; } //===----------------------------------------------------------------------===// // ASM Interface (i.e.: alias) //===----------------------------------------------------------------------===// class TritonGPUOpAsmInterface : public OpAsmDialectInterface { public: using OpAsmDialectInterface::OpAsmDialectInterface; AliasResult getAlias(Attribute attr, raw_ostream &os) const override { if (auto mmaAttr = attr.dyn_cast()) { os << "mma"; return AliasResult::FinalAlias; } else if (auto sharedAttr = attr.dyn_cast()) { os << "shared"; return AliasResult::FinalAlias; } else if (auto blockedAttr = attr.dyn_cast()) { os << "blocked"; return AliasResult::FinalAlias; } return OpAsmDialectInterface::getAlias(attr, os); } }; void TritonGPUDialect::initialize() { addAttributes< #define GET_ATTRDEF_LIST #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" >(); addOperations< #define GET_OP_LIST #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" >(); addInterfaces(); } namespace mlir { namespace triton { // Type inference static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(type.getContext(), 1); if (auto tensorType = type.dyn_cast()) return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding()); return Type(); } static Type getPointeeType(Type type) { if (auto tensorType = type.dyn_cast()) { // Tensor of pointers auto shape = tensorType.getShape(); auto ptrType = tensorType.getElementType().dyn_cast(); Type pointeeType = ptrType.getPointeeType(); return RankedTensorType::get(shape, pointeeType, tensorType.getEncoding()); } else if (auto ptrType = type.dyn_cast()) { // scalar pointer Type pointeeType = ptrType.getPointeeType(); return pointeeType; } return Type(); } } // namespace triton } // namespace mlir static LogicalResult verify(CopyAsyncOp op) { Type resType = op.getResult().getType(); if (auto tensorType = 
resType.dyn_cast()) { Attribute encoding = tensorType.getEncoding(); if (!encoding.isa()) return op.emitOpError("copy_async should return a shared memory tensor"); } else return op.emitOpError("copy_async should return a tensor"); return success(); } #define GET_OP_CLASSES #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" // verify TritonGPU ops LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { // TODO: fill this. return success(); }