more work
This commit is contained in:
@@ -34,16 +34,17 @@ public:
|
||||
OpBuilder builder(cvtOp);
|
||||
auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = cvtOp.getType().cast<RankedTensorType>();
|
||||
auto srcBlocked =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
||||
auto srcEncoding = srcType.getEncoding();
|
||||
if(srcEncoding.isa<triton::gpu::SharedEncodingAttr>())
|
||||
return;
|
||||
auto dstDotOp =
|
||||
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
if (srcBlocked && dstDotOp) {
|
||||
if (dstDotOp) {
|
||||
auto tmpType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(),
|
||||
triton::gpu::SharedEncodingAttr::get(
|
||||
mod.getContext(), dstDotOp, srcType.getShape(),
|
||||
getOrder(srcBlocked), srcType.getElementType()));
|
||||
triton::gpu::getOrder(srcEncoding), srcType.getElementType()));
|
||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
|
Reference in New Issue
Block a user