diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 5f4c803e8..661b02afa 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -65,19 +65,37 @@ void storeDistributedToShared(Value src, Value llSrc, if (i % minVec == minVec - 1) { // step 1: recover the multidim_index from the index of SmallVector multiDimIdx = srcIndices[i]; - SmallVector dbgVal = srcIndices[i]; + Value dynIdx0 = multiDimIdx[outOrd[0]]; + Value staIdx0 = i32_val(0); + Value dynIdx1 = multiDimIdx[outOrd[1]]; + Value staIdx1 = i32_val(0); + // if(auto addOp = dyn_cast(dynIdx0.getDefiningOp())) + // if(auto cstRhs = dyn_cast(addOp.getRhs().getDefiningOp())) { + // dynIdx0 = addOp.getLhs(); + // staIdx0 = mul(udiv(cstRhs, minVecVal), minVecVal); + // } + if(auto addOp = dyn_cast(dynIdx1.getDefiningOp())) + if(auto cstRhs = dyn_cast(addOp.getRhs().getDefiningOp())) { + dynIdx1 = addOp.getLhs(); + staIdx1 = addOp.getRhs(); + } + - // step 2: do swizzling - Value remained = urem(multiDimIdx[outOrd[0]], outVecVal); - multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal); - Value off_1 = mul(multiDimIdx[outOrd[1]], dstStrides[outOrd[1]]); - Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase)); - phaseId = urem(phaseId, i32_val(maxPhase)); - Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId); - off_0 = mul(off_0, outVecVal); + Value stride0 = dstStrides[outOrd[0]]; + Value stride1 = dstStrides[outOrd[1]]; + + // offset along non-contiguous dimension + Value off1 = mul(dynIdx1, stride1); + // swizzled offset along contiguous dimension + Value phaseId = urem(udiv(dynIdx1, i32_val(perPhase)), i32_val(maxPhase)); + Value off0 = xor_(udiv(dynIdx0, outVecVal), phaseId); + off0 = mul(off0, outVecVal); + Value remained = urem(dynIdx0, outVecVal); remained = udiv(remained, minVecVal); - off_0 = add(off_0, mul(remained, minVecVal)); - Value offset = add(off_1, mul(off_0, dstStrides[outOrd[0]])); + off0 = add(off0, mul(remained, minVecVal)); + Value offset = add(off1, mul(off0, stride0)); + offset = add(offset, mul(staIdx1, stride1)); + offset = add(offset, mul(staIdx0, stride0)); // step 3: store Value smemAddr = gep(elemPtrTy, smemBase, offset); diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 2da77fbc6..8099e010f 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -376,4 +376,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) return ms -bench_flash_attention.run(save_path='.', print_data=True) \ No newline at end of file +#bench_flash_attention.run(save_path='.', print_data=True) \ No newline at end of file