[Triton-MLIR][BACKEND] make MMAv1 splitk works (#960)

This commit is contained in:
Yan Chunwei
2022-12-07 16:58:38 +08:00
committed by GitHub
parent b2b793dfb5
commit 4eab9dcedf
6 changed files with 58 additions and 24 deletions

View File

@@ -72,6 +72,11 @@ struct DotOpMmaV1ConversionHelper {
bool isARow = order[0] != 0;
bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
// TODO[Superjomn]: Support the case when isAVec4=false later
// Currently, we only support ld.v2, for the mma layout varies with
// different ld vector width.
isAVec4 = true;
int packSize0 = (isARow || isAVec4) ? 1 : 2;
SmallVector<int> fpw({2, 2, 1});
@@ -98,6 +103,11 @@ struct DotOpMmaV1ConversionHelper {
auto order = getOrder();
bool isBRow = order[0] != 0;
bool isBVec4 = isBRow && shape[order[0]] <= 16;
// TODO[Superjomn]: Support the case when isBVec4=false later
// Currently, we only support ld.v2, for the mma layout varies with
// different ld vector width.
isBVec4 = true;
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
SmallVector<int> fpw({2, 2, 1});
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
@@ -1455,7 +1465,6 @@ Value DotOpMmaV1ConversionHelper::loadB(
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
bool isBRow = order[0] != 0;
bool isBVec4 = isBRow && shape[order[0]] <= 16;