[examples] removed dependency on isaac for auto-tuning
This commit is contained in:
@@ -81,13 +81,37 @@ private:
|
|||||||
high_resolution_clock::time_point _start;
|
high_resolution_clock::time_point _start;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
T min(std::vector<T> x)
|
||||||
|
{ return *std::min_element(x.begin(), x.end()); }
|
||||||
|
|
||||||
|
|
||||||
|
template<class OP, class SYNC>
|
||||||
|
double bench(OP const & op, SYNC const & sync, triton::driver::device const & device)
|
||||||
|
{
|
||||||
|
timer tmr;
|
||||||
|
std::vector<size_t> times;
|
||||||
|
double total_time = 0;
|
||||||
|
op();
|
||||||
|
sync();
|
||||||
|
while(total_time*1e-9 < 1e-3){
|
||||||
|
float norm = (float)device.current_sm_clock()/device.max_sm_clock();
|
||||||
|
tmr.start();
|
||||||
|
op();
|
||||||
|
sync();
|
||||||
|
times.push_back(norm*tmr.get().count());
|
||||||
|
total_time+=times.back();
|
||||||
|
}
|
||||||
|
return min(times);
|
||||||
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
// initialize default compute device
|
// initialize default compute device
|
||||||
auto context = triton::driver::backend::contexts::get_default();
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
|
triton::jit jit(context);
|
||||||
|
|
||||||
// matrix multiplication parameters
|
// matrix multiplication parameters
|
||||||
size_t M = 512, N = 512, K = 512;
|
size_t M = 512, N = 512, K = 512;
|
||||||
size_t bound = 8;
|
|
||||||
std::vector<float> hc(M*N);
|
std::vector<float> hc(M*N);
|
||||||
std::vector<float> rc(M*N);
|
std::vector<float> rc(M*N);
|
||||||
std::vector<float> ha(M*K);
|
std::vector<float> ha(M*K);
|
||||||
@@ -112,6 +136,22 @@ int main() {
|
|||||||
// benchmark a given matrix multiplication kernel
|
// benchmark a given matrix multiplication kernel
|
||||||
auto benchmark = [&](triton::driver::kernel kernel,
|
auto benchmark = [&](triton::driver::kernel kernel,
|
||||||
triton::jit::launch_information info) {
|
triton::jit::launch_information info) {
|
||||||
|
// launch info
|
||||||
|
unsigned TM = info.global_range_size[0];
|
||||||
|
unsigned TN = info.global_range_size[1];
|
||||||
|
unsigned nthreads = info.num_threads;
|
||||||
|
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
|
||||||
|
// fast bounds-checking
|
||||||
|
unsigned TK = jit.get_int("TK");
|
||||||
|
unsigned lasti = (grid[0]*TM - 1)*TM + TM - 1;
|
||||||
|
unsigned lastj = (grid[1]*TN - 1)*TN + TN - 1;
|
||||||
|
unsigned lastk = TK - 1;
|
||||||
|
bool AT = false;
|
||||||
|
bool BT = true;
|
||||||
|
unsigned last_safe_a = (AT==false)?(M*K - 1 - lasti)/M - lastk : M*K - 1 - lasti*K - lastk;
|
||||||
|
unsigned last_safe_b = (BT==true)?(N*K - 1 - lastj)/N - lastk : N*K - 1 - lastj*K - lastk;
|
||||||
|
int32_t bound = std::max<unsigned>(1, std::max(K - last_safe_a, K - last_safe_b));
|
||||||
|
// set argument
|
||||||
kernel.setArg(0, da);
|
kernel.setArg(0, da);
|
||||||
kernel.setArg(1, db);
|
kernel.setArg(1, db);
|
||||||
kernel.setArg(2, dc);
|
kernel.setArg(2, dc);
|
||||||
@@ -119,39 +159,33 @@ int main() {
|
|||||||
kernel.setArg(4, N);
|
kernel.setArg(4, N);
|
||||||
kernel.setArg(5, K);
|
kernel.setArg(5, K);
|
||||||
kernel.setArg(6, bound);
|
kernel.setArg(6, bound);
|
||||||
unsigned TM = info.global_range_size[0];
|
// dry run
|
||||||
unsigned TN = info.global_range_size[1];
|
stream.enqueue(kernel, grid, {nthreads, 1, 1});
|
||||||
unsigned nthreads = info.num_threads;
|
|
||||||
timer t;
|
|
||||||
t.start();
|
|
||||||
stream.enqueue(kernel, {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}, {nthreads, 1, 1});
|
|
||||||
stream.synchronize();
|
stream.synchronize();
|
||||||
double ts = t.get().count()*1e-9;
|
// benchmark
|
||||||
|
double ts = bench([&](){stream.enqueue(kernel, grid, {nthreads, 1, 1});},
|
||||||
|
[&](){ stream.synchronize(); },
|
||||||
|
context.device());
|
||||||
|
ts = ts * 1e-9;
|
||||||
double tflops = 2*M*N*K / ts * 1e-12;
|
double tflops = 2*M*N*K / ts * 1e-12;
|
||||||
std::cout << tflops << std::endl;
|
return tflops;
|
||||||
return ts;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// just-in-time compile source-code
|
// just-in-time compile source-code
|
||||||
std::vector<unsigned> params = {
|
std::vector<unsigned> params = {
|
||||||
// a0
|
16, 2, 64,
|
||||||
8, 2, 16,
|
32, 2, 64,
|
||||||
// b0
|
16, 8, 2, 2,
|
||||||
4, 4, 16,
|
8, 1, 8,
|
||||||
// c
|
4, 1
|
||||||
8, 4, 2, 4,
|
|
||||||
// a1
|
|
||||||
4, 2, 8,
|
|
||||||
// b1
|
|
||||||
8, 1
|
|
||||||
};
|
};
|
||||||
triton::jit jit(context);
|
|
||||||
jit.autotune(src, benchmark);
|
// jit.autotune(src, benchmark);
|
||||||
jit.add_module(src, params);
|
jit.add_module(src, params);
|
||||||
triton::driver::kernel kernel = jit.get_function("matmul");
|
triton::driver::kernel kernel = jit.get_function("matmul");
|
||||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||||
benchmark(kernel, info);
|
std::cout << benchmark(kernel, info) << std::endl;
|
||||||
stream.read(dc, true, 0, hc);
|
stream.read(dc, true, 0, hc);
|
||||||
simple_gemm(rc, ha, hb, M, N, K);
|
simple_gemm(rc, ha, hb, M, N, K);
|
||||||
for(size_t i = 0; i < M*N; i++)
|
for(size_t i = 0; i < M*N; i++)
|
||||||
|
@@ -74,12 +74,15 @@ public:
|
|||||||
functions_list_t &get_function_list() { return functions_; }
|
functions_list_t &get_function_list() { return functions_; }
|
||||||
function *get_or_insert_function(const std::string &name, function_type *ty);
|
function *get_or_insert_function(const std::string &name, function_type *ty);
|
||||||
// Scope
|
// Scope
|
||||||
void add_new_scope() { if(scopes_.empty()) scopes_.push(scope()); else scopes_.push(scope(get_scope())); }
|
void add_new_scope() { if(scopes_.empty()) scopes_.push(scope()); else scopes_.push(scope(get_scope())); }
|
||||||
void pop_scope() { scopes_.pop(); }
|
void pop_scope() { scopes_.pop(); }
|
||||||
scope& get_scope() { return scopes_.top(); }
|
scope& get_scope() { return scopes_.top(); }
|
||||||
// Const allocation
|
// Const allocation
|
||||||
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
|
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
|
||||||
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
|
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
|
||||||
|
// Register global
|
||||||
|
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
|
||||||
|
const std::map<std::string, ir::value*>& globals() const { return globals_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string name_;
|
std::string name_;
|
||||||
@@ -96,6 +99,7 @@ private:
|
|||||||
std::map<value*, value**> current_phi_;
|
std::map<value*, value**> current_phi_;
|
||||||
std::stack<scope> scopes_;
|
std::stack<scope> scopes_;
|
||||||
std::vector<ir::alloc_const*> allocs_;
|
std::vector<ir::alloc_const*> allocs_;
|
||||||
|
std::map<std::string, ir::value*> globals_;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -39,7 +39,7 @@ public:
|
|||||||
std::vector<unsigned> global_range_size;
|
std::vector<unsigned> global_range_size;
|
||||||
unsigned num_threads;
|
unsigned num_threads;
|
||||||
};
|
};
|
||||||
typedef std::function<unsigned(driver::kernel, launch_information)> benchmark_t;
|
typedef std::function<double(driver::kernel, launch_information)> benchmark_t;
|
||||||
|
|
||||||
struct passes_wrapper {
|
struct passes_wrapper {
|
||||||
passes_wrapper(): shared(&buffer_info), liveness(&buffer_info),
|
passes_wrapper(): shared(&buffer_info), liveness(&buffer_info),
|
||||||
@@ -80,6 +80,7 @@ public:
|
|||||||
void add_module(const std::string &src, const std::vector<unsigned>& params = {});
|
void add_module(const std::string &src, const std::vector<unsigned>& params = {});
|
||||||
driver::kernel get_function(const std::string &name);
|
driver::kernel get_function(const std::string &name);
|
||||||
launch_information get_launch_info(const std::string &name);
|
launch_information get_launch_info(const std::string &name);
|
||||||
|
unsigned get_int(const std::string &name);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<driver::module> modules_;
|
std::vector<driver::module> modules_;
|
||||||
@@ -87,6 +88,7 @@ private:
|
|||||||
llvm::LLVMContext llvm_context_;
|
llvm::LLVMContext llvm_context_;
|
||||||
ir::context triton_context_;
|
ir::context triton_context_;
|
||||||
std::map<std::string, launch_information> launch_info_map_;
|
std::map<std::string, launch_information> launch_info_map_;
|
||||||
|
std::map<std::string, unsigned> global_ints_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@@ -412,7 +412,8 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
|||||||
if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
|
if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
|
||||||
assert(expr_ == nullptr);
|
assert(expr_ == nullptr);
|
||||||
//TODO: implement ranges
|
//TODO: implement ranges
|
||||||
value = ir::metaparameter::create(mod->get_context(), ty, 8, 64);
|
value = ir::metaparameter::create(mod->get_context(), ty, 8, (name=="TK")?8:64);
|
||||||
|
mod->register_global(name, value);
|
||||||
}
|
}
|
||||||
if(expr_){
|
if(expr_){
|
||||||
value = expr_->codegen(mod);
|
value = expr_->codegen(mod);
|
||||||
|
@@ -144,7 +144,7 @@ void tune::run(ir::module &mod) {
|
|||||||
// Layout parameters
|
// Layout parameters
|
||||||
while(!nodes_.empty()){
|
while(!nodes_.empty()){
|
||||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||||
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 2);
|
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 2);
|
||||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
|
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
|
||||||
connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
|
connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
|
||||||
}
|
}
|
||||||
|
15
lib/jit.cpp
15
lib/jit.cpp
@@ -111,6 +111,7 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
|
|||||||
}
|
}
|
||||||
// iterate over parameters
|
// iterate over parameters
|
||||||
unsigned i;
|
unsigned i;
|
||||||
|
double best = 0;
|
||||||
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
|
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
|
||||||
std::map<ir::value*, std::vector<std::string>> errors;
|
std::map<ir::value*, std::vector<std::string>> errors;
|
||||||
i = 0;
|
i = 0;
|
||||||
@@ -142,7 +143,12 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
|
|||||||
launch_information info = launch_info_map_.at("matmul");
|
launch_information info = launch_info_map_.at("matmul");
|
||||||
for(unsigned p: params)
|
for(unsigned p: params)
|
||||||
std::cout << p << " " << std::flush;
|
std::cout << p << " " << std::flush;
|
||||||
benchmark(kernel, info);
|
// add globals
|
||||||
|
for(auto x: tt_module.globals())
|
||||||
|
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||||
|
double perf = benchmark(kernel, info);
|
||||||
|
best = std::max(perf, best);
|
||||||
|
std::cout << perf << " [ " << best << " ] " << std::endl;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,6 +172,9 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> ¶ms)
|
|||||||
auto ll_module = make_llvm_module(tt_module, passes);
|
auto ll_module = make_llvm_module(tt_module, passes);
|
||||||
// llvm module -> machine code
|
// llvm module -> machine code
|
||||||
modules_.push_back(driver::module(driver_context_, &*ll_module));
|
modules_.push_back(driver::module(driver_context_, &*ll_module));
|
||||||
|
// add globals
|
||||||
|
for(auto x: tt_module.globals())
|
||||||
|
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||||
}
|
}
|
||||||
|
|
||||||
void jit::add_module(const std::string &src, const std::vector<unsigned> ¶ms) {
|
void jit::add_module(const std::string &src, const std::vector<unsigned> ¶ms) {
|
||||||
@@ -181,4 +190,8 @@ jit::launch_information jit::get_launch_info(const std::string &name) {
|
|||||||
return launch_info_map_.at(name);
|
return launch_info_map_.at(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsigned jit::get_int(const std::string &name){
|
||||||
|
return global_ints_.at(name);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user