[examples] removed dependency on isaac for auto-tuning

This commit is contained in:
Philippe Tillet
2019-03-11 22:22:43 -04:00
parent 87c85ed50d
commit b73c3bdd25
6 changed files with 86 additions and 32 deletions

View File

@@ -81,13 +81,37 @@ private:
high_resolution_clock::time_point _start; high_resolution_clock::time_point _start;
}; };
template<class T>
T min(std::vector<T> x)
{ return *std::min_element(x.begin(), x.end()); }
template<class OP, class SYNC>
double bench(OP const & op, SYNC const & sync, triton::driver::device const & device)
{
timer tmr;
std::vector<size_t> times;
double total_time = 0;
op();
sync();
while(total_time*1e-9 < 1e-3){
float norm = (float)device.current_sm_clock()/device.max_sm_clock();
tmr.start();
op();
sync();
times.push_back(norm*tmr.get().count());
total_time+=times.back();
}
return min(times);
}
int main() { int main() {
// initialize default compute device // initialize default compute device
auto context = triton::driver::backend::contexts::get_default(); auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
// matrix multiplication parameters // matrix multiplication parameters
size_t M = 512, N = 512, K = 512; size_t M = 512, N = 512, K = 512;
size_t bound = 8;
std::vector<float> hc(M*N); std::vector<float> hc(M*N);
std::vector<float> rc(M*N); std::vector<float> rc(M*N);
std::vector<float> ha(M*K); std::vector<float> ha(M*K);
@@ -112,6 +136,22 @@ int main() {
// benchmark a given matrix multiplication kernel // benchmark a given matrix multiplication kernel
auto benchmark = [&](triton::driver::kernel kernel, auto benchmark = [&](triton::driver::kernel kernel,
triton::jit::launch_information info) { triton::jit::launch_information info) {
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
// fast bounds-checking
unsigned TK = jit.get_int("TK");
unsigned lasti = (grid[0]*TM - 1)*TM + TM - 1;
unsigned lastj = (grid[1]*TN - 1)*TN + TN - 1;
unsigned lastk = TK - 1;
bool AT = false;
bool BT = true;
unsigned last_safe_a = (AT==false)?(M*K - 1 - lasti)/M - lastk : M*K - 1 - lasti*K - lastk;
unsigned last_safe_b = (BT==true)?(N*K - 1 - lastj)/N - lastk : N*K - 1 - lastj*K - lastk;
int32_t bound = std::max<unsigned>(1, std::max(K - last_safe_a, K - last_safe_b));
// set argument
kernel.setArg(0, da); kernel.setArg(0, da);
kernel.setArg(1, db); kernel.setArg(1, db);
kernel.setArg(2, dc); kernel.setArg(2, dc);
@@ -119,39 +159,33 @@ int main() {
kernel.setArg(4, N); kernel.setArg(4, N);
kernel.setArg(5, K); kernel.setArg(5, K);
kernel.setArg(6, bound); kernel.setArg(6, bound);
unsigned TM = info.global_range_size[0]; // dry run
unsigned TN = info.global_range_size[1]; stream.enqueue(kernel, grid, {nthreads, 1, 1});
unsigned nthreads = info.num_threads;
timer t;
t.start();
stream.enqueue(kernel, {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}, {nthreads, 1, 1});
stream.synchronize(); stream.synchronize();
double ts = t.get().count()*1e-9; // benchmark
double ts = bench([&](){stream.enqueue(kernel, grid, {nthreads, 1, 1});},
[&](){ stream.synchronize(); },
context.device());
ts = ts * 1e-9;
double tflops = 2*M*N*K / ts * 1e-12; double tflops = 2*M*N*K / ts * 1e-12;
std::cout << tflops << std::endl; return tflops;
return ts;
}; };
// just-in-time compile source-code // just-in-time compile source-code
std::vector<unsigned> params = { std::vector<unsigned> params = {
// a0 16, 2, 64,
8, 2, 16, 32, 2, 64,
// b0 16, 8, 2, 2,
4, 4, 16, 8, 1, 8,
// c 4, 1
8, 4, 2, 4,
// a1
4, 2, 8,
// b1
8, 1
}; };
triton::jit jit(context);
jit.autotune(src, benchmark); // jit.autotune(src, benchmark);
jit.add_module(src, params); jit.add_module(src, params);
triton::driver::kernel kernel = jit.get_function("matmul"); triton::driver::kernel kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul"); triton::jit::launch_information info = jit.get_launch_info("matmul");
benchmark(kernel, info); std::cout << benchmark(kernel, info) << std::endl;
stream.read(dc, true, 0, hc); stream.read(dc, true, 0, hc);
simple_gemm(rc, ha, hb, M, N, K); simple_gemm(rc, ha, hb, M, N, K);
for(size_t i = 0; i < M*N; i++) for(size_t i = 0; i < M*N; i++)

View File

@@ -74,12 +74,15 @@ public:
functions_list_t &get_function_list() { return functions_; } functions_list_t &get_function_list() { return functions_; }
function *get_or_insert_function(const std::string &name, function_type *ty); function *get_or_insert_function(const std::string &name, function_type *ty);
// Scope // Scope
void add_new_scope() { if(scopes_.empty()) scopes_.push(scope()); else scopes_.push(scope(get_scope())); } void add_new_scope() { if(scopes_.empty()) scopes_.push(scope()); else scopes_.push(scope(get_scope())); }
void pop_scope() { scopes_.pop(); } void pop_scope() { scopes_.pop(); }
scope& get_scope() { return scopes_.top(); } scope& get_scope() { return scopes_.top(); }
// Const allocation // Const allocation
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); } void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
const std::vector<ir::alloc_const*>& allocs() { return allocs_; } const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
// Register global
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
const std::map<std::string, ir::value*>& globals() const { return globals_; }
private: private:
std::string name_; std::string name_;
@@ -96,6 +99,7 @@ private:
std::map<value*, value**> current_phi_; std::map<value*, value**> current_phi_;
std::stack<scope> scopes_; std::stack<scope> scopes_;
std::vector<ir::alloc_const*> allocs_; std::vector<ir::alloc_const*> allocs_;
std::map<std::string, ir::value*> globals_;
}; };
} }

View File

@@ -39,7 +39,7 @@ public:
std::vector<unsigned> global_range_size; std::vector<unsigned> global_range_size;
unsigned num_threads; unsigned num_threads;
}; };
typedef std::function<unsigned(driver::kernel, launch_information)> benchmark_t; typedef std::function<double(driver::kernel, launch_information)> benchmark_t;
struct passes_wrapper { struct passes_wrapper {
passes_wrapper(): shared(&buffer_info), liveness(&buffer_info), passes_wrapper(): shared(&buffer_info), liveness(&buffer_info),
@@ -80,6 +80,7 @@ public:
void add_module(const std::string &src, const std::vector<unsigned>& params = {}); void add_module(const std::string &src, const std::vector<unsigned>& params = {});
driver::kernel get_function(const std::string &name); driver::kernel get_function(const std::string &name);
launch_information get_launch_info(const std::string &name); launch_information get_launch_info(const std::string &name);
unsigned get_int(const std::string &name);
private: private:
std::vector<driver::module> modules_; std::vector<driver::module> modules_;
@@ -87,6 +88,7 @@ private:
llvm::LLVMContext llvm_context_; llvm::LLVMContext llvm_context_;
ir::context triton_context_; ir::context triton_context_;
std::map<std::string, launch_information> launch_info_map_; std::map<std::string, launch_information> launch_info_map_;
std::map<std::string, unsigned> global_ints_;
}; };

View File

@@ -412,7 +412,8 @@ ir::value* initializer::codegen(ir::module * mod) const{
if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){ if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
assert(expr_ == nullptr); assert(expr_ == nullptr);
//TODO: implement ranges //TODO: implement ranges
value = ir::metaparameter::create(mod->get_context(), ty, 8, 64); value = ir::metaparameter::create(mod->get_context(), ty, 8, (name=="TK")?8:64);
mod->register_global(name, value);
} }
if(expr_){ if(expr_){
value = expr_->codegen(mod); value = expr_->codegen(mod);

View File

@@ -144,7 +144,7 @@ void tune::run(ir::module &mod) {
// Layout parameters // Layout parameters
while(!nodes_.empty()){ while(!nodes_.empty()){
ir::type *ty = mod.get_builder().get_int32_ty(); ir::type *ty = mod.get_builder().get_int32_ty();
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 2); ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 2);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32); ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_); connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
} }

View File

@@ -111,6 +111,7 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
} }
// iterate over parameters // iterate over parameters
unsigned i; unsigned i;
double best = 0;
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){ loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
std::map<ir::value*, std::vector<std::string>> errors; std::map<ir::value*, std::vector<std::string>> errors;
i = 0; i = 0;
@@ -142,7 +143,12 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
launch_information info = launch_info_map_.at("matmul"); launch_information info = launch_info_map_.at("matmul");
for(unsigned p: params) for(unsigned p: params)
std::cout << p << " " << std::flush; std::cout << p << " " << std::flush;
benchmark(kernel, info); // add globals
for(auto x: tt_module.globals())
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
double perf = benchmark(kernel, info);
best = std::max(perf, best);
std::cout << perf << " [ " << best << " ] " << std::endl;
}); });
} }
@@ -166,6 +172,9 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params)
auto ll_module = make_llvm_module(tt_module, passes); auto ll_module = make_llvm_module(tt_module, passes);
// llvm module -> machine code // llvm module -> machine code
modules_.push_back(driver::module(driver_context_, &*ll_module)); modules_.push_back(driver::module(driver_context_, &*ll_module));
// add globals
for(auto x: tt_module.globals())
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
} }
void jit::add_module(const std::string &src, const std::vector<unsigned> &params) { void jit::add_module(const std::string &src, const std::vector<unsigned> &params) {
@@ -181,4 +190,8 @@ jit::launch_information jit::get_launch_info(const std::string &name) {
return launch_info_map_.at(name); return launch_info_map_.at(name);
} }
unsigned jit::get_int(const std::string &name){
return global_ints_.at(name);
}
} }