A more general pipeliner

This commit is contained in:
Yan Da
2022-05-25 21:52:51 +08:00
parent 441fd7c3cc
commit 9308e9c90c
3 changed files with 210 additions and 112 deletions

View File

@@ -163,16 +163,16 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
return failure();
Value a = adaptor.a();
Value b = adaptor.b();
if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
}
if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
// if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
// Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
// auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
// a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
// }
// if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
// Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
// auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
// b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
// }
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, adaptor.c(), adaptor.allowTF32()
);