Fix
This commit is contained in:
@@ -4082,24 +4082,39 @@ struct InsertSliceAsyncOpConversion
|
|||||||
auto selectOp = select(pred, i32_val(byteWidth), i32_val(0));
|
auto selectOp = select(pred, i32_val(byteWidth), i32_val(0));
|
||||||
srcSize = ptxBuilder.newOperand(selectOp, "r");
|
srcSize = ptxBuilder.newOperand(selectOp, "r");
|
||||||
if (llOther) {
|
if (llOther) {
|
||||||
// Construct a new vecTy
|
auto storeVecSize = 4;
|
||||||
auto vecTy = LLVM::getFixedVectorType(resElemTy, numWordElems);
|
auto remStoreElems = numWordElems % storeVecSize;
|
||||||
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
auto constraint =
|
||||||
for (size_t i = 0; i < numWordElems; ++i) {
|
resElemTy.getIntOrFloatBitWidth() <= 32 ? "r" : "l";
|
||||||
Value falseVal = otherElems[elemIdx + wordElemIdx + i];
|
for (auto i = 0; i < numWordElems - remStoreElems;
|
||||||
Value indexVal = createIndexAttrConstant(
|
i += storeVecSize) {
|
||||||
rewriter, loc, this->getTypeConverter()->getIndexType(), i);
|
|
||||||
v = insert_element(vecTy, v, falseVal, indexVal);
|
|
||||||
}
|
|
||||||
v = bitcast(v, IntegerType::get(getContext(), bitWidth));
|
|
||||||
// Write shared memory if predicate is true
|
|
||||||
PTXBuilder ptxStoreBuilder;
|
PTXBuilder ptxStoreBuilder;
|
||||||
auto *valOperand = ptxStoreBuilder.newOperand(v, "r");
|
auto *valOperands = ptxStoreBuilder.newListOperand();
|
||||||
auto &st = *ptxStoreBuilder.create<PTXInstr>("st");
|
for (auto s = 0; s < storeVecSize; ++s) {
|
||||||
st.shared().o("b" + std::to_string(bitWidth));
|
auto value = otherElems[elemIdx + wordElemIdx + i + s];
|
||||||
st(dstOperand, valOperand).predicate(pred);
|
auto *opr = ptxStoreBuilder.newOperand(value, constraint);
|
||||||
|
valOperands->listAppend(opr);
|
||||||
|
}
|
||||||
|
auto *storeDstOperand = ptxStoreBuilder.newAddrOperand(
|
||||||
|
basePtr, "r", (wordElemIdx + i) * resByteWidth);
|
||||||
|
auto &st = ptxStoreBuilder.create<PTXInstr>("st")->shared();
|
||||||
|
st.v(storeVecSize).b(resElemTy.getIntOrFloatBitWidth());
|
||||||
|
st(storeDstOperand, valOperands).predicate(pred);
|
||||||
ptxStoreBuilder.launch(rewriter, loc, void_ty(getContext()));
|
ptxStoreBuilder.launch(rewriter, loc, void_ty(getContext()));
|
||||||
}
|
}
|
||||||
|
for (auto i = numWordElems - remStoreElems; i < numWordElems; ++i) {
|
||||||
|
PTXBuilder ptxStoreBuilder;
|
||||||
|
auto value = otherElems[elemIdx + wordElemIdx + i];
|
||||||
|
auto *storeValOperand =
|
||||||
|
ptxStoreBuilder.newOperand(value, constraint);
|
||||||
|
auto *storeDstOperand = ptxStoreBuilder.newAddrOperand(
|
||||||
|
basePtr, "r", (wordElemIdx + i) * resByteWidth);
|
||||||
|
auto &st = ptxStoreBuilder.create<PTXInstr>("st")->shared();
|
||||||
|
st.b(resElemTy.getIntOrFloatBitWidth());
|
||||||
|
st(storeDstOperand, storeValOperand).predicate(pred);
|
||||||
|
ptxStoreBuilder.launch(rewriter, loc, void_ty(getContext()));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
|
copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
|
||||||
ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
|
ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
|
||||||
|
@@ -545,11 +545,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
|||||||
%other_tensor = tt.splat %other : (f32) -> tensor<16x64xf32, #AL>
|
%other_tensor = tt.splat %other : (f32) -> tensor<16x64xf32, #AL>
|
||||||
|
|
||||||
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
||||||
// CHECK-SAME: @${{.*}} st.shared.b128 [ ${{.*}} + 0 ], ${{.*}}
|
// CHECK-SAME: @${{.*}} st.shared.v4.b32 [ ${{.*}} + 0 ]
|
||||||
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
||||||
// CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, ${{.*}}
|
// CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, ${{.*}}
|
||||||
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
||||||
// CHECK-SAME: @${{.*}} st.shared.b128 [ ${{.*}} + 16 ], ${{.*}}
|
// CHECK-SAME: @${{.*}} st.shared.v4.b32 [ ${{.*}} + 16 ]
|
||||||
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
||||||
// CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, ${{.*}}
|
// CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, ${{.*}}
|
||||||
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
|
||||||
|
Reference in New Issue
Block a user