""" Currently disabled since this was done in a very poor way Hashed str representation of objects """ import json import hashlib import os import pytest from gym import spaces, logger from tests.envs.spec_list import spec_list DATA_DIR = os.path.dirname(__file__) ROLLOUT_STEPS = 100 episodes = ROLLOUT_STEPS steps = ROLLOUT_STEPS ROLLOUT_FILE = os.path.join(DATA_DIR, "rollout.json") if not os.path.isfile(ROLLOUT_FILE): with open(ROLLOUT_FILE, "w") as outfile: json.dump({}, outfile, indent=2) def hash_object(unhashed): return hashlib.sha256( str(unhashed).encode("utf-16") ).hexdigest() # This is really bad, str could be same while values change def generate_rollout_hash(spec): spaces.seed(0) env = spec.make() env.seed(0) observation_list = [] action_list = [] reward_list = [] done_list = [] total_steps = 0 for episode in range(episodes): if total_steps >= ROLLOUT_STEPS: break observation = env.reset() for step in range(steps): action = env.action_space.sample() observation, reward, done, _ = env.step(action) action_list.append(action) observation_list.append(observation) reward_list.append(reward) done_list.append(done) total_steps += 1 if total_steps >= ROLLOUT_STEPS: break if done: break observations_hash = hash_object(observation_list) actions_hash = hash_object(action_list) rewards_hash = hash_object(reward_list) dones_hash = hash_object(done_list) env.close() return observations_hash, actions_hash, rewards_hash, dones_hash @pytest.mark.parametrize("spec", spec_list) def test_env_semantics(spec): logger.warn("Skipping this test. Existing hashes were generated in a bad way") return with open(ROLLOUT_FILE) as data_file: rollout_dict = json.load(data_file) if spec.id not in rollout_dict: if not spec.nondeterministic: logger.warn( "Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format( spec.id ) ) return logger.info("Testing rollout for {} environment...".format(spec.id)) observations_now, actions_now, rewards_now, dones_now = generate_rollout_hash(spec) errors = [] if rollout_dict[spec.id]["observations"] != observations_now: errors.append( "Observations not equal for {} -- expected {} but got {}".format( spec.id, rollout_dict[spec.id]["observations"], observations_now ) ) if rollout_dict[spec.id]["actions"] != actions_now: errors.append( "Actions not equal for {} -- expected {} but got {}".format( spec.id, rollout_dict[spec.id]["actions"], actions_now ) ) if rollout_dict[spec.id]["rewards"] != rewards_now: errors.append( "Rewards not equal for {} -- expected {} but got {}".format( spec.id, rollout_dict[spec.id]["rewards"], rewards_now ) ) if rollout_dict[spec.id]["dones"] != dones_now: errors.append( "Dones not equal for {} -- expected {} but got {}".format( spec.id, rollout_dict[spec.id]["dones"], dones_now ) ) if len(errors): for error in errors: logger.warn(error) raise ValueError(errors)