dq now mma
This commit is contained in:
@@ -72,23 +72,24 @@ void storeDistributedToShared(Value src, Value llSrc,
|
||||
Value staIdx1 = i32_val(0);
|
||||
Value stride0 = dstStrides[outOrd[0]];
|
||||
Value stride1 = dstStrides[outOrd[1]];
|
||||
if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
|
||||
if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
|
||||
unsigned rhsVal = cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
|
||||
unsigned key = (rhsVal/outVec) % maxPhase;
|
||||
llvm::outs() << srcDistributedLayout.dyn_cast<MmaEncodingAttr>() << " " << rhsVal << " " << key << "\n";
|
||||
if(cache.find(key) == cache.end())
|
||||
cache[key] = dynIdx0;
|
||||
dynIdx0 = cache[key];
|
||||
staIdx0 = i32_val((rhsVal)/(outVec*maxPhase)*(outVec*maxPhase));
|
||||
}
|
||||
if(auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
|
||||
if(auto cstRhs = dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
|
||||
dynIdx1 = addOp.getLhs();
|
||||
staIdx1 = addOp.getRhs();
|
||||
}
|
||||
|
||||
|
||||
if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx0.getDefiningOp()))
|
||||
if (auto cstRhs =
|
||||
dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
|
||||
unsigned rhsVal =
|
||||
cstRhs.getValue().cast<IntegerAttr>().getValue().getSExtValue();
|
||||
unsigned key = (rhsVal / outVec) % maxPhase;
|
||||
if (cache.find(key) == cache.end())
|
||||
cache[key] = dynIdx0;
|
||||
dynIdx0 = cache[key];
|
||||
staIdx0 =
|
||||
i32_val((rhsVal) / (outVec * maxPhase) * (outVec * maxPhase));
|
||||
}
|
||||
if (auto addOp = dyn_cast<LLVM::AddOp>(dynIdx1.getDefiningOp()))
|
||||
if (auto cstRhs =
|
||||
dyn_cast<LLVM::ConstantOp>(addOp.getRhs().getDefiningOp())) {
|
||||
dynIdx1 = addOp.getLhs();
|
||||
staIdx1 = addOp.getRhs();
|
||||
}
|
||||
|
||||
// offset along non-contiguous dimension
|
||||
Value off1 = mul(dynIdx1, stride1);
|
||||
@@ -100,10 +101,9 @@ void storeDistributedToShared(Value src, Value llSrc,
|
||||
remained = udiv(remained, minVecVal);
|
||||
off0 = add(off0, mul(remained, minVecVal));
|
||||
Value offset = add(off1, mul(off0, stride0));
|
||||
|
||||
Value staOffset = add(mul(staIdx1, stride1), mul(staIdx0, stride0));
|
||||
// add static offset
|
||||
offset = add(offset, mul(staIdx1, stride1));
|
||||
offset = add(offset, mul(staIdx0, stride0));
|
||||
offset = add(offset, staOffset);
|
||||
|
||||
// step 3: store
|
||||
Value smemAddr = gep(elemPtrTy, smemBase, offset);
|
||||
|
Reference in New Issue
Block a user