mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-21 06:20:15 +00:00
TimeLimit refactor with Monitor Simplification (#482)
* fix double reset, as suggested by @jietang
* better floors and ceilings
* add convenience methods to monitor
* add wrappers to gym namespace
* allow playing Atari games, with potentially more coming in the future
* simplify example in docs
* Move play out of the Env
* fix tests
* no more deprecation warnings
* remove env.monitor
* monitor simplification
* monitor simplifications
* monitor related fixes
* a few changes suggested by linter
* timestep_limit fixes
* keep track of gym env variables for future compatibility
* timestep_limit => max_episode_timesteps
* don't apply TimeLimit wrapper in make for VNC envs
* Respect old timestep_limit argument
* Pass max_episode_seconds through registration
* Don't include deprecation warnings yet
This commit is contained in:
@@ -47,5 +47,6 @@ from gym.core import Env, Space, Wrapper, ObservationWrapper, ActionWrapper, Rew
|
|||||||
from gym.benchmarks import benchmark_spec
|
from gym.benchmarks import benchmark_spec
|
||||||
from gym.envs import make, spec
|
from gym.envs import make, spec
|
||||||
from gym.scoreboard.api import upload
|
from gym.scoreboard.api import upload
|
||||||
|
from gym import wrappers
|
||||||
|
|
||||||
__all__ = ["Env", "Space", "Wrapper", "make", "spec", "upload"]
|
__all__ = ["Env", "Space", "Wrapper", "make", "spec", "upload", "wrappers"]
|
||||||
|
@@ -269,30 +269,44 @@ register_benchmark(
|
|||||||
{'env_id': 'HalfCheetah-v1',
|
{'env_id': 'HalfCheetah-v1',
|
||||||
'trials': 3,
|
'trials': 3,
|
||||||
'max_timesteps': 1000000,
|
'max_timesteps': 1000000,
|
||||||
|
'reward_floor': -280.0,
|
||||||
|
'reward_ceiling': 4000.0,
|
||||||
},
|
},
|
||||||
{'env_id': 'Hopper-v1',
|
{'env_id': 'Hopper-v1',
|
||||||
'trials': 3,
|
'trials': 3,
|
||||||
'max_timesteps': 1000000,
|
'max_timesteps': 1000000,
|
||||||
|
'reward_floor': 16.0,
|
||||||
|
'reward_ceiling': 4000.0,
|
||||||
},
|
},
|
||||||
{'env_id': 'InvertedDoublePendulum-v1',
|
{'env_id': 'InvertedDoublePendulum-v1',
|
||||||
'trials': 3,
|
'trials': 3,
|
||||||
'max_timesteps': 1000000,
|
'max_timesteps': 1000000,
|
||||||
|
'reward_floor': 53.0,
|
||||||
|
'reward_ceiling': 10000.0,
|
||||||
},
|
},
|
||||||
{'env_id': 'InvertedPendulum-v1',
|
{'env_id': 'InvertedPendulum-v1',
|
||||||
'trials': 3,
|
'trials': 3,
|
||||||
'max_timesteps': 1000000,
|
'max_timesteps': 1000000,
|
||||||
|
'reward_floor': 5.6,
|
||||||
|
'reward_ceiling': 1000.0,
|
||||||
},
|
},
|
||||||
{'env_id': 'Reacher-v1',
|
{'env_id': 'Reacher-v1',
|
||||||
'trials': 3,
|
'trials': 3,
|
||||||
'max_timesteps': 1000000,
|
'max_timesteps': 1000000,
|
||||||
|
'reward_floor': -43.0,
|
||||||
|
'reward_ceiling': -0.5,
|
||||||
},
|
},
|
||||||
{'env_id': 'Swimmer-v1',
|
{'env_id': 'Swimmer-v1',
|
||||||
'trials': 3,
|
'trials': 3,
|
||||||
'max_timesteps': 1000000,
|
'max_timesteps': 1000000,
|
||||||
|
'reward_floor': 0.23,
|
||||||
|
'reward_ceiling': 500.0,
|
||||||
},
|
},
|
||||||
{'env_id': 'Walker2d-v1',
|
{'env_id': 'Walker2d-v1',
|
||||||
'trials': 3,
|
'trials': 3,
|
||||||
'max_timesteps': 1000000,
|
'max_timesteps': 1000000,
|
||||||
|
'reward_floor': 1.6,
|
||||||
|
'reward_ceiling': 5500.0,
|
||||||
}
|
}
|
||||||
])
|
])
|
||||||
|
|
||||||
|
@@ -1,8 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import gym
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
root_logger = logging.getLogger()
|
root_logger = logging.getLogger()
|
||||||
|
@@ -6,14 +6,14 @@ from gym.envs.registration import registry, register, make, spec
|
|||||||
register(
|
register(
|
||||||
id='Copy-v0',
|
id='Copy-v0',
|
||||||
entry_point='gym.envs.algorithmic:CopyEnv',
|
entry_point='gym.envs.algorithmic:CopyEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
reward_threshold=25.0,
|
reward_threshold=25.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='RepeatCopy-v0',
|
id='RepeatCopy-v0',
|
||||||
entry_point='gym.envs.algorithmic:RepeatCopyEnv',
|
entry_point='gym.envs.algorithmic:RepeatCopyEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
reward_threshold=75.0,
|
reward_threshold=75.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ register(
|
|||||||
id='ReversedAddition-v0',
|
id='ReversedAddition-v0',
|
||||||
entry_point='gym.envs.algorithmic:ReversedAdditionEnv',
|
entry_point='gym.envs.algorithmic:ReversedAdditionEnv',
|
||||||
kwargs={'rows' : 2},
|
kwargs={'rows' : 2},
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
reward_threshold=25.0,
|
reward_threshold=25.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,21 +29,21 @@ register(
|
|||||||
id='ReversedAddition3-v0',
|
id='ReversedAddition3-v0',
|
||||||
entry_point='gym.envs.algorithmic:ReversedAdditionEnv',
|
entry_point='gym.envs.algorithmic:ReversedAdditionEnv',
|
||||||
kwargs={'rows' : 3},
|
kwargs={'rows' : 3},
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
reward_threshold=25.0,
|
reward_threshold=25.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='DuplicatedInput-v0',
|
id='DuplicatedInput-v0',
|
||||||
entry_point='gym.envs.algorithmic:DuplicatedInputEnv',
|
entry_point='gym.envs.algorithmic:DuplicatedInputEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
reward_threshold=9.0,
|
reward_threshold=9.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='Reverse-v0',
|
id='Reverse-v0',
|
||||||
entry_point='gym.envs.algorithmic:ReverseEnv',
|
entry_point='gym.envs.algorithmic:ReverseEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
reward_threshold=25.0,
|
reward_threshold=25.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -53,41 +53,41 @@ register(
|
|||||||
register(
|
register(
|
||||||
id='CartPole-v0',
|
id='CartPole-v0',
|
||||||
entry_point='gym.envs.classic_control:CartPoleEnv',
|
entry_point='gym.envs.classic_control:CartPoleEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
reward_threshold=195.0,
|
reward_threshold=195.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='CartPole-v1',
|
id='CartPole-v1',
|
||||||
entry_point='gym.envs.classic_control:CartPoleEnv',
|
entry_point='gym.envs.classic_control:CartPoleEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 500},
|
max_episode_steps=500,
|
||||||
reward_threshold=475.0,
|
reward_threshold=475.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='MountainCar-v0',
|
id='MountainCar-v0',
|
||||||
entry_point='gym.envs.classic_control:MountainCarEnv',
|
entry_point='gym.envs.classic_control:MountainCarEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
reward_threshold=-110.0,
|
reward_threshold=-110.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='MountainCarContinuous-v0',
|
id='MountainCarContinuous-v0',
|
||||||
entry_point='gym.envs.classic_control:Continuous_MountainCarEnv',
|
entry_point='gym.envs.classic_control:Continuous_MountainCarEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 999},
|
max_episode_steps=999,
|
||||||
reward_threshold=90.0,
|
reward_threshold=90.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='Pendulum-v0',
|
id='Pendulum-v0',
|
||||||
entry_point='gym.envs.classic_control:PendulumEnv',
|
entry_point='gym.envs.classic_control:PendulumEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='Acrobot-v1',
|
id='Acrobot-v1',
|
||||||
entry_point='gym.envs.classic_control:AcrobotEnv',
|
entry_point='gym.envs.classic_control:AcrobotEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 500},
|
max_episode_steps=500,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Box2d
|
# Box2d
|
||||||
@@ -96,35 +96,35 @@ register(
|
|||||||
register(
|
register(
|
||||||
id='LunarLander-v2',
|
id='LunarLander-v2',
|
||||||
entry_point='gym.envs.box2d:LunarLander',
|
entry_point='gym.envs.box2d:LunarLander',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 1000},
|
max_episode_steps=1000,
|
||||||
reward_threshold=200,
|
reward_threshold=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='LunarLanderContinuous-v2',
|
id='LunarLanderContinuous-v2',
|
||||||
entry_point='gym.envs.box2d:LunarLanderContinuous',
|
entry_point='gym.envs.box2d:LunarLanderContinuous',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 1000},
|
max_episode_steps=1000,
|
||||||
reward_threshold=200,
|
reward_threshold=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='BipedalWalker-v2',
|
id='BipedalWalker-v2',
|
||||||
entry_point='gym.envs.box2d:BipedalWalker',
|
entry_point='gym.envs.box2d:BipedalWalker',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 1600},
|
max_episode_steps=1600,
|
||||||
reward_threshold=300,
|
reward_threshold=300,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='BipedalWalkerHardcore-v2',
|
id='BipedalWalkerHardcore-v2',
|
||||||
entry_point='gym.envs.box2d:BipedalWalkerHardcore',
|
entry_point='gym.envs.box2d:BipedalWalkerHardcore',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 2000},
|
max_episode_steps=2000,
|
||||||
reward_threshold=300,
|
reward_threshold=300,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='CarRacing-v0',
|
id='CarRacing-v0',
|
||||||
entry_point='gym.envs.box2d:CarRacing',
|
entry_point='gym.envs.box2d:CarRacing',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 1000},
|
max_episode_steps=1000,
|
||||||
reward_threshold=900,
|
reward_threshold=900,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -140,7 +140,7 @@ register(
|
|||||||
id='FrozenLake-v0',
|
id='FrozenLake-v0',
|
||||||
entry_point='gym.envs.toy_text:FrozenLakeEnv',
|
entry_point='gym.envs.toy_text:FrozenLakeEnv',
|
||||||
kwargs={'map_name' : '4x4'},
|
kwargs={'map_name' : '4x4'},
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 100},
|
max_episode_steps=100,
|
||||||
reward_threshold=0.78, # optimum = .8196
|
reward_threshold=0.78, # optimum = .8196
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -148,39 +148,39 @@ register(
|
|||||||
id='FrozenLake8x8-v0',
|
id='FrozenLake8x8-v0',
|
||||||
entry_point='gym.envs.toy_text:FrozenLakeEnv',
|
entry_point='gym.envs.toy_text:FrozenLakeEnv',
|
||||||
kwargs={'map_name' : '8x8'},
|
kwargs={'map_name' : '8x8'},
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
reward_threshold=0.99, # optimum = 1
|
reward_threshold=0.99, # optimum = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='NChain-v0',
|
id='NChain-v0',
|
||||||
entry_point='gym.envs.toy_text:NChainEnv',
|
entry_point='gym.envs.toy_text:NChainEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 1000},
|
max_episode_steps=1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='Roulette-v0',
|
id='Roulette-v0',
|
||||||
entry_point='gym.envs.toy_text:RouletteEnv',
|
entry_point='gym.envs.toy_text:RouletteEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 100},
|
max_episode_steps=100,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='Taxi-v2',
|
id='Taxi-v2',
|
||||||
entry_point='gym.envs.toy_text.taxi:TaxiEnv',
|
entry_point='gym.envs.toy_text.taxi:TaxiEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
|
||||||
reward_threshold=8, # optimum = 8.46
|
reward_threshold=8, # optimum = 8.46
|
||||||
|
max_episode_steps=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='GuessingGame-v0',
|
id='GuessingGame-v0',
|
||||||
entry_point='gym.envs.toy_text.guessing_game:GuessingGame',
|
entry_point='gym.envs.toy_text.guessing_game:GuessingGame',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='HotterColder-v0',
|
id='HotterColder-v0',
|
||||||
entry_point='gym.envs.toy_text.hotter_colder:HotterColder',
|
entry_point='gym.envs.toy_text.hotter_colder:HotterColder',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mujoco
|
# Mujoco
|
||||||
@@ -191,68 +191,68 @@ register(
|
|||||||
register(
|
register(
|
||||||
id='Reacher-v1',
|
id='Reacher-v1',
|
||||||
entry_point='gym.envs.mujoco:ReacherEnv',
|
entry_point='gym.envs.mujoco:ReacherEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 50},
|
max_episode_steps=50,
|
||||||
reward_threshold=-3.75,
|
reward_threshold=-3.75,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='InvertedPendulum-v1',
|
id='InvertedPendulum-v1',
|
||||||
entry_point='gym.envs.mujoco:InvertedPendulumEnv',
|
entry_point='gym.envs.mujoco:InvertedPendulumEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 1000},
|
max_episode_steps=1000,
|
||||||
reward_threshold=950.0,
|
reward_threshold=950.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='InvertedDoublePendulum-v1',
|
id='InvertedDoublePendulum-v1',
|
||||||
entry_point='gym.envs.mujoco:InvertedDoublePendulumEnv',
|
entry_point='gym.envs.mujoco:InvertedDoublePendulumEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 1000},
|
max_episode_steps=1000,
|
||||||
reward_threshold=9100.0,
|
reward_threshold=9100.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='HalfCheetah-v1',
|
id='HalfCheetah-v1',
|
||||||
entry_point='gym.envs.mujoco:HalfCheetahEnv',
|
entry_point='gym.envs.mujoco:HalfCheetahEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 1000},
|
max_episode_steps=1000,
|
||||||
reward_threshold=4800.0,
|
reward_threshold=4800.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='Hopper-v1',
|
id='Hopper-v1',
|
||||||
entry_point='gym.envs.mujoco:HopperEnv',
|
entry_point='gym.envs.mujoco:HopperEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 1000},
|
max_episode_steps=1000,
|
||||||
reward_threshold=3800.0,
|
reward_threshold=3800.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='Swimmer-v1',
|
id='Swimmer-v1',
|
||||||
entry_point='gym.envs.mujoco:SwimmerEnv',
|
entry_point='gym.envs.mujoco:SwimmerEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 1000},
|
max_episode_steps=1000,
|
||||||
reward_threshold=360.0,
|
reward_threshold=360.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='Walker2d-v1',
|
id='Walker2d-v1',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 1000},
|
max_episode_steps=1000,
|
||||||
entry_point='gym.envs.mujoco:Walker2dEnv',
|
entry_point='gym.envs.mujoco:Walker2dEnv',
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='Ant-v1',
|
id='Ant-v1',
|
||||||
entry_point='gym.envs.mujoco:AntEnv',
|
entry_point='gym.envs.mujoco:AntEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 1000},
|
max_episode_steps=1000,
|
||||||
reward_threshold=6000.0,
|
reward_threshold=6000.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='Humanoid-v1',
|
id='Humanoid-v1',
|
||||||
entry_point='gym.envs.mujoco:HumanoidEnv',
|
entry_point='gym.envs.mujoco:HumanoidEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 1000},
|
max_episode_steps=1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='HumanoidStandup-v1',
|
id='HumanoidStandup-v1',
|
||||||
entry_point='gym.envs.mujoco:HumanoidStandupEnv',
|
entry_point='gym.envs.mujoco:HumanoidStandupEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 1000},
|
max_episode_steps=1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Atari
|
# Atari
|
||||||
@@ -286,7 +286,7 @@ for game in ['air_raid', 'alien', 'amidar', 'assault', 'asterix', 'asteroids', '
|
|||||||
id='{}-v0'.format(name),
|
id='{}-v0'.format(name),
|
||||||
entry_point='gym.envs.atari:AtariEnv',
|
entry_point='gym.envs.atari:AtariEnv',
|
||||||
kwargs={'game': game, 'obs_type': obs_type, 'repeat_action_probability': 0.25},
|
kwargs={'game': game, 'obs_type': obs_type, 'repeat_action_probability': 0.25},
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 10000},
|
max_episode_steps=10000,
|
||||||
nondeterministic=nondeterministic,
|
nondeterministic=nondeterministic,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -294,7 +294,7 @@ for game in ['air_raid', 'alien', 'amidar', 'assault', 'asterix', 'asteroids', '
|
|||||||
id='{}-v3'.format(name),
|
id='{}-v3'.format(name),
|
||||||
entry_point='gym.envs.atari:AtariEnv',
|
entry_point='gym.envs.atari:AtariEnv',
|
||||||
kwargs={'game': game, 'obs_type': obs_type},
|
kwargs={'game': game, 'obs_type': obs_type},
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 100000},
|
max_episode_steps=100000,
|
||||||
nondeterministic=nondeterministic,
|
nondeterministic=nondeterministic,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -309,7 +309,7 @@ for game in ['air_raid', 'alien', 'amidar', 'assault', 'asterix', 'asteroids', '
|
|||||||
id='{}Deterministic-v0'.format(name),
|
id='{}Deterministic-v0'.format(name),
|
||||||
entry_point='gym.envs.atari:AtariEnv',
|
entry_point='gym.envs.atari:AtariEnv',
|
||||||
kwargs={'game': game, 'obs_type': obs_type, 'frameskip': frameskip, 'repeat_action_probability': 0.25},
|
kwargs={'game': game, 'obs_type': obs_type, 'frameskip': frameskip, 'repeat_action_probability': 0.25},
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 100000},
|
max_episode_steps=100000,
|
||||||
nondeterministic=nondeterministic,
|
nondeterministic=nondeterministic,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -317,7 +317,7 @@ for game in ['air_raid', 'alien', 'amidar', 'assault', 'asterix', 'asteroids', '
|
|||||||
id='{}Deterministic-v3'.format(name),
|
id='{}Deterministic-v3'.format(name),
|
||||||
entry_point='gym.envs.atari:AtariEnv',
|
entry_point='gym.envs.atari:AtariEnv',
|
||||||
kwargs={'game': game, 'obs_type': obs_type, 'frameskip': frameskip},
|
kwargs={'game': game, 'obs_type': obs_type, 'frameskip': frameskip},
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 100000},
|
max_episode_steps=100000,
|
||||||
nondeterministic=nondeterministic,
|
nondeterministic=nondeterministic,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -325,7 +325,7 @@ for game in ['air_raid', 'alien', 'amidar', 'assault', 'asterix', 'asteroids', '
|
|||||||
id='{}NoFrameskip-v0'.format(name),
|
id='{}NoFrameskip-v0'.format(name),
|
||||||
entry_point='gym.envs.atari:AtariEnv',
|
entry_point='gym.envs.atari:AtariEnv',
|
||||||
kwargs={'game': game, 'obs_type': obs_type, 'frameskip': 1, 'repeat_action_probability': 0.25}, # A frameskip of 1 means we get every frame
|
kwargs={'game': game, 'obs_type': obs_type, 'frameskip': 1, 'repeat_action_probability': 0.25}, # A frameskip of 1 means we get every frame
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': frameskip * 100000},
|
max_episode_steps=frameskip * 100000,
|
||||||
nondeterministic=nondeterministic,
|
nondeterministic=nondeterministic,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -335,7 +335,7 @@ for game in ['air_raid', 'alien', 'amidar', 'assault', 'asterix', 'asteroids', '
|
|||||||
id='{}NoFrameskip-v3'.format(name),
|
id='{}NoFrameskip-v3'.format(name),
|
||||||
entry_point='gym.envs.atari:AtariEnv',
|
entry_point='gym.envs.atari:AtariEnv',
|
||||||
kwargs={'game': game, 'obs_type': obs_type, 'frameskip': 1}, # A frameskip of 1 means we get every frame
|
kwargs={'game': game, 'obs_type': obs_type, 'frameskip': 1}, # A frameskip of 1 means we get every frame
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': frameskip * 100000},
|
max_episode_steps=frameskip * 100000,
|
||||||
nondeterministic=nondeterministic,
|
nondeterministic=nondeterministic,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -430,13 +430,13 @@ register(
|
|||||||
register(
|
register(
|
||||||
id='PredictActionsCartpole-v0',
|
id='PredictActionsCartpole-v0',
|
||||||
entry_point='gym.envs.safety:PredictActionsCartpoleEnv',
|
entry_point='gym.envs.safety:PredictActionsCartpoleEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='PredictObsCartpole-v0',
|
id='PredictObsCartpole-v0',
|
||||||
entry_point='gym.envs.safety:PredictObsCartpoleEnv',
|
entry_point='gym.envs.safety:PredictObsCartpoleEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
# semi_supervised envs
|
# semi_supervised envs
|
||||||
@@ -444,30 +444,30 @@ register(
|
|||||||
register(
|
register(
|
||||||
id='SemisuperPendulumNoise-v0',
|
id='SemisuperPendulumNoise-v0',
|
||||||
entry_point='gym.envs.safety:SemisuperPendulumNoiseEnv',
|
entry_point='gym.envs.safety:SemisuperPendulumNoiseEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
)
|
)
|
||||||
# somewhat harder because of higher variance:
|
# somewhat harder because of higher variance:
|
||||||
register(
|
register(
|
||||||
id='SemisuperPendulumRandom-v0',
|
id='SemisuperPendulumRandom-v0',
|
||||||
entry_point='gym.envs.safety:SemisuperPendulumRandomEnv',
|
entry_point='gym.envs.safety:SemisuperPendulumRandomEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
)
|
)
|
||||||
# probably the hardest because you only get a constant number of rewards in total:
|
# probably the hardest because you only get a constant number of rewards in total:
|
||||||
register(
|
register(
|
||||||
id='SemisuperPendulumDecay-v0',
|
id='SemisuperPendulumDecay-v0',
|
||||||
entry_point='gym.envs.safety:SemisuperPendulumDecayEnv',
|
entry_point='gym.envs.safety:SemisuperPendulumDecayEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
# off_switch envs
|
# off_switch envs
|
||||||
register(
|
register(
|
||||||
id='OffSwitchCartpole-v0',
|
id='OffSwitchCartpole-v0',
|
||||||
entry_point='gym.envs.safety:OffSwitchCartpoleEnv',
|
entry_point='gym.envs.safety:OffSwitchCartpoleEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='OffSwitchCartpoleProb-v0',
|
id='OffSwitchCartpoleProb-v0',
|
||||||
entry_point='gym.envs.safety:OffSwitchCartpoleProbEnv',
|
entry_point='gym.envs.safety:OffSwitchCartpoleProbEnv',
|
||||||
tags={'wrapper_config.TimeLimit.max_episode_steps': 200},
|
max_episode_steps=200,
|
||||||
)
|
)
|
||||||
|
@@ -124,6 +124,29 @@ class AtariEnv(gym.Env, utils.EzPickle):
|
|||||||
def get_action_meanings(self):
|
def get_action_meanings(self):
|
||||||
return [ACTION_MEANING[i] for i in self._action_set]
|
return [ACTION_MEANING[i] for i in self._action_set]
|
||||||
|
|
||||||
|
def get_keys_to_action(self):
|
||||||
|
KEYWORD_TO_KEY = {
|
||||||
|
'UP': ord('w'),
|
||||||
|
'DOWN': ord('s'),
|
||||||
|
'LEFT': ord('a'),
|
||||||
|
'RIGHT': ord('d'),
|
||||||
|
'FIRE': ord(' '),
|
||||||
|
}
|
||||||
|
|
||||||
|
keys_to_action = {}
|
||||||
|
|
||||||
|
for action_id, action_meaning in enumerate(self.get_action_meanings()):
|
||||||
|
keys = []
|
||||||
|
for keyword, key in KEYWORD_TO_KEY.items():
|
||||||
|
if keyword in action_meaning:
|
||||||
|
keys.append(key)
|
||||||
|
keys = tuple(sorted(keys))
|
||||||
|
|
||||||
|
assert keys not in keys_to_action
|
||||||
|
keys_to_action[keys] = action_id
|
||||||
|
|
||||||
|
return keys_to_action
|
||||||
|
|
||||||
# def save_state(self):
|
# def save_state(self):
|
||||||
# return self.ale.saveState()
|
# return self.ale.saveState()
|
||||||
|
|
||||||
|
@@ -495,4 +495,4 @@ if __name__=="__main__":
|
|||||||
if not record_video: # Faster, but you can as well call env.render() every time to play full window.
|
if not record_video: # Faster, but you can as well call env.render() every time to play full window.
|
||||||
env.render()
|
env.render()
|
||||||
if done or restart: break
|
if done or restart: break
|
||||||
env.monitor.close()
|
env.close()
|
||||||
|
@@ -1,9 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
import re
|
import re
|
||||||
import sys
|
|
||||||
|
|
||||||
from gym import error
|
from gym import error
|
||||||
|
import warnings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
# This format is true today, but it's *not* an official spec.
|
# This format is true today, but it's *not* an official spec.
|
||||||
@@ -37,7 +36,7 @@ class EnvSpec(object):
|
|||||||
trials (int): The number of trials run in official evaluation
|
trials (int): The number of trials run in official evaluation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, id, entry_point=None, trials=100, reward_threshold=None, local_only=False, kwargs=None, nondeterministic=False, tags=None, timestep_limit=None):
|
def __init__(self, id, entry_point=None, trials=100, reward_threshold=None, local_only=False, kwargs=None, nondeterministic=False, tags=None, max_episode_steps=None, max_episode_seconds=None, timestep_limit=None):
|
||||||
self.id = id
|
self.id = id
|
||||||
# Evaluation parameters
|
# Evaluation parameters
|
||||||
self.trials = trials
|
self.trials = trials
|
||||||
@@ -49,7 +48,24 @@ class EnvSpec(object):
|
|||||||
tags = {}
|
tags = {}
|
||||||
self.tags = tags
|
self.tags = tags
|
||||||
|
|
||||||
self.timestep_limit = timestep_limit
|
# BACKWARDS COMPAT 2017/1/18
|
||||||
|
if tags.get('wrapper_config.TimeLimit.max_episode_steps'):
|
||||||
|
max_episode_steps = tags.get('wrapper_config.TimeLimit.max_episode_steps')
|
||||||
|
# TODO: Add the following deprecation warning after 2017/02/18
|
||||||
|
# warnings.warn("DEPRECATION WARNING wrapper_config.TimeLimit has been deprecated. Replace any calls to `register(tags={'wrapper_config.TimeLimit.max_episode_steps': 200)}` with `register(max_episode_steps=200)`. This change was made 2017/1/31 and is included in gym version 0.8.0. If you are getting many of these warnings, you may need to update universe past version 0.21.3")
|
||||||
|
|
||||||
|
tags['wrapper_config.TimeLimit.max_episode_steps'] = max_episode_steps
|
||||||
|
######
|
||||||
|
|
||||||
|
# BACKWARDS COMPAT 2017/1/31
|
||||||
|
if timestep_limit is not None:
|
||||||
|
max_episode_steps = timestep_limit
|
||||||
|
# TODO: Add the following deprecation warning after 2017/03/01
|
||||||
|
# warnings.warn("register(timestep_limit={}) is deprecated. Use register(max_episode_steps={}) instead.".format(timestep_limit, timestep_limit))
|
||||||
|
######
|
||||||
|
|
||||||
|
self.max_episode_steps = max_episode_steps
|
||||||
|
self.max_episode_seconds = max_episode_seconds
|
||||||
|
|
||||||
# We may make some of these other parameters public if they're
|
# We may make some of these other parameters public if they're
|
||||||
# useful.
|
# useful.
|
||||||
@@ -71,6 +87,7 @@ class EnvSpec(object):
|
|||||||
|
|
||||||
# Make the enviroment aware of which spec it came from.
|
# Make the enviroment aware of which spec it came from.
|
||||||
env.spec = self
|
env.spec = self
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@@ -78,15 +95,12 @@ class EnvSpec(object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def timestep_limit(self):
|
def timestep_limit(self):
|
||||||
logger.warn("DEPRECATION WARNING: env.spec.timestep_limit has been deprecated. Replace your call to `env.spec.timestep_limit` with `env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps')`. This change was made 12/28/2016 and is included in version 0.7.0")
|
return self.max_episode_steps
|
||||||
return self.tags.get('wrapper_config.TimeLimit.max_episode_steps')
|
|
||||||
|
|
||||||
@timestep_limit.setter
|
@timestep_limit.setter
|
||||||
def timestep_limit(self, timestep_limit):
|
def timestep_limit(self, value):
|
||||||
if timestep_limit is not None:
|
self.max_episode_steps = value
|
||||||
logger.warn(
|
|
||||||
"DEPRECATION WARNING: env.spec.timestep_limit has been deprecated. Replace any calls to `register(timestep_limit=200)` with `register(tags={'wrapper_config.TimeLimit.max_episode_steps': 200)}`, . This change was made 12/28/2016 and is included in gym version 0.7.0. If you are getting many of these warnings, you may need to update universe past version 0.21.1")
|
|
||||||
self.tags['wrapper_config.TimeLimit.max_episode_steps'] = timestep_limit
|
|
||||||
|
|
||||||
class EnvRegistry(object):
|
class EnvRegistry(object):
|
||||||
"""Register an env by ID. IDs remain stable over time and are
|
"""Register an env by ID. IDs remain stable over time and are
|
||||||
@@ -102,7 +116,14 @@ class EnvRegistry(object):
|
|||||||
def make(self, id):
|
def make(self, id):
|
||||||
logger.info('Making new env: %s', id)
|
logger.info('Making new env: %s', id)
|
||||||
spec = self.spec(id)
|
spec = self.spec(id)
|
||||||
return spec.make()
|
env = spec.make()
|
||||||
|
if (env.spec.timestep_limit is not None) and not spec.tags.get('vnc'):
|
||||||
|
from gym.wrappers.time_limit import TimeLimit
|
||||||
|
env = TimeLimit(env,
|
||||||
|
max_episode_steps=env.spec.max_episode_steps,
|
||||||
|
max_episode_seconds=env.spec.max_episode_seconds)
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
def all(self):
|
def all(self):
|
||||||
return self.env_specs.values()
|
return self.env_specs.values()
|
||||||
|
@@ -6,7 +6,7 @@ from gym.envs.classic_control import cartpole
|
|||||||
def test_make():
|
def test_make():
|
||||||
env = envs.make('CartPole-v0')
|
env = envs.make('CartPole-v0')
|
||||||
assert env.spec.id == 'CartPole-v0'
|
assert env.spec.id == 'CartPole-v0'
|
||||||
assert isinstance(env, cartpole.CartPoleEnv)
|
assert isinstance(env.unwrapped, cartpole.CartPoleEnv)
|
||||||
|
|
||||||
def test_make_deprecated():
|
def test_make_deprecated():
|
||||||
try:
|
try:
|
||||||
|
@@ -1,9 +1,3 @@
|
|||||||
from gym.monitoring.monitor_manager import (
|
|
||||||
_open_monitors,
|
|
||||||
detect_training_manifests,
|
|
||||||
load_env_info_from_manifests,
|
|
||||||
load_results,
|
|
||||||
MonitorManager,
|
|
||||||
)
|
|
||||||
from gym.monitoring.stats_recorder import StatsRecorder
|
from gym.monitoring.stats_recorder import StatsRecorder
|
||||||
from gym.monitoring.video_recorder import VideoRecorder
|
from gym.monitoring.video_recorder import VideoRecorder
|
||||||
|
from gym.wrappers.monitoring import load_results, detect_training_manifests, load_env_info_from_manifests, _open_monitors
|
@@ -1,411 +0,0 @@
|
|||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import weakref
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import six
|
|
||||||
from gym import error, version
|
|
||||||
from gym.monitoring import stats_recorder, video_recorder
|
|
||||||
from gym.utils import atomic_write, closer
|
|
||||||
from gym.utils.json_utils import json_encode_np
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
FILE_PREFIX = 'openaigym'
|
|
||||||
MANIFEST_PREFIX = FILE_PREFIX + '.manifest'
|
|
||||||
|
|
||||||
def detect_training_manifests(training_dir, files=None):
    """Return the paths of the manifest files found in ``training_dir``.

    Args:
        training_dir: Directory the returned paths are joined onto.
        files: Optional list of file names to filter. When ``None`` the
            names come from ``os.listdir(training_dir)``.

    Returns:
        A list with ``os.path.join(training_dir, name)`` for every name that
        starts with ``MANIFEST_PREFIX + '.'``.
    """
    candidates = os.listdir(training_dir) if files is None else files
    manifest_marker = MANIFEST_PREFIX + '.'
    manifests = []
    for name in candidates:
        if name.startswith(manifest_marker):
            manifests.append(os.path.join(training_dir, name))
    return manifests
|
|
||||||
|
|
||||||
def detect_monitor_files(training_dir):
    """Return the paths of all monitor-owned files in ``training_dir``.

    A file counts as monitor-owned when its name starts with
    ``FILE_PREFIX + '.'``.
    """
    owned_marker = FILE_PREFIX + '.'
    monitor_files = []
    for name in os.listdir(training_dir):
        if name.startswith(owned_marker):
            monitor_files.append(os.path.join(training_dir, name))
    return monitor_files
|
|
||||||
|
|
||||||
def clear_monitor_files(training_dir):
    """Delete every monitor file that ``detect_monitor_files`` finds.

    Does nothing (and logs nothing) when there are no such files.
    """
    stale_paths = detect_monitor_files(training_dir)
    if stale_paths:
        logger.info('Clearing %d monitor files from previous run (because force=True was provided)', len(stale_paths))
        for stale_path in stale_paths:
            os.unlink(stale_path)
|
|
||||||
|
|
||||||
def capped_cubic_video_schedule(episode_id):
    """Decide whether ``episode_id`` should be recorded.

    Below 1000 the answer is True for perfect cubes (0, 1, 8, 27, ...);
    from 1000 onwards it is True for every multiple of 1000.
    """
    if episode_id >= 1000:
        return episode_id % 1000 == 0
    # Round the floating-point cube root, then confirm by cubing it back.
    nearest_root = int(round(episode_id ** (1. / 3)))
    return nearest_root ** 3 == episode_id
|
|
||||||
|
|
||||||
def disable_videos(episode_id):
    """Video schedule that never records: False for every ``episode_id``."""
    record_this_episode = False
    return record_this_episode
|
|
||||||
|
|
||||||
monitor_closer = closer.Closer()
|
|
||||||
|
|
||||||
# This method gets used for a sanity check in scoreboard/api.py. It's
|
|
||||||
# not intended for use outside of the gym codebase.
|
|
||||||
def _open_monitors():
    """Return a list of the objects currently registered with ``monitor_closer``."""
    registered = monitor_closer.closeables
    return list(registered.values())
|
|
||||||
|
|
||||||
class MonitorManager(object):
    """A configurable monitor for your training runs.

    Every env has an attached monitor, which you can access as
    'env.monitor'. Simple usage is just to call 'monitor.start(dir)'
    to begin monitoring and 'monitor.close()' when training is
    complete. This will record stats and will periodically record a video.

    For finer-grained control over how often videos are collected, use the
    video_callable argument, e.g.
    'monitor.start(video_callable=lambda count: count % 100 == 0)'
    to record every 100 episodes. ('count' is how many episodes have completed)

    Depending on the environment, video can slow down execution. You
    can also use 'monitor.configure(video_callable=lambda count: False)' to disable
    video.

    MonitorManager supports multiple threads and multiple processes writing
    to the same directory of training data. The data will later be
    joined by scoreboard.upload_training_data and on the server.

    Args:
        env (gym.Env): The environment instance to monitor.

    Attributes:
        id (Optional[str]): The ID of the monitored environment
    """

    def __init__(self, env):
        # Python's GC allows refcycles *or* for objects to have a
        # __del__ method. So we need to maintain a weakref to env.
        #
        # https://docs.python.org/2/library/gc.html#gc.garbage
        self._env_ref = weakref.ref(env)
        self.videos = []

        self.stats_recorder = None
        self.video_recorder = None
        self.enabled = False
        self.episode_id = 0
        self._monitor_id = None
        self.env_semantics_autoreset = env.metadata.get('semantics.autoreset')

    @property
    def env(self):
        """The monitored env; raises error.Error once it has been garbage collected."""
        env = self._env_ref()
        if env is None:
            raise error.Error("env has been garbage collected. To keep using a monitor, you must keep around a reference to the env object. (HINT: try assigning the env to a variable in your code.)")
        return env

    def start(self, directory, video_callable=None, force=False, resume=False,
              write_upon_reset=False, uid=None, mode=None):
        """Start monitoring.

        Args:
            directory (str): A per-training run directory where to record stats.
            video_callable (Optional[function, False]): function that takes in the index of the episode and outputs a boolean, indicating whether we should record a video on this episode. The default (for video_callable is None) is to take perfect cubes, capped at 1000. False disables video recording.
            force (bool): Clear out existing training data from this directory (by deleting every file prefixed with "openaigym.").
            resume (bool): Retain the training data already in this directory, which will be merged with our new data
            write_upon_reset (bool): Write the manifest file on each reset. (This is currently a JSON file, so writing it is somewhat expensive.)
            uid (Optional[str]): A unique id used as part of the suffix for the file. By default, uses os.getpid().
            mode (['evaluation', 'training']): Whether this is an evaluation or training episode.

        Raises:
            error.Error: If video_callable is not a function, None or False,
                or if the directory already holds manifests and neither
                force nor resume was given.
        """
        if self.env.spec is None:
            logger.warning("Trying to monitor an environment which has no 'spec' set. This usually means you did not create it via 'gym.make', and is recommended only for advanced users.")
            env_id = '(unknown)'
        else:
            env_id = self.env.spec.id

        # Create the directory up front so that everything below (manifest
        # detection, the stats recorder) can rely on it existing.
        if not os.path.exists(directory):
            logger.info('Creating monitor directory %s', directory)
            if six.PY3:
                os.makedirs(directory, exist_ok=True)
            else:
                os.makedirs(directory)

        if video_callable is None:
            video_callable = capped_cubic_video_schedule
        elif video_callable == False:  # noqa: E712 -- equality kept on purpose so existing falsy-equal arguments still disable video
            video_callable = disable_videos
        elif not callable(video_callable):
            raise error.Error('You must provide a function, None, or False for video_callable, not {}: {}'.format(type(video_callable), video_callable))
        self.video_callable = video_callable

        # Check on whether we need to clear anything
        if force:
            clear_monitor_files(directory)
        elif not resume:
            training_manifests = detect_training_manifests(directory)
            if len(training_manifests) > 0:
                raise error.Error('''Trying to write to monitor directory {} with existing monitor files: {}.

You should use a unique directory for each training run, or use 'force=True' to automatically clear previous monitor files.'''.format(directory, ', '.join(training_manifests[:5])))

        self._monitor_id = monitor_closer.register(self)

        self.enabled = True
        self.directory = os.path.abspath(directory)
        # We use the 'openai-gym' prefix to determine if a file is
        # ours
        self.file_prefix = FILE_PREFIX
        self.file_infix = '{}.{}'.format(self._monitor_id, uid if uid else os.getpid())

        self.stats_recorder = stats_recorder.StatsRecorder(directory, '{}.episode_batch.{}'.format(self.file_prefix, self.file_infix), autoreset=self.env_semantics_autoreset, env_id=env_id)

        self.write_upon_reset = write_upon_reset

        if mode is not None:
            self._set_mode(mode)

    def _flush(self, force=False):
        """Flush all relevant monitor information to disk."""
        if not self.write_upon_reset and not force:
            return

        self.stats_recorder.flush()

        # Give it a very distinguished name, since we need to pick it
        # up from the filesystem later.
        path = os.path.join(self.directory, '{}.manifest.{}.manifest.json'.format(self.file_prefix, self.file_infix))
        logger.debug('Writing training manifest file to %s', path)
        with atomic_write.atomic_write(path) as f:
            # We need to write relative paths here since people may
            # move the training_dir around. It would be cleaner to
            # already have the basenames rather than basename'ing
            # manually, but this works for now.
            json.dump({
                'stats': os.path.basename(self.stats_recorder.path),
                'videos': [(os.path.basename(v), os.path.basename(m))
                           for v, m in self.videos],
                'env_info': self._env_info(),
            }, f, default=json_encode_np)

    def close(self):
        """Flush all monitor data to disk and close any open rendering windows."""
        if not self.enabled:
            return
        self.stats_recorder.close()
        if self.video_recorder is not None:
            self._close_video_recorder()
        self._flush(force=True)

        env = self._env_ref()
        # Only take action if the env hasn't been GC'd
        if env is not None:
            # Note we'll close the env's rendering window even if we did
            # not open it. There isn't a particular great way to know if
            # we did, since some environments will have a window pop up
            # during video recording.
            try:
                env.render(close=True)
            except Exception as e:
                key = env.spec.id if env.spec else env
                # We don't want to avoid writing the manifest simply
                # because we couldn't close the renderer.
                logger.error('Could not close renderer for %s: %s', key, e)

            # Remove the env's pointer to this monitor
            if hasattr(env, '_monitor'):
                del env._monitor

        # Stop tracking this for autoclose
        monitor_closer.unregister(self._monitor_id)
        self.enabled = False

        logger.info('''Finished writing results. You can upload them to the scoreboard via gym.upload(%r)''', self.directory)

    def _set_mode(self, mode):
        """Set the stats recorder's episode type from 'evaluation' or 'training'.

        Raises:
            error.Error: If mode is neither of the two accepted strings.
        """
        if mode == 'evaluation':
            episode_type = 'e'
        elif mode == 'training':
            episode_type = 't'
        else:
            # The mode must be interpolated with .format(); passing it as a
            # second positional argument left the '{}' placeholder unfilled.
            raise error.Error('Invalid mode {}: must be "training" or "evaluation"'.format(mode))
        self.stats_recorder.type = episode_type

    def _before_step(self, action):
        """Forward the pending action to the stats recorder (no-op when disabled)."""
        if not self.enabled:
            return
        self.stats_recorder.before_step(action)

    def _after_step(self, observation, reward, done, info):
        """Record the outcome of a step and return ``done`` unchanged."""
        if not self.enabled:
            return done

        if done and self.env_semantics_autoreset:
            # For envs with BlockingReset wrapping VNCEnv, this observation will be the first one of the new episode
            self._reset_video_recorder()
            self.episode_id += 1
            self._flush()

        # Semisupervised envs modify the rewards, but we want the original
        # when scoring. Compare against None so that a true reward of 0 is
        # still honoured instead of silently falling back to the modified one.
        true_reward = info.get('true_reward', None)
        if true_reward is not None:
            reward = true_reward

        # Record stats
        self.stats_recorder.after_step(observation, reward, done, info)
        # Record video
        self.video_recorder.capture_frame()

        return done

    def _before_reset(self):
        """Notify the stats recorder that a reset is about to happen."""
        if not self.enabled:
            return
        self.stats_recorder.before_reset()

    def _after_reset(self, observation):
        """Record the reset, start a fresh video recorder and bump episode_id."""
        if not self.enabled:
            return

        # Reset the stat count
        self.stats_recorder.after_reset(observation)

        self._reset_video_recorder()

        # Bump *after* all reset activity has finished
        self.episode_id += 1

        self._flush()

    def _reset_video_recorder(self):
        """Close the current video recorder, if any, and open one for this episode."""
        # Close any existing video recorder
        if self.video_recorder:
            self._close_video_recorder()

        # Start recording the next video.
        #
        # TODO: calculate a more correct 'episode_id' upon merge
        self.video_recorder = video_recorder.VideoRecorder(
            env=self.env,
            base_path=os.path.join(self.directory, '{}.video.{}.video{:06}'.format(self.file_prefix, self.file_infix, self.episode_id)),
            metadata={'episode_id': self.episode_id},
            enabled=self._video_enabled(),
        )
        self.video_recorder.capture_frame()

    def _close_video_recorder(self):
        """Close the video recorder and remember its files if it was functional."""
        self.video_recorder.close()
        if self.video_recorder.functional:
            self.videos.append((self.video_recorder.path, self.video_recorder.metadata_path))

    def _video_enabled(self):
        """Ask video_callable whether the current episode should be recorded."""
        return self.video_callable(self.episode_id)

    def _env_info(self):
        """Build the env_info dict written to the manifest."""
        env_info = {
            'gym_version': version.VERSION,
        }
        if self.env.spec:
            env_info['env_id'] = self.env.spec.id
        return env_info

    def __del__(self):
        # Make sure we've closed up shop when garbage collecting
        self.close()
|
|
||||||
|
|
||||||
def load_env_info_from_manifests(manifests, training_dir):
    """Read the 'env_info' entry of every manifest and collapse them into one.

    Args:
        manifests: Paths of JSON manifest files to read.
        training_dir: Passed through to ``collapse_env_infos``.

    Returns:
        Whatever ``collapse_env_infos`` returns for the collected entries.
    """
    collected = []
    for manifest_path in manifests:
        with open(manifest_path) as handle:
            manifest = json.load(handle)
            collected.append(manifest['env_info'])

    return collapse_env_infos(collected, training_dir)
|
|
||||||
|
|
||||||
def load_results(training_dir):
    """Load and merge everything the manifests in ``training_dir`` point to.

    Returns:
        ``None`` (after logging an error) when the directory does not exist
        or holds no manifests; otherwise a dict with the manifests, the
        collapsed env_info, the merged per-episode stats and the video paths.
    """
    if not os.path.exists(training_dir):
        logger.error('Training directory %s not found', training_dir)
        return None

    manifests = detect_training_manifests(training_dir)
    if not manifests:
        logger.error('No manifests found in training directory %s', training_dir)
        return None

    logger.debug('Uploading data from manifest %s', ', '.join(manifests))

    # Load up stats + video files
    stats_files, videos, env_infos = [], [], []
    for manifest_path in manifests:
        with open(manifest_path) as handle:
            manifest = json.load(handle)
            # Manifests store basenames; make these paths absolute again.
            stats_files.append(os.path.join(training_dir, manifest['stats']))
            for video_name, metadata_name in manifest['videos']:
                videos.append((os.path.join(training_dir, video_name),
                               os.path.join(training_dir, metadata_name)))
            env_infos.append(manifest['env_info'])

    env_info = collapse_env_infos(env_infos, training_dir)
    merged = merge_stats_files(stats_files)
    (data_sources, initial_reset_timestamps, timestamps, episode_lengths,
     episode_rewards, episode_types, initial_reset_timestamp) = merged

    results = {}
    results['manifests'] = manifests
    results['env_info'] = env_info
    results['data_sources'] = data_sources
    results['timestamps'] = timestamps
    results['episode_lengths'] = episode_lengths
    results['episode_rewards'] = episode_rewards
    results['episode_types'] = episode_types
    results['initial_reset_timestamps'] = initial_reset_timestamps
    results['initial_reset_timestamp'] = initial_reset_timestamp
    results['videos'] = videos
    return results
|
|
||||||
|
|
||||||
def merge_stats_files(stats_files):
    """Merge several stats JSON files into timestamp-ordered episode lists.

    Args:
        stats_files: Paths of JSON files, each holding the keys
            'timestamps', 'episode_lengths', 'episode_rewards',
            'initial_reset_timestamp' and optionally 'episode_types'.
            Files with no timestamps are skipped entirely.

    Returns:
        A 7-tuple ``(data_sources, initial_reset_timestamps, timestamps,
        episode_lengths, episode_rewards, episode_types,
        initial_reset_timestamp)``. The per-episode lists are sorted by
        timestamp; ``data_sources[k]`` is the index in ``stats_files`` that
        episode ``k`` came from. ``episode_types`` is ``None`` when the files
        do not provide exactly one type per episode.
        ``initial_reset_timestamp`` is the minimum over the non-empty files,
        or 0 when there are none.
    """
    timestamps = []
    episode_lengths = []
    episode_rewards = []
    episode_types = []
    initial_reset_timestamps = []
    data_sources = []

    for i, path in enumerate(stats_files):
        with open(path) as f:
            content = json.load(f)
        if len(content['timestamps']) == 0:
            continue  # so empty file doesn't mess up results, due to null initial_reset_timestamp
        # Keep track of where each episode came from.
        data_sources += [i] * len(content['timestamps'])
        timestamps += content['timestamps']
        episode_lengths += content['episode_lengths']
        episode_rewards += content['episode_rewards']
        # Recent addition
        episode_types += content.get('episode_types', [])
        initial_reset_timestamps.append(content['initial_reset_timestamp'])

    # A stable sort keeps episodes with identical timestamps in file order,
    # so the merged result is deterministic.
    idxs = np.argsort(timestamps, kind='mergesort')
    timestamps = np.array(timestamps)[idxs].tolist()
    episode_lengths = np.array(episode_lengths)[idxs].tolist()
    episode_rewards = np.array(episode_rewards)[idxs].tolist()
    data_sources = np.array(data_sources)[idxs].tolist()

    # Types can only be reordered when there is exactly one per episode. If
    # some files lack 'episode_types' the list is shorter than idxs and
    # indexing it would raise IndexError, so report "no types" instead.
    if episode_types and len(episode_types) == len(timestamps):
        episode_types = np.array(episode_types)[idxs].tolist()
    else:
        episode_types = None

    if len(initial_reset_timestamps) > 0:
        initial_reset_timestamp = min(initial_reset_timestamps)
    else:
        initial_reset_timestamp = 0

    return data_sources, initial_reset_timestamps, timestamps, episode_lengths, episode_rewards, episode_types, initial_reset_timestamp
|
|
||||||
|
|
||||||
# TODO training_dir isn't used except for error messages, clean up the layering
|
|
||||||
def collapse_env_infos(env_infos, training_dir):
    """Check that all env_infos agree and return the first one.

    Args:
        env_infos: Non-empty list of env_info dicts.
        training_dir: Only used in error messages.

    Raises:
        error.Error: If two entries differ, or if the first entry lacks
            'env_id' or 'gym_version'.
    """
    assert len(env_infos) > 0

    reference = env_infos[0]
    remaining = env_infos[1:]
    for candidate in remaining:
        if reference == candidate:
            continue
        raise error.Error('Found two unequal env_infos: {} and {}. This usually indicates that your training directory {} has commingled results from multiple runs.'.format(reference, candidate, training_dir))

    required_keys = ['env_id', 'gym_version']
    for key in required_keys:
        if key in reference:
            continue
        raise error.Error("env_info {} from training directory {} is missing expected key {}. This is unexpected and likely indicates a bug in gym.".format(reference, training_dir, key))
    return reference
|
|
@@ -103,9 +103,7 @@ logger = logging.getLogger()
|
|||||||
gym.envs.register(
|
gym.envs.register(
|
||||||
id='Autoreset-v0',
|
id='Autoreset-v0',
|
||||||
entry_point='gym.monitoring.tests.test_monitor:AutoresetEnv',
|
entry_point='gym.monitoring.tests.test_monitor:AutoresetEnv',
|
||||||
tags={
|
max_episode_steps=2,
|
||||||
'wrapper_config.TimeLimit.max_episode_steps': 2,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
def test_env_reuse():
|
def test_env_reuse():
|
||||||
with helpers.tempdir() as temp:
|
with helpers.tempdir() as temp:
|
||||||
@@ -189,9 +187,7 @@ def test_only_complete_episodes_written():
|
|||||||
register(
|
register(
|
||||||
id='test.StepsLimitCartpole-v0',
|
id='test.StepsLimitCartpole-v0',
|
||||||
entry_point='gym.envs.classic_control:CartPoleEnv',
|
entry_point='gym.envs.classic_control:CartPoleEnv',
|
||||||
tags={
|
max_episode_steps=2
|
||||||
'wrapper_config.TimeLimit.max_episode_steps': 2
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_steps_limit_restart():
|
def test_steps_limit_restart():
|
||||||
@@ -207,6 +203,6 @@ def test_steps_limit_restart():
|
|||||||
# Limit reached, now we get a done signal and the env resets itself
|
# Limit reached, now we get a done signal and the env resets itself
|
||||||
_, _, done, info = env.step(env.action_space.sample())
|
_, _, done, info = env.step(env.action_space.sample())
|
||||||
assert done == True
|
assert done == True
|
||||||
assert env._monitor.episode_id == 1
|
assert env.episode_id == 1
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
@@ -13,6 +13,7 @@ from gym.scoreboard.registration import registry, add_task, add_group, add_bench
|
|||||||
|
|
||||||
# Discover API key from the environment. (You should never have to
|
# Discover API key from the environment. (You should never have to
|
||||||
# change api_base / web_base.)
|
# change api_base / web_base.)
|
||||||
|
env_key_names = ['OPENAI_GYM_API_KEY', 'OPENAI_GYM_API_BASE', 'OPENAI_GYM_WEB_BASE']
|
||||||
api_key = os.environ.get('OPENAI_GYM_API_KEY')
|
api_key = os.environ.get('OPENAI_GYM_API_KEY')
|
||||||
api_base = os.environ.get('OPENAI_GYM_API_BASE', 'https://gym-api.openai.com')
|
api_base = os.environ.get('OPENAI_GYM_API_BASE', 'https://gym-api.openai.com')
|
||||||
web_base = os.environ.get('OPENAI_GYM_WEB_BASE', 'https://gym.openai.com')
|
web_base = os.environ.get('OPENAI_GYM_WEB_BASE', 'https://gym.openai.com')
|
||||||
|
@@ -43,7 +43,7 @@ def upload(training_dir, algorithm_id=None, writeup=None, tags=None, benchmark_i
|
|||||||
# Validate against benchmark spec
|
# Validate against benchmark spec
|
||||||
try:
|
try:
|
||||||
spec = benchmark_spec(benchmark_id)
|
spec = benchmark_spec(benchmark_id)
|
||||||
except error.UnregisteredBenchmark as e:
|
except error.UnregisteredBenchmark:
|
||||||
raise error.Error("Invalid benchmark id: {}. Are you using a benchmark registered in gym/benchmarks/__init__.py?".format(benchmark_id))
|
raise error.Error("Invalid benchmark id: {}. Are you using a benchmark registered in gym/benchmarks/__init__.py?".format(benchmark_id))
|
||||||
|
|
||||||
# TODO: verify that the number of trials matches
|
# TODO: verify that the number of trials matches
|
||||||
@@ -54,7 +54,7 @@ def upload(training_dir, algorithm_id=None, writeup=None, tags=None, benchmark_i
|
|||||||
|
|
||||||
# This could be more stringent about mixing evaluations
|
# This could be more stringent about mixing evaluations
|
||||||
if sorted(env_ids) != sorted(spec_env_ids):
|
if sorted(env_ids) != sorted(spec_env_ids):
|
||||||
logger.info("WARNING: Evaluations do not match spec for benchmark {}. In {}, we found evaluations for {}, expected {}".format(benchmark_id, training_dir, sorted(env_ids), sorted(spec_env_ids)))
|
logger.info("WARNING: Evaluations do not match spec for benchmark %s. In %s, we found evaluations for %s, expected %s", benchmark_id, training_dir, sorted(env_ids), sorted(spec_env_ids))
|
||||||
|
|
||||||
benchmark_run = resource.BenchmarkRun.create(benchmark_id=benchmark_id, algorithm_id=algorithm_id, tags=json.dumps(tags))
|
benchmark_run = resource.BenchmarkRun.create(benchmark_id=benchmark_id, algorithm_id=algorithm_id, tags=json.dumps(tags))
|
||||||
benchmark_run_id = benchmark_run.id
|
benchmark_run_id = benchmark_run.id
|
||||||
@@ -77,7 +77,7 @@ OpenAI Gym! You can find it at:
|
|||||||
return benchmark_run_id
|
return benchmark_run_id
|
||||||
else:
|
else:
|
||||||
if tags is not None:
|
if tags is not None:
|
||||||
logger.warn("Tags will NOT be uploaded for this submission.")
|
logger.warning("Tags will NOT be uploaded for this submission.")
|
||||||
# Single evalution upload
|
# Single evalution upload
|
||||||
benchmark_run_id = None
|
benchmark_run_id = None
|
||||||
evaluation = _upload(training_dir, algorithm_id, writeup, benchmark_run_id, api_key, ignore_open_monitors)
|
evaluation = _upload(training_dir, algorithm_id, writeup, benchmark_run_id, api_key, ignore_open_monitors)
|
||||||
@@ -117,7 +117,7 @@ def _upload(training_dir, algorithm_id=None, writeup=None, benchmark_run_id=None
|
|||||||
elif training_video_id is not None:
|
elif training_video_id is not None:
|
||||||
logger.info('[%s] Creating evaluation object from %s with training video', env_id, training_dir)
|
logger.info('[%s] Creating evaluation object from %s with training video', env_id, training_dir)
|
||||||
else:
|
else:
|
||||||
raise error.Error("[{}] You didn't have any recorded training data in {}. Once you've used 'env.monitor.start(training_dir)' to start recording, you need to actually run some rollouts. Please join the community chat on https://gym.openai.com if you have any issues.".format(env_id, training_dir))
|
raise error.Error("[%s] You didn't have any recorded training data in %s. Once you've used 'env.monitor.start(training_dir)' to start recording, you need to actually run some rollouts. Please join the community chat on https://gym.openai.com if you have any issues."%(env_id, training_dir))
|
||||||
|
|
||||||
evaluation = resource.Evaluation.create(
|
evaluation = resource.Evaluation.create(
|
||||||
training_episode_batch=training_episode_batch_id,
|
training_episode_batch=training_episode_batch_id,
|
||||||
@@ -140,7 +140,7 @@ def upload_training_data(training_dir, api_key=None):
|
|||||||
if not results:
|
if not results:
|
||||||
raise error.Error('''Could not find any manifest files in {}.
|
raise error.Error('''Could not find any manifest files in {}.
|
||||||
|
|
||||||
(HINT: this usually means you did not yet close() your env.monitor and have not yet exited the process. You should call 'env.monitor.start(training_dir)' at the start of training and 'env.monitor.close()' at the end, or exit the process.)'''.format(training_dir))
|
(HINT: this usually means you did not yet close() your env.monitor and have not yet exited the process. You should call 'env.monitor.start(training_dir)' at the start of training and 'env.close()' at the end, or exit the process.)'''.format(training_dir))
|
||||||
|
|
||||||
manifests = results['manifests']
|
manifests = results['manifests']
|
||||||
env_info = results['env_info']
|
env_info = results['env_info']
|
||||||
@@ -162,8 +162,8 @@ def upload_training_data(training_dir, api_key=None):
|
|||||||
training_episode_batch = None
|
training_episode_batch = None
|
||||||
|
|
||||||
if len(videos) > MAX_VIDEOS:
|
if len(videos) > MAX_VIDEOS:
|
||||||
logger.warn('[%s] You recorded videos for %s episodes, but the scoreboard only supports up to %s. We will automatically subsample for you, but you also might wish to adjust your video recording rate.', env_id, len(videos), MAX_VIDEOS)
|
logger.warning('[%s] You recorded videos for %s episodes, but the scoreboard only supports up to %s. We will automatically subsample for you, but you also might wish to adjust your video recording rate.', env_id, len(videos), MAX_VIDEOS)
|
||||||
subsample_inds = np.linspace(0, len(videos)-1, MAX_VIDEOS).astype('int')
|
subsample_inds = np.linspace(0, len(videos)-1, MAX_VIDEOS).astype('int') #pylint: disable=E1101
|
||||||
videos = [videos[i] for i in subsample_inds]
|
videos = [videos[i] for i in subsample_inds]
|
||||||
|
|
||||||
if len(videos) > 0:
|
if len(videos) > 0:
|
||||||
|
@@ -31,7 +31,7 @@ def score_from_remote(url):
|
|||||||
|
|
||||||
def score_from_local(directory):
|
def score_from_local(directory):
|
||||||
"""Calculate score from a local results directory"""
|
"""Calculate score from a local results directory"""
|
||||||
results = gym.monitoring.monitor.load_results(directory)
|
results = gym.monitoring.load_results(directory)
|
||||||
# No scores yet saved
|
# No scores yet saved
|
||||||
if results is None:
|
if results is None:
|
||||||
return None
|
return None
|
||||||
|
193
gym/utils/play.py
Normal file
193
gym/utils/play.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
import gym
|
||||||
|
import pygame
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
from pygame.locals import HWSURFACE, DOUBLEBUF, RESIZABLE, VIDEORESIZE
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
try:
|
||||||
|
matplotlib.use('GTK3Agg')
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def display_arr(screen, arr, video_size, transpose):
    """Rescale ``arr`` to the 0-255 range and draw it on ``screen``.

    Args:
        screen: Surface to blit onto (as returned by
            ``pygame.display.set_mode``).
        arr: Array exposing ``min()``, ``max()`` and ``swapaxes()``; it is
            handed to ``pygame.surfarray.make_surface`` after rescaling.
        video_size: ``(width, height)`` the image is scaled to.
        transpose: When true, axes 0 and 1 are swapped before drawing.
    """
    arr_min, arr_max = arr.min(), arr.max()
    value_range = arr_max - arr_min
    if value_range == 0:
        # A constant frame has no range to stretch; dividing by zero would
        # fill the image with NaNs, so draw it as all zeros instead.
        arr = 0.0 * (arr - arr_min)
    else:
        arr = 255.0 * (arr - arr_min) / value_range
    pyg_img = pygame.surfarray.make_surface(arr.swapaxes(0, 1) if transpose else arr)
    pyg_img = pygame.transform.scale(pyg_img, video_size)
    screen.blit(pyg_img, (0, 0))
|
||||||
|
|
||||||
|
def play(env, transpose=True, fps=30, zoom=None, callback=None, keys_to_action=None):
    """Allows one to play the game using keyboard.

    To simply play the game use:

        play(gym.make("Pong-v3"))

    Above code works also if env is wrapped, so it's particularly useful in
    verifying that the frame-level preprocessing does not render the game
    unplayable.

    If you wish to plot real time statistics as you play, you can use
    gym.utils.play.PlayPlot. Here's a sample code for plotting the reward
    for last 5 second of gameplay.

        def callback(obs_t, obs_tp1, action, rew, done, info):
            return [rew,]
        env_plotter = PlayPlot(callback, 30 * 5, ["reward"])

        env = gym.make("Pong-v3")
        play(env, callback=env_plotter.callback)


    Arguments
    ---------
    env: gym.Env
        Environment to use for playing.
    transpose: bool
        If True the output of observation is transposed.
        Defaults to true.
    fps: int
        Maximum number of steps of the environment to execute every second.
        Defaults to 30.
    zoom: float
        Make screen edge this many times bigger
    callback: lambda or None
        Callback if a callback is provided it will be executed after
        every step. It takes the following input:
            obs_t: observation before performing action
            obs_tp1: observation after performing action
            action: action that was executed
            rew: reward that was received
            done: whether the environment is done or not
            info: debug info
    keys_to_action: dict: tuple(int) -> int or None
        Mapping from keys pressed to action performed.
        For example if pressed 'w' and space at the same time is supposed
        to trigger action number 2 then key_to_action dict would look like this:

            {
                # ...
                tuple(sorted((ord('w'), ord(' ')))): 2,
                # ...
            }
        If None, default key_to_action mapping for that env is used, if provided.
        A combination of pressed keys that has no entry falls back to the
        action mapped to "no keys pressed" (the empty tuple), or to action 0
        when the mapping has no such entry.
    """

    obs_s = env.observation_space
    assert isinstance(obs_s, gym.spaces.box.Box)
    assert len(obs_s.shape) == 2 or (len(obs_s.shape) == 3 and obs_s.shape[2] in [1, 3])

    if keys_to_action is None:
        if hasattr(env, 'get_keys_to_action'):
            keys_to_action = env.get_keys_to_action()
        elif hasattr(env.unwrapped, 'get_keys_to_action'):
            keys_to_action = env.unwrapped.get_keys_to_action()
        else:
            # Raised explicitly (rather than `assert False`) so the check
            # still fires when Python runs with assertions stripped.
            raise AssertionError(env.spec.id + " does not have explicit key to action mapping, " +
                                 "please specify one manually")
    relevant_keys = set(sum(map(list, keys_to_action.keys()), []))
    # Action used when the currently pressed combination has no mapping.
    fallback_action = keys_to_action.get((), 0)

    if transpose:
        video_size = env.observation_space.shape[1], env.observation_space.shape[0]
    else:
        video_size = env.observation_space.shape[0], env.observation_space.shape[1]

    if zoom is not None:
        video_size = int(video_size[0] * zoom), int(video_size[1] * zoom)

    pressed_keys = []
    running = True
    env_done = True

    screen = pygame.display.set_mode(video_size)
    clock = pygame.time.Clock()

    while running:
        if env_done:
            env_done = False
            obs = env.reset()
        else:
            # Pressing several mapped keys at once can form a combination
            # that is not in the mapping; fall back instead of raising KeyError.
            action = keys_to_action.get(tuple(sorted(pressed_keys)), fallback_action)
            prev_obs = obs
            obs, rew, env_done, info = env.step(action)
            if callback is not None:
                callback(prev_obs, obs, action, rew, env_done, info)
        if obs is not None:
            if len(obs.shape) == 2:
                obs = obs[:, :, None]
            if obs.shape[2] == 1:
                obs = obs.repeat(3, axis=2)
            display_arr(screen, obs, transpose=transpose, video_size=video_size)

        # process pygame events
        for event in pygame.event.get():
            # test events, set key states
            if event.type == pygame.KEYDOWN:
                if event.key in relevant_keys:
                    # Avoid duplicate entries if a KEYDOWN is delivered twice.
                    if event.key not in pressed_keys:
                        pressed_keys.append(event.key)
                elif event.key == pygame.K_ESCAPE:
                    running = False
            elif event.type == pygame.KEYUP:
                # The matching KEYDOWN may never have been seen (e.g. the key
                # was already held when the window got focus).
                if event.key in relevant_keys and event.key in pressed_keys:
                    pressed_keys.remove(event.key)
            elif event.type == pygame.QUIT:
                running = False
            elif event.type == VIDEORESIZE:
                video_size = event.size
                screen = pygame.display.set_mode(video_size)

        pygame.display.flip()
        clock.tick(fps)
    pygame.quit()
|
||||||
|
|
||||||
|
class PlayPlot(object):
    """Live scatter plots of per-step values produced by a user callback.

    ``callback`` receives ``(obs_t, obs_tp1, action, rew, done, info)`` on
    every step; the values it returns are paired, in order, with the plots
    named in ``plot_names``.  Each series keeps at most ``horizon_timesteps``
    of its most recent values.
    """

    def __init__(self, callback, horizon_timesteps, plot_names):
        self.data_callback = callback
        self.horizon_timesteps = horizon_timesteps
        self.plot_names = plot_names

        plot_count = len(self.plot_names)
        self.fig, self.ax = plt.subplots(plot_count)
        if plot_count == 1:
            # Wrap the single axes object so self.ax can always be indexed.
            self.ax = [self.ax]
        for axis, title in zip(self.ax, plot_names):
            axis.set_title(title)

        self.t = 0
        self.cur_plot = [None] * plot_count
        self.data = [deque(maxlen=horizon_timesteps) for _ in range(plot_count)]

    def callback(self, obs_t, obs_tp1, action, rew, done, info):
        """Record the values for one step and redraw every plot."""
        new_points = self.data_callback(obs_t, obs_tp1, action, rew, done, info)
        for value, series in zip(new_points, self.data):
            series.append(value)
        self.t += 1

        # x range covered by the values currently held in the series
        x_hi = self.t
        x_lo = max(0, x_hi - self.horizon_timesteps)

        for idx, stale in enumerate(self.cur_plot):
            if stale is not None:
                # Drop the previous scatter before drawing the new one
                stale.remove()
            axis = self.ax[idx]
            self.cur_plot[idx] = axis.scatter(range(x_lo, x_hi), list(self.data[idx]))
            axis.set_xlim(x_lo, x_hi)
        plt.pause(0.000001)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Manual demo: play an Atari environment while plotting per-step statistics.
    from rl_algs.common.atari_wrappers import wrap_deepmind

    def callback(obs_t, obs_tp1, action, rew, done, info):
        """Return the values to plot for one step: the reward and the mean of the previous observation."""
        return [rew, obs_t.mean()]

    # The plotting helper defined in this module is PlayPlot; the name used
    # here before (EnvPlotter) is not defined and raised NameError.
    env_plotter = PlayPlot(callback, 30 * 5, ["reward", "mean intensity"])

    env = gym.make("MontezumaRevengeNoFrameskip-v3")
    env = wrap_deepmind(env)

    # NOTE(review): `play_env` is not defined in the visible part of this
    # module -- confirm it matches the name of the play function above.
    play_env(env, zoom=4, callback=env_plotter.callback, fps=30)
|
||||||
|
|
@@ -1,32 +1,44 @@
|
|||||||
import gym
|
import gym
|
||||||
from gym import monitoring
|
|
||||||
from gym import Wrapper
|
from gym import Wrapper
|
||||||
from gym.wrappers.time_limit import TimeLimit
|
from gym import error, version
|
||||||
from gym import error
|
import os, json, logging, numpy as np, six
|
||||||
|
from gym.monitoring import stats_recorder, video_recorder
|
||||||
import logging
|
from gym.utils import atomic_write, closer
|
||||||
|
from gym.utils.json_utils import json_encode_np
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FILE_PREFIX = 'openaigym'
|
||||||
|
MANIFEST_PREFIX = FILE_PREFIX + '.manifest'
|
||||||
|
|
||||||
class _Monitor(Wrapper):
|
class _Monitor(Wrapper):
|
||||||
def __init__(self, env, directory, video_callable=None, force=False, resume=False,
|
def __init__(self, env, directory, video_callable=None, force=False, resume=False,
|
||||||
write_upon_reset=False, uid=None, mode=None):
|
write_upon_reset=False, uid=None, mode=None):
|
||||||
super(_Monitor, self).__init__(env)
|
super(_Monitor, self).__init__(env)
|
||||||
self._monitor = monitoring.MonitorManager(env)
|
|
||||||
self._monitor.start(directory, video_callable, force, resume,
|
self.videos = []
|
||||||
|
|
||||||
|
self.stats_recorder = None
|
||||||
|
self.video_recorder = None
|
||||||
|
self.enabled = False
|
||||||
|
self.episode_id = 0
|
||||||
|
self._monitor_id = None
|
||||||
|
self.env_semantics_autoreset = env.metadata.get('semantics.autoreset')
|
||||||
|
|
||||||
|
self._start(directory, video_callable, force, resume,
|
||||||
write_upon_reset, uid, mode)
|
write_upon_reset, uid, mode)
|
||||||
|
|
||||||
def _step(self, action):
|
def _step(self, action):
|
||||||
self._monitor._before_step(action)
|
self._before_step(action)
|
||||||
observation, reward, done, info = self.env.step(action)
|
observation, reward, done, info = self.env.step(action)
|
||||||
done = self._monitor._after_step(observation, reward, done, info)
|
done = self._after_step(observation, reward, done, info)
|
||||||
|
|
||||||
return observation, reward, done, info
|
return observation, reward, done, info
|
||||||
|
|
||||||
def _reset(self):
|
def _reset(self):
|
||||||
self._monitor._before_reset()
|
self._before_reset()
|
||||||
observation = self.env.reset()
|
observation = self.env.reset()
|
||||||
self._monitor._after_reset(observation)
|
self._after_reset(observation)
|
||||||
|
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
@@ -35,17 +47,347 @@ class _Monitor(Wrapper):
|
|||||||
|
|
||||||
# _monitor will not be set if super(Monitor, self).__init__ raises, this check prevents a confusing error message
|
# _monitor will not be set if super(Monitor, self).__init__ raises, this check prevents a confusing error message
|
||||||
if getattr(self, '_monitor', None):
|
if getattr(self, '_monitor', None):
|
||||||
self._monitor.close()
|
self.close()
|
||||||
|
|
||||||
def set_monitor_mode(self, mode):
|
def set_monitor_mode(self, mode):
|
||||||
logger.info("Setting the monitor mode is deprecated and will be removed soon")
|
logger.info("Setting the monitor mode is deprecated and will be removed soon")
|
||||||
self._monitor._set_mode(mode)
|
self._set_mode(mode)
|
||||||
|
|
||||||
|
|
||||||
|
def _start(self, directory, video_callable=None, force=False, resume=False,
           write_upon_reset=False, uid=None, mode=None):
    """Start monitoring.

    Args:
        directory (str): A per-training run directory where to record stats.
            It is created if it does not exist yet.
        video_callable (Optional[function, False]): function that takes in the
            index of the episode and outputs a boolean, indicating whether we
            should record a video on this episode. The default (for
            video_callable is None) is to take perfect cubes, capped at 1000.
            False disables video recording.
        force (bool): Clear out existing training data from this directory (by
            deleting every file prefixed with "openaigym.").
        resume (bool): Retain the training data already in this directory,
            which will be merged with our new data.
        write_upon_reset (bool): Write the manifest file on each reset. (This
            is currently a JSON file, so writing it is somewhat expensive.)
        uid (Optional[str]): A unique id used as part of the suffix for the
            file. By default, uses os.getpid().
        mode (['evaluation', 'training']): Whether this is an evaluation or
            training episode.

    Raises:
        error.Error: If video_callable is neither callable, None nor False, or
            if the directory already holds monitor manifests and neither
            force nor resume was requested.
    """
    if self.env.spec is None:
        logger.warning("Trying to monitor an environment which has no 'spec' set. This usually means you did not create it via 'gym.make', and is recommended only for advanced users.")
        env_id = '(unknown)'
    else:
        env_id = self.env.spec.id

    if not os.path.exists(directory):
        logger.info('Creating monitor directory %s', directory)
        if six.PY3:
            os.makedirs(directory, exist_ok=True)
        else:
            os.makedirs(directory)

    if video_callable is None:
        video_callable = capped_cubic_video_schedule
    elif video_callable == False:
        # `==` rather than `is` is kept on purpose so that callers passing an
        # equal-to-False value (e.g. 0) keep getting videos disabled.
        video_callable = disable_videos
    elif not callable(video_callable):
        raise error.Error('You must provide a function, None, or False for video_callable, not {}: {}'.format(type(video_callable), video_callable))
    self.video_callable = video_callable

    # Check on whether we need to clear anything
    if force:
        clear_monitor_files(directory)
    elif not resume:
        training_manifests = detect_training_manifests(directory)
        if len(training_manifests) > 0:
            raise error.Error('''Trying to write to monitor directory {} with existing monitor files: {}.

 You should use a unique directory for each training run, or use 'force=True' to automatically clear previous monitor files.'''.format(directory, ', '.join(training_manifests[:5])))

    # Registration must happen before file_infix is built: the id it
    # returns is part of every file name written by this monitor.
    self._monitor_id = monitor_closer.register(self)

    self.enabled = True
    self.directory = os.path.abspath(directory)
    # Files starting with FILE_PREFIX ('openaigym') are recognised as ours.
    self.file_prefix = FILE_PREFIX
    self.file_infix = '{}.{}'.format(self._monitor_id, uid if uid else os.getpid())

    self.stats_recorder = stats_recorder.StatsRecorder(directory, '{}.episode_batch.{}'.format(self.file_prefix, self.file_infix), autoreset=self.env_semantics_autoreset, env_id=env_id)

    # The directory was already created above, so no second existence
    # check / mkdir is needed here.
    self.write_upon_reset = write_upon_reset

    if mode is not None:
        self._set_mode(mode)
|
||||||
|
|
||||||
|
def _flush(self, force=False):
    """Write the stats recorder data and the manifest file to disk.

    Nothing is written unless ``force`` is true or ``write_upon_reset`` was
    set on this monitor.
    """
    if not (self.write_upon_reset or force):
        return

    self.stats_recorder.flush()

    # The manifest gets a very distinguished name, since it has to be picked
    # up from the filesystem later.
    manifest_name = '{}.manifest.{}.manifest.json'.format(self.file_prefix, self.file_infix)
    path = os.path.join(self.directory, manifest_name)
    logger.debug('Writing training manifest file to %s', path)
    with atomic_write.atomic_write(path) as f:
        # Only basenames are stored, since people may move the training
        # directory around.
        manifest = {
            'stats': os.path.basename(self.stats_recorder.path),
            'videos': [(os.path.basename(v), os.path.basename(m))
                       for v, m in self.videos],
            'env_info': self._env_info(),
        }
        json.dump(manifest, f, default=json_encode_np)
|
||||||
|
|
||||||
|
def close(self):
    """Close the recorders, flush everything to disk and stop auto-close tracking.

    Does nothing when the monitor is not enabled.
    """
    if self.enabled:
        self.stats_recorder.close()
        if self.video_recorder is not None:
            self._close_video_recorder()
        self._flush(force=True)

        # Stop tracking this for autoclose
        monitor_closer.unregister(self._monitor_id)
        self.enabled = False

        logger.info('''Finished writing results. You can upload them to the scoreboard via gym.upload(%r)''', self.directory)
|
||||||
|
|
||||||
|
def _set_mode(self, mode):
|
||||||
|
if mode == 'evaluation':
|
||||||
|
type = 'e'
|
||||||
|
elif mode == 'training':
|
||||||
|
type = 't'
|
||||||
|
else:
|
||||||
|
raise error.Error('Invalid mode {}: must be "training" or "evaluation"', mode)
|
||||||
|
self.stats_recorder.type = type
|
||||||
|
|
||||||
|
def _before_step(self, action):
|
||||||
|
if not self.enabled: return
|
||||||
|
self.stats_recorder.before_step(action)
|
||||||
|
|
||||||
|
def _after_step(self, observation, reward, done, info):
|
||||||
|
if not self.enabled: return done
|
||||||
|
|
||||||
|
if done and self.env_semantics_autoreset:
|
||||||
|
# For envs with BlockingReset wrapping VNCEnv, this observation will be the first one of the new episode
|
||||||
|
self._reset_video_recorder()
|
||||||
|
self.episode_id += 1
|
||||||
|
self._flush()
|
||||||
|
|
||||||
|
if info.get('true_reward', None): # Semisupervised envs modify the rewards, but we want the original when scoring
|
||||||
|
reward = info['true_reward']
|
||||||
|
|
||||||
|
# Record stats
|
||||||
|
self.stats_recorder.after_step(observation, reward, done, info)
|
||||||
|
# Record video
|
||||||
|
self.video_recorder.capture_frame()
|
||||||
|
|
||||||
|
return done
|
||||||
|
|
||||||
|
def _before_reset(self):
|
||||||
|
if not self.enabled: return
|
||||||
|
self.stats_recorder.before_reset()
|
||||||
|
|
||||||
|
def _after_reset(self, observation):
|
||||||
|
if not self.enabled: return
|
||||||
|
|
||||||
|
# Reset the stat count
|
||||||
|
self.stats_recorder.after_reset(observation)
|
||||||
|
|
||||||
|
self._reset_video_recorder()
|
||||||
|
|
||||||
|
# Bump *after* all reset activity has finished
|
||||||
|
self.episode_id += 1
|
||||||
|
|
||||||
|
self._flush()
|
||||||
|
|
||||||
|
def _reset_video_recorder(self):
    """Close the current video recorder, if any, and start one for ``self.episode_id``."""
    # Close any existing video recorder
    if self.video_recorder:
        self._close_video_recorder()

    # Start recording the next video.
    #
    # TODO: calculate a more correct 'episode_id' upon merge
    video_name = '{}.video.{}.video{:06}'.format(self.file_prefix, self.file_infix, self.episode_id)
    self.video_recorder = video_recorder.VideoRecorder(
        env=self.env,
        base_path=os.path.join(self.directory, video_name),
        metadata={'episode_id': self.episode_id},
        enabled=self._video_enabled(),
    )
    # Capture a frame right away with the freshly created recorder
    self.video_recorder.capture_frame()
|
||||||
|
|
||||||
|
def _close_video_recorder(self):
|
||||||
|
self.video_recorder.close()
|
||||||
|
if self.video_recorder.functional:
|
||||||
|
self.videos.append((self.video_recorder.path, self.video_recorder.metadata_path))
|
||||||
|
|
||||||
|
def _video_enabled(self):
|
||||||
|
return self.video_callable(self.episode_id)
|
||||||
|
|
||||||
|
def _env_info(self):
    """Return a dict with the gym version and, when the env has a spec, its env id."""
    info = dict(gym_version=version.VERSION)
    if self.env.spec:
        info['env_id'] = self.env.spec.id
    return info
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
# Make sure we've closed up shop when garbage collecting
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def get_total_steps(self):
    """Return the stats recorder's ``total_steps`` value."""
    recorder = self.stats_recorder
    return recorder.total_steps
|
||||||
|
|
||||||
|
def get_episode_rewards(self):
    """Return the stats recorder's ``episode_rewards`` value."""
    recorder = self.stats_recorder
    return recorder.episode_rewards
|
||||||
|
|
||||||
|
def get_episode_lengths(self):
    """Return the stats recorder's ``episode_lengths`` value."""
    recorder = self.stats_recorder
    return recorder.episode_lengths
|
||||||
|
|
||||||
|
|
||||||
def Monitor(env=None, directory=None, video_callable=None, force=False, resume=False,
|
def Monitor(env=None, directory=None, video_callable=None, force=False, resume=False,
|
||||||
write_upon_reset=False, uid=None, mode=None):
|
write_upon_reset=False, uid=None, mode=None):
|
||||||
if not isinstance(env, gym.Env):
|
if not isinstance(env, gym.Env):
|
||||||
raise error.Error("Monitor decorator syntax is deprecated as of 12/28/2016. Replace your call to `env = gym.wrappers.Monitor(directory)(env)` with `env = gym.wrappers.Monitor(env, directory)`")
|
raise error.Error("Monitor decorator syntax is deprecated as of 12/28/2016. Replace your call to `env = gym.wrappers.Monitor(directory)(env)` with `env = gym.wrappers.Monitor(env, directory)`")
|
||||||
|
return _Monitor(env, directory, video_callable, force, resume, write_upon_reset, uid, mode)
|
||||||
|
|
||||||
# TODO: add duration in seconds also
|
def detect_training_manifests(training_dir, files=None):
|
||||||
return _Monitor(TimeLimit(env, max_episode_steps=env.spec.timestep_limit), directory, video_callable, force, resume,
|
if files is None:
|
||||||
write_upon_reset, uid, mode)
|
files = os.listdir(training_dir)
|
||||||
|
return [os.path.join(training_dir, f) for f in files if f.startswith(MANIFEST_PREFIX + '.')]
|
||||||
|
|
||||||
|
def detect_monitor_files(training_dir):
    """Return the paths of all entries in ``training_dir`` whose name starts with ``FILE_PREFIX + '.'``."""
    wanted_prefix = FILE_PREFIX + '.'
    matches = []
    for entry in os.listdir(training_dir):
        if entry.startswith(wanted_prefix):
            matches.append(os.path.join(training_dir, entry))
    return matches
|
||||||
|
|
||||||
|
def clear_monitor_files(training_dir):
    """Delete every monitor file found in ``training_dir``; does nothing if there are none."""
    stale_paths = detect_monitor_files(training_dir)
    if not stale_paths:
        return

    logger.info('Clearing %d monitor files from previous run (because force=True was provided)', len(stale_paths))
    for stale_path in stale_paths:
        os.unlink(stale_path)
|
||||||
|
|
||||||
|
def capped_cubic_video_schedule(episode_id):
    """Video schedule: perfect cubes below 1000, then every multiple of 1000.

    Returns True when a video should be recorded for ``episode_id``.
    """
    if episode_id < 1000:
        # Round the cube root to the nearest integer and check that cubing it
        # gives back the episode id exactly.
        nearest_root = int(round(episode_id ** (1. / 3)))
        return nearest_root ** 3 == episode_id
    return episode_id % 1000 == 0
|
||||||
|
|
||||||
|
def disable_videos(episode_id):
    """Video schedule that never records: returns False for every episode id."""
    return False
|
||||||
|
|
||||||
|
monitor_closer = closer.Closer()
|
||||||
|
|
||||||
|
# This method gets used for a sanity check in scoreboard/api.py. It's
|
||||||
|
# not intended for use outside of the gym codebase.
|
||||||
|
def _open_monitors():
    """Return a list of everything currently registered with ``monitor_closer``."""
    registered = monitor_closer.closeables
    return list(registered.values())
|
||||||
|
|
||||||
|
def load_env_info_from_manifests(manifests, training_dir):
    """Read the 'env_info' entry of every manifest file and collapse them into one.

    ``training_dir`` is only passed through to ``collapse_env_infos``.
    """
    collected = []
    for manifest_path in manifests:
        with open(manifest_path) as f:
            manifest = json.load(f)
        collected.append(manifest['env_info'])

    return collapse_env_infos(collected, training_dir)
|
||||||
|
|
||||||
|
def load_results(training_dir):
    """Load and merge everything the monitors wrote into ``training_dir``.

    Returns a dict with the manifests, the collapsed env info, the merged
    per-episode statistics and the video paths.  Returns None (after logging
    an error) when the directory does not exist or holds no manifests.
    """
    if not os.path.exists(training_dir):
        logger.error('Training directory %s not found', training_dir)
        return

    manifests = detect_training_manifests(training_dir)
    if not manifests:
        logger.error('No manifests found in training directory %s', training_dir)
        return

    logger.debug('Uploading data from manifest %s', ', '.join(manifests))

    def in_training_dir(name):
        # Manifests store basenames; join them back onto training_dir
        return os.path.join(training_dir, name)

    # Load up stats + video files
    stats_files, videos, env_infos = [], [], []
    for manifest in manifests:
        with open(manifest) as f:
            contents = json.load(f)
            stats_files.append(in_training_dir(contents['stats']))
            videos.extend([(in_training_dir(v), in_training_dir(m))
                           for v, m in contents['videos']])
            env_infos.append(contents['env_info'])

    env_info = collapse_env_infos(env_infos, training_dir)
    (data_sources, initial_reset_timestamps, timestamps, episode_lengths,
     episode_rewards, episode_types, initial_reset_timestamp) = merge_stats_files(stats_files)

    results = {}
    results['manifests'] = manifests
    results['env_info'] = env_info
    results['data_sources'] = data_sources
    results['timestamps'] = timestamps
    results['episode_lengths'] = episode_lengths
    results['episode_rewards'] = episode_rewards
    results['episode_types'] = episode_types
    results['initial_reset_timestamps'] = initial_reset_timestamps
    results['initial_reset_timestamp'] = initial_reset_timestamp
    results['videos'] = videos
    return results
|
||||||
|
|
||||||
|
def merge_stats_files(stats_files):
    """Merge several stats JSON files into timestamp-ordered episode lists.

    Returns the tuple ``(data_sources, initial_reset_timestamps, timestamps,
    episode_lengths, episode_rewards, episode_types,
    initial_reset_timestamp)``.  ``data_sources[k]`` is the index into
    ``stats_files`` of the file episode ``k`` came from.  ``episode_types``
    is None when no file provided any, and ``initial_reset_timestamp`` is 0
    when no file contributed episodes.
    """
    timestamps, episode_lengths, episode_rewards = [], [], []
    episode_types, initial_reset_timestamps, data_sources = [], [], []

    for source_index, path in enumerate(stats_files):
        with open(path) as f:
            content = json.load(f)
        file_timestamps = content['timestamps']
        if len(file_timestamps) == 0:
            # so empty file doesn't mess up results, due to null initial_reset_timestamp
            continue
        # Keep track of where each episode came from.
        data_sources.extend([source_index] * len(file_timestamps))
        timestamps.extend(file_timestamps)
        episode_lengths.extend(content['episode_lengths'])
        episode_rewards.extend(content['episode_rewards'])
        # Recent addition, so it may be absent
        episode_types.extend(content.get('episode_types', []))
        initial_reset_timestamps.append(content['initial_reset_timestamp'])

    order = np.argsort(timestamps)

    def by_time(values):
        # Reorder one per-episode list with the timestamp ordering
        return np.array(values)[order].tolist()

    timestamps = by_time(timestamps)
    episode_lengths = by_time(episode_lengths)
    episode_rewards = by_time(episode_rewards)
    data_sources = by_time(data_sources)
    episode_types = by_time(episode_types) if episode_types else None

    if initial_reset_timestamps:
        initial_reset_timestamp = min(initial_reset_timestamps)
    else:
        initial_reset_timestamp = 0

    return data_sources, initial_reset_timestamps, timestamps, episode_lengths, episode_rewards, episode_types, initial_reset_timestamp
|
||||||
|
|
||||||
|
# TODO training_dir isn't used except for error messages, clean up the layering
|
||||||
|
def collapse_env_infos(env_infos, training_dir):
    """Check that all env_infos are equal and complete, and return the first one.

    ``training_dir`` is only used in error messages.

    Raises:
        error.Error: If two env_infos differ, or if the first one lacks the
            'env_id' or 'gym_version' key.
    """
    assert len(env_infos) > 0

    first = env_infos[0]
    conflicting = [other for other in env_infos[1:] if first != other]
    if conflicting:
        raise error.Error('Found two unequal env_infos: {} and {}. This usually indicates that your training directory {} has commingled results from multiple runs.'.format(first, conflicting[0], training_dir))

    missing = [key for key in ('env_id', 'gym_version') if key not in first]
    if missing:
        raise error.Error("env_info {} from training directory {} is missing expected key {}. This is unexpected and likely indicates a bug in gym.".format(first, training_dir, missing[0]))
    return first
|
@@ -16,7 +16,6 @@ def test_skip():
|
|||||||
|
|
||||||
def test_configured():
|
def test_configured():
|
||||||
env = gym.make("FrozenLake-v0")
|
env = gym.make("FrozenLake-v0")
|
||||||
env = wrappers.TimeLimit(env)
|
|
||||||
env.configure()
|
env.configure()
|
||||||
|
|
||||||
# Make sure all layers of wrapping are configured
|
# Make sure all layers of wrapping are configured
|
||||||
|
@@ -37,12 +37,9 @@ class TimeLimit(Wrapper):
|
|||||||
self._elapsed_steps += 1
|
self._elapsed_steps += 1
|
||||||
|
|
||||||
if self._past_limit():
|
if self._past_limit():
|
||||||
# TODO(jie) we are resetting and discarding the observation here.
|
if self.metadata.get('semantics.autoreset'):
|
||||||
# This _should_ be fine since you can always call reset() again to
|
_ = self.reset() # automatically reset the env
|
||||||
# get a new, freshly initialized observation, but it would be better
|
done = True
|
||||||
# to clean this up.
|
|
||||||
_ = self.reset() # Force a reset, discard the observation
|
|
||||||
done = True # Force a done = True
|
|
||||||
|
|
||||||
return observation, reward, done, info
|
return observation, reward, done, info
|
||||||
|
|
||||||
|
@@ -46,7 +46,7 @@ def main():
|
|||||||
alldata = {}
|
alldata = {}
|
||||||
for i in xrange(2):
|
for i in xrange(2):
|
||||||
np.random.seed(i)
|
np.random.seed(i)
|
||||||
data = rollout(env, agent, env.spec.tags['wrapper_config.TimeLimit.max_episode_steps'])
|
data = rollout(env, agent, env.spec.max_episode_steps)
|
||||||
for (k, v) in data.items():
|
for (k, v) in data.items():
|
||||||
alldata["%i-%s"%(i, k)] = v
|
alldata["%i-%s"%(i, k)] = v
|
||||||
np.savez(args.outfile, **alldata)
|
np.savez(args.outfile, **alldata)
|
||||||
|
Reference in New Issue
Block a user