From 2fa0dfbce9203176983c1a2a495dc3089c59b973 Mon Sep 17 00:00:00 2001
From: Phil Tillet
Date: Mon, 9 Jan 2023 22:50:38 -0800
Subject: [PATCH] .

---
 .../DecomposeConversionsToDotOperand.cpp      | 28 +++++++++++--------
 python/tutorials/06-fused-attention.py        |  2 +-
 2 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/lib/Dialect/TritonGPU/Transforms/DecomposeConversionsToDotOperand.cpp b/lib/Dialect/TritonGPU/Transforms/DecomposeConversionsToDotOperand.cpp
index 29fc123c5..441287f8c 100644
--- a/lib/Dialect/TritonGPU/Transforms/DecomposeConversionsToDotOperand.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/DecomposeConversionsToDotOperand.cpp
@@ -39,19 +39,23 @@ public:
         return;
       auto dstDotOp =
           dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
-      if (dstDotOp) {
-        auto tmpType = RankedTensorType::get(
-            dstType.getShape(), dstType.getElementType(),
-            triton::gpu::SharedEncodingAttr::get(
-                mod.getContext(), dstDotOp, srcType.getShape(),
-                triton::gpu::getOrder(srcEncoding), srcType.getElementType()));
-        auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
-            cvtOp.getLoc(), tmpType, cvtOp.getOperand());
-        auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
-            cvtOp.getLoc(), dstType, tmp);
-        cvtOp.replaceAllUsesWith(newConvert.getResult());
-        cvtOp.erase();
+      if (!dstDotOp)
+        return;
+      if (auto srcMmaEncoding = srcEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>()) {
+        if (srcMmaEncoding.getWarpsPerCTA()[1] == 1 && dstDotOp.getParent() == srcMmaEncoding)
+          return;
       }
+      auto tmpType = RankedTensorType::get(
+          dstType.getShape(), dstType.getElementType(),
+          triton::gpu::SharedEncodingAttr::get(
+              mod.getContext(), dstDotOp, srcType.getShape(),
+              triton::gpu::getOrder(srcEncoding), srcType.getElementType()));
+      auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
+          cvtOp.getLoc(), tmpType, cvtOp.getOperand());
+      auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
+          cvtOp.getLoc(), dstType, tmp);
+      cvtOp.replaceAllUsesWith(newConvert.getResult());
+      cvtOp.erase();
     });
   }
 };
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index c961c9a62..72675b577 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -345,7 +345,7 @@ configs = [triton.testing.Benchmark(
     ylabel='ms',
     plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
     args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
-) for mode in ['bwd']]
+) for mode in ['fwd']]
 
 
 @triton.testing.perf_report(configs)
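
Note on the new early exit (editor's sketch, not part of the commit): this pass rewrites blocked -> dot_operand layout conversions into a two-step blocked -> shared -> dot_operand form so they can be lowered through shared memory. The guard added here skips that decomposition when the conversion can already be resolved in registers: the source carries an MMA encoding, its warps are not split along dim 1 (warpsPerCTA[1] == 1), and the destination dot-operand layout is derived from that same MMA parent. This is the fused-attention pattern, where the first dot's MMA output feeds the second dot directly. A minimal restatement of the predicate, assuming the MmaEncodingAttr / DotOperandEncodingAttr API as of this commit; the helper name is hypothetical:

#include "mlir/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

// Hypothetical helper, for illustration only: returns true when a
// convert_layout into `dstDotOp` can stay register-to-register, making the
// shared-memory decomposition performed by this pass unnecessary.
static bool canStayInRegisters(Attribute srcEncoding,
                               triton::gpu::DotOperandEncodingAttr dstDotOp) {
  auto srcMma = srcEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>();
  return srcMma && srcMma.getWarpsPerCTA()[1] == 1 &&
         dstDotOp.getParent() == srcMma;
}

Leaving such conversions intact lets the subsequent lowering handle them as an in-register shuffle rather than a shared-memory round trip, which is what makes back-to-back dots cheap in the forward attention kernel benchmarked by the tutorial change above.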