From 0cbbcce5c0e8cd356704da71cc7c876e15813e36 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 8 Sep 2019 21:38:08 -0400 Subject: [PATCH] added missing file --- include/triton/codegen/transform/reorder.h | 39 +++++++++ lib/codegen/transform/reorder.cc | 96 ++++++++++++++++++++++ 2 files changed, 135 insertions(+) create mode 100644 include/triton/codegen/transform/reorder.h create mode 100644 lib/codegen/transform/reorder.cc diff --git a/include/triton/codegen/transform/reorder.h b/include/triton/codegen/transform/reorder.h new file mode 100644 index 000000000..19bffab03 --- /dev/null +++ b/include/triton/codegen/transform/reorder.h @@ -0,0 +1,39 @@ +#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H +#define TDL_INCLUDE_CODEGEN_OPTIMIZE_REORDER_H + +#include +#include + +namespace triton { + +namespace ir { + class module; + class value; +} + +namespace codegen{ + +namespace analysis{ + class align; + class meminfo; +} + +namespace transform{ + +class reorder { +public: + reorder(analysis::align* algin, analysis::meminfo* mem); + std::vector get_order(ir::value* v); + void run(ir::module &mod); + +private: + analysis::align* align_; + analysis::meminfo* mem_; + std::map> order_; +}; + +} +} +} + +#endif diff --git a/lib/codegen/transform/reorder.cc b/lib/codegen/transform/reorder.cc new file mode 100644 index 000000000..c5bc31d59 --- /dev/null +++ b/lib/codegen/transform/reorder.cc @@ -0,0 +1,96 @@ +#include +#include +#include +#include "triton/ir/function.h" +#include "triton/ir/cfg.h" +#include "triton/ir/basic_block.h" +#include "triton/ir/instructions.h" +#include "triton/ir/module.h" +#include "triton/codegen/analysis/meminfo.h" +#include "triton/codegen/analysis/align.h" +#include "triton/codegen/transform/reorder.h" + +namespace triton { +namespace codegen{ +namespace transform{ + +reorder::reorder(analysis::align* align, analysis::meminfo *mem) + : align_(align), mem_(mem) { } + +std::vector reorder::get_order(ir::value* v) { + std::cout << v->get_name() << std::endl; + return order_.at(v); +} + +void reorder::run(ir::module &mod) { + + std::set io; + + // initialize work-list + for(ir::function *fn: mod.get_function_list()) + for(ir::basic_block *block: ir::cfg::reverse_post_order(fn)) + for(ir::instruction *i: block->get_inst_list()){ + if(auto *x = dynamic_cast(i)) { + ir::type* ptr_ty = x->get_pointer_operand()->get_type(); + if(ptr_ty->is_tile_ty()) + io.insert(x); + std::vector order(ptr_ty->get_tile_shapes().size()); + std::iota(order.begin(), order.end(), 0); + order_[i] = order; + } + } + + ir::builder &builder = mod.get_builder(); + for(ir::io_inst *i: io) { + ir::value *ptr = i->get_pointer_operand(); + auto max_contiguous = align_->get_max_contiguous_vec(ptr); + std::vector order(max_contiguous.size()); + std::iota(order.begin(), order.end(), 0); + std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } ); + std::list work_list; + if(order != order_[i]){ + work_list.push_back(i); + } + // rematerialize recursively + while(!work_list.empty()) { + ir::instruction* current = work_list.back(); + order_[current] = order; + work_list.pop_back(); + for(ir::value *op: current->ops()) { + ir::instruction* i_op = dynamic_cast(op); + if(!i_op) + continue; + ir::type *ty = i_op->get_type(); + if(!ty->is_tile_ty()) + continue; + auto& inst_list = i_op->get_parent()->get_inst_list(); + auto it = std::find(inst_list.begin(), inst_list.end(), i_op); + it++; + builder.set_insert_point(it); + // found a load; write to shared memory and stop recursion + ir::instruction *n_op = nullptr; + if(mem_->is_shared(i_op)){ + continue; + } + if(auto* ld = dynamic_cast(i_op)) { + n_op = ir::copy_to_shared_inst::create(ld); + } + // not a load; rematerialize and recurse + else { + n_op = i_op->clone(); + work_list.push_back(n_op); + } + n_op = builder.insert(n_op); + order_[n_op] = order; + align_->copy(n_op, i_op); + current->replace_uses_of_with(i_op, n_op); + } + } + + } +} + + +} +} +}