[Triton-MLIR] Add ex2.approx implementation for ExpOp and fix smem allocation for ReduceOpConversion (#875)
This commit is contained in:
@@ -161,9 +161,16 @@ private:
|
||||
// TODO(Keren): Reduce with index is not supported yet.
|
||||
auto value = op->getOperand(0);
|
||||
if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
|
||||
auto srcLayout = tensorType.getEncoding();
|
||||
bool fastReduce = reduceOp.axis() == getOrder(srcLayout)[0];
|
||||
auto smemShape = getScratchConfigForReduce(reduceOp);
|
||||
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
|
||||
std::multiplies{});
|
||||
if (fastReduce) {
|
||||
auto mod = op->getParentOfType<ModuleOp>();
|
||||
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
elems = std::max<unsigned>(elems, numWarps * 32);
|
||||
}
|
||||
auto bytes = elems * tensorType.getElementTypeBitWidth() / 8;
|
||||
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
|
||||
}
|
||||
|
Reference in New Issue
Block a user