Return output from the render method in the right way (#1248)

* Close output StringIO after returning value

* Test render output is immutable
This commit is contained in:
Misha Behersky
2019-02-09 02:58:51 +02:00
committed by pzhokhov
parent 5f73c5dff5
commit 0659520f8c
6 changed files with 87 additions and 63 deletions

View File

@@ -33,10 +33,11 @@ leveling up many times to reach their reward threshold.
from gym import Env, logger from gym import Env, logger
from gym.spaces import Discrete, Tuple from gym.spaces import Discrete, Tuple
from gym.utils import colorize, seeding from gym.utils import colorize, seeding
import sys
from contextlib import closing
import numpy as np import numpy as np
from six import StringIO from six import StringIO
import sys
import math
class AlgorithmicEnv(Env): class AlgorithmicEnv(Env):
@@ -112,7 +113,6 @@ class AlgorithmicEnv(Env):
raise NotImplementedError raise NotImplementedError
def render(self, mode='human'): def render(self, mode='human'):
outfile = StringIO() if mode == 'ansi' else sys.stdout outfile = StringIO() if mode == 'ansi' else sys.stdout
inp = "Total length of input instance: %d, step: %d\n" % (self.input_width, self.time) inp = "Total length of input instance: %d, step: %d\n" % (self.input_width, self.time)
outfile.write(inp) outfile.write(inp)
@@ -149,7 +149,10 @@ class AlgorithmicEnv(Env):
outfile.write(" prediction: %s)\n" % pred_str) outfile.write(" prediction: %s)\n" % pred_str)
else: else:
outfile.write("\n" * 5) outfile.write("\n" * 5)
return outfile
if mode != 'human':
with closing(outfile):
return outfile.getvalue()
@property @property
def input_width(self): def input_width(self):
@@ -234,6 +237,7 @@ class AlgorithmicEnv(Env):
def _move(self, movement): def _move(self, movement):
raise NotImplemented raise NotImplemented
class TapeAlgorithmicEnv(AlgorithmicEnv): class TapeAlgorithmicEnv(AlgorithmicEnv):
"""An algorithmic env with a 1-d input tape.""" """An algorithmic env with a 1-d input tape."""
MOVEMENTS = ['left', 'right'] MOVEMENTS = ['left', 'right']
@@ -269,6 +273,7 @@ class TapeAlgorithmicEnv(AlgorithmicEnv):
x_str += "\n" x_str += "\n"
return x_str return x_str
class GridAlgorithmicEnv(AlgorithmicEnv): class GridAlgorithmicEnv(AlgorithmicEnv):
"""An algorithmic env with a 2-d input grid.""" """An algorithmic env with a 2-d input grid."""
MOVEMENTS = ['left', 'right', 'up', 'down'] MOVEMENTS = ['left', 'right', 'up', 'down']

View File

@@ -48,3 +48,17 @@ def test_random_rollout():
(ob, _reward, done, _info) = env.step(a) (ob, _reward, done, _info) = env.step(a)
if done: break if done: break
env.close() env.close()
def test_env_render_result_is_immutable():
    """Check that ``render(mode='ansi')`` returns a plain ``str``.

    A ``str`` is immutable, so the rendered text handed to the caller
    cannot be modified afterwards.
    """
    env_ids = ['Taxi-v2', 'FrozenLake-v0', 'Reverse-v0']

    for env_id in env_ids:
        # Build each env only when it is needed, so a failure on an earlier
        # one does not leave the remaining environments open.
        env = envs.make(env_id)
        try:
            env.reset()
            output = env.render(mode='ansi')
            assert isinstance(output, str)
        finally:
            # Release the environment even when the assertion fails.
            env.close()

View File

@@ -21,75 +21,75 @@ steps = ROLLOUT_STEPS
ROLLOUT_FILE = os.path.join(DATA_DIR, 'rollout.json') ROLLOUT_FILE = os.path.join(DATA_DIR, 'rollout.json')
if not os.path.isfile(ROLLOUT_FILE): if not os.path.isfile(ROLLOUT_FILE):
with open(ROLLOUT_FILE, "w") as outfile: with open(ROLLOUT_FILE, "w") as outfile:
json.dump({}, outfile, indent=2) json.dump({}, outfile, indent=2)
def hash_object(unhashed):
    """Return the SHA-256 hex digest of ``str(unhashed)`` encoded as UTF-16."""
    # This is really bad, str could be same while values change
    text = str(unhashed)
    digest = hashlib.sha256(text.encode('utf-16'))
    return digest.hexdigest()
def generate_rollout_hash(spec):
    """Run a seeded random rollout for ``spec`` and hash what was recorded.

    ``spaces`` and the environment created by ``spec.make()`` are both
    seeded with 0 before any action is sampled. Up to ``episodes`` episodes
    of at most ``steps`` steps each are played with sampled actions; the
    rollout stops early once ``ROLLOUT_STEPS`` steps have been taken in
    total.

    Returns a 4-tuple ``(observations_hash, actions_hash, rewards_hash,
    dones_hash)`` obtained by passing each recorded list to ``hash_object``.
    """
    spaces.seed(0)
    env = spec.make()
    env.seed(0)

    observations = []
    actions = []
    rewards = []
    dones = []
    steps_taken = 0

    for _episode in range(episodes):
        if steps_taken >= ROLLOUT_STEPS:
            break

        # The initial observation is not recorded; only step results are.
        env.reset()

        for _step in range(steps):
            action = env.action_space.sample()
            observation, reward, done, _ = env.step(action)

            actions.append(action)
            observations.append(observation)
            rewards.append(reward)
            dones.append(done)

            steps_taken += 1
            # Stop this episode when the step budget is used up or it ended.
            if steps_taken >= ROLLOUT_STEPS or done:
                break

    hashes = tuple(
        hash_object(recorded)
        for recorded in (observations, actions, rewards, dones)
    )

    env.close()
    return hashes
@pytest.mark.parametrize("spec", spec_list)
def test_env_semantics(spec):
    """Compare fresh rollout hashes for ``spec`` with those in ``ROLLOUT_FILE``.

    The check is currently switched off: the function logs a warning and
    returns straight away, so none of the comparison code below runs.
    """
    logger.warn("Skipping this test. Existing hashes were generated in a bad way")
    return

    # Everything from here on is unreachable while the early return above
    # stays in place.
    with open(ROLLOUT_FILE) as data_file:
        rollout_dict = json.load(data_file)

    if spec.id not in rollout_dict:
        if not spec.nondeterministic:
            logger.warn("Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id))
        return

    logger.info("Testing rollout for {} environment...".format(spec.id))

    observations_now, actions_now, rewards_now, dones_now = generate_rollout_hash(spec)

    stored = rollout_dict[spec.id]
    # (key in the stored rollout, message template, freshly computed hash)
    checks = (
        ('observations', 'Observations not equal for {} -- expected {} but got {}', observations_now),
        ('actions', 'Actions not equal for {} -- expected {} but got {}', actions_now),
        ('rewards', 'Rewards not equal for {} -- expected {} but got {}', rewards_now),
        ('dones', 'Dones not equal for {} -- expected {} but got {}', dones_now),
    )
    errors = [
        template.format(spec.id, stored[key], fresh)
        for key, template, fresh in checks
        if stored[key] != fresh
    ]

    if errors:
        for error in errors:
            logger.warn(error)
        raise ValueError(errors)

View File

@@ -31,7 +31,7 @@ class DiscreteEnv(Env):
def __init__(self, nS, nA, P, isd): def __init__(self, nS, nA, P, isd):
self.P = P self.P = P
self.isd = isd self.isd = isd
self.lastaction=None # for rendering self.lastaction = None # for rendering
self.nS = nS self.nS = nS
self.nA = nA self.nA = nA
@@ -47,7 +47,7 @@ class DiscreteEnv(Env):
def reset(self):
    """Draw a new state from ``self.isd``, clear the last action, return the state."""
    state = categorical_sample(self.isd, self.np_random)
    self.s = state
    self.lastaction = None
    return state
def step(self, a): def step(self, a):
@@ -55,5 +55,5 @@ class DiscreteEnv(Env):
i = categorical_sample([t[0] for t in transitions], self.np_random) i = categorical_sample([t[0] for t in transitions], self.np_random)
p, s, r, d= transitions[i] p, s, r, d= transitions[i]
self.s = s self.s = s
self.lastaction=a self.lastaction = a
return (s, r, d, {"prob" : p}) return (s, r, d, {"prob" : p})

View File

@@ -1,5 +1,7 @@
import numpy as np
import sys import sys
from contextlib import closing
import numpy as np
from six import StringIO, b from six import StringIO, b
from gym import utils from gym import utils
@@ -129,4 +131,5 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
outfile.write("\n".join(''.join(line) for line in desc)+"\n") outfile.write("\n".join(''.join(line) for line in desc)+"\n")
if mode != 'human': if mode != 'human':
return outfile with closing(outfile):
return outfile.getvalue()

View File

@@ -1,4 +1,5 @@
import sys import sys
from contextlib import closing
from six import StringIO from six import StringIO
from gym import utils from gym import utils
from gym.envs.toy_text import discrete from gym.envs.toy_text import discrete
@@ -149,4 +150,5 @@ class TaxiEnv(discrete.DiscreteEnv):
# No need to return anything for human # No need to return anything for human
if mode != 'human': if mode != 'human':
return outfile with closing(outfile):
return outfile.getvalue()