[examples] removed dependency on isaac for auto-tuning
This commit is contained in:
@@ -81,13 +81,37 @@ private:
|
||||
high_resolution_clock::time_point _start;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
T min(std::vector<T> x)
|
||||
{ return *std::min_element(x.begin(), x.end()); }
|
||||
|
||||
|
||||
template<class OP, class SYNC>
|
||||
double bench(OP const & op, SYNC const & sync, triton::driver::device const & device)
|
||||
{
|
||||
timer tmr;
|
||||
std::vector<size_t> times;
|
||||
double total_time = 0;
|
||||
op();
|
||||
sync();
|
||||
while(total_time*1e-9 < 1e-3){
|
||||
float norm = (float)device.current_sm_clock()/device.max_sm_clock();
|
||||
tmr.start();
|
||||
op();
|
||||
sync();
|
||||
times.push_back(norm*tmr.get().count());
|
||||
total_time+=times.back();
|
||||
}
|
||||
return min(times);
|
||||
}
|
||||
|
||||
int main() {
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
triton::jit jit(context);
|
||||
|
||||
// matrix multiplication parameters
|
||||
size_t M = 512, N = 512, K = 512;
|
||||
size_t bound = 8;
|
||||
std::vector<float> hc(M*N);
|
||||
std::vector<float> rc(M*N);
|
||||
std::vector<float> ha(M*K);
|
||||
@@ -112,6 +136,22 @@ int main() {
|
||||
// benchmark a given matrix multiplication kernel
|
||||
auto benchmark = [&](triton::driver::kernel kernel,
|
||||
triton::jit::launch_information info) {
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
|
||||
// fast bounds-checking
|
||||
unsigned TK = jit.get_int("TK");
|
||||
unsigned lasti = (grid[0]*TM - 1)*TM + TM - 1;
|
||||
unsigned lastj = (grid[1]*TN - 1)*TN + TN - 1;
|
||||
unsigned lastk = TK - 1;
|
||||
bool AT = false;
|
||||
bool BT = true;
|
||||
unsigned last_safe_a = (AT==false)?(M*K - 1 - lasti)/M - lastk : M*K - 1 - lasti*K - lastk;
|
||||
unsigned last_safe_b = (BT==true)?(N*K - 1 - lastj)/N - lastk : N*K - 1 - lastj*K - lastk;
|
||||
int32_t bound = std::max<unsigned>(1, std::max(K - last_safe_a, K - last_safe_b));
|
||||
// set argument
|
||||
kernel.setArg(0, da);
|
||||
kernel.setArg(1, db);
|
||||
kernel.setArg(2, dc);
|
||||
@@ -119,39 +159,33 @@ int main() {
|
||||
kernel.setArg(4, N);
|
||||
kernel.setArg(5, K);
|
||||
kernel.setArg(6, bound);
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
timer t;
|
||||
t.start();
|
||||
stream.enqueue(kernel, {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}, {nthreads, 1, 1});
|
||||
// dry run
|
||||
stream.enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
stream.synchronize();
|
||||
double ts = t.get().count()*1e-9;
|
||||
// benchmark
|
||||
double ts = bench([&](){stream.enqueue(kernel, grid, {nthreads, 1, 1});},
|
||||
[&](){ stream.synchronize(); },
|
||||
context.device());
|
||||
ts = ts * 1e-9;
|
||||
double tflops = 2*M*N*K / ts * 1e-12;
|
||||
std::cout << tflops << std::endl;
|
||||
return ts;
|
||||
return tflops;
|
||||
};
|
||||
|
||||
|
||||
// just-in-time compile source-code
|
||||
std::vector<unsigned> params = {
|
||||
// a0
|
||||
8, 2, 16,
|
||||
// b0
|
||||
4, 4, 16,
|
||||
// c
|
||||
8, 4, 2, 4,
|
||||
// a1
|
||||
4, 2, 8,
|
||||
// b1
|
||||
8, 1
|
||||
16, 2, 64,
|
||||
32, 2, 64,
|
||||
16, 8, 2, 2,
|
||||
8, 1, 8,
|
||||
4, 1
|
||||
};
|
||||
triton::jit jit(context);
|
||||
jit.autotune(src, benchmark);
|
||||
|
||||
// jit.autotune(src, benchmark);
|
||||
jit.add_module(src, params);
|
||||
triton::driver::kernel kernel = jit.get_function("matmul");
|
||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||
benchmark(kernel, info);
|
||||
std::cout << benchmark(kernel, info) << std::endl;
|
||||
stream.read(dc, true, 0, hc);
|
||||
simple_gemm(rc, ha, hb, M, N, K);
|
||||
for(size_t i = 0; i < M*N; i++)
|
||||
|
@@ -80,6 +80,9 @@ public:
|
||||
// Const allocation
|
||||
void add_alloc(ir::alloc_const* x) { allocs_.push_back(x); }
|
||||
const std::vector<ir::alloc_const*>& allocs() { return allocs_; }
|
||||
// Register global
|
||||
void register_global(const std::string& name, ir::value *x) { globals_[name] = x; }
|
||||
const std::map<std::string, ir::value*>& globals() const { return globals_; }
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
@@ -96,6 +99,7 @@ private:
|
||||
std::map<value*, value**> current_phi_;
|
||||
std::stack<scope> scopes_;
|
||||
std::vector<ir::alloc_const*> allocs_;
|
||||
std::map<std::string, ir::value*> globals_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -39,7 +39,7 @@ public:
|
||||
std::vector<unsigned> global_range_size;
|
||||
unsigned num_threads;
|
||||
};
|
||||
typedef std::function<unsigned(driver::kernel, launch_information)> benchmark_t;
|
||||
typedef std::function<double(driver::kernel, launch_information)> benchmark_t;
|
||||
|
||||
struct passes_wrapper {
|
||||
passes_wrapper(): shared(&buffer_info), liveness(&buffer_info),
|
||||
@@ -80,6 +80,7 @@ public:
|
||||
void add_module(const std::string &src, const std::vector<unsigned>& params = {});
|
||||
driver::kernel get_function(const std::string &name);
|
||||
launch_information get_launch_info(const std::string &name);
|
||||
unsigned get_int(const std::string &name);
|
||||
|
||||
private:
|
||||
std::vector<driver::module> modules_;
|
||||
@@ -87,6 +88,7 @@ private:
|
||||
llvm::LLVMContext llvm_context_;
|
||||
ir::context triton_context_;
|
||||
std::map<std::string, launch_information> launch_info_map_;
|
||||
std::map<std::string, unsigned> global_ints_;
|
||||
};
|
||||
|
||||
|
||||
|
@@ -412,7 +412,8 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
||||
if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
|
||||
assert(expr_ == nullptr);
|
||||
//TODO: implement ranges
|
||||
value = ir::metaparameter::create(mod->get_context(), ty, 8, 64);
|
||||
value = ir::metaparameter::create(mod->get_context(), ty, 8, (name=="TK")?8:64);
|
||||
mod->register_global(name, value);
|
||||
}
|
||||
if(expr_){
|
||||
value = expr_->codegen(mod);
|
||||
|
@@ -144,7 +144,7 @@ void tune::run(ir::module &mod) {
|
||||
// Layout parameters
|
||||
while(!nodes_.empty()){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 2);
|
||||
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 2);
|
||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
|
||||
connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
|
||||
}
|
||||
|
15
lib/jit.cpp
15
lib/jit.cpp
@@ -111,6 +111,7 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
|
||||
}
|
||||
// iterate over parameters
|
||||
unsigned i;
|
||||
double best = 0;
|
||||
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
|
||||
std::map<ir::value*, std::vector<std::string>> errors;
|
||||
i = 0;
|
||||
@@ -142,7 +143,12 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
|
||||
launch_information info = launch_info_map_.at("matmul");
|
||||
for(unsigned p: params)
|
||||
std::cout << p << " " << std::flush;
|
||||
benchmark(kernel, info);
|
||||
// add globals
|
||||
for(auto x: tt_module.globals())
|
||||
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||
double perf = benchmark(kernel, info);
|
||||
best = std::max(perf, best);
|
||||
std::cout << perf << " [ " << best << " ] " << std::endl;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -166,6 +172,9 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> ¶ms)
|
||||
auto ll_module = make_llvm_module(tt_module, passes);
|
||||
// llvm module -> machine code
|
||||
modules_.push_back(driver::module(driver_context_, &*ll_module));
|
||||
// add globals
|
||||
for(auto x: tt_module.globals())
|
||||
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
|
||||
}
|
||||
|
||||
void jit::add_module(const std::string &src, const std::vector<unsigned> ¶ms) {
|
||||
@@ -181,4 +190,8 @@ jit::launch_information jit::get_launch_info(const std::string &name) {
|
||||
return launch_info_map_.at(name);
|
||||
}
|
||||
|
||||
unsigned jit::get_int(const std::string &name){
|
||||
return global_ints_.at(name);
|
||||
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user