Files
Gymnasium/misc/write_rollout_data.py
Tom Brown d337f4e571 TimeLimit refactor with Monitor Simplification (#482)
* fix double reset, as suggested by @jietang

* better floors and ceilings

* add convenience methods to monitor

* add wrappers to gym namespace

* allow playing Atari games, with potentially more coming in the future

* simplify example in docs

* Move play out of the Env

* fix tests

* no more deprecation warnings

* remove env.monitor

* monitor simplification

* monitor simplifications

* monitor related fixes

* a few changes suggested by linter

* timestep_limit fixes

* keep track of gym env variables for future compatibility

* timestep_limit => max_episode_timesteps

* don't apply TimeLimit wrapper in make for VNC envs

* Respect old timestep_limit argument

* Pass max_episode_seconds through registration

* Don't include deprecation warnings yet
2017-02-01 13:10:59 -08:00

56 lines
1.6 KiB
Python

"""
This script runs a few rollouts of an environment with a random agent and writes the data to an npz file.
Its purpose is to help verify that you haven't functionally changed an environment.
(If you have, you should bump the environment's version number.)
"""
import argparse, numpy as np, collections, sys
from os import path
class RandomAgent(object):
    """
    Baseline agent that ignores what it observes.

    Every call to ``act`` returns the result of ``ac_space.sample()``.
    """

    def __init__(self, ac_space):
        # Space to draw actions from; only its ``sample()`` method is used.
        self.ac_space = ac_space

    def act(self, _):
        """Return a freshly sampled action; the observation argument is unused."""
        draw = self.ac_space.sample
        return draw()
def rollout(env, agent, max_episode_steps):
    """
    Simulate the env and agent for at most max_episode_steps steps.

    Resets ``env`` once, then repeatedly asks ``agent`` for an action and
    steps the environment, stopping early as soon as ``step`` reports done.

    Parameters
    ----------
    env : object
        Must provide ``reset()`` returning an observation and
        ``step(action)`` returning a 4-tuple ``(observation, reward, done, info)``.
    agent : object
        Must provide ``act(observation)`` returning an action.
    max_episode_steps : int
        Upper bound on the number of ``step`` calls. Must be an int;
        ``None`` is rejected by ``range`` with a TypeError.

    Returns
    -------
    collections.defaultdict(list)
        Keys ``"observation"``, ``"action"`` and ``"reward"``, each holding one
        entry per step taken. The observation stored at each step is the one
        the action was chosen from, so the observation returned by the final
        ``step`` call is not recorded. If no step is taken, the dict is empty.
    """
    ob = env.reset()
    data = collections.defaultdict(list)
    # range (not Python 2's xrange) so the script also runs on Python 3.
    for _ in range(max_episode_steps):
        data["observation"].append(ob)
        action = agent.act(ob)
        data["action"].append(action)
        ob, rew, done, _info = env.step(action)
        data["reward"].append(rew)
        if done:
            break
    return data
def _seed_rollout(env, seed):
    """Seed every RNG a rollout may draw from, so repeated runs can match."""
    # Global numpy RNG: the only one the script originally seeded.
    np.random.seed(seed)
    # The env and its action space may keep their own RNGs that
    # np.random.seed does not reach; seed them too when the installed gym
    # exposes a hook.
    # NOTE(review): getattr guards because these hooks differ across gym versions.
    env_seed = getattr(env, "seed", None)
    if callable(env_seed):
        env_seed(seed)
    space_seed = getattr(env.action_space, "seed", None)
    if callable(space_seed):
        space_seed(seed)


def main():
    """
    Command-line entry point: save two seeded random rollouts of an env to an npz file.

    Positional arguments: ``envid`` (passed to ``gym.make``) and ``outfile``
    (passed to ``np.savez``). ``--gymdir`` optionally prepends a directory to
    ``sys.path`` so that a specific gym checkout is imported.

    Rollouts use seeds 0 and 1. Each recorded field is saved under the key
    ``"<seed>-<field>"``, e.g. ``"0-observation"``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("envid")
    parser.add_argument("outfile")
    parser.add_argument("--gymdir")
    args = parser.parse_args()
    if args.gymdir:
        # Put the requested checkout first so it shadows any installed gym.
        sys.path.insert(0, args.gymdir)
    # Imported only after sys.path is adjusted, so --gymdir takes effect.
    import gym
    from gym import utils
    # print(...) with a single argument behaves the same on Python 2 and 3.
    print(utils.colorize("gym directory: %s" % path.dirname(gym.__file__),
                         "yellow"))
    env = gym.make(args.envid)
    try:
        agent = RandomAgent(env.action_space)
        alldata = {}
        for i in range(2):
            _seed_rollout(env, i)
            data = rollout(env, agent, env.spec.max_episode_steps)
            for k, v in data.items():
                alldata["%i-%s" % (i, k)] = v
    finally:
        # Release the env's resources even if a rollout raises.
        env.close()
    np.savez(args.outfile, **alldata)
# Only run the CLI when executed as a script, not when imported.
if __name__ == "__main__":
    main()