diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index c77a29c21..ad64e3ddd 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -457,6 +457,25 @@ struct ConvertTritonGPUOpToLLVMPatternBase {
     }
     return results;
   }
+
+  static SharedMemoryObject
+  getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
+                                  ConversionPatternRewriter &rewriter) {
+    auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
+    return SharedMemoryObject(/*base=*/elems[0],
+                              /*strides=*/{elems.begin() + 1, elems.end()});
+  }
+
+  static Value
+  getStructFromSharedMemoryObject(Location loc,
+                                  const SharedMemoryObject &smemObj,
+                                  ConversionPatternRewriter &rewriter) {
+    auto elems = smemObj.getElems();
+    auto types = smemObj.getTypes();
+    auto structTy =
+        LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
+    return getStructFromElements(loc, elems, rewriter, structTy);
+  }
 };
 
 template <typename SourceOp>
@@ -830,25 +849,6 @@ public:
     return base;
   }
 
-  static SharedMemoryObject
-  getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
-                                  ConversionPatternRewriter &rewriter) {
-    auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
-    return SharedMemoryObject(/*base=*/elems[0],
-                              /*strides=*/{elems.begin() + 1, elems.end()});
-  }
-
-  static Value
-  getStructFromSharedMemoryObject(Location loc,
-                                  const SharedMemoryObject &smemObj,
-                                  ConversionPatternRewriter &rewriter) {
-    auto elems = smemObj.getElems();
-    auto types = smemObj.getTypes();
-    auto structTy =
-        LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
-    return getStructFromElements(loc, elems, rewriter, structTy);
-  }
-
 protected:
   const Allocation *allocation;
   Value smem;
@@ -3566,7 +3566,8 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
     }
 
     if (op.getType().cast<RankedTensorType>().getElementType().isF32() &&
-        A.getType().cast<RankedTensorType>().getElementType().isF32())
+        A.getType().cast<RankedTensorType>().getElementType().isF32() &&
+        !op.allowTF32())
       return convertFMADot(op, adaptor, rewriter);
 
     llvm::report_fatal_error(
@@ -4385,6 +4386,90 @@ private:
   }
 };
 
+// Helper for conversion of FMA DotOp.
+struct DotOpFMAConversionHelper {
+  Attribute layout;
+  MLIRContext *ctx{};
+
+  using ValueTable = std::map<std::pair<int, int>, Value>;
+
+  explicit DotOpFMAConversionHelper(Attribute layout)
+      : layout(layout), ctx(layout.getContext()) {}
+
+  SmallVector<Value> getThreadIds(Value threadId,
+                                  ArrayRef<unsigned> shapePerCTA,
+                                  ArrayRef<unsigned> order,
+                                  ConversionPatternRewriter &rewriter,
+                                  Location loc) const;
+
+  Value loadA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
+              Location loc, ConversionPatternRewriter &rewriter) const;
+
+  Value loadB(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
+              Location loc, ConversionPatternRewriter &rewriter) const;
+
+  ValueTable getValueTableFromStruct(Value val, int K, int n0, int shapePerCTA,
+                                     int sizePerThread,
+                                     ConversionPatternRewriter &rewriter,
+                                     Location loc) const;
+
+  Value getStructFromValueTable(ValueTable vals,
+                                ConversionPatternRewriter &rewriter,
+                                Location loc) const {
+    SmallVector<Type> elemTypes(vals.size(), f32_ty);
+    SmallVector<Value> elems;
+    elems.reserve(vals.size());
+    for (auto &item : vals) {
+      elems.push_back(item.second);
+    }
+
+    Type structTy = struct_ty(elemTypes);
+    return getStructFromElements(loc, elems, rewriter, structTy);
+  }
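+
+  // Note: ValueTable is an ordered map keyed by (m-or-n, k), so the loop
+  // above and the ordered key set rebuilt in getValueTableFromStruct visit
+  // elements in the same lexicographic order; struct packing and unpacking
+  // therefore agree on element order.
+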
+  // Get number of elements per thread for $a or $b.
+  static int getNumElemsPerThread(ArrayRef<int64_t> shape,
+                                  DotOperandEncodingAttr dotOpLayout) {
+    auto blockedLayout = dotOpLayout.getParent().cast<BlockedEncodingAttr>();
+    auto shapePerCTA = getShapePerCTA(blockedLayout);
+    auto sizePerThread = getSizePerThread(blockedLayout);
+    auto order = blockedLayout.getOrder();
+
+    // TODO[Superjomn]: we assume the k axis is fixed for $a and $b here; fix
+    // it if not.
+    int K = dotOpLayout.getOpIdx() == 0 ? shape[1] : shape[0];
+    int otherDim = dotOpLayout.getOpIdx() == 1 ? shape[1] : shape[0];
+
+    bool isM = dotOpLayout.getOpIdx() == 0;
+    int shapePerCTAMN = getShapePerCTAForMN(blockedLayout, isM);
+    int sizePerThreadMN = getsizePerThreadForMN(blockedLayout, isM);
+    return K * std::max(otherDim / shapePerCTAMN, 1) * sizePerThreadMN;
+  }
+
+  // Get shapePerCTA for the M or N axis.
+  static int getShapePerCTAForMN(BlockedEncodingAttr layout, bool isM) {
+    auto order = layout.getOrder();
+    auto shapePerCTA = getShapePerCTA(layout);
+
+    int mShapePerCTA =
+        order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
+    int nShapePerCTA =
+        order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
+    return isM ? mShapePerCTA : nShapePerCTA;
+  }
+
+  // Get sizePerThread for the M or N axis.
+  static int getsizePerThreadForMN(BlockedEncodingAttr layout, bool isM) {
+    auto order = layout.getOrder();
+    auto sizePerThread = getSizePerThread(layout);
+
+    int mSizePerThread =
+        order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
+    int nSizePerThread =
+        order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
+    return isM ? mSizePerThread : nSizePerThread;
+  }
+};
+
 Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
     triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
@@ -4393,14 +4478,15 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
   Value src = op.src();
   Value dst = op.result();
   auto dstTensorTy = dst.getType().cast<RankedTensorType>();
-  // TODO[Superjomn]: the allowTF32 is not available in ConvertLayoutOp for it
-  // is an attribute of DotOp.
+  // TODO[Superjomn]: allowTF32 is not accessible here, since it is an
+  // attribute of an Op instance.
   bool allowTF32 = false;
   bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, allowTF32,
                                            mmaLayout.getVersion());
 
   auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.src(), rewriter);
   Value res;
+
   if (!isOuter && mmaLayout.getVersion() == 2 && isHMMA) { // tensor core v2
     MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
                                        rewriter, getTypeConverter(),
@@ -4459,7 +4545,25 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
   } else if (auto blockedLayout =
                  dotOperandLayout.getParent()
                      .dyn_cast_or_null<BlockedEncodingAttr>()) {
-    assert(false && "Blocked layout is not supported yet");
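+    // A blocked parent layout means this operand feeds an FMA-lowered dot;
+    // materialize its per-thread values from shared memory below.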
+    // TODO[Superjomn]: allowTF32 is not available in ConvertLayoutOp, since
+    // it is an attribute of DotOp.
+    bool allowTF32 = false;
+    bool isFMADot = dstTensorTy.getElementType().isF32() && !allowTF32;
+    if (isFMADot) {
+      auto dotOpLayout =
+          dstTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
+      auto blockedLayout = dotOpLayout.getParent().cast<BlockedEncodingAttr>();
+      DotOpFMAConversionHelper helper(blockedLayout);
+      auto thread = getThreadId(rewriter, loc);
+      if (dotOpLayout.getOpIdx() == 0) { // $a
+        res = helper.loadA(src, adaptor.src(), blockedLayout, thread, loc,
+                           rewriter);
+      } else { // $b
+        res = helper.loadB(src, adaptor.src(), blockedLayout, thread, loc,
+                           rewriter);
+      }
+    } else
+      assert(false && "Unsupported dot operand layout found");
   } else {
     assert(false && "Unsupported dot operand layout found");
   }
@@ -4925,6 +5029,183 @@ DotOpMmaV1ConversionHelper::extractLoadedOperand(
   return rcds;
 }
 
+Value DotOpFMAConversionHelper::loadA(
+    Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
+    Location loc, ConversionPatternRewriter &rewriter) const {
+  auto aTensorTy = A.getType().cast<RankedTensorType>();
+  auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
+  auto aShape = aTensorTy.getShape();
+
+  auto aOrder = aLayout.getOrder();
+  auto order = dLayout.getOrder();
+
+  bool isARow = aOrder[0] == 1;
+
+  int strideAM = isARow ? aShape[1] : 1;
+  int strideAK = isARow ? 1 : aShape[0];
+  int strideA0 = isARow ? strideAK : strideAM;
+  int strideA1 = isARow ? strideAM : strideAK;
+  int lda = isARow ? strideAM : strideAK;
+  int aNumPtr = 8;
+  int bNumPtr = 8;
+  int NK = aShape[1];
+
+  auto shapePerCTA = getShapePerCTA(dLayout);
+  auto sizePerThread = getSizePerThread(dLayout);
+
+  Value _0 = i32_val(0);
+
+  Value mContig = i32_val(sizePerThread[order[1]]);
+  Value nContig = i32_val(sizePerThread[order[0]]);
+
+  // threadId in blocked layout
+  auto threadIds = getThreadIds(thread, shapePerCTA, order, rewriter, loc);
+
+  Value threadIdM = threadIds[0];
+  Value threadIdN = threadIds[1];
+
+  Value offA0 = isARow ? _0 : mul(threadIdM, mContig);
+  Value offA1 = isARow ? mul(threadIdM, mContig) : _0;
+  SmallVector<Value> aOff(aNumPtr);
+  for (int i = 0; i < aNumPtr; ++i) {
+    aOff[i] = add(mul(offA0, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
+  }
+
+  auto aSmem =
+      ConvertTritonGPUOpToLLVMPatternBase::getSharedMemoryObjectFromStruct(
+          loc, llA, rewriter);
+
+  Type f32PtrTy = ptr_ty(f32_ty);
+  SmallVector<Value> aPtrs(aNumPtr);
+  for (int i = 0; i < aNumPtr; ++i)
+    aPtrs[i] = gep(f32PtrTy, aSmem.base, aOff[i]);
+
+  ValueTable has;
+  int M = aShape[aOrder[1]];
+
+  int mShapePerCTA = getShapePerCTAForMN(dLayout, true /*isM*/);
+  int mSizePerThread = getsizePerThreadForMN(dLayout, true /*isM*/);
+
+  for (unsigned k = 0; k < NK; ++k) {
+    for (unsigned m = 0; m < M; m += mShapePerCTA)
+      for (unsigned mm = 0; mm < mSizePerThread; ++mm)
+        if (!has.count({m + mm, k})) {
+          Value pa = gep(f32PtrTy, aPtrs[0],
+                         i32_val((m + mm) * strideAM + k * strideAK));
+          Value va = load(pa);
+          has[{m + mm, k}] = va;
+        }
+  }
+
+  return getStructFromValueTable(has, rewriter, loc);
+}
+
+Value DotOpFMAConversionHelper::loadB(
+    Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
+    Location loc, ConversionPatternRewriter &rewriter) const {
+
+  auto bTensorTy = B.getType().cast<RankedTensorType>();
+  auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
+  auto bShape = bTensorTy.getShape();
+
+  auto bOrder = bLayout.getOrder();
+  auto order = dLayout.getOrder();
+
+  bool isBRow = bOrder[0] == 1;
+
+  int strideBN = isBRow ? 1 : bShape[0];
+  int strideBK = isBRow ? bShape[1] : 1;
+  int strideB0 = isBRow ? strideBN : strideBK;
+  int strideB1 = isBRow ? strideBK : strideBN;
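+  // Mirrors the addressing logic of loadA with the contiguous axis swapped:
+  // for a row-major $b, N is the contiguous (stride-1) dimension and K
+  // strides by a full row.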
+  int ldb = isBRow ? strideBK : strideBN;
+  int bNumPtr = 8;
+  int NK = bShape[0];
+
+  auto shapePerCTA = getShapePerCTA(dLayout);
+  auto sizePerThread = getSizePerThread(dLayout);
+
+  Value _0 = i32_val(0);
+
+  Value mContig = i32_val(sizePerThread[order[1]]);
+  Value nContig = i32_val(sizePerThread[order[0]]);
+
+  // threadId in blocked layout
+  auto threadIds = getThreadIds(thread, shapePerCTA, order, rewriter, loc);
+  Value threadIdN = threadIds[1];
+
+  Value offB0 = isBRow ? mul(threadIdN, nContig) : _0;
+  Value offB1 = isBRow ? _0 : mul(threadIdN, nContig);
+  SmallVector<Value> bOff(bNumPtr);
+  for (int i = 0; i < bNumPtr; ++i) {
+    bOff[i] = add(mul(offB0, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
+  }
+
+  auto bSmem =
+      ConvertTritonGPUOpToLLVMPatternBase::getSharedMemoryObjectFromStruct(
+          loc, llB, rewriter);
+
+  Type f32PtrTy = ptr_ty(f32_ty);
+  SmallVector<Value> bPtrs(bNumPtr);
+  for (int i = 0; i < bNumPtr; ++i)
+    bPtrs[i] = gep(f32PtrTy, bSmem.base, bOff[i]);
+
+  int N = bShape[bOrder[0]];
+  ValueTable hbs;
+
+  int nShapePerCTA = getShapePerCTAForMN(dLayout, false /*isM*/);
+  int nSizePerThread = getsizePerThreadForMN(dLayout, false /*isM*/);
+
+  for (unsigned k = 0; k < NK; ++k)
+    for (unsigned n = 0; n < N; n += nShapePerCTA)
+      for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
+        Value pb = gep(f32PtrTy, bPtrs[0],
+                       i32_val((n + nn) * strideBN + k * strideBK));
+        Value vb = load(pb);
+        hbs[{n + nn, k}] = vb;
+      }
+
+  return getStructFromValueTable(hbs, rewriter, loc);
+}
+
+DotOpFMAConversionHelper::ValueTable
+DotOpFMAConversionHelper::getValueTableFromStruct(
+    Value val, int K, int n0, int shapePerCTA, int sizePerThread,
+    ConversionPatternRewriter &rewriter, Location loc) const {
+  ValueTable res;
+  auto elems = ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct(
+      loc, val, rewriter);
+  int id = 0;
+  std::set<std::pair<int, int>> keys; // ordered
+  for (unsigned k = 0; k < K; ++k) {
+    for (unsigned m = 0; m < n0; m += shapePerCTA)
+      for (unsigned mm = 0; mm < sizePerThread; ++mm) {
+        keys.insert({m + mm, k});
+      }
+  }
+
+  for (auto &key : llvm::enumerate(keys)) {
+    res[key.value()] = elems[key.index()];
+  }
+
+  return res;
+}
+
+SmallVector<Value> DotOpFMAConversionHelper::getThreadIds(
+    Value threadId, ArrayRef<unsigned> shapePerCTA, ArrayRef<unsigned> order,
+    ConversionPatternRewriter &rewriter, Location loc) const {
+  int dim = order.size();
+  SmallVector<Value> threadIds(dim);
+  for (unsigned k = 0; k < dim - 1; k++) {
+    Value dimK = i32_val(shapePerCTA[order[k]]);
+    Value rem = urem(threadId, dimK);
+    threadId = udiv(threadId, dimK);
+    threadIds[order[k]] = rem;
+  }
+  Value dimK = i32_val(shapePerCTA[order[dim - 1]]);
+  threadIds[order[dim - 1]] = urem(threadId, dimK);
+  return threadIds;
+}
+
 LogicalResult
 DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
@@ -4948,120 +5229,68 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
   auto bShape = bTensorTy.getShape();
   auto cShape = cTensorTy.getShape();
 
-  auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
-  auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
-  auto cLayout = cTensorTy.getEncoding().cast<BlockedEncodingAttr>();
-  auto dLayout = dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
-
-  auto aOrder = aLayout.getOrder();
-  auto bOrder = bLayout.getOrder();
-
+  ValueTable has, hbs;
+  int mShapePerCTA{-1}, nShapePerCTA{-1};
+  int mSizePerThread{-1}, nSizePerThread{-1};
+  ArrayRef<unsigned> aOrder, bOrder;
+  Value llA, llB;
+  BlockedEncodingAttr dLayout =
+      dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
   auto order = dLayout.getOrder();
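+  // Unpack the accumulator ($c) elements; the loop at the bottom overwrites
+  // them in place, one llvm.intr.fmuladd per output element per k step.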
+  auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
 
-  bool isARow = aOrder[0] == 1;
-  bool isBRow = bOrder[0] == 1;
+  DotOpFMAConversionHelper helper(dLayout);
+  if (auto aDotOpLayout =
+          aTensorTy.getEncoding()
+              .dyn_cast<DotOperandEncodingAttr>()) { // get input from
+                                                     // convert_layout
+    auto bDotOpLayout =
+        bTensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
+    auto aLayout = aDotOpLayout.getParent().cast<BlockedEncodingAttr>();
+    auto bLayout = bDotOpLayout.getParent().cast<BlockedEncodingAttr>();
 
-  int strideAM = isARow ? aShape[1] : 1;
-  int strideAK = isARow ? 1 : aShape[0];
-  int strideBN = isBRow ? 1 : bShape[0];
-  int strideBK = isBRow ? bShape[1] : 1;
-  int strideA0 = isARow ? strideAK : strideAM;
-  int strideA1 = isARow ? strideAM : strideAK;
-  int strideB0 = isBRow ? strideBN : strideBK;
-  int strideB1 = isBRow ? strideBK : strideBN;
-  int lda = isARow ? strideAM : strideAK;
-  int ldb = isBRow ? strideBK : strideBN;
-  int aPerPhase = aLayout.getPerPhase();
-  int aMaxPhase = aLayout.getMaxPhase();
-  int bPerPhase = bLayout.getPerPhase();
-  int bMaxPhase = bLayout.getMaxPhase();
-  int aNumPtr = 8;
-  int bNumPtr = 8;
-  int NK = aShape[1];
-
-  auto shapePerCTA = getShapePerCTA(dLayout);
+    assert(bLayout);
+    llA = adaptor.a();
+    llB = adaptor.b();
+  } else if (auto aLayout =
+                 aTensorTy.getEncoding()
+                     .dyn_cast<SharedEncodingAttr>()) { // load input from smem
+    auto bLayout = bTensorTy.getEncoding().dyn_cast<SharedEncodingAttr>();
+    assert(bLayout);
+    Value thread = getThreadId(rewriter, loc);
+    llA = helper.loadA(A, adaptor.a(), dLayout, thread, loc, rewriter);
+    llB = helper.loadB(B, adaptor.b(), dLayout, thread, loc, rewriter);
+  }
 
   auto sizePerThread = getSizePerThread(dLayout);
+  auto shapePerCTA = getShapePerCTA(dLayout);
 
-  Value _0 = i32_val(0);
+  int K = aShape[1];
+  int M = aShape[0];
+  int N = bShape[1];
 
-  Value mContig = i32_val(sizePerThread[order[1]]);
-  Value nContig = i32_val(sizePerThread[order[0]]);
+  mShapePerCTA = order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
+  mSizePerThread =
+      order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
+  nShapePerCTA = order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
+  nSizePerThread =
+      order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
 
-  // threadId in blocked layout
-  SmallVector<Value> threadIds;
-  {
-    int dim = cShape.size();
-    threadIds.resize(dim);
-    for (unsigned k = 0; k < dim - 1; k++) {
-      Value dimK = i32_val(shapePerCTA[order[k]]);
-      Value rem = urem(threadId, dimK);
-      threadId = udiv(threadId, dimK);
-      threadIds[order[k]] = rem;
-    }
-    Value dimK = i32_val(shapePerCTA[order[dim - 1]]);
-    threadIds[order[dim - 1]] = urem(threadId, dimK);
-  }
+  has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA, mSizePerThread,
+                                       rewriter, loc);
+  hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA, nSizePerThread,
+                                       rewriter, loc);
 
-  Value threadIdM = threadIds[0];
-  Value threadIdN = threadIds[1];
-
-  Value offA0 = isARow ? _0 : mul(threadIdM, mContig);
-  Value offA1 = isARow ? mul(threadIdM, mContig) : _0;
-  SmallVector<Value> aOff(aNumPtr);
-  for (int i = 0; i < aNumPtr; ++i) {
-    aOff[i] = add(mul(offA0, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
-  }
-
-  Value offB0 = isBRow ? mul(threadIdN, nContig) : _0;
-  Value offB1 = isBRow ? _0 : mul(threadIdN, nContig);
-  SmallVector<Value> bOff(bNumPtr);
-  for (int i = 0; i < bNumPtr; ++i) {
-    bOff[i] = add(mul(offB0, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
-  }
-
-  auto aSmem = getSharedMemoryObjectFromStruct(loc, adaptor.a(), rewriter);
-  auto bSmem = getSharedMemoryObjectFromStruct(loc, adaptor.b(), rewriter);
-
-  Type f32PtrTy = ptr_ty(f32_ty);
-  SmallVector<Value> aPtrs(aNumPtr);
-  for (int i = 0; i < aNumPtr; ++i)
-    aPtrs[i] = gep(f32PtrTy, aSmem.base, aOff[i]);
-
-  SmallVector<Value> bPtrs(bNumPtr);
-  for (int i = 0; i < bNumPtr; ++i)
-    bPtrs[i] = gep(f32PtrTy, bSmem.base, bOff[i]);
-
-  ValueTable has, hbs;
-  auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
   SmallVector<Value> ret = cc;
-  // is this compatible with blocked layout?
-
-  for (unsigned k = 0; k < NK; k++) {
+
+  for (unsigned k = 0; k < K; k++) {
     int z = 0;
-    for (unsigned i = 0; i < cShape[order[1]]; i += shapePerCTA[order[1]])
-      for (unsigned j = 0; j < cShape[order[0]]; j += shapePerCTA[order[0]])
-        for (unsigned ii = 0; ii < sizePerThread[order[1]]; ++ii)
-          for (unsigned jj = 0; jj < sizePerThread[order[0]]; ++jj) {
-            unsigned m = order[0] == 1 ? i : j;
-            unsigned n = order[0] == 1 ? j : i;
-            unsigned mm = order[0] == 1 ? ii : jj;
-            unsigned nn = order[0] == 1 ? jj : ii;
-            if (!has.count({m + mm, k})) {
-              Value pa = gep(f32PtrTy, aPtrs[0],
-                             i32_val((m + mm) * strideAM + k * strideAK));
-              Value va = load(pa);
-              has[{m + mm, k}] = va;
-            }
-            if (!hbs.count({n + nn, k})) {
-              Value pb = gep(f32PtrTy, bPtrs[0],
-                             i32_val((n + nn) * strideBN + k * strideBK));
-              Value vb = load(pb);
-              hbs[{n + nn, k}] = vb;
-            }
-
+    for (unsigned m = 0; m < M; m += mShapePerCTA)
+      for (unsigned n = 0; n < N; n += nShapePerCTA)
+        for (unsigned mm = 0; mm < mSizePerThread; ++mm)
+          for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
             ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
                                                       hbs[{n + nn, k}], ret[z]);
+            ++z;
           }
   }
@@ -5138,6 +5367,13 @@ public:
     auto ctx = type.getContext();
     Attribute layout = type.getEncoding();
    auto shape = type.getShape();
+
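+    // The FMA dot-operand path below keeps one scalar f32 per element,
+    // whereas the MMA dot-operand paths pack elements into 2- or 4-wide
+    // vectors.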
+    // TODO[Keren, Superjomn]: fix it; allowTF32 is not accessible here, for
+    // it is bound to an Op instance.
+    bool allowTF32 = false;
+    bool isFMADot = type.getElementType().isF32() && !allowTF32 &&
+                    layout.dyn_cast_or_null<DotOperandEncodingAttr>();
+
     if (layout &&
         (layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
          layout.isa<MmaEncodingAttr>())) {
@@ -5158,65 +5394,55 @@ public:
         types.push_back(IntegerType::get(ctx, 32));
       }
       return LLVM::LLVMStructType::getLiteral(ctx, types);
-    } else if (auto mmaLayout = layout.dyn_cast_or_null<MmaEncodingAttr>()) {
-      if (mmaLayout.getVersion() == 2) {
-        auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type);
-        size_t fcSize = 4 * repM * repN;
-        return LLVM::LLVMStructType::getLiteral(
-            ctx, SmallVector<Type>(fcSize, convertType(type.getElementType())));
-      }
-
-      if (mmaLayout.getVersion() == 1) {
-        DotOpMmaV1ConversionHelper helper(mmaLayout);
-        int repM = helper.getRepM(shape[0]);
-        int repN = helper.getRepN(shape[1]);
-        int elems = 8 * repM * repN;
-        return LLVM::LLVMStructType::getLiteral(
-            ctx, SmallVector<Type>(elems, convertType(type.getElementType())));
-      }
-
-      llvm::errs()
-          << "Unexpected mma layout detected in TritonToLLVMTypeConverter";
-      return llvm::None;
-
-    } else if (auto dot_op_layout =
+    } else if (auto dotOpLayout =
                    layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
-      auto mmaLayout = dot_op_layout.getParent().cast<MmaEncodingAttr>();
-      auto wpt = mmaLayout.getWarpsPerCTA();
-      Type elemTy = convertType(type.getElementType());
-      auto vecSize = 1;
-      if (elemTy.getIntOrFloatBitWidth() == 16) {
-        vecSize = 2;
-      } else if (elemTy.getIntOrFloatBitWidth() == 8) {
-        vecSize = 4;
-      } else {
-        assert(false && "Unsupported element type");
-      }
-      Type vecTy = vec_ty(elemTy, vecSize);
-      if (mmaLayout.getVersion() == 2) {
-        if (dot_op_layout.getOpIdx() == 0) { // $a
-          int elems =
-              MMA16816ConversionHelper::getANumElemsPerThread(type, wpt);
-          return LLVM::LLVMStructType::getLiteral(
-              ctx, SmallVector<Type>(elems, vecTy));
-        }
-        if (dot_op_layout.getOpIdx() == 1) { // $b
-          int elems =
-              MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt);
-          return struct_ty(SmallVector<Type>(elems, vecTy));
-        }
-      }
+      if (isFMADot) { // for parent is blocked layout
+        int numElemsPerThread =
+            DotOpFMAConversionHelper::getNumElemsPerThread(shape, dotOpLayout);
 
-      if (mmaLayout.getVersion() == 1) {
-        DotOpMmaV1ConversionHelper helper(mmaLayout);
+        return LLVM::LLVMStructType::getLiteral(
+            ctx, SmallVector<Type>(numElemsPerThread, type::f32Ty(ctx)));
 
-        if (dot_op_layout.getOpIdx() == 0) { // $a
-          int elems = helper.numElemsPerThreadA(type);
-          return struct_ty(SmallVector<Type>(elems, vecTy));
+      } else { // for parent is MMA layout
+        auto mmaLayout = dotOpLayout.getParent().cast<MmaEncodingAttr>();
+        auto wpt = mmaLayout.getWarpsPerCTA();
+        Type elemTy = convertType(type.getElementType());
+        auto vecSize = 1;
+        if (elemTy.getIntOrFloatBitWidth() == 16) {
+          vecSize = 2;
+        } else if (elemTy.getIntOrFloatBitWidth() == 8) {
+          vecSize = 4;
+        } else {
+          assert(false && "Unsupported element type");
        }
-        if (dot_op_layout.getOpIdx() == 1) { // $b
-          int elems = helper.numElemsPerThreadB(type);
-          return struct_ty(SmallVector<Type>(elems, vecTy));
+        Type vecTy = vec_ty(elemTy, vecSize);
+        if (mmaLayout.getVersion() == 2) {
+          if (dotOpLayout.getOpIdx() == 0) { // $a
+            int elems =
+                MMA16816ConversionHelper::getANumElemsPerThread(type, wpt);
+            return LLVM::LLVMStructType::getLiteral(
+                ctx, SmallVector<Type>(elems, vecTy));
+          }
+          if (dotOpLayout.getOpIdx() == 1) { // $b
+            int elems =
+                MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt);
+            return struct_ty(SmallVector<Type>(elems, vecTy));
+          }
+        }
+
+        if (mmaLayout.getVersion() == 1) {
+          DotOpMmaV1ConversionHelper helper(mmaLayout);
+
+          if (dotOpLayout.getOpIdx() == 0) { // $a
+            int elems = helper.numElemsPerThreadA(type);
+            Type x2Ty = vec_ty(elemTy, 2);
+            return struct_ty(SmallVector<Type>(elems, x2Ty));
+          }
+          if (dotOpLayout.getOpIdx() == 1) { // $b
+            int elems = helper.numElemsPerThreadB(type);
+            Type x2Ty = vec_ty(elemTy, 2);
+            return struct_ty(SmallVector<Type>(elems, x2Ty));
+          }
        }
      }
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 2973422c5..d603e823c 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -64,8 +64,7 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
 }
 
 unsigned getElemsPerThread(Type type) {
-  if (type.isIntOrIndexOrFloat() ||
-      type.isa<triton::Float8Type>() ||
+  if (type.isIntOrIndexOrFloat() || type.isa<triton::Float8Type>() ||
       type.isa<triton::PointerType>())
     return 1;
   auto tensorType = type.cast<RankedTensorType>();
@@ -372,6 +371,9 @@ unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
 
 unsigned
 DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
+  if (auto blockedLayout = getParent().dyn_cast<BlockedEncodingAttr>()) {
+    return blockedLayout.getElemsPerThread(shape);
+  }
   assert(0 && "DotOperandEncodingAttr::getElemsPerThread not implemented");
   return 0;
 }
diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
index 2bac2c83f..01738db1c 100644
--- a/python/tests/test_gemm.py
+++ b/python/tests/test_gemm.py
@@ -171,65 +171,64 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
     assert_close(c, golden, rtol=max(1e-4, 1.5 * golden_rel_err),
                  atol=max(1e-4, 1.5 * golden_abs_err), check_dtype=False)
 
-# XXX(Keren): Temporarily disable this test until we have shared -> dot conversion implemented
-#@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
-#    [32, 32, 16, 4, 32, 32, 16],
-#    [32, 16, 16, 4, 32, 32, 16],
-#    [128, 8, 8, 4, 32, 32, 16],
-#    [127, 41, 43, 4, 32, 32, 16],
-#])
-#def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
-#    @triton.jit
-#    def matmul_kernel(
-#        a_ptr, b_ptr, c_ptr,
-#        M, N, K,
-#        stride_am, stride_ak,
-#        stride_bk, stride_bn,
-#        stride_cm, stride_cn,
-#        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
-#    ):
-#        pid = tl.program_id(axis=0)
-#        # num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
-#        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
-#        pid_m = pid // num_pid_n
-#        pid_n = pid % num_pid_n
-#
-#        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-#        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-#        offs_k = tl.arange(0, BLOCK_SIZE_K)
-#        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
-#        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
-#
-#        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-#        for k in range(0, K, BLOCK_SIZE_K):
-#            a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K)
-#            b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N)
-#            a = tl.load(a_ptrs, a_mask)
-#            b = tl.load(b_ptrs, b_mask)
-#            # NOTE the allow_tf32 should be false to force the dot op to do fmadot lowering
-#            accumulator += tl.dot(a, b, allow_tf32=False)
-#            a_ptrs += BLOCK_SIZE_K * stride_ak
-#            b_ptrs += BLOCK_SIZE_K * stride_bk
-#            offs_k += BLOCK_SIZE_K
-#
-#        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-#        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-#        c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
-#        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
-#        tl.store(c_ptrs, accumulator, c_mask)
-#
-#    a = torch.randn((M, K), device='cuda', dtype=torch.float32)
-#    b = torch.randn((K, N), device='cuda', dtype=torch.float32)
-#    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
-#
-#    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
-#    matmul_kernel[grid](a, b, c,
-#                        M, N, K,
-#                        stride_am=a.stride(0), stride_ak=a.stride(1),
-#                        stride_bk=b.stride(0), stride_bn=b.stride(1),
-#                        stride_cm=c.stride(0), stride_cn=c.stride(1),
-#                        BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
-#
-#    golden = torch.matmul(a, b)
-#    torch.testing.assert_close(c, golden)
-#
+@pytest.mark.parametrize('M,N,K,num_warps,block_M,block_N,block_K', [
+    [32, 32, 16, 4, 32, 32, 16],
+    [32, 16, 16, 4, 32, 32, 16],
+    [128, 8, 8, 4, 32, 32, 16],
+    # TODO[Superjomn]: fix it later
+    # [127, 41, 43, 4, 32, 32, 16],
+])
+def test_gemm_fmadot(M, N, K, num_warps, block_M, block_N, block_K):
+    @triton.jit
+    def matmul_kernel(
+        a_ptr, b_ptr, c_ptr,
+        M, N, K,
+        stride_am, stride_ak,
+        stride_bk, stride_bn,
+        stride_cm, stride_cn,
+        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+    ):
+        pid = tl.program_id(axis=0)
+        # num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+        pid_m = pid // num_pid_n
+        pid_n = pid % num_pid_n
+
+        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        offs_k = tl.arange(0, BLOCK_SIZE_K)
+        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+        for k in range(0, K, BLOCK_SIZE_K):
+            a_mask = (offs_am[:, None] < M) & (offs_k[None, :] < K)
+            b_mask = (offs_k[:, None] < K) & (offs_bn[None, :] < N)
+            a = tl.load(a_ptrs, a_mask)
+            b = tl.load(b_ptrs, b_mask)
+            # NOTE: allow_tf32 must be False to force FMA lowering of the dot op
+            accumulator += tl.dot(a, b, allow_tf32=False)
+            a_ptrs += BLOCK_SIZE_K * stride_ak
+            b_ptrs += BLOCK_SIZE_K * stride_bk
+            offs_k += BLOCK_SIZE_K
+
+        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
+        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+        tl.store(c_ptrs, accumulator, c_mask)
+
+    a = torch.randn((M, K), device='cuda', dtype=torch.float32)
+    b = torch.randn((K, N), device='cuda', dtype=torch.float32)
+    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
+
+    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
+    matmul_kernel[grid](a, b, c,
+                        M, N, K,
+                        stride_am=a.stride(0), stride_ak=a.stride(1),
+                        stride_bk=b.stride(0), stride_bn=b.stride(1),
+                        stride_cm=c.stride(0), stride_cn=c.stride(1),
+                        BLOCK_SIZE_M=block_M, BLOCK_SIZE_N=block_N, BLOCK_SIZE_K=block_K)
+
+    golden = torch.matmul(a, b)
+    torch.testing.assert_close(c, golden)
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 77f8db1db..e4095671b 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -820,18 +820,20 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
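+// For FMA lowering, the dot operands take a blocked parent layout rather
+// than an MMA layout.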
-#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
-#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
-#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#blocked}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32}, %a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
-    // We are going to completely depracate using shared layout for operands of dot
-    //%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
-    //%28 = tt.dot %a, %b, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #shared> * tensor<16x32xf32, #shared> -> tensor<32x32xf32, #blocked>
-    //%30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
-    //%36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
-    //tt.store %36, %28 : tensor<32x32xf32, #blocked>
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
+    // CHECK: llvm.intr.fmuladd
+    %a_mat = triton_gpu.convert_layout %a : (tensor<32x16xf32, #shared>) -> tensor<32x16xf32, #dot_operand_a>
+    %b_mat = triton_gpu.convert_layout %b : (tensor<16x32xf32, #shared>) -> tensor<16x32xf32, #dot_operand_b>
+
+    %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked>
+    %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x1x!tt.ptr<f32>, #blocked>
+    %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr<f32>, #blocked>) -> tensor<32x32x!tt.ptr<f32>, #blocked>
+    tt.store %36, %28 : tensor<32x32xf32, #blocked>
     return
   }
 }
@@ -846,4 +848,4 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     %0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
     return
   }
-}
\ No newline at end of file
+}