History prior to this date belonged to the now deprecated ISAAC project, and was deleted to save space
This commit is contained in:
143
lib/codegen/transform/coalesce.cc
Normal file
143
lib/codegen/transform/coalesce.cc
Normal file
@@ -0,0 +1,143 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
// Construct the pass from the alignment and layout analyses it consults.
// Both pointers are stored as given; nothing else happens at construction.
coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
  : align_(align),
    layout_(layouts)
{ }
|
||||
|
||||
// Find all values that are used as pointer operands in LD/ST
|
||||
void coalesce::extract_io_use(ir::value *v, std::set<ir::io_inst*>& result) {
|
||||
for(ir::user* u: v->get_users()){
|
||||
auto i = dynamic_cast<ir::io_inst*>(u);
|
||||
if(i && i->get_pointer_operand() == v)
|
||||
result.insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
// File `i` under the axis along which its pointer operand has the largest
// contiguity, as reported by the alignment analysis. On ties the first
// (lowest-index) maximal axis is used, since std::max_element returns the
// first maximum.
void coalesce::extract_ld(ir::io_inst* i, std::map<int, std::vector<ir::io_inst*>>& result) {
  auto contiguity = align_->contiguous(i->get_pointer_operand());
  const auto first = contiguity.begin();
  const auto widest = std::max_element(first, contiguity.end());
  const int axis = std::distance(first, widest);
  result[axis].push_back(i);
}
|
||||
|
||||
// Return a replacement for `x`, cloning the instructions it depends on.
//
//  - values already present in `seen` map to their recorded replacement;
//  - non-instructions and copy_to_shared instructions are returned as-is;
//  - a load is wrapped in a fresh copy_to_shared inserted right after it
//    (this result is not recorded in `seen`);
//  - any other instruction is cloned right after itself, recorded in `seen`,
//    and the clone's operands are rematerialized recursively.
ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder,
                                   std::map<ir::value*, ir::value*>& seen) {
  // reuse an earlier result for this value
  auto known = seen.find(x);
  if(known != seen.end())
    return known->second;

  // not an instruction -- forward value
  ir::instruction *inst = dynamic_cast<ir::instruction*>(x);
  if(!inst)
    return x;

  // already in shared memory -- forward value
  if(dynamic_cast<ir::copy_to_shared_inst*>(x))
    return x;

  // new instructions go immediately after `inst` in its parent block
  auto& siblings = inst->get_parent()->get_inst_list();
  auto where = std::find(siblings.begin(), siblings.end(), inst);
  ++where;
  builder.set_insert_point(where);

  // loads are routed through a copy_to_shared rather than cloned
  if(dynamic_cast<ir::load_inst*>(x))
    return builder.insert(ir::copy_to_shared_inst::create(x));

  // default -- clone, register the clone, then rewrite its operands
  ir::instruction *copy = builder.insert(inst->clone());
  seen[inst] = copy;
  for(ir::value *operand: copy->ops())
    copy->replace_uses_of_with(operand, rematerialize(operand, builder, seen));
  return copy;
}
|
||||
|
||||
// Run the pass over `mod` in three phases:
//
//  1. For every layout that is an mma884 layout, start from its dot
//     instruction and walk the user graph; at the first fp_trunc reached,
//     insert a recoalesce instruction right after it and redirect the
//     trunc's users to the recoalesce.
//  2. For every layout, gather the io instructions whose pointer operand
//     belongs to the layout and has the layout's rank, bucket them by
//     their most contiguous axis (see extract_ld), and schedule every
//     bucket except the one with the largest axis key for rematerialization.
//  3. Rematerialize the operands of each scheduled io instruction; loads
//     additionally get a copy_to_shared inserted after them that takes over
//     all of the load's uses.
void coalesce::run(ir::module &mod) {
  size_t num_groups = layout_->num_layouts();

  // phase 1 -- recoalesce after the truncation that follows a dot
  for(size_t id = 0; id < num_groups; id++) {
    if(!layout_->get(id)->to_mma884())
      continue;
    // locate the dot of this layout (the last one in iteration order wins)
    const auto& values = layout_->values_of(id);
    ir::value* dot = nullptr;
    for(ir::value *v: values)
      if(auto x = dynamic_cast<ir::dot_inst*>(v))
        dot = x;
    // Without a dot there is nothing to traverse. The previous code pushed
    // the null pointer on the worklist and dereferenced it below.
    if(!dot)
      continue;

    ir::builder& builder = mod.get_builder();
    std::vector<ir::value*> worklist = {dot};
    std::set<ir::value*> seen;
    while(!worklist.empty()) {
      ir::value *current = worklist.back();
      seen.insert(current);
      worklist.pop_back();
      // stop if trunc
      if(auto x = dynamic_cast<ir::fp_trunc_inst*>(current)){
        builder.set_insert_point_after(x);
        ir::recoalesce_inst* rc = ir::recoalesce_inst::create(x);
        builder.insert(rc);
        x->replace_all_uses_with(rc);
        // the line above also rewired rc's own operand to rc; restore it
        rc->replace_uses_of_with(rc, x);
        break;
      }
      // recurse
      for(ir::user *u: current->get_users())
        if(seen.find(u) == seen.end())
          worklist.push_back(u);
    }
  }

  // phase 2 -- find values to rematerialize
  std::vector<ir::io_inst*> remat;
  for(size_t id = 0; id < num_groups; id++) {
    const auto& values = layout_->values_of(id);
    // extract pointers used in ld/st operations
    std::set<ir::io_inst*> io;
    for(ir::value *v: values)
      extract_io_use(v, io);
    // extract leading axes
    std::map<int, std::vector<ir::io_inst*>> axes;
    for(ir::io_inst *i: io){
      if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->get_rank())
        extract_ld(i, axes);
    }
    // update list of values to rematerialize
    if(axes.empty())
      continue;
    // skip the bucket with the largest axis key; prepend all the others
    for(auto it = ++axes.rbegin(); it != axes.rend(); ++it)
      remat.insert(remat.begin(), it->second.begin(), it->second.end());
  }

  // phase 3 -- rematerialize values
  for(ir::io_inst *r: remat) {
    ir::builder& builder = mod.get_builder();
    // rematerialize operands
    std::map<ir::value*, ir::value*> seen;
    for(ir::value *op: r->ops())
      r->replace_uses_of_with(op, rematerialize(op, mod.get_builder(), seen));
    // copy to shared if load
    auto& inst_list = r->get_parent()->get_inst_list();
    auto pos = ++std::find(inst_list.begin(), inst_list.end(), r);
    builder.set_insert_point(pos);
    if(dynamic_cast<ir::load_inst*>(r)){
      ir::instruction *cts = builder.insert(ir::copy_to_shared_inst::create(r));
      r->replace_all_uses_with(cts);
      // undo the self-reference created by replace_all_uses_with
      cts->replace_uses_of_with(cts, r);
    }
  }
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
95
lib/codegen/transform/cts.cc
Normal file
95
lib/codegen/transform/cts.cc
Normal file
@@ -0,0 +1,95 @@
|
||||
#include "triton/codegen/transform/cts.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
// Tell whether operand number `op` of instruction `i` is one of the slots
// this pass routes through copy_to_shared (see cts::run): operands 0 and 1
// of a dot, and operand 0 of a copy_from_shared or a trans.
inline bool is_shmem_op(ir::instruction* i, int op) {
  switch(i->get_id()){
    case ir::INST_DOT:
      return op == 0 || op == 1;
    case ir::INST_COPY_FROM_SHARED:
    case ir::INST_TRANS:
      return op == 0;
    default:
      return false;
  }
}
|
||||
|
||||
// Tell whether `v` is an instruction treated by this pass as already
// producing a shared-memory result: a trans, a reduce or a copy_to_shared.
// Non-instruction values always yield false.
inline bool is_shmem_res(ir::value* v){
  ir::instruction* producer = dynamic_cast<ir::instruction*>(v);
  if(!producer)
    return false;
  switch(producer->get_id()){
    case ir::INST_TRANS:
    case ir::INST_REDUCE:
    case ir::INST_COPY_TO_SHARED:
      return true;
    default:
      return false;
  }
}
|
||||
|
||||
|
||||
// run pass on module
|
||||
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
|
||||
auto *i = dynamic_cast<ir::instruction*>(x);
|
||||
// not an instruction
|
||||
if(!i) {
|
||||
builder.set_insert_point(parent);
|
||||
ir::value *copy;
|
||||
if(to_shared)
|
||||
copy = builder.create_copy_to_shared(x);
|
||||
else
|
||||
copy = builder.create_copy_from_shared(x);
|
||||
parent->replace_uses_of_with(x, copy);
|
||||
return;
|
||||
}
|
||||
// phi node
|
||||
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
|
||||
for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
|
||||
add_copy(phi, phi->get_incoming_value(i), builder, to_shared);
|
||||
return;
|
||||
}
|
||||
// already in shared memory
|
||||
if(to_shared && is_shmem_res(i))
|
||||
return;
|
||||
// copy
|
||||
builder.set_insert_point_after(i);
|
||||
ir::value *copy;
|
||||
if(to_shared)
|
||||
copy = builder.create_copy_to_shared(x);
|
||||
else
|
||||
copy = builder.create_copy_from_shared(x);
|
||||
parent->replace_uses_of_with(x, copy);
|
||||
}
|
||||
|
||||
// Walk every instruction of every function in `mod` and insert copies:
//  - each operand slot flagged by is_shmem_op gets a copy_to_shared;
//  - for non-phi instructions, each remaining operand that is_shmem_res
//    reports as a shared result gets a copy_from_shared.
void cts::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();
  for(ir::function* fn: mod.get_function_list()){
    for(ir::basic_block* block: fn->blocks())
    for(ir::instruction* inst: block->get_inst_list()){
      // operand count is sampled once, before any rewriting of `inst`
      const size_t operand_count = inst->get_num_operands();

      // copy to shared operands
      for(size_t k = 0; k < operand_count; k++){
        if(!is_shmem_op(inst, k))
          continue;
        add_copy(inst, inst->get_operand(k), builder, true);
      }

      // copy from shared operands (phi nodes are left alone)
      if(dynamic_cast<ir::phi_node*>(inst))
        continue;
      for(size_t k = 0; k < operand_count; k++){
        if(is_shmem_op(inst, k))
          continue;
        if(!is_shmem_res(inst->get_operand(k)))
          continue;
        add_copy(inst, inst->get_operand(k), builder, false);
      }
    }
  }
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
75
lib/codegen/transform/dce.cc
Normal file
75
lib/codegen/transform/dce.cc
Normal file
@@ -0,0 +1,75 @@
|
||||
#include "triton/codegen/transform/dce.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
// Mark-and-sweep dead code elimination over `mod`.
//
// Roots are returns, branches, stores, atomics and barriers. Everything
// reachable from a root through instruction operands is kept; every other
// instruction is erased from its parent block.
void dce::run(ir::module &mod) {
  std::vector<ir::instruction*> pending;   // LIFO work-list
  std::set<ir::instruction*> live;

  // initialize work-list with the root instructions
  for(ir::function *fn: mod.get_function_list()){
    std::vector<ir::basic_block*> order = ir::cfg::reverse_post_order(fn);
    for(ir::basic_block *block: order)
    for(ir::instruction *inst: block->get_inst_list()){
      bool is_root = false;
      switch(inst->get_id()){
        case ir::INST_RETURN:
        case ir::INST_UNCOND_BRANCH:
        case ir::INST_COND_BRANCH:
        case ir::INST_UNMASKED_STORE:
        case ir::INST_MASKED_STORE:
        case ir::INST_ATOMIC_ADD:
        case ir::INST_ATOMIC_CAS:
        case ir::INST_ATOMIC_EXCH:
        case ir::INST_BARRIER:
          is_root = true;
          break;
        default:
          break;
      }
      if(!is_root)
        continue;
      pending.push_back(inst);
      live.insert(inst);
    }
  }

  // mark -- ignore branches
  while(!pending.empty()){
    ir::instruction* current = pending.back();
    pending.pop_back();
    // mark instruction operands; queue those seen for the first time
    for(ir::value* operand: current->ops()) {
      auto *producer = dynamic_cast<ir::instruction*>(operand);
      if(producer && live.insert(producer).second)
        pending.push_back(producer);
    }
    // TODO: mark last instruction of current's reverse-dominance frontier
  }

  // sweep -- collect unmarked instructions first, erase afterwards
  std::vector<ir::instruction*> dead;
  for(ir::function *fn: mod.get_function_list()){
    std::vector<ir::basic_block*> order = ir::cfg::reverse_post_order(fn);
    for(ir::basic_block *block: order)
    for(ir::instruction *inst: block->get_inst_list()){
      if(live.count(inst) == 0)
        dead.push_back(inst);
    }
  }

  // delete
  for(ir::instruction* inst: dead)
    inst->erase_from_parent();
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
76
lib/codegen/transform/disassociate.cc
Normal file
76
lib/codegen/transform/disassociate.cc
Normal file
@@ -0,0 +1,76 @@
|
||||
#include "triton/codegen/transform/disassociate.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/builder.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
// Record `root` and, recursively, the users among its operands in `result`,
// keyed by their distance (`depth`) from the starting node.
//
// `seen` guards against visiting a node twice. Recursion does not descend
// below make_range and splat instructions (they are recorded, but their
// operands are not followed). Operands that are not ir::user are skipped.
void extract_retile_chain(ir::user *root,
                          std::map<int, std::set<ir::user*>>& result,
                          int depth,
                          std::set<ir::value*>& seen) {
  const bool first_visit = seen.insert(root).second;
  if(!first_visit)
    return;
  result[depth].insert(root);

  // chain terminators
  if(dynamic_cast<ir::make_range*>(root))
    return;
  if(dynamic_cast<ir::splat_inst*>(root))
    return;

  for(ir::value *operand: root->ops()){
    if(ir::user *next = dynamic_cast<ir::user*>(operand))
      extract_retile_chain(next, result, depth + 1, seen);
  }
}
|
||||
|
||||
// Give every reshape instruction a private copy of the chain of users that
// feeds it.
//
// First, for each reshape whose operand 0 is an ir::user, the chain is
// collected with extract_retile_chain (depth 0 is the reshape itself).
// Then, depth by depth starting at 1, every chain member is cloned right
// before the original, and the consumers one level up (the reshape at
// depth 1, the clones of the previous depth afterwards) are rewired to the
// clone.
void disassociate::run(ir::module &mod) {
  ir::builder &bld = mod.get_builder();

  // reshape -> (depth -> chain members)
  std::map<ir::user*, std::map<int, std::set<ir::user*>>> clone_info;
  ir::for_each_instruction(mod, [&](ir::instruction *inst){
    if(!dynamic_cast<ir::reshape_inst*>(inst))
      return;
    std::map<int, std::set<ir::user*>> chains;
    std::set<ir::value*> seen;
    if(!dynamic_cast<ir::user*>(inst->get_operand(0)))
      return;
    extract_retile_chain(inst, chains, 0, seen);
    if(!chains.empty())
      clone_info[inst] = chains;
  });

  for(const auto& entry: clone_info){
    ir::user *reshape = entry.first;
    const auto& levels = entry.second;
    std::map<ir::instruction*, ir::instruction*> clone_map;
    for(int depth = 1; levels.find(depth) != levels.end(); depth += 1){
      // clone all users at this depth
      const auto& members = levels.at(depth);
      for(ir::user* member: members){
        ir::instruction *original = static_cast<ir::instruction*>(member);
        ir::instruction *cloned = original->clone();
        bld.set_insert_point(original);
        bld.insert(cloned);
        clone_map[original] = cloned;
        // replace operands of parents
        if(depth > 1){
          for(ir::user* parent: levels.at(depth - 1))
            clone_map.at(static_cast<ir::instruction*>(parent))->replace_uses_of_with(original, cloned);
        }
        else {
          reshape->replace_uses_of_with(original, cloned);
        }
      }
    }
  }
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
168
lib/codegen/transform/membar.cc
Normal file
168
lib/codegen/transform/membar.cc
Normal file
@@ -0,0 +1,168 @@
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
bool membar::intersect(const interval_vec_t &X, interval_t x) {
|
||||
return std::any_of(X.begin(), X.end(), [&](const interval_t &y){
|
||||
bool left_intersect = y.first <= x.first && x.first < y.second;
|
||||
bool right_intersect = y.first <= x.second && x.second < y.second;
|
||||
return left_intersect || right_intersect;
|
||||
});
|
||||
}
|
||||
|
||||
bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) {
|
||||
return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){
|
||||
return intersect(X, y);
|
||||
});
|
||||
}
|
||||
|
||||
// Append to `res` the interval [offset, offset + size) occupied by `v`,
// provided `v` is a tile-typed instruction whose layout is a shared layout
// that the allocation has an offset for. In every other case `res` is left
// untouched.
void membar::add_reference(ir::value *v, interval_vec_t &res){
  auto *inst = dynamic_cast<ir::instruction*>(v);
  if(!inst || !inst->get_type()->is_tile_ty())
    return;
  analysis::shared_layout* layout = layouts_->get(v)->to_shared();
  if(!layout)
    return;
  if(!alloc_->has_offset(layout))
    return;
  unsigned offset = alloc_->offset(layout);
  res.push_back(interval_t(offset, offset + layout->get_size()));
}
|
||||
|
||||
// Collect into `res` the intervals referenced by the operands of `i`
// (one add_reference call per operand).
void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){
  for(ir::value *operand: i->ops()){
    add_reference(operand, res);
  }
}
|
||||
|
||||
// Collect into `res` the interval referenced by the result of `i`.
// Phi nodes and trans instructions contribute nothing.
void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
  if(dynamic_cast<ir::phi_node*>(i))
    return;
  if(dynamic_cast<ir::trans_inst*>(i))
    return;
  add_reference(i, res);
}
|
||||
|
||||
// Insert a barrier for the hazard recorded at `instr`.
//
// For an ordinary instruction the barrier goes right before it. For a phi
// node, a barrier is placed before the last instruction of the block that
// defines each distinct incoming value; every incoming value is asserted
// to be an instruction.
void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) {
  auto *phi = dynamic_cast<ir::phi_node*>(instr);
  if(!phi){
    builder.set_insert_point(instr);
    builder.create_barrier();
    return;
  }
  // de-duplicate on the incoming value, not on its block
  std::set<ir::value*> handled;
  for(unsigned n = 0; n < phi->get_num_incoming(); n++){
    ir::instruction *source = dynamic_cast<ir::instruction*>(phi->get_incoming_value(n));
    assert(source);
    if(!handled.insert(source).second)
      continue;
    ir::basic_block *block = source->get_parent();
    builder.set_insert_point(block->get_inst_list().back());
    builder.create_barrier();
  }
}
|
||||
|
||||
// Concatenate all interval vectors of `intervals`, in order, into a single
// vector. Duplicates are kept.
membar::interval_vec_t membar::join(const std::vector<interval_vec_t>& intervals) {
  membar::interval_vec_t merged;
  for(const interval_vec_t& group: intervals)
    merged.insert(merged.end(), group.begin(), group.end());
  return merged;
}
|
||||
|
||||
// Transfer function of the barrier-placement analysis for one block.
//
// Starting from the incoming `written_to` / `read_from` interval sets, each
// instruction is inspected in order. When its reads overlap the pending
// writes, or its writes overlap the pending reads, the instruction is added
// to `insert_loc` and both pending sets are cleared. Instructions listed in
// `safe_war` never raise a hazard. In all cases the instruction's own
// reads and writes are then appended to the pending sets.
//
// Returns the pending (written, read) sets at the end of the block.
std::pair<membar::interval_vec_t,
          membar::interval_vec_t> membar::transfer(ir::basic_block *block,
                                            const interval_vec_t &written_to,
                                            const interval_vec_t &read_from,
                                            std::set<ir::instruction*>& insert_loc,
                                            std::set<ir::value*>& safe_war) {
  // iterate over a copy of the block's instruction list
  ir::basic_block::inst_list_t snapshot = block->get_inst_list();
  interval_vec_t pending_writes = written_to;
  interval_vec_t pending_reads = read_from;

  for(ir::instruction *inst: snapshot){
    interval_vec_t reads, writes;
    get_read_intervals(inst, reads);
    get_written_intervals(inst, writes);
    // double buffering: exempt instructions never count as a hazard
    const bool exempt = safe_war.find(inst) != safe_war.end();
    const bool read_after_write = !exempt && intersect(pending_writes, reads);
    const bool write_after_read = !exempt && intersect(pending_reads, writes);
    // record hazards
    if(read_after_write || write_after_read) {
      insert_loc.insert(inst);
      pending_writes.clear();
      pending_reads.clear();
    }
    pending_writes.insert(pending_writes.end(), writes.begin(), writes.end());
    pending_reads.insert(pending_reads.end(), reads.begin(), reads.end());
  }
  return std::make_pair(pending_writes, pending_reads);
}
|
||||
|
||||
// Insert barriers in every function of `mod`.
//
// Values of double-buffered shared layouts, other than the layout's phi,
// are first collected in `safe_war`: per the original note they can be read
// from and written to without needing synchronization. Then, per function,
// the transfer function is applied to all blocks in reverse post-order
// until a sweep adds no new barrier location, and a barrier is inserted at
// every recorded location.
void membar::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();

  // values exempt from hazard detection (double buffering)
  std::set<ir::value*> safe_war;
  for(const auto& entry: layouts_->get_all()){
    analysis::shared_layout* layout = entry.second->to_shared();
    if(!layout)
      continue;
    if(!layout->get_double_buffer())
      continue;
    for(ir::value *v: layout->get_values()){
      if(v == layout->get_double_buffer()->phi)
        continue;
      safe_war.insert(v);
    }
  }

  for(ir::function *fn: mod.get_function_list()){
    std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
    std::map<ir::basic_block*, interval_vec_t> written_to;
    std::map<ir::basic_block*, interval_vec_t> read_from;
    std::set<ir::instruction*> insert_locs;
    size_t previous_count = 0;

    // iterate to a fixed point on the number of barrier locations
    while(true){
      for(ir::basic_block *block: rpo){
        // gather the state at the exit of every predecessor
        std::vector<interval_vec_t> pred_written_to;
        std::vector<interval_vec_t> pred_read_from;
        for(ir::basic_block* pred: block->get_predecessors()){
          pred_written_to.push_back(written_to[pred]);
          pred_read_from.push_back(read_from[pred]);
        }
        // apply transfer function
        auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war);
        written_to[block] = result.first;
        read_from[block] = result.second;
      }
      const size_t current_count = insert_locs.size();
      if(current_count == previous_count)
        break;
      previous_count = current_count;
    }

    for(ir::instruction* loc: insert_locs)
      insert_barrier(loc, builder);
  }
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
191
lib/codegen/transform/peephole.cc
Normal file
191
lib/codegen/transform/peephole.cc
Normal file
@@ -0,0 +1,191 @@
|
||||
#include <algorithm>
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
// Build the transposed counterpart of `value` with permutation `perm`.
//
//  - For a phi node: every incoming value is rewritten recursively and a
//    new phi over the rewritten values is created right before the old one.
//  - For any other instruction: a trans of it is created right after it.
//  - For anything else: nullptr.
//
// Returns nullptr when the rewrite is not possible. This now includes a
// phi with no incoming value and a phi with an incoming value that cannot
// be rewritten; those cases previously dereferenced a null pointer through
// `incs[0]->get_type()` or handed nullptr to add_incoming. Transposes
// already created for earlier incoming values are left in place, unused.
ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder,
                                 const std::vector<int>& perm) {
  if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
    // transpose operands; give up as soon as one cannot be transposed
    std::vector<ir::value*> incs;
    for(unsigned n = 0; n < phi->get_num_incoming(); n++){
      ir::value *inc = rewrite_trans_phi_impl(phi->get_incoming_value(n), builder, perm);
      if(!inc)
        return nullptr;
      incs.push_back(inc);
    }
    if(incs.empty())
      return nullptr;
    // create phi for transposed values
    builder.set_insert_point(phi);
    ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size());
    for(unsigned n = 0; n < phi->get_num_incoming(); n++)
      result->add_incoming(incs[n], phi->get_incoming_block(n));
    return result;
  }
  else if(auto i = dynamic_cast<ir::instruction*>(value)){
    // insert the transposition immediately after the instruction
    ir::basic_block* block = i->get_parent();
    auto it = std::find(block->begin(), block->end(), i);
    ++it;
    builder.set_insert_point(it);
    ir::instruction *trans = static_cast<ir::instruction*>(builder.create_trans(i, perm));
    trans->set_operand(0, i);
    return trans;
  }
  return nullptr;
}
|
||||
|
||||
// trans(phi) -> phi(trans(), trans()...)
//
// Applies when `value` is a trans with at most one user and at most one
// operand, and that operand is a phi node. The replacement phi is built by
// rewrite_trans_phi_impl and takes over all uses of the trans. Returns
// true when the rewrite was performed.
bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
  ir::trans_inst *trans = dynamic_cast<ir::trans_inst*>(value);
  if(!trans)
    return false;
  auto users = trans->get_users();
  auto operands = trans->ops();
  if(users.size() > 1)
    return false;
  if(operands.size() > 1)
    return false;
  // only a transposed phi qualifies
  ir::phi_node *phi = dynamic_cast<ir::phi_node*>(*operands.begin());
  if(!phi)
    return false;
  ir::value* replacement = rewrite_trans_phi_impl(phi, builder, trans->get_perm());
  if(!replacement)
    return false;
  trans->replace_all_uses_with(replacement);
  return true;
}
|
||||
|
||||
// dot(a, b, 0) + c -> dot(a, b, c)
//
// Applies when `value` is a floating-point add with a dot on either side
// (the left one is preferred when both are dots) and that dot's
// accumulator is a splat of the floating-point constant 0. A new dot with
// the other addend as accumulator is inserted before the add and takes
// over all of its uses.
//
// Returns true when the rewrite was performed, false otherwise. The
// function previously fell off its end without a return value whenever
// `value` was not an FAdd, which is undefined behaviour.
bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
  auto add = dynamic_cast<ir::binary_operator*>(value);
  if(!add || add->get_op() != ir::binary_op_t::FAdd)
    return false;

  ir::value *lhs = add->get_operand(0);
  ir::value *rhs = add->get_operand(1);
  ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs);
  ir::dot_inst *rhs_dot = dynamic_cast<ir::dot_inst*>(rhs);
  if(!lhs_dot && !rhs_dot)
    return false;
  ir::dot_inst *dot = lhs_dot ? lhs_dot : rhs_dot;
  ir::value *other = (dot == lhs) ? rhs : lhs;

  // the accumulator must be splat(0.0)
  ir::value *acc = dot->get_operand(2);
  ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc);
  ir::constant_fp *_0 = nullptr;
  if(splat)
    _0 = dynamic_cast<ir::constant_fp*>(splat->get_operand(0));
  if(!(_0 && _0->get_value() == 0.0))
    return false;

  // fold the addend into the accumulator slot
  ir::value *a = dot->get_operand(0);
  ir::value *b = dot->get_operand(1);
  builder.set_insert_point(add);
  ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->get_name()));
  add->replace_all_uses_with(new_dot);
  return true;
}
|
||||
|
||||
// Replace a reduction along an axis of extent 1 by a reshape of its
// argument to the reduction's result shape. Returns true when `value` was
// such a reduction and has been replaced.
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
  ir::reduce_inst *red = dynamic_cast<ir::reduce_inst*>(value);
  if(!red)
    return false;
  ir::value *input = red->get_operand(0);
  auto input_shapes = input->get_type()->get_tile_shapes();
  // only reductions over a unit-sized axis can be expressed as a reshape
  if(!(input_shapes[red->get_axis()] == 1))
    return false;
  builder.set_insert_point(red);
  ir::value* reshaped = builder.create_reshape(input, red->get_type()->get_tile_shapes());
  red->replace_all_uses_with(reshaped);
  return true;
}
|
||||
|
||||
// x * splat(1) -> x   and   splat(1) * x -> x
//
// Applies when `value` is an integer multiplication and one operand is a
// splat of the integer constant 1; all uses of the multiplication are
// redirected to the other operand. The left operand is tested first.
// Returns true when the rewrite was performed.
//
// The constant's value is now actually compared against 1. Before, any
// splatted integer constant was accepted, so e.g. `x * splat(2)` was
// silently replaced by `x`.
bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
  auto binop = dynamic_cast<ir::binary_operator*>(value);
  if(!binop || binop->get_op() != ir::binary_op_t::Mul)
    return false;

  // true when `v` is a splat whose scalar operand is the integer constant 1
  auto is_splat_of_one = [](ir::value *v) {
    auto *splat = dynamic_cast<ir::splat_inst*>(v);
    if(!splat)
      return false;
    auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
    return cst && cst->get_value() == 1;
  };

  ir::value *lhs = binop->get_operand(0);
  ir::value *rhs = binop->get_operand(1);
  if(is_splat_of_one(lhs)){
    binop->replace_all_uses_with(rhs);
    return true;
  }
  if(is_splat_of_one(rhs)){
    binop->replace_all_uses_with(lhs);
    return true;
  }
  return false;
}
|
||||
|
||||
|
||||
// gep(gep(ptr, 0 - off), off) -> ptr
//
// Applies when `value` is a getelementptr whose pointer operand is another
// getelementptr, the inner one's first index is the subtraction `0 - off`
// with a constant-int zero on the left, and `off` is the very value used
// as the outer one's first index. All uses of the outer gep are then
// redirected to the inner gep's pointer operand. Returns true on rewrite.
bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) {
  auto outer = dynamic_cast<ir::getelementptr_inst*>(value);
  if(!outer)
    return false;
  auto inner = dynamic_cast<ir::getelementptr_inst*>(outer->get_pointer_operand());
  if(!inner)
    return false;
  auto inner_idx = *inner->idx_begin();
  auto negation = dynamic_cast<ir::binary_operator*>(inner_idx);
  if(!negation)
    return false;

  // the inner offset must be `0 - <outer offset>`
  if(!(negation->get_op() == ir::binary_op_t::Sub))
    return false;
  auto *zero = dynamic_cast<ir::constant_int*>(negation->get_operand(0));
  if(!zero || !(zero->get_value() == 0))
    return false;
  if(!(negation->get_operand(1) == *outer->idx_begin()))
    return false;

  outer->replace_all_uses_with(inner->get_pointer_operand());
  return true;
}
|
||||
|
||||
|
||||
// Apply the peephole rewrites to every instruction of `mod`.
//
// Two rounds are run, each repeated until a full sweep rewrites nothing
// new: first rewrite_dot alone, then the remaining rewrites (mult,
// trans-of-phi, unit reduction, gep offset cancellation), of which at most
// one is applied per instruction. An instruction that has been rewritten
// once in a round is not visited again in that round.
void peephole::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();
  // instructions already rewritten in the current round
  std::set<ir::value*> rewritten;
  size_t count_before;

  // rewrite dots first
  do{
    count_before = rewritten.size();
    for(ir::function *fn: mod.get_function_list())
    for(ir::basic_block *block: fn->blocks())
    for(ir::instruction* inst: block->get_inst_list()){
      if(rewritten.count(inst) != 0)
        continue;
      if(rewrite_dot(inst, builder))
        rewritten.insert(inst);
    }
  }while(rewritten.size() != count_before);

  // rewrite other ops
  rewritten.clear();
  do{
    count_before = rewritten.size();
    for(ir::function *fn: mod.get_function_list())
    for(ir::basic_block *block: fn->blocks())
    for(ir::instruction* inst: block->get_inst_list()){
      if(rewritten.count(inst) != 0)
        continue;
      // short-circuit: stop at the first rewrite that fires
      const bool changed = rewrite_mult(inst, builder)
                        || rewrite_trans_phi(inst, builder)
                        || rewrite_unit_red(inst, builder)
                        || rewrite_gep_ptr_min_off_plus_off(inst, builder);
      if(changed)
        rewritten.insert(inst);
    }
  }while(rewritten.size() != count_before);
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
264
lib/codegen/transform/reassociate.cc
Normal file
264
lib/codegen/transform/reassociate.cc
Normal file
@@ -0,0 +1,264 @@
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
|
||||
// Return `x` viewed as an instruction when it is a binary operator whose
// opcode is Add, and nullptr otherwise.
inline ir::instruction* reassociate::is_bin_add(ir::value *x) {
  ir::binary_operator *bin = dynamic_cast<ir::binary_operator*>(x);
  if(!bin)
    return nullptr;
  if(!(bin->get_op() == ir::binary_op_t::Add))
    return nullptr;
  return static_cast<ir::instruction*>(x);
}
|
||||
|
||||
// Tell whether `x` is a constant, possibly wrapped in any number of
// retile instructions (each one is looked through via its operand 0).
inline bool is_cst(ir::value *x) {
  for(;;){
    if(dynamic_cast<ir::constant*>(x))
      return true;
    auto *retile = dynamic_cast<ir::retile_inst*>(x);
    if(!retile)
      return false;
    // peel one retiling layer and test what it wraps
    x = retile->get_operand(0);
  }
}
|
||||
|
||||
// Reassociate the additions feeding `old_value` so that a constant term,
// when there is one, ends up as a direct operand of the outermost add.
//
// On return `cst` / `noncst` hold the constant and non-constant operands of
// the resulting add; when the result is not an add with a constant operand
// they keep whatever the steps below left in them (initially nullptr and
// `old_value`). When a new value was built, all uses of `old_value` are
// redirected to it. Returns the new value, or `old_value` if unchanged.
ir::value *reassociate::reassociate_idx(ir::value *old_value,
                                        ir::builder &builder,
                                        ir::value *&noncst,
                                        ir::value *&cst){
  // value doesn't change by default
  ir::value* new_value = old_value;
  cst = nullptr;
  noncst = old_value;

  // handle retiling
  if(ir::instruction* retile = dynamic_cast<ir::retile_inst*>(old_value)){
    auto shapes = retile->get_type()->get_tile_shapes();
    ir::value *inner_old = retile->get_operand(0);
    ir::value *inner_new = reassociate_idx(inner_old, builder, noncst, cst);
    // retile(x + y) = retile(x) + retile(y)
    ir::instruction* inner_add = is_bin_add(inner_new);
    if(inner_add && cst){
      ir::value *add_lhs = inner_add->get_operand(0);
      ir::value *add_rhs = inner_add->get_operand(1);
      ir::value *tiled_lhs = nullptr;
      ir::value *tiled_rhs = nullptr;
      // the three kinds are tested independently, as in the original
      if(dynamic_cast<ir::reshape_inst*>(retile)){
        builder.set_insert_point(retile);
        tiled_lhs = builder.create_reshape(add_lhs, shapes);
        tiled_rhs = builder.create_reshape(add_rhs, shapes);
        new_value = builder.create_add(tiled_lhs, tiled_rhs, retile->get_name());
      }
      if(dynamic_cast<ir::broadcast_inst*>(retile)){
        builder.set_insert_point(retile);
        tiled_lhs = builder.create_broadcast(add_lhs, shapes);
        tiled_rhs = builder.create_broadcast(add_rhs, shapes);
        new_value = builder.create_add(tiled_lhs, tiled_rhs, retile->get_name());
      }
      if(dynamic_cast<ir::splat_inst*>(retile)){
        builder.set_insert_point(retile);
        tiled_lhs = builder.create_splat(add_lhs, shapes);
        tiled_rhs = builder.create_splat(add_rhs, shapes);
        new_value = builder.create_add(tiled_lhs, tiled_rhs, retile->get_name());
      }
    }
  }

  // handle binary addition
  if(ir::instruction* add = is_bin_add(old_value)){
    builder.set_insert_point(add);
    std::string name = add->get_name();
    ir::value *lhs = reassociate_idx(add->get_operand(0), builder, noncst, cst);
    ir::value *rhs = reassociate_idx(add->get_operand(1), builder, noncst, cst);
    // the recursive calls may have moved the insert point
    builder.set_insert_point(add);
    // (x + y) + z
    if(ir::instruction* lhs_add = is_bin_add(lhs)){
      ir::value *llhs = lhs_add->get_operand(0);
      ir::value *rlhs = lhs_add->get_operand(1);
      // (cst + x) + y -> cst + (x + y)
      if(is_cst(llhs)){
        ir::value *rest = builder.create_add(rlhs, rhs);
        new_value = builder.create_add(llhs, rest, name);
      }
      // (x + cst) + y -> cst + (x + y)
      if(is_cst(rlhs)){
        ir::value *rest = builder.create_add(llhs, rhs);
        new_value = builder.create_add(rlhs, rest, name);
      }
    }
    // x + (y + z)
    // NOTE(review): the two calls below forward `cst` as a fourth argument,
    // exactly as the original did -- confirm against create_add's signature.
    if(ir::instruction* rhs_add = is_bin_add(rhs)){
      ir::value *lrhs = rhs_add->get_operand(0);
      ir::value *rrhs = rhs_add->get_operand(1);
      // x + (cst + y) -> cst + (x + y)
      if(is_cst(lrhs)){
        ir::value *rest = builder.create_add(rrhs, lhs);
        new_value = builder.create_add(lrhs, rest, name, cst);
      }
      // x + (y + cst) -> cst + (x + y)
      if(is_cst(rrhs)){
        ir::value *rest = builder.create_add(lrhs, lhs);
        new_value = builder.create_add(rrhs, rest, name, cst);
      }
    }
  }

  // extract constant and non-constant
  if(ir::instruction *result_add = is_bin_add(new_value)){
    ir::value *first = result_add->get_operand(0);
    ir::value *second = result_add->get_operand(1);
    if(is_cst(first)){
      cst = first;
      noncst = second;
    }
    if(is_cst(second)){
      cst = second;
      noncst = first;
    }
  }

  // clean-up if some re-ordering happened
  if(old_value != new_value)
    old_value->replace_all_uses_with(new_value);
  return new_value;
}
|
||||
|
||||
/* run */
|
||||
/**
 * Runs the reassociation pass on every function of `mod`.
 *
 * Phase 1: every `make_range` that appears as an instruction operand is
 * replaced by `make_range_dyn + make_range_sta`, inserted at the first
 * non-phi position of the first block in reverse post-order.
 *
 * Phase 2: instructions are visited in reverse post-order, repeatedly, until
 * a sweep adds nothing to the set of replaced values. During a sweep:
 *   - a broadcast whose operand has a recorded (dynamic pointer, static GEP)
 *     pair gets a broadcast version of that pair recorded for itself;
 *   - a GEP whose first index yields a static part from `reassociate_idx` is
 *     split into a dynamic GEP followed by a static GEP;
 *   - a GEP whose pointer operand has a recorded pair is rebuilt so that the
 *     recorded static index is applied last;
 *   - a GEP whose pointer operand is a two-input phi, with the GEP itself as
 *     one input and a value with a recorded pair as the other, has the static
 *     index moved behind a new phi over the dynamic pointers.
 *
 * @param mod module whose functions are rewritten in place; its builder is
 *            used for all insertions and its insertion point is left modified.
 */
void reassociate::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();

  // constant_range -> nv_dynamic_program_idx + nv_static_program_idx
  for(ir::function *fn: mod.get_function_list()){
    std::vector<ir::make_range*> ranges;
    // A range used by several instructions must be rewritten only once;
    // otherwise each additional use would emit another dead dyn/add pair.
    std::set<ir::make_range*> collected;
    std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
    for(ir::basic_block *block: rpo){
      // iterate through instructions
      for(ir::instruction *i: block->get_inst_list()){
        for(ir::value *op: i->ops()){
          auto *range = dynamic_cast<ir::make_range*>(op);
          if(range && collected.insert(range).second)
            ranges.push_back(range);
        }
      }
    }
    // no block means nothing to rewrite and nothing to anchor the builder on
    if(rpo.empty())
      continue;

    builder.set_insert_point(rpo.front()->get_first_non_phi());
    for(ir::make_range *old_range: ranges){
      ir::value *dyn_range = builder.insert(ir::make_range_dyn::create(old_range->get_type()));
      ir::value *static_range = ir::make_range_sta::get(old_range);
      ir::value *new_range = builder.create_add(dyn_range, static_range);
      old_range->replace_all_uses_with(new_range);
    }
  }

  // reassociate
  // infos: value -> (pointer without the static offset, GEP that adds it)
  std::map<ir::value*, cst_info> infos;
  // values already rewritten; its growth drives the fixed-point loop below
  std::set<ir::value*> replaced;
  size_t n_replaced;
  do{
    n_replaced = replaced.size();
    for(ir::function *fn: mod.get_function_list()){
      std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
      // iterate through blocks
      for(ir::basic_block *block: rpo){
        // iterate through instructions
        for(ir::instruction *i: block->get_inst_list()){
          // retiling
          if(ir::retile_inst *rt = dynamic_cast<ir::retile_inst*>(i)){
            ir::value *op = rt->get_operand(0);
            auto it_op = infos.find(op);
            if(it_op != infos.end()){
              builder.set_insert_point(rt);
              ir::getelementptr_inst *op_sta = it_op->second.sta_ptr;
              ir::value *op_dyn = it_op->second.dyn_ptr;
              ir::value *cst = *op_sta->idx_begin();
              // only broadcasts are propagated through
              if(dynamic_cast<ir::broadcast_inst*>(rt)){
                auto shapes = rt->get_type()->get_tile_shapes();
                ir::value *ndyn = builder.create_broadcast(op_dyn, shapes);
                ir::value *broadcast = builder.create_broadcast(cst, shapes);
                auto *nsta = static_cast<ir::getelementptr_inst*>(builder.create_gep(ndyn, {broadcast}));
                infos[rt] = cst_info{ndyn, nsta};
              }
            }
          }

          // getelementptr instruction
          if(ir::getelementptr_inst *pz = dynamic_cast<ir::getelementptr_inst*>(i)){
            if(replaced.find(pz) != replaced.end())
              continue;
            // unpack GEP instruction
            ir::value *py = pz->get_pointer_operand();
            ir::value *offset = *pz->idx_begin();

            // reassociate index
            // `sta` stays null unless reassociate_idx reports a constant part
            ir::value *sta = nullptr;
            ir::value *dyn = offset;
            reassociate_idx(offset, builder, dyn, sta);
            if(sta){
              builder.set_insert_point(pz);
              ir::value *dyn_ptr = builder.create_gep(py, {dyn});
              ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta});
              pz->replace_all_uses_with(sta_ptr);
              infos[sta_ptr] = cst_info{dyn_ptr, static_cast<ir::getelementptr_inst*>(sta_ptr)};
              replaced.insert(pz);
            }

            // reassociate pointer argument
            auto it_py = infos.find(py);
            if(it_py != infos.end()){
              builder.set_insert_point(pz);
              ir::getelementptr_inst *py_sta = it_py->second.sta_ptr;
              ir::value *py_dyn = it_py->second.dyn_ptr;
              ir::value *cst = *py_sta->idx_begin();
              ir::value *off = *pz->idx_begin();
              // apply pz's own index first, then the recorded static index
              ir::value *pz_dyn = builder.create_gep(py_dyn, {off});
              ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}, pz->get_name());
              pz->replace_all_uses_with(pz_sta);
              infos[pz_sta] = cst_info{pz_dyn, static_cast<ir::getelementptr_inst*>(pz_sta)};
              replaced.insert(pz);
            }

            // reassociate phi-node pointer
            if(ir::phi_node *phi = dynamic_cast<ir::phi_node*>(py)){
              // only optimize the case where py = phi pa, pz for now
              const std::vector<ir::value*> &ops = phi->ops();
              if(ops.size() != 2)
                continue;
              if(ops[0] != pz && ops[1] != pz)
                continue;
              // grab incoming
              size_t idx_z = (ops[0] == pz) ? 0 : 1;
              size_t idx_a = (ops[0] == pz) ? 1 : 0;
              // check if pa is known to have constant offset
              ir::value *vpa = phi->get_incoming_value(idx_a);
              auto it_a = infos.find(vpa);
              if(it_a == infos.end())
                continue;
              // unpack dynamically/statically offset pointer
              ir::value *pa_dyn = it_a->second.dyn_ptr;
              ir::getelementptr_inst *pa_sta = it_a->second.sta_ptr;
              ir::value *vpz = phi->get_incoming_value(idx_z);
              // extract offset
              ir::value *off = *pa_sta->idx_begin();
              builder.set_insert_point(phi);
              ir::phi_node *phi_dyn = builder.create_phi(phi->get_type(), 2);
              phi_dyn->add_incoming(pa_dyn, phi->get_incoming_block(idx_a));
              builder.set_insert_point(phi->get_parent()->get_first_non_phi());
              // re-add the offset
              ir::value *phi_sta = builder.create_gep(phi_dyn, {off}, phi->get_name() + "_sta");
              phi->replace_all_uses_with(phi_sta);
              // remove offset from pz
              if(auto *x = dynamic_cast<ir::instruction*>(vpz)){
                // Insert right after x. Bind the list by reference (no copy)
                // and only move the insertion point when x really has a
                // successor; otherwise the current insertion point is kept
                // instead of dereferencing end().
                const auto &insts = x->get_parent()->get_inst_list();
                auto next = std::find(insts.begin(), insts.end(), x);
                if(next != insts.end())
                  ++next;
                if(next != insts.end())
                  builder.set_insert_point(*next);
              }
              ir::value *zero = builder.get_int32(0);
              if(off->get_type()->is_tile_ty())
                zero = builder.create_splat(zero, off->get_type()->get_tile_shapes());
              ir::value *neg_off = builder.create_sub(zero, off);
              ir::value *pz_dyn = builder.create_gep(vpz, {neg_off});
              phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z));
              infos[phi_sta] = cst_info{phi_dyn, static_cast<ir::getelementptr_inst*>(phi_sta)};
              replaced.insert(phi);
            }
          }
        }
      }
    }
  }while(replaced.size() != n_replaced);
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user