Fix v100 fp32
@@ -790,18 +790,16 @@ public:
     if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
       return failure();
 
-    auto A = dotOp.getOperand(0).getType().cast<RankedTensorType>();
-    auto B = dotOp.getOperand(1).getType().cast<RankedTensorType>();
+    int version = computeCapabilityToMMAVersion(computeCapability);
+
     // for FMA, should retain the blocked layout.
-    if (A.getElementType().isF32() && B.getElementType().isF32() &&
-        !dotOp.allowTF32())
+    if (!supportMMA(dotOp, version))
       return failure();
 
     // get MMA encoding for the given number of warps
     auto retShape = oldRetType.getShape();
     auto mod = op->getParentOfType<mlir::ModuleOp>();
     int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
-    int version = computeCapabilityToMMAVersion(computeCapability);
 
     auto newRetType = RankedTensorType::get(
         retShape, oldRetType.getElementType(),
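For context: before this change, an fp32 dot kept the blocked (FMA) layout only when both operands were fp32 and allowTF32 was off, so an fp32 dot with allowTF32 set could still be rewritten to an MMA layout on V100, whose MMA generation (v1) has no fp32/TF32 path. Computing the MMA version up front and asking a single supportMMA(dotOp, version) predicate closes that hole.

The body of supportMMA is not part of this hunk. Below is a minimal sketch of what such a predicate could look like, assuming it simply folds the removed inline element-type test together with a version gate; the version >= 2 threshold and the overall structure are assumptions inferred from the removed lines and the commit title, not the actual Triton implementation.

// Hypothetical sketch only -- the real supportMMA body is not shown in
// this diff. Assumes the same headers and types as the surrounding pass
// file (RankedTensorType, triton::DotOp, etc.).
static bool supportMMA(triton::DotOp dotOp, int version) {
  auto aElemTy = dotOp.getOperand(0)
                     .getType()
                     .cast<RankedTensorType>()
                     .getElementType();
  auto bElemTy = dotOp.getOperand(1)
                     .getType()
                     .cast<RankedTensorType>()
                     .getElementType();
  // fp32 x fp32 dots can reach tensor cores only via TF32, which needs
  // MMA v2 (sm_80+); on V100 (MMA v1) they must keep the FMA path.
  if (aElemTy.isF32() && bElemTy.isF32())
    return dotOp.allowTF32() && version >= 2;
  // Other element types are assumed MMA-capable in this sketch; the real
  // helper likely has per-type, per-version checks.
  return true;
}

Hoisting the version computation above the early-exit (the moved computeCapabilityToMMAVersion line in the hunk) is what lets the predicate take the MMA version as an argument, so one check covers both the FMA fallback and hardware generations that lack the required MMA support.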