.
This commit is contained in:
@@ -39,19 +39,23 @@ public:
|
|||||||
return;
|
return;
|
||||||
auto dstDotOp =
|
auto dstDotOp =
|
||||||
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||||
if (dstDotOp) {
|
if(!dstDotOp)
|
||||||
auto tmpType = RankedTensorType::get(
|
return;
|
||||||
dstType.getShape(), dstType.getElementType(),
|
if (auto srcMmaEncoding = srcEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>()) {
|
||||||
triton::gpu::SharedEncodingAttr::get(
|
if(srcMmaEncoding.getWarpsPerCTA()[1] == 1 && dstDotOp.getParent()==srcMmaEncoding)
|
||||||
mod.getContext(), dstDotOp, srcType.getShape(),
|
return;
|
||||||
triton::gpu::getOrder(srcEncoding), srcType.getElementType()));
|
|
||||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
|
||||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
|
||||||
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
|
||||||
cvtOp.getLoc(), dstType, tmp);
|
|
||||||
cvtOp.replaceAllUsesWith(newConvert.getResult());
|
|
||||||
cvtOp.erase();
|
|
||||||
}
|
}
|
||||||
|
auto tmpType = RankedTensorType::get(
|
||||||
|
dstType.getShape(), dstType.getElementType(),
|
||||||
|
triton::gpu::SharedEncodingAttr::get(
|
||||||
|
mod.getContext(), dstDotOp, srcType.getShape(),
|
||||||
|
triton::gpu::getOrder(srcEncoding), srcType.getElementType()));
|
||||||
|
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||||
|
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||||
|
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||||
|
cvtOp.getLoc(), dstType, tmp);
|
||||||
|
cvtOp.replaceAllUsesWith(newConvert.getResult());
|
||||||
|
cvtOp.erase();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@@ -345,7 +345,7 @@ configs = [triton.testing.Benchmark(
|
|||||||
ylabel='ms',
|
ylabel='ms',
|
||||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||||
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
|
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
|
||||||
) for mode in ['bwd']]
|
) for mode in ['fwd']]
|
||||||
|
|
||||||
|
|
||||||
@triton.testing.perf_report(configs)
|
@triton.testing.perf_report(configs)
|
||||||
|
Reference in New Issue
Block a user