[codegen][coalesce] some bugfix for phi-nodes

This commit is contained in:
Philippe Tillet
2019-09-12 22:44:07 -04:00
parent 0c41bade07
commit 11ff27d638
9 changed files with 126 additions and 123 deletions

View File

@@ -19,7 +19,7 @@ namespace ir{
namespace codegen{
namespace transform{
class reorder;
class coalesce;
}
namespace analysis{
@@ -46,7 +46,7 @@ private:
public:
grids(size_t num_warps, transform::reorder* reorder);
grids(size_t num_warps, transform::coalesce* reorder);
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
unsigned get_param_group(ir::value *value, unsigned ax);
fragment_t get_fragment(ir::value *value, unsigned ax) { return fragments_.at({value, ax}); }
@@ -66,7 +66,7 @@ private:
std::vector<ir::value*> grids_;
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
size_t num_warps_;
transform::reorder* reorder_;
transform::coalesce* reorder_;
};

View File

@@ -50,7 +50,7 @@ class meminfo;
}
namespace transform{
class reorder;
class coalesce;
}
class target;
@@ -195,7 +195,7 @@ private:
public:
selection(analysis::memalloc *alloc, analysis::grids *params, analysis::meminfo *buffer_info, analysis::align *alignment, transform::reorder* reorder, target *tgt)
selection(analysis::memalloc *alloc, analysis::grids *params, analysis::meminfo *buffer_info, analysis::align *alignment, transform::coalesce* reorder, target *tgt)
: alloc_(alloc), params_(params), buffer_info_(buffer_info), alignment_(alignment), reorder_(reorder), tgt_(tgt){ }
void run(ir::module &src, Module &dst);
@@ -207,7 +207,7 @@ private:
analysis::grids *params_;
analysis::meminfo *buffer_info_;
analysis::align *alignment_;
transform::reorder *reorder_;
transform::coalesce *reorder_;
target *tgt_;
std::map<unsigned, distributed_axis> axes_;
Value *sh_mem_ptr_;

View File

@@ -20,9 +20,9 @@ namespace analysis{
namespace transform{
class reorder {
class coalesce {
public:
reorder(analysis::align* algin, analysis::meminfo* mem);
coalesce(analysis::align* algin, analysis::meminfo* mem);
std::vector<unsigned> get_order(ir::value* v);
void run(ir::module &mod);

View File

@@ -1,6 +1,6 @@
#include <algorithm>
#include <cstdlib>
#include "triton/codegen/transform/reorder.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/codegen/analysis/grid.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
@@ -16,7 +16,7 @@ namespace triton{
namespace codegen{
namespace analysis{
grids::grids(size_t num_warps, transform::reorder *reorder): num_warps_(num_warps), reorder_(reorder)
grids::grids(size_t num_warps, transform::coalesce *reorder): num_warps_(num_warps), reorder_(reorder)
{ }
bool is_hmma(ir::value *v){
@@ -298,7 +298,7 @@ void grids::run(ir::module &mod) {
unsigned current = num_threads;
std::string nts = "nts.d" + s_ld;
std::string mts = "mts.d" + s_ld;
params_.at(i).at(nts)->set_value(clamp(size / num_threads, 1, 1));
params_.at(i).at(nts)->set_value(clamp(size / num_threads, 1, 8));
params_.at(i).at(mts)->set_value(clamp(current, 1, shapes[ld] / params_.at(i).at(nts)->get_value()));
current = current / params_.at(i).at(mts)->get_value();
for(size_t d = 1; d < shapes.size(); d++){

View File

@@ -3,7 +3,7 @@
#include "triton/codegen/analysis/grid.h"
#include "triton/codegen/analysis/memalloc.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/transform/reorder.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/ir/context.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"

View File

@@ -0,0 +1,110 @@
#include <iostream>
#include <algorithm>
#include <numeric>
#include "triton/ir/function.h"
#include "triton/ir/cfg.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/module.h"
#include "triton/codegen/analysis/meminfo.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/transform/coalesce.h"
namespace triton {
namespace codegen{
namespace transform{
coalesce::coalesce(analysis::align* align, analysis::meminfo *mem)
: align_(align), mem_(mem) { }
// Returns a copy of the order recorded for `v`.
// The lookup goes through `at`, so it is a checked access into `order_`.
std::vector<unsigned> coalesce::get_order(ir::value* v) {
  const std::vector<unsigned> &recorded = order_.at(v);
  return recorded;
}
// Walks every instruction of every function in `mod` (blocks visited in
// reverse post-order) and records an order for each value reached:
//   - tile-typed values get the identity order 0, 1, ..., rank-1;
//   - every other value gets an empty order.
// I/O instructions whose pointer operand is tile-typed are gathered in `io`,
// which is only consumed by the disabled code at the end of this function.
void coalesce::run(ir::module &mod) {
  std::set<ir::io_inst*> io;

  std::function<void(ir::value*)> assign_order = [&](ir::value *val) -> void {
    // already handled (or currently being handled further up the recursion)
    if(order_.count(val) > 0)
      return;
    // record the value *before* visiting its operands so that a value
    // reachable from itself is not visited a second time
    order_[val] = {};
    ir::user *as_user = dynamic_cast<ir::user*>(val);
    if(as_user != nullptr) {
      for(ir::value *operand: as_user->ops())
        assign_order(operand);
    }
    ir::type *val_ty = val->get_type();
    if(val_ty->is_tile_ty()) {
      // identity order over the tile's dimensions
      std::vector<unsigned> identity(val_ty->get_tile_shapes().size());
      std::iota(identity.begin(), identity.end(), 0);
      order_[val] = identity;
    }
  };

  // initialize work-list
  for(ir::function *fn: mod.get_function_list()) {
    for(ir::basic_block *block: ir::cfg::reverse_post_order(fn)) {
      for(ir::instruction *inst: block->get_inst_list()) {
        auto *as_io = dynamic_cast<ir::io_inst*>(inst);
        if(as_io && as_io->get_pointer_operand()->get_type()->is_tile_ty())
          io.insert(as_io);
        assign_order(inst);
      }
    }
  }

  // The rematerialization step below is currently disabled (commented out).
//  ir::builder &builder = mod.get_builder();
//  std::set<ir::value*> seen;
//  for(ir::io_inst *i: io) {
//    ir::value *ptr = i->get_pointer_operand();
//    auto max_contiguous = align_->get_max_contiguous_vec(ptr);
//    std::vector<unsigned> order(max_contiguous.size());
//    std::iota(order.begin(), order.end(), 0);
//    std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } );
//    std::list<ir::instruction*> work_list;
//    if(order != order_[i])
//      work_list.push_back(i);
//    // rematerialize recursively
//    while(!work_list.empty()) {
//      ir::instruction* current = work_list.back();
//      order_[current] = order;
//      work_list.pop_back();
//      for(ir::value *op: current->ops()) {
//        ir::instruction* i_op = dynamic_cast<ir::instruction*>(op);
//        if(!seen.insert(op).second)
//          continue;
//        if(!i_op)
//          continue;
//        ir::type *ty = i_op->get_type();
//        if(!ty->is_tile_ty())
//          continue;
//        auto& inst_list = i_op->get_parent()->get_inst_list();
//        auto it = std::find(inst_list.begin(), inst_list.end(), i_op);
//        it++;
//        builder.set_insert_point(it);
//        // found a load; write to shared memory and stop recursion
//        ir::instruction *n_op = nullptr;
//        if(mem_->is_shared(i_op)){
//          continue;
//        }
//        if(auto* ld = dynamic_cast<ir::load_inst*>(i_op)) {
//          n_op = ir::copy_to_shared_inst::create(ld);
//        }
//        // not a load; rematerialize and recurse
//        else {
//          n_op = i_op->clone();
//          work_list.push_back(n_op);
//        }
//        n_op = builder.insert(n_op);
//        order_[n_op] = order;
//        align_->copy(n_op, i_op);
//        current->replace_uses_of_with(i_op, n_op);
//      }
//    }
//  }
}
}
}
}

View File

@@ -1,106 +0,0 @@
#include <iostream>
#include <algorithm>
#include <numeric>
#include "triton/ir/function.h"
#include "triton/ir/cfg.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/module.h"
#include "triton/codegen/analysis/meminfo.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/transform/reorder.h"
namespace triton {
namespace codegen{
namespace transform{
reorder::reorder(analysis::align* align, analysis::meminfo *mem)
: align_(align), mem_(mem) { }
// Returns a copy of the order recorded for `v`.
// The lookup goes through `at`, so it is a checked access into `order_`.
std::vector<unsigned> reorder::get_order(ir::value* v) {
  const std::vector<unsigned> &recorded = order_.at(v);
  return recorded;
}
// Runs the pass over `mod` in two phases.
//
// Phase 1: visit every instruction of every function (blocks in reverse
//   post-order) and, recursively, its operands. Tile-typed values get the
//   identity order 0, 1, ..., rank-1; other values get an empty order.
//
// Phase 2: for each I/O instruction whose pointer operand is tile-typed,
//   compute an order that sorts dimensions by decreasing max-contiguous value
//   (as reported by `align_`). When it differs from the recorded order, the
//   tile-typed instruction operands are re-created right after their original
//   definition: loads become copy-to-shared instructions, anything else is
//   cloned and processed in turn. Operands for which `mem_->is_shared` is
//   true are left untouched.
void reorder::run(ir::module &mod) {
  std::set<ir::io_inst*> io;

  std::function<void(ir::value*)> set_order = [&](ir::value *v) -> void {
    if(order_.find(v) != order_.end())
      return;
    // Record `v` before recursing into its operands. Without this, a value
    // that is reachable from its own operands (e.g. through a phi-node) makes
    // the recursion below never terminate.
    order_[v] = {};
    if(ir::user* u = dynamic_cast<ir::user*>(v))
      for(ir::value* op: u->ops())
        set_order(op);
    ir::type* ty = v->get_type();
    if(!ty->is_tile_ty())
      return;
    // identity order over the tile's dimensions
    std::vector<unsigned> order(ty->get_tile_shapes().size());
    std::iota(order.begin(), order.end(), 0);
    order_[v] = order;
  };

  // initialize work-list
  for(ir::function *fn: mod.get_function_list())
  for(ir::basic_block *block: ir::cfg::reverse_post_order(fn))
  for(ir::instruction *i: block->get_inst_list()){
    if(auto *x = dynamic_cast<ir::io_inst*>(i)) {
      ir::type* ptr_ty = x->get_pointer_operand()->get_type();
      if(ptr_ty->is_tile_ty())
        io.insert(x);
    }
    set_order(i);
  }

  ir::builder &builder = mod.get_builder();
  for(ir::io_inst *i: io) {
    ir::value *ptr = i->get_pointer_operand();
    auto max_contiguous = align_->get_max_contiguous_vec(ptr);
    // dimensions sorted by decreasing max-contiguous value
    std::vector<unsigned> order(max_contiguous.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } );
    std::list<ir::instruction*> work_list;
    // nothing to do when the recorded order already matches
    if(order != order_[i])
      work_list.push_back(i);
    // rematerialize recursively
    while(!work_list.empty()) {
      ir::instruction* current = work_list.back();
      order_[current] = order;
      work_list.pop_back();
      for(ir::value *op: current->ops()) {
        ir::instruction* i_op = dynamic_cast<ir::instruction*>(op);
        if(!i_op)
          continue;
        ir::type *ty = i_op->get_type();
        if(!ty->is_tile_ty())
          continue;
        // insert the replacement right after the original operand
        auto& inst_list = i_op->get_parent()->get_inst_list();
        auto it = std::find(inst_list.begin(), inst_list.end(), i_op);
        it++;
        builder.set_insert_point(it);
        ir::instruction *n_op = nullptr;
        if(mem_->is_shared(i_op)){
          continue;
        }
        // found a load; write to shared memory and stop recursion
        if(auto* ld = dynamic_cast<ir::load_inst*>(i_op)) {
          n_op = ir::copy_to_shared_inst::create(ld);
        }
        // not a load; rematerialize and recurse
        else {
          n_op = i_op->clone();
          work_list.push_back(n_op);
        }
        n_op = builder.insert(n_op);
        order_[n_op] = order;
        align_->copy(n_op, i_op);
        current->replace_uses_of_with(i_op, n_op);
      }
    }
  }
}
}
}
}

View File

@@ -241,7 +241,6 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -5,7 +5,7 @@
#include <algorithm>
#include "triton/codegen/selection.h"
#include "triton/runtime/function.h"
#include "triton/codegen/transform/reorder.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/lang/cpp.h"
#include "triton/lang/parser.h"
#include "triton/lang/code_gen.h"
@@ -197,7 +197,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::analysis::meminfo shmem_info;
codegen::analysis::liveness shmem_liveness(&shmem_info);
codegen::analysis::align alignment_info;
codegen::transform::reorder reorder(&alignment_info, &shmem_info);
codegen::transform::coalesce reorder(&alignment_info, &shmem_info);
codegen::analysis::grids grids(opt.num_warps, &reorder);
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids);
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
@@ -215,7 +215,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
// ir::print(module, std::cout);
reorder.run(module);
dce.run(module);
ir::print(module, std::cout);
// ir::print(module, std::cout);
grids.run(module);
reassociate.run(module);
dce.run(module);
@@ -231,7 +231,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
dce.run(module);
vectorize.run(module);
dce.run(module);
ir::print(module, std::cout);
// ir::print(module, std::cout);
// generate llvm code
llvm::LLVMContext ctx;
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));