[Triton-MLIR][BACKEND] Minor fixes of shared memory in ReduceOpConversion (#924)

This commit is contained in:
Qingyi Liu
2022-11-29 11:50:31 +08:00
committed by GitHub
parent c87fbf886e
commit 661be523c0
3 changed files with 56 additions and 67 deletions

View File

@@ -58,11 +58,6 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
smemShapes[1].push_back(numWarps * 32); smemShapes[1].push_back(numWarps * 32);
/// FIXME(Qingyi): This requirement is actually not necessary, because it is
/// always smaller than smemShapes[0] shared memory block2
smemShapes[2] = convertType<unsigned>(getSrcShape());
smemShapes[2].erase(smemShapes[2].begin() + axis);
return smemShapes; return smemShapes;
} }

View File

@@ -43,6 +43,8 @@ using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::LLVM::getStructFromElements; using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::MMA16816ConversionHelper; using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::LLVM::SharedMemoryObject; using ::mlir::LLVM::SharedMemoryObject;
using ::mlir::LLVM::shflSync;
using ::mlir::LLVM::storeShared;
using ::mlir::triton::gpu::BlockedEncodingAttr; using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getElemsPerThread; using ::mlir::triton::gpu::getElemsPerThread;
@@ -307,20 +309,6 @@ T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape,
reorder(shape, order)); reorder(shape, order));
} }
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
Value val, Value pred) {
MLIRContext *ctx = rewriter.getContext();
unsigned bits = val.getType().getIntOrFloatBitWidth();
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
PTXBuilder builder;
auto *ptrOpr = builder.newAddrOperand(ptr, "r");
auto *valOpr = builder.newOperand(val, c);
auto &st = builder.create<>("st")->shared().b(bits);
st(ptrOpr, valOpr).predicate(pred, "b");
return builder.launch(rewriter, loc, void_ty(ctx));
}
struct ConvertTritonGPUOpToLLVMPatternBase { struct ConvertTritonGPUOpToLLVMPatternBase {
static Value static Value
getStructFromSharedMemoryObject(Location loc, getStructFromSharedMemoryObject(Location loc,
@@ -1342,9 +1330,6 @@ private:
RedOp redOp, Value &acc, Value &accIndex, Value cur, RedOp redOp, Value &acc, Value &accIndex, Value cur,
Value curIndex, bool isFirst) const; Value curIndex, bool isFirst) const;
Value shflSync(ConversionPatternRewriter &rewriter, Location loc, Value val,
int i) const;
// Use shared memory for reduction within warps and across warps // Use shared memory for reduction within warps and across warps
LogicalResult matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor, LogicalResult matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const; ConversionPatternRewriter &rewriter) const;
@@ -1472,34 +1457,6 @@ void ReduceOpConversion::accumulateWithIndex(
} }
} }
Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
Location loc, Value val, int i) const {
unsigned bits = val.getType().getIntOrFloatBitWidth();
if (bits == 64) {
Type vecTy = vec_ty(f32_ty, 2);
Value vec = bitcast(val, vecTy);
Value val0 = extract_element(f32_ty, vec, i32_val(0));
Value val1 = extract_element(f32_ty, vec, i32_val(1));
val0 = shflSync(rewriter, loc, val0, i);
val1 = shflSync(rewriter, loc, val1, i);
vec = undef(vecTy);
vec = insert_element(vecTy, vec, val0, i32_val(0));
vec = insert_element(vecTy, vec, val1, i32_val(1));
return bitcast(vec, val.getType());
}
PTXBuilder builder;
auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32");
auto *dOpr = builder.newOperand("=r");
auto *aOpr = builder.newOperand(val, "r");
auto *bOpr = builder.newConstantOperand(i);
auto *cOpr = builder.newConstantOperand("0x1f");
auto *maskOpr = builder.newConstantOperand("0xffffffff");
shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
return builder.launch(rewriter, loc, val.getType(), false);
}
LogicalResult ReduceOpConversion::matchAndRewriteBasic( LogicalResult ReduceOpConversion::matchAndRewriteBasic(
triton::ReduceOp op, OpAdaptor adaptor, triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
@@ -1665,7 +1622,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
auto smemShapes = helper.getScratchConfigsFast(); auto smemShapes = helper.getScratchConfigsFast();
unsigned elems = product<unsigned>(smemShapes[0]); unsigned elems = product<unsigned>(smemShapes[0]);
unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1])); unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
maxElems = std::max(maxElems, product<unsigned>(smemShapes[2]));
Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems)); Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems));
indexSmemBase = bitcast(indexSmemBase, indexPtrTy); indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
@@ -1725,11 +1681,11 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
// reduce within warps // reduce within warps
for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) { for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
Value shfl = shflSync(rewriter, loc, acc, N); Value shfl = shflSync(loc, rewriter, acc, N);
if (!withIndex) { if (!withIndex) {
accumulate(rewriter, loc, op.redOp(), acc, shfl, false); accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
} else { } else {
Value shflIndex = shflSync(rewriter, loc, accIndex, N); Value shflIndex = shflSync(loc, rewriter, accIndex, N);
accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl, accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
shflIndex, false); shflIndex, false);
} }
@@ -1750,8 +1706,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
barrier(); barrier();
// the second round of shuffle reduction // the second round of shuffle reduction
// now the problem size: sizeInterWarps, s1, s2, .. , sn => // now the problem size: sizeInterWarps, s1, s2, .. , sn
// 1, s1, s2, .. , sn
// where sizeInterWarps is 2^m // where sizeInterWarps is 2^m
// //
// each thread needs to process: // each thread needs to process:
@@ -1762,6 +1717,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
Value readOffset = threadId; Value readOffset = threadId;
for (unsigned round = 0; round < elemsPerThread; ++round) { for (unsigned round = 0; round < elemsPerThread; ++round) {
Value readPtr = gep(elemPtrTy, smemBase, readOffset); Value readPtr = gep(elemPtrTy, smemBase, readOffset);
// FIXME(Qingyi): need predicate icmp_slt(threadId, i32_val(sizeInerWarps))
Value acc = load(readPtr); Value acc = load(readPtr);
Value accIndex; Value accIndex;
if (withIndex) { if (withIndex) {
@@ -1770,17 +1726,18 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
} }
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) { for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
Value shfl = shflSync(rewriter, loc, acc, N); Value shfl = shflSync(loc, rewriter, acc, N);
if (!withIndex) { if (!withIndex) {
accumulate(rewriter, loc, op.redOp(), acc, shfl, false); accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
} else { } else {
Value shflIndex = shflSync(rewriter, loc, accIndex, N); Value shflIndex = shflSync(loc, rewriter, accIndex, N);
accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl, accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
shflIndex, false); shflIndex, false);
} }
} }
Value writeOffset = udiv(readOffset, i32_val(sizeInterWarps)); // only the first thread in each sizeInterWarps is writing
Value writeOffset = readOffset;
Value writePtr = gep(elemPtrTy, smemBase, writeOffset); Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
@@ -1807,22 +1764,17 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) { if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
// nd-tensor where n >= 1 // nd-tensor where n >= 1
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>(); auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
SmallVector<unsigned> resultOrd; auto resultShape = resultTy.getShape();
for (auto ord : order) {
if (ord != 0)
resultOrd.push_back(ord - 1);
}
unsigned resultElems = getElemsPerThread(resultTy); unsigned resultElems = getElemsPerThread(resultTy);
auto resultIndices = auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
emitIndices(loc, rewriter, resultLayout, resultTy.getShape());
assert(resultIndices.size() == resultElems); assert(resultIndices.size() == resultElems);
SmallVector<Value> resultVals(resultElems); SmallVector<Value> resultVals(resultElems);
for (size_t i = 0; i < resultElems; ++i) { for (size_t i = 0; i < resultElems; ++i) {
SmallVector<Value> readIdx = resultIndices[i]; SmallVector<Value> readIdx = resultIndices[i];
readIdx.insert(readIdx.begin() + axis, i32_val(0));
Value readOffset = Value readOffset =
linearize(rewriter, loc, readIdx, smemShapes[2], resultOrd); linearize(rewriter, loc, readIdx, smemShapes[0], order);
Value readPtr = gep(elemPtrTy, smemBase, readOffset); Value readPtr = gep(elemPtrTy, smemBase, readOffset);
Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset); Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr); resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);

View File

@@ -269,6 +269,48 @@ getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
/*offsets=*/{elems.begin() + 1 + rank, elems.end()}}; /*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
} }
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
Value val, Value pred) {
MLIRContext *ctx = rewriter.getContext();
unsigned bits = val.getType().getIntOrFloatBitWidth();
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
PTXBuilder builder;
auto *ptrOpr = builder.newAddrOperand(ptr, "r");
auto *valOpr = builder.newOperand(val, c);
auto &st = builder.create<>("st")->shared().b(bits);
st(ptrOpr, valOpr).predicate(pred, "b");
return builder.launch(rewriter, loc, void_ty(ctx));
}
Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
int i) {
unsigned bits = val.getType().getIntOrFloatBitWidth();
if (bits == 64) {
Type vecTy = vec_ty(f32_ty, 2);
Value vec = bitcast(val, vecTy);
Value val0 = extract_element(f32_ty, vec, i32_val(0));
Value val1 = extract_element(f32_ty, vec, i32_val(1));
val0 = shflSync(loc, rewriter, val0, i);
val1 = shflSync(loc, rewriter, val1, i);
vec = undef(vecTy);
vec = insert_element(vecTy, vec, val0, i32_val(0));
vec = insert_element(vecTy, vec, val1, i32_val(1));
return bitcast(vec, val.getType());
}
PTXBuilder builder;
auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32");
auto *dOpr = builder.newOperand("=r");
auto *aOpr = builder.newOperand(val, "r");
auto *bOpr = builder.newConstantOperand(i);
auto *cOpr = builder.newConstantOperand("0x1f");
auto *maskOpr = builder.newConstantOperand("0xffffffff");
shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
return builder.launch(rewriter, loc, val.getType(), false);
}
} // namespace LLVM } // namespace LLVM
} // namespace mlir } // namespace mlir