[Triton-MLIR][BACKEND] Add argmin / argmax implementation for ReduceOp (#918)
This commit is contained in:
@@ -88,25 +88,6 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
return paddedRepShape;
|
||||
}
|
||||
|
||||
/// Computes the shape of the scratch buffer associated with a reduce op.
///
/// The result starts as an element-wise copy of the source shape reported by
/// `ReduceOpHelper`. The extent along the reduction axis (`op.axis()`) is then
/// overwritten:
///   - fast reduction: with `helper.getInterWarpSize()`;
///   - otherwise: with the smaller of the original extent and
///     `helper.getThreadsReductionAxis()`.
///
/// \param op the reduce operation to size the scratch buffer for.
/// \returns one extent per source dimension, with the reduction axis adjusted.
SmallVector<unsigned> getScratchConfigForReduce(triton::ReduceOp op) {
  ReduceOpHelper helper(op);

  // Copy the source extents one by one; each element is converted to
  // `unsigned` by push_back.
  SmallVector<unsigned> scratchShape;
  for (auto extent : helper.getSrcShape())
    scratchShape.push_back(extent);

  // Only the reduction axis is resized; every other dimension keeps its
  // source extent.
  auto reduceAxis = op.axis();
  if (!helper.isFastReduction()) {
    // Non-fast path: never exceed the original extent along the axis.
    scratchShape[reduceAxis] = std::min(scratchShape[reduceAxis],
                                        helper.getThreadsReductionAxis());
  } else {
    scratchShape[reduceAxis] = helper.getInterWarpSize();
  }

  return scratchShape;
}
|
||||
|
||||
// TODO: extend beyond scalars
|
||||
SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
|
||||
SmallVector<unsigned> smemShape;
|
||||
@@ -173,21 +154,9 @@ private:
|
||||
/// Initializes temporary shared memory for a given operation.
|
||||
void getScratchValueSize(Operation *op) {
|
||||
if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
|
||||
// TODO(Keren): Reduce with index is not supported yet.
|
||||
auto value = op->getOperand(0);
|
||||
if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
|
||||
bool fastReduce = ReduceOpHelper(reduceOp).isFastReduction();
|
||||
auto smemShape = getScratchConfigForReduce(reduceOp);
|
||||
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
|
||||
std::multiplies{});
|
||||
if (fastReduce) {
|
||||
auto mod = op->getParentOfType<ModuleOp>();
|
||||
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
elems = std::max<unsigned>(elems, numWarps * 32);
|
||||
}
|
||||
auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
|
||||
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
}
|
||||
ReduceOpHelper helper(reduceOp);
|
||||
unsigned bytes = helper.getScratchSizeInBytes();
|
||||
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
|
||||
auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();
|
||||
auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
|
||||
|
Reference in New Issue
Block a user