[OPTIMIZER] Rewrite patterns for layout conversions (#64)
This commit is contained in:
@@ -29,7 +29,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
llvm::SmallVector<unsigned> order(rank);
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
llvm::SmallVector<unsigned> sizePerThread(rank, 1);
|
||||
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
|
||||
Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
|
||||
this->context, shape, sizePerThread, order, this->numWarps);
|
||||
return RankedTensorType::get(shape, tensorType.getElementType(), encoding);
|
||||
});
|
||||
@@ -95,9 +95,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
dotOp.a().getType().cast<RankedTensorType>().getEncoding();
|
||||
Attribute bEncoding =
|
||||
dotOp.b().getType().cast<RankedTensorType>().getEncoding();
|
||||
if (aEncoding &&
|
||||
aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
|
||||
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
|
||||
if (aEncoding && aEncoding.isa<triton::gpu::SharedEncodingAttr>() &&
|
||||
bEncoding && bEncoding.isa<triton::gpu::SharedEncodingAttr>())
|
||||
return true;
|
||||
return false;
|
||||
});
|
||||
|
Reference in New Issue
Block a user