diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 8fd1cd661..b90e1dc61 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -8,12 +8,14 @@ class TritonGPU_Attr<string name, list<Trait> traits = []> : AttrDef<TritonGPU_Dialect, name, traits>; def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> { - let mnemonic = "shared (memory) encoding"; + let mnemonic = "shared_layout"; let description = [{ -An encoding for tensors whose elements may be simultaneously accessed by different warps in the programs, via shared memory. +An encoding for tensors whose elements may be simultaneously accessed by +different warps in the programs, via shared memory. -In order to avoid shared memory bank conflicts, elements may be stored in a swizzled layout. +In order to avoid shared memory bank conflicts, elements may be stored in a +swizzled layout. For example, a swizzled row-major layout stores would store data as follows: A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2 @@ -29,10 +31,11 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... 
[phase 1] / And the associated TritonGPU MLIR ```mlir -#SMEM = #triton_gpu.encoding<{ +#SMEM = #triton_gpu.shared_layout<{ vec = 2, perPhase = 2, - maxPhase = 4 + maxPhase = 4, + order = [1, 0] }> ``` }]; @@ -40,12 +43,13 @@ And the associated TritonGPU MLIR let parameters = ( ins // swizzle info - "unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase + "unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase, + ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order ); } -def TritonGPUDistributedEncodingAttr : TritonGPU_Attr<"TritonGPUDistributedEncoding"> { - let mnemonic = "coalesced encoding"; +def TritonGPUShardedEncodingAttr : TritonGPU_Attr<"TritonGPUShardedEncoding"> { + let mnemonic = "sharded_layout"; let description = [{ An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout @@ -70,7 +74,7 @@ size } .... A_{63, 0}[T60] A_{63, 1}[T60] ... A_{63, 6}[T63] A_{63, 7}[T63] A_{63, 8}[T60] A_{63, 9}[T60] ... A_{63, 14}[T63] A_{63, 15}[T63] And the associated TritonGPU MLIR -#SMEM = #triton_gpu.encoding<{ +#LAYOUT = #triton_gpu.sharded_layout<{ threadTileSize = {2, 2} blockTileSize = {32, 8} }> @@ -81,28 +85,55 @@ And the associated TritonGPU MLIR let parameters = ( ins - ArrayRefParameter<"unsigned">:$threadTileSize, - ArrayRefParameter<"unsigned">:$blockTileSize, + // TODO: should we rename this as laneTileSize? 
+ ArrayRefParameter< + "unsigned", + /*desc*/"size of a tile that is held by a thread" + >:$threadTileSize, + ArrayRefParameter< + "unsigned", + "size of a tile that is held by a warp" + >:$warpTileSize, + ArrayRefParameter< + "unsigned", + "size of a tile that is held by a thread block" + >:$blockTileSize, + // // TODO: It seems that we don't need this (because we can re-compute this) + // ArrayRefParameter<"unsigned">:$repetitions, // fastest-changing axis first - ArrayRefParameter<"unsigned">:$order + ArrayRefParameter< + "unsigned", + "order of axes by the rate of changing" + >:$order + // "AffineMap":$threadOrdering, + // "AffineMap":$warpOrdering, + // "AffineMap":$blockOrdering, + ); // let genVerifyDecl = 1; } def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> { - let mnemonic = "mma encoding"; + let mnemonic = "mma_layout"; let description = [{TODO: I think we may be able to implement it as a special-case of Distributed encoding with maybe one more warpTileSize attribute!}]; let parameters = ( ins + // only used by Volta mma.884 ArrayRefParameter<"unsigned">:$fragmentPerWarp, + // aka shapeOfInstr (e.g., {16,8,16}) ArrayRefParameter<"unsigned">:$shapePerWarp, + // TODO: should we rename this as warpTileSize? (consistent naming with Distributed layout) ArrayRefParameter<"unsigned">:$warpPerTile, + // TODO: should we rename this as blockTileSize? 
(consistent naming with Distributed layout) ArrayRefParameter<"unsigned">:$shapePerTile, + // TODO: should Distributed layout also ArrayRefParameter<"unsigned">:$reptitions, ArrayRefParameter<"unsigned">:$contigPerThread + // "AffineMap":$warpOrdering, + // "AffineMap":$blockOrdering ); // let genVerifyDecl = 1; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index bc2b968c3..da5b09de5 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -4,46 +4,181 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" +using namespace mlir; using namespace mlir::triton::gpu; +// parse an array of integers +static LogicalResult parseIntArrayAttr(AsmParser &parser, + const NamedAttribute &attr, + SmallVector<unsigned> &res, + StringRef desc) { + auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>(); + if (!arrayAttr) { + parser.emitError(parser.getNameLoc(), "expected an array for ") + << desc; + return failure(); + } + for (Attribute i : arrayAttr) { + auto intAttr = i.dyn_cast<IntegerAttr>(); + if (!intAttr) { + parser.emitError(parser.getNameLoc(), "expected an integer value in ") + << desc; + return failure(); + } + res.push_back(intAttr.getUInt()); + } + return success(); +}; + //===----------------------------------------------------------------------===// // Attribute methods //===----------------------------------------------------------------------===// #define GET_ATTRDEF_CLASSES #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" -mlir::Attribute -TritonGPUDistributedEncodingAttr::parse(mlir::AsmParser &parser, mlir::Type type) { - llvm_unreachable("Not implemented"); +Attribute +TritonGPUShardedEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + SmallVector<unsigned> threadTileSize; + SmallVector<unsigned>
warpTileSize; + SmallVector<unsigned> blockTileSize; + SmallVector<unsigned> order; + + // parse an array of integers + // auto parseIntArrayAttr = [&parser](const NamedAttribute &attr, + // SmallVector<unsigned> &res, + // StringRef desc) -> LogicalResult { + // auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>(); + // if (!arrayAttr) { + // parser.emitError(parser.getNameLoc(), "expected an array for ") + // << desc; + // return failure(); + // } + // for (Attribute i : arrayAttr) { + // auto intAttr = i.dyn_cast<IntegerAttr>(); + // if (!intAttr) { + // parser.emitError(parser.getNameLoc(), "expected an integer value in ") + // << desc; + // return failure(); + // } + // res.push_back(intAttr.getUInt()); + // } + // return success(); + // }; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "threadTileSize") { + if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size").failed()) + return {}; + } else if (attr.getName() == "warpTileSize") { + if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size").failed()) + return {}; + } else if (attr.getName() == "blockTileSize") { + if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size").failed()) + return {}; + } else if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + return parser.getChecked<TritonGPUShardedEncodingAttr>(parser.getContext(), + threadTileSize, + warpTileSize, + blockTileSize, + order); } -void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const { +void TritonGPUShardedEncodingAttr::print(mlir::AsmPrinter &printer) const { printer << "<" - << "threadTileSize = " << getThreadTileSize() - << ", blockTileSize = " << getBlockTileSize() - << ", order = " << getOrder() + << "threadTileSize = [" << getThreadTileSize() << "]" + << ", warpTileSize = [" << getWarpTileSize() << "]" + << ", blockTileSize = [" << getBlockTileSize() << "]" 
+ << ", order = [" << getOrder() << "]" << ">"; } -mlir::Attribute -TritonGPUMmaEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) { +Attribute +TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) { llvm_unreachable("Not implemented"); } -void TritonGPUMmaEncodingAttr::print(mlir::AsmPrinter &printer) const { +void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const { llvm_unreachable("Not implemented"); } -mlir::Attribute -TritonGPUSharedEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) { - llvm_unreachable("Not implemented"); +Attribute +TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned vec = 0; + unsigned perPhase = 0; + unsigned maxPhase = 0; + SmallVector<unsigned> order; + + auto parseUInt = [&parser](const NamedAttribute &attr, + unsigned &value, + StringRef desc) -> LogicalResult { + auto intAttr = attr.getValue().dyn_cast<IntegerAttr>(); + if (!intAttr) { + parser.emitError(parser.getNameLoc(), "expected an integer ") << desc; + return failure(); + } + value = intAttr.getUInt(); + return success(); + }; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "vec") { + if (parseUInt(attr, vec, "vec").failed()) + return {}; + } else if (attr.getName() == "perPhase") { + if (parseUInt(attr, perPhase, "perPhase").failed()) + return {}; + } else if (attr.getName() == "maxPhase") { + if (parseUInt(attr, maxPhase, "maxPhase").failed()) + return {}; + } else if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + return parser.getChecked<TritonGPUSharedEncodingAttr>(parser.getContext(), + vec, + perPhase, + maxPhase, + order); } -void 
TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const { +void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const { printer << "<" - // << "threadTileSize = " << getThreadTileSize() - // << ", blockTileSize = " << getBlockTileSize() - // << ", order = " << getOrder() + << "vec = " << getVec() + << ", perPhase = " << getPerPhase() << ", maxPhase = " << getMaxPhase() + << ", order = [" << getOrder() << "]" << ">"; } @@ -92,9 +227,9 @@ static Type getPointeeType(Type type) { // verify TritonGPU ops -mlir::LogicalResult -TritonGPUDialect::verifyOperationAttribute(mlir::Operation *op, - mlir::NamedAttribute attr) { +LogicalResult +TritonGPUDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { // TODO: fill this. return success(); } diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index 8ae549ca3..8edd88ad1 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -35,6 +35,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // Now we assume: // contiguous = 1, order = 0, 1, 2, ..., llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout + llvm::SmallVector<unsigned> warpTileSize(rank, 1); llvm::SmallVector<unsigned> blockTileSize(rank); llvm::SmallVector<unsigned> order(rank); int remainingThreads = numThreads; @@ -45,8 +46,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, remainingThreads /= blockTileSize[dim]; // TODO: will we need repetition? 
} - Attribute encoding = triton::gpu::TritonGPUDistributedEncodingAttr::get( - context, threadTileSize, blockTileSize, order); + Attribute encoding = triton::gpu::TritonGPUShardedEncodingAttr::get( + context, threadTileSize, warpTileSize, blockTileSize, order); return RankedTensorType::get(shape, elementType, encoding); }); diff --git a/lib/Dialect/TritonGPU/Transforms/Verifier.cpp b/lib/Dialect/TritonGPU/Transforms/Verifier.cpp index 3749de00e..619701f85 100644 --- a/lib/Dialect/TritonGPU/Transforms/Verifier.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Verifier.cpp @@ -50,7 +50,7 @@ private: if (!encoding) return dotOp.emitError() << name << " should have encoding"; if (!encoding.isa<TritonGPUMmaEncodingAttr>() && - !encoding.isa<TritonGPUDistributedEncodingAttr>()) + !encoding.isa<TritonGPUShardedEncodingAttr>()) return dotOp.emitError() << name << " should be of distributed layout"; if (name == 'c') cLayout = encoding;