diff --git a/lib/Conversion/TritonGPUToLLVM/DotHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
new file mode 100644
index 000000000..a50dbb66d
--- /dev/null
+++ b/lib/Conversion/TritonGPUToLLVM/DotHelpers.h
@@ -0,0 +1,1835 @@
+#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_DOT_HELPERS_H
+#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_DOT_HELPERS_H
+
+#include "./Utility.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
+#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "triton/Analysis/Allocation.h"
+#include "triton/Analysis/AxisInfo.h"
+#include "triton/Analysis/Membar.h"
+#include "triton/Analysis/Utility.h"
+#include "triton/Conversion/MLIRTypes.h"
+#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <map>
+#include <string>
+#include <tuple>
+
+namespace mlir {
+namespace LLVM {
+using namespace mlir::triton;
+using ::mlir::triton::gpu::BlockedEncodingAttr;
+using ::mlir::triton::gpu::DotOperandEncodingAttr;
+using ::mlir::triton::gpu::MmaEncodingAttr;
+using ::mlir::triton::gpu::SharedEncodingAttr;
+
+// Forward declarations of functions that live in TritonGPUToLLVM.cpp.
+llvm::SmallVector<mlir::Value>
+getElementsFromStruct(mlir::Location loc, mlir::Value llvmStruct,
+                      mlir::ConversionPatternRewriter &rewriter);
+
+mlir::LLVM::SharedMemoryObject
+getSharedMemoryObjectFromStruct(mlir::Location loc, mlir::Value llvmStruct,
+                                mlir::ConversionPatternRewriter &rewriter);
+
+// Helper for the conversion of a DotOp with mma, i.e. sm < 80.
+struct DotOpMmaV1ConversionHelper {
+  MmaEncodingAttr mmaLayout;
+  ArrayRef<unsigned> wpt;
+
+  using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;
+
+  explicit DotOpMmaV1ConversionHelper(MmaEncodingAttr mmaLayout)
+      : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}
+
+  int getRepM(int M) const {
+    return std::max<int>(M / (wpt[0] * instrShape[0]), 1);
+  }
+  int getRepN(int N) const {
+    return std::max<int>(N / (wpt[1] * instrShape[1]), 1);
+  }
+
+  static ArrayRef<unsigned> getMmaInstrShape() { return instrShape; }
+
+  static Type getMmaRetType(TensorType operand) {
+    auto *ctx = operand.getContext();
+    Type fp32Ty = type::f32Ty(ctx);
+    // f16*f16+f32->f32
+    return struct_ty(SmallVector<Type>(8, fp32Ty));
+  }
+
+  // Number of fp16x2 elements for $a.
+  int numElemsPerThreadA(RankedTensorType tensorTy) const {
+    auto shape = tensorTy.getShape();
+    auto order = getOrder();
+
+    bool isARow = order[0] != 0;
+    bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16 bytes
+    int packSize0 = (isARow || isAVec4) ? 1 : 2;
+
+    SmallVector<int> fpw({2, 2, 1});
+    int repM = 2 * packSize0;
+    int repK = 1;
+    int spwM = fpw[0] * 4 * repM;
+    SmallVector<int> rep({repM, 0, repK}); // pad N with 0
+    SmallVector<int> spw({spwM, 0, 1});    // pad N with 0
+
+    int NK = shape[1];
+    unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]);
+
+    // NOTE: we couldn't get the vec from the shared layout.
+    // int vecA = sharedLayout.getVec();
+    // TODO[Superjomn]: Consider the case when vecA > 4
+    bool vecGt4 = false;
+    int elemsPerLd = vecGt4 ? 4 : 2;
+    return (numM / 2) * (NK / 4) * elemsPerLd;
+  }
+
+  // Number of fp16x2 elements for $b.
+  int numElemsPerThreadB(RankedTensorType tensorTy) const {
+    auto shape = tensorTy.getShape();
+    auto order = getOrder();
+    bool isBRow = order[0] != 0;
+    bool isBVec4 = isBRow && shape[order[0]] <= 16;
+    int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
+    SmallVector<int> fpw({2, 2, 1});
+    SmallVector<int> rep({0, 2 * packSize1, 1});       // pad M with 0
+    SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
+    // NOTE: we couldn't get the vec from the shared layout.
+    // int vecB = sharedLayout.getVec();
+    // TODO[Superjomn]: Consider the case when vecB > 4
+    bool vecGt4 = false;
+    int elemsPerLd = vecGt4 ? 4 : 2;
+    int NK = shape[0];
+
+    unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
+    return (numN / 2) * (NK / 4) * elemsPerLd;
+  }
+
+  // Load $a from smem, returning an LLVM::Struct.
+  Value loadA(Value A, const SharedMemoryObject &smemObj, Value thread,
+              Location loc, ConversionPatternRewriter &rewriter) const;
+
+  // Load $b from smem, returning an LLVM::Struct.
+  Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread,
+              Location loc, ConversionPatternRewriter &rewriter) const;
+
+  static ArrayRef<unsigned> getOrder() { return mmaOrder; }
+
+  // Compute the offsets of the matrix to load.
+  // Returns offsetAM, offsetAK, offsetBN, offsetBK.
+  // NOTE: the M (from $a) and N (from $b) information cannot be retrieved at
+  // the same time in the convert_layout[shared->dot_op] usage; we leave the
+  // nonexistent values as 0 and only read the desired fields from the
+  // composed result. This retains the original code structure of the
+  // convert_mma884 method for easier debugging.
+  std::tuple<Value, Value, Value, Value>
+  computeOffsets(Value threadId, bool isARow, bool isBRow, ArrayRef<int> fpw,
+                 ArrayRef<int> spw, ArrayRef<int> rep,
+                 ConversionPatternRewriter &rewriter, Location loc) const;
+
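For intuition, the per-thread element counts above can be traced standalone. The sketch below is not part of the patch; it re-derives `numElemsPerThreadA` in plain C++ for a made-up 128x64 fp16 row-major $a tile with 2 warps along M (all values are illustrative assumptions):

```cpp
#include <cstdio>

// Mirror of numElemsPerThreadA for a row-major $a tile.
int numElemsPerThreadA(int shapeM, int shapeK, int wptM, bool isARow) {
  bool isAVec4 = !isARow && shapeM <= 16; // fp16*4 = 16 bytes
  int packSize0 = (isARow || isAVec4) ? 1 : 2;
  int fpw0 = 2;               // fragments per warp along M
  int repM = 2 * packSize0;   // replication along M
  int spwM = fpw0 * 4 * repM; // shape per warp along M
  int numM = repM * shapeM / (spwM * wptM);
  int elemsPerLd = 2;         // vec <= 4, so each load yields 2 fp16x2
  return (numM / 2) * (shapeK / 4) * elemsPerLd;
}

int main() {
  // 128x64 tile, 2 warps along M, row-major $a -> 128 fp16x2 per thread.
  std::printf("%d fp16x2 values per thread\n",
              numElemsPerThreadA(128, 64, 2, /*isARow=*/true));
}
```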
+  // Extract the values belonging to $a or $b from an LLVM struct; the shape
+  // is n0 x n1.
+  DotOpMmaV1ConversionHelper::ValueTable
+  extractLoadedOperand(Value llStruct, int NK,
+                       ConversionPatternRewriter &rewriter) const;
+
+private:
+  static constexpr unsigned instrShape[] = {16, 16, 4};
+  static constexpr unsigned mmaOrder[] = {0, 1};
+};
+
+// Helper for the conversion of a DotOp with mma, i.e. sm >= 80.
+struct DotOpMmaV2ConversionHelper {
+  enum class TensorCoreType : uint8_t {
+    // floating-point tensor core instr
+    FP32_FP16_FP16_FP32 = 0, // default
+    FP32_BF16_BF16_FP32,
+    FP32_TF32_TF32_FP32,
+    // integer tensor core instr
+    INT32_INT1_INT1_INT32, // Not implemented
+    INT32_INT4_INT4_INT32, // Not implemented
+    INT32_INT8_INT8_INT32, // Not implemented
+    //
+    NOT_APPLICABLE,
+  };
+
+  MmaEncodingAttr mmaLayout;
+  MLIRContext *ctx{};
+
+  explicit DotOpMmaV2ConversionHelper(MmaEncodingAttr mmaLayout)
+      : mmaLayout(mmaLayout) {
+    ctx = mmaLayout.getContext();
+  }
+
+  void deduceMmaType(DotOp op) const { mmaType = getMmaType(op); }
+  void deduceMmaType(Type operandTy) const {
+    mmaType = getTensorCoreTypeFromOperand(operandTy);
+  }
+
+  // Get the M and N of the mma instruction shape.
+  static std::tuple<int, int> getInstrShapeMN() {
+    // According to DotOpConversionHelper::mmaInstrShape, all the M, N are
+    // {16, 8}.
+    return {16, 8};
+  }
+
+  static std::tuple<int, int> getRepMN(const RankedTensorType &tensorTy) {
+    auto mmaLayout = tensorTy.getEncoding().cast<MmaEncodingAttr>();
+    auto wpt = mmaLayout.getWarpsPerCTA();
+
+    int M = tensorTy.getShape()[0];
+    int N = tensorTy.getShape()[1];
+    auto [instrM, instrN] = getInstrShapeMN();
+    int repM = std::max<int>(M / (wpt[0] * instrM), 1);
+    int repN = std::max<int>(N / (wpt[1] * instrN), 1);
+    return {repM, repN};
+  }
+
+  Type getShemPtrTy() const {
+    switch (mmaType) {
+    case TensorCoreType::FP32_FP16_FP16_FP32:
+      return ptr_ty(type::f16Ty(ctx), 3);
+    case TensorCoreType::FP32_BF16_BF16_FP32:
+      return ptr_ty(type::bf16Ty(ctx), 3);
+    case TensorCoreType::FP32_TF32_TF32_FP32:
+      return ptr_ty(type::f32Ty(ctx), 3);
+    case TensorCoreType::INT32_INT8_INT8_INT32:
+      return ptr_ty(type::i8Ty(ctx), 3);
+    default:
+      llvm::report_fatal_error("mma16816 data type not supported");
+    }
+    return Type{};
+  }
+
+  // The type of matrix loaded by either an ldmatrix or a composed sequence
+  // of lds instructions.
+  Type getMatType() const {
+    Type fp32Ty = type::f32Ty(ctx);
+    Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
+    Type bf16x2Ty = vec_ty(type::bf16Ty(ctx), 2);
+    // floating point types
+    Type fp16x2Pack4Ty =
+        LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp16x2Ty));
+    Type bf16x2Pack4Ty =
+        LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, bf16x2Ty));
+    Type fp32Pack4Ty =
+        LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp32Ty));
+    // integer types
+    Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4);
+    Type i8x4Pack4Ty =
+        LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i8x4Ty));
+
+    switch (mmaType) {
+    case TensorCoreType::FP32_FP16_FP16_FP32:
+      return fp16x2Pack4Ty;
+    case TensorCoreType::FP32_BF16_BF16_FP32:
+      return bf16x2Pack4Ty;
+    case TensorCoreType::FP32_TF32_TF32_FP32:
+      return fp32Pack4Ty;
+    case TensorCoreType::INT32_INT8_INT8_INT32:
+      return i8x4Pack4Ty;
+    default:
+      llvm::report_fatal_error("Unsupported mma type found");
+    }
+
+    return Type{};
+  }
+
+  Type getLoadElemTy() {
+    switch (mmaType) {
+    case TensorCoreType::FP32_FP16_FP16_FP32:
+      return vec_ty(type::f16Ty(ctx), 2);
+    case TensorCoreType::FP32_BF16_BF16_FP32:
+      return vec_ty(type::bf16Ty(ctx), 2);
+    case TensorCoreType::FP32_TF32_TF32_FP32:
+      return type::f32Ty(ctx);
+    case TensorCoreType::INT32_INT8_INT8_INT32:
+      return type::i32Ty(ctx);
+    default:
+      llvm::report_fatal_error("Unsupported mma type found");
+    }
+
+    return Type{};
+  }
+
+  Type getMmaRetType() const {
+    Type fp32Ty = type::f32Ty(ctx);
+    Type i32Ty = type::i32Ty(ctx);
+    Type fp32x4Ty =
+        LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp32Ty));
+    Type i32x4Ty =
+        LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i32Ty));
+    switch (mmaType) {
+    case TensorCoreType::FP32_FP16_FP16_FP32:
+      return fp32x4Ty;
+    case TensorCoreType::FP32_BF16_BF16_FP32:
+      return fp32x4Ty;
+    case TensorCoreType::FP32_TF32_TF32_FP32:
+      return fp32x4Ty;
+    case TensorCoreType::INT32_INT8_INT8_INT32:
+      return i32x4Ty;
+    default:
+      llvm::report_fatal_error("Unsupported mma type found");
+    }
+
+    return Type{};
+  }
+
+  ArrayRef<int> getMmaInstrShape() const {
+    assert(mmaType != TensorCoreType::NOT_APPLICABLE &&
+           "Unknown mma type found.");
+    return mmaInstrShape.at(mmaType);
+  }
+
+  static ArrayRef<int> getMmaInstrShape(TensorCoreType tensorCoreType) {
+    assert(tensorCoreType != TensorCoreType::NOT_APPLICABLE &&
+           "Unknown mma type found.");
+    return mmaInstrShape.at(tensorCoreType);
+  }
+
+  ArrayRef<int> getMmaMatShape() const {
+    assert(mmaType != TensorCoreType::NOT_APPLICABLE &&
+           "Unknown mma type found.");
+    return mmaMatShape.at(mmaType);
+  }
+
+  // Deduce the TensorCoreType from either $a's or $b's type.
+  static TensorCoreType getTensorCoreTypeFromOperand(Type operandTy) {
+    auto tensorTy = operandTy.cast<RankedTensorType>();
+    auto elemTy = tensorTy.getElementType();
+    if (elemTy.isF16())
+      return TensorCoreType::FP32_FP16_FP16_FP32;
+    if (elemTy.isF32())
+      return TensorCoreType::FP32_TF32_TF32_FP32;
+    if (elemTy.isBF16())
+      return TensorCoreType::FP32_BF16_BF16_FP32;
+    if (elemTy.isInteger(8))
+      return TensorCoreType::INT32_INT8_INT8_INT32;
+    return TensorCoreType::NOT_APPLICABLE;
+  }
+
+  int getVec() const {
+    assert(mmaType != TensorCoreType::NOT_APPLICABLE &&
+           "Unknown mma type found.");
+    return mmaInstrVec.at(mmaType);
+  }
+
+  StringRef getMmaInstr() const {
+    assert(mmaType != TensorCoreType::NOT_APPLICABLE &&
+           "Unknown mma type found.");
+    return mmaInstrPtx.at(mmaType);
+  }
+
+  static TensorCoreType getMmaType(triton::DotOp op) {
+    Value A = op.a();
+    Value B = op.b();
+    auto aTy = A.getType().cast<RankedTensorType>();
+    auto bTy = B.getType().cast<RankedTensorType>();
+    // d = a*b + c
+    auto dTy = op.d().getType().cast<RankedTensorType>();
+
+    if (dTy.getElementType().isF32()) {
+      if (aTy.getElementType().isF16() && bTy.getElementType().isF16())
+        return TensorCoreType::FP32_FP16_FP16_FP32;
+      if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16())
+        return TensorCoreType::FP32_BF16_BF16_FP32;
+      if (aTy.getElementType().isF32() && bTy.getElementType().isF32() &&
+          op.allowTF32())
+        return TensorCoreType::FP32_TF32_TF32_FP32;
+    } else if (dTy.getElementType().isInteger(32)) {
+      if (aTy.getElementType().isInteger(8) &&
+          bTy.getElementType().isInteger(8))
+        return TensorCoreType::INT32_INT8_INT8_INT32;
+    }
+
+    return TensorCoreType::NOT_APPLICABLE;
+  }
+
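The deduction rules in `getMmaType` can be exercised in isolation. This is a minimal sketch, not the patch's code: the enum and element-type flags below stand in for the MLIR types, and the point it shows is that the TF32 path is only taken when `allowTF32` is set:

```cpp
#include <cassert>
#include <cstdint>

enum class TensorCoreType : uint8_t {
  FP32_FP16_FP16_FP32,
  FP32_BF16_BF16_FP32,
  FP32_TF32_TF32_FP32,
  INT32_INT8_INT8_INT32,
  NOT_APPLICABLE,
};

enum class ElemTy { F16, BF16, F32, I8, I32 };

// Same branch structure as getMmaType: dispatch on d's type, then a/b.
TensorCoreType getMmaType(ElemTy a, ElemTy b, ElemTy d, bool allowTF32) {
  if (d == ElemTy::F32) {
    if (a == ElemTy::F16 && b == ElemTy::F16)
      return TensorCoreType::FP32_FP16_FP16_FP32;
    if (a == ElemTy::BF16 && b == ElemTy::BF16)
      return TensorCoreType::FP32_BF16_BF16_FP32;
    if (a == ElemTy::F32 && b == ElemTy::F32 && allowTF32)
      return TensorCoreType::FP32_TF32_TF32_FP32;
  } else if (d == ElemTy::I32) {
    if (a == ElemTy::I8 && b == ElemTy::I8)
      return TensorCoreType::INT32_INT8_INT8_INT32;
  }
  return TensorCoreType::NOT_APPLICABLE;
}

int main() {
  // f32 = f32*f32 + f32 only maps to tensor cores when TF32 is allowed.
  assert(getMmaType(ElemTy::F32, ElemTy::F32, ElemTy::F32, false) ==
         TensorCoreType::NOT_APPLICABLE);
  assert(getMmaType(ElemTy::F32, ElemTy::F32, ElemTy::F32, true) ==
         TensorCoreType::FP32_TF32_TF32_FP32);
}
```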
+private:
+  mutable TensorCoreType mmaType{TensorCoreType::NOT_APPLICABLE};
+
+  // Used on NVIDIA GPUs with mma layout .version == 2.
+  // Refer to
+  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-storage
+  // for more details.
+  inline static const std::map<TensorCoreType, llvm::SmallVector<int>>
+      mmaInstrShape = {
+          {TensorCoreType::FP32_FP16_FP16_FP32, {16, 8, 16}},
+          {TensorCoreType::FP32_BF16_BF16_FP32, {16, 8, 16}},
+          {TensorCoreType::FP32_TF32_TF32_FP32, {16, 8, 8}},
+
+          {TensorCoreType::INT32_INT1_INT1_INT32, {16, 8, 256}},
+          {TensorCoreType::INT32_INT4_INT4_INT32, {16, 8, 64}},
+          {TensorCoreType::INT32_INT8_INT8_INT32, {16, 8, 32}},
+      };
+
+  // Shape of the matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices).
+  // Refer to
+  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
+  // for more details.
+  inline static const std::map<TensorCoreType, llvm::SmallVector<int>>
+      mmaMatShape = {
+          {TensorCoreType::FP32_FP16_FP16_FP32, {8, 8, 8}},
+          {TensorCoreType::FP32_BF16_BF16_FP32, {8, 8, 8}},
+          {TensorCoreType::FP32_TF32_TF32_FP32, {8, 8, 4}},
+
+          {TensorCoreType::INT32_INT1_INT1_INT32, {8, 8, 64}},
+          {TensorCoreType::INT32_INT4_INT4_INT32, {8, 8, 32}},
+          {TensorCoreType::INT32_INT8_INT8_INT32, {8, 8, 16}},
+      };
+
+  // Supported mma instructions in PTX.
+  // Refer to
+  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
+  // for more details.
+  inline static const std::map<TensorCoreType, std::string> mmaInstrPtx = {
+      {TensorCoreType::FP32_FP16_FP16_FP32,
+       "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
+      {TensorCoreType::FP32_BF16_BF16_FP32,
+       "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
+      {TensorCoreType::FP32_TF32_TF32_FP32,
+       "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},
+
+      {TensorCoreType::INT32_INT1_INT1_INT32,
+       "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"},
+      {TensorCoreType::INT32_INT4_INT4_INT32,
+       "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"},
+      {TensorCoreType::INT32_INT8_INT8_INT32,
+       "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
+  };
+
+  // vector length per ldmatrix (16*8/element_size_in_bits)
+  inline static const std::map<TensorCoreType, int> mmaInstrVec = {
+      {TensorCoreType::FP32_FP16_FP16_FP32, 8},
+      {TensorCoreType::FP32_BF16_BF16_FP32, 8},
+      {TensorCoreType::FP32_TF32_TF32_FP32, 4},
+
+      {TensorCoreType::INT32_INT1_INT1_INT32, 128},
+      {TensorCoreType::INT32_INT4_INT4_INT32, 32},
+      {TensorCoreType::INT32_INT8_INT8_INT32, 16},
+  };
+};
+
+// Data loader for the mma.16816 instruction.
+class MMA16816SmemLoader {
+public:
+  MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
+                     ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
+                     ArrayRef<int> instrShape, ArrayRef<int> matShape,
+                     int perPhase, int maxPhase, int elemBytes,
+                     ConversionPatternRewriter &rewriter,
+                     TypeConverter *typeConverter, const Location &loc)
+      : order(order.begin(), order.end()), kOrder(kOrder),
+        tileShape(tileShape.begin(), tileShape.end()),
+        instrShape(instrShape.begin(), instrShape.end()),
+        matShape(matShape.begin(), matShape.end()), perPhase(perPhase),
+        maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter), loc(loc),
+        ctx(rewriter.getContext()) {
+    cMatShape = matShape[order[0]];
+    sMatShape = matShape[order[1]];
+
+    sStride = smemStrides[order[1]];
+
+    // rule: k must be the fast-changing axis.
+    needTrans = kOrder != order[0];
+    canUseLdmatrix = elemBytes == 2 || (!needTrans); // b16
+
+    if (canUseLdmatrix) {
+      // Within a CTA the warps are arranged as [1 x wpt] if not transposed,
+      // otherwise as [wpt x 1], and each warp performs one mma.
+      numPtrs =
+          tileShape[order[0]] / (needTrans ? wpt : 1) / instrShape[order[0]];
+    } else {
+      numPtrs = tileShape[order[0]] / wpt / matShape[order[0]];
+    }
+    numPtrs = std::max(numPtrs, 2);
+
+    // Special rule for i8/u8: 4 ptrs for each matrix.
+    if (!canUseLdmatrix && elemBytes == 1)
+      numPtrs *= 4;
+
+    int loadStrideInMat[2];
+    loadStrideInMat[kOrder] =
+        2; // instrShape[kOrder] / matShape[kOrder], always 2
+    loadStrideInMat[kOrder ^ 1] =
+        wpt * (instrShape[kOrder ^ 1] / matShape[kOrder ^ 1]);
+
+    pLoadStrideInMat = loadStrideInMat[order[0]];
+    sMatStride =
+        loadStrideInMat[order[1]] / (instrShape[order[1]] / matShape[order[1]]);
+
+    // Each matArr contains warpOffStride matrices.
+    matArrStride = kOrder == 1 ? 1 : wpt;
+    warpOffStride = instrShape[kOrder ^ 1] / matShape[kOrder ^ 1];
+  }
+
+  // lane = thread % 32
+  // warpOff = (thread / 32) % wpt(0)
+  llvm::SmallVector<Value> computeOffsets(Value warpOff, Value lane,
+                                          Value cSwizzleOffset) {
+    if (canUseLdmatrix)
+      return computeLdmatrixMatOffs(warpOff, lane, cSwizzleOffset);
+    else if (elemBytes == 4 && needTrans)
+      return computeB32MatOffs(warpOff, lane, cSwizzleOffset);
+    else if (elemBytes == 1 && needTrans)
+      return computeB8MatOffs(warpOff, lane, cSwizzleOffset);
+    else
+      llvm::report_fatal_error("Invalid smem load config");
+
+    return {};
+  }
+
+  int getNumPtrs() const { return numPtrs; }
+
+  // Compute the offset of the matrix that this thread (indexed by warpOff
+  // and lane) maps to.
+  SmallVector<Value> computeLdmatrixMatOffs(Value warpId, Value lane,
+                                            Value cSwizzleOffset) {
+    // 4x4 matrices
+    Value c = urem(lane, i32_val(8));
+    Value s = udiv(lane, i32_val(8)); // sub-warp-id
+
+    // Decompose s => s_0, s_1, i.e. the coordinate in the 2x2 matrices of a
+    // warp.
+    Value s0 = urem(s, i32_val(2));
+    Value s1 = udiv(s, i32_val(2));
+
+    // We use different orders for a and b for better performance.
+    Value kMatArr = kOrder == 1 ? s1 : s0;
+    Value nkMatArr = kOrder == 1 ? s0 : s1;
+
+    // Matrix coordinates inside a CTA; the matrix layout is [2x2wpt] for A
+    // and [2wptx2] for B. e.g., with wpt=3, the data layout for A (kOrder=1)
+    // is
+    //   |0 0 1 1 2 2| -> 0,1,2 are the warp ids
+    //   |0 0 1 1 2 2|
+    //
+    // and for B (kOrder=0) it is
+    //   |0 0| -> 0,1,2 are the warp ids
+    //   |1 1|
+    //   |2 2|
+    //   |0 0|
+    //   |1 1|
+    //   |2 2|
+    // Note that each warp handles a 2x2 block of matrices; that block is
+    // what the coordinate (s0, s1) indexes.
+
+    Value matOff[2];
+    matOff[kOrder ^ 1] = add(
+        mul(warpId, i32_val(warpOffStride)),   // warp offset
+        mul(nkMatArr, i32_val(matArrStride))); // matrix offset inside a warp
+    matOff[kOrder] = kMatArr;
+
+    // Physical offset (before swizzling)
+    Value cMatOff = matOff[order[0]];
+    Value sMatOff = matOff[order[1]];
+    Value cSwizzleMatOff = udiv(cSwizzleOffset, i32_val(cMatShape));
+    cMatOff = add(cMatOff, cSwizzleMatOff);
+
+    // row offset inside a matrix, each matrix has 8 rows.
+    Value sOffInMat = c;
+
+    SmallVector<Value> offs(numPtrs);
+    Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase));
+    Value sOff = add(sOffInMat, mul(sMatOff, i32_val(sMatShape)));
+    for (int i = 0; i < numPtrs; ++i) {
+      Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat));
+      cMatOffI = xor_(cMatOffI, phase);
+      offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sStride));
+    }
+
+    return offs;
+  }
+
+  // Compute 32-bit matrix offsets.
+  SmallVector<Value> computeB32MatOffs(Value warpOff, Value lane,
+                                       Value cSwizzleOffset) {
+    assert(needTrans && "Only used in transpose mode.");
+    // Load tf32 matrices with lds32
+    Value cOffInMat = udiv(lane, i32_val(4));
+    Value sOffInMat = urem(lane, i32_val(4));
+
+    Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase));
+    SmallVector<Value> offs(numPtrs);
+
+    for (int mat = 0; mat < 4; ++mat) { // Load 4 mats each time
+      int kMatArrInt = kOrder == 1 ? mat / 2 : mat % 2;
+      int nkMatArrInt = kOrder == 1 ? mat % 2 : mat / 2;
+      if (kMatArrInt > 0) // we don't need pointers for k
+        continue;
+      Value kMatArr = i32_val(kMatArrInt);
+      Value nkMatArr = i32_val(nkMatArrInt);
+
+      Value cMatOff = add(mul(warpOff, i32_val(warpOffStride)),
+                          mul(nkMatArr, i32_val(matArrStride)));
+      Value cSwizzleMatOff = udiv(cSwizzleOffset, i32_val(cMatShape));
+      cMatOff = add(cMatOff, cSwizzleMatOff);
+
+      Value sMatOff = kMatArr;
+      Value sOff = add(sOffInMat, mul(sMatOff, i32_val(sMatShape)));
+      // FIXME: (kOrder == 1 ?) is a really dirty hack
+      for (int i = 0; i < numPtrs / 2; ++i) {
+        Value cMatOffI =
+            add(cMatOff, i32_val(i * pLoadStrideInMat * (kOrder == 1 ? 1 : 2)));
+        cMatOffI = xor_(cMatOffI, phase);
+        Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape)));
+        cOff = urem(cOff, i32_val(tileShape[order[0]]));
+        sOff = urem(sOff, i32_val(tileShape[order[1]]));
+        offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, sStride));
+      }
+    }
+    return offs;
+  }
+
+  // Compute 8-bit matrix offsets.
+  SmallVector<Value> computeB8MatOffs(Value warpOff, Value lane,
+                                      Value cSwizzleOffset) {
+    assert(needTrans && "Only used in transpose mode.");
+    Value cOffInMat = udiv(lane, i32_val(4));
+    Value sOffInMat =
+        mul(urem(lane, i32_val(4)), i32_val(4)); // each thread loads 4 cols
+
+    SmallVector<Value> offs(numPtrs);
+    for (int mat = 0; mat < 4; ++mat) {
+      int kMatArrInt = kOrder == 1 ? mat / 2 : mat % 2;
+      int nkMatArrInt = kOrder == 1 ? mat % 2 : mat / 2;
+      if (kMatArrInt > 0) // we don't need pointers for k
+        continue;
+      Value kMatArr = i32_val(kMatArrInt);
+      Value nkMatArr = i32_val(nkMatArrInt);
+
+      Value cMatOff = add(mul(warpOff, i32_val(warpOffStride)),
+                          mul(nkMatArr, i32_val(matArrStride)));
+      Value sMatOff = kMatArr;
+
+      for (int loadx4Off = 0; loadx4Off < numPtrs / 8; ++loadx4Off) {
+        for (int elemOff = 0; elemOff < 4; ++elemOff) {
+          int ptrOff = loadx4Off * 8 + nkMatArrInt * 4 + elemOff;
+          Value cMatOffI = add(cMatOff, i32_val(loadx4Off * pLoadStrideInMat *
+                                                (kOrder == 1 ? 1 : 2)));
+          Value sOffInMatElem = add(sOffInMat, i32_val(elemOff));
+
+          // disable swizzling ...
+
+          Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape)));
+          Value sOff = add(sOffInMatElem, mul(sMatOff, i32_val(sMatShape)));
+          // To prevent out-of-bound access when the tile is too small.
+          cOff = urem(cOff, i32_val(tileShape[order[0]]));
+          sOff = urem(sOff, i32_val(tileShape[order[1]]));
+          offs[ptrOff] = add(cOff, mul(sOff, sStride));
+        }
+      }
+    }
+    return offs;
+  }
+
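All three `compute*MatOffs` routines share the same swizzle idiom: a phase is derived from the row (`phase = (row / perPhase) % maxPhase`) and XORed into the column index, so consecutive rows land in different shared-memory banks. A standalone sketch of just that pattern, with illustrative `perPhase`/`maxPhase` values (not taken from any particular layout in the patch):

```cpp
#include <cstdio>

// Column swizzle as used above: XOR the column with a row-derived phase.
int swizzledCol(int col, int row, int perPhase, int maxPhase) {
  int phase = (row / perPhase) % maxPhase;
  return col ^ phase;
}

int main() {
  const int perPhase = 1, maxPhase = 8; // example swizzle parameters
  for (int row = 0; row < 4; ++row) {
    for (int col = 0; col < 4; ++col)
      std::printf("%d ", swizzledCol(col, row, perPhase, maxPhase));
    std::printf("\n"); // each row permutes columns differently
  }
}
```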
+  // Load 4 matrices and return 4 vec<2> elements.
+  std::tuple<Value, Value, Value, Value>
+  loadX4(int mat0, int mat1, ArrayRef<Value> offs, ArrayRef<Value> ptrs,
+         Type ldmatrixRetTy, Type shemPtrTy) const {
+    assert(mat0 % 2 == 0 && mat1 % 2 == 0 &&
+           "smem matrix load must be aligned");
+    int matIdx[2] = {mat0, mat1};
+
+    int ptrIdx{-1};
+
+    if (canUseLdmatrix)
+      ptrIdx = matIdx[order[0]] / (instrShape[order[0]] / matShape[order[0]]);
+    else if (elemBytes == 4 && needTrans)
+      ptrIdx = matIdx[order[0]];
+    else if (elemBytes == 1 && needTrans)
+      ptrIdx = matIdx[order[0]] * 4;
+    else
+      llvm::report_fatal_error("unsupported mma type found");
+
+    // The main difference from the original triton code is that we removed
+    // the prefetch-related logic here; the upstream optimizer phase should
+    // take care of it, and that is transparent to the dot conversion.
+    auto getPtr = [&](int idx) { return ptrs[idx]; };
+
+    Value ptr = getPtr(ptrIdx);
+
+    if (canUseLdmatrix) {
+      Value sOffset =
+          mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sStride);
+      Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset);
+
+      PTXBuilder builder;
+      // ldmatrix.m8n8.x4 returns 4x2xfp16 (that is, 4xb32) elements for a
+      // thread.
+      auto resArgs = builder.newListOperand(4, "=r");
+      auto addrArg = builder.newAddrOperand(sOffsetPtr, "r");
+
+      auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4")
+                          ->o("trans", needTrans /*predicate*/)
+                          .o("shared.b16");
+      ldmatrix(resArgs, addrArg);
+
+      // The result type is 4xi32; each i32 is composed of 2xf16 elements
+      // (two adjacent columns in a row).
+      Value resV4 = builder.launch(rewriter, loc, ldmatrixRetTy);
+
+      auto getIntAttr = [&](int v) {
+        return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
+      };
+
+      // The struct should have exactly the same element types.
+      Type elemType = resV4.getType().cast<LLVM::LLVMStructType>().getBody()[0];
+
+      return {extract_val(elemType, resV4, getIntAttr(0)),
+              extract_val(elemType, resV4, getIntAttr(1)),
+              extract_val(elemType, resV4, getIntAttr(2)),
+              extract_val(elemType, resV4, getIntAttr(3))};
+    } else if (elemBytes == 4 &&
+               needTrans) { // Use lds.32 to load tf32 matrices
+      Value ptr2 = getPtr(ptrIdx + 1);
+      assert(sMatStride == 1);
+      int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
+      Value sOffsetElemVal = mul(i32_val(sOffsetElem), sStride);
+      int sOffsetArrElem = sMatStride * sMatShape;
+      Value sOffsetArrElemVal =
+          add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sStride));
+
+      Value elems[4];
+      Type elemTy = type::f32Ty(ctx);
+      Type elemPtrTy = ptr_ty(elemTy);
+      if (kOrder == 1) {
+        elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
+        elems[1] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
+        elems[2] =
+            load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+        elems[3] =
+            load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+      } else {
+        elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
+        elems[2] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
+        elems[1] =
+            load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+        elems[3] =
+            load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+      }
+      return {elems[0], elems[1], elems[2], elems[3]};
+
+    } else if (elemBytes == 1 && needTrans) { // work with int8
+      std::array<std::array<Value, 4>, 2> ptrs;
+      ptrs[0] = {
+          getPtr(ptrIdx),
+          getPtr(ptrIdx + 1),
+          getPtr(ptrIdx + 2),
+          getPtr(ptrIdx + 3),
+      };
+
+      ptrs[1] = {
+          getPtr(ptrIdx + 4),
+          getPtr(ptrIdx + 5),
+          getPtr(ptrIdx + 6),
+          getPtr(ptrIdx + 7),
+      };
+
+      assert(sMatStride == 1);
+      int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape);
+      Value sOffsetElemVal = mul(i32_val(sOffsetElem), sStride);
+      int sOffsetArrElem = 1 * (sMatStride * sMatShape);
+      Value sOffsetArrElemVal =
+          add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sStride));
+
+      std::array<Value, 4> i8v4Elems;
+      std::array<Value, 4> i32Elems;
+      i8v4Elems.fill(
+          rewriter.create<LLVM::UndefOp>(loc, vec_ty(type::i8Ty(ctx), 4)));
+
+      Value i8Elems[4][4];
+      Type elemTy = type::i8Ty(ctx);
+      Type elemPtrTy = ptr_ty(elemTy);
+      Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4);
+      if (kOrder == 1) {
+        for (int i = 0; i < 2; ++i)
+          for (int j = 0; j < 4; ++j)
+            i8Elems[i][j] = load(gep(elemPtrTy, ptrs[i][j], sOffsetElemVal));
+
+        for (int i = 2; i < 4; ++i)
+          for (int j = 0; j < 4; ++j)
+            i8Elems[i][j] =
+                load(gep(elemPtrTy, ptrs[i - 2][j], sOffsetArrElemVal));
+
+        for (int m = 0; m < 4; ++m) {
+          for (int e = 0; e < 4; ++e)
+            i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
+                                          i8Elems[m][e], i32_val(e));
+          i32Elems[m] = bitcast(i8v4Elems[m], i8x4Ty);
+        }
+      } else { // k first
+        for (int j = 0; j < 4; ++j)
+          i8Elems[0][j] = load(gep(elemPtrTy, ptrs[0][j], sOffsetElemVal));
+        for (int j = 0; j < 4; ++j)
+          i8Elems[2][j] = load(gep(elemPtrTy, ptrs[1][j], sOffsetElemVal));
+        for (int j = 0; j < 4; ++j)
+          i8Elems[1][j] = load(gep(elemPtrTy, ptrs[0][j], sOffsetArrElemVal));
+        for (int j = 0; j < 4; ++j)
+          i8Elems[3][j] = load(gep(elemPtrTy, ptrs[1][j], sOffsetArrElemVal));
+
+        for (int m = 0; m < 4; ++m) {
+          for (int e = 0; e < 4; ++e)
+            i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m],
+                                          i8Elems[m][e], i32_val(e));
+          i32Elems[m] = bitcast(i8v4Elems[m], i8x4Ty);
+        }
+      }
+
+      return {i32Elems[0], i32Elems[1], i32Elems[2], i32Elems[3]};
+    }
+
+    assert(false && "Invalid smem load");
+    return {Value{}, Value{}, Value{}, Value{}};
+  }
+
+private:
+  SmallVector<uint32_t> order;
+  int kOrder;
+  SmallVector<int64_t> tileShape;
+  SmallVector<int> instrShape;
+  SmallVector<int> matShape;
+  int perPhase;
+  int maxPhase;
+  int elemBytes;
+  ConversionPatternRewriter &rewriter;
+  const Location &loc;
+  MLIRContext *ctx{};
+
+  int cMatShape;
+  int sMatShape;
+
+  Value sStride;
+
+  bool needTrans;
+  bool canUseLdmatrix;
+
+  int numPtrs;
+
+  int pLoadStrideInMat;
+  int sMatStride;
+
+  int matArrStride;
+  int warpOffStride;
+};
+
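`loadX4` leans on the fragment layout of `ldmatrix.sync.aligned.m8n8.x4`: per the PTX warp-level matrix-load semantics, each of the four 8x8 b16 matrices is spread across the warp so that every lane ends up with one packed b16x2 register per matrix (four b32 registers total). A small sketch of that lane-to-coordinate mapping, as this reviewer understands the PTX spec (verify against the linked ISA docs before relying on it):

```cpp
#include <cstdio>

// For one 8x8 b16 matrix loaded by ldmatrix: lane L holds row L/4,
// columns 2*(L%4) and 2*(L%4)+1, packed into a single 32-bit register.
void fragmentCoords(int lane, int &row, int &col) {
  row = lane / 4;
  col = 2 * (lane % 4);
}

int main() {
  for (int lane : {0, 1, 4, 31}) {
    int r, c;
    fragmentCoords(lane, r, c);
    std::printf("lane %2d -> row %d, cols %d-%d\n", lane, r, c, c + 1);
  }
}
```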
+// This class helps to adapt the existing DotOpConversion to the latest
+// DotOpOperand layout design. It decouples the existing implementation into
+// two parts:
+// 1. loading the specific operand matrix (for $a, $b, $c) from smem
+// 2. passing the loaded values on and performing the mma codegen
+struct MMA16816ConversionHelper {
+  MmaEncodingAttr mmaLayout;
+  ArrayRef<unsigned> wpt;
+  SmallVector<unsigned> properWpt;
+
+  Value thread, lane, warp;
+
+  DotOpMmaV2ConversionHelper helper;
+  ConversionPatternRewriter &rewriter;
+  TypeConverter *typeConverter;
+  Location loc;
+  MLIRContext *ctx{};
+
+  using ValueTable = std::map<std::pair<int, int>, Value>;
+
+  // dotOperand: the type of either one of dotOp's operands.
+  MMA16816ConversionHelper(Type dotOperand, MmaEncodingAttr mmaLayout,
+                           Value thread, ConversionPatternRewriter &rewriter,
+                           TypeConverter *typeConverter, Location loc)
+      : mmaLayout(mmaLayout), thread(thread), helper(mmaLayout),
+        rewriter(rewriter), typeConverter(typeConverter), loc(loc),
+        ctx(mmaLayout.getContext()), wpt(mmaLayout.getWarpsPerCTA()) {
+    helper.deduceMmaType(dotOperand);
+
+    Value _32 = i32_val(32);
+    lane = urem(thread, _32);
+    warp = udiv(thread, _32);
+  }
+
+  // Get a warpId along the M axis.
+  Value getWarpM(int M) const {
+    auto matShape = helper.getMmaMatShape();
+    return urem(urem(warp, i32_val(wpt[0])), i32_val(M / matShape[0]));
+  }
+
+  // Get a warpId along the N axis.
+  Value getWarpN(int N) const {
+    auto matShape = helper.getMmaMatShape();
+    Value warpMN = udiv(warp, i32_val(wpt[0]));
+    return urem(urem(warpMN, i32_val(wpt[1])), i32_val(N / matShape[1]));
+  }
+
+  // Get the mmaInstrShape, deducing it from either $a or $b.
+  std::tuple<int, int, int> getMmaInstrShape(Type operand) const {
+    helper.deduceMmaType(operand);
+    auto mmaInstrShape = helper.getMmaInstrShape();
+    int mmaInstrM = mmaInstrShape[0];
+    int mmaInstrN = mmaInstrShape[1];
+    int mmaInstrK = mmaInstrShape[2];
+    return std::make_tuple(mmaInstrM, mmaInstrN, mmaInstrK);
+  }
+
+  // Get the mmaMatShape, deducing it from either $a or $b.
+  std::tuple<int, int, int> getMmaMatShape(Type operand) const {
+    helper.deduceMmaType(operand);
+    auto matShape = helper.getMmaMatShape();
+    int matShapeM = matShape[0];
+    int matShapeN = matShape[1];
+    int matShapeK = matShape[2];
+    return std::make_tuple(matShapeM, matShapeN, matShapeK);
+  }
+
+  // \param operand is either $a's or $b's type.
+  inline int getNumRepM(Type operand, int M) const {
+    return getNumRepM(operand, M, wpt[0]);
+  }
+
+  // \param operand is either $a's or $b's type.
+  inline int getNumRepN(Type operand, int N) const {
+    return getNumRepN(operand, N, wpt[1]);
+  }
+
+  // \param operand is either $a's or $b's type.
+  inline int getNumRepK(Type operand, int K) const {
+    return getNumRepK_(operand, K);
+  }
+
+  static int getNumRepM(Type operand, int M, int wpt) {
+    auto tensorCoreType =
+        DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand);
+    int mmaInstrM =
+        DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[0];
+    return std::max(M / (wpt * mmaInstrM), 1);
+  }
+
+  static int getNumRepN(Type operand, int N, int wpt) {
+    auto tensorCoreType =
+        DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand);
+    int mmaInstrN =
+        DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[1];
+    return std::max(N / (wpt * mmaInstrN), 1);
+  }
+
+  static int getNumRepK_(Type operand, int K) {
+    auto tensorCoreType =
+        DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand);
+    int mmaInstrK =
+        DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[2];
+    return std::max(K / mmaInstrK, 1);
+  }
+
+  // Get the number of elements per thread for the $a operand.
+  static size_t getANumElemsPerThread(RankedTensorType operand, int wpt) {
+    auto shape = operand.getShape();
+    int repM = getNumRepM(operand, shape[0], wpt);
+    int repK = getNumRepK_(operand, shape[1]);
+    return 4 * repM * repK;
+  }
+
+  // Get the number of elements per thread for the $b operand.
+  static size_t getBNumElemsPerThread(RankedTensorType operand, int wpt) {
+    auto shape = operand.getShape();
+    int repK = getNumRepK_(operand, shape[0]);
+    int repN = getNumRepN(operand, shape[1], wpt);
+    return 4 * std::max(repN / 2, 1) * repK;
+  }
+
+  // Load $a from smem, returning an LLVM::Struct.
+  Value loadA(Value tensor, const SharedMemoryObject &smemObj) const {
+    auto aTensorTy = tensor.getType().cast<RankedTensorType>();
+    auto layout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
+
+    SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
+                               aTensorTy.getShape().end());
+    // TODO[Superjomn]: transA cannot be accessed in ConvertLayoutOp.
+    bool transA = false;
+    if (transA) {
+      std::swap(shape[0], shape[1]);
+    }
+
+    ValueTable ha;
+    std::function<void(int, int)> loadFn;
+    auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy);
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy);
+
+    int numRepM = getNumRepM(aTensorTy, shape[0]);
+    int numRepK = getNumRepK(aTensorTy, shape[1]);
+
+    if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
+      Value warpM = getWarpM(shape[0]);
+      // load from smem
+      loadFn = getLoadMatrixFn(
+          tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/,
+          1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/,
+          {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/,
+          true /*isA*/);
+    } else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
+      // load from registers, used in gemm fuse
+      // TODO(Superjomn) Port the logic.
+      assert(false && "Loading A from register is not supported yet.");
+    } else {
+      assert(false && "A's layout is not supported.");
+    }
+
+    // step 1. Perform the loading.
+    for (int m = 0; m < numRepM; ++m)
+      for (int k = 0; k < numRepK; ++k)
+        loadFn(2 * m, 2 * k);
+
+    // step 2. Pack the values into an LLVM::Struct to pass to the mma
+    // codegen.
+    return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
+  }
+
+  // Load $b from smem, returning an LLVM::Struct.
+  Value loadB(Value tensor, const SharedMemoryObject &smemObj) {
+    ValueTable hb;
+    auto tensorTy = tensor.getType().cast<RankedTensorType>();
+    auto layout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
+
+    SmallVector<int64_t> shape(tensorTy.getShape().begin(),
+                               tensorTy.getShape().end());
+
+    // TODO[Superjomn]: transB cannot be accessed in ConvertLayoutOp.
+    bool transB = false;
+    if (transB) {
+      std::swap(shape[0], shape[1]);
+    }
+
+    auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(tensorTy);
+    auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(tensorTy);
+    int numRepK = getNumRepK(tensorTy, shape[0]);
+    int numRepN = getNumRepN(tensorTy, shape[1]);
+
+    Value warpN = getWarpN(shape[1]);
+    auto loadFn = getLoadMatrixFn(
+        tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/,
+        0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/,
+        {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/,
+        false /*isA*/);
+
+    for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
+      for (int k = 0; k < numRepK; ++k)
+        loadFn(2 * n, 2 * k);
+    }
+
+    Value result = composeValuesToDotOperandLayoutStruct(
+        hb, std::max(numRepN / 2, 1), numRepK);
+    return result;
+  }
+
+  // Load $c into registers, returning a Value.
+  Value loadC(Value tensor, Value llTensor) const {
+    auto tensorTy = tensor.getType().cast<RankedTensorType>();
+    auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(tensorTy);
+    size_t fcSize = 4 * repM * repN;
+
+    assert(tensorTy.getEncoding().isa<MmaEncodingAttr>() &&
+           "Currently, we only support $c with a mma layout.");
+    // Load a normal C tensor with mma layout; that should be an
+    // LLVM::struct with fcSize elements.
+    auto structTy = llTensor.getType().cast<LLVM::LLVMStructType>();
+    assert(structTy.getBody().size() == fcSize &&
+           "DotOp's $c operand should pass the same number of values as $d in "
+           "mma layout.");
+    return llTensor;
+  }
+
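The repetition counts computed by `getNumRepM/N/K` determine how many `mma.sync` instructions each warp issues; `convertDot` below loops over all of them. A standalone trace of that arithmetic for the m16n8k16 (f16) shape, with a made-up 128x128x64 dot on a 2x2 warp grid:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  // Example problem: D[128x128] = A[128x64] * B[64x128], 2x2 warps per CTA.
  const int M = 128, N = 128, K = 64;
  const int wptM = 2, wptN = 2;
  const int instrM = 16, instrN = 8, instrK = 16; // mma.m16n8k16

  int numRepM = std::max(M / (wptM * instrM), 1); // 4
  int numRepN = std::max(N / (wptN * instrN), 1); // 8
  int numRepK = std::max(K / instrK, 1);          // 4
  // convertDot issues one mma.sync per (m, n, k) repetition.
  std::printf("mma.sync per warp: %d\n", numRepM * numRepN * numRepK); // 128
}
```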
+  // Conduct the Dot conversion.
+  // \param a, \param b, \param c and \param d are DotOp operands.
+  // \param loadedA, \param loadedB and \param loadedC are the results of
+  // loading.
+  LogicalResult convertDot(Value a, Value b, Value c, Value d, Value loadedA,
+                           Value loadedB, Value loadedC, DotOp op,
+                           DotOpAdaptor adaptor) const {
+    helper.deduceMmaType(op);
+
+    auto aTensorTy = a.getType().cast<RankedTensorType>();
+    auto dTensorTy = d.getType().cast<RankedTensorType>();
+
+    SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
+                                aTensorTy.getShape().end());
+    if (op.transA())
+      std::swap(aShape[0], aShape[1]);
+
+    auto dShape = dTensorTy.getShape();
+
+    // shape / shape_per_cta
+    int numRepM = getNumRepM(aTensorTy, dShape[0]);
+    int numRepN = getNumRepN(aTensorTy, dShape[1]);
+    int numRepK = getNumRepK(aTensorTy, aShape[1]);
+
+    ValueTable ha =
+        getValuesFromDotOperandLayoutStruct(loadedA, numRepM, numRepK);
+    ValueTable hb = getValuesFromDotOperandLayoutStruct(
+        loadedB, std::max(numRepN / 2, 1), numRepK);
+    auto fc = getElementsFromStruct(loc, loadedC, rewriter);
+
+    auto callMma = [&](unsigned m, unsigned n, unsigned k) {
+      unsigned colsPerThread = numRepN * 2;
+      PTXBuilder builder;
+      auto &mma = *builder.create(helper.getMmaInstr().str());
+      auto retArgs = builder.newListOperand(4, "=r");
+      auto aArgs = builder.newListOperand({
+          {ha[{m, k}], "r"},
+          {ha[{m + 1, k}], "r"},
+          {ha[{m, k + 1}], "r"},
+          {ha[{m + 1, k + 1}], "r"},
+      });
+      auto bArgs =
+          builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}});
+      auto cArgs = builder.newListOperand();
+      for (int i = 0; i < 4; ++i) {
+        cArgs->listAppend(builder.newOperand(fc[m * colsPerThread + 4 * n + i],
+                                             std::to_string(i)));
+        // reuse the output registers
+      }
+
+      mma(retArgs, aArgs, bArgs, cArgs);
+      Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType());
+
+      auto getIntAttr = [&](int v) {
+        return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)});
+      };
+
+      Type elemTy = mmaOut.getType().cast<LLVM::LLVMStructType>().getBody()[0];
+      for (int i = 0; i < 4; ++i)
+        fc[m * colsPerThread + 4 * n + i] =
+            extract_val(elemTy, mmaOut, getIntAttr(i));
+    };
+
+    for (int k = 0; k < numRepK; ++k)
+      for (int m = 0; m < numRepM; ++m)
+        for (int n = 0; n < numRepN; ++n)
+          callMma(2 * m, n, 2 * k);
+
+    Type resElemTy = dTensorTy.getElementType();
+
+    for (auto &elem : fc) {
+      elem = bitcast(elem, resElemTy);
+    }
+
+    // replace with the new packed result
+    Type structTy = LLVM::LLVMStructType::getLiteral(
+        ctx, SmallVector<Type>(fc.size(), resElemTy));
+    Value res = getStructFromElements(loc, fc, rewriter, structTy);
+    rewriter.replaceOp(op, res);
+
+    return success();
+  }
+
+private:
+  std::function<void(int, int)>
+  getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
+                  MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder,
+                  ArrayRef<int> instrShape, ArrayRef<int> matShape,
+                  Value warpId, ValueTable &vals, bool isA) const {
+    auto tensorTy = tensor.getType().cast<RankedTensorType>();
+    // We assume that the input operand of Dot comes from a shared layout.
+    // TODO(Superjomn) Consider other layouts if needed later.
+    auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
+    const int perPhase = sharedLayout.getPerPhase();
+    const int maxPhase = sharedLayout.getMaxPhase();
+    const int elemBytes = tensorTy.getElementTypeBitWidth() / 8;
+    auto order = sharedLayout.getOrder();
+
+    // the original register_lds2, but with the prefetch logic discarded.
+    auto ld2 = [](ValueTable &vals, int mn, int k, Value val) {
+      vals[{mn, k}] = val;
+    };
+
+    // (a, b) is the coordinate.
+    auto load = [=, &vals, &ld2](int a, int b) {
+      MMA16816SmemLoader loader(
+          wpt, sharedLayout.getOrder(), kOrder, smemObj.strides,
+          tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
+          maxPhase, elemBytes, rewriter, typeConverter, loc);
+      Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
+      SmallVector<Value> offs =
+          loader.computeOffsets(warpId, lane, cSwizzleOffset);
+      const int numPtrs = loader.getNumPtrs();
+      SmallVector<Value> ptrs(numPtrs);
+
+      Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
+      Type smemPtrTy = helper.getShemPtrTy();
+      for (int i = 0; i < numPtrs; ++i) {
+        ptrs[i] =
+            bitcast(gep(smemPtrTy, smemBase, ValueRange({offs[i]})), smemPtrTy);
+      }
+
+      auto [ha0, ha1, ha2, ha3] = loader.loadX4(
+          (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
+          ptrs, helper.getMatType(), helper.getShemPtrTy());
+
+      if (isA) {
+        ld2(vals, a, b, ha0);
+        ld2(vals, a + 1, b, ha1);
+        ld2(vals, a, b + 1, ha2);
+        ld2(vals, a + 1, b + 1, ha3);
+      } else {
+        ld2(vals, a, b, ha0);
+        ld2(vals, a + 1, b, ha2);
+        ld2(vals, a, b + 1, ha1);
+        ld2(vals, a + 1, b + 1, ha3);
+      }
+    };
+
+    return load;
+  }
+
+  // Compose a map of Values into an LLVM::Struct.
+  // The layout is a list of Values with coordinates (i, j), in the following
+  // order:
+  //   [
+  //    (0,0), (0,1), (1,0), (1,1), # i=0, j=0
+  //    (0,2), (0,3), (1,2), (1,3), # i=0, j=1
+  //    (0,4), (0,5), (1,4), (1,5), # i=0, j=2
+  //    ...
+  //    (2,0), (2,1), (3,0), (3,1), # i=1, j=0
+  //    (2,2), (2,3), (3,2), (3,3), # i=1, j=1
+  //    (2,4), (2,5), (3,4), (3,5), # i=1, j=2
+  //    ...
+  //   ]
+  // i \in [0, n0) and j \in [0, n1)
+  // There should be 4 * \param n0 * \param n1 elements in the output Struct.
+  Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0,
+                                              int n1) const {
+    std::vector<Value> elems;
+    for (int m = 0; m < n0; ++m)
+      for (int k = 0; k < n1; ++k) {
+        elems.push_back(vals.at({2 * m, 2 * k}));
+        elems.push_back(vals.at({2 * m, 2 * k + 1}));
+        elems.push_back(vals.at({2 * m + 1, 2 * k}));
+        elems.push_back(vals.at({2 * m + 1, 2 * k + 1}));
+      }
+
+    assert(!elems.empty());
+
+    Type elemTy = elems[0].getType();
+    Type structTy = LLVM::LLVMStructType::getLiteral(
+        ctx, SmallVector<Type>(elems.size(), elemTy));
+    auto result = getStructFromElements(loc, elems, rewriter, structTy);
+    return result;
+  }
+
+  ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0,
+                                                 int n1) const {
+    auto elems = getElementsFromStruct(loc, value, rewriter);
+
+    int offset{};
+    ValueTable vals;
+    for (int i = 0; i < n0; ++i) {
+      for (int j = 0; j < n1; j++) {
+        vals[{2 * i, 2 * j}] = elems[offset++];
+        vals[{2 * i, 2 * j + 1}] = elems[offset++];
+        vals[{2 * i + 1, 2 * j}] = elems[offset++];
+        vals[{2 * i + 1, 2 * j + 1}] = elems[offset++];
+      }
+    }
+    return vals;
+  }
+};
+
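The coordinate order fixed by `composeValuesToDotOperandLayoutStruct` (and mirrored by `getValuesFromDotOperandLayoutStruct`) can be enumerated directly; this sketch just prints the flat-index-to-coordinate mapping for small example repetition counts:

```cpp
#include <cstdio>

// Enumerate the documented order: for each (m, k) repetition, the four
// values (2m,2k), (2m,2k+1), (2m+1,2k), (2m+1,2k+1) are emitted in sequence.
int main() {
  const int n0 = 2, n1 = 2; // example repetition counts
  int flat = 0;
  for (int m = 0; m < n0; ++m)
    for (int k = 0; k < n1; ++k) {
      std::printf("%2d: (%d,%d)\n", flat++, 2 * m, 2 * k);
      std::printf("%2d: (%d,%d)\n", flat++, 2 * m, 2 * k + 1);
      std::printf("%2d: (%d,%d)\n", flat++, 2 * m + 1, 2 * k);
      std::printf("%2d: (%d,%d)\n", flat++, 2 * m + 1, 2 * k + 1);
    }
}
```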
+// Helper for the conversion of an FMA DotOp.
+struct DotOpFMAConversionHelper {
+  Attribute layout;
+  MLIRContext *ctx{};
+
+  using ValueTable = std::map<std::pair<int, int>, Value>;
+
+  explicit DotOpFMAConversionHelper(Attribute layout)
+      : layout(layout), ctx(layout.getContext()) {}
+
+  SmallVector<Value> getThreadIds(Value threadId,
+                                  ArrayRef<unsigned> shapePerCTA,
+                                  ArrayRef<unsigned> order,
+                                  ConversionPatternRewriter &rewriter,
+                                  Location loc) const;
+
+  Value loadA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
+              Location loc, ConversionPatternRewriter &rewriter) const;
+
+  Value loadB(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
+              Location loc, ConversionPatternRewriter &rewriter) const;
+
+  ValueTable getValueTableFromStruct(Value val, int K, int n0, int shapePerCTA,
+                                     int sizePerThread,
+                                     ConversionPatternRewriter &rewriter,
+                                     Location loc) const;
+
+  Value getStructFromValueTable(const ValueTable &vals,
+                                ConversionPatternRewriter &rewriter,
+                                Location loc) const {
+    SmallVector<Type> elemTypes(vals.size(), f32_ty);
+    SmallVector<Value> elems;
+    elems.reserve(vals.size());
+    for (auto &item : vals) {
+      elems.push_back(item.second);
+    }
+
+    Type structTy = struct_ty(elemTypes);
+    return getStructFromElements(loc, elems, rewriter, structTy);
+  }
+
+  // Get the number of elements per thread for $a or $b.
+  static int getNumElemsPerThread(ArrayRef<int64_t> shape,
+                                  DotOperandEncodingAttr dotOpLayout) {
+    auto blockedLayout = dotOpLayout.getParent().cast<BlockedEncodingAttr>();
+    auto shapePerCTA = getShapePerCTA(blockedLayout);
+    auto sizePerThread = getSizePerThread(blockedLayout);
+    auto order = blockedLayout.getOrder();
+
+    // TODO[Superjomn]: we assume the k axis is fixed for $a and $b here;
+    // fix this if that does not hold.
+    int K = dotOpLayout.getOpIdx() == 0 ? shape[1] : shape[0];
+    int otherDim = dotOpLayout.getOpIdx() == 1 ? shape[1] : shape[0];
+
+    bool isM = dotOpLayout.getOpIdx() == 0;
+    int shapePerCTAMN = getShapePerCTAForMN(blockedLayout, isM);
+    int sizePerThreadMN = getsizePerThreadForMN(blockedLayout, isM);
+    return K * std::max(otherDim / shapePerCTAMN, 1) * sizePerThreadMN;
+  }
+
+  // Get shapePerCTA for the M or N axis.
+  static int getShapePerCTAForMN(BlockedEncodingAttr layout, bool isM) {
+    auto order = layout.getOrder();
+    auto shapePerCTA = getShapePerCTA(layout);
+
+    int mShapePerCTA =
+        order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
+    int nShapePerCTA =
+        order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]];
+    return isM ? mShapePerCTA : nShapePerCTA;
+  }
+
+  // Get sizePerThread for the M or N axis.
+  static int getsizePerThreadForMN(BlockedEncodingAttr layout, bool isM) {
+    auto order = layout.getOrder();
+    auto sizePerThread = getSizePerThread(layout);
+
+    int mSizePerThread =
+        order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]];
+    int nSizePerThread =
+        order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];
+    return isM ? mSizePerThread : nSizePerThread;
+  }
+};
+
+Value DotOpMmaV1ConversionHelper::loadA(
+    Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
+    ConversionPatternRewriter &rewriter) const {
+
+  auto *ctx = rewriter.getContext();
+  auto tensorTy = tensor.getType().cast<RankedTensorType>();
+  auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
+  SmallVector<int64_t> shape(tensorTy.getShape().begin(),
+                             tensorTy.getShape().end());
+  SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
+                              sharedLayout.getOrder().end());
+
+  // TODO[Superjomn]: transA cannot be accessed in ConvertLayoutOp.
+  bool transA = false;
+  if (transA) {
+    std::swap(shape[0], shape[1]);
+    std::swap(order[0], order[1]);
+  }
+
+  Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
+
+  bool isARow = order[0] != 0;
+  bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16 bytes
+  int packSize0 = (isARow || isAVec4) ? 1 : 2;
+
+  SmallVector<int> fpw({2, 2, 1});
+  int repM = 2 * packSize0;
+  int repK = 1;
+  int spwM = fpw[0] * 4 * repM;
+  SmallVector<int> rep({repM, 0, repK}); // pad N with 0
+  SmallVector<int> spw({spwM, 0, 1});    // pad N with 0
+
+  int vecA = sharedLayout.getVec();
+
+  auto strides = smemObj.strides;
+  Value strideAM = isARow ? strides[0] : i32_val(1);
+  Value strideAK = isARow ? i32_val(1) : strides[1];
+  Value strideA0 = isARow ? strideAK : strideAM;
+  Value strideA1 = isARow ? strideAM : strideAK;
+
+  int strideRepM = wpt[0] * fpw[0] * 8;
+  int strideRepK = 1;
+
+  auto [offsetAM, offsetAK, _0, _1] =
+      computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc);
+
+  // swizzling
+  int perPhaseA = sharedLayout.getPerPhase();
+  int maxPhaseA = sharedLayout.getMaxPhase();
+  int stepA0 = isARow ? strideRepK : strideRepM;
+  int numPtrA = std::max(2 * perPhaseA * maxPhaseA / stepA0, 1);
+  int NK = shape[1];
+
+  // pre-compute pointer lanes
+  Value offA0 = isARow ? offsetAK : offsetAM;
+  Value offA1 = isARow ? offsetAM : offsetAK;
+  Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA));
+  offA0 = add(offA0, cSwizzleOffset);
+  SmallVector<Value> offA(numPtrA);
+  for (int i = 0; i < numPtrA; i++) {
+    Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM)));
+    offA0I = udiv(offA0I, i32_val(vecA));
+    offA0I = xor_(offA0I, phaseA);
+    offA0I = mul(offA0I, i32_val(vecA));
+    offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1));
+  }
+
+  Type f16x2Ty = vec_ty(f16_ty, 2);
+  // One thread gets 8 elements as the result.
+  Type retTy = LLVM::LLVMStructType::getLiteral(
+      ctx, SmallVector<Type>(8, type::f32Ty(ctx)));
+
+  // prepare arguments
+  SmallVector<Value> ptrA(numPtrA);
+
+  std::map<std::pair<int, int>, std::pair<Value, Value>> has;
+  auto smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
+  for (int i = 0; i < numPtrA; i++)
+    ptrA[i] = gep(ptr_ty(f16_ty), smem, offA[i]);
+
+  auto instrShape = getMmaInstrShape();
+  unsigned numM = std::max<int>(rep[0] * shape[0] / (spw[0] * wpt[0]), 1);
+
+  Type f16PtrTy = ptr_ty(f16_ty);
+
+  auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
+    vals[{m, k}] = {val0, val1};
+  };
+  auto loadA = [&](int m, int k) {
+    int offidx = (isARow ? k / 4 : m) % numPtrA;
+    Value thePtrA = gep(f16PtrTy, smem, offA[offidx]);
+
+    int stepAM = isARow ? m : m / numPtrA * numPtrA;
+    int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
+    Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM),
+                       mul(i32_val(stepAK), strideAK));
+    Value pa = gep(f16PtrTy, thePtrA, offset);
+    Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max(vecA / 2, 1)), 3);
+    Value ha = load(bitcast(pa, aPtrTy));
+    // record lds that needs to be moved
+    Value ha00 = bitcast(extract_element(ha, i32_val(0)), f16x2Ty);
+    Value ha01 = bitcast(extract_element(ha, i32_val(1)), f16x2Ty);
+    ld(has, m, k, ha00, ha01);
+
+    if (vecA > 4) {
+      Value ha10 = bitcast(extract_element(ha, i32_val(2)), f16x2Ty);
+      Value ha11 = bitcast(extract_element(ha, i32_val(3)), f16x2Ty);
+      if (isARow)
+        ld(has, m, k + 4, ha10, ha11);
+      else
+        ld(has, m + 1, k, ha10, ha11);
+    }
+  };
+
+  for (unsigned k = 0; k < NK; k += 4)
+    for (unsigned m = 0; m < numM / 2; ++m)
+      if (!has.count({m, k}))
+        loadA(m, k);
+
+  SmallVector<Value> elems;
+  elems.reserve(has.size() * 2);
+  auto vecTy = vec_ty(f16_ty, 2);
+  for (auto item : has) { // `has` is an ordered map, so the keys are ordered.
+    elems.push_back(item.second.first);
+    elems.push_back(item.second.second);
+  }
+
+  Type resTy = struct_ty(SmallVector<Type>(elems.size(), f16x2Ty));
+  Value res = getStructFromElements(loc, elems, rewriter, resTy);
+  return res;
+}
+
+Value DotOpMmaV1ConversionHelper::loadB(
+    Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
+    ConversionPatternRewriter &rewriter) const {
+  // smem
+  Value smem = smemObj.base;
+  auto strides = smemObj.strides;
+
+  auto *ctx = rewriter.getContext();
+  auto tensorTy = tensor.getType().cast<RankedTensorType>();
+  auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
+
+  SmallVector<int64_t> shape(tensorTy.getShape().begin(),
+                             tensorTy.getShape().end());
+  SmallVector<unsigned> order(sharedLayout.getOrder().begin(),
+                              sharedLayout.getOrder().end());
+
+  // TODO[Superjomn]: transB cannot be accessed in ConvertLayoutOp.
+  bool transB = false;
+
+  if (transB) {
+    std::swap(order[0], order[1]);
+    std::swap(shape[0], shape[1]);
+  }
+
+  bool isBRow = order[0] != 0;
+  bool isBVec4 = isBRow && shape[order[0]] <= 16;
+  int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
+  SmallVector<int> fpw({2, 2, 1});
+  SmallVector<int> rep({0, 2 * packSize1, 1});       // pad M with 0
+  SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
+  int vecB = sharedLayout.getVec();
+  Value strideBN = isBRow ? i32_val(1) : strides[1];
+  Value strideBK = isBRow ? strides[0] : i32_val(1);
+  Value strideB0 = isBRow ? strideBN : strideBK;
+  Value strideB1 = isBRow ? strideBK : strideBN;
+  int strideRepN = wpt[1] * fpw[1] * 8;
+  int strideRepK = 1;
+
+  // swizzling
+  int perPhaseA = sharedLayout.getPerPhase();
+  int maxPhaseA = sharedLayout.getMaxPhase();
+  int perPhaseB = sharedLayout.getPerPhase();
+  int maxPhaseB = sharedLayout.getMaxPhase();
+  int stepB0 = isBRow ? strideRepN : strideRepK;
+  int numPtrB = std::max(2 * perPhaseB * maxPhaseB / stepB0, 1);
+  int NK = shape[0];
+
+  auto [_0, _1, offsetBN, offsetBK] =
+      computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc);
+  if (transB)
+    std::swap(offsetBK, offsetBN);
+
+  Value offB0 = isBRow ? offsetBN : offsetBK;
+  Value offB1 = isBRow ? offsetBK : offsetBN;
+  Value phaseB = urem(udiv(offB1, i32_val(perPhaseB)), i32_val(maxPhaseB));
+  Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
+  offB0 = add(offB0, cSwizzleOffset);
+  SmallVector<Value> offB(numPtrB);
+  for (int i = 0; i < numPtrB; ++i) {
+    Value offB0I = add(offB0, i32_val(i * (isBRow ? strideRepN : 4)));
+    offB0I = udiv(offB0I, i32_val(vecB));
+    offB0I = xor_(offB0I, phaseB);
+    offB0I = mul(offB0I, i32_val(vecB));
+    offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1));
+  }
+
+  Type f16PtrTy = ptr_ty(f16_ty);
+  Type f16x2Ty = vec_ty(f16_ty, 2);
+
+  SmallVector<Value> ptrB(numPtrB);
+  ValueTable hbs;
+  for (int i = 0; i < numPtrB; ++i)
+    ptrB[i] = gep(ptr_ty(f16_ty), smem, offB[i]);
+
+  auto ld = [&](decltype(hbs) &vals, int m, int k, Value val0, Value val1) {
+    vals[{m, k}] = {val0, val1};
+  };
+
+  auto loadB = [&](int n, int K) {
+    int offidx = (isBRow ? n : K / 4) % numPtrB;
+    Value thePtrB = ptrB[offidx];
+
+    int stepBN = isBRow ? n / numPtrB * numPtrB : n;
+    int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
+    Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
+                       mul(i32_val(stepBK), strideBK));
+    Value pb = gep(f16PtrTy, thePtrB, offset);
+    Value hb =
+        load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
+    // record lds that needs to be moved
+    Value hb00 = bitcast(extract_element(hb, i32_val(0)), f16x2Ty);
+    Value hb01 = bitcast(extract_element(hb, i32_val(1)), f16x2Ty);
+    ld(hbs, n, K, hb00, hb01);
+    if (vecB > 4) {
+      Value hb10 = bitcast(extract_element(hb, i32_val(2)), f16x2Ty);
+      Value hb11 = bitcast(extract_element(hb, i32_val(3)), f16x2Ty);
+      if (isBRow)
+        ld(hbs, n + 1, K, hb10, hb11);
+      else
+        ld(hbs, n, K + 4, hb10, hb11);
+    }
+  };
+
+  unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
+  for (unsigned k = 0; k < NK; k += 4)
+    for (unsigned n = 0; n < numN / 2; ++n) {
+      if (!hbs.count({n, k}))
+        loadB(n, k);
+    }
+
+  SmallVector<Value> elems;
+  for (auto &item : hbs) { // `hbs` is an ordered map, so the keys are ordered.
+    elems.push_back(item.second.first);
+    elems.push_back(item.second.second);
+  }
+  Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
+  Type resTy = struct_ty(SmallVector<Type>(elems.size(), fp16x2Ty));
+  Value res = getStructFromElements(loc, elems, rewriter, resTy);
+  return res;
+}
+
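Both V1 loaders fetch a `vec<i32>` from shared memory and re-slice each i32 into an fp16x2 pair via `bitcast`/`extract_element`. A plain-C++ analogue of that packing, using `uint16_t` as the f16 bit pattern (little-endian byte order assumed, as on the GPUs this code targets):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // f16 bit patterns for 1.0, 2.0, 3.0, 4.0 packed as the loader sees them.
  uint16_t f16Bits[4] = {0x3C00, 0x4000, 0x4200, 0x4400};
  uint32_t packed[2];
  std::memcpy(packed, f16Bits, sizeof(packed)); // one vec<2xi32> load
  // ha00 <- packed[0], ha01 <- packed[1]; each holds two f16 lanes.
  for (int i = 0; i < 2; ++i)
    std::printf("reg %d: lo=0x%04x hi=0x%04x\n", i,
                packed[i] & 0xFFFF, packed[i] >> 16);
}
```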
+std::tuple<Value, Value, Value, Value>
+DotOpMmaV1ConversionHelper::computeOffsets(Value threadId, bool isARow,
+                                           bool isBRow, ArrayRef<int> fpw,
+                                           ArrayRef<int> spw,
+                                           ArrayRef<int> rep,
+                                           ConversionPatternRewriter &rewriter,
+                                           Location loc) const {
+  auto *ctx = rewriter.getContext();
+  Value _1 = i32_val(1);
+  Value _3 = i32_val(3);
+  Value _4 = i32_val(4);
+  Value _16 = i32_val(16);
+  Value _32 = i32_val(32);
+
+  Value lane = urem(threadId, _32);
+  Value warp = udiv(threadId, _32);
+
+  // warp offset
+  Value warp0 = urem(warp, i32_val(wpt[0]));
+  Value warp12 = udiv(warp, i32_val(wpt[0]));
+  Value warp1 = urem(warp12, i32_val(wpt[1]));
+  Value warpMOff = mul(warp0, i32_val(spw[0]));
+  Value warpNOff = mul(warp1, i32_val(spw[1]));
+  // Quad offset
+  Value quadMOff = mul(udiv(and_(lane, _16), _4), i32_val(fpw[0]));
+  Value quadNOff = mul(udiv(and_(lane, _16), _4), i32_val(fpw[1]));
+  // Pair offset
+  Value pairMOff = udiv(urem(lane, _16), _4);
+  pairMOff = urem(pairMOff, i32_val(fpw[0]));
+  pairMOff = mul(pairMOff, _4);
+  Value pairNOff = udiv(urem(lane, _16), _4);
+  pairNOff = udiv(pairNOff, i32_val(fpw[0]));
+  pairNOff = urem(pairNOff, i32_val(fpw[1]));
+  pairNOff = mul(pairNOff, _4);
+  // scale
+  pairMOff = mul(pairMOff, i32_val(rep[0] / 2));
+  quadMOff = mul(quadMOff, i32_val(rep[0] / 2));
+  pairNOff = mul(pairNOff, i32_val(rep[1] / 2));
+  quadNOff = mul(quadNOff, i32_val(rep[1] / 2));
+  // Quad pair offset
+  Value laneMOff = add(pairMOff, quadMOff);
+  Value laneNOff = add(pairNOff, quadNOff);
+  // A offset
+  Value offsetAM = add(warpMOff, laneMOff);
+  Value offsetAK = and_(lane, _3);
+  // B offset
+  Value offsetBN = add(warpNOff, laneNOff);
+  Value offsetBK = and_(lane, _3);
+  // i indices
+  Value offsetCM = add(and_(lane, _1), offsetAM);
+  if (isARow) {
+    offsetAM = add(offsetAM, urem(threadId, _4));
+    offsetAK = i32_val(0);
+  }
+  if (!isBRow) {
+    offsetBN = add(offsetBN, urem(threadId, _4));
+    offsetBK = i32_val(0);
+  }
+
+  return std::make_tuple(offsetAM, offsetAK, offsetBN, offsetBK);
+}
+
+DotOpMmaV1ConversionHelper::ValueTable
+DotOpMmaV1ConversionHelper::extractLoadedOperand(
+    Value llStruct, int NK, ConversionPatternRewriter &rewriter) const {
+  ValueTable rcds;
+  SmallVector<Value> elems =
+      getElementsFromStruct(llStruct.getLoc(), llStruct, rewriter);
+
+  for (int k = 0, offset = 0, i = 0; k < NK && offset < elems.size();
+       k += 4, i++, offset += 2) {
+    rcds[{i, k}] = std::make_pair(elems[offset], elems[offset + 1]);
+  }
+
+  return rcds;
+}
+
+Value DotOpFMAConversionHelper::loadA(
+    Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, Location loc,
+    ConversionPatternRewriter &rewriter) const {
+  auto aTensorTy = A.getType().cast<RankedTensorType>();
+  auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
+  auto aShape = aTensorTy.getShape();
+
+  auto aOrder = aLayout.getOrder();
+  auto order = dLayout.getOrder();
+
+  bool isARow = aOrder[0] == 1;
+
+  int strideAM = isARow ? aShape[1] : 1;
+  int strideAK = isARow ? 1 : aShape[0];
+  int strideA0 = isARow ? strideAK : strideAM;
+  int strideA1 = isARow ? strideAM : strideAK;
+  int lda = isARow ? strideAM : strideAK;
+  int aNumPtr = 8;
+  int bNumPtr = 8;
+  int NK = aShape[1];
+
+  auto shapePerCTA = getShapePerCTA(dLayout);
+  auto sizePerThread = getSizePerThread(dLayout);
+
+  Value _0 = i32_val(0);
+
+  Value mContig = i32_val(sizePerThread[order[1]]);
+  Value nContig = i32_val(sizePerThread[order[0]]);
+
+  // threadId in blocked layout
+  auto threadIds = getThreadIds(thread, shapePerCTA, order, rewriter, loc);
+
+  Value threadIdM = threadIds[0];
+  Value threadIdN = threadIds[1];
+
+  Value offA0 = isARow ? _0 : mul(threadIdM, mContig);
+  Value offA1 = isARow ? mul(threadIdM, mContig) : _0;
+  SmallVector<Value> aOff(aNumPtr);
+  for (int i = 0; i < aNumPtr; ++i) {
+    aOff[i] = add(mul(offA0, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
+  }
+
+  auto aSmem = getSharedMemoryObjectFromStruct(loc, llA, rewriter);
+
+  Type f32PtrTy = ptr_ty(f32_ty);
+  SmallVector<Value> aPtrs(aNumPtr);
+  for (int i = 0; i < aNumPtr; ++i)
+    aPtrs[i] = gep(f32PtrTy, aSmem.base, aOff[i]);
+
+  ValueTable has;
+  int M = aShape[aOrder[1]];
+
+  int mShapePerCTA = getShapePerCTAForMN(dLayout, true /*isM*/);
+  int mSizePerThread = getsizePerThreadForMN(dLayout, true /*isM*/);
+
+  for (unsigned k = 0; k < NK; ++k) {
+    for (unsigned m = 0; m < M; m += mShapePerCTA)
+      for (unsigned mm = 0; mm < mSizePerThread; ++mm)
+        if (!has.count({m + mm, k})) {
+          Value pa = gep(f32PtrTy, aPtrs[0],
+                         i32_val((m + mm) * strideAM + k * strideAK));
+          Value va = load(pa);
+          has[{m + mm, k}] = va;
+        }
+  }
+
+  return getStructFromValueTable(has, rewriter, loc);
+}
+
+Value DotOpFMAConversionHelper::loadB(
+    Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, Location loc,
+    ConversionPatternRewriter &rewriter) const {
+
+  auto bTensorTy = B.getType().cast<RankedTensorType>();
+  auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
+  auto bShape = bTensorTy.getShape();
+
+  auto bOrder = bLayout.getOrder();
+  auto order = dLayout.getOrder();
+
+  bool isBRow = bOrder[0] == 1;
+
+  int strideBN = isBRow ? 1 : bShape[0];
+  int strideBK = isBRow ? bShape[1] : 1;
+  int strideB0 = isBRow ? strideBN : strideBK;
+  int strideB1 = isBRow ? strideBK : strideBN;
+  int ldb = isBRow ? strideBK : strideBN;
+  int bNumPtr = 8;
+  int NK = bShape[0];
+
+  auto shapePerCTA = getShapePerCTA(dLayout);
+  auto sizePerThread = getSizePerThread(dLayout);
+
+  Value _0 = i32_val(0);
+
+  Value mContig = i32_val(sizePerThread[order[1]]);
+  Value nContig = i32_val(sizePerThread[order[0]]);
+
+  // threadId in blocked layout
+  auto threadIds = getThreadIds(thread, shapePerCTA, order, rewriter, loc);
+  Value threadIdN = threadIds[1];
+
+  Value offB0 = isBRow ? mul(threadIdN, nContig) : _0;
+  Value offB1 = isBRow ? _0 : mul(threadIdN, nContig);
+  SmallVector<Value> bOff(bNumPtr);
+  for (int i = 0; i < bNumPtr; ++i) {
+    bOff[i] = add(mul(offB0, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
+  }
+
+  auto bSmem = getSharedMemoryObjectFromStruct(loc, llB, rewriter);
+
+  Type f32PtrTy = ptr_ty(f32_ty);
+  SmallVector<Value> bPtrs(bNumPtr);
+  for (int i = 0; i < bNumPtr; ++i)
+    bPtrs[i] = gep(f32PtrTy, bSmem.base, bOff[i]);
+
+  int N = bShape[bOrder[0]];
+  ValueTable hbs;
+
+  int nShapePerCTA = getShapePerCTAForMN(dLayout, false /*isM*/);
+  int nSizePerThread = getsizePerThreadForMN(dLayout, false /*isM*/);
+
+  for (unsigned k = 0; k < NK; ++k)
+    for (unsigned n = 0; n < N; n += nShapePerCTA)
+      for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
+        Value pb = gep(f32PtrTy, bPtrs[0],
+                       i32_val((n + nn) * strideBN + k * strideBK));
+        Value vb = load(pb);
+        hbs[{n + nn, k}] = vb;
+      }
+
+  return getStructFromValueTable(hbs, rewriter, loc);
+}
+
+DotOpFMAConversionHelper::ValueTable
+DotOpFMAConversionHelper::getValueTableFromStruct(
+    Value val, int K, int n0, int shapePerCTA, int sizePerThread,
+    ConversionPatternRewriter &rewriter, Location loc) const {
+  ValueTable res;
+  auto elems = getElementsFromStruct(loc, val, rewriter);
+  int id = 0;
+  std::set<std::pair<int, int>> keys; // ordered
+  for (unsigned k = 0; k < K; ++k) {
+    for (unsigned m = 0; m < n0; m += shapePerCTA)
+      for (unsigned mm = 0; mm < sizePerThread; ++mm) {
+        keys.insert({m + mm, k});
+      }
+  }
+
+  for (auto kv : llvm::enumerate(keys)) {
+    res[kv.value()] = elems[kv.index()];
+  }
+
+  return res;
+}
+
+SmallVector<Value> DotOpFMAConversionHelper::getThreadIds(
+    Value threadId, ArrayRef<unsigned> shapePerCTA, ArrayRef<unsigned> order,
+    ConversionPatternRewriter &rewriter, Location loc) const {
+  int dim = order.size();
+  SmallVector<Value> threadIds(dim);
+  for (unsigned k = 0; k < dim - 1; k++) {
+    Value dimK = i32_val(shapePerCTA[order[k]]);
+    Value rem = urem(threadId, dimK);
+    threadId = udiv(threadId, dimK);
+    threadIds[order[k]] = rem;
+  }
+  Value dimK = i32_val(shapePerCTA[order[dim - 1]]);
+  threadIds[order[dim - 1]] = urem(threadId, dimK);
+  return threadIds;
+}
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif
+using ::mlir::LLVM::getElementsFromStruct; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::LLVM::getStridesFromShapeAndOrder; +using ::mlir::LLVM::getStructFromElements; +using ::mlir::LLVM::MMA16816ConversionHelper; +using ::mlir::LLVM::SharedMemoryObject; using ::mlir::triton::gpu::BlockedEncodingAttr; using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::getElemsPerThread; @@ -48,110 +59,10 @@ namespace LLVM { static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; } -namespace { - -// Create a 32-bit integer constant. -Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) { - auto i32ty = rewriter.getIntegerType(32); - return rewriter.create(loc, i32ty, - IntegerAttr::get(i32ty, v)); -} - -Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) { - auto type = type::f32Ty(rewriter.getContext()); - return rewriter.create(loc, type, - rewriter.getF32FloatAttr(v)); -} - -Value createConstantF64(Location loc, PatternRewriter &rewriter, float v) { - auto type = type::f64Ty(rewriter.getContext()); - return rewriter.create(loc, type, - rewriter.getF64FloatAttr(v)); -} - -// Create an index type constant. -Value createIndexConstant(OpBuilder &builder, Location loc, - TypeConverter *converter, int64_t value) { - Type ty = converter->convertType(builder.getIndexType()); - return builder.create(loc, ty, - builder.getIntegerAttr(ty, value)); -} - -// Create an integer constant of \param width bits. -Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, - int64_t value) { - Type ty = builder.getIntegerType(width); - return builder.create(loc, ty, - builder.getIntegerAttr(ty, value)); -} - -} // namespace - // A helper function for using printf in LLVM conversion. void llPrintf(StringRef msg, ValueRange args, ConversionPatternRewriter &rewriter); -// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive -#define inttoptr(...) rewriter.create(loc, __VA_ARGS__) -#define ptrtoint(...) rewriter.create(loc, __VA_ARGS__) -#define zext(...) rewriter.create(loc, __VA_ARGS__) -#define udiv(...) rewriter.create(loc, __VA_ARGS__) -#define urem(...) rewriter.create(loc, __VA_ARGS__) -#define add(...) rewriter.create(loc, __VA_ARGS__) -#define sub(...) rewriter.create(loc, __VA_ARGS__) -#define fadd(...) rewriter.create(loc, __VA_ARGS__) -#define mul(...) rewriter.create(loc, __VA_ARGS__) -#define smax(...) rewriter.create(loc, __VA_ARGS__) -#define umax(...) rewriter.create(loc, __VA_ARGS__) -#define fmax(...) rewriter.create(loc, __VA_ARGS__) -#define smin(...) rewriter.create(loc, __VA_ARGS__) -#define umin(...) rewriter.create(loc, __VA_ARGS__) -#define fmin(...) rewriter.create(loc, __VA_ARGS__) -#define and_(...) rewriter.create(loc, __VA_ARGS__) -#define xor_(...) rewriter.create(loc, __VA_ARGS__) -#define bitcast(val__, type__) \ - rewriter.create(loc, type__, val__) -#define gep(...) rewriter.create(loc, __VA_ARGS__) -#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__) -#define insert_val(...) rewriter.create(loc, __VA_ARGS__) -#define extract_val(...) rewriter.create(loc, __VA_ARGS__) -#define insert_element(...) \ - rewriter.create(loc, __VA_ARGS__) -#define extract_element(...) \ - rewriter.create(loc, __VA_ARGS__) -#define load(...) rewriter.create(loc, __VA_ARGS__) -#define store(val, ptr) rewriter.create(loc, val, ptr) -#define icmp_eq(...) \ - rewriter.create(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__) -#define icmp_ne(...) 
\ - rewriter.create(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__) -#define icmp_slt(...) \ - rewriter.create(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__) -#define select(...) rewriter.create(loc, __VA_ARGS__) -#define address_of(...) rewriter.create(loc, __VA_ARGS__) -#define barrier() rewriter.create(loc) -#define undef(...) rewriter.create(loc, __VA_ARGS__) -#define i32_ty rewriter.getIntegerType(32) -#define ui32_ty rewriter.getIntegerType(32, false) -#define f16_ty rewriter.getF16Type() -#define bf16_ty rewriter.getBF16Type() -#define i8_ty rewriter.getIntegerType(8) -#define f32_ty rewriter.getF32Type() -#define f64_ty rewriter.getF64Type() -#define vec_ty(type, num) VectorType::get(num, type) -#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__) -#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__) -#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx) -#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__) - -// Creator for constant -#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__) -#define int_val(width, val) \ - LLVM::createLLVMIntegerConstant(rewriter, loc, width, val) -#define idx_val(...) \ - LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(), \ - __VA_ARGS__) - // Helper function #define tid_val() getThreadId(rewriter, loc) #define llprintf(fmt, ...) LLVM::llPrintf(fmt, {__VA_ARGS__}, rewriter) @@ -205,20 +116,6 @@ auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) { b.getContext(), b.getNamedAttr(LLVM::getStructAttrsAttrName(), attrs)); } -/// Helper function to get strides from a given shape and its order -auto getStridesFromShapeAndOrder(ArrayRef shape, - ArrayRef order, Location loc, - ConversionPatternRewriter &rewriter) { - auto rank = shape.size(); - SmallVector strides(rank); - auto stride = 1; - for (auto idx : order) { - strides[idx] = i32_val(stride); - stride *= shape[idx]; - } - return strides; -} - struct FuncOpConversionBase : public ConvertOpToLLVMPattern { protected: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -353,22 +250,6 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> { } }; -Value getStructFromElements(Location loc, ValueRange resultVals, - ConversionPatternRewriter &rewriter, - Type structType) { - if (!structType.isa()) { - return *resultVals.begin(); - } - - Value llvmStruct = rewriter.create(loc, structType); - for (const auto &v : llvm::enumerate(resultVals)) { - assert(v.value() && "can not insert null values"); - llvmStruct = insert_val(structType, llvmStruct, v.value(), - rewriter.getI64ArrayAttr(v.index())); - } - return llvmStruct; -} - // delinearize supposing order is [0, 1, .. , n] template SmallVector getMultiDimIndexImpl(T linearIndex, ArrayRef shape) { @@ -440,97 +321,7 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, return builder.launch(rewriter, loc, void_ty(ctx)); } -struct SharedMemoryObject { - Value base; // i32 ptr. The start address of the shared memory object. - // We need to store strides as Values but not integers because the - // extract_slice instruction can take a slice at artibary offsets. - // Take $a[16:32, 16:32] as an example, though we know the stride of $a[0] is - // 32, we need to let the instruction that uses $a to be aware of that. - // Otherwise, when we use $a, we only know that the shape of $a is 16x16. 
If - // we store strides into an attribute array of integers, the information - // cannot pass through block argument assignment because attributes are - // associated with operations but not Values. - // TODO(Keren): We may need to figure out a way to store strides as integers - // if we want to support more optimizations. - SmallVector - strides; // i32 int. The strides of the shared memory object. - SmallVector offsets; // i32 int. The offsets of the shared memory - // objects from the originally allocated object. - - SharedMemoryObject(Value base, ArrayRef strides, - ArrayRef offsets) - : base(base), strides(strides.begin(), strides.end()), - offsets(offsets.begin(), offsets.end()) {} - - SharedMemoryObject(Value base, ArrayRef shape, - ArrayRef order, Location loc, - ConversionPatternRewriter &rewriter) - : base(base) { - strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter); - - for (auto idx : order) { - offsets.emplace_back(i32_val(0)); - } - } - - SmallVector getElems() const { - SmallVector elems; - elems.push_back(base); - elems.append(strides.begin(), strides.end()); - elems.append(offsets.begin(), offsets.end()); - return elems; - } - - SmallVector getTypes() const { - SmallVector types; - types.push_back(base.getType()); - types.append(strides.size(), IntegerType::get(base.getContext(), 32)); - types.append(offsets.size(), IntegerType::get(base.getContext(), 32)); - return types; - } - - Value getCSwizzleOffset(int order) const { - assert(order >= 0 && order < strides.size()); - return offsets[order]; - } - - Value getBaseBeforeSwizzle(int order, Location loc, - ConversionPatternRewriter &rewriter) const { - Value cSwizzleOffset = getCSwizzleOffset(order); - Value offset = sub(i32_val(0), cSwizzleOffset); - Type type = base.getType(); - return gep(type, base, offset); - } -}; - struct ConvertTritonGPUOpToLLVMPatternBase { - static SmallVector - getElementsFromStruct(Location loc, Value llvmStruct, - ConversionPatternRewriter &rewriter) { - if (llvmStruct.getType().isIntOrIndexOrFloat() || - llvmStruct.getType().isa() || - llvmStruct.getType().isa()) - return {llvmStruct}; - ArrayRef types = - llvmStruct.getType().cast().getBody(); - SmallVector results(types.size()); - for (unsigned i = 0; i < types.size(); ++i) { - Type type = types[i]; - results[i] = extract_val(type, llvmStruct, rewriter.getI64ArrayAttr(i)); - } - return results; - } - - static SharedMemoryObject - getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, - ConversionPatternRewriter &rewriter) { - auto elems = getElementsFromStruct(loc, llvmStruct, rewriter); - auto rank = (elems.size() - 1) / 2; - return {/*base=*/elems[0], - /*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank}, - /*offsets=*/{elems.begin() + 1 + rank, elems.end()}}; - } - static Value getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, @@ -1927,7 +1718,7 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern { this->getTypeConverter()->convertType(resultTy.getElementType()); SmallVector types(elems, elemTy); Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types); - auto vals = this->getElementsFromStruct(loc, adaptor.src(), rewriter); + auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter); Value view = getStructFromElements(loc, vals, rewriter, structTy); rewriter.replaceOp(op, view); return success(); @@ -1945,7 +1736,7 @@ struct PrintfOpConversion auto loc = op->getLoc(); SmallVector operands; for (auto operand : adaptor.getOperands()) 
{ - auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter); + auto sub_operands = getElementsFromStruct(loc, operand, rewriter); for (auto elem : sub_operands) { operands.push_back(elem); } @@ -2683,7 +2474,7 @@ protected: const unsigned elems, Location loc) const { SmallVector> operands(elems); for (auto operand : adaptor.getOperands()) { - auto sub_operands = this->getElementsFromStruct(loc, operand, rewriter); + auto sub_operands = getElementsFromStruct(loc, operand, rewriter); for (size_t i = 0; i < elems; ++i) { operands[i].push_back(sub_operands[i]); } @@ -3332,422 +3123,7 @@ struct InsertSliceOpConversion /// ====================== dot codegen begin ========================== -// Data loader for mma.16816 instruction. -class MMA16816SmemLoader { -public: - MMA16816SmemLoader(int wpt, ArrayRef order, uint32_t kOrder, - ArrayRef smemStrides, ArrayRef tileShape, - ArrayRef instrShape, ArrayRef matShape, - int perPhase, int maxPhase, int elemBytes, - ConversionPatternRewriter &rewriter, - TypeConverter *typeConverter, const Location &loc) - : order(order.begin(), order.end()), kOrder(kOrder), - tileShape(tileShape.begin(), tileShape.end()), - instrShape(instrShape.begin(), instrShape.end()), - matShape(matShape.begin(), matShape.end()), perPhase(perPhase), - maxPhase(maxPhase), elemBytes(elemBytes), rewriter(rewriter), loc(loc), - ctx(rewriter.getContext()) { - cMatShape = matShape[order[0]]; - sMatShape = matShape[order[1]]; - - sStride = smemStrides[order[1]]; - - // rule: k must be the fast-changing axis. - needTrans = kOrder != order[0]; - canUseLdmatrix = elemBytes == 2 || (!needTrans); // b16 - - if (canUseLdmatrix) { - // Each CTA, the warps is arranged as [1xwpt] if not transposed, - // otherwise [wptx1], and each warp will perform a mma. - numPtrs = - tileShape[order[0]] / (needTrans ? wpt : 1) / instrShape[order[0]]; - } else { - numPtrs = tileShape[order[0]] / wpt / matShape[order[0]]; - } - numPtrs = std::max(numPtrs, 2); - - // Special rule for i8/u8, 4 ptrs for each matrix - if (!canUseLdmatrix && elemBytes == 1) - numPtrs *= 4; - - int loadStrideInMat[2]; - loadStrideInMat[kOrder] = - 2; // instrShape[kOrder] / matShape[kOrder], always 2 - loadStrideInMat[kOrder ^ 1] = - wpt * (instrShape[kOrder ^ 1] / matShape[kOrder ^ 1]); - - pLoadStrideInMat = loadStrideInMat[order[0]]; - sMatStride = - loadStrideInMat[order[1]] / (instrShape[order[1]] / matShape[order[1]]); - - // Each matArr contains warpOffStride matrices. - matArrStride = kOrder == 1 ? 1 : wpt; - warpOffStride = instrShape[kOrder ^ 1] / matShape[kOrder ^ 1]; - } - - // lane = thread % 32 - // warpOff = (thread/32) % wpt(0) - llvm::SmallVector computeOffsets(Value warpOff, Value lane, - Value cSwizzleOffset) { - if (canUseLdmatrix) - return computeLdmatrixMatOffs(warpOff, lane, cSwizzleOffset); - else if (elemBytes == 4 && needTrans) - return computeB32MatOffs(warpOff, lane, cSwizzleOffset); - else if (elemBytes == 1 && needTrans) - return computeB8MatOffs(warpOff, lane, cSwizzleOffset); - else - llvm::report_fatal_error("Invalid smem load config"); - - return {}; - } - - int getNumPtrs() const { return numPtrs; } - - // Compute the offset to the matrix this thread(indexed by warpOff and lane) - // mapped to. 
- SmallVector computeLdmatrixMatOffs(Value warpId, Value lane, - Value cSwizzleOffset) { - // 4x4 matrices - Value c = urem(lane, i32_val(8)); - Value s = udiv(lane, i32_val(8)); // sub-warp-id - - // Decompose s => s_0, s_1, that is the coordinate in 2x2 matrices in a - // warp - Value s0 = urem(s, i32_val(2)); - Value s1 = udiv(s, i32_val(2)); - - // We use different orders for a and b for better performance. - Value kMatArr = kOrder == 1 ? s1 : s0; - Value nkMatArr = kOrder == 1 ? s0 : s1; - - // matrix coordinate inside a CTA, the matrix layout is [2x2wpt] for A and - // [2wptx2] for B. e.g. Setting wpt=3, The data layout for A(kOrder=1) is - // |0 0 1 1 2 2| -> 0,1,2 are the warpids - // |0 0 1 1 2 2| - // - // for B(kOrder=0) is - // |0 0| -> 0,1,2 are the warpids - // |1 1| - // |2 2| - // |0 0| - // |1 1| - // |2 2| - // Note, for each warp, it handles a 2x2 matrices, that is the coordinate - // address (s0,s1) annotates. - - Value matOff[2]; - matOff[kOrder ^ 1] = add( - mul(warpId, i32_val(warpOffStride)), // warp offset - mul(nkMatArr, i32_val(matArrStride))); // matrix offset inside a warp - matOff[kOrder] = kMatArr; - - // Physical offset (before swizzling) - Value cMatOff = matOff[order[0]]; - Value sMatOff = matOff[order[1]]; - Value cSwizzleMatOff = udiv(cSwizzleOffset, i32_val(cMatShape)); - cMatOff = add(cMatOff, cSwizzleMatOff); - - // row offset inside a matrix, each matrix has 8 rows. - Value sOffInMat = c; - - SmallVector offs(numPtrs); - Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase)); - Value sOff = add(sOffInMat, mul(sMatOff, i32_val(sMatShape))); - for (int i = 0; i < numPtrs; ++i) { - Value cMatOffI = add(cMatOff, i32_val(i * pLoadStrideInMat)); - cMatOffI = xor_(cMatOffI, phase); - offs[i] = add(mul(cMatOffI, i32_val(cMatShape)), mul(sOff, sStride)); - } - - return offs; - } - - // Compute 32-bit matrix offsets. - SmallVector computeB32MatOffs(Value warpOff, Value lane, - Value cSwizzleOffset) { - assert(needTrans && "Only used in transpose mode."); - // Load tf32 matrices with lds32 - Value cOffInMat = udiv(lane, i32_val(4)); - Value sOffInMat = urem(lane, i32_val(4)); - - Value phase = urem(udiv(sOffInMat, i32_val(perPhase)), i32_val(maxPhase)); - SmallVector offs(numPtrs); - - for (int mat = 0; mat < 4; ++mat) { // Load 4 mats each time - int kMatArrInt = kOrder == 1 ? mat / 2 : mat % 2; - int nkMatArrInt = kOrder == 1 ? mat % 2 : mat / 2; - if (kMatArrInt > 0) // we don't need pointers for k - continue; - Value kMatArr = i32_val(kMatArrInt); - Value nkMatArr = i32_val(nkMatArrInt); - - Value cMatOff = add(mul(warpOff, i32_val(warpOffStride)), - mul(nkMatArr, i32_val(matArrStride))); - Value cSwizzleMatOff = udiv(cSwizzleOffset, i32_val(cMatShape)); - cMatOff = add(cMatOff, cSwizzleMatOff); - - Value sMatOff = kMatArr; - Value sOff = add(sOffInMat, mul(sMatOff, i32_val(sMatShape))); - // FIXME: (kOrder == 1?) is really dirty hack - for (int i = 0; i < numPtrs / 2; ++i) { - Value cMatOffI = - add(cMatOff, i32_val(i * pLoadStrideInMat * (kOrder == 1 ? 1 : 2))); - cMatOffI = xor_(cMatOffI, phase); - Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape))); - cOff = urem(cOff, i32_val(tileShape[order[0]])); - sOff = urem(sOff, i32_val(tileShape[order[1]])); - offs[2 * i + nkMatArrInt] = add(cOff, mul(sOff, sStride)); - } - } - return offs; - } - - // compute 8-bit matrix offset. 
- SmallVector computeB8MatOffs(Value warpOff, Value lane, - Value cSwizzleOffset) { - assert(needTrans && "Only used in transpose mode."); - Value cOffInMat = udiv(lane, i32_val(4)); - Value sOffInMat = - mul(urem(lane, i32_val(4)), i32_val(4)); // each thread load 4 cols - - SmallVector offs(numPtrs); - for (int mat = 0; mat < 4; ++mat) { - int kMatArrInt = kOrder == 1 ? mat / 2 : mat % 2; - int nkMatArrInt = kOrder == 1 ? mat % 2 : mat / 2; - if (kMatArrInt > 0) // we don't need pointers for k - continue; - Value kMatArr = i32_val(kMatArrInt); - Value nkMatArr = i32_val(nkMatArrInt); - - Value cMatOff = add(mul(warpOff, i32_val(warpOffStride)), - mul(nkMatArr, i32_val(matArrStride))); - Value sMatOff = kMatArr; - - for (int loadx4Off = 0; loadx4Off < numPtrs / 8; ++loadx4Off) { - for (int elemOff = 0; elemOff < 4; ++elemOff) { - int ptrOff = loadx4Off * 8 + nkMatArrInt * 4 + elemOff; - Value cMatOffI = add(cMatOff, i32_val(loadx4Off * pLoadStrideInMat * - (kOrder == 1 ? 1 : 2))); - Value sOffInMatElem = add(sOffInMat, i32_val(elemOff)); - - // disable swizzling ... - - Value cOff = add(cOffInMat, mul(cMatOffI, i32_val(cMatShape))); - Value sOff = add(sOffInMatElem, mul(sMatOff, i32_val(sMatShape))); - // To prevent out-of-bound access when tile is too small. - cOff = urem(cOff, i32_val(tileShape[order[0]])); - sOff = urem(sOff, i32_val(tileShape[order[1]])); - offs[ptrOff] = add(cOff, mul(sOff, sStride)); - } - } - } - return offs; - } - - // Load 4 matrices and returns 4 vec<2> elements. - std::tuple - loadX4(int mat0, int mat1, ArrayRef offs, ArrayRef ptrs, - Type ldmatrixRetTy, Type shemPtrTy) const { - assert(mat0 % 2 == 0 && mat1 % 2 == 0 && - "smem matrix load must be aligned"); - int matIdx[2] = {mat0, mat1}; - - int ptrIdx{-1}; - - if (canUseLdmatrix) - ptrIdx = matIdx[order[0]] / (instrShape[order[0]] / matShape[order[0]]); - else if (elemBytes == 4 && needTrans) - ptrIdx = matIdx[order[0]]; - else if (elemBytes == 1 && needTrans) - ptrIdx = matIdx[order[0]] * 4; - else - llvm::report_fatal_error("unsupported mma type found"); - - // The main difference with the original triton code is we removed the - // prefetch-related logic here for the upstream optimizer phase should - // take care with it, and that is transparent in dot conversion. - auto getPtr = [&](int idx) { return ptrs[idx]; }; - - Value ptr = getPtr(ptrIdx); - - if (canUseLdmatrix) { - Value sOffset = - mul(i32_val(matIdx[order[1]] * sMatStride * sMatShape), sStride); - Value sOffsetPtr = gep(shemPtrTy, ptr, sOffset); - - PTXBuilder builder; - // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a - // thread. - auto resArgs = builder.newListOperand(4, "=r"); - auto addrArg = builder.newAddrOperand(sOffsetPtr, "r"); - - auto ldmatrix = builder.create("ldmatrix.sync.aligned.m8n8.x4") - ->o("trans", needTrans /*predicate*/) - .o("shared.b16"); - ldmatrix(resArgs, addrArg); - - // The result type is 4xi32, each i32 is composed of 2xf16 - // elements(adjacent two columns in a row) - Value resV4 = builder.launch(rewriter, loc, ldmatrixRetTy); - - auto getIntAttr = [&](int v) { - return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)}); - }; - - // The struct should have exactly the same element types. 
- Type elemType = resV4.getType().cast().getBody()[0]; - - return {extract_val(elemType, resV4, getIntAttr(0)), - extract_val(elemType, resV4, getIntAttr(1)), - extract_val(elemType, resV4, getIntAttr(2)), - extract_val(elemType, resV4, getIntAttr(3))}; - } else if (elemBytes == 4 && - needTrans) { // Use lds.32 to load tf32 matrices - Value ptr2 = getPtr(ptrIdx + 1); - assert(sMatStride == 1); - int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape); - Value sOffsetElemVal = mul(i32_val(sOffsetElem), sStride); - int sOffsetArrElem = sMatStride * sMatShape; - Value sOffsetArrElemVal = - add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sStride)); - - Value elems[4]; - Type elemTy = type::f32Ty(ctx); - Type elemPtrTy = ptr_ty(elemTy); - if (kOrder == 1) { - elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem))); - elems[1] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem))); - elems[2] = - load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem))); - elems[3] = - load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem))); - } else { - elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem))); - elems[2] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem))); - elems[1] = - load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem))); - elems[3] = - load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem))); - } - return {elems[0], elems[1], elems[2], elems[3]}; - - } else if (elemBytes == 1 && needTrans) { // work with int8 - std::array, 2> ptrs; - ptrs[0] = { - getPtr(ptrIdx), - getPtr(ptrIdx + 1), - getPtr(ptrIdx + 2), - getPtr(ptrIdx + 3), - }; - - ptrs[1] = { - getPtr(ptrIdx + 4), - getPtr(ptrIdx + 5), - getPtr(ptrIdx + 6), - getPtr(ptrIdx + 7), - }; - - assert(sMatStride == 1); - int sOffsetElem = matIdx[order[1]] * (sMatStride * sMatShape); - Value sOffsetElemVal = mul(i32_val(sOffsetElem), sStride); - int sOffsetArrElem = 1 * (sMatStride * sMatShape); - Value sOffsetArrElemVal = - add(sOffsetElemVal, mul(i32_val(sOffsetArrElem), sStride)); - - std::array i8v4Elems; - std::array i32Elems; - i8v4Elems.fill( - rewriter.create(loc, vec_ty(type::i8Ty(ctx), 4))); - - Value i8Elems[4][4]; - Type elemTy = type::i8Ty(ctx); - Type elemPtrTy = ptr_ty(elemTy); - Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4); - if (kOrder == 1) { - for (int i = 0; i < 2; ++i) - for (int j = 0; j < 4; ++j) - i8Elems[i][j] = load(gep(elemPtrTy, ptrs[i][j], sOffsetElemVal)); - - for (int i = 2; i < 4; ++i) - for (int j = 0; j < 4; ++j) - i8Elems[i][j] = - load(gep(elemPtrTy, ptrs[i - 2][j], sOffsetArrElemVal)); - - for (int m = 0; m < 4; ++m) { - for (int e = 0; e < 4; ++e) - i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m], - i8Elems[m][e], i32_val(e)); - i32Elems[m] = bitcast(i8v4Elems[m], i8x4Ty); - } - } else { // k first - for (int j = 0; j < 4; ++j) - i8Elems[0][j] = load(gep(elemPtrTy, ptrs[0][j], sOffsetElemVal)); - for (int j = 0; j < 4; ++j) - i8Elems[2][j] = load(gep(elemPtrTy, ptrs[1][j], sOffsetElemVal)); - for (int j = 0; j < 4; ++j) - i8Elems[1][j] = load(gep(elemPtrTy, ptrs[0][j], sOffsetArrElemVal)); - for (int j = 0; j < 4; ++j) - i8Elems[3][j] = load(gep(elemPtrTy, ptrs[1][j], sOffsetArrElemVal)); - - for (int m = 0; m < 4; ++m) { - for (int e = 0; e < 4; ++e) - i8v4Elems[m] = insert_element(i8v4Elems[m].getType(), i8v4Elems[m], - i8Elems[m][e], i32_val(e)); - i32Elems[m] = bitcast(i8v4Elems[m], i8x4Ty); - } - } - - return {i32Elems[0], i32Elems[1], i32Elems[2], i32Elems[3]}; - } - - assert(false && "Invalid smem load"); - return {Value{}, Value{}, 
Value{}, Value{}}; - } - -private: - SmallVector order; - int kOrder; - SmallVector tileShape; - SmallVector instrShape; - SmallVector matShape; - int perPhase; - int maxPhase; - int elemBytes; - ConversionPatternRewriter &rewriter; - const Location &loc; - MLIRContext *ctx{}; - - int cMatShape; - int sMatShape; - - Value sStride; - - bool needTrans; - bool canUseLdmatrix; - - int numPtrs; - - int pLoadStrideInMat; - int sMatStride; - - int matArrStride; - int warpOffStride; -}; - struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { - enum class TensorCoreType : uint8_t { - // floating-point tensor core instr - FP32_FP16_FP16_FP32 = 0, // default - FP32_BF16_BF16_FP32, - FP32_TF32_TF32_FP32, - // integer tensor core instr - INT32_INT1_INT1_INT32, // Not implemented - INT32_INT4_INT4_INT32, // Not implemented - INT32_INT8_INT8_INT32, // Not implemented - // - NOT_APPLICABLE, - }; - using ConvertTritonGPUOpToLLVMPattern< triton::DotOp>::ConvertTritonGPUOpToLLVMPattern; @@ -3850,862 +3226,6 @@ private: ConversionPatternRewriter &rewriter) const; }; -// Helper for conversion of DotOp with mma, that is sm<80 -struct DotOpMmaV1ConversionHelper { - MmaEncodingAttr mmaLayout; - ArrayRef wpt; - - using ValueTable = std::map, std::pair>; - - explicit DotOpMmaV1ConversionHelper(MmaEncodingAttr mmaLayout) - : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {} - - int getRepM(int M) const { - return std::max(M / (wpt[0] * instrShape[0]), 1); - } - int getRepN(int N) const { - return std::max(N / (wpt[1] * instrShape[1]), 1); - } - - static ArrayRef getMmaInstrShape() { return instrShape; } - - static Type getMmaRetType(TensorType operand) { - auto *ctx = operand.getContext(); - Type fp32Ty = type::f32Ty(ctx); - // f16*f16+f32->f32 - return struct_ty(SmallVector{8, fp32Ty}); - } - - // number of fp16x2 elements for $a. - int numElemsPerThreadA(RankedTensorType tensorTy) const { - auto shape = tensorTy.getShape(); - auto order = getOrder(); - - bool isARow = order[0] != 0; - bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes - int packSize0 = (isARow || isAVec4) ? 1 : 2; - - SmallVector fpw({2, 2, 1}); - int repM = 2 * packSize0; - int repK = 1; - int spwM = fpw[0] * 4 * repM; - SmallVector rep({repM, 0, repK}); // pad N with 0 - SmallVector spw({spwM, 0, 1}); // pad N with 0 - - int NK = shape[1]; - unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]); - - // NOTE: We couldn't get the vec from the shared layout. - // int vecA = sharedLayout.getVec(); - // TODO[Superjomn]: Consider the case when vecA > 4 - bool vecGt4 = false; - int elemsPerLd = vecGt4 ? 4 : 2; - return (numM / 2) * (NK / 4) * elemsPerLd; - } - - // number of fp16x2 elements for $b. - int numElemsPerThreadB(RankedTensorType tensorTy) const { - auto shape = tensorTy.getShape(); - auto order = getOrder(); - bool isBRow = order[0] != 0; - bool isBVec4 = isBRow && shape[order[0]] <= 16; - int packSize1 = (isBRow && !isBVec4) ? 2 : 1; - SmallVector fpw({2, 2, 1}); - SmallVector rep({0, 2 * packSize1, 1}); // pad M with 0 - SmallVector spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0 - // NOTE: We couldn't get the vec from the shared layout. - // int vecB = sharedLayout.getVec(); - // TODO[Superjomn]: Consider the case when vecA > 4 - bool vecGt4 = false; - int elemsPerLd = vecGt4 ? 4 : 2; - int NK = shape[0]; - - unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]); - return (numN / 2) * (NK / 4) * elemsPerLd; - } - - // Loading $a from smem to registers, returns a LLVM::Struct. 
- Value loadA(Value A, const SharedMemoryObject &smemObj, Value thread, - Location loc, ConversionPatternRewriter &rewriter) const; - - // Loading $b from smem to registers, returns a LLVM::Struct. - Value loadB(Value B, const SharedMemoryObject &smemObj, Value thread, - Location loc, ConversionPatternRewriter &rewriter) const; - - static ArrayRef getOrder() { return mmaOrder; } - - // Compute the offset of the matrix to load. - // Returns offsetAM, offsetAK, offsetBN, offsetBK. - // NOTE, the information M(from $a) and N(from $b) couldn't be retrieved at - // the same time in the usage in convert_layout[shared->dot_op], we leave - // the noexist info to be 0 and only use the desired argument from the - // composed result. In this way we want to retain the original code - // structure in convert_mma884 method for easier debugging. - std::tuple - computeOffsets(Value threadId, bool isARow, bool isBRow, ArrayRef fpw, - ArrayRef spw, ArrayRef rep, - ConversionPatternRewriter &rewriter, Location loc) const; - - // Extract values belong to $a or $b from a LLVMStruct, the shape is n0xn1. - ValueTable extractLoadedOperand(Value llStruct, int n0, int n1, - ConversionPatternRewriter &rewriter) const; - -private: - static constexpr unsigned instrShape[] = {16, 16, 4}; - static constexpr unsigned mmaOrder[] = {0, 1}; -}; - -// Helper for conversion of DotOp with mma, that is sm>=80 -struct DotOpMmaV2ConversionHelper { - using TensorCoreType = DotOpConversion::TensorCoreType; - - MmaEncodingAttr mmaLayout; - MLIRContext *ctx{}; - - explicit DotOpMmaV2ConversionHelper(MmaEncodingAttr mmaLayout) - : mmaLayout(mmaLayout) { - ctx = mmaLayout.getContext(); - } - - void deduceMmaType(DotOp op) const { mmaType = getMmaType(op); } - void deduceMmaType(Type operandTy) const { - mmaType = getTensorCoreTypeFromOperand(operandTy); - } - - // Get the M and N of mma instruction shape. - static std::tuple getInstrShapeMN() { - // According to DotOpConversionHelper::mmaInstrShape, all the M,N are - // {16,8} - return {16, 8}; - } - - static std::tuple getRepMN(const RankedTensorType &tensorTy) { - auto mmaLayout = tensorTy.getEncoding().cast(); - auto wpt = mmaLayout.getWarpsPerCTA(); - - int M = tensorTy.getShape()[0]; - int N = tensorTy.getShape()[1]; - auto [instrM, instrN] = getInstrShapeMN(); - int repM = std::max(M / (wpt[0] * instrM), 1); - int repN = std::max(N / (wpt[1] * instrN), 1); - return {repM, repN}; - } - - Type getShemPtrTy() const { - switch (mmaType) { - case TensorCoreType::FP32_FP16_FP16_FP32: - return ptr_ty(type::f16Ty(ctx), 3); - case TensorCoreType::FP32_BF16_BF16_FP32: - return ptr_ty(type::bf16Ty(ctx), 3); - case TensorCoreType::FP32_TF32_TF32_FP32: - return ptr_ty(type::f32Ty(ctx), 3); - case TensorCoreType::INT32_INT8_INT8_INT32: - return ptr_ty(type::i8Ty(ctx), 3); - default: - llvm::report_fatal_error("mma16816 data type not supported"); - } - return Type{}; - } - - // The type of matrix that loaded by either a ldmatrix or composed lds. 
- Type getMatType() const { - Type fp32Ty = type::f32Ty(ctx); - Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2); - Type bf16x2Ty = vec_ty(type::bf16Ty(ctx), 2); - // floating point types - Type fp16x2Pack4Ty = - LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp16x2Ty)); - Type bf16x2Pack4Ty = - LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, bf16x2Ty)); - Type fp32Pack4Ty = - LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp32Ty)); - // integer types - Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4); - Type i8x4Pack4Ty = - LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, i8x4Ty)); - - switch (mmaType) { - case TensorCoreType::FP32_FP16_FP16_FP32: - return fp16x2Pack4Ty; - case TensorCoreType::FP32_BF16_BF16_FP32: - return bf16x2Pack4Ty; - case TensorCoreType::FP32_TF32_TF32_FP32: - return fp32Pack4Ty; - case TensorCoreType::INT32_INT8_INT8_INT32: - return i8x4Pack4Ty; - default: - llvm::report_fatal_error("Unsupported mma type found"); - } - - return Type{}; - } - - Type getLoadElemTy() { - switch (mmaType) { - case TensorCoreType::FP32_FP16_FP16_FP32: - return vec_ty(type::f16Ty(ctx), 2); - case TensorCoreType::FP32_BF16_BF16_FP32: - return vec_ty(type::bf16Ty(ctx), 2); - case TensorCoreType::FP32_TF32_TF32_FP32: - return type::f32Ty(ctx); - case TensorCoreType::INT32_INT8_INT8_INT32: - return type::i32Ty(ctx); - default: - llvm::report_fatal_error("Unsupported mma type found"); - } - - return Type{}; - } - - Type getMmaRetType() const { - Type fp32Ty = type::f32Ty(ctx); - Type i32Ty = type::i32Ty(ctx); - Type fp32x4Ty = - LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp32Ty)); - Type i32x4Ty = - LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, i32Ty)); - switch (mmaType) { - case TensorCoreType::FP32_FP16_FP16_FP32: - return fp32x4Ty; - case TensorCoreType::FP32_BF16_BF16_FP32: - return fp32x4Ty; - case TensorCoreType::FP32_TF32_TF32_FP32: - return fp32x4Ty; - case TensorCoreType::INT32_INT8_INT8_INT32: - return i32x4Ty; - default: - llvm::report_fatal_error("Unsupported mma type found"); - } - - return Type{}; - } - - ArrayRef getMmaInstrShape() const { - assert(mmaType != TensorCoreType::NOT_APPLICABLE && - "Unknown mma type found."); - return mmaInstrShape.at(mmaType); - } - - static ArrayRef getMmaInstrShape(TensorCoreType tensorCoreType) { - assert(tensorCoreType != TensorCoreType::NOT_APPLICABLE && - "Unknown mma type found."); - return mmaInstrShape.at(tensorCoreType); - } - - ArrayRef getMmaMatShape() const { - assert(mmaType != TensorCoreType::NOT_APPLICABLE && - "Unknown mma type found."); - return mmaMatShape.at(mmaType); - } - - // Deduce the TensorCoreType from either $a or $b's type. 
- static TensorCoreType getTensorCoreTypeFromOperand(Type operandTy) { - auto tensorTy = operandTy.cast(); - auto elemTy = tensorTy.getElementType(); - if (elemTy.isF16()) - return TensorCoreType::FP32_FP16_FP16_FP32; - if (elemTy.isF32()) - return TensorCoreType::FP32_TF32_TF32_FP32; - if (elemTy.isBF16()) - return TensorCoreType::FP32_BF16_BF16_FP32; - if (elemTy.isInteger(8)) - return TensorCoreType::INT32_INT8_INT8_INT32; - return TensorCoreType::NOT_APPLICABLE; - } - - int getVec() const { - assert(mmaType != TensorCoreType::NOT_APPLICABLE && - "Unknown mma type found."); - return mmaInstrVec.at(mmaType); - } - - StringRef getMmaInstr() const { - assert(mmaType != TensorCoreType::NOT_APPLICABLE && - "Unknown mma type found."); - return mmaInstrPtx.at(mmaType); - } - - static TensorCoreType getMmaType(triton::DotOp op) { - Value A = op.a(); - Value B = op.b(); - auto aTy = A.getType().cast(); - auto bTy = B.getType().cast(); - // d = a*b + c - auto dTy = op.d().getType().cast(); - - if (dTy.getElementType().isF32()) { - if (aTy.getElementType().isF16() && bTy.getElementType().isF16()) - return TensorCoreType::FP32_FP16_FP16_FP32; - if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16()) - return TensorCoreType::FP32_BF16_BF16_FP32; - if (aTy.getElementType().isF32() && bTy.getElementType().isF32() && - op.allowTF32()) - return TensorCoreType::FP32_TF32_TF32_FP32; - } else if (dTy.getElementType().isInteger(32)) { - if (aTy.getElementType().isInteger(8) && - bTy.getElementType().isInteger(8)) - return TensorCoreType::INT32_INT8_INT8_INT32; - } - - return TensorCoreType::NOT_APPLICABLE; - } - -private: - mutable TensorCoreType mmaType{TensorCoreType::NOT_APPLICABLE}; - - // Used on nvidia GPUs mma layout .version == 2 - // Refer to - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-storage - // for more details. - inline static const std::map> - mmaInstrShape = { - {TensorCoreType::FP32_FP16_FP16_FP32, {16, 8, 16}}, - {TensorCoreType::FP32_BF16_BF16_FP32, {16, 8, 16}}, - {TensorCoreType::FP32_TF32_TF32_FP32, {16, 8, 8}}, - - {TensorCoreType::INT32_INT1_INT1_INT32, {16, 8, 256}}, - {TensorCoreType::INT32_INT4_INT4_INT32, {16, 8, 64}}, - {TensorCoreType::INT32_INT8_INT8_INT32, {16, 8, 32}}, - }; - - // shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices) - // Refer to - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix - // for more details. - inline static const std::map> - mmaMatShape = { - {TensorCoreType::FP32_FP16_FP16_FP32, {8, 8, 8}}, - {TensorCoreType::FP32_BF16_BF16_FP32, {8, 8, 8}}, - {TensorCoreType::FP32_TF32_TF32_FP32, {8, 8, 4}}, - - {TensorCoreType::INT32_INT1_INT1_INT32, {8, 8, 64}}, - {TensorCoreType::INT32_INT4_INT4_INT32, {8, 8, 32}}, - {TensorCoreType::INT32_INT8_INT8_INT32, {8, 8, 16}}, - }; - - // Supported mma instruction in PTX. - // Refer to - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma - // for more details. 
- inline static const std::map mmaInstrPtx = { - {TensorCoreType::FP32_FP16_FP16_FP32, - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"}, - {TensorCoreType::FP32_BF16_BF16_FP32, - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"}, - {TensorCoreType::FP32_TF32_TF32_FP32, - "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"}, - - {TensorCoreType::INT32_INT1_INT1_INT32, - "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"}, - {TensorCoreType::INT32_INT4_INT4_INT32, - "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"}, - {TensorCoreType::INT32_INT8_INT8_INT32, - "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"}, - }; - - // vector length per ldmatrix (16*8/element_size_in_bits) - inline static const std::map mmaInstrVec = { - {TensorCoreType::FP32_FP16_FP16_FP32, 8}, - {TensorCoreType::FP32_BF16_BF16_FP32, 8}, - {TensorCoreType::FP32_TF32_TF32_FP32, 4}, - - {TensorCoreType::INT32_INT1_INT1_INT32, 128}, - {TensorCoreType::INT32_INT4_INT4_INT32, 32}, - {TensorCoreType::INT32_INT8_INT8_INT32, 16}, - }; -}; - -// This class helps to adapt the existing DotOpConversion to the latest -// DotOpOperand layout design. It decouples the exising implementation to two -// parts: -// 1. loading the specific operand matrix(for $a, $b, $c) from smem -// 2. passing the loaded value and perform the mma codegen -struct MMA16816ConversionHelper { - MmaEncodingAttr mmaLayout; - ArrayRef wpt; - SmallVector properWpt; - - Value thread, lane, warp; - - DotOpMmaV2ConversionHelper helper; - ConversionPatternRewriter &rewriter; - TypeConverter *typeConverter; - Location loc; - MLIRContext *ctx{}; - - using ValueTable = std::map, Value>; - - // dotOperand: type of either one operand of dotOp. - MMA16816ConversionHelper(Type dotOperand, MmaEncodingAttr mmaLayout, - Value thread, ConversionPatternRewriter &rewriter, - TypeConverter *typeConverter, Location loc) - : mmaLayout(mmaLayout), thread(thread), helper(mmaLayout), - rewriter(rewriter), typeConverter(typeConverter), loc(loc), - ctx(mmaLayout.getContext()), wpt(mmaLayout.getWarpsPerCTA()) { - helper.deduceMmaType(dotOperand); - - Value _32 = i32_val(32); - lane = urem(thread, _32); - warp = udiv(thread, _32); - } - - // Get a warpId for M axis. - Value getWarpM(int M) const { - auto matShape = helper.getMmaMatShape(); - return urem(urem(warp, i32_val(wpt[0])), i32_val(M / matShape[0])); - } - - // Get a warpId for N axis. - Value getWarpN(int N) const { - auto matShape = helper.getMmaMatShape(); - Value warpMN = udiv(warp, i32_val(wpt[0])); - return urem(urem(warpMN, i32_val(wpt[1])), i32_val(N / matShape[1])); - } - - // Get the mmaInstrShape deducing either from $a or $b. - std::tuple getMmaInstrShape(Type operand) const { - helper.deduceMmaType(operand); - auto mmaInstrShape = helper.getMmaInstrShape(); - int mmaInstrM = mmaInstrShape[0]; - int mmaInstrN = mmaInstrShape[1]; - int mmaInstrK = mmaInstrShape[2]; - return std::make_tuple(mmaInstrM, mmaInstrN, mmaInstrK); - } - - // Get the mmaMatShape deducing either from $a or $b. - std::tuple getMmaMatShape(Type operand) const { - helper.deduceMmaType(operand); - auto matShape = helper.getMmaMatShape(); - int matShapeM = matShape[0]; - int matShapeN = matShape[1]; - int matShapeK = matShape[2]; - return std::make_tuple(matShapeM, matShapeN, matShapeK); - } - - // \param operand is either $a or $b's type. - inline int getNumRepM(Type operand, int M) const { - return getNumRepM(operand, M, wpt[0]); - } - - // \param operand is either $a or $b's type. 
- inline int getNumRepN(Type operand, int N) const { - return getNumRepN(operand, N, wpt[1]); - } - - // \param operand is either $a or $b's type. - inline int getNumRepK(Type operand, int K) const { - return getNumRepK_(operand, K); - } - - static int getNumRepM(Type operand, int M, int wpt) { - auto tensorCoreType = - DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand); - int mmaInstrM = - DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[0]; - return std::max(M / (wpt * mmaInstrM), 1); - } - - static int getNumRepN(Type operand, int N, int wpt) { - auto tensorCoreType = - DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand); - int mmaInstrN = - DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[1]; - return std::max(N / (wpt * mmaInstrN), 1); - } - - static int getNumRepK_(Type operand, int K) { - auto tensorCoreType = - DotOpMmaV2ConversionHelper::getTensorCoreTypeFromOperand(operand); - int mmaInstrK = - DotOpMmaV2ConversionHelper::getMmaInstrShape(tensorCoreType)[2]; - return std::max(K / mmaInstrK, 1); - } - - // Get number of elements per thread for $a operand. - static size_t getANumElemsPerThread(RankedTensorType operand, int wpt) { - auto shape = operand.getShape(); - int repM = getNumRepM(operand, shape[0], wpt); - int repK = getNumRepK_(operand, shape[1]); - return 4 * repM * repK; - } - - // Get number of elements per thread for $b operand. - static size_t getBNumElemsPerThread(RankedTensorType operand, int wpt) { - auto shape = operand.getShape(); - int repK = getNumRepK_(operand, shape[0]); - int repN = getNumRepN(operand, shape[1], wpt); - return 4 * std::max(repN / 2, 1) * repK; - } - - // Loading $a from smem to registers, returns a LLVM::Struct. - Value loadA(Value tensor, const SharedMemoryObject &smemObj) const { - auto aTensorTy = tensor.getType().cast(); - auto layout = aTensorTy.getEncoding().cast(); - - SmallVector shape(aTensorTy.getShape().begin(), - aTensorTy.getShape().end()); - ValueTable ha; - std::function loadFn; - auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(aTensorTy); - auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(aTensorTy); - - int numRepM = getNumRepM(aTensorTy, shape[0]); - int numRepK = getNumRepK(aTensorTy, shape[1]); - - if (aTensorTy.getEncoding().isa()) { - Value warpM = getWarpM(shape[0]); - // load from smem - loadFn = getLoadMatrixFn( - tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[0] /*wpt*/, - 1 /*kOrder*/, {mmaInstrM, mmaInstrK} /*instrShape*/, - {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, ha /*vals*/, - true /*isA*/); - } else if (aTensorTy.getEncoding().isa()) { - // load from registers, used in gemm fuse - // TODO(Superjomn) Port the logic. - assert(false && "Loading A from register is not supported yet."); - } else { - assert(false && "A's layout is not supported."); - } - - // step1. Perform loading. - for (int m = 0; m < numRepM; ++m) - for (int k = 0; k < numRepK; ++k) - loadFn(2 * m, 2 * k); - - // step2. Format the values to LLVM::Struct to passing to mma codegen. - return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK); - } - - // Loading $b from smem to registers, returns a LLVM::Struct. 
- Value loadB(Value tensor, const SharedMemoryObject &smemObj) { - ValueTable hb; - auto tensorTy = tensor.getType().cast(); - auto layout = tensorTy.getEncoding().cast(); - - SmallVector shape(tensorTy.getShape().begin(), - tensorTy.getShape().end()); - - auto [matShapeM, matShapeN, matShapeK] = getMmaMatShape(tensorTy); - auto [mmaInstrM, mmaInstrN, mmaInstrK] = getMmaInstrShape(tensorTy); - int numRepK = getNumRepK(tensorTy, shape[0]); - int numRepN = getNumRepN(tensorTy, shape[1]); - - Value warpN = getWarpN(shape[1]); - auto loadFn = getLoadMatrixFn( - tensor, smemObj, mmaLayout, mmaLayout.getWarpsPerCTA()[1] /*wpt*/, - 0 /*kOrder*/, {mmaInstrK, mmaInstrN} /*instrShape*/, - {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, hb /*vals*/, - false /*isA*/); - - for (int n = 0; n < std::max(numRepN / 2, 1); ++n) { - for (int k = 0; k < numRepK; ++k) - loadFn(2 * n, 2 * k); - } - - Value result = composeValuesToDotOperandLayoutStruct( - hb, std::max(numRepN / 2, 1), numRepK); - return result; - } - - // Loading $c to registers, returns a Value. - Value loadC(Value tensor, Value llTensor) const { - auto tensorTy = tensor.getType().cast(); - auto [repM, repN] = DotOpMmaV2ConversionHelper::getRepMN(tensorTy); - size_t fcSize = 4 * repM * repN; - - assert(tensorTy.getEncoding().isa() && - "Currently, we only support $c with a mma layout."); - // Load a normal C tensor with mma layout, that should be a - // LLVM::struct with fcSize elements. - auto structTy = llTensor.getType().cast(); - assert(structTy.getBody().size() == fcSize && - "DotOp's $c operand should pass the same number of values as $d in " - "mma layout."); - return llTensor; - } - - // Conduct the Dot conversion. - // \param a, \param b, \param c and \param d are DotOp operands. - // \param loadedA, \param loadedB, \param loadedC, all of them are result of - // loading. 
- LogicalResult convertDot(Value a, Value b, Value c, Value d, Value loadedA, - Value loadedB, Value loadedC, DotOp op, - DotOpAdaptor adaptor) const { - helper.deduceMmaType(op); - - auto aTensorTy = a.getType().cast(); - auto dTensorTy = d.getType().cast(); - - SmallVector aShape(aTensorTy.getShape().begin(), - aTensorTy.getShape().end()); - if (op.transA()) - std::swap(aShape[0], aShape[1]); - - auto dShape = dTensorTy.getShape(); - - // shape / shape_per_cta - int numRepM = getNumRepM(aTensorTy, dShape[0]); - int numRepN = getNumRepN(aTensorTy, dShape[1]); - int numRepK = getNumRepK(aTensorTy, aShape[1]); - - ValueTable ha = - getValuesFromDotOperandLayoutStruct(loadedA, numRepM, numRepK); - ValueTable hb = getValuesFromDotOperandLayoutStruct( - loadedB, std::max(numRepN / 2, 1), numRepK); - auto fc = ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct( - loc, loadedC, rewriter); - - auto callMma = [&](unsigned m, unsigned n, unsigned k) { - unsigned colsPerThread = numRepN * 2; - PTXBuilder builder; - auto &mma = *builder.create(helper.getMmaInstr().str()); - auto retArgs = builder.newListOperand(4, "=r"); - auto aArgs = builder.newListOperand({ - {ha[{m, k}], "r"}, - {ha[{m + 1, k}], "r"}, - {ha[{m, k + 1}], "r"}, - {ha[{m + 1, k + 1}], "r"}, - }); - auto bArgs = - builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}}); - auto cArgs = builder.newListOperand(); - for (int i = 0; i < 4; ++i) { - cArgs->listAppend(builder.newOperand(fc[m * colsPerThread + 4 * n + i], - std::to_string(i))); - // reuse the output registers - } - - mma(retArgs, aArgs, bArgs, cArgs); - Value mmaOut = builder.launch(rewriter, loc, helper.getMmaRetType()); - - auto getIntAttr = [&](int v) { - return ArrayAttr::get(ctx, {IntegerAttr::get(i32_ty, v)}); - }; - - Type elemTy = mmaOut.getType().cast().getBody()[0]; - for (int i = 0; i < 4; ++i) - fc[m * colsPerThread + 4 * n + i] = - extract_val(elemTy, mmaOut, getIntAttr(i)); - }; - - for (int k = 0; k < numRepK; ++k) - for (int m = 0; m < numRepM; ++m) - for (int n = 0; n < numRepN; ++n) - callMma(2 * m, n, 2 * k); - - Type resElemTy = dTensorTy.getElementType(); - - for (auto &elem : fc) { - elem = bitcast(elem, resElemTy); - } - - // replace with new packed result - Type structTy = LLVM::LLVMStructType::getLiteral( - ctx, SmallVector(fc.size(), resElemTy)); - Value res = getStructFromElements(loc, fc, rewriter, structTy); - rewriter.replaceOp(op, res); - - return success(); - } - -private: - std::function - getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj, - MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder, - ArrayRef instrShape, ArrayRef matShape, - Value warpId, ValueTable &vals, bool isA) const { - auto tensorTy = tensor.getType().cast(); - // We assumes that the input operand of Dot should be from shared layout. - // TODO(Superjomn) Consider other layouts if needed later. - auto sharedLayout = tensorTy.getEncoding().cast(); - const int perPhase = sharedLayout.getPerPhase(); - const int maxPhase = sharedLayout.getMaxPhase(); - const int elemBytes = tensorTy.getElementTypeBitWidth() / 8; - auto order = sharedLayout.getOrder(); - - // the original register_lds2, but discard the prefetch logic. - auto ld2 = [](ValueTable &vals, int mn, int k, Value val) { - vals[{mn, k}] = val; - }; - - // (a, b) is the coordinate. 
- auto load = [=, &vals, &ld2](int a, int b) { - MMA16816SmemLoader loader( - wpt, sharedLayout.getOrder(), kOrder, smemObj.strides, - tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase, - maxPhase, elemBytes, rewriter, typeConverter, loc); - Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); - SmallVector offs = - loader.computeOffsets(warpId, lane, cSwizzleOffset); - const int numPtrs = loader.getNumPtrs(); - SmallVector ptrs(numPtrs); - - Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter); - Type smemPtrTy = helper.getShemPtrTy(); - for (int i = 0; i < numPtrs; ++i) { - ptrs[i] = - bitcast(gep(smemPtrTy, smemBase, ValueRange({offs[i]})), smemPtrTy); - } - - auto [ha0, ha1, ha2, ha3] = loader.loadX4( - (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs, - ptrs, helper.getMatType(), helper.getShemPtrTy()); - - if (isA) { - ld2(vals, a, b, ha0); - ld2(vals, a + 1, b, ha1); - ld2(vals, a, b + 1, ha2); - ld2(vals, a + 1, b + 1, ha3); - } else { - ld2(vals, a, b, ha0); - ld2(vals, a + 1, b, ha2); - ld2(vals, a, b + 1, ha1); - ld2(vals, a + 1, b + 1, ha3); - } - }; - - return load; - } - - // Compose a map of Values to a LLVM::Struct. - // The layout is a list of Value with coordinate of (i,j), the order is as - // the follows: - // [ - // (0,0), (0,1), (1,0), (1,1), # i=0, j=0 - // (0,2), (0,3), (1,2), (1,3), # i=0, j=1 - // (0,4), (0,5), (1,4), (1,5), # i=0, j=2 - // ... - // (2,0), (2,1), (3,0), (3,1), # i=1, j=0 - // (2,2), (2,3), (3,2), (3,3), # i=1, j=1 - // (2,4), (2,5), (3,4), (3,5), # i=1, j=2 - // ... - // ] - // i \in [0, n0) and j \in [0, n1) - // There should be \param n0 * \param n1 elements in the output Struct. - Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0, - int n1) const { - std::vector elems; - for (int m = 0; m < n0; ++m) - for (int k = 0; k < n1; ++k) { - elems.push_back(vals.at({2 * m, 2 * k})); - elems.push_back(vals.at({2 * m, 2 * k + 1})); - elems.push_back(vals.at({2 * m + 1, 2 * k})); - elems.push_back(vals.at({2 * m + 1, 2 * k + 1})); - } - - assert(!elems.empty()); - - Type elemTy = elems[0].getType(); - Type structTy = LLVM::LLVMStructType::getLiteral( - ctx, SmallVector(elems.size(), elemTy)); - auto result = getStructFromElements(loc, elems, rewriter, structTy); - return result; - } - - ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, - int n1) const { - auto elems = ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct( - loc, value, rewriter); - - int offset{}; - ValueTable vals; - for (int i = 0; i < n0; ++i) { - for (int j = 0; j < n1; j++) { - vals[{2 * i, 2 * j}] = elems[offset++]; - vals[{2 * i, 2 * j + 1}] = elems[offset++]; - vals[{2 * i + 1, 2 * j}] = elems[offset++]; - vals[{2 * i + 1, 2 * j + 1}] = elems[offset++]; - } - } - return vals; - } -}; - -// Helper for conversion of FMA DotOp. 
-struct DotOpFMAConversionHelper { - Attribute layout; - MLIRContext *ctx{}; - - using ValueTable = std::map, Value>; - - explicit DotOpFMAConversionHelper(Attribute layout) - : layout(layout), ctx(layout.getContext()) {} - - SmallVector getThreadIds(Value threadId, - ArrayRef shapePerCTA, - ArrayRef order, - ConversionPatternRewriter &rewriter, - Location loc) const; - - Value loadA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, - Location loc, ConversionPatternRewriter &rewriter) const; - - Value loadB(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, - Location loc, ConversionPatternRewriter &rewriter) const; - - ValueTable getValueTableFromStruct(Value val, int K, int n0, int shapePerCTA, - int sizePerThread, - ConversionPatternRewriter &rewriter, - Location loc) const; - - Value getStructFromValueTable(const ValueTable &vals, - ConversionPatternRewriter &rewriter, - Location loc) const { - SmallVector elemTypes(vals.size(), f32_ty); - SmallVector elems; - elems.reserve(vals.size()); - for (auto &item : vals) { - elems.push_back(item.second); - } - - Type structTy = struct_ty(elemTypes); - return getStructFromElements(loc, elems, rewriter, structTy); - } - // get number of elements per thread for $a or $b. - static int getNumElemsPerThread(ArrayRef shape, - DotOperandEncodingAttr dotOpLayout) { - auto blockedLayout = dotOpLayout.getParent().cast(); - auto shapePerCTA = getShapePerCTA(blockedLayout); - auto sizePerThread = getSizePerThread(blockedLayout); - auto order = blockedLayout.getOrder(); - - // TODO[Superjomn]: we assume the k aixs is fixed for $a and $b here, fix it - // if not. - int K = dotOpLayout.getOpIdx() == 0 ? shape[1] : shape[0]; - int otherDim = dotOpLayout.getOpIdx() == 1 ? shape[1] : shape[0]; - - bool isM = dotOpLayout.getOpIdx() == 0; - int shapePerCTAMN = getShapePerCTAForMN(blockedLayout, isM); - int sizePerThreadMN = getsizePerThreadForMN(blockedLayout, isM); - return K * std::max(otherDim / shapePerCTAMN, 1) * sizePerThreadMN; - } - - // Get shapePerCTA for M or N axis. - static int getShapePerCTAForMN(BlockedEncodingAttr layout, bool isM) { - auto order = layout.getOrder(); - auto shapePerCTA = getShapePerCTA(layout); - - int mShapePerCTA = - order[0] == 1 ? shapePerCTA[order[1]] : shapePerCTA[order[0]]; - int nShapePerCTA = - order[0] == 0 ? shapePerCTA[order[1]] : shapePerCTA[order[0]]; - return isM ? mShapePerCTA : nShapePerCTA; - } - - // Get sizePerThread for M or N axis. - static int getsizePerThreadForMN(BlockedEncodingAttr layout, bool isM) { - auto order = layout.getOrder(); - auto sizePerThread = getSizePerThread(layout); - - int mSizePerThread = - order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - int nSizePerThread = - order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - return isM ? 
-
 Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
     triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter, const MmaEncodingAttr &mmaLayout,
@@ -4861,16 +3381,31 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   auto ATensorTy = A.getType().cast<RankedTensorType>();
   auto BTensorTy = B.getType().cast<RankedTensorType>();
   auto DTensorTy = D.getType().cast<RankedTensorType>();
-  auto AShape = ATensorTy.getShape();
-  auto BShape = BTensorTy.getShape();
+  SmallVector<int64_t> AShape(ATensorTy.getShape().begin(),
+                              ATensorTy.getShape().end());
+  SmallVector<int64_t> BShape(BTensorTy.getShape().begin(),
+                              BTensorTy.getShape().end());
   auto DShape = DTensorTy.getShape();
   auto wpt = mmaLayout.getWarpsPerCTA();
   bool transA = op.transA();
   bool transB = op.transB();
-  bool isARow = !transA;
-  bool isBRow = !transB;
+  // TODO[Superjomn]: the order cannot be accessed in DotOp yet.
+  SmallVector<unsigned> AOrder({1, 0});
+  SmallVector<unsigned> BOrder({1, 0});
+
+  if (transA) {
+    std::swap(AShape[0], AShape[1]);
+    std::swap(AOrder[0], AOrder[1]);
+  }
+  if (transB) {
+    std::swap(BShape[0], BShape[1]);
+    std::swap(BOrder[0], BOrder[1]);
+  }
+
+  bool isARow = AOrder[0] != 0;
+  bool isBRow = BOrder[0] != 0;
   bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
   bool isBVec4 = isBRow && BShape[isBRow] <= 16;
   int packSize0 = (isARow || isAVec4) ? 1 : 2;
@@ -4888,10 +3423,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   unsigned numN = rep[1] * DShape[1] / (spw[1] * wpt[0]);
   unsigned NK = AShape[1];
 
-  auto has = helper.extractLoadedOperand(loadedA, numM / 2, NK, rewriter);
-  auto hbs = helper.extractLoadedOperand(loadedB, numN / 2, NK, rewriter);
-
-  size_t accSize = numM * numN;
+  auto has = helper.extractLoadedOperand(loadedA, NK, rewriter);
+  auto hbs = helper.extractLoadedOperand(loadedB, NK, rewriter);
 
   // initialize accumulators
   SmallVector<Value> acc = getElementsFromStruct(loc, loadedC, rewriter);
@@ -4957,491 +3490,6 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   return success();
 }
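
// Standalone sketch of the transpose handling introduced above: transposing
// an operand swaps both its logical shape and its order vector, so a
// row-major (order {1, 0}) 16x32 tensor becomes a column-major (order {0, 1})
// 32x16 one, and isRow flips accordingly. Illustrative values only.
#include <cassert>
#include <utility>
#include <vector>

int main() {
  std::vector<long> shape{16, 32};
  std::vector<unsigned> order{1, 0}; // fastest-varying axis first: row-major
  bool trans = true;
  if (trans) {
    std::swap(shape[0], shape[1]);
    std::swap(order[0], order[1]);
  }
  bool isRow = order[0] != 0;
  assert(shape[0] == 32 && shape[1] == 16 && !isRow);
  return 0;
}
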
-Value DotOpMmaV1ConversionHelper::loadA(
-    Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
-    ConversionPatternRewriter &rewriter) const {
-
-  auto *ctx = rewriter.getContext();
-  auto tensorTy = tensor.getType().cast<RankedTensorType>();
-  auto shape = tensorTy.getShape();
-  auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
-  auto order = sharedLayout.getOrder();
-  Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
-
-  bool isARow = order[0] != 0;
-  bool isAVec4 = !isARow && shape[order[0]] <= 16; // fp16*4 = 16bytes
-  int packSize0 = (isARow || isAVec4) ? 1 : 2;
-
-  SmallVector<int> fpw({2, 2, 1});
-  int repM = 2 * packSize0;
-  int repK = 1;
-  int spwM = fpw[0] * 4 * repM;
-  SmallVector<int> rep({repM, 0, repK}); // pad N with 0
-  SmallVector<int> spw({spwM, 0, 1});    // pad N with 0
-
-  int vecA = sharedLayout.getVec();
-
-  auto strides = smemObj.strides;
-  Value strideAM = isARow ? strides[0] : i32_val(1);
-  Value strideAK = isARow ? i32_val(1) : strides[1];
-  Value strideA0 = isARow ? strideAK : strideAM;
-  Value strideA1 = isARow ? strideAM : strideAK;
-
-  int strideRepM = wpt[0] * fpw[0] * 8;
-  int strideRepK = 1;
-
-  auto [offsetAM, offsetAK, _0, _1] =
-      computeOffsets(thread, isARow, false, fpw, spw, rep, rewriter, loc);
-
-  // swizzling
-  int perPhaseA = sharedLayout.getPerPhase();
-  int maxPhaseA = sharedLayout.getMaxPhase();
-  int stepA0 = isARow ? strideRepK : strideRepM;
-  int numPtrA = std::max(2 * perPhaseA * maxPhaseA / stepA0, 1);
-  int NK = shape[1];
-
-  // pre-compute pointer lanes
-  Value offA0 = isARow ? offsetAK : offsetAM;
-  Value offA1 = isARow ? offsetAM : offsetAK;
-  Value phaseA = urem(udiv(offA1, i32_val(perPhaseA)), i32_val(maxPhaseA));
-  offA0 = add(offA0, cSwizzleOffset);
-  SmallVector<Value> offA(numPtrA);
-  for (int i = 0; i < numPtrA; i++) {
-    Value offA0I = add(offA0, i32_val(i * (isARow ? 4 : strideRepM)));
-    offA0I = udiv(offA0I, i32_val(vecA));
-    offA0I = xor_(offA0I, phaseA);
-    offA0I = xor_(offA0I, i32_val(vecA));
-    offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1));
-  }
-
-  Type f16x2Ty = vec_ty(f16_ty, 2);
-  // One thread get 8 elements as result
-  Type retTy = LLVM::LLVMStructType::getLiteral(
-      ctx, SmallVector<Type>(8, type::f32Ty(ctx)));
-
-  // prepare arguments
-  SmallVector<Value> ptrA(numPtrA);
-
-  std::map<std::pair<int, int>, std::pair<Value, Value>> has;
-  auto smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
-  for (int i = 0; i < numPtrA; i++)
-    ptrA[i] = gep(ptr_ty(f16_ty), smem, offA[i]);
-
-  auto instrShape = getMmaInstrShape();
-  unsigned numM = rep[0] * shape[0] / (spw[0] * wpt[0]);
-
-  Type f16PtrTy = ptr_ty(f16_ty);
-
-  auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) {
-    vals[{m, k}] = {val0, val1};
-  };
-  auto loadA = [&](int m, int k) {
-    int offidx = (isARow ? k / 4 : m) % numPtrA;
-    Value thePtrA = gep(f16PtrTy, smem, offA[offidx]);
-
-    int stepAM = isARow ? m : m / numPtrA * numPtrA;
-    int stepAK = isARow ? k / (numPtrA * vecA) * (numPtrA * vecA) : k;
-    Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM),
-                       mul(i32_val(stepAK), strideAK));
-    Value pa = gep(f16PtrTy, thePtrA, offset);
-    Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max(vecA / 2, 1)), 3);
-    Value ha = load(bitcast(pa, aPtrTy));
-    // record lds that needs to be moved
-    Value ha00 = bitcast(extract_element(ha, i32_val(0)), f16x2Ty);
-    Value ha01 = bitcast(extract_element(ha, i32_val(1)), f16x2Ty);
-    ld(has, m, k, ha00, ha01);
-
-    if (vecA > 4) {
-      Value ha10 = bitcast(extract_element(ha, i32_val(2)), f16x2Ty);
-      Value ha11 = bitcast(extract_element(ha, i32_val(3)), f16x2Ty);
-      if (isARow)
-        ld(has, m, k + 4, ha10, ha11);
-      else
-        ld(has, m + 1, k, ha10, ha11);
-    }
-  };
-
-  for (unsigned k = 0; k < NK; k += 4)
-    for (unsigned m = 0; m < numM / 2; ++m)
-      if (!has.count({m, k}))
-        loadA(m, k);
-
-  SmallVector<Value> elems;
-  elems.reserve(has.size() * 2);
-  auto vecTy = vec_ty(f16_ty, 2);
-  for (auto item : has) { // has is a map, the key should be ordered.
-    elems.push_back(item.second.first);
-    elems.push_back(item.second.second);
-  }
-
-  Type resTy = struct_ty(SmallVector<Type>(elems.size(), f16x2Ty));
-  Value res = getStructFromElements(loc, elems, rewriter, resTy);
-  return res;
-}
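
// Scalar model of the phase-based swizzling used by the shared-memory loads
// above: an offset is reduced to vec-sized groups, XOR-ed with a phase derived
// from the orthogonal coordinate, then scaled back (the form loadB below uses;
// loadA differs slightly in its final step). Constants are hypothetical.
#include <cassert>

static int swizzle(int off, int row, int vec, int perPhase, int maxPhase) {
  int phase = (row / perPhase) % maxPhase;
  return ((off / vec) ^ phase) * vec;
}

int main() {
  // vec = 4, perPhase = 1, maxPhase = 4: row 0 keeps offsets in place while
  // row 1 exchanges neighboring 4-element groups, avoiding bank conflicts.
  assert(swizzle(0, 0, 4, 1, 4) == 0);
  assert(swizzle(0, 1, 4, 1, 4) == 4);
  assert(swizzle(4, 1, 4, 1, 4) == 0);
  return 0;
}
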
-
-Value DotOpMmaV1ConversionHelper::loadB(
-    Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
-    ConversionPatternRewriter &rewriter) const {
-  // smem
-  Value smem = smemObj.base;
-  auto strides = smemObj.strides;
-
-  auto *ctx = rewriter.getContext();
-  auto tensorTy = tensor.getType().cast<RankedTensorType>();
-  auto shape = tensorTy.getShape();
-  auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
-  auto order = sharedLayout.getOrder();
-
-  bool isBRow = order[0] != 0;
-  bool isBVec4 = isBRow && shape[order[0]] <= 16;
-  int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
-  SmallVector<int> fpw({2, 2, 1});
-  SmallVector<int> rep({0, 2 * packSize1, 1});       // pad M with 0
-  SmallVector<int> spw({0, fpw[1] * 4 * rep[1], 1}); // pad M with 0
-  int vecB = sharedLayout.getVec();
-
-  Value strideBN = isBRow ? i32_val(1) : strides[1];
-  Value strideBK = isBRow ? strides[0] : i32_val(1);
-  Value strideB0 = isBRow ? strideBN : strideBK;
-  Value strideB1 = isBRow ? strideBK : strideBN;
-  int strideRepN = wpt[1] * fpw[1] * 8;
-  int strideRepK = 1;
-
-  // swizzling
-  int perPhaseA = sharedLayout.getPerPhase();
-  int maxPhaseA = sharedLayout.getMaxPhase();
-  int perPhaseB = sharedLayout.getPerPhase();
-  int maxPhaseB = sharedLayout.getMaxPhase();
-  int stepB0 = isBRow ? strideRepN : strideRepK;
-  int numPtrB = std::max(2 * perPhaseB * maxPhaseB / stepB0, 1);
-  int NK = shape[0];
-
-  auto [_0, _1, offsetBN, offsetBK] =
-      computeOffsets(thread, false, isBRow, fpw, spw, rep, rewriter, loc);
-
-  Value offB0 = isBRow ? offsetBN : offsetBK;
-  Value offB1 = isBRow ? offsetBK : offsetBN;
-  Value phaseB = urem(udiv(offB1, i32_val(perPhaseB)), i32_val(maxPhaseB));
-  Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
-  offB0 = add(offB0, cSwizzleOffset);
-  SmallVector<Value> offB(numPtrB);
-  for (int i = 0; i < numPtrB; ++i) {
-    Value offB0I = add(offB0, i32_val(i * (isBRow ? strideRepN : 4)));
-    offB0I = udiv(offB0I, i32_val(vecB));
-    offB0I = xor_(offB0I, phaseB);
-    offB0I = mul(offB0I, i32_val(vecB));
-    offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1));
-  }
-
-  Type f16PtrTy = ptr_ty(f16_ty);
-  Type f16x2Ty = vec_ty(f16_ty, 2);
-
-  SmallVector<Value> ptrB(numPtrB);
-  ValueTable hbs;
-  for (int i = 0; i < numPtrB; ++i)
-    ptrB[i] = gep(ptr_ty(f16_ty), smem, offB[i]);
-
-  auto ld = [&](decltype(hbs) &vals, int m, int k, Value val0, Value val1) {
-    vals[{m, k}] = {val0, val1};
-  };
-
-  auto loadB = [&](int n, int K) {
-    int offidx = (isBRow ? n : K / 4) % numPtrB;
-    Value thePtrB = ptrB[offidx];
-
-    int stepBN = isBRow ? n / numPtrB * numPtrB : n;
-    int stepBK = isBRow ? K : K / (numPtrB * vecB) * (numPtrB * vecB);
-    Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN),
-                       mul(i32_val(stepBK), strideBK));
-    Value pb = gep(f16PtrTy, thePtrB, offset);
-    Value hb =
-        load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3)));
-    // record lds that needs to be moved
-    Value hb00 = bitcast(extract_element(hb, i32_val(0)), f16x2Ty);
-    Value hb01 = bitcast(extract_element(hb, i32_val(1)), f16x2Ty);
-    ld(hbs, n, K, hb00, hb01);
-    if (vecB > 4) {
-      Value hb10 = bitcast(extract_element(hb, i32_val(2)), f16x2Ty);
-      Value hb11 = bitcast(extract_element(hb, i32_val(3)), f16x2Ty);
-      if (isBRow)
-        ld(hbs, n + 1, K, hb10, hb11);
-      else
-        ld(hbs, n, K + 4, hb10, hb11);
-    }
-  };
-
-  unsigned numN = rep[1] * shape[1] / (spw[1] * wpt[0]);
-  for (unsigned k = 0; k < NK; k += 4)
-    for (unsigned n = 0; n < numN / 2; ++n) {
-      if (!hbs.count({n, k}))
-        loadB(n, k);
-    }
-
-  SmallVector<Value> elems;
-  for (auto &item : hbs) { // has is a map, the key should be ordered.
-    elems.push_back(item.second.first);
-    elems.push_back(item.second.second);
-  }
-  Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
-  Type resTy = struct_ty(SmallVector<Type>(elems.size(), fp16x2Ty));
-  Value res = getStructFromElements(loc, elems, rewriter, resTy);
-  return res;
-}
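
// Worked example of the pointer-count formula shared by loadA and loadB above:
// numPtr = max(2 * perPhase * maxPhase / step, 1), i.e. enough base pointers
// to cover one full swizzle period. Values are hypothetical.
#include <algorithm>
#include <cassert>

int main() {
  int perPhase = 2, maxPhase = 4, step = 4;
  assert(std::max(2 * perPhase * maxPhase / step, 1) == 4);
  // A large step degenerates to a single base pointer:
  assert(std::max(2 * perPhase * maxPhase / 64, 1) == 1);
  return 0;
}
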
-
-std::tuple<Value, Value, Value, Value>
-DotOpMmaV1ConversionHelper::computeOffsets(Value threadId, bool isARow,
                                           bool isBRow, ArrayRef<int> fpw,
                                           ArrayRef<int> spw, ArrayRef<int> rep,
                                           ConversionPatternRewriter &rewriter,
                                           Location loc) const {
-  auto *ctx = rewriter.getContext();
-  Value _1 = i32_val(1);
-  Value _3 = i32_val(3);
-  Value _4 = i32_val(4);
-  Value _16 = i32_val(16);
-  Value _32 = i32_val(32);
-
-  Value lane = urem(threadId, _32);
-  Value warp = udiv(threadId, _32);
-
-  // warp offset
-  Value warp0 = urem(warp, i32_val(wpt[0]));
-  Value warp12 = udiv(warp, i32_val(wpt[0]));
-  Value warp1 = urem(warp12, i32_val(wpt[1]));
-  Value warpMOff = mul(warp0, i32_val(spw[0]));
-  Value warpNOff = mul(warp1, i32_val(spw[1]));
-  // Quad offset
-  Value quadMOff = mul(udiv(and_(lane, _16), _4), i32_val(fpw[0]));
-  Value quadNOff = mul(udiv(and_(lane, _16), _4), i32_val(fpw[1]));
-  // Pair offset
-  Value pairMOff = udiv(urem(lane, _16), _4);
-  pairMOff = urem(pairMOff, i32_val(fpw[0]));
-  pairMOff = mul(pairMOff, _4);
-  Value pairNOff = udiv(urem(lane, _16), _4);
-  pairNOff = udiv(pairNOff, i32_val(fpw[0]));
-  pairNOff = urem(pairNOff, i32_val(fpw[1]));
-  pairNOff = mul(pairNOff, _4);
-  // scale
-  pairMOff = mul(pairMOff, i32_val(rep[0] / 2));
-  quadMOff = mul(quadMOff, i32_val(rep[0] / 2));
-  pairNOff = mul(pairNOff, i32_val(rep[1] / 2));
-  quadNOff = mul(quadNOff, i32_val(rep[1] / 2));
-  // Quad pair offset
-  Value laneMOff = add(pairMOff, quadMOff);
-  Value laneNOff = add(pairNOff, quadNOff);
-  // A offset
-  Value offsetAM = add(warpMOff, laneMOff);
-  Value offsetAK = and_(lane, _3);
-  // B offset
-  Value offsetBN = add(warpNOff, laneNOff);
-  Value offsetBK = and_(lane, _3);
-  // i indices
-  Value offsetCM = add(and_(lane, _1), offsetAM);
-  if (isARow) {
-    offsetAM = add(offsetAM, urem(threadId, _4));
-    offsetAK = i32_val(0);
-  }
-  if (!isBRow) {
-    offsetBN = add(offsetBN, urem(threadId, _4));
-    offsetBK = i32_val(0);
-  }
-
-  return std::make_tuple(offsetAM, offsetAK, offsetBN, offsetBK);
-}
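
// Scalar model of the mma.884 lane decomposition computed above, with
// fpw = {2, 2} and the rep scaling left out; purely illustrative. Each lane's
// M offset combines a quad component (lanes 16..31 shift by 8) and a pair
// component (0 or 4), so laneM always lands on a multiple of 4.
#include <cassert>

int main() {
  int fpw0 = 2;
  for (int tid = 0; tid < 32; ++tid) {
    int lane = tid % 32;
    int quadM = ((lane & 16) / 4) * fpw0;       // 0 or 8
    int pairM = (((lane % 16) / 4) % fpw0) * 4; // 0 or 4
    int laneM = quadM + pairM;
    assert(laneM % 4 == 0 && laneM <= 12);
  }
  return 0;
}
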
-
-DotOpMmaV1ConversionHelper::ValueTable
-DotOpMmaV1ConversionHelper::extractLoadedOperand(
-    Value llStruct, int n0, int n1, ConversionPatternRewriter &rewriter) const {
-  ValueTable rcds;
-  SmallVector<Value> elems =
-      ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct(
-          llStruct.getLoc(), llStruct, rewriter);
-
-  int offset = 0;
-  for (int i = 0; i < n0; ++i)
-    for (int k = 0; k < n1; k += 4) {
-      rcds[{i, k}] = std::make_pair(elems[offset], elems[offset + 1]);
-      offset += 2;
-    }
-
-  return rcds;
-}
-
-Value DotOpFMAConversionHelper::loadA(
-    Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, Location loc,
-    ConversionPatternRewriter &rewriter) const {
-  auto aTensorTy = A.getType().cast<RankedTensorType>();
-  auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
-  auto aShape = aTensorTy.getShape();
-
-  auto aOrder = aLayout.getOrder();
-  auto order = dLayout.getOrder();
-
-  bool isARow = aOrder[0] == 1;
-
-  int strideAM = isARow ? aShape[1] : 1;
-  int strideAK = isARow ? 1 : aShape[0];
-  int strideA0 = isARow ? strideAK : strideAM;
-  int strideA1 = isARow ? strideAM : strideAK;
-  int lda = isARow ? strideAM : strideAK;
-  int aNumPtr = 8;
-  int bNumPtr = 8;
-  int NK = aShape[1];
-
-  auto shapePerCTA = getShapePerCTA(dLayout);
-  auto sizePerThread = getSizePerThread(dLayout);
-
-  Value _0 = i32_val(0);
-
-  Value mContig = i32_val(sizePerThread[order[1]]);
-  Value nContig = i32_val(sizePerThread[order[0]]);
-
-  // threadId in blocked layout
-  auto threadIds = getThreadIds(thread, shapePerCTA, order, rewriter, loc);
-
-  Value threadIdM = threadIds[0];
-  Value threadIdN = threadIds[1];
-
-  Value offA0 = isARow ? _0 : mul(threadIdM, mContig);
-  Value offA1 = isARow ? mul(threadIdM, mContig) : _0;
-  SmallVector<Value> aOff(aNumPtr);
-  for (int i = 0; i < aNumPtr; ++i) {
-    aOff[i] = add(mul(offA0, i32_val(strideA0)), mul(offA1, i32_val(strideA1)));
-  }
-
-  auto aSmem =
-      ConvertTritonGPUOpToLLVMPatternBase::getSharedMemoryObjectFromStruct(
-          loc, llA, rewriter);
-
-  Type f32PtrTy = ptr_ty(f32_ty);
-  SmallVector<Value> aPtrs(aNumPtr);
-  for (int i = 0; i < aNumPtr; ++i)
-    aPtrs[i] = gep(f32PtrTy, aSmem.base, aOff[i]);
-
-  ValueTable has;
-  int M = aShape[aOrder[1]];
-
-  int mShapePerCTA = getShapePerCTAForMN(dLayout, true /*isM*/);
-  int mSizePerThread = getsizePerThreadForMN(dLayout, true /*isM*/);
-
-  for (unsigned k = 0; k < NK; ++k) {
-    for (unsigned m = 0; m < M; m += mShapePerCTA)
-      for (unsigned mm = 0; mm < mSizePerThread; ++mm)
-        if (!has.count({m + mm, k})) {
-          Value pa = gep(f32PtrTy, aPtrs[0],
                         i32_val((m + mm) * strideAM + k * strideAK));
-          Value va = load(pa);
-          has[{m + mm, k}] = va;
-        }
-  }
-
-  return getStructFromValueTable(has, rewriter, loc);
-}
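
// Illustrative enumeration of the rows one thread reads in the FMA loadA loop
// above: with M = 16, mShapePerCTA = 8 and mSizePerThread = 2, the loop visits
// m in {0, 8} and mm in {0, 1}, i.e. rows {0, 1, 8, 9} for every k. The
// numbers are hypothetical.
#include <cassert>
#include <vector>

int main() {
  int M = 16, mShapePerCTA = 8, mSizePerThread = 2;
  std::vector<int> rows;
  for (int m = 0; m < M; m += mShapePerCTA)
    for (int mm = 0; mm < mSizePerThread; ++mm)
      rows.push_back(m + mm);
  assert((rows == std::vector<int>{0, 1, 8, 9}));
  return 0;
}
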
-
-Value DotOpFMAConversionHelper::loadB(
-    Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, Location loc,
-    ConversionPatternRewriter &rewriter) const {
-
-  auto bTensorTy = B.getType().cast<RankedTensorType>();
-  auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
-  auto bShape = bTensorTy.getShape();
-
-  auto bOrder = bLayout.getOrder();
-  auto order = dLayout.getOrder();
-
-  bool isBRow = bOrder[0] == 1;
-
-  int strideBN = isBRow ? 1 : bShape[0];
-  int strideBK = isBRow ? bShape[1] : 1;
-  int strideB0 = isBRow ? strideBN : strideBK;
-  int strideB1 = isBRow ? strideBK : strideBN;
-  int ldb = isBRow ? strideBK : strideBN;
-  int bNumPtr = 8;
-  int NK = bShape[0];
-
-  auto shapePerCTA = getShapePerCTA(dLayout);
-  auto sizePerThread = getSizePerThread(dLayout);
-
-  Value _0 = i32_val(0);
-
-  Value mContig = i32_val(sizePerThread[order[1]]);
-  Value nContig = i32_val(sizePerThread[order[0]]);
-
-  // threadId in blocked layout
-  auto threadIds = getThreadIds(thread, shapePerCTA, order, rewriter, loc);
-  Value threadIdN = threadIds[1];
-
-  Value offB0 = isBRow ? mul(threadIdN, nContig) : _0;
-  Value offB1 = isBRow ? _0 : mul(threadIdN, nContig);
-  SmallVector<Value> bOff(bNumPtr);
-  for (int i = 0; i < bNumPtr; ++i) {
-    bOff[i] = add(mul(offB0, i32_val(strideB0)), mul(offB1, i32_val(strideB1)));
-  }
-
-  auto bSmem =
-      ConvertTritonGPUOpToLLVMPatternBase::getSharedMemoryObjectFromStruct(
-          loc, llB, rewriter);
-
-  Type f32PtrTy = ptr_ty(f32_ty);
-  SmallVector<Value> bPtrs(bNumPtr);
-  for (int i = 0; i < bNumPtr; ++i)
-    bPtrs[i] = gep(f32PtrTy, bSmem.base, bOff[i]);
-
-  int N = bShape[bOrder[0]];
-  ValueTable hbs;
-
-  int nShapePerCTA = getShapePerCTAForMN(dLayout, false /*isM*/);
-  int nSizePerThread = getsizePerThreadForMN(dLayout, false /*isM*/);
-
-  for (unsigned k = 0; k < NK; ++k)
-    for (unsigned n = 0; n < N; n += nShapePerCTA)
-      for (unsigned nn = 0; nn < nSizePerThread; ++nn) {
-        Value pb = gep(f32PtrTy, bPtrs[0],
                       i32_val((n + nn) * strideBN + k * strideBK));
-        Value vb = load(pb);
-        hbs[{n + nn, k}] = vb;
-      }
-
-  return getStructFromValueTable(hbs, rewriter, loc);
-}
-
-DotOpFMAConversionHelper::ValueTable
-DotOpFMAConversionHelper::getValueTableFromStruct(
-    Value val, int K, int n0, int shapePerCTA, int sizePerThread,
-    ConversionPatternRewriter &rewriter, Location loc) const {
-  ValueTable res;
-  auto elems = ConvertTritonGPUOpToLLVMPatternBase::getElementsFromStruct(
-      loc, val, rewriter);
-  int id = 0;
-  std::set<std::pair<int, int>> keys; // ordered
-  for (unsigned k = 0; k < K; ++k) {
-    for (unsigned m = 0; m < n0; m += shapePerCTA)
-      for (unsigned mm = 0; mm < sizePerThread; ++mm) {
-        keys.insert({m + mm, k});
-      }
-  }
-
-  for (auto &key : llvm::enumerate(keys)) {
-    res[key.value()] = elems[key.index()];
-  }
-
-  return res;
-}
-SmallVector<Value> DotOpFMAConversionHelper::getThreadIds(
-    Value threadId, ArrayRef<unsigned> shapePerCTA,
-    ArrayRef<unsigned> order, ConversionPatternRewriter &rewriter,
-    Location loc) const {
-  int dim = order.size();
-  SmallVector<Value> threadIds(dim);
-  for (unsigned k = 0; k < dim - 1; k++) {
-    Value dimK = i32_val(shapePerCTA[order[k]]);
-    Value rem = urem(threadId, dimK);
-    threadId = udiv(threadId, dimK);
-    threadIds[order[k]] = rem;
-  }
-  Value dimK = i32_val(shapePerCTA[order[dim - 1]]);
-  threadIds[order[dim - 1]] = urem(threadId, dimK);
-  return threadIds;
-}
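
// Worked example of the delinearization in getThreadIds above: with
// shapePerCTA = {4, 8} and order = {1, 0} (axis 1 fastest), thread 27 maps to
// ids = {(27 / 8) % 4, 27 % 8} = {3, 3}. Standalone and illustrative.
#include <cassert>
#include <vector>

int main() {
  std::vector<unsigned> shapePerCTA{4, 8}, order{1, 0};
  unsigned tid = 27;
  std::vector<unsigned> ids(2);
  ids[order[0]] = tid % shapePerCTA[order[0]]; // 27 % 8 == 3
  tid /= shapePerCTA[order[0]];                // 27 / 8 == 3
  ids[order[1]] = tid % shapePerCTA[order[1]]; // 3 % 4 == 3
  assert(ids[0] == 3 && ids[1] == 3);
  return 0;
}
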
"triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h" +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/FormatVariadic.h" +#include +#include + +// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive +#define inttoptr(...) rewriter.create(loc, __VA_ARGS__) +#define ptrtoint(...) rewriter.create(loc, __VA_ARGS__) +#define zext(...) rewriter.create(loc, __VA_ARGS__) +#define udiv(...) rewriter.create(loc, __VA_ARGS__) +#define urem(...) rewriter.create(loc, __VA_ARGS__) +#define add(...) rewriter.create(loc, __VA_ARGS__) +#define sub(...) rewriter.create(loc, __VA_ARGS__) +#define fadd(...) rewriter.create(loc, __VA_ARGS__) +#define mul(...) rewriter.create(loc, __VA_ARGS__) +#define smax(...) rewriter.create(loc, __VA_ARGS__) +#define umax(...) rewriter.create(loc, __VA_ARGS__) +#define fmax(...) rewriter.create(loc, __VA_ARGS__) +#define smin(...) rewriter.create(loc, __VA_ARGS__) +#define umin(...) rewriter.create(loc, __VA_ARGS__) +#define fmin(...) rewriter.create(loc, __VA_ARGS__) +#define and_(...) rewriter.create(loc, __VA_ARGS__) +#define xor_(...) rewriter.create(loc, __VA_ARGS__) +#define bitcast(val__, type__) \ + rewriter.create(loc, type__, val__) +#define gep(...) rewriter.create(loc, __VA_ARGS__) +#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__) +#define insert_val(...) rewriter.create(loc, __VA_ARGS__) +#define extract_val(...) rewriter.create(loc, __VA_ARGS__) +#define insert_element(...) \ + rewriter.create(loc, __VA_ARGS__) +#define extract_element(...) \ + rewriter.create(loc, __VA_ARGS__) +#define load(...) rewriter.create(loc, __VA_ARGS__) +#define store(val, ptr) rewriter.create(loc, val, ptr) +#define icmp_eq(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__) +#define icmp_ne(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__) +#define icmp_slt(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__) +#define select(...) rewriter.create(loc, __VA_ARGS__) +#define address_of(...) rewriter.create(loc, __VA_ARGS__) +#define barrier() rewriter.create(loc) +#define undef(...) rewriter.create(loc, __VA_ARGS__) +#define i32_ty rewriter.getIntegerType(32) +#define ui32_ty rewriter.getIntegerType(32, false) +#define f16_ty rewriter.getF16Type() +#define bf16_ty rewriter.getBF16Type() +#define i8_ty rewriter.getIntegerType(8) +#define f32_ty rewriter.getF32Type() +#define f64_ty rewriter.getF64Type() +#define vec_ty(type, num) VectorType::get(num, type) +#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__) +#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__) +#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx) +#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__) + +// Creator for constant +#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__) +#define int_val(width, val) \ + LLVM::createLLVMIntegerConstant(rewriter, loc, width, val) +#define idx_val(...) 
+
+namespace mlir {
+namespace LLVM {
+
+static Value getStructFromElements(Location loc, ValueRange resultVals,
+                                   ConversionPatternRewriter &rewriter,
+                                   Type structType) {
+  if (!structType.isa<LLVM::LLVMStructType>()) {
+    return *resultVals.begin();
+  }
+
+  Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
+  for (const auto &v : llvm::enumerate(resultVals)) {
+    assert(v.value() && "can not insert null values");
+    llvmStruct = insert_val(structType, llvmStruct, v.value(),
+                            rewriter.getI64ArrayAttr(v.index()));
+  }
+  return llvmStruct;
+}
+
+static SmallVector<Value>
+getElementsFromStruct(Location loc, Value llvmStruct,
+                      ConversionPatternRewriter &rewriter) {
+  if (llvmStruct.getType().isIntOrIndexOrFloat() ||
+      llvmStruct.getType().isa<triton::PointerType>() ||
+      llvmStruct.getType().isa<LLVM::LLVMPointerType>())
+    return {llvmStruct};
+  ArrayRef<Type> types =
+      llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
+  SmallVector<Value> results(types.size());
+  for (unsigned i = 0; i < types.size(); ++i) {
+    Type type = types[i];
+    results[i] = extract_val(type, llvmStruct, rewriter.getI64ArrayAttr(i));
+  }
+  return results;
+}
+
+namespace {
+using namespace mlir::triton;
+
+// Create a 32-bit integer constant.
+Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
+  auto i32ty = rewriter.getIntegerType(32);
+  return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
+                                           IntegerAttr::get(i32ty, v));
+}
+
+Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) {
+  auto type = type::f32Ty(rewriter.getContext());
+  return rewriter.create<LLVM::ConstantOp>(loc, type,
+                                           rewriter.getF32FloatAttr(v));
+}
+
+Value createConstantF64(Location loc, PatternRewriter &rewriter, double v) {
+  auto type = type::f64Ty(rewriter.getContext());
+  return rewriter.create<LLVM::ConstantOp>(loc, type,
+                                           rewriter.getF64FloatAttr(v));
+}
+
+// Create an index type constant.
+Value createIndexConstant(OpBuilder &builder, Location loc,
+                          TypeConverter *converter, int64_t value) {
+  Type ty = converter->convertType(builder.getIndexType());
+  return builder.create<LLVM::ConstantOp>(loc, ty,
+                                          builder.getIntegerAttr(ty, value));
+}
+
+// Create an integer constant of \param width bits.
+Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
+                                int64_t value) {
+  Type ty = builder.getIntegerType(width);
+  return builder.create<LLVM::ConstantOp>(loc, ty,
+                                          builder.getIntegerAttr(ty, value));
+}
+
+} // namespace
+
+/// Helper function to get strides from a given shape and its order
+static SmallVector<Value>
+getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
+                            Location loc, ConversionPatternRewriter &rewriter) {
+  auto rank = shape.size();
+  SmallVector<Value> strides(rank);
+  auto stride = 1;
+  for (auto idx : order) {
+    strides[idx] = i32_val(stride);
+    stride *= shape[idx];
+  }
+  return strides;
+}
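
// Round-trip sketch (illustrative, not from the pass): getStructFromElements
// and getElementsFromStruct are inverses for a given struct type, which is
// what lets a lowering pattern move a bundle of Values through one
// struct-typed Value. `vals` is a hypothetical SmallVector<Value> of two f32s.
//
//   Type structTy = struct_ty(SmallVector<Type>(2, f32_ty));
//   Value packed = getStructFromElements(loc, vals, rewriter, structTy);
//   SmallVector<Value> unpacked = getElementsFromStruct(loc, packed, rewriter);
//   // unpacked[i] is extract_val(packed, i), so unpacked mirrors vals.
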
+
+struct SharedMemoryObject {
+  Value base; // i32 ptr. The start address of the shared memory object.
+  // We need to store strides as Values but not integers because the
+  // extract_slice instruction can take a slice at arbitrary offsets.
+  // Take $a[16:32, 16:32] as an example, though we know the stride of $a[0] is
+  // 32, we need to let the instruction that uses $a to be aware of that.
+  // Otherwise, when we use $a, we only know that the shape of $a is 16x16. If
+  // we store strides into an attribute array of integers, the information
+  // cannot pass through block argument assignment because attributes are
+  // associated with operations but not Values.
+  // TODO(Keren): We may need to figure out a way to store strides as integers
+  // if we want to support more optimizations.
+  SmallVector<Value>
+      strides; // i32 int. The strides of the shared memory object.
+  SmallVector<Value> offsets; // i32 int. The offsets of the shared memory
+                              // objects from the originally allocated object.
+
+  SharedMemoryObject(Value base, ArrayRef<Value> strides,
+                     ArrayRef<Value> offsets)
+      : base(base), strides(strides.begin(), strides.end()),
+        offsets(offsets.begin(), offsets.end()) {}
+
+  SharedMemoryObject(Value base, ArrayRef<int64_t> shape,
+                     ArrayRef<unsigned> order, Location loc,
+                     ConversionPatternRewriter &rewriter)
+      : base(base) {
+    strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter);
+
+    for (auto idx : order) {
+      offsets.emplace_back(i32_val(0));
+    }
+  }
+
+  SmallVector<Value> getElems() const {
+    SmallVector<Value> elems;
+    elems.push_back(base);
+    elems.append(strides.begin(), strides.end());
+    elems.append(offsets.begin(), offsets.end());
+    return elems;
+  }
+
+  SmallVector<Type> getTypes() const {
+    SmallVector<Type> types;
+    types.push_back(base.getType());
+    types.append(strides.size(), IntegerType::get(base.getContext(), 32));
+    types.append(offsets.size(), IntegerType::get(base.getContext(), 32));
+    return types;
+  }
+
+  Value getCSwizzleOffset(int order) const {
+    assert(order >= 0 && order < strides.size());
+    return offsets[order];
+  }
+
+  Value getBaseBeforeSwizzle(int order, Location loc,
+                             ConversionPatternRewriter &rewriter) const {
+    Value cSwizzleOffset = getCSwizzleOffset(order);
+    Value offset = sub(i32_val(0), cSwizzleOffset);
+    Type type = base.getType();
+    return gep(type, base, offset);
+  }
+};
+
+static SharedMemoryObject
+getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
+                                ConversionPatternRewriter &rewriter) {
+  auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
+  auto rank = (elems.size() - 1) / 2;
+  return {/*base=*/elems[0],
+          /*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
+          /*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
+}
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif
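
// Element-packing sketch for SharedMemoryObject: getElems() serializes the
// object as [base, stride_0..stride_{r-1}, offset_0..offset_{r-1}], so the
// struct carries 2*r + 1 members and getSharedMemoryObjectFromStruct recovers
// r as (elems.size() - 1) / 2. Standalone check of that index arithmetic:
#include <cassert>
#include <cstddef>

int main() {
  for (std::size_t rank = 1; rank <= 4; ++rank) {
    std::size_t numElems = 1 + 2 * rank; // base + strides + offsets
    assert((numElems - 1) / 2 == rank);
    std::size_t offsetsBegin = 1 + rank; // strides occupy [1, 1 + rank)
    assert(offsetsBegin + rank == numElems);
  }
  return 0;
}
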