[BACKEND] Add isRow attribute for DotOp tensors whose parent is mmav1 (#970)

Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
This commit is contained in:
Philippe Tillet
2022-12-11 19:01:57 -08:00
committed by GitHub
parent 4fb048873a
commit 52accd4c2b
7 changed files with 186 additions and 27 deletions

View File

@@ -3432,6 +3432,20 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
} else if (!isOuter && mmaLayout.getVersion() == 1 &&
isHMMA) { // tensor core v1
DotOpMmaV1ConversionHelper helper(mmaLayout);
bool isMMAv1Row =
dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto srcSharedLayout = src.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<SharedEncodingAttr>();
// Can only convert [1, 0] to row or [0, 1] to col for now
if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
(srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
return Value();
}
if (dotOperandLayout.getOpIdx() == 0) { // operand $a
// TODO[Superjomn]: transA is not available here.
bool transA = false;
@@ -3544,6 +3558,14 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
.cast<RankedTensorType>()
.getEncoding()
.cast<MmaEncodingAttr>();
auto ALayout = A.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<DotOperandEncodingAttr>();
auto BLayout = B.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<DotOperandEncodingAttr>();
auto ATensorTy = A.getType().cast<RankedTensorType>();
auto BTensorTy = B.getType().cast<RankedTensorType>();
@@ -3555,12 +3577,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
auto DShape = DTensorTy.getShape();
auto wpt = mmaLayout.getWarpsPerCTA();
// TODO[Superjomn]: order cannot accessed in DotOp.
SmallVector<unsigned> AOrder({1, 0});
SmallVector<unsigned> BOrder({1, 0});
bool isARow = AOrder[0] != 0;
bool isBRow = BOrder[0] != 0;
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
bool isBVec4 = isBRow && BShape[isBRow] <= 16;
// TODO[Superjomn]: ld.v4 is not supported.