From e8994209f41351225232fa1d307c4f463f822305 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Sun, 20 Nov 2022 11:29:09 +0800 Subject: [PATCH] [Triton-MLIR][Backend]fix mma-v2 transpose error (#888) --- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 5804c06e3..b1d37a0a5 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -3218,8 +3218,8 @@ public: cMatShape = matShape[order[0]]; sMatShape = matShape[order[1]]; - cTileStride = smemStrides[order[0]]; - sTileStride = smemStrides[order[1]]; + cStride = smemStrides[1]; + sStride = smemStrides[0]; // rule: k must be the fast-changing axis. needTrans = kOrder != order[0]; @@ -3322,7 +3322,7 @@ public: for (int i = 0; i < numPtr; ++i) { Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat)); cMatOffI = xor_(cMatOffI, phase); - offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sTileStride)); + offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sStride)); } return offs; @@ -3358,7 +3358,7 @@ public: Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape))); cOff = urem(cOff, i32_val(tileShape[order[0]])); sOff = urem(sOff, i32_val(tileShape[order[1]])); - offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, sTileStride)); + offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, sStride)); } } return offs; @@ -3398,7 +3398,7 @@ public: // To prevent out-of-bound access when tile is too small. cOff = urem(cOff, i32_val(tileShape[order[0]])); sOff = urem(sOff, i32_val(tileShape[order[1]])); - offs[ptrOff] = add(cOff, mul(sOff, sTileStride)); + offs[ptrOff] = add(cOff, mul(sOff, sStride)); } } } @@ -3433,7 +3433,7 @@ public: if (canUseLdmatrix) { Value sOffset = - mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sTileStride); + mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sStride); Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset); PTXBuilder builder; @@ -3467,10 +3467,10 @@ public: Value ptr2 = getPtr(ptrIdx + 1); assert(sMatStride == 1); int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape); - Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride); + Value sOffsetElemVal = mul(i32_val(sOffsetElem), sStride); int sOffsetArrElem = sMatStride * sMatShape; Value sOffsetArrElemVal = - add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride)); + add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sStride)); Value elems[4]; Type elemTy = type::f32Ty(ctx); @@ -3510,10 +3510,10 @@ public: assert(sMatStride == 1); int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape); - Value sOffsetElemVal = mul(i32_val(sOffsetElem), sTileStride); + Value sOffsetElemVal = mul(i32_val(sOffsetElem), sStride); int sOffsetArrElem = 1 * (sMatStride * sMatShape); Value sOffsetArrElemVal = - add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sTileStride)); + add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sStride)); std::array i8v4Elems; std::array i32Elems; @@ -3581,8 +3581,8 @@ private: int cMatShape; int sMatShape; - Value cTileStride; - Value sTileStride; + Value cStride; + Value sStride; bool needTrans; bool canUseLdmatrix; @@ -4224,7 +4224,8 @@ struct MMA16816ConversionHelper { loadFn = getLoadMatrixFn( tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/, 1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/, - {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/); + {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/, + true /*isA*/); } else if (aTensorTy.getEncoding().isa()) { // load from registers, used in gemm fuse // TODO(Superjomn) Port the logic. @@ -4255,7 +4256,8 @@ struct MMA16816ConversionHelper { auto loadFn = getLoadMatrixFn( tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, 0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/, - {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/); + {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/, + false /*isA*/); for (int n = 0; n < std::max(numRepN / 2, 1); ++n) { for (int k = 0; k < numRepK; ++k) @@ -4369,7 +4371,7 @@ private: getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj, MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder, ArrayRef instrShape, ArrayRef matShape, - Value warpId, ValueTable &vals) const { + Value warpId, ValueTable &vals, bool isA) const { auto tensorTy = tensor.getType().cast(); // We assumes that the input operand of Dot should be from shared layout. // TODO(Superjomn) Consider other layouts if needed later. @@ -4379,8 +4381,6 @@ private: const int elemBytes = tensorTy.getElementTypeBitWidth() / 8; auto order = sharedLayout.getOrder(); - bool needTrans = kOrder != order[0]; - // the original register_lds2, but discard the prefetch logic. auto ld2 = [](ValueTable &vals, int mn, int k, Value val) { vals[{mn, k}] = val; @@ -4406,7 +4406,7 @@ private: (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs, ptrs, helper.getMatType(), helper.getShemPtrTy()); - if (!needTrans) { + if (isA) { ld2(vals, a, b, ha0); ld2(vals, a + 1, b, ha1); ld2(vals, a, b + 1, ha2); @@ -6009,6 +6009,7 @@ struct AtomicRMWOpConversion } atom.o(rmwOp).o(sTy); atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask); + auto ret = ptxBuilder.launch(rewriter, loc, valueElemTy); for (int ii = 0; ii < vec; ++ii) { resultVals[i * vec + ii] =