FINALLY
This commit is contained in:
@@ -20,24 +20,26 @@ class getelementptr_inst;
|
||||
namespace codegen{
|
||||
|
||||
class tune;
|
||||
class alignment_info;
|
||||
|
||||
class reassociate {
|
||||
struct cst_info {
|
||||
ir::value* sta;
|
||||
ir::value* dyn;
|
||||
ir::getelementptr_inst* dyn_ptr;
|
||||
ir::getelementptr_inst* sta_ptr;
|
||||
};
|
||||
|
||||
private:
|
||||
ir::instruction* is_bin_add(ir::value *x);
|
||||
ir::value *reassociate_idx(ir::value *value, ir::builder &builder, std::vector<ir::instruction*>& to_delete, ir::value *&noncst, ir::value *&cst);
|
||||
ir::value *reassociate_idx(ir::value *value, ir::builder &builder, ir::value *&noncst, ir::value *&cst);
|
||||
ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map<ir::value*, cst_info> &offsets);
|
||||
|
||||
public:
|
||||
reassociate(tune *params);
|
||||
reassociate(tune *params, alignment_info *align);
|
||||
void run(ir::module& module);
|
||||
|
||||
private:
|
||||
tune* params_;
|
||||
alignment_info* align_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -66,7 +66,7 @@ public:
|
||||
optimize_cse(),
|
||||
optimize_trans(),
|
||||
alignment_info(),
|
||||
reassociate(&tune),
|
||||
reassociate(&tune, &alignment_info),
|
||||
target_(target) { }
|
||||
|
||||
void target_independent(ir::module &module) {
|
||||
@@ -79,7 +79,7 @@ public:
|
||||
alignment_info.run(module);
|
||||
reassociate.run(module);
|
||||
ir::print(module, std::cout);
|
||||
//exit(EXIT_FAILURE);
|
||||
// exit(EXIT_FAILURE);
|
||||
if(target_->is_gpu()){
|
||||
shmem_info.run(module);
|
||||
shmem_liveness.run(module);
|
||||
|
@@ -1,5 +1,6 @@
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/reassociate.h"
|
||||
#include "triton/codegen/alignment_info.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
@@ -48,33 +49,8 @@ inline bool is_cst(ir::value *x) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
// reassociate pointer
|
||||
// pz = py + a = (px + (cst + b)) + a -> (px + b) + (cst + a)
|
||||
ir::value *reassociate::reassociate_ptr(ir::getelementptr_inst* pz,
|
||||
ir::builder &builder,
|
||||
std::map<ir::value*, cst_info> &info) {
|
||||
ir::value *a = *pz->idx_begin();
|
||||
ir::value *vpy = pz->get_pointer_operand();
|
||||
if(info.find(vpy) == info.end())
|
||||
return nullptr;
|
||||
ir::getelementptr_inst *py = (ir::getelementptr_inst*)vpy;
|
||||
ir::value *px = py->get_pointer_operand();
|
||||
ir::value *cst = info.at(py).sta;
|
||||
ir::value *b = info.at(py).dyn;
|
||||
ir::value *new_py = builder.create_gep(px, {b});
|
||||
ir::value *new_a = builder.create_add(cst, a);
|
||||
ir::value *new_pz = builder.create_gep(new_py, {new_a});
|
||||
params_->copy(new_pz, pz);
|
||||
params_->copy(new_py, vpy);
|
||||
params_->copy(new_a, a);
|
||||
pz->replace_all_uses_with(new_pz);
|
||||
return pz;
|
||||
}
|
||||
|
||||
ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
ir::builder &builder,
|
||||
std::vector<ir::instruction*>& to_delete,
|
||||
ir::value *&noncst,
|
||||
ir::value *&cst){
|
||||
// value doesn't change by default
|
||||
@@ -86,7 +62,7 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
if(ir::instruction* op = dynamic_cast<ir::retile_inst*>(old_value)){
|
||||
auto shapes = op->get_type()->get_tile_shapes();
|
||||
ir::value *old_arg = op->get_operand(0);
|
||||
ir::value *new_arg = reassociate_idx(old_arg, builder, to_delete, noncst, cst);
|
||||
ir::value *new_arg = reassociate_idx(old_arg, builder, noncst, cst);
|
||||
// retile(x + y) = retile(x) + retile(y)
|
||||
if(ir::instruction* bin_add = is_bin_add(new_arg))
|
||||
if(cst){
|
||||
@@ -116,7 +92,6 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
params_->copy(new_value, old_value);
|
||||
params_->copy(new_lhs, old_value);
|
||||
params_->copy(new_rhs, old_value);
|
||||
to_delete.push_back(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -125,8 +100,8 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
if(ir::instruction* op = is_bin_add(old_value)){
|
||||
builder.set_insert_point(op);
|
||||
std::string name = op->get_name();
|
||||
ir::value *lhs = reassociate_idx(op->get_operand (0), builder, to_delete, noncst, cst);
|
||||
ir::value *rhs = reassociate_idx(op->get_operand(1), builder, to_delete, noncst, cst);
|
||||
ir::value *lhs = reassociate_idx(op->get_operand (0), builder, noncst, cst);
|
||||
ir::value *rhs = reassociate_idx(op->get_operand(1), builder, noncst, cst);
|
||||
builder.set_insert_point(op);
|
||||
// (x + y) + z
|
||||
if(ir::instruction* bin_lhs = is_bin_add(lhs)){
|
||||
@@ -138,9 +113,6 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
// (x + cst) + y -> cst + (x + y)
|
||||
if(is_cst(rlhs))
|
||||
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs), name);
|
||||
if(new_value != old_value){
|
||||
to_delete.push_back(bin_lhs);
|
||||
}
|
||||
}
|
||||
// x + (y + z)
|
||||
if(ir::instruction* bin_rhs = is_bin_add(rhs)){
|
||||
@@ -152,8 +124,6 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
// x + (y + cst) -> cst + (x + y)
|
||||
if(is_cst(rrhs))
|
||||
new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), name, cst);
|
||||
if(new_value != op)
|
||||
to_delete.push_back(bin_rhs);
|
||||
}
|
||||
if(new_value != old_value){
|
||||
params_->copy(new_value, old_value);
|
||||
@@ -179,22 +149,19 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value,
|
||||
// clean-up if some re-ordering happened
|
||||
if(old_value != new_value){
|
||||
old_value->replace_all_uses_with(new_value);
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(old_value))
|
||||
to_delete.push_back(x);
|
||||
}
|
||||
|
||||
return new_value;
|
||||
}
|
||||
|
||||
reassociate::reassociate(tune* params)
|
||||
: params_(params)
|
||||
reassociate::reassociate(tune* params, alignment_info* align)
|
||||
: params_(params), align_(align)
|
||||
{ }
|
||||
|
||||
|
||||
/* run */
|
||||
void reassociate::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
std::vector<ir::instruction*> to_delete;
|
||||
|
||||
// constant_range -> nv_dynamic_range_idx + nv_static_range_idx
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
@@ -232,56 +199,75 @@ void reassociate::run(ir::module &mod) {
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
// getelementptr instruction
|
||||
if(ir::getelementptr_inst *pz = dynamic_cast<ir::getelementptr_inst*>(i)){
|
||||
|
||||
// pz = py + offset
|
||||
// tries to achieve pz = py + (cst + a)
|
||||
// by modifying py and/or offset
|
||||
// unpack GEP instruction
|
||||
ir::value* py = pz->get_pointer_operand();
|
||||
ir::value* offset = *pz->idx_begin();
|
||||
|
||||
// reassociate index
|
||||
ir::value *sta = nullptr;
|
||||
ir::value *dyn = offset;
|
||||
reassociate_idx(pz, builder, to_delete, dyn, sta);
|
||||
reassociate_idx(offset, builder, dyn, sta);
|
||||
if(sta){
|
||||
infos[pz] = {sta, dyn};
|
||||
re_ordered[block].insert(pz);
|
||||
builder.set_insert_point(pz);
|
||||
ir::value *dyn_ptr = builder.create_gep(py, {dyn});
|
||||
ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
|
||||
params_->copy(dyn_ptr, pz);
|
||||
params_->copy(sta_ptr, pz);
|
||||
align_->copy(sta_ptr, pz);
|
||||
pz->replace_all_uses_with(sta_ptr);
|
||||
infos[sta_ptr].dyn_ptr = (ir::getelementptr_inst*)dyn_ptr;
|
||||
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
|
||||
}
|
||||
// reassociate phi-node pointer
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(py)){
|
||||
// only optimize the case where py = phi pa, pz for now
|
||||
std::vector<ir::value*> ops = phi->ops();
|
||||
if(ops.size() != 2)
|
||||
continue;
|
||||
if(ops[0] != pz && ops[1] != pz)
|
||||
continue;
|
||||
// grab incoming
|
||||
size_t idx_z = (ops[0] == pz) ? 0 : 1;
|
||||
size_t idx_a = (ops[0] == pz) ? 1 : 0;
|
||||
// check if pa is known to have constant offset
|
||||
ir::value *vpa = phi->get_incoming_value(idx_a);
|
||||
auto it = infos.find(vpa);
|
||||
if(it == infos.end())
|
||||
continue;
|
||||
ir::getelementptr_inst *pa = (ir::getelementptr_inst*)vpa;
|
||||
// unpack dynamically/statically offset pointer
|
||||
ir::getelementptr_inst *dyn_ptr = it->second.dyn_ptr;
|
||||
ir::getelementptr_inst *sta_ptr = it->second.sta_ptr;
|
||||
// we take static offset out of the phi function
|
||||
builder.set_insert_point(phi);
|
||||
ir::phi_node *new_phi = builder.create_phi(phi->get_type(), 2);
|
||||
// new pz for phi has the same offsets
|
||||
builder.set_insert_point(pz);
|
||||
std::vector<ir::value*> idxs(pz->idx_begin(), pz->idx_end());
|
||||
ir::value *new_phi_pz = builder.create_gep(new_phi, idxs);
|
||||
// fold the static offset into the new pz value
|
||||
ir::value *new_pz = builder.create_gep(new_phi_pz, {*sta_ptr->idx_begin()});
|
||||
// populate incoming values
|
||||
new_phi->add_incoming(dyn_ptr, phi->get_incoming_block(idx_a));
|
||||
new_phi->add_incoming(new_phi_pz, phi->get_incoming_block(idx_z));
|
||||
// replace phi uses
|
||||
phi->replace_all_uses_with(new_phi);
|
||||
// replace pz uses
|
||||
pz->replace_all_uses_with(new_pz);
|
||||
// copy params
|
||||
params_->copy(new_phi_pz, pz);
|
||||
params_->copy(new_phi, phi);
|
||||
params_->copy(new_pz, pz);
|
||||
align_->copy(new_pz, pz);
|
||||
}
|
||||
|
||||
// // reassociate pointer
|
||||
// reassociate_ptr(pz, builder, offsets);
|
||||
|
||||
// // reassociate phi-node
|
||||
// if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(py)){
|
||||
// // only optimize the case where py = phi pa, pz
|
||||
// std::vector<ir::value*> ops = phi->ops();
|
||||
// if(!(ops.size() == 2 && (ops[0] == pz || ops[1] == pz)))
|
||||
// continue;
|
||||
// size_t idx_z = (ops[0] == pz) ? 0 : 1;
|
||||
// size_t idx_a = (idx_z + 1) % 2;
|
||||
// ir::value *vpa = phi->get_incoming_value(idx_a);
|
||||
// ir::value *block_a = phi->get_incoming_block(idx_a);
|
||||
// ir::value *block_z = phi->get_incoming_value(idx_z);
|
||||
// auto it = infos.find(vpa);
|
||||
// if(it == infos.end())
|
||||
// continue;
|
||||
// ir::value *b = it->a;
|
||||
// // pa = px + (cst + b)
|
||||
// ir::getelementptr_inst *pa = (ir::getelementptr_inst*)vpa;
|
||||
// ir::getelementptr_inst *px = pa->get_pointer_operand();
|
||||
// // new_pa = px + b
|
||||
// ir::getelementptr_inst *new_pa = builder.create_gep(px, {b});
|
||||
// // new_pz = py + (offset + a)
|
||||
// ir::getelementptr_inst *new_offset = builder.create_add(it->cst, dyn);
|
||||
// ir::getelementptr_inst *new_pz = builder.create_gep(pz->get_pointer_operand(), {new_offset});
|
||||
// }
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// erase dead code
|
||||
for(ir::instruction* i: to_delete)
|
||||
i->erase_from_parent();
|
||||
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user