From b99a73afe37206775ac8b884d32a36e213a3fac2 Mon Sep 17 00:00:00 2001 From: SOLARIS <14846501+solaris33@users.noreply.github.com> Date: Sat, 9 Nov 2019 08:20:54 +0900 Subject: [PATCH] entrypoint variable made public (#970) and Fix RuntimeError (#910) (#1015) (#1032) --- baselines/run.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/baselines/run.py b/baselines/run.py index 4910290..fbf2453 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -31,7 +31,7 @@ except ImportError: _game_envs = defaultdict(set) for env in gym.envs.registry.all(): # TODO: solve this with regexes - env_type = env._entry_point.split(':')[0].split('.')[-1] + env_type = env.entry_point.split(':')[0].split('.')[-1] _game_envs[env_type].add(env.id) # reading benchmark names directly from retro requires @@ -119,7 +119,7 @@ def get_env_type(args): # Re-parse the gym registry, since we could have new envs since last time. for env in gym.envs.registry.all(): - env_type = env._entry_point.split(':')[0].split('.')[-1] + env_type = env.entry_point.split(':')[0].split('.')[-1] _game_envs[env_type].add(env.id) # This is a set so add is idempotent if env_id in _game_envs.keys(): @@ -222,7 +222,7 @@ def main(args): state = model.initial_state if hasattr(model, 'initial_state') else None - episode_rew = 0 + episode_rew = np.zeros(env.num_envs) if isinstance(env, VecEnv) else np.zeros(1) while True: if state is not None: actions, _, state, _ = model.step(obs) @@ -232,13 +232,13 @@ def main(args): obs, rew, done, _ = env.step(actions.numpy()) if not isinstance(env, VecEnv): obs = np.expand_dims(np.array(obs), axis=0) - episode_rew += rew[0] if isinstance(env, VecEnv) else rew + episode_rew += rew env.render() - done = done.any() if isinstance(done, np.ndarray) else done - if done: - print('episode_rew={}'.format(episode_rew)) - episode_rew = 0 - obs = env.reset() + done_any = done.any() if isinstance(done, np.ndarray) else done + if done_any: + for i in np.nonzero(done)[0]: + print('episode_rew={}'.format(episode_rew[i])) + episode_rew[i] = 0 env.close()