From 15bfd0cb796be2f7f6dff53f20ff5aea6b22b635 Mon Sep 17 00:00:00 2001 From: goostavz <109190422+goostavz@users.noreply.github.com> Date: Sun, 18 Sep 2022 05:58:42 +0800 Subject: [PATCH] [BACKEND] Support of ConvertLayoutOp from blocked to blocked and SliceLayout with blocked parent (#658) --- bin/triton-translate.cpp | 10 - include/triton/Analysis/Allocation.h | 7 +- include/triton/Analysis/Utility.h | 9 + .../TritonGPUToLLVM/TritonGPUToLLVM.h | 8 + include/triton/Dialect/TritonGPU/IR/Dialect.h | 12 + .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 15 +- include/triton/tools/sys/getenv.hpp | 9 + lib/Analysis/Allocation.cpp | 74 +++ lib/Analysis/AxisInfo.cpp | 13 + .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 608 ++++++++++++++---- lib/Dialect/TritonGPU/IR/Dialect.cpp | 80 +++ lib/Target/LLVMIR/LLVMIRTranslation.cpp | 12 + python/src/triton.cc | 11 +- python/tests/test_transpose.py | 68 ++ python/tests/test_vecadd_no_scf.py | 1 + test/Conversion/triton_to_llvm.mlir | 37 -- test/Conversion/tritongpu_to_llvm.mlir | 242 ++++++- 17 files changed, 1025 insertions(+), 191 deletions(-) create mode 100644 python/tests/test_transpose.py delete mode 100644 test/Conversion/triton_to_llvm.mlir diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp index 9aa8881f5..bfb658a96 100644 --- a/bin/triton-translate.cpp +++ b/bin/triton-translate.cpp @@ -64,16 +64,6 @@ OwningOpRef loadMLIRModule(llvm::StringRef inputFilename, return nullptr; } - mlir::PassManager pm(module->getContext()); - applyPassManagerCLOptions(pm); - - pm.addPass(createConvertTritonGPUToLLVMPass()); - - if (failed(pm.run(module->getOperation()))) { - llvm::errs() << "Pass execution failed"; - return nullptr; - } - return module; } diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index c7353b3a5..cb4e77228 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -14,7 +14,12 @@ namespace mlir { namespace triton { class AllocationAnalysis; -} + +SmallVector +getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, + unsigned &outVec); + +} // namespace triton /// Modified from llvm-15.0: llvm/ADT/AddressRanges.h /// A class that represents an interval, specified using a start and an end diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index 31869c10f..fccd13320 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -2,7 +2,10 @@ #define TRITON_ANALYSIS_UTILITY_H #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include +#include #include + namespace mlir { bool isSharedEncoding(Value value); @@ -11,6 +14,12 @@ bool maybeSharedAllocationOp(Operation *op); std::string getValueOperandName(Value value, AsmState &state); +template Int product(llvm::ArrayRef arr) { + return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{}); +} + +template Int ceil(Int m, Int n) { return (m + n - 1) / n; } + } // namespace mlir #endif // TRITON_ANALYSIS_UTILITY_H diff --git a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h index 85ffc1944..ef81d82c3 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h @@ -18,6 +18,14 @@ public: mlir::LLVMTypeConverter &typeConverter); }; +class TritonLLVMFunctionConversionTarget : public ConversionTarget { + mlir::LLVMTypeConverter &typeConverter; + +public: + explicit 
TritonLLVMFunctionConversionTarget( + MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter); +}; + namespace triton { // Names for identifying different NVVM annotations. It is used as attribute diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index ac02470be..66fae4de3 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -16,4 +16,16 @@ #define GET_OP_CLASSES #include "triton/Dialect/TritonGPU/IR/Ops.h.inc" +namespace mlir { +namespace triton { +namespace gpu { + +unsigned getElemsPerThread(Attribute layout, ArrayRef shape); + +unsigned getShapePerCTA(const Attribute &layout, unsigned d); + +} // namespace gpu +} // namespace triton +} // namespace mlir + #endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 2277b509e..af231c1ef 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -31,6 +31,10 @@ Then, attaching $\mathcal{L} to a tensor $T$ would mean that: Right now, Triton implements two classes of layouts: shared, and distributed. }]; + + code extraBaseClassDeclaration = [{ + unsigned getElemsPerThread(ArrayRef shape) const; + }]; } //===----------------------------------------------------------------------===// @@ -64,6 +68,8 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] / "unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase, ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order ); + + let extraClassDeclaration = extraBaseClassDeclaration; } //===----------------------------------------------------------------------===// @@ -93,6 +99,8 @@ Then the data of A would be distributed as follow between the 16 CUDA threads: L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, {4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ] }]; + + let extraClassDeclaration = extraBaseClassDeclaration; } //===----------------------------------------------------------------------===// @@ -171,11 +179,10 @@ for }]> ]; - let extraClassDeclaration = [{ + let extraClassDeclaration = extraBaseClassDeclaration # [{ SliceEncodingAttr squeeze(int axis); }]; - let parameters = ( ins ArrayRefParameter<"unsigned">:$sizePerThread, @@ -282,6 +289,8 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: "unsigned":$version, ArrayRefParameter<"unsigned">:$warpsPerCTA ); + + let extraClassDeclaration = extraBaseClassDeclaration; } def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> { @@ -311,6 +320,8 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> { // TODO: constraint here to only take distributed encodings "Attribute":$parent ); + + let extraClassDeclaration = extraBaseClassDeclaration; } diff --git a/include/triton/tools/sys/getenv.hpp b/include/triton/tools/sys/getenv.hpp index 1f1c57521..7dd960070 100644 --- a/include/triton/tools/sys/getenv.hpp +++ b/include/triton/tools/sys/getenv.hpp @@ -22,6 +22,7 @@ #ifndef TDL_TOOLS_SYS_GETENV_HPP #define TDL_TOOLS_SYS_GETENV_HPP +#include #include #include @@ -37,6 +38,14 @@ inline std::string getenv(const char *name) { return result; } +inline bool getBoolEnv(const std::string &env) { + const char *s = std::getenv(env.c_str()); + std::string str(s ? 
s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + return (str == "on" || str == "true" || str == "1"); +} + } // namespace tools } // namespace triton diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 7d883546a..6afa7ea1a 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -8,6 +8,11 @@ #include #include +#include + +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::MmaEncodingAttr; +using ::mlir::triton::gpu::SharedEncodingAttr; namespace mlir { @@ -15,6 +20,54 @@ namespace mlir { // Shared Memory Allocation Analysis //===----------------------------------------------------------------------===// namespace triton { + +SmallVector +getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, + unsigned &outVec) { + auto srcTy = op.src().getType().cast(); + auto dstTy = op.result().getType().cast(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + assert(srcLayout && dstLayout && + "Unexpect layout in getScratchConfigForCvtLayout()"); + unsigned rank = dstTy.getRank(); + SmallVector paddedRepShape(rank); + // TODO: move to TritonGPUAttrDefs.h.inc + auto getShapePerCTA = [&](const Attribute &layout, unsigned d) -> unsigned { + if (auto blockedLayout = layout.dyn_cast()) { + return blockedLayout.getSizePerThread()[d] * + blockedLayout.getThreadsPerWarp()[d] * + blockedLayout.getWarpsPerCTA()[d]; + } else { + assert(0 && "Unimplemented usage of getShapePerCTA"); + return 0; + } + }; + if (srcLayout.isa() && + dstLayout.isa()) { + auto srcBlockedLayout = srcLayout.cast(); + auto dstBlockedLayout = dstLayout.cast(); + auto inOrd = srcBlockedLayout.getOrder(); + auto outOrd = dstBlockedLayout.getOrder(); + // TODO: Fix the legacy issue that ourOrd[0] == 0 always means + // that we cannot do vectorization. + inVec = outOrd[0] == 0 ? 1 + : inOrd[0] == 0 ? 1 + : srcBlockedLayout.getSizePerThread()[inOrd[0]]; + outVec = + outOrd[0] == 0 ? 
1 : dstBlockedLayout.getSizePerThread()[outOrd[0]]; + unsigned pad = std::max(inVec, outVec); + for (unsigned d = 0; d < rank; ++d) { + paddedRepShape[d] = std::max( + std::min(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)), + std::min(dstTy.getShape()[d], + getShapePerCTA(dstLayout, d))); + } + paddedRepShape[outOrd[0]] += pad; + } + return paddedRepShape; +} + class AllocationAnalysis { public: AllocationAnalysis(Operation *operation, Allocation *allocation) @@ -73,6 +126,27 @@ private: tensorType.getElementTypeBitWidth() / 8; allocation->addBuffer(op, bytes); } + } else if (auto cvtLayout = dyn_cast(op)) { + auto srcTy = cvtLayout.src().getType().cast(); + auto dstTy = cvtLayout.result().getType().cast(); + auto srcEncoding = srcTy.getEncoding(); + auto dstEncoding = dstTy.getEncoding(); + if (srcEncoding.isa() || + dstEncoding.isa()) { + // Only blocked -> blocked conversion requires for scratch allocation + return; + } + // ConvertLayoutOp with both input/output non-shared_layout + // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's + // also possible to realize it with other approaches in restricted + // conditions, such as warp-shuffle + unsigned inVec = 0; + unsigned outVec = 0; + auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec); + unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, + std::multiplies{}); + auto bytes = elems * srcTy.getElementTypeBitWidth() / 8; + allocation->addBuffer(op, bytes); } } diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 441530d25..e5472eff8 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -1,4 +1,5 @@ #include "mlir/Analysis/DataFlowAnalysis.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "llvm/Support/raw_ostream.h" #include @@ -46,6 +47,11 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) { fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility"); if (attr) divHint = attr.cast().getValue().getZExtValue(); + } else if (auto fun = dyn_cast(op)) { + Attribute attr = + fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility"); + if (attr) + divHint = attr.cast().getValue().getZExtValue(); } } DimVectorT contiguity(rank, 1); @@ -203,6 +209,13 @@ ChangeResult AxisInfoAnalysis::visitOperation( } curr = AxisInfo(contiguity, divisibility, constancy); } + // UnrealizedConversionCast + // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is + // in the process of a PartialConversion, where UnrealizedConversionCast + // may exist + if (llvm::isa(op)) { + curr = operands[0]->getValue(); + } if (curr.getRank() == 0) { return markAllPessimisticFixpoint(op->getResults()); } diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index c6f0efbd9..ef1a0c2ed 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -11,7 +11,9 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" +#include "triton/Analysis/Allocation.h" #include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" @@ -26,8 +28,11 @@ using namespace mlir; using namespace mlir::triton; using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::getElemsPerThread; 
+using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::MmaEncodingAttr; using ::mlir::triton::gpu::SharedEncodingAttr; +using ::mlir::triton::gpu::SliceEncodingAttr; namespace mlir { namespace LLVM { @@ -43,11 +48,6 @@ namespace type = mlir::triton::type; class TritonGPUToLLVMTypeConverter; -// TODO(Superjomn) Move to somewhere general utilities locates. -template size_t product(llvm::ArrayRef arr) { - return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{}); -} - // FuncOpConversion/FuncOpConversionBase is borrowed from // https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276 // since it is not exposed on header files in mlir v14 @@ -214,36 +214,6 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> { } }; -static int64_t getLinearIndex(std::vector multidim_index, - ArrayRef shape) { - assert(multidim_index.size() == shape.size()); - // sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1} - int64_t rank = shape.size(); - int64_t acc_mul = 1; - for (int64_t i = 1; i < rank; ++i) { - acc_mul *= shape[i]; - } - int64_t linear_index = 0; - for (int64_t i = 0; i < rank; ++i) { - linear_index += multidim_index[i] * acc_mul; - if (i != (rank - 1)) { - acc_mul = acc_mul / shape[i + 1]; - } - } - return linear_index; -} - -static unsigned getElemsPerThread(BlockedEncodingAttr layout, - ArrayRef shape) { - size_t rank = shape.size(); - SmallVector elemsPerThreadPerDim(rank); - for (size_t i = 0; i < rank; ++i) { - unsigned t = layout.getThreadsPerWarp()[i] * layout.getWarpsPerCTA()[i]; - elemsPerThreadPerDim[i] = (shape[i] + t - 1) / t; - } - return product(elemsPerThreadPerDim); -} - static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value) { return builder.create( @@ -309,9 +279,9 @@ static T getLinearIndex(ArrayRef multidim_index, ArrayRef shape) { } struct ConvertTritonGPUOpToLLVMPatternBase { - SmallVector + static SmallVector getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter) { SmallVector results(elems); for (unsigned i = 0; i < elems; ++i) { Type type = @@ -344,7 +314,12 @@ public: for (unsigned i = 0; i < rank; ++i) { reordered[i] = shape[order[i]]; } - return delinearize(rewriter, loc, linear, reordered); + auto reorderedMultiDim = delinearize(rewriter, loc, linear, reordered); + SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; } SmallVector delinearize(ConversionPatternRewriter &rewriter, @@ -370,13 +345,29 @@ public: return multiDim; } - // Emit indices calculation within each ConversionPattern - // TODO: [goostavz] Double confirm the redundant indices calculations will - // be eliminated in the consequent MLIR/LLVM optimization - SmallVector> - emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &b, - const BlockedEncodingAttr &blocked_layout, - ArrayRef shape) const { + Value linearize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDim, ArrayRef shape) const { + int rank = multiDim.size(); + Value linear = createIndexAttrConstant( + rewriter, loc, this->getTypeConverter()->getIndexType(), 0); + if (rank > 0) { + linear = multiDim.front(); + for (auto &&z : llvm::zip(multiDim.drop_front(), shape.drop_front())) { + Value dimSize = createIndexAttrConstant( + rewriter, loc, 
this->getTypeConverter()->getIndexType(), + std::get<1>(z)); + linear = rewriter.create( + loc, rewriter.create(loc, linear, dimSize), + std::get<0>(z)); + } + } + return linear; + } + + SmallVector + emitBaseIndexForBlockedLayout(Location loc, ConversionPatternRewriter &b, + const BlockedEncodingAttr &blocked_layout, + ArrayRef shape) const { auto llvmIndexTy = this->getTypeConverter()->getIndexType(); auto cast = b.create( loc, TypeRange{llvmIndexTy}, @@ -391,7 +382,6 @@ public: auto warpsPerCTA = blocked_layout.getWarpsPerCTA(); auto order = blocked_layout.getOrder(); unsigned rank = shape.size(); - SmallVector threadIds(rank); // step 1, delinearize threadId to get the base index SmallVector multiDimWarpId = @@ -400,8 +390,19 @@ public: delinearize(b, loc, laneId, threadsPerWarp, order); SmallVector multiDimBase(rank); for (unsigned k = 0; k < rank; ++k) { - // multiDimBase[k] = (multiDimThreadId[k] + multiDimWarpId[k] * - // threadsPerWarp[k]) * + // Wrap around multiDimWarpId/multiDimThreadId incase + // shape[k] > shapePerCTA[k] + unsigned maxWarps = + ceil(shape[k], sizePerThread[k] * threadsPerWarp[k]); + unsigned maxThreads = ceil(shape[k], sizePerThread[k]); + multiDimWarpId[k] = b.create( + loc, multiDimWarpId[k], + createIndexAttrConstant(b, loc, llvmIndexTy, maxWarps)); + multiDimThreadId[k] = b.create( + loc, multiDimThreadId[k], + createIndexAttrConstant(b, loc, llvmIndexTy, maxThreads)); + // multiDimBase[k] = (multiDimThreadId[k] + + // multiDimWarpId[k] * threadsPerWarp[k]) * // sizePerThread[k]; Value threadsPerWarpK = createIndexAttrConstant(b, loc, llvmIndexTy, threadsPerWarp[k]); @@ -413,17 +414,100 @@ public: loc, multiDimThreadId[k], b.create(loc, multiDimWarpId[k], threadsPerWarpK))); } + return multiDimBase; + } + + SmallVector> emitIndices(Location loc, + ConversionPatternRewriter &b, + const Attribute &layout, + ArrayRef shape) const { + if (auto blocked = layout.dyn_cast()) { + return emitIndicesForBlockedLayout(loc, b, blocked, shape); + } else if (auto slice = layout.dyn_cast()) { + return emitIndicesForSliceLayout(loc, b, slice, shape); + } else { + assert(0 && "emitIndices for layouts other than blocked & slice not " + "implemented yet"); + return {}; + } + } + + SmallVector> + emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &b, + const SliceEncodingAttr &sliceLayout, + ArrayRef shape) const { + auto parent = sliceLayout.getParent(); + unsigned dim = sliceLayout.getDim(); + size_t rank = shape.size(); + if (auto blockedParent = parent.dyn_cast()) { + SmallVector paddedShape(rank + 1); + for (unsigned d = 0; d < rank + 1; ++d) { + if (d < dim) { + paddedShape[d] = shape[d]; + } else if (d == dim) { + paddedShape[d] = 1; + } else { + paddedShape[d] = shape[d - 1]; + } + } + auto paddedIndices = + emitIndicesForBlockedLayout(loc, b, blockedParent, paddedShape); + unsigned numIndices = paddedIndices.size(); + SmallVector> resultIndices(numIndices); + for (unsigned i = 0; i < numIndices; ++i) { + for (unsigned d = 0; d < rank + 1; ++d) { + if (d != dim) { + resultIndices[i].push_back(paddedIndices[i][d]); + } + } + } + return resultIndices; + + } else if (auto sliceParent = parent.dyn_cast()) { + assert(0 && "emitIndicesForSliceLayout with parent of sliceLayout" + "is not implemented yet"); + return {}; + + } else { + assert(0 && "emitIndicesForSliceLayout with parent other than blocked & " + "slice not implemented yet"); + return {}; + } + } + + // Emit indices calculation within each ConversionPattern + // TODO: [goostavz] Double confirm the 
redundant indices calculations will + // be eliminated in the consequent MLIR/LLVM optimization. We might + // implement a indiceCache if necessary. + SmallVector> + emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &b, + const BlockedEncodingAttr &blockedLayout, + ArrayRef shape) const { + auto llvmIndexTy = this->getTypeConverter()->getIndexType(); + auto sizePerThread = blockedLayout.getSizePerThread(); + auto threadsPerWarp = blockedLayout.getThreadsPerWarp(); + auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); + unsigned rank = shape.size(); + SmallVector shapePerCTA(rank); + for (unsigned k = 0; k < rank; ++k) { + shapePerCTA[k] = sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]; + } + + // step 1, delinearize threadId to get the base index + auto multiDimBase = + emitBaseIndexForBlockedLayout(loc, b, blockedLayout, shape); // step 2, get offset of each element unsigned elemsPerThread = 1; SmallVector> offset(rank); SmallVector multiDimElemsPerThread(rank); for (unsigned k = 0; k < rank; ++k) { - multiDimElemsPerThread[k] = shape[k] / threadsPerWarp[k] / warpsPerCTA[k]; + multiDimElemsPerThread[k] = + ceil(shape[k], shapePerCTA[k]) * sizePerThread[k]; elemsPerThread *= multiDimElemsPerThread[k]; + // 1 block in minimum if shape[k] is less than shapePerCTA[k] for (unsigned blockOffset = 0; - blockOffset < - shape[k] / (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]); + blockOffset < ceil(shape[k], shapePerCTA[k]); ++blockOffset) for (unsigned warpOffset = 0; warpOffset < warpsPerCTA[k]; ++warpOffset) for (unsigned threadOffset = 0; threadOffset < threadsPerWarp[k]; @@ -445,7 +529,7 @@ public: std::multiplies()); SmallVector threadsPerDim(rank); for (unsigned k = 0; k < rank; ++k) { - threadsPerDim[k] = shape[k] / sizePerThread[k]; + threadsPerDim[k] = ceil(shape[k], sizePerThread[k]); } for (unsigned n = 0; n < elemsPerThread; ++n) { unsigned linearNanoTileId = n / accumSizePerThread; @@ -469,6 +553,20 @@ public: return multiDimIdx; } + + Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter, + Value smem, const Allocation *allocation, + Operation *op) const { + auto ptrTy = LLVM::LLVMPointerType::get( + this->getTypeConverter()->convertType(rewriter.getIntegerType(8)), 3); + auto bufferId = allocation->getBufferId(op); + assert(bufferId != Allocation::InvalidBufferId && "BufferId not found"); + size_t offset = allocation->getOffset(bufferId); + auto llvmIndexTy = this->getTypeConverter()->getIndexType(); + Value offVal = createIndexAttrConstant(rewriter, loc, llvmIndexTy, offset); + Value base = rewriter.create(loc, ptrTy, smem, offVal); + return base; + } }; // Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a @@ -482,19 +580,10 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, ConversionPatternRewriter &rewriter, Location loc) { auto tensorTy = resType.cast(); - auto layout = tensorTy.getEncoding().cast(); + auto layout = tensorTy.getEncoding(); auto srcType = typeConverter->convertType(elemType); auto llSrc = rewriter.create(loc, srcType, constVal); - - auto numElems = layout.getSizePerThread(); - size_t totalElems = - std::accumulate(tensorTy.getShape().begin(), tensorTy.getShape().end(), 1, - std::multiplies<>{}); - size_t numThreads = - product(layout.getWarpsPerCTA()) * product(layout.getThreadsPerWarp()); - // TODO(Superjomn) add numElemsPerThread to the layout encodings. 
- size_t numElemsPerThread = totalElems / numThreads; - + size_t numElemsPerThread = getElemsPerThread(layout, tensorTy.getShape()); llvm::SmallVector elems(numElemsPerThread, llSrc); llvm::SmallVector elemTypes(elems.size(), srcType); auto structTy = @@ -580,7 +669,7 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase { auto shape = ty.getShape(); // Here, we assume that all inputs should have a blockedLayout - unsigned valueElems = getElemsPerThread(layout, shape); + unsigned valueElems = layout.getElemsPerThread(shape); auto llvmElemTy = typeConverter->convertType(ty.getElementType()); auto llvmElemPtrPtrTy = @@ -595,16 +684,15 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase { auto ty = val.getType().cast(); // Here, we assume that all inputs should have a blockedLayout auto layout = ty.getEncoding().dyn_cast(); + assert(layout && "unexpected layout in getLayout"); auto shape = ty.getShape(); - unsigned valueElems = getElemsPerThread(layout, shape); + unsigned valueElems = layout.getElemsPerThread(shape); return std::make_tuple(layout, valueElems); } unsigned getAlignment(Value val, const BlockedEncodingAttr &layout) const { auto axisInfo = getAxisInfo(val); - auto order = layout.getOrder(); - unsigned maxMultiple = axisInfo->getDivisibility(order[0]); unsigned maxContig = axisInfo->getContiguity(order[0]); unsigned alignment = std::min(maxMultiple, maxContig); @@ -614,22 +702,18 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase { unsigned getVectorizeSize(Value ptr, const BlockedEncodingAttr &layout) const { auto axisInfo = getAxisInfo(ptr); - auto contig = axisInfo->getContiguity(); // Here order should be ordered by contiguous first, so the first element // should have the largest contiguous. 
auto order = layout.getOrder(); unsigned align = getAlignment(ptr, layout); - auto getTensorShape = [](Value val) -> ArrayRef { - auto ty = val.getType().cast(); - auto shape = ty.getShape(); - return shape; - }; - - // unsigned contigPerThread = layout.getSizePerThread()[order[0]]; - unsigned contigPerThread = getElemsPerThread(layout, getTensorShape(ptr)); + auto ty = ptr.getType().dyn_cast(); + assert(ty); + auto shape = ty.getShape(); + unsigned contigPerThread = layout.getSizePerThread()[order[0]]; unsigned vec = std::min(align, contigPerThread); + vec = std::min(shape[order[0]], vec); return vec; } @@ -819,25 +903,22 @@ struct BroadcastOpConversion auto srcShape = srcTy.getShape(); auto resultShape = resultTy.getShape(); unsigned rank = srcTy.getRank(); - // TODO: [goostavz] double confirm the op semantics with Phil assert(rank == resultTy.getRank()); SmallVector srcLogicalShape(2 * rank); SmallVector resultLogicalShape(2 * rank); SmallVector broadcastDims; - SmallVector broadcastSizes; - int64_t duplicates = 1; for (unsigned d = 0; d < rank; ++d) { - int64_t numCtas = resultShape[d] / (resultLayout.getSizePerThread()[d] * - resultLayout.getThreadsPerWarp()[d] * - resultLayout.getWarpsPerCTA()[d]); + unsigned resultShapePerCTA = resultLayout.getSizePerThread()[d] * + resultLayout.getThreadsPerWarp()[d] * + resultLayout.getWarpsPerCTA()[d]; + int64_t numCtas = ceil(resultShape[d], resultShapePerCTA); if (srcShape[d] != resultShape[d]) { assert(srcShape[d] == 1); broadcastDims.push_back(d); - broadcastSizes.push_back(resultShape[d]); srcLogicalShape[d] = 1; - srcLogicalShape[d + rank] = 1; - duplicates *= resultShape[d]; + srcLogicalShape[d + rank] = + std::max(unsigned(1), srcLayout.getSizePerThread()[d]); } else { srcLogicalShape[d] = numCtas; srcLogicalShape[d + rank] = resultLayout.getSizePerThread()[d]; @@ -845,18 +926,37 @@ struct BroadcastOpConversion resultLogicalShape[d] = numCtas; resultLogicalShape[d + rank] = resultLayout.getSizePerThread()[d]; } - unsigned srcElems = getElemsPerThread(srcLayout, srcShape); + int64_t duplicates = 1; + SmallVector broadcastSizes(broadcastDims.size() * 2); + for (auto it : llvm::enumerate(broadcastDims)) { + // Incase there are multiple indices in the src that is actually + // calculating the same element, srcLogicalShape may not need to be 1. 
+ // Such as the case when src of shape [256, 1], and with a blocked layout: + // sizePerThread: [1, 4]; threadsPerWarp: [1, 32]; warpsPerCTA: [1, 2] + int64_t d = resultLogicalShape[it.value()] / srcLogicalShape[it.value()]; + broadcastSizes[it.index()] = d; + duplicates *= d; + d = resultLogicalShape[it.value() + rank] / + srcLogicalShape[it.value() + rank]; + broadcastSizes[it.index() + broadcastDims.size()] = d; + duplicates *= d; + } + + unsigned srcElems = srcLayout.getElemsPerThread(srcShape); auto elemTy = resultTy.getElementType(); auto srcVals = getElementsFromStruct(loc, src, srcElems, rewriter); - unsigned resultElems = getElemsPerThread(resultLayout, resultShape); + unsigned resultElems = resultLayout.getElemsPerThread(resultShape); SmallVector resultVals(resultElems); for (unsigned i = 0; i < srcElems; ++i) { auto srcMultiDim = getMultiDimIndex(i, srcLogicalShape); - auto resultMultiDim = srcMultiDim; for (int64_t j = 0; j < duplicates; ++j) { + auto resultMultiDim = srcMultiDim; auto bcastMultiDim = getMultiDimIndex(j, broadcastSizes); for (auto bcastDim : llvm::enumerate(broadcastDims)) { - resultMultiDim[bcastDim.value()] = bcastMultiDim[bcastDim.index()]; + resultMultiDim[bcastDim.value()] += bcastMultiDim[bcastDim.index()]; + resultMultiDim[bcastDim.value() + rank] += + bcastMultiDim[bcastDim.index() + broadcastDims.size()] * + srcLogicalShape[bcastDim.index() + broadcastDims.size()]; } auto resultLinearIndex = getLinearIndex(resultMultiDim, resultLogicalShape); @@ -871,27 +971,29 @@ struct BroadcastOpConversion } }; -struct ViewOpConversion - : public ConvertTritonGPUOpToLLVMPattern { - using ConvertTritonGPUOpToLLVMPattern< - triton::ViewOp>::ConvertTritonGPUOpToLLVMPattern; +template +struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern { + using OpAdaptor = typename SourceOp::Adaptor; + explicit ViewLikeOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit) {} LogicalResult - matchAndRewrite(triton::ViewOp op, OpAdaptor adaptor, + matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // We cannot directly // rewriter.replaceOp(op, adaptor.src()); // due to MLIR's restrictions Location loc = op->getLoc(); - auto resultTy = op.getType().cast(); - auto resultLayout = resultTy.getEncoding().dyn_cast(); + auto resultTy = op.getType().template cast(); auto resultShape = resultTy.getShape(); - unsigned elems = getElemsPerThread(resultLayout, resultShape); + unsigned elems = getElemsPerThread(resultTy.getEncoding(), resultShape); Type elemTy = this->getTypeConverter()->convertType(resultTy.getElementType()); SmallVector types(elems, elemTy); - Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types); - auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter); + Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types); + auto vals = + this->getElementsFromStruct(loc, adaptor.src(), elems, rewriter); Value view = getStructFromElements(loc, vals, rewriter, structTy); rewriter.replaceOp(op, view); return success(); @@ -911,12 +1013,12 @@ struct MakeRangeOpConversion Location loc = op->getLoc(); auto rankedTy = op.result().getType().dyn_cast(); auto shape = rankedTy.getShape(); - auto layout = rankedTy.getEncoding().cast(); + auto layout = rankedTy.getEncoding(); auto elemTy = rankedTy.getElementType(); assert(elemTy.isInteger(32)); Value start = createIndexAttrConstant(rewriter, loc, 
elemTy, op.start()); - auto idxs = emitIndicesForBlockedLayout(loc, rewriter, layout, shape); + auto idxs = emitIndices(loc, rewriter, layout, shape); unsigned elems = idxs.size(); SmallVector retVals(elems); for (auto multiDim : llvm::enumerate(idxs)) { @@ -1184,8 +1286,9 @@ struct AddPtrOpConversion Location loc = op->getLoc(); auto resultTy = op.getType().dyn_cast(); auto resultLayout = resultTy.getEncoding().dyn_cast(); + assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion"); auto resultShape = resultTy.getShape(); - unsigned elems = getElemsPerThread(resultLayout, resultShape); + unsigned elems = resultLayout.getElemsPerThread(resultShape); Type elemTy = this->getTypeConverter()->convertType(resultTy.getElementType()); SmallVector types(elems, elemTy); @@ -1225,7 +1328,8 @@ public: auto resultLayout = resultTy.getEncoding().template dyn_cast(); auto resultShape = resultTy.getShape(); - unsigned elems = getElemsPerThread(resultLayout, resultShape); + assert(resultLayout && "Unexpected resultLayout in BinaryOpConversion"); + unsigned elems = resultLayout.getElemsPerThread(resultShape); Type elemTy = this->getTypeConverter()->convertType(resultTy.getElementType()); SmallVector types(elems, elemTy); @@ -1244,6 +1348,228 @@ public: } }; +struct ConvertLayoutOpConversion + : public ConvertTritonGPUOpToLLVMPattern { +public: + using ConvertTritonGPUOpToLLVMPattern< + triton::gpu::ConvertLayoutOp>::ConvertTritonGPUOpToLLVMPattern; + + ConvertLayoutOpConversion(LLVMTypeConverter &converter, + const Allocation *allocation, Value smem, + PatternBenefit benefit) + : ConvertTritonGPUOpToLLVMPattern(converter, + benefit), + allocation_(allocation), smem_(smem) {} + + LogicalResult + matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value src = op.src(); + Value dst = op.result(); + auto srcTy = src.getType().cast(); + auto dstTy = dst.getType().cast(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + if ((!srcLayout.isa()) || + (!dstLayout.isa())) { + // TODO: not implemented + assert(0 && + "convert_layout except for blocked -> blocked is not implemented"); + return failure(); + } + auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType()); + Value smemBase = getSharedMemoryBase(loc, rewriter, smem_, allocation_, + op.getOperation()); + auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); + smemBase = rewriter.create(loc, elemPtrTy, smemBase); + + auto shape = dstTy.getShape(); + unsigned rank = dstTy.getRank(); + auto getContigPerThread = [&](const Attribute &layout, + unsigned d) -> unsigned { + if (auto blockedLayout = layout.dyn_cast()) { + return blockedLayout.getSizePerThread()[d]; + } else { + assert(0 && "Unimplemented usage of getContigPerThread"); + return 0; + } + }; + auto getAccumElemsPerThread = [&](const Attribute &layout) -> unsigned { + if (auto blockedLayout = layout.dyn_cast()) { + return product(blockedLayout.getSizePerThread()); + } else { + assert(0 && "Unimplemented usage of getAccumElemsPerThread"); + return 0; + } + }; + auto getOrder = [&](const Attribute &layout) -> ArrayRef { + if (auto blockedLayout = layout.dyn_cast()) { + return blockedLayout.getOrder(); + } else { + assert(0 && "Unimplemented usage of getAccumElemsPerThread"); + return {}; + } + }; + SmallVector numReplicates(rank); + SmallVector inNumCTAsEachRep(rank); + SmallVector outNumCTAsEachRep(rank); + SmallVector 
inNumCTAs(rank); + SmallVector outNumCTAs(rank); + for (unsigned d = 0; d < rank; ++d) { + unsigned inPerCTA = + std::min(unsigned(shape[d]), getShapePerCTA(srcLayout, d)); + unsigned outPerCTA = + std::min(unsigned(shape[d]), getShapePerCTA(dstLayout, d)); + unsigned maxPerCTA = std::max(inPerCTA, outPerCTA); + numReplicates[d] = ceil(shape[d], maxPerCTA); + inNumCTAsEachRep[d] = maxPerCTA / inPerCTA; + outNumCTAsEachRep[d] = maxPerCTA / outPerCTA; + // TODO: confirm this + assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0); + inNumCTAs[d] = ceil(shape[d], inPerCTA); + outNumCTAs[d] = ceil(shape[d], outPerCTA); + } + // Potentially we need to store for multiple CTAs in this replication + unsigned accumNumReplicates = product(numReplicates); + unsigned accumInSizePerThread = getAccumElemsPerThread(srcLayout); + unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape()); + auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter); + unsigned inVec = 0; + unsigned outVec = 0; + auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec); + + unsigned outElems = getElemsPerThread(dstLayout, shape); + auto outOrd = getOrder(dstLayout); + SmallVector outVals(outElems); + for (unsigned repId = 0; repId < accumNumReplicates; ++repId) { + auto multiDimRepId = getMultiDimIndex(repId, numReplicates); + rewriter.create(loc); + if (auto srcBlockedLayout = srcLayout.dyn_cast()) { + processReplicaBlocked(loc, rewriter, /*stNotRd*/ true, srcTy, + inNumCTAsEachRep, multiDimRepId, inVec, + paddedRepShape, outOrd, vals, smemBase); + } else { + assert(0 && "ConvertLayout with input layout not implemented"); + return failure(); + } + rewriter.create(loc); + if (auto dstBlockedLayout = dstLayout.dyn_cast()) { + processReplicaBlocked(loc, rewriter, /*stNotRd*/ false, dstTy, + outNumCTAsEachRep, multiDimRepId, outVec, + paddedRepShape, outOrd, outVals, smemBase); + } else { + assert(0 && "ConvertLayout with output layout not implemented"); + return failure(); + } + } + + SmallVector types(outElems, llvmElemTy); + Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types); + Value result = getStructFromElements(loc, outVals, rewriter, structTy); + rewriter.replaceOp(op, result); + return success(); + } + +private: + template + SmallVector reorder(ArrayRef input, ArrayRef order) const { + size_t rank = order.size(); + assert(input.size() == rank); + SmallVector result(rank); + for (auto it : llvm::enumerate(order)) { + result[rank - 1 - it.value()] = input[it.index()]; + } + return result; + }; + + void processReplicaBlocked(Location loc, ConversionPatternRewriter &rewriter, + bool stNotRd, RankedTensorType type, + ArrayRef numCTAsEachRep, + ArrayRef multiDimRepId, unsigned vec, + ArrayRef paddedRepShape, + ArrayRef outOrd, + SmallVector &vals, Value smemBase) const { + unsigned accumNumCTAsEachRep = product(numCTAsEachRep); + auto layout = type.getEncoding().cast(); + auto rank = type.getRank(); + auto sizePerThread = layout.getSizePerThread(); + auto accumSizePerThread = product(sizePerThread); + auto llvmIndexTy = getTypeConverter()->getIndexType(); + SmallVector numCTAs(rank); + SmallVector shapePerCTA(rank); + for (unsigned d = 0; d < rank; ++d) { + shapePerCTA[d] = layout.getSizePerThread()[d] * + layout.getThreadsPerWarp()[d] * + layout.getWarpsPerCTA()[d]; + numCTAs[d] = ceil(type.getShape()[d], shapePerCTA[d]); + } + auto llvmElemTy = getTypeConverter()->convertType(type.getElementType()); + auto multiDimOffsetFirstElem = + 
emitBaseIndexForBlockedLayout(loc, rewriter, layout, type.getShape()); + for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) { + auto multiDimCTAInRepId = + getMultiDimIndex(ctaId, numCTAsEachRep); + SmallVector multiDimCTAId(rank); + for (auto it : llvm::enumerate(multiDimCTAInRepId)) { + auto d = it.index(); + multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value(); + } + + unsigned linearCTAId = getLinearIndex(multiDimCTAId, numCTAs); + // TODO: This is actually redundant index calculation, we should + // consider of caching the index calculation result in case + // of performance issue observed. + // for (unsigned elemId = linearCTAId * accumSizePerThread; + // elemId < (linearCTAId + 1) * accumSizePerThread; elemId += vec) { + for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) { + auto multiDimElemId = + getMultiDimIndex(elemId, layout.getSizePerThread()); + SmallVector multiDimOffset(rank); + for (unsigned d = 0; d < rank; ++d) { + multiDimOffset[d] = rewriter.create( + loc, multiDimOffsetFirstElem[d], + createIndexAttrConstant(rewriter, loc, llvmIndexTy, + multiDimCTAInRepId[d] * shapePerCTA[d] + + multiDimElemId[d])); + } + Value offset = + linearize(rewriter, loc, reorder(multiDimOffset, outOrd), + reorder(paddedRepShape, outOrd)); + auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); + Value ptr = + rewriter.create(loc, elemPtrTy, smemBase, offset); + auto vecTy = VectorType::get(vec, llvmElemTy); + ptr = rewriter.create( + loc, LLVM::LLVMPointerType::get(vecTy, 3), ptr); + if (stNotRd) { + Value valVec = rewriter.create(loc, vecTy); + for (unsigned v = 0; v < vec; ++v) { + Value vVal = createIndexAttrConstant( + rewriter, loc, getTypeConverter()->getIndexType(), v); + valVec = rewriter.create( + loc, vecTy, valVec, + vals[elemId + linearCTAId * accumSizePerThread + v], vVal); + } + rewriter.create(loc, valVec, ptr); + } else { + Value valVec = rewriter.create(loc, ptr); + for (unsigned v = 0; v < vec; ++v) { + Value vVal = createIndexAttrConstant( + rewriter, loc, getTypeConverter()->getIndexType(), v); + vals[elemId + linearCTAId * accumSizePerThread + v] = + rewriter.create(loc, llvmElemTy, valVec, + vVal); + } + } + } + } + } + + const Allocation *allocation_; + Value smem_; +}; + class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter { public: using TypeConverter::convertType; @@ -1266,9 +1592,10 @@ public: llvm::Optional convertTritonTensorType(RankedTensorType type) { Attribute layout = type.getEncoding(); - if (auto blocked_layout = layout.dyn_cast()) { + if (layout && (layout.isa() || + layout.isa())) { unsigned numElementsPerThread = - getElemsPerThread(blocked_layout, type.getShape()); + getElemsPerThread(layout, type.getShape()); SmallVector types(numElementsPerThread, convertType(type.getElementType())); return LLVM::LLVMStructType::getLiteral(&getContext(), types); @@ -1285,7 +1612,8 @@ public: void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, - AxisInfoAnalysis &analysis, + AxisInfoAnalysis &axisInfoAnalysis, + const Allocation *allocation, Value smem, PatternBenefit benefit = 1) { patterns.add(typeConverter, benefit); patterns.add>(typeConverter, @@ -1296,17 +1624,19 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, benefit); patterns.add>(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, numWarps, benefit); patterns.add(typeConverter, benefit); + 
patterns.add(typeConverter, allocation, smem, + benefit); patterns.add(typeConverter, benefit); - patterns.add(typeConverter, analysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); - patterns.add(typeConverter, analysis, benefit); - patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add>(typeConverter, benefit); + patterns.add>(typeConverter, + benefit); } class ConvertTritonGPUToLLVM @@ -1322,19 +1652,34 @@ public: // TODO: need confirm option.overrideIndexBitwidth(32); TritonGPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMFunctionConversionTarget funcTarget(*context, typeConverter); TritonLLVMConversionTarget target(*context, typeConverter); - RewritePatternSet patterns(context); - int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + // step 1: Convert FuncOp to LLVMFuncOp via partial conversion + // step 2: Allocate for shared memories + // step 3: Convert the rest of ops via partial conversion + // The reason for a seperation between 1/3 is that, step 2 is out of + // the scope of Dialect Conversion, thus we need to make sure the smem_ + // is not revised during the conversion of step 3. + RewritePatternSet func_patterns(context); + func_patterns.add(typeConverter, numWarps, 1 /*benefit*/); + if (failed( + applyPartialConversion(mod, funcTarget, std::move(func_patterns)))) + return signalPassFailure(); + + Allocation allocation(mod); auto axisAnalysis = runAxisAnalysis(mod); + initSharedMemory(allocation.getSharedMemorySize(), typeConverter); // We set a higher benefit here to ensure triton's patterns runs before // arith patterns for some encoding not supported by the community // patterns. + RewritePatternSet patterns(context); populateTritonToLLVMPatterns(typeConverter, patterns, numWarps, - *axisAnalysis, 10 /*benefit*/); + *axisAnalysis, &allocation, smem_, + 10 /*benefit*/); // Add arith/math's patterns to help convert scalar expression to LLVM. 
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, @@ -1352,10 +1697,35 @@ protected: auto axisAnalysisPass = std::make_unique(module->getContext()); axisAnalysisPass->run(module); + return axisAnalysisPass; } + + void initSharedMemory(size_t size, + TritonGPUToLLVMTypeConverter &typeConverter); + + Value smem_; }; +void ConvertTritonGPUToLLVM::initSharedMemory( + size_t size, TritonGPUToLLVMTypeConverter &typeConverter) { + ModuleOp mod = getOperation(); + OpBuilder b(mod.getBodyRegion()); + auto loc = mod.getLoc(); + auto elemTy = typeConverter.convertType(b.getIntegerType(8)); + auto arrayTy = LLVM::LLVMArrayType::get(elemTy, size); + auto global = b.create( + loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::Internal, + "global_smem", /*value=*/Attribute(), + /*alignment=*/0, mlir::gpu::GPUDialect::getWorkgroupAddressSpace()); + SmallVector funcs; + mod.walk([&](LLVM::LLVMFuncOp func) { funcs.push_back(func); }); + assert(funcs.size() == 1 && + "Inliner pass is expected before TritonGPUToLLVM"); + b.setInsertionPointToStart(&funcs[0].getBody().front()); + smem_ = b.create(loc, global); +} + } // namespace namespace mlir { @@ -1366,10 +1736,20 @@ TritonLLVMConversionTarget::TritonLLVMConversionTarget( addLegalDialect(); addLegalDialect(); // addIllegalDialect(); + // addIllegalDialect(); addIllegalDialect(); addLegalOp(); } +TritonLLVMFunctionConversionTarget::TritonLLVMFunctionConversionTarget( + MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter) + : ConversionTarget(ctx), typeConverter(typeConverter) { + addLegalDialect(); + // addLegalDialect(); + addIllegalOp(); + addLegalOp(); +} + namespace triton { std::unique_ptr> createConvertTritonGPUToLLVMPass() { diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index d337ed662..8ccaf6f3d 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -39,6 +39,37 @@ static Type getPointeeType(Type type) { return Type(); } +namespace gpu { + +// TODO: Inheritation of layout attributes +unsigned getElemsPerThread(Attribute layout, ArrayRef shape) { + size_t rank = shape.size(); + if (auto blockedLayout = layout.dyn_cast()) { + return blockedLayout.getElemsPerThread(shape); + } else if (auto sliceLayout = layout.dyn_cast()) { + return sliceLayout.getElemsPerThread(shape); + } else if (auto mmaLayout = layout.dyn_cast()) { + return mmaLayout.getElemsPerThread(shape); + } else if (auto sharedLayout = layout.dyn_cast()) { + return sharedLayout.getElemsPerThread(shape); + } else { + assert(0 && "getElemsPerThread not implemented"); + return 0; + } +} + +unsigned getShapePerCTA(const Attribute &layout, unsigned d) { + if (auto blockedLayout = layout.dyn_cast()) { + return blockedLayout.getSizePerThread()[d] * + blockedLayout.getThreadsPerWarp()[d] * + blockedLayout.getWarpsPerCTA()[d]; + } else { + assert(0 && "Unimplemented usage of getShapePerCTA"); + return 0; + } +}; + +} // namespace gpu } // namespace triton } // namespace mlir @@ -108,6 +139,55 @@ SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) { return SliceEncodingAttr::get(getContext(), axis, *this); } +unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef shape) const { + size_t rank = shape.size(); + assert(rank == getSizePerThread().size() && + "unexpected rank in BlockedEncodingAttr::getElemsPerThread"); + SmallVector elemsPerThreadPerDim(rank); + for (size_t i = 0; i < rank; ++i) { + unsigned t = + getSizePerThread()[i] * getThreadsPerWarp()[i] * getWarpsPerCTA()[i]; + 
elemsPerThreadPerDim[i] = + ceil(shape[i], t) * getSizePerThread()[i]; + } + return product(elemsPerThreadPerDim); +} + +unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef shape) const { + size_t rank = shape.size(); + auto parent = getParent(); + unsigned dim = getDim(); + if (auto blockedParent = parent.dyn_cast()) { + assert(rank == blockedParent.getSizePerThread().size() - 1 && + "unexpected rank in SliceEncodingAttr::getElemsPerThread"); + SmallVector paddedShape(rank + 1); + for (unsigned d = 0; d < rank + 1; ++d) { + if (d < dim) + paddedShape[d] = shape[d]; + else if (d == dim) + paddedShape[d] = 1; + else + paddedShape[d] = shape[d - 1]; + } + return blockedParent.getElemsPerThread(paddedShape); + } else { + assert(0 && "getElemsPerThread not implemented"); + return 0; + } +} + +unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef shape) const { + // TODO: + assert(0 && "MmaEncodingAttr::getElemsPerThread not implemented"); + return 0; +} + +unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef shape) const { + // TODO: + assert(0 && "SharedEncodingAttr::getElemsPerThread not implemented"); + return 0; +} + //===----------------------------------------------------------------------===// // Blocked Encoding //===----------------------------------------------------------------------===// diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 03ebb1578..5a04496d5 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -14,6 +14,7 @@ #include "mlir/Transforms/Passes.h" #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h" #include "triton/driver/llvm.h" +#include "triton/tools/sys/getenv.hpp" #include "llvm/IR/Constants.h" namespace mlir { @@ -124,6 +125,17 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) { mlir::PassManager pm(module->getContext()); applyPassManagerCLOptions(pm); + auto printingFlags = mlir::OpPrintingFlags(); + printingFlags.elideLargeElementsAttrs(16); + pm.enableIRPrinting( + /*shouldPrintBeforePass=*/nullptr, + /*shouldPrintAfterPass=*/ + [](mlir::Pass *pass, mlir::Operation *) { + return ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"); + }, + /*printModuleScope=*/false, + /*printAfterOnlyOnChange=*/true, + /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags); pm.addPass(createConvertTritonGPUToLLVMPass()); // Conanicalize to eliminate the remaining UnrealizedConversionCastOp diff --git a/python/src/triton.cc b/python/src/triton.cc index 20d12ce56..6810d72fd 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -19,6 +19,7 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Target/LLVMIR/LLVMIRTranslation.h" #include "triton/Target/PTX/PTXTranslation.h" +#include "triton/tools/sys/getenv.hpp" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" @@ -100,14 +101,6 @@ long pow2_divisor(long N) { return 1; } -bool getBoolEnv(const std::string &env) { - const char *s = std::getenv(env.c_str()); - std::string str(s ? s : ""); - std::transform(str.begin(), str.end(), str.begin(), - [](unsigned char c) { return std::tolower(c); }); - return (str == "on" || str == "true" || str == "1"); -} - // Returns something like "int16", whether dtype is a torch.dtype or // triton.language.dtype. 
std::string dtype_cache_key_part(const py::object &dtype) { @@ -1635,7 +1628,7 @@ void init_triton_ir(py::module &&m) { /*shouldPrintBeforePass=*/nullptr, /*shouldPrintAfterPass=*/ [](mlir::Pass *pass, mlir::Operation *) { - return getBoolEnv("MLIR_ENABLE_DUMP"); + return ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"); }, /*printModuleScope=*/false, /*printAfterOnlyOnChange=*/true, diff --git a/python/tests/test_transpose.py b/python/tests/test_transpose.py new file mode 100644 index 000000000..3daee7bfe --- /dev/null +++ b/python/tests/test_transpose.py @@ -0,0 +1,68 @@ +import pytest +import torch +from torch.testing import assert_allclose + +import triton +import triton.language as tl +import triton.runtime as runtime + + +@triton.jit +def kernel(x_ptr, stride_xm, + z_ptr, stride_zn, + SIZE_M: tl.constexpr, SIZE_N: tl.constexpr): + off_m = tl.arange(0, SIZE_M) + off_n = tl.arange(0, SIZE_N) + Xs = x_ptr + off_m[:, None] * stride_xm + off_n[None, :] * 1 + Zs = z_ptr + off_m[:, None] * 1 + off_n[None, :] * stride_zn + tl.store(Zs, tl.load(Xs)) + +# These sizes cover the case of: +# - blocked layout and sliced layout with block parent +# -- blocked layout in which sizePerThread/threadsPerWarp/warpsPerCTA +# need/need not to be wrapped +# -- sliced layout incase sizePerThread need to be wrapped +# -- different orders +# - LayoutConversion from blocked -> blocked +# - tt.Broadcast which requires for broadcast in either/both of +# CTA/perThread level + +# What is not covered and requires for TODO: +# - vectorization load/store of shared memory +# - multiple replication of layout conversion + + +@pytest.mark.parametrize('NUM_WARPS,SIZE_M,SIZE_N', [ + [1, 16, 16], + [1, 32, 32], + [1, 32, 64], + [2, 64, 128], + [2, 128, 64] +]) +def test_convert_layout_impl(NUM_WARPS, SIZE_M, SIZE_N): + # TODO: this is to initialize the cuda context since it is not properly + # dealed with in the existing runtime, remove this when the runtime + # is updated + torch.zeros([10], device=torch.device('cuda')) + device = torch.cuda.current_device() + binary = runtime.build_kernel(kernel, + "*fp32,i32,*fp32,i32", + constants={"SIZE_M": SIZE_M, + "SIZE_N": SIZE_N}, + num_warps=NUM_WARPS, + num_stages=3) + grid = lambda META: (1, ) + + x = torch.randn((SIZE_M, SIZE_N), device='cuda', dtype=torch.float32) + z = torch.empty((SIZE_N, SIZE_M), device=x.device, dtype=x.dtype) + runtime.launch_kernel(kernel=binary, + device=device, + grid=grid, + x_ptr=x, + stride_xm=x.stride(0), + z_ptr=z, + stride_zn=z.stride(0), + SIZE_M=tl.constexpr(SIZE_M), + SIZE_N=tl.constexpr(SIZE_N)) + golden_z = torch.t(x) + assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7) diff --git a/python/tests/test_vecadd_no_scf.py b/python/tests/test_vecadd_no_scf.py index 780d76926..6f72171ab 100644 --- a/python/tests/test_vecadd_no_scf.py +++ b/python/tests/test_vecadd_no_scf.py @@ -48,6 +48,7 @@ def vecadd_no_scf_tester(num_warps, block_size): def test_vecadd_no_scf(): + vecadd_no_scf_tester(num_warps=4, block_size=256) vecadd_no_scf_tester(num_warps=2, block_size=256) vecadd_no_scf_tester(num_warps=1, block_size=256) diff --git a/test/Conversion/triton_to_llvm.mlir b/test/Conversion/triton_to_llvm.mlir deleted file mode 100644 index ec16d2fbf..000000000 --- a/test/Conversion/triton_to_llvm.mlir +++ /dev/null @@ -1,37 +0,0 @@ -// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 -convert-triton-gpu-to-llvm | FileCheck %s - -func @test_splat(%ptr: !tt.ptr) { - // Here, 128 elements, 64(2*32) threads, so each need to process 2 
elements - // - // CHECK: %0 = llvm.bitcast %arg0 : !llvm.ptr to !llvm.ptr - // CHECK: %1 = llvm.mlir.undef : !llvm.struct<(ptr, ptr)> - // CHECK: %2 = llvm.insertvalue %0, %1[0] : !llvm.struct<(ptr, ptr)> - // CHECK: %3 = llvm.insertvalue %0, %2[1] : !llvm.struct<(ptr, ptr)> - %ptrs = tt.splat %ptr : (!tt.ptr) -> tensor<128x!tt.ptr> - %a = arith.constant 1.0 : f32 - %true = arith.constant 1 : i1 - %b = tt.splat %a : (f32) -> tensor<128xf32> - - // Here, each thread process only 1 element - // CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(i1)> - %mask = tt.splat %true : (i1) -> tensor<64xi1> - - return -} - -// ----- - -func @test_store_splat(%ptr: !tt.ptr) { - %ptrs = tt.splat %ptr : (!tt.ptr) -> tensor<128x!tt.ptr> - %a = arith.constant 1.0 : f32 - %true = arith.constant 1 : i1 - - %vs = tt.splat %a : (f32) -> tensor<128xf32> - %mask = tt.splat %true : (i1) -> tensor<128xi1> - - // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", - // CHECK-SAME: "r,l,b" %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr, i1) -> !llvm.void - tt.store %ptrs, %vs, %mask : tensor<128xf32> - - return -} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 1d84f6275..59ae739c6 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1,16 +1,13 @@ // RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s module attributes {"triton_gpu.num-warps" = 4 : i32} { - -// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr) -// Here the 128 comes from the 4 in module attribute multiples 32 -// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}} -func @test_empty_kernel(%lb : index, %A : !tt.ptr) { - - // CHECK: llvm.return - return -} - + // CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr) + // Here the 128 comes from the 4 in module attribute multiples 32 + // CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}} + func @test_empty_kernel(%lb : index, %A : !tt.ptr) { + // CHECK: llvm.return + return + } } // end module // ----- @@ -58,7 +55,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- -// TODO: Pending on the support of isSplat constant +// TODO: masked load with vectorization is pending on TODO #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: masked_load_const_other @@ -71,10 +68,23 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // ----- +// TODO: masked load with vectorization is pending on TODO +#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: masked_load_const_other_vec + func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> + %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> + return + } +} + +// ----- + #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> module attributes {"triton_gpu.num-warps" = 2 : i32} { - // CHECK-LABEL: kernel__Pfp32_Pfp32_Pfp32_i32__3c256 - func 
@kernel__Pfp32_Pfp32_Pfp32_i32__3c256(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr {tt.divisibility = 4 : i32}, %arg2: !tt.ptr {tt.divisibility = 4 : i32}, %arg3: i32) { + // CHECK-LABEL: global_load_store_no_vec + func @global_load_store_no_vec(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr {tt.divisibility = 4 : i32}, %arg2: !tt.ptr {tt.divisibility = 4 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 %0 = tt.get_program_id {axis = 0 : i32} : i32 %1 = arith.muli %0, %c256_i32 : i32 @@ -86,22 +96,107 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0> - // CHECK: ld.global.v4.b32 + // Load 4 elements from vector0 + // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + + // Load 4 elements from vector1 + // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ]; %9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> - // CHECK: ld.global.v4.b32 %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> %12 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0> // Store 4 elements to global - // CHECK: st.global.b32.v4 + // CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} }; + // CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} }; + // CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} }; + // CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} }; tt.store %13, %11 : tensor<256xf32, #blocked0> return } } +// ----- +#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> +module attributes {"triton_gpu.num-warps" = 2 : i32} { + // CHECK-LABEL: global_load_store_vec4 + func @global_load_store_vec4(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr {tt.divisibility = 4 : i32}, %arg2: !tt.ptr {tt.divisibility = 4 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0> + %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0> + + // Load 4 elements from A with single one vectorized load instruction + // CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + + // Load 4 elements from B with single one vectorized load instruction + // CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + + %9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> + %10 = 
tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + %12 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0> + + // Store 4 elements to global with single one vectorized store instruction + // CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }; + tt.store %13, %11 : tensor<256xf32, #blocked0> + return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: global_load_store_vec8 + func @global_load_store_vec8(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr {tt.divisibility = 4 : i32}, %arg2: !tt.ptr {tt.divisibility = 4 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id {axis = 0 : i32} : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0> + %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0> + + // Load 8 elements from A with two vectorized load instruction + // CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + + // Load 8 elements from B with two vectorized load instruction + // CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + // CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ]; + + %9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> + %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + %12 = tt.splat %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0> + + // Store 8 elements to global with two vectorized store instruction + // CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }; + // CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }; + tt.store %13, %11 : tensor<256xf32, #blocked0> + return + } +} // TODO: Add a testcase to verify the optimization when ptr of the LoadOp // is from an addptr with const idx @@ -217,10 +312,121 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_store func @basic_store(%ptrs: tensor<256x!tt.ptr, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) { // CHECK: llvm.inline_asm has_side_effects asm_dialect = att - // CHECK-SAME: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,l,b" %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr, i1) -> !llvm.void + // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} }; // CHECK: llvm.inline_asm has_side_effects asm_dialect = att - // CHECK-SAME: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };", "r,l,b" %{{.*}}, %{{.*}}, %{{.*}} : (i32, !llvm.ptr, i1) -> !llvm.void + // CHECK-SAME: st.global.b32 [ 
${{.*}} + 0 ], { ${{.*}} }; tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0> return } } + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [0, 1]}> +module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<1088 x i8> + // CHECK-LABEL: convert_layout_blocked_blocked + func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) { + // CHECK: llvm.mlir.addressof @global_smem + // CHECK: nvvm.barrier0 + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: nvvm.barrier0 + // CHECK: llvm.load + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.load + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.load + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.load + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.load + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.load + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.load + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.load + // CHECK-SAME: !llvm.ptr, 3> + %0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1> + return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<1280 x i8> + // CHECK-LABEL: convert_layout_blocked_blocked_vec + func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) { + // CHECK: llvm.mlir.addressof @global_smem + // CHECK: nvvm.barrier0 + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: nvvm.barrier0 + // CHECK: llvm.load + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.load + // CHECK-SAME: !llvm.ptr, 3> + %0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1> + return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<640 x i8> + // CHECK-LABEL: convert_layout_blocked_blocked_multi_rep + func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) { + // CHECK: llvm.mlir.addressof @global_smem + // CHECK: nvvm.barrier0 + // CHECK: llvm.store + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: nvvm.barrier0 + // CHECK: llvm.load + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.load + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: nvvm.barrier0 + // CHECK: 
llvm.store + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: nvvm.barrier0 + // CHECK: llvm.load + // CHECK-SAME: !llvm.ptr, 3> + // CHECK: llvm.load + // CHECK-SAME: !llvm.ptr, 3> + %0 = triton_gpu.convert_layout %arg0 : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #blocked1> + return + } +} + +// TODO: problems in MLIR's parser on slice layout +// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> +// module attributes {"triton_gpu.num-warps" = 1 : i32} { +// func @make_range_sliced_layout() { +// %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> +// return +// } +// } \ No newline at end of file
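
The @global_smem sizes asserted in the three convert_layout tests above (1088, 1280 and 640 bytes for a 16x16 f32 tensor) are consistent with a simple sizing rule. The sketch below is a back-of-the-envelope reconstruction, not the scratch-size computation performed by the conversion pass itself: the rule (per-dimension maximum of the two per-CTA tile shapes, clamped to the tensor shape, then padded along the fastest-varying dimension by the shared vector width, or by one element when the two orders differ) and the helper names are assumptions made for illustration only.

def shape_per_cta(size_per_thread, threads_per_warp, warps_per_cta):
    # Elements covered by one CTA in each dimension before wrapping.
    return [s * t * w for s, t, w in zip(size_per_thread, threads_per_warp, warps_per_cta)]


def scratch_bytes(shape, src, dst, elem_bytes=4):
    # src/dst are (sizePerThread, threadsPerWarp, warpsPerCTA, order) tuples
    # taken from the #blocked attributes in the tests above.
    src_size, src_threads, src_warps, src_order = src
    dst_size, dst_threads, dst_warps, dst_order = dst
    src_cta = shape_per_cta(src_size, src_threads, src_warps)
    dst_cta = shape_per_cta(dst_size, dst_threads, dst_warps)
    # One replication of the exchange: the larger of the two CTA tiles per
    # dimension, never exceeding the tensor itself (assumed rule).
    rep = [max(min(s, a), min(s, b)) for s, a, b in zip(shape, src_cta, dst_cta)]
    # Assumed padding: the common per-thread vector width along the fastest
    # axis when both layouts agree on the order, otherwise one element.
    vec = min(src_size[src_order[0]], dst_size[dst_order[0]]) if src_order == dst_order else 1
    padded = list(rep)
    padded[src_order[0]] += vec
    elems = 1
    for d in padded:
        elems *= d
    return elems * elem_bytes


blocked0 = ([1, 4], [8, 4], [1, 1], [1, 0])    # #blocked0 in all three tests
blocked1a = ([4, 1], [4, 8], [1, 1], [0, 1])   # #blocked1, convert_layout_blocked_blocked
blocked1b = ([1, 4], [16, 2], [1, 1], [1, 0])  # #blocked1, convert_layout_blocked_blocked_vec
blocked1c = ([1, 4], [4, 8], [1, 1], [1, 0])   # #blocked1, convert_layout_blocked_blocked_multi_rep

print(scratch_bytes([16, 16], blocked0, blocked1a))  # 1088
print(scratch_bytes([16, 16], blocked0, blocked1b))  # 1280
print(scratch_bytes([16, 16], blocked0, blocked1c))  # 640

Running the script prints 1088, 1280 and 640, matching the sizes of the @global_smem arrays in the CHECK lines of the three convert_layout tests.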