trying to decrease register pressure

This commit is contained in:
Phil Tillet
2023-01-05 13:02:38 -08:00
parent 1bde80b1e8
commit 764134ee34
2 changed files with 30 additions and 12 deletions

View File

@@ -65,19 +65,37 @@ void storeDistributedToShared(Value src, Value llSrc,
if (i % minVec == minVec - 1) {
// step 1: recover the multidim_index of the i-th source element (srcIndices[i])
SmallVector<Value> multiDimIdx = srcIndices[i];
SmallVector<Value> dbgVal = srcIndices[i];
Value dynIdx0 = multiDimIdx[outOrd[0]];
Value staIdx0 = i32_val(0);
Value dynIdx1 = multiDimIdx[outOrd[1]];
Value staIdx1 = i32_val(0);
// if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
// if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
// dynIdx0 = addOp.getLhs();
// staIdx0 = mul(udiv(cstRhs, minVecVal), minVecVal);
// }
if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
dynIdx1 = addOp.getLhs();
staIdx1 = addOp.getRhs();
}
// step 2: do swizzling
Value remained = urem(multiDimIdx[outOrd[0]], outVecVal);
multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal);
Value off_1 = mul(multiDimIdx[outOrd[1]], dstStrides[outOrd[1]]);
Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase));
phaseId = urem(phaseId, i32_val(maxPhase));
Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId);
off_0 = mul(off_0, outVecVal);
Value stride0 = dstStrides[outOrd[0]];
Value stride1 = dstStrides[outOrd[1]];
// offset along non-contiguous dimension
Value off1 = mul(dynIdx1, stride1);
// swizzled offset along contiguous dimension
Value phaseId = urem(udiv(dynIdx1, i32_val(perPhase)), i32_val(maxPhase));
Value off0 = xor_(udiv(dynIdx0, outVecVal), phaseId);
off0 = mul(off0, outVecVal);
Value remained = urem(dynIdx0, outVecVal);
remained = udiv(remained, minVecVal);
off_0 = add(off_0, mul(remained, minVecVal));
Value offset = add(off_1, mul(off_0, dstStrides[outOrd[0]]));
off0 = add(off0, mul(remained, minVecVal));
Value offset = add(off1, mul(off0, stride0));
offset = add(offset, mul(staIdx1, stride1));
offset = add(offset, mul(staIdx0, stride0));
// step 3: store
Value smemAddr = gep(elemPtrTy, smemBase, offset);

View File

@@ -376,4 +376,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
bench_flash_attention.run(save_path='.', print_data=True)
#bench_flash_attention.run(save_path='.', print_data=True)