[Analysis] Added Axis Info Analysis (#8)
This commit is contained in:
@@ -183,6 +183,7 @@ set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
|
|||||||
target_link_libraries(triton
|
target_link_libraries(triton
|
||||||
${PYTHON_LIBRARIES}
|
${PYTHON_LIBRARIES}
|
||||||
TritonIR
|
TritonIR
|
||||||
|
TritonAnalysis
|
||||||
TritonTransforms
|
TritonTransforms
|
||||||
TritonDriver
|
TritonDriver
|
||||||
TritonToTritonGPU
|
TritonToTritonGPU
|
||||||
|
@@ -10,10 +10,14 @@ add_llvm_executable(triton-opt triton-opt.cpp)
|
|||||||
# TODO: what's this?
|
# TODO: what's this?
|
||||||
# llvm_update_compile_flags(triton-opt)
|
# llvm_update_compile_flags(triton-opt)
|
||||||
target_link_libraries(triton-opt PRIVATE
|
target_link_libraries(triton-opt PRIVATE
|
||||||
|
TritonAnalysis
|
||||||
TritonTransforms
|
TritonTransforms
|
||||||
TritonGPUTransforms
|
TritonGPUTransforms
|
||||||
${dialect_libs}
|
${dialect_libs}
|
||||||
${conversion_libs}
|
${conversion_libs}
|
||||||
|
# tests
|
||||||
|
TritonTestAnalysis
|
||||||
|
# MLIR core
|
||||||
MLIROptLib
|
MLIROptLib
|
||||||
MLIRPass
|
MLIRPass
|
||||||
MLIRTransforms
|
MLIRTransforms
|
||||||
|
@@ -10,11 +10,17 @@
|
|||||||
#include "mlir/InitAllPasses.h"
|
#include "mlir/InitAllPasses.h"
|
||||||
#include "mlir/Support/MlirOptMain.h"
|
#include "mlir/Support/MlirOptMain.h"
|
||||||
|
|
||||||
|
namespace mlir{
|
||||||
|
namespace test{
|
||||||
|
void registerTestAlignmentPass();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
mlir::registerAllPasses();
|
mlir::registerAllPasses();
|
||||||
mlir::registerTritonPasses();
|
mlir::registerTritonPasses();
|
||||||
mlir::registerTritonGPUPasses();
|
mlir::registerTritonGPUPasses();
|
||||||
|
mlir::test::registerTestAlignmentPass();
|
||||||
mlir::triton::registerConvertTritonToTritonGPUPass();
|
mlir::triton::registerConvertTritonToTritonGPUPass();
|
||||||
|
|
||||||
// TODO: register Triton & TritonGPU passes
|
// TODO: register Triton & TritonGPU passes
|
||||||
|
145
include/triton/Analysis/AxisInfo.h
Normal file
145
include/triton/Analysis/AxisInfo.h
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
#ifndef TRITON_ANALYSIS_AXISINFO_H
|
||||||
|
#define TRITON_ANALYSIS_AXISINFO_H
|
||||||
|
|
||||||
|
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AxisInfo
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// This lattice value represents known information on the axes of a lattice.
|
||||||
|
/// Axis information is represented by a std::map<int, int>
|
||||||
|
class AxisInfo {
|
||||||
|
public:
|
||||||
|
typedef std::vector<int> ContiguityT;
|
||||||
|
typedef std::vector<int> DivisibilityT;
|
||||||
|
typedef std::vector<int> ConstancyT;
|
||||||
|
|
||||||
|
public:
|
||||||
|
// Default constructor
|
||||||
|
AxisInfo(): AxisInfo({}, {}, {}) { }
|
||||||
|
// Construct contiguity info with known contiguity
|
||||||
|
AxisInfo(ContiguityT knownContiguity, DivisibilityT knownDivisibility,
|
||||||
|
ConstancyT knownConstancy)
|
||||||
|
: contiguity(knownContiguity), divisibility(knownDivisibility),
|
||||||
|
constancy(knownConstancy), rank(contiguity.size()) {
|
||||||
|
assert(knownDivisibility.size() == rank);
|
||||||
|
assert(knownConstancy.size() == rank);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Accessors
|
||||||
|
int getContiguity(size_t d) const { return contiguity[d]; }
|
||||||
|
const ContiguityT& getContiguity() const { return contiguity; }
|
||||||
|
|
||||||
|
int getDivisibility(size_t d) const { return divisibility[d]; }
|
||||||
|
const DivisibilityT& getDivisibility() const { return divisibility; }
|
||||||
|
|
||||||
|
int getConstancy(size_t d) const { return constancy[d]; }
|
||||||
|
const ConstancyT& getConstancy() const { return constancy; }
|
||||||
|
|
||||||
|
int getRank() const { return rank; }
|
||||||
|
|
||||||
|
// Comparison
|
||||||
|
bool operator==(const AxisInfo &other) const {
|
||||||
|
return (contiguity == other.contiguity) &&
|
||||||
|
(divisibility == other.divisibility) &&
|
||||||
|
(constancy == other.constancy);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The pessimistic value state of the contiguity is unknown.
|
||||||
|
static AxisInfo getPessimisticValueState(MLIRContext *context)
|
||||||
|
{ return AxisInfo(); }
|
||||||
|
static AxisInfo getPessimisticValueState(Value value);
|
||||||
|
|
||||||
|
// The gcd of both arguments for each dimension
|
||||||
|
static AxisInfo join(const AxisInfo &lhs,
|
||||||
|
const AxisInfo &rhs);
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// The _contiguity_ information maps the `d`-th
|
||||||
|
/// dimension to the length of the shortest
|
||||||
|
/// sequence of contiguous integers along it
|
||||||
|
/// For example:
|
||||||
|
/// [10, 11, 12, 13, 18, 19, 20, 21]
|
||||||
|
/// [20, 21, 22, 23, 28, 29, 30, 31]
|
||||||
|
/// Would have contiguity [1, 4].
|
||||||
|
/// and
|
||||||
|
/// [12, 16, 20, 24]
|
||||||
|
/// [13, 17, 21, 25]
|
||||||
|
/// [14, 18, 22, 26]
|
||||||
|
/// [15, 19, 23, 27]
|
||||||
|
/// [18, 22, 26, 30]
|
||||||
|
/// [19, 23, 27, 31]
|
||||||
|
/// Would have contiguity [2, 1].
|
||||||
|
ContiguityT contiguity;
|
||||||
|
|
||||||
|
/// The _divisibility_ information maps the `d`-th
|
||||||
|
/// dimension to the largest power-of-two that
|
||||||
|
/// divides the first element of all the values along it
|
||||||
|
/// For example:
|
||||||
|
/// [10, 11, 12, 13, 18, 19, 20, 21]
|
||||||
|
/// [20, 21, 22, 23, 28, 29, 30, 31]
|
||||||
|
// would have divisibility [1, 2]
|
||||||
|
// and
|
||||||
|
/// [12, 16, 20, 24]
|
||||||
|
/// [13, 17, 21, 25]
|
||||||
|
/// [14, 18, 22, 26]
|
||||||
|
/// [15, 19, 23, 27]
|
||||||
|
// would have divisibility [4, 1]
|
||||||
|
DivisibilityT divisibility;
|
||||||
|
|
||||||
|
/// The _constancy_ information maps the `d`-th
|
||||||
|
/// dimension to the length of the shortest
|
||||||
|
/// sequence of constant integer along it. This is
|
||||||
|
/// particularly useful to infer the contiguity
|
||||||
|
/// of operations (e.g., add) involving a constant
|
||||||
|
/// For example
|
||||||
|
/// [8, 8, 8, 8, 12, 12, 12, 12]
|
||||||
|
/// [16, 16, 16, 16, 20, 20, 20, 20]
|
||||||
|
/// would have constancy [1, 4]
|
||||||
|
ConstancyT constancy;
|
||||||
|
|
||||||
|
// number of dimensions of the lattice
|
||||||
|
int rank;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class AxisInfoAnalysis
|
||||||
|
: public ForwardDataFlowAnalysis<AxisInfo> {
|
||||||
|
|
||||||
|
private:
|
||||||
|
static const int maxPow2Divisor = 65536;
|
||||||
|
|
||||||
|
int highestPowOf2Divisor(int n){
|
||||||
|
if(n==0)
|
||||||
|
return maxPow2Divisor;
|
||||||
|
return (n & (~(n - 1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
AxisInfo visitBinaryOp(Operation* op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||||
|
const std::function<int(AxisInfo,AxisInfo,int)>& getContiguity,
|
||||||
|
const std::function<int(AxisInfo,AxisInfo,int)>& getDivisibility,
|
||||||
|
const std::function<int(AxisInfo,AxisInfo,int)>& getConstancy);
|
||||||
|
|
||||||
|
public:
|
||||||
|
using ForwardDataFlowAnalysis<AxisInfo>::ForwardDataFlowAnalysis;
|
||||||
|
|
||||||
|
ChangeResult visitOperation(Operation *op,
|
||||||
|
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#endif
|
@@ -33,6 +33,8 @@ def Triton_Dialect : Dialect {
|
|||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
void registerTypes();
|
void registerTypes();
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let hasConstantMaterializer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // TRITON_DIALECT
|
#endif // TRITON_DIALECT
|
||||||
|
@@ -143,8 +143,20 @@ def TT_ReshapeOp : TT_Op<"reshape", [NoSideEffect, SameOperandsAndResultElementT
|
|||||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TT_SplatOp : TT_Op<"splat", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
|
let summary = "splat";
|
||||||
|
|
||||||
|
let arguments = (ins TT_Type:$src);
|
||||||
|
|
||||||
|
let results = (outs TT_Tensor:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, SameOperandsAndResultElementType]> {
|
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
let summary = "broadcast";
|
let summary = "broadcast. No left-padding as of now.";
|
||||||
|
|
||||||
let arguments = (ins TT_Type:$src);
|
let arguments = (ins TT_Type:$src);
|
||||||
|
|
||||||
|
@@ -59,6 +59,9 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async",
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
|
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
|
||||||
|
// This is needed because Arith's Cmp ops don't
|
||||||
|
// handle encodings
|
||||||
|
// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
|
||||||
def TTG_CmpIOp : TTG_Op<"cmpi"> {
|
def TTG_CmpIOp : TTG_Op<"cmpi"> {
|
||||||
let summary = "integer comparison operation";
|
let summary = "integer comparison operation";
|
||||||
|
|
||||||
|
239
lib/Analysis/AxisInfo.cpp
Normal file
239
lib/Analysis/AxisInfo.cpp
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "triton/Analysis/AxisInfo.h"
|
||||||
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AxisInfo
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Function for extended Euclidean Algorithm
|
||||||
|
static int gcd_impl(int a, int b, int *x, int *y){
|
||||||
|
// Base Case
|
||||||
|
if (a == 0) {
|
||||||
|
*x = 0;
|
||||||
|
*y = 1;
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
int x1, y1; // To store results of recursive call
|
||||||
|
int gcd = gcd_impl(b%a, a, &x1, &y1);
|
||||||
|
// Update x and y using results of
|
||||||
|
// recursive call
|
||||||
|
*x = y1 - (b/a) * x1;
|
||||||
|
*y = x1;
|
||||||
|
return gcd;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int gcd(int a, int b) {
|
||||||
|
int x, y;
|
||||||
|
return gcd_impl(a, b, &x, &y);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
AxisInfo AxisInfo::getPessimisticValueState(Value value) {
|
||||||
|
size_t rank = 1;
|
||||||
|
if(TensorType ty = value.getType().dyn_cast<TensorType>())
|
||||||
|
rank = ty.getRank();
|
||||||
|
int divHint = 1;
|
||||||
|
if(BlockArgument blockArg = value.dyn_cast<BlockArgument>()){
|
||||||
|
Operation* op = blockArg.getOwner()->getParentOp();
|
||||||
|
if(FuncOp fun = dyn_cast<FuncOp>(op)){
|
||||||
|
Attribute attr = fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
|
||||||
|
if(attr)
|
||||||
|
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ContiguityT contiguity(rank, 1);
|
||||||
|
DivisibilityT divisibility(rank, divHint);
|
||||||
|
ConstancyT constancy(rank, 1);
|
||||||
|
return AxisInfo(contiguity, divisibility, constancy);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// The gcd of both arguments for each dimension
|
||||||
|
AxisInfo AxisInfo::join(const AxisInfo &lhs,
|
||||||
|
const AxisInfo &rhs) {
|
||||||
|
ContiguityT retContiguity;
|
||||||
|
DivisibilityT retDivisibility;
|
||||||
|
ConstancyT retConstancy;
|
||||||
|
for(size_t d = 0; d < lhs.getRank(); d++){
|
||||||
|
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
|
||||||
|
retDivisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
|
||||||
|
retConstancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
|
||||||
|
}
|
||||||
|
return AxisInfo(retContiguity, retDivisibility, retConstancy);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AxisInfoAnalysis
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
AxisInfo AxisInfoAnalysis::visitBinaryOp(Operation* op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||||
|
const std::function<int(AxisInfo,AxisInfo,int)>& getContiguity,
|
||||||
|
const std::function<int(AxisInfo,AxisInfo,int)>& getDivisibility,
|
||||||
|
const std::function<int(AxisInfo,AxisInfo,int)>& getConstancy) {
|
||||||
|
int rank = lhsInfo.getRank();
|
||||||
|
AxisInfo::ContiguityT newContiguity;
|
||||||
|
AxisInfo::DivisibilityT newDivisibility;
|
||||||
|
AxisInfo::ConstancyT newConstancy;
|
||||||
|
for(size_t d = 0; d < rank; d++){
|
||||||
|
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
|
||||||
|
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
|
||||||
|
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
|
||||||
|
}
|
||||||
|
return AxisInfo(newContiguity, newDivisibility, newConstancy);
|
||||||
|
}
|
||||||
|
|
||||||
|
ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||||
|
ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
||||||
|
AxisInfo curr;
|
||||||
|
// This preserves the input axes (e.g., cast):
|
||||||
|
if (llvm::isa<arith::ExtSIOp, arith::ExtUIOp, arith::TruncIOp,
|
||||||
|
triton::PtrToIntOp, triton::IntToPtrOp>(op))
|
||||||
|
curr = operands[0]->getValue();
|
||||||
|
// Constant ranges
|
||||||
|
if (triton::MakeRangeOp make_range = llvm::dyn_cast<triton::MakeRangeOp>(op)){
|
||||||
|
int start = make_range.start();
|
||||||
|
int end = make_range.end();
|
||||||
|
AxisInfo::ContiguityT contiguity = {end - start};
|
||||||
|
AxisInfo::DivisibilityT divisibility = {highestPowOf2Divisor(start)};
|
||||||
|
AxisInfo::ConstancyT constancy = {1};
|
||||||
|
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||||
|
}
|
||||||
|
// Constant
|
||||||
|
if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)){
|
||||||
|
auto intAttr = constant.getValue().dyn_cast<IntegerAttr>();
|
||||||
|
if(intAttr){
|
||||||
|
size_t val = intAttr.getValue().getZExtValue();
|
||||||
|
curr = AxisInfo({1}, {highestPowOf2Divisor(val)}, {1});
|
||||||
|
}
|
||||||
|
// TODO: generalize to dense attr
|
||||||
|
auto splatAttr = constant.getValue().dyn_cast<SplatElementsAttr>();
|
||||||
|
if(splatAttr && splatAttr.getElementType().isInteger(32)){
|
||||||
|
auto value = splatAttr.getSplatValue<int>();
|
||||||
|
TensorType ty = splatAttr.getType().cast<TensorType>();
|
||||||
|
curr = AxisInfo(AxisInfo::ContiguityT(ty.getRank(), 1),
|
||||||
|
AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
|
||||||
|
AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Addition
|
||||||
|
if (llvm::isa<arith::AddIOp, triton::GEPOp>(op)){
|
||||||
|
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d){
|
||||||
|
return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)),
|
||||||
|
gcd(lhs.getConstancy(d), rhs.getContiguity(d)));
|
||||||
|
};
|
||||||
|
auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d){
|
||||||
|
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||||
|
};
|
||||||
|
auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d){
|
||||||
|
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
|
||||||
|
};
|
||||||
|
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||||
|
newContiguity, newDivisibility, newConstancy);
|
||||||
|
}
|
||||||
|
// Multiplication
|
||||||
|
if (llvm::isa<arith::MulIOp>(op)){
|
||||||
|
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d){
|
||||||
|
return 1;
|
||||||
|
};
|
||||||
|
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d){
|
||||||
|
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||||
|
};
|
||||||
|
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d){
|
||||||
|
return lhs.getDivisibility(d)*rhs.getDivisibility(d);
|
||||||
|
};
|
||||||
|
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||||
|
newContiguity, newDivisibility, newConstancy);
|
||||||
|
}
|
||||||
|
// Splat
|
||||||
|
if (llvm::isa<triton::SplatOp>(op)){
|
||||||
|
Type _retTy = *op->result_type_begin();
|
||||||
|
TensorType retTy = _retTy.cast<TensorType>();
|
||||||
|
AxisInfo opInfo = operands[0]->getValue();
|
||||||
|
AxisInfo::ContiguityT contiguity;
|
||||||
|
AxisInfo::DivisibilityT divisibility;
|
||||||
|
AxisInfo::ConstancyT constancy;
|
||||||
|
for(size_t d = 0; d < retTy.getRank(); d++){
|
||||||
|
contiguity.push_back(1);
|
||||||
|
divisibility.push_back(opInfo.getDivisibility(0));
|
||||||
|
constancy.push_back(retTy.getShape()[d]);
|
||||||
|
}
|
||||||
|
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||||
|
}
|
||||||
|
// Reshape
|
||||||
|
// TODO: Replace by `unsqueeze`
|
||||||
|
if (llvm::isa<triton::ReshapeOp>(op)){
|
||||||
|
Type _retTy = *op->result_type_begin();
|
||||||
|
Type _opTy = *op->operand_type_begin();
|
||||||
|
TensorType retTy = _retTy.cast<TensorType>();
|
||||||
|
TensorType opTy = _opTy.cast<TensorType>();
|
||||||
|
ArrayRef<int64_t> retShape = retTy.getShape();
|
||||||
|
ArrayRef<int64_t> opShape = opTy.getShape();
|
||||||
|
AxisInfo opInfo = operands[0]->getValue();
|
||||||
|
AxisInfo::ContiguityT contiguity;
|
||||||
|
AxisInfo::DivisibilityT divisibility;
|
||||||
|
AxisInfo::ConstancyT constancy;
|
||||||
|
bool is_skewed = false;
|
||||||
|
size_t current = 0;
|
||||||
|
for(size_t d = 0; d < retTy.getRank(); d++){
|
||||||
|
if(retShape[d] == 1){
|
||||||
|
contiguity.push_back(1);
|
||||||
|
divisibility.push_back(1);
|
||||||
|
constancy.push_back(1);
|
||||||
|
}
|
||||||
|
else if(!is_skewed
|
||||||
|
&& retShape[d] == opShape[current]){
|
||||||
|
contiguity.push_back(opInfo.getContiguity()[current]);
|
||||||
|
divisibility.push_back(opInfo.getDivisibility()[current]);
|
||||||
|
constancy.push_back(opInfo.getConstancy()[current]);
|
||||||
|
current++;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
is_skewed = true;
|
||||||
|
contiguity.push_back(1);
|
||||||
|
divisibility.push_back(1);
|
||||||
|
constancy.push_back(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||||
|
}
|
||||||
|
// Broadcast
|
||||||
|
if (llvm::isa<triton::BroadcastOp>(op)){
|
||||||
|
Type _retTy = *op->result_type_begin();
|
||||||
|
Type _opTy = *op->operand_type_begin();
|
||||||
|
TensorType retTy = _retTy.cast<TensorType>();
|
||||||
|
TensorType opTy = _opTy.cast<TensorType>();
|
||||||
|
ArrayRef<int64_t> retShape = retTy.getShape();
|
||||||
|
ArrayRef<int64_t> opShape = opTy.getShape();
|
||||||
|
AxisInfo opInfo = operands[0]->getValue();
|
||||||
|
AxisInfo::ContiguityT contiguity;
|
||||||
|
AxisInfo::DivisibilityT divisibility;
|
||||||
|
AxisInfo::ConstancyT constancy;
|
||||||
|
for(size_t d = 0; d < retTy.getRank(); d++){
|
||||||
|
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
|
||||||
|
divisibility.push_back(opInfo.getDivisibility(d));
|
||||||
|
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
|
||||||
|
}
|
||||||
|
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||||
|
}
|
||||||
|
if(curr.getRank() == 0){
|
||||||
|
return markAllPessimisticFixpoint(op->getResults());
|
||||||
|
}
|
||||||
|
// join all latice elements
|
||||||
|
ChangeResult result = ChangeResult::NoChange;
|
||||||
|
for (Value value : op->getResults()) {
|
||||||
|
result |= getLatticeElement(value).join(curr);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
3
lib/Analysis/CMakeLists.txt
Normal file
3
lib/Analysis/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
add_mlir_library(TritonAnalysis
|
||||||
|
AxisInfo.cpp
|
||||||
|
)
|
@@ -1,4 +1,5 @@
|
|||||||
# add_subdirectory(codegen)
|
# add_subdirectory(codegen)
|
||||||
add_subdirectory(driver)
|
add_subdirectory(driver)
|
||||||
|
add_subdirectory(Analysis)
|
||||||
add_subdirectory(Conversion)
|
add_subdirectory(Conversion)
|
||||||
add_subdirectory(Dialect)
|
add_subdirectory(Dialect)
|
||||||
|
@@ -240,6 +240,7 @@ void populateTritonPatterns(
|
|||||||
) {
|
) {
|
||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
|
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
|
||||||
|
TritonGenericPattern<triton::SplatOp>,
|
||||||
TritonGenericPattern<triton::BroadcastOp>,
|
TritonGenericPattern<triton::BroadcastOp>,
|
||||||
TritonGenericPattern<triton::GEPOp>,
|
TritonGenericPattern<triton::GEPOp>,
|
||||||
TritonReducePattern,
|
TritonReducePattern,
|
||||||
|
@@ -23,3 +23,8 @@ void TritonDialect::initialize() {
|
|||||||
|
|
||||||
// We can also add interface here.
|
// We can also add interface here.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Operation *TritonDialect::materializeConstant(OpBuilder &builder, Attribute value,
|
||||||
|
Type type, Location loc) {
|
||||||
|
return builder.create<arith::ConstantOp>(loc, type, value);
|
||||||
|
}
|
@@ -97,14 +97,22 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::
|
|||||||
|
|
||||||
//-- DotOp --
|
//-- DotOp --
|
||||||
|
|
||||||
|
//-- SplatOp --
|
||||||
|
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
|
||||||
|
if (!constOperand)
|
||||||
|
return {};
|
||||||
|
auto shapedType = getType().cast<ShapedType>();
|
||||||
|
auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
//-- BroadcastOp --
|
//-- BroadcastOp --
|
||||||
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
||||||
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
|
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
|
||||||
if (!constOperand)
|
if (!constOperand)
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
auto shapedType = getType().cast<ShapedType>();
|
auto shapedType = getType().cast<ShapedType>();
|
||||||
|
|
||||||
return SplatElementsAttr::get(shapedType, {constOperand.getValue()});
|
return SplatElementsAttr::get(shapedType, {constOperand.getValue()});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -788,6 +788,11 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
|
.def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
|
||||||
return self.addEntryBlock();
|
return self.addEntryBlock();
|
||||||
}, ret::reference)
|
}, ret::reference)
|
||||||
|
.def("set_arg_attr", [](mlir::FuncOp &self, int arg_no, const std::string& name, int val){
|
||||||
|
// set arg attributes "name" to value "val"
|
||||||
|
auto attrTy = mlir::IntegerType::get(self.getContext(), 32);
|
||||||
|
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
|
||||||
|
}, ret::reference)
|
||||||
.def("reset_type", &mlir::FuncOp::setType)
|
.def("reset_type", &mlir::FuncOp::setType)
|
||||||
;
|
;
|
||||||
|
|
||||||
@@ -1265,9 +1270,10 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("create_splat", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
|
.def("create_splat", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
|
||||||
auto loc = self.getUnknownLoc();
|
auto loc = self.getUnknownLoc();
|
||||||
auto argType = arg.getType();
|
auto argType = arg.getType();
|
||||||
return self.create<mlir::triton::BroadcastOp>(
|
auto ret = self.createOrFold<mlir::triton::SplatOp>(
|
||||||
loc, mlir::RankedTensorType::get(shape, argType), arg
|
loc, mlir::RankedTensorType::get(shape, argType), arg
|
||||||
);
|
);
|
||||||
|
return ret;
|
||||||
})
|
})
|
||||||
// // atomic
|
// // atomic
|
||||||
.def("create_atomic_cas", [](mlir::OpBuilder &self, mlir::Value &ptr,
|
.def("create_atomic_cas", [](mlir::OpBuilder &self, mlir::Value &ptr,
|
||||||
@@ -1337,6 +1343,12 @@ void init_triton_ir(py::module &&m) {
|
|||||||
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
|
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
|
||||||
return mlir::succeeded(self.run(mod.getOperation()));
|
return mlir::succeeded(self.run(mod.getOperation()));
|
||||||
})
|
})
|
||||||
|
.def("add_sccp_pass", [](mlir::PassManager &self) {
|
||||||
|
self.addPass(mlir::createSCCPPass());
|
||||||
|
})
|
||||||
|
.def("add_symbol_dce_pass", [](mlir::PassManager &self) {
|
||||||
|
self.addPass(mlir::createSymbolDCEPass());
|
||||||
|
})
|
||||||
.def("add_inliner_pass", [](mlir::PassManager &self) {
|
.def("add_inliner_pass", [](mlir::PassManager &self) {
|
||||||
self.addPass(mlir::createInlinerPass());
|
self.addPass(mlir::createInlinerPass());
|
||||||
})
|
})
|
||||||
|
@@ -199,14 +199,8 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
arg_values.append(cst)
|
arg_values.append(cst)
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
# TODO: ...
|
if i in self.attributes:
|
||||||
# if i in self.attributes:
|
fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i])
|
||||||
# is_ptr = fn.args[idx].type.is_ptr()
|
|
||||||
# attr = 'aligned' if is_ptr else 'multiple_of'
|
|
||||||
# attr = getattr(_triton.ir.attribute_kind, attr)
|
|
||||||
# attr = _triton.ir.attribute(attr, self.attributes[i])
|
|
||||||
# fn.add_attr(idx + 1, attr)
|
|
||||||
# fn.args[idx].name = arg_name
|
|
||||||
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
|
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
@@ -1307,8 +1301,13 @@ class JITFunction:
|
|||||||
raise CompilationError(self.src, node) from e
|
raise CompilationError(self.src, node) from e
|
||||||
# cache num_warps & num_stages
|
# cache num_warps & num_stages
|
||||||
self.num_warps, self.num_stages = num_warps, num_stages
|
self.num_warps, self.num_stages = num_warps, num_stages
|
||||||
|
# run simple SCCP and DCE here to clean-up the generated IR
|
||||||
|
mod = generator.module
|
||||||
|
pm = _triton.ir.pass_manager(context)
|
||||||
|
pm.add_canonicalizer_pass()
|
||||||
|
pm.run(mod)
|
||||||
# FIXME: now we need to return context, otherwise it will be deleted
|
# FIXME: now we need to return context, otherwise it will be deleted
|
||||||
return generator.module, context
|
return mod, context
|
||||||
|
|
||||||
def compile_ttir_to_llir(self, mod, ctx):
|
def compile_ttir_to_llir(self, mod, ctx):
|
||||||
num_warps, num_stages = self.num_warps, self.num_stages
|
num_warps, num_stages = self.num_warps, self.num_stages
|
||||||
|
52
test/Analysis/test-alignment.mlir
Normal file
52
test/Analysis/test-alignment.mlir
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
// RUN: triton-opt %s -test-print-alignment -split-input-file 2>&1 | FileCheck %s
|
||||||
|
|
||||||
|
func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||||
|
// CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
|
||||||
|
%cst = arith.constant dense<true> : tensor<128x128xi1>
|
||||||
|
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
|
||||||
|
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
|
||||||
|
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
|
||||||
|
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||||
|
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
|
||||||
|
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||||
|
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
|
||||||
|
%2 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32>
|
||||||
|
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
|
||||||
|
%3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
|
||||||
|
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1048576, 16] ; Constancy: [1, 1]
|
||||||
|
%4 = arith.muli %2, %3 : tensor<128x1xi32>
|
||||||
|
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
|
||||||
|
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
||||||
|
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1]
|
||||||
|
%6 = tt.getelementptr %5, %4 : tensor<128x1x!tt.ptr<f32>>
|
||||||
|
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
|
||||||
|
%7 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32>
|
||||||
|
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
|
||||||
|
%8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
|
||||||
|
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1]
|
||||||
|
%9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||||
|
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1]
|
||||||
|
%10 = tt.getelementptr %8, %9 : tensor<128x128x!tt.ptr<f32>>
|
||||||
|
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
|
||||||
|
%11 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32>
|
||||||
|
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
|
||||||
|
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
||||||
|
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
|
||||||
|
%13 = tt.getelementptr %12, %11 : tensor<128x1x!tt.ptr<f32>>
|
||||||
|
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
|
||||||
|
%14 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32>
|
||||||
|
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
|
||||||
|
%15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
|
||||||
|
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [1, 1]
|
||||||
|
%16 = arith.muli %14, %15 : tensor<1x128xi32>
|
||||||
|
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 128]
|
||||||
|
%17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
|
||||||
|
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [128, 1]
|
||||||
|
%18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||||
|
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
|
||||||
|
%19 = tt.getelementptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
|
||||||
|
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
|
||||||
|
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
|
||||||
|
tt.store %19, %20, %cst, : tensor<128x128xf32>
|
||||||
|
return
|
||||||
|
}
|
@@ -1,3 +1,5 @@
|
|||||||
|
add_subdirectory(lib)
|
||||||
|
|
||||||
llvm_canonicalize_cmake_booleans(
|
llvm_canonicalize_cmake_booleans(
|
||||||
MLIR_ENABLE_BINDINGS_PYTHON
|
MLIR_ENABLE_BINDINGS_PYTHON
|
||||||
)
|
)
|
||||||
|
6
test/lib/Analysis/CMakeLists.txt
Normal file
6
test/lib/Analysis/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
add_mlir_library(TritonTestAnalysis
|
||||||
|
TestAxisInfo.cpp
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
TritonAnalysis
|
||||||
|
)
|
67
test/lib/Analysis/TestAxisInfo.cpp
Normal file
67
test/lib/Analysis/TestAxisInfo.cpp
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
#include "triton/Analysis/AxisInfo.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace{
|
||||||
|
|
||||||
|
struct TestAxisInfoPass
|
||||||
|
: public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>>{
|
||||||
|
|
||||||
|
// LLVM15+
|
||||||
|
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass);
|
||||||
|
|
||||||
|
void print(const std::string& name, raw_ostream& os, ArrayRef<int> vals){
|
||||||
|
os << name << ": [";
|
||||||
|
for(size_t d = 0; d < vals.size(); d++){
|
||||||
|
if(d != 0) os << ", ";
|
||||||
|
os << vals[d];
|
||||||
|
}
|
||||||
|
os << "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef getArgument() const final { return "test-print-alignment"; }
|
||||||
|
StringRef getDescription() const final
|
||||||
|
{ return "print the result of the alignment analysis pass"; }
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
Operation* operation = getOperation();
|
||||||
|
auto& os = llvm::errs();
|
||||||
|
os << "Testing: " << operation->getName() << "\n";
|
||||||
|
AxisInfoAnalysis analysis(&getContext());
|
||||||
|
analysis.run(operation);
|
||||||
|
operation->walk([&](Operation* op){
|
||||||
|
if(op->getNumResults() < 1)
|
||||||
|
return;
|
||||||
|
for(Value result: op->getResults()){
|
||||||
|
// std::ostringstream oss;
|
||||||
|
// result.print(oss);
|
||||||
|
// os << " => ";
|
||||||
|
LatticeElement<AxisInfo> *latticeElement = analysis.lookupLatticeElement(result);
|
||||||
|
if(!latticeElement){
|
||||||
|
os << "None\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
AxisInfo& info = latticeElement->getValue();
|
||||||
|
print("Contiguity", os, info.getContiguity());
|
||||||
|
os << " ; ";
|
||||||
|
print("Divisibility", os, info.getDivisibility());
|
||||||
|
os << " ; ";
|
||||||
|
print("Constancy", os, info.getConstancy());
|
||||||
|
os << " ( ";
|
||||||
|
result.print(os);
|
||||||
|
os << " ) ";
|
||||||
|
os << "\n";
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace mlir{
|
||||||
|
namespace test{
|
||||||
|
void registerTestAlignmentPass() { PassRegistration<TestAxisInfoPass>(); }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
1
test/lib/CMakeLists.txt
Normal file
1
test/lib/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
add_subdirectory(Analysis)
|
Reference in New Issue
Block a user