adapt isMMAv1Row in backend (#969)
@@ -434,7 +434,7 @@ section 9.7.13.4.1 for more details.
     "Attribute":$parent), [{
       if(parent.isa<MmaEncodingAttr>() &&
          parent.cast<MmaEncodingAttr>().getVersion() == 1){
-        llvm::errs() << "DotOperand for MMAv1 must have isMMAv1Row field\n";
+        llvm::report_fatal_error("DotOperand for MMAv1 must have isMMAv1Row field");
         return {};
       }
       Attribute none;
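In the hunk above, the custom parser for the DotOperand encoding used to print a warning to llvm::errs() and return a null attribute when an MMAv1 parent layout was missing the isMMAv1Row field; it now calls llvm::report_fatal_error, so parsing aborts instead of letting the malformed layout propagate. A standalone sketch of that behavior (the helper below is illustrative, not part of the commit):

    // report_fatal_error never returns, so the subsequent "return {};" in the
    // parser is effectively dead and a null attribute can no longer escape.
    #include "llvm/Support/ErrorHandling.h"

    static void requireIsMMAv1Row(bool hasIsMMAv1Row) {
      if (!hasIsMMAv1Row)
        llvm::report_fatal_error("DotOperand for MMAv1 must have isMMAv1Row field");
    }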
@@ -3428,12 +3428,16 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
   } else if (!isOuter && mmaLayout.getVersion() == 1 &&
              isHMMA) { // tensor core v1
     DotOpMmaV1ConversionHelper helper(mmaLayout);
-    bool isMMAv1Row = dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
-    auto srcSharedLayout = src.getType().cast<RankedTensorType>().getEncoding().cast<SharedEncodingAttr>();
+    bool isMMAv1Row =
+        dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    auto srcSharedLayout = src.getType()
+                               .cast<RankedTensorType>()
+                               .getEncoding()
+                               .cast<SharedEncodingAttr>();

     // Can only convert [1, 0] to row or [0, 1] to col for now
-    if((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
-       (srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)){
+    if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
+        (srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
       llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
       return Value();
     }
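This hunk only reflows the isMMAv1Row and srcSharedLayout declarations and tightens the guard's spacing; the logic is unchanged: a shared layout with order [1, 0] (row-major) can only be converted to a row MMAv1 operand, and order [0, 1] (column-major) only to a column operand. The guard's intent as a self-contained predicate (helper name and array types assumed for illustration):

    #include <array>

    // Conversion is supported only when the shared layout's fastest-varying
    // dimension matches the requested MMAv1 operand orientation.
    static bool canConvertSharedToDotOperandMMAv1(
        std::array<unsigned, 2> sharedOrder, bool isMMAv1Row) {
      bool sharedIsRowMajor = sharedOrder[0] == 1; // order [1, 0] => row-major
      return sharedIsRowMajor == isMMAv1Row;
    }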
@@ -3550,6 +3554,14 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
                        .cast<RankedTensorType>()
                        .getEncoding()
                        .cast<MmaEncodingAttr>();
+  auto ALayout = A.getType()
+                     .cast<RankedTensorType>()
+                     .getEncoding()
+                     .cast<DotOperandEncodingAttr>();
+  auto BLayout = B.getType()
+                     .cast<RankedTensorType>()
+                     .getEncoding()
+                     .cast<DotOperandEncodingAttr>();

   auto ATensorTy = A.getType().cast<RankedTensorType>();
   auto BTensorTy = B.getType().cast<RankedTensorType>();
@@ -3561,12 +3573,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   auto DShape = DTensorTy.getShape();
   auto wpt = mmaLayout.getWarpsPerCTA();

-  // TODO[Superjomn]: order cannot accessed in DotOp.
-  SmallVector<unsigned> AOrder({1, 0});
-  SmallVector<unsigned> BOrder({1, 0});
-
-  bool isARow = AOrder[0] != 0;
-  bool isBRow = BOrder[0] != 0;
+  bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+  bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
   bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
   bool isBVec4 = isBRow && BShape[isBRow] <= 16;
   // TODO[Superjomn]: ld.v4 is not supported.
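Together, the two hunks above replace the hard-coded AOrder/BOrder of {1, 0} in convertMMA884 with row flags read from the isMMAv1Row attribute of the A and B DotOperand layouts, so the vec4 decision now follows the actual operand orientation. A rough standalone restatement (shape type assumed for illustration; the size comment is carried over from the original code):

    #include <array>
    #include <cstdint>

    // Same rule as the diff: a column-major A, or a row-major B, with a small
    // indexed dimension takes the vec4 path (original comment: fp16*4 = 16bytes).
    static bool computeIsAVec4(bool isARow, std::array<int64_t, 2> AShape) {
      return !isARow && AShape[isARow] <= 16;
    }
    static bool computeIsBVec4(bool isBRow, std::array<int64_t, 2> BShape) {
      return isBRow && BShape[isBRow] <= 16;
    }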
@@ -303,9 +303,17 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
     [32, 32, 32, 1, 32, 32, 32, False, False],
     [128, 32, 32, 1, 128, 32, 32, False, False],
+
+    [128, 32, 32, 1, 128, 32, 32, True, False],
+    [128, 32, 32, 1, 128, 32, 32, True, True],
+
     # split-K
     [16, 16, 32, 1, 16, 16, 16, False, False],
     [64, 64, 128, 1, 64, 64, 32, False, False],
+
+    [16, 16, 32, 1, 16, 16, 16, True, False],
+    [16, 16, 32, 1, 16, 16, 16, True, True],
+    [64, 64, 128, 1, 64, 64, 32, True, True],
 ])
 def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
     test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)