This commit is contained in:
@@ -31,7 +31,7 @@ except ImportError:
|
||||
_game_envs = defaultdict(set)
|
||||
for env in gym.envs.registry.all():
|
||||
# TODO: solve this with regexes
|
||||
env_type = env._entry_point.split(':')[0].split('.')[-1]
|
||||
env_type = env.entry_point.split(':')[0].split('.')[-1]
|
||||
_game_envs[env_type].add(env.id)
|
||||
|
||||
# reading benchmark names directly from retro requires
|
||||
@@ -119,7 +119,7 @@ def get_env_type(args):
|
||||
|
||||
# Re-parse the gym registry, since we could have new envs since last time.
|
||||
for env in gym.envs.registry.all():
|
||||
env_type = env._entry_point.split(':')[0].split('.')[-1]
|
||||
env_type = env.entry_point.split(':')[0].split('.')[-1]
|
||||
_game_envs[env_type].add(env.id) # This is a set so add is idempotent
|
||||
|
||||
if env_id in _game_envs.keys():
|
||||
@@ -222,7 +222,7 @@ def main(args):
|
||||
|
||||
state = model.initial_state if hasattr(model, 'initial_state') else None
|
||||
|
||||
episode_rew = 0
|
||||
episode_rew = np.zeros(env.num_envs) if isinstance(env, VecEnv) else np.zeros(1)
|
||||
while True:
|
||||
if state is not None:
|
||||
actions, _, state, _ = model.step(obs)
|
||||
@@ -232,13 +232,13 @@ def main(args):
|
||||
obs, rew, done, _ = env.step(actions.numpy())
|
||||
if not isinstance(env, VecEnv):
|
||||
obs = np.expand_dims(np.array(obs), axis=0)
|
||||
episode_rew += rew[0] if isinstance(env, VecEnv) else rew
|
||||
episode_rew += rew
|
||||
env.render()
|
||||
done = done.any() if isinstance(done, np.ndarray) else done
|
||||
if done:
|
||||
print('episode_rew={}'.format(episode_rew))
|
||||
episode_rew = 0
|
||||
obs = env.reset()
|
||||
done_any = done.any() if isinstance(done, np.ndarray) else done
|
||||
if done_any:
|
||||
for i in np.nonzero(done)[0]:
|
||||
print('episode_rew={}'.format(episode_rew[i]))
|
||||
episode_rew[i] = 0
|
||||
|
||||
env.close()
|
||||
|
||||
|
Reference in New Issue
Block a user