Files
Gymnasium/tests/vector/test_step_compatibility_vector.py
John Balis 3a8daafce1 Removing return_info argument to env.reset() and deprecated env.seed() function (reset now always returns info) (#2962)
* removed return_info, made info dict mandatory in reset

* tentatively removed deprecated seed api for environments

* added more info type checks to wrapper tests

* formatting/style compliance

* addressed some comments

* polish to address review

* fixed tests after merge, and added a test of the return_info deprecation assertion if found in reset signature

* some organization of env_checker tests, reverted a probable merge error

* added deprecation check for seed function in env

* updated docstring

* removed debug prints, tweaked test_check_seed_deprecation

* changed return_info deprecation check from assertion to warning

* fixes to vector envs, now  should be correctly structured

* added some explanation and typehints for mockup deprecated return info reset function

* re-removed seed function from vector envs

* added explanation to _reset_return_info_type and changed the return statement
2022-08-23 11:09:54 -04:00

89 lines
2.3 KiB
Python

import numpy as np
import pytest
import gym
from gym.spaces import Discrete
from gym.vector import AsyncVectorEnv, SyncVectorEnv
class OldStepEnv(gym.Env):
    """Minimal environment whose ``step`` returns the old 4-tuple step API.

    ``step`` yields ``(obs, reward, done, info)``; it is used to exercise the
    step-API compatibility handling of the vector environments.
    """

    def __init__(self):
        self.action_space = Discrete(2)
        self.observation_space = Discrete(2)

    def reset(self, *, seed=None, options=None):
        """Return a fixed initial observation and an empty info dict.

        Args:
            seed: Accepted for compatibility with the keyword-only arguments
                of ``gym.Env.reset``; ignored, since the result is constant.
            options: Accepted for the same reason; ignored.

        Returns:
            The tuple ``(0, {})``.
        """
        return 0, {}

    def step(self, action):
        """Sample an observation and return an old-style 4-tuple.

        Args:
            action: Ignored by this mock environment.

        Returns:
            ``(obs, reward, done, info)`` with ``reward == 0``,
            ``done is False`` and an empty ``info`` dict.
        """
        obs = self.observation_space.sample()
        rew = 0
        done = False
        info = {}
        return obs, rew, done, info
class NewStepEnv(gym.Env):
    """Minimal environment whose ``step`` returns the new 5-tuple step API.

    ``step`` yields ``(obs, reward, terminated, truncated, info)``; it is used
    to exercise the step-API compatibility handling of the vector environments.
    """

    def __init__(self):
        self.action_space = Discrete(2)
        self.observation_space = Discrete(2)

    def reset(self, *, seed=None, options=None):
        """Return a fixed initial observation and an empty info dict.

        Args:
            seed: Accepted for compatibility with the keyword-only arguments
                of ``gym.Env.reset``; ignored, since the result is constant.
            options: Accepted for the same reason; ignored.

        Returns:
            The tuple ``(0, {})``.
        """
        return 0, {}

    def step(self, action):
        """Sample an observation and return a new-style 5-tuple.

        Args:
            action: Ignored by this mock environment.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` with
            ``reward == 0``, both flags ``False`` and an empty ``info`` dict.
        """
        obs = self.observation_space.sample()
        rew = 0
        terminated = False
        truncated = False
        info = {}
        return obs, rew, terminated, truncated, info
@pytest.mark.parametrize("VecEnv", [AsyncVectorEnv, SyncVectorEnv])
def test_vector_step_compatibility_new_env(VecEnv):
    """Check both step APIs of a vector env built from mixed-API sub-envs.

    A vector env is built from one ``OldStepEnv`` and one ``NewStepEnv``.
    With default arguments ``step`` must return 4 values with a boolean
    ``dones`` array; with ``new_step_api=True`` it must return 5 values with
    boolean ``terminateds`` and ``truncateds`` arrays.

    Args:
        VecEnv: The vector environment class under test.
    """
    envs = [
        OldStepEnv(),
        NewStepEnv(),
    ]
    # Bind each env through a default argument. A plain ``lambda: env`` closes
    # over the comprehension variable, so every factory would return the LAST
    # env and the old-API environment would never be exercised.
    env_fns = [lambda env=env: env for env in envs]

    vec_env = VecEnv(env_fns)
    try:
        vec_env.reset()
        step_returns = vec_env.step([0, 0])
        assert len(step_returns) == 4
        _, _, dones, _ = step_returns
        assert dones.dtype == np.bool_
    finally:
        # Always release the vector env, even when an assertion fails.
        vec_env.close()

    vec_env = VecEnv(env_fns, new_step_api=True)
    try:
        vec_env.reset()
        step_returns = vec_env.step([0, 0])
        assert len(step_returns) == 5
        _, _, terminateds, truncateds, _ = step_returns
        assert terminateds.dtype == np.bool_
        assert truncateds.dtype == np.bool_
    finally:
        vec_env.close()
@pytest.mark.parametrize("async_bool", [True, False])
def test_vector_step_compatibility_existing(async_bool):
    """Check both step APIs of a vectorized registered environment.

    Builds three ``CartPole-v1`` copies through ``gym.vector.make``. With
    default arguments ``step`` must return 4 values with a boolean ``dones``
    array; with ``new_step_api=True`` it must return 5 values with boolean
    ``terminateds`` and ``truncateds`` arrays.

    Args:
        async_bool: Passed as ``asynchronous`` to ``gym.vector.make``.
    """
    env = gym.vector.make("CartPole-v1", num_envs=3, asynchronous=async_bool)
    try:
        env.reset()
        step_returns = env.step(env.action_space.sample())
        assert len(step_returns) == 4
        _, _, dones, _ = step_returns
        assert dones.dtype == np.bool_
    finally:
        # Always release the vector env, even when an assertion fails.
        env.close()

    env = gym.vector.make(
        "CartPole-v1", num_envs=3, asynchronous=async_bool, new_step_api=True
    )
    try:
        env.reset()
        step_returns = env.step(env.action_space.sample())
        assert len(step_returns) == 5
        _, _, terminateds, truncateds, _ = step_returns
        assert terminateds.dtype == np.bool_
        assert truncateds.dtype == np.bool_
    finally:
        env.close()