A more general pipeliner
This commit is contained in:
@@ -163,16 +163,16 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
return failure();
|
||||
Value a = adaptor.a();
|
||||
Value b = adaptor.b();
|
||||
if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
|
||||
auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
|
||||
}
|
||||
if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
|
||||
auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
||||
}
|
||||
// if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
// Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
|
||||
// auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
|
||||
// a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
|
||||
// }
|
||||
// if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
// Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
|
||||
// auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
|
||||
// b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
||||
// }
|
||||
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
|
||||
op, retType, a, b, adaptor.c(), adaptor.allowTF32()
|
||||
);
|
||||
|
Reference in New Issue
Block a user