more reassociation

This commit is contained in:
Phil Tillet
2023-01-05 16:05:11 -08:00
parent 764134ee34
commit 520b69fe70
2 changed files with 16 additions and 8 deletions

View File

@@ -58,6 +58,7 @@ void storeDistributedToShared(Value src, Value llSrc,
Value outVecVal = i32_val(outVec);
Value minVecVal = i32_val(minVec);
Value word;
std::map<unsigned, Value> cache;
for (unsigned i = 0; i < numElems; ++i) {
if (i % minVec == 0)
word = undef(wordTy);
@@ -69,11 +70,18 @@ void storeDistributedToShared(Value src, Value llSrc,
Value staIdx0 = i32_val(0);
Value dynIdx1 = multiDimIdx[outOrd[1]];
Value staIdx1 = i32_val(0);
// if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
// if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
// dynIdx0 = addOp.getLhs();
// staIdx0 = mul(udiv(cstRhs, minVecVal), minVecVal);
// }
Value stride0 = dstStrides[outOrd[0]];
Value stride1 = dstStrides[outOrd[1]];
if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
unsigned rhsVal = cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
unsigned key = (rhsVal/outVec) % maxPhase;
llvm::outs() << srcDistributedLayout.dyn_cast<MmaEncodingAttr>() << " " << rhsVal << " " << key << "\n";
if(cache.find(key) == cache.end())
cache[key] = dynIdx0;
dynIdx0 = cache[key];
staIdx0 = i32_val((rhsVal)/(outVec*maxPhase)*(outVec*maxPhase));
}
if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
dynIdx1 = addOp.getLhs();
@@ -81,8 +89,6 @@ void storeDistributedToShared(Value src, Value llSrc,
}
Value stride0 = dstStrides[outOrd[0]];
Value stride1 = dstStrides[outOrd[1]];
// offset along non-contiguous dimension
Value off1 = mul(dynIdx1, stride1);
@@ -94,6 +100,8 @@ void storeDistributedToShared(Value src, Value llSrc,
remained = udiv(remained, minVecVal);
off0 = add(off0, mul(remained, minVecVal));
Value offset = add(off1, mul(off0, stride0));
// add static offset
offset = add(offset, mul(staIdx1, stride1));
offset = add(offset, mul(staIdx0, stride0));

View File

@@ -376,4 +376,4 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
return ms
#bench_flash_attention.run(save_path='.', print_data=True)
bench_flash_attention.run(save_path='.', print_data=True)