[OPTIMIZER] Rewrite patterns for layout conversions (#64)

Author: Philippe Tillet
Date:   2022-08-18 12:49:37 -07:00 (committed by GitHub)
Commit: 192be76b3c
Parent: e0bedeb44c

19 changed files with 851 additions and 127 deletions

@@ -156,8 +156,7 @@ struct TritonExpandDimsPattern
     Attribute _argEncoding = argType.getEncoding();
     if (!_argEncoding)
       return failure();
-    auto argEncoding =
-        _argEncoding.cast<triton::gpu::TritonGPUBlockedEncodingAttr>();
+    auto argEncoding = _argEncoding.cast<triton::gpu::BlockedEncodingAttr>();
     // return shape
     auto retShape = argType.getShape().vec();
     retShape.insert(retShape.begin() + op.axis(), 1);
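
The rename visible in this hunk drops the redundant `TritonGPU` prefix from attribute classes that already live in the `triton::gpu` namespace, so `TritonGPUBlockedEncodingAttr` becomes `BlockedEncodingAttr`. As an illustrative sketch (not code from this patch), the same operand check can be written with `dyn_cast_or_null` on the renamed class, reusing the diff's `argType`:

// Illustrative variant only: fold the null check and the cast into one call.
// Unlike the cast above, this also bails out (rather than asserting) if the
// encoding is present but is not a blocked layout.
auto argEncoding =
    argType.getEncoding().dyn_cast_or_null<triton::gpu::BlockedEncodingAttr>();
if (!argEncoding)
  return failure(); // no blocked layout on the operand; pattern does not apply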
@@ -170,10 +169,10 @@ struct TritonExpandDimsPattern
     retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.axis(), 1);
     SmallVector<unsigned, 4> retOrder(retShape.size());
     std::iota(retOrder.begin(), retOrder.end(), 0);
-    triton::gpu::TritonGPUBlockedEncodingAttr retEncoding =
-        triton::gpu::TritonGPUBlockedEncodingAttr::get(
-            getContext(), retSizePerThread, retThreadsPerWarp, retWarpsPerCTA,
-            retOrder);
+    triton::gpu::BlockedEncodingAttr retEncoding =
+        triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread,
+                                              retThreadsPerWarp, retWarpsPerCTA,
+                                              retOrder);
     // return type
     RankedTensorType retType =
         RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
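
To make the layout derivation concrete, here is a worked example with invented numbers (not taken from the patch): expanding axis 0 of a 1-D operand whose blocked layout has sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2]. Inserting a 1 at the expanded axis of each parameter and rebuilding the order with std::iota over the new rank gives the 2-D result encoding below; `ctx` stands in for `getContext()`.

// Worked example (illustrative values): result layout of tt.expand_dims at axis 0.
SmallVector<unsigned, 4> retSizePerThread{1, 4};   // [4]  -> [1, 4]
SmallVector<unsigned, 4> retThreadsPerWarp{1, 32}; // [32] -> [1, 32]
SmallVector<unsigned, 4> retWarpsPerCTA{1, 2};     // [2]  -> [1, 2]
SmallVector<unsigned, 4> retOrder{0, 1};           // std::iota over the new rank
auto retEncoding = triton::gpu::BlockedEncodingAttr::get(
    ctx, retSizePerThread, retThreadsPerWarp, retWarpsPerCTA, retOrder);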
@@ -201,16 +200,16 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
     Value a = adaptor.a();
     Value b = adaptor.b();
     SmallVector<unsigned, 2> order{1, 0};
-    if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
-      Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
-          getContext(), 1, 1, 1, order);
+    if (!aEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
+      Attribute encoding =
+          triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
       auto dstType = RankedTensorType::get(aType.getShape(),
                                            aType.getElementType(), encoding);
       a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
     }
-    if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
-      Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
-          getContext(), 1, 1, 1, order);
+    if (!bEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
+      Attribute encoding =
+          triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
       auto dstType = RankedTensorType::get(bType.getShape(),
                                            bType.getElementType(), encoding);
       b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
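
The two branches above differ only in which operand they wrap, so the conversion could be factored into a helper. The sketch below is an illustration rather than code from this commit; it assumes the usual `using namespace mlir;` of these pattern files and takes the rewriter explicitly.

// Illustrative helper (not in the patch): give a tt.dot operand a shared
// layout via ConvertLayoutOp unless it already carries one.
static Value toSharedEncoding(Value v, ArrayRef<unsigned> order,
                              PatternRewriter &rewriter) {
  auto type = v.getType().cast<RankedTensorType>();
  Attribute enc = type.getEncoding();
  if (enc && enc.isa<triton::gpu::SharedEncodingAttr>())
    return v; // already in a shared layout; nothing to convert
  Attribute encoding =
      triton::gpu::SharedEncodingAttr::get(v.getContext(), 1, 1, 1, order);
  auto dstType =
      RankedTensorType::get(type.getShape(), type.getElementType(), encoding);
  return rewriter.create<triton::gpu::ConvertLayoutOp>(v.getLoc(), dstType, v);
}

With such a helper, the call sites would reduce to a = toSharedEncoding(a, order, rewriter); and the analogous call for b.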