Fix V100 fp32

This commit is contained in:
Jokeren
2022-12-12 15:52:16 -08:00
parent 3a1c140385
commit d8d6b9f3f1
4 changed files with 21 additions and 26 deletions

View File

@@ -790,18 +790,16 @@ public:
if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
return failure();
auto A = dotOp.getOperand(0).getType().cast<RankedTensorType>();
auto B = dotOp.getOperand(1).getType().cast<RankedTensorType>();
int version = computeCapabilityToMMAVersion(computeCapability);
// for FMA, should retain the blocked layout.
if (A.getElementType().isF32() && B.getElementType().isF32() &&
!dotOp.allowTF32())
if (!supportMMA(dotOp, version))
return failure();
// get MMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
int version = computeCapabilityToMMAVersion(computeCapability);
auto newRetType = RankedTensorType::get(
retShape, oldRetType.getElementType(),