Reparameterized in terms of micro- and nano- tiles

This commit is contained in:
Philippe Tillet
2019-03-10 23:10:17 -04:00
parent c96a263896
commit 94e315ea8a
7 changed files with 62 additions and 53 deletions

View File

@@ -108,18 +108,32 @@ int main() {
return float(0);
};
// std::vector<unsigned> params = {
// // a0
// 2, 8, 1, 16,
// // b0
// 4, 4, 1, 16,
// // c
// 2, 4, 8, 4, 1, 1,
// // a1
// 2, 4, 1, 8,
// // b1
// 1, 8, 1
// };
// just-in-time compile source-code
std::vector<unsigned> params = {
// a0
2, 8, 1, 16,
8, 2, 16,
// b0
4, 4, 1, 16,
4, 4, 16,
// c
2, 4, 8, 4, 1, 1,
8, 4, 2, 4,
// a1
2, 4, 1, 8,
4, 2, 8,
// b1
1, 8, 1
8, 1
};
triton::jit jit(context);
jit.add_module(src, params);

View File

@@ -47,9 +47,9 @@ private:
std::set<node_t> nodes_;
std::map<node_t, unsigned> static_params_;
std::map<ir::value*, std::map<std::string, ir::metaparameter*>> params_;
std::vector<ir::metaparameter*> num_threads_mp_vec_;
std::map<unsigned, ir::metaparameter*> global_range_sizes_;
unsigned num_global_ranges_;
unsigned num_threads_;
};

View File

@@ -412,7 +412,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
assert(expr_ == nullptr);
//TODO: implement ranges
value = ir::metaparameter::create(mod->get_context(), ty, 8, 64);
value = ir::metaparameter::create(mod->get_context(), ty, 8, 128);
}
if(expr_){
value = expr_->codegen(mod);

View File

@@ -379,9 +379,9 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
std::vector<unsigned> n_warps(dim);
for(unsigned i = 0; i < shapes.size(); i++){
std::string str_i = std::to_string(i);
contiguous[i] = params_->get_param(v, "p0.d" + str_i)->get_value();
warp_size[i] = params_->get_param(v, "p1.d" + str_i)->get_value();
n_warps[i] = params_->get_param(v, "p2.d" + str_i)->get_value();
contiguous[i] = params_->get_param(v, "nts.d" + str_i)->get_value();
warp_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
n_warps[i] = shapes[i]->get_value() / (contiguous[i] * warp_size[i]);
}
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, warp_size, builder);
std::vector<Value*> warp_id = delinearize(u_warp_id, n_warps, builder);
@@ -399,7 +399,7 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
idx_list[n] = builder.CreateAdd(thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[params_->get_param(v, "p0.d" + str_k)] = distributed_axis{contiguous[k], idx_list};
axes_[params_->get_param(v, "nts.d" + str_k)] = distributed_axis{contiguous[k], idx_list};
}
}
@@ -432,7 +432,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d]->get_value() == 1)
continue;
ir::metaparameter *x = params_->get_param(v, "p0.d" + std::to_string(d));
ir::metaparameter *x = params_->get_param(v, "nts.d" + std::to_string(d));
ir::value *&r = references[x];
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
r = v;
@@ -517,7 +517,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
std::vector<distributed_axis> axes(shapes.size());
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d]->get_value() > 1){
ir::metaparameter *x = params_->get_param(v, "p0.d" + std::to_string(d));
ir::metaparameter *x = params_->get_param(v, "nts.d" + std::to_string(d));
axes[d] = axes_.at(x);
}
else{

View File

@@ -84,9 +84,8 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
if(nodes.find(x) != nodes.end()){
nodes.erase(x);
std::string suffix = ".d" + std::to_string(x.second);
params_[x.first].insert({"p0" + suffix, mps[0]});
params_[x.first].insert({"p1" + suffix, mps[1]});
params_[x.first].insert({"p2" + suffix, mps[2]});
params_[x.first].insert({"nts" + suffix, mps[0]});
params_[x.first].insert({"mts" + suffix, mps[1]});
ir::type *ty = x.first->get_type();
if(ty->is_tile_ty()){
ir::type::tile_shapes_t::value_type shape = ty->get_tile_shapes().at(x.second);
@@ -101,7 +100,6 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
if(static_params_.find(x) != static_params_.end()){
mps[0]->set_value(static_params_.at(x));
mps[1]->set_value(static_params_.at(x));
mps[2]->set_value(static_params_.at(x));
}
for(const node_t &y: graph[x])
connected_components(y, mps, nodes, graph);
@@ -145,25 +143,11 @@ void tune::run(ir::module &mod) {
// Layout parameters
while(!nodes_.empty()){
ir::type *ty = mod.get_builder().get_int32_ty();
ir::metaparameter *mp0 = ir::metaparameter::create(ctx, ty, 2, 2);
ir::metaparameter *mp1 = ir::metaparameter::create(ctx, ty, 4, 8);
ir::metaparameter *mp2 = ir::metaparameter::create(ctx, ty, 1, 4);
connected_components(*nodes_.begin(), {mp0, mp1, mp2}, nodes_, dependencies_);
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 2);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 8);
connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
}
}
// // Get launch info
// for(ir::function *fn: mod.get_function_list()){
// std::map<ir::metaparameter*, ir::instruction*> references;
// std::vector<ir::instruction*> grids;
// create_grids(grids, references, fn);
// ir::instruction *first = grids.front();
// for(unsigned i = 0; i < first->get_type()->get_tile_shapes().size(); i++){
// std::string suffix = ".d" + std::to_string(i);
// num_threads_mp_vec_.push_back(params_.at(first).at("p1" + suffix));
// num_threads_mp_vec_.push_back(params_.at(first).at("p2" + suffix));
// }
// }
}
void tune::create_grids(std::vector<ir::instruction*> &grids,
@@ -207,16 +191,26 @@ for(ir::function *fn: mod.get_function_list()){
std::vector<ir::instruction*> grids;
create_grids(grids, references, fn);
for(unsigned i = 0; i < grids.front()->get_type()->get_tile_shapes().size(); i++){
std::string suffix = ".d" + std::to_string(i);
num_threads_mp_vec_.push_back(params_.at(grids.front()).at("p1" + suffix));
num_threads_mp_vec_.push_back(params_.at(grids.front()).at("p2" + suffix));
// Returns the number of warps that instruction `i` needs along `axis`, computed as
// tile_shape[axis] / (mts * nts) with unsigned integer division.
// "mts"/"nts" are presumably the micro-tile and nano-tile sizes (per the commit title) — confirm.
// NOTE(review): params_[i][key] uses map operator[], which inserts a null entry when the key
// is missing, so ->get_value() would then dereference a null pointer. This assumes every
// instruction passed in already has "mts.d<k>"/"nts.d<k>" entries — verify against callers.
auto get_num_warps = [&](ir::instruction *i, unsigned axis) {
std::string strk = to_string(axis);
unsigned mts = params_[i]["mts.d" + strk]->get_value();
unsigned nts = params_[i]["nts.d" + strk]->get_value();
unsigned shape = i->get_type()->get_tile_shapes()[axis]->get_value();
// The quotient truncates when shape is not a multiple of mts * nts.
// NOTE(review): a zero product would divide by zero — confirm the parameter ranges exclude 0.
return shape / (mts * nts);
};
num_threads_ = 1;
ir::instruction *first = grids.front();
for(unsigned k = 0; k < first->get_type()->get_tile_shapes().size(); k++){
std::string suffix = ".d" + std::to_string(k);
num_threads_ *= params_.at(first).at("mts" + suffix)->get_value();
num_threads_ *= get_num_warps(first, k);
}
// number of warps
int num_warps = 1;
for(size_t k = 0; k < grids.front()->get_type()->get_tile_shapes().size(); k++)
num_warps *= params_[grids.front()]["p2.d" + to_string(k)]->get_value();
for(size_t k = 0; k < first->get_type()->get_tile_shapes().size(); k++)
num_warps *= get_num_warps(first, k);
// check constraints
for(ir::instruction *i: grids){
@@ -226,10 +220,9 @@ for(ir::function *fn: mod.get_function_list()){
// must divide the shape
for(size_t k = 0; k < shapes.size(); k++) {
std::string strk = to_string(k);
ir::metaparameter *mp0 = params_[i]["p0.d" + strk];
ir::metaparameter *mp1 = params_[i]["p1.d" + strk];
ir::metaparameter *mp2 = params_[i]["p2.d" + strk];
unsigned multiple = mp0->get_value()*mp1->get_value()*mp2->get_value();
ir::metaparameter *mts = params_[i]["mts.d" + strk];
ir::metaparameter *nts = params_[i]["nts.d" + strk];
unsigned multiple = mts->get_value()*nts->get_value();
if(shapes[k]->get_value() % multiple != 0)
errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")"
" is not a multiple of layout (" + to_string(multiple) + ")");
@@ -237,14 +230,14 @@ for(ir::function *fn: mod.get_function_list()){
// the number of threads per warp must be 32
int num_threads = 1;
for(size_t k = 0; k < shapes.size(); k++)
num_threads *= params_[i]["p1.d" + to_string(k)]->get_value();
num_threads *= params_[i]["mts.d" + to_string(k)]->get_value();
if(num_threads != 32)
errors[i].push_back("number of threads per warp (" + to_string(num_threads) + ") must be 32");
// The number of warps required by the layout must be the same
// for all tiles in the function
int required_num_warps = 1;
for(size_t k = 0; k < shapes.size(); k++)
required_num_warps *= params_[i]["p2.d" + to_string(k)]->get_value();
required_num_warps *= get_num_warps(i, k);
if(required_num_warps != num_warps)
errors[i].push_back("number of warps (" + to_string(required_num_warps) + ") must be " + to_string(num_warps));
}
@@ -261,10 +254,7 @@ unsigned tune::get_global_range_size(unsigned axis) {
}
unsigned tune::get_num_threads() {
unsigned result = 1;
for(ir::metaparameter *mp: num_threads_mp_vec_)
result *= mp->get_value();
return result;
return num_threads_;
}

View File

@@ -16,7 +16,7 @@ void vectorize::run(ir::module &mod) {
for(ir::instruction *i: block->get_inst_list())
if(dynamic_cast<ir::copy_to_shared_inst*>(i)){
ir::value *x = i->get_operand(0);
if(params_->get_param(x, "p0.d0")->get_value() == 1)
if(params_->get_param(x, "nts.d0")->get_value() == 1)
continue;
builder.set_insert_point(i);
ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x);

View File

@@ -86,7 +86,6 @@ std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, codegen:
// constraints
std::map<ir::value*, std::vector<std::string>> errors;
tune.check_constraints(module, errors);
std::cout << "errors: " << errors.size() << std::endl;
for(auto &x: errors){
for(auto &e: x.second)
std::cout << x.first->get_name() << " " << e << std::endl;
@@ -150,7 +149,13 @@ void jit::autotune(ir::module &tt_module, benchmark_t benchmark) {
tune.check_constraints(tt_module, errors);
if(errors.size())
return;
std::cout << "valid" << std::endl;
ir::module copy(tt_module);
auto ll_module = make_llvm_module(copy, tune);
driver::module module(driver_context_, &*ll_module);
driver::kernel kernel(module, "matmul");
launch_information info = launch_info_map_.at("matmul");
benchmark(kernel, info);
std::cout << "benchmarked" << std::endl;
});
}