From 1baa4e125f5193fd3ac54388bff0999692c34309 Mon Sep 17 00:00:00 2001
From: Yan Chunwei
Date: Wed, 12 Oct 2022 10:45:17 +0800
Subject: [PATCH] [triton-mlir][BACKEND] decouple loading from mma codegen in dot conversion (#764)

This PR decouples the operand loading from the MMA codegen to make it ready
for the ongoing `DotOperandEncodingAttr` migration. The existing DotOp
conversion is composed of the following two procedures:

1. Loading the $a, $b, $c operands from smem to registers
2. Conducting the MMA instruction codegen

In the latest design, the first stage should become part of the
`convert_layout(shared_layout) -> dot_operand_layout` lowering, which is why
the decoupling is necessary.

In more detail, this PR introduces an `MMA16816ConversionHelper` class with
`loadA`, `loadB`, and `loadC` methods that help load $a, $b, and $c from smem
to registers. Both `loadA` and `loadB` return an `LLVM::Struct` that should be
compatible with the new `DotOperandEncodingAttr` conversion. The conversion
layout for $a and $b is as follows:

```c++
// The layout is a list of Value with coordinate of (i,j), the order is as
// the follows:
// [
//   (0,0), (0,1), (1,0), (1,1), # i=0, j=0
//   (0,2), (0,3), (1,2), (1,3), # i=0, j=1
//   (0,4), (0,5), (1,4), (1,5), # i=0, j=2
//   ...
//   (2,0), (2,1), (3,0), (3,1), # i=1, j=0
//   (2,2), (2,3), (3,2), (3,3), # i=1, j=1
//   (2,4), (2,5), (2,4), (2,5), # i=1, j=2
//   ...
// ]
// i \in [0, n0) and j \in [0, n1)
```

The `convertDot` method takes the loaded $a, $b, and $c ($a and $b are of type
`LLVM::Struct` while $c is a scalar Value), extracts the elements from the
`LLVM::Struct` following the layout above, and passes them to the MMA inline
asm.
---
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 453 +++++++++++-------
 1 file changed, 286 insertions(+), 167 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index ac7111bf2..2820e03c0 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -337,12 +337,13 @@ static T getLinearIndex(ArrayRef multiDimIndex, ArrayRef shape) {
 
 struct ConvertTritonGPUOpToLLVMPatternBase {
   static SmallVector
-  getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems,
+  getElementsFromStruct(Location loc, Value llvmStruct,
                         ConversionPatternRewriter &rewriter) {
-    SmallVector results(elems);
-    for (unsigned i = 0; i < elems; ++i) {
-      Type type =
-          llvmStruct.getType().cast().getBody()[i];
+    ArrayRef types =
+        llvmStruct.getType().cast().getBody();
+    SmallVector results(types.size());
+    for (unsigned i = 0; i < types.size(); ++i) {
+      Type type = types[i];
       results[i] = extract_val(type, llvmStruct, rewriter.getI64ArrayAttr(i));
     }
     return results;
@@ -715,8 +716,7 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {
     auto shape = value.getType().cast().getShape();
     // Here, we assume that all inputs should have a blockedLayout
-    unsigned valueElems = layout.getElemsPerThread(shape);
-    auto valueVals = getElementsFromStruct(loc, llValue, valueElems, rewriter);
+    auto valueVals = getElementsFromStruct(loc, llValue, rewriter);
     return valueVals;
   }
 
@@ -977,7 +977,7 @@ struct BroadcastOpConversion
     unsigned srcElems = srcLayout.getElemsPerThread(srcShape);
     auto elemTy = resultTy.getElementType();
-    auto srcVals = getElementsFromStruct(loc, src, srcElems, rewriter);
+    auto srcVals = getElementsFromStruct(loc, src, rewriter);
     unsigned resultElems = resultLayout.getElemsPerThread(resultShape);
SmallVector resultVals(resultElems); for (unsigned i = 0; i < srcElems; ++i) { @@ -1025,8 +1025,7 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern { this->getTypeConverter()->convertType(resultTy.getElementType()); SmallVector types(elems, elemTy); Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types); - auto vals = - this->getElementsFromStruct(loc, adaptor.src(), elems, rewriter); + auto vals = this->getElementsFromStruct(loc, adaptor.src(), rewriter); Value view = getStructFromElements(loc, vals, rewriter, structTy); rewriter.replaceOp(op, view); return success(); @@ -1308,9 +1307,8 @@ struct AddPtrOpConversion getTypeConverter()->convertType(resultTensorTy.getElementType()); SmallVector types(elems, elemTy); Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types); - auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), elems, rewriter); - auto offsets = - getElementsFromStruct(loc, adaptor.offset(), elems, rewriter); + auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), rewriter); + auto offsets = getElementsFromStruct(loc, adaptor.offset(), rewriter); SmallVector resultVals(elems); for (unsigned i = 0; i < elems; ++i) { resultVals[i] = gep(elemTy, ptrs[i], offsets[i]); @@ -1416,9 +1414,9 @@ public: auto *concreteThis = static_cast(this); auto lhss = this->getElementsFromStruct(loc, concreteThis->getLhs(adaptor), - elems, rewriter); + rewriter); auto rhss = this->getElementsFromStruct(loc, concreteThis->getRhs(adaptor), - elems, rewriter); + rewriter); SmallVector resultVals(elems); for (unsigned i = 0; i < elems; ++i) { resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i], @@ -1761,7 +1759,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed( // Potentially we need to store for multiple CTAs in this replication unsigned accumNumReplicates = product(numReplicates); unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape()); - auto vals = getElementsFromStruct(loc, adaptor.src(), elems, rewriter); + auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter); unsigned inVec = 0; unsigned outVec = 0; auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec); @@ -1824,7 +1822,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( unsigned perPhase = dstSharedLayout.getPerPhase(); unsigned maxPhase = dstSharedLayout.getMaxPhase(); unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape); - auto inVals = getElementsFromStruct(loc, adaptor.src(), numElems, rewriter); + auto inVals = getElementsFromStruct(loc, adaptor.src(), rewriter); unsigned srcAccumSizeInThreads = product(srcBlockedLayout.getSizePerThread()); auto elemTy = srcTy.getElementType(); @@ -2661,75 +2659,205 @@ private: DotOp dot; }; -LogicalResult -DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter, - ConversionPatternRewriter &rewriter) const { - Location loc = op->getLoc(); - MLIRContext *ctx = op->getContext(); - // D = A * B + C - Value A = op.a(); - Value B = op.b(); - Value C = op.c(); - Value D = op.getResult(); - bool allowTF32 = op.allowTF32(); +// This class helps to adapt the existing DotOpConversion to the latest +// DotOpOperand layout design. It decouples the exising implementation to two +// parts: +// 1. loading the specific operand matrix(for $a, $b, $c) from smem +// 2. 
passing the loaded value and perform the mma codegen +struct MMA16816ConversionHelper { + Value A, B, C, D; + RankedTensorType aTensorTy, bTensorTy, dTensorTy; + ArrayRef aShape, bShape, dShape; + MmaEncodingAttr mmaLayout; + ArrayRef wpt; - auto aTensorTy = A.getType().cast(); - auto bTensorTy = B.getType().cast(); - auto dTensorTy = D.getType().cast(); + int mmaInstrM{-1}, mmaInstrN{-1}, mmaInstrK{-1}; + int matShapeM{-1}, matShapeN{-1}, matShapeK{-1}; + int numRepM{-1}, numRepN{-1}, numRepK{-1}; + Value thread, lane, warp, warpMN, warpN, warpM; + size_t aElemBytes{}, bElemBytes{}; - auto aShape = aTensorTy.getShape(); - auto bShape = bTensorTy.getShape(); - auto dShape = dTensorTy.getShape(); + DotOpConversionHelper helper; + triton::DotOp op; + DotOpAdaptor adapter; + ConversionPatternRewriter &rewriter; + TypeConverter *typeConverter; + Location loc; + MLIRContext *ctx{}; - auto mmaLayout = dTensorTy.getEncoding().cast(); + using ValueTable = std::map, Value>; - auto wpt = mmaLayout.getWarpsPerCTA(); + MMA16816ConversionHelper(triton::DotOp op, Value thread, DotOpAdaptor adapter, + ConversionPatternRewriter &rewriter, + TypeConverter *typeConverter, Location loc) + : helper(op), op(op), adapter(adapter), rewriter(rewriter), + typeConverter(typeConverter), loc(loc), ctx(op.getContext()), + thread(thread) { + A = op.a(); + B = op.b(); + C = op.c(); + D = op.getResult(); - // TODO(Superjomn) Process C->is_trans_a() logic + aTensorTy = A.getType().cast(); + bTensorTy = B.getType().cast(); + dTensorTy = D.getType().cast(); - DotOpConversionHelper helper(op); + aShape = aTensorTy.getShape(); + bShape = bTensorTy.getShape(); + dShape = dTensorTy.getShape(); - int NK = aShape[1]; + mmaLayout = dTensorTy.getEncoding().cast(); - auto mmaInstrShape = helper.getMmaInstrShape(); - const int mmaInstrM = mmaInstrShape[0]; - const int mmaInstrN = mmaInstrShape[1]; - const int mmaInstrK = mmaInstrShape[2]; + wpt = mmaLayout.getWarpsPerCTA(); - auto matShape = helper.getMmaMatShape(); - const int matShapeM = matShape[0]; - const int matShapeN = matShape[1]; - const int matShapeK = matShape[2]; + auto mmaInstrShape = helper.getMmaInstrShape(); + mmaInstrM = mmaInstrShape[0]; + mmaInstrN = mmaInstrShape[1]; + mmaInstrK = mmaInstrShape[2]; - // shape / shape_per_cta - const int numRepM = std::max(dShape[0] / (wpt[0] * mmaInstrM), 1); - const int numRepN = std::max(dShape[1] / (wpt[1] * mmaInstrN), 1); - const int numRepK = std::max(NK / mmaInstrK, 1); + auto matShape = helper.getMmaMatShape(); + matShapeM = matShape[0]; + matShapeN = matShape[1]; + matShapeK = matShape[2]; - Value _32 = i32_val(32); - Value thread = getThreadId(rewriter, loc); - Value lane = urem(thread, _32); - Value warp = udiv(thread, _32); - Value warpMN = udiv(warp, i32_val(wpt[0])); - Value warpM = urem(warp, i32_val(wpt[0])); - Value warpN = urem(warpMN, i32_val(wpt[1])); + int NK = aShape[1]; + // shape / shape_per_cta + numRepM = std::max(dShape[0] / (wpt[0] * mmaInstrM), 1); + numRepN = std::max(dShape[1] / (wpt[1] * mmaInstrN), 1); + numRepK = std::max(NK / mmaInstrK, 1); - size_t aElemBytes = aTensorTy.getElementTypeBitWidth() / 8; - size_t bElemBytes = bTensorTy.getElementTypeBitWidth() / 8; + Value _32 = i32_val(32); + lane = urem(thread, _32); + warp = udiv(thread, _32); + warpMN = udiv(warp, i32_val(wpt[0])); + warpM = urem(warp, i32_val(wpt[0])); + warpN = urem(warpMN, i32_val(wpt[1])); - std::map, Value> ha; - std::map, Value> hb; + aElemBytes = aTensorTy.getElementTypeBitWidth() / 8; + bElemBytes = 
bTensorTy.getElementTypeBitWidth() / 8; + } - // the original register_lds2, but discard the prefetch logic. - auto ld2 = [](decltype(ha) &vals, int mn, int k, Value val) { - vals[{mn, k}] = val; - }; + // Loading $a from smem to registers, returns a LLVM::Struct. + Value loadA() { + ValueTable ha; + std::function loadFn; + if (aTensorTy.getEncoding().isa()) { + // load from smem + loadFn = getLoadMatrixFn( + A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/, + 1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/, + {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/); + } else if (aTensorTy.getEncoding().isa()) { + // load from registers, used in gemm fuse + // TODO(Superjomn) Port the logic. + assert(false && "Loading A from register is not supported yet."); + } else { + assert(false && "A's layout is not supported."); + } + + // step1. Perform loading. + for (unsigned m = 0; m < numRepM; ++m) + for (unsigned k = 0; k < numRepK; ++k) + loadFn(2 * m, 2 * k); + + // step2. Format the values to LLVM::Struct to passing to mma codegen. + Value result = composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK); + + // TODO[Superjomn]: Replace the convert_layout op with the result once the + // DotOperandEncodingAttr is ready. + return result; + } + + // Loading $b from smem to registers, returns a LLVM::Struct. + Value loadB() { + ValueTable hb; + auto loadFn = getLoadMatrixFn( + B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, + 0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/, + {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/); + + for (unsigned n = 0; n < std::max(numRepN / 2, 1); ++n) { + for (unsigned k = 0; k < numRepK; ++k) + loadFn(2 * n, 2 * k); + } + + Value result = composeValuesToDotOperandLayoutStruct( + hb, std::max(numRepN / 2, 1), numRepK); + return result; + } + + // Loading $c from smem(?) to registers, returns a Value. + // NOTE Only SplatLike tensor is supported now. + Value loadC() { + // Currently, we only support a SplatLike C. For the other cases, e.g., C in + // shared layout or blocked layout, we will support them by expanding + // convert_layout. + auto hc = helper.loadSplatLikeC(C, loc, rewriter); + assert(hc.size() == 4UL && "Only splat-like C is supported now"); + return hc[0]; + } + + // Conduct the Dot conversion. + // Input the \param a, \param b, \param c, all of them are result of loading. 
+ LogicalResult convertDot(Value a, Value b, Value c) { + ValueTable ha = getValuesFromDotOperandLayoutStruct(a, numRepM, numRepK); + ValueTable hb = getValuesFromDotOperandLayoutStruct( + b, std::max(numRepN / 2, 1), numRepK); + + const int fcSize = 4 * numRepM * numRepN; + SmallVector fc(fcSize, c); + + auto callMma = [&](unsigned m, unsigned n, unsigned k) { + unsigned colsPerThread = numRepN * 2; + PTXBuilder builder; + auto &mma = *builder.create(helper.getMmaInstr().str()); + auto retArgs = builder.newListOperand(4, "=r"); + auto aArgs = builder.newListOperand({ + {ha[{m, k}], "r"}, + {ha[{m + 1, k}], "r"}, + {ha[{m, k + 1}], "r"}, + {ha[{m + 1, k + 1}], "r"}, + }); + auto bArgs = + builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}}); + auto cArgs = builder.newListOperand(); + for (int i = 0; i < 4; ++i) { + cArgs->listAppend(builder.newOperand(fc[m * colsPerThread + 4 * n + i], + std::to_string(i))); + // reuse the output registers + } + mma(retArgs, aArgs, bArgs, cArgs); + Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType()); + + auto getIntAttr = [&](int v) { + return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)}); + }; + + for (int i = 0; i < 4; i++) + fc[m * colsPerThread + 4 * n + i] = + extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i)); + }; + + for (unsigned k = 0; k < numRepK; ++k) + for (unsigned m = 0; m < numRepM; ++m) + for (unsigned n = 0; n < numRepN; ++n) + callMma(2 * m, n, 2 * k); + + // replace with new packed result + Type structTy = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(fc.size(), type::f32Ty(ctx))); + Value res = getStructFromElements(loc, fc, rewriter, structTy); + rewriter.replaceOp(op, res); + + return success(); + } + +private: + std::function + getLoadMatrixFn(Value tensor, Value llTensor, int wpt, int kOrder, + ArrayRef instrShape, ArrayRef matShape, + Value warpId, ValueTable &vals) { - // Load A or B matrix. - auto getLoadMatrixFn = - [&](Value tensor, Value llTensor, int wpt, int kOrder, - ArrayRef instrShape, ArrayRef matShape, Value warpId, - decltype(ha) &vals) -> std::function { auto tensorTy = tensor.getType().cast(); // We assumes that the input operand of Dot should be from shared layout. // TODO(Superjomn) Consider other layouts if needed later. @@ -2739,25 +2867,31 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter, const int elemBytes = tensorTy.getElementTypeBitWidth() / 8; auto order = sharedLayout.getOrder(); - MMA16816SmemLoader loader(wpt, sharedLayout.getOrder(), kOrder, - tensorTy.getShape() /*tileShape*/, instrShape, - matShape, perPhase, maxPhase, elemBytes, rewriter, - typeConverter, loc); - SmallVector offs = loader.computeOffsets(warpId, lane); - - const int numPtrs = loader.getNumPtr(); - SmallVector ptrs(numPtrs); - - Type smemPtrTy = helper.getShemPtrTy(); - for (int i = 0; i < numPtrs; ++i) { - ptrs[i] = - bitcast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]}))); - } - bool needTrans = kOrder != order[0]; + // the original register_lds2, but discard the prefetch logic. + auto ld2 = [](ValueTable &vals, int mn, int k, Value val) { + vals[{mn, k}] = val; + }; + // (a, b) is the coordinate. 
- auto load = [=, &vals, &helper, &ld2](int a, int b) { + auto load = [=, &vals, &ld2](int a, int b) { + MMA16816SmemLoader loader(wpt, sharedLayout.getOrder(), kOrder, + tensorTy.getShape() /*tileShape*/, instrShape, + matShape, perPhase, maxPhase, elemBytes, + rewriter, typeConverter, loc); + SmallVector offs = loader.computeOffsets(warpId, lane); + + const int numPtrs = loader.getNumPtr(); + + SmallVector ptrs(numPtrs); + + Type smemPtrTy = helper.getShemPtrTy(); + for (int i = 0; i < numPtrs; ++i) { + ptrs[i] = + bitcast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]}))); + } + auto [ha0, ha1, ha2, ha3] = loader.loadX4( (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs, ptrs, helper.getMatType(), helper.getShemPtrTy()); @@ -2775,89 +2909,74 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter, }; return load; - }; - - std::function loadA; - if (aTensorTy.getEncoding().isa()) { - // load from smem - loadA = getLoadMatrixFn( - A, adapter.a() /*llTensor*/, mmaLayout.getWarpsPerCTA()[0] /*wpt*/, - 1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/, - {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/); - } else if (aTensorTy.getEncoding().isa()) { - // load from registers, used in gemm fuse - // TODO(Superjomn) Port the logic. - assert(false && "Loading A from register is not supported yet."); - } else { - assert(false && "A's layout is not supported."); } - std::function loadB = getLoadMatrixFn( - B, adapter.b() /*llTensor*/, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, - 0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/, - {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/); - - const int fcSize = 4 * numRepM * numRepN; - SmallVector fc(fcSize); - - // Currently, we only support a SplatLike C. For the other cases, e.g., C in - // shared layout or blocked layout, we will support them by expanding - // convert_layout. - auto hc = helper.loadSplatLikeC(C, loc, rewriter); - assert(hc.size() == 4UL && "Only splat-like C is supported now"); - for (int i = 0; i < fc.size(); i++) - fc[i] = hc[0]; - - auto callMma = [&](unsigned m, unsigned n, unsigned k) { - unsigned colsPerThread = numRepN * 2; - PTXBuilder builder; - auto &mma = *builder.create(helper.getMmaInstr().str()); - auto retArgs = builder.newListOperand(4, "=r"); - auto aArgs = builder.newListOperand({ - {ha[{m, k}], "r"}, - {ha[{m + 1, k}], "r"}, - {ha[{m, k + 1}], "r"}, - {ha[{m + 1, k + 1}], "r"}, - }); - auto bArgs = - builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}}); - auto cArgs = builder.newListOperand(); - for (int i = 0; i < 4; ++i) { - cArgs->listAppend(builder.newOperand(fc[m * colsPerThread + 4 * n + i], - std::to_string(i))); - // reuse the output registers - } - mma(retArgs, aArgs, bArgs, cArgs); - Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType()); - - auto getIntAttr = [&](int v) { - return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)}); - }; - - for (int i = 0; i < 4; i++) - fc[m * colsPerThread + 4 * n + i] = - extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i)); - }; - - // Main program - for (unsigned k = 0; k < numRepK; ++k) { - for (unsigned m = 0; m < numRepM; ++m) - loadA(2 * m, 2 * k); - for (unsigned n = 0; n < numRepN; n += 2) - loadB(n, 2 * k); - for (unsigned m = 0; m < numRepM; ++m) - for (unsigned n = 0; n < numRepN; ++n) { - callMma(2 * m, n, 2 * k); + // Compose a map of Values to a LLVM::Struct. 
+ // The layout is a list of Value with coordinate of (i,j), the order is as + // the follows: + // [ + // (0,0), (0,1), (1,0), (1,1), # i=0, j=0 + // (0,2), (0,3), (1,2), (1,3), # i=0, j=1 + // (0,4), (0,5), (1,4), (1,5), # i=0, j=2 + // ... + // (2,0), (2,1), (3,0), (3,1), # i=1, j=0 + // (2,2), (2,3), (3,2), (3,3), # i=1, j=1 + // (2,4), (2,5), (2,4), (2,5), # i=1, j=2 + // ... + // ] + // i \in [0, n0) and j \in [0, n1) + // There should be \param n0 * \param n1 elements in the output Struct. + Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0, + int n1) { + std::vector elems; + for (unsigned m = 0; m < n0; ++m) + for (unsigned k = 0; k < n1; ++k) { + elems.push_back(vals.at({2 * m, 2 * k})); + elems.push_back(vals.at({2 * m, 2 * k + 1})); + elems.push_back(vals.at({2 * m + 1, 2 * k})); + elems.push_back(vals.at({2 * m + 1, 2 * k + 1})); } + + assert(!elems.empty()); + + Type fp16Ty = aTensorTy.getElementType(); + Type fp16x2Ty = vec_ty(fp16Ty, 2); + Type structTy = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(elems.size(), fp16x2Ty)); + auto result = getStructFromElements(loc, elems, rewriter, structTy); + return result; } - // replace with new packed result - Type structTy = LLVM::LLVMStructType::getLiteral( - ctx, SmallVector(fc.size(), type::f32Ty(ctx))); - Value res = getStructFromElements(loc, fc, rewriter, structTy); - rewriter.replaceOp(op, res); + ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1) { + auto elems = ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct( + loc, value, rewriter); - return success(); + int offset{}; + ValueTable vals; + for (int i = 0; i < n0; i++) { + for (int j = 0; j < n1; j++) { + vals[{2 * i, 2 * j}] = elems[offset++]; + vals[{2 * i, 2 * j + 1}] = elems[offset++]; + vals[{2 * i + 1, 2 * j}] = elems[offset++]; + vals[{2 * i + 1, 2 * j + 1}] = elems[offset++]; + } + } + return vals; + } +}; + +LogicalResult +DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + MMA16816ConversionHelper mmaHelper(op, getThreadId(rewriter, loc), adapter, + rewriter, getTypeConverter(), loc); + + auto A = mmaHelper.loadA(); + auto B = mmaHelper.loadB(); + auto C = mmaHelper.loadC(); + + return mmaHelper.convertDot(A, B, C); } /// ====================== mma codegen end ============================ @@ -3012,9 +3131,9 @@ struct InsertSliceAsyncOpConversion auto inOrder = srcBlockedLayout.getOrder(); auto outOrder = resSharedLayout.getOrder(); - // If perPhase * maxPhase > threadsPerCTA, we need to swizzle over elements - // across phases. - // If perPhase * maxPhase == threadsPerCTA, swizzle is not allowd + // If perPhase * maxPhase > threadsPerCTA, we need to swizzle over + // elements across phases. If perPhase * maxPhase == threadsPerCTA, + // swizzle is not allowd auto numSwizzleRows = std::max( (perPhase * maxPhase) / threadsPerCTA[inOrder[1]], 1); // A sharedLayout encoding has a "vec" parameter.