diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 9fcffd8a6..9621a580a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -28,6 +28,8 @@ lib/Analysis/Utility.cpp @Jokeren # ---------- # Pipeline pass lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @daadaada +# Prefetch pass +lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @daadaada # Coalesce pass lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @ptillet # Layout simplification pass diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index 3c2953fe1..fb4a64607 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -4,6 +4,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" @@ -30,7 +31,15 @@ public: virtual LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, - Attribute &resultEncoding) const = 0; + Attribute &resultEncoding, + Optional location) const = 0; + + // Note: this function only verify operand encoding but doesn't infer result + // encoding + virtual LogicalResult + inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, + Attribute retEncoding, + Optional location) const = 0; }; } // namespace triton diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 85e298dd4..8216a6317 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -330,7 +330,6 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> { }]; } - def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> { let mnemonic = "dot_op"; diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 4a8824193..488c6a72d 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -37,7 +37,7 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> { // Port Arith_CmpIOp & Arith_CmpFOp & Std_SelectOp to TritonGPU. // This is needed because these ops don't // handle encodings -// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111 +// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111 def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> { let summary = "integer comparison operation"; diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h index 4b3b3e8b6..f70350d89 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -6,6 +6,9 @@ namespace mlir { std::unique_ptr createTritonGPUPipelinePass(int numStages = 2); +// TODO(Keren): prefetch pass not working yet +std::unique_ptr createTritonGPUPrefetchPass(); + std::unique_ptr createTritonGPUCanonicalizeLoopsPass(); std::unique_ptr createTritonGPUSwizzlePass(); diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index 493f9afd7..8f3f0f32f 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -7,7 +7,7 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { let summary = "pipeline"; let description = [{ - TODO + Unroll loops to hide global memory -> shared memory latency. }]; let constructor = "mlir::createTritonGPUPipelinePass()"; @@ -23,6 +23,20 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { ]; } +def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> { + let summary = "prefetch"; + + let description = [{ + Prefetch operands (a and b) of tt.dot into shared memory to hide shared memory -> register latency. + }]; + + let constructor = "mlir::createTritonGPUPrefetchPass()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithmeticDialect"]; +} + def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> { let summary = "coalesce"; diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 35795376b..e70dc935e 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -12,6 +12,7 @@ #include using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getSizePerThread; @@ -26,6 +27,26 @@ namespace mlir { //===----------------------------------------------------------------------===// namespace triton { +static std::pair, SmallVector> +getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) { + auto srcBlockedLayout = srcLayout.dyn_cast(); + auto srcMmaLayout = srcLayout.dyn_cast(); + auto srcDotLayout = srcLayout.dyn_cast(); + auto dstBlockedLayout = dstLayout.dyn_cast(); + auto dstMmaLayout = dstLayout.dyn_cast(); + auto dstDotLayout = dstLayout.dyn_cast(); + assert(!(srcMmaLayout && dstMmaLayout) && + "Unexpected mma -> mma layout conversion"); + // mma or dot layout does not have an order, so the order depends on the + // layout of the other operand. + auto inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout) + : getOrder(srcLayout); + auto outOrd = (dstMmaLayout || dstDotLayout) ? getOrder(srcLayout) + : getOrder(dstLayout); + + return {inOrd, outOrd}; +} + SmallVector getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, unsigned &outVec) { @@ -35,16 +56,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, Attribute dstLayout = dstTy.getEncoding(); assert(srcLayout && dstLayout && "Unexpect layout in getScratchConfigForCvtLayout()"); - unsigned rank = dstTy.getRank(); - SmallVector paddedRepShape(rank); - auto srcBlockedLayout = srcLayout.dyn_cast(); - auto srcMmaLayout = srcLayout.dyn_cast(); - auto dstBlockedLayout = dstLayout.dyn_cast(); - auto dstMmaLayout = dstLayout.dyn_cast(); - assert(!(srcMmaLayout && dstMmaLayout) && - "Unexpected mma -> mma layout conversion"); - auto inOrd = srcMmaLayout ? getOrder(dstLayout) : getOrder(srcLayout); - auto outOrd = dstMmaLayout ? getOrder(srcLayout) : getOrder(dstLayout); + auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout); unsigned srcContigPerThread = getSizePerThread(srcLayout)[inOrd[0]]; unsigned dstContigPerThread = getSizePerThread(dstLayout)[outOrd[0]]; // TODO: Fix the legacy issue that ourOrd[0] == 0 always means @@ -55,6 +67,8 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, auto srcShapePerCTA = getShapePerCTA(srcLayout); auto dstShapePerCTA = getShapePerCTA(dstLayout); + unsigned rank = dstTy.getRank(); + SmallVector paddedRepShape(rank); unsigned pad = std::max(inVec, outVec); for (unsigned d = 0; d < rank; ++d) { paddedRepShape[d] = @@ -143,8 +157,6 @@ private: /// Initializes temporary shared memory for a given operation. void getScratchValueSize(Operation *op) { - // TODO(Keren): Add atomic ops - // TODO(Keren): Add convert ops if (auto reduceOp = dyn_cast(op)) { // TODO(Keren): Reduce with index is not supported yet. auto value = op->getOperand(0); @@ -167,7 +179,7 @@ private: auto dstEncoding = dstTy.getEncoding(); if (srcEncoding.isa() || dstEncoding.isa()) { - // Only blocked -> blocked conversion requires for scratch allocation + // Conversions from/to shared memory do not need scratch memory. return; } // ConvertLayoutOp with both input/output non-shared_layout diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 30711eb5c..4df17be02 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -2326,6 +2326,19 @@ private: LogicalResult lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const; + + // shared -> dot_operand if the result layout is mma + Value lowerSharedToDotOperandMMA( + triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout, + const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const; + + // shared -> dot_operand if the result layout is blocked + Value lowerSharedToDotOperandBlocked( + triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + const BlockedEncodingAttr &blockedLayout, + const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const; }; void ConvertLayoutOpConversion::processReplica( @@ -3011,6 +3024,7 @@ public: Value i8Elems[4][4]; Type elemTy = type::i8Ty(ctx); Type elemPtrTy = ptr_ty(elemTy); + Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4); if (kOrder == 1) { for (int i = 0; i < 2; ++i) for (int j = 0; j < 4; ++j) @@ -3025,7 +3039,7 @@ public: for (int e = 0; e < 4; ++e) i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m], i8Elems[m][e], i32_val(e)); - i32Elems[m] = bitcast(i8v4Elems[m], i32_ty); + i32Elems[m] = bitcast(i8v4Elems[m], i8x4Ty); } } else { // k first for (int j = 0; j < 4; ++j) @@ -3041,7 +3055,7 @@ public: for (int e = 0; e < 4; ++e) i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m], i8Elems[m][e], i32_val(e)); - i32Elems[m] = bitcast(i8v4Elems[m], i32_ty); + i32Elems[m] = bitcast(i8v4Elems[m], i8x4Ty); } } @@ -3725,8 +3739,7 @@ struct MMA16816ConversionHelper { loadFn(2 * m, 2 * k); // step2. Format the values to LLVM::Struct to passing to mma codegen. - Value result = composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK); - return result; + return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK); } // Loading $b from smem to registers, returns a LLVM::Struct. @@ -3963,31 +3976,14 @@ private: } }; -LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand( +Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA( triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout, + const DotOperandEncodingAttr &dotOperandLayout, bool isOuter) const { auto loc = op.getLoc(); Value src = op.src(); Value dst = op.result(); auto dstTensorTy = dst.getType().cast(); - - auto dotOperandLayout = - dstTensorTy.getEncoding().cast(); - - MmaEncodingAttr mmaLayout = - dotOperandLayout.getParent().dyn_cast_or_null(); - assert(mmaLayout); - - bool isOuter{}; - { - int K{}; - if (dotOperandLayout.getOpIdx() == 0) // $a - K = dstTensorTy.getShape()[1]; - else // $b - K = dstTensorTy.getShape()[0]; - isOuter = K == 1; - } - // TODO[Superjomn]: the allowTF32 is not available in ConvertLayoutOp for it // is an attribute of DotOp. bool allowTF32 = false; @@ -4023,6 +4019,41 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand( } else { assert(false && "Unsupported mma layout found"); } + return res; +} + +LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand( + triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + Value src = op.src(); + Value dst = op.result(); + auto dstTensorTy = dst.getType().cast(); + auto srcTensorTy = src.getType().cast(); + auto dotOperandLayout = + dstTensorTy.getEncoding().cast(); + auto sharedLayout = srcTensorTy.getEncoding().cast(); + + bool isOuter{}; + int K{}; + if (dotOperandLayout.getOpIdx() == 0) // $a + K = dstTensorTy.getShape()[sharedLayout.getOrder()[0]]; + else // $b + K = dstTensorTy.getShape()[sharedLayout.getOrder()[1]]; + isOuter = K == 1; + + Value res; + if (auto mmaLayout = + dotOperandLayout.getParent().dyn_cast_or_null()) { + res = lowerSharedToDotOperandMMA(op, adaptor, rewriter, mmaLayout, + dotOperandLayout, isOuter); + } else if (auto blockedLayout = + dotOperandLayout.getParent() + .dyn_cast_or_null()) { + assert(false && "Blocked layout is not supported yet"); + } else { + assert(false && "Unsupported dot operand layout found"); + } rewriter.replaceOp(op, res); return success(); @@ -4046,23 +4077,13 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor, auto ATensorTy = A.getType().cast(); auto BTensorTy = B.getType().cast(); - Value loadedA, loadedB, loadedC; - // We support two kinds of operand layouts: 1. both $a, $b are dot_operand - // layout, 2. both of them are shared layout. - if (ATensorTy.getEncoding().isa()) { - assert(BTensorTy.getEncoding().isa() && - "Both $a and %b should be DotOperand layout."); - loadedA = adaptor.a(); - loadedB = adaptor.b(); - } else { - SharedMemoryObject smemA = - getSharedMemoryObjectFromStruct(loc, adaptor.a(), rewriter); - SharedMemoryObject smemB = - getSharedMemoryObjectFromStruct(loc, adaptor.b(), rewriter); - loadedA = mmaHelper.loadA(op.a(), smemA); - loadedB = mmaHelper.loadB(op.b(), smemB); - } + assert(ATensorTy.getEncoding().isa() && + BTensorTy.getEncoding().isa() && + "Both $a and %b should be DotOperand layout."); + Value loadedA, loadedB, loadedC; + loadedA = adaptor.a(); + loadedB = adaptor.b(); loadedC = mmaHelper.loadC(op.c(), adaptor.c()); return mmaHelper.convertDot(A, B, C, op.d(), loadedA, loadedB, loadedC, op, @@ -4753,20 +4774,26 @@ public: auto mmaLayout = dot_op_layout.getParent().cast(); auto wpt = mmaLayout.getWarpsPerCTA(); Type elemTy = type.getElementType(); + auto vecSize = 1; + if (elemTy.getIntOrFloatBitWidth() == 16) { + vecSize = 2; + } else if (elemTy.getIntOrFloatBitWidth() == 8) { + vecSize = 4; + } else { + assert(false && "Unsupported element type"); + } + Type vecTy = vec_ty(elemTy, vecSize); if (mmaLayout.getVersion() == 2) { - if (dot_op_layout.getOpIdx() == 0) { // $a int elems = MMA16816ConversionHelper::getANumElemsPerThread(type, wpt); - Type x2Ty = vec_ty(elemTy, 2); return LLVM::LLVMStructType::getLiteral( - ctx, SmallVector(elems, x2Ty)); + ctx, SmallVector(elems, vecTy)); } if (dot_op_layout.getOpIdx() == 1) { // $b int elems = MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt); - Type x2Ty = vec_ty(elemTy, 2); - return struct_ty(SmallVector(elems, x2Ty)); + return struct_ty(SmallVector(elems, vecTy)); } } @@ -4775,13 +4802,11 @@ public: if (dot_op_layout.getOpIdx() == 0) { // $a int elems = helper.numElemsPerThreadA(type); - Type x2Ty = vec_ty(elemTy, 2); - return struct_ty(SmallVector(elems, x2Ty)); + return struct_ty(SmallVector(elems, vecTy)); } if (dot_op_layout.getOpIdx() == 1) { // $b int elems = helper.numElemsPerThreadB(type); - Type x2Ty = vec_ty(elemTy, 2); - return struct_ty(SmallVector(elems, x2Ty)); + return struct_ty(SmallVector(elems, vecTy)); } } diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index a8e7ffeeb..32448f46d 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -221,6 +221,7 @@ struct TritonDotPattern : public OpConversionPattern { matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); + Attribute dEncoding = retType.cast().getEncoding(); // a & b must be of smem layout auto aType = adaptor.a().getType().cast(); auto bType = adaptor.b().getType().cast(); @@ -230,17 +231,16 @@ struct TritonDotPattern : public OpConversionPattern { return failure(); Value a = adaptor.a(); Value b = adaptor.b(); - SmallVector order{1, 0}; - if (!aEncoding.isa()) { + if (!aEncoding.isa()) { Attribute encoding = - triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order); + triton::gpu::DotOperandEncodingAttr::get(getContext(), 0, dEncoding); auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding); a = rewriter.create(a.getLoc(), dstType, a); } - if (!bEncoding.isa()) { + if (!bEncoding.isa()) { Attribute encoding = - triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order); + triton::gpu::DotOperandEncodingAttr::get(getContext(), 1, dEncoding); auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding); b = rewriter.create(b.getLoc(), dstType, b); diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 32a0b7204..3982c6b58 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -191,6 +191,20 @@ mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes( // type is the same as the accumulator auto accTy = operands[2].getType().cast(); inferredReturnTypes.push_back(accTy); + + // verify encodings + auto aEnc = operands[0].getType().cast().getEncoding(); + auto bEnc = operands[1].getType().cast().getEncoding(); + auto retEnc = accTy.getEncoding(); + if (aEnc) { + assert(bEnc); + Dialect &dialect = aEnc.getDialect(); + auto interface = dyn_cast(&dialect); + if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) + return mlir::failure(); + if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) + return mlir::failure(); + } return mlir::success(); } @@ -244,7 +258,7 @@ OpFoldResult SplatOp::fold(ArrayRef operands) { //-- ExpandDimsOp -- mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, + MLIRContext *context, Optional loc, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { // infer shape @@ -260,11 +274,9 @@ mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes( Dialect &dialect = argEncoding.getDialect(); auto inferLayoutInterface = dyn_cast(&dialect); if (inferLayoutInterface - ->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding) - .failed()) { - llvm::report_fatal_error("failed to infer layout for ExpandDimsOp"); - return mlir::failure(); - } + ->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc) + .failed()) + return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp"); } // create type auto argEltTy = argTy.getElementType(); diff --git a/lib/Dialect/Triton/IR/Traits.cpp b/lib/Dialect/Triton/IR/Traits.cpp index e457a8564..dabd6b9fc 100644 --- a/lib/Dialect/Triton/IR/Traits.cpp +++ b/lib/Dialect/Triton/IR/Traits.cpp @@ -48,7 +48,8 @@ mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) { << " has more than that"; if ((numElements & (numElements - 1)) != 0) return op->emitError("Number of elements must be power-of-two, but ") - << *op << " doesn't follow the rule"; + << *op << " doesn't follow the rule (" << numElements << ")" + << " elements"; } } for (auto opType : op->getResultTypes()) { @@ -62,7 +63,8 @@ mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) { << " has more than that"; if ((numElements & (numElements - 1)) != 0) return op->emitError("Number of elements must be power-of-two, but ") - << *op << " doesn't follow the rule"; + << *op << " doesn't follow the rule (" << numElements << ")" + << " elements"; } } return success(); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 6ffed8df9..9d3a0bcd1 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -57,6 +57,8 @@ unsigned getElemsPerThread(Type type) { return mmaLayout.getElemsPerThread(shape); } else if (auto sharedLayout = layout.dyn_cast()) { return sharedLayout.getElemsPerThread(shape); + } else if (auto dotLayout = layout.dyn_cast()) { + return dotLayout.getElemsPerThread(shape); } else { assert(0 && "getElemsPerThread not implemented"); return 0; @@ -73,6 +75,27 @@ SmallVector getSizePerThread(Attribute layout) { assert(mmaLayout.getVersion() == 2 && "mmaLayout version = 1 is not implemented yet"); return SmallVector{2, 2}; + } else if (auto dotLayout = layout.dyn_cast()) { + auto parentLayout = dotLayout.getParent(); + assert(parentLayout && "DotOperandEncodingAttr must have a parent"); + if (auto parentMmaLayout = parentLayout.dyn_cast()) { + assert(parentMmaLayout.getVersion() == 2 && + "mmaLayout version = 1 is not implemented yet"); + auto parentShapePerCTA = getShapePerCTA(parentLayout); + auto opIdx = dotLayout.getOpIdx(); + if (opIdx == 0) { + return {2, 4}; + } else if (opIdx == 1) { + return {4, 1}; + } else { + assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1"); + return {}; + } + } else { + assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not " + "supported yet"); + return {}; + } } else { assert(0 && "getSizePerThread not implemented"); return {}; @@ -124,6 +147,25 @@ SmallVector getShapePerCTA(const Attribute &layout) { return {16 * mmaLayout.getWarpsPerCTA()[0], 16 * mmaLayout.getWarpsPerCTA()[1]}; assert(0 && "Unexpected MMA layout version found"); + } else if (auto dotLayout = layout.dyn_cast()) { + auto parentLayout = dotLayout.getParent(); + assert(parentLayout && "DotOperandEncodingAttr must have a parent"); + if (auto parentMmaLayout = parentLayout.dyn_cast()) { + assert(parentMmaLayout.getVersion() == 2 && + "mmaLayout version = 1 is not implemented yet"); + auto parentShapePerCTA = getShapePerCTA(parentLayout); + auto opIdx = dotLayout.getOpIdx(); + if (opIdx == 0) { + return {parentShapePerCTA[0], 16}; + } else if (opIdx == 1) { + return {16, parentShapePerCTA[1]}; + } else { + assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1"); + } + } else { + assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not " + "supported yet"); + } } else { assert(0 && "Unimplemented usage of getShapePerCTA"); } @@ -136,6 +178,8 @@ SmallVector getOrder(const Attribute &layout) { blockedLayout.getOrder().end()); } else if (auto mmaLayout = layout.dyn_cast()) { return SmallVector{1, 0}; + } else if (auto dotLayout = layout.dyn_cast()) { + return SmallVector{1, 0}; } else if (auto sliceLayout = layout.dyn_cast()) { SmallVector parentOrder = getOrder(sliceLayout.getParent()); unsigned dim = sliceLayout.getDim(); @@ -300,6 +344,12 @@ unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef shape) const { return 0; } +unsigned +DotOperandEncodingAttr::getElemsPerThread(ArrayRef shape) const { + assert(0 && "DotOPerandEncodingAttr::getElemsPerThread not implemented"); + return 0; +} + //===----------------------------------------------------------------------===// // Blocked Encoding //===----------------------------------------------------------------------===// @@ -471,6 +521,30 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const { << "}>"; } +//===----------------------------------------------------------------------===// +// DotOperand Encoding +//===----------------------------------------------------------------------===// +Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + NamedAttrList attrs; + if (parser.parseOptionalAttrDict(attrs).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + unsigned opIdx = attrs.get("opIdx").cast().getInt(); + Attribute parent = attrs.get("parent"); + + return parser.getChecked(parser.getContext(), opIdx, + parent); +} + +void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{" + << "opIdx = " << getOpIdx() << ", " + << "parent = " << getParent() << "}>"; +} + //===----------------------------------------------------------------------===// // InsertSliceAsyncOp //===----------------------------------------------------------------------===// @@ -530,30 +604,6 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer, printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType()); } -//===----------------------------------------------------------------------===// -// DotOperand Encoding -//===----------------------------------------------------------------------===// -Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) { - if (parser.parseLess().failed()) - return {}; - NamedAttrList attrs; - if (parser.parseOptionalAttrDict(attrs).failed()) - return {}; - if (parser.parseGreater().failed()) - return {}; - unsigned opIdx = attrs.get("opIdx").cast().getInt(); - Attribute parent = attrs.get("parent"); - - return parser.getChecked(parser.getContext(), opIdx, - parent); -} - -void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const { - printer << "<{" - << "opIdx = " << getOpIdx() << ", " - << "parent = " << getParent() << "}>"; -} - //===----------------------------------------------------------------------===// // ASM Interface (i.e.: alias) //===----------------------------------------------------------------------===// @@ -594,21 +644,32 @@ struct TritonGPUInferLayoutInterface LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, - Attribute &resultEncoding) const override { + Attribute &resultEncoding, + Optional location) const override { auto sliceEncoding = operandEncoding.dyn_cast(); - if (!sliceEncoding) { - llvm::report_fatal_error( - "ExpandDimsOp operand encoding must be SliceEncodingAttr"); - return failure(); - } - if (sliceEncoding.getDim() != axis) { - llvm::report_fatal_error( - "Incompatible slice dimension for ExpandDimsOp operand"); - return failure(); - } + if (!sliceEncoding) + return emitOptionalError( + location, "ExpandDimsOp operand encoding must be SliceEncodingAttr"); + if (sliceEncoding.getDim() != axis) + return emitOptionalError( + location, "Incompatible slice dimension for ExpandDimsOp operand"); resultEncoding = sliceEncoding.getParent(); return success(); } + + LogicalResult inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, + Attribute retEncoding, + Optional location) const override { + if (auto dotOpEnc = operandEncoding.dyn_cast()) { + if (opIdx != dotOpEnc.getOpIdx()) + return emitOptionalError(location, "Wrong opIdx"); + if (retEncoding != dotOpEnc.getParent()) + return emitOptionalError(location, "Incompatible parent encoding"); + } else + return emitOptionalError( + location, "Dot's a/b's encoding should be of DotOperandEncodingAttr"); + return success(); + } }; void TritonGPUDialect::initialize() { diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 464a1aabf..6f440df5d 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -7,6 +7,7 @@ add_mlir_dialect_library(TritonGPUTransforms CanonicalizeLoops.cpp Combine.cpp Pipeline.cpp + Prefetch.cpp Swizzle.cpp TritonGPUConversion.cpp diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index d7c622e50..db8a0ebe9 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -12,21 +12,13 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include using namespace mlir; - -static bool isSharedLayout(Value v) { - if (auto tensorType = v.getType().dyn_cast()) { - Attribute encoding = tensorType.getEncoding(); - return encoding.isa(); - } - return false; -} - namespace { #include "TritonGPUCombine.inc" @@ -37,7 +29,7 @@ namespace { // convert(blocked, dot_operand) -> // convert(blocked, mma) + convert(mma, dot_operand) // if this value is itself the result of a dot operation -// this is a hueiristics to accomodate some pattern seen in fused attention +// this is a heuristic to accomodate some pattern seen in fused attention // kernels. // TODO: replace this by something more generic, i.e. layout-aware CSE class DecomposeDotOperand : public mlir::RewritePattern { @@ -59,9 +51,8 @@ public: dstType.getEncoding().isa()) { auto tmpType = RankedTensorType::get(dstType.getShape(), dstType.getElementType(), - dstType.getEncoding() - .cast() - .getParent()); + triton::gpu::SharedEncodingAttr::get( + op->getContext(), 1, 1, 1, {1, 0})); auto tmp = rewriter.create( convert.getLoc(), tmpType, convert.getOperand()); auto newConvert = rewriter.create( @@ -87,11 +78,12 @@ public: if (!llvm::isa(op)) return mlir::failure(); auto convert = llvm::cast(op); + auto srcType = convert.getOperand().getType().cast(); auto dstType = convert.getType().cast(); // we don't handle conversions to DotOperandEncodingAttr // this is a heuristics to accomodate fused attention - if (dstType.getEncoding().isa()) - return mlir::failure(); + // if (dstType.getEncoding().isa()) + // return mlir::failure(); // convert to the same layout -- we can delete if (op->getResultTypes() == op->getOperandTypes()) { rewriter.replaceOp(op, op->getOperands()); @@ -122,8 +114,8 @@ public: rewriter.replaceOpWithNewOp( op, newType, insert_slice.src(), newArg.getResult(), insert_slice.index(), insert_slice.mask(), insert_slice.other(), - insert_slice.cache(), insert_slice.evict(), - insert_slice.isVolatile(), insert_slice.axis()); + insert_slice.cache(), insert_slice.evict(), insert_slice.isVolatile(), + insert_slice.axis()); return mlir::success(); } // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2)) @@ -133,7 +125,10 @@ public: auto newType = RankedTensorType::get( origType.getShape(), origType.getElementType(), op->getResult(0).getType().cast().getEncoding()); - auto resType = op->getResult(0).getType().cast(); + auto origResType = op->getResult(0).getType().cast(); + auto resType = RankedTensorType::get( + origResType.getShape(), origResType.getElementType(), + extract_slice.getType().cast().getEncoding()); // Ensure that the new extract_slice op is placed in the same place as the // old extract_slice op. Otherwise, the new extract_slice op may be placed // after the async_wait op, which is not allowed. @@ -148,8 +143,21 @@ public: extract_slice.static_strides()); return mlir::success(); } + // cvt(type2, x) if (llvm::isa(arg)) { + auto argType = arg->getOperand(0).getType().cast(); + if (arg->getOperand(0).getDefiningOp() && + !argType.getEncoding().isa() && + srcType.getEncoding().isa() && + !dstType.getEncoding().isa()) { + + return mlir::failure(); + } + auto srcShared = + srcType.getEncoding().dyn_cast(); + if (srcShared && srcShared.getVec() > 1) + return mlir::failure(); rewriter.replaceOpWithNewOp( op, op->getResultTypes().front(), arg->getOperand(0)); return mlir::success(); @@ -253,8 +261,8 @@ public: if (!op) return mlir::failure(); // we don't want to rematerialize any conversion to/from shared - if (isSharedLayout(cvt->getResults()[0]) || - isSharedLayout(cvt->getOperand(0))) + if (isSharedEncoding(cvt->getResults()[0]) || + isSharedEncoding(cvt->getOperand(0))) return mlir::failure(); // we don't handle conversions to DotOperandEncodingAttr // this is a heuristics to accomodate fused attention @@ -325,7 +333,6 @@ public: for (Operation *op : tmp) sortedValues.push_back(op->getResult(0)); - // llvm::outs() << "----\n"; BlockAndValueMapping mapping; for (Value currOperand : sortedValues) { // unpack information @@ -346,7 +353,6 @@ public: newOperand->moveAfter(currOperation); mapping.map(currOperand, newOperand); } - // llvm::outs() << cvt->getParentOfType() << "\n"; rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0))); return mlir::success(); } @@ -356,8 +362,6 @@ public: // // ----------------------------------------------------------------------------- -// int test = 0; - class MoveConvertOutOfLoop : public mlir::RewritePattern { public: MoveConvertOutOfLoop(mlir::MLIRContext *context) @@ -435,9 +439,25 @@ public: auto users = iterArg.value().getUsers(); // check first condition SetVector cvtTargetTypes; - for (auto user : users) - if (isa(user)) - cvtTargetTypes.insert(user->getResults()[0].getType()); + for (auto user : users) { + if (isa(user)) { + auto newType = + user->getResults()[0].getType().cast(); + auto oldType = user->getOperand(0).getType().cast(); + if (oldType.getEncoding().isa() && + newType.getEncoding() + .isa()) { + continue; + } + if (newType.getEncoding().isa()) { + if (newType.getEncoding() + .cast() + .getVec() == 1) + continue; + } + cvtTargetTypes.insert(newType); + } + } if (cvtTargetTypes.size() != 1) continue; // TODO: check second condition @@ -446,6 +466,7 @@ public: continue; } // check + // llvm::outs() << "replacing " << iterArg.index() << "\n"; for (auto op : iterArg.value().getUsers()) { auto cvt = dyn_cast(op); if (!cvt) @@ -597,10 +618,23 @@ public: auto oldAcc = dotOp.getOperand(2); auto newAcc = rewriter.create( oldAcc.getLoc(), newRetType, oldAcc); - // convert output + Value a = dotOp.a(); + Value b = dotOp.b(); + auto oldAType = a.getType().cast(); + auto oldBType = b.getType().cast(); + auto newAType = RankedTensorType::get( + oldAType.getShape(), oldAType.getElementType(), + triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0, + newRetType.getEncoding())); + auto newBType = RankedTensorType::get( + oldBType.getShape(), oldBType.getElementType(), + triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1, + newRetType.getEncoding())); + a = rewriter.create(a.getLoc(), newAType, a); + b = rewriter.create(b.getLoc(), newBType, b); auto newDot = rewriter.create( - dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1), - newAcc, dotOp.allowTF32(), dotOp.transA(), dotOp.transB()); + dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32(), + dotOp.transA(), dotOp.transB()); rewriter.replaceOpWithNewOp( op, oldRetType, newDot.getResult()); @@ -623,7 +657,7 @@ public: mlir::RewritePatternSet patterns(context); patterns.add(context); - patterns.add(context); + // patterns.add(context); patterns.add(context); patterns.add(context); patterns.add(context); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 1115fa200..aa27c1aad 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -1,3 +1,4 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BlockAndValueMapping.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" @@ -11,6 +12,7 @@ //===----------------------------------------------------------------------===// using namespace mlir; +namespace ttg = triton::gpu; #define GEN_PASS_CLASSES #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" @@ -24,6 +26,7 @@ static Type getI1SameShape(Value v) { } namespace { + class LoopPipeliner { /// cache forOp we are working on scf::ForOp forOp; @@ -37,6 +40,8 @@ class LoopPipeliner { DenseMap loadsMapping; /// load => buffer DenseMap loadsBuffer; + /// load => buffer type (with shared layout after swizzling) + DenseMap loadsBufferType; /// load => buffer at stage N DenseMap> loadStageBuffer; /// load => after extract @@ -67,8 +72,11 @@ class LoopPipeliner { Value lookupOrDefault(Value origin, int stage); /// returns a empty buffer of size - triton::gpu::AllocTensorOp allocateEmptyBuffer(Operation *op, - OpBuilder &builder); + ttg::AllocTensorOp allocateEmptyBuffer(Operation *op, OpBuilder &builder); + + /// compute type of shared buffers (with swizzled shared layouts) + RankedTensorType getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc, + RankedTensorType tensorType); public: LoopPipeliner(scf::ForOp forOp, int numStages) @@ -128,25 +136,82 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet &deps) { } } -triton::gpu::AllocTensorOp -LoopPipeliner::allocateEmptyBuffer(Operation *op, OpBuilder &builder) { +ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op, + OpBuilder &builder) { // allocate a buffer for each pipelined tensor // shape: e.g. (numStages==4), <32x64xbf16> -> <4x32x64xbf16> Value convertLayout = loadsMapping[op->getResult(0)]; if (auto tensorType = convertLayout.getType().dyn_cast()) { - SmallVector shape(tensorType.getShape().begin(), - tensorType.getShape().end()); - shape.insert(shape.begin(), numStages); - Type elementType = tensorType.getElementType(); - // The encoding of the buffer is similar to the original tensor - Attribute encoding = tensorType.getEncoding(); - auto bufferType = RankedTensorType::get(shape, elementType, encoding); - return builder.create(convertLayout.getLoc(), - bufferType); + return builder.create( + convertLayout.getLoc(), loadsBufferType[op->getResult(0)]); } llvm_unreachable("Async copy's return should be of RankedTensorType"); } +// TODO: I copied the code from Swizzle.cpp. Should find a way to unify the +// code path. +// Swizzle has to be performed before pipeline for now. If we do swizzle +// after pipeline, we need to propagate the swizzled layout to all +// operands that is an alias of the swizzled tensor. The alias analysis +// component maybe helpful for this purpose. +RankedTensorType +LoopPipeliner::getSwizzleType(ttg::DotOperandEncodingAttr dotOpEnc, + RankedTensorType ty) { + int opIdx = dotOpEnc.getOpIdx(); + int vec = 1; + int maxPhase = 1; + int perPhase = 1; + llvm::SmallVector order; + if (auto mmaEnc = dotOpEnc.getParent().dyn_cast()) { + // Only support row major for now + // TODO(Keren): check why column major code crashes + order = {1, 0}; + int version = mmaEnc.getVersion(); + auto tyEncoding = ty.getEncoding().cast(); + // number of rows per phase + perPhase = 128 / (ty.getShape()[order[0]] * + (ty.getElementType().getIntOrFloatBitWidth() / 8)); + perPhase = std::max(perPhase, 1); + + // index of the inner dimension in `order` + unsigned inner = (opIdx == 0) ? 0 : 1; + if (version == 1) { + maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase; + // TODO: handle rep (see + // https://github.com/openai/triton/blob/master/lib/codegen/analysis/layout.cc#L209) + } else if (version == 2) { + auto eltTy = ty.getElementType(); + std::vector matShape = {8, 8, + 2 * 64 / eltTy.getIntOrFloatBitWidth()}; + // for now, disable swizzle when using transposed int8 tensor cores + if (ty.getElementType().isInteger(8) && order[0] == inner) + perPhase = 1; + else { + if (opIdx == 0) { // compute swizzling for A operand + vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m + int mmaStride = order[0] == 1 ? matShape[0] : matShape[2]; + maxPhase = mmaStride / perPhase; + } else if (opIdx == 1) { // compute swizzling for B operand + vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k + int mmaStride = order[0] == 1 ? matShape[2] : matShape[1]; + maxPhase = mmaStride / perPhase; + } else + llvm_unreachable("invalid operand index"); + } + } else // version not in [1, 2] + llvm_unreachable("unsupported swizzling for provided MMA version"); + } else { // If the layout of dot is not mma, we don't need to swizzle + auto blockedEnc = dotOpEnc.getParent().cast(); + order = llvm::SmallVector(blockedEnc.getOrder().begin(), + blockedEnc.getOrder().end()); + } + auto newEncoding = ttg::SharedEncodingAttr::get(ty.getContext(), vec, + perPhase, maxPhase, order); + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), numStages); + return RankedTensorType::get(bufferShape, ty.getElementType(), newEncoding); +} + /// A load instruction can be pipelined if: /// - the load doesn't depend on any other loads (after loop peeling) /// - (?) this load is not a loop-invariant value (we should run LICM before @@ -186,19 +251,21 @@ LogicalResult LoopPipeliner::initialize() { } } - // For now, we only pipeline loads that have one covert_layout (to smem) use + // We only pipeline loads that have one covert_layout (to dot_op) use // TODO: lift this constraint in the future if (isCandiate && loadOp.getResult().hasOneUse()) { isCandiate = false; Operation *use = *loadOp.getResult().getUsers().begin(); - if (auto convertLayout = - llvm::dyn_cast(use)) { + if (auto convertLayout = llvm::dyn_cast(use)) { if (auto tensorType = convertLayout.getResult() .getType() .dyn_cast()) { - if (tensorType.getEncoding().isa()) { + if (auto dotOpEnc = tensorType.getEncoding() + .dyn_cast()) { isCandiate = true; loadsMapping[loadOp] = convertLayout; + loadsBufferType[loadOp] = getSwizzleType( + dotOpEnc, loadOp.getType().cast()); } } } @@ -238,6 +305,9 @@ void LoopPipeliner::emitPrologue() { setValueMapping(arg, operand.get(), 0); } + // helper to construct int attribute + auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); }; + // prologue from [0, numStage-1) Value iv = forOp.getLowerBound(); pipelineIterIdx = builder.create(iv.getLoc(), 0, 32); @@ -330,14 +400,15 @@ void LoopPipeliner::emitPrologue() { builder.create(iv.getLoc(), 1, 32)); } // for (int stage = 0; stage < numStages - 1; ++stage) - auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); }; - // async.wait & extract_slice - builder.create(loads[0].getLoc(), - loads.size() * (numStages - 2)); + builder.create(loads[0].getLoc(), + loads.size() * (numStages - 2)); loopIterIdx = builder.create(iv.getLoc(), 0, 32); for (Value loadOp : loads) { auto sliceType = loadsMapping[loadOp].getType().cast(); + sliceType = + RankedTensorType::get(sliceType.getShape(), sliceType.getElementType(), + loadsBufferType[loadOp].getEncoding()); Value extractSlice = builder.create( loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1], SmallVector{intAttr(0), intAttr(0), intAttr(0)}, @@ -366,6 +437,7 @@ void LoopPipeliner::emitEpilogue() { scf::ForOp LoopPipeliner::createNewForOp() { OpBuilder builder(forOp); + auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); }; // order of new args: // (original args), @@ -477,8 +549,6 @@ scf::ForOp LoopPipeliner::createNewForOp() { extractSliceIndex = builder.create( extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex); - auto intAttr = [&](int64_t v) { return builder.getI64IntegerAttr(v); }; - for (Operation *op : orderedDeps) { Operation *nextOp = nullptr; // update loading mask @@ -508,6 +578,9 @@ scf::ForOp LoopPipeliner::createNewForOp() { loadOp.evict(), loadOp.isVolatile(), /*axis*/ 0); nextBuffers.push_back(insertAsyncOp); auto sliceType = loadsMapping[loadOp].getType().cast(); + sliceType = RankedTensorType::get(sliceType.getShape(), + sliceType.getElementType(), + loadsBufferType[loadOp].getEncoding()); nextOp = builder.create( op->getLoc(), sliceType, insertAsyncOp, SmallVector{extractSliceIndex, intAttr(0), intAttr(0)}, @@ -534,8 +607,37 @@ scf::ForOp LoopPipeliner::createNewForOp() { } } + { + OpBuilder::InsertionGuard guard(builder); + for (Operation &op : *newForOp.getBody()) { + if (auto dotOp = llvm::dyn_cast(&op)) { + builder.setInsertionPoint(&op); + auto dotType = dotOp.getType().cast(); + Value a = dotOp.a(); + Value b = dotOp.b(); + auto layoutCast = [&](Value dotOperand, int opIdx) -> Value { + auto tensorType = dotOperand.getType().cast(); + if (!tensorType.getEncoding().isa()) { + auto newEncoding = ttg::DotOperandEncodingAttr::get( + tensorType.getContext(), opIdx, dotType.getEncoding()); + auto newType = + RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), newEncoding); + return builder.create(dotOperand.getLoc(), + newType, dotOperand); + } + return dotOperand; + }; + a = layoutCast(a, 0); + b = layoutCast(b, 1); + dotOp->setOperand(0, a); + dotOp->setOperand(1, b); + } + } + } + // async.wait & extract_slice - Operation *asyncWait = builder.create( + Operation *asyncWait = builder.create( loads[0].getLoc(), loads.size() * (numStages - 2)); for (auto it = extractSlices.rbegin(); it != extractSlices.rend(); ++it) { // move extract_slice after asyncWait diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp new file mode 100644 index 000000000..92be1bf7e --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -0,0 +1,304 @@ +//===----------------------------------------------------------------------===// +// +// This pass tries to prefetch operands (a and b) of tt.dot. +// Those ConvertLayoutOps will be lowered to shared memory loads. +// +// For example: +// %a: tensor<128x32xf16, #enc> +// scf.for %iv = ... iter_args(%a_arg = %a, ...) { +// %d = tt.dot %a_arg, %b, %c +// ... +// scf.yield %a_next, ... +// } +// +// will be translated to +// +// %a: tensor<128x32xf16, #enc> +// %a_tmp = tensor.extract_slice %a[0, 0] [128, 16] +// %a_prefetch = triton_gpu.convert_layout %a_tmp +// scf.for %iv = ... iter_args(%a_buf = %a, ..., %a_prefetch_arg = %a_prefetch) +// { +// %x = tt.dot %a_arg, %b, %c +// %a_tmp_rem = tensor.extract_slice %a_buf[0, 16] [128, 16] +// %a_prefetch_next = triton_gpu.convert_layout %a_tmp_rem +// ... +// scf.yield %next_a, ..., %a_prefetch_next +// } +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BlockAndValueMapping.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { + +class Prefetcher { + /// cache the ForOp we are working on + scf::ForOp forOp; + /// cache the YieldOp of this ForOp + scf::YieldOp yieldOp; + /// + // TODO: add a hook to infer prefetchWidth + unsigned prefetchWidth = 16; + + /// dots to be prefetched + SetVector dots; + /// dot => dot operand + DenseMap dot2aLoopArg; + DenseMap dot2aHeaderDef; + DenseMap dot2bLoopArg; + DenseMap dot2bHeaderDef; + DenseMap dot2aYield; + DenseMap dot2bYield; + /// operand => defining + DenseMap operand2headPrefetch; + + LogicalResult isForOpOperand(Value v); + + Value generatePrefetch(Value v, unsigned opIdx, bool isPrefetch, + Attribute dotEncoding, OpBuilder &builder, + llvm::Optional offsetK = llvm::None, + llvm::Optional shapeK = llvm::None); + +public: + Prefetcher() = delete; + + Prefetcher(scf::ForOp forOp) : forOp(forOp) { + yieldOp = cast(forOp.getBody()->getTerminator()); + } + + LogicalResult initialize(); + + void emitPrologue(); + + scf::ForOp createNewForOp(); +}; + +Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrefetch, + Attribute dotEncoding, OpBuilder &builder, + llvm::Optional offsetK, + llvm::Optional shapeK) { + // opIdx: 0 => a, 1 => b + auto type = v.getType().cast(); + SmallVector shape{type.getShape().begin(), type.getShape().end()}; + SmallVector offset{0, 0}; + Type elementType = type.getElementType(); + + auto intAttr = [&](int64_t val) { return builder.getI64IntegerAttr(val); }; + + // k => (prefetchWidth, k - prefetchWidth) + int64_t kIdx = opIdx == 0 ? 1 : 0; + + offset[kIdx] = isPrefetch ? 0 : prefetchWidth; + shape[kIdx] = isPrefetch ? prefetchWidth : (shape[kIdx] - prefetchWidth); + + if (shapeK) + shape[kIdx] = *shapeK; + if (offsetK) + offset[kIdx] = *offsetK; + + Value newSmem = builder.create( + v.getLoc(), + // TODO: encoding? + RankedTensorType::get(shape, elementType, type.getEncoding()), v, + SmallVector{intAttr(offset[0]), intAttr(offset[1])}, + SmallVector{intAttr(shape[0]), intAttr(shape[1])}, + SmallVector{intAttr(1), intAttr(1)}); + + auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( + builder.getContext(), opIdx, dotEncoding); + Value prefetchSlice = builder.create( + v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc), + newSmem); + + return prefetchSlice; +} + +LogicalResult Prefetcher::initialize() { + Block *loop = forOp.getBody(); + + SmallVector dotsInFor; + for (Operation &op : *loop) + if (auto dotOp = dyn_cast(op)) + dotsInFor.push_back(dotOp); + + if (dotsInFor.empty()) + return failure(); + + // returns source of cvt + auto getPrefetchSrc = [](Value v) -> Value { + // TODO: Check if the layout of src is SharedEncodingAttr + if (auto cvt = v.getDefiningOp()) + return cvt.src(); + return Value(); + }; + + auto getIncomingOp = [this](Value v) -> Value { + if (auto arg = v.dyn_cast()) + if (arg.getOwner()->getParentOp() == forOp.getOperation()) + return forOp.getOpOperandForRegionIterArg(arg).get(); + return Value(); + }; + + auto getYieldOp = [this](Value v) -> Value { + auto arg = v.cast(); + unsigned yieldIdx = arg.getArgNumber() - forOp.getNumInductionVars(); + return yieldOp.getOperand(yieldIdx); + }; + + for (triton::DotOp dot : dotsInFor) { + Value aSmem = getPrefetchSrc(dot.a()); + Value bSmem = getPrefetchSrc(dot.b()); + if (aSmem && bSmem) { + Value aHeaderDef = getIncomingOp(aSmem); + Value bHeaderDef = getIncomingOp(bSmem); + // Only prefetch loop arg + if (aHeaderDef && bHeaderDef) { + dots.insert(dot); + dot2aHeaderDef[dot] = aHeaderDef; + dot2bHeaderDef[dot] = bHeaderDef; + dot2aLoopArg[dot] = aSmem; + dot2bLoopArg[dot] = bSmem; + dot2aYield[dot] = getYieldOp(aSmem); + dot2bYield[dot] = getYieldOp(bSmem); + } + } + } + + return success(); +} + +void Prefetcher::emitPrologue() { + OpBuilder builder(forOp); + + for (Value dot : dots) { + Attribute dotEncoding = + dot.getType().cast().getEncoding(); + Value aPrefetched = + generatePrefetch(dot2aHeaderDef[dot], 0, true, dotEncoding, builder); + operand2headPrefetch[dot.getDefiningOp().a()] = aPrefetched; + Value bPrefetched = + generatePrefetch(dot2bHeaderDef[dot], 1, true, dotEncoding, builder); + operand2headPrefetch[dot.getDefiningOp().b()] = bPrefetched; + } +} + +scf::ForOp Prefetcher::createNewForOp() { + OpBuilder builder(forOp); + + SmallVector loopArgs; + for (auto v : forOp.getIterOperands()) + loopArgs.push_back(v); + for (Value dot : dots) { + loopArgs.push_back( + operand2headPrefetch[dot.getDefiningOp().a()]); + loopArgs.push_back( + operand2headPrefetch[dot.getDefiningOp().b()]); + } + + auto newForOp = builder.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), loopArgs); + + auto largestPow2 = [](int64_t n) -> int64_t { + while ((n & (n - 1)) != 0) + n = n & (n - 1); + return n; + }; + + builder.setInsertionPointToStart(newForOp.getBody()); + BlockAndValueMapping mapping; + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + + for (Operation &op : forOp.getBody()->without_terminator()) { + Operation *newOp = nullptr; + auto dot = dyn_cast(&op); + if (dots.contains(dot)) { + Attribute dotEncoding = + dot.getType().cast().getEncoding(); + // prefetched dot + Operation *firstDot = builder.clone(*dot, mapping); + if (Value a = operand2headPrefetch.lookup(dot.a())) + firstDot->setOperand( + 0, newForOp.getRegionIterArgForOpOperand(*a.use_begin())); + if (Value b = operand2headPrefetch.lookup(dot.b())) + firstDot->setOperand( + 1, newForOp.getRegionIterArgForOpOperand(*b.use_begin())); + + // remaining part + int64_t kOff = prefetchWidth; + int64_t kRem = dot.a().getType().cast().getShape()[1] - + prefetchWidth; + Operation *prevDot = firstDot; + while (kRem != 0) { + int64_t kShape = largestPow2(kRem); + Value aRem = + generatePrefetch(mapping.lookup(dot2aLoopArg[dot]), 0, false, + dotEncoding, builder, kOff, kShape); + Value bRem = + generatePrefetch(mapping.lookup(dot2bLoopArg[dot]), 1, false, + dotEncoding, builder, kOff, kShape); + newOp = builder.clone(*dot, mapping); + newOp->setOperand(0, aRem); + newOp->setOperand(1, bRem); + newOp->setOperand(2, prevDot->getResult(0)); + prevDot = newOp; + kOff += kShape; + kRem -= kShape; + } + } else { + newOp = builder.clone(op, mapping); + } + // update mapping of results + for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults())) + mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx)); + } + + // prefetch next iteration + SmallVector yieldValues; + for (Value v : forOp.getBody()->getTerminator()->getOperands()) + yieldValues.push_back(mapping.lookup(v)); + for (Value dot : dots) { + Attribute dotEncoding = + dot.getType().cast().getEncoding(); + yieldValues.push_back(generatePrefetch(mapping.lookup(dot2aYield[dot]), 0, + true, dotEncoding, builder)); + yieldValues.push_back(generatePrefetch(mapping.lookup(dot2bYield[dot]), 1, + true, dotEncoding, builder)); + } + // Update ops of yield + builder.create(yieldOp.getLoc(), yieldValues); + return newForOp; +} + +struct PrefetchPass : public TritonGPUPrefetchBase { + void runOnOperation() override { + getOperation()->walk([&](scf::ForOp forOp) { + Prefetcher prefetcher(forOp); + + if (prefetcher.initialize().failed()) + return; + + prefetcher.emitPrologue(); + + scf::ForOp newForOp = prefetcher.createNewForOp(); + + // replace the original loop + for (unsigned i = 0; i < forOp->getNumResults(); ++i) + forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i)); + forOp->erase(); + }); + } +}; + +} // anonymous namespace + +std::unique_ptr mlir::createTritonGPUPrefetchPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/TritonGPU/Transforms/Swizzle.cpp b/lib/Dialect/TritonGPU/Transforms/Swizzle.cpp index 776fc9973..a519e32db 100644 --- a/lib/Dialect/TritonGPU/Transforms/Swizzle.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Swizzle.cpp @@ -39,23 +39,23 @@ struct SwizzlePass : public TritonGPUSwizzleBase { return SwizzleInfo{vec, perPhase, maxPhase}; } else if (version == 2) { auto eltTy = ty.getElementType(); - std::vector mat_shape = {8, 8, - 2 * 64 / eltTy.getIntOrFloatBitWidth()}; + std::vector matShape = {8, 8, + 2 * 64 / eltTy.getIntOrFloatBitWidth()}; // for now, disable swizzle when using transposed int8 tensor cores - bool is_int8_mma = ty.getElementType().isInteger(8); - if (is_int8_mma && order[0] == inner) + bool isInt8Mma = ty.getElementType().isInteger(8); + if (isInt8Mma && order[0] == inner) return noSwizzling; // compute swizzling for A operand if (opIdx == 0) { - int vec = order[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m - int mmaStride = order[0] == 1 ? mat_shape[0] : mat_shape[2]; + int vec = order[0] == 1 ? matShape[2] : matShape[0]; // k : m + int mmaStride = order[0] == 1 ? matShape[0] : matShape[2]; int maxPhase = mmaStride / perPhase; return SwizzleInfo{vec, perPhase, maxPhase}; } // compute swizzling for B operand else if (opIdx == 1) { - int vec = order[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k - int mmaStride = order[0] == 1 ? mat_shape[2] : mat_shape[1]; + int vec = order[0] == 1 ? matShape[1] : matShape[2]; // n : k + int mmaStride = order[0] == 1 ? matShape[2] : matShape[1]; int maxPhase = mmaStride / perPhase; return SwizzleInfo{vec, perPhase, maxPhase}; } else { @@ -67,32 +67,64 @@ struct SwizzlePass : public TritonGPUSwizzleBase { void runOnOperation() override { Operation *op = getOperation(); - op->walk([&](triton::DotOp dotOp) -> void { - OpBuilder builder(dotOp); - auto _retEncoding = - dotOp.getResult().getType().cast().getEncoding(); - auto retEncoding = _retEncoding.dyn_cast(); - if (!retEncoding) - return; - for (int opIdx : {0, 1}) { - Value op = dotOp.getOperand(opIdx); - auto ty = op.getType().template cast(); - // compute new swizzled encoding - SwizzleInfo swizzle = getSwizzleMMA(opIdx, retEncoding, ty); - auto newEncoding = triton::gpu::SharedEncodingAttr::get( - &getContext(), swizzle.vec, swizzle.perPhase, swizzle.maxPhase, - ty.getEncoding() - .cast() - .getOrder()); - // create conversion - auto newType = RankedTensorType::get(ty.getShape(), ty.getElementType(), - newEncoding); - Operation *newOp = builder.create( - op.getLoc(), newType, op); - // bind new op to dot operand - dotOp->replaceUsesOfWith(op, newOp->getResult(0)); + // replace blocked -> dot_op with + // blocked -> shared -> dot_op in order to + // expose opportunities for swizzling + op->walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto srcType = cvtOp.getOperand().getType().cast(); + auto dstType = cvtOp.getType().cast(); + if (srcType.getEncoding().isa() && + dstType.getEncoding().isa()) { + auto tmpType = + RankedTensorType::get(dstType.getShape(), dstType.getElementType(), + triton::gpu::SharedEncodingAttr::get( + op->getContext(), 1, 1, 1, {1, 0})); + auto tmp = builder.create( + cvtOp.getLoc(), tmpType, cvtOp.getOperand()); + auto newConvert = builder.create( + cvtOp.getLoc(), dstType, tmp); + cvtOp.replaceAllUsesWith(newConvert.getResult()); } }); + + op->walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto arg = cvtOp.getOperand(); + auto retType = cvtOp.getResult().getType().cast(); + auto retEncoding = + retType.getEncoding().dyn_cast(); + auto argType = arg.getType().cast(); + auto argEncoding = + argType.getEncoding().dyn_cast(); + if (!argEncoding || !retEncoding) + return; + auto opIdx = retEncoding.getOpIdx(); + // compute new swizzled encoding + auto parentEncoding = + retEncoding.getParent().dyn_cast(); + if (!parentEncoding) + return; + auto swizzleType = argType; + if (arg.getDefiningOp() && + isa(arg.getDefiningOp())) { + swizzleType = arg.getDefiningOp() + ->getOperand(0) + .getType() + .cast(); + } + SwizzleInfo swizzle = getSwizzleMMA(opIdx, parentEncoding, swizzleType); + auto newEncoding = triton::gpu::SharedEncodingAttr::get( + &getContext(), swizzle.vec, swizzle.perPhase, swizzle.maxPhase, + argEncoding.getOrder()); + // create conversion + auto newType = RankedTensorType::get( + argType.getShape(), argType.getElementType(), newEncoding); + Operation *newArg = builder.create( + cvtOp.getLoc(), newType, arg); + // bind new op to cvt operand + cvtOp->replaceUsesOfWith(arg, newArg->getResult(0)); + }); } }; } // anonymous namespace diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index 6bd11de81..14ff03615 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -95,8 +95,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( dotOp.a().getType().cast().getEncoding(); Attribute bEncoding = dotOp.b().getType().cast().getEncoding(); - if (aEncoding && aEncoding.isa() && - bEncoding && bEncoding.isa()) + if (aEncoding && aEncoding.isa() && + bEncoding && bEncoding.isa()) return true; return false; }); diff --git a/python/src/triton.cc b/python/src/triton.cc index 89d401b6a..b4b6b068d 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1255,6 +1255,10 @@ void init_triton_ir(py::module &&m) { [](mlir::PassManager &self, int numStages) { self.addPass(mlir::createTritonGPUPipelinePass(numStages)); }) + .def("add_tritongpu_prefetch_pass", + [](mlir::PassManager &self) { + self.addPass(mlir::createTritonGPUPrefetchPass()); + }) .def("add_triton_gpu_combine_pass", [](mlir::PassManager &self) { self.addPass(mlir::createTritonGPUCombineOpsPass()); diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py index e264fe086..2bac2c83f 100644 --- a/python/tests/test_gemm.py +++ b/python/tests/test_gemm.py @@ -171,63 +171,65 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err), atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False) -@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [ - [32, 32, 16, 4, 32, 32, 16], - [32, 16, 16, 4, 32, 32, 16], - [128, 8, 8, 4, 32, 32, 16], - [127, 41, 43, 4, 32, 32, 16], -]) -def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K): - @triton.jit - def matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - ): - pid = tl.program_id(axis=0) - # num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, K, BLOCK_SIZE_K): - a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K) - b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N) - a = tl.load(a_ptrs, a_mask) - b = tl.load(b_ptrs, b_mask) - # NOTE the allow_tf32 should be false to force the dot op to do fmadot lowering - accumulator += tl.dot(a, b, allow_tf32=False) - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - offs_k += BLOCK_SIZE_K - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, c_mask) - - a = torch.randn((M, K), device='cuda', dtype=torch.float32) - b = torch.randn((K, N), device='cuda', dtype=torch.float32) - c = torch.empty((M, N), device=a.device, dtype=torch.float32) - - grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) - matmul_kernel[grid](a, b, c, - M, N, K, - stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_cm=c.stride(0), stride_cn=c.stride(1), - BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K) - - golden = torch.matmul(a, b) - torch.testing.assert_close(c, golden) +# XXX(Keren): Temporarily disable this test until we have shared -> dot conversion implemented +#@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [ +# [32, 32, 16, 4, 32, 32, 16], +# [32, 16, 16, 4, 32, 32, 16], +# [128, 8, 8, 4, 32, 32, 16], +# [127, 41, 43, 4, 32, 32, 16], +#]) +#def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K): +# @triton.jit +# def matmul_kernel( +# a_ptr, b_ptr, c_ptr, +# M, N, K, +# stride_am, stride_ak, +# stride_bk, stride_bn, +# stride_cm, stride_cn, +# BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +# ): +# pid = tl.program_id(axis=0) +# # num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) +# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) +# pid_m = pid // num_pid_n +# pid_n = pid % num_pid_n +# +# offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) +# offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) +# offs_k = tl.arange(0, BLOCK_SIZE_K) +# a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) +# b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) +# +# accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) +# for k in range(0, K, BLOCK_SIZE_K): +# a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K) +# b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N) +# a = tl.load(a_ptrs, a_mask) +# b = tl.load(b_ptrs, b_mask) +# # NOTE the allow_tf32 should be false to force the dot op to do fmadot lowering +# accumulator += tl.dot(a, b, allow_tf32=False) +# a_ptrs += BLOCK_SIZE_K * stride_ak +# b_ptrs += BLOCK_SIZE_K * stride_bk +# offs_k += BLOCK_SIZE_K +# +# offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) +# offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) +# c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn +# c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) +# tl.store(c_ptrs, accumulator, c_mask) +# +# a = torch.randn((M, K), device='cuda', dtype=torch.float32) +# b = torch.randn((K, N), device='cuda', dtype=torch.float32) +# c = torch.empty((M, N), device=a.device, dtype=torch.float32) +# +# grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) +# matmul_kernel[grid](a, b, c, +# M, N, K, +# stride_am=a.stride(0), stride_ak=a.stride(1), +# stride_bk=b.stride(0), stride_bn=b.stride(1), +# stride_cm=c.stride(0), stride_cn=c.stride(1), +# BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K) +# +# golden = torch.matmul(a, b) +# torch.testing.assert_close(c, golden) +# diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 7be9e6553..c1dbbfd16 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -876,6 +876,9 @@ def ttir_to_ttgir(mod, num_warps, num_stages): pm = _triton.ir.pass_manager(mod.context) pm.add_convert_triton_to_tritongpu_pass(num_warps) pm.enable_debug() + # Convert blocked layout to mma layout for dot ops so that pipeline + # can get shared memory swizzled correctly. + pm.add_triton_gpu_combine_pass() pm.add_tritongpu_pipeline_pass(num_stages) pm.add_canonicalizer_pass() pm.add_cse_pass() diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index fc6e1e289..ea3c1e7e6 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -2,11 +2,14 @@ #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}> +#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}> +#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}> // CHECK-LABEL: matmul_loop +// There shouldn't be any aliasing with the dot op encoding. func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> %b_ptr_init = tt.broadcast %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> @@ -19,12 +22,10 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - // CHECK: %4 -> %4 - %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> - // CHECK-NEXT: %6 -> %6 - %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT> + %c = tt.dot %a, %b, %prev_c {transA = false, transB = false, allowTF32 = true} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> @@ -36,10 +37,10 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B // CHECK-LABEL: alloc func @alloc(%A : !tt.ptr) { // CHECK: %cst -> %cst - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK: %0 -> %0 - %cst2 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A> + %cst2 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED> return } @@ -47,7 +48,7 @@ func @alloc(%A : !tt.ptr) { func @convert(%A : !tt.ptr) { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> // CHECK: %0 -> %0 - %cst1 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A> + %cst1 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED> return } @@ -57,38 +58,38 @@ func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> // CHECK: %cst_0 -> %cst_0 - %tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A> + %tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED> %index = arith.constant 0 : i32 // CHECK: %2 -> %cst_0 - %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<1x16x16xf16, #A> + %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<1x16x16xf16, #A_SHARED> return } // CHECK-LABEL: extract_slice func @extract_slice(%A : !tt.ptr) { // CHECK: %cst -> %cst - %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED> %index = arith.constant 0 : index // CHECK-NEXT: %0 -> %cst - %cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A> to tensor<16x16xf16, #A> + %cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED> return } // CHECK-LABEL: if_cat func @if_cat(%i1 : i1) { // CHECK: %cst -> %cst - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK: %cst_0 -> %cst_0 - %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK: %0 -> %1,%1 - %cst2 = scf.if %i1 -> tensor<32x16xf16, #A> { + %cst2 = scf.if %i1 -> tensor<32x16xf16, #A_SHARED> { // CHECK: %1 -> %1 - %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> - scf.yield %a : tensor<32x16xf16, #A> + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> + scf.yield %a : tensor<32x16xf16, #A_SHARED> } else { // CHECK: %1 -> %1 - %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> - scf.yield %b : tensor<32x16xf16, #A> + %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> + scf.yield %b : tensor<32x16xf16, #A_SHARED> } return } @@ -96,14 +97,14 @@ func @if_cat(%i1 : i1) { // CHECK-LABEL: if_alias func @if_alias(%i1 : i1) { // CHECK: %cst -> %cst - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: %cst_0 -> %cst_0 - %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: %0 -> %cst,%cst_0 - %cst2 = scf.if %i1 -> tensor<16x16xf16, #A> { - scf.yield %cst0 : tensor<16x16xf16, #A> + %cst2 = scf.if %i1 -> tensor<16x16xf16, #A_SHARED> { + scf.yield %cst0 : tensor<16x16xf16, #A_SHARED> } else { - scf.yield %cst1 : tensor<16x16xf16, #A> + scf.yield %cst1 : tensor<16x16xf16, #A_SHARED> } return } @@ -111,19 +112,19 @@ func @if_alias(%i1 : i1) { // CHECK-LABEL: for func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { // CHECK: %cst -> %cst - %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: %cst_0 -> %cst_0 - %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: %cst_1 -> %cst_1 - %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: %arg6 -> %cst // CHECK-NEXT: %arg7 -> %cst_0 // CHECK-NEXT: %arg8 -> %cst_1 // CHECK-NEXT: %0#0 -> %cst,%cst_0 // CHECK-NEXT: %0#1 -> %cst,%cst_0 // CHECK-NEXT: %0#2 -> %cst,%cst_0 - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { - scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) { + scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED> } return } @@ -131,25 +132,25 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.p // CHECK-LABEL: for_if func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: %cst -> %cst - %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: %cst_0 -> %cst_0 - %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: %cst_1 -> %cst_1 - %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: %arg7 -> %cst // CHECK-NEXT: %arg8 -> %cst_0 // CHECK-NEXT: %arg9 -> %cst_1 // CHECK-NEXT: %0#0 -> %cst,%cst_0 // CHECK-NEXT: %0#1 -> %cst,%cst_0 // CHECK-NEXT: %0#2 -> %cst,%cst_0 - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) { scf.if %i1 { %index = arith.constant 8 : index // CHECK-NEXT: %1 -> %cst,%cst_0 - %cst0 = tensor.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A> to tensor<32xf16, #A> + %cst0 = tensor.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A_SHARED> to tensor<32xf16, #A_SHARED> scf.yield } - scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED> } return } @@ -157,34 +158,34 @@ func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-LABEL: for_if_for func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: %cst -> %cst - %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: %cst_0 -> %cst_0 - %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: %cst_1 -> %cst_1 - %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: %arg7 -> %cst // CHECK-NEXT: %arg8 -> %cst_0 // CHECK-NEXT: %arg9 -> %cst_1 // CHECK-NEXT: %0#0 -> %cst // CHECK-NEXT: %0#1 -> %cst_0 // CHECK-NEXT: %0#2 -> %cst_2,%cst_2 - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) { // CHECK-NEXT: %arg11 -> %cst_1,%cst_2,%cst_2 // CHECK-NEXT: %1 -> %cst_2,%cst_2 - %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A_SHARED>) { // CHECK-NEXT: %2 -> %cst_2,%cst_2 - %c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A> { + %c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A_SHARED> { // CHECK-NEXT: %cst_2 -> %cst_2 - %cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> - scf.yield %cst0 : tensor<128x32xf16, #A> + %cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + scf.yield %cst0 : tensor<128x32xf16, #A_SHARED> } else { // CHECK-NEXT: %cst_2 -> %cst_2 - %cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> - scf.yield %cst0 : tensor<128x32xf16, #A> + %cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + scf.yield %cst0 : tensor<128x32xf16, #A_SHARED> } - scf.yield %c_shared_next_next : tensor<128x32xf16, #A> + scf.yield %c_shared_next_next : tensor<128x32xf16, #A_SHARED> } - scf.yield %a_shared, %b_shared, %c_shared_next : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + scf.yield %a_shared, %b_shared, %c_shared_next : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED> } return } diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 85e50005c..7e2a7d675 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -3,9 +3,11 @@ #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}> +#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}> +#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}> // CHECK-LABEL: matmul_loop func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { @@ -23,20 +25,20 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - // CHECK: offset = 0, size = 8192 - %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + // CHECK: offset = 0, size = 4608 + %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> - // CHECK-NEXT: offset = 8192, size = 8192 - %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> + // CHECK-NEXT: offset = 0, size = 4224 + %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } return - // CHECK-NEXT: size = 16384 + // CHECK-NEXT: size = 4608 } // Shared memory is available after a tensor's liveness range ends @@ -51,21 +53,21 @@ func @reusable(%A : !tt.ptr) { %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> %b_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<32x128x!tt.ptr, #AL> %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - // CHECK-NEXT: offset = 0, size = 8192 - %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + // CHECK-NEXT: offset = 0, size = 4608 + %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> %a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL> - // CHECK-NEXT: offset = 8192, size = 8192 - %a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A> + // CHECK-NEXT: offset = 0, size = 1152 + %a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT> %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - // CHECK-NEXT: offset = 16384, size = 8192 - %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> - %c = tt.dot %a1, %a2, %c_init {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + // CHECK-NEXT: offset = 0, size = 4608 + %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> + %c = tt.dot %a1, %a2, %c_init {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL> - // CHECK-NEXT: offset = 0, size = 8192 - %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A> - %c1 = tt.dot %a3, %a4, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + // CHECK-NEXT: offset = 0, size = 1152 + %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT> + %c1 = tt.dot %a3, %a4, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> return - // CHECK-NEXT: size = 24576 + // CHECK-NEXT: size = 4608 } // A tensor's shared memory offset is larger than it needs to accommodate further tensors @@ -75,33 +77,33 @@ func @reusable(%A : !tt.ptr) { // CHECK-LABEL: preallocate func @preallocate(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1536, size = 512 - %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 2048, size = 1024 - %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> // CHECK-NEXT: offset = 3072, size = 1024 - %b = tt.cat %cst0, %cst2 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %b = tt.cat %cst0, %cst2 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> // CHECK-NEXT: offset = 0, size = 1024 - %c = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %c = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1024, size = 1024 - %cst4 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A> + %cst4 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A_SHARED> // CHECK-NEXT: offset = 6144, size = 2048 - %e = tt.cat %a, %cst4 {axis = 0} : (tensor<32x16xf16, #A>, tensor<32x16xf16, #A>) -> tensor<64x16xf16, #A> + %e = tt.cat %a, %cst4 {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED> // CHECK-NEXT: offset = 8192, size = 2048 - %d = tt.cat %b, %cst4 {axis = 0} : (tensor<32x16xf16, #A>, tensor<32x16xf16, #A>) -> tensor<64x16xf16, #A> + %d = tt.cat %b, %cst4 {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED> // CHECK-NEXT: offset = 10240, size = 2048 - %f = tt.cat %c, %cst4 {axis = 0} : (tensor<32x16xf16, #A>, tensor<32x16xf16, #A>) -> tensor<64x16xf16, #A> + %f = tt.cat %c, %cst4 {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED> // CHECK-NEXT: offset = 0, size = 2048 - %cst5 = arith.constant dense<0.000000e+00> : tensor<64x16xf16, #A> + %cst5 = arith.constant dense<0.000000e+00> : tensor<64x16xf16, #A_SHARED> // CHECK-NEXT: offset = 2048, size = 4096 - %g = tt.cat %e, %cst5 {axis = 0} : (tensor<64x16xf16, #A>, tensor<64x16xf16, #A>) -> tensor<128x16xf16, #A> + %g = tt.cat %e, %cst5 {axis = 0} : (tensor<64x16xf16, #A_SHARED>, tensor<64x16xf16, #A_SHARED>) -> tensor<128x16xf16, #A_SHARED> // CHECK-NEXT: offset = 2048, size = 4096 - %h = tt.cat %d, %cst5 {axis = 0} : (tensor<64x16xf16, #A>, tensor<64x16xf16, #A>) -> tensor<128x16xf16, #A> + %h = tt.cat %d, %cst5 {axis = 0} : (tensor<64x16xf16, #A_SHARED>, tensor<64x16xf16, #A_SHARED>) -> tensor<128x16xf16, #A_SHARED> // CHECK-NEXT: offset = 2048, size = 4096 - %i = tt.cat %f, %cst5 {axis = 0} : (tensor<64x16xf16, #A>, tensor<64x16xf16, #A>) -> tensor<128x16xf16, #A> + %i = tt.cat %f, %cst5 {axis = 0} : (tensor<64x16xf16, #A_SHARED>, tensor<64x16xf16, #A_SHARED>) -> tensor<128x16xf16, #A_SHARED> return // CHECK-NEXT: size = 12288 } @@ -110,13 +112,13 @@ func @preallocate(%A : !tt.ptr) { // CHECK-LABEL: unused func @unused(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 - %cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A_SHARED> // CHECK-NEXT: offset = 0, size = 512 - %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 512, size = 512 - %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1024, size = 1024 - %a = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %a = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> return // CHECK: size = 2048 } @@ -125,27 +127,27 @@ func @unused(%A : !tt.ptr) { // CHECK-LABEL: longlive func @longlive(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 512, size = 512 - %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1024, size = 512 - %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1536, size = 1024 - %a = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %a = tt.cat %cst1, %cst2 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> // CHECK-NEXT: offset = 512, size = 512 - %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1024, size = 512 - %cst4 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst4 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1536, size = 1024 - %b = tt.cat %cst3, %cst4 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %b = tt.cat %cst3, %cst4 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1536, size = 512 - %cst5 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst5 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1536, size = 512 - %cst6 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst6 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1536, size = 1024 - %c = tt.cat %cst3, %cst4 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %c = tt.cat %cst3, %cst4 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> // CHECK-NEXT: offset = 512, size = 1024 - %d = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %d = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> return // CHECK-NEXT: size = 2560 } @@ -153,10 +155,10 @@ func @longlive(%A : !tt.ptr) { // CHECK-LABEL: alloc func @alloc(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: offset = 0, size = 512 - %cst2 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A> + %cst2 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED> return // CHECK-NEXT: size = 512 } @@ -176,9 +178,9 @@ func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> // CHECK: offset = 0, size = 512 - %tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A> + %tensor = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED> %index = arith.constant 0 : i32 - %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<1x16x16xf16, #A> + %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<1x16x16xf16, #A_SHARED> return // CHECK-NEXT: size = 512 } @@ -186,9 +188,9 @@ func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { // CHECK-LABEL: extract_slice func @extract_slice(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED> %index = arith.constant 0 : index - %cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1,1,1] : tensor<1x16x16xf16, #A> to tensor<16x16xf16, #A> + %cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1,1,1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED> return // CHECK-NEXT: size = 512 } @@ -198,21 +200,21 @@ func @extract_slice(%A : !tt.ptr) { // CHECK-LABEL: if func @if(%i1 : i1) { // CHECK: offset = 0, size = 512 - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 512, size = 512 - %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> scf.if %i1 { // CHECK-NEXT: offset = 1024, size = 1024 - %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1024, size = 1024 - %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> } // CHECK-NEXT: offset = 0, size = 512 - %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 512, size = 512 - %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1024, size = 1024 - %a = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %a = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> return // CHECK-NEXT: size = 2048 } @@ -222,24 +224,24 @@ func @if(%i1 : i1) { // CHECK-LABEL: if_else func @if_else(%i1 : i1) { // CHECK: offset = 0, size = 512 - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 512, size = 512 - %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> scf.if %i1 { // CHECK-NEXT: offset = 1024, size = 1024 - %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1024, size = 1024 - %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> } else { // CHECK-NEXT: offset = 1024, size = 512 - %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1536, size = 512 - %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 2048, size = 1024 - %a = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %a = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> } // CHECK-NEXT: offset = 1024, size = 1024 - %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> return // CHECK-NEXT: size = 3072 } @@ -249,13 +251,13 @@ func @if_else(%i1 : i1) { // CHECK-LABEL: for func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { // CHECK: offset = 0, size = 8192 - %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { - scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) { + scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED> } return // CHECK-NEXT: size = 24576 @@ -264,18 +266,18 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.p // CHECK-LABEL: for_if_slice func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 - %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) { scf.if %i1 { %index = arith.constant 8 : index - %cst0 = tensor.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A> to tensor<32xf16, #A> + %cst0 = tensor.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A_SHARED> to tensor<32xf16, #A_SHARED> scf.yield } - scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED> } return // CHECK-NEXT: size = 24576 @@ -286,28 +288,28 @@ func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr, % // CHECK-LABEL: for_if_for func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 - %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { - %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A>) { - %c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A> { + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A_SHARED>) { + %c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A_SHARED> { // CHECK-NEXT: offset = 24576, size = 8192 - %cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> - scf.yield %cst0 : tensor<128x32xf16, #A> + %cst0 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + scf.yield %cst0 : tensor<128x32xf16, #A_SHARED> } else { // CHECK-NEXT: offset = 32768, size = 8192 - %cst1 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> - scf.yield %cst1 : tensor<128x32xf16, #A> + %cst1 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + scf.yield %cst1 : tensor<128x32xf16, #A_SHARED> } - scf.yield %c_shared_next_next : tensor<128x32xf16, #A> + scf.yield %c_shared_next_next : tensor<128x32xf16, #A_SHARED> } - scf.yield %a_shared, %b_shared, %c_shared_next : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + scf.yield %a_shared, %b_shared, %c_shared_next : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED> } // CHECK-NEXT: offset = 0, size = 8192 - %cst2 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %cst2 = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> return // CHECK-NEXT: size = 40960 } diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 14d3844a4..8aeb7b2dd 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -3,11 +3,14 @@ #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}> +#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}> +#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}> // CHECK-LABEL: matmul_loop +// There shouldn't be any membar with the dot op encoding. func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> %b_ptr_init = tt.broadcast %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> @@ -23,11 +26,10 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> - %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - // CHECK: Membar 13 - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> @@ -42,9 +44,9 @@ func @raw_single_block(%A : !tt.ptr) { %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED> // CHECK: Membar 5 - %a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #A> + %a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #A_SHARED> return } @@ -54,56 +56,56 @@ func @war_single_block(%A : !tt.ptr) { %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED> // CHECK: Membar 5 - %a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #AL> + %a2 = triton_gpu.convert_layout %a1 : (tensor<128x32xf16, #A_SHARED>) -> tensor<128x32xf16, #AL> // a2's liveness range ends here, and a3 and a2 have the same address range. // So it makes sense to have a WAR dependency between a2 and a3. // CHECK-NEXT: Membar 7 - %a3 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + %a3 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_SHARED> return } // CHECK-LABEL: scratch func @scratch() { - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK: Membar 1 - %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> // CHECK-NEXT: Membar 3 - %aa = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A>) -> tensor<32x16xf16, #AL> + %aa = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL> %b = tt.reduce %aa {redOp = 1 : i32, axis = 0 : i32} : tensor<32x16xf16, #AL> -> tensor<16xf16, #sliceAd0> return } // CHECK-LABEL: async_wait func @async_wait() { - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK: Membar 1 - %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> triton_gpu.async_wait {num = 4 : i32} // CHECK-NEXT: Membar 4 - %a_ = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A>) -> tensor<32x16xf16, #AL> + %a_ = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL> return } // CHECK-LABEL: alloc func @alloc() { - %cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A> - %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED> + %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> // CHECK: Membar 2 - %b = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A>) -> tensor<32x16xf16, #AL> + %b = triton_gpu.convert_layout %a : (tensor<32x16xf16, #A_SHARED>) -> tensor<32x16xf16, #AL> return } // CHECK-LABEL: extract_slice func @extract_slice() { - %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED> %index = arith.constant 0 : index - %cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A> to tensor<16x16xf16, #A> + %cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED> // CHECK: Membar 3 - %cst2 = triton_gpu.convert_layout %cst1 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL> + %cst2 = triton_gpu.convert_layout %cst1 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL> // CHECK-NEXT: Membar 5 - %cst3 = triton_gpu.convert_layout %cst2 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A> + %cst3 = triton_gpu.convert_layout %cst2 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED> return } @@ -112,119 +114,119 @@ func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<16x16x!tt.ptr, #AL> %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %tensor = triton_gpu.alloc_tensor : tensor<1x16x16xf16, #A> + %tensor = triton_gpu.alloc_tensor : tensor<1x16x16xf16, #A_SHARED> %index = arith.constant 0 : i32 - %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<1x16x16xf16, #A> - %b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A>, tensor<1x16x16xf16, #A>) -> tensor<2x16x16xf16, #A> + %a = triton_gpu.insert_slice_async %a_ptr, %tensor, %index, %mask, %other {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16x!tt.ptr, #AL> -> tensor<1x16x16xf16, #A_SHARED> + %b = tt.cat %a, %a {axis = 0} : (tensor<1x16x16xf16, #A_SHARED>, tensor<1x16x16xf16, #A_SHARED>) -> tensor<2x16x16xf16, #A_SHARED> // CHECK: Membar 7 - %c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A>, tensor<2x16x16xf16, #A>) -> tensor<4x16x16xf16, #A> + %c = tt.cat %b, %b {axis = 0} : (tensor<2x16x16xf16, #A_SHARED>, tensor<2x16x16xf16, #A_SHARED>) -> tensor<4x16x16xf16, #A_SHARED> return } // If branch inserted a barrier for %cst0 and %cst1, but else didn't, then the barrier should be inserted in the parent region // CHECK-LABEL: multi_blocks func @multi_blocks(%i1 : i1) { - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> scf.if %i1 { // CHECK: Membar 2 - %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> scf.yield } else { - %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst2 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst3 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: Membar 7 - %b = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %b = tt.cat %cst2, %cst3 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> scf.yield } // CHECK-NEXT: Membar 10 - %c = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %c = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> return } // Both branches inserted a barrier for %cst0 and %cst1, then the barrier doesn't need to be inserted in the parent region // CHECK-LABEL: multi_blocks_join_barrier func @multi_blocks_join_barrier(%i1 : i1) { - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> scf.if %i1 { // CHECK: Membar 2 - %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> scf.yield } else { // CHECK-NEXT: Membar 5 - %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> scf.yield } - %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL> + %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL> return } // Read yielded tensor requires a barrier // CHECK-LABEL: multi_blocks_yield func @multi_blocks_yield(%i1 : i1) { - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - %a = scf.if %i1 -> (tensor<32x16xf16, #A>) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) { // CHECK: Membar 2 - %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> - scf.yield %a : tensor<32x16xf16, #A> + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> + scf.yield %a : tensor<32x16xf16, #A_SHARED> } else { // CHECK-NEXT: Membar 5 - %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> - scf.yield %b : tensor<32x16xf16, #A> + %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> + scf.yield %b : tensor<32x16xf16, #A_SHARED> } - %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL> + %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL> // CHECK-NEXT: Membar 9 - %b = tt.cat %a, %a {axis = 0} : (tensor<32x16xf16, #A>, tensor<32x16xf16, #A>) -> tensor<64x16xf16, #A> + %b = tt.cat %a, %a {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #A_SHARED> return } // Conservatively add a barrier as if the branch (%i1) is never taken // CHECK-LABEL: multi_blocks_noelse func @multi_blocks_noelse(%i1 : i1) { - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> scf.if %i1 { // CHECK: Membar 2 - %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %a = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> scf.yield } - %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL> + %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL> return } // Conservatively add a barrier as if the branch (%i2) is never taken // CHECK-LABEL: multi_blocks_nested_scf func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { - %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> - %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> scf.if %i1 { scf.if %i2 { // CHECK: Membar 2 - %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> scf.yield } scf.yield } else { // CHECK-NEXT: Membar 6 - %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A>, tensor<16x16xf16, #A>) -> tensor<32x16xf16, #A> + %b = tt.cat %cst0, %cst1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> scf.yield } // CHECK-NEXT: Membar 9 - %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A>) -> tensor<16x16xf16, #AL> + %a_ = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL> return } // CHECK-LABEL: for func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { - %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> - %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> - %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) { // CHECK-NEXT: Membar 3 - %cst0 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) -> tensor<256x32xf16, #A> - scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + %cst0 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED> + scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED> } return } @@ -233,18 +235,18 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.p // they are reassociated with aliases (c_shared) and thus require a barrier. // CHECK-LABEL: for_alias func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { - %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> - %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> + %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: Membar 2 - %cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) -> tensor<256x32xf16, #A> - %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) { - %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) -> tensor<256x32xf16, #A> + %cst0 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED> + %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) { + %cst1 = tt.cat %a_shared_init, %b_shared_init {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED> // CHECK-NEXT: Membar 6 - %cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A>, tensor<128x32xf16, #A>) -> tensor<256x32xf16, #A> - scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A>, tensor<128x32xf16, #A>, tensor<128x32xf16, #A> + %cst2 = tt.cat %a_shared, %b_shared {axis = 0} : (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) -> tensor<256x32xf16, #A_SHARED> + scf.yield %c_shared, %a_shared, %b_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED> } // CHECK-NEXT: Membar 9 - %cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A>, tensor<256x32xf16, #A>) -> tensor<512x32xf16, #A> + %cst3 = tt.cat %cst0, %cst0 {axis = 0} : (tensor<256x32xf16, #A_SHARED>, tensor<256x32xf16, #A_SHARED>) -> tensor<512x32xf16, #A_SHARED> return } diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index c3a6fd63a..77f8db1db 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -669,22 +669,26 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> #shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}> #mma0 = #triton_gpu.mma<{version=2, warpsPerCTA=[1,1]}> +#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0}> +#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0}> module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_dot func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { %AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0> %BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0> + // CHECK: llvm.inline_asm + // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 + // CHECK: llvm.inline_asm + // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 + %AA_DOT = triton_gpu.convert_layout %AA : (tensor<16x16xf16, #shared0>) -> tensor<16x16xf16, #dot_operand_a> + %BB_DOT = triton_gpu.convert_layout %BB : (tensor<16x16xf16, #shared0>) -> tensor<16x16xf16, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> - // CHECK: llvm.inline_asm - // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 - // CHECK: llvm.inline_asm - // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 // CHECK: llvm.inline_asm // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 // CHECK: llvm.inline_asm // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 - %D = tt.dot %AA, %BB, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #shared0> * tensor<16x16xf16, #shared0> -> tensor<16x16xf32, #mma0> + %D = tt.dot %AA_DOT, %BB_DOT, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> return } @@ -813,6 +817,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { } // ----- + #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> #mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}> @@ -821,12 +826,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { module attributes {"triton_gpu.num-warps" = 4 : i32} { func @matmul_fmadot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, %a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) { - %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> - // CHECK: llvm.intr.fmuladd - %28 = tt.dot %a, %b, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #shared> * tensor<16x32xf32, #shared> -> tensor<32x32xf32, #blocked> - %30 = tt.splat %ptr : (!tt.ptr) -> tensor<32x1x!tt.ptr, #blocked> - %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr, #blocked>) -> tensor<32x32x!tt.ptr, #blocked> - tt.store %36, %28 : tensor<32x32xf32, #blocked> + // We are going to completely depracate using shared layout for operands of dot + //%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> + //%28 = tt.dot %a, %b, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #shared> * tensor<16x32xf32, #shared> -> tensor<32x32xf32, #blocked> + //%30 = tt.splat %ptr : (!tt.ptr) -> tensor<32x1x!tt.ptr, #blocked> + //%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr, #blocked>) -> tensor<32x32x!tt.ptr, #blocked> + //tt.store %36, %28 : tensor<32x32xf32, #blocked> return } } diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index fe9a10e27..a8d0fef14 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -4,9 +4,9 @@ // matmul: 128x32 @ 32x128 -> 128x128 #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}> +#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C}> +#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C}> // CHECK: func @matmul_loop // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 @@ -30,7 +30,9 @@ // CHECK: %[[A0:.*]] = tensor.extract_slice %[[A1BUFFER]][0, 0, 0] // CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0] // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]] -// CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}} +// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] +// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}} // CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]] // CHECK-DAG: %[[EXTRACT_INT:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]] // CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.index_cast %[[EXTRACT_INT]] : i32 to index @@ -87,15 +89,17 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B // CHECK: %[[A0:.*]] = tensor.extract_slice %[[A1BUFFER]][0, 0, 0] // CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0] // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]] -// CHECK: tt.dot %[[arg_a0]], %[[arg_b0]], {{.*}} +// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] +// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] +// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}} // CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]] // CHECK-DAG: %[[EXTRACT_INT:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]] // CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.index_cast %[[EXTRACT_INT]] : i32 to index // CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] // CHECK: triton_gpu.async_wait {num = 2 : i32} -// CHECK: %[[NEXT_A:.*]] = tensor.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK: %[[NEXT_B:.*]] = tensor.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] +// CHECK: %[[NEXT_A:.*]] = tensor.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0] +// CHECK: %[[NEXT_B:.*]] = tensor.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] @@ -141,7 +145,8 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr&1 | FileCheck %s +// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu -tritongpu-combine -tritongpu-pipeline=num-stages=3 -tritongpu-combine -test-print-allocation 2>&1 | FileCheck %s // CHECK: offset = 0, size = 49152 // CHECK: offset = 49152, size = 49152 diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir new file mode 100644 index 000000000..efba86b90 --- /dev/null +++ b/test/TritonGPU/prefetch.mlir @@ -0,0 +1,65 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-prefetch | FileCheck %s + +// 4 warps +// matmul: 128x32 @ 32x128 -> 128x128 +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}> +#A_OP = #triton_gpu.dot_op<{opIdx = 0, parent = #C}> +#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C}> + + +// CHECK: func @matmul_loop +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[A0:.*]][0, 0] [128, 16] +// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.convert_layout %[[A0_PREFETCH_SMEM]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[B0:.*]][0, 0] [16, 128] +// CHECK-DAG: %[[B0_PREFETCH:.*]] = triton_gpu.convert_layout %[[B0_PREFETCH_SMEM]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_PREFETCH]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] +// CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}} +// CHECK-DAG: %[[A_REM_SMEM:.*]] = tensor.extract_slice %[[arg_a0]][0, 16] [128, 16] +// CHECK-DAG: %[[A_REM:.*]] = triton_gpu.convert_layout %[[A_REM_SMEM]] +// CHECK-DAG: %[[B_REM_SMEM:.*]] = tensor.extract_slice %[[arg_b0]][16, 0] [16, 128] +// CHECK-DAG: %[[B_REM:.*]] = triton_gpu.convert_layout %[[B_REM_SMEM]] +// CHECK: tt.dot %[[A_REM]], %[[B_REM]], %[[D_FIRST:.*]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = tensor.extract_slice {{.*}}[0, 0] [128, 16] +// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_A_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = tensor.extract_slice {{.*}}[0, 0] [16, 128] +// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_B_PREFETCH_SMEM]] +// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH]], %[[NEXT_B_PREFETCH]] +func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { + %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> + %b_ptr_init = tt.broadcast %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> + %a_init = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + %b_ = tt.load %b_ptr_init, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> + %b_init = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> + + scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32xf16, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>) { + %a_op = triton_gpu.convert_layout %a : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #A_OP> + %b_op = triton_gpu.convert_layout %b : (tensor<32x128xf16, #B>) -> tensor<32x128xf16, #B_OP> + %c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL> + %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> + %next_a = triton_gpu.convert_layout %next_a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> + %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> + %next_b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> + + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x32xf16, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C> + } + return +} + diff --git a/test/TritonGPU/swizzle.mlir b/test/TritonGPU/swizzle.mlir index bb97bf1cb..256b4f1b9 100644 --- a/test/TritonGPU/swizzle.mlir +++ b/test/TritonGPU/swizzle.mlir @@ -13,14 +13,25 @@ #shared2 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> #shared3 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0]}> +#mma1w_op0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma1w}> +#mma1w_op1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma1w}> +#mma2w_op0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma2w}> +#mma2w_op1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma2w}> +#mma4w_op0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma4w}> +#mma4w_op1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma4w}> +#mma8w_op0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma8w}> +#mma8w_op1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma8w}> + module attributes {"triton_gpu.num-warps" = 8 : i32} { // CHECK-LABEL: swizzle_mma_f16_128x256x64_w8 - func @swizzle_mma_f16_128x256x64_w8(%A: tensor<128x64xf16, #shared>, %B: tensor<64x256xf16, #shared>) { + func @swizzle_mma_f16_128x256x64_w8(%A_SMEM: tensor<128x64xf16, #shared>, %B_SMEM: tensor<64x256xf16, #shared>) { %cst0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma8w> // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]> // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x256xf16, {{.*}}>) -> tensor<64x256xf16, [[shared_v8p1m8]]> - %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x64xf16, #shared> * tensor<64x256xf16, #shared> -> tensor<128x256xf32, #mma8w> + %A = triton_gpu.convert_layout %A_SMEM : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #mma8w_op0> + %B = triton_gpu.convert_layout %B_SMEM : (tensor<64x256xf16, #shared>) -> tensor<64x256xf16, #mma8w_op1> + %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x64xf16, #mma8w_op0> * tensor<64x256xf16, #mma8w_op1> -> tensor<128x256xf32, #mma8w> return } } @@ -28,44 +39,52 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} { module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: swizzle_mma_f16_128x128x64_w4 - func @swizzle_mma_f16_128x128x64_w4(%A: tensor<128x64xf16, #shared>, %B: tensor<64x128xf16, #shared>) { + func @swizzle_mma_f16_128x128x64_w4(%A_SMEM: tensor<128x64xf16, #shared>, %B_SMEM: tensor<64x128xf16, #shared>) { %cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w> // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x64xf16, {{.*}}>) -> tensor<128x64xf16, [[shared_v8p1m8]]> // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<64x128xf16, {{.*}}>) -> tensor<64x128xf16, [[shared_v8p1m8]]> - %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x64xf16, #shared> * tensor<64x128xf16, #shared> -> tensor<128x128xf32, #mma4w> + %A = triton_gpu.convert_layout %A_SMEM : (tensor<128x64xf16, #shared>) -> tensor<128x64xf16, #mma4w_op0> + %B = triton_gpu.convert_layout %B_SMEM : (tensor<64x128xf16, #shared>) -> tensor<64x128xf16, #mma4w_op1> + %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x64xf16, #mma4w_op0> * tensor<64x128xf16, #mma4w_op1> -> tensor<128x128xf32, #mma4w> return } } module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: swizzle_mma_f16_128x128x32_w4 - func @swizzle_mma_f16_128x128x32_w4(%A: tensor<128x32xf16, #shared>, %B: tensor<32x128xf16, #shared>) { + func @swizzle_mma_f16_128x128x32_w4(%A_SMEM: tensor<128x32xf16, #shared>, %B_SMEM: tensor<32x128xf16, #shared>) { %cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma4w> // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<128x32xf16, {{.*}}>) -> tensor<128x32xf16, [[shared_v8p2m4]]> // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x128xf16, {{.*}}>) -> tensor<32x128xf16, [[shared_v8p1m8]]> - %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma4w> + %A = triton_gpu.convert_layout %A_SMEM : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #mma4w_op0> + %B = triton_gpu.convert_layout %B_SMEM : (tensor<32x128xf16, #shared>) -> tensor<32x128xf16, #mma4w_op1> + %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #mma4w_op0> * tensor<32x128xf16, #mma4w_op1> -> tensor<128x128xf32, #mma4w> return } } module attributes {"triton_gpu.num-warps" = 2 : i32} { // CHECK-LABEL: swizzle_mma_f16_32x32x32_w2 - func @swizzle_mma_f16_32x32x32_w2(%A: tensor<32x32xf16, #shared>, %B: tensor<32x32xf16, #shared>) { + func @swizzle_mma_f16_32x32x32_w2(%A_SMEM: tensor<32x32xf16, #shared>, %B_SMEM: tensor<32x32xf16, #shared>) { %cst0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma2w> // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]> // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<32x32xf16, {{.*}}>) -> tensor<32x32xf16, [[shared_v8p2m4]]> - %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<32x32xf16, #shared> * tensor<32x32xf16, #shared> -> tensor<32x32xf32, #mma2w> + %A = triton_gpu.convert_layout %A_SMEM : (tensor<32x32xf16, #shared>) -> tensor<32x32xf16, #mma2w_op0> + %B = triton_gpu.convert_layout %B_SMEM : (tensor<32x32xf16, #shared>) -> tensor<32x32xf16, #mma2w_op1> + %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<32x32xf16, #mma2w_op0> * tensor<32x32xf16, #mma2w_op1> -> tensor<32x32xf32, #mma2w> return } } module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: swizzle_mma_f16_16x16x16_w1 - func @swizzle_mma_f16_16x16x16_w1(%A: tensor<16x16xf16, #shared>, %B: tensor<16x16xf16, #shared>) { + func @swizzle_mma_f16_16x16x16_w1(%A_SMEM: tensor<16x16xf16, #shared>, %B_SMEM: tensor<16x16xf16, #shared>) { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma1w> // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]> // CHECK: {{.*}} = triton_gpu.convert_layout {{.*}} : (tensor<16x16xf16, {{.*}}>) -> tensor<16x16xf16, [[shared_v8p4m2]]> - %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #shared> * tensor<16x16xf16, #shared> -> tensor<16x16xf32, #mma1w> + %A = triton_gpu.convert_layout %A_SMEM : (tensor<16x16xf16, #shared>) -> tensor<16x16xf16, #mma1w_op0> + %B = triton_gpu.convert_layout %B_SMEM : (tensor<16x16xf16, #shared>) -> tensor<16x16xf16, #mma1w_op1> + %D = tt.dot %A, %B, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #mma1w_op0> * tensor<16x16xf16, #mma1w_op1> -> tensor<16x16xf32, #mma1w> return } }