[Triton-MLIR][BACKEND] make MMAv1 splitk works (#960)

This commit is contained in:
Yan Chunwei
2022-12-07 16:58:38 +08:00
committed by GitHub
parent b2b793dfb5
commit 4eab9dcedf
6 changed files with 58 additions and 24 deletions

View File

@@ -72,6 +72,11 @@ struct DotOpMmaV1ConversionHelper {
bool isARow = order[0] != 0;
bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
// TODO[Superjomn]: Support the case when isAVec4=false later
// Currently, we only support ld.v2, for the mma layout varies with
// different ld vector width.
isAVec4 = true;
int packSize0 = (isARow || isAVec4) ? 1 : 2;
SmallVector<int> fpw({2, 2, 1});
@@ -98,6 +103,11 @@ struct DotOpMmaV1ConversionHelper {
auto order = getOrder();
bool isBRow = order[0] != 0;
bool isBVec4 = isBRow && shape[order[0]] <= 16;
// TODO[Superjomn]: Support the case when isBVec4=false later
// Currently, we only support ld.v2, for the mma layout varies with
// different ld vector width.
isBVec4 = true;
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
SmallVector<int> fpw({2, 2, 1});
SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0
@@ -1455,7 +1465,6 @@ Value DotOpMmaV1ConversionHelper::loadB(
SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());
Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
bool isBRow = order[0] != 0;
bool isBVec4 = isBRow && shape[order[0]] <= 16;

View File

@@ -3555,14 +3555,18 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
auto hbs = helper.extractLoadedOperand(loadedB, NK, rewriter);
// initialize accumulators
// Initialize accumulators with external values, the acc holds the accumulator
// value that is shared between the MMA instructions inside a DotOp, we can
// call the order of the values the accumulator-internal order.
SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
size_t resSize = acc.size();
// The resVals holds the final result of the DotOp.
// NOTE The current order of resVals is different from acc, we call it the
// accumulator-external order. and
SmallVector<Value> resVals(resSize);
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
auto ha = has.at({m, k});
auto hb = hbs.at({n, k});
auto getIdx = [&](int m, int n) {
std::vector<size_t> idx{{
(m * 2 + 0) + (n * 4 + 0) * numM, // row0
(m * 2 + 0) + (n * 4 + 1) * numM,
@@ -3573,8 +3577,29 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
(m * 2 + 1) + (n * 4 + 2) * numM, // row3
(m * 2 + 1) + (n * 4 + 3) * numM,
}};
return idx;
};
{ // convert the acc's value from accumuator-external order to
// accumulator-internal order.
SmallVector<Value> accInit(acc.size());
for (unsigned m = 0; m < numM / 2; ++m)
for (unsigned n = 0; n < numN / 2; ++n) {
auto idx = getIdx(m, n);
for (unsigned i = 0; i < 8; ++i)
accInit[idx[i]] = acc[(m * numN / 2 + n) * 8 + i];
}
acc = accInit;
}
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
auto ha = has.at({m, k});
auto hb = hbs.at({n, k});
PTXBuilder builder;
auto idx = getIdx(m, n);
auto *resOprs = builder.newListOperand(8, "=f");
auto *AOprs = builder.newListOperand({
@@ -3606,8 +3631,6 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
for (unsigned i = 0; i < 8; i++) {
Value elem = extract_val(f32_ty, res, getIntAttr(i));
acc[idx[i]] = elem;
// TODO[goostavz]: double confirm this when m/n/k = [32, 32, x] has been
// verified before MMA
resVals[(m * numN / 2 + n) * 8 + i] = elem;
}
};