From 922155f1d24f10cce4030cc5cd2c57e906cb585f Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Fri, 23 Sep 2022 11:43:54 +0800 Subject: [PATCH] [BACKEND] add dot conversion (mma version=2) (#672) LLVM Conversion for Dot op. Due to the lack of `convert_layout`, currently, the dot only supports the following combination of operands - `$a` in shared layout - `$b` in shared layout - `$c` in MMA layout (but only splat-like, leaving the generic cases to `convert_layout`). This PR focuses on the `mma.16816`-related logic, leaving the other cases to follow-up PRs. Co-authored-by: Philippe Tillet --- include/triton/Conversion/MLIRTypes.h | 11 +- lib/Analysis/Allocation.cpp | 14 +- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 1075 +++++++++++++++-- lib/Dialect/TritonGPU/IR/Dialect.cpp | 6 +- lib/Target/LLVMIR/LLVMIRTranslation.cpp | 2 +- test/Analysis/test-allocation.mlir | 10 +- test/Conversion/tritongpu_to_llvm.mlir | 31 +- 7 files changed, 1033 insertions(+), 116 deletions(-) diff --git a/include/triton/Conversion/MLIRTypes.h b/include/triton/Conversion/MLIRTypes.h index 78c1bea33..1daf34358 100644 --- a/include/triton/Conversion/MLIRTypes.h +++ b/include/triton/Conversion/MLIRTypes.h @@ -10,14 +10,10 @@ namespace triton { namespace type { // Integer types -Type i32Ty(MLIRContext *ctx) { - return IntegerType::get(ctx, 32, IntegerType::Signed); -} -Type i8Ty(MLIRContext *ctx) { - return IntegerType::get(ctx, 8, IntegerType::Signed); -} +Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); } +Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); } Type u32Ty(MLIRContext *ctx) { - return IntegerType::get(ctx, 32, IntegerType::Signless); + return IntegerType::get(ctx, 32, IntegerType::Unsigned); } Type u1Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 1, IntegerType::Unsigned); @@ -27,6 +23,7 @@ Type u1Ty(MLIRContext *ctx) { Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); } Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); } Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); } +Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); } static bool isFloat(Type type) { return type.isF32() || type.isF64() || type.isF16() || type.isF128(); diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 6afa7ea1a..0ea29afdc 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -43,6 +43,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, return 0; } }; + // blocked -> blocked if (srcLayout.isa() && dstLayout.isa()) { auto srcBlockedLayout = srcLayout.cast(); @@ -65,6 +66,14 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, } paddedRepShape[outOrd[0]] += pad; } + // blocked -> shared + if (srcLayout.isa() && + dstLayout.isa()) { + auto sharedLayout = dstLayout.cast(); + for (int v : dstTy.getShape()) + paddedRepShape.push_back(v); + } + return paddedRepShape; } @@ -131,9 +140,8 @@ private: auto dstTy = cvtLayout.result().getType().cast(); auto srcEncoding = srcTy.getEncoding(); auto dstEncoding = dstTy.getEncoding(); - if (srcEncoding.isa() || - dstEncoding.isa()) { - // Only blocked -> blocked conversion requires for scratch allocation + if (srcEncoding.isa()) { + // only block->block and block->shared is supported now return; } // ConvertLayoutOp with both input/output non-shared_layout diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index ef1a0c2ed..eec722935 100644 
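As a quick sanity check on the numbers that show up in the updated test-allocation.mlir expectations further down: staging a 128x32 f16 operand in shared memory for tt.dot takes 128 * 32 * 2 = 8192 bytes, and the new blocked -> shared branch above simply reports the unpadded destination shape (the conversion itself needs no separate scratch buffer, hence the `scratch ... size = 0` lines in the test). A minimal standalone sketch of that arithmetic, not the Allocation API itself:

#include <cstdint>
#include <cstdio>
#include <vector>

// For blocked -> shared, the reported shape is just the destination tensor
// shape, so the byte count is product(shape) * sizeof(element).
int64_t sharedBytesForDst(const std::vector<int64_t> &dstShape, int elemBytes) {
  int64_t elems = 1;
  for (int64_t d : dstShape)
    elems *= d;
  return elems * elemBytes;
}

int main() {
  // 128x32 f16 operand staged into shared memory for tt.dot:
  std::printf("%lld\n", (long long)sharedBytesForDst({128, 32}, 2)); // 8192
}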
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -39,6 +39,38 @@ namespace LLVM { static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; } +namespace { + +// Create a 32-bit integer constant. +Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) { + auto i32ty = rewriter.getIntegerType(32); + return rewriter.create(loc, i32ty, + IntegerAttr::get(i32ty, v)); +} + +// Add other specification if needed... + +} // namespace + +#define udiv(...) rewriter.create(loc, __VA_ARGS__) +#define urem(...) rewriter.create(loc, __VA_ARGS__) +#define add(...) rewriter.create(loc, __VA_ARGS__) +#define mul(...) rewriter.create(loc, __VA_ARGS__) +#define xor_(...) rewriter.create(loc, __VA_ARGS__) +#define bit_cast(...) rewriter.create(loc, __VA_ARGS__) +#define gep(...) rewriter.create(loc, __VA_ARGS__) +#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__) +#define insert_val(...) rewriter.create(loc, __VA_ARGS__) +#define extract_val(...) rewriter.create(loc, __VA_ARGS__) +#define insert_element(...) \ + rewriter.create(loc, __VA_ARGS__) +#define extract_element(...) \ + rewriter.create(loc, __VA_ARGS__) +#define address_of(...) rewriter.create(loc, __VA_ARGS__) + +#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__) +#define i32_ty() rewriter.getIntegerType(32) + } // namespace LLVM } // namespace mlir @@ -180,9 +212,8 @@ struct FuncOpConversion : public FuncOpConversionBase { // Set an attribute for maxntidx, it could be used in latter LLVM codegen // for `nvvm.annotation` metadata. - newFuncOp->setAttr( - NVVMMetadataField::MaxNTid, - rewriter.getIntegerAttr(type::i32Ty(ctx), 32 * NumWarps)); + newFuncOp->setAttr(NVVMMetadataField::MaxNTid, + rewriter.getIntegerAttr(i32_ty(), 32 * NumWarps)); rewriter.eraseOp(funcOp); return success(); @@ -232,9 +263,8 @@ Value getStructFromElements(Location loc, ValueRange resultVals, Type structType) { Value llvmStruct = rewriter.create(loc, structType); for (auto v : llvm::enumerate(resultVals)) { - llvmStruct = rewriter.create( - loc, structType, llvmStruct, v.value(), - rewriter.getI64ArrayAttr(v.index())); + llvmStruct = insert_val(structType, llvmStruct, v.value(), + rewriter.getI64ArrayAttr(v.index())); } return llvmStruct; } @@ -286,8 +316,7 @@ struct ConvertTritonGPUOpToLLVMPatternBase { for (unsigned i = 0; i < elems; ++i) { Type type = llvmStruct.getType().cast().getBody()[i]; - results[i] = rewriter.create( - loc, type, llvmStruct, rewriter.getI64ArrayAttr(i)); + results[i] = extract_val(type, llvmStruct, rewriter.getI64ArrayAttr(i)); } return results; } @@ -304,6 +333,16 @@ public: PatternBenefit benefit = 1) : ConvertOpToLLVMPattern(typeConverter, benefit) {} + Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const { + auto llvmIndexTy = this->getTypeConverter()->getIndexType(); + auto cast = rewriter.create( + loc, TypeRange{llvmIndexTy}, + ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>( + loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)}); + Value threadId = cast.getResult(0); + return threadId; + } + SmallVector delinearize(ConversionPatternRewriter &rewriter, Location loc, Value linear, ArrayRef shape, @@ -336,9 +375,8 @@ public: Value dimSize = createIndexAttrConstant( rewriter, loc, this->getTypeConverter()->getIndexType(), en.value()); - multiDim[rank - 1 - en.index()] = - rewriter.create(loc, remained, dimSize); - remained = rewriter.create(loc, remained, dimSize); + multiDim[rank 
- 1 - en.index()] = urem(remained, dimSize); + remained = udiv(remained, dimSize); } multiDim[0] = remained; } @@ -356,27 +394,22 @@ public: Value dimSize = createIndexAttrConstant( rewriter, loc, this->getTypeConverter()->getIndexType(), std::get<1>(z)); - linear = rewriter.create( - loc, rewriter.create(loc, linear, dimSize), - std::get<0>(z)); + linear = add(mul(linear, dimSize), std::get<0>(z)); } } return linear; } SmallVector - emitBaseIndexForBlockedLayout(Location loc, ConversionPatternRewriter &b, + emitBaseIndexForBlockedLayout(Location loc, + ConversionPatternRewriter &rewriter, const BlockedEncodingAttr &blocked_layout, ArrayRef shape) const { auto llvmIndexTy = this->getTypeConverter()->getIndexType(); - auto cast = b.create( - loc, TypeRange{llvmIndexTy}, - ValueRange{b.create<::mlir::gpu::ThreadIdOp>( - loc, b.getIndexType(), ::mlir::gpu::Dimension::x)}); - Value threadId = cast.getResult(0); - Value warpSize = createIndexAttrConstant(b, loc, llvmIndexTy, 32); - Value laneId = b.create(loc, threadId, warpSize); - Value warpId = b.create(loc, threadId, warpSize); + Value threadId = getThreadId(rewriter, loc); + Value warpSize = createIndexAttrConstant(rewriter, loc, llvmIndexTy, 32); + Value laneId = urem(threadId, warpSize); + Value warpId = udiv(threadId, warpSize); auto sizePerThread = blocked_layout.getSizePerThread(); auto threadsPerWarp = blocked_layout.getThreadsPerWarp(); auto warpsPerCTA = blocked_layout.getWarpsPerCTA(); @@ -385,9 +418,9 @@ public: // step 1, delinearize threadId to get the base index SmallVector multiDimWarpId = - delinearize(b, loc, warpId, warpsPerCTA, order); + delinearize(rewriter, loc, warpId, warpsPerCTA, order); SmallVector multiDimThreadId = - delinearize(b, loc, laneId, threadsPerWarp, order); + delinearize(rewriter, loc, laneId, threadsPerWarp, order); SmallVector multiDimBase(rank); for (unsigned k = 0; k < rank; ++k) { // Wrap around multiDimWarpId/multiDimThreadId incase @@ -395,24 +428,22 @@ public: unsigned maxWarps = ceil(shape[k], sizePerThread[k] * threadsPerWarp[k]); unsigned maxThreads = ceil(shape[k], sizePerThread[k]); - multiDimWarpId[k] = b.create( - loc, multiDimWarpId[k], - createIndexAttrConstant(b, loc, llvmIndexTy, maxWarps)); - multiDimThreadId[k] = b.create( - loc, multiDimThreadId[k], - createIndexAttrConstant(b, loc, llvmIndexTy, maxThreads)); + multiDimWarpId[k] = + urem(multiDimWarpId[k], + createIndexAttrConstant(rewriter, loc, llvmIndexTy, maxWarps)); + multiDimThreadId[k] = + urem(multiDimThreadId[k], + createIndexAttrConstant(rewriter, loc, llvmIndexTy, maxThreads)); // multiDimBase[k] = (multiDimThreadId[k] + // multiDimWarpId[k] * threadsPerWarp[k]) * // sizePerThread[k]; - Value threadsPerWarpK = - createIndexAttrConstant(b, loc, llvmIndexTy, threadsPerWarp[k]); + Value threadsPerWarpK = createIndexAttrConstant( + rewriter, loc, llvmIndexTy, threadsPerWarp[k]); Value sizePerThreadK = - createIndexAttrConstant(b, loc, llvmIndexTy, sizePerThread[k]); - multiDimBase[k] = b.create( - loc, sizePerThreadK, - b.create( - loc, multiDimThreadId[k], - b.create(loc, multiDimWarpId[k], threadsPerWarpK))); + createIndexAttrConstant(rewriter, loc, llvmIndexTy, sizePerThread[k]); + multiDimBase[k] = + mul(sizePerThreadK, add(multiDimThreadId[k], + mul(multiDimWarpId[k], threadsPerWarpK))); } return multiDimBase; } @@ -433,7 +464,7 @@ public: } SmallVector> - emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &b, + emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter, const 
SliceEncodingAttr &sliceLayout, ArrayRef shape) const { auto parent = sliceLayout.getParent(); @@ -450,8 +481,8 @@ public: paddedShape[d] = shape[d - 1]; } } - auto paddedIndices = - emitIndicesForBlockedLayout(loc, b, blockedParent, paddedShape); + auto paddedIndices = emitIndicesForBlockedLayout( + loc, rewriter, blockedParent, paddedShape); unsigned numIndices = paddedIndices.size(); SmallVector> resultIndices(numIndices); for (unsigned i = 0; i < numIndices; ++i) { @@ -480,7 +511,7 @@ public: // be eliminated in the consequent MLIR/LLVM optimization. We might // implement a indiceCache if necessary. SmallVector> - emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &b, + emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter, const BlockedEncodingAttr &blockedLayout, ArrayRef shape) const { auto llvmIndexTy = this->getTypeConverter()->getIndexType(); @@ -495,7 +526,7 @@ public: // step 1, delinearize threadId to get the base index auto multiDimBase = - emitBaseIndexForBlockedLayout(loc, b, blockedLayout, shape); + emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape); // step 2, get offset of each element unsigned elemsPerThread = 1; @@ -544,10 +575,10 @@ public: multiDimNanoTileId[k] * (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) + multiElemsInNanoTileId[k]; - multiDimIdx[n][k] = b.create( - loc, multiDimBase[k], - createIndexAttrConstant(b, loc, llvmIndexTy, - offset[k][reorderedMultiDimId])); + multiDimIdx[n][k] = + add(multiDimBase[k], + createIndexAttrConstant(rewriter, loc, llvmIndexTy, + offset[k][reorderedMultiDimId])); } } @@ -564,7 +595,7 @@ public: size_t offset = allocation->getOffset(bufferId); auto llvmIndexTy = this->getTypeConverter()->getIndexType(); Value offVal = createIndexAttrConstant(rewriter, loc, llvmIndexTy, offset); - Value base = rewriter.create(loc, ptrTy, smem, offVal); + Value base = gep(ptrTy, smem, offVal); return base; } }; @@ -582,7 +613,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, auto tensorTy = resType.cast(); auto layout = tensorTy.getEncoding(); auto srcType = typeConverter->convertType(elemType); - auto llSrc = rewriter.create(loc, srcType, constVal); + auto llSrc = bit_cast(srcType, constVal); size_t numElemsPerThread = getElemsPerThread(layout, tensorTy.getShape()); llvm::SmallVector elems(numElemsPerThread, llSrc); llvm::SmallVector elemTypes(elems.size(), srcType); @@ -819,15 +850,15 @@ struct StoreOpConversion Value elem = valueElems[elemOffset]; if (elem.getType().isInteger(1)) elem = rewriter.create(loc, type::i8Ty(ctx), elem); - elem = rewriter.create(loc, valueElemTy, elem); + elem = bit_cast(valueElemTy, elem); - llWord = rewriter.create( - loc, wordTy, llWord, elem, - rewriter.create( - loc, type::u32Ty(ctx), - IntegerAttr::get(type::u32Ty(ctx), elemIdx))); + Type u32Ty = typeConverter->convertType(type::u32Ty(ctx)); + llWord = + insert_element(wordTy, llWord, elem, + rewriter.create( + loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx))); } - llWord = rewriter.create(loc, valArgTy, llWord); + llWord = bit_cast(valArgTy, llWord); std::string constraint = (width == 64) ? "l" : ((width == 32) ? 
"r" : "c"); asmArgList->listAppend(ptxBuilder.newOperand(llWord, constraint)); @@ -1023,8 +1054,7 @@ struct MakeRangeOpConversion SmallVector retVals(elems); for (auto multiDim : llvm::enumerate(idxs)) { assert(multiDim.value().size() == 1); - retVals[multiDim.index()] = - rewriter.create(loc, multiDim.value()[0], start); + retVals[multiDim.index()] = add(multiDim.value()[0], start); } SmallVector types(elems, elemTy); Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types); @@ -1183,11 +1213,9 @@ struct LoadOpConversion Value falseVal = otherElems[vecStart + ii * size + s]; Value sVal = createIndexAttrConstant( rewriter, loc, this->getTypeConverter()->getIndexType(), s); - v = rewriter.create(loc, vecTy, v, falseVal, - sVal); + v = insert_element(vecTy, v, falseVal, sVal); } - v = rewriter.create( - loc, IntegerType::get(getContext(), width), v); + v = bit_cast(IntegerType::get(getContext(), width), v); PTXInstr::Operand *opr{}; if (otherIsSplatConstInt) { @@ -1228,14 +1256,13 @@ struct LoadOpConversion for (unsigned int ii = 0; ii < nWords; ii++) { Value curr; if (retTy.isa()) { - curr = rewriter.create( - loc, IntegerType::get(getContext(), width), ret, - rewriter.getI64ArrayAttr(ii)); + curr = extract_val(IntegerType::get(getContext(), width), ret, + rewriter.getI64ArrayAttr(ii)); } else { curr = ret; } - curr = rewriter.create( - loc, LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits), + curr = bit_cast( + LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits), curr); rets.push_back(curr); } @@ -1243,8 +1270,7 @@ struct LoadOpConversion for (size_t ii = 0; ii < vec; ii++) { Value vecIdx = createIndexAttrConstant( rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp); - Value loaded = rewriter.create( - loc, valueElemTy, rets[ii / tmp], vecIdx); + Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx); loadedVals.push_back(loaded); } } // end vec @@ -1298,8 +1324,7 @@ struct AddPtrOpConversion getElementsFromStruct(loc, adaptor.offset(), elems, rewriter); SmallVector resultVals(elems); for (unsigned i = 0; i < elems; ++i) { - resultVals[i] = - rewriter.create(loc, elemTy, ptrs[i], offsets[i]); + resultVals[i] = gep(elemTy, ptrs[i], offsets[i]); } Value view = getStructFromElements(loc, resultVals, rewriter, structTy); rewriter.replaceOp(op, view); @@ -1359,7 +1384,7 @@ public: PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern(converter, benefit), - allocation_(allocation), smem_(smem) {} + allocation(allocation), smem(smem) {} LogicalResult matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, @@ -1374,15 +1399,15 @@ public: if ((!srcLayout.isa()) || (!dstLayout.isa())) { // TODO: not implemented - assert(0 && - "convert_layout except for blocked -> blocked is not implemented"); + llvm::errs() + << "convert_layout except for blocked -> blocked is not implemented"; return failure(); } auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType()); - Value smemBase = getSharedMemoryBase(loc, rewriter, smem_, allocation_, - op.getOperation()); + Value smemBase = + getSharedMemoryBase(loc, rewriter, smem, allocation, op.getOperation()); auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); - smemBase = rewriter.create(loc, elemPtrTy, smemBase); + smemBase = bit_cast(elemPtrTy, smemBase); auto shape = dstTy.getShape(); unsigned rank = dstTy.getRank(); @@ -1527,8 +1552,8 @@ private: getMultiDimIndex(elemId, layout.getSizePerThread()); SmallVector multiDimOffset(rank); for (unsigned d = 0; d < 
rank; ++d) { - multiDimOffset[d] = rewriter.create( - loc, multiDimOffsetFirstElem[d], + multiDimOffset[d] = add( + multiDimOffsetFirstElem[d], createIndexAttrConstant(rewriter, loc, llvmIndexTy, multiDimCTAInRepId[d] * shapePerCTA[d] + multiDimElemId[d])); @@ -1537,18 +1562,16 @@ private: linearize(rewriter, loc, reorder(multiDimOffset, outOrd), reorder(paddedRepShape, outOrd)); auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); - Value ptr = - rewriter.create(loc, elemPtrTy, smemBase, offset); + Value ptr = gep(elemPtrTy, smemBase, offset); auto vecTy = VectorType::get(vec, llvmElemTy); - ptr = rewriter.create( - loc, LLVM::LLVMPointerType::get(vecTy, 3), ptr); + ptr = bit_cast(LLVM::LLVMPointerType::get(vecTy, 3), ptr); if (stNotRd) { Value valVec = rewriter.create(loc, vecTy); for (unsigned v = 0; v < vec; ++v) { Value vVal = createIndexAttrConstant( rewriter, loc, getTypeConverter()->getIndexType(), v); - valVec = rewriter.create( - loc, vecTy, valVec, + valVec = insert_element( + vecTy, valVec, vals[elemId + linearCTAId * accumSizePerThread + v], vVal); } rewriter.create(loc, valVec, ptr); @@ -1558,18 +1581,875 @@ private: Value vVal = createIndexAttrConstant( rewriter, loc, getTypeConverter()->getIndexType(), v); vals[elemId + linearCTAId * accumSizePerThread + v] = - rewriter.create(loc, llvmElemTy, valVec, - vVal); + extract_element(llvmElemTy, valVec, vVal); } } } } } - const Allocation *allocation_; - Value smem_; + const Allocation *allocation; + Value smem; }; +/// ====================== dot codegen begin ========================== + +class MMA16816SmemLoader { +public: + MMA16816SmemLoader(int wpt, ArrayRef order, int kOrder, + ArrayRef tileShape, ArrayRef instrShape, + ArrayRef matShape, int perPhase, int maxPhase, + int elemBytes, ConversionPatternRewriter &rewriter, + TypeConverter *typeConverter, const Location &loc) + : wpt(wpt), order(order), kOrder(kOrder), tileShape(tileShape), + instrShape(instrShape), matShape(matShape), perPhase(perPhase), + maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter), + typeConverter(typeConverter), loc(loc), ctx(rewriter.getContext()) { + cMatShape = matShape[order[0]]; + sMatShape = matShape[order[1]]; + + cTileStride = tileShape[order[1]]; + sTileStride = tileShape[order[0]]; + + // rule: k must be the fast-changing axis. + needTrans = kOrder != order[0]; + canUseLdmatrix = elemBytes == 2 || (!needTrans); // b16 + + if (canUseLdmatrix) { + // Each CTA, the warps is arranged as [1xwpt] if not transposed, + // otherwise [wptx1], and each warp will perform a mma. + numPtr = + tileShape[order[0]] / (needTrans ? wpt : 1) / instrShape[order[0]]; + } else { + numPtr = tileShape[order[0]] / wpt / matShape[order[0]]; + } + + numPtr = std::max(numPtr, 2); + + // Special rule for i8/u8, 4 ptrs for each matrix + if (!canUseLdmatrix && elemBytes == 1) + numPtr *= 4; + + int loadStrideInMat[2]; + loadStrideInMat[kOrder] = + 2; // instrShape[kOrder] / matShape[kOrder], always 2 + loadStrideInMat[kOrder ^ 1] = + wpt * (instrShape[order[1]] / matShape[order[1]]); + + pLoadStrideInMat = loadStrideInMat[order[0]]; + sMatStride = + loadStrideInMat[order[1]] / (instrShape[order[1]] / matShape[order[1]]); + + // Each matArr contains warpOffStride matrices. + matArrStride = kOrder == 1 ? 
1 : wpt; + warpOffStride = instrShape[kOrder ^ 1] / matShape[kOrder ^ 1]; + } + + // lane = thread % 32 + // warpOff = (thread/32) % wpt(0) + llvm::SmallVector computeOffsets(Value warpOff, Value lane) { + if (canUseLdmatrix) + return computeLdmatrixMatOffs(warpOff, lane); + else if (elemBytes == 4 && needTrans) + return computeB32MatOffs(warpOff, lane); + else if (elemBytes == 1 && needTrans) + return computeB8MatOffs(warpOff, lane); + else + llvm::report_fatal_error("Invalid smem load config"); + + return {}; + } + + int getNumPtr() const { return numPtr; } + + // Compute the offset to the matrix this thread(indexed by warpOff and lane) + // mapped to. + SmallVector computeLdmatrixMatOffs(Value warpId, Value lane) { + MLIRContext *ctx = warpId.getContext(); + + // 4x4 matrices + Value c = urem(lane, i32_val(8)); + Value s = udiv(lane, i32_val(8)); // sub-warp-id + + // Decompose s => s_0, s_1, that is the coordinate in 2x2 matrices in a warp + Value s0 = urem(s, i32_val(2)); + Value s1 = udiv(s, i32_val(2)); + + // We use different orders for a and b for better performance. + Value kMatArr = kOrder == 1 ? s1 : s0; + Value nkMatArr = kOrder == 1 ? s0 : s1; + + // matrix coordinate inside a CTA, the matrix layout is [2x2wpt] for A and + // [2wptx2] for B. e.g. Setting wpt=3, The data layout for A(kOrder=1) is + // |0 0 1 1 2 2| -> 0,1,2 are the warpids + // |0 0 1 1 2 2| + // + // for B(kOrder=0) is + // |0 0| -> 0,1,2 are the warpids + // |1 1| + // |2 2| + // |0 0| + // |1 1| + // |2 2| + // Note, for each warp, it handles a 2x2 matrices, that is the coordinate + // address (s0,s1) annotates. + + Value matOff[2]; + matOff[kOrder ^ 1] = add( + mul(warpId, i32_val(warpOffStride)), // warp offset + mul(nkMatArr, i32_val(matArrStride))); // matrix offset inside a warp + matOff[kOrder] = kMatArr; + + // Physical offset (before swizzling) + Value cMatOff = matOff[order[0]]; + Value sMatOff = matOff[order[1]]; + + // row offset inside a matrix, each matrix has 8 rows. + Value sOffInMat = c; + + SmallVector offs(numPtr); + Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase)); + Value sOff = add(sOffInMat, mul(sMatOff, i32_val(sMatShape))); + for (int i = 0; i < numPtr; ++i) { + Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat)); + cMatOffI = xor_(cMatOffI, phase); + offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), + mul(sOff, i32_val(sTileStride))); + } + + return offs; + } + + // Compute 32-bit matrix offsets. + SmallVector computeB32MatOffs(Value warpOff, Value lane) { + assert(needTrans && "Only used in transpose mode."); + // Load tf32 matrices with lds32 + Value cOffInMat = udiv(lane, i32_val(4)); + Value sOffInMat = urem(lane, i32_val(4)); + + Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase)); + SmallVector offs(numPtr); + + for (int mat = 0; mat < 4; ++mat) { // Load 4 mats each time + int kMatArrInt = kOrder == 1 ? mat / 2 : mat % 2; + int nkMatArrInt = kOrder == 1 ? mat % 2 : mat / 2; + if (kMatArrInt > 0) // we don't need pointers for k + continue; + Value kMatArr = i32_val(kMatArrInt); + Value nkMatArr = i32_val(nkMatArrInt); + + Value cMatOff = add(mul(warpOff, i32_val(warpOffStride)), + mul(nkMatArr, i32_val(matArrStride))); + Value sMatOff = kMatArr; + Value sOff = add(sOffInMat, mul(sMatOff, i32_val(sMatShape))); + // FIXME: (kOrder == 1?) is really dirty hack + for (int i = 0; i < numPtr / 2; ++i) { + Value cMatOffI = + add(cMatOff, i32_val(i * pLoadStrideInMat * (kOrder == 1 ? 
1 : 2))); + cMatOffI = xor_(cMatOffI, phase); + Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape))); + cOff = urem(cOff, i32_val(tileShape[order[0]])); + sOff = urem(sOff, i32_val(tileShape[order[1]])); + offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, i32_val(sTileStride))); + } + } + return offs; + } + + // compute 8-bit matrix offset. + SmallVector computeB8MatOffs(Value warpOff, Value lane) { + assert(needTrans && "Only used in transpose mode."); + Value cOffInMat = udiv(lane, i32_val(4)); + Value sOffInMat = + mul(urem(lane, i32_val(4)), i32_val(4)); // each thread load 4 cols + + SmallVector offs(numPtr); + for (int mat = 0; mat < 4; ++mat) { + int kMatArrInt = kOrder == 1 ? mat / 2 : mat % 2; + int nkMatArrInt = kOrder == 1 ? mat % 2 : mat / 2; + if (kMatArrInt > 0) // we don't need pointers for k + continue; + Value kMatArr = i32_val(kMatArrInt); + Value nkMatArr = i32_val(nkMatArrInt); + + Value cMatOff = add(mul(warpOff, i32_val(warpOffStride)), + mul(nkMatArr, i32_val(matArrStride))); + Value sMatOff = kMatArr; + + for (int loadx4Off = 0; loadx4Off < numPtr / 8; ++loadx4Off) { + for (int elemOff = 0; elemOff < 4; ++elemOff) { + int ptrOff = loadx4Off * 8 + nkMatArrInt * 4 + elemOff; + Value cMatOffI = add(cMatOff, i32_val(loadx4Off * pLoadStrideInMat * + (kOrder == 1 ? 1 : 2))); + Value sOffInMatElem = add(sOffInMat, i32_val(elemOff)); + + // disable swizzling ... + + Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape))); + Value sOff = add(sOffInMatElem, mul(sMatOff, i32_val(sMatShape))); + // To prevent out-of-bound access when tile is too small. + cOff = urem(cOff, i32_val(tileShape[order[0]])); + sOff = urem(sOff, i32_val(tileShape[order[1]])); + offs[ptrOff] = add(cOff, mul(sOff, i32_val(sTileStride))); + } + } + } + return offs; + } + + // Load 4 matrices and returns 4 vec<2> elements. + std::tuple + loadX4(int mat0, int mat1, ArrayRef offs, ArrayRef ptrs, + Type ldmatrixRetTy, Type shemPtrTy) const { + assert(mat0 % 2 == 0 && mat1 % 2 == 0 && + "smem matrix load must be aligned"); + int matIdx[2] = {mat0, mat1}; + int k = matIdx[kOrder]; + + int ptrIdx{-1}; + if (canUseLdmatrix) + ptrIdx = matIdx[order[0]] / (instrShape[order[0]] / matShape[order[0]]); + else if (elemBytes == 4 && needTrans) // tf32 & trans + ptrIdx = matIdx[order[0]]; + else if (elemBytes == 1 && needTrans) + ptrIdx = matIdx[order[0]] * 4; + else + llvm::report_fatal_error("unsupported mma type found"); + + // prefetch logic removed here. + auto getPtr = [&](int idx) { return ptrs[idx]; }; + + Value ptr = getPtr(ptrIdx); + + Value resV4; + if (canUseLdmatrix) { + int sOffset = + matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes; + PTXBuilder builder; + + auto resArgs = builder.newListOperand(); + + // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a thread. 
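The offset helpers above (computeLdmatrixMatOffs, computeB32MatOffs, computeB8MatOffs) all share the same swizzling step: a phase derived from the row inside an 8x8 matrix is xor-ed into the matrix-column index so that threads walking down a column hit different shared-memory banks. A standalone sketch of that arithmetic, using the perPhase=2 / maxPhase=8 values from the #shared0 encoding in the new lit test (names are illustrative, not the loader's API):

#include <cstdio>

// Same formula as the helpers: phase = (row / perPhase) % maxPhase, then the
// phase is xor-ed into the 8x8-matrix column index (cMatOffI = xor_(cMatOffI, phase)).
int swizzledMatCol(int matCol, int rowInMat, int perPhase, int maxPhase) {
  int phase = (rowInMat / perPhase) % maxPhase;
  return matCol ^ phase;
}

int main() {
  const int perPhase = 2, maxPhase = 8;
  for (int row = 0; row < 8; ++row) {
    std::printf("row %d:", row);
    for (int col = 0; col < 4; ++col)
      std::printf(" %d", swizzledMatCol(col, row, perPhase, maxPhase));
    std::printf("\n");
  }
}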
+ for (int i = 0; i < 4; i++) + resArgs->listAppend(builder.newOperand("=r")); + auto addrArg = builder.newAddrOperand(ptr, "r", sOffset); + + auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4") + ->o("trans", needTrans /*predicate*/) + .o("shared.b16"); + ldmatrix(resArgs, addrArg); + + auto inlineAsm = rewriter.create( + loc, ldmatrixRetTy, builder.getAllMLIRArgs(), // operands + builder.dump(), // asm_string + builder.getConstraints(), // constraints + true, // has_side_effects + false, // is_align_stack + LLVM::AsmDialectAttr::get(ctx, + LLVM::AsmDialect::AD_ATT), // asm_dialect + ArrayAttr::get(ctx, {}) // operand_attrs + ); + + auto getIntAttr = [&](int v) { + return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty(), v)}); + }; + + Value resV4 = inlineAsm.getRes(); // 4xi32, each is composed of 2xf16 + // elements(adjacent columns in a row) + + Type fp16x2Ty = VectorType::get({2}, type::f16Ty(ctx)); + + return std::make_tuple(extract_val(fp16x2Ty, resV4, getIntAttr(0)), + extract_val(fp16x2Ty, resV4, getIntAttr(1)), + extract_val(fp16x2Ty, resV4, getIntAttr(2)), + extract_val(fp16x2Ty, resV4, getIntAttr(3))); + } else if (elemBytes == 4 && + needTrans) { // Use lds.32 to load tf32 matrices + assert(false && "Not implemented yet"); + } else if (elemBytes == 1 && needTrans) { + assert(false && "Not implemented yet"); + } + return std::make_tuple(Value{}, Value{}, Value{}, Value{}); + } + +private: + int wpt; + ArrayRef order; + int kOrder; + ArrayRef tileShape; + ArrayRef instrShape; + ArrayRef matShape; + int perPhase; + int maxPhase; + int elemBytes; + ConversionPatternRewriter &rewriter; + TypeConverter *typeConverter{}; + const Location &loc; + MLIRContext *ctx{}; + + int cMatShape; + int sMatShape; + + int cTileStride; + int sTileStride; + + bool needTrans; + bool canUseLdmatrix; + + int numPtr; + + int pLoadStrideInMat; + int sMatStride; + + int matArrStride; + int warpOffStride; +}; + +bool isSplatLike(Value value) { + if (auto constv = dyn_cast(value.getDefiningOp())) + if (auto attr = constv.getValue().dyn_cast()) + return attr.isSplat(); + return false; +} + +struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { + enum class TensorCoreType : uint8_t { + // floating-point tensor core instr + FP32_FP16_FP16_FP32 = 0, // default + FP32_BF16_BF16_FP32, + FP32_TF32_TF32_FP32, + // integer tensor core instr + INT32_INT1_INT1_INT32, // Not implemented + INT32_INT4_INT4_INT32, // Not implemented + INT32_INT8_INT8_INT32, // Not implemented + // + NOT_APPLICABLE, + }; + + explicit DotOpConversion(LLVMTypeConverter &typeConverter, + const Allocation *allocation, Value smem, + PatternBenefit benefit = 1) + : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit), + allocation(allocation), smem(smem) {} + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + // D = A * B + C + Value A = op.a(); + Value B = op.b(); + Value C = op.c(); + Value D = op.getResult(); + MLIRContext *ctx = op->getContext(); + bool allowTF32 = op.allowTF32(); + + assert(isSplatLike(C) && "Currently only splat-like C is supported now"); + + // Here we assume the DotOp's operands always comes from shared memory. 
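For reference, each of the four .b32 registers that loadX4 above receives from ldmatrix.m8n8.x4 carries two adjacent f16 columns of one row, which is why the result struct is unpacked as vector<2xf16> elements. A plain C++ sketch of that packing, with half-precision values treated as raw 16-bit payloads to stay self-contained:

#include <cstdint>
#include <cstdio>

// Two f16 payloads packed into one 32-bit register (low half first).
uint32_t packHalves(uint16_t lo, uint16_t hi) {
  return (uint32_t)lo | ((uint32_t)hi << 16);
}

void unpackHalves(uint32_t reg, uint16_t &lo, uint16_t &hi) {
  lo = (uint16_t)(reg & 0xffff);
  hi = (uint16_t)(reg >> 16);
}

int main() {
  uint32_t reg = packHalves(0x3c00 /*1.0 in f16 bits*/, 0x4000 /*2.0 in f16 bits*/);
  uint16_t a, b;
  unpackHalves(reg, a, b);
  std::printf("0x%08x -> 0x%04x 0x%04x\n", reg, a, b);
}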
+ auto AShape = A.getType().cast().getShape(); + size_t reduceAxis = 1; + unsigned K = AShape[reduceAxis]; + bool isOuter = K == 1; + bool isMMA = D.getType() + .cast() + .getEncoding() + .isa(); + MmaEncodingAttr mmaLayout; + if (isMMA) + mmaLayout = D.getType() + .cast() + .getEncoding() + .cast(); + + if (!isOuter && isMMA) { + if (mmaLayout.getVersion() == 1) + return convertMMA884(op, adaptor, rewriter); + if (mmaLayout.getVersion() == 2) + return convertMMA16816(op, adaptor, rewriter); + llvm::report_fatal_error( + "Unsupported MMA kind found when converting DotOp to LLVM."); + } + + if (op.getType().cast().getElementType().isF32() && + A.getType().cast().getElementType().isF32()) + return convertFMADot(op, adaptor, rewriter); + + llvm::report_fatal_error( + "Unsupported DotOp found when converting TritonGPU to LLVM."); + } + +private: + // Convert to mma.m16n8k16 + LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adapter, + ConversionPatternRewriter &rewriter) const; + /// Convert to mma.m8n8k4 + LogicalResult convertMMA884(triton::DotOp op, OpAdaptor adapter, + ConversionPatternRewriter &rewriter) const { + assert(false && "Not implemented yet."); + return failure(); + } + + LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adapter, + ConversionPatternRewriter &rewriter) const { + assert(false && "Not implemented yet."); + return failure(); + } + + Value getSmemAddr(Value value, Location loc, + ConversionPatternRewriter &rewriter) const { + return getSharedMemoryBase(loc, rewriter, smem, allocation, + value.getDefiningOp()); + } + + const Allocation *allocation; + Value smem; +}; + +struct DotOpConversionHelper { + using TensorCoreType = DotOpConversion::TensorCoreType; + + Value A, B, C, D; + MmaEncodingAttr mmaLayout; + RankedTensorType ATensorTy, BTensorTy, DTensorTy; + MLIRContext *ctx{}; + + explicit DotOpConversionHelper(DotOp dot) + : dot(dot), mmaType(getMmaType(dot)) { + A = dot.a(); + B = dot.b(); + C = dot.c(); + D = dot.d(); + ctx = dot->getContext(); + mmaLayout = C.getType() + .cast() + .getEncoding() + .cast(); + + ATensorTy = A.getType().cast(); + BTensorTy = B.getType().cast(); + DTensorTy = D.getType().cast(); + } + + // Load SplatLike C which contains a constVal. It simply returns 4 fp32 + // constVal. 
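Before the definition of loadSplatLikeC just below, here is a plain-arithmetic sketch of what it hands back for a splat accumulator: one mma.m16n8k16 tile holds 16 * 8 = 128 accumulator values spread across 32 threads, so each thread contributes 4 values per mma invocation, and a splat C is simply the constant replicated that many times (illustrative helper name, not the real API):

#include <cstdio>
#include <vector>

// numRes = instrM * instrN / 32, i.e. 4 for m16n8k16.
std::vector<float> splatCPerThread(float constVal, int instrM = 16, int instrN = 8) {
  int numRes = instrM * instrN / 32;
  return std::vector<float>(numRes, constVal); // replicated constant
}

int main() {
  auto c = splatCPerThread(0.0f);
  std::printf("%zu values per thread, first = %f\n", c.size(), c[0]);
}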
+ SmallVector loadSplatLikeC(Value C, Location loc, + ConversionPatternRewriter &rewriter) { + assert(isSplatLike(C)); + + int numRes = getMmaInstrShape()[0] * getMmaInstrShape()[1] / 32; + if (auto constv = llvm::dyn_cast(C.getDefiningOp())) { + if (auto attr = constv.getValue().dyn_cast()) { + Type elemType = attr.getElementType(); + if (elemType.isInteger(32)) { + int v = attr.getSplatValue(); + return SmallVector(numRes, i32_val(v)); + } else if (elemType.isInteger(8)) { + int v = attr.getSplatValue(); + auto newv = rewriter.create( + loc, elemType, IntegerAttr::get(elemType, v)); + return SmallVector(numRes, newv); + } else if (elemType.isF32()) { + int v = attr.getSplatValue(); + auto newv = rewriter.create( + loc, elemType, FloatAttr::get(elemType, v)); + return SmallVector(numRes, newv); + } + } + } + + assert(false && "Not supported type."); + return {}; + } + + Type getShemPtrTy() const { + switch (mmaType) { + case TensorCoreType::FP32_FP16_FP16_FP32: + return ptr_ty(type::f16Ty(ctx), 3); + case TensorCoreType::FP32_BF16_BF16_FP32: + return ptr_ty(type::bf16Ty(ctx), 3); + case TensorCoreType::FP32_TF32_TF32_FP32: + return ptr_ty(type::f32Ty(ctx), 3); + case TensorCoreType::INT32_INT8_INT8_INT32: + return ptr_ty(type::i8Ty(ctx), 3); + default: + llvm::report_fatal_error("mma16816 data type not supported"); + } + return Type{}; + } + + // The type of a matrix that loaded by either a ldmatrix or composed lds. + Type getMatType() const { + Type fp32Ty = type::f32Ty(ctx); + Type fp16x2Ty = VectorType::get({2}, type::f16Ty(ctx)); + Type bf16x2Ty = VectorType::get({2}, type::bf16Ty(ctx)); + // floating point types + Type fp16x2Pack4Ty = + LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp16x2Ty)); + Type bf16x2Pack4Ty = + LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, bf16x2Ty)); + Type fp32Pack4Ty = + LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp32Ty)); + // integer types + Type i8x4Ty = VectorType::get({4}, type::i8Ty(ctx)); + Type i8x4Pack4Ty = + LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, i8x4Ty)); + Type i32Pack4Ty = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(4, type::i32Ty(ctx))); + + switch (mmaType) { + case TensorCoreType::FP32_FP16_FP16_FP32: + return fp16x2Pack4Ty; + case TensorCoreType::FP32_BF16_BF16_FP32: + return bf16x2Pack4Ty; + case TensorCoreType::FP32_TF32_TF32_FP32: + return fp32Pack4Ty; + case TensorCoreType::INT32_INT8_INT8_INT32: + return i8x4Pack4Ty; + default: + llvm::report_fatal_error("Unsupported mma type found"); + } + + return Type{}; + } + + Type getMmaRetType() const { + Type fp32Ty = type::f32Ty(ctx); + Type i32Ty = type::i32Ty(ctx); + Type fp32x4Ty = + LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp32Ty)); + Type i32x4Ty = + LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, i32Ty)); + switch (mmaType) { + case TensorCoreType::FP32_FP16_FP16_FP32: + return fp32x4Ty; + case TensorCoreType::FP32_BF16_BF16_FP32: + return fp32x4Ty; + case TensorCoreType::FP32_TF32_TF32_FP32: + return fp32x4Ty; + case TensorCoreType::INT32_INT8_INT8_INT32: + return i32x4Ty; + default: + llvm::report_fatal_error("Unsupported mma type found"); + } + + return Type{}; + } + + ArrayRef getMmaInstrShape() const { + assert(mmaType != TensorCoreType::NOT_APPLICABLE && + "Unknown mma type found."); + return mmaInstrShape.at(mmaType); + } + + ArrayRef getMmaMatShape() const { + assert(mmaType != TensorCoreType::NOT_APPLICABLE && + "Unknown mma type found."); + return mmaMatShape.at(mmaType); + } + + int getVec() const { + 
assert(mmaType != TensorCoreType::NOT_APPLICABLE && + "Unknown mma type found."); + return mmaInstrVec.at(mmaType); + } + + StringRef getMmaInstr() const { + assert(mmaType != TensorCoreType::NOT_APPLICABLE && + "Unknown mma type found."); + return mmaInstrPtx.at(mmaType); + } + + static TensorCoreType getMmaType(triton::DotOp op) { + Value A = op.a(); + Value B = op.b(); + auto aTy = A.getType().cast(); + auto bTy = B.getType().cast(); + // d = a*b + c + auto dTy = op.d().getType().cast(); + auto mmaLayout = dTy.getEncoding().cast(); + + if (dTy.getElementType().isF32()) { + if (aTy.getElementType().isF16() && bTy.getElementType().isF16()) + return TensorCoreType::FP32_FP16_FP16_FP32; + if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16()) + return TensorCoreType::FP32_BF16_BF16_FP32; + if (aTy.getElementType().isF32() && bTy.getElementType().isF32() && + op.allowTF32()) + return TensorCoreType::FP32_TF32_TF32_FP32; + } else if (dTy.getElementType().isInteger(32)) { + if (aTy.getElementType().isInteger(8) && + bTy.getElementType().isInteger(8)) + return TensorCoreType::INT32_INT8_INT8_INT32; + } + + return TensorCoreType::NOT_APPLICABLE; + } + +private: + TensorCoreType mmaType; + + // Used on nvidia GPUs mma layout .version == 2 + // Refer to + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-storage + // for more details. + inline static const std::map> + mmaInstrShape = { + {TensorCoreType::FP32_FP16_FP16_FP32, {16, 8, 16}}, + {TensorCoreType::FP32_BF16_BF16_FP32, {16, 8, 16}}, + {TensorCoreType::FP32_TF32_TF32_FP32, {16, 8, 8}}, + + {TensorCoreType::INT32_INT1_INT1_INT32, {16, 8, 256}}, + {TensorCoreType::INT32_INT4_INT4_INT32, {16, 8, 64}}, + {TensorCoreType::INT32_INT8_INT8_INT32, {16, 8, 32}}, + }; + + // shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices) + // Refer to + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix + // for more details. + inline static const std::map> + mmaMatShape = { + {TensorCoreType::FP32_FP16_FP16_FP32, {8, 8, 8}}, + {TensorCoreType::FP32_BF16_BF16_FP32, {8, 8, 8}}, + {TensorCoreType::FP32_TF32_TF32_FP32, {8, 8, 4}}, + + {TensorCoreType::INT32_INT1_INT1_INT32, {8, 8, 64}}, + {TensorCoreType::INT32_INT4_INT4_INT32, {8, 8, 32}}, + {TensorCoreType::INT32_INT8_INT8_INT32, {8, 8, 16}}, + }; + + // Supported mma instruction in PTX. + // Refer to + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma + // for more details. 
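A standalone sketch of the operand-type selection that getMmaType above implements, with MLIR element types reduced to plain tags; the PTX mnemonic table that follows keys off the same enum:

#include <cstdio>

enum class Elem { F16, BF16, F32, I8, I32 };
enum class TC { FP32_FP16, FP32_BF16, FP32_TF32, INT32_INT8, NotApplicable };

// Mirrors getMmaType: the accumulator/output type narrows the choice first,
// then the A/B element types (and allowTF32 for the f32 case) pick the instr.
TC pickTensorCore(Elem a, Elem b, Elem d, bool allowTF32) {
  if (d == Elem::F32) {
    if (a == Elem::F16 && b == Elem::F16) return TC::FP32_FP16;   // m16n8k16
    if (a == Elem::BF16 && b == Elem::BF16) return TC::FP32_BF16; // m16n8k16
    if (a == Elem::F32 && b == Elem::F32 && allowTF32)
      return TC::FP32_TF32;                                       // m16n8k8
  } else if (d == Elem::I32) {
    if (a == Elem::I8 && b == Elem::I8) return TC::INT32_INT8;    // m16n8k32
  }
  return TC::NotApplicable;
}

int main() {
  std::printf("%d\n", (int)pickTensorCore(Elem::F16, Elem::F16, Elem::F32, true));
}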
+ inline static const std::map mmaInstrPtx = { + {TensorCoreType::FP32_FP16_FP16_FP32, + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"}, + {TensorCoreType::FP32_BF16_BF16_FP32, + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"}, + {TensorCoreType::FP32_TF32_TF32_FP32, + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"}, + + {TensorCoreType::INT32_INT1_INT1_INT32, + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"}, + {TensorCoreType::INT32_INT4_INT4_INT32, + "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"}, + {TensorCoreType::INT32_INT8_INT8_INT32, + "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"}, + }; + + // vector length per ldmatrix (16*8/elelment_size_in_bits) + inline static const std::map mmaInstrVec = { + {TensorCoreType::FP32_FP16_FP16_FP32, 8}, + {TensorCoreType::FP32_BF16_BF16_FP32, 8}, + {TensorCoreType::FP32_TF32_TF32_FP32, 4}, + + {TensorCoreType::INT32_INT1_INT1_INT32, 128}, + {TensorCoreType::INT32_INT4_INT4_INT32, 32}, + {TensorCoreType::INT32_INT8_INT8_INT32, 16}, + }; + +private: + DotOp dot; +}; + +LogicalResult +DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + MLIRContext *ctx = op->getContext(); + // D = A * B + C + Value A = op.a(); + Value B = op.b(); + Value C = op.c(); + Value D = op.getResult(); + bool allowTF32 = op.allowTF32(); + + auto aTensorTy = A.getType().cast(); + auto bTensorTy = B.getType().cast(); + auto dTensorTy = D.getType().cast(); + + auto aShape = aTensorTy.getShape(); + auto bShape = bTensorTy.getShape(); + auto dShape = dTensorTy.getShape(); + + auto mmaLayout = dTensorTy.getEncoding().cast(); + + auto wpt = mmaLayout.getWarpsPerCTA(); + + // TODO(Superjomn) Process C->is_trans_a() logic + + DotOpConversionHelper helper(op); + + int NK = aShape[1]; + + auto mmaInstrShape = helper.getMmaInstrShape(); + const int mmaInstrM = mmaInstrShape[0]; + const int mmaInstrN = mmaInstrShape[1]; + const int mmaInstrK = mmaInstrShape[2]; + + auto matShape = helper.getMmaMatShape(); + const int matShapeM = matShape[0]; + const int matShapeN = matShape[1]; + const int matShapeK = matShape[2]; + + // shape / shape_per_cta + const int numRepM = std::max(dShape[0] / (wpt[0] * mmaInstrM), 1); + const int numRepN = std::max(dShape[1] / (wpt[1] * mmaInstrN), 1); + const int numRepK = std::max(NK / mmaInstrK, 1); + + Value head = getThreadId(rewriter, loc); + Value lane = urem(head, i32_val(32)); + Value warp = udiv(head, i32_val(32)); + Value warpMN = udiv(warp, i32_val(wpt[0])); + Value warpM = urem(warp, i32_val(wpt[0])); + Value warpN = urem(warpMN, i32_val(wpt[1])); + + size_t aElemBytes = aTensorTy.getElementTypeBitWidth() / 8; + size_t bElemBytes = bTensorTy.getElementTypeBitWidth() / 8; + + std::map, Value> ha; + std::map, Value> hb; + + // the original register_lds2, but discard the prefetch logic. + auto ld2 = [&](decltype(ha) &vals, int mn, int k, Value val) { + vals[{mn, k}] = val; + }; + + // Load A or B matrix. + auto getLoadMatrixFn = + [&](Value tensor, int wpt, int kOrder, ArrayRef instrShape, + ArrayRef matShape, Value warpId, + decltype(ha) &vals) -> std::function { + auto tensorTy = tensor.getType().cast(); + // We assumes that the input operand of Dot should be from shared layout. + // TODO(Superjomn) Consider other layouts if needed later. 
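A plain-arithmetic sketch of the thread/warp decomposition and repetition counts used by convertMMA16816 above (illustrative names; the 128x128x32 shape with warpsPerCTA = [2, 2] is just an assumed example):

#include <algorithm>
#include <cstdio>

struct WarpCoord { int warpM, warpN, lane; };

// lane = tid % 32; warpM = warp % wpt[0]; warpN = (warp / wpt[0]) % wpt[1],
// matching the urem/udiv chain in convertMMA16816.
WarpCoord decompose(int tid, int wpt0, int wpt1) {
  int lane = tid % 32;
  int warp = tid / 32;
  return {warp % wpt0, (warp / wpt0) % wpt1, lane};
}

// numRep = max(dim / (wpt * instrDim), 1): how many instruction-sized tiles
// each warp iterates over along one dimension.
int numRep(int dim, int wpt, int instr) { return std::max(dim / (wpt * instr), 1); }

int main() {
  WarpCoord w = decompose(/*tid=*/70, /*wpt0=*/2, /*wpt1=*/2);
  std::printf("warpM=%d warpN=%d lane=%d\n", w.warpM, w.warpN, w.lane);
  // 128x128x32 dot, warpsPerCTA = [2, 2], mma.m16n8k16:
  std::printf("numRepM=%d numRepN=%d numRepK=%d\n",
              numRep(128, 2, 16), numRep(128, 2, 8), std::max(32 / 16, 1));
}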
+ auto sharedLayout = tensorTy.getEncoding().cast(); + const int perPhase = sharedLayout.getPerPhase(); + const int maxPhase = sharedLayout.getMaxPhase(); + const int elemBytes = tensorTy.getElementTypeBitWidth() / 8; + + MMA16816SmemLoader loader(wpt, sharedLayout.getOrder(), kOrder, + tensorTy.getShape() /*tileShape*/, instrShape, + matShape, perPhase, maxPhase, elemBytes, rewriter, + typeConverter, loc); + SmallVector offs = loader.computeOffsets(warpId, lane); + + const int numPtrs = loader.getNumPtr(); + SmallVector ptrs(numPtrs); + + Type smemPtrTy = helper.getShemPtrTy(); + auto smemBase = getSmemAddr(tensor, loc, rewriter); + for (int i = 0; i < numPtrs; i++) { + ptrs[i] = bit_cast( + smemPtrTy, gep(smemBase.getType(), smemBase, ValueRange({offs[i]}))); + } + + // (a, b) is the coordinate. + auto load = [&, loader, ptrs, offs](int a, int b) { + auto [ha0, ha1, ha2, ha3] = loader.loadX4( + (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs, + ptrs, helper.getMatType(), helper.getShemPtrTy()); + ld2(vals, a, b, ha0); + ld2(vals, a + 1, b, ha1); + ld2(vals, a, b + 1, ha2); + ld2(vals, a + 1, b + 1, ha3); + }; + + return load; + }; + + std::function loadA = getLoadMatrixFn( + A, mmaLayout.getWarpsPerCTA()[0] /*wpt*/, 1 /*kOrder*/, + {mmaInstrM, mmaInstrK} /*instrShpae*/, + {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/); + std::function loadB = getLoadMatrixFn( + B, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, 0 /*kOrder*/, + {mmaInstrK, mmaInstrN} /*instrShpae*/, + {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/); + + const unsigned mStride = numRepN * 2; + SmallVector fc(numRepM * mStride + numRepN * 2); + auto callMma = [&](unsigned m, unsigned n, unsigned k) { + PTXBuilder builder; + + auto &mma = *builder.create(helper.getMmaInstr().str()); + + auto retArgs = builder.newListOperand(); + for (int i = 0; i < 4; ++i) + retArgs->listAppend(builder.newOperand("=r")); + auto aArg0 = builder.newOperand(ha[{m, k}], "r"); + auto aArg1 = builder.newOperand(ha[{m + 1, k}], "r"); + auto aArg2 = builder.newOperand(ha[{m, k + 1}], "r"); + auto aArg3 = builder.newOperand(ha[{m + 1, k}], "r"); + + auto bArg0 = builder.newOperand(ha[{n, k}], "r"); + auto bArg1 = builder.newOperand(ha[{n, k + 1}], "r"); + + // Currently, we only support a SplatLike C. For the other cases, e.g., C in + // shared layout or blocked layout, we will support them by expanding + // convert_layout. 
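For the f16 path, the instruction the PTXBuilder in callMma assembles consumes four .b32 registers of A (eight f16 values), two of B, and four .f32 of C, and writes four .f32 of D, which matches the operand lists built above. A sketch that merely formats that instruction with placeholder register names (the real builder substitutes constraint indices):

#include <cstdio>
#include <string>

std::string mmaM16N8K16F16(const std::string &d, const std::string &a,
                           const std::string &b, const std::string &c) {
  // {d0..d3}, {a0..a3}, {b0,b1}, {c0..c3}
  return "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
         "{" + d + "0," + d + "1," + d + "2," + d + "3}, "
         "{" + a + "0," + a + "1," + a + "2," + a + "3}, "
         "{" + b + "0," + b + "1}, "
         "{" + c + "0," + c + "1," + c + "2," + c + "3};";
}

int main() { std::printf("%s\n", mmaM16N8K16F16("%d", "%a", "%b", "%c").c_str()); }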
+ auto hc = helper.loadSplatLikeC(C, loc, rewriter); + assert(hc.size() == 4UL && "Only splat-like C is supported now"); + auto cArg0 = builder.newOperand(hc[0], "0"); // reuse the output registers + auto cArg1 = builder.newOperand(hc[1], "1"); + auto cArg2 = builder.newOperand(hc[2], "2"); + auto cArg3 = builder.newOperand(hc[3], "3"); + + mma({retArgs, aArg0, aArg1, aArg2, aArg3, bArg0, bArg1, cArg0, cArg1, cArg2, + cArg3}); + + auto inlineAsm = rewriter.create( + loc, helper.getMmaRetType(), builder.getAllMLIRArgs(), // operands + builder.dump(), // asm_string + builder.getConstraints(), // constraints + true, // has_side_effects + false, // is_align_stack + LLVM::AsmDialectAttr::get(ctx, + LLVM::AsmDialect::AD_ATT), // asm_dialect + ArrayAttr::get(ctx, {}) // operand_attrs + ); + + auto mmaOut = inlineAsm.getRes(); + auto getIntAttr = [&](int v) { + return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty(), v)}); + }; + + fc[(m + 0) * mStride + (n * 2 + 0)] = + extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(0)); + fc[(m + 0) * mStride + (n * 2 + 1)] = + extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(1)); + fc[(m + 1) * mStride + (n * 2 + 0)] = + extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(2)); + fc[(m + 1) * mStride + (n * 2 + 1)] = + extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(3)); + }; + + // Main program + + for (unsigned k = 0; k < numRepK; k++) { + for (unsigned m = 0; m < numRepM; m++) + loadA(2 * m, 2 * k); + for (unsigned n = 0; n < numRepN; n += 2) + loadB(n, 2 * k); + for (unsigned m = 0; m < numRepM; m++) + for (unsigned n = 0; n < numRepN; n++) { + callMma(2 * m, n, 2 * k); + } + } + + // replace with new packed result + Type structTy = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(fc.size(), type::f32Ty(ctx))); + Value res = getStructFromElements(loc, fc, rewriter, structTy); + rewriter.replaceOp(op, res); + + return success(); +} + +/// ====================== mma codegen end ============================ + class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter { public: using TypeConverter::convertType; @@ -1600,11 +2480,9 @@ public: convertType(type.getElementType())); return LLVM::LLVMStructType::getLiteral(&getContext(), types); } else if (auto mma_layout = layout.dyn_cast()) { - // TODO: Not implemented - return llvm::None; + return type; } else if (auto shared_layout = layout.dyn_cast()) { - // TODO: Not implemented - return llvm::None; + return type; } return llvm::None; } @@ -1637,6 +2515,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, patterns.add>(typeConverter, benefit); patterns.add>(typeConverter, benefit); + patterns.add(typeConverter, allocation, smem, benefit); } class ConvertTritonGPUToLLVM @@ -1661,7 +2540,7 @@ public: // step 2: Allocate for shared memories // step 3: Convert the rest of ops via partial conversion // The reason for a seperation between 1/3 is that, step 2 is out of - // the scope of Dialect Conversion, thus we need to make sure the smem_ + // the scope of Dialect Conversion, thus we need to make sure the smem // is not revised during the conversion of step 3. RewritePatternSet func_patterns(context); func_patterns.add(typeConverter, numWarps, 1 /*benefit*/); @@ -1678,7 +2557,7 @@ public: // patterns. RewritePatternSet patterns(context); populateTritonToLLVMPatterns(typeConverter, patterns, numWarps, - *axisAnalysis, &allocation, smem_, + *axisAnalysis, &allocation, smem, 10 /*benefit*/); // Add arith/math's patterns to help convert scalar expression to LLVM. 
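The four f32 values each mma.m16n8k16 writes per thread land at fixed positions of the 16x8 output tile, following the warp-level matrix storage layout linked in the comments above: registers 0/1 cover the upper eight rows and 2/3 the lower eight, two adjacent columns per thread. That is the per-thread view the extract_val calls into fc rely on. A small sketch of the mapping:

#include <cstdio>

// Position of output register `reg` (0..3) of lane `lane` inside the 16x8 tile.
void dRegPosition(int lane, int reg, int &row, int &col) {
  int groupRow = lane / 4;        // 8 row groups of 4 threads each
  int groupCol = 2 * (lane % 4);  // each thread owns two adjacent columns
  row = groupRow + (reg < 2 ? 0 : 8);
  col = groupCol + (reg & 1);
}

int main() {
  for (int reg = 0; reg < 4; ++reg) {
    int row, col;
    dRegPosition(/*lane=*/5, reg, row, col);
    std::printf("lane 5, d%d -> row %d, col %d\n", reg, row, col);
  }
}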
@@ -1704,7 +2583,7 @@ protected: void initSharedMemory(size_t size, TritonGPUToLLVMTypeConverter &typeConverter); - Value smem_; + Value smem; }; void ConvertTritonGPUToLLVM::initSharedMemory( @@ -1723,7 +2602,7 @@ void ConvertTritonGPUToLLVM::initSharedMemory( assert(funcs.size() == 1 && "Inliner pass is expected before TritonGPUToLLVM"); b.setInsertionPointToStart(&funcs[0].getBody().front()); - smem_ = b.create(loc, global); + smem = b.create(loc, global); } } // namespace diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 8ccaf6f3d..d6fd864b7 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -177,9 +177,9 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef shape) const { } unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef shape) const { - // TODO: - assert(0 && "MmaEncodingAttr::getElemsPerThread not implemented"); - return 0; + int threads = product(getWarpsPerCTA()); + int numElem = product(shape); + return numElem / threads; } unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef shape) const { diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 5a04496d5..179c9391a 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -66,7 +66,7 @@ void extractNVVMMetadata(mlir::ModuleOp module, // maxntid if (op->hasAttr(NVVMMetadataField::MaxNTid)) { auto attr = op->getAttr(NVVMMetadataField::MaxNTid); - meta.maxntidx = attr.dyn_cast().getSInt(); + meta.maxntidx = attr.dyn_cast().getInt(); hasMetadata = true; } diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 7cfecade2..2fbff865d 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -22,9 +22,11 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - // CHECK: offset = 0, size = 8192 + // CHECK: scratch offset = 8192, size = 0 + // CHECK-NEXT: offset = 0, size = 8192 %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> + // CHECK-NEXT: scratch offset = 16384, size = 0 // CHECK-NEXT: offset = 8192, size = 8192 %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> @@ -50,16 +52,20 @@ func @reusable(%A : !tt.ptr) { %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> %b_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<32x128x!tt.ptr, #AL> %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - // CHECK: offset = 0, size = 8192 + // CHECK: scratch offset = 8192, size = 0 + // CHECK-NEXT: offset = 0, size = 8192 %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> %a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL> + // CHECK-NEXT: scratch offset = 16384, size = 0 // CHECK-NEXT: offset = 8192, size = 8192 %a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> 
tensor<32x128xf16, #A> %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> + // CHECK-NEXT: scratch offset = 24576, size = 0 // CHECK-NEXT: offset = 16384, size = 8192 %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A> %c = tt.dot %a1, %a2, %c_init {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL> + // CHECK-NEXT: scratch offset = 8192, size = 0 // CHECK-NEXT: offset = 0, size = 8192 %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #A> %c1 = tt.dot %a3, %a4, %c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 59ae739c6..915cc6c0c 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -3,7 +3,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr) // Here the 128 comes from the 4 in module attribute multiples 32 - // CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : si32} {{.*}} + // CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : i32} {{.*}} func @test_empty_kernel(%lb : index, %A : !tt.ptr) { // CHECK: llvm.return return @@ -422,6 +422,33 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { } } +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> +#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}> +#mma0 = #triton_gpu.mma<{version=2, warpsPerCTA=[1,1]}> +module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_dot + func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { + %AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0> + %BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> + // CHECK: llvm.inline_asm + // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 + // CHECK: llvm.inline_asm + // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 + + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + %D = tt.dot %AA, %BB, %cst0 {allowTF32 = true} : tensor<16x16xf16, #shared0> * tensor<16x16xf16, #shared0> -> tensor<16x16xf32, #mma0> + + return + } +} + + // TODO: problems in MLIR's parser on slice layout // #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> // module attributes {"triton_gpu.num-warps" = 1 : i32} { @@ -429,4 +456,4 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> // return // } -// } \ No newline at end of file +// }
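A back-of-the-envelope check of why the new convert_dot lit test expects exactly two ldmatrix.m8n8.x4 and two mma.sync instructions: with a 16x16x16 dot, warpsPerCTA = [1, 1], and the m16n8k16 instruction shape, the repetition counts come out to numRepM = 1, numRepN = 2, numRepK = 1, and the main loop in convertMMA16816 issues one x4 load per A repetition, one per pair of B repetitions, and one mma per (m, n, k) triple. Assumed values are taken from the test above:

#include <algorithm>
#include <cstdio>

int main() {
  const int M = 16, N = 16, K = 16;
  const int wpt[2] = {1, 1};
  const int instrM = 16, instrN = 8, instrK = 16;

  int numRepM = std::max(M / (wpt[0] * instrM), 1); // 1
  int numRepN = std::max(N / (wpt[1] * instrN), 1); // 2
  int numRepK = std::max(K / instrK, 1);            // 1

  int ldmatrixA = numRepM * numRepK;       // one x4 load per (m, k) repetition
  int ldmatrixB = (numRepN / 2) * numRepK; // each B load covers two n repetitions
  int mmas = numRepM * numRepN * numRepK;  // one mma.sync per (m, n, k)

  std::printf("ldmatrix = %d, mma.sync = %d\n", ldmatrixA + ldmatrixB, mmas); // 2, 2
}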