[codegen][coalesce] more bugfixes

Philippe Tillet
2019-09-13 14:17:21 -04:00
parent 3fa3b90f16
commit 579a662e60
10 changed files with 92 additions and 124 deletions

@@ -298,7 +298,7 @@ void grids::run(ir::module &mod) {
unsigned current = num_threads;
std::string nts = "nts.d" + s_ld;
std::string mts = "mts.d" + s_ld;
- params_.at(i).at(nts)->set_value(clamp(size / num_threads, 1, 8));
+ params_.at(i).at(nts)->set_value(clamp(size / num_threads, 1, 4));
params_.at(i).at(mts)->set_value(clamp(current, 1, shapes[ld] / params_.at(i).at(nts)->get_value()));
current = current / params_.at(i).at(mts)->get_value();
for(size_t d = 1; d < shapes.size(); d++){
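
The cap on nts (elements handled contiguously per thread) drops from 8 to 4. For context, a rough standalone sketch of how the nts/mts split then comes out; the clamp helper, thread count and shapes below are assumptions, not the actual grids pass:

    #include <algorithm>
    #include <cstdio>

    // Hypothetical clamp with the usual min/max semantics (the real pass has its own helper).
    static unsigned clamp(unsigned x, unsigned lo, unsigned hi) {
      return std::min(std::max(x, lo), hi);
    }

    int main() {
      // Assumed figures purely for illustration: 128 threads covering 1024 elements
      // along a leading dimension of extent 64.
      unsigned num_threads = 128, size = 1024, shape_ld = 64;
      unsigned nts = clamp(size / num_threads, 1, 4);          // per-thread contiguity, now capped at 4
      unsigned mts = clamp(num_threads, 1, shape_ld / nts);    // threads along the dim, bounded by the shape
      std::printf("nts.d = %u, mts.d = %u\n", nts, mts);       // nts.d = 4, mts.d = 16
    }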

@@ -34,6 +34,16 @@ void meminfo::replace(ir::value* before, ir::value *after) {
}
}
+ void meminfo::copy(ir::value* y, ir::value *x) {
+ if(shared_.find(x) != shared_.end())
+ shared_.insert(y);
+ if(refs_.find(x) != refs_.end())
+ refs_[y] = refs_[x];
+ if(double_.find(x) != double_.end())
+ double_.insert(y);
+ }
inline bool get_is_shared(ir::value* v) {
if(dynamic_cast<ir::atomic_cas_inst*>(v))
return true;
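
The new meminfo::copy mirrors x's bookkeeping onto y, which the coalescing pass can use when it clones an instruction (see the commented-out mem_->copy call further down). A minimal sketch of the same propagation pattern on simplified containers; the struct and the string keys are assumptions standing in for ir::value*:

    #include <cassert>
    #include <map>
    #include <set>
    #include <string>

    // Simplified stand-in for meminfo's bookkeeping, keyed by strings for illustration.
    struct meminfo_sketch {
      std::set<std::string> shared_, double_;
      std::map<std::string, std::string> refs_;
      // Mirror x's metadata onto y, as meminfo::copy does for a cloned instruction.
      void copy(const std::string& y, const std::string& x) {
        if (shared_.count(x)) shared_.insert(y);
        if (refs_.count(x))   refs_[y] = refs_[x];
        if (double_.count(x)) double_.insert(y);
      }
    };

    int main() {
      meminfo_sketch m;
      m.shared_.insert("load0");
      m.refs_["load0"] = "bufferA";
      m.copy("load0_clone", "load0");          // the clone inherits shared membership and its ref
      assert(m.shared_.count("load0_clone"));
      assert(m.refs_.at("load0_clone") == "bufferA");
    }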

@@ -556,15 +556,15 @@ inline int32_t ceil(int32_t num, int32_t div){
return (num + div - 1)/div;
}
- inline void to_warps(const std::vector<unsigned> &bs, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
+ inline void to_warps(const std::vector<unsigned> &bs, const std::vector<unsigned>& order, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
static const size_t warp_size = 32;
size_t nthreads = 1, nwarps = 1;
nw.resize(bs.size());
ws.resize(bs.size());
for(size_t i = 0; i < bs.size(); ++i){
nthreads *= bs[i];
- nw[i] = ceil(nthreads, nwarps*warp_size);
- nwarps *= nw[i];
+ nw[order[i]] = ceil(nthreads, nwarps*warp_size);
+ nwarps *= nw[order[i]];
}
for(size_t i = 0; i < bs.size(); ++i){
ws[i] = bs[i] / nw[i];
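
To see what the new order argument does, here is the updated helper copied into a standalone driver; the ceil helper is reproduced from above, and the example block sizes and order are assumptions:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    inline int32_t ceil(int32_t num, int32_t div){
      return (num + div - 1)/div;
    }

    // The updated helper, copied from the hunk above so it can be exercised standalone.
    inline void to_warps(const std::vector<unsigned> &bs, const std::vector<unsigned>& order,
                         std::vector<unsigned> &nw, std::vector<unsigned> &ws){
      static const size_t warp_size = 32;
      size_t nthreads = 1, nwarps = 1;
      nw.resize(bs.size());
      ws.resize(bs.size());
      for(size_t i = 0; i < bs.size(); ++i){
        nthreads *= bs[i];
        nw[order[i]] = ceil(nthreads, nwarps*warp_size);   // warp counts are now written per `order`
        nwarps *= nw[order[i]];
      }
      for(size_t i = 0; i < bs.size(); ++i)
        ws[i] = bs[i] / nw[i];
    }

    int main(){
      std::vector<unsigned> bs = {4, 32};    // assumed threads per dimension (128 threads, 4 warps)
      std::vector<unsigned> nw, ws;
      to_warps(bs, {1, 0}, nw, ws);          // assumed order: dimension 1 first
      std::printf("nw = {%u, %u}, ws = {%u, %u}\n", nw[0], nw[1], ws[0], ws[1]);  // nw = {4, 1}, ws = {1, 32}
    }
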
@@ -585,7 +585,7 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
contiguous[i] = params_->get_param(v, "nts.d" + str_i)->get_value();
block_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
}
- to_warps(block_size, n_warps, warp_size);
+ to_warps(block_size, order, n_warps, warp_size);
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, order, warp_size, builder);
std::vector<Value*> warp_id = delinearize(u_warp_id, order, n_warps, builder);
// Create axes
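
The surrounding delinearize calls split a linear thread or warp id into per-dimension indices, following order. A scalar sketch of that computation; the real routine emits LLVM IR through the builder, so this integer version and its inputs are assumptions for illustration only:

    #include <cstdio>
    #include <vector>

    // Split a linear id into one index per dimension, consuming the dimension
    // listed first in `order` fastest.
    static std::vector<unsigned> delinearize(unsigned linear, const std::vector<unsigned>& order,
                                             const std::vector<unsigned>& sizes) {
      std::vector<unsigned> result(sizes.size());
      for (unsigned dim : order) {
        result[dim] = linear % sizes[dim];
        linear /= sizes[dim];
      }
      return result;
    }

    int main() {
      std::vector<unsigned> warp_size = {1, 32};   // assumed threads-per-warp along each dimension
      for (unsigned tid : {0u, 5u, 31u}) {
        auto idx = delinearize(tid, {1, 0}, warp_size);
        std::printf("lane %2u -> (%u, %u)\n", tid, idx[0], idx[1]);
      }
    }
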
@@ -711,52 +711,6 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
}
}
- void selection::create_grids(std::vector<ir::value*> &grids,
- std::map<unsigned, ir::value*> &references,
- ir::function *fn) {
- // get number of dimensions greater than 1
- auto get_tile_gt1_dim = [&](ir::value *v){
- unsigned result = 0;
- for(auto shape: v->get_type()->get_tile_shapes()) {
- result += (shape > 1)? shape : 0;
- }
- return result;
- };
- // bind references
- std::set<ir::value*> seen;
- std::function<void(ir::value*)> bind_references = [&](ir::value *v)
- {
- // skip
- if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
- return;
- // recurse
- if(auto *user = dynamic_cast<ir::user*>(v))
- for(ir::value *op: user->ops())
- bind_references(op);
- // bind
- const auto& shapes = v->get_type()->get_tile_shapes();
- if(buffer_info_->is_shared(v))
- return;
- for(size_t d = 0; d < shapes.size(); d++){
- if(shapes[d] == 1)
- continue;
- unsigned x = params_->get_param_group(v, d);
- ir::value *&r = references[x];
- if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
- r = v;
- }
- };
- for(ir::basic_block *block: fn->blocks())
- for(ir::instruction *i: block->get_inst_list())
- bind_references(i);
- // create grid
- for(auto &ref: references)
- if(std::find(grids.begin(), grids.end(), ref.second) == grids.end())
- grids.push_back(ref.second);
- }
bool static inline has_phi_user(ir::value *v) {
for(ir::user *usr: v->get_users()){
if(dynamic_cast<ir::phi_node*>(usr))

@@ -28,16 +28,17 @@ void coalesce::run(ir::module &mod) {
std::function<void(ir::value*)> set_order = [&](ir::value *v) -> void {
if(order_.find(v) != order_.end())
return;
order_[v] = {};
- ir::type *tile_ty = v->get_type();
- if(auto *x = dynamic_cast<ir::store_inst*>(v))
- tile_ty = x->get_operand(0)->get_type();
- if(!tile_ty->is_tile_ty())
- return;
- std::vector<unsigned> order(tile_ty->get_tile_shapes().size());
- std::iota(order.begin(), order.end(), 0);
- order_[v] = order;
+ if(ir::user* u = dynamic_cast<ir::user*>(v))
+ for(ir::value* op: u->ops())
+ set_order(op);
+ ir::type* ty = v->get_type();
+ if(!ty->is_tile_ty())
+ return;
+ std::vector<unsigned> order(ty->get_tile_shapes().size());
+ std::iota(order.begin(), order.end(), 0);
+ order_[v] = order;
};
// initialize work-list
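
The reworked set_order lambda now recurses into a value's operands before deciding whether the value itself gets a default order, so a store whose own type is not a tile still propagates an order to its tile operands. A toy sketch of that traversal; the value struct, its fields, and the graph are assumptions chosen only for illustration:

    #include <cstdio>
    #include <functional>
    #include <map>
    #include <numeric>
    #include <vector>

    // Toy stand-in for ir::value / ir::user.
    struct value {
      std::vector<value*> ops;   // operands
      int rank;                  // number of tile dimensions; 0 means "not a tile"
    };

    int main() {
      // store(ptr, val): the store itself is not a tile type, but its operands are 2-D tiles.
      value ptr{{}, 2}, val{{}, 2}, store{{&ptr, &val}, 0};

      std::map<value*, std::vector<unsigned>> order_;
      std::function<void(value*)> set_order = [&](value* v) {
        if (order_.count(v)) return;
        order_[v] = {};                              // mark visited before recursing
        for (value* op : v->ops) set_order(op);      // operands are visited too
        if (v->rank == 0) return;                    // non-tile values keep the empty entry
        std::vector<unsigned> order(v->rank);
        std::iota(order.begin(), order.end(), 0);    // default identity order for tiles
        order_[v] = order;
      };
      set_order(&store);

      size_t tiles = 0;
      for (auto& kv : order_) tiles += !kv.second.empty();
      std::printf("visited %zu values, %zu got a default order\n", order_.size(), tiles);  // 3 and 2
    }
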
@@ -52,56 +53,58 @@ void coalesce::run(ir::module &mod) {
set_order(i);
}
- // ir::builder &builder = mod.get_builder();
- // std::set<ir::value*> seen;
- // for(ir::io_inst *i: io) {
- // ir::value *ptr = i->get_pointer_operand();
- // auto max_contiguous = align_->get_max_contiguous_vec(ptr);
- // std::vector<unsigned> order(max_contiguous.size());
- // std::iota(order.begin(), order.end(), 0);
- // std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } );
- // std::list<ir::instruction*> work_list;
- // if(order != order_[i])
- // work_list.push_back(i);
- // // rematerialize recursively
- // while(!work_list.empty()) {
- // ir::instruction* current = work_list.back();
- // order_[current] = order;
- // work_list.pop_back();
- // for(ir::value *op: current->ops()) {
- // ir::instruction* i_op = dynamic_cast<ir::instruction*>(op);
- // if(!seen.insert(op).second)
- // continue;
- // if(!i_op)
- // continue;
- // ir::type *ty = i_op->get_type();
- // if(!ty->is_tile_ty())
- // continue;
- // auto& inst_list = i_op->get_parent()->get_inst_list();
- // auto it = std::find(inst_list.begin(), inst_list.end(), i_op);
- // it++;
- // builder.set_insert_point(it);
- // // found a load; write to shared memory and stop recursion
- // ir::instruction *n_op = nullptr;
- // if(mem_->is_shared(i_op)){
- // continue;
- // }
- // if(auto* ld = dynamic_cast<ir::load_inst*>(i_op)) {
- // n_op = ir::copy_to_shared_inst::create(ld);
- // }
- // // not a load; rematerialize and recurse
- // else {
- // n_op = i_op->clone();
- // work_list.push_back(n_op);
- // }
- // n_op = builder.insert(n_op);
- // order_[n_op] = order;
- // align_->copy(n_op, i_op);
- // current->replace_uses_of_with(i_op, n_op);
- // }
- // }
+ ir::builder &builder = mod.get_builder();
+ std::map<ir::value*, ir::value*> replaced;
+ for(ir::io_inst *i: io) {
+ ir::value *ptr = i->get_pointer_operand();
+ auto max_contiguous = align_->get_max_contiguous_vec(ptr);
+ std::vector<unsigned> order(max_contiguous.size());
+ std::iota(order.begin(), order.end(), 0);
+ std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; } );
+ std::list<ir::instruction*> work_list;
+ if(order != order_[i])
+ work_list.push_back(i);
+ // rematerialize recursively
+ while(!work_list.empty()) {
+ ir::instruction* current = work_list.back();
+ order_[current] = order;
+ work_list.pop_back();
+ for(ir::value *op: current->ops()) {
+ ir::instruction* i_op = dynamic_cast<ir::instruction*>(op);
+ if(replaced.find(i_op) != replaced.end()){
+ current->replace_uses_of_with(i_op, replaced.at(i_op));
+ continue;
+ }
+ if(!i_op)
+ continue;
+ ir::type *ty = i_op->get_type();
+ if(!ty->is_tile_ty())
+ continue;
+ auto& inst_list = i_op->get_parent()->get_inst_list();
+ auto it = std::find(inst_list.begin(), inst_list.end(), i_op);
+ it++;
+ builder.set_insert_point(it);
+ // found a load; write to shared memory and stop recursion
+ ir::instruction *n_op = nullptr;
+ if(mem_->is_shared(i_op))
+ continue;
+ if(auto* ld = dynamic_cast<ir::load_inst*>(i_op))
+ n_op = ir::copy_to_shared_inst::create(ld);
+ // not a load; rematerialize and recurse
+ else {
+ n_op = i_op->clone();
+ work_list.push_back(n_op);
+ }
+ n_op = builder.insert(n_op);
+ replaced.insert({i_op, n_op});
+ order_[n_op] = order;
+ align_->copy(n_op, i_op);
+ // mem_->copy(n_op, i_op);
+ current->replace_uses_of_with(i_op, n_op);
+ }
+ }
// }
}
}
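
The re-enabled rematerialization loop differs from the commented-out draft mainly in the replaced map: each operand is cloned (or copied to shared memory) at most once, and later users are redirected to the existing clone instead of producing duplicates. A small sketch of that memoization pattern; the node type and driver are assumptions, not the Triton IR API:

    #include <cstdio>
    #include <map>
    #include <string>

    struct node { std::string name; };

    int main() {
      // Each original node is cloned at most once; later users reuse the stored clone.
      std::map<node*, node> replaced;
      auto clone_of = [&](node* original) -> node* {
        auto it = replaced.find(original);
        if (it == replaced.end())
          it = replaced.insert({original, node{original->name + ".copy"}}).first;   // clone once
        return &it->second;
      };

      node shared{"load0"};
      node* a = clone_of(&shared);
      node* b = clone_of(&shared);               // second user gets the same clone back
      std::printf("same clone reused: %s\n", (a == b) ? "yes" : "no");
    }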

@@ -3,6 +3,7 @@
#include "triton/ir/module.h"
#include "triton/ir/cfg.h"
#include "triton/codegen/transform/dce.h"
+ #include <iostream>
namespace triton {
namespace codegen{
@@ -35,9 +36,10 @@ void dce::run(ir::module &mod) {
work_list.pop_back();
// mark instruction operands
for(ir::value* op: current->ops()) {
- if(auto *i = dynamic_cast<ir::instruction*>(op))
+ if(auto *i = dynamic_cast<ir::instruction*>(op)){
if(marked.insert(i).second)
work_list.push_back(i);
+ }
}
// TODO: mark last instruction of current's reverse-dominance frontier
}
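
For reference, the mark phase shown here follows the usual worklist scheme: seed with instructions that must be kept, then transitively mark their operands. A compact sketch of that scheme on a toy instruction type; the struct and the example chain are assumptions:

    #include <cstdio>
    #include <set>
    #include <vector>

    struct inst { std::vector<inst*> ops; bool has_side_effects; };

    int main() {
      // tiny chain: store -> add -> load, plus one dead instruction
      inst load{{}, false}, add{{&load}, false}, store{{&add}, true}, dead{{}, false};
      std::vector<inst*> all = {&load, &add, &store, &dead};

      std::set<inst*> marked;
      std::vector<inst*> work_list;
      for (inst* i : all)                          // seed with instructions that must be kept
        if (i->has_side_effects && marked.insert(i).second)
          work_list.push_back(i);
      while (!work_list.empty()) {                 // transitively mark operands
        inst* current = work_list.back();
        work_list.pop_back();
        for (inst* op : current->ops)
          if (marked.insert(op).second)
            work_list.push_back(op);
      }
      std::printf("kept %zu of %zu instructions\n", marked.size(), all.size());   // 3 of 4
    }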