From 274d61348837fec1e12bdea2ff59bfd5f21d0583 Mon Sep 17 00:00:00 2001 From: daadaada Date: Thu, 2 Sep 2021 00:55:12 +0800 Subject: [PATCH] [IR] Better printer (#256) --- include/triton/ir/basic_block.h | 3 + include/triton/ir/function.h | 8 +- include/triton/ir/instructions.h | 2 + include/triton/ir/module.h | 2 + include/triton/ir/value.h | 2 + lib/codegen/pass.cc | 1 + lib/ir/print.cc | 320 +++++++++++++++++++++++++++++++ 7 files changed, 337 insertions(+), 1 deletion(-) diff --git a/include/triton/ir/basic_block.h b/include/triton/ir/basic_block.h index 3d274815a..840145246 100644 --- a/include/triton/ir/basic_block.h +++ b/include/triton/ir/basic_block.h @@ -39,6 +39,7 @@ public: // get instruction list inst_list_t &get_inst_list() { return inst_list_; } + const inst_list_t &get_inst_list() const { return inst_list_; } void erase(instruction *i) { inst_list_.remove(i); } // instruction iterator functions @@ -67,6 +68,8 @@ public: // factory functions static basic_block* create(context &ctx, const std::string &name, function *parent); + void print(std::ostream &os); + // visitor void accept(visitor *v) { v->visit_basic_block(this); } diff --git a/include/triton/ir/function.h b/include/triton/ir/function.h index 2a944fbb5..9e1bc981a 100644 --- a/include/triton/ir/function.h +++ b/include/triton/ir/function.h @@ -104,19 +104,25 @@ public: // accessors const args_t &args() const { return args_; } function_type* get_fn_type() { return fn_ty_; } + const function_type* get_fn_type() const { return fn_ty_; } + module *get_parent() { return parent_; } + const module *get_parent() const { return parent_; } // factory methods static function *create(function_type *ty, linkage_types_t linkage, const std::string &name, module *mod); // blocks const blocks_t &blocks() { return blocks_; } + const blocks_t &blocks() const { return blocks_; } void insert_block(basic_block* block, basic_block *next = nullptr); // attributes void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); } const attr_map_t &attrs() { return attrs_; } bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); } - std::set get_attributes(argument* arg) { return attrs_[arg->get_arg_no() + 1]; } + std::set get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; } + + void print(std::ostream &os); // visitor void accept(visitor *v) { v->visit_function(this); } diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index c9db25477..3a5011276 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -73,6 +73,8 @@ public: // instruction id value_id_t get_id() const { return id_; } + void print(std::ostream &os); + private: basic_block *parent_; std::map metadatas_; diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index fb4e6455b..b350e3cc9 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -87,6 +87,8 @@ public: // Metadata void add_metadata(const std::string &name, md_pair_t x) { metadatas_[name] = x; } + void print(std::ostream &os); + private: std::string name_; builder& builder_; diff --git a/include/triton/ir/value.h b/include/triton/ir/value.h index e1599d6bc..7a132d5e2 100644 --- a/include/triton/ir/value.h +++ b/include/triton/ir/value.h @@ -35,6 +35,7 @@ public: // name void set_name(const std::string &name); const std::string &get_name() const { return name_; } + bool has_name() const { return !name_.empty(); } type* get_type() const { return ty_; } // visitor virtual void accept(visitor *v) = 0; @@ -70,6 +71,7 @@ public: // Operands const ops_t& ops() { return ops_; } + const ops_t& ops() const { return ops_; } op_iterator op_begin() { return ops_.begin(); } op_iterator op_end() { return ops_.end(); } void set_operand(unsigned i, value *x); diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 82fe61257..9b96d77e1 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -93,6 +93,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, allocation.run(ir); prefetch_s.run(ir); barriers.run(ir); + // ir.print(std::cout); isel.visit(ir, *llvm); mod = driver::module::create(dev, std::move(llvm)); ker = driver::kernel::create(&*mod, name.c_str()); diff --git a/lib/ir/print.cc b/lib/ir/print.cc index 1552193fa..f63c81587 100644 --- a/lib/ir/print.cc +++ b/lib/ir/print.cc @@ -2,14 +2,334 @@ #include "triton/ir/basic_block.h" #include "triton/ir/module.h" #include "triton/ir/type.h" +#include "triton/ir/value.h" #include "triton/ir/constant.h" #include "triton/ir/function.h" #include "triton/ir/instructions.h" #include "triton/ir/print.h" +#include +#include + namespace triton{ namespace ir{ +namespace { +class SlotTracker { + // A mapping of values to slot numbers. + using value_map = std::map; + + // The module for which we are holding slot numbers. + const module *mod_; + bool module_processed = false; + + // The function for which we are holding slot numbers. + const function *func_ = nullptr; + bool function_processed = false; + + // m_map - The slot map for the module level data. + value_map m_map; + unsigned m_next = 0; + + // f_map - The slot map for the function level data. + value_map f_map; + unsigned f_next = 0; + +public: + // Construct from a module + explicit SlotTracker(const module *mod) : mod_(mod) {} + + // Construct from a function + explicit SlotTracker(const function *f) + : mod_(f? f->get_parent() : nullptr), func_(f) {} + + // Return the slot number of the specified value. If something is not in + // the SlotTracker, return -1 + int get_local_slot(const value *v); + + void initialize_if_needed(); + + // If you'd like to deal with a function instead of just a module, use + // this method to get its data into the SlotTracker + void incorporate_function(const function *f) { + func_ = f; + function_processed = false; + } + +private: + // Add all of the module level global variables (and their initializers) + // and function declarations, but not contents of those functions. + void process_module(); + + // Add all of the functions arguments, basic blocks, and instructions. + void process_function(); + + // Insert specified value* into the slot table + void create_function_slot(const value *v); +}; + +class AssemblyWriter { + std::ostream &os; + SlotTracker &slot_tracker; + +public: + AssemblyWriter(std::ostream &os, SlotTracker &slot_tracker) + : os(os), slot_tracker(slot_tracker) {} + + void print_module(const module *mod); + void print_function(const function *f); + void print_argument(const argument *arg); + void print_basic_block(const basic_block *bb); + void print_instruction(const instruction *instr); + void print_value(const value *v); + + void write_operand(const value *op, bool print_type = false); +}; +} // anonymous namespace + +//------------------------- +// SlotTracker +//------------------------- +void SlotTracker::process_module() { + // Nothing to do at the moment. + // Create slots for global variable & unamed functions & ... + module_processed = true; +} + +void SlotTracker::process_function() { + f_next = 0; + + // Add all the function arguments with no names. + for (const argument *arg : func_->args()) + if (!arg->has_name()) + create_function_slot(arg); + + // Add all of the basic blocks and instructions with no names. + for (const basic_block *bb : func_->blocks()) { + if (!bb->has_name()) + create_function_slot(bb); + + for (const instruction *instr : bb->get_inst_list()) { + if (!instr->get_type()->is_void_ty() && !instr->has_name()) + create_function_slot(instr); + } + } + + function_processed = true; +} + +void SlotTracker::create_function_slot(const value *v) { + assert(!v->get_type()->is_void_ty() && !v->has_name() && "Doesn't need a slot"); + + unsigned dst_slot = f_next++; + f_map[v] = dst_slot; +} + +int SlotTracker::get_local_slot(const value *v) { + assert(dynamic_cast(v) == nullptr && "Can't get a constant slot"); + + // Check for uninitialized state and do lazy initialization. + initialize_if_needed(); + + value_map::iterator f_iter = f_map.find(v); + return f_iter == f_map.end() ? -1 : (int)f_iter->second; +} + +void SlotTracker::initialize_if_needed() { + if (mod_ && !module_processed) + process_module(); + + if (func_ && !function_processed) + process_function(); +} + + +//------------------------------- +// AssemblyWriter +//------------------------------- +void AssemblyWriter::write_operand(const value *operand, bool print_type) { + if (!operand) { + os << ""; + return; + } + + if (auto *c = dynamic_cast(operand)) { + os << c->repr(); + return; + } + + if (operand->has_name()) { + os << operand->get_name(); + return; + } + + // Print the normal way + int slot_num = slot_tracker.get_local_slot(operand); + + if (slot_num != -1) + os << "%" << slot_num; + else + os << ""; +} + +void AssemblyWriter::print_module(const module *mod) { + slot_tracker.initialize_if_needed(); + // ;ModuleID = ... + // source_filename = ... + + // Print all of the functions. + for (function *f : mod->get_function_list()) { + os << "\n"; + print_function(f); + } +} + +void AssemblyWriter::print_function(const function *f) { + // Annotation & Attributes + + slot_tracker.incorporate_function(f); + + os << "def "; + ir::type *rt_type = f->get_fn_type()->get_return_ty(); + // Functions must have names. + os << rt_type->repr() << " " << f->get_name() << "("; + // Print arguments + for (ir::argument *arg : f->args()) { + if (arg->get_arg_no() > 0) + os << ", "; + print_argument(arg); + } + os << ")"; + + // Print function body + os << "{"; + for (const basic_block *bb : f->blocks()) + print_basic_block(bb); + os << "}\n"; +} + +void AssemblyWriter::print_argument(const argument *arg) { + // Print type + os << arg->get_type()->repr(); + + // Print name, if available. + if (arg->has_name()) + os << " " << arg->get_name(); + else { + int slot_num = slot_tracker.get_local_slot(arg); + assert(slot_num != -1 && "expect argument in function here"); + os << " %" << slot_num; + } + + // Print attributes + std::set attrs = arg->get_parent()->get_attributes(arg); + for (attribute attr : attrs) + os << " " << attr.repr(); +} + +void AssemblyWriter::print_basic_block(const basic_block *bb) { + // bb label + if (bb->has_name()) { + os << "\n"; + os << bb->get_name() << ":"; + } else { + os << "\n"; + int slot_num = slot_tracker.get_local_slot(bb); + if (slot_num != -1) + os << slot_num << ":"; + else + os << ":"; + } + + // Print predecessors for the block + auto const &predecessors = bb->get_predecessors(); + if (!predecessors.empty()) { + os << std::setw(50) << std::setfill(' ') + << "; preds = "; + for (size_t i=0; iget_inst_list()) + print_instruction(instr); +} + +void AssemblyWriter::print_instruction(const instruction *instr) { + // Print out indentation for an instruction. + os << " "; + + ir::type *type = instr->get_type(); + if (instr->has_name()) { + os << instr->get_name(); + os << " = "; + } else if (!type->is_void_ty()) { + // Print out the def slot taken. + int slot_num = slot_tracker.get_local_slot(instr); + if (slot_num == -1) + os << " = "; + else + os << "%" << slot_num << " = "; + } + + // Print out opcode + os << instr->repr() << " " << type->repr(); + + size_t num_ops = instr->get_num_operands(); + if (num_ops > 0) + os << " "; + ir::instruction::ops_t ops = instr->ops(); + for (unsigned i = 0; i < num_ops; ++i) { + if (i) + os << ", "; + write_operand(ops[i]); + } + + os << ";\n"; +} + +void AssemblyWriter::print_value(const value *v) { + // Not implemented +} + + +//------------------------------- +// External interface +//------------------------------- +void module::print(std::ostream &os) { + SlotTracker slot_tracker(this); + AssemblyWriter writer(os, slot_tracker); + writer.print_module(this); +} + +void function::print(std::ostream &os) { + SlotTracker slot_tracker(this); + AssemblyWriter writer(os, slot_tracker); + writer.print_function(this); +} + +void basic_block::print(std::ostream &os) { + SlotTracker slot_tracker(this->get_parent()); + AssemblyWriter writer(os, slot_tracker); + writer.print_basic_block(this); +} + +void instruction::print(std::ostream &os) { + SlotTracker slot_tracker(this->get_parent()->get_parent()); + AssemblyWriter writer(os, slot_tracker); + writer.print_instruction(this); +} + +//------------------------------- +// legacy print interface +//------------------------------- std::string get_name(ir::value *v, unsigned i) { if(v->get_name().empty()){ std::string name = "%" + std::to_string(i);