[code generation] reparameterization

Philippe Tillet
2019-03-11 19:30:21 -04:00
parent 614f83baee
commit 87c85ed50d
5 changed files with 29 additions and 25 deletions

View File

@@ -86,7 +86,7 @@ int main() {
   auto context = triton::driver::backend::contexts::get_default();
   // matrix multiplication parameters
-  size_t M = 128, N = 128, K = 128;
+  size_t M = 512, N = 512, K = 512;
   size_t bound = 8;
   std::vector<float> hc(M*N);
   std::vector<float> rc(M*N);

View File

@@ -412,7 +412,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
   if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
     assert(expr_ == nullptr);
     //TODO: implement ranges
-    value = ir::metaparameter::create(mod->get_context(), ty, 8, 128);
+    value = ir::metaparameter::create(mod->get_context(), ty, 8, 64);
   }
   if(expr_){
     value = expr_->codegen(mod);

View File

@@ -371,18 +371,37 @@ std::vector<Value*> delinearize(Value *trailing, std::vector<unsigned> &shapes,
   return result;
 }
 
+inline int32_t ceil(int32_t num, int32_t div){
+  return (num + div - 1)/div;
+}
+
+inline void to_warps(const std::vector<unsigned> &bs, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
+  static const size_t warp_size = 32;
+  size_t nthreads = 1, nwarps = 1;
+  nw.resize(bs.size());
+  ws.resize(bs.size());
+  for(size_t i = 0; i < bs.size(); ++i){
+    nthreads *= bs[i];
+    nw[i] = ceil(nthreads, nwarps*warp_size);
+    nwarps *= nw[i];
+  }
+  for(size_t i = 0; i < bs.size(); ++i)
+    ws[i] = bs[i] / nw[i];
+}
+
 void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
   const auto& shapes = v->get_type()->get_tile_shapes();
   size_t dim = shapes.size();
   std::vector<unsigned> contiguous(dim);
+  std::vector<unsigned> block_size(dim);
   std::vector<unsigned> warp_size(dim);
   std::vector<unsigned> n_warps(dim);
   for(unsigned i = 0; i < shapes.size(); i++){
     std::string str_i = std::to_string(i);
     contiguous[i] = params_->get_param(v, "nts.d" + str_i)->get_value();
-    warp_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
-    n_warps[i] = shapes[i]->get_value() / (contiguous[i] * warp_size[i]);
+    block_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
   }
+  to_warps(block_size, n_warps, warp_size);
   std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, warp_size, builder);
   std::vector<Value*> warp_id = delinearize(u_warp_id, n_warps, builder);
   // Create axes
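
For reference, a minimal standalone sketch of how the new to_warps helper splits a per-dimension block size into warps and threads-per-warp. The helper bodies are copied from the hunk above; the {16, 8} block size is only an illustrative choice, not something fixed by this commit.

// Standalone illustration of the to_warps decomposition added above.
// Assumes the 32-thread warp used in the diff; {16, 8} is an arbitrary example.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

inline int32_t ceil(int32_t num, int32_t div){
  return (num + div - 1)/div;
}

// Greedily assigns warps to dimensions: dimension i receives
// nw[i] = ceil(thread_count_so_far / (warps_so_far * 32)) warps,
// and ws[i] = bs[i] / nw[i] threads per warp along that dimension.
inline void to_warps(const std::vector<unsigned> &bs, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
  static const size_t warp_size = 32;
  size_t nthreads = 1, nwarps = 1;
  nw.resize(bs.size());
  ws.resize(bs.size());
  for(size_t i = 0; i < bs.size(); ++i){
    nthreads *= bs[i];
    nw[i] = ceil(nthreads, nwarps*warp_size);
    nwarps *= nw[i];
  }
  for(size_t i = 0; i < bs.size(); ++i)
    ws[i] = bs[i] / nw[i];
}

int main(){
  std::vector<unsigned> block = {16, 8};   // 128 threads per block
  std::vector<unsigned> n_warps, warp_size;
  to_warps(block, n_warps, warp_size);
  // Prints: warps = {1, 4}, threads/warp = {16, 2}  (4 warps of 16*2 = 32 threads each)
  std::cout << "warps = {" << n_warps[0] << ", " << n_warps[1] << "}, "
            << "threads/warp = {" << warp_size[0] << ", " << warp_size[1] << "}\n";
  return 0;
}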

View File

@@ -144,8 +144,8 @@ void tune::run(ir::module &mod) {
   // Layout parameters
   while(!nodes_.empty()){
     ir::type *ty = mod.get_builder().get_int32_ty();
-    ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 4);
-    ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 8);
+    ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 2);
+    ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
     connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
   }
 }
@@ -157,21 +157,12 @@ void tune::init(ir::module &mod) {
     std::map<ir::metaparameter*, ir::instruction*> references;
     create_grids(grids_, references, fn);
   }
-  // number of warps
-  auto get_num_warps = [&](ir::instruction *i, unsigned axis) {
-    std::string strk = std::to_string(axis);
-    unsigned mts = params_[i]["mts.d" + strk]->get_value();
-    unsigned nts = params_[i]["nts.d" + strk]->get_value();
-    unsigned shape = i->get_type()->get_tile_shapes()[axis]->get_value();
-    return shape / (mts * nts);
-  };
   // number of threads
   num_threads_ = 1;
   ir::instruction *first = grids_.front();
   for(unsigned k = 0; k < first->get_type()->get_tile_shapes().size(); k++){
     std::string suffix = ".d" + std::to_string(k);
     num_threads_ *= params_.at(first).at("mts" + suffix)->get_value();
-    num_threads_ *= get_num_warps(first, k);
   }
 }
@@ -243,15 +234,10 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
     int num_threads = 1;
     for(size_t k = 0; k < shapes.size(); k++)
       num_threads *= params_[i]["mts.d" + to_string(k)]->get_value();
-    if(num_threads != 32)
-      errors[i].push_back("number of threads per warp (" + to_string(num_threads) + ") must be 32");
-    // The number of warps required by the layout is the same
-    // for all tiles in the function
-    int required_num_warps = 1;
-    for(size_t k = 0; k < shapes.size(); k++)
-      required_num_warps *= get_num_warps(i, k);
-    if(required_num_warps != num_warps)
-      errors[i].push_back("number of warps (" + to_string(required_num_warps) + ") must be " + to_string(num_warps));
+    if(num_threads % 32 != 0)
+      errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of 32");
+    if(num_threads != num_threads_)
+      errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
   }
   return errors.empty();
 }
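
A small sketch of what the reparameterized constraint amounts to: mts.d{k} is now the per-dimension block size rather than a per-warp size, so the product over dimensions only needs to be a multiple of 32 and identical across tiles. The loop below mirrors the check in tune::check_constraints above; the parameter values and the expected thread count are hypothetical, hard-coded purely for illustration.

// Illustration of the reparameterized constraint checked above:
// mts.d0 * mts.d1 (threads per block) must be a multiple of the warp size
// and must match the num_threads_ computed in tune::init for every tile.
// The values below are made up for the example.
#include <iostream>
#include <string>
#include <vector>

int main(){
  const unsigned warp_size = 32;
  std::vector<unsigned> mts = {16, 8};          // hypothetical mts.d0, mts.d1
  const unsigned num_threads_expected = 128;    // value tune::init would have computed

  unsigned num_threads = 1;
  for(unsigned m : mts)
    num_threads *= m;

  std::vector<std::string> errors;
  if(num_threads % warp_size != 0)
    errors.push_back("number of threads per block (" + std::to_string(num_threads) +
                     ") must be multiple of 32");
  if(num_threads != num_threads_expected)
    errors.push_back("Number of threads must be the same for all tiles (" +
                     std::to_string(num_threads_expected) + ")");

  std::cout << (errors.empty() ? "constraints satisfied" : errors.front()) << "\n";
  return 0;
}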

View File

@@ -142,7 +142,6 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
     launch_information info = launch_info_map_.at("matmul");
     for(unsigned p: params)
       std::cout << p << " " << std::flush;
-    std::cout << std::endl;
     benchmark(kernel, info);
   });
 }