[code generation] now using ir::metaparameter* for all tunable metaparameters
This commit is contained in:
Philippe Tillet
2019-03-09 12:05:12 -05:00
parent d049679aa2
commit 5f29263044
10 changed files with 80 additions and 49 deletions

View File

@@ -117,9 +117,9 @@ private:
// grid construction
void create_grids(std::vector<ir::value *> &grids,
std::map<unsigned *, ir::value *> &references,
std::map<ir::metaparameter *, ir::value *> &references,
ir::function *fn);
void create_tile(ir::value *v, llvm::IRBuilder<> &builder, const std::map<unsigned *, ir::value *> &references, std::set<ir::value *> &seen, llvm::Value *sh_mem_ptr);
void create_tile(ir::value *v, llvm::IRBuilder<> &builder, const std::map<ir::metaparameter *, ir::value *> &references, std::set<ir::value *> &seen, llvm::Value *sh_mem_ptr);
void init_axes(ir::value *i, llvm::IRBuilder<> &builder, llvm::Value *u_thread_id, llvm::Value *u_warp_id);
void init_grids(ir::function *fn, llvm::IRBuilder<> &builder, llvm::Value *sh_mem_ptr);
@@ -139,7 +139,7 @@ private:
allocation *alloc_;
tune *params_;
buffer_info_pass *buffer_info_;
std::map<unsigned*, distributed_axis> axes_;
std::map<ir::metaparameter*, distributed_axis> axes_;
};
}

View File

@@ -12,6 +12,7 @@ namespace ir{
class module;
class instruction;
class function;
class metaparameter;
}
namespace codegen{
@@ -24,24 +25,28 @@ private:
void add_constraint(node_t x, node_t y);
void init_c_phi(ir::instruction *i);
void init_c_graph(ir::instruction *v);
void connected_components(node_t x, const std::vector<unsigned*> vals, std::set<node_t> &nodes, graph_t &graph);
void create_grids(std::vector<ir::instruction*> &grids, std::map<unsigned*, ir::instruction*> &references, ir::function *fn);
void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, std::set<node_t> &nodes, graph_t &graph);
void create_grids(std::vector<ir::instruction*> &grids, std::map<ir::metaparameter *, ir::instruction *> &references, ir::function *fn);
public:
std::vector<unsigned *> get_params(ir::module& mod);
std::map<std::string, unsigned *> get_params(ir::instruction* i);
unsigned *get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
std::vector<ir::metaparameter *> get_params(ir::module& mod);
std::map<std::string, ir::metaparameter *> get_params(ir::instruction* i);
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
void copy(ir::value *dst, ir::value *src) { params_[dst] = params_[src]; }
bool check_constraints(ir::module &fn, std::map<ir::value *, std::vector<std::string>> &errors);
void run(ir::module &mod);
ir::metaparameter* get_num_threads();
ir::metaparameter* get_global_range_size(unsigned axis);
private:
std::map<ir::value*, std::map<std::string, unsigned*>> params_;
std::vector<unsigned*> pool_;
graph_t dependencies_;
std::set<node_t> nodes_;
std::map<node_t, unsigned> static_params_;
std::map<ir::value*, std::map<std::string, ir::metaparameter*>> params_;
ir::metaparameter *num_threads_;
std::vector<ir::metaparameter*> global_range_sizes_;
};

View File

@@ -35,6 +35,11 @@ public:
// Constants
value *get_int32(unsigned val);
// Types
type *get_int1_ty();
type *get_int8_ty();
type *get_int16_ty();
type *get_int32_ty();
type *get_int64_ty();
type *get_float_ty();
type *get_double_ty();
// Insert

View File

@@ -49,11 +49,13 @@ class metaparameter: public constant_int{
public:
static metaparameter *create(context &ctx, type *ty, unsigned lo, unsigned hi);
void set_value(uint64_t value) { value_ = value; }
void set_value(uint64_t value) { has_value_ = true; value_ = value; }
bool has_value() { return has_value_; }
private:
unsigned lo_;
unsigned hi_;
bool has_value_;
};
/* constant range */

View File

@@ -379,9 +379,9 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
std::vector<unsigned> n_warps(dim);
for(unsigned i = 0; i < shapes.size(); i++){
std::string str_i = std::to_string(i);
contiguous[i] = *params_->get_param(v, "p0.d" + str_i);
warp_size[i] = *params_->get_param(v, "p1.d" + str_i);
n_warps[i] = *params_->get_param(v, "p2.d" + str_i);
contiguous[i] = params_->get_param(v, "p0.d" + str_i)->get_value();
warp_size[i] = params_->get_param(v, "p1.d" + str_i)->get_value();
n_warps[i] = params_->get_param(v, "p2.d" + str_i)->get_value();
}
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, warp_size, builder);
std::vector<Value*> warp_id = delinearize(u_warp_id, n_warps, builder);
@@ -404,7 +404,7 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
}
void selection::create_grids(std::vector<ir::value*> &grids,
std::map<unsigned*, ir::value*> &references,
std::map<ir::metaparameter*, ir::value*> &references,
ir::function *fn) {
// get number of dimensions greater than 1
auto get_tile_gt1_dim = [&](ir::value *v){
@@ -432,7 +432,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d]->get_value() == 1)
continue;
unsigned *x = params_->get_param(v, "p0.d" + std::to_string(d));
ir::metaparameter *x = params_->get_param(v, "p0.d" + std::to_string(d));
ir::value *&r = references[x];
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
r = v;
@@ -457,7 +457,7 @@ bool static inline has_phi_user(ir::value *v) {
return false;
}
void selection::create_tile(ir::value *v, IRBuilder<> &builder,
const std::map<unsigned*, ir::value*>& references,
const std::map<ir::metaparameter*, ir::value*>& references,
std::set<ir::value*> &seen, Value *sh_mem_ptr) {
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
return;
@@ -517,7 +517,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
std::vector<distributed_axis> axes(shapes.size());
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d]->get_value() > 1){
unsigned *x = params_->get_param(v, "p0.d" + std::to_string(d));
ir::metaparameter *x = params_->get_param(v, "p0.d" + std::to_string(d));
axes[d] = axes_.at(x);
}
else{
@@ -549,7 +549,7 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
// create grid
std::vector<ir::value*> grids;
std::map<unsigned*, ir::value*> references;
std::map<ir::metaparameter*, ir::value*> references;
create_grids(grids, references, fn);
for(ir::value* i: grids){
if(auto *instr = dynamic_cast<ir::instruction*>(i))

View File

@@ -4,6 +4,7 @@
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/context_impl.h"
#include "triton/ir/constant.h"
#include <cstdlib>
@@ -77,43 +78,44 @@ void tune::init_c_graph(ir::instruction *v) {
}
}
void tune::connected_components(node_t x, const std::vector<unsigned *> vals, std::set<node_t> &nodes, graph_t &graph) {
void tune::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, std::set<node_t> &nodes, graph_t &graph) {
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
std::string suffix = ".d" + std::to_string(x.second);
params_[x.first].insert({"p0" + suffix, vals[0]});
params_[x.first].insert({"p1" + suffix, vals[1]});
params_[x.first].insert({"p2" + suffix, vals[2]});
params_[x.first].insert({"p0" + suffix, mps[0]});
params_[x.first].insert({"p1" + suffix, mps[1]});
params_[x.first].insert({"p2" + suffix, mps[2]});
if(static_params_.find(x) != static_params_.end()){
*vals[0] = static_params_.at(x);
*vals[1] = static_params_.at(x);
*vals[2] = static_params_.at(x);
mps[0]->set_value(static_params_.at(x));
mps[1]->set_value(static_params_.at(x));
mps[2]->set_value(static_params_.at(x));
}
for(const node_t &y: graph[x])
connected_components(y, vals, nodes, graph);
connected_components(y, mps, nodes, graph);
}
}
std::vector<unsigned*> tune::get_params(ir::module &mod) {
std::vector<unsigned *> result;
std::set<unsigned*> seen;
std::vector<ir::metaparameter *> tune::get_params(ir::module &mod) {
std::vector<ir::metaparameter*> result;
std::set<ir::metaparameter*> seen;
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list())
for(auto &x: params_[i])
if(seen.insert(x.second).second && *x.second == 0){
if(seen.insert(x.second).second && !x.second->has_value()){
result.push_back(x.second);
}
return result;
}
std::map<std::string, unsigned*> tune::get_params(ir::instruction* i) {
std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i) {
return params_.at(i);
}
void tune::run(ir::module &mod) {
ir::context &ctx = mod.get_context();
for(ir::function *fn: mod.get_function_list()){
// Build constraints graph
for(ir::basic_block *block: fn->blocks())
@@ -128,16 +130,17 @@ void tune::run(ir::module &mod) {
init_c_phi(i);
// Layout parameters
while(!nodes_.empty()){
unsigned *v0 = new unsigned(0);
unsigned *v1 = new unsigned(0);
unsigned *v2 = new unsigned(0);
connected_components(*nodes_.begin(), {v0, v1, v2}, nodes_, dependencies_);
ir::type *ty = mod.get_builder().get_int32_ty();
ir::metaparameter *mp0 = ir::metaparameter::create(ctx, ty, 1, 4);
ir::metaparameter *mp1 = ir::metaparameter::create(ctx, ty, 4, 32);
ir::metaparameter *mp2 = ir::metaparameter::create(ctx, ty, 4, 32);
connected_components(*nodes_.begin(), {mp0, mp1, mp2}, nodes_, dependencies_);
}
}
}
void tune::create_grids(std::vector<ir::instruction*> &grids,
std::map<unsigned*, ir::instruction*> &references,
std::map<ir::metaparameter*, ir::instruction*> &references,
ir::function *fn) {
// get number of dimensions greater than 1
auto get_tile_gt1_dim = [&](ir::value *v){
@@ -154,7 +157,7 @@ void tune::create_grids(std::vector<ir::instruction*> &grids,
if(!i->get_type()->is_tile_ty())
continue;
for(auto &param: params_.at(i)){
if(*param.second == 1)
if(param.second->get_value() == 1)
continue;
ir::instruction *&r = references[param.second];
if(!r || get_tile_gt1_dim(i) > get_tile_gt1_dim(r))
@@ -173,14 +176,14 @@ for(ir::function *fn: mod.get_function_list()){
using std::to_string;
// initialize grids
std::map<unsigned*, ir::instruction*> references;
std::map<ir::metaparameter*, ir::instruction*> references;
std::vector<ir::instruction*> grids;
create_grids(grids, references, fn);
// number of warps
int num_warps = 1;
for(size_t k = 0; k < grids.front()->get_type()->get_tile_shapes().size(); k++)
num_warps *= *params_[grids.front()]["p2.d" + to_string(k)];
num_warps *= params_[grids.front()]["p2.d" + to_string(k)]->get_value();
// check constraints
for(ir::instruction *i: grids){
@@ -190,10 +193,10 @@ for(ir::function *fn: mod.get_function_list()){
// must divide the shape
for(size_t k = 0; k < shapes.size(); k++) {
std::string strk = to_string(k);
unsigned *s0 = params_[i]["p0.d" + strk];
unsigned *s1 = params_[i]["p1.d" + strk];
unsigned *s2 = params_[i]["p2.d" + strk];
unsigned multiple = (*s0)*(*s1)*(*s2);
ir::metaparameter *mp0 = params_[i]["p0.d" + strk];
ir::metaparameter *mp1 = params_[i]["p1.d" + strk];
ir::metaparameter *mp2 = params_[i]["p2.d" + strk];
unsigned multiple = mp0->get_value()*mp1->get_value()*mp2->get_value();
if(shapes[k]->get_value() % multiple != 0)
errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")"
" is not a multiple of layout (" + to_string(multiple) + ")");
@@ -201,14 +204,14 @@ for(ir::function *fn: mod.get_function_list()){
// the number of threads per warp must be 32
int num_threads = 1;
for(size_t k = 0; k < shapes.size(); k++)
num_threads *= *params_[i]["p1.d" + to_string(k)];
num_threads *= params_[i]["p1.d" + to_string(k)]->get_value();
if(num_threads != 32)
errors[i].push_back("number of threads per warp (" + to_string(num_threads) + ") must be 32");
// The number of warps required by the layout is the same
// for all tiles in the function
int required_num_warps = 1;
for(size_t k = 0; k < shapes.size(); k++)
required_num_warps *= *params_[i]["p2.d" + to_string(k)];
required_num_warps *= params_[i]["p2.d" + to_string(k)]->get_value();
if(required_num_warps != num_warps)
errors[i].push_back("number of warps (" + to_string(required_num_warps) + ") must be " + to_string(num_warps));
}

View File

@@ -16,7 +16,7 @@ void vectorize::run(ir::module &mod) {
for(ir::instruction *i: block->get_inst_list())
if(dynamic_cast<ir::copy_to_shared_inst*>(i)){
ir::value *x = i->get_operand(0);
if(*params_->get_param(x, "p0.d0") == 1)
if(params_->get_param(x, "p0.d0")->get_value() == 1)
continue;
builder.set_insert_point(i);
ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x);

View File

@@ -41,6 +41,21 @@ value *builder::get_int32(unsigned val) {
return constant_int::get(type::get_int32_ty(ctx_), val);
}
type *builder::get_int1_ty()
{ return type::get_int1_ty(ctx_); }
type *builder::get_int8_ty()
{ return type::get_int8_ty(ctx_); }
type *builder::get_int16_ty()
{ return type::get_int16_ty(ctx_); }
type *builder::get_int32_ty()
{ return type::get_int32_ty(ctx_); }
type *builder::get_int64_ty()
{ return type::get_int64_ty(ctx_); }
type *builder::get_float_ty()
{ return type::get_float_ty(ctx_); }

View File

@@ -99,7 +99,7 @@ constant *constant_fp::get(context &ctx, double v){
// metaparameter
metaparameter::metaparameter(type *ty, unsigned lo, unsigned hi)
: constant_int(ty, 0), lo_(lo), hi_(hi){ }
: constant_int(ty, 0), lo_(lo), hi_(hi), has_value_(false){ }
metaparameter* metaparameter::create(context &ctx, type *ty, unsigned lo, unsigned hi) {
context_impl *impl = ctx.p_impl.get();

View File

@@ -65,8 +65,9 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, const st
triton_context_.p_impl->mp_constants_[0]->set_value(params[0]);
triton_context_.p_impl->mp_constants_[1]->set_value(params[1]);
triton_context_.p_impl->mp_constants_[2]->set_value(params[2]);
for(unsigned *x: tune.get_params(module))
*x = params[3 + i++];
for(ir::metaparameter *x: tune.get_params(module)){
x->set_value(params[3 + i++]);
}
// constraints
std::map<ir::value*, std::vector<std::string>> errors;
tune.check_constraints(module, errors);