dq now mma

This commit is contained in:
Philippe Tillet
2023-01-05 20:46:15 -08:00
parent 520b69fe70
commit 6f997f4ecb
2 changed files with 39 additions and 41 deletions

View File

@@ -72,23 +72,24 @@ void storeDistributedToShared(Value src, Value llSrc,
Value staIdx1 = i32_val(0);
Value stride0 = dstStrides[outOrd[0]];
Value stride1 = dstStrides[outOrd[1]];
if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
unsigned rhsVal = cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
unsigned key = (rhsVal/outVec) % maxPhase;
llvm::outs() << srcDistributedLayout.dyn_cast<MmaEncodingAttr>() << " " << rhsVal << " " << key << "\n";
if(cache.find(key) == cache.end())
cache[key] = dynIdx0;
dynIdx0 = cache[key];
staIdx0 = i32_val((rhsVal)/(outVec*maxPhase)*(outVec*maxPhase));
}
if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
dynIdx1 = addOp.getLhs();
staIdx1 = addOp.getRhs();
}
if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
if (auto cstRhs =
dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
unsigned rhsVal =
cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
unsigned key = (rhsVal / outVec) % maxPhase;
if (cache.find(key) == cache.end())
cache[key] = dynIdx0;
dynIdx0 = cache[key];
staIdx0 =
i32_val((rhsVal) / (outVec * maxPhase) * (outVec * maxPhase));
}
if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
if (auto cstRhs =
dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
dynIdx1 = addOp.getLhs();
staIdx1 = addOp.getRhs();
}
// offset along non-contiguous dimension
Value off1 = mul(dynIdx1, stride1);
@@ -100,10 +101,9 @@ void storeDistributedToShared(Value src, Value llSrc,
remained = udiv(remained, minVecVal);
off0 = add(off0, mul(remained, minVecVal));
Value offset = add(off1, mul(off0, stride0));
Value staOffset = add(mul(staIdx1, stride1), mul(staIdx0, stride0));
// add static offset
offset = add(offset, mul(staIdx1, stride1));
offset = add(offset, mul(staIdx0, stride0));
offset = add(offset, staOffset);
// step 3: store
Value smemAddr = gep(elemPtrTy, smemBase, offset);