[Analysis] Added Axis Info Analysis (#8)
This commit is contained in:
52
test/Analysis/test-alignment.mlir
Normal file
52
test/Analysis/test-alignment.mlir
Normal file
@@ -0,0 +1,52 @@
|
||||
// RUN: triton-opt %s -test-print-alignment -split-input-file 2>&1 | FileCheck %s
|
||||
|
||||
func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
|
||||
%cst = arith.constant dense<true> : tensor<128x128xi1>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
|
||||
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
|
||||
%2 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
|
||||
%3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1048576, 16] ; Constancy: [1, 1]
|
||||
%4 = arith.muli %2, %3 : tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1]
|
||||
%6 = tt.getelementptr %5, %4 : tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
|
||||
%7 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
|
||||
%8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1]
|
||||
%9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1]
|
||||
%10 = tt.getelementptr %8, %9 : tensor<128x128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
|
||||
%11 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
|
||||
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
|
||||
%13 = tt.getelementptr %12, %11 : tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
|
||||
%14 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
|
||||
%15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [1, 1]
|
||||
%16 = arith.muli %14, %15 : tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 128]
|
||||
%17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [128, 1]
|
||||
%18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
|
||||
%19 = tt.getelementptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
|
||||
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
|
||||
tt.store %19, %20, %cst, : tensor<128x128xf32>
|
||||
return
|
||||
}
|
@@ -1,3 +1,5 @@
|
||||
add_subdirectory(lib)
|
||||
|
||||
llvm_canonicalize_cmake_booleans(
|
||||
MLIR_ENABLE_BINDINGS_PYTHON
|
||||
)
|
||||
|
6
test/lib/Analysis/CMakeLists.txt
Normal file
6
test/lib/Analysis/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
add_mlir_library(TritonTestAnalysis
|
||||
TestAxisInfo.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
TritonAnalysis
|
||||
)
|
67
test/lib/Analysis/TestAxisInfo.cpp
Normal file
67
test/lib/Analysis/TestAxisInfo.cpp
Normal file
@@ -0,0 +1,67 @@
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace{
|
||||
|
||||
struct TestAxisInfoPass
|
||||
: public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>>{
|
||||
|
||||
// LLVM15+
|
||||
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass);
|
||||
|
||||
void print(const std::string& name, raw_ostream& os, ArrayRef<int> vals){
|
||||
os << name << ": [";
|
||||
for(size_t d = 0; d < vals.size(); d++){
|
||||
if(d != 0) os << ", ";
|
||||
os << vals[d];
|
||||
}
|
||||
os << "]";
|
||||
}
|
||||
|
||||
StringRef getArgument() const final { return "test-print-alignment"; }
|
||||
StringRef getDescription() const final
|
||||
{ return "print the result of the alignment analysis pass"; }
|
||||
|
||||
void runOnOperation() override {
|
||||
Operation* operation = getOperation();
|
||||
auto& os = llvm::errs();
|
||||
os << "Testing: " << operation->getName() << "\n";
|
||||
AxisInfoAnalysis analysis(&getContext());
|
||||
analysis.run(operation);
|
||||
operation->walk([&](Operation* op){
|
||||
if(op->getNumResults() < 1)
|
||||
return;
|
||||
for(Value result: op->getResults()){
|
||||
// std::ostringstream oss;
|
||||
// result.print(oss);
|
||||
// os << " => ";
|
||||
LatticeElement<AxisInfo> *latticeElement = analysis.lookupLatticeElement(result);
|
||||
if(!latticeElement){
|
||||
os << "None\n";
|
||||
return;
|
||||
}
|
||||
AxisInfo& info = latticeElement->getValue();
|
||||
print("Contiguity", os, info.getContiguity());
|
||||
os << " ; ";
|
||||
print("Divisibility", os, info.getDivisibility());
|
||||
os << " ; ";
|
||||
print("Constancy", os, info.getConstancy());
|
||||
os << " ( ";
|
||||
result.print(os);
|
||||
os << " ) ";
|
||||
os << "\n";
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
namespace mlir{
|
||||
namespace test{
|
||||
void registerTestAlignmentPass() { PassRegistration<TestAxisInfoPass>(); }
|
||||
}
|
||||
}
|
||||
|
1
test/lib/CMakeLists.txt
Normal file
1
test/lib/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_subdirectory(Analysis)
|
Reference in New Issue
Block a user