more optimizations

This commit is contained in:
Phil Tillet
2023-01-06 20:27:49 -08:00
parent 18c7a72973
commit 600bcefb12
4 changed files with 262 additions and 28 deletions

View File

@@ -27,7 +27,32 @@ class TritonGPUDecomposeConversionsToDotOperandPass
public:
TritonGPUDecomposeConversionsToDotOperandPass() = default;
// No-op implementation: returns immediately without inspecting the operation.
void runOnOperation() override {}
/// Rewrites every `triton::gpu::ConvertLayoutOp` in the module whose source
/// tensor has a blocked encoding and whose result has a dot-operand encoding
/// into two chained conversions:
///   source -> tensor with a `SharedEncodingAttr` -> original result type.
/// Conversions with any other source/result encoding pair are left untouched.
void runOnOperation() override {
  MLIRContext *context = &getContext();
  ModuleOp mod = getOperation();
  mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
    auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
    auto dstType = cvtOp.getType().cast<RankedTensorType>();
    auto srcBlocked =
        srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
    auto dstDotOp =
        dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
    // Only the blocked -> dot-operand case is decomposed.
    if (!srcBlocked || !dstDotOp)
      return;

    // New ops are inserted immediately before the conversion being replaced.
    OpBuilder builder(cvtOp);

    // Intermediate type: same shape and element type as the result, but with
    // a shared encoding derived from the dot-operand encoding and the order
    // of the blocked source layout.
    auto sharedEncoding = triton::gpu::SharedEncodingAttr::get(
        context, dstDotOp, srcType.getShape(), getOrder(srcBlocked),
        srcType.getElementType());
    auto tmpType = RankedTensorType::get(
        dstType.getShape(), dstType.getElementType(), sharedEncoding);

    auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
        cvtOp.getLoc(), tmpType, cvtOp.getOperand());
    auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
        cvtOp.getLoc(), dstType, tmp);

    // NOTE(review): the visited op is erased from inside the walk callback;
    // this relies on the walk tolerating erasure of the current op — confirm
    // against the MLIR version in use.
    cvtOp.replaceAllUsesWith(newConvert.getResult());
    cvtOp.erase();
  });
}
};
std::unique_ptr<Pass>