Fix v100 fp32
@@ -43,6 +43,8 @@ bool maybeSharedAllocationOp(Operation *op);
 
 bool maybeAliasOp(Operation *op);
 
+bool supportMMA(triton::DotOp op, int version);
+
 std::string getValueOperandName(Value value, AsmState &state);
 
 template <typename T_OUT, typename T_IN>
@@ -110,6 +110,19 @@ bool maybeAliasOp(Operation *op) {
          isa<tensor::InsertSliceOp>(op);
 }
 
+bool supportMMA(triton::DotOp op, int version) {
+  // Refer to mma section for the data type supported by Volta and Hopper
+  // Tensor Core in
+  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
+  auto aElemTy = op.a().getType().cast<RankedTensorType>().getElementType();
+  auto bElemTy = op.b().getType().cast<RankedTensorType>().getElementType();
+  return (aElemTy.isF16() && bElemTy.isF16()) ||
+         (aElemTy.isBF16() && bElemTy.isBF16()) ||
+         (aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() &&
+          version >= 2) ||
+         (aElemTy.isInteger(8) && bElemTy.isInteger(8) && version >= 2);
+}
+
 std::string getValueOperandName(Value value, AsmState &state) {
   std::string opName;
   llvm::raw_string_ostream ss(opName);
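
Note: supportMMA is the single predicate the rest of this commit funnels through. The two `version >= 2` terms are what fix V100: fp32 (via TF32) and int8 MMA require version 2, while V100's tensor cores are version 1. Below is a standalone sketch of the same decision table; ElemTy and supportMMASketch are hypothetical stand-ins for the MLIR types used in the hunk above.

    // Sketch only: a plain enum stands in for the MLIR element types.
    #include <cassert>

    enum class ElemTy { F16, BF16, F32, I8 };

    bool supportMMASketch(ElemTy a, ElemTy b, bool allowTF32, int version) {
      return (a == ElemTy::F16 && b == ElemTy::F16) ||
             (a == ElemTy::BF16 && b == ElemTy::BF16) ||
             (a == ElemTy::F32 && b == ElemTy::F32 && allowTF32 && version >= 2) ||
             (a == ElemTy::I8 && b == ElemTy::I8 && version >= 2);
    }

    int main() {
      // fp32 dot on MMA version 1 (V100) is rejected -> FMA fallback.
      assert(!supportMMASketch(ElemTy::F32, ElemTy::F32, true, 1));
      // The same dot on version 2 (Ampere) may use TF32 MMA.
      assert(supportMMASketch(ElemTy::F32, ElemTy::F32, true, 2));
      // fp16 is accepted at either version.
      assert(supportMMASketch(ElemTy::F16, ElemTy::F16, false, 1));
    }
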
@@ -3336,9 +3336,8 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
           "Unsupported MMA kind found when converting DotOp to LLVM.");
     }
 
-    if (op.getType().cast<RankedTensorType>().getElementType().isF32() &&
-        A.getType().cast<RankedTensorType>().getElementType().isF32() &&
-        !op.allowTF32())
+    // XXX: fp64 has not been tested yet. In theory, it should work.
+    if (!isMMA)
       return convertFMADot(op, adaptor, rewriter);
 
     llvm::report_fatal_error(
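
Note: the FMA fallback no longer re-derives the answer from the element types; it keys off a single isMMA flag, so the lowering follows whatever layout the pass below picked. isMMA is defined earlier in the surrounding matchAndRewrite, outside this hunk; a plausible definition, consistent with isDotHMMA further down, would be (hypothetical, since the diff does not show it):

    // Hypothetical: how isMMA is presumably computed earlier in
    // matchAndRewrite, mirroring the encoding check in isDotHMMA below.
    bool isMMA = op.getResult()
                     .getType()
                     .cast<RankedTensorType>()
                     .getEncoding()
                     .isa<MmaEncodingAttr>();
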
@@ -3348,33 +3347,16 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
 // Tell whether a DotOp support HMMA.
 // This is port from the master branch, the original logic is retained.
 static bool isDotHMMA(DotOp op) {
-  auto a = op.a();
-  auto b = op.b();
-  auto c = op.c();
   auto d = op.getResult();
-  auto aTensorTy = a.getType().cast<RankedTensorType>();
-  auto bTensorTy = b.getType().cast<RankedTensorType>();
-  auto cTensorTy = c.getType().cast<RankedTensorType>();
   auto dTensorTy = d.getType().cast<RankedTensorType>();
 
   if (!dTensorTy.getEncoding().isa<MmaEncodingAttr>())
     return false;
 
   auto mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
-  auto aElemTy = aTensorTy.getElementType();
-  auto bElemTy = bTensorTy.getElementType();
-
   assert((mmaLayout.getVersion() == 1 || mmaLayout.getVersion() == 2) &&
          "Unexpected MMA layout version found");
-  // Refer to mma section for the data type supported by Volta and Hopper
-  // Tensor Core in
-  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
-  return (aElemTy.isF16() && bElemTy.isF16()) ||
-         (aElemTy.isBF16() && bElemTy.isBF16()) ||
-         (aElemTy.isF32() && bElemTy.isF32() && op.allowTF32() &&
-          mmaLayout.getVersion() >= 2) ||
-         (aElemTy.isInteger(8) && bElemTy.isInteger(8) &&
-          mmaLayout.getVersion() >= 2);
+  return supportMMA(op, mmaLayout.getVersion());
 }
 
 // Tell whether a DotOp support HMMA by the operand type(either $a or $b).
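
Note: after this hunk, isDotHMMA keeps only the encoding check and the version assert and delegates the element-type logic to the new supportMMA helper, so the predicate can no longer drift apart between the analysis, the layout pass, and this lowering. Reassembled from the context lines above (not new code), the function now reads:

    static bool isDotHMMA(DotOp op) {
      auto d = op.getResult();
      auto dTensorTy = d.getType().cast<RankedTensorType>();

      if (!dTensorTy.getEncoding().isa<MmaEncodingAttr>())
        return false;

      auto mmaLayout = dTensorTy.getEncoding().cast<MmaEncodingAttr>();
      assert((mmaLayout.getVersion() == 1 || mmaLayout.getVersion() == 2) &&
             "Unexpected MMA layout version found");
      return supportMMA(op, mmaLayout.getVersion());
    }
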
@@ -790,18 +790,16 @@ public:
     if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
       return failure();
 
-    auto A = dotOp.getOperand(0).getType().cast<RankedTensorType>();
-    auto B = dotOp.getOperand(1).getType().cast<RankedTensorType>();
+    int version = computeCapabilityToMMAVersion(computeCapability);
+
     // for FMA, should retain the blocked layout.
-    if (A.getElementType().isF32() && B.getElementType().isF32() &&
-        !dotOp.allowTF32())
+    if (!supportMMA(dotOp, version))
       return failure();
 
     // get MMA encoding for the given number of warps
     auto retShape = oldRetType.getShape();
     auto mod = op->getParentOfType<mlir::ModuleOp>();
     int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
-    int version = computeCapabilityToMMAVersion(computeCapability);
 
     auto newRetType = RankedTensorType::get(
         retShape, oldRetType.getElementType(),
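
Note: this is the layout-pass half of the fix. version is now computed before the guard, and the old inline fp32/allowTF32 test is replaced by supportMMA, so on V100 an fp32 dot fails the guard, keeps its blocked layout, and later takes the `if (!isMMA)` FMA path above. computeCapabilityToMMAVersion is defined outside this diff; the mapping below is purely an assumption, consistent with the commit title (V100 is sm_70 and must land on MMA version 1).

    // Hypothetical sketch of the version mapping assumed by this hunk.
    #include <cassert>

    int computeCapabilityToMMAVersionSketch(int computeCapability) {
      if (computeCapability < 80)
        return 1;  // pre-Ampere tensor cores (V100 is sm_70)
      return 2;    // Ampere and newer (TF32 / int8 MMA)
    }

    int main() {
      assert(computeCapabilityToMMAVersionSketch(70) == 1);  // V100
      assert(computeCapabilityToMMAVersionSketch(80) == 2);  // A100
    }
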