diff --git a/include/triton/codegen/analysis/grid.h b/include/triton/codegen/analysis/grid.h index c361db260..1eb352b00 100644 --- a/include/triton/codegen/analysis/grid.h +++ b/include/triton/codegen/analysis/grid.h @@ -19,7 +19,7 @@ namespace ir{ namespace codegen{ namespace transform{ -class reorder; +class coalesce; } namespace analysis{ @@ -46,7 +46,7 @@ private: public: - grids(size_t num_warps, transform::reorder* reorder); + grids(size_t num_warps, transform::coalesce* reorder); ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; } unsigned get_param_group(ir::value *value, unsigned ax); fragment_t get_fragment(ir::value *value, unsigned ax) { return fragments_.at({value, ax}); } @@ -66,7 +66,7 @@ private: std::vector grids_; std::map> groups_; size_t num_warps_; - transform::reorder* reorder_; + transform::coalesce* reorder_; }; diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index 842f544aa..3efe0a256 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -50,7 +50,7 @@ class meminfo; } namespace transform{ -class reorder; +class coalesce; } class target; @@ -195,7 +195,7 @@ private: public: - selection(analysis::memalloc *alloc, analysis::grids *params, analysis::meminfo *buffer_info, analysis::align *alignment, transform::reorder* reorder, target *tgt) + selection(analysis::memalloc *alloc, analysis::grids *params, analysis::meminfo *buffer_info, analysis::align *alignment, transform::coalesce* reorder, target *tgt) : alloc_(alloc), params_(params), buffer_info_(buffer_info), alignment_(alignment), reorder_(reorder), tgt_(tgt){ } void run(ir::module &src, Module &dst); @@ -207,7 +207,7 @@ private: analysis::grids *params_; analysis::meminfo *buffer_info_; analysis::align *alignment_; - transform::reorder *reorder_; + transform::coalesce *reorder_; target *tgt_; std::map axes_; Value *sh_mem_ptr_; diff --git a/include/triton/codegen/transform/reorder.h b/include/triton/codegen/transform/coalesce.h similarity index 87% rename from include/triton/codegen/transform/reorder.h rename to include/triton/codegen/transform/coalesce.h index 19bffab03..e78010703 100644 --- a/include/triton/codegen/transform/reorder.h +++ b/include/triton/codegen/transform/coalesce.h @@ -20,9 +20,9 @@ namespace analysis{ namespace transform{ -class reorder { +class coalesce { public: - reorder(analysis::align* algin, analysis::meminfo* mem); + coalesce(analysis::align* algin, analysis::meminfo* mem); std::vector get_order(ir::value* v); void run(ir::module &mod); diff --git a/lib/codegen/analysis/grid.cc b/lib/codegen/analysis/grid.cc index a33c4c25d..43a3eb1d9 100644 --- a/lib/codegen/analysis/grid.cc +++ b/lib/codegen/analysis/grid.cc @@ -1,6 +1,6 @@ #include #include -#include "triton/codegen/transform/reorder.h" +#include "triton/codegen/transform/coalesce.h" #include "triton/codegen/analysis/grid.h" #include "triton/ir/instructions.h" #include "triton/ir/type.h" @@ -16,7 +16,7 @@ namespace triton{ namespace codegen{ namespace analysis{ -grids::grids(size_t num_warps, transform::reorder *reorder): num_warps_(num_warps), reorder_(reorder) +grids::grids(size_t num_warps, transform::coalesce *reorder): num_warps_(num_warps), reorder_(reorder) { } bool is_hmma(ir::value *v){ @@ -298,7 +298,7 @@ void grids::run(ir::module &mod) { unsigned current = num_threads; std::string nts = "nts.d" + s_ld; std::string mts = "mts.d" + s_ld; - params_.at(i).at(nts)->set_value(clamp(size / num_threads, 1, 1)); + params_.at(i).at(nts)->set_value(clamp(size / num_threads, 1, 8)); params_.at(i).at(mts)->set_value(clamp(current, 1, shapes[ld] / params_.at(i).at(nts)->get_value())); current = current / params_.at(i).at(mts)->get_value(); for(size_t d = 1; d < shapes.size(); d++){ diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index 271de7640..f452cc384 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -3,7 +3,7 @@ #include "triton/codegen/analysis/grid.h" #include "triton/codegen/analysis/memalloc.h" #include "triton/codegen/analysis/align.h" -#include "triton/codegen/transform/reorder.h" +#include "triton/codegen/transform/coalesce.h" #include "triton/ir/context.h" #include "triton/ir/module.h" #include "triton/ir/function.h" diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc new file mode 100644 index 000000000..29a87129c --- /dev/null +++ b/lib/codegen/transform/coalesce.cc @@ -0,0 +1,110 @@ +#include +#include +#include +#include "triton/ir/function.h" +#include "triton/ir/cfg.h" +#include "triton/ir/basic_block.h" +#include "triton/ir/instructions.h" +#include "triton/ir/module.h" +#include "triton/codegen/analysis/meminfo.h" +#include "triton/codegen/analysis/align.h" +#include "triton/codegen/transform/coalesce.h" + +namespace triton { +namespace codegen{ +namespace transform{ + +coalesce::coalesce(analysis::align* align, analysis::meminfo *mem) + : align_(align), mem_(mem) { } + +std::vector coalesce::get_order(ir::value* v) { + return order_.at(v); +} + +void coalesce::run(ir::module &mod) { + + std::set io; + + std::function set_order = [&](ir::value *v) -> void { + if(order_.find(v) != order_.end()) + return; + order_[v] = {}; + if(ir::user* u = dynamic_cast(v)) + for(ir::value* op: u->ops()) + set_order(op); + ir::type* ty = v->get_type(); + if(!ty->is_tile_ty()) + return; + std::vector order(ty->get_tile_shapes().size()); + std::iota(order.begin(), order.end(), 0); + order_[v] = order; + }; + + // initialize work-list + for(ir::function *fn: mod.get_function_list()) + for(ir::basic_block *block: ir::cfg::reverse_post_order(fn)) + for(ir::instruction *i: block->get_inst_list()){ + if(auto *x = dynamic_cast(i)) { + ir::type* ptr_ty = x->get_pointer_operand()->get_type(); + if(ptr_ty->is_tile_ty()) + io.insert(x); + } + set_order(i); + } + +// ir::builder &builder = mod.get_builder(); +// std::set seen; +// for(ir::io_inst *i: io) { +// ir::value *ptr = i->get_pointer_operand(); +// auto max_contiguous = align_->get_max_contiguous_vec(ptr); +// std::vector order(max_contiguous.size()); +// std::iota(order.begin(), order.end(), 0); +// std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } ); +// std::list work_list; +// if(order != order_[i]) +// work_list.push_back(i); +// // rematerialize recursively +// while(!work_list.empty()) { +// ir::instruction* current = work_list.back(); +// order_[current] = order; +// work_list.pop_back(); +// for(ir::value *op: current->ops()) { +// ir::instruction* i_op = dynamic_cast(op); +// if(!seen.insert(op).second) +// continue; +// if(!i_op) +// continue; +// ir::type *ty = i_op->get_type(); +// if(!ty->is_tile_ty()) +// continue; +// auto& inst_list = i_op->get_parent()->get_inst_list(); +// auto it = std::find(inst_list.begin(), inst_list.end(), i_op); +// it++; +// builder.set_insert_point(it); +// // found a load; write to shared memory and stop recursion +// ir::instruction *n_op = nullptr; +// if(mem_->is_shared(i_op)){ +// continue; +// } +// if(auto* ld = dynamic_cast(i_op)) { +// n_op = ir::copy_to_shared_inst::create(ld); +// } +// // not a load; rematerialize and recurse +// else { +// n_op = i_op->clone(); +// work_list.push_back(n_op); +// } +// n_op = builder.insert(n_op); +// order_[n_op] = order; +// align_->copy(n_op, i_op); +// current->replace_uses_of_with(i_op, n_op); +// } +// } + +// } +} + + +} +} +} diff --git a/lib/codegen/transform/reorder.cc b/lib/codegen/transform/reorder.cc deleted file mode 100644 index 875faaab1..000000000 --- a/lib/codegen/transform/reorder.cc +++ /dev/null @@ -1,106 +0,0 @@ -#include -#include -#include -#include "triton/ir/function.h" -#include "triton/ir/cfg.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/instructions.h" -#include "triton/ir/module.h" -#include "triton/codegen/analysis/meminfo.h" -#include "triton/codegen/analysis/align.h" -#include "triton/codegen/transform/reorder.h" - -namespace triton { -namespace codegen{ -namespace transform{ - -reorder::reorder(analysis::align* align, analysis::meminfo *mem) - : align_(align), mem_(mem) { } - -std::vector reorder::get_order(ir::value* v) { - return order_.at(v); -} - -void reorder::run(ir::module &mod) { - - std::set io; - - std::function set_order = [&](ir::value *v) -> void { - if(order_.find(v) != order_.end()) - return; - if(ir::user* u = dynamic_cast(v)) - for(ir::value* op: u->ops()) - set_order(op); - ir::type* ty = v->get_type(); - if(!ty->is_tile_ty()) - return; - std::vector order(ty->get_tile_shapes().size()); - std::iota(order.begin(), order.end(), 0); - order_[v] = order; - }; - - // initialize work-list - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: ir::cfg::reverse_post_order(fn)) - for(ir::instruction *i: block->get_inst_list()){ - if(auto *x = dynamic_cast(i)) { - ir::type* ptr_ty = x->get_pointer_operand()->get_type(); - if(ptr_ty->is_tile_ty()) - io.insert(x); - } - set_order(i); - } - - ir::builder &builder = mod.get_builder(); - for(ir::io_inst *i: io) { - ir::value *ptr = i->get_pointer_operand(); - auto max_contiguous = align_->get_max_contiguous_vec(ptr); - std::vector order(max_contiguous.size()); - std::iota(order.begin(), order.end(), 0); - std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } ); - std::list work_list; - if(order != order_[i]) - work_list.push_back(i); - // rematerialize recursively - while(!work_list.empty()) { - ir::instruction* current = work_list.back(); - order_[current] = order; - work_list.pop_back(); - for(ir::value *op: current->ops()) { - ir::instruction* i_op = dynamic_cast(op); - if(!i_op) - continue; - ir::type *ty = i_op->get_type(); - if(!ty->is_tile_ty()) - continue; - auto& inst_list = i_op->get_parent()->get_inst_list(); - auto it = std::find(inst_list.begin(), inst_list.end(), i_op); - it++; - builder.set_insert_point(it); - // found a load; write to shared memory and stop recursion - ir::instruction *n_op = nullptr; - if(mem_->is_shared(i_op)){ - continue; - } - if(auto* ld = dynamic_cast(i_op)) { - n_op = ir::copy_to_shared_inst::create(ld); - } - // not a load; rematerialize and recurse - else { - n_op = i_op->clone(); - work_list.push_back(n_op); - } - n_op = builder.insert(n_op); - order_[n_op] = order; - align_->copy(n_op, i_op); - current->replace_uses_of_with(i_op, n_op); - } - } - - } -} - - -} -} -} diff --git a/lib/driver/module.cc b/lib/driver/module.cc index 85877f911..0bf85c84f 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -241,7 +241,6 @@ std::string cu_module::compile_llvm_module(std::unique_ptr module, cu_module::cu_module(driver::context * context, std::unique_ptr ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { } cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ - std::cout << source << std::endl; cu_context::context_switcher ctx_switch(*context); // JIT compile source-code CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 05c39a451..ead9a9ab4 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -5,7 +5,7 @@ #include #include "triton/codegen/selection.h" #include "triton/runtime/function.h" -#include "triton/codegen/transform/reorder.h" +#include "triton/codegen/transform/coalesce.h" #include "triton/lang/cpp.h" #include "triton/lang/parser.h" #include "triton/lang/code_gen.h" @@ -197,7 +197,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c codegen::analysis::meminfo shmem_info; codegen::analysis::liveness shmem_liveness(&shmem_info); codegen::analysis::align alignment_info; - codegen::transform::reorder reorder(&alignment_info, &shmem_info); + codegen::transform::coalesce reorder(&alignment_info, &shmem_info); codegen::analysis::grids grids(opt.num_warps, &reorder); codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids); codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info); @@ -215,7 +215,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c // ir::print(module, std::cout); reorder.run(module); dce.run(module); - ir::print(module, std::cout); +// ir::print(module, std::cout); grids.run(module); reassociate.run(module); dce.run(module); @@ -231,7 +231,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c dce.run(module); vectorize.run(module); dce.run(module); - ir::print(module, std::cout); +// ir::print(module, std::cout); // generate llvm code llvm::LLVMContext ctx; std::unique_ptr llvm(new llvm::Module(module.get_name(), ctx));