diff --git a/python/src/triton.cc b/python/src/triton.cc index 940c5bc0a..f72513395 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -558,16 +558,16 @@ void init_triton_codegen(py::module &&m) { } if(backend == CUDA) return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map); - if(backend == ROCM) - return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map); + assert(backend == ROCM); + return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map); }, py::return_value_policy::take_ownership); m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ py::gil_scoped_release allow_threads; if(backend == CUDA) return cu_load_binary(name, asm_map, n_shared_bytes, dev); - if(backend == ROCM) - return hip_load_binary(name, asm_map, n_shared_bytes, dev); + assert(backend == ROCM); + return hip_load_binary(name, asm_map, n_shared_bytes, dev); }, py::return_value_policy::take_ownership); }