[Triton-MLIR][Backend] Fix the definition of MmaEncodingAttr v1, and the output sequence of DotConversion in MMAv1 (#941)

This commit is contained in:
goostavz
2022-12-03 21:12:48 +08:00
committed by GitHub
parent 521ff9ad74
commit 4d64589b22
2 changed files with 47 additions and 40 deletions

View File

@@ -2876,20 +2876,21 @@ private:
} else if (mmaLayout.getVersion() == 1) {
multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 16));
Value partId = udiv(laneId, _4);
Value partIdDiv4 = udiv(partId, _4);
Value partIdRem4 = urem(partId, _4);
Value partRowOffset = mul(udiv(partIdRem4, _2), _8);
partRowOffset = add(mul(partIdDiv4, _4), partRowOffset);
Value partColOffset = mul(urem(partIdRem4, _2), _8);
Value colOffset = add(mul(multiDimWarpId[0], _16), partColOffset);
Value rowOffset = add(mul(multiDimWarpId[1], _16), partRowOffset);
mmaRowIdx[0] = add(urem(laneId, _2), rowOffset);
Value laneIdDiv16 = udiv(laneId, _16);
Value laneIdRem16 = urem(laneId, _16);
Value laneIdRem2 = urem(laneId, _2);
Value laneIdRem16Div8 = udiv(laneIdRem16, _8);
Value laneIdRem16Div4 = udiv(laneIdRem16, _4);
Value laneIdRem16Div4Rem2 = urem(laneIdRem16Div4, _2);
Value laneIdRem4Div2 = udiv(urem(laneId, _4), _2);
mmaRowIdx[0] =
add(add(mul(laneIdDiv16, _8), mul(laneIdRem16Div4Rem2, _4)),
laneIdRem2);
mmaRowIdx[1] = add(mmaRowIdx[0], _2);
mmaColIdx[0] = add(mul(udiv(urem(laneId, _4), _2), _2), colOffset);
mmaColIdx[0] = add(mul(laneIdRem16Div8, _4), mul(laneIdRem4Div2, _2));
mmaColIdx[1] = add(mmaColIdx[0], _1);
mmaColIdx[2] = add(mmaColIdx[0], _4);
mmaColIdx[3] = add(mmaColIdx[0], idx_val(5));
mmaColIdx[2] = add(mmaColIdx[0], _8);
mmaColIdx[3] = add(mmaColIdx[0], idx_val(9));
} else {
llvm_unreachable("Unexpected MMALayout version");
}
@@ -3543,6 +3544,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
// initialize accumulators
SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
size_t resSize = acc.size();
SmallVector<Value> resVals(resSize);
auto callMMA = [&](unsigned m, unsigned n, unsigned k) {
auto ha = has[{m, k}];
@@ -3586,8 +3589,14 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
auto getIntAttr = [&](int v) {
return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
};
for (unsigned i = 0; i < 8; i++)
acc[idx[i]] = extract_val(f32_ty, res, getIntAttr(i));
for (unsigned i = 0; i < 8; i++) {
Value elem = extract_val(f32_ty, res, getIntAttr(i));
acc[idx[i]] = elem;
// TODO[goostavz]: double confirm this when m/n/k = [32, 32, x] has been
// verified before MMA
resVals[(m * numN / 2 + n) * 8 + i] = elem;
}
};
for (unsigned k = 0; k < NK; k += 4)
@@ -3596,12 +3605,10 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
callMMA(m, n, k);
}
// replace with new packed result
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(acc.size(), type::f32Ty(ctx)));
Value res = getStructFromElements(loc, acc, rewriter, structTy);
ctx, SmallVector<Type>(resSize, type::f32Ty(ctx)));
Value res = getStructFromElements(loc, resVals, rewriter, structTy);
rewriter.replaceOp(op, res);
return success();
}