[jit] basic auto-tuning support
This commit is contained in:
@@ -63,6 +63,24 @@ void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T
|
||||
}
|
||||
}
|
||||
|
||||
class timer{
|
||||
typedef std::chrono::high_resolution_clock high_resolution_clock;
|
||||
typedef std::chrono::nanoseconds nanoseconds;
|
||||
|
||||
public:
|
||||
explicit timer(bool run = false)
|
||||
{ if (run) start(); }
|
||||
|
||||
void start()
|
||||
{ _start = high_resolution_clock::now(); }
|
||||
|
||||
nanoseconds get() const
|
||||
{ return std::chrono::duration_cast<nanoseconds>(high_resolution_clock::now() - _start); }
|
||||
|
||||
private:
|
||||
high_resolution_clock::time_point _start;
|
||||
};
|
||||
|
||||
int main() {
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
@@ -90,6 +108,7 @@ int main() {
|
||||
stream.write(dc, true, 0, hc);
|
||||
stream.synchronize();
|
||||
|
||||
|
||||
// benchmark a given matrix multiplication kernel
|
||||
auto benchmark = [&](triton::driver::kernel kernel,
|
||||
triton::jit::launch_information info) {
|
||||
@@ -103,25 +122,17 @@ int main() {
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
timer t;
|
||||
t.start();
|
||||
stream.enqueue(kernel, {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}, {nthreads, 1, 1});
|
||||
stream.synchronize();
|
||||
return float(0);
|
||||
double ts = t.get().count()*1e-9;
|
||||
double tflops = 2*M*N*K / ts * 1e-12;
|
||||
std::cout << tflops << std::endl;
|
||||
return ts;
|
||||
};
|
||||
|
||||
|
||||
// std::vector<unsigned> params = {
|
||||
// // a0
|
||||
// 2, 8, 1, 16,
|
||||
// // b0
|
||||
// 4, 4, 1, 16,
|
||||
// // c
|
||||
// 2, 4, 8, 4, 1, 1,
|
||||
// // a1
|
||||
// 2, 4, 1, 8,
|
||||
// // b1
|
||||
// 1, 8, 1
|
||||
// };
|
||||
|
||||
// just-in-time compile source-code
|
||||
std::vector<unsigned> params = {
|
||||
// a0
|
||||
@@ -136,8 +147,8 @@ int main() {
|
||||
8, 1
|
||||
};
|
||||
triton::jit jit(context);
|
||||
jit.add_module(src, params);
|
||||
jit.autotune(src, benchmark);
|
||||
jit.add_module(src, params);
|
||||
triton::driver::kernel kernel = jit.get_function("matmul");
|
||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||
benchmark(kernel, info);
|
||||
|
@@ -17,6 +17,8 @@ namespace ir{
|
||||
|
||||
namespace codegen{
|
||||
|
||||
class place_shared_copy;
|
||||
|
||||
class tune {
|
||||
typedef std::pair<ir::value*, unsigned> node_t;
|
||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||
@@ -35,8 +37,9 @@ public:
|
||||
std::map<std::string, ir::metaparameter *> get_params(ir::instruction* i);
|
||||
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
|
||||
void copy(ir::value *dst, ir::value *src) { params_[dst] = params_[src]; }
|
||||
bool check_constraints(ir::module &fn, std::map<ir::value *, std::vector<std::string>> &errors);
|
||||
bool check_constraints(std::map<ir::value *, std::vector<std::string>> &errors);
|
||||
void run(ir::module &mod);
|
||||
void init(ir::module &mod);
|
||||
unsigned get_num_global_range();
|
||||
unsigned get_global_range_size(unsigned axis);
|
||||
unsigned get_num_threads();
|
||||
@@ -50,6 +53,7 @@ private:
|
||||
std::map<unsigned, ir::metaparameter*> global_range_sizes_;
|
||||
unsigned num_global_ranges_;
|
||||
unsigned num_threads_;
|
||||
std::vector<ir::instruction*> grids_;
|
||||
};
|
||||
|
||||
|
||||
|
@@ -7,6 +7,14 @@
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/driver/module.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/codegen/tune.h"
|
||||
#include "triton/codegen/shared_copy.h"
|
||||
#include "triton/codegen/allocation.h"
|
||||
#include "triton/codegen/liveness.h"
|
||||
#include "triton/codegen/vectorize.h"
|
||||
#include "triton/codegen/buffer_info.h"
|
||||
#include "triton/codegen/barriers.h"
|
||||
#include <functional>
|
||||
|
||||
namespace llvm {
|
||||
@@ -33,14 +41,40 @@ public:
|
||||
};
|
||||
typedef std::function<unsigned(driver::kernel, launch_information)> benchmark_t;
|
||||
|
||||
struct passes_wrapper {
|
||||
passes_wrapper(): shared(&buffer_info), liveness(&buffer_info),
|
||||
allocation(&liveness, &buffer_info),
|
||||
barriers(&allocation, &buffer_info),
|
||||
vectorize(&tune),
|
||||
selection(&allocation, &tune, &buffer_info){ }
|
||||
|
||||
void init(ir::module &module) {
|
||||
// generate ptx
|
||||
buffer_info.run(module);
|
||||
shared.run(module);
|
||||
liveness.run(module);
|
||||
allocation.run();
|
||||
barriers.run(module);
|
||||
vectorize.run(module);
|
||||
}
|
||||
|
||||
codegen::tune tune;
|
||||
codegen::buffer_info_pass buffer_info;
|
||||
codegen::place_shared_copy shared;
|
||||
codegen::liveness liveness;
|
||||
codegen::allocation allocation;
|
||||
codegen::barriers barriers;
|
||||
codegen::vectorize vectorize;
|
||||
codegen::selection selection;
|
||||
};
|
||||
|
||||
private:
|
||||
std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
|
||||
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, codegen::tune &tune);
|
||||
std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes);
|
||||
std::unique_ptr<ir::module> make_triton_module(const std::string &src);
|
||||
|
||||
public:
|
||||
jit(driver::context context);
|
||||
void autotune(ir::module &module, benchmark_t benchmark);
|
||||
void autotune(const std::string &src, benchmark_t benchmark);
|
||||
void add_module(ir::module &module, const std::vector<unsigned>& params = {});
|
||||
void add_module(const std::string &src, const std::vector<unsigned>& params = {});
|
||||
|
@@ -1,4 +1,5 @@
|
||||
#include "triton/codegen/tune.h"
|
||||
#include "triton/codegen/shared_copy.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/module.h"
|
||||
@@ -143,13 +144,37 @@ void tune::run(ir::module &mod) {
|
||||
// Layout parameters
|
||||
while(!nodes_.empty()){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 2);
|
||||
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 4);
|
||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 8);
|
||||
connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void tune::init(ir::module &mod) {
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// initialize grids
|
||||
std::map<ir::metaparameter*, ir::instruction*> references;
|
||||
create_grids(grids_, references, fn);
|
||||
}
|
||||
// number of warps
|
||||
auto get_num_warps = [&](ir::instruction *i, unsigned axis) {
|
||||
std::string strk = std::to_string(axis);
|
||||
unsigned mts = params_[i]["mts.d" + strk]->get_value();
|
||||
unsigned nts = params_[i]["nts.d" + strk]->get_value();
|
||||
unsigned shape = i->get_type()->get_tile_shapes()[axis]->get_value();
|
||||
return shape / (mts * nts);
|
||||
};
|
||||
// number of threads
|
||||
num_threads_ = 1;
|
||||
ir::instruction *first = grids_.front();
|
||||
for(unsigned k = 0; k < first->get_type()->get_tile_shapes().size(); k++){
|
||||
std::string suffix = ".d" + std::to_string(k);
|
||||
num_threads_ *= params_.at(first).at("mts" + suffix)->get_value();
|
||||
num_threads_ *= get_num_warps(first, k);
|
||||
}
|
||||
}
|
||||
|
||||
void tune::create_grids(std::vector<ir::instruction*> &grids,
|
||||
std::map<ir::metaparameter*, ir::instruction*> &references,
|
||||
ir::function *fn) {
|
||||
@@ -182,15 +207,9 @@ void tune::create_grids(std::vector<ir::instruction*> &grids,
|
||||
}
|
||||
|
||||
|
||||
bool tune::check_constraints(ir::module &mod, std::map<ir::value *, std::vector<std::string>> &errors) {
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &errors) {
|
||||
using std::to_string;
|
||||
|
||||
// initialize grids
|
||||
std::map<ir::metaparameter*, ir::instruction*> references;
|
||||
std::vector<ir::instruction*> grids;
|
||||
create_grids(grids, references, fn);
|
||||
|
||||
auto get_num_warps = [&](ir::instruction *i, unsigned axis) {
|
||||
std::string strk = to_string(axis);
|
||||
unsigned mts = params_[i]["mts.d" + strk]->get_value();
|
||||
@@ -199,21 +218,14 @@ for(ir::function *fn: mod.get_function_list()){
|
||||
return shape / (mts * nts);
|
||||
};
|
||||
|
||||
num_threads_ = 1;
|
||||
ir::instruction *first = grids.front();
|
||||
for(unsigned k = 0; k < first->get_type()->get_tile_shapes().size(); k++){
|
||||
std::string suffix = ".d" + std::to_string(k);
|
||||
num_threads_ *= params_.at(first).at("mts" + suffix)->get_value();
|
||||
num_threads_ *= get_num_warps(first, k);
|
||||
}
|
||||
|
||||
// number of warps
|
||||
ir::instruction *first = grids_.front();
|
||||
int num_warps = 1;
|
||||
for(size_t k = 0; k < first->get_type()->get_tile_shapes().size(); k++)
|
||||
num_warps *= get_num_warps(first, k);
|
||||
|
||||
// check constraints
|
||||
for(ir::instruction *i: grids){
|
||||
for(ir::instruction *i: grids_){
|
||||
ir::type *ty = i->get_type();
|
||||
const auto &shapes = ty->get_tile_shapes();
|
||||
// for each dimension, the product of layout components
|
||||
@@ -243,7 +255,6 @@ for(ir::function *fn: mod.get_function_list()){
|
||||
}
|
||||
return errors.empty();
|
||||
}
|
||||
}
|
||||
|
||||
unsigned tune::get_num_global_range() {
|
||||
return num_global_ranges_;
|
||||
|
118
lib/jit.cpp
118
lib/jit.cpp
@@ -1,16 +1,9 @@
|
||||
#include "triton/jit.h"
|
||||
#include "triton/jit.h"
|
||||
#include <string>
|
||||
#include "triton/ast/ast.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/context_impl.h"
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/codegen/tune.h"
|
||||
#include "triton/codegen/shared_copy.h"
|
||||
#include "triton/codegen/allocation.h"
|
||||
#include "triton/codegen/liveness.h"
|
||||
#include "triton/codegen/vectorize.h"
|
||||
#include "triton/codegen/buffer_info.h"
|
||||
#include "triton/codegen/barriers.h"
|
||||
#include "triton/driver/device.h"
|
||||
#include "llvm/IR/IRPrintingPasses.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
@@ -71,42 +64,15 @@ void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(
|
||||
|
||||
|
||||
|
||||
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, codegen::tune & tune) {
|
||||
|
||||
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes) {
|
||||
llvm::Module* result = new llvm::Module("matmul", llvm_context_);
|
||||
|
||||
// create passes
|
||||
codegen::buffer_info_pass buffer_info;
|
||||
codegen::place_shared_copy shared(&buffer_info);
|
||||
codegen::liveness liveness(&buffer_info);
|
||||
codegen::allocation allocation(&liveness, &buffer_info);
|
||||
codegen::barriers barriers(&allocation, &buffer_info);
|
||||
codegen::vectorize vectorize(&tune);
|
||||
codegen::selection selection(&allocation, &tune, &buffer_info);
|
||||
|
||||
// constraints
|
||||
std::map<ir::value*, std::vector<std::string>> errors;
|
||||
tune.check_constraints(module, errors);
|
||||
for(auto &x: errors){
|
||||
for(auto &e: x.second)
|
||||
std::cout << x.first->get_name() << " " << e << std::endl;
|
||||
}
|
||||
if(errors.size())
|
||||
exit(EXIT_FAILURE);
|
||||
|
||||
// generate ptx
|
||||
buffer_info.run(module);
|
||||
shared.run(module);
|
||||
liveness.run(module);
|
||||
allocation.run();
|
||||
barriers.run(module);
|
||||
vectorize.run(module);
|
||||
selection.run(module, *result);
|
||||
|
||||
passes.selection.run(module, *result);
|
||||
// launch information
|
||||
auto &launch_info_map = launch_info_map_[result->getName()];
|
||||
for(unsigned i = 0; i < tune.get_num_global_range(); i++)
|
||||
launch_info_map.global_range_size.push_back(tune.get_global_range_size(i));
|
||||
launch_info_map.num_threads = tune.get_num_threads();
|
||||
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
|
||||
launch_info_map.global_range_size.push_back(passes.tune.get_global_range_size(i));
|
||||
launch_info_map.num_threads = passes.tune.get_num_threads();
|
||||
return std::unique_ptr<llvm::Module>(result);
|
||||
}
|
||||
|
||||
@@ -127,11 +93,14 @@ jit::jit(driver::context context): driver_context_(context) {
|
||||
}
|
||||
|
||||
|
||||
void jit::autotune(ir::module &tt_module, benchmark_t benchmark) {
|
||||
void jit::autotune(const std::string &src, benchmark_t benchmark) {
|
||||
// find metaparameters
|
||||
codegen::tune tune;
|
||||
tune.run(tt_module);
|
||||
auto mps = tune.get_params(tt_module);
|
||||
auto ptt_module = make_triton_module(src);
|
||||
ir::module &tt_module = *ptt_module;
|
||||
// set parameters
|
||||
passes_wrapper passes;
|
||||
passes.tune.run(tt_module);
|
||||
auto mps = passes.tune.get_params(tt_module);
|
||||
// create parameter ranges
|
||||
std::vector<std::vector<unsigned>> ranges;
|
||||
for(ir::metaparameter *mp: mps){
|
||||
@@ -141,39 +110,62 @@ void jit::autotune(ir::module &tt_module, benchmark_t benchmark) {
|
||||
ranges.push_back(current);
|
||||
}
|
||||
// iterate over parameters
|
||||
unsigned i;
|
||||
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
|
||||
std::map<ir::value*, std::vector<std::string>> errors;
|
||||
unsigned i = 0;
|
||||
i = 0;
|
||||
for(ir::metaparameter *mp: mps)
|
||||
mp->set_value(params[i++]);
|
||||
tune.check_constraints(tt_module, errors);
|
||||
if(errors.size())
|
||||
passes.tune.init(tt_module);
|
||||
if(!passes.tune.check_constraints(errors))
|
||||
return;
|
||||
ir::module copy(tt_module);
|
||||
auto ll_module = make_llvm_module(copy, tune);
|
||||
// Deep copy of the module and tuner
|
||||
auto ptt_module = make_triton_module(src);
|
||||
ir::module &tt_module = *ptt_module;
|
||||
passes_wrapper passes;
|
||||
passes.tune.run(tt_module);
|
||||
i = 0;
|
||||
for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){
|
||||
mp->set_value(params[i++]);
|
||||
}
|
||||
passes.tune.init(tt_module);
|
||||
passes.init(tt_module);
|
||||
const driver::device &device = driver_context_.device();
|
||||
if(passes.allocation.get_allocated_size() > device.max_shared_memory())
|
||||
return;
|
||||
if(passes.tune.get_num_threads() > device.max_threads_per_block())
|
||||
return;
|
||||
// Compile
|
||||
auto ll_module = make_llvm_module(tt_module, passes);
|
||||
driver::module module(driver_context_, &*ll_module);
|
||||
driver::kernel kernel(module, "matmul");
|
||||
launch_information info = launch_info_map_.at("matmul");
|
||||
for(unsigned p: params)
|
||||
std::cout << p << " " << std::flush;
|
||||
std::cout << std::endl;
|
||||
benchmark(kernel, info);
|
||||
std::cout << "benchmarked" << std::endl;
|
||||
});
|
||||
}
|
||||
|
||||
void jit::autotune(const std::string &src, benchmark_t benchmark) {
|
||||
auto ptt_module = make_triton_module(src);
|
||||
autotune(*ptt_module, benchmark);
|
||||
}
|
||||
|
||||
void jit::add_module(ir::module &tt_module, const std::vector<unsigned> ¶ms) {
|
||||
// set parameters
|
||||
codegen::tune tune;
|
||||
tune.run(tt_module);
|
||||
passes_wrapper passes;
|
||||
passes.tune.run(tt_module);
|
||||
unsigned i = 0;
|
||||
for(ir::metaparameter* mp: tune.get_params(tt_module))
|
||||
for(ir::metaparameter* mp: passes.tune.get_params(tt_module))
|
||||
mp->set_value(params[i++]);
|
||||
// compiler to llvm
|
||||
auto ll_module = make_llvm_module(tt_module, tune);
|
||||
// send llvm module to driver
|
||||
passes.tune.init(tt_module);
|
||||
passes.init(tt_module);
|
||||
// check constraints
|
||||
std::map<ir::value*, std::vector<std::string>> errors;
|
||||
passes.tune.check_constraints(errors);
|
||||
if(errors.size())
|
||||
throw std::runtime_error("invalid parameters");
|
||||
if(passes.allocation.get_allocated_size() > driver_context_.device().max_shared_memory())
|
||||
throw std::runtime_error("invalid parameters");
|
||||
// triton module -> llvm module
|
||||
auto ll_module = make_llvm_module(tt_module, passes);
|
||||
// llvm module -> machine code
|
||||
modules_.push_back(driver::module(driver_context_, &*ll_module));
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user