[Triton-MLIR][BACKEND] make MMAv1 splitk works (#960)
This commit is contained in:
@@ -72,6 +72,11 @@ struct DotOpMmaV1ConversionHelper {
|
||||
|
||||
bool isARow = order[0] != 0;
|
||||
bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
|
||||
// TODO[Superjomn]: Support the case when isAVec4=false later
|
||||
// Currently, we only support ld.v2, for the mma layout varies with
|
||||
// different ld vector width.
|
||||
isAVec4 = true;
|
||||
|
||||
int packSize0 = (isARow || isAVec4) ? 1 : 2;
|
||||
|
||||
SmallVector<int> fpw({2, 2, 1});
|
||||
@@ -98,6 +103,11 @@ struct DotOpMmaV1ConversionHelper {
|
||||
auto order = getOrder();
|
||||
bool isBRow = order[0] != 0;
|
||||
bool isBVec4 = isBRow && shape[order[0]] <= 16;
|
||||
// TODO[Superjomn]: Support the case when isBVec4=false later
|
||||
// Currently, we only support ld.v2, for the mma layout varies with
|
||||
// different ld vector width.
|
||||
isBVec4 = true;
|
||||
|
||||
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
|
||||
SmallVector<int> fpw({2, 2, 1});
|
||||
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
|
||||
@@ -1455,7 +1465,6 @@ Value DotOpMmaV1ConversionHelper::loadB(
|
||||
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
|
||||
sharedLayout.getOrder().end());
|
||||
|
||||
|
||||
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
||||
bool isBRow = order[0] != 0;
|
||||
bool isBVec4 = isBRow && shape[order[0]] <= 16;
|
||||
|
Reference in New Issue
Block a user