From 66c36c4378621a0997422b5f0d766869dd97fe4d Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 29 Nov 2022 17:56:45 +0100 Subject: [PATCH] [BACKEND] Fixed bounds-wrapping issues (#926) This fixes an issue that led to out-of-bounds shared memory accesses on small matrices --- lib/Conversion/TritonGPUToLLVM/DotHelpers.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h index a50dbb66d..d08a83918 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h +++ b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h @@ -973,8 +973,9 @@ struct MMA16816ConversionHelper { if (aTensorTy.getEncoding().isa()) { Value warpM = getWarpM(shape[0]); // load from smem + int wpt = std::min(mmaLayout.getWarpsPerCTA()[0], shape[0] / matShapeM); loadFn = getLoadMatrixFn( - tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/, + tensor, smemObj, mmaLayout, wpt /*wpt*/, 1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/, {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/, true /*isA*/); @@ -1016,8 +1017,9 @@ struct MMA16816ConversionHelper { int numRepN = getNumRepN(tensorTy, shape[1]); Value warpN = getWarpN(shape[1]); + int wpt = std::min(mmaLayout.getWarpsPerCTA()[1], shape[1] / matShapeN); auto loadFn = getLoadMatrixFn( - tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, + tensor, smemObj, mmaLayout, wpt /*wpt*/, 0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/, {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/, false /*isA*/);