From 520b69fe70f3dea81c626bdce776313d4b056b85 Mon Sep 17 00:00:00 2001
From: Phil Tillet
Date: Thu, 5 Jan 2023 16:05:11 -0800
Subject: [PATCH] more reassociation

---
 .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 22 +++++++++++++++-------
 python/tutorials/06-fused-attention.py       |  2 +-
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 661b02afa..9cce8bab6 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -58,6 +58,7 @@ void storeDistributedToShared(Value src, Value llSrc,
   Value outVecVal = i32_val(outVec);
   Value minVecVal = i32_val(minVec);
   Value word;
+  std::map<unsigned, Value> cache;
   for (unsigned i = 0; i < numElems; ++i) {
     if (i % minVec == 0)
       word = undef(wordTy);
@@ -69,11 +70,18 @@ void storeDistributedToShared(Value src, Value llSrc,
     Value staIdx0 = i32_val(0);
     Value dynIdx1 = multiDimIdx[outOrd[1]];
     Value staIdx1 = i32_val(0);
-    // if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
-    //   if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
-    //     dynIdx0 = addOp.getLhs();
-    //     staIdx0 = mul(udiv(cstRhs, minVecVal), minVecVal);
-    //   }
+    Value stride0 = dstStrides[outOrd[0]];
+    Value stride1 = dstStrides[outOrd[1]];
+    if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
+      if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
+        unsigned rhsVal = cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
+        unsigned key = (rhsVal / outVec) % maxPhase;
+        llvm::outs() << srcDistributedLayout << " " << rhsVal << " " << key << "\n";
+        if(cache.find(key) == cache.end())
+          cache[key] = dynIdx0;
+        dynIdx0 = cache[key];
+        staIdx0 = i32_val((rhsVal) / (outVec * maxPhase) * (outVec * maxPhase));
+      }
     if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
       if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
         dynIdx1 = addOp.getLhs();
@@ -81,8 +89,6 @@ void storeDistributedToShared(Value src, Value llSrc,
       }
 
-    Value stride0 = dstStrides[outOrd[0]];
-    Value stride1 = dstStrides[outOrd[1]];
 
     // offset along non-contiguous dimension
     Value off1 = mul(dynIdx1, stride1);
@@ -94,6 +100,8 @@ void storeDistributedToShared(Value src, Value llSrc,
     remained = udiv(remained, minVecVal);
     off0 = add(off0, mul(remained, minVecVal));
     Value offset = add(off1, mul(off0, stride0));
+
+    // add static offset
     offset = add(offset, mul(staIdx1, stride1));
     offset = add(offset, mul(staIdx0, stride0));
 
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 8099e010f..2da77fbc6 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -376,4 +376,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
     ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
     return ms
 
-#bench_flash_attention.run(save_path='.', print_data=True)
\ No newline at end of file
+bench_flash_attention.run(save_path='.', print_data=True)
\ No newline at end of file
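
For intuition, the arithmetic behind this reassociation can be sketched outside of MLIR. The
fast-moving shared-memory index arrives as `base + rhsVal` with `rhsVal` a compile-time
constant. The patch splits `rhsVal` into a swizzle-periodic part, `(rhsVal / outVec) % maxPhase`,
used as a cache key so that all offsets landing on the same phase reuse a single dynamic index
value, and a static remainder, `rhsVal / (outVec * maxPhase) * (outVec * maxPhase)`, re-added at
the end as a constant. The standalone C++ below is a minimal sketch of that arithmetic with
plain integers; `outVec` and `maxPhase` mirror the patch, while the driver values and helper
names are made up for illustration:

    #include <cassert>
    #include <cstdio>
    #include <map>

    // Split a constant index offset the way the patch does:
    //   key        - which phase of the swizzle pattern the offset lands on
    //   staticPart - multiple of (outVec * maxPhase) that can be re-added
    //                later as a compile-time-constant offset
    struct Reassoc {
      unsigned key;
      unsigned staticPart;
    };

    Reassoc reassociate(unsigned rhsVal, unsigned outVec, unsigned maxPhase) {
      return {(rhsVal / outVec) % maxPhase,
              rhsVal / (outVec * maxPhase) * (outVec * maxPhase)};
    }

    int main() {
      const unsigned outVec = 4, maxPhase = 8;
      std::map<unsigned, unsigned> cache; // phase key -> first offset seen
      for (unsigned rhsVal : {0u, 4u, 32u, 36u, 64u}) {
        Reassoc r = reassociate(rhsVal, outVec, maxPhase);
        if (cache.find(r.key) == cache.end())
          cache[r.key] = rhsVal; // analogous to cache[key] = dynIdx0
        std::printf("offset %2u -> phase %u, static %2u, dynamic part shared with offset %u\n",
                    rhsVal, r.key, r.staticPart, cache[r.key]);
      }
      // Offsets 4 and 36 differ by outVec * maxPhase = 32, so they land on the
      // same phase: 36 reuses the dynamic index computed for 4, plus static 32.
      assert(reassociate(36, outVec, maxPhase).key ==
             reassociate(4, outVec, maxPhase).key);
      return 0;
    }

The point being that offsets differing by a multiple of `outVec * maxPhase` hit the same swizzle
phase, so their per-thread dynamic index computations collapse into one SSA value, while the
constant difference moves into `staIdx0`, where the final `offset = add(offset, mul(staIdx0,
stride0))` can fold it into the address as an immediate; presumably this is what lets LLVM CSE
the swizzling arithmetic across the stores.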