#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_DOT_HELPERS_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_DOT_HELPERS_H

#include "./Utility.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Membar.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

#include <algorithm>
#include <array>
#include <cstdint>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <tuple>
#include <utility>

namespace mlir {
namespace LLVM {
using namespace mlir::triton;

using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;

// Helper for conversion of DotOp with mma, that is sm<80
struct DotOpMmaV1ConversionHelper {
  MmaEncodingAttr mmaLayout;
  ArrayRef<unsigned> wpt;

  using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;

  explicit DotOpMmaV1ConversionHelper(MmaEncodingAttr mmaLayout)
      : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}

  int getRepM(int M) const {
    return std::max<int>(M / (wpt[0] * instrShape[0]), 1);
  }
  int getRepN(int N) const {
    return std::max<int>(N / (wpt[1] * instrShape[1]), 1);
  }

  static ArrayRef<unsigned> getMmaInstrShape() { return instrShape; }

  static Type getMmaRetType(TensorType operand) {
    auto *ctx = operand.getContext();
    Type fp32Ty = type::f32Ty(ctx);
    // f16*f16+f32->f32
    return struct_ty(SmallVector<Type>{8, fp32Ty});
  }

  // number of fp16x2 elements for $a.
  int numElemsPerThreadA(RankedTensorType tensorTy) const {
    auto shape = tensorTy.getShape();
    auto order = getOrder();

    bool isARow = order[0] != 0;
    bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
    int packSize0 = (isARow || isAVec4) ? 1 : 2;

    SmallVector<int> fpw({2, 2, 1});
    int repM = 2 * packSize0;
    int repK = 1;
    int spwM = fpw[0] * 4 * repM;
    SmallVector<int> rep({repM, 0, repK}); // pad N with 0
    SmallVector<int> spw({spwM, 0, 1});    // pad N with 0

    int NK = shape[1];
    unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]);

    // NOTE: We couldn't get the vec from the shared layout.
    // int vecA = sharedLayout.getVec();
    // TODO[Superjomn]: Consider the case when vecA > 4
    bool vecGt4 = false;
    int elemsPerLd = vecGt4 ? 4 : 2;
    return (numM / 2) * (NK / 4) * elemsPerLd;
  }

  // number of fp16x2 elements for $b.
  int numElemsPerThreadB(RankedTensorType tensorTy) const {
    auto shape = tensorTy.getShape();
    auto order = getOrder();
    bool isBRow = order[0] != 0;
    bool isBVec4 = isBRow && shape[order[0]] <= 16;
    int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
    SmallVector<int> fpw({2, 2, 1});
    SmallVector<int> rep({0, 2 * packSize1, 1});       // pad M with 0
    SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0

    // NOTE: We couldn't get the vec from the shared layout.
    // int vecB = sharedLayout.getVec();
    // TODO[Superjomn]: Consider the case when vecB > 4
    bool vecGt4 = false;
    int elemsPerLd = vecGt4 ? 4 : 2;
    int NK = shape[0];

    unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
    return (numN / 2) * (NK / 4) * elemsPerLd;
  }

  // Loading $a from smem to registers, returns a LLVM::Struct.
  Value loadA(Value A, const SharedMemoryObject &smemObj, Value thread,
              Location loc, ConversionPatternRewriter &rewriter) const;

  // Loading $b from smem to registers, returns a LLVM::Struct.
  Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
              Location loc, ConversionPatternRewriter &rewriter) const;

  static ArrayRef<unsigned> getOrder() { return mmaOrder; }

  // Compute the offset of the matrix to load.
  // Returns offsetAM, offsetAK, offsetBN, offsetBK.
  // NOTE: the information M (from $a) and N (from $b) couldn't be retrieved at
  // the same time in the usage in convert_layout[shared->dot_op]; we leave the
  // non-existent info as 0 and only use the desired argument from the composed
  // result. In this way we want to retain the original code structure in the
  // convert_mma884 method for easier debugging.
  std::tuple<Value, Value, Value, Value>
  computeOffsets(Value threadId, bool isARow, bool isBRow, ArrayRef<int> fpw,
                 ArrayRef<int> spw, ArrayRef<int> rep,
                 ConversionPatternRewriter &rewriter, Location loc) const;

  // Extract values belonging to $a or $b from an LLVM struct; the shape is
  // n0xn1.
  DotOpMmaV1ConversionHelper::ValueTable
  extractLoadedOperand(Value llStruct, int NK,
                       ConversionPatternRewriter &rewriter) const;

private:
  static constexpr unsigned instrShape[] = {16, 16, 4};
  static constexpr unsigned mmaOrder[] = {0, 1};
};

// Helper for conversion of DotOp with mma, that is sm>=80
struct DotOpMmaV2ConversionHelper {
  enum class TensorCoreType : uint8_t {
    // floating-point tensor core instr
    FP32_FP16_FP16_FP32 = 0, // default
    FP32_BF16_BF16_FP32,
    FP32_TF32_TF32_FP32,
    // integer tensor core instr
    INT32_INT1_INT1_INT32, // Not implemented
    INT32_INT4_INT4_INT32, // Not implemented
    INT32_INT8_INT8_INT32, // Not implemented
    //
    NOT_APPLICABLE,
  };

  MmaEncodingAttr mmaLayout;
  MLIRContext *ctx{};

  explicit DotOpMmaV2ConversionHelper(MmaEncodingAttr mmaLayout)
      : mmaLayout(mmaLayout) {
    ctx = mmaLayout.getContext();
  }

  void deduceMmaType(DotOp op) const { mmaType = getMmaType(op); }
  void deduceMmaType(Type operandTy) const {
    mmaType = getTensorCoreTypeFromOperand(operandTy);
  }

  // Get the M and N of mma instruction shape.
  static std::tuple<int, int> getInstrShapeMN() {
    // According to DotOpConversionHelper::mmaInstrShape, all the M,N are
    // {16, 8}
    return {16, 8};
  }

  static std::tuple<int, int> getRepMN(const RankedTensorType &tensorTy) {
    auto mmaLayout = tensorTy.getEncoding().cast<MmaEncodingAttr>();
    auto wpt = mmaLayout.getWarpsPerCTA();

    int M = tensorTy.getShape()[0];
    int N = tensorTy.getShape()[1];
    auto [instrM, instrN] = getInstrShapeMN();
    int repM = std::max<int>(M / (wpt[0] * instrM), 1);
    int repN = std::max<int>(N / (wpt[1] * instrN), 1);
    return {repM, repN};
  }

  Type getShemPtrTy() const {
    switch (mmaType) {
    case TensorCoreType::FP32_FP16_FP16_FP32:
      return ptr_ty(type::f16Ty(ctx), 3);
    case TensorCoreType::FP32_BF16_BF16_FP32:
      return ptr_ty(type::bf16Ty(ctx), 3);
    case TensorCoreType::FP32_TF32_TF32_FP32:
      return ptr_ty(type::f32Ty(ctx), 3);
    case TensorCoreType::INT32_INT8_INT8_INT32:
      return ptr_ty(type::i8Ty(ctx), 3);
    default:
      llvm::report_fatal_error("mma16816 data type not supported");
    }
    return Type{};
  }

  // The type of matrix that is loaded by either a ldmatrix or composed lds.
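  // For example: with FP16 operands each thread receives four 32-bit
  // fragments (the result layout of ldmatrix.m8n8.x4), so the type below is a
  // literal struct of four vector<2xf16> values, while TF32 yields four
  // scalar f32 fragments and int8 four vector<4xi8> packs.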
  Type getMatType() const {
    Type fp32Ty = type::f32Ty(ctx);
    Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
    Type bf16x2Ty = vec_ty(type::bf16Ty(ctx), 2);
    // floating point types
    Type fp16x2Pack4Ty = LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(4, fp16x2Ty));
    Type bf16x2Pack4Ty = LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(4, bf16x2Ty));
    Type fp32Pack4Ty =
        LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp32Ty));
    // integer types
    Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4);
    Type i8x4Pack4Ty =
        LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i8x4Ty));

    switch (mmaType) {
    case TensorCoreType::FP32_FP16_FP16_FP32:
      return fp16x2Pack4Ty;
    case TensorCoreType::FP32_BF16_BF16_FP32:
      return bf16x2Pack4Ty;
    case TensorCoreType::FP32_TF32_TF32_FP32:
      return fp32Pack4Ty;
    case TensorCoreType::INT32_INT8_INT8_INT32:
      return i8x4Pack4Ty;
    default:
      llvm::report_fatal_error("Unsupported mma type found");
    }
    return Type{};
  }

  Type getLoadElemTy() {
    switch (mmaType) {
    case TensorCoreType::FP32_FP16_FP16_FP32:
      return vec_ty(type::f16Ty(ctx), 2);
    case TensorCoreType::FP32_BF16_BF16_FP32:
      return vec_ty(type::bf16Ty(ctx), 2);
    case TensorCoreType::FP32_TF32_TF32_FP32:
      return type::f32Ty(ctx);
    case TensorCoreType::INT32_INT8_INT8_INT32:
      return type::i32Ty(ctx);
    default:
      llvm::report_fatal_error("Unsupported mma type found");
    }
    return Type{};
  }

  Type getMmaRetType() const {
    Type fp32Ty = type::f32Ty(ctx);
    Type i32Ty = type::i32Ty(ctx);
    Type fp32x4Ty =
        LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp32Ty));
    Type i32x4Ty =
        LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i32Ty));
    switch (mmaType) {
    case TensorCoreType::FP32_FP16_FP16_FP32:
      return fp32x4Ty;
    case TensorCoreType::FP32_BF16_BF16_FP32:
      return fp32x4Ty;
    case TensorCoreType::FP32_TF32_TF32_FP32:
      return fp32x4Ty;
    case TensorCoreType::INT32_INT8_INT8_INT32:
      return i32x4Ty;
    default:
      llvm::report_fatal_error("Unsupported mma type found");
    }
    return Type{};
  }

  ArrayRef<int> getMmaInstrShape() const {
    assert(mmaType != TensorCoreType::NOT_APPLICABLE &&
           "Unknown mma type found.");
    return mmaInstrShape.at(mmaType);
  }

  static ArrayRef<int> getMmaInstrShape(TensorCoreType tensorCoreType) {
    assert(tensorCoreType != TensorCoreType::NOT_APPLICABLE &&
           "Unknown mma type found.");
    return mmaInstrShape.at(tensorCoreType);
  }

  ArrayRef<int> getMmaMatShape() const {
    assert(mmaType != TensorCoreType::NOT_APPLICABLE &&
           "Unknown mma type found.");
    return mmaMatShape.at(mmaType);
  }

  // Deduce the TensorCoreType from either $a or $b's type.
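  // For example: an f16 operand selects FP32_FP16_FP16_FP32 (m16n8k16), f32
  // selects FP32_TF32_TF32_FP32 (m16n8k8), and i8 selects
  // INT32_INT8_INT8_INT32 (m16n8k32); any other element type yields
  // NOT_APPLICABLE.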
  static TensorCoreType getTensorCoreTypeFromOperand(Type operandTy) {
    auto tensorTy = operandTy.cast<RankedTensorType>();
    auto elemTy = tensorTy.getElementType();
    if (elemTy.isF16())
      return TensorCoreType::FP32_FP16_FP16_FP32;
    if (elemTy.isF32())
      return TensorCoreType::FP32_TF32_TF32_FP32;
    if (elemTy.isBF16())
      return TensorCoreType::FP32_BF16_BF16_FP32;
    if (elemTy.isInteger(8))
      return TensorCoreType::INT32_INT8_INT8_INT32;
    return TensorCoreType::NOT_APPLICABLE;
  }

  int getVec() const {
    assert(mmaType != TensorCoreType::NOT_APPLICABLE &&
           "Unknown mma type found.");
    return mmaInstrVec.at(mmaType);
  }

  StringRef getMmaInstr() const {
    assert(mmaType != TensorCoreType::NOT_APPLICABLE &&
           "Unknown mma type found.");
    return mmaInstrPtx.at(mmaType);
  }

  static TensorCoreType getMmaType(triton::DotOp op) {
    Value A = op.a();
    Value B = op.b();
    auto aTy = A.getType().cast<RankedTensorType>();
    auto bTy = B.getType().cast<RankedTensorType>();
    // d = a*b + c
    auto dTy = op.d().getType().cast<RankedTensorType>();

    if (dTy.getElementType().isF32()) {
      if (aTy.getElementType().isF16() && bTy.getElementType().isF16())
        return TensorCoreType::FP32_FP16_FP16_FP32;
      if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16())
        return TensorCoreType::FP32_BF16_BF16_FP32;
      if (aTy.getElementType().isF32() && bTy.getElementType().isF32() &&
          op.allowTF32())
        return TensorCoreType::FP32_TF32_TF32_FP32;
    } else if (dTy.getElementType().isInteger(32)) {
      if (aTy.getElementType().isInteger(8) &&
          bTy.getElementType().isInteger(8))
        return TensorCoreType::INT32_INT8_INT8_INT32;
    }

    return TensorCoreType::NOT_APPLICABLE;
  }

private:
  mutable TensorCoreType mmaType{TensorCoreType::NOT_APPLICABLE};

  // Used on nvidia GPUs mma layout .version == 2
  // Refer to
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-storage
  // for more details.
  inline static const std::map<TensorCoreType, llvm::SmallVector<int>>
      mmaInstrShape = {
          {TensorCoreType::FP32_FP16_FP16_FP32, {16, 8, 16}},
          {TensorCoreType::FP32_BF16_BF16_FP32, {16, 8, 16}},
          {TensorCoreType::FP32_TF32_TF32_FP32, {16, 8, 8}},

          {TensorCoreType::INT32_INT1_INT1_INT32, {16, 8, 256}},
          {TensorCoreType::INT32_INT4_INT4_INT32, {16, 8, 64}},
          {TensorCoreType::INT32_INT8_INT8_INT32, {16, 8, 32}},
      };

  // shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices)
  // Refer to
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
  // for more details.
  inline static const std::map<TensorCoreType, llvm::SmallVector<int>>
      mmaMatShape = {
          {TensorCoreType::FP32_FP16_FP16_FP32, {8, 8, 8}},
          {TensorCoreType::FP32_BF16_BF16_FP32, {8, 8, 8}},
          {TensorCoreType::FP32_TF32_TF32_FP32, {8, 8, 4}},

          {TensorCoreType::INT32_INT1_INT1_INT32, {8, 8, 64}},
          {TensorCoreType::INT32_INT4_INT4_INT32, {8, 8, 32}},
          {TensorCoreType::INT32_INT8_INT8_INT32, {8, 8, 16}},
      };

  // Supported mma instruction in PTX.
  // Refer to
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
  // for more details.
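  // How to read these names: in
  // "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32", m16n8k16 is the
  // per-instruction tile (A is 16x16, B is 16x8, C/D are 16x8), row/col are
  // the A/B layouts, and the trailing types are D, A, B, C respectively.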
  inline static const std::map<TensorCoreType, std::string> mmaInstrPtx = {
      {TensorCoreType::FP32_FP16_FP16_FP32,
       "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
      {TensorCoreType::FP32_BF16_BF16_FP32,
       "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
      {TensorCoreType::FP32_TF32_TF32_FP32,
       "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},

      {TensorCoreType::INT32_INT1_INT1_INT32,
       "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"},
      {TensorCoreType::INT32_INT4_INT4_INT32,
       "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"},
      {TensorCoreType::INT32_INT8_INT8_INT32,
       "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
  };

  // vector length per ldmatrix (16*8/element_size_in_bits)
  inline static const std::map<TensorCoreType, int> mmaInstrVec = {
      {TensorCoreType::FP32_FP16_FP16_FP32, 8},
      {TensorCoreType::FP32_BF16_BF16_FP32, 8},
      {TensorCoreType::FP32_TF32_TF32_FP32, 4},

      {TensorCoreType::INT32_INT1_INT1_INT32, 128},
      {TensorCoreType::INT32_INT4_INT4_INT32, 32},
      {TensorCoreType::INT32_INT8_INT8_INT32, 16},
  };
};

// Data loader for mma.16816 instruction.
class MMA16816SmemLoader {
public:
  MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
                     ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
                     ArrayRef<int> instrShape, ArrayRef<int> matShape,
                     int perPhase, int maxPhase, int elemBytes,
                     ConversionPatternRewriter &rewriter,
                     TypeConverter *typeConverter, const Location &loc)
      : order(order.begin(), order.end()), kOrder(kOrder),
        tileShape(tileShape.begin(), tileShape.end()),
        instrShape(instrShape.begin(), instrShape.end()),
        matShape(matShape.begin(), matShape.end()), perPhase(perPhase),
        maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter), loc(loc),
        ctx(rewriter.getContext()) {
    cMatShape = matShape[order[0]];
    sMatShape = matShape[order[1]];

    sStride = smemStrides[order[1]];

    // rule: k must be the fast-changing axis.
    needTrans = kOrder != order[0];
    canUseLdmatrix = elemBytes == 2 || (!needTrans); // b16

    if (canUseLdmatrix) {
      // Within each CTA, the warps are arranged as [1xwpt] if not transposed,
      // otherwise [wptx1], and each warp will perform a mma.
      numPtrs =
          tileShape[order[0]] / (needTrans ? wpt : 1) / instrShape[order[0]];
    } else {
      numPtrs = tileShape[order[0]] / wpt / matShape[order[0]];
    }
    numPtrs = std::max(numPtrs, 2);

    // Special rule for i8/u8, 4 ptrs for each matrix
    if (!canUseLdmatrix && elemBytes == 1)
      numPtrs *= 4;

    int loadStrideInMat[2];
    loadStrideInMat[kOrder] =
        2; // instrShape[kOrder] / matShape[kOrder], always 2
    loadStrideInMat[kOrder ^ 1] =
        wpt * (instrShape[kOrder ^ 1] / matShape[kOrder ^ 1]);

    pLoadStrideInMat = loadStrideInMat[order[0]];
    sMatStride =
        loadStrideInMat[order[1]] / (instrShape[order[1]] / matShape[order[1]]);

    // Each matArr contains warpOffStride matrices.
    matArrStride = kOrder == 1 ? 1 : wpt;
    warpOffStride = instrShape[kOrder ^ 1] / matShape[kOrder ^ 1];
  }

  // lane = thread % 32
  // warpOff = (thread/32) % wpt(0)
  llvm::SmallVector<Value> computeOffsets(Value warpOff, Value lane,
                                          Value cSwizzleOffset) {
    if (canUseLdmatrix)
      return computeLdmatrixMatOffs(warpOff, lane, cSwizzleOffset);
    else if (elemBytes == 4 && needTrans)
      return computeB32MatOffs(warpOff, lane, cSwizzleOffset);
    else if (elemBytes == 1 && needTrans)
      return computeB8MatOffs(warpOff, lane, cSwizzleOffset);
    else
      llvm::report_fatal_error("Invalid smem load config");
    return {};
  }

  int getNumPtrs() const { return numPtrs; }

  // Compute the offset to the matrix this thread (indexed by warpOff and lane)
  // is mapped to.
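  // Worked example (hypothetical lane): with kOrder == 1, lane 21 gives
  // c = 21 % 8 = 5 and s = 21 / 8 = 2, hence s0 = 0 and s1 = 1, so this
  // thread addresses row 5 of the k-index-1 / nk-index-0 8x8 matrix owned by
  // its warp.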
SmallVector computeLdmatrixMatOffs(Value warpId, Value lane, Value cSwizzleOffset) { // 4x4 matrices Value c = urem(lane, i32_val(8)); Value s = udiv(lane, i32_val(8)); // sub-warp-id // Decompose s => s_0, s_1, that is the coordinate in 2x2 matrices in a // warp Value s0 = urem(s, i32_val(2)); Value s1 = udiv(s, i32_val(2)); // We use different orders for a and b for better performance. Value kMatArr = kOrder == 1 ? s1 : s0; Value nkMatArr = kOrder == 1 ? s0 : s1; // matrix coordinate inside a CTA, the matrix layout is [2x2wpt] for A and // [2wptx2] for B. e.g. Setting wpt=3, The data layout for A(kOrder=1) is // |0 0 1 1 2 2| -> 0,1,2 are the warpids // |0 0 1 1 2 2| // // for B(kOrder=0) is // |0 0| -> 0,1,2 are the warpids // |1 1| // |2 2| // |0 0| // |1 1| // |2 2| // Note, for each warp, it handles a 2x2 matrices, that is the coordinate // address (s0,s1) annotates. Value matOff[2]; matOff[kOrder ^ 1] = add( mul(warpId, i32_val(warpOffStride)), // warp offset mul(nkMatArr, i32_val(matArrStride))); // matrix offset inside a warp matOff[kOrder] = kMatArr; // Physical offset (before swizzling) Value cMatOff = matOff[order[0]]; Value sMatOff = matOff[order[1]]; Value cSwizzleMatOff = udiv(cSwizzleOffset, i32_val(cMatShape)); cMatOff = add(cMatOff, cSwizzleMatOff); // row offset inside a matrix, each matrix has 8 rows. Value sOffInMat = c; SmallVector offs(numPtrs); Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase)); Value sOff = add(sOffInMat, mul(sMatOff, i32_val(sMatShape))); for (int i = 0; i < numPtrs; ++i) { Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat)); cMatOffI = xor_(cMatOffI, phase); offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sStride)); } return offs; } // Compute 32-bit matrix offsets. SmallVector computeB32MatOffs(Value warpOff, Value lane, Value cSwizzleOffset) { assert(needTrans && "Only used in transpose mode."); // Load tf32 matrices with lds32 Value cOffInMat = udiv(lane, i32_val(4)); Value sOffInMat = urem(lane, i32_val(4)); Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase)); SmallVector offs(numPtrs); for (int mat = 0; mat < 4; ++mat) { // Load 4 mats each time int kMatArrInt = kOrder == 1 ? mat / 2 : mat % 2; int nkMatArrInt = kOrder == 1 ? mat % 2 : mat / 2; if (kMatArrInt > 0) // we don't need pointers for k continue; Value kMatArr = i32_val(kMatArrInt); Value nkMatArr = i32_val(nkMatArrInt); Value cMatOff = add(mul(warpOff, i32_val(warpOffStride)), mul(nkMatArr, i32_val(matArrStride))); Value cSwizzleMatOff = udiv(cSwizzleOffset, i32_val(cMatShape)); cMatOff = add(cMatOff, cSwizzleMatOff); Value sMatOff = kMatArr; Value sOff = add(sOffInMat, mul(sMatOff, i32_val(sMatShape))); // FIXME: (kOrder == 1?) is really dirty hack for (int i = 0; i < numPtrs / 2; ++i) { Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat * (kOrder == 1 ? 1 : 2))); cMatOffI = xor_(cMatOffI, phase); Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape))); cOff = urem(cOff, i32_val(tileShape[order[0]])); sOff = urem(sOff, i32_val(tileShape[order[1]])); offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, sStride)); } } return offs; } // compute 8-bit matrix offset. 
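  // Note (clarifying assumption): this path issues plain 8-bit loads, with
  // each thread covering four consecutive columns per matrix, which is why
  // the constructor allocates 4x as many pointers when elemBytes == 1 and
  // ldmatrix cannot be used.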
SmallVector computeB8MatOffs(Value warpOff, Value lane, Value cSwizzleOffset) { assert(needTrans && "Only used in transpose mode."); Value cOffInMat = udiv(lane, i32_val(4)); Value sOffInMat = mul(urem(lane, i32_val(4)), i32_val(4)); // each thread load 4 cols SmallVector offs(numPtrs); for (int mat = 0; mat < 4; ++mat) { int kMatArrInt = kOrder == 1 ? mat / 2 : mat % 2; int nkMatArrInt = kOrder == 1 ? mat % 2 : mat / 2; if (kMatArrInt > 0) // we don't need pointers for k continue; Value kMatArr = i32_val(kMatArrInt); Value nkMatArr = i32_val(nkMatArrInt); Value cMatOff = add(mul(warpOff, i32_val(warpOffStride)), mul(nkMatArr, i32_val(matArrStride))); Value sMatOff = kMatArr; for (int loadx4Off = 0; loadx4Off < numPtrs / 8; ++loadx4Off) { for (int elemOff = 0; elemOff < 4; ++elemOff) { int ptrOff = loadx4Off * 8 + nkMatArrInt * 4 + elemOff; Value cMatOffI = add(cMatOff, i32_val(loadx4Off * pLoadStrideInMat * (kOrder == 1 ? 1 : 2))); Value sOffInMatElem = add(sOffInMat, i32_val(elemOff)); // disable swizzling ... Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape))); Value sOff = add(sOffInMatElem, mul(sMatOff, i32_val(sMatShape))); // To prevent out-of-bound access when tile is too small. cOff = urem(cOff, i32_val(tileShape[order[0]])); sOff = urem(sOff, i32_val(tileShape[order[1]])); offs[ptrOff] = add(cOff, mul(sOff, sStride)); } } } return offs; } // Load 4 matrices and returns 4 vec<2> elements. std::tuple loadX4(int mat0, int mat1, ArrayRef offs, ArrayRef ptrs, Type ldmatrixRetTy, Type shemPtrTy) const { assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned"); int matIdx[2] = {mat0, mat1}; int ptrIdx{-1}; if (canUseLdmatrix) ptrIdx = matIdx[order[0]] / (instrShape[order[0]] / matShape[order[0]]); else if (elemBytes == 4 && needTrans) ptrIdx = matIdx[order[0]]; else if (elemBytes == 1 && needTrans) ptrIdx = matIdx[order[0]] * 4; else llvm::report_fatal_error("unsupported mma type found"); // The main difference with the original triton code is we removed the // prefetch-related logic here for the upstream optimizer phase should // take care with it, and that is transparent in dot conversion. auto getPtr = [&](int idx) { return ptrs[idx]; }; Value ptr = getPtr(ptrIdx); if (canUseLdmatrix) { Value sOffset = mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sStride); Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset); PTXBuilder builder; // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a // thread. auto resArgs = builder.newListOperand(4, "=r"); auto addrArg = builder.newAddrOperand(sOffsetPtr, "r"); auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4") ->o("trans", needTrans /*predicate*/) .o("shared.b16"); ldmatrix(resArgs, addrArg); // The result type is 4xi32, each i32 is composed of 2xf16 // elements(adjacent two columns in a row) Value resV4 = builder.launch(rewriter, loc, ldmatrixRetTy); auto getIntAttr = [&](int v) { return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)}); }; // The struct should have exactly the same element types. 
Type elemType = resV4.getType().cast().getBody()[0]; return {extract_val(elemType, resV4, getIntAttr(0)), extract_val(elemType, resV4, getIntAttr(1)), extract_val(elemType, resV4, getIntAttr(2)), extract_val(elemType, resV4, getIntAttr(3))}; } else if (elemBytes == 4 && needTrans) { // Use lds.32 to load tf32 matrices Value ptr2 = getPtr(ptrIdx + 1); assert(sMatStride == 1); int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape); Value sOffsetElemVal = mul(i32_val(sOffsetElem), sStride); int sOffsetArrElem = sMatStride * sMatShape; Value sOffsetArrElemVal = add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sStride)); Value elems[4]; Type elemTy = type::f32Ty(ctx); Type elemPtrTy = ptr_ty(elemTy); if (kOrder == 1) { elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal)); elems[1] = load(gep(elemPtrTy, ptr2, sOffsetElemVal)); elems[2] = load(gep(elemPtrTy, ptr, sOffsetArrElemVal)); elems[3] = load(gep(elemPtrTy, ptr2, sOffsetArrElemVal)); } else { elems[0] = load(gep(elemPtrTy, ptr, sOffsetElemVal)); elems[2] = load(gep(elemPtrTy, ptr2, sOffsetElemVal)); elems[1] = load(gep(elemPtrTy, ptr, sOffsetArrElemVal)); elems[3] = load(gep(elemPtrTy, ptr2, sOffsetArrElemVal)); } return {elems[0], elems[1], elems[2], elems[3]}; } else if (elemBytes == 1 && needTrans) { // work with int8 std::array, 2> ptrs; ptrs[0] = { getPtr(ptrIdx), getPtr(ptrIdx + 1), getPtr(ptrIdx + 2), getPtr(ptrIdx + 3), }; ptrs[1] = { getPtr(ptrIdx + 4), getPtr(ptrIdx + 5), getPtr(ptrIdx + 6), getPtr(ptrIdx + 7), }; assert(sMatStride == 1); int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape); Value sOffsetElemVal = mul(i32_val(sOffsetElem), sStride); int sOffsetArrElem = 1 * (sMatStride * sMatShape); Value sOffsetArrElemVal = add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sStride)); std::array i8v4Elems; std::array i32Elems; i8v4Elems.fill( rewriter.create(loc, vec_ty(type::i8Ty(ctx), 4))); Value i8Elems[4][4]; Type elemTy = type::i8Ty(ctx); Type elemPtrTy = ptr_ty(elemTy); Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4); if (kOrder == 1) { for (int i = 0; i < 2; ++i) for (int j = 0; j < 4; ++j) i8Elems[i][j] = load(gep(elemPtrTy, ptrs[i][j], sOffsetElemVal)); for (int i = 2; i < 4; ++i) for (int j = 0; j < 4; ++j) i8Elems[i][j] = load(gep(elemPtrTy, ptrs[i - 2][j], sOffsetArrElemVal)); for (int m = 0; m < 4; ++m) { for (int e = 0; e < 4; ++e) i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m], i8Elems[m][e], i32_val(e)); i32Elems[m] = bitcast(i8v4Elems[m], i8x4Ty); } } else { // k first for (int j = 0; j < 4; ++j) i8Elems[0][j] = load(gep(elemPtrTy, ptrs[0][j], sOffsetElemVal)); for (int j = 0; j < 4; ++j) i8Elems[2][j] = load(gep(elemPtrTy, ptrs[1][j], sOffsetElemVal)); for (int j = 0; j < 4; ++j) i8Elems[1][j] = load(gep(elemPtrTy, ptrs[0][j], sOffsetArrElemVal)); for (int j = 0; j < 4; ++j) i8Elems[3][j] = load(gep(elemPtrTy, ptrs[1][j], sOffsetArrElemVal)); for (int m = 0; m < 4; ++m) { for (int e = 0; e < 4; ++e) i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m], i8Elems[m][e], i32_val(e)); i32Elems[m] = bitcast(i8v4Elems[m], i8x4Ty); } } return {i32Elems[0], i32Elems[1], i32Elems[2], i32Elems[3]}; } assert(false && "Invalid smem load"); return {Value{}, Value{}, Value{}, Value{}}; } private: SmallVector order; int kOrder; SmallVector tileShape; SmallVector instrShape; SmallVector matShape; int perPhase; int maxPhase; int elemBytes; ConversionPatternRewriter &rewriter; const Location &loc; MLIRContext *ctx{}; int cMatShape; int sMatShape; Value sStride; bool needTrans; bool 
  canUseLdmatrix;
  int numPtrs;

  int pLoadStrideInMat;
  int sMatStride;

  int matArrStride;
  int warpOffStride;
};

// This class helps to adapt the existing DotOpConversion to the latest
// DotOpOperand layout design. It decouples the existing implementation into
// two parts:
// 1. loading the specific operand matrix (for $a, $b, $c) from smem
// 2. passing the loaded values on and performing the mma codegen
struct MMA16816ConversionHelper {
  MmaEncodingAttr mmaLayout;
  ArrayRef<unsigned> wpt;
  SmallVector<unsigned> properWpt;

  Value thread, lane, warp;

  DotOpMmaV2ConversionHelper helper;
  ConversionPatternRewriter &rewriter;
  TypeConverter *typeConverter;
  Location loc;
  MLIRContext *ctx{};

  using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;

  // dotOperand: type of either one operand of dotOp.
  MMA16816ConversionHelper(Type dotOperand, MmaEncodingAttr mmaLayout,
                           Value thread, ConversionPatternRewriter &rewriter,
                           TypeConverter *typeConverter, Location loc)
      : mmaLayout(mmaLayout), thread(thread), helper(mmaLayout),
        rewriter(rewriter), typeConverter(typeConverter), loc(loc),
        ctx(mmaLayout.getContext()), wpt(mmaLayout.getWarpsPerCTA()) {
    helper.deduceMmaType(dotOperand);

    Value _32 = i32_val(32);
    lane = urem(thread, _32);
    warp = udiv(thread, _32);
  }

  // Get a warpId for M axis.
  Value getWarpM(int M) const {
    auto matShape = helper.getMmaMatShape();
    return urem(urem(warp, i32_val(wpt[0])), i32_val(M / matShape[0]));
  }

  // Get a warpId for N axis.
  Value getWarpN(int N) const {
    auto matShape = helper.getMmaMatShape();
    Value warpMN = udiv(warp, i32_val(wpt[0]));
    return urem(urem(warpMN, i32_val(wpt[1])), i32_val(N / matShape[1]));
  }

  // Get the mmaInstrShape, deducing it either from $a or $b.
  std::tuple<int, int, int> getMmaInstrShape(Type operand) const {
    helper.deduceMmaType(operand);
    auto mmaInstrShape = helper.getMmaInstrShape();
    int mmaInstrM = mmaInstrShape[0];
    int mmaInstrN = mmaInstrShape[1];
    int mmaInstrK = mmaInstrShape[2];
    return std::make_tuple(mmaInstrM, mmaInstrN, mmaInstrK);
  }

  // Get the mmaMatShape, deducing it either from $a or $b.
  std::tuple<int, int, int> getMmaMatShape(Type operand) const {
    helper.deduceMmaType(operand);
    auto matShape = helper.getMmaMatShape();
    int matShapeM = matShape[0];
    int matShapeN = matShape[1];
    int matShapeK = matShape[2];
    return std::make_tuple(matShapeM, matShapeN, matShapeK);
  }

  // \param operand is either $a or $b's type.
  inline int getNumRepM(Type operand, int M) const {
    return getNumRepM(operand, M, wpt[0]);
  }

  // \param operand is either $a or $b's type.
  inline int getNumRepN(Type operand, int N) const {
    return getNumRepN(operand, N, wpt[1]);
  }

  // \param operand is either $a or $b's type.
  inline int getNumRepK(Type operand, int K) const {
    return getNumRepK_(operand, K);
  }

  static int getNumRepM(Type operand, int M, int wpt) {
    auto tensorCoreType =
        DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand);
    int mmaInstrM =
        DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[0];
    return std::max(M / (wpt * mmaInstrM), 1);
  }

  static int getNumRepN(Type operand, int N, int wpt) {
    auto tensorCoreType =
        DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand);
    int mmaInstrN =
        DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[1];
    return std::max(N / (wpt * mmaInstrN), 1);
  }

  static int getNumRepK_(Type operand, int K) {
    auto tensorCoreType =
        DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand);
    int mmaInstrK =
        DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[2];
    return std::max(K / mmaInstrK, 1);
  }

  // Get number of elements per thread for $a operand.
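  // Worked example (hypothetical shapes): for a 128x64 f16 $a operand with
  // wpt = 2, repM = max(128 / (2 * 16), 1) = 4 and repK = max(64 / 16, 1) = 4,
  // so getANumElemsPerThread returns 4 * 4 * 4 = 64 packed values per thread.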
static size_t getANumElemsPerThread(RankedTensorType operand, int wpt) { auto shape = operand.getShape(); int repM = getNumRepM(operand, shape[0], wpt); int repK = getNumRepK_(operand, shape[1]); return 4 * repM * repK; } // Get number of elements per thread for $b operand. static size_t getBNumElemsPerThread(RankedTensorType operand, int wpt) { auto shape = operand.getShape(); int repK = getNumRepK_(operand, shape[0]); int repN = getNumRepN(operand, shape[1], wpt); return 4 * std::max(repN / 2, 1) * repK; } // Loading $a from smem to registers, returns a LLVM::Struct. Value loadA(Value tensor, const SharedMemoryObject &smemObj) const { auto aTensorTy = tensor.getType().cast(); SmallVector shape(aTensorTy.getShape().begin(), aTensorTy.getShape().end()); ValueTable ha; std::function loadFn; auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy); auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy); int numRepM = getNumRepM(aTensorTy, shape[0]); int numRepK = getNumRepK(aTensorTy, shape[1]); if (aTensorTy.getEncoding().isa()) { Value warpM = getWarpM(shape[0]); // load from smem int wpt = std::min(mmaLayout.getWarpsPerCTA()[0], shape[0] / matShapeM); loadFn = getLoadMatrixFn(tensor, smemObj, mmaLayout, wpt /*wpt*/, 1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/, {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/, true /*isA*/); } else if (aTensorTy.getEncoding().isa()) { // load from registers, used in gemm fuse // TODO(Superjomn) Port the logic. assert(false && "Loading A from register is not supported yet."); } else { assert(false && "A's layout is not supported."); } // step1. Perform loading. for (int m = 0; m < numRepM; ++m) for (int k = 0; k < numRepK; ++k) loadFn(2 * m, 2 * k); // step2. Format the values to LLVM::Struct to passing to mma codegen. return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK); } // Loading $b from smem to registers, returns a LLVM::Struct. Value loadB(Value tensor, const SharedMemoryObject &smemObj) { ValueTable hb; auto tensorTy = tensor.getType().cast(); SmallVector shape(tensorTy.getShape().begin(), tensorTy.getShape().end()); // TODO[Superjomn]: transB cannot be accessed in ConvertLayoutOp. bool transB = false; if (transB) { std::swap(shape[0], shape[1]); } auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(tensorTy); auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(tensorTy); int numRepK = getNumRepK(tensorTy, shape[0]); int numRepN = getNumRepN(tensorTy, shape[1]); Value warpN = getWarpN(shape[1]); int wpt = std::min(mmaLayout.getWarpsPerCTA()[1], shape[1] / matShapeN); auto loadFn = getLoadMatrixFn(tensor, smemObj, mmaLayout, wpt /*wpt*/, 0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/, {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/, false /*isA*/); for (int n = 0; n < std::max(numRepN / 2, 1); ++n) { for (int k = 0; k < numRepK; ++k) loadFn(2 * n, 2 * k); } Value result = composeValuesToDotOperandLayoutStruct( hb, std::max(numRepN / 2, 1), numRepK); return result; } // Loading $c to registers, returns a Value. Value loadC(Value tensor, Value llTensor) const { auto tensorTy = tensor.getType().cast(); auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(tensorTy); size_t fcSize = 4 * repM * repN; assert(tensorTy.getEncoding().isa() && "Currently, we only support $c with a mma layout."); // Load a normal C tensor with mma layout, that should be a // LLVM::struct with fcSize elements. 
auto structTy = llTensor.getType().cast(); assert(structTy.getBody().size() == fcSize && "DotOp's $c operand should pass the same number of values as $d in " "mma layout."); return llTensor; } // Conduct the Dot conversion. // \param a, \param b, \param c and \param d are DotOp operands. // \param loadedA, \param loadedB, \param loadedC, all of them are result of // loading. LogicalResult convertDot(Value a, Value b, Value c, Value d, Value loadedA, Value loadedB, Value loadedC, DotOp op, DotOpAdaptor adaptor) const { helper.deduceMmaType(op); auto aTensorTy = a.getType().cast(); auto dTensorTy = d.getType().cast(); SmallVector aShape(aTensorTy.getShape().begin(), aTensorTy.getShape().end()); auto dShape = dTensorTy.getShape(); // shape / shape_per_cta int numRepM = getNumRepM(aTensorTy, dShape[0]); int numRepN = getNumRepN(aTensorTy, dShape[1]); int numRepK = getNumRepK(aTensorTy, aShape[1]); ValueTable ha = getValuesFromDotOperandLayoutStruct(loadedA, numRepM, numRepK); ValueTable hb = getValuesFromDotOperandLayoutStruct( loadedB, std::max(numRepN / 2, 1), numRepK); auto fc = getElementsFromStruct(loc, loadedC, rewriter); auto callMma = [&](unsigned m, unsigned n, unsigned k) { unsigned colsPerThread = numRepN * 2; PTXBuilder builder; auto &mma = *builder.create(helper.getMmaInstr().str()); auto retArgs = builder.newListOperand(4, "=r"); auto aArgs = builder.newListOperand({ {ha[{m, k}], "r"}, {ha[{m + 1, k}], "r"}, {ha[{m, k + 1}], "r"}, {ha[{m + 1, k + 1}], "r"}, }); auto bArgs = builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}}); auto cArgs = builder.newListOperand(); for (int i = 0; i < 4; ++i) { cArgs->listAppend(builder.newOperand(fc[m * colsPerThread + 4 * n + i], std::to_string(i))); // reuse the output registers } mma(retArgs, aArgs, bArgs, cArgs); Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType()); auto getIntAttr = [&](int v) { return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)}); }; Type elemTy = mmaOut.getType().cast().getBody()[0]; for (int i = 0; i < 4; ++i) fc[m * colsPerThread + 4 * n + i] = extract_val(elemTy, mmaOut, getIntAttr(i)); }; for (int k = 0; k < numRepK; ++k) for (int m = 0; m < numRepM; ++m) for (int n = 0; n < numRepN; ++n) callMma(2 * m, n, 2 * k); Type resElemTy = dTensorTy.getElementType(); for (auto &elem : fc) { elem = bitcast(elem, resElemTy); } // replace with new packed result Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(fc.size(), resElemTy)); Value res = getStructFromElements(loc, fc, rewriter, structTy); rewriter.replaceOp(op, res); return success(); } private: std::function getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj, MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder, SmallVector instrShape, SmallVector matShape, Value warpId, ValueTable &vals, bool isA) const { auto tensorTy = tensor.getType().cast(); // We assumes that the input operand of Dot should be from shared layout. // TODO(Superjomn) Consider other layouts if needed later. auto sharedLayout = tensorTy.getEncoding().cast(); const int perPhase = sharedLayout.getPerPhase(); const int maxPhase = sharedLayout.getMaxPhase(); const int elemBytes = tensorTy.getElementTypeBitWidth() / 8; auto order = sharedLayout.getOrder(); // the original register_lds2, but discard the prefetch logic. auto ld2 = [](ValueTable &vals, int mn, int k, Value val) { vals[{mn, k}] = val; }; // (a, b) is the coordinate. 
auto load = [=, &vals, &ld2](int a, int b) { MMA16816SmemLoader loader( wpt, sharedLayout.getOrder(), kOrder, smemObj.strides, tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase, maxPhase, elemBytes, rewriter, typeConverter, loc); Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); SmallVector offs = loader.computeOffsets(warpId, lane, cSwizzleOffset); const int numPtrs = loader.getNumPtrs(); SmallVector ptrs(numPtrs); Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter); Type smemPtrTy = helper.getShemPtrTy(); for (int i = 0; i < numPtrs; ++i) { ptrs[i] = bitcast(gep(smemPtrTy, smemBase, ValueRange({offs[i]})), smemPtrTy); } auto [ha0, ha1, ha2, ha3] = loader.loadX4( (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs, ptrs, helper.getMatType(), helper.getShemPtrTy()); if (isA) { ld2(vals, a, b, ha0); ld2(vals, a + 1, b, ha1); ld2(vals, a, b + 1, ha2); ld2(vals, a + 1, b + 1, ha3); } else { ld2(vals, a, b, ha0); ld2(vals, a + 1, b, ha2); ld2(vals, a, b + 1, ha1); ld2(vals, a + 1, b + 1, ha3); } }; return load; } // Compose a map of Values to a LLVM::Struct. // The layout is a list of Value with coordinate of (i,j), the order is as // the follows: // [ // (0,0), (0,1), (1,0), (1,1), # i=0, j=0 // (0,2), (0,3), (1,2), (1,3), # i=0, j=1 // (0,4), (0,5), (1,4), (1,5), # i=0, j=2 // ... // (2,0), (2,1), (3,0), (3,1), # i=1, j=0 // (2,2), (2,3), (3,2), (3,3), # i=1, j=1 // (2,4), (2,5), (3,4), (3,5), # i=1, j=2 // ... // ] // i \in [0, n0) and j \in [0, n1) // There should be \param n0 * \param n1 elements in the output Struct. Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0, int n1) const { std::vector elems; for (int m = 0; m < n0; ++m) for (int k = 0; k < n1; ++k) { elems.push_back(vals.at({2 * m, 2 * k})); elems.push_back(vals.at({2 * m, 2 * k + 1})); elems.push_back(vals.at({2 * m + 1, 2 * k})); elems.push_back(vals.at({2 * m + 1, 2 * k + 1})); } assert(!elems.empty()); Type elemTy = elems[0].getType(); Type structTy = LLVM::LLVMStructType::getLiteral( ctx, SmallVector(elems.size(), elemTy)); auto result = getStructFromElements(loc, elems, rewriter, structTy); return result; } ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1) const { auto elems = getElementsFromStruct(loc, value, rewriter); int offset{}; ValueTable vals; for (int i = 0; i < n0; ++i) { for (int j = 0; j < n1; j++) { vals[{2 * i, 2 * j}] = elems[offset++]; vals[{2 * i, 2 * j + 1}] = elems[offset++]; vals[{2 * i + 1, 2 * j}] = elems[offset++]; vals[{2 * i + 1, 2 * j + 1}] = elems[offset++]; } } return vals; } }; // Helper for conversion of FMA DotOp. 
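// Note: this FMA path is the fallback when no tensor-core (mma) layout
// applies; $a and $b are read back as scalar f32 values arranged by the
// blocked layout of $d, and the dot itself lowers to per-element
// multiply-add.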
struct DotOpFMAConversionHelper {
  Attribute layout;
  MLIRContext *ctx{};

  using ValueTable = std::map<std::pair<int, int>, Value>;

  explicit DotOpFMAConversionHelper(Attribute layout)
      : layout(layout), ctx(layout.getContext()) {}

  SmallVector<Value> getThreadIds(Value threadId,
                                  ArrayRef<unsigned> shapePerCTA,
                                  ArrayRef<unsigned> order,
                                  ConversionPatternRewriter &rewriter,
                                  Location loc) const;

  Value loadA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
              Location loc, ConversionPatternRewriter &rewriter) const;

  Value loadB(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
              Location loc, ConversionPatternRewriter &rewriter) const;

  ValueTable getValueTableFromStruct(Value val, int K, int n0, int shapePerCTA,
                                     int sizePerThread,
                                     ConversionPatternRewriter &rewriter,
                                     Location loc) const;

  Value getStructFromValueTable(const ValueTable &vals,
                                ConversionPatternRewriter &rewriter,
                                Location loc) const {
    SmallVector<Type> elemTypes(vals.size(), f32_ty);
    SmallVector<Value> elems;
    elems.reserve(vals.size());
    for (auto &item : vals) {
      elems.push_back(item.second);
    }

    Type structTy = struct_ty(elemTypes);
    return getStructFromElements(loc, elems, rewriter, structTy);
  }

  // get number of elements per thread for $a or $b.
  static int getNumElemsPerThread(ArrayRef<int64_t> shape,
                                  DotOperandEncodingAttr dotOpLayout) {
    auto blockedLayout = dotOpLayout.getParent().cast<BlockedEncodingAttr>();
    auto shapePerCTA = getShapePerCTA(blockedLayout);
    auto sizePerThread = getSizePerThread(blockedLayout);

    // TODO[Superjomn]: we assume the k axis is fixed for $a and $b here; fix
    // it if not.
    int K = dotOpLayout.getOpIdx() == 0 ? shape[1] : shape[0];
    int otherDim = dotOpLayout.getOpIdx() == 1 ? shape[1] : shape[0];

    bool isM = dotOpLayout.getOpIdx() == 0;
    int shapePerCTAMN = getShapePerCTAForMN(blockedLayout, isM);
    int sizePerThreadMN = getsizePerThreadForMN(blockedLayout, isM);
    return K * std::max(otherDim / shapePerCTAMN, 1) * sizePerThreadMN;
  }

  // Get shapePerCTA for M or N axis.
  static int getShapePerCTAForMN(BlockedEncodingAttr layout, bool isM) {
    auto order = layout.getOrder();
    auto shapePerCTA = getShapePerCTA(layout);

    int mShapePerCTA =
        order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
    int nShapePerCTA =
        order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
    return isM ? mShapePerCTA : nShapePerCTA;
  }

  // Get sizePerThread for M or N axis.
  static int getsizePerThreadForMN(BlockedEncodingAttr layout, bool isM) {
    auto order = layout.getOrder();
    auto sizePerThread = getSizePerThread(layout);

    int mSizePerThread =
        order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
    int nSizePerThread =
        order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
    return isM ? mSizePerThread : nSizePerThread;
  }
};

Value DotOpMmaV1ConversionHelper::loadA(
    Value tensor, const SharedMemoryObject &smemObj, Value thread,
    Location loc, ConversionPatternRewriter &rewriter) const {
  auto *ctx = rewriter.getContext();
  auto tensorTy = tensor.getType().cast<RankedTensorType>();
  auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
  SmallVector<int64_t> shape(tensorTy.getShape().begin(),
                             tensorTy.getShape().end());
  SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
                              sharedLayout.getOrder().end());

  Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
  Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);

  bool isARow = order[0] != 0;
  bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
  // TODO[Superjomn]: Support the case when isAVec4=false later
  // Currently, we only support ld.v2, for the mma layout varies with
  // different ld vector width.
isAVec4 = true; int packSize0 = (isARow || isAVec4) ? 1 : 2; SmallVector fpw({2, 2, 1}); int repM = 2 * packSize0; int repK = 1; int spwM = fpw[0] * 4 * repM; SmallVector rep({repM, 0, repK}); // pad N with 0 SmallVector spw({spwM, 0, 1}); // pad N with 0 auto [offsetAM, offsetAK, _0, _1] = computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc); // TODO [Superjomn]: transA cannot be accessed in ConvertLayoutOp. bool transA = false; if (transA) { std::swap(shape[0], shape[1]); std::swap(offsetAM, offsetAK); std::swap(order[0], order[1]); } int vecA = sharedLayout.getVec(); auto strides = smemObj.strides; Value strideAM = isARow ? strides[0] : i32_val(1); Value strideAK = isARow ? i32_val(1) : strides[1]; Value strideA0 = isARow ? strideAK : strideAM; Value strideA1 = isARow ? strideAM : strideAK; int strideRepM = wpt[0] * fpw[0] * 8; int strideRepK = 1; // swizzling int perPhaseA = sharedLayout.getPerPhase(); int maxPhaseA = sharedLayout.getMaxPhase(); int stepA0 = isARow ? strideRepK : strideRepM; int numPtrA = std::max(2 * perPhaseA * maxPhaseA / stepA0, 1); int NK = shape[1]; // pre-compute pointer lanes Value offA0 = isARow ? offsetAK : offsetAM; Value offA1 = isARow ? offsetAM : offsetAK; Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA)); offA0 = add(offA0, cSwizzleOffset); SmallVector offA(numPtrA); for (int i = 0; i < numPtrA; i++) { Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM))); offA0I = udiv(offA0I, i32_val(vecA)); offA0I = xor_(offA0I, phaseA); offA0I = mul(offA0I, i32_val(vecA)); offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1)); } Type f16x2Ty = vec_ty(f16_ty, 2); // prepare arguments SmallVector ptrA(numPtrA); std::map, std::pair> has; for (int i = 0; i < numPtrA; i++) ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]); unsigned numM = std::max(rep[0] * shape[0] / (spw[0] * wpt[0]), 1); Type f16PtrTy = ptr_ty(f16_ty); auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) { vals[{m, k}] = {val0, val1}; }; auto loadA = [&](int m, int k) { int offidx = (isARow ? k / 4 : m) % numPtrA; Value thePtrA = gep(f16PtrTy, smemBase, offA[offidx]); int stepAM = isARow ? m : m / numPtrA * numPtrA; int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k; Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM), mul(i32_val(stepAK), strideAK)); Value pa = gep(f16PtrTy, thePtrA, offset); Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max(vecA / 2, 1)), 3); Value ha = load(bitcast(pa, aPtrTy)); // record lds that needs to be moved Value ha00 = bitcast(extract_element(ha, i32_val(0)), f16x2Ty); Value ha01 = bitcast(extract_element(ha, i32_val(1)), f16x2Ty); ld(has, m, k, ha00, ha01); if (vecA > 4) { Value ha10 = bitcast(extract_element(ha, i32_val(2)), f16x2Ty); Value ha11 = bitcast(extract_element(ha, i32_val(3)), f16x2Ty); if (isARow) ld(has, m, k + 4, ha10, ha11); else ld(has, m + 1, k, ha10, ha11); } }; for (unsigned k = 0; k < NK; k += 4) for (unsigned m = 0; m < numM / 2; ++m) loadA(m, k); SmallVector elems; elems.reserve(has.size() * 2); for (auto item : has) { // has is a map, the key should be ordered. 
elems.push_back(item.second.first); elems.push_back(item.second.second); } Type resTy = struct_ty(SmallVector(elems.size(), f16x2Ty)); Value res = getStructFromElements(loc, elems, rewriter, resTy); return res; } Value DotOpMmaV1ConversionHelper::loadB( Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc, ConversionPatternRewriter &rewriter) const { // smem auto strides = smemObj.strides; auto *ctx = rewriter.getContext(); auto tensorTy = tensor.getType().cast(); auto sharedLayout = tensorTy.getEncoding().cast(); SmallVector shape(tensorTy.getShape().begin(), tensorTy.getShape().end()); SmallVector order(sharedLayout.getOrder().begin(), sharedLayout.getOrder().end()); Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter); bool isBRow = order[0] != 0; bool isBVec4 = isBRow && shape[order[0]] <= 16; // TODO[Superjomn]: Support the case when isBVec4=false later // Currently, we only support ld.v2, for the mma layout varies with different // ld vector width. isBVec4 = true; int packSize1 = (isBRow && !isBVec4) ? 2 : 1; SmallVector fpw({2, 2, 1}); SmallVector rep({0, 2 * packSize1, 1}); // pad M with 0 SmallVector spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0 int vecB = sharedLayout.getVec(); Value strideBN = isBRow ? i32_val(1) : strides[1]; Value strideBK = isBRow ? strides[0] : i32_val(1); Value strideB0 = isBRow ? strideBN : strideBK; Value strideB1 = isBRow ? strideBK : strideBN; int strideRepN = wpt[1] * fpw[1] * 8; int strideRepK = 1; // TODO [Superjomn]: transB cannot be accessed in ConvertLayoutOp. bool transB = false; auto [_0, _1, offsetBN, offsetBK] = computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc); if (transB) { std::swap(order[0], order[1]); std::swap(shape[0], shape[1]); std::swap(offsetBK, offsetBN); } // swizzling int perPhaseB = sharedLayout.getPerPhase(); int maxPhaseB = sharedLayout.getMaxPhase(); int stepB0 = isBRow ? strideRepN : strideRepK; int numPtrB = std::max(2 * perPhaseB * maxPhaseB / stepB0, 1); int NK = shape[0]; Value offB0 = isBRow ? offsetBN : offsetBK; Value offB1 = isBRow ? offsetBK : offsetBN; Value phaseB = urem(udiv(offB1, i32_val(perPhaseB)), i32_val(maxPhaseB)); Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); offB0 = add(offB0, cSwizzleOffset); SmallVector offB(numPtrB); for (int i = 0; i < numPtrB; ++i) { Value offB0I = add(offB0, i32_val(i * (isBRow ? strideRepN : 4))); offB0I = udiv(offB0I, i32_val(vecB)); offB0I = xor_(offB0I, phaseB); offB0I = mul(offB0I, i32_val(vecB)); offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1)); } Type f16PtrTy = ptr_ty(f16_ty); Type f16x2Ty = vec_ty(f16_ty, 2); SmallVector ptrB(numPtrB); ValueTable hbs; for (int i = 0; i < numPtrB; ++i) ptrB[i] = gep(ptr_ty(f16_ty), smem, offB[i]); auto ld = [&](decltype(hbs) &vals, int m, int k, Value val0, Value val1) { vals[{m, k}] = {val0, val1}; }; auto loadB = [&](int n, int K) { int offidx = (isBRow ? n : K / 4) % numPtrB; Value thePtrB = ptrB[offidx]; int stepBN = isBRow ? n / numPtrB * numPtrB : n; int stepBK = isBRow ? 
K : K / (numPtrB * vecB) * (numPtrB * vecB); Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN), mul(i32_val(stepBK), strideBK)); Value pb = gep(f16PtrTy, thePtrB, offset); Value hb = load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3))); // record lds that needs to be moved Value hb00 = bitcast(extract_element(hb, i32_val(0)), f16x2Ty); Value hb01 = bitcast(extract_element(hb, i32_val(1)), f16x2Ty); ld(hbs, n, K, hb00, hb01); if (vecB > 4) { Value hb10 = bitcast(extract_element(hb, i32_val(2)), f16x2Ty); Value hb11 = bitcast(extract_element(hb, i32_val(3)), f16x2Ty); if (isBRow) ld(hbs, n + 1, K, hb10, hb11); else ld(hbs, n, K + 4, hb10, hb11); } }; unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]); for (unsigned k = 0; k < NK; k += 4) for (unsigned n = 0; n < numN / 2; ++n) { if (!hbs.count({n, k})) loadB(n, k); } SmallVector elems; for (auto &item : hbs) { // has is a map, the key should be ordered. elems.push_back(item.second.first); elems.push_back(item.second.second); } Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2); Type resTy = struct_ty(SmallVector(elems.size(), fp16x2Ty)); Value res = getStructFromElements(loc, elems, rewriter, resTy); return res; } std::tuple DotOpMmaV1ConversionHelper::computeOffsets(Value threadId, bool isARow, bool isBRow, ArrayRef fpw, ArrayRef spw, ArrayRef rep, ConversionPatternRewriter &rewriter, Location loc) const { auto *ctx = rewriter.getContext(); Value _1 = i32_val(1); Value _3 = i32_val(3); Value _4 = i32_val(4); Value _16 = i32_val(16); Value _32 = i32_val(32); Value lane = urem(threadId, _32); Value warp = udiv(threadId, _32); // warp offset Value warp0 = urem(warp, i32_val(wpt[0])); Value warp12 = udiv(warp, i32_val(wpt[0])); Value warp1 = urem(warp12, i32_val(wpt[1])); Value warpMOff = mul(warp0, i32_val(spw[0])); Value warpNOff = mul(warp1, i32_val(spw[1])); // Quad offset Value quadMOff = mul(udiv(and_(lane, _16), _4), i32_val(fpw[0])); Value quadNOff = mul(udiv(and_(lane, _16), _4), i32_val(fpw[1])); // Pair offset Value pairMOff = udiv(urem(lane, _16), _4); pairMOff = urem(pairMOff, i32_val(fpw[0])); pairMOff = mul(pairMOff, _4); Value pairNOff = udiv(urem(lane, _16), _4); pairNOff = udiv(pairNOff, i32_val(fpw[0])); pairNOff = urem(pairNOff, i32_val(fpw[1])); pairNOff = mul(pairNOff, _4); // scale pairMOff = mul(pairMOff, i32_val(rep[0] / 2)); quadMOff = mul(quadMOff, i32_val(rep[0] / 2)); pairNOff = mul(pairNOff, i32_val(rep[1] / 2)); quadNOff = mul(quadNOff, i32_val(rep[1] / 2)); // Quad pair offset Value laneMOff = add(pairMOff, quadMOff); Value laneNOff = add(pairNOff, quadNOff); // A offset Value offsetAM = add(warpMOff, laneMOff); Value offsetAK = and_(lane, _3); // B offset Value offsetBN = add(warpNOff, laneNOff); Value offsetBK = and_(lane, _3); // i indices Value offsetCM = add(and_(lane, _1), offsetAM); if (isARow) { offsetAM = add(offsetAM, urem(threadId, _4)); offsetAK = i32_val(0); } if (!isBRow) { offsetBN = add(offsetBN, urem(threadId, _4)); offsetBK = i32_val(0); } return std::make_tuple(offsetAM, offsetAK, offsetBN, offsetBK); } DotOpMmaV1ConversionHelper::ValueTable DotOpMmaV1ConversionHelper::extractLoadedOperand( Value llStruct, int NK, ConversionPatternRewriter &rewriter) const { ValueTable rcds; SmallVector elems = getElementsFromStruct(llStruct.getLoc(), llStruct, rewriter); int offset = 0; for (int i = 0; offset < elems.size(); ++i) { for (int k = 0; k < NK; k += 4) { rcds[{i, k}] = std::make_pair(elems[offset], elems[offset + 1]); offset += 2; } } return rcds; } Value 
DotOpFMAConversionHelper::loadA( Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, Location loc, ConversionPatternRewriter &rewriter) const { auto aTensorTy = A.getType().cast(); auto aLayout = aTensorTy.getEncoding().cast(); auto aShape = aTensorTy.getShape(); auto aOrder = aLayout.getOrder(); auto order = dLayout.getOrder(); bool isARow = aOrder[0] == 1; int strideAM = isARow ? aShape[1] : 1; int strideAK = isARow ? 1 : aShape[0]; int strideA0 = isARow ? strideAK : strideAM; int strideA1 = isARow ? strideAM : strideAK; int aNumPtr = 8; int NK = aShape[1]; auto shapePerCTA = getShapePerCTA(dLayout); auto sizePerThread = getSizePerThread(dLayout); Value _0 = i32_val(0); Value mContig = i32_val(sizePerThread[order[1]]); // threadId in blocked layout auto threadIds = getThreadIds(thread, shapePerCTA, order, rewriter, loc); Value threadIdM = threadIds[0]; Value offA0 = isARow ? _0 : mul(threadIdM, mContig); Value offA1 = isARow ? mul(threadIdM, mContig) : _0; SmallVector aOff(aNumPtr); for (int i = 0; i < aNumPtr; ++i) { aOff[i] = add(mul(offA0, i32_val(strideA0)), mul(offA1, i32_val(strideA1))); } auto aSmem = getSharedMemoryObjectFromStruct(loc, llA, rewriter); Type f32PtrTy = ptr_ty(f32_ty); SmallVector aPtrs(aNumPtr); for (int i = 0; i < aNumPtr; ++i) aPtrs[i] = gep(f32PtrTy, aSmem.base, aOff[i]); ValueTable has; int M = aShape[aOrder[1]]; int mShapePerCTA = getShapePerCTAForMN(dLayout, true /*isM*/); int mSizePerThread = getsizePerThreadForMN(dLayout, true /*isM*/); for (unsigned k = 0; k < NK; ++k) { for (unsigned m = 0; m < M; m += mShapePerCTA) for (unsigned mm = 0; mm < mSizePerThread; ++mm) if (!has.count({m + mm, k})) { Value pa = gep(f32PtrTy, aPtrs[0], i32_val((m + mm) * strideAM + k * strideAK)); Value va = load(pa); has[{m + mm, k}] = va; } } return getStructFromValueTable(has, rewriter, loc); } Value DotOpFMAConversionHelper::loadB( Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, Location loc, ConversionPatternRewriter &rewriter) const { auto bTensorTy = B.getType().cast(); auto bLayout = bTensorTy.getEncoding().cast(); auto bShape = bTensorTy.getShape(); auto bOrder = bLayout.getOrder(); auto order = dLayout.getOrder(); bool isBRow = bOrder[0] == 1; int strideBN = isBRow ? 1 : bShape[0]; int strideBK = isBRow ? bShape[1] : 1; int strideB0 = isBRow ? strideBN : strideBK; int strideB1 = isBRow ? strideBK : strideBN; int bNumPtr = 8; int NK = bShape[0]; auto shapePerCTA = getShapePerCTA(dLayout); auto sizePerThread = getSizePerThread(dLayout); Value _0 = i32_val(0); Value nContig = i32_val(sizePerThread[order[0]]); // threadId in blocked layout auto threadIds = getThreadIds(thread, shapePerCTA, order, rewriter, loc); Value threadIdN = threadIds[1]; Value offB0 = isBRow ? mul(threadIdN, nContig) : _0; Value offB1 = isBRow ? 
_0 : mul(threadIdN, nContig); SmallVector bOff(bNumPtr); for (int i = 0; i < bNumPtr; ++i) { bOff[i] = add(mul(offB0, i32_val(strideB0)), mul(offB1, i32_val(strideB1))); } auto bSmem = getSharedMemoryObjectFromStruct(loc, llB, rewriter); Type f32PtrTy = ptr_ty(f32_ty); SmallVector bPtrs(bNumPtr); for (int i = 0; i < bNumPtr; ++i) bPtrs[i] = gep(f32PtrTy, bSmem.base, bOff[i]); int N = bShape[bOrder[0]]; ValueTable hbs; int nShapePerCTA = getShapePerCTAForMN(dLayout, false /*isM*/); int nSizePerThread = getsizePerThreadForMN(dLayout, false /*isM*/); for (unsigned k = 0; k < NK; ++k) for (unsigned n = 0; n < N; n += nShapePerCTA) for (unsigned nn = 0; nn < nSizePerThread; ++nn) { Value pb = gep(f32PtrTy, bPtrs[0], i32_val((n + nn) * strideBN + k * strideBK)); Value vb = load(pb); hbs[{n + nn, k}] = vb; } return getStructFromValueTable(hbs, rewriter, loc); } DotOpFMAConversionHelper::ValueTable DotOpFMAConversionHelper::getValueTableFromStruct( Value val, int K, int n0, int shapePerCTA, int sizePerThread, ConversionPatternRewriter &rewriter, Location loc) const { ValueTable res; auto elems = getElementsFromStruct(loc, val, rewriter); int id = 0; std::set> keys; // ordered for (unsigned k = 0; k < K; ++k) { for (unsigned m = 0; m < n0; m += shapePerCTA) for (unsigned mm = 0; mm < sizePerThread; ++mm) { keys.insert({m + mm, k}); } } for (auto &key : llvm::enumerate(keys)) { res[key.value()] = elems[key.index()]; } return res; } SmallVector DotOpFMAConversionHelper::getThreadIds( Value threadId, ArrayRef shapePerCTA, ArrayRef order, ConversionPatternRewriter &rewriter, Location loc) const { int dim = order.size(); SmallVector threadIds(dim); for (unsigned k = 0; k < dim - 1; k++) { Value dimK = i32_val(shapePerCTA[order[k]]); Value rem = urem(threadId, dimK); threadId = udiv(threadId, dimK); threadIds[order[k]] = rem; } Value dimK = i32_val(shapePerCTA[order[dim - 1]]); threadIds[order[dim - 1]] = urem(threadId, dimK); return threadIds; } } // namespace LLVM } // namespace mlir #endif