progress on re-association

This commit is contained in:
Philippe Tillet
2019-07-23 17:21:24 -07:00
parent 38b3771c26
commit 397d76156b
11 changed files with 167 additions and 26 deletions

View File

@@ -32,6 +32,7 @@ public:
void run(ir::module &mod);
unsigned get_starting_multiple(ir::value* v) const;
unsigned get_max_contiguous(ir::value* v) const;
void copy(ir::value *dst, ir::value *src);
private:
std::map<ir::value*, cst_info> is_constant_;

View File

@@ -18,14 +18,19 @@ class instruction;
namespace codegen{
class tune;
class reassociate {
private:
ir::instruction* is_bin_add(ir::value *x);
ir::value *reorder_op(ir::value *value, ir::builder &builder, std::vector<ir::instruction*>& to_delete, ir::value *&noncst, ir::value *&cst);
public:
reassociate();
reassociate(tune *params);
void run(ir::module& module);
private:
tune* params_;
};
}

View File

@@ -44,7 +44,7 @@ public:
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
fragment_t get_fragment(ir::value *value, unsigned ax) { return fragments_.at({value, ax}); }
unsigned get_param_group(ir::value *value, unsigned ax);
void copy(ir::value *dst, ir::value *src) { params_[dst] = params_[src]; groups_[dst] = groups_[src]; }
void copy(ir::value *dst, ir::value *src);
bool check_constraints(std::map<ir::value *, std::vector<std::string>> &errors);
void run(ir::module &mod);
void init(ir::module &mod);

View File

@@ -2,6 +2,7 @@
#define TDL_INCLUDE_IR_INSTRUCTIONS_H
#include <vector>
#include "triton/ir/constant.h"
#include "triton/ir/value.h"
#include "triton/ir/type.h"
#include "triton/ir/metadata.h"
@@ -651,6 +652,31 @@ public:
static vectorize_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
};
// On NVIDIA, implementation is such that
// constant_range = nv_dynamic_range_idx + nv_static_range_idx
// so as to enable re-association on nv_static_range_idx which is constant
class nv_dynamic_range_idx_inst: public instruction {
private:
nv_dynamic_range_idx_inst(type *ty, const std::string &name, instruction *next);
std::string repr_impl() const { return "nv_dynamic_range_idx"; }
public:
static nv_dynamic_range_idx_inst* create(type *ty, const std::string &name = "", instruction *next = nullptr);
};
class nv_static_range_idx: public constant {
private:
nv_static_range_idx(constant_range *range);
public:
static nv_static_range_idx *get(constant_range* range);
constant_range* get_range() const;
private:
constant_range *range_;
};
}
}

View File

@@ -66,18 +66,19 @@ public:
optimize_cse(),
optimize_trans(),
alignment_info(),
reassociate(&tune),
target_(target) { }
void target_independent(ir::module &module) {
optimize_dot.run(module);
optimize_trans.run(module);
ir::print(module, std::cout);
reassociate_.run(module);
// ir::print(module, std::cout);
}
void target_dependent(ir::module &module) {
ir::print(module, std::cout);
alignment_info.run(module);
reassociate.run(module);
ir::print(module, std::cout);
if(target_->is_gpu()){
shmem_info.run(module);
shmem_liveness.run(module);
@@ -98,7 +99,7 @@ public:
codegen::optimize_cse optimize_cse;
codegen::optimize_trans optimize_trans;
codegen::alignment_info alignment_info;
codegen::reassociate reassociate_;
codegen::reassociate reassociate;
codegen::target* target_;
};

View File

@@ -228,6 +228,12 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
if(auto *x = dynamic_cast<ir::constant_range*>(v)){
return cache(x->get_first()->get_value());
}
if(auto *x = dynamic_cast<ir::nv_dynamic_range_idx_inst*>(v)){
return cache(128);
}
if(auto *x = dynamic_cast<ir::nv_static_range_idx*>(v)){
return cache(x->get_range()->get_first()->get_value());
}
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
int lhs = populate_starting_multiple(x->get_operand(0));
int rhs = populate_starting_multiple(x->get_operand(1));
@@ -280,6 +286,12 @@ unsigned alignment_info::get_max_contiguous(ir::value* v) const {
return max_contiguous_.at(v);
}
void alignment_info::copy(ir::value *dst, ir::value *src) {
starting_multiple_[dst] = starting_multiple_[src];
max_contiguous_[dst] = max_contiguous_[src];
is_constant_[dst] = is_constant_[src];
}
///TODO: This doesn't seem to work in DOT-NN, DOT-TT, DOT-TN
void alignment_info::run(ir::module &mod) {
// populate constant
@@ -301,7 +313,7 @@ void alignment_info::run(ir::module &mod) {
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
populate_max_contiguous(i);
// std::cout << i->get_name() << " " << is_constant_.at(i).num_cst << " " << max_contiguous_.at(i) << " " << starting_multiple_.at(i) << std::endl;
std::cout << i->get_name() << " " << is_constant_.at(i).num_cst << " " << max_contiguous_.at(i) << " " << starting_multiple_.at(i) << std::endl;
}
}

View File

@@ -5,6 +5,7 @@
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/cfg.h"
#include "triton/codegen/tune.h"
namespace triton {
namespace codegen{
@@ -68,25 +69,32 @@ inline ir::value *reassociate::reorder_op(ir::value *old_value,
if(cst){
ir::value *old_lhs = bin_add->get_operand(0);
ir::value *old_rhs = bin_add->get_operand(1);
ir::value *new_lhs = nullptr;
ir::value *new_rhs = nullptr;
if(dynamic_cast<ir::reshape_inst*>(op)){
builder.set_insert_point(op);
ir::value *new_lhs = builder.create_reshape(old_lhs, shapes);
ir::value *new_rhs = builder.create_reshape(old_rhs, shapes);
new_lhs = builder.create_reshape(old_lhs, shapes);
new_rhs = builder.create_reshape(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
}
if(dynamic_cast<ir::broadcast_inst*>(op)){
builder.set_insert_point(op);
ir::value *new_lhs = builder.create_broadcast(old_lhs, shapes);
ir::value *new_rhs = builder.create_broadcast(old_rhs, shapes);
new_lhs = builder.create_broadcast(old_lhs, shapes);
new_rhs = builder.create_broadcast(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
}
if(dynamic_cast<ir::splat_inst*>(op)){
builder.set_insert_point(op);
ir::value *new_lhs = builder.create_splat(old_lhs, shapes);
ir::value *new_rhs = builder.create_splat(old_rhs, shapes);
new_lhs = builder.create_splat(old_lhs, shapes);
new_rhs = builder.create_splat(old_rhs, shapes);
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
}
to_delete.push_back(op);
if(new_value != old_value){
params_->copy(new_value, old_value);
params_->copy(new_lhs, old_value);
params_->copy(new_rhs, old_value);
to_delete.push_back(op);
}
}
}
@@ -107,8 +115,9 @@ inline ir::value *reassociate::reorder_op(ir::value *old_value,
// (x + cst) + y -> cst + (x + y)
if(is_cst(rlhs))
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs), name);
if(new_value != op)
if(new_value != old_value){
to_delete.push_back(bin_lhs);
}
}
// x + (y + z)
if(ir::instruction* bin_rhs = is_bin_add(rhs)){
@@ -123,6 +132,11 @@ inline ir::value *reassociate::reorder_op(ir::value *old_value,
if(new_value != op)
to_delete.push_back(bin_rhs);
}
if(new_value != old_value){
params_->copy(new_value, old_value);
params_->copy(((ir::instruction*)new_value)->get_operand(0), old_value);
params_->copy(((ir::instruction*)new_value)->get_operand(1), old_value);
}
}
// extract constant and non-constant
@@ -149,13 +163,39 @@ inline ir::value *reassociate::reorder_op(ir::value *old_value,
return new_value;
}
reassociate::reassociate() {
}
reassociate::reassociate(tune* params)
: params_(params)
{ }
void reassociate::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
std::vector<ir::instruction*> to_delete;
// constant_range -> nv_dynamic_range_idx + nv_static_range_idx
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::constant_range*> ranges;
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
for(ir::basic_block *block: rpo){
// iterate through instruction
for(ir::instruction *i: block->get_inst_list())
for(ir::value* op: i->ops())
if(auto *range = dynamic_cast<ir::constant_range*>(op))
ranges.push_back(range);
}
builder.set_insert_point(rpo.front()->get_first_non_phi());
for(ir::constant_range* old_range: ranges){
ir::value* dyn_range = builder.insert(ir::nv_dynamic_range_idx_inst::create(old_range->get_type()));
ir::value* static_range = ir::nv_static_range_idx::get(old_range);
ir::value* new_range = builder.create_add(dyn_range, static_range);
old_range->replace_all_uses_with(new_range);
params_->copy(dyn_range, old_range);
params_->copy(static_range, old_range);
params_->copy(new_range, old_range);
}
}
// reassociate
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
bool done = false;

View File

@@ -690,12 +690,22 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
distributed_tile *T = new distributed_tile(ty, shapes, axes, builder, vectorize);
tmap_.insert({v, T});
// constant range
if(dynamic_cast<ir::constant*>(v) && !dynamic_cast<ir::undef_value*>(v)){
if(dynamic_cast<ir::constant_range*>(v)){
T->for_each([&](indices_t idx){
assert(idx.size() == 1);
T->set_value(idx, idx[0]);
});
}
if(dynamic_cast<ir::nv_static_range_idx*>(v)){
T->for_each([&](indices_t idx){
assert(idx.size() == 1);
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
assert(bin_add);
Value *res = bin_add->getOperand(1);
assert(isa<Constant>(res));
T->set_value(idx, res);
});
}
}
}
@@ -835,6 +845,16 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
result->set_value(idx, builder.CreateAdd(bin, offset));
});
}
// nv_dynamic_range_idx_inst
if(dynamic_cast<ir::nv_dynamic_range_idx_inst*>(ins)){
result->for_each([&](indices_t idx){
assert(idx.size() == 1);
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
assert(bin_add);
Value *res = bin_add->getOperand(0);
result->set_value(idx, res);
});
}
// // mask
// else if(dynamic_cast<ir::mask_inst*>(ins)) {
// distributed_tile* pred = (distributed_tile*)tmap_.at(ins->get_operand(0));

View File

@@ -133,7 +133,7 @@ tune::fragment_t tune::get_fragmentation_type(node_t x, graph_t &graph){
}
void tune::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
groups_[x.first][x.second] = group_id;
groups_[x.first].insert({x.second, group_id});
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
std::string suffix = ".d" + std::to_string(x.second);
@@ -145,11 +145,11 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
if(auto mp = dynamic_cast<ir::metaparameter*>(shape))
params_[x.first].insert({"shape" + suffix, mp});
}
if(auto range = dynamic_cast<ir::get_global_range_inst*>(x.first)){
unsigned ax = range->get_axis();
global_range_sizes_[ax] = params_[x.first].at("shape.d0");
num_global_ranges_ = std::max(num_global_ranges_, ax + 1);
}
// if(auto range = dynamic_cast<ir::get_global_range_inst*>(x.first)){
// unsigned ax = range->get_axis();
// global_range_sizes_[ax] = params_[x.first].at("shape.d0");
// num_global_ranges_ = std::max(num_global_ranges_, ax + 1);
// }
if(static_params_.find(x) != static_params_.end()){
for(ir::metaparameter *mp: mps)
mp->set_value(static_params_.at(x));
@@ -190,6 +190,14 @@ unsigned tune::get_param_group(ir::value *value, unsigned ax) {
return result;
}
//TODO: This shouldn't exist!
void tune::copy(ir::value *dst, ir::value *src) {
params_[dst] = params_[src];
groups_[dst] = groups_[src];
fragments_[{dst, 0}] = fragments_[{src, 0}];
}
void tune::run(ir::module &mod) {
ir::context &ctx = mod.get_context();
// Create metaparameters

View File

@@ -59,6 +59,8 @@ constant_int *constant_int::get(type *ty, uint64_t value) {
// constant_range
// FIXME use something like APInt
//"[" + std::to_string(first->get_value()) + " ... " + std::to_string(ty->get_tile_shapes()[0]->get_value()) + "]"
constant_range::constant_range(type *ty, constant_int *first, constant_int *last)
: constant(ty, 0), first_(first), last_(last){ }

View File

@@ -688,22 +688,48 @@ instruction* atomic_add_inst::create(value *ptr, value *val, const std::string &
//===----------------------------------------------------------------------===//
// intrinsic instructions
//===----------------------------------------------------------------------===//
// copy to shared
copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name,
instruction *next) {
return new copy_to_shared_inst(arg->get_type(), arg, name, next);
}
// vectorize
vectorize_inst* vectorize_inst::create(value *arg, const std::string &name, instruction *next) {
return new vectorize_inst(arg->get_type(), arg, name, next);
}
// barrier
barrier_inst::barrier_inst(context &ctx, const std::string &name,
instruction *next)
: instruction(type::get_void_ty(ctx), 0, 0, name, next){ }
: instruction(type::get_void_ty(ctx), 0, 0, name, next) { }
barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
return new barrier_inst(ctx, name, next);
}
// nv_dynamic_range_idx
nv_dynamic_range_idx_inst::nv_dynamic_range_idx_inst(type *ty, const std::string &name, instruction *next)
: instruction(ty, 0, 1, name, next) { }
nv_dynamic_range_idx_inst* nv_dynamic_range_idx_inst::create(type *ty, const std::string &name, instruction *next) {
return new nv_dynamic_range_idx_inst(ty, name, next);
}
// nv_static_range_idx
nv_static_range_idx::nv_static_range_idx(constant_range *range)
: constant(range->get_type(), 0), range_(range) { }
constant_range* nv_static_range_idx::get_range() const
{ return range_; }
nv_static_range_idx* nv_static_range_idx::get(constant_range* range) {
static std::map<constant_range*, nv_static_range_idx*> cache;
if(cache.find(range) == cache.end())
cache.insert({range, new nv_static_range_idx(range)});
return cache.at(range);
}
}
}