diff --git a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
index a50dbb66d..d08a83918 100644
--- a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
+++ b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
@@ -973,8 +973,9 @@ struct MMA16816ConversionHelper {
     if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
       Value warpM = getWarpM(shape[0]);
       // load from smem
+      int wpt = std::min<int>(mmaLayout.getWarpsPerCTA()[0], shape[0] / matShapeM);
       loadFn = getLoadMatrixFn(
-          tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
+          tensor, smemObj, mmaLayout, wpt /*wpt*/,
           1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
           {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/,
           true /*isA*/);
@@ -1016,8 +1017,9 @@ struct MMA16816ConversionHelper {
     int numRepN = getNumRepN(tensorTy, shape[1]);
     Value warpN = getWarpN(shape[1]);
 
+    int wpt = std::min<int>(mmaLayout.getWarpsPerCTA()[1], shape[1] / matShapeN);
     auto loadFn = getLoadMatrixFn(
-        tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
+        tensor, smemObj, mmaLayout, wpt /*wpt*/,
         0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
         {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/,
         false /*isA*/);
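
The patch caps wpt at the number of mat tiles that actually fit along the M (resp. N) dimension, so warps beyond that count no longer index past the tile grid when loading from shared memory. Below is a minimal standalone sketch of that clamp, using a hypothetical effectiveWpt helper and illustrative numbers; it is not Triton's actual API, only a demonstration of the std::min logic the patch adds.

// Illustrative sketch only: mirrors the clamp introduced in the patch above.
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Hypothetical helper: effective warps-per-tile along one dimension.
//   warpsPerCTA - warps assigned to this dimension by the MMA layout
//   dimSize     - tensor extent along this dimension
//   matShape    - extent of one mat tile along this dimension
static int effectiveWpt(unsigned warpsPerCTA, int64_t dimSize, int matShape) {
  // Cap at the number of mat tiles that fit, so surplus warps do not
  // address tiles that do not exist.
  return std::min<int>(warpsPerCTA, dimSize / matShape);
}

int main() {
  // Illustrative numbers: 8 warps along a dimension of size 16 with a mat
  // tile of size 8 leaves only 2 tiles, so the effective wpt is clamped to 2.
  std::printf("wpt = %d\n", effectiveWpt(8, 16, 8)); // prints "wpt = 2"
  return 0;
}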