From 3a84278530fd3b0eb27baabe34e0dda5d3ba7af3 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Tue, 27 Sep 2022 14:38:34 +0800 Subject: [PATCH] [Triton-MLIR][BACKEND] Refine dot conversion (#710) This PR does 1. Refine the dot conversion 2. some other tiny code refinement --- include/triton/Analysis/Utility.h | 12 + .../Conversion/TritonGPUToLLVM/PtxAsmFormat.h | 33 + include/triton/Dialect/TritonGPU/IR/Dialect.h | 2 +- lib/Analysis/Allocation.cpp | 9 +- lib/Analysis/AxisInfo.cpp | 8 +- .../TritonGPUToLLVM/PtxAsmFormat.cpp | 24 +- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 570 ++++++++++-------- lib/Dialect/TritonGPU/IR/Dialect.cpp | 49 +- unittest/Analysis/CMakeLists.txt | 2 +- unittest/Analysis/UtilityTest.cpp | 19 +- .../Conversion/TritonGPUToLLVM/CMakeLists.txt | 2 +- 11 files changed, 439 insertions(+), 291 deletions(-) diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index fccd13320..6152c11f5 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -20,6 +20,18 @@ template Int product(llvm::ArrayRef arr) { template Int ceil(Int m, Int n) { return (m + n - 1) / n; } +// output[i] = input[order[i]] +template +SmallVector reorder(ArrayRef input, ArrayRef order) { + size_t rank = order.size(); + assert(input.size() == rank); + SmallVector result(rank); + for (auto it : llvm::enumerate(order)) { + result[it.index()] = input[it.value()]; + } + return result; +} + } // namespace mlir #endif // TRITON_ANALYSIS_UTILITY_H diff --git a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h index 9a082a1ec..c956eb899 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h @@ -8,6 +8,9 @@ #include namespace mlir { +class ConversionPatternRewriter; +class Location; + namespace triton { using llvm::StringRef; @@ -104,6 +107,31 @@ struct PTXBuilder { // Create a list of operands. Operand *newListOperand() { return newOperand(); } + Operand *newListOperand(ArrayRef> items) { + auto *list = newOperand(); + for (auto &item : items) { + list->listAppend(newOperand(item.first, item.second)); + } + return list; + } + + Operand *newListOperand(unsigned count, mlir::Value val, + const std::string &constraint) { + auto *list = newOperand(); + for (int i = 0; i < count; ++i) { + list->listAppend(newOperand(val, constraint)); + } + return list; + } + + Operand *newListOperand(unsigned count, const std::string &constraint) { + auto *list = newOperand(); + for (int i = 0; i < count; ++i) { + list->listAppend(newOperand(constraint)); + } + return list; + } + // Create a new operand. It will not add to operand list. // @value: the MLIR value bind to this operand. // @constraint: ASM operand constraint, .e.g. 
"=r" @@ -131,6 +159,11 @@ struct PTXBuilder { std::string dump() const; + mlir::Value launch(ConversionPatternRewriter &rewriter, Location loc, + Type resTy, bool hasSideEffect = true, + bool isAlignStack = false, + ArrayRef attrs = {}) const; + private: Operand *newOperand() { argArchive.emplace_back(std::make_unique()); diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index c36b0e501..b9aa13cae 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -24,7 +24,7 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef shape); SmallVector getSizePerThread(Attribute layout); -unsigned getShapePerCTA(const Attribute &layout, unsigned d); +SmallVector getShapePerCTA(const Attribute &layout); SmallVector getOrder(const Attribute &layout); diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 1a23dca6d..e4efee3ce 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -56,11 +56,14 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread; outVec = outOrd[0] == 0 ? 1 : dstContigPerThread; + auto srcShapePerCTA = getShapePerCTA(srcLayout); + auto dstShapePerCTA = getShapePerCTA(dstLayout); + unsigned pad = std::max(inVec, outVec); for (unsigned d = 0; d < rank; ++d) { - paddedRepShape[d] = std::max( - std::min(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)), - std::min(dstTy.getShape()[d], getShapePerCTA(dstLayout, d))); + paddedRepShape[d] = + std::max(std::min(srcTy.getShape()[d], srcShapePerCTA[d]), + std::min(dstTy.getShape()[d], dstShapePerCTA[d])); } unsigned paddedDim = 1; if (auto dstBlockedLayout = dstLayout.dyn_cast()) { diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index e5472eff8..0d0706eab 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -65,7 +65,7 @@ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) { DimVectorT retContiguity; DimVectorT retDivisibility; DimVectorT retConstancy; - for (size_t d = 0; d < lhs.getRank(); d++) { + for (size_t d = 0; d < lhs.getRank(); ++d) { retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d))); retDivisibility.push_back( gcd(lhs.getDivisibility(d), rhs.getDivisibility(d))); @@ -87,7 +87,7 @@ AxisInfo AxisInfoAnalysis::visitBinaryOp( AxisInfo::DimVectorT newContiguity; AxisInfo::DimVectorT newDivisibility; AxisInfo::DimVectorT newConstancy; - for (size_t d = 0; d < rank; d++) { + for (size_t d = 0; d < rank; ++d) { newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d)); newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d)); newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d)); @@ -166,7 +166,7 @@ ChangeResult AxisInfoAnalysis::visitOperation( AxisInfo::DimVectorT contiguity; AxisInfo::DimVectorT divisibility; AxisInfo::DimVectorT constancy; - for (size_t d = 0; d < retTy.getRank(); d++) { + for (size_t d = 0; d < retTy.getRank(); ++d) { contiguity.push_back(1); divisibility.push_back(opInfo.getDivisibility(0)); constancy.push_back(retTy.getShape()[d]); @@ -202,7 +202,7 @@ ChangeResult AxisInfoAnalysis::visitOperation( AxisInfo::DimVectorT contiguity; AxisInfo::DimVectorT divisibility; AxisInfo::DimVectorT constancy; - for (size_t d = 0; d < retTy.getRank(); d++) { + for (size_t d = 0; d < retTy.getRank(); ++d) { contiguity.push_back(opShape[d] == 1 ? 
1 : opInfo.getContiguity(d)); divisibility.push_back(opInfo.getDivisibility(d)); constancy.push_back(opShape[d] == 1 ? retShape[d] : 1); diff --git a/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp b/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp index 9eeb49995..5827b7301 100644 --- a/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp +++ b/lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp @@ -1,4 +1,6 @@ #include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/raw_ostream.h" #include // unify to llvm::raw_string_ostream ? @@ -10,7 +12,7 @@ std::string strJoin(llvm::ArrayRef strs, llvm::StringRef delimiter) { std::string osStr; llvm::raw_string_ostream os(osStr); - for (size_t i = 0; !strs.empty() && i < strs.size() - 1; i++) + for (size_t i = 0; !strs.empty() && i < strs.size() - 1; ++i) os << strs[i] << delimiter; if (!strs.empty()) os << strs.back(); @@ -74,6 +76,25 @@ SmallVector PTXBuilder::getAllArgs() const { return res; } +mlir::Value PTXBuilder::launch(ConversionPatternRewriter &rewriter, + Location loc, Type resTy, bool hasSideEffect, + bool isAlignStack, + ArrayRef attrs) const { + auto *ctx = rewriter.getContext(); + auto inlineAsm = rewriter.create( + loc, resTy, getAllMLIRArgs(), // operands + dump(), // asm_string + getConstraints(), // constraints + hasSideEffect, // has_side_effects + isAlignStack, // is_align_stack + LLVM::AsmDialectAttr::get(ctx, + LLVM::AsmDialect::AD_ATT), // asm_dialect + ArrayAttr::get(ctx, attrs) // operand_attrs + ); + + return inlineAsm.getRes(); +} + std::string PTXInstr::Operand::dump() const { if (repr) return repr(idx); @@ -151,5 +172,6 @@ PTXInstrExecution::getArgList() const { } return args; } + } // namespace triton } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 4d1ba7330..742ba60d7 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -50,10 +50,28 @@ Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) { IntegerAttr::get(i32ty, v)); } +// Create a index type constant. +Value createIndexConstant(OpBuilder &builder, Location loc, + + TypeConverter *converter, int64_t value) { + Type ty = converter->convertType(builder.getIndexType()); + return builder.create(loc, ty, + builder.getIntegerAttr(ty, value)); +} + +// Create an integer constant of \param width bits. +Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, + int64_t value) { + Type ty = builder.getIntegerType(width); + return builder.create(loc, ty, + builder.getIntegerAttr(ty, value)); +} + // Add other specification if needed... } // namespace +// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive #define udiv(...) rewriter.create(loc, __VA_ARGS__) #define urem(...) rewriter.create(loc, __VA_ARGS__) #define add(...) rewriter.create(loc, __VA_ARGS__) @@ -68,10 +86,19 @@ Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) { rewriter.create(loc, __VA_ARGS__) #define extract_element(...) \ rewriter.create(loc, __VA_ARGS__) +#define load(...) rewriter.create(loc, __VA_ARGS__) +#define store(val, ptr) rewriter.create(loc, val, ptr) #define address_of(...) 
rewriter.create(loc, __VA_ARGS__) +#define i32_ty rewriter.getIntegerType(32) +#define vec_ty(type, num) VectorType::get(num, type) +// Creator for constant #define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__) -#define i32_ty() rewriter.getIntegerType(32) +#define int_val(width, val) \ + LLVM::createLLVMIntegerConstant(rewriter, loc, width, val) +#define idx_val(...) \ + LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(), \ + __VA_ARGS__) } // namespace LLVM } // namespace mlir @@ -215,7 +242,7 @@ struct FuncOpConversion : public FuncOpConversionBase { // Set an attribute for maxntidx, it could be used in latter LLVM codegen // for `nvvm.annotation` metadata. newFuncOp->setAttr(NVVMMetadataField::MaxNTid, - rewriter.getIntegerAttr(i32_ty(), 32 * NumWarps)); + rewriter.getIntegerAttr(i32_ty, 32 * NumWarps)); rewriter.eraseOp(funcOp); return success(); @@ -247,19 +274,6 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> { } }; -static Value createIndexAttrConstant(OpBuilder &builder, Location loc, - Type resultType, int64_t value) { - return builder.create( - loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); -} - -static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, - LLVMTypeConverter *converter, Type ty, - int64_t value) { - return builder.create(loc, converter->convertType(ty), - builder.getIntegerAttr(ty, value)); -} - Value getStructFromElements(Location loc, ValueRange resultVals, ConversionPatternRewriter &rewriter, Type structType) { @@ -272,42 +286,36 @@ Value getStructFromElements(Location loc, ValueRange resultVals, } template -static SmallVector getMultiDimIndex(T linear_index, ArrayRef shape) { - // sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1} +static SmallVector getMultiDimIndex(T linearIndex, ArrayRef shape) { + // shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1} size_t rank = shape.size(); - T acc_mul = 1; - for (size_t i = 1; i < rank; ++i) { - acc_mul *= shape[i]; - } - T linear_remain = linear_index; - SmallVector multidim_index(rank); + T accMul = product(shape.drop_front()); + T linearRemain = linearIndex; + SmallVector multiDimIndex(rank); for (size_t i = 0; i < rank; ++i) { - multidim_index[i] = linear_remain / acc_mul; - linear_remain = linear_remain % acc_mul; + multiDimIndex[i] = linearRemain / accMul; + linearRemain = linearRemain % accMul; if (i != (rank - 1)) { - acc_mul = acc_mul / shape[i + 1]; + accMul = accMul / shape[i + 1]; } } - return multidim_index; + return multiDimIndex; } template -static T getLinearIndex(ArrayRef multidim_index, ArrayRef shape) { - assert(multidim_index.size() == shape.size()); - // sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1} +static T getLinearIndex(ArrayRef multiDimIndex, ArrayRef shape) { + assert(multiDimIndex.size() == shape.size()); + // shape: {a, b, c, d} -> accMul: {b*c*d, c*d, d, 1} size_t rank = shape.size(); - T acc_mul = 1; - for (size_t i = 1; i < rank; ++i) { - acc_mul *= shape[i]; - } - T linear_index = 0; + T accMul = product(shape.drop_front()); + T linearIndex = 0; for (size_t i = 0; i < rank; ++i) { - linear_index += multidim_index[i] * acc_mul; + linearIndex += multiDimIndex[i] * accMul; if (i != (rank - 1)) { - acc_mul = acc_mul / shape[i + 1]; + accMul = accMul / shape[i + 1]; } } - return linear_index; + return linearIndex; } struct ConvertTritonGPUOpToLLVMPatternBase { @@ -352,23 +360,15 @@ public: return threadId; } - Value createIndexConst(ConversionPatternRewriter &rewriter, Location loc, - 
int64_t value) const { - return rewriter.create( - loc, this->getTypeConverter()->getIndexType(), - rewriter.getIntegerAttr(rewriter.getIndexType(), value)); - } - + // Convert an \param index to a multi-dim coordinate given \param shape and + // \param order. SmallVector delinearize(ConversionPatternRewriter &rewriter, Location loc, Value linear, ArrayRef shape, ArrayRef order) const { unsigned rank = shape.size(); assert(rank == order.size()); - SmallVector reordered(rank); - for (unsigned i = 0; i < rank; ++i) { - reordered[i] = shape[order[i]]; - } + auto reordered = reorder(shape, order); auto reorderedMultiDim = delinearize(rewriter, loc, linear, reordered); SmallVector multiDim(rank); for (unsigned i = 0; i < rank; ++i) { @@ -388,9 +388,7 @@ public: } else { Value remained = linear; for (auto &&en : llvm::enumerate(llvm::reverse(shape.drop_front()))) { - Value dimSize = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), - en.value()); + Value dimSize = idx_val(en.value()); multiDim[rank - 1 - en.index()] = urem(remained, dimSize); remained = udiv(remained, dimSize); } @@ -402,20 +400,19 @@ public: Value linearize(ConversionPatternRewriter &rewriter, Location loc, ArrayRef multiDim, ArrayRef shape) const { int rank = multiDim.size(); - Value linear = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), 0); + Value linear = idx_val(0); if (rank > 0) { linear = multiDim.front(); - for (auto &&z : llvm::zip(multiDim.drop_front(), shape.drop_front())) { - Value dimSize = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), - std::get<1>(z)); - linear = add(mul(linear, dimSize), std::get<0>(z)); + for (auto [dim, shape] : + llvm::zip(multiDim.drop_front(), shape.drop_front())) { + Value dimSize = idx_val(shape); + linear = add(mul(linear, dimSize), dim); } } return linear; } + // Get an index-base for each dimension for a \param blocked_layout. 
SmallVector emitBaseIndexForBlockedLayout(Location loc, ConversionPatternRewriter &rewriter, @@ -423,7 +420,7 @@ public: ArrayRef shape) const { auto llvmIndexTy = this->getTypeConverter()->getIndexType(); Value threadId = getThreadId(rewriter, loc); - Value warpSize = createIndexAttrConstant(rewriter, loc, llvmIndexTy, 32); + Value warpSize = idx_val(32); Value laneId = urem(threadId, warpSize); Value warpId = udiv(threadId, warpSize); auto sizePerThread = blocked_layout.getSizePerThread(); @@ -444,19 +441,13 @@ public: unsigned maxWarps = ceil(shape[k], sizePerThread[k] * threadsPerWarp[k]); unsigned maxThreads = ceil(shape[k], sizePerThread[k]); - multiDimWarpId[k] = - urem(multiDimWarpId[k], - createIndexAttrConstant(rewriter, loc, llvmIndexTy, maxWarps)); - multiDimThreadId[k] = - urem(multiDimThreadId[k], - createIndexAttrConstant(rewriter, loc, llvmIndexTy, maxThreads)); + multiDimWarpId[k] = urem(multiDimWarpId[k], idx_val(maxWarps)); + multiDimThreadId[k] = urem(multiDimThreadId[k], idx_val(maxThreads)); // multiDimBase[k] = (multiDimThreadId[k] + // multiDimWarpId[k] * threadsPerWarp[k]) * // sizePerThread[k]; - Value threadsPerWarpK = createIndexAttrConstant( - rewriter, loc, llvmIndexTy, threadsPerWarp[k]); - Value sizePerThreadK = - createIndexAttrConstant(rewriter, loc, llvmIndexTy, sizePerThread[k]); + Value threadsPerWarpK = idx_val(threadsPerWarp[k]); + Value sizePerThreadK = idx_val(sizePerThread[k]); multiDimBase[k] = mul(sizePerThreadK, add(multiDimThreadId[k], mul(multiDimWarpId[k], threadsPerWarpK))); @@ -496,25 +487,22 @@ public: if (auto blockedParent = parent.dyn_cast()) { SmallVector paddedShape(rank + 1); for (unsigned d = 0; d < rank + 1; ++d) { - if (d < dim) { + if (d < dim) paddedShape[d] = shape[d]; - } else if (d == dim) { + else if (d == dim) paddedShape[d] = 1; - } else { + else paddedShape[d] = shape[d - 1]; - } } auto paddedIndices = emitIndicesForBlockedLayout( loc, rewriter, blockedParent, paddedShape); unsigned numIndices = paddedIndices.size(); SmallVector> resultIndices(numIndices); - for (unsigned i = 0; i < numIndices; ++i) { - for (unsigned d = 0; d < rank + 1; ++d) { - if (d != dim) { + for (unsigned i = 0; i < numIndices; ++i) + for (unsigned d = 0; d < rank + 1; ++d) + if (d != dim) resultIndices[i].push_back(paddedIndices[i][d]); - } - } - } + return resultIndices; } else if (auto sliceParent = parent.dyn_cast()) { @@ -529,7 +517,8 @@ public: } } - // Emit indices calculation within each ConversionPattern + // Emit indices calculation within each ConversionPattern, and returns a + // [elemsPerThread X rank] index matrix. // TODO: [goostavz] Double confirm the redundant indices calculations will // be eliminated in the consequent MLIR/LLVM optimization. We might // implement a indiceCache if necessary. 
@@ -542,23 +531,16 @@ public: auto threadsPerWarp = blockedLayout.getThreadsPerWarp(); auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); unsigned rank = shape.size(); - SmallVector shapePerCTA(rank); - for (unsigned k = 0; k < rank; ++k) { - shapePerCTA[k] = sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]; - } + SmallVector shapePerCTA = getShapePerCTA(blockedLayout); // step 1, delinearize threadId to get the base index auto multiDimBase = emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape); // step 2, get offset of each element - unsigned elemsPerThread = 1; + unsigned elemsPerThread = blockedLayout.getElemsPerThread(shape); SmallVector> offset(rank); - SmallVector multiDimElemsPerThread(rank); for (unsigned k = 0; k < rank; ++k) { - multiDimElemsPerThread[k] = - ceil(shape[k], shapePerCTA[k]) * sizePerThread[k]; - elemsPerThread *= multiDimElemsPerThread[k]; // 1 block in minimum if shape[k] is less than shapePerCTA[k] for (unsigned blockOffset = 0; blockOffset < ceil(shape[k], shapePerCTA[k]); @@ -574,34 +556,29 @@ public: threadsPerWarp[k] + threadOffset * sizePerThread[k] + elemOffset); } - // step 3, add offset to base, and reorder the sequence of indices, - // to guarantee that elems in a same sizePerThread are adjacent in - // order - SmallVector> multiDimIdx(elemsPerThread); - unsigned accumSizePerThread = - std::accumulate(sizePerThread.begin(), sizePerThread.end(), 1, - std::multiplies()); + // step 3, add offset to base, and reorder the sequence of indices to + // guarantee that elems in the same sizePerThread are adjacent in order + SmallVector> multiDimIdx(elemsPerThread, + SmallVector(rank)); + unsigned totalSizePerThread = product(sizePerThread); SmallVector threadsPerDim(rank); - for (unsigned k = 0; k < rank; ++k) { + for (unsigned k = 0; k < rank; ++k) threadsPerDim[k] = ceil(shape[k], sizePerThread[k]); - } + for (unsigned n = 0; n < elemsPerThread; ++n) { - unsigned linearNanoTileId = n / accumSizePerThread; - unsigned linearElemsInNanoTileId = n % accumSizePerThread; + unsigned linearNanoTileId = n / totalSizePerThread; + unsigned linearNanoTileElemId = n % totalSizePerThread; SmallVector multiDimNanoTileId = getMultiDimIndex(linearNanoTileId, threadsPerDim); - SmallVector multiElemsInNanoTileId = - getMultiDimIndex(linearElemsInNanoTileId, sizePerThread); - multiDimIdx[n].resize(rank); + SmallVector multiDimNanoTileElemId = + getMultiDimIndex(linearNanoTileElemId, sizePerThread); for (unsigned k = 0; k < rank; ++k) { unsigned reorderedMultiDimId = multiDimNanoTileId[k] * (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) + - multiElemsInNanoTileId[k]; + multiDimNanoTileElemId[k]; multiDimIdx[n][k] = - add(multiDimBase[k], - createIndexAttrConstant(rewriter, loc, llvmIndexTy, - offset[k][reorderedMultiDimId])); + add(multiDimBase[k], idx_val(offset[k][reorderedMultiDimId])); } } @@ -617,7 +594,7 @@ public: assert(bufferId != Allocation::InvalidBufferId && "BufferId not found"); size_t offset = allocation->getOffset(bufferId); auto llvmIndexTy = this->getTypeConverter()->getIndexType(); - Value offVal = createIndexAttrConstant(rewriter, loc, llvmIndexTy, offset); + Value offVal = idx_val(offset); Value base = gep(ptrTy, smem, offVal); return base; } @@ -636,19 +613,17 @@ protected: Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, TypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc) { - auto tensorTy = resType.cast(); auto layout = tensorTy.getEncoding(); auto srcType = 
typeConverter->convertType(elemType); auto llSrc = bit_cast(srcType, constVal); - size_t numElemsPerThread = getElemsPerThread(layout, tensorTy.getShape()); - llvm::SmallVector elems(numElemsPerThread, llSrc); + size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape()); + llvm::SmallVector elems(elemsPerThread, llSrc); llvm::SmallVector elemTypes(elems.size(), srcType); auto structTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes); - auto llStruct = getStructFromElements(loc, elems, rewriter, structTy); - return llStruct; + return getStructFromElements(loc, elems, rewriter, structTy); } struct SplatOpConversion @@ -745,7 +720,7 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase { assert(layout && "unexpected layout in getLayout"); auto shape = ty.getShape(); unsigned valueElems = layout.getElemsPerThread(shape); - return std::make_tuple(layout, valueElems); + return {layout, valueElems}; } unsigned getAlignment(Value val, const BlockedEncodingAttr &layout) const { @@ -864,14 +839,14 @@ struct StoreOpConversion llvm::SmallVector asmArgs; Type valArgTy = IntegerType::get(ctx, width); - auto wordTy = VectorType::get(wordNElems, valueElemTy); + auto wordTy = vec_ty(valueElemTy, wordNElems); auto *asmArgList = ptxBuilder.newListOperand(); - for (int wordIdx = 0; wordIdx < nWords; wordIdx++) { + for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) { // llWord is a width-len composition Value llWord = rewriter.create(loc, wordTy); // Insert each value element to the composition - for (int elemIdx = 0; elemIdx < wordNElems; elemIdx++) { + for (int elemIdx = 0; elemIdx < wordNElems; ++elemIdx) { const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx; assert(elemOffset < valueElems.size()); Value elem = valueElems[elemOffset]; @@ -894,10 +869,7 @@ struct StoreOpConversion // TODO(Superjomn) Need to check masks before vectorize the load for all // the values share one predicate? Here assume all the mask values are // the same. - Value maskVal = - llMask ? maskElems[vecStart] - : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(), - rewriter.getIntegerType(1), 1); + Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1); ptxStoreInstr.global().b(width).v(nWords); auto *asmAddr = @@ -906,22 +878,12 @@ struct StoreOpConversion ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b"); Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1)); llvm::SmallVector argTys({boolTy, ptr.getType()}); - for (int i = 0; i < nWords; i++) + for (int i = 0; i < nWords; ++i) argTys.push_back(valArgTy); auto ASMReturnTy = LLVM::LLVMVoidType::get(ctx); - auto inlineAsm = rewriter.create( - loc, ASMReturnTy, ptxBuilder.getAllMLIRArgs(), // operands - ptxBuilder.dump(), // asm_string - ptxBuilder.getConstraints(), // constraints - // TODO(Superjomn) determine the side effect. - true, // has_side_effects - false, // is_align_stack - LLVM::AsmDialectAttr::get(ctx, - LLVM::AsmDialect::AD_ATT), // asm_dialect - ArrayAttr::get(ctx, {}) // operand_attrs - ); + ptxBuilder.launch(rewriter, loc, ASMReturnTy); } rewriter.eraseOp(op); return success(); @@ -1183,10 +1145,7 @@ struct LoadOpConversion // TODO(Superjomn) Need to check masks before vectorize the load for all // the values share one predicate? Here assume all the mask values are // the same. - Value pred = - mask ? maskElems[vecStart] - : createLLVMIntegerConstant(rewriter, loc, getTypeConverter(), - rewriter.getIntegerType(1), 1); + Value pred = mask ? 
maskElems[vecStart] : int_val(1, 1); const std::string readConstraint = (width == 64) ? "l" : ((width == 32) ? "r" : "c"); @@ -1195,7 +1154,7 @@ struct LoadOpConversion // prepare asm operands auto *dstsOpr = ptxBuilder.newListOperand(); - for (int wordIdx = 0; wordIdx < nWords; wordIdx++) { + for (int wordIdx = 0; wordIdx < nWords; ++wordIdx) { auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations dstsOpr->listAppend(opr); } @@ -1228,7 +1187,7 @@ struct LoadOpConversion SmallVector others; if (other) { - for (size_t ii = 0; ii < nWords; ii++) { + for (size_t ii = 0; ii < nWords; ++ii) { PTXInstr &mov = *ptxBuilder.create<>("mov"); mov.o("u", width); @@ -1236,7 +1195,7 @@ struct LoadOpConversion auto vecTy = LLVM::getFixedVectorType(valueElemTy, size); Value v = rewriter.create(loc, vecTy); - for (size_t s = 0; s < size; s++) { + for (size_t s = 0; s < size; ++s) { Value falseVal = otherElems[vecStart + ii * size + s]; Value sVal = createIndexAttrConstant( rewriter, loc, this->getTypeConverter()->getIndexType(), s); @@ -1267,20 +1226,13 @@ struct LoadOpConversion // TODO: if (has_l2_evict_policy) auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), LLVM::AsmDialect::AD_ATT); - auto inlineAsmOp = rewriter.create( - loc, retTy, /*operands=*/ptxBuilder.getAllMLIRArgs(), - /*asm_string=*/ptxBuilder.dump(), - /*constraints=*/ptxBuilder.getConstraints(), - /*has_side_effects=*/true, - /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr, - /*operand_attrs=*/ArrayAttr()); - Value ret = inlineAsmOp.getResult(0); + Value ret = ptxBuilder.launch(rewriter, loc, retTy); // --- // extract and store return values // --- SmallVector rets; - for (unsigned int ii = 0; ii < nWords; ii++) { + for (unsigned int ii = 0; ii < nWords; ++ii) { Value curr; if (retTy.isa()) { curr = extract_val(IntegerType::get(getContext(), width), ret, @@ -1293,8 +1245,8 @@ struct LoadOpConversion curr); rets.push_back(curr); } - int tmp = (width / valueElemNbits); - for (size_t ii = 0; ii < vec; ii++) { + int tmp = width / valueElemNbits; + for (size_t ii = 0; ii < vec; ++ii) { Value vecIdx = createIndexAttrConstant( rewriter, loc, this->getTypeConverter()->getIndexType(), ii % tmp); Value loaded = extract_element(valueElemTy, rets[ii / tmp], vecIdx); @@ -1480,6 +1432,7 @@ public: (!dstLayout.isa() && !dstLayout.isa())) { // TODO: to be implemented + llvm::errs() << "Unsupported ConvertLayout found"; return failure(); } auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType()); @@ -1494,11 +1447,11 @@ public: SmallVector outNumCTAsEachRep(rank); SmallVector inNumCTAs(rank); SmallVector outNumCTAs(rank); + auto srcShapePerCTA = getShapePerCTA(srcLayout); + auto dstShapePerCTA = getShapePerCTA(dstLayout); for (unsigned d = 0; d < rank; ++d) { - unsigned inPerCTA = - std::min(unsigned(shape[d]), getShapePerCTA(srcLayout, d)); - unsigned outPerCTA = - std::min(unsigned(shape[d]), getShapePerCTA(dstLayout, d)); + unsigned inPerCTA = std::min(shape[d], srcShapePerCTA[d]); + unsigned outPerCTA = std::min(shape[d], dstShapePerCTA[d]); unsigned maxPerCTA = std::max(inPerCTA, outPerCTA); numReplicates[d] = ceil(shape[d], maxPerCTA); inNumCTAsEachRep[d] = maxPerCTA / inPerCTA; @@ -1579,9 +1532,8 @@ private: auto accumSizePerThread = product(sizePerThread); auto llvmIndexTy = getTypeConverter()->getIndexType(); SmallVector numCTAs(rank); - SmallVector shapePerCTA(rank); + auto shapePerCTA = getShapePerCTA(layout); for (unsigned d = 0; d < rank; ++d) { - shapePerCTA[d] = 
getShapePerCTA(layout, d); numCTAs[d] = ceil(type.getShape()[d], shapePerCTA[d]); } auto llvmElemTy = getTypeConverter()->convertType(type.getElementType()); @@ -1603,16 +1555,15 @@ private: Value warpSize = createIndexAttrConstant( rewriter, loc, this->getTypeConverter()->getIndexType(), 32); Value laneId = rewriter.create(loc, threadId, warpSize); - Value fourVal = createIndexConst(rewriter, loc, 4); + Value fourVal = idx_val(4); mmaGrpId = rewriter.create(loc, laneId, fourVal); - mmaGrpIdP8 = rewriter.create( - loc, mmaGrpId, createIndexConst(rewriter, loc, 8)); + mmaGrpIdP8 = rewriter.create(loc, mmaGrpId, idx_val(8)); Value mmaThreadIdInGrp = rewriter.create(loc, laneId, fourVal); - mmaThreadIdInGrpM2 = rewriter.create( - loc, mmaThreadIdInGrp, createIndexConst(rewriter, loc, 2)); - mmaThreadIdInGrpM2P1 = rewriter.create( - loc, mmaThreadIdInGrpM2, createIndexConst(rewriter, loc, 1)); + mmaThreadIdInGrpM2 = + rewriter.create(loc, mmaThreadIdInGrp, idx_val(2)); + mmaThreadIdInGrpM2P1 = + rewriter.create(loc, mmaThreadIdInGrpM2, idx_val(1)); } for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) { auto multiDimCTAInRepId = @@ -1654,7 +1605,7 @@ private: reorder(paddedRepShape, outOrd)); auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); Value ptr = gep(elemPtrTy, smemBase, offset); - auto vecTy = VectorType::get(vec, llvmElemTy); + auto vecTy = vec_ty(llvmElemTy, vec); ptr = bit_cast(LLVM::LLVMPointerType::get(vecTy, 3), ptr); if (stNotRd) { Value valVec = rewriter.create(loc, vecTy); @@ -1665,9 +1616,9 @@ private: vecTy, valVec, vals[elemId + linearCTAId * accumSizePerThread + v], vVal); } - rewriter.create(loc, valVec, ptr); + store(valVec, ptr); } else { - Value valVec = rewriter.create(loc, ptr); + Value valVec = load(ptr); for (unsigned v = 0; v < vec; ++v) { Value vVal = createIndexAttrConstant( rewriter, loc, getTypeConverter()->getIndexType(), v); @@ -1682,6 +1633,7 @@ private: /// ====================== dot codegen begin ========================== +// Data loader for mma.16816 instruction. class MMA16816SmemLoader { public: MMA16816SmemLoader(int wpt, ArrayRef order, int kOrder, @@ -1689,8 +1641,10 @@ public: ArrayRef matShape, int perPhase, int maxPhase, int elemBytes, ConversionPatternRewriter &rewriter, TypeConverter *typeConverter, const Location &loc) - : wpt(wpt), order(order), kOrder(kOrder), tileShape(tileShape), - instrShape(instrShape), matShape(matShape), perPhase(perPhase), + : wpt(wpt), order(order.begin(), order.end()), kOrder(kOrder), + tileShape(tileShape.begin(), tileShape.end()), + instrShape(instrShape.begin(), instrShape.end()), + matShape(matShape.begin(), matShape.end()), perPhase(perPhase), maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter), typeConverter(typeConverter), loc(loc), ctx(rewriter.getContext()) { cMatShape = matShape[order[0]]; @@ -1722,7 +1676,7 @@ public: loadStrideInMat[kOrder] = 2; // instrShape[kOrder] / matShape[kOrder], always 2 loadStrideInMat[kOrder ^ 1] = - wpt * (instrShape[order[1]] / matShape[order[1]]); + wpt * (instrShape[kOrder ^ 1] / matShape[kOrder ^ 1]); pLoadStrideInMat = loadStrideInMat[order[0]]; sMatStride = @@ -1753,8 +1707,6 @@ public: // Compute the offset to the matrix this thread(indexed by warpOff and lane) // mapped to. 
SmallVector computeLdmatrixMatOffs(Value warpId, Value lane) { - MLIRContext *ctx = warpId.getContext(); - // 4x4 matrices Value c = urem(lane, i32_val(8)); Value s = udiv(lane, i32_val(8)); // sub-warp-id @@ -1895,6 +1847,7 @@ public: int k = matIdx[kOrder]; int ptrIdx{-1}; + if (canUseLdmatrix) ptrIdx = matIdx[order[0]] / (instrShape[order[0]] / matShape[order[0]]); else if (elemBytes == 4 && needTrans) // tf32 & trans @@ -1904,7 +1857,9 @@ public: else llvm::report_fatal_error("unsupported mma type found"); - // prefetch logic removed here. + // The main difference with the original triton code is we removed the + // prefetch-related logic here for the upstream optimizer phase should take + // care with it, and that is transparent in dot conversion. auto getPtr = [&](int idx) { return ptrs[idx]; }; Value ptr = getPtr(ptrIdx); @@ -1915,11 +1870,8 @@ public: matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes; PTXBuilder builder; - auto resArgs = builder.newListOperand(); - // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a thread. - for (int i = 0; i < 4; i++) - resArgs->listAppend(builder.newOperand("=r")); + auto resArgs = builder.newListOperand(4, "=r"); auto addrArg = builder.newAddrOperand(ptr, "r", sOffset); auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4") @@ -1927,46 +1879,127 @@ public: .o("shared.b16"); ldmatrix(resArgs, addrArg); - auto inlineAsm = rewriter.create( - loc, ldmatrixRetTy, builder.getAllMLIRArgs(), // operands - builder.dump(), // asm_string - builder.getConstraints(), // constraints - true, // has_side_effects - false, // is_align_stack - LLVM::AsmDialectAttr::get(ctx, - LLVM::AsmDialect::AD_ATT), // asm_dialect - ArrayAttr::get(ctx, {}) // operand_attrs - ); + // The result type is 4xi32, each i32 is composed of 2xf16 + // elements(adjacent two columns in a row) + Value resV4 = builder.launch(rewriter, loc, ldmatrixRetTy); auto getIntAttr = [&](int v) { - return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty(), v)}); + return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)}); }; - Value resV4 = inlineAsm.getRes(); // 4xi32, each is composed of 2xf16 - // elements(adjacent columns in a row) + Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2); - Type fp16x2Ty = VectorType::get({2}, type::f16Ty(ctx)); - - return std::make_tuple(extract_val(fp16x2Ty, resV4, getIntAttr(0)), - extract_val(fp16x2Ty, resV4, getIntAttr(1)), - extract_val(fp16x2Ty, resV4, getIntAttr(2)), - extract_val(fp16x2Ty, resV4, getIntAttr(3))); + return {extract_val(fp16x2Ty, resV4, getIntAttr(0)), + extract_val(fp16x2Ty, resV4, getIntAttr(1)), + extract_val(fp16x2Ty, resV4, getIntAttr(2)), + extract_val(fp16x2Ty, resV4, getIntAttr(3))}; } else if (elemBytes == 4 && needTrans) { // Use lds.32 to load tf32 matrices - assert(false && "Not implemented yet"); + Value ptr2 = getPtr(ptrIdx + 1); + assert(sMatStride == 1); + int sOffsetElem = + matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride; + int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride; + + Value elems[4]; + Type elemTy = type::f32Ty(ctx); + if (kOrder == 1) { + elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem))); + elems[1] = load(gep(elemTy, ptr2, i32_val(sOffsetElem))); + elems[2] = + load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem))); + elems[3] = + load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem))); + } else { + elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem))); + elems[2] = load(gep(elemTy, ptr2, i32_val(sOffsetElem))); + elems[1] = + 
load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem))); + elems[3] = + load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem))); + } + + return {elems[0], elems[1], elems[2], elems[3]}; } else if (elemBytes == 1 && needTrans) { - assert(false && "Not implemented yet"); + std::array, 2> ptrs; + ptrs[0] = { + getPtr(ptrIdx), + getPtr(ptrIdx + 1), + getPtr(ptrIdx + 2), + getPtr(ptrIdx + 3), + }; + + ptrs[1] = { + getPtr(ptrIdx + 4), + getPtr(ptrIdx + 5), + getPtr(ptrIdx + 6), + getPtr(ptrIdx + 7), + }; + + assert(sMatStride == 1); + int sOffsetElem = + matIdx[order[1]] * (sMatStride * sMatShape) * sTileStride; + int sOffsetArrElem = 1 * (sMatStride * sMatShape) * sTileStride; + + std::array i8v4Elems; + std::array i32Elems; + i8v4Elems.fill( + rewriter.create(loc, vec_ty(type::i8Ty(ctx), 4))); + + Value i8Elems[4][4]; + Type elemTy = type::i8Ty(ctx); + if (kOrder == 1) { + Value offset = i32_val(sOffsetElem); + + for (int i = 0; i < 2; ++i) + for (int j = 0; j < 4; ++j) + i8Elems[i][j] = load(gep(elemTy, ptrs[i][j], offset)); + + offset = i32_val(sOffsetElem + sOffsetArrElem); + for (int i = 2; i < 4; ++i) + for (int j = 0; j < 4; ++j) + i8Elems[i][j] = load(gep(elemTy, ptrs[i - 2][j], offset)); + + for (int m = 0; m < 4; ++m) { + for (int e = 0; e < 4; ++e) + i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m], + i8Elems[m][e], i32_val(e)); + i32Elems[m] = bit_cast(i32_ty, i8v4Elems[m]); + } + } else { // k first + Value offset = i32_val(sOffsetElem); + for (int j = 0; j < 4; ++j) + i8Elems[0][j] = load(gep(elemTy, ptrs[0][j], offset)); + for (int j = 0; j < 4; ++j) + i8Elems[2][j] = load(gep(elemTy, ptrs[1][j], offset)); + offset = i32_val(sOffsetElem + sOffsetArrElem); + for (int j = 0; j < 4; ++j) + i8Elems[1][j] = load(gep(elemTy, ptrs[0][j], offset)); + for (int j = 0; j < 4; ++j) + i8Elems[3][j] = load(gep(elemTy, ptrs[1][j], offset)); + + for (int m = 0; m < 4; ++m) { + for (int e = 0; e < 4; ++e) + i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m], + i8Elems[m][e], i32_val(e)); + i32Elems[m] = bit_cast(i32_ty, i8v4Elems[m]); + } + } + + return {i32Elems[0], i32Elems[1], i32Elems[2], i32Elems[3]}; } - return std::make_tuple(Value{}, Value{}, Value{}, Value{}); + + assert(false && "Invalid smem load"); + return {Value{}, Value{}, Value{}, Value{}}; } private: int wpt; - ArrayRef order; + SmallVector order; int kOrder; - ArrayRef tileShape; - ArrayRef instrShape; - ArrayRef matShape; + SmallVector tileShape; + SmallVector instrShape; + SmallVector matShape; int perPhase; int maxPhase; int elemBytes; @@ -2157,8 +2190,8 @@ struct DotOpConversionHelper { // The type of a matrix that loaded by either a ldmatrix or composed lds. 
Type getMatType() const { Type fp32Ty = type::f32Ty(ctx); - Type fp16x2Ty = VectorType::get({2}, type::f16Ty(ctx)); - Type bf16x2Ty = VectorType::get({2}, type::bf16Ty(ctx)); + Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2); + Type bf16x2Ty = vec_ty(type::bf16Ty(ctx), 2); // floating point types Type fp16x2Pack4Ty = LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp16x2Ty)); @@ -2167,7 +2200,7 @@ struct DotOpConversionHelper { Type fp32Pack4Ty = LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp32Ty)); // integer types - Type i8x4Ty = VectorType::get({4}, type::i8Ty(ctx)); + Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4); Type i8x4Pack4Ty = LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, i8x4Ty)); Type i32Pack4Ty = LLVM::LLVMStructType::getLiteral( @@ -2189,6 +2222,23 @@ struct DotOpConversionHelper { return Type{}; } + Type getLoadElemTy() { + switch (mmaType) { + case TensorCoreType::FP32_FP16_FP16_FP32: + return vec_ty(type::f16Ty(ctx), 2); + case TensorCoreType::FP32_BF16_BF16_FP32: + return vec_ty(type::bf16Ty(ctx), 2); + case TensorCoreType::FP32_TF32_TF32_FP32: + return type::f32Ty(ctx); + case TensorCoreType::INT32_INT8_INT8_INT32: + return type::i32Ty(ctx); + default: + llvm::report_fatal_error("Unsupported mma type found"); + } + + return Type{}; + } + Type getMmaRetType() const { Type fp32Ty = type::f32Ty(ctx); Type i32Ty = type::i32Ty(ctx); @@ -2375,9 +2425,10 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter, const int numRepN = std::max(dShape[1] / (wpt[1] * mmaInstrN), 1); const int numRepK = std::max(NK / mmaInstrK, 1); - Value head = getThreadId(rewriter, loc); - Value lane = urem(head, i32_val(32)); - Value warp = udiv(head, i32_val(32)); + Value _32 = i32_val(32); + Value thread = getThreadId(rewriter, loc); + Value lane = urem(thread, _32); + Value warp = udiv(thread, _32); Value warpMN = udiv(warp, i32_val(wpt[0])); Value warpM = urem(warp, i32_val(wpt[0])); Value warpN = urem(warpMN, i32_val(wpt[1])); @@ -2389,7 +2440,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter, std::map, Value> hb; // the original register_lds2, but discard the prefetch logic. - auto ld2 = [&](decltype(ha) &vals, int mn, int k, Value val) { + auto ld2 = [](decltype(ha) &vals, int mn, int k, Value val) { vals[{mn, k}] = val; }; @@ -2405,6 +2456,7 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter, const int perPhase = sharedLayout.getPerPhase(); const int maxPhase = sharedLayout.getMaxPhase(); const int elemBytes = tensorTy.getElementTypeBitWidth() / 8; + auto order = sharedLayout.getOrder(); MMA16816SmemLoader loader(wpt, sharedLayout.getOrder(), kOrder, tensorTy.getShape() /*tileShape*/, instrShape, @@ -2417,34 +2469,56 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter, Type smemPtrTy = helper.getShemPtrTy(); auto smemBase = getSharedMemoryBase(loc, rewriter, tensor); - for (int i = 0; i < numPtrs; i++) { + for (int i = 0; i < numPtrs; ++i) { ptrs[i] = bit_cast( smemPtrTy, gep(smemBase.getType(), smemBase, ValueRange({offs[i]}))); } + bool needTrans = kOrder != order[0]; + // (a, b) is the coordinate. - auto load = [&, loader, ptrs, offs](int a, int b) { + auto load = [&, loader, ptrs, offs, needTrans](int a, int b) { auto [ha0, ha1, ha2, ha3] = loader.loadX4( (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? 
b : a /*mat1*/, offs, ptrs, helper.getMatType(), helper.getShemPtrTy()); - ld2(vals, a, b, ha0); - ld2(vals, a + 1, b, ha1); - ld2(vals, a, b + 1, ha2); - ld2(vals, a + 1, b + 1, ha3); + if (!needTrans) { + ld2(vals, a, b, ha0); + ld2(vals, a + 1, b, ha1); + ld2(vals, a, b + 1, ha2); + ld2(vals, a + 1, b + 1, ha3); + } else { + ld2(vals, a, b, ha0); + ld2(vals, a + 1, b, ha2); + ld2(vals, a, b + 1, ha1); + ld2(vals, a + 1, b + 1, ha3); + } }; return load; }; - std::function loadA = getLoadMatrixFn( - A, mmaLayout.getWarpsPerCTA()[0] /*wpt*/, 1 /*kOrder*/, - {mmaInstrM, mmaInstrK} /*instrShpae*/, - {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/); + std::function loadA; std::function loadB = getLoadMatrixFn( B, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, 0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShpae*/, {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/); + if (aTensorTy.getEncoding() + .dyn_cast()) { // load from smem + loadA = getLoadMatrixFn(A, mmaLayout.getWarpsPerCTA()[0] /*wpt*/, + 1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShpae*/, + {matShapeM, matShapeK} /*matShape*/, + warpM /*warpId*/, ha /*vals*/); + } else if (auto blockedLayout = + aTensorTy.getEncoding() + .dyn_cast()) { // load from registers, + // used in gemm fuse + // TODO(Superjomn) Port the logic. + assert(false && "Loading A from register is not supported yet."); + } else { + assert(false && "A's layout is not supported."); + } + const unsigned mStride = numRepN * 2; SmallVector fc(numRepM * mStride + numRepN * 2); auto callMma = [&](unsigned m, unsigned n, unsigned k) { @@ -2452,44 +2526,36 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter, auto &mma = *builder.create(helper.getMmaInstr().str()); - auto retArgs = builder.newListOperand(); - for (int i = 0; i < 4; ++i) - retArgs->listAppend(builder.newOperand("=r")); - auto aArg0 = builder.newOperand(ha[{m, k}], "r"); - auto aArg1 = builder.newOperand(ha[{m + 1, k}], "r"); - auto aArg2 = builder.newOperand(ha[{m, k + 1}], "r"); - auto aArg3 = builder.newOperand(ha[{m + 1, k}], "r"); + auto retArgs = builder.newListOperand(4, "=r"); - auto bArg0 = builder.newOperand(ha[{n, k}], "r"); - auto bArg1 = builder.newOperand(ha[{n, k + 1}], "r"); + auto aArgs = builder.newListOperand({ + {ha[{m, k}], "r"}, + {ha[{m + 1, k}], "r"}, + {ha[{m, k + 1}], "r"}, + {ha[{m + 1, k + 1}], "r"}, + }); + + auto bArgs = + builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}}); // Currently, we only support a SplatLike C. For the other cases, e.g., C in // shared layout or blocked layout, we will support them by expanding // convert_layout. 
auto hc = helper.loadSplatLikeC(C, loc, rewriter); assert(hc.size() == 4UL && "Only splat-like C is supported now"); - auto cArg0 = builder.newOperand(hc[0], "0"); // reuse the output registers - auto cArg1 = builder.newOperand(hc[1], "1"); - auto cArg2 = builder.newOperand(hc[2], "2"); - auto cArg3 = builder.newOperand(hc[3], "3"); - mma({retArgs, aArg0, aArg1, aArg2, aArg3, bArg0, bArg1, cArg0, cArg1, cArg2, - cArg3}); + auto cArgs = builder.newListOperand(); + for (int i = 0; i < hc.size(); ++i) { + cArgs->listAppend(builder.newOperand( + hc[i], std::to_string(i))); // reuse the output registers + } - auto inlineAsm = rewriter.create( - loc, helper.getMmaRetType(), builder.getAllMLIRArgs(), // operands - builder.dump(), // asm_string - builder.getConstraints(), // constraints - true, // has_side_effects - false, // is_align_stack - LLVM::AsmDialectAttr::get(ctx, - LLVM::AsmDialect::AD_ATT), // asm_dialect - ArrayAttr::get(ctx, {}) // operand_attrs - ); + mma(retArgs, aArgs, bArgs, cArgs); + + Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType()); - auto mmaOut = inlineAsm.getRes(); auto getIntAttr = [&](int v) { - return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty(), v)}); + return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)}); }; fc[(m + 0) * mStride + (n * 2 + 0)] = @@ -2504,13 +2570,13 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adapter, // Main program - for (unsigned k = 0; k < numRepK; k++) { - for (unsigned m = 0; m < numRepM; m++) + for (unsigned k = 0; k < numRepK; ++k) { + for (unsigned m = 0; m < numRepM; ++m) loadA(2 * m, 2 * k); for (unsigned n = 0; n < numRepN; n += 2) loadB(n, 2 * k); - for (unsigned m = 0; m < numRepM; m++) - for (unsigned n = 0; n < numRepN; n++) { + for (unsigned m = 0; m < numRepM; ++m) + for (unsigned n = 0; n < numRepN; ++n) { callMma(2 * m, n, 2 * k); } } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index bcfa3176f..9050b0485 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -72,26 +72,24 @@ SmallVector getSizePerThread(Attribute layout) { } } -unsigned getShapePerCTA(const Attribute &layout, unsigned d) { +SmallVector getShapePerCTA(const Attribute &layout) { + SmallVector shape; if (auto blockedLayout = layout.dyn_cast()) { - return blockedLayout.getSizePerThread()[d] * - blockedLayout.getThreadsPerWarp()[d] * - blockedLayout.getWarpsPerCTA()[d]; + for (int d = 0, n = blockedLayout.getOrder().size(); d < n; ++d) + shape.push_back(blockedLayout.getSizePerThread()[d] * + blockedLayout.getThreadsPerWarp()[d] * + blockedLayout.getWarpsPerCTA()[d]); } else if (auto mmaLayout = layout.dyn_cast()) { assert(mmaLayout.getVersion() == 2 && "mmaLayout version = 1 is not implemented yet"); - assert(d < 2 && "Unexpected usage of getShapePerCTA"); - if (d == 0) { - return 16 * mmaLayout.getWarpsPerCTA()[0]; - } else { - // d == 1 - return 8 * mmaLayout.getWarpsPerCTA()[1]; - } + return {16 * mmaLayout.getWarpsPerCTA()[0], + 8 * mmaLayout.getWarpsPerCTA()[1]}; } else { assert(0 && "Unimplemented usage of getShapePerCTA"); - return 0; } -}; + + return shape; +} SmallVector getOrder(const Attribute &layout) { if (auto blockedLayout = layout.dyn_cast()) { @@ -106,7 +104,7 @@ SmallVector getOrder(const Attribute &layout) { assert(0 && "Unimplemented usage of getOrder"); return {}; } -}; +} } // namespace gpu } // namespace triton @@ -180,16 +178,17 @@ SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) { unsigned 
BlockedEncodingAttr::getElemsPerThread(ArrayRef shape) const { size_t rank = shape.size(); - assert(rank == getSizePerThread().size() && + auto sizePerThread = getSizePerThread(); + auto warpsPerCTA = getWarpsPerCTA(); + auto threadsPerWarp = getThreadsPerWarp(); + assert(rank == sizePerThread.size() && "unexpected rank in BlockedEncodingAttr::getElemsPerThread"); - SmallVector elemsPerThreadPerDim(rank); + SmallVector elemsPerThread(rank); for (size_t i = 0; i < rank; ++i) { - unsigned t = - getSizePerThread()[i] * getThreadsPerWarp()[i] * getWarpsPerCTA()[i]; - elemsPerThreadPerDim[i] = - ceil(shape[i], t) * getSizePerThread()[i]; + unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i]; + elemsPerThread[i] = ceil(shape[i], t) * sizePerThread[i]; } - return product(elemsPerThreadPerDim); + return product(elemsPerThread); } unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef shape) const { @@ -216,11 +215,9 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef shape) const { } unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef shape) const { - size_t rank = shape.size(); - assert(rank == 2 && "Unexpected rank of mma layout"); - unsigned elemsCol = ceil(shape[0], 16 * getWarpsPerCTA()[0]) * 2; - unsigned elemsRow = ceil(shape[1], 8 * getWarpsPerCTA()[1]) * 2; - return elemsCol * elemsRow; + int threads = product(getWarpsPerCTA()); + int numElem = product(shape); + return numElem / threads; } unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef shape) const { diff --git a/unittest/Analysis/CMakeLists.txt b/unittest/Analysis/CMakeLists.txt index 4db4a37af..880c8117b 100644 --- a/unittest/Analysis/CMakeLists.txt +++ b/unittest/Analysis/CMakeLists.txt @@ -1,5 +1,5 @@ add_triton_ut( - NAME TritonAnalysisTests + NAME TestTritonAnalysis SRCS UtilityTest.cpp LIBS TritonAnalysis ) diff --git a/unittest/Analysis/UtilityTest.cpp b/unittest/Analysis/UtilityTest.cpp index 69a7119a4..2d25a8803 100644 --- a/unittest/Analysis/UtilityTest.cpp +++ b/unittest/Analysis/UtilityTest.cpp @@ -4,11 +4,26 @@ //===----------------------------------------------------------------------===// #include "triton/Analysis/Utility.h" -#include #include namespace mlir { -TEST(UtilityTest, DummyTest) { EXPECT_EQ(true, true); } +TEST(Analysis, reorder) { + SmallVector shape({10, 20, 30}); + { + SmallVector order({2, 1, 0}); + auto reordered = reorder(shape, order); + EXPECT_EQ(reordered[0], 30); + EXPECT_EQ(reordered[1], 20); + EXPECT_EQ(reordered[2], 10); + } + { + SmallVector order({1, 0, 2}); + auto reordered = reorder(shape, order); + EXPECT_EQ(reordered[0], 20); + EXPECT_EQ(reordered[1], 10); + EXPECT_EQ(reordered[2], 30); + } +} } // namespace mlir diff --git a/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt b/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt index 1cd7e9dc5..c9f495cf4 100644 --- a/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -1,5 +1,5 @@ add_triton_ut( - NAME PtxAsmFormatTest + NAME TestPtxAsmFormat SRCS PtxAsmFormatTest.cpp LIBS TritonGPUToLLVM )
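
For reference, a minimal sketch of how the new PTXBuilder helpers introduced by this patch (newListOperand, newAddrOperand, launch) compose, modeled on the ldmatrix path in MMA16816SmemLoader::loadX4 above. The wrapper function emitLdmatrixX4 and its parameters (ptr, sOffset, ldmatrixRetTy) are illustrative assumptions rather than code from the patch; the individual calls mirror ones that appear in the diff.

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"

using namespace mlir;
using namespace mlir::triton;

// Hypothetical helper (name and parameters are illustrative only): issue one
// ldmatrix.sync.aligned.m8n8.x4 and return the packed 4xi32 result Value.
static Value emitLdmatrixX4(ConversionPatternRewriter &rewriter, Location loc,
                            Value ptr, int sOffset, Type ldmatrixRetTy) {
  PTXBuilder builder;

  // Four "=r" outputs collected into a single list operand; this replaces the
  // old manual loop of newOperand("=r") + listAppend at every call site.
  auto *resArgs = builder.newListOperand(4, "=r");
  auto *addrArg = builder.newAddrOperand(ptr, "r", sOffset);

  auto &ldmatrix = *builder.create("ldmatrix.sync.aligned.m8n8.x4");
  ldmatrix.o("shared.b16");
  ldmatrix(resArgs, addrArg);

  // launch() now wraps the LLVM::InlineAsmOp boilerplate (asm string,
  // constraints, side-effect flag, ATT asm dialect) that each conversion
  // previously constructed inline, and returns the inline-asm result.
  return builder.launch(rewriter, loc, ldmatrixRetTy);
}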