[Triton-MLIR][Backend] fix mma-v2 transpose error (#888)
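Summary (my reading of the diff below, not the author's notes): the mma-v2 shared-memory loader looked its tile strides up through `order`, which returns a different stride once the operand tile is transposed; the renamed `cStride`/`sStride` now read `smemStrides[1]`/`smemStrides[0]` directly. An explicit `isA` flag is threaded through `getLoadMatrixFn` to pick the register placement instead of inferring it from `needTrans`, and `AtomicRMWOpConversion` gains the missing operand/predicate binding for its `atom` instruction.

A minimal standalone C++ sketch of the stride mix-up, with made-up shapes (an illustration, not the Triton helper itself):

// Illustration only: why smemStrides[order[0]] != smemStrides[1] for a
// transposed tile. smemStrides[d] is the element stride of logical axis d;
// order[0] names the contiguous axis.
#include <array>
#include <cstdio>

int main() {
  struct Case {
    std::array<int, 2> order;
    std::array<int, 2> smemStrides;
  };
  // 16x32 tile: row-major  -> order {1,0}, strides {32,1};
  //             transposed -> order {0,1}, strides {1,16}.
  for (const Case &c : {Case{{1, 0}, {32, 1}}, Case{{0, 1}, {1, 16}}}) {
    int oldC = c.smemStrides[c.order[0]]; // old cTileStride
    int newC = c.smemStrides[1];          // new cStride
    std::printf("order={%d,%d}: old=%d fixed=%d%s\n", c.order[0], c.order[1],
                oldC, newC,
                oldC != newC ? "  <- transposed case differed" : "");
  }
}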
@@ -3218,8 +3218,8 @@ public:
     cMatShape = matShape[order[0]];
     sMatShape = matShape[order[1]];
 
-    cTileStride = smemStrides[order[0]];
-    sTileStride = smemStrides[order[1]];
+    cStride = smemStrides[1];
+    sStride = smemStrides[0];
 
     // rule: k must be the fast-changing axis.
     needTrans = kOrder != order[0];
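A quick enumeration of the `needTrans` rule this hunk keeps; the per-operand `kOrder` values (1 for A, 0 for B) come from the `MMA16816ConversionHelper` hunks further down:

#include <cstdio>

int main() {
  // needTrans = kOrder != order[0]: a transposed load is needed exactly
  // when k is not the contiguous axis of the shared tile.
  for (int kOrder : {1, 0})   // 1 = operand A, 0 = operand B
    for (int contig : {1, 0}) // order[0]
      std::printf("kOrder=%d order[0]=%d -> needTrans=%d\n", kOrder, contig,
                  kOrder != contig);
}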
@@ -3322,7 +3322,7 @@ public:
     for (int i = 0; i < numPtr; ++i) {
       Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat));
       cMatOffI = xor_(cMatOffI, phase);
-      offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sTileStride));
+      offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sStride));
     }
 
     return offs;
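A scalar rendering of the offset formula in this hunk, with plain ints standing in for MLIR `Value`s and made-up sizes; the XOR applies the shared-memory swizzle phase before the `sStride` scaling:

#include <cstdio>

int main() {
  // Made-up sizes for illustration.
  int numPtr = 4, cMatOff = 2, pLoadStrideInMat = 2;
  int phase = 1, cMatShape = 8, sOff = 3, sStride = 64;
  int offs[4];
  for (int i = 0; i < numPtr; ++i) {
    int cMatOffI = cMatOff + i * pLoadStrideInMat;
    cMatOffI ^= phase; // xor_(cMatOffI, phase): swizzle against bank conflicts
    offs[i] = cMatOffI * cMatShape + sOff * sStride;
    std::printf("offs[%d] = %d\n", i, offs[i]);
  }
}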
@@ -3358,7 +3358,7 @@ public:
         Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape)));
         cOff = urem(cOff, i32_val(tileShape[order[0]]));
         sOff = urem(sOff, i32_val(tileShape[order[1]]));
-        offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, sTileStride));
+        offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, sStride));
       }
     }
     return offs;
@@ -3398,7 +3398,7 @@ public:
           // To prevent out-of-bound access when tile is too small.
           cOff = urem(cOff, i32_val(tileShape[order[0]]));
           sOff = urem(sOff, i32_val(tileShape[order[1]]));
-          offs[ptrOff] = add(cOff, mul(sOff, sTileStride));
+          offs[ptrOff] = add(cOff, mul(sOff, sStride));
         }
       }
     }
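The `urem` clamp preserved by these hunks wraps coordinates into the tile; a tiny illustration of the wrap-around:

#include <cstdio>

int main() {
  int tileShapeC = 16; // tileShape[order[0]] in the diff
  for (int cOff : {5, 16, 21})
    std::printf("cOff=%2d -> urem=%d\n", cOff, cOff % tileShapeC);
  // 5->5, 16->0, 21->5: small tiles are read with wrap-around instead of
  // running past the end of shared memory.
}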
@@ -3433,7 +3433,7 @@ public:
 
     if (canUseLdmatrix) {
       Value sOffset =
-          mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sTileStride);
+          mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sStride);
       Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset);
 
       PTXBuilder builder;
@@ -3467,10 +3467,10 @@ public:
       Value ptr2 = getPtr(ptrIdx + 1);
       assert(sMatStride == 1);
       int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
-      Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride);
+      Value sOffsetElemVal = mul(i32_val(sOffsetElem), sStride);
       int sOffsetArrElem = sMatStride * sMatShape;
       Value sOffsetArrElemVal =
-          add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride));
+          add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sStride));
 
       Value elems[4];
       Type elemTy = type::f32Ty(ctx);
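Scalar sketch of the two shared-memory offsets formed in this hunk for the non-ldmatrix path: one for the addressed 8x8 matrix and one for its neighbor along the s axis (sizes are illustrative; names mirror the diff):

#include <cassert>
#include <cstdio>

int main() {
  int sMatStride = 1, sMatShape = 8, sStride = 64;
  int matIdxS = 2;         // matIdx[order[1]]
  assert(sMatStride == 1); // same precondition the diff asserts
  int sOffsetElem = matIdxS * (sMatStride * sMatShape);
  int sOffsetElemVal = sOffsetElem * sStride; // first matrix
  int sOffsetArrElem = sMatStride * sMatShape;
  int sOffsetArrElemVal =
      sOffsetElemVal + sOffsetArrElem * sStride; // next matrix along s
  std::printf("first=%d next=%d\n", sOffsetElemVal, sOffsetArrElemVal);
}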
@@ -3510,10 +3510,10 @@ public:
 
       assert(sMatStride == 1);
       int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
-      Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride);
+      Value sOffsetElemVal = mul(i32_val(sOffsetElem), sStride);
       int sOffsetArrElem = 1 * (sMatStride * sMatShape);
       Value sOffsetArrElemVal =
-          add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride));
+          add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sStride));
 
       std::array<Value, 4> i8v4Elems;
       std::array<Value, 4> i32Elems;
@@ -3581,8 +3581,8 @@ private:
   int cMatShape;
   int sMatShape;
 
-  Value cTileStride;
-  Value sTileStride;
+  Value cStride;
+  Value sStride;
 
   bool needTrans;
   bool canUseLdmatrix;
@@ -4224,7 +4224,8 @@ struct MMA16816ConversionHelper {
       loadFn = getLoadMatrixFn(
           tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
           1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
-          {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/);
+          {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/,
+          true /*isA*/);
     } else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
       // load from registers, used in gemm fuse
       // TODO(Superjomn) Port the logic.
@@ -4255,7 +4256,8 @@ struct MMA16816ConversionHelper {
     auto loadFn = getLoadMatrixFn(
         tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
         0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
-        {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/);
+        {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/,
+        false /*isA*/);
 
     for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
       for (int k = 0; k < numRepK; ++k)
@@ -4369,7 +4371,7 @@ private:
   getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
                   MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder,
                   ArrayRef<int> instrShape, ArrayRef<int> matShape,
-                  Value warpId, ValueTable &vals) const {
+                  Value warpId, ValueTable &vals, bool isA) const {
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
     // We assumes that the input operand of Dot should be from shared layout.
     // TODO(Superjomn) Consider other layouts if needed later.
@@ -4379,8 +4381,6 @@ private:
     const int elemBytes = tensorTy.getElementTypeBitWidth() / 8;
     auto order = sharedLayout.getOrder();
 
-    bool needTrans = kOrder != order[0];
-
     // the original register_lds2, but discard the prefetch logic.
     auto ld2 = [](ValueTable &vals, int mn, int k, Value val) {
       vals[{mn, k}] = val;
@@ -4406,7 +4406,7 @@ private:
           (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
           ptrs, helper.getMatType(), helper.getShemPtrTy());
 
-      if (!needTrans) {
+      if (isA) {
         ld2(vals, a, b, ha0);
         ld2(vals, a + 1, b, ha1);
         ld2(vals, a, b + 1, ha2);
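The guard change above is the heart of the transpose fix as I read it: fragment placement is a property of which operand is loaded, not of whether the load is transposed. The old `!needTrans` and the new `isA` agree only when the layout happens to make them coincide; a quick table:

#include <cstdio>

int main() {
  for (bool isA : {true, false})
    for (bool needTrans : {false, true})
      std::printf("isA=%d needTrans=%d : old=%s new=%s%s\n", isA, needTrans,
                  !needTrans ? "if" : "else", isA ? "if" : "else",
                  (!needTrans) != isA ? "  <- branch changed" : "");
}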
@@ -6009,6 +6009,7 @@ struct AtomicRMWOpConversion
       }
       atom.o(rmwOp).o(sTy);
+      atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
 
       auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy);
       for (int ii = 0; ii < vec; ++ii) {
         resultVals[i * vec + ii] =
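Before this addition the `atom` instruction was configured but, as far as I can tell from the hunk, never given its operands or a guard. The added line binds (dst, ptr, val) and predicates execution on `rmwMask`. A scalar model of what a per-lane predicate means (illustration, not full PTX semantics):

#include <atomic>
#include <cstdio>

int main() {
  std::atomic<float> dst{0.f};
  bool rmwMask[4] = {true, false, true, true}; // per-lane guard bits
  float val[4] = {1.f, 2.f, 3.f, 4.f};
  for (int lane = 0; lane < 4; ++lane) {
    if (!rmwMask[lane])
      continue; // @rmwMask in PTX: masked-off lanes skip the atomic
    float cur = dst.load();
    while (!dst.compare_exchange_weak(cur, cur + val[lane])) {
    }
  }
  std::printf("dst = %g\n", dst.load()); // 8: lane 1 contributed nothing
}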