progress on re-association
This commit is contained in:
@@ -32,6 +32,7 @@ public:
|
||||
void run(ir::module &mod);
|
||||
unsigned get_starting_multiple(ir::value* v) const;
|
||||
unsigned get_max_contiguous(ir::value* v) const;
|
||||
void copy(ir::value *dst, ir::value *src);
|
||||
|
||||
private:
|
||||
std::map<ir::value*, cst_info> is_constant_;
|
||||
|
@@ -18,14 +18,19 @@ class instruction;
|
||||
|
||||
namespace codegen{
|
||||
|
||||
class tune;
|
||||
|
||||
class reassociate {
|
||||
private:
|
||||
ir::instruction* is_bin_add(ir::value *x);
|
||||
ir::value *reorder_op(ir::value *value, ir::builder &builder, std::vector<ir::instruction*>& to_delete, ir::value *&noncst, ir::value *&cst);
|
||||
|
||||
public:
|
||||
reassociate();
|
||||
reassociate(tune *params);
|
||||
void run(ir::module& module);
|
||||
|
||||
private:
|
||||
tune* params_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -44,7 +44,7 @@ public:
|
||||
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
|
||||
fragment_t get_fragment(ir::value *value, unsigned ax) { return fragments_.at({value, ax}); }
|
||||
unsigned get_param_group(ir::value *value, unsigned ax);
|
||||
void copy(ir::value *dst, ir::value *src) { params_[dst] = params_[src]; groups_[dst] = groups_[src]; }
|
||||
void copy(ir::value *dst, ir::value *src);
|
||||
bool check_constraints(std::map<ir::value *, std::vector<std::string>> &errors);
|
||||
void run(ir::module &mod);
|
||||
void init(ir::module &mod);
|
||||
|
@@ -2,6 +2,7 @@
|
||||
#define TDL_INCLUDE_IR_INSTRUCTIONS_H
|
||||
|
||||
#include <vector>
|
||||
#include "triton/ir/constant.h"
|
||||
#include "triton/ir/value.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/metadata.h"
|
||||
@@ -651,6 +652,31 @@ public:
|
||||
static vectorize_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
// On NVIDIA, implementation is such that
|
||||
// constant_range = nv_dynamic_range_idx + nv_static_range_idx
|
||||
// so as to enable re-association on nv_static_range_idx which is constant
|
||||
class nv_dynamic_range_idx_inst: public instruction {
|
||||
private:
|
||||
nv_dynamic_range_idx_inst(type *ty, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "nv_dynamic_range_idx"; }
|
||||
|
||||
public:
|
||||
static nv_dynamic_range_idx_inst* create(type *ty, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class nv_static_range_idx: public constant {
|
||||
private:
|
||||
nv_static_range_idx(constant_range *range);
|
||||
|
||||
public:
|
||||
static nv_static_range_idx *get(constant_range* range);
|
||||
constant_range* get_range() const;
|
||||
|
||||
private:
|
||||
constant_range *range_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -66,18 +66,19 @@ public:
|
||||
optimize_cse(),
|
||||
optimize_trans(),
|
||||
alignment_info(),
|
||||
reassociate(&tune),
|
||||
target_(target) { }
|
||||
|
||||
void target_independent(ir::module &module) {
|
||||
optimize_dot.run(module);
|
||||
optimize_trans.run(module);
|
||||
ir::print(module, std::cout);
|
||||
reassociate_.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
}
|
||||
|
||||
void target_dependent(ir::module &module) {
|
||||
ir::print(module, std::cout);
|
||||
alignment_info.run(module);
|
||||
reassociate.run(module);
|
||||
ir::print(module, std::cout);
|
||||
if(target_->is_gpu()){
|
||||
shmem_info.run(module);
|
||||
shmem_liveness.run(module);
|
||||
@@ -98,7 +99,7 @@ public:
|
||||
codegen::optimize_cse optimize_cse;
|
||||
codegen::optimize_trans optimize_trans;
|
||||
codegen::alignment_info alignment_info;
|
||||
codegen::reassociate reassociate_;
|
||||
codegen::reassociate reassociate;
|
||||
codegen::target* target_;
|
||||
};
|
||||
|
||||
|
@@ -228,6 +228,12 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
|
||||
if(auto *x = dynamic_cast<ir::constant_range*>(v)){
|
||||
return cache(x->get_first()->get_value());
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::nv_dynamic_range_idx_inst*>(v)){
|
||||
return cache(128);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::nv_static_range_idx*>(v)){
|
||||
return cache(x->get_range()->get_first()->get_value());
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
|
||||
int lhs = populate_starting_multiple(x->get_operand(0));
|
||||
int rhs = populate_starting_multiple(x->get_operand(1));
|
||||
@@ -280,6 +286,12 @@ unsigned alignment_info::get_max_contiguous(ir::value* v) const {
|
||||
return max_contiguous_.at(v);
|
||||
}
|
||||
|
||||
void alignment_info::copy(ir::value *dst, ir::value *src) {
|
||||
starting_multiple_[dst] = starting_multiple_[src];
|
||||
max_contiguous_[dst] = max_contiguous_[src];
|
||||
is_constant_[dst] = is_constant_[src];
|
||||
}
|
||||
|
||||
///TODO: This doesn't seem to work in DOT-NN, DOT-TT, DOT-TN
|
||||
void alignment_info::run(ir::module &mod) {
|
||||
// populate constant
|
||||
@@ -301,7 +313,7 @@ void alignment_info::run(ir::module &mod) {
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
populate_max_contiguous(i);
|
||||
// std::cout << i->get_name() << " " << is_constant_.at(i).num_cst << " " << max_contiguous_.at(i) << " " << starting_multiple_.at(i) << std::endl;
|
||||
std::cout << i->get_name() << " " << is_constant_.at(i).num_cst << " " << max_contiguous_.at(i) << " " << starting_multiple_.at(i) << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -5,6 +5,7 @@
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/cfg.h"
|
||||
#include "triton/codegen/tune.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
@@ -68,25 +69,32 @@ inline ir::value *reassociate::reorder_op(ir::value *old_value,
|
||||
if(cst){
|
||||
ir::value *old_lhs = bin_add->get_operand(0);
|
||||
ir::value *old_rhs = bin_add->get_operand(1);
|
||||
ir::value *new_lhs = nullptr;
|
||||
ir::value *new_rhs = nullptr;
|
||||
if(dynamic_cast<ir::reshape_inst*>(op)){
|
||||
builder.set_insert_point(op);
|
||||
ir::value *new_lhs = builder.create_reshape(old_lhs, shapes);
|
||||
ir::value *new_rhs = builder.create_reshape(old_rhs, shapes);
|
||||
new_lhs = builder.create_reshape(old_lhs, shapes);
|
||||
new_rhs = builder.create_reshape(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
|
||||
}
|
||||
if(dynamic_cast<ir::broadcast_inst*>(op)){
|
||||
builder.set_insert_point(op);
|
||||
ir::value *new_lhs = builder.create_broadcast(old_lhs, shapes);
|
||||
ir::value *new_rhs = builder.create_broadcast(old_rhs, shapes);
|
||||
new_lhs = builder.create_broadcast(old_lhs, shapes);
|
||||
new_rhs = builder.create_broadcast(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
|
||||
}
|
||||
if(dynamic_cast<ir::splat_inst*>(op)){
|
||||
builder.set_insert_point(op);
|
||||
ir::value *new_lhs = builder.create_splat(old_lhs, shapes);
|
||||
ir::value *new_rhs = builder.create_splat(old_rhs, shapes);
|
||||
new_lhs = builder.create_splat(old_lhs, shapes);
|
||||
new_rhs = builder.create_splat(old_rhs, shapes);
|
||||
new_value = builder.create_add(new_lhs, new_rhs, op->get_name());
|
||||
}
|
||||
to_delete.push_back(op);
|
||||
if(new_value != old_value){
|
||||
params_->copy(new_value, old_value);
|
||||
params_->copy(new_lhs, old_value);
|
||||
params_->copy(new_rhs, old_value);
|
||||
to_delete.push_back(op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,8 +115,9 @@ inline ir::value *reassociate::reorder_op(ir::value *old_value,
|
||||
// (x + cst) + y -> cst + (x + y)
|
||||
if(is_cst(rlhs))
|
||||
new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs), name);
|
||||
if(new_value != op)
|
||||
if(new_value != old_value){
|
||||
to_delete.push_back(bin_lhs);
|
||||
}
|
||||
}
|
||||
// x + (y + z)
|
||||
if(ir::instruction* bin_rhs = is_bin_add(rhs)){
|
||||
@@ -123,6 +132,11 @@ inline ir::value *reassociate::reorder_op(ir::value *old_value,
|
||||
if(new_value != op)
|
||||
to_delete.push_back(bin_rhs);
|
||||
}
|
||||
if(new_value != old_value){
|
||||
params_->copy(new_value, old_value);
|
||||
params_->copy(((ir::instruction*)new_value)->get_operand(0), old_value);
|
||||
params_->copy(((ir::instruction*)new_value)->get_operand(1), old_value);
|
||||
}
|
||||
}
|
||||
|
||||
// extract constant and non-constant
|
||||
@@ -149,13 +163,39 @@ inline ir::value *reassociate::reorder_op(ir::value *old_value,
|
||||
return new_value;
|
||||
}
|
||||
|
||||
reassociate::reassociate() {
|
||||
|
||||
}
|
||||
reassociate::reassociate(tune* params)
|
||||
: params_(params)
|
||||
{ }
|
||||
|
||||
void reassociate::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
std::vector<ir::instruction*> to_delete;
|
||||
|
||||
// constant_range -> nv_dynamic_range_idx + nv_static_range_idx
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::constant_range*> ranges;
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
for(ir::basic_block *block: rpo){
|
||||
// iterate through instruction
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
for(ir::value* op: i->ops())
|
||||
if(auto *range = dynamic_cast<ir::constant_range*>(op))
|
||||
ranges.push_back(range);
|
||||
}
|
||||
|
||||
builder.set_insert_point(rpo.front()->get_first_non_phi());
|
||||
for(ir::constant_range* old_range: ranges){
|
||||
ir::value* dyn_range = builder.insert(ir::nv_dynamic_range_idx_inst::create(old_range->get_type()));
|
||||
ir::value* static_range = ir::nv_static_range_idx::get(old_range);
|
||||
ir::value* new_range = builder.create_add(dyn_range, static_range);
|
||||
old_range->replace_all_uses_with(new_range);
|
||||
params_->copy(dyn_range, old_range);
|
||||
params_->copy(static_range, old_range);
|
||||
params_->copy(new_range, old_range);
|
||||
}
|
||||
}
|
||||
|
||||
// reassociate
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
bool done = false;
|
||||
|
@@ -690,12 +690,22 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
distributed_tile *T = new distributed_tile(ty, shapes, axes, builder, vectorize);
|
||||
tmap_.insert({v, T});
|
||||
// constant range
|
||||
if(dynamic_cast<ir::constant*>(v) && !dynamic_cast<ir::undef_value*>(v)){
|
||||
if(dynamic_cast<ir::constant_range*>(v)){
|
||||
T->for_each([&](indices_t idx){
|
||||
assert(idx.size() == 1);
|
||||
T->set_value(idx, idx[0]);
|
||||
});
|
||||
}
|
||||
if(dynamic_cast<ir::nv_static_range_idx*>(v)){
|
||||
T->for_each([&](indices_t idx){
|
||||
assert(idx.size() == 1);
|
||||
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
assert(bin_add);
|
||||
Value *res = bin_add->getOperand(1);
|
||||
assert(isa<Constant>(res));
|
||||
T->set_value(idx, res);
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
@@ -835,6 +845,16 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
result->set_value(idx, builder.CreateAdd(bin, offset));
|
||||
});
|
||||
}
|
||||
// nv_dynamic_range_idx_inst
|
||||
if(dynamic_cast<ir::nv_dynamic_range_idx_inst*>(ins)){
|
||||
result->for_each([&](indices_t idx){
|
||||
assert(idx.size() == 1);
|
||||
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
assert(bin_add);
|
||||
Value *res = bin_add->getOperand(0);
|
||||
result->set_value(idx, res);
|
||||
});
|
||||
}
|
||||
// // mask
|
||||
// else if(dynamic_cast<ir::mask_inst*>(ins)) {
|
||||
// distributed_tile* pred = (distributed_tile*)tmap_.at(ins->get_operand(0));
|
||||
|
@@ -133,7 +133,7 @@ tune::fragment_t tune::get_fragmentation_type(node_t x, graph_t &graph){
|
||||
}
|
||||
|
||||
void tune::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
|
||||
groups_[x.first][x.second] = group_id;
|
||||
groups_[x.first].insert({x.second, group_id});
|
||||
if(nodes.find(x) != nodes.end()){
|
||||
nodes.erase(x);
|
||||
std::string suffix = ".d" + std::to_string(x.second);
|
||||
@@ -145,11 +145,11 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
|
||||
if(auto mp = dynamic_cast<ir::metaparameter*>(shape))
|
||||
params_[x.first].insert({"shape" + suffix, mp});
|
||||
}
|
||||
if(auto range = dynamic_cast<ir::get_global_range_inst*>(x.first)){
|
||||
unsigned ax = range->get_axis();
|
||||
global_range_sizes_[ax] = params_[x.first].at("shape.d0");
|
||||
num_global_ranges_ = std::max(num_global_ranges_, ax + 1);
|
||||
}
|
||||
// if(auto range = dynamic_cast<ir::get_global_range_inst*>(x.first)){
|
||||
// unsigned ax = range->get_axis();
|
||||
// global_range_sizes_[ax] = params_[x.first].at("shape.d0");
|
||||
// num_global_ranges_ = std::max(num_global_ranges_, ax + 1);
|
||||
// }
|
||||
if(static_params_.find(x) != static_params_.end()){
|
||||
for(ir::metaparameter *mp: mps)
|
||||
mp->set_value(static_params_.at(x));
|
||||
@@ -190,6 +190,14 @@ unsigned tune::get_param_group(ir::value *value, unsigned ax) {
|
||||
return result;
|
||||
}
|
||||
|
||||
//TODO: This shouldn't exist!
|
||||
void tune::copy(ir::value *dst, ir::value *src) {
|
||||
params_[dst] = params_[src];
|
||||
groups_[dst] = groups_[src];
|
||||
fragments_[{dst, 0}] = fragments_[{src, 0}];
|
||||
}
|
||||
|
||||
|
||||
void tune::run(ir::module &mod) {
|
||||
ir::context &ctx = mod.get_context();
|
||||
// Create metaparameters
|
||||
|
@@ -59,6 +59,8 @@ constant_int *constant_int::get(type *ty, uint64_t value) {
|
||||
// constant_range
|
||||
// FIXME use something like APInt
|
||||
|
||||
//"[" + std::to_string(first->get_value()) + " ... " + std::to_string(ty->get_tile_shapes()[0]->get_value()) + "]"
|
||||
|
||||
constant_range::constant_range(type *ty, constant_int *first, constant_int *last)
|
||||
: constant(ty, 0), first_(first), last_(last){ }
|
||||
|
||||
|
@@ -688,22 +688,48 @@ instruction* atomic_add_inst::create(value *ptr, value *val, const std::string &
|
||||
//===----------------------------------------------------------------------===//
|
||||
// intrinsic instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
// copy to shared
|
||||
copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name,
|
||||
instruction *next) {
|
||||
return new copy_to_shared_inst(arg->get_type(), arg, name, next);
|
||||
}
|
||||
|
||||
// vectorize
|
||||
vectorize_inst* vectorize_inst::create(value *arg, const std::string &name, instruction *next) {
|
||||
return new vectorize_inst(arg->get_type(), arg, name, next);
|
||||
}
|
||||
|
||||
// barrier
|
||||
barrier_inst::barrier_inst(context &ctx, const std::string &name,
|
||||
instruction *next)
|
||||
: instruction(type::get_void_ty(ctx), 0, 0, name, next){ }
|
||||
: instruction(type::get_void_ty(ctx), 0, 0, name, next) { }
|
||||
|
||||
barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
|
||||
return new barrier_inst(ctx, name, next);
|
||||
}
|
||||
|
||||
// nv_dynamic_range_idx
|
||||
nv_dynamic_range_idx_inst::nv_dynamic_range_idx_inst(type *ty, const std::string &name, instruction *next)
|
||||
: instruction(ty, 0, 1, name, next) { }
|
||||
|
||||
nv_dynamic_range_idx_inst* nv_dynamic_range_idx_inst::create(type *ty, const std::string &name, instruction *next) {
|
||||
return new nv_dynamic_range_idx_inst(ty, name, next);
|
||||
}
|
||||
|
||||
// nv_static_range_idx
|
||||
nv_static_range_idx::nv_static_range_idx(constant_range *range)
|
||||
: constant(range->get_type(), 0), range_(range) { }
|
||||
|
||||
constant_range* nv_static_range_idx::get_range() const
|
||||
{ return range_; }
|
||||
|
||||
nv_static_range_idx* nv_static_range_idx::get(constant_range* range) {
|
||||
static std::map<constant_range*, nv_static_range_idx*> cache;
|
||||
if(cache.find(range) == cache.end())
|
||||
cache.insert({range, new nv_static_range_idx(range)});
|
||||
return cache.at(range);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user