Avoid using default config while requesting available GPUs (#863)
This commit is contained in:
@@ -305,12 +305,17 @@ def display_var_info(vars):
|
||||
logger.info("Total model parameters: %0.2f million" % (count_params*1e-6))
|
||||
|
||||
|
||||
def get_available_gpus():
|
||||
# recipe from here:
|
||||
# https://stackoverflow.com/questions/38559755/how-to-get-current-available-gpus-in-tensorflow?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
|
||||
def get_available_gpus(session_config=None):
|
||||
# based on recipe from https://stackoverflow.com/a/38580201
|
||||
|
||||
# Unless we allocate a session here, subsequent attempts to create one
|
||||
# will ignore our custom config (in particular, allow_growth=True will have
|
||||
# no effect).
|
||||
if session_config is None:
|
||||
session_config = get_session()._config
|
||||
|
||||
from tensorflow.python.client import device_lib
|
||||
local_device_protos = device_lib.list_local_devices()
|
||||
local_device_protos = device_lib.list_local_devices(session_config)
|
||||
return [x.name for x in local_device_protos if x.device_type == 'GPU']
|
||||
|
||||
# ================================================================
|
||||
|
Reference in New Issue
Block a user