[OPTIMIZER] Rewrite patterns for layout conversions (#64)
This commit is contained in:
@@ -156,8 +156,7 @@ struct TritonExpandDimsPattern
     Attribute _argEncoding = argType.getEncoding();
     if (!_argEncoding)
       return failure();
-    auto argEncoding =
-        _argEncoding.cast<triton::gpu::TritonGPUBlockedEncodingAttr>();
+    auto argEncoding = _argEncoding.cast<triton::gpu::BlockedEncodingAttr>();
     // return shape
     auto retShape = argType.getShape().vec();
     retShape.insert(retShape.begin() + op.axis(), 1);
@@ -170,10 +169,10 @@ struct TritonExpandDimsPattern
     retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.axis(), 1);
     SmallVector<unsigned, 4> retOrder(retShape.size());
     std::iota(retOrder.begin(), retOrder.end(), 0);
-    triton::gpu::TritonGPUBlockedEncodingAttr retEncoding =
-        triton::gpu::TritonGPUBlockedEncodingAttr::get(
-            getContext(), retSizePerThread, retThreadsPerWarp, retWarpsPerCTA,
-            retOrder);
+    triton::gpu::BlockedEncodingAttr retEncoding =
+        triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread,
+                                              retThreadsPerWarp, retWarpsPerCTA,
+                                              retOrder);
     // return type
     RankedTensorType retType =
         RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
@@ -201,16 +200,16 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
     Value a = adaptor.a();
     Value b = adaptor.b();
     SmallVector<unsigned, 2> order{1, 0};
-    if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
-      Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
-          getContext(), 1, 1, 1, order);
+    if (!aEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
+      Attribute encoding =
+          triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
       auto dstType = RankedTensorType::get(aType.getShape(),
                                            aType.getElementType(), encoding);
       a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
     }
-    if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
-      Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
-          getContext(), 1, 1, 1, order);
+    if (!bEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
+      Attribute encoding =
+          triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
       auto dstType = RankedTensorType::get(bType.getShape(),
                                            bType.getElementType(), encoding);
       b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
Reference in New Issue
Block a user