This commit is contained in:
Phil Tillet
2023-01-09 22:50:38 -08:00
parent 993bc17311
commit 2fa0dfbce9
2 changed files with 17 additions and 13 deletions

View File

@@ -39,19 +39,23 @@ public:
return; return;
auto dstDotOp = auto dstDotOp =
dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>(); dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
if (dstDotOp) { if(!dstDotOp)
auto tmpType = RankedTensorType::get( return;
dstType.getShape(), dstType.getElementType(), if (auto srcMmaEncoding = srcEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>()) {
triton::gpu::SharedEncodingAttr::get( if(srcMmaEncoding.getWarpsPerCTA()[1] == 1 && dstDotOp.getParent()==srcMmaEncoding)
mod.getContext(), dstDotOp, srcType.getShape(), return;
triton::gpu::getOrder(srcEncoding), srcType.getElementType()));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), dstType, tmp);
cvtOp.replaceAllUsesWith(newConvert.getResult());
cvtOp.erase();
} }
auto tmpType = RankedTensorType::get(
dstType.getShape(), dstType.getElementType(),
triton::gpu::SharedEncodingAttr::get(
mod.getContext(), dstDotOp, srcType.getShape(),
triton::gpu::getOrder(srcEncoding), srcType.getElementType()));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), dstType, tmp);
cvtOp.replaceAllUsesWith(newConvert.getResult());
cvtOp.erase();
}); });
} }
}; };

View File

@@ -345,7 +345,7 @@ configs = [triton.testing.Benchmark(
ylabel='ms', ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode} args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
) for mode in ['bwd']] ) for mode in ['fwd']]
@triton.testing.perf_report(configs) @triton.testing.perf_report(configs)