From 031c2ae77b7b808b92647e4dd4767602793f40a7 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Tue, 1 Nov 2022 09:42:14 +0800 Subject: [PATCH] [Triton-MLIR][BACKEND] Port the mma conversion (#815) This PR does the following: - port the mma-related code and support dot conversion and convert_layout[shared->dot_op] - add a lit test for dot v1 --- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 736 +++++++++++++++--- lib/Dialect/TritonGPU/IR/Dialect.cpp | 21 +- test/Conversion/tritongpu_to_llvm.mlir | 25 + 3 files changed, 690 insertions(+), 92 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index d1fcdb92f..f1641f05c 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -56,6 +56,12 @@ Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) { IntegerAttr::get(i32ty, v)); } +Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) { + auto type = type::f32Ty(rewriter.getContext()); + return rewriter.create<LLVM::ConstantOp>(loc, type, + rewriter.getF32FloatAttr(v)); +} + // Create an index type constant. Value createIndexConstant(OpBuilder &builder, Location loc, @@ -90,7 +96,8 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, #define fmin(...) rewriter.create(loc, __VA_ARGS__) #define and_(...) rewriter.create(loc, __VA_ARGS__) #define xor_(...) rewriter.create(loc, __VA_ARGS__) -#define bitcast(...) rewriter.create<LLVM::BitcastOp>(loc, __VA_ARGS__) +#define bitcast(val__, type__) \ rewriter.create<LLVM::BitcastOp>(loc, type__, val__) #define gep(...) rewriter.create(loc, __VA_ARGS__) #define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__) #define insert_val(...) rewriter.create(loc, __VA_ARGS__) @@ -112,11 +119,13 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, #define barrier() rewriter.create(loc) #define undef(...) rewriter.create(loc, __VA_ARGS__) #define i32_ty rewriter.getIntegerType(32) +#define f16_ty rewriter.getF16Type() #define i8_ty rewriter.getIntegerType(8) #define f32_ty rewriter.getF32Type() #define vec_ty(type, num) VectorType::get(num, type) +#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__) #define void_ty(ctx) LLVM::LLVMVoidType::get(ctx) -#define struct_ty(...) LLVM::LLVMStructType::getLiteral(__VA_ARGS__) +#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__) // Creator for constant #define i32_val(...)
LLVM::createConstantI32(loc, rewriter, __VA_ARGS__) @@ -685,7 +694,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, if (tensorTy.getEncoding().isa()) { auto tensorTy = resType.cast(); auto srcType = typeConverter->convertType(elemType); - auto llSrc = bitcast(srcType, constVal); + auto llSrc = bitcast(constVal, srcType); size_t elemsPerThread = getElemsPerThread(tensorTy); llvm::SmallVector elems(elemsPerThread, llSrc); llvm::SmallVector elemTypes(elems.size(), srcType); @@ -964,7 +973,7 @@ struct LoadOpConversion rewriter, loc, this->getTypeConverter()->getIndexType(), s); v = insert_element(vecTy, v, falseVal, sVal); } - v = bitcast(IntegerType::get(getContext(), width), v); + v = bitcast(v, IntegerType::get(getContext(), width)); PTXInstr::Operand *opr{}; if (otherIsSplatConstInt) @@ -1001,9 +1010,8 @@ struct LoadOpConversion } else { curr = ret; } - curr = bitcast( - LLVM::getFixedVectorType(valueElemTy, width / valueElemNbits), - curr); + curr = bitcast(curr, LLVM::getFixedVectorType(valueElemTy, + width / valueElemNbits)); rets.push_back(curr); } int tmp = width / valueElemNbits; @@ -1101,7 +1109,7 @@ struct StoreOpConversion Value elem = valueElems[elemOffset]; if (elem.getType().isInteger(1)) elem = rewriter.create(loc, type::i8Ty(ctx), elem); - elem = bitcast(valueElemTy, elem); + elem = bitcast(elem, valueElemTy); Type u32Ty = typeConverter->convertType(type::u32Ty(ctx)); llWord = @@ -1109,7 +1117,7 @@ struct StoreOpConversion rewriter.create( loc, u32Ty, IntegerAttr::get(u32Ty, elemIdx))); } - llWord = bitcast(valArgTy, llWord); + llWord = bitcast(llWord, valArgTy); std::string constraint = (width == 64) ? "l" : ((width == 32) ? "r" : "c"); asmArgs.emplace_back(llWord, constraint); @@ -1328,7 +1336,7 @@ Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter, if (bits == 64) { Type vecTy = vec_ty(f32_ty, 2); - Value vec = bitcast(vecTy, val); + Value vec = bitcast(val, vecTy); Value val0 = extract_element(f32_ty, vec, i32_val(0)); Value val1 = extract_element(f32_ty, vec, i32_val(1)); val0 = shflSync(rewriter, loc, val0, i); @@ -1336,7 +1344,7 @@ Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter, vec = undef(vecTy); vec = insert_element(vecTy, vec, val0, i32_val(0)); vec = insert_element(vecTy, vec, val1, i32_val(1)); - return bitcast(val.getType(), vec); + return bitcast(vec, val.getType()); } PTXBuilder builder; @@ -1363,7 +1371,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic( auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); - smemBase = bitcast(elemPtrTy, smemBase); + smemBase = bitcast(smemBase, elemPtrTy); auto smemShape = getScratchConfigForReduce(op); @@ -1430,7 +1438,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic( barrier(); SmallVector resultVals(resultElems); - for (size_t i = 0; i < resultElems; i++) { + for (unsigned i = 0; i < resultElems; ++i) { SmallVector readIdx = resultIndices[i]; readIdx.insert(readIdx.begin() + axis, ints[0]); Value readOffset = linearize(rewriter, loc, readIdx, smemShape); @@ -1469,7 +1477,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3); Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); - smemBase = bitcast(elemPtrTy, smemBase); + 
smemBase = bitcast(smemBase, elemPtrTy); auto order = srcLayout.getOrder(); unsigned sizeIntraWarps = threadsPerWarp[axis]; @@ -1569,7 +1577,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast( barrier(); SmallVector resultVals(resultElems); - for (size_t i = 0; i < resultElems; i++) { + for (size_t i = 0; i < resultElems; ++i) { SmallVector readIdx = resultIndices[i]; readIdx.insert(readIdx.begin() + axis, i32_val(0)); Value readOffset = linearize(rewriter, loc, readIdx, smemShape); @@ -2136,7 +2144,7 @@ void ConvertLayoutOpConversion::processReplica( auto elemPtrTy = ptr_ty(llvmElemTy, 3); Value ptr = gep(elemPtrTy, smemBase, offset); auto vecTy = vec_ty(llvmElemTy, vec); - ptr = bitcast(ptr_ty(vecTy, 3), ptr); + ptr = bitcast(ptr, ptr_ty(vecTy, 3)); if (stNotRd) { Value valVec = undef(vecTy); for (unsigned v = 0; v < vec; ++v) { @@ -2175,7 +2183,7 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed( auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType()); Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation()); auto elemPtrTy = ptr_ty(llvmElemTy, 3); - smemBase = bitcast(elemPtrTy, smemBase); + smemBase = bitcast(smemBase, elemPtrTy); auto shape = dstTy.getShape(); unsigned rank = dstTy.getRank(); SmallVector numReplicates(rank); @@ -2234,7 +2242,8 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed( } SmallVector types(outElems, llvmElemTy); - Type structTy = struct_ty(getContext(), types); + auto *ctx = llvmElemTy.getContext(); + Type structTy = struct_ty(types); Value result = getStructFromElements(loc, outVals, rewriter, structTy); rewriter.replaceOp(op, result); @@ -2294,7 +2303,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( Value minVecVal = idx_val(minVec); Value smemBase = getSharedMemoryBase(loc, rewriter, dst); auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3); - smemBase = bitcast(elemPtrTy, smemBase); + smemBase = bitcast(smemBase, elemPtrTy); unsigned numWordsEachRep = product(wordsInEachRep); SmallVector wordVecs(numWordsEachRep); // TODO: We should get less barriers if it is handled by membar pass @@ -2350,7 +2359,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared( // step 3: store Value smemAddr = gep(elemPtrTy, smemBase, offset); - smemAddr = bitcast(ptr_ty(wordTy, 3), smemAddr); + smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3)); store(wordVecs[linearWordIdx], smemAddr); } } @@ -2693,7 +2702,7 @@ public: for (int e = 0; e < 4; ++e) i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m], i8Elems[m][e], i32_val(e)); - i32Elems[m] = bitcast(i32_ty, i8v4Elems[m]); + i32Elems[m] = bitcast(i8v4Elems[m], i32_ty); } } else { // k first Value offset = i32_val(sOffsetElem); @@ -2711,7 +2720,7 @@ public: for (int e = 0; e < 4; ++e) i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m], i8Elems[m][e], i32_val(e)); - i32Elems[m] = bitcast(i32_ty, i8v4Elems[m]); + i32Elems[m] = bitcast(i8v4Elems[m], i32_ty); } } @@ -2823,10 +2832,7 @@ private: ConversionPatternRewriter &rewriter) const; /// Convert to mma.m8n8k4 LogicalResult convertMMA884(triton::DotOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - assert(false && "Not implemented yet."); - return failure(); - } + ConversionPatternRewriter &rewriter) const; LogicalResult convertFMADot(triton::DotOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -2835,48 +2841,127 @@ private: } }; -struct DotOpConversionHelper { +// Helper 
for conversion of DotOp with mma, i.e., sm < 80 +struct DotOpMmaV1ConversionHelper { + MmaEncodingAttr mmaLayout; + ArrayRef<unsigned> wpt; + + using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>; + + explicit DotOpMmaV1ConversionHelper(MmaEncodingAttr mmaLayout) + : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {} + + int getRepM(int M) const { + return std::max<int>(M / (wpt[0] * instrShape[0]), 1); + } + int getRepN(int N) const { + return std::max<int>(N / (wpt[1] * instrShape[1]), 1); + } + int getRepK(int K) const { return std::max<int>(K / instrShape[2], 1); } + + static ArrayRef<unsigned> getMmaInstrShape() { return instrShape; } + + static Type getMmaRetType(TensorType operand) { + auto *ctx = operand.getContext(); + Type fp32Ty = type::f32Ty(ctx); + // f16*f16+f32->f32 + return struct_ty(SmallVector<Type>(8, fp32Ty)); + } + + // Number of fp16x2 elements for $a. + int numElemsPerThreadA(RankedTensorType tensorTy) const { + auto shape = tensorTy.getShape(); + auto order = getOrder(); + + bool isARow = order[0] != 0; + bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes + int packSize0 = (isARow || isAVec4) ? 1 : 2; + + SmallVector<int> fpw({2, 2, 1}); + int repM = 2 * packSize0; + int repK = 1; + int spwM = fpw[0] * 4 * repM; + SmallVector<int> rep({repM, 0, repK}); // pad N with 0 + SmallVector<int> spw({spwM, 0, 1}); // pad N with 0 + + int NK = shape[1]; + unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]); + + // NOTE: we couldn't get the vec from the shared layout. + // int vecA = sharedLayout.getVec(); + // TODO[Superjomn]: Consider the case when vecA > 4 + bool vecGt4 = false; + int elemsPerLd = vecGt4 ? 4 : 2; + return (numM / 2) * (NK / 4) * elemsPerLd; + } + + // Number of fp16x2 elements for $b. + int numElemsPerThreadB(RankedTensorType tensorTy) const { + auto shape = tensorTy.getShape(); + auto order = getOrder(); + bool isBRow = order[0] != 0; + bool isBVec4 = isBRow && shape[order[0]] <= 16; + int packSize1 = (isBRow && !isBVec4) ? 2 : 1; + SmallVector<int> fpw({2, 2, 1}); + SmallVector<int> rep({0, 2 * packSize1, 1}); // pad M with 0 + SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0 + // NOTE: we couldn't get the vec from the shared layout. + // int vecB = sharedLayout.getVec(); + // TODO[Superjomn]: Consider the case when vecB > 4 + bool vecGt4 = false; + int elemsPerLd = vecGt4 ? 4 : 2; + int NK = shape[0]; + + unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]); + return (numN / 2) * (NK / 4) * elemsPerLd; + } + + // Loading $a from smem to registers, returns an LLVM::Struct. + Value loadA(Value A, Value llA, Value thread, Value smem, Location loc, + ConversionPatternRewriter &rewriter) const; + + // Loading $b from smem to registers, returns an LLVM::Struct. + Value loadB(Value B, Value llB, Value thread, Value smem, Location loc, + ConversionPatternRewriter &rewriter) const; + + // Loading $c to registers, returns an LLVM::Struct. + Value loadC(Value C, Value llC, ConversionPatternRewriter &rewriter) const; + + static ArrayRef<unsigned> getOrder() { return mmaOrder; } + + // Compute the offsets of the matrix to load. + // Returns offsetAM, offsetAK, offsetBN, offsetBK. + // NOTE: the M (from $a) and N (from $b) information cannot be retrieved at + // the same time when called from convert_layout[shared->dot_op]; we leave + // the nonexistent part as 0 and only use the desired components of the + // composed result. This keeps the original code structure of the + // convert_mma884 method for easier debugging.
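+ // Illustrative walk-through (assuming loadA's row-major settings with
+ // wpt = {2, 2}, fpw = {2, 2, 1}, spw = {16, 0, 1}, rep = {2, 0, 1}):
+ // threadId = 35 gives lane = 3 and warp = 1, so warpMOff = 1 * 16 = 16
+ // and laneMOff = 0; since $a is row-major, offsetAM = 16 + 35 % 4 = 19
+ // and offsetAK = 0.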
+ std::tuple<Value, Value, Value, Value> + computeOffsets(Value threadId, bool isARow, bool isBRow, ArrayRef<int> fpw, + ArrayRef<int> spw, ArrayRef<int> rep, + ConversionPatternRewriter &rewriter, Location loc) const; + + // Extract the values belonging to $a or $b from an LLVM struct; the shape is + // n0 x n1. + ValueTable extractLoadedOperand(Value llStruct, int n0, int n1, + ConversionPatternRewriter &rewriter) const; + +private: + static constexpr unsigned instrShape[] = {16, 16, 4}; + static constexpr unsigned mmaOrder[] = {0, 1}; +}; + +// Helper for conversion of DotOp with mma, i.e., sm >= 80 +struct DotOpMmaV2ConversionHelper { using TensorCoreType = DotOpConversion::TensorCoreType; MmaEncodingAttr mmaLayout; MLIRContext *ctx{}; - explicit DotOpConversionHelper(MmaEncodingAttr mmaLayout) + explicit DotOpMmaV2ConversionHelper(MmaEncodingAttr mmaLayout) : mmaLayout(mmaLayout) { ctx = mmaLayout.getContext(); } - // Load SplatLike C which contains a constVal. It simply returns 4 fp32 - // constVal. - SmallVector<Value> loadSplatLikeC(Value C, Location loc, - ConversionPatternRewriter &rewriter) const { - assert(isSplatLike(C)); - - int numRes = getMmaInstrShape()[0] * getMmaInstrShape()[1] / 32; - if (auto constv = llvm::dyn_cast<arith::ConstantOp>(C.getDefiningOp())) { - if (auto attr = constv.getValue().dyn_cast<SplatElementsAttr>()) { - Type elemType = attr.getElementType(); - if (elemType.isInteger(32)) { - int v = attr.getSplatValue(); - return SmallVector<Value>(numRes, i32_val(v)); - } else if (elemType.isInteger(8)) { - int v = attr.getSplatValue(); - auto newv = rewriter.create<LLVM::ConstantOp>( - loc, elemType, IntegerAttr::get(elemType, v)); - return SmallVector<Value>(numRes, newv); - } else if (elemType.isF32()) { - int v = attr.getSplatValue(); - auto newv = rewriter.create<LLVM::ConstantOp>( - loc, elemType, FloatAttr::get(elemType, v)); - return SmallVector<Value>(numRes, newv); - } - } - } - - assert(false && "Not supported type."); - return {}; - } - void deduceMmaType(DotOp op) const { mmaType = getMmaType(op); } void deduceMmaType(Type operandTy) const { mmaType = getTensorCoreTypeFromOperand(operandTy); } @@ -2884,8 +2969,8 @@ struct DotOpConversionHelper { // Get the M and N of mat instruction shape.
static std::tuple getMatShapeMN() { - // According to DotOpConversionHelper::mmaMatShape, all the matrix shape's - // M,N are {8,8} + // According to DotOpMmaV2ConversionHelper::mmaMatShape, all the matrix + // shape's M,N are {8,8} return {8, 8}; } @@ -3143,7 +3228,7 @@ struct MMA16816ConversionHelper { Value thread, lane, warp, warpMN, warpN, warpM; - DotOpConversionHelper helper; + DotOpMmaV2ConversionHelper helper; ConversionPatternRewriter &rewriter; TypeConverter *typeConverter; Location loc; @@ -3203,22 +3288,25 @@ struct MMA16816ConversionHelper { static int getNumRepM(Type operand, int M, int wpt) { auto tensorCoreType = - DotOpConversionHelper::getTensorCoreTypeFromOperand(operand); - int mmaInstrM = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[0]; + DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand); + int mmaInstrM = + DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[0]; return std::max(M / (wpt * mmaInstrM), 1); } static int getNumRepN(Type operand, int N, int wpt) { auto tensorCoreType = - DotOpConversionHelper::getTensorCoreTypeFromOperand(operand); - int mmaInstrN = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[1]; + DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand); + int mmaInstrN = + DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[1]; return std::max(N / (wpt * mmaInstrN), 1); } static int getNumRepK_(Type operand, int K) { auto tensorCoreType = - DotOpConversionHelper::getTensorCoreTypeFromOperand(operand); - int mmaInstrK = DotOpConversionHelper::getMmaInstrShape(tensorCoreType)[2]; + DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand); + int mmaInstrK = + DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[2]; return std::max(K / mmaInstrK, 1); } @@ -3304,7 +3392,7 @@ struct MMA16816ConversionHelper { // Loading $c to registers, returns a Value. 
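+ // For instance (assuming a 128x256 f32 accumulator tensor and
+ // wpt = {2, 2}): getRepMN yields repM = 4 and repN = 16, so each thread
+ // owns fcSize = 4 * 4 * 16 = 256 accumulator values.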
Value loadC(Value tensor, Value llTensor) const { auto tensorTy = tensor.getType().cast(); - auto [repM, repN] = DotOpConversionHelper::getRepMN(tensorTy); + auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(tensorTy); size_t fcSize = 4 * repM * repN; assert(tensorTy.getEncoding().isa() && @@ -3371,7 +3459,7 @@ struct MMA16816ConversionHelper { return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)}); }; - for (int i = 0; i < 4; i++) + for (int i = 0; i < 4; ++i) fc[m * colsPerThread + 4 * n + i] = extract_val(type::f32Ty(ctx), mmaOut, getIntAttr(i)); }; @@ -3427,7 +3515,7 @@ private: Type smemPtrTy = helper.getShemPtrTy(); for (int i = 0; i < numPtrs; ++i) { ptrs[i] = - bitcast(smemPtrTy, gep(smemPtrTy, llTensor, ValueRange({offs[i]}))); + bitcast(gep(smemPtrTy, llTensor, ValueRange({offs[i]})), smemPtrTy); } auto [ha0, ha1, ha2, ha3] = loader.loadX4( @@ -3492,7 +3580,7 @@ private: int offset{}; ValueTable vals; - for (int i = 0; i < n0; i++) { + for (int i = 0; i < n0; ++i) { for (int j = 0; j < n1; j++) { vals[{2 * i, 2 * j}] = elems[offset++]; vals[{2 * i, 2 * j + 1}] = elems[offset++]; @@ -3514,20 +3602,37 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand( auto dotOperandLayout = dstTensorTy.getEncoding().cast(); + MmaEncodingAttr mmaLayout = dotOperandLayout.getParent().dyn_cast_or_null(); assert(mmaLayout); - MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc), - rewriter, getTypeConverter(), op.getLoc()); - Value res; - if (dotOperandLayout.getOpIdx() == 0) { - // operand $a - res = mmaHelper.loadA(src, adaptor.src()); - } else if (dotOperandLayout.getOpIdx() == 1) { - // operand $b - res = mmaHelper.loadB(src, adaptor.src()); + if (mmaLayout.getVersion() == 2) { + MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc), + rewriter, getTypeConverter(), + op.getLoc()); + + if (dotOperandLayout.getOpIdx() == 0) { + // operand $a + res = mmaHelper.loadA(src, adaptor.src()); + } else if (dotOperandLayout.getOpIdx() == 1) { + // operand $b + res = mmaHelper.loadB(src, adaptor.src()); + } + } else if (mmaLayout.getVersion() == 1) { + DotOpMmaV1ConversionHelper helper(mmaLayout); + if (dotOperandLayout.getOpIdx() == 0) { + // operand $a + res = helper.loadA(src, adaptor.src(), getThreadId(rewriter, loc), + adaptor.src(), loc, rewriter); + } else if (dotOperandLayout.getOpIdx() == 1) { + // operand $b + res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc), + adaptor.src(), loc, rewriter); + } + } else { + assert(false && "Unsupported mma layout found"); } rewriter.replaceOp(op, res); @@ -3571,6 +3676,424 @@ DotOpConversion::convertMMA16816(triton::DotOp op, OpAdaptor adaptor, adaptor); } +// Simply port the old code here to avoid large difference and make debugging +// and profiling easier. 
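+// As a rough sketch of the output, for a row-major $a and column-major $b the
+// builder below assembles an instruction of the form
+//   mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32
+// with eight f32 results, two f16x2 fragments for each of $a and $b, and
+// eight f32 accumulator inputs.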
+LogicalResult +DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto *ctx = op.getContext(); + auto loc = op.getLoc(); + + Value A = op.a(); + Value B = op.b(); + Value D = op.getResult(); + auto mmaLayout = D.getType() + .cast() + .getEncoding() + .cast(); + + auto ATensorTy = A.getType().cast(); + auto BTensorTy = B.getType().cast(); + auto DTensorTy = D.getType().cast(); + auto AShape = ATensorTy.getShape(); + auto BShape = BTensorTy.getShape(); + auto DShape = DTensorTy.getShape(); + auto wpt = mmaLayout.getWarpsPerCTA(); + + bool transA = op.transA(); + bool transB = op.transB(); + + bool isARow = !transA; + bool isBRow = !transB; + bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes + bool isBVec4 = isBRow && BShape[isBRow] <= 16; + int packSize0 = (isARow || isAVec4) ? 1 : 2; + int packSize1 = (isBRow && !isBVec4) ? 2 : 1; + SmallVector fpw({2, 2, 1}); + SmallVector rep({2 * packSize0, 2 * packSize1, 1}); + SmallVector spw({fpw[0] * 4 * rep[0], fpw[1] * 4 * rep[1], 1}); + + Value loadedA = adaptor.a(); + Value loadedB = adaptor.b(); + Value loadedC = adaptor.c(); + DotOpMmaV1ConversionHelper helper(mmaLayout); + + unsigned numM = rep[0] * DShape[0] / (spw[0] * wpt[0]); + unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[0]); + unsigned NK = AShape[1]; + + auto has = helper.extractLoadedOperand(loadedA, numM / 2, NK, rewriter); + auto hbs = helper.extractLoadedOperand(loadedB, numN / 2, NK, rewriter); + + size_t accSize = numM * numN; + + // initialize accumulators + SmallVector acc = getElementsFromStruct(loc, loadedC, rewriter); + + auto callMMA = [&](unsigned m, unsigned n, unsigned k) { + auto ha = has[{m, k}]; + auto hb = hbs[{n, k}]; + std::vector idx{{ + (m * 2 + 0) + (n * 4 + 0) * numM, // row0 + (m * 2 + 0) + (n * 4 + 1) * numM, + (m * 2 + 1) + (n * 4 + 0) * numM, // row1 + (m * 2 + 1) + (n * 4 + 1) * numM, + (m * 2 + 0) + (n * 4 + 2) * numM, // row2 + (m * 2 + 0) + (n * 4 + 3) * numM, + (m * 2 + 1) + (n * 4 + 2) * numM, // row3 + (m * 2 + 1) + (n * 4 + 3) * numM, + }}; + + PTXBuilder builder; + + auto *resOprs = builder.newListOperand(8, "=f"); + auto *AOprs = builder.newListOperand({ + {ha.first, "f"}, + {ha.second, "f"}, + }); + + auto *BOprs = builder.newListOperand({ + {hb.first, "f"}, + {hb.second, "f"}, + }); + auto *COprs = builder.newListOperand(); + for (int i = 0; i < 8; ++i) + COprs->listAppend(builder.newOperand(acc[idx[i]], std::to_string(i))); + + auto mma = builder.create("mma.sync.aligned.m8n8k4") + ->o(isARow ? "row" : "col") + .o(isBRow ? 
"row" : "col") + .o(".f32.f16.f16.f32"); + + mma(resOprs, AOprs, BOprs, COprs); + + Value res = builder.launch(rewriter, loc, helper.getMmaRetType(ATensorTy)); + + auto getIntAttr = [&](int v) { + return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)}); + }; + for (unsigned i = 0; i < 8; i++) + acc[idx[i]] = extract_val(f32_ty, res, getIntAttr(i)); + }; + + for (unsigned k = 0; k < NK; k += 4) + for (unsigned m = 0; m < numM / 2; ++m) + for (unsigned n = 0; n < numN / 2; ++n) { + callMMA(m, n, k); + } + + // replace with new packed result + Type structTy = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(acc.size(), type::f32Ty(ctx))); + Value res = getStructFromElements(loc, acc, rewriter, structTy); + rewriter.replaceOp(op, res); + + return success(); +} + +Value DotOpMmaV1ConversionHelper::loadA( + Value tensor, Value llTensor, Value thread, Value smem, Location loc, + ConversionPatternRewriter &rewriter) const { + auto *ctx = rewriter.getContext(); + auto tensorTy = tensor.getType().cast(); + auto shape = tensorTy.getShape(); + auto sharedLayout = tensorTy.getEncoding().cast(); + auto order = sharedLayout.getOrder(); + + bool isARow = order[0] != 0; + bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes + int packSize0 = (isARow || isAVec4) ? 1 : 2; + + SmallVector fpw({2, 2, 1}); + int repM = 2 * packSize0; + int repK = 1; + int spwM = fpw[0] * 4 * repM; + SmallVector rep({repM, 0, repK}); // pad N with 0 + SmallVector spw({spwM, 0, 1}); // pad N with 0 + + int vecA = sharedLayout.getVec(); + + int strideAM = isARow ? shape[1] : 1; + int strideAK = isARow ? 1 : shape[0]; + int strideA0 = isARow ? strideAK : strideAM; + int strideA1 = isARow ? strideAM : strideAK; + + int strideRepM = wpt[0] * fpw[0] * 8; + int strideRepK = 1; + + auto [offsetAM, offsetAK, _0, _1] = + computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc); + + // swizzling + int perPhaseA = sharedLayout.getPerPhase(); + int maxPhaseA = sharedLayout.getMaxPhase(); + int stepA0 = isARow ? strideRepK : strideRepM; + int numPtrA = std::max(2 * perPhaseA * maxPhaseA / stepA0, 1); + int NK = shape[1]; + + // pre-compute pointer lanes + Value offA0 = isARow ? offsetAK : offsetAM; + Value offA1 = isARow ? offsetAM : offsetAK; + Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA)); + SmallVector offA(numPtrA); + + for (int i = 0; i < numPtrA; i++) { + Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM))); + offA0I = udiv(offA0I, i32_val(vecA)); + offA0I = xor_(offA0I, phaseA); + offA0I = xor_(offA0I, i32_val(vecA)); + offA[i] = + add(mul(offA0I, i32_val(strideA0)), mul(offA1, i32_val(strideA1))); + } + + Type f16x2Ty = vec_ty(f16_ty, 2); + // One thread get 8 elements as result + Type retTy = + LLVM::LLVMStructType::getLiteral(ctx, SmallVector(8, type::f32Ty(ctx))); + + // prepare arguments + SmallVector ptrA(numPtrA); + + std::map, std::pair> has; + for (int i = 0; i < numPtrA; i++) + ptrA[i] = gep(ptr_ty(f16_ty), smem, offA[i]); + + auto instrShape = getMmaInstrShape(); + unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]); + + Type f16PtrTy = ptr_ty(f16_ty); + + auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) { + vals[{m, k}] = {val0, val1}; + }; + auto loadA = [&](int m, int k) { + int offidx = (isARow ? k / 4 : m) % numPtrA; + Value thePtrA = gep(f16PtrTy, smem, offA[offidx]); + + int stepAM = isARow ? m : m / numPtrA * numPtrA; + int stepAK = isARow ? 
k / (numPtrA * vecA) * (numPtrA * vecA) : k; + Value pa = gep(f16PtrTy, thePtrA, + i32_val(stepAM * strideRepM * strideAM + stepAK * strideAK)); + Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max(vecA / 2, 1)), 3); + Value ha = load(bitcast(pa, aPtrTy)); + // record lds that needs to be moved + Value ha00 = bitcast(extract_element(ha, i32_val(0)), f16x2Ty); + Value ha01 = bitcast(extract_element(ha, i32_val(1)), f16x2Ty); + ld(has, m, k, ha00, ha01); + + if (vecA > 4) { + Value ha10 = bitcast(extract_element(ha, i32_val(2)), f16x2Ty); + Value ha11 = bitcast(extract_element(ha, i32_val(3)), f16x2Ty); + if (isARow) + ld(has, m, k + 4, ha10, ha11); + else + ld(has, m + 1, k, ha10, ha11); + } + }; + + for (unsigned k = 0; k < NK; k += 4) + for (unsigned m = 0; m < numM / 2; ++m) + if (!has.count({m, k})) + loadA(m, k); + + SmallVector elems; + elems.reserve(has.size() * 2); + auto vecTy = vec_ty(f16_ty, 2); + for (auto item : has) { // has is a map, the key should be ordered. + elems.push_back(item.second.first); + elems.push_back(item.second.second); + } + + Type resTy = struct_ty(SmallVector(elems.size(), f16x2Ty)); + Value res = getStructFromElements(loc, elems, rewriter, resTy); + return res; +} + +Value DotOpMmaV1ConversionHelper::loadB( + Value tensor, Value llTensor, Value thread, Value smem, Location loc, + ConversionPatternRewriter &rewriter) const { + auto *ctx = rewriter.getContext(); + auto tensorTy = tensor.getType().cast(); + auto shape = tensorTy.getShape(); + auto sharedLayout = tensorTy.getEncoding().cast(); + auto order = sharedLayout.getOrder(); + bool isBRow = order[0] != 0; + bool isBVec4 = isBRow && shape[order[0]] <= 16; + int packSize1 = (isBRow && !isBVec4) ? 2 : 1; + SmallVector fpw({2, 2, 1}); + SmallVector rep({0, 2 * packSize1, 1}); // pad M with 0 + SmallVector spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0 + int vecB = sharedLayout.getVec(); + int strideBN = isBRow ? 1 : shape[0]; + int strideBK = isBRow ? shape[1] : 1; + int strideB0 = isBRow ? strideBN : strideBK; + int strideB1 = isBRow ? strideBK : strideBN; + int strideRepN = wpt[1] * fpw[1] * 8; + int strideRepK = 1; + + // swizzling + int perPhaseA = sharedLayout.getPerPhase(); + int maxPhaseA = sharedLayout.getMaxPhase(); + int perPhaseB = sharedLayout.getPerPhase(); + int maxPhaseB = sharedLayout.getMaxPhase(); + int stepB0 = isBRow ? strideRepN : strideRepK; + int numPtrB = std::max(2 * perPhaseB * maxPhaseB / stepB0, 1); + int NK = shape[0]; + + auto [_0, _1, offsetBN, offsetBK] = + computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc); + + Value offB0 = isBRow ? offsetBN : offsetBK; + Value offB1 = isBRow ? offsetBK : offsetBN; + Value phaseB = urem(udiv(offB1, i32_val(perPhaseB)), i32_val(maxPhaseB)); + SmallVector offB(numPtrB); + for (int i = 0; i < numPtrB; ++i) { + Value offB0I = add(offB0, i32_val(i * (isBRow ? strideRepN : 4))); + offB0I = udiv(offB0I, i32_val(vecB)); + offB0I = xor_(offB0I, phaseB); + offB0I = mul(offB0I, i32_val(vecB)); + offB[i] = + add(mul(offB0I, i32_val(strideB0)), mul(offB1, i32_val(strideB1))); + } + + Type f16PtrTy = ptr_ty(f16_ty); + Type f16x2Ty = vec_ty(f16_ty, 2); + + SmallVector ptrB(numPtrB); + ValueTable hbs; + for (int i = 0; i < numPtrB; ++i) + ptrB[i] = gep(ptr_ty(f16_ty), smem, offB[i]); + + auto ld = [&](decltype(hbs) &vals, int m, int k, Value val0, Value val1) { + vals[{m, k}] = {val0, val1}; + }; + + auto loadB = [&](int n, int K) { + int offidx = (isBRow ? n : K / 4) % numPtrB; + Value thePtrB = ptrB[offidx]; + + int stepBN = isBRow ? 
n / numPtrB * numPtrB : n; + int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB); + Value pb = gep(f16PtrTy, thePtrB, + i32_val(stepBN * strideRepN * strideBN + stepBK * strideBK)); + Value hb = + load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3))); + // record lds that need to be moved + Value hb00 = bitcast(extract_element(hb, i32_val(0)), f16x2Ty); + Value hb01 = bitcast(extract_element(hb, i32_val(1)), f16x2Ty); + ld(hbs, n, K, hb00, hb01); + if (vecB > 4) { + Value hb10 = bitcast(extract_element(hb, i32_val(2)), f16x2Ty); + Value hb11 = bitcast(extract_element(hb, i32_val(3)), f16x2Ty); + if (isBRow) + ld(hbs, n + 1, K, hb10, hb11); + else + ld(hbs, n, K + 4, hb10, hb11); + } + }; + + unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]); + for (unsigned k = 0; k < NK; k += 4) + for (unsigned n = 0; n < numN / 2; ++n) { + if (!hbs.count({n, k})) + loadB(n, k); + } + + SmallVector<Value> elems; + for (auto &item : hbs) { // hbs is a map; its keys are ordered. + elems.push_back(item.second.first); + elems.push_back(item.second.second); + } + Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2); + Type resTy = struct_ty(SmallVector<Type>(elems.size(), fp16x2Ty)); + Value res = getStructFromElements(loc, elems, rewriter, resTy); + return res; +} + +Value DotOpMmaV1ConversionHelper::loadC( + Value tensor, Value llTensor, ConversionPatternRewriter &rewriter) const { + return llTensor; +} + +std::tuple<Value, Value, Value, Value> +DotOpMmaV1ConversionHelper::computeOffsets(Value threadId, bool isARow, + bool isBRow, ArrayRef<int> fpw, + ArrayRef<int> spw, ArrayRef<int> rep, + ConversionPatternRewriter &rewriter, + Location loc) const { + auto *ctx = rewriter.getContext(); + Value _1 = i32_val(1); + Value _3 = i32_val(3); + Value _4 = i32_val(4); + Value _16 = i32_val(16); + Value _32 = i32_val(32); + + Value lane = urem(threadId, _32); + Value warp = udiv(threadId, _32); + + // warp offset + Value warp0 = urem(warp, i32_val(wpt[0])); + Value warp12 = udiv(warp, i32_val(wpt[0])); + Value warp1 = urem(warp12, i32_val(wpt[1])); + Value warpMOff = mul(warp0, i32_val(spw[0])); + Value warpNOff = mul(warp1, i32_val(spw[1])); + // Quad offset + Value quadMOff = mul(udiv(and_(lane, _16), _4), i32_val(fpw[0])); + Value quadNOff = mul(udiv(and_(lane, _16), _4), i32_val(fpw[1])); + // Pair offset + Value pairMOff = udiv(urem(lane, _16), _4); + pairMOff = urem(pairMOff, i32_val(fpw[0])); + pairMOff = mul(pairMOff, _4); + Value pairNOff = udiv(urem(lane, _16), _4); + pairNOff = udiv(pairNOff, i32_val(fpw[0])); + pairNOff = urem(pairNOff, i32_val(fpw[1])); + pairNOff = mul(pairNOff, _4); + // scale + pairMOff = mul(pairMOff, i32_val(rep[0] / 2)); + quadMOff = mul(quadMOff, i32_val(rep[0] / 2)); + pairNOff = mul(pairNOff, i32_val(rep[1] / 2)); + quadNOff = mul(quadNOff, i32_val(rep[1] / 2)); + // Quad pair offset + Value laneMOff = add(pairMOff, quadMOff); + Value laneNOff = add(pairNOff, quadNOff); + // A offset + Value offsetAM = add(warpMOff, laneMOff); + Value offsetAK = and_(lane, _3); + // B offset + Value offsetBN = add(warpNOff, laneNOff); + Value offsetBK = and_(lane, _3); + // i indices + Value offsetCM = add(and_(lane, _1), offsetAM); + if (isARow) { + offsetAM = add(offsetAM, urem(threadId, _4)); + offsetAK = i32_val(0); + } + if (!isBRow) { + offsetBN = add(offsetBN, urem(threadId, _4)); + offsetBK = i32_val(0); + } + + return std::make_tuple(offsetAM, offsetAK, offsetBN, offsetBK); +} + +DotOpMmaV1ConversionHelper::ValueTable +DotOpMmaV1ConversionHelper::extractLoadedOperand( + Value llStruct, int n0, int n1,
ConversionPatternRewriter &rewriter) const { + ValueTable rcds; + SmallVector<Value> elems = + ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct( + llStruct.getLoc(), llStruct, rewriter); + + int offset = 0; + for (int i = 0; i < n0; ++i) + for (int k = 0; k < n1; k += 4) { + rcds[{i, k}] = std::make_pair(elems[offset], elems[offset + 1]); + offset += 2; + } + + return rcds; +} + /// ====================== mma codegen end ============================ Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout, @@ -3579,9 +4102,10 @@ Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout, TypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc) { + auto tensorTy = resType.cast<RankedTensorType>(); + auto shape = tensorTy.getShape(); if (layout.getVersion() == 2) { - auto tensorTy = resType.cast<RankedTensorType>(); - auto [repM, repN] = DotOpConversionHelper::getRepMN(tensorTy); + auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(tensorTy); size_t fcSize = 4 * repM * repN; auto structTy = LLVM::LLVMStructType::getLiteral( @@ -3589,6 +4113,18 @@ Value convertSplatLikeOpWithMmaLayout(const MmaEncodingAttr &layout, return getStructFromElements(loc, SmallVector<Value>(fcSize, constVal), rewriter, structTy); } + if (layout.getVersion() == 1) { + DotOpMmaV1ConversionHelper helper(layout); + int repM = helper.getRepM(shape[0]); + int repN = helper.getRepN(shape[1]); + // According to the v1 mma layout, each thread processes 8 elements. + int elems = 8 * repM * repN; + + auto structTy = LLVM::LLVMStructType::getLiteral( + rewriter.getContext(), SmallVector<Type>(elems, elemType)); + return getStructFromElements(loc, SmallVector<Value>(elems, constVal), + rewriter, structTy); + } assert(false && "Unsupported mma layout found"); } @@ -3620,6 +4156,7 @@ public: llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) { auto ctx = type.getContext(); Attribute layout = type.getEncoding(); + auto shape = type.getShape(); if (layout && (layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() || layout.isa<MmaEncodingAttr>())) { @@ -3632,12 +4169,21 @@ public: return LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3); } else if (auto mmaLayout = layout.dyn_cast_or_null<MmaEncodingAttr>()) { if (mmaLayout.getVersion() == 2) { - auto [repM, repN] = DotOpConversionHelper::getRepMN(type); + auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(type); size_t fcSize = 4 * repM * repN; return LLVM::LLVMStructType::getLiteral( ctx, SmallVector<Type>(fcSize, type.getElementType())); } + if (mmaLayout.getVersion() == 1) { + DotOpMmaV1ConversionHelper helper(mmaLayout); + int repM = helper.getRepM(shape[0]); + int repN = helper.getRepN(shape[1]); + int elems = 8 * repM * repN; + return LLVM::LLVMStructType::getLiteral( + ctx, SmallVector<Type>(elems, type.getElementType())); + } + llvm::errs() << "Unexpected mma layout detected in TritonToLLVMTypeConverter"; return llvm::None; @@ -3645,9 +4191,9 @@ public: } else if (auto dot_op_layout = layout.dyn_cast_or_null<DotOperandEncodingAttr>()) { auto mmaLayout = dot_op_layout.getParent().cast<MmaEncodingAttr>(); + auto wpt = mmaLayout.getWarpsPerCTA(); + Type elemTy = type.getElementType(); if (mmaLayout.getVersion() == 2) { - auto wpt = mmaLayout.getWarpsPerCTA(); - Type elemTy = type.getElementType(); if (dot_op_layout.getOpIdx() == 0) { // $a int elems = MMA16816ConversionHelper::getANumElemsPerThread(type, wpt); Type x2Ty = vec_ty(elemTy, 2); return LLVM::LLVMStructType::getLiteral( ctx, SmallVector<Type>(elems, x2Ty)); } if (dot_op_layout.getOpIdx() == 1) { // $b int elems = MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt); Type x2Ty = vec_ty(elemTy, 2); - return LLVM::LLVMStructType::getLiteral( - ctx, SmallVector<Type>(elems, x2Ty)); + return struct_ty(SmallVector<Type>(elems, x2Ty)); + } + } + + if (mmaLayout.getVersion() == 1) { +
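+ // Example figures (assuming a 128x32 f16 $a with wpt = {2, 2}): numM =
+ // 128 / (8 * wpt[0]) = 8 and NK = 32, so numElemsPerThreadA returns
+ // (8 / 2) * (32 / 4) * 2 = 64 fp16x2 values per thread.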
DotOpMmaV1ConversionHelper helper(mmaLayout); + + if (dot_op_layout.getOpIdx() == 0) { // $a + int elems = helper.numElemsPerThreadA(type); + Type x2Ty = vec_ty(elemTy, 2); + return struct_ty(SmallVector<Type>(elems, x2Ty)); + } + if (dot_op_layout.getOpIdx() == 1) { // $b + int elems = helper.numElemsPerThreadB(type); + Type x2Ty = vec_ty(elemTy, 2); + return struct_ty(SmallVector<Type>(elems, x2Ty)); + } + } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 1b570b289..c44cbf5fc 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -270,10 +270,23 @@ unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const { unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const { size_t rank = shape.size(); assert(rank == 2 && "Unexpected rank of mma layout"); - assert(getVersion() == 2 && "mmaLayout version = 1 is not implemented yet"); - unsigned elemsCol = ceil(shape[0], 16 * getWarpsPerCTA()[0]) * 2; - unsigned elemsRow = ceil(shape[1], 8 * getWarpsPerCTA()[1]) * 2; - return elemsCol * elemsRow; + assert((getVersion() == 1 || getVersion() == 2) && + "Only versions 1 and 2 are supported"); + + int res = 0; + if (getVersion() == 1) { + unsigned mmasRow = ceil(shape[0], 16 * getWarpsPerCTA()[0]); + unsigned mmasCol = ceil(shape[1], 16 * getWarpsPerCTA()[1]); + // Each warp-level mma884 performs an m16xn16xk4 mma and thus yields an + // m16xn16 matrix as result, i.e. 16 * 16 / 32 = 8 elements per thread. + res = mmasRow * mmasCol * (16 * 16 / 32); + } else if (getVersion() == 2) { + unsigned elemsCol = ceil(shape[0], 16 * getWarpsPerCTA()[0]) * 2; + unsigned elemsRow = ceil(shape[1], 8 * getWarpsPerCTA()[1]) * 2; + res = elemsCol * elemsRow; + } + + return res; } unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const { diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index a1a9392d7..3bfed0569 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -738,3 +738,28 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { return } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#mma = #triton_gpu.mma<{version = 1, warpsPerCTA = [2, 2]}> +#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}> +#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32}, + %a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 + %a_mat = triton_gpu.convert_layout %a : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #dot_operand_a> + %b_mat = triton_gpu.convert_layout %b : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #dot_operand_b> + + %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma> + // TODO[goostavz]: uncomment the following lines after convert_layout[mma -> blocked] is ready.
// %38 = triton_gpu.convert_layout %28 : (tensor<128x256xf32, #mma>) -> tensor<128x256xf32, #blocked> + // %30 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>, #blocked> + // %36 = tt.broadcast %30 : (tensor<128x1x!tt.ptr<f32>, #blocked>) -> tensor<128x256x!tt.ptr<f32>, #blocked> + // tt.store %36, %38 : tensor<128x256xf32, #blocked> return } }