more reassociation
This commit is contained in:
@@ -58,6 +58,7 @@ void storeDistributedToShared(Value src, Value llSrc,
|
||||
Value outVecVal = i32_val(outVec);
|
||||
Value minVecVal = i32_val(minVec);
|
||||
Value word;
|
||||
std::map<unsigned, Value> cache;
|
||||
for (unsigned i = 0; i < numElems; ++i) {
|
||||
if (i % minVec == 0)
|
||||
word = undef(wordTy);
|
||||
@@ -69,11 +70,18 @@ void storeDistributedToShared(Value src, Value llSrc,
|
||||
Value staIdx0 = i32_val(0);
|
||||
Value dynIdx1 = multiDimIdx[outOrd[1]];
|
||||
Value staIdx1 = i32_val(0);
|
||||
// if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
|
||||
// if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
|
||||
// dynIdx0 = addOp.getLhs();
|
||||
// staIdx0 = mul(udiv(cstRhs, minVecVal), minVecVal);
|
||||
// }
|
||||
Value stride0 = dstStrides[outOrd[0]];
|
||||
Value stride1 = dstStrides[outOrd[1]];
|
||||
if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
|
||||
if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
|
||||
unsigned rhsVal = cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
|
||||
unsigned key = (rhsVal/outVec) % maxPhase;
|
||||
llvm::outs() << srcDistributedLayout.dyn_cast<MmaEncodingAttr>() << " " << rhsVal << " " << key << "\n";
|
||||
if(cache.find(key) == cache.end())
|
||||
cache[key] = dynIdx0;
|
||||
dynIdx0 = cache[key];
|
||||
staIdx0 = i32_val((rhsVal)/(outVec*maxPhase)*(outVec*maxPhase));
|
||||
}
|
||||
if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
|
||||
if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
|
||||
dynIdx1 = addOp.getLhs();
|
||||
@@ -81,8 +89,6 @@ void storeDistributedToShared(Value src, Value llSrc,
|
||||
}
|
||||
|
||||
|
||||
Value stride0 = dstStrides[outOrd[0]];
|
||||
Value stride1 = dstStrides[outOrd[1]];
|
||||
|
||||
// offset along non-contiguous dimension
|
||||
Value off1 = mul(dynIdx1, stride1);
|
||||
@@ -94,6 +100,8 @@ void storeDistributedToShared(Value src, Value llSrc,
|
||||
remained = udiv(remained, minVecVal);
|
||||
off0 = add(off0, mul(remained, minVecVal));
|
||||
Value offset = add(off1, mul(off0, stride0));
|
||||
|
||||
// add static offset
|
||||
offset = add(offset, mul(staIdx1, stride1));
|
||||
offset = add(offset, mul(staIdx0, stride0));
|
||||
|
||||
|
@@ -376,4 +376,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
|
||||
#bench_flash_attention.run(save_path='.', print_data=True)
|
||||
bench_flash_attention.run(save_path='.', print_data=True)
|
Reference in New Issue
Block a user