adapt isMMAv1Row in backend (#969)

Yan Chunwei
2022-12-09 12:52:43 +08:00
committed by GitHub
parent 0d27912554
commit 13644e7ac4
3 changed files with 28 additions and 12 deletions

View File

@@ -434,7 +434,7 @@ section 9.7.13.4.1 for more details.
"Attribute":$parent), [{
if(parent.isa<MmaEncodingAttr>() &&
parent.cast<MmaEncodingAttr>().getVersion() == 1){
-      llvm::errs() << "DotOperand for MMAv1 must have isMMAv1Row field\n";
+      llvm::report_fatal_error("DotOperand for MMAv1 must have isMMAv1Row field");
return {};
}
Attribute none;
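
The practical effect of this hunk: building a DotOperandEncodingAttr for MMA v1 without the isMMAv1Row field now aborts compilation instead of printing a warning and returning a null attribute that callers could silently ignore. A minimal sketch of the difference; the standalone helper below is hypothetical and only illustrates the two behaviors:

    #include "llvm/Support/ErrorHandling.h"

    // Hypothetical stand-in for the check inside the attribute builder.
    void checkMMAv1DotOperand(bool parentIsMMAv1, bool hasIsMMAv1Row) {
      if (parentIsMMAv1 && !hasIsMMAv1Row) {
        // Before: llvm::errs() only printed the message and the builder went on
        // to return an empty attribute, so the mistake could go unnoticed.
        // After: report_fatal_error never returns, so the malformed encoding
        // cannot reach later lowering stages.
        llvm::report_fatal_error("DotOperand for MMAv1 must have isMMAv1Row field");
      }
    }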

View File

@@ -3428,8 +3428,12 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
isHMMA) { // tensor core v1
DotOpMmaV1ConversionHelper helper(mmaLayout);
-    bool isMMAv1Row = dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
-    auto srcSharedLayout = src.getType().cast<RankedTensorType>().getEncoding().cast<SharedEncodingAttr>();
+    bool isMMAv1Row =
+        dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    auto srcSharedLayout = src.getType()
+                               .cast<RankedTensorType>()
+                               .getEncoding()
+                               .cast<SharedEncodingAttr>();
// Can only convert [1, 0] to row or [0, 1] to col for now
if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
@@ -3550,6 +3554,14 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
.cast<RankedTensorType>()
.getEncoding()
.cast<MmaEncodingAttr>();
+  auto ALayout = A.getType()
+                     .cast<RankedTensorType>()
+                     .getEncoding()
+                     .cast<DotOperandEncodingAttr>();
+  auto BLayout = B.getType()
+                     .cast<RankedTensorType>()
+                     .getEncoding()
+                     .cast<DotOperandEncodingAttr>();
auto ATensorTy = A.getType().cast<RankedTensorType>();
auto BTensorTy = B.getType().cast<RankedTensorType>();
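
The hunk above and the one in lowerSharedToDotOperandMMA rely on the same lookup: the dot-operand encoding hangs off the value's RankedTensorType, and its isMMAv1Row field is stored as a BoolAttr. A condensed sketch of that pattern; the helper name is hypothetical and the include paths are assumed from the repository layout, while cast<> (as used in the commit) asserts if the types do not match:

    #include "mlir/IR/BuiltinAttributes.h"
    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/Value.h"
    #include "triton/Dialect/TritonGPU/IR/Dialect.h"

    // Hypothetical helper: read the isMMAv1Row flag of a dot operand.
    static bool getIsMMAv1Row(mlir::Value operand) {
      auto tensorTy = operand.getType().cast<mlir::RankedTensorType>();
      auto layout = tensorTy.getEncoding()
                        .cast<mlir::triton::gpu::DotOperandEncodingAttr>();
      return layout.getIsMMAv1Row().cast<mlir::BoolAttr>().getValue();
    }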
@@ -3561,12 +3573,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
auto DShape = DTensorTy.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
// TODO[Superjomn]: order cannot accessed in DotOp.
-  SmallVector<unsigned> AOrder({1, 0});
-  SmallVector<unsigned> BOrder({1, 0});
-  bool isARow = AOrder[0] != 0;
-  bool isBRow = BOrder[0] != 0;
+  bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+  bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
bool isBVec4 = isBRow && BShape[isBRow] <= 16;
// TODO[Superjomn]: ld.v4 is not supported.
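
Net effect of the last two hunks: convertMMA884 no longer hard-codes a row-major {1, 0} order for A and B; each operand's orientation comes from its isMMAv1Row flag, and that orientation in turn drives the vec4 flags used for the MMA884 operand loads. A standalone mock of just that decision, with hypothetical names; the expressions mirror the diff above:

    #include <array>
    #include <cstdint>

    struct MMAv1Operand {
      std::array<int64_t, 2> shape; // operand tile shape
      bool isRow;                   // value of the operand's isMMAv1Row flag
    };

    // Mirrors `bool isAVec4 = !isARow && AShape[isARow] <= 16;`
    static bool useVec4ForA(const MMAv1Operand &a) {
      return !a.isRow && a.shape[a.isRow] <= 16;
    }

    // Mirrors `bool isBVec4 = isBRow && BShape[isBRow] <= 16;`
    static bool useVec4ForB(const MMAv1Operand &b) {
      return b.isRow && b.shape[b.isRow] <= 16;
    }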

View File

@@ -303,9 +303,17 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
[32, 32, 32, 1, 32, 32, 32, False, False],
[128, 32, 32, 1, 128, 32, 32, False, False],
[128, 32, 32, 1, 128, 32, 32, True, False],
[128, 32, 32, 1, 128, 32, 32, True, True],
# split-K
[16, 16, 32, 1, 16, 16, 16, False, False],
[64, 64, 128, 1, 64, 64, 32, False, False],
[16, 16, 32, 1, 16, 16, 16, True, False],
[16, 16, 32, 1, 16, 16, 16, True, True],
[64, 64, 128, 1, 64, 64, 32, True, True],
])
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)