progress on re-association

This commit is contained in:
Philippe Tillet
2019-07-23 17:21:24 -07:00
parent 38b3771c26
commit 397d76156b
11 changed files with 167 additions and 26 deletions

View File

@@ -5,6 +5,7 @@
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/cfg.h"
#include "triton/codegen/tune.h"
namespace triton {
namespace codegen{
@@ -68,25 +69,32 @@ inline ir::value *reassociate::reorder_op(ir::value *old_value,
if(cst){
ir::value *old_lhs = bin_add->get_operand(0);
ir::value *old_rhs = bin_add->get_operand(1);
ir::value *new_lhs = nullptr;
ir::value *new_rhs = nullptr;
if(dynamic_cast<ir::reshape_inst*>(op)){
builder.set_insert_point(op);
ir::value *new_lhs = builder.create_reshape(old_lhs, shapes);
ir::value *new_rhs = builder.create_reshape(old_rhs, shapes);
new_lhs = builder.create_reshape(old_lhs, shapes);
new_rhs = builder.create_reshape(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
}
if(dynamic_cast<ir::broadcast_inst*>(op)){
builder.set_insert_point(op);
ir::value *new_lhs = builder.create_broadcast(old_lhs, shapes);
ir::value *new_rhs = builder.create_broadcast(old_rhs, shapes);
new_lhs = builder.create_broadcast(old_lhs, shapes);
new_rhs = builder.create_broadcast(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
}
if(dynamic_cast<ir::splat_inst*>(op)){
builder.set_insert_point(op);
ir::value *new_lhs = builder.create_splat(old_lhs, shapes);
ir::value *new_rhs = builder.create_splat(old_rhs, shapes);
new_lhs = builder.create_splat(old_lhs, shapes);
new_rhs = builder.create_splat(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
}
to_delete.push_back(op);
if(new_value != old_value){
params_->copy(new_value, old_value);
params_->copy(new_lhs, old_value);
params_->copy(new_rhs, old_value);
to_delete.push_back(op);
}
}
}
@@ -107,8 +115,9 @@ inline ir::value *reassociate::reorder_op(ir::value *old_value,
// (x + cst) + y -> cst + (x + y)
if(is_cst(rlhs))
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs), name);
if(new_value != op)
if(new_value != old_value){
to_delete.push_back(bin_lhs);
}
}
// x + (y + z)
if(ir::instruction* bin_rhs = is_bin_add(rhs)){
@@ -123,6 +132,11 @@ inline ir::value *reassociate::reorder_op(ir::value *old_value,
if(new_value != op)
to_delete.push_back(bin_rhs);
}
if(new_value != old_value){
params_->copy(new_value, old_value);
params_->copy(((ir::instruction*)new_value)->get_operand(0), old_value);
params_->copy(((ir::instruction*)new_value)->get_operand(1), old_value);
}
}
// extract constant and non-constant
@@ -149,13 +163,39 @@ inline ir::value *reassociate::reorder_op(ir::value *old_value,
return new_value;
}
reassociate::reassociate() {
}
reassociate::reassociate(tune* params)
: params_(params)
{ }
void reassociate::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
std::vector<ir::instruction*> to_delete;
// constant_range -> nv_dynamic_range_idx + nv_static_range_idx
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::constant_range*> ranges;
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
for(ir::basic_block *block: rpo){
// iterate through instruction
for(ir::instruction *i: block->get_inst_list())
for(ir::value* op: i->ops())
if(auto *range = dynamic_cast<ir::constant_range*>(op))
ranges.push_back(range);
}
builder.set_insert_point(rpo.front()->get_first_non_phi());
for(ir::constant_range* old_range: ranges){
ir::value* dyn_range = builder.insert(ir::nv_dynamic_range_idx_inst::create(old_range->get_type()));
ir::value* static_range = ir::nv_static_range_idx::get(old_range);
ir::value* new_range = builder.create_add(dyn_range, static_range);
old_range->replace_all_uses_with(new_range);
params_->copy(dyn_range, old_range);
params_->copy(static_range, old_range);
params_->copy(new_range, old_range);
}
}
// reassociate
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
bool done = false;