[Analysis] Added Axis Info Analysis (#8)
This commit is contained in:
@@ -183,6 +183,7 @@ set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
|
||||
target_link_libraries(triton
|
||||
${PYTHON_LIBRARIES}
|
||||
TritonIR
|
||||
TritonAnalysis
|
||||
TritonTransforms
|
||||
TritonDriver
|
||||
TritonToTritonGPU
|
||||
|
@@ -10,10 +10,14 @@ add_llvm_executable(triton-opt triton-opt.cpp)
|
||||
# TODO: what's this?
|
||||
# llvm_update_compile_flags(triton-opt)
|
||||
target_link_libraries(triton-opt PRIVATE
|
||||
TritonAnalysis
|
||||
TritonTransforms
|
||||
TritonGPUTransforms
|
||||
${dialect_libs}
|
||||
${conversion_libs}
|
||||
# tests
|
||||
TritonTestAnalysis
|
||||
# MLIR core
|
||||
MLIROptLib
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
|
@@ -10,11 +10,17 @@
|
||||
#include "mlir/InitAllPasses.h"
|
||||
#include "mlir/Support/MlirOptMain.h"
|
||||
|
||||
namespace mlir{
|
||||
namespace test{
|
||||
void registerTestAlignmentPass();
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
mlir::registerAllPasses();
|
||||
mlir::registerTritonPasses();
|
||||
mlir::registerTritonGPUPasses();
|
||||
mlir::test::registerTestAlignmentPass();
|
||||
mlir::triton::registerConvertTritonToTritonGPUPass();
|
||||
|
||||
// TODO: register Triton & TritonGPU passes
|
||||
|
145
include/triton/Analysis/AxisInfo.h
Normal file
145
include/triton/Analysis/AxisInfo.h
Normal file
@@ -0,0 +1,145 @@
|
||||
#ifndef TRITON_ANALYSIS_AXISINFO_H
|
||||
#define TRITON_ANALYSIS_AXISINFO_H
|
||||
|
||||
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <iostream>
|
||||
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This lattice value represents known information on the axes of a lattice.
|
||||
/// Axis information is represented by a std::map<int, int>
|
||||
class AxisInfo {
|
||||
public:
|
||||
typedef std::vector<int> ContiguityT;
|
||||
typedef std::vector<int> DivisibilityT;
|
||||
typedef std::vector<int> ConstancyT;
|
||||
|
||||
public:
|
||||
// Default constructor
|
||||
AxisInfo(): AxisInfo({}, {}, {}) { }
|
||||
// Construct contiguity info with known contiguity
|
||||
AxisInfo(ContiguityT knownContiguity, DivisibilityT knownDivisibility,
|
||||
ConstancyT knownConstancy)
|
||||
: contiguity(knownContiguity), divisibility(knownDivisibility),
|
||||
constancy(knownConstancy), rank(contiguity.size()) {
|
||||
assert(knownDivisibility.size() == rank);
|
||||
assert(knownConstancy.size() == rank);
|
||||
}
|
||||
|
||||
|
||||
// Accessors
|
||||
int getContiguity(size_t d) const { return contiguity[d]; }
|
||||
const ContiguityT& getContiguity() const { return contiguity; }
|
||||
|
||||
int getDivisibility(size_t d) const { return divisibility[d]; }
|
||||
const DivisibilityT& getDivisibility() const { return divisibility; }
|
||||
|
||||
int getConstancy(size_t d) const { return constancy[d]; }
|
||||
const ConstancyT& getConstancy() const { return constancy; }
|
||||
|
||||
int getRank() const { return rank; }
|
||||
|
||||
// Comparison
|
||||
bool operator==(const AxisInfo &other) const {
|
||||
return (contiguity == other.contiguity) &&
|
||||
(divisibility == other.divisibility) &&
|
||||
(constancy == other.constancy);
|
||||
}
|
||||
|
||||
/// The pessimistic value state of the contiguity is unknown.
|
||||
static AxisInfo getPessimisticValueState(MLIRContext *context)
|
||||
{ return AxisInfo(); }
|
||||
static AxisInfo getPessimisticValueState(Value value);
|
||||
|
||||
// The gcd of both arguments for each dimension
|
||||
static AxisInfo join(const AxisInfo &lhs,
|
||||
const AxisInfo &rhs);
|
||||
|
||||
private:
|
||||
/// The _contiguity_ information maps the `d`-th
|
||||
/// dimension to the length of the shortest
|
||||
/// sequence of contiguous integers along it
|
||||
/// For example:
|
||||
/// [10, 11, 12, 13, 18, 19, 20, 21]
|
||||
/// [20, 21, 22, 23, 28, 29, 30, 31]
|
||||
/// Would have contiguity [1, 4].
|
||||
/// and
|
||||
/// [12, 16, 20, 24]
|
||||
/// [13, 17, 21, 25]
|
||||
/// [14, 18, 22, 26]
|
||||
/// [15, 19, 23, 27]
|
||||
/// [18, 22, 26, 30]
|
||||
/// [19, 23, 27, 31]
|
||||
/// Would have contiguity [2, 1].
|
||||
ContiguityT contiguity;
|
||||
|
||||
/// The _divisibility_ information maps the `d`-th
|
||||
/// dimension to the largest power-of-two that
|
||||
/// divides the first element of all the values along it
|
||||
/// For example:
|
||||
/// [10, 11, 12, 13, 18, 19, 20, 21]
|
||||
/// [20, 21, 22, 23, 28, 29, 30, 31]
|
||||
// would have divisibility [1, 2]
|
||||
// and
|
||||
/// [12, 16, 20, 24]
|
||||
/// [13, 17, 21, 25]
|
||||
/// [14, 18, 22, 26]
|
||||
/// [15, 19, 23, 27]
|
||||
// would have divisibility [4, 1]
|
||||
DivisibilityT divisibility;
|
||||
|
||||
/// The _constancy_ information maps the `d`-th
|
||||
/// dimension to the length of the shortest
|
||||
/// sequence of constant integer along it. This is
|
||||
/// particularly useful to infer the contiguity
|
||||
/// of operations (e.g., add) involving a constant
|
||||
/// For example
|
||||
/// [8, 8, 8, 8, 12, 12, 12, 12]
|
||||
/// [16, 16, 16, 16, 20, 20, 20, 20]
|
||||
/// would have constancy [1, 4]
|
||||
ConstancyT constancy;
|
||||
|
||||
// number of dimensions of the lattice
|
||||
int rank;
|
||||
};
|
||||
|
||||
|
||||
class AxisInfoAnalysis
|
||||
: public ForwardDataFlowAnalysis<AxisInfo> {
|
||||
|
||||
private:
|
||||
static const int maxPow2Divisor = 65536;
|
||||
|
||||
int highestPowOf2Divisor(int n){
|
||||
if(n==0)
|
||||
return maxPow2Divisor;
|
||||
return (n & (~(n - 1)));
|
||||
}
|
||||
|
||||
AxisInfo visitBinaryOp(Operation* op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getContiguity,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getDivisibility,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getConstancy);
|
||||
|
||||
public:
|
||||
using ForwardDataFlowAnalysis<AxisInfo>::ForwardDataFlowAnalysis;
|
||||
|
||||
ChangeResult visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
|
||||
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
#endif
|
@@ -33,6 +33,8 @@ def Triton_Dialect : Dialect {
|
||||
let extraClassDeclaration = [{
|
||||
void registerTypes();
|
||||
}];
|
||||
|
||||
let hasConstantMaterializer = 1;
|
||||
}
|
||||
|
||||
#endif // TRITON_DIALECT
|
||||
|
@@ -143,8 +143,20 @@ def TT_ReshapeOp : TT_Op<"reshape", [NoSideEffect, SameOperandsAndResultElementT
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def TT_SplatOp : TT_Op<"splat", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "splat";
|
||||
|
||||
let arguments = (ins TT_Type:$src);
|
||||
|
||||
let results = (outs TT_Tensor:$result);
|
||||
|
||||
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "broadcast";
|
||||
let summary = "broadcast. No left-padding as of now.";
|
||||
|
||||
let arguments = (ins TT_Type:$src);
|
||||
|
||||
|
@@ -59,6 +59,9 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async",
|
||||
}
|
||||
|
||||
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
|
||||
// This is needed because Arith's Cmp ops don't
|
||||
// handle encodings
|
||||
// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
|
||||
def TTG_CmpIOp : TTG_Op<"cmpi"> {
|
||||
let summary = "integer comparison operation";
|
||||
|
||||
|
239
lib/Analysis/AxisInfo.cpp
Normal file
239
lib/Analysis/AxisInfo.cpp
Normal file
@@ -0,0 +1,239 @@
|
||||
#include "mlir/Analysis/DataFlowAnalysis.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <iostream>
|
||||
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Function for extended Euclidean Algorithm
|
||||
static int gcd_impl(int a, int b, int *x, int *y){
|
||||
// Base Case
|
||||
if (a == 0) {
|
||||
*x = 0;
|
||||
*y = 1;
|
||||
return b;
|
||||
}
|
||||
int x1, y1; // To store results of recursive call
|
||||
int gcd = gcd_impl(b%a, a, &x1, &y1);
|
||||
// Update x and y using results of
|
||||
// recursive call
|
||||
*x = y1 - (b/a) * x1;
|
||||
*y = x1;
|
||||
return gcd;
|
||||
}
|
||||
|
||||
static int gcd(int a, int b) {
|
||||
int x, y;
|
||||
return gcd_impl(a, b, &x, &y);
|
||||
}
|
||||
|
||||
|
||||
AxisInfo AxisInfo::getPessimisticValueState(Value value) {
|
||||
size_t rank = 1;
|
||||
if(TensorType ty = value.getType().dyn_cast<TensorType>())
|
||||
rank = ty.getRank();
|
||||
int divHint = 1;
|
||||
if(BlockArgument blockArg = value.dyn_cast<BlockArgument>()){
|
||||
Operation* op = blockArg.getOwner()->getParentOp();
|
||||
if(FuncOp fun = dyn_cast<FuncOp>(op)){
|
||||
Attribute attr = fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
|
||||
if(attr)
|
||||
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
|
||||
}
|
||||
}
|
||||
ContiguityT contiguity(rank, 1);
|
||||
DivisibilityT divisibility(rank, divHint);
|
||||
ConstancyT constancy(rank, 1);
|
||||
return AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
|
||||
|
||||
// The gcd of both arguments for each dimension
|
||||
AxisInfo AxisInfo::join(const AxisInfo &lhs,
|
||||
const AxisInfo &rhs) {
|
||||
ContiguityT retContiguity;
|
||||
DivisibilityT retDivisibility;
|
||||
ConstancyT retConstancy;
|
||||
for(size_t d = 0; d < lhs.getRank(); d++){
|
||||
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
|
||||
retDivisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
|
||||
retConstancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
|
||||
}
|
||||
return AxisInfo(retContiguity, retDivisibility, retConstancy);
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AxisInfoAnalysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AxisInfo AxisInfoAnalysis::visitBinaryOp(Operation* op, AxisInfo lhsInfo, AxisInfo rhsInfo,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getContiguity,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getDivisibility,
|
||||
const std::function<int(AxisInfo,AxisInfo,int)>& getConstancy) {
|
||||
int rank = lhsInfo.getRank();
|
||||
AxisInfo::ContiguityT newContiguity;
|
||||
AxisInfo::DivisibilityT newDivisibility;
|
||||
AxisInfo::ConstancyT newConstancy;
|
||||
for(size_t d = 0; d < rank; d++){
|
||||
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
|
||||
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
|
||||
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
|
||||
}
|
||||
return AxisInfo(newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
|
||||
ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<AxisInfo> *> operands) {
|
||||
AxisInfo curr;
|
||||
// This preserves the input axes (e.g., cast):
|
||||
if (llvm::isa<arith::ExtSIOp, arith::ExtUIOp, arith::TruncIOp,
|
||||
triton::PtrToIntOp, triton::IntToPtrOp>(op))
|
||||
curr = operands[0]->getValue();
|
||||
// Constant ranges
|
||||
if (triton::MakeRangeOp make_range = llvm::dyn_cast<triton::MakeRangeOp>(op)){
|
||||
int start = make_range.start();
|
||||
int end = make_range.end();
|
||||
AxisInfo::ContiguityT contiguity = {end - start};
|
||||
AxisInfo::DivisibilityT divisibility = {highestPowOf2Divisor(start)};
|
||||
AxisInfo::ConstancyT constancy = {1};
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
// Constant
|
||||
if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)){
|
||||
auto intAttr = constant.getValue().dyn_cast<IntegerAttr>();
|
||||
if(intAttr){
|
||||
size_t val = intAttr.getValue().getZExtValue();
|
||||
curr = AxisInfo({1}, {highestPowOf2Divisor(val)}, {1});
|
||||
}
|
||||
// TODO: generalize to dense attr
|
||||
auto splatAttr = constant.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if(splatAttr && splatAttr.getElementType().isInteger(32)){
|
||||
auto value = splatAttr.getSplatValue<int>();
|
||||
TensorType ty = splatAttr.getType().cast<TensorType>();
|
||||
curr = AxisInfo(AxisInfo::ContiguityT(ty.getRank(), 1),
|
||||
AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
|
||||
AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
|
||||
|
||||
}
|
||||
}
|
||||
// Addition
|
||||
if (llvm::isa<arith::AddIOp, triton::GEPOp>(op)){
|
||||
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)),
|
||||
gcd(lhs.getConstancy(d), rhs.getContiguity(d)));
|
||||
};
|
||||
auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||
};
|
||||
auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
|
||||
};
|
||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
// Multiplication
|
||||
if (llvm::isa<arith::MulIOp>(op)){
|
||||
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
return 1;
|
||||
};
|
||||
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
|
||||
};
|
||||
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d){
|
||||
return lhs.getDivisibility(d)*rhs.getDivisibility(d);
|
||||
};
|
||||
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
|
||||
newContiguity, newDivisibility, newConstancy);
|
||||
}
|
||||
// Splat
|
||||
if (llvm::isa<triton::SplatOp>(op)){
|
||||
Type _retTy = *op->result_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
AxisInfo opInfo = operands[0]->getValue();
|
||||
AxisInfo::ContiguityT contiguity;
|
||||
AxisInfo::DivisibilityT divisibility;
|
||||
AxisInfo::ConstancyT constancy;
|
||||
for(size_t d = 0; d < retTy.getRank(); d++){
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(opInfo.getDivisibility(0));
|
||||
constancy.push_back(retTy.getShape()[d]);
|
||||
}
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
// Reshape
|
||||
// TODO: Replace by `unsqueeze`
|
||||
if (llvm::isa<triton::ReshapeOp>(op)){
|
||||
Type _retTy = *op->result_type_begin();
|
||||
Type _opTy = *op->operand_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
TensorType opTy = _opTy.cast<TensorType>();
|
||||
ArrayRef<int64_t> retShape = retTy.getShape();
|
||||
ArrayRef<int64_t> opShape = opTy.getShape();
|
||||
AxisInfo opInfo = operands[0]->getValue();
|
||||
AxisInfo::ContiguityT contiguity;
|
||||
AxisInfo::DivisibilityT divisibility;
|
||||
AxisInfo::ConstancyT constancy;
|
||||
bool is_skewed = false;
|
||||
size_t current = 0;
|
||||
for(size_t d = 0; d < retTy.getRank(); d++){
|
||||
if(retShape[d] == 1){
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(1);
|
||||
constancy.push_back(1);
|
||||
}
|
||||
else if(!is_skewed
|
||||
&& retShape[d] == opShape[current]){
|
||||
contiguity.push_back(opInfo.getContiguity()[current]);
|
||||
divisibility.push_back(opInfo.getDivisibility()[current]);
|
||||
constancy.push_back(opInfo.getConstancy()[current]);
|
||||
current++;
|
||||
}
|
||||
else {
|
||||
is_skewed = true;
|
||||
contiguity.push_back(1);
|
||||
divisibility.push_back(1);
|
||||
constancy.push_back(1);
|
||||
}
|
||||
}
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
// Broadcast
|
||||
if (llvm::isa<triton::BroadcastOp>(op)){
|
||||
Type _retTy = *op->result_type_begin();
|
||||
Type _opTy = *op->operand_type_begin();
|
||||
TensorType retTy = _retTy.cast<TensorType>();
|
||||
TensorType opTy = _opTy.cast<TensorType>();
|
||||
ArrayRef<int64_t> retShape = retTy.getShape();
|
||||
ArrayRef<int64_t> opShape = opTy.getShape();
|
||||
AxisInfo opInfo = operands[0]->getValue();
|
||||
AxisInfo::ContiguityT contiguity;
|
||||
AxisInfo::DivisibilityT divisibility;
|
||||
AxisInfo::ConstancyT constancy;
|
||||
for(size_t d = 0; d < retTy.getRank(); d++){
|
||||
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
|
||||
divisibility.push_back(opInfo.getDivisibility(d));
|
||||
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
|
||||
}
|
||||
curr = AxisInfo(contiguity, divisibility, constancy);
|
||||
}
|
||||
if(curr.getRank() == 0){
|
||||
return markAllPessimisticFixpoint(op->getResults());
|
||||
}
|
||||
// join all latice elements
|
||||
ChangeResult result = ChangeResult::NoChange;
|
||||
for (Value value : op->getResults()) {
|
||||
result |= getLatticeElement(value).join(curr);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
3
lib/Analysis/CMakeLists.txt
Normal file
3
lib/Analysis/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
add_mlir_library(TritonAnalysis
|
||||
AxisInfo.cpp
|
||||
)
|
@@ -1,4 +1,5 @@
|
||||
# add_subdirectory(codegen)
|
||||
add_subdirectory(driver)
|
||||
add_subdirectory(Analysis)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Dialect)
|
||||
|
@@ -240,6 +240,7 @@ void populateTritonPatterns(
|
||||
) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
|
||||
TritonGenericPattern<triton::SplatOp>,
|
||||
TritonGenericPattern<triton::BroadcastOp>,
|
||||
TritonGenericPattern<triton::GEPOp>,
|
||||
TritonReducePattern,
|
||||
|
@@ -23,3 +23,8 @@ void TritonDialect::initialize() {
|
||||
|
||||
// We can also add interface here.
|
||||
}
|
||||
|
||||
Operation *TritonDialect::materializeConstant(OpBuilder &builder, Attribute value,
|
||||
Type type, Location loc) {
|
||||
return builder.create<arith::ConstantOp>(loc, type, value);
|
||||
}
|
@@ -97,14 +97,22 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::
|
||||
|
||||
//-- DotOp --
|
||||
|
||||
//-- SplatOp --
|
||||
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
|
||||
if (!constOperand)
|
||||
return {};
|
||||
auto shapedType = getType().cast<ShapedType>();
|
||||
auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
|
||||
return ret;
|
||||
}
|
||||
|
||||
//-- BroadcastOp --
|
||||
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
|
||||
if (!constOperand)
|
||||
return {};
|
||||
|
||||
auto shapedType = getType().cast<ShapedType>();
|
||||
|
||||
return SplatElementsAttr::get(shapedType, {constOperand.getValue()});
|
||||
}
|
||||
|
||||
|
@@ -788,6 +788,11 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
|
||||
return self.addEntryBlock();
|
||||
}, ret::reference)
|
||||
.def("set_arg_attr", [](mlir::FuncOp &self, int arg_no, const std::string& name, int val){
|
||||
// set arg attributes "name" to value "val"
|
||||
auto attrTy = mlir::IntegerType::get(self.getContext(), 32);
|
||||
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
|
||||
}, ret::reference)
|
||||
.def("reset_type", &mlir::FuncOp::setType)
|
||||
;
|
||||
|
||||
@@ -1265,9 +1270,10 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("create_splat", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto argType = arg.getType();
|
||||
return self.create<mlir::triton::BroadcastOp>(
|
||||
auto ret = self.createOrFold<mlir::triton::SplatOp>(
|
||||
loc, mlir::RankedTensorType::get(shape, argType), arg
|
||||
);
|
||||
return ret;
|
||||
})
|
||||
// // atomic
|
||||
.def("create_atomic_cas", [](mlir::OpBuilder &self, mlir::Value &ptr,
|
||||
@@ -1337,6 +1343,12 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
|
||||
return mlir::succeeded(self.run(mod.getOperation()));
|
||||
})
|
||||
.def("add_sccp_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createSCCPPass());
|
||||
})
|
||||
.def("add_symbol_dce_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createSymbolDCEPass());
|
||||
})
|
||||
.def("add_inliner_pass", [](mlir::PassManager &self) {
|
||||
self.addPass(mlir::createInlinerPass());
|
||||
})
|
||||
|
@@ -199,14 +199,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
arg_values.append(cst)
|
||||
else:
|
||||
pass
|
||||
# TODO: ...
|
||||
# if i in self.attributes:
|
||||
# is_ptr = fn.args[idx].type.is_ptr()
|
||||
# attr = 'aligned' if is_ptr else 'multiple_of'
|
||||
# attr = getattr(_triton.ir.attribute_kind, attr)
|
||||
# attr = _triton.ir.attribute(attr, self.attributes[i])
|
||||
# fn.add_attr(idx + 1, attr)
|
||||
# fn.args[idx].name = arg_name
|
||||
if i in self.attributes:
|
||||
fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i])
|
||||
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
|
||||
idx += 1
|
||||
|
||||
@@ -1307,8 +1301,13 @@ class JITFunction:
|
||||
raise CompilationError(self.src, node) from e
|
||||
# cache num_warps & num_stages
|
||||
self.num_warps, self.num_stages = num_warps, num_stages
|
||||
# run simple SCCP and DCE here to clean-up the generated IR
|
||||
mod = generator.module
|
||||
pm = _triton.ir.pass_manager(context)
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.run(mod)
|
||||
# FIXME: now we need to return context, otherwise it will be deleted
|
||||
return generator.module, context
|
||||
return mod, context
|
||||
|
||||
def compile_ttir_to_llir(self, mod, ctx):
|
||||
num_warps, num_stages = self.num_warps, self.num_stages
|
||||
|
52
test/Analysis/test-alignment.mlir
Normal file
52
test/Analysis/test-alignment.mlir
Normal file
@@ -0,0 +1,52 @@
|
||||
// RUN: triton-opt %s -test-print-alignment -split-input-file 2>&1 | FileCheck %s
|
||||
|
||||
func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
|
||||
%cst = arith.constant dense<true> : tensor<128x128xi1>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
|
||||
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
|
||||
%2 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
|
||||
%3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1048576, 16] ; Constancy: [1, 1]
|
||||
%4 = arith.muli %2, %3 : tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1]
|
||||
%6 = tt.getelementptr %5, %4 : tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
|
||||
%7 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
|
||||
%8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1]
|
||||
%9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1]
|
||||
%10 = tt.getelementptr %8, %9 : tensor<128x128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
|
||||
%11 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
|
||||
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
|
||||
%13 = tt.getelementptr %12, %11 : tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
|
||||
%14 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
|
||||
%15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [1, 1]
|
||||
%16 = arith.muli %14, %15 : tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 128]
|
||||
%17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [128, 1]
|
||||
%18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
|
||||
%19 = tt.getelementptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
|
||||
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
|
||||
tt.store %19, %20, %cst, : tensor<128x128xf32>
|
||||
return
|
||||
}
|
@@ -1,3 +1,5 @@
|
||||
add_subdirectory(lib)
|
||||
|
||||
llvm_canonicalize_cmake_booleans(
|
||||
MLIR_ENABLE_BINDINGS_PYTHON
|
||||
)
|
||||
|
6
test/lib/Analysis/CMakeLists.txt
Normal file
6
test/lib/Analysis/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
add_mlir_library(TritonTestAnalysis
|
||||
TestAxisInfo.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
TritonAnalysis
|
||||
)
|
67
test/lib/Analysis/TestAxisInfo.cpp
Normal file
67
test/lib/Analysis/TestAxisInfo.cpp
Normal file
@@ -0,0 +1,67 @@
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace{
|
||||
|
||||
struct TestAxisInfoPass
|
||||
: public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>>{
|
||||
|
||||
// LLVM15+
|
||||
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass);
|
||||
|
||||
void print(const std::string& name, raw_ostream& os, ArrayRef<int> vals){
|
||||
os << name << ": [";
|
||||
for(size_t d = 0; d < vals.size(); d++){
|
||||
if(d != 0) os << ", ";
|
||||
os << vals[d];
|
||||
}
|
||||
os << "]";
|
||||
}
|
||||
|
||||
StringRef getArgument() const final { return "test-print-alignment"; }
|
||||
StringRef getDescription() const final
|
||||
{ return "print the result of the alignment analysis pass"; }
|
||||
|
||||
void runOnOperation() override {
|
||||
Operation* operation = getOperation();
|
||||
auto& os = llvm::errs();
|
||||
os << "Testing: " << operation->getName() << "\n";
|
||||
AxisInfoAnalysis analysis(&getContext());
|
||||
analysis.run(operation);
|
||||
operation->walk([&](Operation* op){
|
||||
if(op->getNumResults() < 1)
|
||||
return;
|
||||
for(Value result: op->getResults()){
|
||||
// std::ostringstream oss;
|
||||
// result.print(oss);
|
||||
// os << " => ";
|
||||
LatticeElement<AxisInfo> *latticeElement = analysis.lookupLatticeElement(result);
|
||||
if(!latticeElement){
|
||||
os << "None\n";
|
||||
return;
|
||||
}
|
||||
AxisInfo& info = latticeElement->getValue();
|
||||
print("Contiguity", os, info.getContiguity());
|
||||
os << " ; ";
|
||||
print("Divisibility", os, info.getDivisibility());
|
||||
os << " ; ";
|
||||
print("Constancy", os, info.getConstancy());
|
||||
os << " ( ";
|
||||
result.print(os);
|
||||
os << " ) ";
|
||||
os << "\n";
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
namespace mlir{
|
||||
namespace test{
|
||||
void registerTestAlignmentPass() { PassRegistration<TestAxisInfoPass>(); }
|
||||
}
|
||||
}
|
||||
|
1
test/lib/CMakeLists.txt
Normal file
1
test/lib/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_subdirectory(Analysis)
|
Reference in New Issue
Block a user