[BACKEND] Fixed bounds-wrapping issues (#926)

This fixes an issue that led to out-of-bounds shared-memory accesses on
small matrices. The warps-per-tile (wpt) value passed to getLoadMatrixFn is
now clamped with std::min so it never exceeds the number of matrix tiles
that actually fit along the corresponding dimension (shape / matShape).
Author: Philippe Tillet
Date: 2022-11-29 17:56:45 +01:00
Committed by: GitHub
Parent: 661be523c0
Commit: 66c36c4378


@@ -973,8 +973,9 @@ struct MMA16816ConversionHelper {
     if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
       Value warpM = getWarpM(shape[0]);
       // load from smem
+      int wpt = std::min<int>(mmaLayout.getWarpsPerCTA()[0], shape[0] / matShapeM);
       loadFn = getLoadMatrixFn(
-          tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
+          tensor, smemObj, mmaLayout, wpt /*wpt*/,
           1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
           {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/,
           true /*isA*/);
@@ -1016,8 +1017,9 @@ struct MMA16816ConversionHelper {
     int numRepN = getNumRepN(tensorTy, shape[1]);
     Value warpN = getWarpN(shape[1]);
+    int wpt = std::min<int>(mmaLayout.getWarpsPerCTA()[1], shape[1] / matShapeN);
     auto loadFn = getLoadMatrixFn(
-        tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
+        tensor, smemObj, mmaLayout, wpt /*wpt*/,
         0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
         {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/,
         false /*isA*/);
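
For context, below is a minimal standalone sketch (not Triton code, not part of the commit) of the clamp arithmetic used in both hunks. The concrete numbers (4 warps along M, a 16-row operand, 16-row matrix fragments) are hypothetical; only the std::min expression mirrors the diff, and the interpretation of wpt as a per-warp row offset multiplier is an assumption based on the commit message.

// Hypothetical illustration of the wpt clamp; values chosen for demonstration.
#include <algorithm>
#include <cstdio>

int main() {
  int warpsPerCTA_M = 4; // warps the layout assigns along M (assumed value)
  int shapeM = 16;       // M extent of a small operand tile (assumed value)
  int matShapeM = 16;    // M extent of one matrix fragment (assumed value)

  // Before the fix: wpt taken directly from warpsPerCTA[0].
  int wptOld = warpsPerCTA_M;                               // 4
  // After the fix: never more warp tiles than fit in the tensor.
  int wptNew = std::min(warpsPerCTA_M, shapeM / matShapeM); // 1

  printf("old wpt = %d, new wpt = %d\n", wptOld, wptNew);
  // With wptOld = 4, warps 1-3 would be placed at row offsets 16, 32, 48
  // of a 16-row shared-memory tile, i.e. past its end; with wptNew = 1
  // every warp stays within the tile.
  return 0;
}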