mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 09:55:39 +00:00
Return output from render method in a right way (#1248)
* Close output StringIO after returning value * Test render output is immutable
This commit is contained in:
@@ -33,10 +33,11 @@ leveling up many times to reach their reward threshold.
|
|||||||
from gym import Env, logger
|
from gym import Env, logger
|
||||||
from gym.spaces import Discrete, Tuple
|
from gym.spaces import Discrete, Tuple
|
||||||
from gym.utils import colorize, seeding
|
from gym.utils import colorize, seeding
|
||||||
|
import sys
|
||||||
|
from contextlib import closing
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from six import StringIO
|
from six import StringIO
|
||||||
import sys
|
|
||||||
import math
|
|
||||||
|
|
||||||
class AlgorithmicEnv(Env):
|
class AlgorithmicEnv(Env):
|
||||||
|
|
||||||
@@ -112,7 +113,6 @@ class AlgorithmicEnv(Env):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def render(self, mode='human'):
|
def render(self, mode='human'):
|
||||||
|
|
||||||
outfile = StringIO() if mode == 'ansi' else sys.stdout
|
outfile = StringIO() if mode == 'ansi' else sys.stdout
|
||||||
inp = "Total length of input instance: %d, step: %d\n" % (self.input_width, self.time)
|
inp = "Total length of input instance: %d, step: %d\n" % (self.input_width, self.time)
|
||||||
outfile.write(inp)
|
outfile.write(inp)
|
||||||
@@ -149,7 +149,10 @@ class AlgorithmicEnv(Env):
|
|||||||
outfile.write(" prediction: %s)\n" % pred_str)
|
outfile.write(" prediction: %s)\n" % pred_str)
|
||||||
else:
|
else:
|
||||||
outfile.write("\n" * 5)
|
outfile.write("\n" * 5)
|
||||||
return outfile
|
|
||||||
|
if mode != 'human':
|
||||||
|
with closing(outfile):
|
||||||
|
return outfile.getvalue()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_width(self):
|
def input_width(self):
|
||||||
@@ -234,6 +237,7 @@ class AlgorithmicEnv(Env):
|
|||||||
def _move(self, movement):
|
def _move(self, movement):
|
||||||
raise NotImplemented
|
raise NotImplemented
|
||||||
|
|
||||||
|
|
||||||
class TapeAlgorithmicEnv(AlgorithmicEnv):
|
class TapeAlgorithmicEnv(AlgorithmicEnv):
|
||||||
"""An algorithmic env with a 1-d input tape."""
|
"""An algorithmic env with a 1-d input tape."""
|
||||||
MOVEMENTS = ['left', 'right']
|
MOVEMENTS = ['left', 'right']
|
||||||
@@ -269,6 +273,7 @@ class TapeAlgorithmicEnv(AlgorithmicEnv):
|
|||||||
x_str += "\n"
|
x_str += "\n"
|
||||||
return x_str
|
return x_str
|
||||||
|
|
||||||
|
|
||||||
class GridAlgorithmicEnv(AlgorithmicEnv):
|
class GridAlgorithmicEnv(AlgorithmicEnv):
|
||||||
"""An algorithmic env with a 2-d input grid."""
|
"""An algorithmic env with a 2-d input grid."""
|
||||||
MOVEMENTS = ['left', 'right', 'up', 'down']
|
MOVEMENTS = ['left', 'right', 'up', 'down']
|
||||||
|
@@ -48,3 +48,17 @@ def test_random_rollout():
|
|||||||
(ob, _reward, done, _info) = env.step(a)
|
(ob, _reward, done, _info) = env.step(a)
|
||||||
if done: break
|
if done: break
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_env_render_result_is_immutable():
|
||||||
|
environs = [
|
||||||
|
envs.make('Taxi-v2'),
|
||||||
|
envs.make('FrozenLake-v0'),
|
||||||
|
envs.make('Reverse-v0'),
|
||||||
|
]
|
||||||
|
|
||||||
|
for env in environs:
|
||||||
|
env.reset()
|
||||||
|
output = env.render(mode='ansi')
|
||||||
|
assert isinstance(output, str)
|
||||||
|
env.close()
|
||||||
|
@@ -21,75 +21,75 @@ steps = ROLLOUT_STEPS
|
|||||||
ROLLOUT_FILE = os.path.join(DATA_DIR, 'rollout.json')
|
ROLLOUT_FILE = os.path.join(DATA_DIR, 'rollout.json')
|
||||||
|
|
||||||
if not os.path.isfile(ROLLOUT_FILE):
|
if not os.path.isfile(ROLLOUT_FILE):
|
||||||
with open(ROLLOUT_FILE, "w") as outfile:
|
with open(ROLLOUT_FILE, "w") as outfile:
|
||||||
json.dump({}, outfile, indent=2)
|
json.dump({}, outfile, indent=2)
|
||||||
|
|
||||||
def hash_object(unhashed):
|
def hash_object(unhashed):
|
||||||
return hashlib.sha256(str(unhashed).encode('utf-16')).hexdigest() # This is really bad, str could be same while values change
|
return hashlib.sha256(str(unhashed).encode('utf-16')).hexdigest() # This is really bad, str could be same while values change
|
||||||
|
|
||||||
def generate_rollout_hash(spec):
|
def generate_rollout_hash(spec):
|
||||||
spaces.seed(0)
|
spaces.seed(0)
|
||||||
env = spec.make()
|
env = spec.make()
|
||||||
env.seed(0)
|
env.seed(0)
|
||||||
|
|
||||||
observation_list = []
|
observation_list = []
|
||||||
action_list = []
|
action_list = []
|
||||||
reward_list = []
|
reward_list = []
|
||||||
done_list = []
|
done_list = []
|
||||||
|
|
||||||
total_steps = 0
|
total_steps = 0
|
||||||
for episode in range(episodes):
|
for episode in range(episodes):
|
||||||
if total_steps >= ROLLOUT_STEPS: break
|
if total_steps >= ROLLOUT_STEPS: break
|
||||||
observation = env.reset()
|
observation = env.reset()
|
||||||
|
|
||||||
for step in range(steps):
|
for step in range(steps):
|
||||||
action = env.action_space.sample()
|
action = env.action_space.sample()
|
||||||
observation, reward, done, _ = env.step(action)
|
observation, reward, done, _ = env.step(action)
|
||||||
|
|
||||||
action_list.append(action)
|
action_list.append(action)
|
||||||
observation_list.append(observation)
|
observation_list.append(observation)
|
||||||
reward_list.append(reward)
|
reward_list.append(reward)
|
||||||
done_list.append(done)
|
done_list.append(done)
|
||||||
|
|
||||||
total_steps += 1
|
total_steps += 1
|
||||||
if total_steps >= ROLLOUT_STEPS: break
|
if total_steps >= ROLLOUT_STEPS: break
|
||||||
|
|
||||||
if done: break
|
if done: break
|
||||||
|
|
||||||
observations_hash = hash_object(observation_list)
|
observations_hash = hash_object(observation_list)
|
||||||
actions_hash = hash_object(action_list)
|
actions_hash = hash_object(action_list)
|
||||||
rewards_hash = hash_object(reward_list)
|
rewards_hash = hash_object(reward_list)
|
||||||
dones_hash = hash_object(done_list)
|
dones_hash = hash_object(done_list)
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
return observations_hash, actions_hash, rewards_hash, dones_hash
|
return observations_hash, actions_hash, rewards_hash, dones_hash
|
||||||
|
|
||||||
@pytest.mark.parametrize("spec", spec_list)
|
@pytest.mark.parametrize("spec", spec_list)
|
||||||
def test_env_semantics(spec):
|
def test_env_semantics(spec):
|
||||||
logger.warn("Skipping this test. Existing hashes were generated in a bad way")
|
logger.warn("Skipping this test. Existing hashes were generated in a bad way")
|
||||||
return
|
return
|
||||||
with open(ROLLOUT_FILE) as data_file:
|
with open(ROLLOUT_FILE) as data_file:
|
||||||
rollout_dict = json.load(data_file)
|
rollout_dict = json.load(data_file)
|
||||||
|
|
||||||
if spec.id not in rollout_dict:
|
if spec.id not in rollout_dict:
|
||||||
if not spec.nondeterministic:
|
if not spec.nondeterministic:
|
||||||
logger.warn("Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id))
|
logger.warn("Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id))
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("Testing rollout for {} environment...".format(spec.id))
|
logger.info("Testing rollout for {} environment...".format(spec.id))
|
||||||
|
|
||||||
observations_now, actions_now, rewards_now, dones_now = generate_rollout_hash(spec)
|
observations_now, actions_now, rewards_now, dones_now = generate_rollout_hash(spec)
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
if rollout_dict[spec.id]['observations'] != observations_now:
|
if rollout_dict[spec.id]['observations'] != observations_now:
|
||||||
errors.append('Observations not equal for {} -- expected {} but got {}'.format(spec.id, rollout_dict[spec.id]['observations'], observations_now))
|
errors.append('Observations not equal for {} -- expected {} but got {}'.format(spec.id, rollout_dict[spec.id]['observations'], observations_now))
|
||||||
if rollout_dict[spec.id]['actions'] != actions_now:
|
if rollout_dict[spec.id]['actions'] != actions_now:
|
||||||
errors.append('Actions not equal for {} -- expected {} but got {}'.format(spec.id, rollout_dict[spec.id]['actions'], actions_now))
|
errors.append('Actions not equal for {} -- expected {} but got {}'.format(spec.id, rollout_dict[spec.id]['actions'], actions_now))
|
||||||
if rollout_dict[spec.id]['rewards'] != rewards_now:
|
if rollout_dict[spec.id]['rewards'] != rewards_now:
|
||||||
errors.append('Rewards not equal for {} -- expected {} but got {}'.format(spec.id, rollout_dict[spec.id]['rewards'], rewards_now))
|
errors.append('Rewards not equal for {} -- expected {} but got {}'.format(spec.id, rollout_dict[spec.id]['rewards'], rewards_now))
|
||||||
if rollout_dict[spec.id]['dones'] != dones_now:
|
if rollout_dict[spec.id]['dones'] != dones_now:
|
||||||
errors.append('Dones not equal for {} -- expected {} but got {}'.format(spec.id, rollout_dict[spec.id]['dones'], dones_now))
|
errors.append('Dones not equal for {} -- expected {} but got {}'.format(spec.id, rollout_dict[spec.id]['dones'], dones_now))
|
||||||
if len(errors):
|
if len(errors):
|
||||||
for error in errors:
|
for error in errors:
|
||||||
logger.warn(error)
|
logger.warn(error)
|
||||||
raise ValueError(errors)
|
raise ValueError(errors)
|
||||||
|
@@ -31,7 +31,7 @@ class DiscreteEnv(Env):
|
|||||||
def __init__(self, nS, nA, P, isd):
|
def __init__(self, nS, nA, P, isd):
|
||||||
self.P = P
|
self.P = P
|
||||||
self.isd = isd
|
self.isd = isd
|
||||||
self.lastaction=None # for rendering
|
self.lastaction = None # for rendering
|
||||||
self.nS = nS
|
self.nS = nS
|
||||||
self.nA = nA
|
self.nA = nA
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ class DiscreteEnv(Env):
|
|||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.s = categorical_sample(self.isd, self.np_random)
|
self.s = categorical_sample(self.isd, self.np_random)
|
||||||
self.lastaction=None
|
self.lastaction = None
|
||||||
return self.s
|
return self.s
|
||||||
|
|
||||||
def step(self, a):
|
def step(self, a):
|
||||||
@@ -55,5 +55,5 @@ class DiscreteEnv(Env):
|
|||||||
i = categorical_sample([t[0] for t in transitions], self.np_random)
|
i = categorical_sample([t[0] for t in transitions], self.np_random)
|
||||||
p, s, r, d= transitions[i]
|
p, s, r, d= transitions[i]
|
||||||
self.s = s
|
self.s = s
|
||||||
self.lastaction=a
|
self.lastaction = a
|
||||||
return (s, r, d, {"prob" : p})
|
return (s, r, d, {"prob" : p})
|
||||||
|
@@ -1,5 +1,7 @@
|
|||||||
import numpy as np
|
|
||||||
import sys
|
import sys
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from six import StringIO, b
|
from six import StringIO, b
|
||||||
|
|
||||||
from gym import utils
|
from gym import utils
|
||||||
@@ -129,4 +131,5 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
|
|||||||
outfile.write("\n".join(''.join(line) for line in desc)+"\n")
|
outfile.write("\n".join(''.join(line) for line in desc)+"\n")
|
||||||
|
|
||||||
if mode != 'human':
|
if mode != 'human':
|
||||||
return outfile
|
with closing(outfile):
|
||||||
|
return outfile.getvalue()
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
import sys
|
import sys
|
||||||
|
from contextlib import closing
|
||||||
from six import StringIO
|
from six import StringIO
|
||||||
from gym import utils
|
from gym import utils
|
||||||
from gym.envs.toy_text import discrete
|
from gym.envs.toy_text import discrete
|
||||||
@@ -149,4 +150,5 @@ class TaxiEnv(discrete.DiscreteEnv):
|
|||||||
|
|
||||||
# No need to return anything for human
|
# No need to return anything for human
|
||||||
if mode != 'human':
|
if mode != 'human':
|
||||||
return outfile
|
with closing(outfile):
|
||||||
|
return outfile.getvalue()
|
||||||
|
Reference in New Issue
Block a user