diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 2bfbbb090..c5779c021 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -3974,10 +3974,6 @@ struct InsertSliceAsyncOpConversion // %other SmallVector otherElems; if (llOther) { - // FIXME(Keren): always assume other is 0 for now - // It's not necessary for now because the pipeline pass will skip - // generating insert_slice_async if the load op has any "other" tensor. - // assert(false && "insert_slice_async: Other value not supported yet"); otherElems = getLLVMElems(other, llOther, rewriter, loc); assert(srcElems.size() == otherElems.size()); } @@ -4094,10 +4090,26 @@ struct InsertSliceAsyncOpConversion // We don't use predicate in this case, setting src-size to 0 // if there's any mask. cp.async will automatically fill the // remaining slots with 0 if cp-size > src-size. - // XXX(Keren): Always assume other = 0 for now. - auto selectOp = select(maskElems[elemIdx + wordElemIdx], - i32_val(byteWidth), i32_val(0)); + auto pred = maskElems[elemIdx + wordElemIdx]; + auto selectOp = select(pred, i32_val(byteWidth), i32_val(0)); srcSize = ptxBuilder.newOperand(selectOp, "r"); + if (llOther) { + // Construct a new vecTy + auto vecTy = LLVM::getFixedVectorType(resElemTy, numWordElems); + Value v = rewriter.create(loc, vecTy); + for (size_t i = 0; i < numWordElems; ++i) { + Value falseVal = otherElems[elemIdx + wordElemIdx + i]; + Value indexVal = createIndexAttrConstant( + rewriter, loc, this->getTypeConverter()->getIndexType(), i); + v = insert_element(vecTy, v, falseVal, indexVal); + } + v = bitcast(v, IntegerType::get(getContext(), byteWidth)); + // Write shared memory if predicate is true + auto *valOperand = ptxBuilder.newOperand(v, "r"); + auto &st = *ptxBuilder.create("st"); + st.shared().o("b" + std::to_string(byteWidth)); + st(dstOperand, valOperand).predicate(pred); + } } copyAsyncOp(dstOperand, srcOperand, copySize, srcSize); ptxBuilder.launch(rewriter, loc, void_ty(getContext()));