This commit is contained in:
Jokeren
2022-12-06 17:09:09 -08:00
parent e817fdf1b9
commit 43408fef5a
2 changed files with 33 additions and 18 deletions

View File

@@ -4082,23 +4082,38 @@ struct InsertSliceAsyncOpConversion
auto selectOp = select(pred, i32_val(byteWidth), i32_val(0));
srcSize = ptxBuilder.newOperand(selectOp, "r");
if (llOther) {
// Construct a new vecTy
auto vecTy = LLVM::getFixedVectorType(resElemTy, numWordElems);
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (size_t i = 0; i < numWordElems; ++i) {
Value falseVal = otherElems[elemIdx + wordElemIdx + i];
Value indexVal = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), i);
v = insert_element(vecTy, v, falseVal, indexVal);
auto storeVecSize = 4;
auto remStoreElems = numWordElems % storeVecSize;
auto constraint =
resElemTy.getIntOrFloatBitWidth() <= 32 ? "r" : "l";
for (auto i = 0; i < numWordElems - remStoreElems;
i += storeVecSize) {
PTXBuilder ptxStoreBuilder;
auto *valOperands = ptxStoreBuilder.newListOperand();
for (auto s = 0; s < storeVecSize; ++s) {
auto value = otherElems[elemIdx + wordElemIdx + i + s];
auto *opr = ptxStoreBuilder.newOperand(value, constraint);
valOperands->listAppend(opr);
}
auto *storeDstOperand = ptxStoreBuilder.newAddrOperand(
basePtr, "r", (wordElemIdx + i) * resByteWidth);
auto &st = ptxStoreBuilder.create<PTXInstr>("st")->shared();
st.v(storeVecSize).b(resElemTy.getIntOrFloatBitWidth());
st(storeDstOperand, valOperands).predicate(pred);
ptxStoreBuilder.launch(rewriter, loc, void_ty(getContext()));
}
for (auto i = numWordElems - remStoreElems; i < numWordElems; ++i) {
PTXBuilder ptxStoreBuilder;
auto value = otherElems[elemIdx + wordElemIdx + i];
auto *storeValOperand =
ptxStoreBuilder.newOperand(value, constraint);
auto *storeDstOperand = ptxStoreBuilder.newAddrOperand(
basePtr, "r", (wordElemIdx + i) * resByteWidth);
auto &st = ptxStoreBuilder.create<PTXInstr>("st")->shared();
st.b(resElemTy.getIntOrFloatBitWidth());
st(storeDstOperand, storeValOperand).predicate(pred);
ptxStoreBuilder.launch(rewriter, loc, void_ty(getContext()));
}
v = bitcast(v, IntegerType::get(getContext(), bitWidth));
// Write shared memory if predicate is true
PTXBuilder ptxStoreBuilder;
auto *valOperand = ptxStoreBuilder.newOperand(v, "r");
auto &st = *ptxStoreBuilder.create<PTXInstr>("st");
st.shared().o("b" + std::to_string(bitWidth));
st(dstOperand, valOperand).predicate(pred);
ptxStoreBuilder.launch(rewriter, loc, void_ty(getContext()));
}
}
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);