[Analysis] Added Axis Info Analysis (#8)

Author: Philippe Tillet
Date: 2022-07-19 13:38:48 -07:00 (committed via GitHub)
parent df940aaab0
commit a633d2b403
20 changed files with 582 additions and 13 deletions

View File

@@ -183,6 +183,7 @@ set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
target_link_libraries(triton
${PYTHON_LIBRARIES}
TritonIR
TritonAnalysis
TritonTransforms
TritonDriver
TritonToTritonGPU

View File

@@ -10,10 +10,14 @@ add_llvm_executable(triton-opt triton-opt.cpp)
# TODO: what's this?
# llvm_update_compile_flags(triton-opt)
target_link_libraries(triton-opt PRIVATE
TritonAnalysis
TritonTransforms
TritonGPUTransforms
${dialect_libs}
${conversion_libs}
# tests
TritonTestAnalysis
# MLIR core
MLIROptLib
MLIRPass
MLIRTransforms

View File

@@ -10,11 +10,17 @@
#include "mlir/InitAllPasses.h"
#include "mlir/Support/MlirOptMain.h"
namespace mlir{
namespace test{
void registerTestAlignmentPass();
}
}
int main(int argc, char **argv) {
mlir::registerAllPasses();
mlir::registerTritonPasses();
mlir::registerTritonGPUPasses();
mlir::test::registerTestAlignmentPass();
mlir::triton::registerConvertTritonToTritonGPUPass();
// TODO: register Triton & TritonGPU passes

View File

@@ -0,0 +1,145 @@
#ifndef TRITON_ANALYSIS_AXISINFO_H
#define TRITON_ANALYSIS_AXISINFO_H
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
namespace mlir {
//===----------------------------------------------------------------------===//
// AxisInfo
//===----------------------------------------------------------------------===//
/// This lattice value represents known information on the axes of a value.
/// Axis information is stored per dimension as three vectors of ints:
/// contiguity, divisibility and constancy.
class AxisInfo {
public:
typedef std::vector<int> ContiguityT;
typedef std::vector<int> DivisibilityT;
typedef std::vector<int> ConstancyT;
public:
// Default constructor
AxisInfo(): AxisInfo({}, {}, {}) { }
// Construct contiguity info with known contiguity
AxisInfo(ContiguityT knownContiguity, DivisibilityT knownDivisibility,
ConstancyT knownConstancy)
: contiguity(knownContiguity), divisibility(knownDivisibility),
constancy(knownConstancy), rank(contiguity.size()) {
assert(knownDivisibility.size() == rank);
assert(knownConstancy.size() == rank);
}
// Accessors
int getContiguity(size_t d) const { return contiguity[d]; }
const ContiguityT& getContiguity() const { return contiguity; }
int getDivisibility(size_t d) const { return divisibility[d]; }
const DivisibilityT& getDivisibility() const { return divisibility; }
int getConstancy(size_t d) const { return constancy[d]; }
const ConstancyT& getConstancy() const { return constancy; }
int getRank() const { return rank; }
// Comparison
bool operator==(const AxisInfo &other) const {
return (contiguity == other.contiguity) &&
(divisibility == other.divisibility) &&
(constancy == other.constancy);
}
/// The pessimistic value state: nothing is known about any axis.
static AxisInfo getPessimisticValueState(MLIRContext *context)
{ return AxisInfo(); }
static AxisInfo getPessimisticValueState(Value value);
// The gcd of both arguments for each dimension
static AxisInfo join(const AxisInfo &lhs,
const AxisInfo &rhs);
private:
/// The _contiguity_ information maps the `d`-th
/// dimension to the length of the shortest
/// sequence of contiguous integers along it
/// For example:
/// [10, 11, 12, 13, 18, 19, 20, 21]
/// [20, 21, 22, 23, 28, 29, 30, 31]
/// would have contiguity [1, 4], and
/// [12, 16, 20, 24]
/// [13, 17, 21, 25]
/// [14, 18, 22, 26]
/// [15, 19, 23, 27]
/// [18, 22, 26, 30]
/// [19, 23, 27, 31]
/// would have contiguity [2, 1].
ContiguityT contiguity;
/// The _divisibility_ information maps the `d`-th
/// dimension to the largest power-of-two that
/// divides the first element of all the values along it
/// For example:
/// [10, 11, 12, 13, 18, 19, 20, 21]
/// [20, 21, 22, 23, 28, 29, 30, 31]
/// would have divisibility [1, 2], and
/// [12, 16, 20, 24]
/// [13, 17, 21, 25]
/// [14, 18, 22, 26]
/// [15, 19, 23, 27]
/// would have divisibility [4, 1].
DivisibilityT divisibility;
/// The _constancy_ information maps the `d`-th
/// dimension to the length of the shortest
/// sequence of constant integers along it. This is
/// particularly useful to infer the contiguity
/// of operations (e.g., add) involving a constant
/// For example
/// [8, 8, 8, 8, 12, 12, 12, 12]
/// [16, 16, 16, 16, 20, 20, 20, 20]
/// would have constancy [1, 4]
ConstancyT constancy;
// number of dimensions of the lattice
int rank;
};
class AxisInfoAnalysis
: public ForwardDataFlowAnalysis<AxisInfo> {
private:
static const int maxPow2Divisor = 65536;
int highestPowOf2Divisor(int n){
if(n==0)
return maxPow2Divisor;
return (n & (~(n - 1)));
}
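// Note: `n & ~(n - 1)` isolates the lowest set bit of n, which is exactly
// the largest power of two dividing n; e.g. highestPowOf2Divisor(12) == 4
// and highestPowOf2Divisor(24) == 8. n == 0 is mapped to maxPow2Divisor,
// since every power of two divides 0.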
AxisInfo visitBinaryOp(Operation* op, AxisInfo lhsInfo, AxisInfo rhsInfo,
const std::function<int(AxisInfo,AxisInfo,int)>& getContiguity,
const std::function<int(AxisInfo,AxisInfo,int)>& getDivisibility,
const std::function<int(AxisInfo,AxisInfo,int)>& getConstancy);
public:
using ForwardDataFlowAnalysis<AxisInfo>::ForwardDataFlowAnalysis;
ChangeResult visitOperation(Operation *op,
ArrayRef<LatticeElement<AxisInfo> *> operands) override;
};
}
#endif

View File

@@ -33,6 +33,8 @@ def Triton_Dialect : Dialect {
let extraClassDeclaration = [{
void registerTypes();
}];
let hasConstantMaterializer = 1;
}
#endif // TRITON_DIALECT

View File

@@ -143,8 +143,20 @@ def TT_ReshapeOp : TT_Op<"reshape", [NoSideEffect, SameOperandsAndResultElementT
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
}
def TT_SplatOp : TT_Op<"splat", [NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "splat";
let arguments = (ins TT_Type:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
}
def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "broadcast";
let summary = "broadcast. No left-padding as of now.";
let arguments = (ins TT_Type:$src);

View File

@@ -59,6 +59,9 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async",
}
// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
// This is needed because Arith's Cmp ops don't
// handle encodings
// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
def TTG_CmpIOp : TTG_Op<"cmpi"> {
let summary = "integer comparison operation";

lib/Analysis/AxisInfo.cpp (new file, 239 lines)
View File

@@ -0,0 +1,239 @@
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "llvm/Support/raw_ostream.h"
#include <iostream>
#include "triton/Analysis/AxisInfo.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
namespace mlir {
//===----------------------------------------------------------------------===//
// AxisInfo
//===----------------------------------------------------------------------===//
// Extended Euclidean algorithm: computes gcd(a, b) together with Bézout
// coefficients x, y such that a*x + b*y == gcd(a, b).
static int gcd_impl(int a, int b, int *x, int *y){
// Base Case
if (a == 0) {
*x = 0;
*y = 1;
return b;
}
int x1, y1; // To store results of recursive call
int gcd = gcd_impl(b%a, a, &x1, &y1);
// Update x and y using results of
// recursive call
*x = y1 - (b/a) * x1;
*y = x1;
return gcd;
}
static int gcd(int a, int b) {
int x, y;
return gcd_impl(a, b, &x, &y);
}
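// Worked example: gcd(12, 8) == 4; the Bézout coefficients x, y
// (with 12*x + 8*y == 4) computed by gcd_impl are unused by this wrapper.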
AxisInfo AxisInfo::getPessimisticValueState(Value value) {
size_t rank = 1;
if(TensorType ty = value.getType().dyn_cast<TensorType>())
rank = ty.getRank();
int divHint = 1;
if(BlockArgument blockArg = value.dyn_cast<BlockArgument>()){
Operation* op = blockArg.getOwner()->getParentOp();
if(FuncOp fun = dyn_cast<FuncOp>(op)){
Attribute attr = fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
if(attr)
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
}
}
ContiguityT contiguity(rank, 1);
DivisibilityT divisibility(rank, divHint);
ConstancyT constancy(rank, 1);
return AxisInfo(contiguity, divisibility, constancy);
}
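// For example, a kernel argument declared as
//   %arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}
// starts with divisibility [16] and contiguity/constancy [1]; a value
// without such a hint starts with every entry equal to 1.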
// The gcd of both arguments for each dimension
AxisInfo AxisInfo::join(const AxisInfo &lhs,
const AxisInfo &rhs) {
ContiguityT retContiguity;
DivisibilityT retDivisibility;
ConstancyT retConstancy;
for(size_t d = 0; d < lhs.getRank(); d++){
retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
retDivisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
retConstancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
}
return AxisInfo(retContiguity, retDivisibility, retConstancy);
}
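// For example, joining {contiguity=[4], divisibility=[8], constancy=[1]}
// with {contiguity=[2], divisibility=[4], constancy=[1]} yields
// {contiguity=[2], divisibility=[4], constancy=[1]}: the per-dimension gcd
// keeps only what is guaranteed along every incoming path.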
//===----------------------------------------------------------------------===//
// AxisInfoAnalysis
//===----------------------------------------------------------------------===//
AxisInfo AxisInfoAnalysis::visitBinaryOp(Operation* op, AxisInfo lhsInfo, AxisInfo rhsInfo,
const std::function<int(AxisInfo,AxisInfo,int)>& getContiguity,
const std::function<int(AxisInfo,AxisInfo,int)>& getDivisibility,
const std::function<int(AxisInfo,AxisInfo,int)>& getConstancy) {
int rank = lhsInfo.getRank();
AxisInfo::ContiguityT newContiguity;
AxisInfo::DivisibilityT newDivisibility;
AxisInfo::ConstancyT newConstancy;
for(size_t d = 0; d < rank; d++){
newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
}
return AxisInfo(newContiguity, newDivisibility, newConstancy);
}
ChangeResult AxisInfoAnalysis::visitOperation(Operation *op,
ArrayRef<LatticeElement<AxisInfo> *> operands) {
AxisInfo curr;
// This preserves the input axes (e.g., cast):
if (llvm::isa<arith::ExtSIOp, arith::ExtUIOp, arith::TruncIOp,
triton::PtrToIntOp, triton::IntToPtrOp>(op))
curr = operands[0]->getValue();
// Constant ranges
if (triton::MakeRangeOp make_range = llvm::dyn_cast<triton::MakeRangeOp>(op)){
int start = make_range.start();
int end = make_range.end();
AxisInfo::ContiguityT contiguity = {end - start};
AxisInfo::DivisibilityT divisibility = {highestPowOf2Divisor(start)};
AxisInfo::ConstancyT constancy = {1};
curr = AxisInfo(contiguity, divisibility, constancy);
}
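// e.g. make_range(0, 128) gets contiguity [128] (one contiguous run of
// length end - start), divisibility [65536] (start == 0 is mapped to
// maxPow2Divisor) and constancy [1].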
// Constant
if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)){
auto intAttr = constant.getValue().dyn_cast<IntegerAttr>();
if(intAttr){
size_t val = intAttr.getValue().getZExtValue();
curr = AxisInfo({1}, {highestPowOf2Divisor(val)}, {1});
}
// TODO: generalize to dense attr
auto splatAttr = constant.getValue().dyn_cast<SplatElementsAttr>();
if(splatAttr && splatAttr.getElementType().isInteger(32)){
auto value = splatAttr.getSplatValue<int>();
TensorType ty = splatAttr.getType().cast<TensorType>();
curr = AxisInfo(AxisInfo::ContiguityT(ty.getRank(), 1),
AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
}
}
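// e.g. an i32 splat constant of 16 over tensor<128x128xi32> gets
// divisibility [16, 16] and constancy [128, 128]; non-i32 splats (such as
// the i1 and f32 constants in the test below) match neither branch and end
// up in the pessimistic all-ones state.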
// Addition
if (llvm::isa<arith::AddIOp, triton::GEPOp>(op)){
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d){
return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)),
gcd(lhs.getConstancy(d), rhs.getContiguity(d)));
};
auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d){
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d){
return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
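// Worked example (mirrored by the test below): offsetting a splatted
// pointer with constancy 128 and divisibility 16 by a range with
// contiguity 128 and divisibility 65536 gives contiguity
// max(gcd(1, 1), gcd(128, 128)) = 128 and divisibility gcd(16, 65536) = 16
// along that dimension.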
// Multiplication
if (llvm::isa<arith::MulIOp>(op)){
auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d){
return 1;
};
auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d){
return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
};
auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d){
return lhs.getDivisibility(d)*rhs.getDivisibility(d);
};
curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
newContiguity, newDivisibility, newConstancy);
}
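// Worked example: multiplying a range with divisibility 65536 by a splatted
// stride with divisibility 16 gives divisibility 65536 * 16 = 1048576,
// while contiguity drops to 1 (products of consecutive integers are not
// consecutive); see %4 in the test below.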
// Splat
if (llvm::isa<triton::SplatOp>(op)){
Type _retTy = *op->result_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
AxisInfo opInfo = operands[0]->getValue();
AxisInfo::ContiguityT contiguity;
AxisInfo::DivisibilityT divisibility;
AxisInfo::ConstancyT constancy;
for(size_t d = 0; d < retTy.getRank(); d++){
contiguity.push_back(1);
divisibility.push_back(opInfo.getDivisibility(0));
constancy.push_back(retTy.getShape()[d]);
}
curr = AxisInfo(contiguity, divisibility, constancy);
}
// Reshape
// TODO: Replace by `unsqueeze`
if (llvm::isa<triton::ReshapeOp>(op)){
Type _retTy = *op->result_type_begin();
Type _opTy = *op->operand_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
TensorType opTy = _opTy.cast<TensorType>();
ArrayRef<int64_t> retShape = retTy.getShape();
ArrayRef<int64_t> opShape = opTy.getShape();
AxisInfo opInfo = operands[0]->getValue();
AxisInfo::ContiguityT contiguity;
AxisInfo::DivisibilityT divisibility;
AxisInfo::ConstancyT constancy;
bool is_skewed = false;
size_t current = 0;
for(size_t d = 0; d < retTy.getRank(); d++){
if(retShape[d] == 1){
contiguity.push_back(1);
divisibility.push_back(1);
constancy.push_back(1);
}
else if(!is_skewed
&& retShape[d] == opShape[current]){
contiguity.push_back(opInfo.getContiguity()[current]);
divisibility.push_back(opInfo.getDivisibility()[current]);
constancy.push_back(opInfo.getConstancy()[current]);
current++;
}
else {
is_skewed = true;
contiguity.push_back(1);
divisibility.push_back(1);
constancy.push_back(1);
}
}
curr = AxisInfo(contiguity, divisibility, constancy);
}
// Broadcast
if (llvm::isa<triton::BroadcastOp>(op)){
Type _retTy = *op->result_type_begin();
Type _opTy = *op->operand_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
TensorType opTy = _opTy.cast<TensorType>();
ArrayRef<int64_t> retShape = retTy.getShape();
ArrayRef<int64_t> opShape = opTy.getShape();
AxisInfo opInfo = operands[0]->getValue();
AxisInfo::ContiguityT contiguity;
AxisInfo::DivisibilityT divisibility;
AxisInfo::ConstancyT constancy;
for(size_t d = 0; d < retTy.getRank(); d++){
contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
divisibility.push_back(opInfo.getDivisibility(d));
constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
}
curr = AxisInfo(contiguity, divisibility, constancy);
}
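// e.g. broadcasting tensor<1x128xi32> with contiguity [1, 128] and
// divisibility [1, 65536] to tensor<128x128xi32> keeps those values and
// sets constancy to [128, 1]: the broadcast dimension repeats one element,
// the other dimension is unchanged.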
if(curr.getRank() == 0){
return markAllPessimisticFixpoint(op->getResults());
}
// join all lattice elements
ChangeResult result = ChangeResult::NoChange;
for (Value value : op->getResults()) {
result |= getLatticeElement(value).join(curr);
}
return result;
}
}

View File

@@ -0,0 +1,3 @@
add_mlir_library(TritonAnalysis
AxisInfo.cpp
)

View File

@@ -1,4 +1,5 @@
# add_subdirectory(codegen)
add_subdirectory(driver)
add_subdirectory(Analysis)
add_subdirectory(Conversion)
add_subdirectory(Dialect)

View File

@@ -240,6 +240,7 @@ void populateTritonPatterns(
) {
MLIRContext *context = patterns.getContext();
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
TritonGenericPattern<triton::SplatOp>,
TritonGenericPattern<triton::BroadcastOp>,
TritonGenericPattern<triton::GEPOp>,
TritonReducePattern,

View File

@@ -23,3 +23,8 @@ void TritonDialect::initialize() {
// We can also add interface here.
}
Operation *TritonDialect::materializeConstant(OpBuilder &builder, Attribute value,
Type type, Location loc) {
return builder.create<arith::ConstantOp>(loc, type, value);
}

View File

@@ -97,14 +97,22 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::
//-- DotOp --
//-- SplatOp --
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
if (!constOperand)
return {};
auto shapedType = getType().cast<ShapedType>();
auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
return ret;
}
//-- BroadcastOp --
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
if (!constOperand)
return {};
auto shapedType = getType().cast<ShapedType>();
return SplatElementsAttr::get(shapedType, {constOperand.getValue()});
}

View File

@@ -788,6 +788,11 @@ void init_triton_ir(py::module &&m) {
.def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block* {
return self.addEntryBlock();
}, ret::reference)
.def("set_arg_attr", [](mlir::FuncOp &self, int arg_no, const std::string& name, int val){
// set arg attributes "name" to value "val"
auto attrTy = mlir::IntegerType::get(self.getContext(), 32);
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
}, ret::reference)
.def("reset_type", &mlir::FuncOp::setType)
;
@@ -1265,9 +1270,10 @@ void init_triton_ir(py::module &&m) {
.def("create_splat", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
auto loc = self.getUnknownLoc();
auto argType = arg.getType();
return self.create<mlir::triton::BroadcastOp>(
auto ret = self.createOrFold<mlir::triton::SplatOp>(
loc, mlir::RankedTensorType::get(shape, argType), arg
);
return ret;
})
// // atomic
.def("create_atomic_cas", [](mlir::OpBuilder &self, mlir::Value &ptr,
@@ -1337,6 +1343,12 @@ void init_triton_ir(py::module &&m) {
.def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
return mlir::succeeded(self.run(mod.getOperation()));
})
.def("add_sccp_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createSCCPPass());
})
.def("add_symbol_dce_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createSymbolDCEPass());
})
.def("add_inliner_pass", [](mlir::PassManager &self) {
self.addPass(mlir::createInlinerPass());
})

View File

@@ -199,14 +199,8 @@ class CodeGenerator(ast.NodeVisitor):
arg_values.append(cst)
else:
pass
# TODO: ...
# if i in self.attributes:
# is_ptr = fn.args[idx].type.is_ptr()
# attr = 'aligned' if is_ptr else 'multiple_of'
# attr = getattr(_triton.ir.attribute_kind, attr)
# attr = _triton.ir.attribute(attr, self.attributes[i])
# fn.add_attr(idx + 1, attr)
# fn.args[idx].name = arg_name
if i in self.attributes:
fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i])
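# e.g. a pointer argument whose address is known to be a multiple of 16 is
# emitted as `%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}`, which
# AxisInfo::getPessimisticValueState reads back as that argument's initial
# divisibility.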
arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
idx += 1
@@ -1307,8 +1301,13 @@ class JITFunction:
raise CompilationError(self.src, node) from e
# cache num_warps & num_stages
self.num_warps, self.num_stages = num_warps, num_stages
# run the canonicalizer here to clean up the generated IR
mod = generator.module
pm = _triton.ir.pass_manager(context)
pm.add_canonicalizer_pass()
pm.run(mod)
# FIXME: now we need to return context, otherwise it will be deleted
return generator.module, context
return mod, context
def compile_ttir_to_llir(self, mod, ctx):
num_warps, num_stages = self.num_warps, self.num_stages

View File

@@ -0,0 +1,52 @@
// RUN: triton-opt %s -test-print-alignment -split-input-file 2>&1 | FileCheck %s
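// Each check line below records the AxisInfo lattice printed by
// -test-print-alignment for the value defined on the following line.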
func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
// CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
%cst = arith.constant dense<true> : tensor<128x128xi1>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
%2 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
%3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1048576, 16] ; Constancy: [1, 1]
%4 = arith.muli %2, %3 : tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1]
%6 = tt.getelementptr %5, %4 : tensor<128x1x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
%7 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
%8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1]
%9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1]
%10 = tt.getelementptr %8, %9 : tensor<128x128x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
%11 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
%13 = tt.getelementptr %12, %11 : tensor<128x1x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
%14 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
%15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [1, 1]
%16 = arith.muli %14, %15 : tensor<1x128xi32>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 128]
%17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [128, 1]
%18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
%19 = tt.getelementptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
tt.store %19, %20, %cst : tensor<128x128xf32>
return
}

View File

@@ -1,3 +1,5 @@
add_subdirectory(lib)
llvm_canonicalize_cmake_booleans(
MLIR_ENABLE_BINDINGS_PYTHON
)

View File

@@ -0,0 +1,6 @@
add_mlir_library(TritonTestAnalysis
TestAxisInfo.cpp
LINK_LIBS PUBLIC
TritonAnalysis
)

View File

@@ -0,0 +1,67 @@
#include "triton/Analysis/AxisInfo.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
namespace{
struct TestAxisInfoPass
: public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>>{
// LLVM15+
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass);
void print(const std::string& name, raw_ostream& os, ArrayRef<int> vals){
os << name << ": [";
for(size_t d = 0; d < vals.size(); d++){
if(d != 0) os << ", ";
os << vals[d];
}
os << "]";
}
StringRef getArgument() const final { return "test-print-alignment"; }
StringRef getDescription() const final
{ return "print the result of the alignment analysis pass"; }
void runOnOperation() override {
Operation* operation = getOperation();
auto& os = llvm::errs();
os << "Testing: " << operation->getName() << "\n";
AxisInfoAnalysis analysis(&getContext());
analysis.run(operation);
operation->walk([&](Operation* op){
if(op->getNumResults() < 1)
return;
for(Value result: op->getResults()){
// std::ostringstream oss;
// result.print(oss);
// os << " => ";
LatticeElement<AxisInfo> *latticeElement = analysis.lookupLatticeElement(result);
if(!latticeElement){
os << "None\n";
return;
}
AxisInfo& info = latticeElement->getValue();
print("Contiguity", os, info.getContiguity());
os << " ; ";
print("Divisibility", os, info.getDivisibility());
os << " ; ";
print("Constancy", os, info.getConstancy());
os << " ( ";
result.print(os);
os << " ) ";
os << "\n";
}
});
}
};
}
namespace mlir{
namespace test{
void registerTestAlignmentPass() { PassRegistration<TestAxisInfoPass>(); }
}
}

test/lib/CMakeLists.txt (new file, 1 line)
View File

@@ -0,0 +1 @@
add_subdirectory(Analysis)