Init
This commit is contained in:
@@ -3974,10 +3974,6 @@ struct InsertSliceAsyncOpConversion
|
|||||||
// %other
|
// %other
|
||||||
SmallVector<Value> otherElems;
|
SmallVector<Value> otherElems;
|
||||||
if (llOther) {
|
if (llOther) {
|
||||||
// FIXME(Keren): always assume other is 0 for now
|
|
||||||
// It's not necessary for now because the pipeline pass will skip
|
|
||||||
// generating insert_slice_async if the load op has any "other" tensor.
|
|
||||||
// assert(false && "insert_slice_async: Other value not supported yet");
|
|
||||||
otherElems = getLLVMElems(other, llOther, rewriter, loc);
|
otherElems = getLLVMElems(other, llOther, rewriter, loc);
|
||||||
assert(srcElems.size() == otherElems.size());
|
assert(srcElems.size() == otherElems.size());
|
||||||
}
|
}
|
||||||
@@ -4094,10 +4090,26 @@ struct InsertSliceAsyncOpConversion
|
|||||||
// We don't use predicate in this case, setting src-size to 0
|
// We don't use predicate in this case, setting src-size to 0
|
||||||
// if there's any mask. cp.async will automatically fill the
|
// if there's any mask. cp.async will automatically fill the
|
||||||
// remaining slots with 0 if cp-size > src-size.
|
// remaining slots with 0 if cp-size > src-size.
|
||||||
// XXX(Keren): Always assume other = 0 for now.
|
auto pred = maskElems[elemIdx + wordElemIdx];
|
||||||
auto selectOp = select(maskElems[elemIdx + wordElemIdx],
|
auto selectOp = select(pred, i32_val(byteWidth), i32_val(0));
|
||||||
i32_val(byteWidth), i32_val(0));
|
|
||||||
srcSize = ptxBuilder.newOperand(selectOp, "r");
|
srcSize = ptxBuilder.newOperand(selectOp, "r");
|
||||||
|
if (llOther) {
|
||||||
|
// Construct a new vecTy
|
||||||
|
auto vecTy = LLVM::getFixedVectorType(resElemTy, numWordElems);
|
||||||
|
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||||
|
for (size_t i = 0; i < numWordElems; ++i) {
|
||||||
|
Value falseVal = otherElems[elemIdx + wordElemIdx + i];
|
||||||
|
Value indexVal = createIndexAttrConstant(
|
||||||
|
rewriter, loc, this->getTypeConverter()->getIndexType(), i);
|
||||||
|
v = insert_element(vecTy, v, falseVal, indexVal);
|
||||||
|
}
|
||||||
|
v = bitcast(v, IntegerType::get(getContext(), byteWidth));
|
||||||
|
// Write shared memory if predicate is true
|
||||||
|
auto *valOperand = ptxBuilder.newOperand(v, "r");
|
||||||
|
auto &st = *ptxBuilder.create<PTXInstr>("st");
|
||||||
|
st.shared().o("b" + std::to_string(byteWidth));
|
||||||
|
st(dstOperand, valOperand).predicate(pred);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
|
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
|
||||||
ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
|
ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
|
||||||
|
Reference in New Issue
Block a user