[BACKEND] two minor bugfix on StoreOpLowering and kernel launch & support optional other in LoadOpLowering (#69)
* [BACKEND] two minor bugfix on StoreOpLowering and kernel launch & support optional other in LoadOpLowering * Clean code Co-authored-by: goostavz <gzhu@nvidia.com> Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
This commit is contained in:
@@ -309,10 +309,10 @@ public:
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
|
||||
|
||||
SmallVector<Value, 4>
|
||||
SmallVector<Value>
|
||||
getElementsFromStruct(Location loc, Value llvmStruct, unsigned elems,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
SmallVector<Value, 4> results(elems);
|
||||
SmallVector<Value> results(elems);
|
||||
for (unsigned i = 0; i < elems; ++i) {
|
||||
Type type =
|
||||
llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody()[i];
|
||||
@@ -710,7 +710,7 @@ struct StoreOpConversion
|
||||
|
||||
PtxIOInstr asmStoreInstr("st");
|
||||
asmStoreInstr.predicate(maskElems[vecIdx], "b");
|
||||
asmStoreInstr.global().v(width).b(nWords);
|
||||
asmStoreInstr.global().b(width).v(nWords);
|
||||
|
||||
llvm::SmallVector<std::string> asmArgs;
|
||||
|
||||
@@ -970,7 +970,10 @@ struct LoadOpConversion
|
||||
unsigned numElems = getElemsPerThread(blockedLayout, shape);
|
||||
auto ptrVals = getElementsFromStruct(loc, ptr, numElems, rewriter);
|
||||
auto maskVals = getElementsFromStruct(loc, mask, numElems, rewriter);
|
||||
auto otherVals = getElementsFromStruct(loc, other, numElems, rewriter);
|
||||
SmallVector<Value> otherVals;
|
||||
if (other != nullptr) {
|
||||
otherVals = getElementsFromStruct(loc, other, numElems, rewriter);
|
||||
}
|
||||
unsigned nbits = elemTy.isa<FloatType>()
|
||||
? elemTy.cast<FloatType>().getWidth()
|
||||
: elemTy.cast<IntegerType>().getWidth();
|
||||
@@ -1039,31 +1042,33 @@ struct LoadOpConversion
|
||||
asmOss << ", $" << n_words + 2;
|
||||
asmOss << ";";
|
||||
SmallVector<Value> others;
|
||||
for (size_t ii = 0; ii < n_words; ii++) {
|
||||
size_t size = width / nbits;
|
||||
auto vecTy = LLVM::getFixedVectorType(elemTy, size);
|
||||
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||
for (size_t s = 0; s < size; s++) {
|
||||
Value falseVal = otherVals[i + ii * size + s];
|
||||
Value sVal = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
|
||||
v = rewriter.create<LLVM::InsertElementOp>(loc, vecTy, v, falseVal,
|
||||
sVal);
|
||||
if (other != nullptr) {
|
||||
for (size_t ii = 0; ii < n_words; ii++) {
|
||||
size_t size = width / nbits;
|
||||
auto vecTy = LLVM::getFixedVectorType(elemTy, size);
|
||||
Value v = rewriter.create<LLVM::UndefOp>(loc, vecTy);
|
||||
for (size_t s = 0; s < size; s++) {
|
||||
Value falseVal = otherVals[i + ii * size + s];
|
||||
Value sVal = createIndexAttrConstant(
|
||||
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
|
||||
v = rewriter.create<LLVM::InsertElementOp>(loc, vecTy, v, falseVal,
|
||||
sVal);
|
||||
}
|
||||
v = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, IntegerType::get(getContext(), width), v);
|
||||
asmOss << "\n ";
|
||||
asmOss << "@!$" << n_words << " mov.u" << width;
|
||||
asmOss << " $" << ii << ", ";
|
||||
std::ios_base::fmtflags flags(asmOss.flags());
|
||||
if (otherIsSplatConstInt)
|
||||
asmOss << "0x" << std::hex << splatVal;
|
||||
else {
|
||||
asmOss << "$" << n_words + has_l2_evict_policy + 2 + ii;
|
||||
others.push_back(v);
|
||||
}
|
||||
asmOss.flags(flags);
|
||||
asmOss << ";";
|
||||
}
|
||||
v = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, IntegerType::get(getContext(), width), v);
|
||||
asmOss << "\n ";
|
||||
asmOss << "@!$" << n_words << " mov.u" << width;
|
||||
asmOss << " $" << ii << ", ";
|
||||
std::ios_base::fmtflags flags(asmOss.flags());
|
||||
if (otherIsSplatConstInt)
|
||||
asmOss << "0x" << std::hex << splatVal;
|
||||
else {
|
||||
asmOss << "$" << n_words + has_l2_evict_policy + 2 + ii;
|
||||
others.push_back(v);
|
||||
}
|
||||
asmOss.flags(flags);
|
||||
asmOss << ";";
|
||||
}
|
||||
// ---
|
||||
// create inline ASM signature
|
||||
|
Reference in New Issue
Block a user