[Triton-MLIR][BACKEND] Minor fixes of shared memory in ReduceOpConversion (#924)
This commit is contained in:
@@ -269,6 +269,48 @@ getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
|
||||
/*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
|
||||
}
|
||||
|
||||
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
|
||||
Value val, Value pred) {
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
||||
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
|
||||
|
||||
PTXBuilder builder;
|
||||
auto *ptrOpr = builder.newAddrOperand(ptr, "r");
|
||||
auto *valOpr = builder.newOperand(val, c);
|
||||
auto &st = builder.create<>("st")->shared().b(bits);
|
||||
st(ptrOpr, valOpr).predicate(pred, "b");
|
||||
return builder.launch(rewriter, loc, void_ty(ctx));
|
||||
}
|
||||
|
||||
Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
|
||||
int i) {
|
||||
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
||||
|
||||
if (bits == 64) {
|
||||
Type vecTy = vec_ty(f32_ty, 2);
|
||||
Value vec = bitcast(val, vecTy);
|
||||
Value val0 = extract_element(f32_ty, vec, i32_val(0));
|
||||
Value val1 = extract_element(f32_ty, vec, i32_val(1));
|
||||
val0 = shflSync(loc, rewriter, val0, i);
|
||||
val1 = shflSync(loc, rewriter, val1, i);
|
||||
vec = undef(vecTy);
|
||||
vec = insert_element(vecTy, vec, val0, i32_val(0));
|
||||
vec = insert_element(vecTy, vec, val1, i32_val(1));
|
||||
return bitcast(vec, val.getType());
|
||||
}
|
||||
|
||||
PTXBuilder builder;
|
||||
auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32");
|
||||
auto *dOpr = builder.newOperand("=r");
|
||||
auto *aOpr = builder.newOperand(val, "r");
|
||||
auto *bOpr = builder.newConstantOperand(i);
|
||||
auto *cOpr = builder.newConstantOperand("0x1f");
|
||||
auto *maskOpr = builder.newConstantOperand("0xffffffff");
|
||||
shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
|
||||
return builder.launch(rewriter, loc, val.getType(), false);
|
||||
}
|
||||
|
||||
} // namespace LLVM
|
||||
} // namespace mlir
|
||||
|
||||
|
Reference in New Issue
Block a user