diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 08b547468..3d4148e7b 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -2642,20 +2642,21 @@ public:
     Value elems[4];
     Type elemTy = type::f32Ty(ctx);
+    Type elemPtrTy = ptr_ty(elemTy);
     if (kOrder == 1) {
-      elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
-      elems[1] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
+      elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
+      elems[1] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
       elems[2] =
-          load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
       elems[3] =
-          load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
     } else {
-      elems[0] = load(gep(elemTy, ptr, i32_val(sOffsetElem)));
-      elems[2] = load(gep(elemTy, ptr2, i32_val(sOffsetElem)));
+      elems[0] = load(gep(elemPtrTy, ptr, i32_val(sOffsetElem)));
+      elems[2] = load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem)));
       elems[1] =
-          load(gep(elemTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr, i32_val(sOffsetElem + sOffsetArrElem)));
       elems[3] =
-          load(gep(elemTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
+          load(gep(elemPtrTy, ptr2, i32_val(sOffsetElem + sOffsetArrElem)));
     }
 
     return {elems[0], elems[1], elems[2], elems[3]};
@@ -2799,6 +2800,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
     size_t reduceAxis = 1;
     unsigned K = AShape[reduceAxis];
     bool isOuter = K == 1;
+
     bool isMMA = D.getType()
                      .cast<RankedTensorType>()
                      .getEncoding()
@@ -2810,11 +2812,13 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
                          .getEncoding()
                          .cast<MmaEncodingAttr>();
 
-    if (!isOuter && isMMA) {
+    bool isHMMA = isDotHMMA(op);
+    if (!isOuter && isMMA && isHMMA) {
       if (mmaLayout.getVersion() == 1)
         return convertMMA884(op, adaptor, rewriter);
       if (mmaLayout.getVersion() == 2)
         return convertMMA16816(op, adaptor, rewriter);
+
       llvm::report_fatal_error(
           "Unsupported MMA kind found when converting DotOp to LLVM.");
     }
@@ -2827,6 +2831,46 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
         "Unsupported DotOp found when converting TritonGPU to LLVM.");
   }
 
+  // Tell whether a DotOp supports HMMA.
+  // This is ported from the master branch; the original logic is retained.
+  static bool isDotHMMA(DotOp op) {
+    auto a = op.a();
+    auto b = op.b();
+    auto c = op.c();
+    auto d = op.getResult();
+    auto aTensorTy = a.getType().cast<RankedTensorType>();
+    auto bTensorTy = b.getType().cast<RankedTensorType>();
+    auto cTensorTy = c.getType().cast<RankedTensorType>();
+    auto dTensorTy = d.getType().cast<RankedTensorType>();
+
+    if (!dTensorTy.getEncoding().isa<MmaEncodingAttr>())
+      return false;
+
+    auto mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
+    auto aElemTy = aTensorTy.getElementType();
+    auto bElemTy = bTensorTy.getElementType();
+    // Refer to the mma section for the data types supported by Volta and
+    // Hopper Tensor Cores in
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
+    return (aElemTy.isF16() && bElemTy.isF16()) ||
+           (aElemTy.isBF16() && bElemTy.isBF16()) ||
+           (aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() &&
+            mmaLayout.getVersion() >= 2) ||
+           (aElemTy.isInteger(8) && bElemTy.isInteger(8) &&
+            mmaLayout.getVersion() >= 2);
+  }
+
+  // Tell whether a DotOp supports HMMA from one operand type (either $a or $b).
+  // We cannot get both operand types here (in the TypeConverter), so we
+  // assume the types of both operands are identical.
+  // TODO[Superjomn]: Find a better way to implement it.
+  static bool isDotHMMA(TensorType operand, bool allowTF32, int mmaVersion) {
+    auto elemTy = operand.getElementType();
+    return elemTy.isF16() || elemTy.isBF16() ||
+           (elemTy.isF32() && allowTF32 && mmaVersion >= 2) ||
+           (elemTy.isInteger(8) && mmaVersion >= 2);
+  }
+
 private:
   // Convert to mma.m16n8k16
   LogicalResult convertMMA16816(triton::DotOp a, OpAdaptor adaptor,
@@ -3590,6 +3634,73 @@ private:
   }
 };
 
+// Helper for FMADot conversion.
+class DotOpFMAConversionHelper {
+public:
+  MmaEncodingAttr mmaLayout;
+  ArrayRef<unsigned> wpt;
+
+  using ValueTable = std::map<std::pair<int, int>, Value>;
+
+  explicit DotOpFMAConversionHelper(MmaEncodingAttr mmaLayout)
+      : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}
+
+  // Currently, we can only tell whether to use FMADot from the operand type,
+  // while in the original code FMADot requires that both the operands and the
+  // result of dot are fp32.
+  // This method should be safe to use in cases where tensor cores are not
+  // applicable.
+  static bool useFMA(TensorType operand) {
+    return operand.getElementType().isF32();
+  }
+
+  SmallVector<unsigned> getOrder() const {
+    SmallVector<unsigned> order(2);
+    if (mmaLayout.getVersion() == 1)
+      order = {0, 1};
+    else if (mmaLayout.getVersion() == 0)
+      order = {1, 0};
+    else {
+      assert(false && "Unexpected MMA version found.");
+    }
+    return order;
+  }
+
+  Value loadA(Value tensor, Value llTensor, Value threadId, Location loc,
+              Value smem, ConversionPatternRewriter &rewriter) const {
+    auto *ctx = rewriter.getContext();
+    auto tensorTy = tensor.getType().cast<RankedTensorType>();
+    auto aShape = tensorTy.getShape();
+    auto aLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();
+    auto aOrder = aLayout.getOrder();
+    auto order = getOrder();
+
+    bool isARow = aOrder[0] == 1;
+
+    int strideAM = isARow ? aShape[1] : 1;
+    int strideAK = isARow ? 1 : aShape[0];
+    int strideA0 = isARow ? strideAK : strideAM;
+    int strideA1 = isARow ? strideAM : strideAK;
+    int lda = isARow ? strideAM : strideAK;
+    int aPerPhase = aLayout.getPerPhase();
+    int aMaxPhase = aLayout.getMaxPhase();
+    int aNumPtr = 8;
+    int bNumPtr = 8;
+    int aVec = 2;
+    int NK = aShape[isARow ? 1 : 0];
+
+    // TODO: materialize the pointers and emit the loads; stub for now.
+    return Value{};
+  }
+
+  Value loadB(Value tensor, Value llTensor, Value threadId, Location loc,
+              Value smem, ConversionPatternRewriter &rewriter) const {
+    // TODO: stub for now.
+    return Value{};
+  }
+
+  ValueTable extractLoadedOperand(Value llTensor) const { return ValueTable{}; }
+};
+
 LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
     triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -3605,8 +3716,24 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
       dotOperandLayout.getParent().dyn_cast_or_null<MmaEncodingAttr>();
   assert(mmaLayout);
 
+  bool isOuter{};
+  {
+    int K{};
+    if (dotOperandLayout.getOpIdx() == 0) // $a
+      K = dstTensorTy.getShape()[1];
+    else // $b
+      K = dstTensorTy.getShape()[0];
+    isOuter = K == 1;
+  }
+
+  // TODO[Superjomn]: allowTF32 is not available in ConvertLayoutOp since it
+  // is an attribute of DotOp.
+  bool allowTF32 = false;
+  bool isHMMA = DotOpConversion::isDotHMMA(dstTensorTy, allowTF32,
+                                           mmaLayout.getVersion());
+
   Value res;
-  if (mmaLayout.getVersion() == 2) {
+  if (!isOuter && mmaLayout.getVersion() == 2 && isHMMA) { // tensor core v2
     MMA16816ConversionHelper mmaHelper(mmaLayout, getThreadId(rewriter, loc),
                                        rewriter, getTypeConverter(),
                                        op.getLoc());
@@ -3618,7 +3745,8 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
       // operand $b
       res = mmaHelper.loadB(src, adaptor.src());
     }
-  } else if (mmaLayout.getVersion() == 1) {
+  } else if (!isOuter && mmaLayout.getVersion() == 1 &&
+             isHMMA) { // tensor core v1
     DotOpMmaV1ConversionHelper helper(mmaLayout);
     if (dotOperandLayout.getOpIdx() == 0) {
       // operand $a
@@ -3629,6 +3757,8 @@ LogicalResult ConvertLayoutOpConversion::lowerSharedToDotOperand(
       res = helper.loadB(src, adaptor.src(), getThreadId(rewriter, loc),
                          adaptor.src(), loc, rewriter);
     }
+  } else if (DotOpFMAConversionHelper::useFMA(dstTensorTy)) { // fmadot
+    // TODO: FMA operand lowering is not implemented yet.
   } else {
     assert(false && "Unsupported mma layout found");
   }
@@ -4184,9 +4314,11 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
   for (int i = 0; i < bNumPtr; ++i)
     bPtrs[i] = gep(f32PtrTy, adaptor.b(), bOff[i]);
 
+  // TODO: initialize ret with the matching elements of $c instead of
+  // broadcasting cc[0].
   std::map<std::pair<int, int>, Value> has, hbs;
-  // TODO initialize ret with zeros.
-  SmallVector<Value> ret(NK);
+  auto cc = getElementsFromStruct(loc, adaptor.c(), rewriter);
+  SmallVector<Value> ret(cShape[0] * cShape[1], cc[0]);
+
   for (unsigned k = 0; k < NK; k++) {
     int z = 0;
     for (unsigned i = 0; i < cShape[order[1]]; i += cShapePerCTA[order[1]])
@@ -4203,14 +4335,15 @@ DotOpConversion::convertFMADot(triton::DotOp op, OpAdaptor adaptor,
             Value va = load(pa);
             has[{m + mm, k}] = va;
           }
-          if (!has.count({n + nn, k})) {
+          if (!hbs.count({n + nn, k})) {
             Value pb = gep(f32PtrTy, bPtrs[0],
                            i32_val((n + nn) * strideBN + k * strideBK));
             Value vb = load(pb);
-            has[{n + nn, k}] = vb;
+            hbs[{n + nn, k}] = vb;
           }
-          ret[z++] = rewriter.create<LLVM::FMulAddOp>(
-              loc, has[{m + mm, k}], hbs[{n + nn, k}], ret[z]);
+          ret[z] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m + mm, k}],
+                                                    hbs[{n + nn, k}], ret[z]);
+          ++z;
         }
       }
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 807bd1396..eae06074a 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -789,3 +789,19 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
     return
   }
 }
+
+// -----
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#mma = #triton_gpu.mma<{version = 2, warpsPerCTA = [2, 2]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+  func @matmul_fmadot(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
+                      %a: tensor<32x16xf32, #shared>, %b: tensor<16x32xf32, #shared>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
+    // CHECK: llvm.intr.fmuladd
+    %28 = tt.dot %a, %b, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #shared> * tensor<16x32xf32, #shared> -> tensor<32x32xf32, #mma>
+    return
+  }
+}
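Reviewer note (illustration only, not part of the patch): the loop nest in convertFMADot above is a plain FMA inner product. For each k, every output element ret[z] accumulates has[{m, k}] * hbs[{n, k}] through llvm.intr.fmuladd, which is why ret must be seeded from $c and why z is advanced only after the fused multiply-add has read ret[z]. The standalone C++ sketch below mirrors that indexing on plain arrays; the 4x4x8 tile shape and the seed value are hypothetical stand-ins for the real $a/$b/$c values.

#include <cmath>
#include <cstdio>

// Standalone sketch of the convertFMADot accumulation scheme.
// M, N, K are hypothetical tile sizes; A, B, and ret stand in for the
// loaded $a/$b elements and the accumulator initialized from $c.
int main() {
  const int M = 4, N = 4, K = 8;
  float A[M][K], B[N][K], ret[M * N];

  for (int m = 0; m < M; ++m)
    for (int k = 0; k < K; ++k)
      A[m][k] = 1.0f;
  for (int n = 0; n < N; ++n)
    for (int k = 0; k < K; ++k)
      B[n][k] = 0.5f;
  for (int z = 0; z < M * N; ++z)
    ret[z] = 2.0f; // plays the role of seeding ret with $c

  for (int k = 0; k < K; ++k) {
    int z = 0; // flat index over the output tile, reset per k-step
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n) {
        // Mirrors ret[z] = fmuladd(a, b, ret[z]); ++z; in the patch.
        ret[z] = std::fma(A[m][k], B[n][k], ret[z]);
        ++z;
      }
  }

  // Each element ends up as c + K * 1.0 * 0.5 = 2.0 + 4.0.
  printf("ret[0] = %g\n", ret[0]); // prints 6
  return 0;
}

Any C++11 compiler runs this as written; the only behavioral point it makes is the accumulation order and the pre-seeded accumulator, not the pointer arithmetic of the real lowering.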