adapt isMMAv1Row in backend (#969)
This commit is contained in:
@@ -434,7 +434,7 @@ section 9.7.13.4.1 for more details.
|
||||
"Attribute":$parent), [{
|
||||
if(parent.isa<MmaEncodingAttr>() &&
|
||||
parent.cast<MmaEncodingAttr>().getVersion() == 1){
|
||||
llvm::errs() << "DotOperand for MMAv1 must have isMMAv1Row field\n";
|
||||
llvm::report_fatal_error("DotOperand for MMAv1 must have isMMAv1Row field");
|
||||
return {};
|
||||
}
|
||||
Attribute none;
|
||||
|
@@ -3428,12 +3428,16 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
|
||||
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
|
||||
isHMMA) { // tensor core v1
|
||||
DotOpMmaV1ConversionHelper helper(mmaLayout);
|
||||
bool isMMAv1Row = dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
auto srcSharedLayout = src.getType().cast<RankedTensorType>().getEncoding().cast<SharedEncodingAttr>();
|
||||
bool isMMAv1Row =
|
||||
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
auto srcSharedLayout = src.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<SharedEncodingAttr>();
|
||||
|
||||
// Can only convert [1, 0] to row or [0, 1] to col for now
|
||||
if((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
|
||||
(srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)){
|
||||
if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
|
||||
(srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
|
||||
llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
|
||||
return Value();
|
||||
}
|
||||
@@ -3550,6 +3554,14 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<MmaEncodingAttr>();
|
||||
auto ALayout = A.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<DotOperandEncodingAttr>();
|
||||
auto BLayout = B.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<DotOperandEncodingAttr>();
|
||||
|
||||
auto ATensorTy = A.getType().cast<RankedTensorType>();
|
||||
auto BTensorTy = B.getType().cast<RankedTensorType>();
|
||||
@@ -3561,12 +3573,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
|
||||
auto DShape = DTensorTy.getShape();
|
||||
auto wpt = mmaLayout.getWarpsPerCTA();
|
||||
|
||||
// TODO[Superjomn]: order cannot accessed in DotOp.
|
||||
SmallVector<unsigned> AOrder({1, 0});
|
||||
SmallVector<unsigned> BOrder({1, 0});
|
||||
|
||||
bool isARow = AOrder[0] != 0;
|
||||
bool isBRow = BOrder[0] != 0;
|
||||
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
|
||||
bool isBVec4 = isBRow && BShape[isBRow] <= 16;
|
||||
// TODO[Superjomn]: ld.v4 is not supported.
|
||||
|
@@ -303,9 +303,17 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
|
||||
[32, 32, 32, 1, 32, 32, 32, False, False],
|
||||
[128, 32, 32, 1, 128, 32, 32, False, False],
|
||||
|
||||
[128, 32, 32, 1, 128, 32, 32, True, False],
|
||||
[128, 32, 32, 1, 128, 32, 32, True, True],
|
||||
|
||||
# split-K
|
||||
[16, 16, 32, 1, 16, 16, 16, False, False],
|
||||
[64, 64, 128, 1, 64, 64, 32, False, False],
|
||||
|
||||
[16, 16, 32, 1, 16, 16, 16, True, False],
|
||||
[16, 16, 32, 1, 16, 16, 16, True, True],
|
||||
[64, 64, 128, 1, 64, 64, 32, True, True],
|
||||
])
|
||||
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
|
||||
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)
|
||||
|
||||
|
Reference in New Issue
Block a user