[code generation] now using ir::metaparameter* for all tunable
metaparameters
This commit is contained in:
@@ -117,9 +117,9 @@ private:
|
||||
|
||||
// grid construction
|
||||
void create_grids(std::vector<ir::value *> &grids,
|
||||
std::map<unsigned *, ir::value *> &references,
|
||||
std::map<ir::metaparameter *, ir::value *> &references,
|
||||
ir::function *fn);
|
||||
void create_tile(ir::value *v, llvm::IRBuilder<> &builder, const std::map<unsigned *, ir::value *> &references, std::set<ir::value *> &seen, llvm::Value *sh_mem_ptr);
|
||||
void create_tile(ir::value *v, llvm::IRBuilder<> &builder, const std::map<ir::metaparameter *, ir::value *> &references, std::set<ir::value *> &seen, llvm::Value *sh_mem_ptr);
|
||||
void init_axes(ir::value *i, llvm::IRBuilder<> &builder, llvm::Value *u_thread_id, llvm::Value *u_warp_id);
|
||||
void init_grids(ir::function *fn, llvm::IRBuilder<> &builder, llvm::Value *sh_mem_ptr);
|
||||
|
||||
@@ -139,7 +139,7 @@ private:
|
||||
allocation *alloc_;
|
||||
tune *params_;
|
||||
buffer_info_pass *buffer_info_;
|
||||
std::map<unsigned*, distributed_axis> axes_;
|
||||
std::map<ir::metaparameter*, distributed_axis> axes_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -12,6 +12,7 @@ namespace ir{
|
||||
class module;
|
||||
class instruction;
|
||||
class function;
|
||||
class metaparameter;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
@@ -24,24 +25,28 @@ private:
|
||||
void add_constraint(node_t x, node_t y);
|
||||
void init_c_phi(ir::instruction *i);
|
||||
void init_c_graph(ir::instruction *v);
|
||||
void connected_components(node_t x, const std::vector<unsigned*> vals, std::set<node_t> &nodes, graph_t &graph);
|
||||
void create_grids(std::vector<ir::instruction*> &grids, std::map<unsigned*, ir::instruction*> &references, ir::function *fn);
|
||||
void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, std::set<node_t> &nodes, graph_t &graph);
|
||||
void create_grids(std::vector<ir::instruction*> &grids, std::map<ir::metaparameter *, ir::instruction *> &references, ir::function *fn);
|
||||
|
||||
|
||||
public:
|
||||
std::vector<unsigned *> get_params(ir::module& mod);
|
||||
std::map<std::string, unsigned *> get_params(ir::instruction* i);
|
||||
unsigned *get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
|
||||
std::vector<ir::metaparameter *> get_params(ir::module& mod);
|
||||
std::map<std::string, ir::metaparameter *> get_params(ir::instruction* i);
|
||||
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
|
||||
void copy(ir::value *dst, ir::value *src) { params_[dst] = params_[src]; }
|
||||
bool check_constraints(ir::module &fn, std::map<ir::value *, std::vector<std::string>> &errors);
|
||||
void run(ir::module &mod);
|
||||
ir::metaparameter* get_num_threads();
|
||||
ir::metaparameter* get_global_range_size(unsigned axis);
|
||||
|
||||
private:
|
||||
std::map<ir::value*, std::map<std::string, unsigned*>> params_;
|
||||
std::vector<unsigned*> pool_;
|
||||
graph_t dependencies_;
|
||||
std::set<node_t> nodes_;
|
||||
std::map<node_t, unsigned> static_params_;
|
||||
std::map<ir::value*, std::map<std::string, ir::metaparameter*>> params_;
|
||||
ir::metaparameter *num_threads_;
|
||||
std::vector<ir::metaparameter*> global_range_sizes_;
|
||||
};
|
||||
|
||||
|
||||
|
@@ -35,6 +35,11 @@ public:
|
||||
// Constants
|
||||
value *get_int32(unsigned val);
|
||||
// Types
|
||||
type *get_int1_ty();
|
||||
type *get_int8_ty();
|
||||
type *get_int16_ty();
|
||||
type *get_int32_ty();
|
||||
type *get_int64_ty();
|
||||
type *get_float_ty();
|
||||
type *get_double_ty();
|
||||
// Insert
|
||||
|
@@ -49,11 +49,13 @@ class metaparameter: public constant_int{
|
||||
|
||||
public:
|
||||
static metaparameter *create(context &ctx, type *ty, unsigned lo, unsigned hi);
|
||||
void set_value(uint64_t value) { value_ = value; }
|
||||
void set_value(uint64_t value) { has_value_ = true; value_ = value; }
|
||||
bool has_value() { return has_value_; }
|
||||
|
||||
private:
|
||||
unsigned lo_;
|
||||
unsigned hi_;
|
||||
bool has_value_;
|
||||
};
|
||||
|
||||
/* constant range */
|
||||
|
@@ -379,9 +379,9 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
std::vector<unsigned> n_warps(dim);
|
||||
for(unsigned i = 0; i < shapes.size(); i++){
|
||||
std::string str_i = std::to_string(i);
|
||||
contiguous[i] = *params_->get_param(v, "p0.d" + str_i);
|
||||
warp_size[i] = *params_->get_param(v, "p1.d" + str_i);
|
||||
n_warps[i] = *params_->get_param(v, "p2.d" + str_i);
|
||||
contiguous[i] = params_->get_param(v, "p0.d" + str_i)->get_value();
|
||||
warp_size[i] = params_->get_param(v, "p1.d" + str_i)->get_value();
|
||||
n_warps[i] = params_->get_param(v, "p2.d" + str_i)->get_value();
|
||||
}
|
||||
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, warp_size, builder);
|
||||
std::vector<Value*> warp_id = delinearize(u_warp_id, n_warps, builder);
|
||||
@@ -404,7 +404,7 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
}
|
||||
|
||||
void selection::create_grids(std::vector<ir::value*> &grids,
|
||||
std::map<unsigned*, ir::value*> &references,
|
||||
std::map<ir::metaparameter*, ir::value*> &references,
|
||||
ir::function *fn) {
|
||||
// get number of dimensions greater than 1
|
||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||
@@ -432,7 +432,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d]->get_value() == 1)
|
||||
continue;
|
||||
unsigned *x = params_->get_param(v, "p0.d" + std::to_string(d));
|
||||
ir::metaparameter *x = params_->get_param(v, "p0.d" + std::to_string(d));
|
||||
ir::value *&r = references[x];
|
||||
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
|
||||
r = v;
|
||||
@@ -457,7 +457,7 @@ bool static inline has_phi_user(ir::value *v) {
|
||||
return false;
|
||||
}
|
||||
void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
const std::map<unsigned*, ir::value*>& references,
|
||||
const std::map<ir::metaparameter*, ir::value*>& references,
|
||||
std::set<ir::value*> &seen, Value *sh_mem_ptr) {
|
||||
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
|
||||
return;
|
||||
@@ -517,7 +517,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
std::vector<distributed_axis> axes(shapes.size());
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d]->get_value() > 1){
|
||||
unsigned *x = params_->get_param(v, "p0.d" + std::to_string(d));
|
||||
ir::metaparameter *x = params_->get_param(v, "p0.d" + std::to_string(d));
|
||||
axes[d] = axes_.at(x);
|
||||
}
|
||||
else{
|
||||
@@ -549,7 +549,7 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
|
||||
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
|
||||
// create grid
|
||||
std::vector<ir::value*> grids;
|
||||
std::map<unsigned*, ir::value*> references;
|
||||
std::map<ir::metaparameter*, ir::value*> references;
|
||||
create_grids(grids, references, fn);
|
||||
for(ir::value* i: grids){
|
||||
if(auto *instr = dynamic_cast<ir::instruction*>(i))
|
||||
|
@@ -4,6 +4,7 @@
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/context_impl.h"
|
||||
#include "triton/ir/constant.h"
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
@@ -77,43 +78,44 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
}
|
||||
}
|
||||
|
||||
void tune::connected_components(node_t x, const std::vector<unsigned *> vals, std::set<node_t> &nodes, graph_t &graph) {
|
||||
void tune::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, std::set<node_t> &nodes, graph_t &graph) {
|
||||
if(nodes.find(x) != nodes.end()){
|
||||
nodes.erase(x);
|
||||
std::string suffix = ".d" + std::to_string(x.second);
|
||||
params_[x.first].insert({"p0" + suffix, vals[0]});
|
||||
params_[x.first].insert({"p1" + suffix, vals[1]});
|
||||
params_[x.first].insert({"p2" + suffix, vals[2]});
|
||||
params_[x.first].insert({"p0" + suffix, mps[0]});
|
||||
params_[x.first].insert({"p1" + suffix, mps[1]});
|
||||
params_[x.first].insert({"p2" + suffix, mps[2]});
|
||||
if(static_params_.find(x) != static_params_.end()){
|
||||
*vals[0] = static_params_.at(x);
|
||||
*vals[1] = static_params_.at(x);
|
||||
*vals[2] = static_params_.at(x);
|
||||
mps[0]->set_value(static_params_.at(x));
|
||||
mps[1]->set_value(static_params_.at(x));
|
||||
mps[2]->set_value(static_params_.at(x));
|
||||
}
|
||||
for(const node_t &y: graph[x])
|
||||
connected_components(y, vals, nodes, graph);
|
||||
connected_components(y, mps, nodes, graph);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<unsigned*> tune::get_params(ir::module &mod) {
|
||||
std::vector<unsigned *> result;
|
||||
std::set<unsigned*> seen;
|
||||
std::vector<ir::metaparameter *> tune::get_params(ir::module &mod) {
|
||||
std::vector<ir::metaparameter*> result;
|
||||
std::set<ir::metaparameter*> seen;
|
||||
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
for(auto &x: params_[i])
|
||||
if(seen.insert(x.second).second && *x.second == 0){
|
||||
if(seen.insert(x.second).second && !x.second->has_value()){
|
||||
result.push_back(x.second);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::map<std::string, unsigned*> tune::get_params(ir::instruction* i) {
|
||||
std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i) {
|
||||
return params_.at(i);
|
||||
}
|
||||
|
||||
|
||||
void tune::run(ir::module &mod) {
|
||||
ir::context &ctx = mod.get_context();
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// Build constraints graph
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
@@ -128,16 +130,17 @@ void tune::run(ir::module &mod) {
|
||||
init_c_phi(i);
|
||||
// Layout parameters
|
||||
while(!nodes_.empty()){
|
||||
unsigned *v0 = new unsigned(0);
|
||||
unsigned *v1 = new unsigned(0);
|
||||
unsigned *v2 = new unsigned(0);
|
||||
connected_components(*nodes_.begin(), {v0, v1, v2}, nodes_, dependencies_);
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
ir::metaparameter *mp0 = ir::metaparameter::create(ctx, ty, 1, 4);
|
||||
ir::metaparameter *mp1 = ir::metaparameter::create(ctx, ty, 4, 32);
|
||||
ir::metaparameter *mp2 = ir::metaparameter::create(ctx, ty, 4, 32);
|
||||
connected_components(*nodes_.begin(), {mp0, mp1, mp2}, nodes_, dependencies_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void tune::create_grids(std::vector<ir::instruction*> &grids,
|
||||
std::map<unsigned*, ir::instruction*> &references,
|
||||
std::map<ir::metaparameter*, ir::instruction*> &references,
|
||||
ir::function *fn) {
|
||||
// get number of dimensions greater than 1
|
||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||
@@ -154,7 +157,7 @@ void tune::create_grids(std::vector<ir::instruction*> &grids,
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
for(auto ¶m: params_.at(i)){
|
||||
if(*param.second == 1)
|
||||
if(param.second->get_value() == 1)
|
||||
continue;
|
||||
ir::instruction *&r = references[param.second];
|
||||
if(!r || get_tile_gt1_dim(i) > get_tile_gt1_dim(r))
|
||||
@@ -173,14 +176,14 @@ for(ir::function *fn: mod.get_function_list()){
|
||||
using std::to_string;
|
||||
|
||||
// initialize grids
|
||||
std::map<unsigned*, ir::instruction*> references;
|
||||
std::map<ir::metaparameter*, ir::instruction*> references;
|
||||
std::vector<ir::instruction*> grids;
|
||||
create_grids(grids, references, fn);
|
||||
|
||||
// number of warps
|
||||
int num_warps = 1;
|
||||
for(size_t k = 0; k < grids.front()->get_type()->get_tile_shapes().size(); k++)
|
||||
num_warps *= *params_[grids.front()]["p2.d" + to_string(k)];
|
||||
num_warps *= params_[grids.front()]["p2.d" + to_string(k)]->get_value();
|
||||
|
||||
// check constraints
|
||||
for(ir::instruction *i: grids){
|
||||
@@ -190,10 +193,10 @@ for(ir::function *fn: mod.get_function_list()){
|
||||
// must device the shape
|
||||
for(size_t k = 0; k < shapes.size(); k++) {
|
||||
std::string strk = to_string(k);
|
||||
unsigned *s0 = params_[i]["p0.d" + strk];
|
||||
unsigned *s1 = params_[i]["p1.d" + strk];
|
||||
unsigned *s2 = params_[i]["p2.d" + strk];
|
||||
unsigned multiple = (*s0)*(*s1)*(*s2);
|
||||
ir::metaparameter *mp0 = params_[i]["p0.d" + strk];
|
||||
ir::metaparameter *mp1 = params_[i]["p1.d" + strk];
|
||||
ir::metaparameter *mp2 = params_[i]["p2.d" + strk];
|
||||
unsigned multiple = mp0->get_value()*mp1->get_value()*mp2->get_value();
|
||||
if(shapes[k]->get_value() % multiple != 0)
|
||||
errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")"
|
||||
" is not a multiple of layout (" + to_string(multiple) + ")");
|
||||
@@ -201,14 +204,14 @@ for(ir::function *fn: mod.get_function_list()){
|
||||
// the number of thread per warp must be 32
|
||||
int num_threads = 1;
|
||||
for(size_t k = 0; k < shapes.size(); k++)
|
||||
num_threads *= *params_[i]["p1.d" + to_string(k)];
|
||||
num_threads *= params_[i]["p1.d" + to_string(k)]->get_value();
|
||||
if(num_threads != 32)
|
||||
errors[i].push_back("number of threads per warp (" + to_string(num_threads) + ") must be 32");
|
||||
// The number of warps required by the layout is the same
|
||||
// for all tiles in the function
|
||||
int required_num_warps = 1;
|
||||
for(size_t k = 0; k < shapes.size(); k++)
|
||||
required_num_warps *= *params_[i]["p2.d" + to_string(k)];
|
||||
required_num_warps *= params_[i]["p2.d" + to_string(k)]->get_value();
|
||||
if(required_num_warps != num_warps)
|
||||
errors[i].push_back("number of warps (" + to_string(required_num_warps) + ") must be " + to_string(num_warps));
|
||||
}
|
||||
|
@@ -16,7 +16,7 @@ void vectorize::run(ir::module &mod) {
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(i)){
|
||||
ir::value *x = i->get_operand(0);
|
||||
if(*params_->get_param(x, "p0.d0") == 1)
|
||||
if(params_->get_param(x, "p0.d0")->get_value() == 1)
|
||||
continue;
|
||||
builder.set_insert_point(i);
|
||||
ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x);
|
||||
|
@@ -41,6 +41,21 @@ value *builder::get_int32(unsigned val) {
|
||||
return constant_int::get(type::get_int32_ty(ctx_), val);
|
||||
}
|
||||
|
||||
type *builder::get_int1_ty()
|
||||
{ return type::get_int1_ty(ctx_); }
|
||||
|
||||
type *builder::get_int8_ty()
|
||||
{ return type::get_int8_ty(ctx_); }
|
||||
|
||||
type *builder::get_int16_ty()
|
||||
{ return type::get_int16_ty(ctx_); }
|
||||
|
||||
type *builder::get_int32_ty()
|
||||
{ return type::get_int32_ty(ctx_); }
|
||||
|
||||
type *builder::get_int64_ty()
|
||||
{ return type::get_int64_ty(ctx_); }
|
||||
|
||||
type *builder::get_float_ty()
|
||||
{ return type::get_float_ty(ctx_); }
|
||||
|
||||
|
@@ -99,7 +99,7 @@ constant *constant_fp::get(context &ctx, double v){
|
||||
|
||||
// metaparameter
|
||||
metaparameter::metaparameter(type *ty, unsigned lo, unsigned hi)
|
||||
: constant_int(ty, 0), lo_(lo), hi_(hi){ }
|
||||
: constant_int(ty, 0), lo_(lo), hi_(hi), has_value_(false){ }
|
||||
|
||||
metaparameter* metaparameter::create(context &ctx, type *ty, unsigned lo, unsigned hi) {
|
||||
context_impl *impl = ctx.p_impl.get();
|
||||
|
@@ -65,8 +65,9 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, const st
|
||||
triton_context_.p_impl->mp_constants_[0]->set_value(params[0]);
|
||||
triton_context_.p_impl->mp_constants_[1]->set_value(params[1]);
|
||||
triton_context_.p_impl->mp_constants_[2]->set_value(params[2]);
|
||||
for(unsigned *x: tune.get_params(module))
|
||||
*x = params[3 + i++];
|
||||
for(ir::metaparameter *x: tune.get_params(module)){
|
||||
x->set_value(params[3 + i++]);
|
||||
}
|
||||
// constraints
|
||||
std::map<ir::value*, std::vector<std::string>> errors;
|
||||
tune.check_constraints(module, errors);
|
||||
|
Reference in New Issue
Block a user