[jit] basic auto-tuning support

Philippe Tillet
2019-03-11 12:00:50 -04:00
parent 94e315ea8a
commit 614f83baee
5 changed files with 151 additions and 99 deletions

View File

@@ -63,6 +63,24 @@ void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T
}
}
class timer{
typedef std::chrono::high_resolution_clock high_resolution_clock;
typedef std::chrono::nanoseconds nanoseconds;
public:
explicit timer(bool run = false)
{ if (run) start(); }
void start()
{ _start = high_resolution_clock::now(); }
nanoseconds get() const
{ return std::chrono::duration_cast<nanoseconds>(high_resolution_clock::now() - _start); }
private:
high_resolution_clock::time_point _start;
};
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
@@ -90,6 +108,7 @@ int main() {
stream.write(dc, true, 0, hc);
stream.synchronize();
// benchmark a given matrix multiplication kernel
auto benchmark = [&](triton::driver::kernel kernel,
triton::jit::launch_information info) {
@@ -103,25 +122,17 @@ int main() {
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
timer t;
t.start();
stream.enqueue(kernel, {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}, {nthreads, 1, 1});
stream.synchronize();
return float(0);
double ts = t.get().count()*1e-9;
double tflops = 2*M*N*K / ts * 1e-12;
std::cout << tflops << std::endl;
return ts;
};
// std::vector<unsigned> params = {
// // a0
// 2, 8, 1, 16,
// // b0
// 4, 4, 1, 16,
// // c
// 2, 4, 8, 4, 1, 1,
// // a1
// 2, 4, 1, 8,
// // b1
// 1, 8, 1
// };
// just-in-time compile source-code
std::vector<unsigned> params = {
// a0
@@ -136,8 +147,8 @@ int main() {
8, 1
};
triton::jit jit(context);
jit.add_module(src, params);
jit.autotune(src, benchmark);
jit.add_module(src, params);
triton::driver::kernel kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
benchmark(kernel, info);
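
The throughput printed by the benchmark lambda above uses the standard 2·M·N·K flop count for a matrix multiplication (one multiply and one add per inner-product term). A standalone sketch of the same arithmetic, with hypothetical sizes chosen purely for illustration:

#include <iostream>

int main() {
  double M = 2048, N = 2048, K = 2048; // hypothetical problem size
  double ts = 5e-3;                    // hypothetical kernel time in seconds
  // a GEMM performs 2*M*N*K floating-point operations
  double tflops = 2*M*N*K / ts * 1e-12;
  std::cout << tflops << " TFLOPS" << std::endl; // prints ~3.44
}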

View File

@@ -17,6 +17,8 @@ namespace ir{
namespace codegen{
class place_shared_copy;
class tune {
typedef std::pair<ir::value*, unsigned> node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
@@ -35,8 +37,9 @@ public:
std::map<std::string, ir::metaparameter *> get_params(ir::instruction* i);
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
void copy(ir::value *dst, ir::value *src) { params_[dst] = params_[src]; }
bool check_constraints(ir::module &fn, std::map<ir::value *, std::vector<std::string>> &errors);
bool check_constraints(std::map<ir::value *, std::vector<std::string>> &errors);
void run(ir::module &mod);
void init(ir::module &mod);
unsigned get_num_global_range();
unsigned get_global_range_size(unsigned axis);
unsigned get_num_threads();
@@ -50,6 +53,7 @@ private:
std::map<unsigned, ir::metaparameter*> global_range_sizes_;
unsigned num_global_ranges_;
unsigned num_threads_;
std::vector<ir::instruction*> grids_;
};
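
The tuner's interface is now split into phases: run() creates the metaparameters, init() derives grid and thread-count information once values have been assigned, and check_constraints() validates the current assignment without taking the module again. A minimal sketch of the intended call sequence, mirroring how the jit drives the tuner further down in this commit; the candidate value and the get_params(module) overload used here are assumptions beyond what this header hunk shows:

#include <map>
#include <string>
#include <vector>
#include "triton/codegen/tune.h"
#include "triton/ir/module.h"

// returns true if the (placeholder) parameter assignment satisfies the
// tuner's layout constraints
bool try_configuration(triton::ir::module &module, unsigned value) {
  triton::codegen::tune tune;
  tune.run(module);                          // create metaparameters
  for(triton::ir::metaparameter *mp: tune.get_params(module))
    mp->set_value(value);                    // placeholder: same value everywhere
  tune.init(module);                         // compute grids and thread counts
  std::map<triton::ir::value*, std::vector<std::string>> errors;
  return tune.check_constraints(errors);     // false -> reject this configuration
}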

View File

@@ -7,6 +7,14 @@
#include "triton/ir/context.h"
#include "triton/driver/module.h"
#include "triton/driver/kernel.h"
#include "triton/codegen/selection.h"
#include "triton/codegen/tune.h"
#include "triton/codegen/shared_copy.h"
#include "triton/codegen/allocation.h"
#include "triton/codegen/liveness.h"
#include "triton/codegen/vectorize.h"
#include "triton/codegen/buffer_info.h"
#include "triton/codegen/barriers.h"
#include <functional>
namespace llvm {
@@ -33,14 +41,40 @@ public:
};
typedef std::function<unsigned(driver::kernel, launch_information)> benchmark_t;
struct passes_wrapper {
passes_wrapper(): shared(&buffer_info), liveness(&buffer_info),
allocation(&liveness, &buffer_info),
barriers(&allocation, &buffer_info),
vectorize(&tune),
selection(&allocation, &tune, &buffer_info){ }
void init(ir::module &module) {
// generate ptx
buffer_info.run(module);
shared.run(module);
liveness.run(module);
allocation.run();
barriers.run(module);
vectorize.run(module);
}
codegen::tune tune;
codegen::buffer_info_pass buffer_info;
codegen::place_shared_copy shared;
codegen::liveness liveness;
codegen::allocation allocation;
codegen::barriers barriers;
codegen::vectorize vectorize;
codegen::selection selection;
};
private:
std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, codegen::tune &tune);
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes);
std::unique_ptr<ir::module> make_triton_module(const std::string &src);
public:
jit(driver::context context);
void autotune(ir::module &module, benchmark_t benchmark);
void autotune(const std::string &src, benchmark_t benchmark);
void add_module(ir::module &module, const std::vector<unsigned>& params = {});
void add_module(const std::string &src, const std::vector<unsigned>& params = {});
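
passes_wrapper bundles the codegen passes with their dependencies wired up in the constructor, so autotune and add_module can drive one shared pipeline instead of rebuilding it by hand. A sketch of the intended order of operations, mirroring its use in the jit implementation below; the free function, its name, and the assumption that the wrapper is reachable from user code are illustrative only:

#include "triton/jit.h"
#include "llvm/IR/Module.h"

// drives the pre-wired pass pipeline for one parameter assignment;
// 'out' receives the lowered IR
void lower(triton::ir::module &module, llvm::Module &out) {
  triton::jit::passes_wrapper passes;
  passes.tune.run(module);            // create metaparameters
  // ... assign a value to every metaparameter ...
  passes.tune.init(module);           // derive grids and thread counts
  passes.init(module);                // buffer info, shared copies, liveness,
                                      // allocation, barriers, vectorization
  passes.selection.run(module, out);  // lower the Triton-IR module to LLVM IR
}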

View File

@@ -1,4 +1,5 @@
#include "triton/codegen/tune.h"
#include "triton/codegen/shared_copy.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include "triton/ir/module.h"
@@ -143,13 +144,37 @@ void tune::run(ir::module &mod) {
// Layout parameters
while(!nodes_.empty()){
ir::type *ty = mod.get_builder().get_int32_ty();
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 2);
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 4);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 8);
connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
}
}
}
void tune::init(ir::module &mod) {
for(ir::function *fn: mod.get_function_list()){
// initialize grids
std::map<ir::metaparameter*, ir::instruction*> references;
create_grids(grids_, references, fn);
}
// number of warps
auto get_num_warps = [&](ir::instruction *i, unsigned axis) {
std::string strk = std::to_string(axis);
unsigned mts = params_[i]["mts.d" + strk]->get_value();
unsigned nts = params_[i]["nts.d" + strk]->get_value();
unsigned shape = i->get_type()->get_tile_shapes()[axis]->get_value();
return shape / (mts * nts);
};
// number of threads
num_threads_ = 1;
ir::instruction *first = grids_.front();
for(unsigned k = 0; k < first->get_type()->get_tile_shapes().size(); k++){
std::string suffix = ".d" + std::to_string(k);
num_threads_ *= params_.at(first).at("mts" + suffix)->get_value();
num_threads_ *= get_num_warps(first, k);
}
}
void tune::create_grids(std::vector<ir::instruction*> &grids,
std::map<ir::metaparameter*, ir::instruction*> &references,
ir::function *fn) {
@@ -182,15 +207,9 @@ void tune::create_grids(std::vector<ir::instruction*> &grids,
}
bool tune::check_constraints(ir::module &mod, std::map<ir::value *, std::vector<std::string>> &errors) {
for(ir::function *fn: mod.get_function_list()){
bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &errors) {
using std::to_string;
// initialize grids
std::map<ir::metaparameter*, ir::instruction*> references;
std::vector<ir::instruction*> grids;
create_grids(grids, references, fn);
auto get_num_warps = [&](ir::instruction *i, unsigned axis) {
std::string strk = to_string(axis);
unsigned mts = params_[i]["mts.d" + strk]->get_value();
@@ -199,21 +218,14 @@ for(ir::function *fn: mod.get_function_list()){
return shape / (mts * nts);
};
num_threads_ = 1;
ir::instruction *first = grids.front();
for(unsigned k = 0; k < first->get_type()->get_tile_shapes().size(); k++){
std::string suffix = ".d" + std::to_string(k);
num_threads_ *= params_.at(first).at("mts" + suffix)->get_value();
num_threads_ *= get_num_warps(first, k);
}
// number of warps
ir::instruction *first = grids_.front();
int num_warps = 1;
for(size_t k = 0; k < first->get_type()->get_tile_shapes().size(); k++)
num_warps *= get_num_warps(first, k);
// check constraints
for(ir::instruction *i: grids){
for(ir::instruction *i: grids_){
ir::type *ty = i->get_type();
const auto &shapes = ty->get_tile_shapes();
// for each dimension, the product of layout components
@@ -243,7 +255,6 @@ for(ir::function *fn: mod.get_function_list()){
}
return errors.empty();
}
}
unsigned tune::get_num_global_range() {
return num_global_ranges_;
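
tune::init above computes the block's thread count from the first grid instruction: each tile dimension contributes shape/(mts·nts) warps and therefore mts·(shape/(mts·nts)) = shape/nts threads. A small worked example of that arithmetic with hypothetical values (chosen so that everything divides evenly):

#include <iostream>

int main() {
  // hypothetical 2D tile
  unsigned shape[2] = {64, 32};
  unsigned mts[2]   = {16, 8};   // "mts" metaparameter per dimension
  unsigned nts[2]   = {4, 2};    // "nts" metaparameter per dimension
  unsigned num_threads = 1;
  for(unsigned k = 0; k < 2; k++){
    unsigned num_warps = shape[k] / (mts[k] * nts[k]); // 1, then 2
    num_threads *= mts[k] * num_warps;                 // 16, then 16*16
  }
  std::cout << num_threads << std::endl; // 256, i.e. (64/4) * (32/2)
}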

View File

@@ -1,16 +1,9 @@
#include "triton/jit.h"
#include "triton/jit.h"
#include <string>
#include "triton/ast/ast.h"
#include "triton/ir/context.h"
#include "triton/ir/context_impl.h"
#include "triton/codegen/selection.h"
#include "triton/codegen/tune.h"
#include "triton/codegen/shared_copy.h"
#include "triton/codegen/allocation.h"
#include "triton/codegen/liveness.h"
#include "triton/codegen/vectorize.h"
#include "triton/codegen/buffer_info.h"
#include "triton/codegen/barriers.h"
#include "triton/driver/device.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LLVMContext.h"
@@ -71,42 +64,15 @@ void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, codegen::tune & tune) {
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes) {
llvm::Module* result = new llvm::Module("matmul", llvm_context_);
// create passes
codegen::buffer_info_pass buffer_info;
codegen::place_shared_copy shared(&buffer_info);
codegen::liveness liveness(&buffer_info);
codegen::allocation allocation(&liveness, &buffer_info);
codegen::barriers barriers(&allocation, &buffer_info);
codegen::vectorize vectorize(&tune);
codegen::selection selection(&allocation, &tune, &buffer_info);
// constraints
std::map<ir::value*, std::vector<std::string>> errors;
tune.check_constraints(module, errors);
for(auto &x: errors){
for(auto &e: x.second)
std::cout << x.first->get_name() << " " << e << std::endl;
}
if(errors.size())
exit(EXIT_FAILURE);
// generate ptx
buffer_info.run(module);
shared.run(module);
liveness.run(module);
allocation.run();
barriers.run(module);
vectorize.run(module);
selection.run(module, *result);
passes.selection.run(module, *result);
// launch information
auto &launch_info_map = launch_info_map_[result->getName()];
for(unsigned i = 0; i < tune.get_num_global_range(); i++)
launch_info_map.global_range_size.push_back(tune.get_global_range_size(i));
launch_info_map.num_threads = tune.get_num_threads();
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
launch_info_map.global_range_size.push_back(passes.tune.get_global_range_size(i));
launch_info_map.num_threads = passes.tune.get_num_threads();
return std::unique_ptr<llvm::Module>(result);
}
@@ -127,11 +93,14 @@ jit::jit(driver::context context): driver_context_(context) {
}
void jit::autotune(ir::module &tt_module, benchmark_t benchmark) {
void jit::autotune(const std::string &src, benchmark_t benchmark) {
// find metaparameters
codegen::tune tune;
tune.run(tt_module);
auto mps = tune.get_params(tt_module);
auto ptt_module = make_triton_module(src);
ir::module &tt_module = *ptt_module;
// set parameters
passes_wrapper passes;
passes.tune.run(tt_module);
auto mps = passes.tune.get_params(tt_module);
// create parameter ranges
std::vector<std::vector<unsigned>> ranges;
for(ir::metaparameter *mp: mps){
@@ -141,39 +110,62 @@ void jit::autotune(ir::module &tt_module, benchmark_t benchmark) {
ranges.push_back(current);
}
// iterate over parameters
unsigned i;
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
std::map<ir::value*, std::vector<std::string>> errors;
unsigned i = 0;
i = 0;
for(ir::metaparameter *mp: mps)
mp->set_value(params[i++]);
tune.check_constraints(tt_module, errors);
if(errors.size())
passes.tune.init(tt_module);
if(!passes.tune.check_constraints(errors))
return;
ir::module copy(tt_module);
auto ll_module = make_llvm_module(copy, tune);
// Deep copy of the module and tuner
auto ptt_module = make_triton_module(src);
ir::module &tt_module = *ptt_module;
passes_wrapper passes;
passes.tune.run(tt_module);
i = 0;
for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){
mp->set_value(params[i++]);
}
passes.tune.init(tt_module);
passes.init(tt_module);
const driver::device &device = driver_context_.device();
if(passes.allocation.get_allocated_size() > device.max_shared_memory())
return;
if(passes.tune.get_num_threads() > device.max_threads_per_block())
return;
// Compile
auto ll_module = make_llvm_module(tt_module, passes);
driver::module module(driver_context_, &*ll_module);
driver::kernel kernel(module, "matmul");
launch_information info = launch_info_map_.at("matmul");
for(unsigned p: params)
std::cout << p << " " << std::flush;
std::cout << std::endl;
benchmark(kernel, info);
std::cout << "benchmarked" << std::endl;
});
}
void jit::autotune(const std::string &src, benchmark_t benchmark) {
auto ptt_module = make_triton_module(src);
autotune(*ptt_module, benchmark);
}
void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params) {
// set parameters
codegen::tune tune;
tune.run(tt_module);
passes_wrapper passes;
passes.tune.run(tt_module);
unsigned i = 0;
for(ir::metaparameter* mp: tune.get_params(tt_module))
for(ir::metaparameter* mp: passes.tune.get_params(tt_module))
mp->set_value(params[i++]);
// compiler to llvm
auto ll_module = make_llvm_module(tt_module, tune);
// send llvm module to driver
passes.tune.init(tt_module);
passes.init(tt_module);
// check constraints
std::map<ir::value*, std::vector<std::string>> errors;
passes.tune.check_constraints(errors);
if(errors.size())
throw std::runtime_error("invalid parameters");
if(passes.allocation.get_allocated_size() > driver_context_.device().max_shared_memory())
throw std::runtime_error("invalid parameters");
// triton module -> llvm module
auto ll_module = make_llvm_module(tt_module, passes);
// llvm module -> machine code
modules_.push_back(driver::module(driver_context_, &*ll_module));
}
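
jit::autotune walks the Cartesian product of the metaparameter ranges through the loop_nest helper, whose body is not shown in this diff. A sketch of one way such a helper can be written, matching the signature in the hunk header above but not necessarily the actual implementation:

#include <cstddef>
#include <functional>
#include <vector>

template<class T>
void loop_nest(std::vector<std::vector<T>> const &iterates,
               std::function<void(std::vector<T>)> const &f) {
  std::vector<T> current(iterates.size());
  std::function<void(size_t)> rec = [&](size_t d) {
    if(d == iterates.size()){ f(current); return; } // full tuple assembled
    for(T const &v: iterates[d]){                   // fix dimension d,
      current[d] = v;                               // recurse on the rest
      rec(d + 1);
    }
  };
  rec(0);
}

In the autotuner, each inner vector holds the candidate values of one metaparameter, and the callback is the lambda that sets those values, re-runs the passes, and benchmarks the resulting kernel.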