Made sure it works for FP16

This commit is contained in:
Philippe Tillet
2019-07-30 20:02:16 -07:00
parent 080bf1af88
commit 5af7e5adac
21 changed files with 118 additions and 101 deletions

View File

@@ -44,11 +44,12 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
// Benchmark callback: initializes `clone` for the given kernel, launches it,
// measures it with triton::tools::bench, and returns num_flops() / ts * 1e-3.
// NOTE(review): the two consecutive init_impl calls below differ only by the
// trailing `info` argument and look like the before/after lines of a diff
// hunk -- confirm which one the real file keeps.
auto benchmark = [&](triton::driver::kernel* kernel,
rt::launch_information info) {
// launch info
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module(), info);
// one untimed launch followed by a synchronize before the measured run
// (presumably a warm-up -- TODO confirm)
clone->enqueue_impl(stream, kernel, args, info);
stream->synchronize();
// measured run; the repetition policy and the unit of `ts` are defined by
// triton::tools::bench, which is not visible here
double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); }, stream);
clone->deinit_impl();
// std::cout << ts * 1e-6 << std::endl;
return num_flops() / ts * 1e-3;
};
// auto-tune and save result
@@ -65,7 +66,8 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
jit->add_module(name_.c_str(), src.c_str(), params);
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
rt::launch_information info = jit->get_launch_info(name_.c_str());
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module(), info);
}
/* retrieved compiled template */
else {
@@ -75,9 +77,10 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
return {it->first, jit};
}
// Obtains a launch context for (stream, args, autotune) via get_launch_context,
// enqueues the selected op's kernel on `stream`, and returns that op.
// NOTE(review): two signature lines follow (void vs. base* return); they look
// like the before/after lines of a diff hunk. The body ends with
// `return info.op;`, which matches the base* form -- confirm against the real file.
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
base* base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
launch_context_t info = get_launch_context(stream, args, autotune);
// launch through the op chosen by the launch context, not necessarily `this`
info.op->enqueue_impl(stream, info.kernel, args, info.info);
return info.op;
}
launch_context_t base::get_launch_context(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {