diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 6314e7a8e..3704fc089 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -4082,23 +4082,38 @@ struct InsertSliceAsyncOpConversion auto selectOp = select(pred, i32_val(byteWidth), i32_val(0)); srcSize = ptxBuilder.newOperand(selectOp, "r"); if (llOther) { - // Construct a new vecTy - auto vecTy = LLVM::getFixedVectorType(resElemTy, numWordElems); - Value v = rewriter.create(loc, vecTy); - for (size_t i = 0; i < numWordElems; ++i) { - Value falseVal = otherElems[elemIdx + wordElemIdx + i]; - Value indexVal = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), i); - v = insert_element(vecTy, v, falseVal, indexVal); + auto storeVecSize = 4; + auto remStoreElems = numWordElems % storeVecSize; + auto constraint = + resElemTy.getIntOrFloatBitWidth() <= 32 ? "r" : "l"; + for (auto i = 0; i < numWordElems - remStoreElems; + i += storeVecSize) { + PTXBuilder ptxStoreBuilder; + auto *valOperands = ptxStoreBuilder.newListOperand(); + for (auto s = 0; s < storeVecSize; ++s) { + auto value = otherElems[elemIdx + wordElemIdx + i + s]; + auto *opr = ptxStoreBuilder.newOperand(value, constraint); + valOperands->listAppend(opr); + } + auto *storeDstOperand = ptxStoreBuilder.newAddrOperand( + basePtr, "r", (wordElemIdx + i) * resByteWidth); + auto &st = ptxStoreBuilder.create("st")->shared(); + st.v(storeVecSize).b(resElemTy.getIntOrFloatBitWidth()); + st(storeDstOperand, valOperands).predicate(pred); + ptxStoreBuilder.launch(rewriter, loc, void_ty(getContext())); + } + for (auto i = numWordElems - remStoreElems; i < numWordElems; ++i) { + PTXBuilder ptxStoreBuilder; + auto value = otherElems[elemIdx + wordElemIdx + i]; + auto *storeValOperand = + ptxStoreBuilder.newOperand(value, constraint); + auto *storeDstOperand = ptxStoreBuilder.newAddrOperand( + basePtr, "r", (wordElemIdx + i) * resByteWidth); + auto &st = ptxStoreBuilder.create("st")->shared(); + st.b(resElemTy.getIntOrFloatBitWidth()); + st(storeDstOperand, storeValOperand).predicate(pred); + ptxStoreBuilder.launch(rewriter, loc, void_ty(getContext())); } - v = bitcast(v, IntegerType::get(getContext(), bitWidth)); - // Write shared memory if predicate is true - PTXBuilder ptxStoreBuilder; - auto *valOperand = ptxStoreBuilder.newOperand(v, "r"); - auto &st = *ptxStoreBuilder.create("st"); - st.shared().o("b" + std::to_string(bitWidth)); - st(dstOperand, valOperand).predicate(pred); - ptxStoreBuilder.launch(rewriter, loc, void_ty(getContext())); } } copyAsyncOp(dstOperand, srcOperand, copySize, srcSize); diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 394ba1d0c..41debbf71 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -545,11 +545,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { %other_tensor = tt.splat %other : (f32) -> tensor<16x64xf32, #AL> // CHECK: llvm.inline_asm has_side_effects asm_dialect = att - // CHECK-SAME: @${{.*}} st.shared.b128 [ ${{.*}} + 0 ], ${{.*}} + // CHECK-SAME: @${{.*}} st.shared.v4.b32 [ ${{.*}} + 0 ] // CHECK: llvm.inline_asm has_side_effects asm_dialect = att // CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, ${{.*}} // CHECK: llvm.inline_asm has_side_effects asm_dialect = att - // CHECK-SAME: @${{.*}} st.shared.b128 [ ${{.*}} + 16 ], ${{.*}} + // CHECK-SAME: @${{.*}} st.shared.v4.b32 [ ${{.*}} + 16 ] // CHECK: llvm.inline_asm has_side_effects asm_dialect = att // CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, ${{.*}} // CHECK: llvm.inline_asm has_side_effects asm_dialect = att