[Analysis] Added Axis Info Analysis (#8)

This commit is contained in:
Philippe Tillet
2022-07-19 13:38:48 -07:00
committed by GitHub
parent df940aaab0
commit a633d2b403
20 changed files with 582 additions and 13 deletions

View File

@@ -0,0 +1,145 @@
#ifndef TRITON_ANALYSIS_AXISINFO_H
#define TRITON_ANALYSIS_AXISINFO_H
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
namespace mlir {
//===----------------------------------------------------------------------===//
// AxisInfo
//===----------------------------------------------------------------------===//
/// This lattice value represents known information on the axes of a lattice.
/// Axis information is represented by a std::map<int, int>
class AxisInfo {
public:
typedef std::vector<int> ContiguityT;
typedef std::vector<int> DivisibilityT;
typedef std::vector<int> ConstancyT;
public:
// Default constructor
AxisInfo(): AxisInfo({}, {}, {}) { }
// Construct contiguity info with known contiguity
AxisInfo(ContiguityT knownContiguity, DivisibilityT knownDivisibility,
ConstancyT knownConstancy)
: contiguity(knownContiguity), divisibility(knownDivisibility),
constancy(knownConstancy), rank(contiguity.size()) {
assert(knownDivisibility.size() == rank);
assert(knownConstancy.size() == rank);
}
// Accessors
int getContiguity(size_t d) const { return contiguity[d]; }
const ContiguityT& getContiguity() const { return contiguity; }
int getDivisibility(size_t d) const { return divisibility[d]; }
const DivisibilityT& getDivisibility() const { return divisibility; }
int getConstancy(size_t d) const { return constancy[d]; }
const ConstancyT& getConstancy() const { return constancy; }
int getRank() const { return rank; }
// Comparison
bool operator==(const AxisInfo &other) const {
return (contiguity == other.contiguity) &&
(divisibility == other.divisibility) &&
(constancy == other.constancy);
}
/// The pessimistic value state of the contiguity is unknown.
static AxisInfo getPessimisticValueState(MLIRContext *context)
{ return AxisInfo(); }
static AxisInfo getPessimisticValueState(Value value);
// The gcd of both arguments for each dimension
static AxisInfo join(const AxisInfo &lhs,
const AxisInfo &rhs);
private:
/// The _contiguity_ information maps the `d`-th
/// dimension to the length of the shortest
/// sequence of contiguous integers along it
/// For example:
/// [10, 11, 12, 13, 18, 19, 20, 21]
/// [20, 21, 22, 23, 28, 29, 30, 31]
/// Would have contiguity [1, 4].
/// and
/// [12, 16, 20, 24]
/// [13, 17, 21, 25]
/// [14, 18, 22, 26]
/// [15, 19, 23, 27]
/// [18, 22, 26, 30]
/// [19, 23, 27, 31]
/// Would have contiguity [2, 1].
ContiguityT contiguity;
/// The _divisibility_ information maps the `d`-th
/// dimension to the largest power-of-two that
/// divides the first element of all the values along it
/// For example:
/// [10, 11, 12, 13, 18, 19, 20, 21]
/// [20, 21, 22, 23, 28, 29, 30, 31]
// would have divisibility [1, 2]
// and
/// [12, 16, 20, 24]
/// [13, 17, 21, 25]
/// [14, 18, 22, 26]
/// [15, 19, 23, 27]
// would have divisibility [4, 1]
DivisibilityT divisibility;
/// The _constancy_ information maps the `d`-th
/// dimension to the length of the shortest
/// sequence of constant integer along it. This is
/// particularly useful to infer the contiguity
/// of operations (e.g., add) involving a constant
/// For example
/// [8, 8, 8, 8, 12, 12, 12, 12]
/// [16, 16, 16, 16, 20, 20, 20, 20]
/// would have constancy [1, 4]
ConstancyT constancy;
// number of dimensions of the lattice
int rank;
};
class AxisInfoAnalysis
: public ForwardDataFlowAnalysis<AxisInfo> {
private:
static const int maxPow2Divisor = 65536;
int highestPowOf2Divisor(int n){
if(n==0)
return maxPow2Divisor;
return (n & (~(n - 1)));
}
AxisInfo visitBinaryOp(Operation* op, AxisInfo lhsInfo, AxisInfo rhsInfo,
const std::function<int(AxisInfo,AxisInfo,int)>& getContiguity,
const std::function<int(AxisInfo,AxisInfo,int)>& getDivisibility,
const std::function<int(AxisInfo,AxisInfo,int)>& getConstancy);
public:
using ForwardDataFlowAnalysis<AxisInfo>::ForwardDataFlowAnalysis;
ChangeResult visitOperation(Operation *op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
};
}
#endif

View File

@@ -33,6 +33,8 @@ def Triton_Dialect : Dialect {
let extraClassDeclaration = [{
void registerTypes();
}];
let hasConstantMaterializer = 1;
}
#endif // TRITON_DIALECT

View File

@@ -143,8 +143,20 @@ def TT_ReshapeOp : TT_Op<"reshape", [NoSideEffect, SameOperandsAndResultElementT
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TT_SplatOp : TT_Op<"splat", [NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "splat";
let arguments = (ins TT_Type:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
}
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "broadcast";
let summary = "broadcast. No left-padding as of now.";
let arguments = (ins TT_Type:$src);

View File

@@ -59,6 +59,9 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async",
}
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
// This is needed because Arith's Cmp ops don't
// handle encodings
// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
def TTG_CmpIOp : TTG_Op<"cmpi"> {
let summary = "integer comparison operation";