diff --git a/CMakeLists.txt b/CMakeLists.txt
index c476f7eb8..6eac81fe1 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -183,6 +183,7 @@ set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
 target_link_libraries(triton ${PYTHON_LIBRARIES}
     TritonIR
+    TritonAnalysis
     TritonTransforms
     TritonDriver
     TritonToTritonGPU
diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt
index 0fb8f28fb..92635eca3 100644
--- a/bin/CMakeLists.txt
+++ b/bin/CMakeLists.txt
@@ -10,10 +10,14 @@ add_llvm_executable(triton-opt triton-opt.cpp)
 # TODO: what's this?
 # llvm_update_compile_flags(triton-opt)
 target_link_libraries(triton-opt PRIVATE
+  TritonAnalysis
   TritonTransforms
   TritonGPUTransforms
   ${dialect_libs}
   ${conversion_libs}
+  # tests
+  TritonTestAnalysis
+  # MLIR core
   MLIROptLib
   MLIRPass
   MLIRTransforms
diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp
index 9975ed63d..d5d73e5f6 100644
--- a/bin/triton-opt.cpp
+++ b/bin/triton-opt.cpp
@@ -10,11 +10,17 @@
 #include "mlir/InitAllPasses.h"
 #include "mlir/Support/MlirOptMain.h"
 
+namespace mlir {
+namespace test {
+void registerTestAlignmentPass();
+} // namespace test
+} // namespace mlir
+
 int main(int argc, char **argv) {
   mlir::registerAllPasses();
   mlir::registerTritonPasses();
   mlir::registerTritonGPUPasses();
+  mlir::test::registerTestAlignmentPass();
   mlir::triton::registerConvertTritonToTritonGPUPass();
 
   // TODO: register Triton & TritonGPU passes
diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h
new file mode 100644
index 000000000..0910f341f
--- /dev/null
+++ b/include/triton/Analysis/AxisInfo.h
@@ -0,0 +1,145 @@
+#ifndef TRITON_ANALYSIS_AXISINFO_H
+#define TRITON_ANALYSIS_AXISINFO_H
+
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "llvm/Support/raw_ostream.h"
+#include <vector>
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// AxisInfo
+//===----------------------------------------------------------------------===//
+
+/// This lattice value represents known information on the axes of a value.
+/// Each kind of axis information is stored as one integer per dimension.
+class AxisInfo {
+public:
+  typedef std::vector<int> ContiguityT;
+  typedef std::vector<int> DivisibilityT;
+  typedef std::vector<int> ConstancyT;
+
+public:
+  // Default constructor
+  AxisInfo() : AxisInfo({}, {}, {}) {}
+  // Construct axis info with known contiguity, divisibility, and constancy
+  AxisInfo(ContiguityT knownContiguity, DivisibilityT knownDivisibility,
+           ConstancyT knownConstancy)
+      : contiguity(knownContiguity), divisibility(knownDivisibility),
+        constancy(knownConstancy), rank(contiguity.size()) {
+    assert(knownDivisibility.size() == rank);
+    assert(knownConstancy.size() == rank);
+  }
+
+  // Accessors
+  int getContiguity(size_t d) const { return contiguity[d]; }
+  const ContiguityT &getContiguity() const { return contiguity; }
+
+  int getDivisibility(size_t d) const { return divisibility[d]; }
+  const DivisibilityT &getDivisibility() const { return divisibility; }
+
+  int getConstancy(size_t d) const { return constancy[d]; }
+  const ConstancyT &getConstancy() const { return constancy; }
+
+  int getRank() const { return rank; }
+
+  // Comparison
+  bool operator==(const AxisInfo &other) const {
+    return (contiguity == other.contiguity) &&
+           (divisibility == other.divisibility) &&
+           (constancy == other.constancy);
+  }
+
+  /// The pessimistic value state, in which nothing is known about the axes.
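+  /// For example (illustration only): a rank-2 value with no alignment
+  /// hints has the pessimistic state contiguity = [1, 1],
+  /// divisibility = [1, 1], constancy = [1, 1], i.e. every axis may be
+  /// non-contiguous, odd, and non-constant.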
+  static AxisInfo getPessimisticValueState(MLIRContext *context) {
+    return AxisInfo();
+  }
+  static AxisInfo getPessimisticValueState(Value value);
+
+  // The gcd of both arguments' fields, for each dimension
+  static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
+
+private:
+  /// The _contiguity_ information maps the `d`-th
+  /// dimension to the length of the shortest
+  /// sequence of contiguous integers along it.
+  /// For example:
+  /// [10, 11, 12, 13, 18, 19, 20, 21]
+  /// [20, 21, 22, 23, 28, 29, 30, 31]
+  /// would have contiguity [1, 4], and
+  /// [12, 16, 20, 24]
+  /// [13, 17, 21, 25]
+  /// [14, 18, 22, 26]
+  /// [15, 19, 23, 27]
+  /// [18, 22, 26, 30]
+  /// [19, 23, 27, 31]
+  /// would have contiguity [2, 1].
+  ContiguityT contiguity;
+
+  /// The _divisibility_ information maps the `d`-th
+  /// dimension to the largest power-of-two that
+  /// divides the first element of all the values along it.
+  /// For example:
+  /// [10, 11, 12, 13, 18, 19, 20, 21]
+  /// [20, 21, 22, 23, 28, 29, 30, 31]
+  /// would have divisibility [1, 2], and
+  /// [12, 16, 20, 24]
+  /// [13, 17, 21, 25]
+  /// [14, 18, 22, 26]
+  /// [15, 19, 23, 27]
+  /// would have divisibility [4, 1].
+  DivisibilityT divisibility;
+
+  /// The _constancy_ information maps the `d`-th
+  /// dimension to the length of the shortest
+  /// sequence of constant integers along it. This is
+  /// particularly useful to infer the contiguity
+  /// of operations (e.g., add) involving a constant.
+  /// For example:
+  /// [8, 8, 8, 8, 12, 12, 12, 12]
+  /// [16, 16, 16, 16, 20, 20, 20, 20]
+  /// would have constancy [1, 4].
+  ConstancyT constancy;
+
+  // number of dimensions of the lattice
+  int rank;
+};
+
+class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
+private:
+  static const int maxPow2Divisor = 65536;
+
+  int highestPowOf2Divisor(int n) {
+    if (n == 0)
+      return maxPow2Divisor;
+    return (n & (~(n - 1)));
+  }
+
+  AxisInfo visitBinaryOp(
+      Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
+      const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
+      const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
+      const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy);
+
+public:
+  using ForwardDataFlowAnalysis<AxisInfo>::ForwardDataFlowAnalysis;
+
+  ChangeResult
+  visitOperation(Operation *op,
+                 ArrayRef<LatticeElement<AxisInfo> *> operands) override;
+};
+
+} // namespace mlir
+
+#endif // TRITON_ANALYSIS_AXISINFO_H
\ No newline at end of file
diff --git a/include/triton/Dialect/Triton/IR/TritonDialect.td b/include/triton/Dialect/Triton/IR/TritonDialect.td
index 83e207fb4..3c5611448 100644
--- a/include/triton/Dialect/Triton/IR/TritonDialect.td
+++ b/include/triton/Dialect/Triton/IR/TritonDialect.td
@@ -33,6 +33,8 @@ def Triton_Dialect : Dialect {
   let extraClassDeclaration = [{
     void registerTypes();
   }];
+
+  let hasConstantMaterializer = 1;
 }
 
 #endif // TRITON_DIALECT
diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index efc90f36d..d8e51ed06 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -143,8 +143,20 @@ def TT_ReshapeOp : TT_Op<"reshape", [NoSideEffect, SameOperandsAndResultElementType]> {
   let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
 }
 
+def TT_SplatOp : TT_Op<"splat", [NoSideEffect, SameOperandsAndResultElementType]> {
+  let summary = "splat";
+
+  let arguments = (ins TT_Type:$src);
+
+  let results = (outs TT_Tensor:$result);
+
+  let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";
+
+  let hasFolder = 1;
+}
+
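+// Example of the new fold path (illustration only, not part of this patch):
+// a splat of a constant
+//   %0 = arith.constant 4 : i32
+//   %1 = tt.splat %0 : (i32) -> tensor<128xi32>
+// canonicalizes to `arith.constant dense<4> : tensor<128xi32>` through
+// SplatOp::fold and the dialect's new constant materializer.
+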
 def TT_BroadcastOp : TT_Op<"broadcast", [NoSideEffect, SameOperandsAndResultElementType]> {
-  let summary = "broadcast";
+  let summary = "broadcast. No left-padding as of now.";
 
   let arguments = (ins TT_Type:$src);
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index bc2ea5ed3..08873acd4 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -59,6 +59,9 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async",
 }
 
 // Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU.
+// This is needed because Arith's Cmp ops don't
+// handle encodings.
+// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
 def TTG_CmpIOp : TTG_Op<"cmpi"> {
   let summary = "integer comparison operation";
diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
new file mode 100644
index 000000000..6222e5261
--- /dev/null
+++ b/lib/Analysis/AxisInfo.cpp
@@ -0,0 +1,239 @@
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "llvm/Support/raw_ostream.h"
+#include <functional>
+
+#include "triton/Analysis/AxisInfo.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// AxisInfo
+//===----------------------------------------------------------------------===//
+
+// Extended Euclidean algorithm: computes gcd(a, b) together with
+// Bezout coefficients x, y such that a*x + b*y == gcd(a, b).
+static int gcd_impl(int a, int b, int *x, int *y) {
+  // Base case
+  if (a == 0) {
+    *x = 0;
+    *y = 1;
+    return b;
+  }
+  int x1, y1; // To store results of the recursive call
+  int gcd = gcd_impl(b % a, a, &x1, &y1);
+  // Update x and y using the results of the recursive call
+  *x = y1 - (b / a) * x1;
+  *y = x1;
+  return gcd;
+}
+
+static int gcd(int a, int b) {
+  int x, y;
+  return gcd_impl(a, b, &x, &y);
+}
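+
+// Sanity check (illustration only): gcd(12, 8) == 4, and one valid pair of
+// Bezout coefficients is x == 1, y == -1, since 12 * 1 + 8 * (-1) == 4.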
+
+AxisInfo AxisInfo::getPessimisticValueState(Value value) {
+  size_t rank = 1;
+  if (TensorType ty = value.getType().dyn_cast<TensorType>())
+    rank = ty.getRank();
+  int divHint = 1;
+  if (BlockArgument blockArg = value.dyn_cast<BlockArgument>()) {
+    Operation *op = blockArg.getOwner()->getParentOp();
+    if (FuncOp fun = dyn_cast<FuncOp>(op)) {
+      Attribute attr =
+          fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
+      if (attr)
+        divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
+    }
+  }
+  ContiguityT contiguity(rank, 1);
+  DivisibilityT divisibility(rank, divHint);
+  ConstancyT constancy(rank, 1);
+  return AxisInfo(contiguity, divisibility, constancy);
+}
+
+// The gcd of both arguments' fields, for each dimension
+AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
+  ContiguityT retContiguity;
+  DivisibilityT retDivisibility;
+  ConstancyT retConstancy;
+  for (int d = 0; d < lhs.getRank(); d++) {
+    retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
+    retDivisibility.push_back(
+        gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
+    retConstancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
+  }
+  return AxisInfo(retContiguity, retDivisibility, retConstancy);
+}
+
+//===----------------------------------------------------------------------===//
+// AxisInfoAnalysis
+//===----------------------------------------------------------------------===//
+
+AxisInfo AxisInfoAnalysis::visitBinaryOp(
+    Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo,
+    const std::function<int(AxisInfo, AxisInfo, int)> &getContiguity,
+    const std::function<int(AxisInfo, AxisInfo, int)> &getDivisibility,
+    const std::function<int(AxisInfo, AxisInfo, int)> &getConstancy) {
+  int rank = lhsInfo.getRank();
+  AxisInfo::ContiguityT newContiguity;
+  AxisInfo::DivisibilityT newDivisibility;
+  AxisInfo::ConstancyT newConstancy;
+  for (int d = 0; d < rank; d++) {
+    newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d));
+    newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d));
+    newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d));
+  }
+  return AxisInfo(newContiguity, newDivisibility, newConstancy);
+}
+
+ChangeResult AxisInfoAnalysis::visitOperation(
+    Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
+  AxisInfo curr;
+  // This preserves the input axes (e.g., cast):
+  if (llvm::isa<arith::ExtSIOp, arith::ExtUIOp, arith::TruncIOp>(op))
+    curr = operands[0]->getValue();
+  // Constant ranges
+  if (triton::MakeRangeOp make_range =
+          llvm::dyn_cast<triton::MakeRangeOp>(op)) {
+    int start = make_range.start();
+    int end = make_range.end();
+    AxisInfo::ContiguityT contiguity = {end - start};
+    AxisInfo::DivisibilityT divisibility = {highestPowOf2Divisor(start)};
+    AxisInfo::ConstancyT constancy = {1};
+    curr = AxisInfo(contiguity, divisibility, constancy);
+  }
+  // Constant
+  if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
+    auto intAttr = constant.getValue().dyn_cast<IntegerAttr>();
+    if (intAttr) {
+      size_t val = intAttr.getValue().getZExtValue();
+      curr = AxisInfo({1}, {highestPowOf2Divisor(val)}, {1});
+    }
+    // TODO: generalize to dense attr
+    auto splatAttr = constant.getValue().dyn_cast<SplatElementsAttr>();
+    if (splatAttr && splatAttr.getElementType().isInteger(32)) {
+      auto value = splatAttr.getSplatValue<int>();
+      TensorType ty = splatAttr.getType().cast<TensorType>();
+      curr = AxisInfo(
+          AxisInfo::ContiguityT(ty.getRank(), 1),
+          AxisInfo::DivisibilityT(ty.getRank(), highestPowOf2Divisor(value)),
+          AxisInfo::ConstancyT(ty.getShape().begin(), ty.getShape().end()));
+    }
+  }
+  // Addition
+  if (llvm::isa<arith::AddIOp, triton::GEPOp>(op)) {
+    auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
+      return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)),
+                      gcd(lhs.getConstancy(d), rhs.getContiguity(d)));
+    };
+    auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d) {
+      return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
+    };
+    auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d) {
+      return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
+    };
+    curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
+                         newContiguity, newDivisibility, newConstancy);
+  }
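+  // Worked instance of the addition rule (illustration only):
+  // [0, 1, 2, 3] (contiguity 4) + [8, 8, 8, 8] (constancy 4) gives
+  // [8, 9, 10, 11], again contiguous of length 4; hence the
+  // max(gcd(contiguity, constancy), gcd(constancy, contiguity)) formula.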
+  // Multiplication
+  if (llvm::isa<arith::MulIOp>(op)) {
+    auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
+    auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
+      return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
+    };
+    auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
+      return lhs.getDivisibility(d) * rhs.getDivisibility(d);
+    };
+    curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
+                         newContiguity, newDivisibility, newConstancy);
+  }
+  // Splat
+  if (llvm::isa<triton::SplatOp>(op)) {
+    Type _retTy = *op->result_type_begin();
+    TensorType retTy = _retTy.cast<TensorType>();
+    AxisInfo opInfo = operands[0]->getValue();
+    AxisInfo::ContiguityT contiguity;
+    AxisInfo::DivisibilityT divisibility;
+    AxisInfo::ConstancyT constancy;
+    for (int d = 0; d < retTy.getRank(); d++) {
+      contiguity.push_back(1);
+      divisibility.push_back(opInfo.getDivisibility(0));
+      constancy.push_back(retTy.getShape()[d]);
+    }
+    curr = AxisInfo(contiguity, divisibility, constancy);
+  }
+  // Reshape
+  // TODO: Replace by `unsqueeze`
+  if (llvm::isa<triton::ReshapeOp>(op)) {
+    Type _retTy = *op->result_type_begin();
+    Type _opTy = *op->operand_type_begin();
+    TensorType retTy = _retTy.cast<TensorType>();
+    TensorType opTy = _opTy.cast<TensorType>();
+    ArrayRef<int64_t> retShape = retTy.getShape();
+    ArrayRef<int64_t> opShape = opTy.getShape();
+    AxisInfo opInfo = operands[0]->getValue();
+    AxisInfo::ContiguityT contiguity;
+    AxisInfo::DivisibilityT divisibility;
+    AxisInfo::ConstancyT constancy;
+    bool is_skewed = false;
+    size_t current = 0;
+    for (int d = 0; d < retTy.getRank(); d++) {
+      if (retShape[d] == 1) {
+        // Newly inserted unit dimension: nothing is known about it.
+        contiguity.push_back(1);
+        divisibility.push_back(1);
+        constancy.push_back(1);
+      } else if (!is_skewed && retShape[d] == opShape[current]) {
+        // Dimension carried over unchanged from the operand.
+        contiguity.push_back(opInfo.getContiguity()[current]);
+        divisibility.push_back(opInfo.getDivisibility()[current]);
+        constancy.push_back(opInfo.getConstancy()[current]);
+        current++;
+      } else {
+        // A genuine reshape: give up on this and all later dimensions.
+        is_skewed = true;
+        contiguity.push_back(1);
+        divisibility.push_back(1);
+        constancy.push_back(1);
+      }
+    }
+    curr = AxisInfo(contiguity, divisibility, constancy);
+  }
+  // Broadcast
+  if (llvm::isa<triton::BroadcastOp>(op)) {
+    Type _retTy = *op->result_type_begin();
+    Type _opTy = *op->operand_type_begin();
+    TensorType retTy = _retTy.cast<TensorType>();
+    TensorType opTy = _opTy.cast<TensorType>();
+    ArrayRef<int64_t> retShape = retTy.getShape();
+    ArrayRef<int64_t> opShape = opTy.getShape();
+    AxisInfo opInfo = operands[0]->getValue();
+    AxisInfo::ContiguityT contiguity;
+    AxisInfo::DivisibilityT divisibility;
+    AxisInfo::ConstancyT constancy;
+    for (int d = 0; d < retTy.getRank(); d++) {
+      contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
+      divisibility.push_back(opInfo.getDivisibility(d));
+      constancy.push_back(opShape[d] == 1 ? retShape[d] : 1);
+    }
+    curr = AxisInfo(contiguity, divisibility, constancy);
+  }
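+  // e.g. (illustration only): broadcasting a tensor<1x128xi32> with
+  // contiguity [1, 128] to tensor<128x128xi32> keeps contiguity [1, 128]
+  // and yields constancy [128, 1], since every new row repeats the source.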
+  if (curr.getRank() == 0) {
+    return markAllPessimisticFixpoint(op->getResults());
+  }
+  // join all lattice elements
+  ChangeResult result = ChangeResult::NoChange;
+  for (Value value : op->getResults()) {
+    result |= getLatticeElement(value).join(curr);
+  }
+  return result;
+}
+
+} // namespace mlir
\ No newline at end of file
diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt
new file mode 100644
index 000000000..896939e53
--- /dev/null
+++ b/lib/Analysis/CMakeLists.txt
@@ -0,0 +1,3 @@
+add_mlir_library(TritonAnalysis
+  AxisInfo.cpp
+)
\ No newline at end of file
diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index 480882592..d19e99c0b 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -1,4 +1,5 @@
 # add_subdirectory(codegen)
 add_subdirectory(driver)
+add_subdirectory(Analysis)
 add_subdirectory(Conversion)
 add_subdirectory(Dialect)
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index 94bc4c696..22663b504 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -240,6 +240,7 @@ void populateTritonPatterns(
 ) {
   MLIRContext *context = patterns.getContext();
   patterns.add<TritonGenericPattern<triton::ReshapeOp>,
+               TritonGenericPattern<triton::SplatOp>,
               TritonGenericPattern<triton::BroadcastOp>,
               TritonGenericPattern<triton::GEPOp>,
               TritonReducePattern,
diff --git a/lib/Dialect/Triton/IR/Dialect.cpp b/lib/Dialect/Triton/IR/Dialect.cpp
index 0b0adef6d..ff7ce0436 100644
--- a/lib/Dialect/Triton/IR/Dialect.cpp
+++ b/lib/Dialect/Triton/IR/Dialect.cpp
@@ -23,3 +23,8 @@ void TritonDialect::initialize() {
   // We can also add interface here.
 }
+
+Operation *TritonDialect::materializeConstant(OpBuilder &builder,
+                                              Attribute value, Type type,
+                                              Location loc) {
+  return builder.create<arith::ConstantOp>(loc, type, value);
+}
\ No newline at end of file
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index e5416bf19..fd911b7a3 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -97,14 +97,22 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::
 
 //-- DotOp --
 
+//-- SplatOp --
+OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
+  auto constOperand = src().getDefiningOp<arith::ConstantOp>();
+  if (!constOperand)
+    return {};
+  auto shapedType = getType().cast<ShapedType>();
+  auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
+  return ret;
+}
+
 //-- BroadcastOp --
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
   auto constOperand = src().getDefiningOp<arith::ConstantOp>();
   if (!constOperand)
     return {};
-
   auto shapedType = getType().cast<ShapedType>();
   return SplatElementsAttr::get(shapedType, {constOperand.getValue()});
 }
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 53bf96d06..ba008cac4 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -788,6 +788,11 @@ void init_triton_ir(py::module &&m) {
       .def("add_entry_block", [](mlir::FuncOp &self) -> mlir::Block * {
         return self.addEntryBlock();
       }, ret::reference)
+      .def("set_arg_attr", [](mlir::FuncOp &self, int arg_no, const std::string &name, int val) {
+        // set the integer attribute `name` of argument `arg_no` to `val`
+        auto attrTy = mlir::IntegerType::get(self.getContext(), 32);
+        self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
+      }, ret::reference)
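+      // Typical use from the Python frontend (illustration only):
+      //   fn.set_arg_attr(idx, "tt.divisibility", 16)
+      // marks argument `idx` as divisible by 16 so that AxisInfoAnalysis
+      // can assume an aligned pointer.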
       .def("reset_type", &mlir::FuncOp::setType)
     ;
 
@@ -1265,9 +1270,10 @@ void init_triton_ir(py::module &&m) {
       .def("create_splat", [](mlir::OpBuilder &self, mlir::Value &arg, std::vector<int64_t> &shape) -> mlir::Value {
         auto loc = self.getUnknownLoc();
         auto argType = arg.getType();
-        return self.create<mlir::triton::SplatOp>(
+        auto ret = self.createOrFold<mlir::triton::SplatOp>(
             loc, mlir::RankedTensorType::get(shape, argType), arg
         );
+        return ret;
       })
       // atomic
       .def("create_atomic_cas", [](mlir::OpBuilder &self, mlir::Value &ptr,
@@ -1337,6 +1343,12 @@ void init_triton_ir(py::module &&m) {
       .def("run", [](mlir::PassManager &self, mlir::ModuleOp &mod) -> bool {
         return mlir::succeeded(self.run(mod.getOperation()));
       })
+      .def("add_sccp_pass", [](mlir::PassManager &self) {
+        self.addPass(mlir::createSCCPPass());
+      })
+      .def("add_symbol_dce_pass", [](mlir::PassManager &self) {
+        self.addPass(mlir::createSymbolDCEPass());
+      })
       .def("add_inliner_pass", [](mlir::PassManager &self) {
         self.addPass(mlir::createInlinerPass());
       })
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 7cac37cee..11ed31381 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -199,14 +199,8 @@ class CodeGenerator(ast.NodeVisitor):
                 arg_values.append(cst)
             else:
                 pass
-                # TODO: ...
-                # if i in self.attributes:
-                #     is_ptr = fn.args[idx].type.is_ptr()
-                #     attr = 'aligned' if is_ptr else 'multiple_of'
-                #     attr = getattr(_triton.ir.attribute_kind, attr)
-                #     attr = _triton.ir.attribute(attr, self.attributes[i])
-                #     fn.add_attr(idx + 1, attr)
-                # fn.args[idx].name = arg_name
+                if i in self.attributes:
+                    fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i])
                 arg_values.append(triton.language.tensor(fn.args(idx), self.prototype.param_types[idx]))
                 idx += 1
@@ -1307,8 +1301,13 @@ class JITFunction:
             raise CompilationError(self.src, node) from e
         # cache num_warps & num_stages
         self.num_warps, self.num_stages = num_warps, num_stages
+        # run the canonicalizer here to clean up the generated IR
+        # (e.g., fold splats of constants)
+        mod = generator.module
+        pm = _triton.ir.pass_manager(context)
+        pm.add_canonicalizer_pass()
+        pm.run(mod)
         # FIXME: now we need to return context, otherwise it will be deleted
-        return generator.module, context
+        return mod, context
 
     def compile_ttir_to_llir(self, mod, ctx):
         num_warps, num_stages = self.num_warps, self.num_stages
diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir
new file mode 100644
index 000000000..3d2b31d58
--- /dev/null
+++ b/test/Analysis/test-alignment.mlir
@@ -0,0 +1,52 @@
+// RUN: triton-opt %s -test-print-alignment -split-input-file 2>&1 | FileCheck %s
+
+func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
+  // CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
+  %cst = arith.constant dense<true> : tensor<128x128xi1>
+  // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
+  %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
+  // CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
+  %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
+  // CHECK-NEXT: Contiguity: [128] ; Divisibility: [65536] ; Constancy: [1]
+  %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
+  // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
+  %2 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32>
+  // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
+  %3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
+  // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1048576, 16] ; Constancy: [1, 1]
+  %4 = arith.muli %2, %3 : tensor<128x1xi32>
+  // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
+  %5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
+  // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1]
+  %6 = tt.getelementptr %5, %4 : tensor<128x1x!tt.ptr<f32>>
+  // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
+  %7 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32>
+  // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
+  %8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
+  // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1]
+  %9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32>
+  // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1]
+  %10 = tt.getelementptr %8, %9 : tensor<128x128x!tt.ptr<f32>>
+  // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
+  %11 = tt.reshape %0 : (tensor<128xi32>) -> tensor<128x1xi32>
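+  // Note (illustration): the getelementptr below behaves like an addition,
+  // so %13 inherits contiguity 128 from %11 and divisibility
+  // gcd(16, 65536) = 16 from its two operands.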
+  // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
+  %12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
+  // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
+  %13 = tt.getelementptr %12, %11 : tensor<128x1x!tt.ptr<f32>>
+  // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
+  %14 = tt.reshape %1 : (tensor<128xi32>) -> tensor<1x128xi32>
+  // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
+  %15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
+  // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [1, 1]
+  %16 = arith.muli %14, %15 : tensor<1x128xi32>
+  // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 128]
+  %17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
+  // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [128, 1]
+  %18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
+  // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
+  %19 = tt.getelementptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
+  // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
+  %20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
+  tt.store %19, %20, %cst : tensor<128x128xf32>
+  return
+}
\ No newline at end of file
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index b1e0ec7df..a37b680a5 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(lib)
+
 llvm_canonicalize_cmake_booleans(
   MLIR_ENABLE_BINDINGS_PYTHON
 )
diff --git a/test/lib/Analysis/CMakeLists.txt b/test/lib/Analysis/CMakeLists.txt
new file mode 100644
index 000000000..2f45f2e53
--- /dev/null
+++ b/test/lib/Analysis/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_library(TritonTestAnalysis
+  TestAxisInfo.cpp
+
+  LINK_LIBS PUBLIC
+  TritonAnalysis
+)
\ No newline at end of file
diff --git a/test/lib/Analysis/TestAxisInfo.cpp b/test/lib/Analysis/TestAxisInfo.cpp
new file mode 100644
index 000000000..fd6493cbf
--- /dev/null
+++ b/test/lib/Analysis/TestAxisInfo.cpp
@@ -0,0 +1,67 @@
+#include "triton/Analysis/AxisInfo.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestAxisInfoPass
+    : public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>> {
+
+  // LLVM15+
+  // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);
+
+  void print(const std::string &name, raw_ostream &os, ArrayRef<int> vals) {
+    os << name << ": [";
+    for (size_t d = 0; d < vals.size(); d++) {
+      if (d != 0)
+        os << ", ";
+      os << vals[d];
+    }
+    os << "]";
+  }
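+
+  // A sample line of output for one SSA value (illustration only):
+  //   Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ( %19 = ... )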
" ; "; + print("Constancy", os, info.getConstancy()); + os << " ( "; + result.print(os); + os << " ) "; + os << "\n"; + } + }); + } +}; + +} + +namespace mlir{ +namespace test{ +void registerTestAlignmentPass() { PassRegistration(); } +} +} + diff --git a/test/lib/CMakeLists.txt b/test/lib/CMakeLists.txt new file mode 100644 index 000000000..5c6d3ffe1 --- /dev/null +++ b/test/lib/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Analysis) \ No newline at end of file