diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index a14691c6c..d79632c55 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -2239,16 +2239,16 @@ struct AllocTensorOpConversion
     smemBase = bitcast(smemBase, elemPtrTy);
     auto order = resultTy.getEncoding().cast<SharedEncodingAttr>().getOrder();
     // workaround for 3D tensors
-    // TODO: We need to modify the pipeline pass to give a proper shared encoding to 3D tensors
+    // TODO: We need to modify the pipeline pass to give a proper shared
+    // encoding to 3D tensors
     SmallVector<unsigned> newOrder;
-    if (resultTy.getShape().size() == 3)
+    if (resultTy.getShape().size() == 3)
       newOrder = {1 + order[0], 1 + order[1], 0};
     else
       newOrder = SmallVector<unsigned>(order.begin(), order.end());
-
-    auto smemObj =
-        SharedMemoryObject(smemBase, resultTy.getShape(), newOrder, loc, rewriter);
+    auto smemObj = SharedMemoryObject(smemBase, resultTy.getShape(), newOrder,
+                                      loc, rewriter);
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
     rewriter.replaceOp(op, retVal);
     return success();
@@ -2882,6 +2882,8 @@ private:
     SmallVector<Value> multiDimWarpId(2);
     multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
     multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
+    multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
+    multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 8));
     Value four = idx_val(4);
     Value mmaGrpId = udiv(laneId, four);
     Value mmaGrpIdP8 = add(mmaGrpId, idx_val(8));
@@ -3661,7 +3663,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
 
     // Here we assume the DotOp's operands always comes from shared memory.
     auto AShape = A.getType().cast<RankedTensorType>().getShape();
-    size_t reduceAxis = 1;
+    size_t reduceAxis = op.transA() ? 0 : 1;
     unsigned K = AShape[reduceAxis];
     bool isOuter = K == 1;
 
@@ -4124,8 +4126,9 @@ private:
 struct MMA16816ConversionHelper {
   MmaEncodingAttr mmaLayout;
   ArrayRef<unsigned> wpt;
+  SmallVector<unsigned> properWpt;
 
-  Value thread, lane, warp, warpMN, warpN, warpM;
+  Value thread, lane, warp;
 
   DotOpMmaV2ConversionHelper helper;
   ConversionPatternRewriter &rewriter;
@@ -4135,23 +4138,34 @@ struct MMA16816ConversionHelper {
 
   using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
 
-  MMA16816ConversionHelper(MmaEncodingAttr mmaLayout, Value thread,
-                           ConversionPatternRewriter &rewriter,
+  // dotOperand: type of either one operand of dotOp.
+  MMA16816ConversionHelper(Type dotOperand, MmaEncodingAttr mmaLayout,
+                           Value thread, ConversionPatternRewriter &rewriter,
                            TypeConverter *typeConverter, Location loc)
       : mmaLayout(mmaLayout), thread(thread), helper(mmaLayout),
         rewriter(rewriter), typeConverter(typeConverter), loc(loc),
-        ctx(mmaLayout.getContext()) {
-    wpt = mmaLayout.getWarpsPerCTA();
+        ctx(mmaLayout.getContext()), wpt(mmaLayout.getWarpsPerCTA()) {
+    helper.deduceMmaType(dotOperand);
 
     Value _32 = i32_val(32);
     lane = urem(thread, _32);
     warp = udiv(thread, _32);
-    warpMN = udiv(warp, i32_val(wpt[0]));
-    warpM = urem(warp, i32_val(wpt[0]));
-    warpN = urem(warpMN, i32_val(wpt[1]));
   }
 
-  // Get the mmaInstrShape from either $a or $b.
+  // Get a warpId for M axis.
+  Value getWarpM(int M) const {
+    auto matShape = helper.getMmaMatShape();
+    return urem(urem(warp, i32_val(wpt[0])), i32_val(M / matShape[0]));
+  }
+
+  // Get a warpId for N axis.
+  Value getWarpN(int N) const {
+    auto matShape = helper.getMmaMatShape();
+    Value warpMN = udiv(warp, i32_val(wpt[0]));
+    return urem(urem(warpMN, i32_val(wpt[1])), i32_val(N / matShape[1]));
+  }
+
+  // Get the mmaInstrShape deducing either from $a or $b.
   std::tuple<int, int, int> getMmaInstrShape(Type operand) const {
     helper.deduceMmaType(operand);
     auto mmaInstrShape = helper.getMmaInstrShape();
@@ -4161,6 +4175,7 @@ struct MMA16816ConversionHelper {
     return std::make_tuple(mmaInstrM, mmaInstrN, mmaInstrK);
   }
 
+  // Get the mmaMatShape deducing either from $a or $b.
   std::tuple<int, int, int> getMmaMatShape(Type operand) const {
     helper.deduceMmaType(operand);
     auto matShape = helper.getMmaMatShape();
@@ -4210,28 +4225,28 @@ struct MMA16816ConversionHelper {
   }
 
   // Get number of elements per thread for $a operand.
-  static size_t getANumElemsPerThread(RankedTensorType operand,
-                                      ArrayRef<unsigned> wpt) {
+  static size_t getANumElemsPerThread(RankedTensorType operand, int wpt) {
     auto shape = operand.getShape();
-    int repM = getNumRepM(operand, shape[0], wpt[0]);
+    int repM = getNumRepM(operand, shape[0], wpt);
     int repK = getNumRepK_(operand, shape[1]);
     return 4 * repM * repK;
   }
 
   // Get number of elements per thread for $b operand.
-  static size_t getBNumElemsPerThread(RankedTensorType operand,
-                                      ArrayRef<unsigned> wpt) {
+  static size_t getBNumElemsPerThread(RankedTensorType operand, int wpt) {
     auto shape = operand.getShape();
     int repK = getNumRepK_(operand, shape[0]);
-    int repN = getNumRepN(operand, shape[1], wpt[1]);
+    int repN = getNumRepN(operand, shape[1], wpt);
     return 4 * std::max(repN / 2, 1) * repK;
   }
 
   // Loading $a from smem to registers, returns a LLVM::Struct.
   Value loadA(Value tensor, const SharedMemoryObject &smemObj) const {
     auto aTensorTy = tensor.getType().cast<RankedTensorType>();
-    auto shape = aTensorTy.getShape();
+    auto layout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
+    SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
+                               aTensorTy.getShape().end());
+
     ValueTable ha;
     std::function<void(int, int)> loadFn;
     auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
@@ -4241,6 +4256,7 @@ struct MMA16816ConversionHelper {
     int numRepK = getNumRepK(aTensorTy, shape[1]);
 
     if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
+      Value warpM = getWarpM(shape[0]);
       // load from smem
       loadFn = getLoadMatrixFn(
           tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
@@ -4268,12 +4284,17 @@ struct MMA16816ConversionHelper {
   Value loadB(Value tensor, const SharedMemoryObject &smemObj) {
     ValueTable hb;
     auto tensorTy = tensor.getType().cast<RankedTensorType>();
-    auto shape = tensorTy.getShape();
+    auto layout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
+
+    SmallVector<int64_t> shape(tensorTy.getShape().begin(),
+                               tensorTy.getShape().end());
+
     auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(tensorTy);
     auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(tensorTy);
     int numRepK = getNumRepK(tensorTy, shape[0]);
     int numRepN = getNumRepN(tensorTy, shape[1]);
+    Value warpN = getWarpN(shape[1]);
 
     auto loadFn = getLoadMatrixFn(
         tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
         0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
@@ -4319,7 +4340,11 @@ struct MMA16816ConversionHelper {
     auto aTensorTy = a.getType().cast<RankedTensorType>();
     auto dTensorTy = d.getType().cast<RankedTensorType>();
 
-    auto aShape = aTensorTy.getShape();
+    SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
+                                aTensorTy.getShape().end());
+    if (op.transA())
+      std::swap(aShape[0], aShape[1]);
+
     auto dShape = dTensorTy.getShape();
 
     // shape / shape_per_cta
@@ -4602,9 +4627,9 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
   Value res;
 
   if (!isOuter && mmaLayout.getVersion() == 2 &&
      isHMMA) { // tensor core v2
-    MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
-                                       rewriter, getTypeConverter(),
-                                       op.getLoc());
+    MMA16816ConversionHelper mmaHelper(src.getType(), mmaLayout,
+                                       getThreadId(rewriter, loc), rewriter,
+                                       getTypeConverter(), op.getLoc());
 
     if (dotOperandLayout.getOpIdx() == 0) { // operand $a
@@ -4695,12 +4720,15 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor,
                        .cast<RankedTensorType>()
                        .getEncoding()
                        .cast<MmaEncodingAttr>();
-  MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
-                                     rewriter, getTypeConverter(), loc);
 
   Value A = op.a();
   Value B = op.b();
   Value C = op.c();
+
+  MMA16816ConversionHelper mmaHelper(A.getType(), mmaLayout,
+                                     getThreadId(rewriter, loc), rewriter,
+                                     getTypeConverter(), loc);
+
   auto ATensorTy = A.getType().cast<RankedTensorType>();
   auto BTensorTy = B.getType().cast<RankedTensorType>();
 
@@ -5532,13 +5560,13 @@ public:
     if (mmaLayout.getVersion() == 2) {
       if (dotOpLayout.getOpIdx() == 0) { // $a
         int elems =
-            MMA16816ConversionHelper::getANumElemsPerThread(type, wpt);
+            MMA16816ConversionHelper::getANumElemsPerThread(type, wpt[0]);
         return LLVM::LLVMStructType::getLiteral(
             ctx, SmallVector<Type>(elems, vecTy));
       }
       if (dotOpLayout.getOpIdx() == 1) { // $b
         int elems =
-            MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt);
+            MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt[1]);
         return struct_ty(SmallVector<Type>(elems, vecTy));
       }
     }
@@ -6159,10 +6187,9 @@ private:
     if (srcBlocked && dstDotOp) {
       auto tmpType = RankedTensorType::get(
          dstType.getShape(), dstType.getElementType(),
-          triton::gpu::SharedEncodingAttr::get(mod.getContext(), dstDotOp,
-                                               srcType.getShape(),
-                                               getOrder(srcBlocked),
-                                               srcType.getElementType()));
+          triton::gpu::SharedEncodingAttr::get(
+              mod.getContext(), dstDotOp, srcType.getShape(),
+              getOrder(srcBlocked), srcType.getElementType()));
       auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
          cvtOp.getLoc(), tmpType, cvtOp.getOperand());
       auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
index f06d3eda0..4deff76a8 100644
--- a/python/tests/test_gemm.py
+++ b/python/tests/test_gemm.py
@@ -27,8 +27,6 @@ def matmul_no_scf_kernel(
     c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
     tl.store(c_ptrs, c)
 
-# TODO: num_warps could only be 4 for now
-
 
 @pytest.mark.parametrize('SHAPE,NUM_WARPS,TRANS_A,TRANS_B', [
     (shape, num_warps, trans_a, trans_b)
@@ -172,6 +170,7 @@ def get_proper_err(a, b, golden):
     # Non-forloop
     [64, 32, 64, 4, 64, 32, 64, False, False],
     [128, 64, 128, 4, 128, 64, 128, False, False],
+    [16, 16, 16, 16, 16, 16, 16, False, False],  # wpt overflow issue
     # K-Forloop
     [64, 32, 128, 4, 64, 32, 64, False, False],
     [128, 16, 128, 4, 128, 16, 32, False, False],
@@ -186,6 +185,7 @@ def get_proper_err(a, b, golden):
     [128, 256, 128, 4, 128, 256, 32, False, False],
     [256, 128, 64, 4, 256, 128, 16, False, False],
     [128, 64, 128, 4, 128, 64, 32, False, False],
+    # [16, 16, 64, 4, 16, 16, 16, False, False],  # TODO failed due to pipeline pass
     # trans
     [128, 64, 128, 4, 128, 64, 32, True, False],
     [128, 64, 128, 4, 128, 64, 32, False, True],
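
Note for reviewers: the clamping added in the @@ -2882 hunk and in getWarpM/getWarpN exists because a CTA can carry more warps than there are mma tiles in the output, and the surplus warps must be folded back onto valid tiles instead of indexing past the tensor. Below is a minimal Python model of that index arithmetic, not code from this patch; the 16/8 divisors come from the shape[0] / 16 and shape[1] / 8 literals above, while the helper name clamped_warp_ids and the [4, 4] warp layout are illustrative assumptions.

def clamped_warp_ids(warp, wpt, shape, instr=(16, 8)):
    # Linear warp id -> 2D id, as in the patch:
    # warpM = warp % wpt[0], warpN = (warp // wpt[0]) % wpt[1].
    warp_m = warp % wpt[0]
    warp_n = (warp // wpt[0]) % wpt[1]
    # The new clamp: fold warps beyond the available tiles back into range,
    # mirroring urem(multiDimWarpId[i], idx_val(shape[i] / divisor)).
    warp_m %= max(shape[0] // instr[0], 1)
    warp_n %= max(shape[1] // instr[1], 1)
    return warp_m, warp_n

# The new [16, 16, 16, 16, ...] test case: 16 warps on a single 16x16 tile.
# Unclamped, warp_m would reach 3 although only one 16-row tile exists;
# clamped, every warp lands on a valid tile.
for w in range(16):
    m, n = clamped_warp_ids(w, wpt=(4, 4), shape=(16, 16))
    assert m == 0 and n <= 1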