Minor fixes to HER release (#319)
* Fix plotting script * Add warning if num_cpu = 1
This commit is contained in:
committed by
GitHub
parent
8b3a6c2051
commit
3cc7df0608
@@ -67,13 +67,13 @@ for curr_path in paths:
|
|||||||
print('skipping {}'.format(curr_path))
|
print('skipping {}'.format(curr_path))
|
||||||
continue
|
continue
|
||||||
print('loading {} ({})'.format(curr_path, len(results['epoch'])))
|
print('loading {} ({})'.format(curr_path, len(results['epoch'])))
|
||||||
with open(os.path.join(curr_path, 'metadata.json'), 'r') as f:
|
with open(os.path.join(curr_path, 'params.json'), 'r') as f:
|
||||||
metadata = json.load(f)
|
params = json.load(f)
|
||||||
|
|
||||||
success_rate = np.array(results['test/success_rate'])
|
success_rate = np.array(results['test/success_rate'])
|
||||||
epoch = np.array(results['epoch']) + 1
|
epoch = np.array(results['epoch']) + 1
|
||||||
env_id = metadata['kwargs']['env_name']
|
env_id = params['env_name']
|
||||||
replay_strategy = metadata['kwargs']['replay_strategy']
|
replay_strategy = params['replay_strategy']
|
||||||
|
|
||||||
if replay_strategy == 'future':
|
if replay_strategy == 'future':
|
||||||
config = 'her'
|
config = 'her'
|
||||||
|
@@ -119,6 +119,18 @@ def launch(
|
|||||||
params = config.prepare_params(params)
|
params = config.prepare_params(params)
|
||||||
config.log_params(params, logger=logger)
|
config.log_params(params, logger=logger)
|
||||||
|
|
||||||
|
if num_cpu == 1:
|
||||||
|
logger.warn()
|
||||||
|
logger.warn('*** Warning ***')
|
||||||
|
logger.warn(
|
||||||
|
'You are running HER with just a single MPI worker. This will work, but the ' +
|
||||||
|
'experiments that we report in Plappert et al. (2018, https://arxiv.org/abs/1802.09464) ' +
|
||||||
|
'were obtained with --num_cpu 19. This makes a significant difference and if you ' +
|
||||||
|
'are looking to reproduce those results, be aware of this. Please also refer to ' +
|
||||||
|
'https://github.com/openai/baselines/issues/314 for further details.')
|
||||||
|
logger.warn('****************')
|
||||||
|
logger.warn()
|
||||||
|
|
||||||
dims = config.configure_dims(params)
|
dims = config.configure_dims(params)
|
||||||
policy = config.configure_ddpg(dims=dims, params=params, clip_return=clip_return)
|
policy = config.configure_ddpg(dims=dims, params=params, clip_return=clip_return)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user