diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index 852e54532..20be101c2 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -43,6 +43,8 @@ bool maybeSharedAllocationOp(Operation *op); bool maybeAliasOp(Operation *op); +bool supportMMA(triton::DotOp op, int version); + std::string getValueOperandName(Value value, AsmState &state); template diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 57b3cfbfa..cbd5defa5 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -110,6 +110,19 @@ bool maybeAliasOp(Operation *op) { isa(op); } +bool supportMMA(triton::DotOp op, int version) { + // Refer to mma section for the data type supported by Volta and Hopper + // Tensor Core in + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + auto aElemTy = op.a().getType().cast<RankedTensorType>().getElementType(); + auto bElemTy = op.b().getType().cast<RankedTensorType>().getElementType(); + return (aElemTy.isF16() && bElemTy.isF16()) || + (aElemTy.isBF16() && bElemTy.isBF16()) || + (aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() && + version >= 2) || + (aElemTy.isInteger(8) && bElemTy.isInteger(8) && version >= 2); +} + std::string getValueOperandName(Value value, AsmState &state) { std::string opName; llvm::raw_string_ostream ss(opName); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index f6d158625..ce59cb214 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -3336,9 +3336,8 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> { "Unsupported MMA kind found when converting DotOp to LLVM."); } - if (op.getType().cast<RankedTensorType>().getElementType().isF32() && - A.getType().cast<RankedTensorType>().getElementType().isF32() && - !op.allowTF32()) + // XXX: fp64 has not been tested yet. In theory, it should work. 
+ if (!isMMA) return convertFMADot(op, adaptor, rewriter); llvm::report_fatal_error( @@ -3348,33 +3347,16 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> { // Tell whether a DotOp support HMMA. // This is port from the master branch, the original logic is retained. static bool isDotHMMA(DotOp op) { - auto a = op.a(); - auto b = op.b(); - auto c = op.c(); auto d = op.getResult(); - auto aTensorTy = a.getType().cast<RankedTensorType>(); - auto bTensorTy = b.getType().cast<RankedTensorType>(); - auto cTensorTy = c.getType().cast<RankedTensorType>(); auto dTensorTy = d.getType().cast<RankedTensorType>(); if (!dTensorTy.getEncoding().isa<MmaEncodingAttr>()) return false; auto mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>(); - auto aElemTy = aTensorTy.getElementType(); - auto bElemTy = bTensorTy.getElementType(); - assert((mmaLayout.getVersion() == 1 || mmaLayout.getVersion() == 2) && "Unexpected MMA layout version found"); - // Refer to mma section for the data type supported by Volta and Hopper - // Tensor Core in - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 - return (aElemTy.isF16() && bElemTy.isF16()) || - (aElemTy.isBF16() && bElemTy.isBF16()) || - (aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() && - mmaLayout.getVersion() >= 2) || - (aElemTy.isInteger(8) && bElemTy.isInteger(8) && - mmaLayout.getVersion() >= 2); + return supportMMA(op, mmaLayout.getVersion()); } // Tell whether a DotOp support HMMA by the operand type(either $a or $b). diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 3bc68feba..193a1faa8 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -790,18 +790,16 @@ public: if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>()) return failure(); - auto A = dotOp.getOperand(0).getType().cast<RankedTensorType>(); - auto B = dotOp.getOperand(1).getType().cast<RankedTensorType>(); + int version = computeCapabilityToMMAVersion(computeCapability); + // for FMA, should retain the blocked layout. 
- if (A.getElementType().isF32() && B.getElementType().isF32() && - !dotOp.allowTF32()) + if (!supportMMA(dotOp, version)) return failure(); // get MMA encoding for the given number of warps auto retShape = oldRetType.getShape(); auto mod = op->getParentOfType<ModuleOp>(); int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); - int version = computeCapabilityToMMAVersion(computeCapability); auto newRetType = RankedTensorType::get( retShape, oldRetType.getElementType(),