[Triton-MLIR][BACKEND] Minor fixes of shared memory in ReduceOpConversion (#924)
This commit is contained in:
@@ -58,11 +58,6 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
|
|||||||
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||||
smemShapes[1].push_back(numWarps * 32);
|
smemShapes[1].push_back(numWarps * 32);
|
||||||
|
|
||||||
/// FIXME(Qingyi): This requirement is actually not necessary, because it is
|
|
||||||
/// always smaller than smemShapes[0] shared memory block2
|
|
||||||
smemShapes[2] = convertType<unsigned>(getSrcShape());
|
|
||||||
smemShapes[2].erase(smemShapes[2].begin() + axis);
|
|
||||||
|
|
||||||
return smemShapes;
|
return smemShapes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -43,6 +43,8 @@ using ::mlir::LLVM::getStridesFromShapeAndOrder;
|
|||||||
using ::mlir::LLVM::getStructFromElements;
|
using ::mlir::LLVM::getStructFromElements;
|
||||||
using ::mlir::LLVM::MMA16816ConversionHelper;
|
using ::mlir::LLVM::MMA16816ConversionHelper;
|
||||||
using ::mlir::LLVM::SharedMemoryObject;
|
using ::mlir::LLVM::SharedMemoryObject;
|
||||||
|
using ::mlir::LLVM::shflSync;
|
||||||
|
using ::mlir::LLVM::storeShared;
|
||||||
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
||||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||||
using ::mlir::triton::gpu::getElemsPerThread;
|
using ::mlir::triton::gpu::getElemsPerThread;
|
||||||
@@ -307,20 +309,6 @@ T getLinearIndex(ArrayRef<T> multiDimIndex, ArrayRef<T> shape,
|
|||||||
reorder(shape, order));
|
reorder(shape, order));
|
||||||
}
|
}
|
||||||
|
|
||||||
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
|
|
||||||
Value val, Value pred) {
|
|
||||||
MLIRContext *ctx = rewriter.getContext();
|
|
||||||
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
|
||||||
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
|
|
||||||
|
|
||||||
PTXBuilder builder;
|
|
||||||
auto *ptrOpr = builder.newAddrOperand(ptr, "r");
|
|
||||||
auto *valOpr = builder.newOperand(val, c);
|
|
||||||
auto &st = builder.create<>("st")->shared().b(bits);
|
|
||||||
st(ptrOpr, valOpr).predicate(pred, "b");
|
|
||||||
return builder.launch(rewriter, loc, void_ty(ctx));
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ConvertTritonGPUOpToLLVMPatternBase {
|
struct ConvertTritonGPUOpToLLVMPatternBase {
|
||||||
static Value
|
static Value
|
||||||
getStructFromSharedMemoryObject(Location loc,
|
getStructFromSharedMemoryObject(Location loc,
|
||||||
@@ -1342,9 +1330,6 @@ private:
|
|||||||
RedOp redOp, Value &acc, Value &accIndex, Value cur,
|
RedOp redOp, Value &acc, Value &accIndex, Value cur,
|
||||||
Value curIndex, bool isFirst) const;
|
Value curIndex, bool isFirst) const;
|
||||||
|
|
||||||
Value shflSync(ConversionPatternRewriter &rewriter, Location loc, Value val,
|
|
||||||
int i) const;
|
|
||||||
|
|
||||||
// Use shared memory for reduction within warps and across warps
|
// Use shared memory for reduction within warps and across warps
|
||||||
LogicalResult matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor,
|
LogicalResult matchAndRewriteBasic(triton::ReduceOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const;
|
ConversionPatternRewriter &rewriter) const;
|
||||||
@@ -1472,34 +1457,6 @@ void ReduceOpConversion::accumulateWithIndex(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Value ReduceOpConversion::shflSync(ConversionPatternRewriter &rewriter,
|
|
||||||
Location loc, Value val, int i) const {
|
|
||||||
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
|
||||||
|
|
||||||
if (bits == 64) {
|
|
||||||
Type vecTy = vec_ty(f32_ty, 2);
|
|
||||||
Value vec = bitcast(val, vecTy);
|
|
||||||
Value val0 = extract_element(f32_ty, vec, i32_val(0));
|
|
||||||
Value val1 = extract_element(f32_ty, vec, i32_val(1));
|
|
||||||
val0 = shflSync(rewriter, loc, val0, i);
|
|
||||||
val1 = shflSync(rewriter, loc, val1, i);
|
|
||||||
vec = undef(vecTy);
|
|
||||||
vec = insert_element(vecTy, vec, val0, i32_val(0));
|
|
||||||
vec = insert_element(vecTy, vec, val1, i32_val(1));
|
|
||||||
return bitcast(vec, val.getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
PTXBuilder builder;
|
|
||||||
auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32");
|
|
||||||
auto *dOpr = builder.newOperand("=r");
|
|
||||||
auto *aOpr = builder.newOperand(val, "r");
|
|
||||||
auto *bOpr = builder.newConstantOperand(i);
|
|
||||||
auto *cOpr = builder.newConstantOperand("0x1f");
|
|
||||||
auto *maskOpr = builder.newConstantOperand("0xffffffff");
|
|
||||||
shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
|
|
||||||
return builder.launch(rewriter, loc, val.getType(), false);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult ReduceOpConversion::matchAndRewriteBasic(
|
LogicalResult ReduceOpConversion::matchAndRewriteBasic(
|
||||||
triton::ReduceOp op, OpAdaptor adaptor,
|
triton::ReduceOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
@@ -1665,7 +1622,6 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
|||||||
auto smemShapes = helper.getScratchConfigsFast();
|
auto smemShapes = helper.getScratchConfigsFast();
|
||||||
unsigned elems = product<unsigned>(smemShapes[0]);
|
unsigned elems = product<unsigned>(smemShapes[0]);
|
||||||
unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
|
unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
|
||||||
maxElems = std::max(maxElems, product<unsigned>(smemShapes[2]));
|
|
||||||
Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems));
|
Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems));
|
||||||
indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
|
indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
|
||||||
|
|
||||||
@@ -1725,11 +1681,11 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
|||||||
|
|
||||||
// reduce within warps
|
// reduce within warps
|
||||||
for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
|
for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
|
||||||
Value shfl = shflSync(rewriter, loc, acc, N);
|
Value shfl = shflSync(loc, rewriter, acc, N);
|
||||||
if (!withIndex) {
|
if (!withIndex) {
|
||||||
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
|
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
|
||||||
} else {
|
} else {
|
||||||
Value shflIndex = shflSync(rewriter, loc, accIndex, N);
|
Value shflIndex = shflSync(loc, rewriter, accIndex, N);
|
||||||
accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
|
accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
|
||||||
shflIndex, false);
|
shflIndex, false);
|
||||||
}
|
}
|
||||||
@@ -1750,8 +1706,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
|||||||
barrier();
|
barrier();
|
||||||
|
|
||||||
// the second round of shuffle reduction
|
// the second round of shuffle reduction
|
||||||
// now the problem size: sizeInterWarps, s1, s2, .. , sn =>
|
// now the problem size: sizeInterWarps, s1, s2, .. , sn
|
||||||
// 1, s1, s2, .. , sn
|
|
||||||
// where sizeInterWarps is 2^m
|
// where sizeInterWarps is 2^m
|
||||||
//
|
//
|
||||||
// each thread needs to process:
|
// each thread needs to process:
|
||||||
@@ -1762,6 +1717,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
|||||||
Value readOffset = threadId;
|
Value readOffset = threadId;
|
||||||
for (unsigned round = 0; round < elemsPerThread; ++round) {
|
for (unsigned round = 0; round < elemsPerThread; ++round) {
|
||||||
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
||||||
|
// FIXME(Qingyi): need predicate icmp_slt(threadId, i32_val(sizeInerWarps))
|
||||||
Value acc = load(readPtr);
|
Value acc = load(readPtr);
|
||||||
Value accIndex;
|
Value accIndex;
|
||||||
if (withIndex) {
|
if (withIndex) {
|
||||||
@@ -1770,17 +1726,18 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
|
for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
|
||||||
Value shfl = shflSync(rewriter, loc, acc, N);
|
Value shfl = shflSync(loc, rewriter, acc, N);
|
||||||
if (!withIndex) {
|
if (!withIndex) {
|
||||||
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
|
accumulate(rewriter, loc, op.redOp(), acc, shfl, false);
|
||||||
} else {
|
} else {
|
||||||
Value shflIndex = shflSync(rewriter, loc, accIndex, N);
|
Value shflIndex = shflSync(loc, rewriter, accIndex, N);
|
||||||
accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
|
accumulateWithIndex(rewriter, loc, op.redOp(), acc, accIndex, shfl,
|
||||||
shflIndex, false);
|
shflIndex, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Value writeOffset = udiv(readOffset, i32_val(sizeInterWarps));
|
// only the first thread in each sizeInterWarps is writing
|
||||||
|
Value writeOffset = readOffset;
|
||||||
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
|
||||||
Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
|
Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
|
||||||
Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
|
Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
|
||||||
@@ -1807,22 +1764,17 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
|
|||||||
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
|
if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
|
||||||
// nd-tensor where n >= 1
|
// nd-tensor where n >= 1
|
||||||
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
|
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
|
||||||
SmallVector<unsigned> resultOrd;
|
auto resultShape = resultTy.getShape();
|
||||||
for (auto ord : order) {
|
|
||||||
if (ord != 0)
|
|
||||||
resultOrd.push_back(ord - 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned resultElems = getElemsPerThread(resultTy);
|
unsigned resultElems = getElemsPerThread(resultTy);
|
||||||
auto resultIndices =
|
auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
|
||||||
emitIndices(loc, rewriter, resultLayout, resultTy.getShape());
|
|
||||||
assert(resultIndices.size() == resultElems);
|
assert(resultIndices.size() == resultElems);
|
||||||
|
|
||||||
SmallVector<Value> resultVals(resultElems);
|
SmallVector<Value> resultVals(resultElems);
|
||||||
for (size_t i = 0; i < resultElems; ++i) {
|
for (size_t i = 0; i < resultElems; ++i) {
|
||||||
SmallVector<Value> readIdx = resultIndices[i];
|
SmallVector<Value> readIdx = resultIndices[i];
|
||||||
|
readIdx.insert(readIdx.begin() + axis, i32_val(0));
|
||||||
Value readOffset =
|
Value readOffset =
|
||||||
linearize(rewriter, loc, readIdx, smemShapes[2], resultOrd);
|
linearize(rewriter, loc, readIdx, smemShapes[0], order);
|
||||||
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
Value readPtr = gep(elemPtrTy, smemBase, readOffset);
|
||||||
Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
|
Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
|
||||||
resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
|
resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
|
||||||
|
@@ -269,6 +269,48 @@ getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
|
|||||||
/*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
|
/*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
|
||||||
|
Value val, Value pred) {
|
||||||
|
MLIRContext *ctx = rewriter.getContext();
|
||||||
|
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
||||||
|
const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");
|
||||||
|
|
||||||
|
PTXBuilder builder;
|
||||||
|
auto *ptrOpr = builder.newAddrOperand(ptr, "r");
|
||||||
|
auto *valOpr = builder.newOperand(val, c);
|
||||||
|
auto &st = builder.create<>("st")->shared().b(bits);
|
||||||
|
st(ptrOpr, valOpr).predicate(pred, "b");
|
||||||
|
return builder.launch(rewriter, loc, void_ty(ctx));
|
||||||
|
}
|
||||||
|
|
||||||
|
Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
|
||||||
|
int i) {
|
||||||
|
unsigned bits = val.getType().getIntOrFloatBitWidth();
|
||||||
|
|
||||||
|
if (bits == 64) {
|
||||||
|
Type vecTy = vec_ty(f32_ty, 2);
|
||||||
|
Value vec = bitcast(val, vecTy);
|
||||||
|
Value val0 = extract_element(f32_ty, vec, i32_val(0));
|
||||||
|
Value val1 = extract_element(f32_ty, vec, i32_val(1));
|
||||||
|
val0 = shflSync(loc, rewriter, val0, i);
|
||||||
|
val1 = shflSync(loc, rewriter, val1, i);
|
||||||
|
vec = undef(vecTy);
|
||||||
|
vec = insert_element(vecTy, vec, val0, i32_val(0));
|
||||||
|
vec = insert_element(vecTy, vec, val1, i32_val(1));
|
||||||
|
return bitcast(vec, val.getType());
|
||||||
|
}
|
||||||
|
|
||||||
|
PTXBuilder builder;
|
||||||
|
auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32");
|
||||||
|
auto *dOpr = builder.newOperand("=r");
|
||||||
|
auto *aOpr = builder.newOperand(val, "r");
|
||||||
|
auto *bOpr = builder.newConstantOperand(i);
|
||||||
|
auto *cOpr = builder.newConstantOperand("0x1f");
|
||||||
|
auto *maskOpr = builder.newConstantOperand("0xffffffff");
|
||||||
|
shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
|
||||||
|
return builder.launch(rewriter, loc, val.getType(), false);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace LLVM
|
} // namespace LLVM
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user