update mma encoding & triton-opt
This commit is contained in:
@@ -112,11 +112,65 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
|
||||
Attribute
|
||||
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
llvm_unreachable("Not implemented");
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
DictionaryAttr dict;
|
||||
if (parser.parseAttribute(dict).failed())
|
||||
return {};
|
||||
if (parser.parseGreater().failed())
|
||||
return {};
|
||||
|
||||
SmallVector<unsigned, 2> fragmentPerWarp;
|
||||
SmallVector<unsigned, 2> shapePerWarp;
|
||||
SmallVector<unsigned, 2> warpPerTile;
|
||||
SmallVector<unsigned, 2> shapePerTile;
|
||||
SmallVector<unsigned, 2> repetitions;
|
||||
SmallVector<unsigned, 2> contigPerThread;
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "fragmentPerWarp") {
|
||||
if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "shapePerWarp") {
|
||||
if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "warpPerTile") {
|
||||
if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "shapePerTile") {
|
||||
if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "repetitions") {
|
||||
if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "contigPerThread") {
|
||||
if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread").failed())
|
||||
return {};
|
||||
} else {
|
||||
parser.emitError(parser.getNameLoc(), "unexpected key: ")
|
||||
<< attr.getName().strref();
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(),
|
||||
fragmentPerWarp,
|
||||
shapePerWarp,
|
||||
warpPerTile,
|
||||
shapePerTile,
|
||||
repetitions,
|
||||
contigPerThread);
|
||||
}
|
||||
|
||||
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||
llvm_unreachable("Not implemented");
|
||||
printer << "<{"
|
||||
<< "fragmentPerWarp = [" << getFragmentPerWarp() << "]"
|
||||
<< ", shapePerWarp = [" << getShapePerWarp() << "]"
|
||||
<< ", warpPerTile = [" << getWarpPerTile() << "]"
|
||||
<< ", shapePerTile = [" << getShapePerTile() << "]"
|
||||
<< ", repetitions = [" << getRepetitions() << "]"
|
||||
<< ", contigPerThread = [" << getContigPerThread() << "]"
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
Attribute
|
||||
|
@@ -45,6 +45,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<Pass> triton::gpu::createCombineOpsPass() {
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
|
||||
return std::make_unique<TritonGPUCombineOpsPass>();
|
||||
}
|
||||
|
@@ -100,6 +100,6 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<Pass> triton::gpu::createTritonGPUVerifier() {
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUVerifier() {
|
||||
return std::make_unique<TritonGPUVerifier>();
|
||||
}
|
||||
|
Reference in New Issue
Block a user