From 4d64589b22cfb0e4f8dd165c34148b9a633cda58 Mon Sep 17 00:00:00 2001 From: goostavz <109190422+goostavz@users.noreply.github.com> Date: Sat, 3 Dec 2022 21:12:48 +0800 Subject: [PATCH] [Triton-MLIR][Backend] Fix the definition of MmaEncodingAttr v1, and the output sequence of DotConversion in MMAv1 (#941) --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 44 +++++++++---------- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 43 ++++++++++-------- 2 files changed, 47 insertions(+), 40 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index d52c8985c..bc848eaf4 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -293,7 +293,7 @@ partitioned between warps. // -------------------------------- version = 1 --------------------------- // For first-gen tensor cores, the implicit warpTileSize is [16, 16]. -Information about this layout can be found in the official PTX documentation +Note: the layout is different from the recommended in PTX ISA https://docs.nvidia.com/cuda/parallel-thread-execution/index.html (mma.884 section, FP32 accumulator). @@ -301,29 +301,29 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: warp 0 --------------------------------/\------------------------------- -[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ] -[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ] -[ 0 0 2 2 0 0 2 2 4 4 6 6 4 4 6 6 ] -[ 1 1 3 3 1 1 3 3 5 5 7 7 5 5 7 7 ] -[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22] -[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23] -[ 16 16 18 18 16 16 18 18 20 20 22 22 20 20 22 22] -[ 17 17 19 19 17 17 19 19 21 21 23 23 21 21 23 23] -[ 8 8 10 10 8 8 10 10 12 12 14 14 12 12 14 14] -[ 9 9 11 11 9 9 11 11 13 13 15 15 13 13 15 15] -[ .............................................................. -[ .............................................................. -[ 24 24 26 26 24 24 26 26 28 28 30 30 28 28 30 30] -[ 25 25 27 27 25 25 27 27 29 29 31 31 29 29 31 31] +[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ] +[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ] +[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ] +[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ] +[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ] +[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ] +[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ] +[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ] +[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ] +[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ] +[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ] +[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ] +[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ] +[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ] +[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ] +[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ] - warp 1 = warp0 + 32 + warp 1 = warp0 + 32 --------------------------------/\------------------------------- -[ 32 32 34 34 32 32 34 34 36 36 38 38 36 36 38 38] -[ 33 33 35 35 33 33 35 35 37 37 39 39 37 37 39 39] -[ .............................................................. -[ .............................................................. -[ 56 56 58 58 56 56 58 58 60 60 62 62 60 60 62 62] -[ 57 57 59 59 57 57 59 59 61 61 63 63 61 61 63 63] +[ 32 32 34 34 40 40 42 42 32 32 34 34 40 40 42 42 ] +[ 33 33 35 35 41 41 43 43 33 33 35 35 41 41 43 43 ] +[ ............................................................... ] + // -------------------------------- version = 2 --------------------------- // diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 68fc4bac3..d8d7648c4 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -2876,20 +2876,21 @@ private: } else if (mmaLayout.getVersion() == 1) { multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16)); multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 16)); - Value partId = udiv(laneId, _4); - Value partIdDiv4 = udiv(partId, _4); - Value partIdRem4 = urem(partId, _4); - Value partRowOffset = mul(udiv(partIdRem4, _2), _8); - partRowOffset = add(mul(partIdDiv4, _4), partRowOffset); - Value partColOffset = mul(urem(partIdRem4, _2), _8); - Value colOffset = add(mul(multiDimWarpId[0], _16), partColOffset); - Value rowOffset = add(mul(multiDimWarpId[1], _16), partRowOffset); - mmaRowIdx[0] = add(urem(laneId, _2), rowOffset); + Value laneIdDiv16 = udiv(laneId, _16); + Value laneIdRem16 = urem(laneId, _16); + Value laneIdRem2 = urem(laneId, _2); + Value laneIdRem16Div8 = udiv(laneIdRem16, _8); + Value laneIdRem16Div4 = udiv(laneIdRem16, _4); + Value laneIdRem16Div4Rem2 = urem(laneIdRem16Div4, _2); + Value laneIdRem4Div2 = udiv(urem(laneId, _4), _2); + mmaRowIdx[0] = + add(add(mul(laneIdDiv16, _8), mul(laneIdRem16Div4Rem2, _4)), + laneIdRem2); mmaRowIdx[1] = add(mmaRowIdx[0], _2); - mmaColIdx[0] = add(mul(udiv(urem(laneId, _4), _2), _2), colOffset); + mmaColIdx[0] = add(mul(laneIdRem16Div8, _4), mul(laneIdRem4Div2, _2)); mmaColIdx[1] = add(mmaColIdx[0], _1); - mmaColIdx[2] = add(mmaColIdx[0], _4); - mmaColIdx[3] = add(mmaColIdx[0], idx_val(5)); + mmaColIdx[2] = add(mmaColIdx[0], _8); + mmaColIdx[3] = add(mmaColIdx[0], idx_val(9)); } else { llvm_unreachable("Unexpected MMALayout version"); } @@ -3543,6 +3544,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor, // initialize accumulators SmallVector acc = getElementsFromStruct(loc, loadedC, rewriter); + size_t resSize = acc.size(); + SmallVector resVals(resSize); auto callMMA = [&](unsigned m, unsigned n, unsigned k) { auto ha = has[{m, k}]; @@ -3586,8 +3589,14 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor, auto getIntAttr = [&](int v) { return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)}); }; - for (unsigned i = 0; i < 8; i++) - acc[idx[i]] = extract_val(f32_ty, res, getIntAttr(i)); + + for (unsigned i = 0; i < 8; i++) { + Value elem = extract_val(f32_ty, res, getIntAttr(i)); + acc[idx[i]] = elem; + // TODO[goostavz]: double confirm this when m/n/k = [32, 32, x] has been + // verified before MMA + resVals[(m * numN / 2 + n) * 8 + i] = elem; + } }; for (unsigned k = 0; k < NK; k += 4) @@ -3596,12 +3605,10 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor, callMMA(m, n, k); } - // replace with new packed result Type structTy = LLVM::LLVMStructType::getLiteral( - ctx, SmallVector(acc.size(), type::f32Ty(ctx))); - Value res = getStructFromElements(loc, acc, rewriter, structTy); + ctx, SmallVector(resSize, type::f32Ty(ctx))); + Value res = getStructFromElements(loc, resVals, rewriter, structTy); rewriter.replaceOp(op, res); - return success(); }