diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index 62c9e1dae..f3110988a 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -88,9 +88,7 @@ jobs:
     - name: Run python tests on V100
       if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'V100'}}
       run: |
-        # TODO[Superjomn]: Remove the forloop-unroll setting after pipeline pass works
        cd python/tests
-        export TRITON_STATIC_LOOP_UNROLLING=1
        pytest test_gemm.py::test_gemm_for_mmav1

    - name: Run CXX unittests
diff --git a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
index 96a9b764c..8703bebcb 100644
--- a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
+++ b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
@@ -43,12 +43,51 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
 struct DotOpMmaV1ConversionHelper {
   MmaEncodingAttr mmaLayout;
   ArrayRef<unsigned> wpt;
+  static constexpr std::array<int, 3> fpw{{2, 2, 1}};

   using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;

   explicit DotOpMmaV1ConversionHelper(MmaEncodingAttr mmaLayout)
       : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}

+  // Helper to share some variables across multiple functions for A.
+  struct AParam {
+    SmallVector<int> rep;
+    SmallVector<int> spw;
+
+    // TODO[Superjomn]: Support the case when isAVec4=false later
+    // Currently, we only support ld.v2, for the mma layout varies with
+    // different ld vector width.
+    // bool isAVec4 = !isARow && shapeTransed[orderTransed[0]] <= 16;
+    const bool isAVec4{true};
+
+    explicit AParam(bool isARow) {
+      int packSize0 = (isARow || isAVec4) ? 1 : 2;
+      int repM = 2 * packSize0;
+      int repK = 1;
+      int spwM = fpw[0] * 4 * repM;
+      rep.assign({repM, 0, repK});
+      spw.assign({spwM, 0, 1});
+    }
+  };
+
+  // Helper to share some variables across multiple functions for B.
+  struct BParam {
+    SmallVector<int> rep;
+    SmallVector<int> spw;
+    // TODO[Superjomn]: Support the case when isBVec4=false later
+    // Currently, we only support ld.v2, for the mma layout varies with
+    // different ld vector width.
+    // bool isBVec4 = isBRow && shapeTransed[orderTransed[0]] <= 16;
+    const bool isBVec4{true};
+
+    explicit BParam(bool isBRow) {
+      int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
+      rep.assign({0, 2 * packSize1, 1});
+      spw.assign({0, fpw[1] * 4 * rep[1], 1});
+    }
+  };
+
   int getRepM(int M) const {
     return std::max<int>(M / (wpt[0] * instrShape[0]), 1);
   }
@@ -65,29 +104,34 @@ struct DotOpMmaV1ConversionHelper {
     return struct_ty(SmallVector<Type>{8, fp32Ty});
   }

-  // number of fp16x2 elements for $a.
-  int numElemsPerThreadA(RankedTensorType tensorTy) const {
-    auto shape = tensorTy.getShape();
-    auto order = getOrder();
+  // Get the number of fp16x2 elements for $a.
+  // \param shapeTransed: the shape or reordered shape if transpose needed.
+  // \param orderTransed: the order or reordered order if transpose needed.
+  unsigned getNumM(ArrayRef<int64_t> shapeTransed,
+                   ArrayRef<unsigned> orderTransed) const {
+    bool isARow = orderTransed[0] != 0;
+    AParam param(isARow);

-    bool isARow = order[0] != 0;
-    bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
-    // TODO[Superjomn]: Support the case when isAVec4=false later
-    // Currently, we only support ld.v2, for the mma layout varies with
-    // different ld vector width.
-    isAVec4 = true;
+    unsigned numM = param.rep[0] * shapeTransed[0] / (param.spw[0] * wpt[0]);
+    return numM;
+  }

-    int packSize0 = (isARow || isAVec4) ? 1 : 2;
+  // Get the number of fp16x2 elements for $b.
+  // \param shapeTransed: the shape or reordered shape if transpose needed.
+  // \param orderTransed: the order or reordered order if transpose needed.
+  unsigned getNumN(ArrayRef<int64_t> shapeTransed,
+                   ArrayRef<unsigned> orderTransed) const {
+    bool isBRow = orderTransed[0] != 0;
+    BParam param(isBRow);

-    SmallVector<int> fpw({2, 2, 1});
-    int repM = 2 * packSize0;
-    int repK = 1;
-    int spwM = fpw[0] * 4 * repM;
-    SmallVector<int> rep({repM, 0, repK}); // pad N with 0
-    SmallVector<int> spw({spwM, 0, 1});    // pad N with 0
+    unsigned numN = param.rep[1] * shapeTransed[1] / (param.spw[1] * wpt[1]);
+    return numN;
+  }

-    int NK = shape[1];
-    unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]);
+  int numElemsPerThreadA(ArrayRef<int64_t> shapeTransed,
+                         ArrayRef<unsigned> orderTransed) const {
+    int numM = getNumM(shapeTransed, orderTransed);
+    int NK = shapeTransed[1];

     // NOTE: We couldn't get the vec from the shared layout.
     // int vecA = sharedLayout.getVec();
@@ -97,39 +141,27 @@ struct DotOpMmaV1ConversionHelper {
     return (numM / 2) * (NK / 4) * elemsPerLd;
   }

-  // number of fp16x2 elements for $b.
-  int numElemsPerThreadB(RankedTensorType tensorTy) const {
-    auto shape = tensorTy.getShape();
-    auto order = getOrder();
-    bool isBRow = order[0] != 0;
-    bool isBVec4 = isBRow && shape[order[0]] <= 16;
-    // TODO[Superjomn]: Support the case when isBVec4=false later
-    // Currently, we only support ld.v2, for the mma layout varies with
-    // different ld vector width.
-    isBVec4 = true;
-
-    int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
-    SmallVector<int> fpw({2, 2, 1});
-    SmallVector<int> rep({0, 2 * packSize1, 1});       // pad M with 0
-    SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
+  int numElemsPerThreadB(ArrayRef<int64_t> shapeTransed,
+                         ArrayRef<unsigned> orderTransed) const {
+    unsigned numN = getNumN(shapeTransed, orderTransed);
+    int NK = shapeTransed[0];

     // NOTE: We couldn't get the vec from the shared layout.
     // int vecB = sharedLayout.getVec();
     // TODO[Superjomn]: Consider the case when vecA > 4
     bool vecGt4 = false;
     int elemsPerLd = vecGt4 ? 4 : 2;
-    int NK = shape[0];
-
-    unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
     return (numN / 2) * (NK / 4) * elemsPerLd;
   }

   // Loading $a from smem to registers, returns a LLVM::Struct.
-  Value loadA(Value A, const SharedMemoryObject &smemObj, Value thread,
-              Location loc, ConversionPatternRewriter &rewriter) const;
+  Value loadA(Value A, bool transA, const SharedMemoryObject &smemObj,
+              Value thread, Location loc,
+              ConversionPatternRewriter &rewriter) const;

   // Loading $b from smem to registers, returns a LLVM::Struct.
-  Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
-              Location loc, ConversionPatternRewriter &rewriter) const;
+  Value loadB(Value B, bool transB, const SharedMemoryObject &smemObj,
+              Value thread, Location loc,
+              ConversionPatternRewriter &rewriter) const;

   static ArrayRef<unsigned> getOrder() { return mmaOrder; }
@@ -1321,8 +1353,8 @@ struct DotOpFMAConversionHelper {
 };

 Value DotOpMmaV1ConversionHelper::loadA(
-    Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
-    ConversionPatternRewriter &rewriter) const {
+    Value tensor, bool transA, const SharedMemoryObject &smemObj, Value thread,
+    Location loc, ConversionPatternRewriter &rewriter) const {
   auto *ctx = rewriter.getContext();
   auto tensorTy = tensor.getType().cast<RankedTensorType>();

@@ -1336,24 +1368,11 @@ Value DotOpMmaV1ConversionHelper::loadA(
   Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);

   bool isARow = order[0] != 0;
-  bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
-  // TODO[Superjomn]: Support the case when isAVec4=false later
-  // Currently, we only support ld.v2, for the mma layout varies with different
-  // ld vector width.
-  isAVec4 = true;
-  int packSize0 = (isARow || isAVec4) ? 1 : 2;
+  AParam param(isARow);

-  SmallVector<int> fpw({2, 2, 1});
-  int repM = 2 * packSize0;
-  int repK = 1;
-  int spwM = fpw[0] * 4 * repM;
-  SmallVector<int> rep({repM, 0, repK}); // pad N with 0
-  SmallVector<int> spw({spwM, 0, 1});    // pad N with 0
+  auto [offsetAM, offsetAK, _0, _1] = computeOffsets(
+      thread, isARow, false, fpw, param.spw, param.rep, rewriter, loc);

-  auto [offsetAM, offsetAK, _0, _1] =
-      computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc);
-  // TODO [Superjomn]: transA cannot be accessed in ConvertLayoutOp.
-  bool transA = false;
   if (transA) {
     std::swap(shape[0], shape[1]);
     std::swap(offsetAM, offsetAK);
@@ -1401,8 +1420,6 @@ Value DotOpMmaV1ConversionHelper::loadA(
   for (int i = 0; i < numPtrA; i++)
     ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]);

-  unsigned numM = std::max<int>(rep[0] * shape[0] / (spw[0] * wpt[0]), 1);
-
   Type f16PtrTy = ptr_ty(f16_ty);

   auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
@@ -1434,6 +1451,7 @@ Value DotOpMmaV1ConversionHelper::loadA(
     }
   };

+  unsigned numM = getNumM(shape, order);
   for (unsigned k = 0; k < NK; k += 4)
     for (unsigned m = 0; m < numM / 2; ++m)
       loadA(m, k);
@@ -1451,8 +1469,8 @@ Value DotOpMmaV1ConversionHelper::loadA(
 }

 Value DotOpMmaV1ConversionHelper::loadB(
-    Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
-    ConversionPatternRewriter &rewriter) const {
+    Value tensor, bool transB, const SharedMemoryObject &smemObj, Value thread,
+    Location loc, ConversionPatternRewriter &rewriter) const {
   // smem
   auto strides = smemObj.strides;

@@ -1467,17 +1485,9 @@ Value DotOpMmaV1ConversionHelper::loadB(
   Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);

   bool isBRow = order[0] != 0;
-  bool isBVec4 = isBRow && shape[order[0]] <= 16;
-  // TODO[Superjomn]: Support the case when isBVec4=false later
-  // Currently, we only support ld.v2, for the mma layout varies with different
-  // ld vector width.
-  isBVec4 = true;
-  int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
-  SmallVector<int> fpw({2, 2, 1});
-  SmallVector<int> rep({0, 2 * packSize1, 1});       // pad M with 0
-  SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
-  int vecB = sharedLayout.getVec();
+  BParam param(isBRow);

+  int vecB = sharedLayout.getVec();
   Value strideBN = isBRow ? i32_val(1) : strides[1];
   Value strideBK = isBRow ? strides[0] : i32_val(1);
   Value strideB0 = isBRow ? strideBN : strideBK;
@@ -1485,11 +1495,8 @@
   int strideRepN = wpt[1] * fpw[1] * 8;
   int strideRepK = 1;

-  // TODO [Superjomn]: transB cannot be accessed in ConvertLayoutOp.
-  bool transB = false;
-
-  auto [_0, _1, offsetBN, offsetBK] =
-      computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc);
+  auto [_0, _1, offsetBN, offsetBK] = computeOffsets(
+      thread, false, isBRow, fpw, param.spw, param.rep, rewriter, loc);
   if (transB) {
     std::swap(order[0], order[1]);
     std::swap(shape[0], shape[1]);
@@ -1556,7 +1563,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
     }
   };

-  unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
+  unsigned numN = getNumN(shape, order);
   for (unsigned k = 0; k < NK; k += 4)
     for (unsigned n = 0; n < numN / 2; ++n) {
       if (!hbs.count({n, k}))
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 534ea9b01..91485c3e5 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -1730,9 +1730,9 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern {
     auto rhsVals = getElementsFromStruct(loc, adaptor.rhs(), rewriter);
     // concatenate (and potentially reorder) values
     SmallVector<Value> retVals;
-    for(Value v: lhsVals)
+    for (Value v : lhsVals)
       retVals.push_back(v);
-    for(Value v: rhsVals)
+    for (Value v : rhsVals)
       retVals.push_back(v);
     // pack and replace
     Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
@@ -3408,14 +3408,16 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
   } else if (!isOuter && mmaLayout.getVersion() == 1 && isHMMA) {
     // tensor core v1
     DotOpMmaV1ConversionHelper helper(mmaLayout);
-    if (dotOperandLayout.getOpIdx() == 0) {
-      // operand $a
-      res =
-          helper.loadA(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
-    } else if (dotOperandLayout.getOpIdx() == 1) {
-      // operand $b
-      res =
-          helper.loadB(src, smemObj, getThreadId(rewriter, loc), loc, rewriter);
+    if (dotOperandLayout.getOpIdx() == 0) { // operand $a
+      // TODO[Superjomn]: transA is not available here.
+      bool transA = false;
+      res = helper.loadA(src, transA, smemObj, getThreadId(rewriter, loc), loc,
+                         rewriter);
+    } else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
+      // TODO[Superjomn]: transB is not available here.
+      bool transB = false;
+      res = helper.loadB(src, transB, smemObj, getThreadId(rewriter, loc), loc,
+                         rewriter);
     }
   } else {
     assert(false && "Unsupported mma layout found");
@@ -3537,6 +3539,10 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   bool isBRow = BOrder[0] != 0;
   bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
   bool isBVec4 = isBRow && BShape[isBRow] <= 16;
+  // TODO[Superjomn]: ld.v4 is not supported.
+  isAVec4 = true;
+  isBVec4 = true;
+
   int packSize0 = (isARow || isAVec4) ? 1 : 2;
   int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
   SmallVector<int> fpw({2, 2, 1});
@@ -3549,7 +3555,7 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   DotOpMmaV1ConversionHelper helper(mmaLayout);

   unsigned numM = rep[0] * DShape[0] / (spw[0] * wpt[0]);
-  unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[0]);
+  unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[1]);
   unsigned NK = AShape[1];

   auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
@@ -3836,7 +3842,8 @@ public:
   llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
     auto ctx = type.getContext();
     Attribute layout = type.getEncoding();
-    auto shape = type.getShape();
+    SmallVector<int64_t> shape(type.getShape().begin(), type.getShape().end());
+
    if (layout &&
        (layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
         layout.isa<MmaEncodingAttr>())) {
@@ -3899,13 +3906,22 @@ public:

      if (mmaLayout.getVersion() == 1) {
        DotOpMmaV1ConversionHelper helper(mmaLayout);
+        // TODO[Superjomn]: Neither transA nor transB is available here.
+        bool trans = false;
+        // TODO[Superjomn]: The orders of A and B are not available here.
+        SmallVector<unsigned> order({1, 0});
+        if (trans) {
+          std::swap(shape[0], shape[1]);
+          std::swap(order[0], order[1]);
+        }
+
        if (dotOpLayout.getOpIdx() == 0) { // $a
-          int elems = helper.numElemsPerThreadA(type);
+          int elems = helper.numElemsPerThreadA(shape, order);
          Type x2Ty = vec_ty(elemTy, 2);
          return struct_ty(SmallVector<Type>(elems, x2Ty));
        }
        if (dotOpLayout.getOpIdx() == 1) { // $b
-          int elems = helper.numElemsPerThreadB(type);
+          int elems = helper.numElemsPerThreadB(shape, order);
          Type x2Ty = vec_ty(elemTy, 2);
          return struct_ty(SmallVector<Type>(elems, x2Ty));
        }
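
Note (not part of the patch): the arithmetic that the new AParam/BParam helpers and getNumM/getNumN centralize can be exercised in isolation. The sketch below is illustrative only; it drops the MLIR types, replaces the helper's ArrayRef shape/order parameters and the wpt member with plain scalars, and uses a made-up 128x128 tile with a 2x2 warp grid.

```cpp
// Standalone sketch of the mma v1 tile-count arithmetic from this patch.
// All names below mirror the patch but are reimplemented without MLIR.
#include <array>
#include <cstdio>
#include <vector>

namespace sketch {

// Mirrors the patch's fpw constant {2, 2, 1}.
constexpr std::array<int, 3> fpw{{2, 2, 1}};

struct AParam {
  std::vector<int> rep, spw;
  explicit AParam(bool isARow) {
    const bool isAVec4 = true; // only ld.v2 is supported, so this stays true
    int packSize0 = (isARow || isAVec4) ? 1 : 2;
    int repM = 2 * packSize0;
    int spwM = fpw[0] * 4 * repM;
    rep = {repM, 0, 1}; // pad N with 0
    spw = {spwM, 0, 1};
  }
};

struct BParam {
  std::vector<int> rep, spw;
  explicit BParam(bool isBRow) {
    const bool isBVec4 = true;
    int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
    rep = {0, 2 * packSize1, 1}; // pad M with 0
    spw = {0, fpw[1] * 4 * rep[1], 1};
  }
};

// numM / numN as computed by getNumM / getNumN: repetitions times the tensor
// extent, divided by the per-warp tile extent (spw) scaled by warps per CTA.
unsigned getNumM(long shapeM, bool isARow, unsigned wpt0) {
  AParam p(isARow);
  return p.rep[0] * shapeM / (p.spw[0] * wpt0);
}

unsigned getNumN(long shapeN, bool isBRow, unsigned wpt1) {
  BParam p(isBRow);
  return p.rep[1] * shapeN / (p.spw[1] * wpt1);
}

} // namespace sketch

int main() {
  // Hypothetical 128x128 tile with a 2x2 warp grid.
  unsigned numM = sketch::getNumM(/*shapeM=*/128, /*isARow=*/true, /*wpt0=*/2);
  unsigned numN = sketch::getNumN(/*shapeN=*/128, /*isBRow=*/false, /*wpt1=*/2);
  std::printf("numM = %u, numN = %u\n", numM, numN); // prints: numM = 8, numN = 8
  return 0;
}
```

With isAVec4/isBVec4 pinned to true (only ld.v2 is emitted), rep and spw are fixed per operand, which is what lets loadA/loadB and the element-count queries share AParam/BParam instead of recomputing packSize/fpw/rep/spw in each function.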