trying to decrease register pressure
This commit is contained in:
@@ -65,19 +65,37 @@ void storeDistributedToShared(Value src, Value llSrc,
|
||||
if (i % minVec == minVec - 1) {
|
||||
// step 1: recover the multidim_index from the index of
|
||||
SmallVector<Value> multiDimIdx = srcIndices[i];
|
||||
SmallVector<Value> dbgVal = srcIndices[i];
|
||||
Value dynIdx0 = multiDimIdx[outOrd[0]];
|
||||
Value staIdx0 = i32_val(0);
|
||||
Value dynIdx1 = multiDimIdx[outOrd[1]];
|
||||
Value staIdx1 = i32_val(0);
|
||||
// if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
|
||||
// if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
|
||||
// dynIdx0 = addOp.getLhs();
|
||||
// staIdx0 = mul(udiv(cstRhs, minVecVal), minVecVal);
|
||||
// }
|
||||
if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
|
||||
if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
|
||||
dynIdx1 = addOp.getLhs();
|
||||
staIdx1 = addOp.getRhs();
|
||||
}
|
||||
|
||||
|
||||
// step 2: do swizzling
|
||||
Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
|
||||
multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
|
||||
Value off_1 = mul(multiDimIdx[outOrd[1]], dstStrides[outOrd[1]]);
|
||||
Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase));
|
||||
phaseId = urem(phaseId, i32_val(maxPhase));
|
||||
Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
|
||||
off_0 = mul(off_0, outVecVal);
|
||||
Value stride0 = dstStrides[outOrd[0]];
|
||||
Value stride1 = dstStrides[outOrd[1]];
|
||||
|
||||
// offset along non-contiguous dimension
|
||||
Value off1 = mul(dynIdx1, stride1);
|
||||
// swizzled offset along contiguous dimension
|
||||
Value phaseId = urem(udiv(dynIdx1, i32_val(perPhase)), i32_val(maxPhase));
|
||||
Value off0 = xor_(udiv(dynIdx0, outVecVal), phaseId);
|
||||
off0 = mul(off0, outVecVal);
|
||||
Value remained = urem(dynIdx0, outVecVal);
|
||||
remained = udiv(remained, minVecVal);
|
||||
off_0 = add(off_0, mul(remained, minVecVal));
|
||||
Value offset = add(off_1, mul(off_0, dstStrides[outOrd[0]]));
|
||||
off0 = add(off0, mul(remained, minVecVal));
|
||||
Value offset = add(off1, mul(off0, stride0));
|
||||
offset = add(offset, mul(staIdx1, stride1));
|
||||
offset = add(offset, mul(staIdx0, stride0));
|
||||
|
||||
// step 3: store
|
||||
Value smemAddr = gep(elemPtrTy, smemBase, offset);
|
||||
|
@@ -376,4 +376,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
|
||||
bench_flash_attention.run(save_path='.', print_data=True)
|
||||
#bench_flash_attention.run(save_path='.', print_data=True)
|
Reference in New Issue
Block a user