update mma encoding & triton-opt

This commit is contained in:
Yan Da
2022-06-06 21:03:58 +08:00
parent 7807f64ef3
commit 366dddc3bc
13 changed files with 88 additions and 28 deletions

View File

@@ -112,11 +112,65 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
Attribute
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
llvm_unreachable("Not implemented");
if (parser.parseLess().failed())
return {};
DictionaryAttr dict;
if (parser.parseAttribute(dict).failed())
return {};
if (parser.parseGreater().failed())
return {};
SmallVector<unsigned, 2> fragmentPerWarp;
SmallVector<unsigned, 2> shapePerWarp;
SmallVector<unsigned, 2> warpPerTile;
SmallVector<unsigned, 2> shapePerTile;
SmallVector<unsigned, 2> repetitions;
SmallVector<unsigned, 2> contigPerThread;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "fragmentPerWarp") {
if (parseIntArrayAttr(parser, attr, fragmentPerWarp, "fragmentPerWarp").failed())
return {};
} else if (attr.getName() == "shapePerWarp") {
if (parseIntArrayAttr(parser, attr, shapePerWarp, "shapePerWarp").failed())
return {};
} else if (attr.getName() == "warpPerTile") {
if (parseIntArrayAttr(parser, attr, warpPerTile, "warpPerTile").failed())
return {};
} else if (attr.getName() == "shapePerTile") {
if (parseIntArrayAttr(parser, attr, shapePerTile, "shapePerTile").failed())
return {};
} else if (attr.getName() == "repetitions") {
if (parseIntArrayAttr(parser, attr, repetitions, "repetitions").failed())
return {};
} else if (attr.getName() == "contigPerThread") {
if (parseIntArrayAttr(parser, attr, contigPerThread, "contigPerThread").failed())
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
<< attr.getName().strref();
return {};
}
}
return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(),
fragmentPerWarp,
shapePerWarp,
warpPerTile,
shapePerTile,
repetitions,
contigPerThread);
}
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
llvm_unreachable("Not implemented");
printer << "<{"
<< "fragmentPerWarp = [" << getFragmentPerWarp() << "]"
<< ", shapePerWarp = [" << getShapePerWarp() << "]"
<< ", warpPerTile = [" << getWarpPerTile() << "]"
<< ", shapePerTile = [" << getShapePerTile() << "]"
<< ", repetitions = [" << getRepetitions() << "]"
<< ", contigPerThread = [" << getContigPerThread() << "]"
<< "}>";
}
Attribute

View File

@@ -45,6 +45,6 @@ public:
}
};
std::unique_ptr<Pass> triton::gpu::createCombineOpsPass() {
std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
return std::make_unique<TritonGPUCombineOpsPass>();
}

View File

@@ -100,6 +100,6 @@ private:
}
};
std::unique_ptr<Pass> triton::gpu::createTritonGPUVerifier() {
std::unique_ptr<Pass> mlir::createTritonGPUVerifier() {
return std::make_unique<TritonGPUVerifier>();
}