diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index 0a1a72bc3..dd301b904 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -52,7 +52,7 @@ different cuda threads in the programs, via shared memory.
 In other words, for all indices i \in R^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}.
 
 In order to avoid shared memory bank conflicts, elements may be swizzled
-in memory. For example, a swizzled row-major layout could store its data 
+in memory. For example, a swizzled row-major layout could store its data
 as follows:
 
 A_{0, 0}  A_{0, 1}  A_{0, 2}  A_{0, 3} ...   [phase 0] \ per_phase = 2
@@ -215,9 +215,9 @@ def MmaEncodingAttr : DistributedEncoding<"MmaEncoding"> {
 An encoding for tensors that have been produced by tensor cores.
 It is characterized by two parameters:
 - A 'version' which specifies the generation the tensor cores
-whose output is being partitioned: 1 for first-gen tensor cores (Volta), 
+whose output is being partitioned: 1 for first-gen tensor cores (Volta),
 and 2 for second-gen tensor cores (Turing/Ampere).
-- A `blockTileSize` to indicate how data should be 
+- A `blockTileSize` to indicate how data should be
 partitioned between warps.
 
 // -------------------------------- version = 1 --------------------------- //
@@ -229,7 +229,7 @@ https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
 
 For example, the matrix L corresponding to blockTileSize=[32,16] is:
 
-                              warp 0 
+                              warp 0
 --------------------------------/\-------------------------------
 [ 0   0   2   2   0   0   2   2   4   4   6   6   4   4   6   6 ]
 [ 1   1   3   3   1   1   3   3   5   5   7   7   5   5   7   7 ]
@@ -246,7 +246,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
 [ 24  24  26  26  24  24  26  26  28  28  30  30  28  28  30  30]
 [ 25  25  27  27  25  25  27  27  29  29  31  31  29  29  31  31]
 
-                          warp 1 = warp0 + 32 
+                          warp 1 = warp0 + 32
 --------------------------------/\-------------------------------
 [ 32  32  34  34  32  32  34  34  36  36  38  38  36  36  38  38]
 [ 33  33  35  35  33  33  35  35  37  37  39  39  37  37  39  39]
@@ -260,29 +260,29 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
 For second-gen tensor cores, the implicit warpTileSize is [16, 8].
 Information about this layout can be found in the official PTX documentation
 https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
-(mma.16816 section, FP32 accumulator). 
+(mma.16816 section, FP32 accumulator).
 
 For example, the matrix L corresponding to blockTileSize=[32,16] is:
 
          warp 0                           warp 1
 -----------------/\-------------  ----------------/\-------------
 [ 0   0   1   1   2   2   3   3   32  32  33  33  34  34  35  35
 [ 4   4   5   5   6   6   7   7   36  36  37  37  38  38  39  39
-[ ..............................  .............................. 
+[ ..............................  ..............................
 [ 28  28  29  29  30  30  31  31  60  60  61  61  62  62  63  63
 [ 0   0   1   1   2   2   3   3   32  32  33  33  34  34  35  35
 [ 4   4   5   5   6   6   7   7   36  36  37  37  38  38  39  39
-[ ..............................  .............................. 
+[ ..............................  ..............................
 [ 28  28  29  29  30  30  31  31  60  60  61  61  62  62  63  63
- 
+
          warp 3                           warp 4
 ----------------/\-------------  ----------------/\-------------
 [ 64  64  65  65  66  66  67  67   96  96  97  97  98  98  99  99
 [ 68  68  69  69  70  70  71  71  100 100 101 101 102 102 103 103
-[ ..............................  ............................... 
+[ ..............................  ...............................
 [ 92  92  93  93  94  94  95  95  124 124 125 125 126 126 127 127
 [ 64  64  65  65  66  66  67  67   96  96  97  97  98  98  99  99
 [ 68  68  69  69  70  70  71  71  100 100 101 101 102 102 103 103
-[ ..............................  ............................... 
+[ ..............................  ...............................
 [ 92  92  93  93  94  94  95  95  124 124 125 125 126 126 127 127
 
 }];
@@ -316,7 +316,7 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
     This is useful for constructing the inverse layout of an
     expand_dims operation during some optimization passes.
   }];
-  
+
   let parameters = (
     ins
     "unsigned":$dim,
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 58a59c477..18d539345 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -29,6 +29,7 @@ using namespace mlir;
 using namespace mlir::triton;
 using ::mlir::triton::gpu::BlockedEncodingAttr;
+using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::getElemsPerThread;
 using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
@@ -1565,6 +1566,10 @@ public:
         dstLayout.isa<SharedEncodingAttr>()) {
       return lowerBlockedToShared(op, adaptor, rewriter);
     }
+    if (srcLayout.isa<SharedEncodingAttr>() &&
+        dstLayout.isa<DotOperandEncodingAttr>()) {
+      return lowerSharedToDotOperand(op, adaptor, rewriter);
+    }
     if ((!srcLayout.isa<BlockedEncodingAttr>() &&
          !srcLayout.isa<MmaEncodingAttr>()) ||
         (!dstLayout.isa<BlockedEncodingAttr>() &&
@@ -1572,6 +1577,7 @@ public:
       // TODO: to be implemented
       return failure();
     }
+
     return lowerDistributedToDistributed(op, adaptor, rewriter);
   }
 
@@ -1609,6 +1615,11 @@ private:
   LogicalResult
   lowerBlockedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter) const;
+
+  // shared -> mma_operand
+  LogicalResult
+  lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+                          ConversionPatternRewriter &rewriter) const;
 };
 
 void ConvertLayoutOpConversion::processReplica(
@@ -1915,6 +1926,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   rewriter.replaceOp(op, smemBase);
   return success();
 }
+
 /// ====================== dot codegen begin ==========================
 
 // Data loader for mma.16816 instruction.
@@ -2383,16 +2395,16 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
 
 private:
   // Convert to mma.m16n8k16
-  LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adapter,
+  LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const;
 
   /// Convert to mma.m8n8k4
-  LogicalResult convertMMA884(triton::DotOp op, OpAdaptor adapter,
+  LogicalResult convertMMA884(triton::DotOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
     assert(false && "Not implemented yet.");
     return failure();
   }
 
-  LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adapter,
+  LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
     assert(false && "Not implemented yet.");
     return failure();
@@ -2402,28 +2414,18 @@ private:
 struct DotOpConversionHelper {
   using TensorCoreType = DotOpConversion::TensorCoreType;
 
-  Value A, B, C, D;
   MmaEncodingAttr mmaLayout;
-  RankedTensorType ATensorTy, BTensorTy, DTensorTy;
   MLIRContext *ctx{};
 
-  explicit DotOpConversionHelper(DotOp dot)
-      : dot(dot), mmaType(getMmaType(dot)) {
-    A = dot.a();
-    B = dot.b();
-    C = dot.c();
-    D = dot.d();
-    ctx = dot->getContext();
-    mmaLayout = C.getType()
-                    .cast<RankedTensorType>()
-                    .getEncoding()
-                    .cast<MmaEncodingAttr>();
+  explicit DotOpConversionHelper(MmaEncodingAttr mmaLayout)
+      : mmaLayout(mmaLayout) {
+    ctx = mmaLayout.getContext();
   }
 
   // Load SplatLike C which contains a constVal. It simply returns 4 fp32
   // constVal.
   SmallVector<Value> loadSplatLikeC(Value C, Location loc,
-                                    ConversionPatternRewriter &rewriter) {
+                                    ConversionPatternRewriter &rewriter) const {
     assert(isSplatLike(C));
 
     int numRes = getMmaInstrShape()[0] * getMmaInstrShape()[1] / 32;
@@ -2451,6 +2453,11 @@ struct DotOpConversionHelper {
     return {};
   }
 
+  void deduceMmaType(DotOp op) const { mmaType = getMmaType(op); }
+  void deduceMmaType(Type operandTy) const {
+    mmaType = getTensorCoreTypeFromOperand(operandTy);
+  }
+
   Type getShemPtrTy() const {
     switch (mmaType) {
     case TensorCoreType::FP32_FP16_FP16_FP32:
@@ -2554,6 +2561,22 @@ struct DotOpConversionHelper {
     return mmaMatShape.at(mmaType);
   }
 
+  // Deduce the TensorCoreType from either $a or $b's type. This method is not
+  // safe, but we cannot get the DotOp in some getMmaMatShape usage cases.
+  TensorCoreType getTensorCoreTypeFromOperand(Type operandTy) const {
+    auto tensorTy = operandTy.cast<RankedTensorType>();
+    auto elemTy = tensorTy.getElementType();
+    if (elemTy.isF16())
+      return TensorCoreType::FP32_FP16_FP16_FP32;
+    if (elemTy.isF32())
+      return TensorCoreType::FP32_TF32_TF32_FP32;
+    if (elemTy.isBF16())
+      return TensorCoreType::FP32_BF16_BF16_FP32;
+    if (elemTy.isInteger(8))
+      return TensorCoreType::INT32_INT8_INT8_INT32;
+    return TensorCoreType::NOT_APPLICABLE;
+  }
+
   int getVec() const {
     assert(mmaType != TensorCoreType::NOT_APPLICABLE &&
            "Unknown mma type found.");
@@ -2593,7 +2616,7 @@ struct DotOpConversionHelper {
   }
 
 private:
-  TensorCoreType mmaType;
+  mutable TensorCoreType mmaType{TensorCoreType::NOT_APPLICABLE};
 
   // Used on nvidia GPUs mma layout .version == 2
   // Refer to
@@ -2655,9 +2678,6 @@ private:
       {TensorCoreType::INT32_INT4_INT4_INT32, 32},
       {TensorCoreType::INT32_INT8_INT8_INT32, 16},
   };
-
-private:
-  DotOp dot;
 };
 
 // This class helps to adapt the existing DotOpConversion to the latest
@@ -2666,21 +2686,12 @@ private:
 // 1. loading the specific operand matrix(for $a, $b, $c) from smem
 // 2. passing the loaded value and perform the mma codegen
 struct MMA16816ConversionHelper {
-  Value A, B, C, D;
-  RankedTensorType aTensorTy, bTensorTy, dTensorTy;
-  ArrayRef<int64_t> aShape, bShape, dShape;
   MmaEncodingAttr mmaLayout;
   ArrayRef<unsigned> wpt;
 
-  int mmaInstrM{-1}, mmaInstrN{-1}, mmaInstrK{-1};
-  int matShapeM{-1}, matShapeN{-1}, matShapeK{-1};
-  int numRepM{-1}, numRepN{-1}, numRepK{-1};
   Value thread, lane, warp, warpMN, warpN, warpM;
-  size_t aElemBytes{}, bElemBytes{};
 
   DotOpConversionHelper helper;
-  triton::DotOp op;
-  DotOpAdaptor adapter;
   ConversionPatternRewriter &rewriter;
   TypeConverter *typeConverter;
   Location loc;
@@ -2688,64 +2699,75 @@ struct MMA16816ConversionHelper {
 
   using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
 
-  MMA16816ConversionHelper(triton::DotOp op, Value thread, DotOpAdaptor adapter,
+  MMA16816ConversionHelper(MmaEncodingAttr mmaLayout, Value thread,
                            ConversionPatternRewriter &rewriter,
                            TypeConverter *typeConverter, Location loc)
-      : helper(op), op(op), adapter(adapter), rewriter(rewriter),
-        typeConverter(typeConverter), loc(loc), ctx(op.getContext()),
+      : mmaLayout(mmaLayout), helper(mmaLayout), rewriter(rewriter),
+        typeConverter(typeConverter), loc(loc), ctx(mmaLayout.getContext()),
         thread(thread) {
-    A = op.a();
-    B = op.b();
-    C = op.c();
-    D = op.getResult();
-
-    aTensorTy = A.getType().cast<RankedTensorType>();
-    bTensorTy = B.getType().cast<RankedTensorType>();
-    dTensorTy = D.getType().cast<RankedTensorType>();
-
-    aShape = aTensorTy.getShape();
-    bShape = bTensorTy.getShape();
-    dShape = dTensorTy.getShape();
-
-    mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
-
     wpt = mmaLayout.getWarpsPerCTA();
 
-    auto mmaInstrShape = helper.getMmaInstrShape();
-    mmaInstrM = mmaInstrShape[0];
-    mmaInstrN = mmaInstrShape[1];
-    mmaInstrK = mmaInstrShape[2];
-
-    auto matShape = helper.getMmaMatShape();
-    matShapeM = matShape[0];
-    matShapeN = matShape[1];
-    matShapeK = matShape[2];
-
-    int NK = aShape[1];
-    // shape / shape_per_cta
-    numRepM = std::max<int>(dShape[0] / (wpt[0] * mmaInstrM), 1);
-    numRepN = std::max<int>(dShape[1] / (wpt[1] * mmaInstrN), 1);
-    numRepK = std::max<int>(NK / mmaInstrK, 1);
-
     Value _32 = i32_val(32);
     lane = urem(thread, _32);
     warp = udiv(thread, _32);
     warpMN = udiv(warp, i32_val(wpt[0]));
     warpM = urem(warp, i32_val(wpt[0]));
     warpN = urem(warpMN, i32_val(wpt[1]));
+  }
 
-    aElemBytes = aTensorTy.getElementTypeBitWidth() / 8;
-    bElemBytes = bTensorTy.getElementTypeBitWidth() / 8;
+  // Get the mmaInstrShape from either $a or $b.
+  std::tuple<int, int, int> getMmaInstrShape(Type operand) const {
+    helper.deduceMmaType(operand);
+    auto mmaInstrShape = helper.getMmaInstrShape();
+    int mmaInstrM = mmaInstrShape[0];
+    int mmaInstrN = mmaInstrShape[1];
+    int mmaInstrK = mmaInstrShape[2];
+    return std::make_tuple(mmaInstrM, mmaInstrN, mmaInstrK);
+  }
+
+  std::tuple<int, int, int> getMmaMatShape(Type operand) const {
+    helper.deduceMmaType(operand);
+    auto matShape = helper.getMmaMatShape();
+    int matShapeM = matShape[0];
+    int matShapeN = matShape[1];
+    int matShapeK = matShape[2];
+    return std::make_tuple(matShapeM, matShapeN, matShapeK);
+  }
+
+  // \param operand is either $a or $b's type.
+  inline int getNumRepM(Type operand, int M) const {
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
+    return std::max<int>(M / (wpt[0] * mmaInstrM), 1);
+  }
+
+  // \param operand is either $a or $b's type.
+  inline int getNumRepN(Type operand, int N) const {
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
+    return std::max<int>(N / (wpt[1] * mmaInstrN), 1);
+  }
+
+  // \param operand is either $a or $b's type.
+  inline int getNumRepK(Type operand, int K) const {
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(operand);
+    return std::max<int>(K / mmaInstrK, 1);
   }
 
   // Loading $a from smem to registers, returns a LLVM::Struct.
-  Value loadA() {
+  Value loadA(Value tensor, Value llTensor) const {
+    auto aTensorTy = tensor.getType().cast<RankedTensorType>();
+    auto shape = aTensorTy.getShape();
+
     ValueTable ha;
     std::function<void(int, int)> loadFn;
+    auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
+    int numRepM = getNumRepM(aTensorTy, shape[0]);
+    int numRepK = getNumRepK(aTensorTy, shape[1]);
+
     if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
       // load from smem
       loadFn = getLoadMatrixFn(
-          A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
+          tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
          1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
          {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
     } else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
@@ -2770,10 +2792,17 @@ struct MMA16816ConversionHelper {
   }
 
   // Loading $b from smem to registers, returns a LLVM::Struct.
-  Value loadB() {
+  Value loadB(Value tensor, Value llTensor) {
     ValueTable hb;
+    auto tensorTy = tensor.getType().cast<RankedTensorType>();
+    auto shape = tensorTy.getShape();
+    auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(tensorTy);
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(tensorTy);
+    int numRepK = getNumRepK(tensorTy, shape[0]);
+    int numRepN = getNumRepN(tensorTy, shape[1]);
+
     auto loadFn = getLoadMatrixFn(
-        B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
+        tensor, llTensor, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
        0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
        {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
@@ -2789,24 +2818,47 @@ struct MMA16816ConversionHelper {
 
   // Loading $c from smem(?) to registers, returns a Value.
   // NOTE Only SplatLike tensor is supported now.
-  Value loadC() {
+  Value loadC(Value tensor) const {
     // Currently, we only support a SplatLike C. For the other cases, e.g., C in
     // shared layout or blocked layout, we will support them by expanding
     // convert_layout.
-    auto hc = helper.loadSplatLikeC(C, loc, rewriter);
+    auto hc = helper.loadSplatLikeC(tensor, loc, rewriter);
     assert(hc.size() == 4UL && "Only splat-like C is supported now");
     return hc[0];
   }
 
   // Conduct the Dot conversion.
-  // Input the \param a, \param b, \param c, all of them are result of loading.
-  LogicalResult convertDot(Value a, Value b, Value c) {
-    ValueTable ha = getValuesFromDotOperandLayoutStruct(a, numRepM, numRepK);
+  // \param a, \param b, \param c and \param d are DotOp operands.
+  // \param loadedA, \param loadedB and \param loadedC are the results of
+  // loading.
+  LogicalResult convertDot(Value a, Value b, Value c, Value d, Value loadedA,
+                           Value loadedB, Value loadedC, DotOp op,
+                           DotOpAdaptor adaptor) const {
+    helper.deduceMmaType(op);
+
+    auto aTensorTy = a.getType().cast<RankedTensorType>();
+    auto bTensorTy = b.getType().cast<RankedTensorType>();
+    auto cTensorTy = c.getType().cast<RankedTensorType>();
+    auto dTensorTy = d.getType().cast<RankedTensorType>();
+
+    auto aShape = aTensorTy.getShape();
+    auto dShape = dTensorTy.getShape();
+
+    int NK = aShape[1];
+    // shape / shape_per_cta
+    auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
+    int numRepM = getNumRepM(aTensorTy, dShape[0]);
+    int numRepN = getNumRepN(aTensorTy, dShape[1]);
+    int numRepK = getNumRepK(aTensorTy, aShape[1]);
+
+    ValueTable ha =
+        getValuesFromDotOperandLayoutStruct(loadedA, numRepM, numRepK);
     ValueTable hb = getValuesFromDotOperandLayoutStruct(
-        b, std::max(numRepN / 2, 1), numRepK);
+        loadedB, std::max(numRepN / 2, 1), numRepK);
     const int fcSize = 4 * numRepM * numRepN;
-    SmallVector<Value> fc(fcSize, c);
+    SmallVector<Value> fc(fcSize, loadedC);
 
     auto callMma = [&](unsigned m, unsigned n, unsigned k) {
       unsigned colsPerThread = numRepN * 2;
@@ -2855,10 +2907,10 @@ struct MMA16816ConversionHelper {
 
 private:
   std::function<void(int, int)>
-  getLoadMatrixFn(Value tensor, Value llTensor, int wpt, int kOrder,
-                  ArrayRef<int> instrShape, ArrayRef<int> matShape,
-                  Value warpId, ValueTable &vals) {
-
+  getLoadMatrixFn(Value tensor, Value llTensor, MmaEncodingAttr mmaLayout,
+                  int wpt, int kOrder, ArrayRef<int> instrShape,
+                  ArrayRef<int> matShape, Value warpId,
+                  ValueTable &vals) const {
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
     // We assume that the input operand of Dot should be from shared layout.
     // TODO(Superjomn) Consider other layouts if needed later.
@@ -2928,7 +2980,7 @@ private:
   // i \in [0, n0) and j \in [0, n1)
   // There should be \param n0 * \param n1 elements in the output Struct.
   Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0,
-                                              int n1) {
+                                              int n1) const {
     std::vector<Value> elems;
     for (unsigned m = 0; m < n0; ++m)
       for (unsigned k = 0; k < n1; ++k) {
@@ -2940,7 +2992,7 @@ private:
 
     assert(!elems.empty());
 
-    Type fp16Ty = aTensorTy.getElementType();
+    Type fp16Ty = type::f16Ty(ctx);
     Type fp16x2Ty = vec_ty(fp16Ty, 2);
     Type structTy = LLVM::LLVMStructType::getLiteral(
         ctx, SmallVector<Type>(elems.size(), fp16x2Ty));
@@ -2948,7 +3000,8 @@ private:
     return result;
   }
 
-  ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1) {
+  ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0,
+                                                 int n1) const {
     auto elems = ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct(
         loc, value, rewriter);
 
@@ -2966,18 +3019,79 @@ private:
   }
 };
 
+LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
+    triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto loc = op.getLoc();
+  Value src = op.src();
+  Value dst = op.result();
+  auto srcTensorTy = src.getType().cast<RankedTensorType>();
+  auto dstTensorTy = dst.getType().cast<RankedTensorType>();
+
+  auto sharedLayout = srcTensorTy.getEncoding().cast<SharedEncodingAttr>();
+  auto dotOperandLayout =
+      dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
+  MmaEncodingAttr mmaLayout =
+      dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
+  assert(mmaLayout);
+
+  MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
+                                     rewriter, getTypeConverter(), op.getLoc());
+
+  Value res;
+  if (dotOperandLayout.getOpIdx() == 0) {
+    // operand $a
+    res = mmaHelper.loadA(src, adaptor.src());
+  } else if (dotOperandLayout.getOpIdx() == 1) {
+    // operand $b
+    res = mmaHelper.loadB(src, adaptor.src());
+  } else if (dotOperandLayout.getOpIdx() == 2) {
+    // operand $c
+    res = mmaHelper.loadC(src);
+  }
+
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
 LogicalResult
-DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter,
+DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
   auto loc = op.getLoc();
-  MMA16816ConversionHelper mmaHelper(op, getThreadId(rewriter, loc), adapter,
+  auto mmaLayout = op.getResult()
+                       .getType()
+                       .cast<RankedTensorType>()
+                       .getEncoding()
+                       .cast<MmaEncodingAttr>();
+  MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
                                      rewriter, getTypeConverter(), loc);
 
-  auto A = mmaHelper.loadA();
-  auto B = mmaHelper.loadB();
-  auto C = mmaHelper.loadC();
+  Value A = op.a();
+  Value B = op.b();
+  Value C = op.c();
+  auto ATensorTy = A.getType().cast<RankedTensorType>();
+  auto BTensorTy = B.getType().cast<RankedTensorType>();
 
-  return mmaHelper.convertDot(A, B, C);
+  Value loadedA, loadedB, loadedC;
+  // We support two kinds of operand layouts: 1. both $a and $b have the
+  // dot_operand layout; 2. both of them have the shared layout.
+  if (ATensorTy.getEncoding().isa<DotOperandEncodingAttr>()) {
+    assert(BTensorTy.getEncoding().isa<DotOperandEncodingAttr>() &&
+           "Both $a and $b should be DotOperand layout.");
+    loadedA = adaptor.a();
+    loadedB = adaptor.b();
+  } else {
+    loadedA = mmaHelper.loadA(op.a(), adaptor.a());
+    loadedB = mmaHelper.loadB(op.b(), adaptor.b());
+  }
+
+  // TODO[Superjomn]: Process C as an mma layout.
+  // Currently, C is simply treated as a Splat Op, and its data layout does
+  // not matter.
+  loadedC = mmaHelper.loadC(op.c());
+
+  return mmaHelper.convertDot(A, B, C, op.d(), loadedA, loadedB, loadedC, op,
+                              adaptor);
 }
 
 /// ====================== mma codegen end ============================
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 12d88f5b1..0d9bc568d 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -610,4 +610,4 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
                                                          NamedAttribute attr) {
   // TODO: fill this.
   return success();
-}
\ No newline at end of file
+}
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index eb9bc404d..4652f6c2b 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -714,3 +714,27 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
     return
   }
 }
+
+// -----
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  func @matmul_kernel_dot_operand_layout(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
+                                         %a: tensor<128x32xf16, #shared>, %b: tensor<32x256xf16, #shared>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
+    // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
+    %a_mat = triton_gpu.convert_layout %a : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #dot_operand_a>
+    %b_mat = triton_gpu.convert_layout %b : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #dot_operand_b>
+
+    %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma>
+    %38 = triton_gpu.convert_layout %28 : (tensor<128x256xf32, #mma>) -> tensor<128x256xf32, #blocked>
+
+    %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>, #blocked>
+    %36 = tt.broadcast %30 : (tensor<128x1x!tt.ptr<f32>, #blocked>) -> tensor<128x256x!tt.ptr<f32>, #blocked>
+    tt.store %36, %38 : tensor<128x256xf32, #blocked>
+    return
+  }
+}
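
Note (not part of the patch): the replication counts computed by the new getNumRepM/getNumRepN/getNumRepK helpers are plain integer arithmetic, so they can be sanity-checked offline. Below is a minimal standalone C++ sketch of that formula, assuming the mma.m16n8k16 instruction shape and the warpsPerCTA = [2, 2], 128x32 * 32x256 shapes used by the new lit test above; the variable names and printed values are illustrative only.

    #include <algorithm>
    #include <cstdio>

    int main() {
      const int M = 128, N = 256, K = 32;             // tensor shapes from the lit test
      const int wpt[2] = {2, 2};                      // #mma warpsPerCTA from the lit test
      const int instrM = 16, instrN = 8, instrK = 16; // mma.m16n8k16 instruction shape

      // Same arithmetic as MMA16816ConversionHelper::getNumRepM/N/K: how many
      // times each warp repeats the instruction tile along each dimension.
      int numRepM = std::max(M / (wpt[0] * instrM), 1); // 128 / 32 = 4
      int numRepN = std::max(N / (wpt[1] * instrN), 1); // 256 / 16 = 16
      int numRepK = std::max(K / instrK, 1);            // 32 / 16 = 2

      std::printf("numRepM=%d numRepN=%d numRepK=%d\n", numRepM, numRepN, numRepK);
      return 0;
    }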