Fix v100 fp32
@@ -3336,9 +3336,8 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
          "Unsupported MMA kind found when converting DotOp to LLVM.");
    }

    if (op.getType().cast<RankedTensorType>().getElementType().isF32() &&
        A.getType().cast<RankedTensorType>().getElementType().isF32() &&
        !op.allowTF32())
    // XXX: fp64 has not been tested yet. In theory, it should work.
    if (!isMMA)
      return convertFMADot(op, adaptor, rewriter);

    llvm::report_fatal_error(
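For context: the guard spelled out in the old lines above says that a dot whose result and A operand are fp32, but which does not allow TF32, cannot use tensor cores and must take the FMA lowering (convertFMADot). Below is a minimal standalone C++ sketch of that decision; ElemType, DotDesc, and needsFMAFallback are illustrative names for this sketch, not Triton/MLIR APIs.

#include <cstdio>

// Illustrative stand-ins for the element types involved.
enum class ElemType { F16, BF16, F32, F64, I8 };

// Hypothetical, simplified description of a dot op: the element types of
// operand A and result D, plus the allowTF32 flag carried by the op.
struct DotDesc {
  ElemType aElem;
  ElemType dElem;
  bool allowTF32;
};

// fp32 inputs without TF32 cannot go down the tensor-core path, so the
// conversion must fall back to the FMA lowering.
bool needsFMAFallback(const DotDesc &dot) {
  return dot.dElem == ElemType::F32 && dot.aElem == ElemType::F32 &&
         !dot.allowTF32;
}

int main() {
  DotDesc noTF32{ElemType::F32, ElemType::F32, /*allowTF32=*/false};
  DotDesc withTF32{ElemType::F32, ElemType::F32, /*allowTF32=*/true};
  std::printf("fp32, tf32 off -> FMA fallback: %d\n", needsFMAFallback(noTF32));   // 1
  std::printf("fp32, tf32 on  -> FMA fallback: %d\n", needsFMAFallback(withTF32)); // 0
  return 0;
}

On the new path shown in the diff, the dispatch only tests !isMMA before falling back, so this fp32/allowTF32 condition no longer needs to be spelled out inline at the call site.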
@@ -3348,33 +3347,16 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
  // Tell whether a DotOp support HMMA.
  // This is port from the master branch, the original logic is retained.
  static bool isDotHMMA(DotOp op) {
    auto a = op.a();
    auto b = op.b();
    auto c = op.c();
    auto d = op.getResult();
    auto aTensorTy = a.getType().cast<RankedTensorType>();
    auto bTensorTy = b.getType().cast<RankedTensorType>();
    auto cTensorTy = c.getType().cast<RankedTensorType>();
    auto dTensorTy = d.getType().cast<RankedTensorType>();

    if (!dTensorTy.getEncoding().isa<MmaEncodingAttr>())
      return false;

    auto mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
    auto aElemTy = aTensorTy.getElementType();
    auto bElemTy = bTensorTy.getElementType();

    assert((mmaLayout.getVersion() == 1 || mmaLayout.getVersion() == 2) &&
           "Unexpected MMA layout version found");
    // Refer to mma section for the data type supported by Volta and Hopper
    // Tensor Core in
    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
    return (aElemTy.isF16() && bElemTy.isF16()) ||
           (aElemTy.isBF16() && bElemTy.isBF16()) ||
           (aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() &&
            mmaLayout.getVersion() >= 2) ||
           (aElemTy.isInteger(8) && bElemTy.isInteger(8) &&
            mmaLayout.getVersion() >= 2);
    return supportMMA(op, mmaLayout.getVersion());
  }
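The data-type/version rules that isDotHMMA encodes above (and that the new code centralizes behind supportMMA) boil down to a small predicate. Below is a standalone sketch of the same truth table, assuming versions 1 and 2 are the two MMA layout generations checked by the assert; dotSupportsHMMA is an illustrative name for this sketch, not the actual supportMMA signature.

#include <cassert>
#include <cstdio>

enum class ElemType { F16, BF16, F32, I8 };

// Mirrors the return expression of isDotHMMA: which operand element types
// can use the tensor-core (HMMA) path for a given MMA layout version.
bool dotSupportsHMMA(ElemType a, ElemType b, bool allowTF32, int mmaVersion) {
  assert((mmaVersion == 1 || mmaVersion == 2) &&
         "Unexpected MMA layout version found");
  return (a == ElemType::F16 && b == ElemType::F16) ||
         (a == ElemType::BF16 && b == ElemType::BF16) ||
         (a == ElemType::F32 && b == ElemType::F32 && allowTF32 &&
          mmaVersion >= 2) ||
         (a == ElemType::I8 && b == ElemType::I8 && mmaVersion >= 2);
}

int main() {
  // f16 x f16 is accepted by both MMA versions.
  std::printf("%d\n", dotSupportsHMMA(ElemType::F16, ElemType::F16, false, 1)); // 1
  // f32 x f32 needs allowTF32 and MMA version >= 2, which is why an fp32 dot
  // on V100-class hardware (version 1) has to take the FMA path instead.
  std::printf("%d\n", dotSupportsHMMA(ElemType::F32, ElemType::F32, true, 1));  // 0
  std::printf("%d\n", dotSupportsHMMA(ElemType::F32, ElemType::F32, true, 2));  // 1
  return 0;
}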
  // Tell whether a DotOp support HMMA by the operand type(either $a or $b).