This commit is contained in:
Jokeren
2022-12-05 23:18:13 -08:00
parent 9490252261
commit 46fa29496c

View File

@@ -3974,10 +3974,6 @@ struct InsertSliceAsyncOpConversion
// %other // %other
SmallVector<Value> otherElems; SmallVector<Value> otherElems;
if (llOther) { if (llOther) {
// FIXME(Keren): always assume other is 0 for now
// It's not necessary for now because the pipeline pass will skip
// generating insert_slice_async if the load op has any "other" tensor.
// assert(false && "insert_slice_async: Other value not supported yet");
otherElems = getLLVMElems(other, llOther, rewriter, loc); otherElems = getLLVMElems(other, llOther, rewriter, loc);
assert(srcElems.size() == otherElems.size()); assert(srcElems.size() == otherElems.size());
} }
@@ -4094,10 +4090,26 @@ struct InsertSliceAsyncOpConversion
// We don't use predicate in this case, setting src-size to 0 // We don't use predicate in this case, setting src-size to 0
// if there's any mask. cp.async will automatically fill the // if there's any mask. cp.async will automatically fill the
// remaining slots with 0 if cp-size > src-size. // remaining slots with 0 if cp-size > src-size.
// XXX(Keren): Always assume other = 0 for now. auto pred = maskElems[elemIdx + wordElemIdx];
auto selectOp = select(maskElems[elemIdx + wordElemIdx], auto selectOp = select(pred, i32_val(byteWidth), i32_val(0));
i32_val(byteWidth), i32_val(0));
srcSize = ptxBuilder.newOperand(selectOp, "r"); srcSize = ptxBuilder.newOperand(selectOp, "r");
if (llOther) {
// Construct a new vecTy
auto vecTy = LLVM::getFixedVectorType(resElemTy, numWordElems);
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
for (size_t i = 0; i < numWordElems; ++i) {
Value falseVal = otherElems[elemIdx + wordElemIdx + i];
Value indexVal = createIndexAttrConstant(
rewriter, loc, this->getTypeConverter()->getIndexType(), i);
v = insert_element(vecTy, v, falseVal, indexVal);
}
v = bitcast(v, IntegerType::get(getContext(), byteWidth));
// Write shared memory if predicate is true
auto *valOperand = ptxBuilder.newOperand(v, "r");
auto &st = *ptxBuilder.create<PTXInstr>("st");
st.shared().o("b" + std::to_string(byteWidth));
st(dstOperand, valOperand).predicate(pred);
}
} }
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize); copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
ptxBuilder.launch(rewriter, loc, void_ty(getContext())); ptxBuilder.launch(rewriter, loc, void_ty(getContext()));