2016-04-27 08:00:58 -07:00
|
|
|
# -*- coding: utf-8 -*-
|
2019-01-29 13:37:43 -08:00
|
|
|
import gym
|
2016-04-27 08:00:58 -07:00
|
|
|
from gym import error, envs
|
|
|
|
from gym.envs import registration
|
|
|
|
from gym.envs.classic_control import cartpole
|
|
|
|
|
2019-01-29 13:37:43 -08:00
|
|
|
class ArgumentEnv(gym.Env):
    """Minimal environment whose only job is to remember its constructor args.

    Tests build it through the registry and then inspect ``arg1``..``arg3``
    to see which values ``make`` actually forwarded to ``__init__``.
    """

    def __init__(self, arg1, arg2, arg3):
        # Keep every argument verbatim under an attribute of the same name.
        self.arg1, self.arg2, self.arg3 = arg1, arg2, arg3
|
|
|
|
|
|
|
|
# Register ArgumentEnv with registry defaults for arg1 and arg2 only.
# arg3 has no default here and ArgumentEnv.__init__ requires it, so any
# make() of this id must pass arg3 explicitly.
gym.register(
    id='test.ArgumentEnv-v0',
    entry_point='gym.envs.tests.test_registration:ArgumentEnv',
    kwargs=dict(arg1='arg1', arg2='arg2'),
)
|
|
|
|
|
2016-04-27 08:00:58 -07:00
|
|
|
def test_make():
    """``envs.make`` returns a CartPole env carrying the requested spec id."""
    cartpole_env = envs.make('CartPole-v0')
    assert cartpole_env.spec.id == 'CartPole-v0'
    # Check .unwrapped rather than the env itself, since make() may return
    # a wrapper around the concrete environment class.
    assert isinstance(cartpole_env.unwrapped, cartpole.CartPoleEnv)
|
2016-04-27 08:00:58 -07:00
|
|
|
|
2019-01-29 13:37:43 -08:00
|
|
|
def test_make_with_kwargs():
    """Kwargs passed to make() override registered kwargs and supply missing ones.

    arg1 keeps its registered default, arg2 is overridden, and arg3 (which has
    no registered default) is provided only by the make() call.
    """
    env = envs.make('test.ArgumentEnv-v0', arg2='override_arg2', arg3='override_arg3')
    assert env.spec.id == 'test.ArgumentEnv-v0'
    assert isinstance(env.unwrapped, ArgumentEnv)

    expected_attrs = (
        ('arg1', 'arg1'),
        ('arg2', 'override_arg2'),
        ('arg3', 'override_arg3'),
    )
    for attr_name, expected_value in expected_attrs:
        assert getattr(env, attr_name) == expected_value, attr_name
|
|
|
|
|
2016-05-06 22:26:40 -07:00
|
|
|
def test_make_deprecated():
    """Making the deprecated id 'Humanoid-v0' is expected to raise a gym error."""
    try:
        envs.make('Humanoid-v0')
    except error.Error:
        pass
    else:
        # Explicit raise instead of ``assert False``: assert statements are
        # stripped under ``python -O``, which would let this test pass silently,
        # and a bare ``assert False`` gives no hint about what went wrong.
        raise AssertionError("envs.make('Humanoid-v0') did not raise error.Error")
|
|
|
|
|
2016-04-27 08:00:58 -07:00
|
|
|
def test_spec():
    """Looking up a registered id returns a spec whose ``id`` matches it."""
    env_id = 'CartPole-v0'
    assert envs.spec(env_id).id == env_id
|
|
|
|
|
2020-04-24 23:49:41 +02:00
|
|
|
def test_spec_with_kwargs():
    """Kwargs given to gym.make() are recorded on the resulting env's spec."""
    requested_map = '8x8'
    frozen_lake = gym.make('FrozenLake-v0', map_name=requested_map)
    # NOTE(review): reads the private EnvSpec._kwargs attribute; this will need
    # updating if EnvSpec stops exposing kwargs under that name.
    recorded_kwargs = frozen_lake.spec._kwargs
    assert recorded_kwargs['map_name'] == requested_map
|
|
|
|
|
2016-04-27 08:00:58 -07:00
|
|
|
def test_missing_lookup():
    """Unregistered ids raise different errors depending on whether the name is known.

    An unknown version of a registered env name must raise
    ``error.DeprecatedEnv``; an id whose name was never registered must raise
    ``error.UnregisteredEnv``.
    """
    registry = registration.EnvRegistry()
    registry.register(id='Test-v0', entry_point=None)
    registry.register(id='Test-v15', entry_point=None)
    registry.register(id='Test-v9', entry_point=None)
    registry.register(id='Other-v100', entry_point=None)

    # 'Test' is a registered name, but v1 is not one of its registered versions.
    try:
        registry.spec('Test-v1')
    except error.DeprecatedEnv:
        pass
    else:
        # Explicit raise: ``assert False`` is stripped under ``python -O``.
        raise AssertionError(
            "registry.spec('Test-v1') did not raise error.DeprecatedEnv")

    # Neither the name 'Unknown' nor any of its versions is registered.
    try:
        registry.spec('Unknown-v1')
    except error.UnregisteredEnv:
        pass
    else:
        raise AssertionError(
            "registry.spec('Unknown-v1') did not raise error.UnregisteredEnv")
|
|
|
|
|
|
|
|
def test_malformed_lookup():
    """An id wrapped in typographic quotes is rejected as a malformed environment ID."""
    registry = registration.EnvRegistry()
    try:
        # The curly quotes are part of the id on purpose: they make it malformed.
        registry.spec(u'“Breakout-v0”')
    except error.Error as e:
        assert 'malformed environment ID' in '{}'.format(e), 'Unexpected message: {}'.format(e)
    else:
        # Explicit raise: ``assert False`` is stripped under ``python -O``.
        raise AssertionError('registry.spec did not reject a malformed environment ID')
|