[Triton-MLIR][Backend]fix mma-v2 transpose error (#888)

Yan Chunwei
2022-11-20 11:29:09 +08:00
committed by GitHub
parent 8a5647782d
commit e8994209f4

@@ -3218,8 +3218,8 @@ public:
cMatShape = matShape[order[0]];
sMatShape = matShape[order[1]];
-cTileStride = smemStrides[order[0]];
-sTileStride = smemStrides[order[1]];
+cStride = smemStrides[1];
+sStride = smemStrides[0];
// rule: k must be the fast-changing axis.
needTrans = kOrder != order[0];
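
Note on the hunk above: the stride used for the s dimension changes from the order-permuted pick smemStrides[order[1]] to the fixed smemStrides[0] (and likewise for the c dimension). The two picks agree when order = {1, 0}, i.e. when k is already the fast-changing axis, and diverge only for the transposed operand (order = {0, 1}), which is the mma-v2 case this commit fixes. A minimal standalone sketch of that arithmetic, assuming illustrative row-major element strides; the names below are local stand-ins, not Triton code:

    #include <array>
    #include <cstdio>

    int main() {
      // Assumed tile shape and element strides, purely for illustration.
      std::array<int, 2> tileShape = {16, 64};
      std::array<int, 2> smemStrides = {tileShape[1], 1}; // row-major: {64, 1}

      const std::array<int, 2> orders[] = {{1, 0},  // k fast-changing (no transpose)
                                           {0, 1}}; // transposed operand
      for (auto order : orders) {
        int oldPick = smemStrides[order[1]]; // previous sTileStride
        int newPick = smemStrides[0];        // sStride after this commit
        std::printf("order={%d,%d}: old=%d new=%d\n",
                    order[0], order[1], oldPick, newPick);
      }
      return 0;
    }
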
@@ -3322,7 +3322,7 @@ public:
for (int i = 0; i < numPtr; ++i) {
Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat));
cMatOffI = xor_(cMatOffI, phase);
-offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sTileStride));
+offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sStride));
}
return offs;
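
The loop above composes each pointer offset from a phase-swizzled matrix column offset plus an s-dimension term that now uses sStride. A scalar sketch of the same formula, with plain ints standing in for the LLVM Values and hypothetical argument names:

    #include <cstdint>

    // off_i = ((cMatOff + i * pLoadStrideInMat) ^ phase) * cMatShape + sOff * sStride
    int32_t matPtrOffset(int32_t cMatOff, int32_t i, int32_t pLoadStrideInMat,
                         int32_t phase, int32_t cMatShape,
                         int32_t sOff, int32_t sStride) {
      int32_t cMatOffI = (cMatOff + i * pLoadStrideInMat) ^ phase; // xor swizzle
      return cMatOffI * cMatShape + sOff * sStride;
    }
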
@@ -3358,7 +3358,7 @@ public:
Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape)));
cOff = urem(cOff, i32_val(tileShape[order[0]]));
sOff = urem(sOff, i32_val(tileShape[order[1]]));
-offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, sTileStride));
+offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, sStride));
}
}
return offs;
@@ -3398,7 +3398,7 @@ public:
// To prevent out-of-bound access when tile is too small.
cOff = urem(cOff, i32_val(tileShape[order[0]]));
sOff = urem(sOff, i32_val(tileShape[order[1]]));
-offs[ptrOff] = add(cOff, mul(sOff, sTileStride));
+offs[ptrOff] = add(cOff, mul(sOff, sStride));
}
}
}
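
The two hunks above keep the out-of-bound guard intact while switching the s term to sStride: both offsets are wrapped back into the tile with an unsigned remainder before being combined. A scalar sketch of that guard, with hypothetical names and unsigned ints standing in for the i32 Values:

    #include <cstdint>

    uint32_t clampedOffset(uint32_t cOff, uint32_t sOff,
                           uint32_t cTileShape, uint32_t sTileShape,
                           uint32_t sStride) {
      cOff %= cTileShape; // urem(cOff, tileShape[order[0]])
      sOff %= sTileShape; // urem(sOff, tileShape[order[1]])
      return cOff + sOff * sStride;
    }
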
@@ -3433,7 +3433,7 @@ public:
if (canUseLdmatrix) {
Value sOffset =
-mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sTileStride);
+mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sStride);
Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset);
PTXBuilder builder;
@@ -3467,10 +3467,10 @@ public:
Value ptr2 = getPtr(ptrIdx + 1);
assert(sMatStride == 1);
int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
-Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride);
+Value sOffsetElemVal = mul(i32_val(sOffsetElem), sStride);
int sOffsetArrElem = sMatStride * sMatShape;
Value sOffsetArrElemVal =
-add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride));
+add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sStride));
Value elems[4];
Type elemTy = type::f32Ty(ctx);
@@ -3510,10 +3510,10 @@ public:
assert(sMatStride == 1);
int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
-Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride);
+Value sOffsetElemVal = mul(i32_val(sOffsetElem), sStride);
int sOffsetArrElem = 1 * (sMatStride * sMatShape);
Value sOffsetArrElemVal =
-add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride));
+add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sStride));
std::array<Value, 4> i8v4Elems;
std::array<Value, 4> i32Elems;
@@ -3581,8 +3581,8 @@ private:
int cMatShape;
int sMatShape;
-Value cTileStride;
-Value sTileStride;
+Value cStride;
+Value sStride;
bool needTrans;
bool canUseLdmatrix;
@@ -4224,7 +4224,8 @@ struct MMA16816ConversionHelper {
loadFn = getLoadMatrixFn(
tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
-{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
+{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/,
+true /*isA*/);
} else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
// load from registers, used in gemm fuse
// TODO(Superjomn) Port the logic.
@@ -4255,7 +4256,8 @@ struct MMA16816ConversionHelper {
auto loadFn = getLoadMatrixFn(
tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
-{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
+{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/,
+false /*isA*/);
for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
for (int k = 0; k < numRepK; ++k)
@@ -4369,7 +4371,7 @@ private:
getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder,
ArrayRef<int> instrShape, ArrayRef<int> matShape,
-Value warpId, ValueTable &vals) const {
+Value warpId, ValueTable &vals, bool isA) const {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
// We assumes that the input operand of Dot should be from shared layout.
// TODO(Superjomn) Consider other layouts if needed later.
@@ -4379,8 +4381,6 @@ private:
const int elemBytes = tensorTy.getElementTypeBitWidth() / 8;
auto order = sharedLayout.getOrder();
-bool needTrans = kOrder != order[0];
// the original register_lds2, but discard the prefetch logic.
auto ld2 = [](ValueTable &vals, int mn, int k, Value val) {
vals[{mn, k}] = val;
@@ -4406,7 +4406,7 @@ private:
(kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
ptrs, helper.getMatType(), helper.getShemPtrTy());
-if (!needTrans) {
+if (isA) {
ld2(vals, a, b, ha0);
ld2(vals, a + 1, b, ha1);
ld2(vals, a, b + 1, ha2);
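
With the new isA flag, the register placement above is keyed off which dot operand is being loaded rather than off the old !needTrans test, so a transposed A operand no longer falls into the B-style placement. A minimal stand-in for the ld2/ValueTable bookkeeping; the fourth ld2 call and the else branch lie outside the hunk and are assumed here:

    #include <map>
    #include <utility>

    using Value = int;                                     // stand-in for mlir::Value
    using ValueTable = std::map<std::pair<int, int>, Value>;

    // the original register_lds2, without the prefetch logic (as in the diff)
    void ld2(ValueTable &vals, int mn, int k, Value val) { vals[{mn, k}] = val; }

    void placeLoadedQuad(ValueTable &vals, int a, int b,
                         Value ha0, Value ha1, Value ha2, Value ha3, bool isA) {
      if (isA) {
        ld2(vals, a, b, ha0);
        ld2(vals, a + 1, b, ha1);
        ld2(vals, a, b + 1, ha2);
        ld2(vals, a + 1, b + 1, ha3); // assumed continuation of the visible pattern
      }
      // else: operand-B placement, not shown in the hunk above
    }
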
@@ -6009,6 +6009,7 @@ struct AtomicRMWOpConversion
}
atom.o(rmwOp).o(sTy);
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
+auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
for (int ii = 0; ii < vec; ++ii) {
resultVals[i * vec + ii] =