[Triton-MLIR][BACKEND] Minor fixes of shared memory in ReduceOpConversion (#924)
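
Summary of the change: move shflSync and storeShared into the shared mlir::LLVM
helpers (shflSync now takes loc before rewriter), drop the redundant
smemShapes[2] scratch buffer from ReduceOpHelper::getScratchConfigsFast, and
fix the cross-warp write-back offset in ReduceOpConversion::matchAndRewriteFast.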
@@ -58,11 +58,6 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
   unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
   smemShapes[1].push_back(numWarps * 32);
 
-  /// FIXME(Qingyi): This requirement is actually not necessary, because it is
-  /// always smaller than smemShapes[0]
-  /// shared memory block2
-  smemShapes[2] = convertType<unsigned>(getSrcShape());
-  smemShapes[2].erase(smemShapes[2].begin() + axis);
 
   return smemShapes;
 }
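
Note (editorial aside, hedged): smemShapes[1] reserves numWarps * 32 scratch
elements, i.e. one slot per thread assuming 32-thread warps (4 warps -> 128
slots). The deleted smemShapes[2] entry sized a third buffer from the source
shape with the reduced axis erased; per the removed FIXME it is always covered
by the smemShapes[0] block, so reserving it separately was redundant.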
@@ -43,6 +43,8 @@ using ::mlir::LLVM::getStridesFromShapeAndOrder;
 using ::mlir::LLVM::getStructFromElements;
 using ::mlir::LLVM::MMA16816ConversionHelper;
 using ::mlir::LLVM::SharedMemoryObject;
+using ::mlir::LLVM::shflSync;
+using ::mlir::LLVM::storeShared;
 using ::mlir::triton::gpu::BlockedEncodingAttr;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::getElemsPerThread;
@@ -307,20 +309,6 @@ T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape,
                        reorder(shape, order));
 }
 
-Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
-                  Value val, Value pred) {
-  MLIRContext *ctx = rewriter.getContext();
-  unsigned bits = val.getType().getIntOrFloatBitWidth();
-  const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
-
-  PTXBuilder builder;
-  auto *ptrOpr = builder.newAddrOperand(ptr, "r");
-  auto *valOpr = builder.newOperand(val, c);
-  auto &st = builder.create<>("st")->shared().b(bits);
-  st(ptrOpr, valOpr).predicate(pred, "b");
-  return builder.launch(rewriter, loc, void_ty(ctx));
-}
-
 struct ConvertTritonGPUOpToLLVMPatternBase {
   static Value
   getStructFromSharedMemoryObject(Location loc,
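
Note (editorial aside, hedged): the helper moved out above emits one predicated
PTX store to shared memory ("@pred st.shared.b<bits> [ptr], val"). A minimal
CUDA analogue of the 32-bit case, with a hypothetical name, is:

  // Sketch only: mirrors what the generated predicated st.shared.b32 does.
  __device__ void storeSharedF32(float *smemPtr, float val, bool pred) {
    if (pred)         // the predicate guards the store, as in the PTX
      *smemPtr = val; // smemPtr must point into __shared__ memory
  }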
@@ -1342,9 +1330,6 @@ private:
                       RedOp redOp, Value &acc, Value &accIndex, Value cur,
                       Value curIndex, bool isFirst) const;
 
-  Value shflSync(ConversionPatternRewriter &rewriter, Location loc, Value val,
-                 int i) const;
-
   // Use shared memory for reduction within warps and across warps
   LogicalResult matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor,
                                      ConversionPatternRewriter &rewriter) const;
@@ -1472,34 +1457,6 @@ void ReduceOpConversion::accumulateWithIndex(
   }
 }
 
-Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
-                                   Location loc, Value val, int i) const {
-  unsigned bits = val.getType().getIntOrFloatBitWidth();
-
-  if (bits == 64) {
-    Type vecTy = vec_ty(f32_ty, 2);
-    Value vec = bitcast(val, vecTy);
-    Value val0 = extract_element(f32_ty, vec, i32_val(0));
-    Value val1 = extract_element(f32_ty, vec, i32_val(1));
-    val0 = shflSync(rewriter, loc, val0, i);
-    val1 = shflSync(rewriter, loc, val1, i);
-    vec = undef(vecTy);
-    vec = insert_element(vecTy, vec, val0, i32_val(0));
-    vec = insert_element(vecTy, vec, val1, i32_val(1));
-    return bitcast(vec, val.getType());
-  }
-
-  PTXBuilder builder;
-  auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32");
-  auto *dOpr = builder.newOperand("=r");
-  auto *aOpr = builder.newOperand(val, "r");
-  auto *bOpr = builder.newConstantOperand(i);
-  auto *cOpr = builder.newConstantOperand("0x1f");
-  auto *maskOpr = builder.newConstantOperand("0xffffffff");
-  shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
-  return builder.launch(rewriter, loc, val.getType(), false);
-}
-
 LogicalResult ReduceOpConversion::matchAndRewriteBasic(
     triton::ReduceOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
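
Note (editorial aside, hedged): shfl.sync.bfly.b32 with c = 0x1f is a butterfly
exchange: each lane swaps its value with lane (laneId XOR i). In CUDA terms it
corresponds to __shfl_xor_sync, and the callers below drive it with descending
offsets (N = 16, 8, 4, 2, 1 for a full warp), yielding the classic warp
reduction; the 64-bit branch simply applies the same 32-bit shuffle to each
half of the value.

  // Sketch only: the butterfly reduction pattern that shflSync enables.
  __device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset >>= 1)
      val += __shfl_xor_sync(0xffffffffu, val, offset); // exchange with lane ^ offset
    return val; // every lane now holds the sum over all 32 lanes
  }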
@@ -1665,7 +1622,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   auto smemShapes = helper.getScratchConfigsFast();
   unsigned elems = product<unsigned>(smemShapes[0]);
   unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
-  maxElems = std::max(maxElems, product<unsigned>(smemShapes[2]));
   Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems));
   indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
 
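
Note (editorial aside, hedged): a sizing example for the lines above, with
hypothetical numbers. If smemShapes[0] = {32, 4} then elems = 128, and with
numWarps = 4, product(smemShapes[1]) = 128, so maxElems = max(128, 128) = 128;
the index scratch region then starts maxElems elements past smemBase. The
smemShapes[2] term can be dropped because that buffer no longer exists.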
@@ -1725,11 +1681,11 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
 
   // reduce within warps
   for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
-    Value shfl = shflSync(rewriter, loc, acc, N);
+    Value shfl = shflSync(loc, rewriter, acc, N);
     if (!withIndex) {
       accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
     } else {
-      Value shflIndex = shflSync(rewriter, loc, accIndex, N);
+      Value shflIndex = shflSync(loc, rewriter, accIndex, N);
       accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
                           shflIndex, false);
     }
@@ -1750,8 +1706,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   barrier();
 
   // the second round of shuffle reduction
-  // now the problem size: sizeInterWarps, s1, s2, .. , sn =>
-  // 1, s1, s2, .. , sn
+  // now the problem size: sizeInterWarps, s1, s2, .. , sn
   // where sizeInterWarps is 2^m
   //
   // each thread needs to process:
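
Note (editorial aside, hedged): a concrete instance of the comment above. With
8 participating warps, sizeInterWarps = 8 = 2^3, and the shuffle loop further
below runs N = 4, 2, 1, collapsing the leading dimension from sizeInterWarps
to 1 while s1, s2, .., sn stay unchanged.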
@@ -1762,6 +1717,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   Value readOffset = threadId;
   for (unsigned round = 0; round < elemsPerThread; ++round) {
     Value readPtr = gep(elemPtrTy, smemBase, readOffset);
+    // FIXME(Qingyi): need predicate icmp_slt(threadId, i32_val(sizeInerWarps))
     Value acc = load(readPtr);
     Value accIndex;
     if (withIndex) {
@@ -1770,17 +1726,18 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
     }
 
     for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
-      Value shfl = shflSync(rewriter, loc, acc, N);
+      Value shfl = shflSync(loc, rewriter, acc, N);
       if (!withIndex) {
         accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
       } else {
-        Value shflIndex = shflSync(rewriter, loc, accIndex, N);
+        Value shflIndex = shflSync(loc, rewriter, accIndex, N);
         accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
                             shflIndex, false);
       }
     }
 
-    Value writeOffset = udiv(readOffset, i32_val(sizeInterWarps));
+    // only the first thread in each sizeInterWarps is writing
+    Value writeOffset = readOffset;
     Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
     Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
     Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
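
Note (editorial aside, hedged): before this change, each group of
sizeInterWarps lanes compacted its result to readOffset / sizeInterWarps; now
the result stays at readOffset and only one lane per group stores it,
presumably selected by comparing laneIdModSizeInterWarps against zero just
below, matching the new "only the first thread in each sizeInterWarps is
writing" comment.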
@@ -1807,22 +1764,17 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
   if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
     // nd-tensor where n >= 1
     auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
-    SmallVector<unsigned> resultOrd;
-    for (auto ord : order) {
-      if (ord != 0)
-        resultOrd.push_back(ord - 1);
-    }
-
+    auto resultShape = resultTy.getShape();
     unsigned resultElems = getElemsPerThread(resultTy);
-    auto resultIndices =
-        emitIndices(loc, rewriter, resultLayout, resultTy.getShape());
+    auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
     assert(resultIndices.size() == resultElems);
 
     SmallVector<Value> resultVals(resultElems);
     for (size_t i = 0; i < resultElems; ++i) {
       SmallVector<Value> readIdx = resultIndices[i];
       readIdx.insert(readIdx.begin() + axis, i32_val(0));
       Value readOffset =
-          linearize(rewriter, loc, readIdx, smemShapes[2], resultOrd);
+          linearize(rewriter, loc, readIdx, smemShapes[0], order);
       Value readPtr = gep(elemPtrTy, smemBase, readOffset);
       Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
       resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
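
Note (editorial aside, hedged): linearize above folds the multi-dimensional
readIdx into a flat shared-memory offset following order, where order[0] is
assumed to be the fastest-varying dimension. A host-side sketch with
hypothetical names:

  // Sketch only: combine idx into a linear offset, order[0] fastest-varying.
  #include <vector>
  unsigned linearizeHost(const std::vector<unsigned> &idx,
                         const std::vector<unsigned> &shape,
                         const std::vector<unsigned> &order) {
    unsigned linear = 0;
    for (int i = (int)order.size() - 1; i >= 0; --i) // major to minor
      linear = linear * shape[order[i]] + idx[order[i]];
    return linear;
  }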
@@ -269,6 +269,48 @@ getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
           /*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
 }
 
+Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
+                  Value val, Value pred) {
+  MLIRContext *ctx = rewriter.getContext();
+  unsigned bits = val.getType().getIntOrFloatBitWidth();
+  const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
+
+  PTXBuilder builder;
+  auto *ptrOpr = builder.newAddrOperand(ptr, "r");
+  auto *valOpr = builder.newOperand(val, c);
+  auto &st = builder.create<>("st")->shared().b(bits);
+  st(ptrOpr, valOpr).predicate(pred, "b");
+  return builder.launch(rewriter, loc, void_ty(ctx));
+}
+
+Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
+               int i) {
+  unsigned bits = val.getType().getIntOrFloatBitWidth();
+
+  if (bits == 64) {
+    Type vecTy = vec_ty(f32_ty, 2);
+    Value vec = bitcast(val, vecTy);
+    Value val0 = extract_element(f32_ty, vec, i32_val(0));
+    Value val1 = extract_element(f32_ty, vec, i32_val(1));
+    val0 = shflSync(loc, rewriter, val0, i);
+    val1 = shflSync(loc, rewriter, val1, i);
+    vec = undef(vecTy);
+    vec = insert_element(vecTy, vec, val0, i32_val(0));
+    vec = insert_element(vecTy, vec, val1, i32_val(1));
+    return bitcast(vec, val.getType());
+  }
+
+  PTXBuilder builder;
+  auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32");
+  auto *dOpr = builder.newOperand("=r");
+  auto *aOpr = builder.newOperand(val, "r");
+  auto *bOpr = builder.newConstantOperand(i);
+  auto *cOpr = builder.newConstantOperand("0x1f");
+  auto *maskOpr = builder.newConstantOperand("0xffffffff");
+  shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
+  return builder.launch(rewriter, loc, val.getType(), false);
+}
+
 } // namespace LLVM
 } // namespace mlir