[Analysis] Added Axis Info Analysis (#8)
This commit is contained in:
145
include/triton/Analysis/AxisInfo.h
Normal file
145
include/triton/Analysis/AxisInfo.h
Normal file
@@ -0,0 +1,145 @@
|
||||
#ifndef TRITON_ANALYSIS_AXISINFO_H
|
||||
#define TRITON_ANALYSIS_AXISINFO_H
|
||||
|
||||
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <iostream>
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This lattice value represents known information on the axes of a lattice.
|
||||
/// Axis information is represented by a std::map<int, int>
|
||||
class AxisInfo {
|
||||
public:
|
||||
typedef std::vector<int> ContiguityT;
|
||||
typedef std::vector<int> DivisibilityT;
|
||||
typedef std::vector<int> ConstancyT;
|
||||
|
||||
public:
|
||||
// Default constructor
|
||||
AxisInfo(): AxisInfo({}, {}, {}) { }
|
||||
// Construct contiguity info with known contiguity
|
||||
AxisInfo(ContiguityT knownContiguity, DivisibilityT knownDivisibility,
|
||||
ConstancyT knownConstancy)
|
||||
: contiguity(knownContiguity), divisibility(knownDivisibility),
|
||||
constancy(knownConstancy), rank(contiguity.size()) {
|
||||
assert(knownDivisibility.size() == rank);
|
||||
assert(knownConstancy.size() == rank);
|
||||
}
|
||||
|
||||
|
||||
// Accessors
|
||||
int getContiguity(size_t d) const { return contiguity[d]; }
|
||||
const ContiguityT& getContiguity() const { return contiguity; }
|
||||
|
||||
int getDivisibility(size_t d) const { return divisibility[d]; }
|
||||
const DivisibilityT& getDivisibility() const { return divisibility; }
|
||||
|
||||
int getConstancy(size_t d) const { return constancy[d]; }
|
||||
const ConstancyT& getConstancy() const { return constancy; }
|
||||
|
||||
int getRank() const { return rank; }
|
||||
|
||||
// Comparison
|
||||
bool operator==(const AxisInfo &other) const {
|
||||
return (contiguity == other.contiguity) &&
|
||||
(divisibility == other.divisibility) &&
|
||||
(constancy == other.constancy);
|
||||
}
|
||||
|
||||
/// The pessimistic value state of the contiguity is unknown.
|
||||
static AxisInfo getPessimisticValueState(MLIRContext *context)
|
||||
{ return AxisInfo(); }
|
||||
static AxisInfo getPessimisticValueState(Value value);
|
||||
|
||||
// The gcd of both arguments for each dimension
|
||||
static AxisInfo join(const AxisInfo &lhs,
|
||||
const AxisInfo &rhs);
|
||||
|
||||
private:
|
||||
/// The _contiguity_ information maps the `d`-th
|
||||
/// dimension to the length of the shortest
|
||||
/// sequence of contiguous integers along it
|
||||
/// For example:
|
||||
/// [10, 11, 12, 13, 18, 19, 20, 21]
|
||||
/// [20, 21, 22, 23, 28, 29, 30, 31]
|
||||
/// Would have contiguity [1, 4].
|
||||
/// and
|
||||
/// [12, 16, 20, 24]
|
||||
/// [13, 17, 21, 25]
|
||||
/// [14, 18, 22, 26]
|
||||
/// [15, 19, 23, 27]
|
||||
/// [18, 22, 26, 30]
|
||||
/// [19, 23, 27, 31]
|
||||
/// Would have contiguity [2, 1].
|
||||
ContiguityT contiguity;
|
||||
|
||||
/// The _divisibility_ information maps the `d`-th
|
||||
/// dimension to the largest power-of-two that
|
||||
/// divides the first element of all the values along it
|
||||
/// For example:
|
||||
/// [10, 11, 12, 13, 18, 19, 20, 21]
|
||||
/// [20, 21, 22, 23, 28, 29, 30, 31]
|
||||
// would have divisibility [1, 2]
|
||||
// and
|
||||
/// [12, 16, 20, 24]
|
||||
/// [13, 17, 21, 25]
|
||||
/// [14, 18, 22, 26]
|
||||
/// [15, 19, 23, 27]
|
||||
// would have divisibility [4, 1]
|
||||
DivisibilityT divisibility;
|
||||
|
||||
/// The _constancy_ information maps the `d`-th
|
||||
/// dimension to the length of the shortest
|
||||
/// sequence of constant integer along it. This is
|
||||
/// particularly useful to infer the contiguity
|
||||
/// of operations (e.g., add) involving a constant
|
||||
/// For example
|
||||
/// [8, 8, 8, 8, 12, 12, 12, 12]
|
||||
/// [16, 16, 16, 16, 20, 20, 20, 20]
|
||||
/// would have constancy [1, 4]
|
||||
ConstancyT constancy;
|
||||
|
||||
// number of dimensions of the lattice
|
||||
int rank;
|
||||
};
|
||||
|
||||
|
||||
class AxisInfoAnalysis
|
||||
: public ForwardDataFlowAnalysis<AxisInfo> {
|
||||
|
||||
private:
|
||||
static const int maxPow2Divisor = 65536;
|
||||
|
||||
int highestPowOf2Divisor(int n){
|
||||
if(n==0)
|
||||
return maxPow2Divisor;
|
||||
return (n & (~(n - 1)));
|
||||
}
|
||||
|
||||
AxisInfo visitBinaryOp(Operation* op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getContiguity,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getDivisibility,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getConstancy);
|
||||
|
||||
public:
|
||||
using ForwardDataFlowAnalysis<AxisInfo>::ForwardDataFlowAnalysis;
|
||||
|
||||
ChangeResult visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
|
||||
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
#endif
|
@@ -33,6 +33,8 @@ def Triton_Dialect : Dialect {
|
||||
let extraClassDeclaration = [{
|
||||
void registerTypes();
|
||||
}];
|
||||
|
||||
let hasConstantMaterializer = 1;
|
||||
}
|
||||
|
||||
#endif // TRITON_DIALECT
|
||||
|
@@ -143,8 +143,20 @@ def TT_ReshapeOp : TT_Op<"reshape", [NoSideEffect, SameOperandsAndResultElementT
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TT_SplatOp : TT_Op<"splat", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "splat";
|
||||
|
||||
let arguments = (ins TT_Type:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "broadcast";
|
||||
let summary = "broadcast. No left-padding as of now.";
|
||||
|
||||
let arguments = (ins TT_Type:$src);
|
||||
|
||||
|
@@ -59,6 +59,9 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async",
|
||||
}
|
||||
|
||||
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
|
||||
// This is needed because Arith's Cmp ops don't
|
||||
// handle encodings
|
||||
// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
|
||||
def TTG_CmpIOp : TTG_Op<"cmpi"> {
|
||||
let summary = "integer comparison operation";
|
||||
|
||||
|
Reference in New Issue
Block a user