Return output from render method in the right way (#1248)

* Close output StringIO after returning value

* Test render output is immutable
This commit is contained in:
Misha Behersky
2019-02-09 02:58:51 +02:00
committed by pzhokhov
parent 5f73c5dff5
commit 0659520f8c
6 changed files with 87 additions and 63 deletions

View File

@@ -33,10 +33,11 @@ leveling up many times to reach their reward threshold.
from gym import Env, logger
from gym.spaces import Discrete, Tuple
from gym.utils import colorize, seeding
import sys
from contextlib import closing
import numpy as np
from six import StringIO
import sys
import math
class AlgorithmicEnv(Env):
@@ -112,7 +113,6 @@ class AlgorithmicEnv(Env):
raise NotImplementedError
def render(self, mode='human'):
outfile = StringIO() if mode == 'ansi' else sys.stdout
inp = "Total length of input instance: %d, step: %d\n" % (self.input_width, self.time)
outfile.write(inp)
@@ -149,7 +149,10 @@ class AlgorithmicEnv(Env):
outfile.write(" prediction: %s)\n" % pred_str)
else:
outfile.write("\n" * 5)
return outfile
if mode != 'human':
with closing(outfile):
return outfile.getvalue()
@property
def input_width(self):
@@ -234,6 +237,7 @@ class AlgorithmicEnv(Env):
def _move(self, movement):
raise NotImplemented
class TapeAlgorithmicEnv(AlgorithmicEnv):
"""An algorithmic env with a 1-d input tape."""
MOVEMENTS = ['left', 'right']
@@ -269,6 +273,7 @@ class TapeAlgorithmicEnv(AlgorithmicEnv):
x_str += "\n"
return x_str
class GridAlgorithmicEnv(AlgorithmicEnv):
"""An algorithmic env with a 2-d input grid."""
MOVEMENTS = ['left', 'right', 'up', 'down']

View File

@@ -48,3 +48,17 @@ def test_random_rollout():
(ob, _reward, done, _info) = env.step(a)
if done: break
env.close()
def test_env_render_result_is_immutable():
    """Check that ``render(mode='ansi')`` returns a plain ``str``.

    A ``str`` result cannot be mutated by the caller, unlike a returned
    buffer object. Each environment is created, exercised and closed in
    turn, so a failure in one does not leave the others unclosed.
    """
    env_ids = [
        'Taxi-v2',
        'FrozenLake-v0',
        'Reverse-v0',
    ]

    for env_id in env_ids:
        env = envs.make(env_id)
        try:
            env.reset()
            output = env.render(mode='ansi')
            assert isinstance(output, str)
        finally:
            # Always release the environment, even when the assertion
            # (or reset/render) fails.
            env.close()

View File

@@ -21,75 +21,75 @@ steps = ROLLOUT_STEPS
ROLLOUT_FILE = os.path.join(DATA_DIR, 'rollout.json')
if not os.path.isfile(ROLLOUT_FILE):
with open(ROLLOUT_FILE, "w") as outfile:
json.dump({}, outfile, indent=2)
with open(ROLLOUT_FILE, "w") as outfile:
json.dump({}, outfile, indent=2)
def hash_object(unhashed):
return hashlib.sha256(str(unhashed).encode('utf-16')).hexdigest() # This is really bad, str could be same while values change
return hashlib.sha256(str(unhashed).encode('utf-16')).hexdigest() # This is really bad, str could be same while values change
def generate_rollout_hash(spec):
spaces.seed(0)
env = spec.make()
env.seed(0)
spaces.seed(0)
env = spec.make()
env.seed(0)
observation_list = []
action_list = []
reward_list = []
done_list = []
observation_list = []
action_list = []
reward_list = []
done_list = []
total_steps = 0
for episode in range(episodes):
if total_steps >= ROLLOUT_STEPS: break
observation = env.reset()
total_steps = 0
for episode in range(episodes):
if total_steps >= ROLLOUT_STEPS: break
observation = env.reset()
for step in range(steps):
action = env.action_space.sample()
observation, reward, done, _ = env.step(action)
for step in range(steps):
action = env.action_space.sample()
observation, reward, done, _ = env.step(action)
action_list.append(action)
observation_list.append(observation)
reward_list.append(reward)
done_list.append(done)
action_list.append(action)
observation_list.append(observation)
reward_list.append(reward)
done_list.append(done)
total_steps += 1
if total_steps >= ROLLOUT_STEPS: break
total_steps += 1
if total_steps >= ROLLOUT_STEPS: break
if done: break
if done: break
observations_hash = hash_object(observation_list)
actions_hash = hash_object(action_list)
rewards_hash = hash_object(reward_list)
dones_hash = hash_object(done_list)
observations_hash = hash_object(observation_list)
actions_hash = hash_object(action_list)
rewards_hash = hash_object(reward_list)
dones_hash = hash_object(done_list)
env.close()
return observations_hash, actions_hash, rewards_hash, dones_hash
env.close()
return observations_hash, actions_hash, rewards_hash, dones_hash
@pytest.mark.parametrize("spec", spec_list)
def test_env_semantics(spec):
logger.warn("Skipping this test. Existing hashes were generated in a bad way")
return
with open(ROLLOUT_FILE) as data_file:
rollout_dict = json.load(data_file)
logger.warn("Skipping this test. Existing hashes were generated in a bad way")
return
with open(ROLLOUT_FILE) as data_file:
rollout_dict = json.load(data_file)
if spec.id not in rollout_dict:
if not spec.nondeterministic:
logger.warn("Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id))
return
if spec.id not in rollout_dict:
if not spec.nondeterministic:
logger.warn("Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id))
return
logger.info("Testing rollout for {} environment...".format(spec.id))
logger.info("Testing rollout for {} environment...".format(spec.id))
observations_now, actions_now, rewards_now, dones_now = generate_rollout_hash(spec)
observations_now, actions_now, rewards_now, dones_now = generate_rollout_hash(spec)
errors = []
if rollout_dict[spec.id]['observations'] != observations_now:
errors.append('Observations not equal for {} -- expected {} but got {}'.format(spec.id, rollout_dict[spec.id]['observations'], observations_now))
if rollout_dict[spec.id]['actions'] != actions_now:
errors.append('Actions not equal for {} -- expected {} but got {}'.format(spec.id, rollout_dict[spec.id]['actions'], actions_now))
if rollout_dict[spec.id]['rewards'] != rewards_now:
errors.append('Rewards not equal for {} -- expected {} but got {}'.format(spec.id, rollout_dict[spec.id]['rewards'], rewards_now))
if rollout_dict[spec.id]['dones'] != dones_now:
errors.append('Dones not equal for {} -- expected {} but got {}'.format(spec.id, rollout_dict[spec.id]['dones'], dones_now))
if len(errors):
for error in errors:
logger.warn(error)
raise ValueError(errors)
errors = []
if rollout_dict[spec.id]['observations'] != observations_now:
errors.append('Observations not equal for {} -- expected {} but got {}'.format(spec.id, rollout_dict[spec.id]['observations'], observations_now))
if rollout_dict[spec.id]['actions'] != actions_now:
errors.append('Actions not equal for {} -- expected {} but got {}'.format(spec.id, rollout_dict[spec.id]['actions'], actions_now))
if rollout_dict[spec.id]['rewards'] != rewards_now:
errors.append('Rewards not equal for {} -- expected {} but got {}'.format(spec.id, rollout_dict[spec.id]['rewards'], rewards_now))
if rollout_dict[spec.id]['dones'] != dones_now:
errors.append('Dones not equal for {} -- expected {} but got {}'.format(spec.id, rollout_dict[spec.id]['dones'], dones_now))
if len(errors):
for error in errors:
logger.warn(error)
raise ValueError(errors)

View File

@@ -31,7 +31,7 @@ class DiscreteEnv(Env):
def __init__(self, nS, nA, P, isd):
self.P = P
self.isd = isd
self.lastaction=None # for rendering
self.lastaction = None # for rendering
self.nS = nS
self.nA = nA
@@ -47,7 +47,7 @@ class DiscreteEnv(Env):
def reset(self):
self.s = categorical_sample(self.isd, self.np_random)
self.lastaction=None
self.lastaction = None
return self.s
def step(self, a):
@@ -55,5 +55,5 @@ class DiscreteEnv(Env):
i = categorical_sample([t[0] for t in transitions], self.np_random)
p, s, r, d= transitions[i]
self.s = s
self.lastaction=a
self.lastaction = a
return (s, r, d, {"prob" : p})

View File

@@ -1,5 +1,7 @@
import numpy as np
import sys
from contextlib import closing
import numpy as np
from six import StringIO, b
from gym import utils
@@ -129,4 +131,5 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
outfile.write("\n".join(''.join(line) for line in desc)+"\n")
if mode != 'human':
return outfile
with closing(outfile):
return outfile.getvalue()

View File

@@ -1,4 +1,5 @@
import sys
from contextlib import closing
from six import StringIO
from gym import utils
from gym.envs.toy_text import discrete
@@ -149,4 +150,5 @@ class TaxiEnv(discrete.DiscreteEnv):
# No need to return anything for human
if mode != 'human':
return outfile
with closing(outfile):
return outfile.getvalue()