This commit is contained in:
Philippe Tillet
2023-01-09 22:11:00 -08:00
parent d88353a5a4
commit ff04a5e9b6
4 changed files with 88 additions and 35 deletions

View File

@@ -72,24 +72,24 @@ void storeDistributedToShared(Value src, Value llSrc,
Value staIdx1 = i32_val(0);
Value stride0 = dstStrides[outOrd[0]];
Value stride1 = dstStrides[outOrd[1]];
// if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
// if (auto cstRhs =
// dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
// unsigned rhsVal =
// cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
// unsigned key = (rhsVal / outVec) % maxPhase;
// if (cache.find(key) == cache.end())
// cache[key] = dynIdx0;
// dynIdx0 = cache[key];
// staIdx0 =
// i32_val((rhsVal) / (outVec * maxPhase) * (outVec * maxPhase));
// }
// if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
// if (auto cstRhs =
// dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
// dynIdx1 = addOp.getLhs();
// staIdx1 = addOp.getRhs();
// }
if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
if (auto cstRhs =
dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
unsigned rhsVal =
cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
unsigned key = (rhsVal / outVec) % maxPhase;
if (cache.find(key) == cache.end())
cache[key] = dynIdx0;
dynIdx0 = cache[key];
staIdx0 =
i32_val((rhsVal) / (outVec * maxPhase) * (outVec * maxPhase));
}
if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
if (auto cstRhs =
dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
dynIdx1 = addOp.getLhs();
staIdx1 = addOp.getRhs();
}
// offset along non-contiguous dimension
Value off1 = mul(dynIdx1, stride1);