[Triton-MLIR][BACKEND] adapt DotOp layout for FMADot (#872)
This commit is contained in:
@@ -64,8 +64,7 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
|
||||
}
|
||||
|
||||
unsigned getElemsPerThread(Type type) {
|
||||
if (type.isIntOrIndexOrFloat() ||
|
||||
type.isa<triton::Float8Type>() ||
|
||||
if (type.isIntOrIndexOrFloat() || type.isa<triton::Float8Type>() ||
|
||||
type.isa<triton::PointerType>())
|
||||
return 1;
|
||||
auto tensorType = type.cast<RankedTensorType>();
|
||||
@@ -372,6 +371,9 @@ unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||
|
||||
unsigned
|
||||
DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||
if (auto blockedLayout = getParent().dyn_cast<BlockedEncodingAttr>()) {
|
||||
return blockedLayout.getElemsPerThread(shape);
|
||||
}
|
||||
assert(0 && "DotOperandEncodingAttr::getElemsPerThread not implemented");
|
||||
return 0;
|
||||
}
|
||||
|
Reference in New Issue
Block a user