#include "triton/Analysis/Utility.h"
#include "mlir/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// NOTE(review): every explicit template argument in this file (static_cast,
// dyn_cast, isa, TypeID::get) was missing from the text under review, which
// left the code ill-formed. The arguments written below are reconstructions
// inferred from context -- confirm them against version control, in
// particular the dialect list in maybeSharedAllocationOp and the op list in
// maybeAliasOp.

namespace mlir {

// Returns true when the reduction axis is the first entry of the order that
// triton::gpu::getOrder reports for the source tensor's encoding.
bool ReduceOpHelper::isFastReduction() {
  auto srcLayout = srcTy.getEncoding();
  auto axis = op.axis();
  return axis == triton::gpu::getOrder(srcLayout)[0];
}

// Returns min(srcShape[axis] / getIntraWarpSize(), warpsPerCTA[axis]), where
// `axis` is the reduction axis of the op.
unsigned ReduceOpHelper::getInterWarpSize() {
  auto srcLayout = srcTy.getEncoding();
  auto srcShape = srcTy.getShape();
  auto axis = op.axis();
  // The cast keeps both std::min operands the same type.
  // NOTE(review): presumably the shape is static and non-zero along `axis`;
  // nothing here guards a zero divisor -- confirm with callers.
  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
  unsigned sizeIntraWarps = getIntraWarpSize();
  return std::min(srcReduceDimSize / sizeIntraWarps,
                  triton::gpu::getWarpsPerCTA(srcLayout)[axis]);
}

// Returns min(srcShape[axis], threadsPerWarp[axis]), where `axis` is the
// reduction axis of the op.
unsigned ReduceOpHelper::getIntraWarpSize() {
  auto srcLayout = srcTy.getEncoding();
  auto srcShape = srcTy.getShape();
  auto axis = op.axis();
  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
  return std::min(srcReduceDimSize,
                  triton::gpu::getThreadsPerWarp(srcLayout)[axis]);
}

// Returns threadsPerWarp[axis] * warpsPerCTA[axis] for the reduction axis of
// the source encoding.
unsigned ReduceOpHelper::getThreadsReductionAxis() {
  auto srcLayout = srcTy.getEncoding();
  auto axis = op.axis();
  return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
         triton::gpu::getWarpsPerCTA(srcLayout)[axis];
}

// Returns true when `value` has a tensor type that carries a non-null
// encoding attribute of the shared-encoding kind; false for any other type or
// for a tensor without an encoding.
bool isSharedEncoding(Value value) {
  auto type = value.getType();
  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    auto encoding = tensorType.getEncoding();
    return encoding && encoding.isa<triton::gpu::SharedEncodingAttr>();
  }
  return false;
}

// Returns true when `op` belongs to one of the dialects listed below. Only
// the owning dialect is inspected, never the op itself; an op whose dialect
// pointer is null yields false.
bool maybeSharedAllocationOp(Operation *op) {
  // TODO(Keren): This function can be replaced by adding
  // MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to
  // query the memory effects of the op.
  auto *dialect = op->getDialect();
  if (!dialect)
    return false;
  // Read the dialect's TypeID once instead of once per comparison.
  const mlir::TypeID dialectID = dialect->getTypeID();
  return dialectID == mlir::TypeID::get<triton::gpu::TritonGPUDialect>() ||
         dialectID == mlir::TypeID::get<triton::TritonDialect>() ||
         dialectID == mlir::TypeID::get<arith::ArithmeticDialect>() ||
         dialectID == mlir::TypeID::get<tensor::TensorDialect>();
}

// Returns true when `op` is one of the slice operations listed below.
bool maybeAliasOp(Operation *op) {
  return isa<tensor::ExtractSliceOp, triton::gpu::InsertSliceAsyncOp,
             tensor::InsertSliceOp>(op);
}

// Prints `value` in operand form (using `state`) into a string and returns
// that string.
std::string getValueOperandName(Value value, AsmState &state) {
  std::string opName;
  llvm::raw_string_ostream ss(opName);
  value.printAsOperand(ss, state);
  // Make sure all printed text has reached `opName` before it is returned;
  // this is a no-op when the stream is unbuffered.
  ss.flush();
  return opName;
}

} // namespace mlir