[FRONTEND] improved caching mechanism (#474)

Co-authored-by: Greg Brockman <gdb@gregbrockman.com>
Co-authored-by: Christopher Hesse <christopherhesse@users.noreply.github.com>
This commit is contained in:
Philippe Tillet
2022-03-15 12:20:51 -07:00
committed by GitHub
parent 21f8a0646d
commit d4d8eaf6c0
3 changed files with 106 additions and 84 deletions

View File

@@ -299,8 +299,12 @@ void init_triton_runtime(py::module &&m) {
 // get cached binary
 py::str key(cache_key);
-if(!bin_cache.contains(key))
-  add_to_cache(key, args, device, num_warps, num_stages);
+py::bool_ noop = false;
+if(!bin_cache.contains(key)) {
+  noop = add_to_cache(key, args, device, num_warps, num_stages);
+}
+if (noop)
+  return (py::object)py::none();
 py::object bin = bin_cache[key];
 // get grid
@@ -529,6 +533,7 @@ void init_triton_codegen(py::module &&m) {
     return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map);
   }, py::return_value_policy::take_ownership);
 m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
+  py::gil_scoped_release allow_threads;
   if(backend == CUDA)
     return cu_load_binary(name, asm_map, n_shared_bytes, dev);
   if(backend == ROCM)