mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-09-02 10:35:52 +00:00
Fix Python3 compat of import dependencies
This commit is contained in:
@@ -3,7 +3,7 @@ from gym.spaces import Discrete, Tuple
|
|||||||
from gym.utils import colorize
|
from gym.utils import colorize
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
import StringIO
|
from six import StringIO
|
||||||
import sys
|
import sys
|
||||||
import math
|
import math
|
||||||
|
|
||||||
@@ -91,7 +91,7 @@ class AlgorithmicEnv(Env):
|
|||||||
# Nothing interesting to close
|
# Nothing interesting to close
|
||||||
return
|
return
|
||||||
|
|
||||||
outfile = StringIO.StringIO() if mode == 'ansi' else sys.stdout
|
outfile = StringIO() if mode == 'ansi' else sys.stdout
|
||||||
inp = "Total length of input instance: %d, step: %d\n" % (self.total_len, self.time)
|
inp = "Total length of input instance: %d, step: %d\n" % (self.total_len, self.time)
|
||||||
outfile.write(inp)
|
outfile.write(inp)
|
||||||
x, y, action = self.x, self.y, self.last_action
|
x, y, action = self.x, self.y, self.last_action
|
||||||
|
@@ -6,8 +6,8 @@ from gym import utils
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import atari_py
|
import atari_py
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
raise error.DependencyNotInstalled("{}. (HINT: you can install Atari dependencies with 'pip install gym[atari].)'")
|
raise error.DependencyNotInstalled("{}. (HINT: you can install Atari dependencies with 'pip install gym[atari].)'".format(e))
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -53,7 +53,7 @@ class AtariEnv(gym.Env, utils.EzPickle):
|
|||||||
reward = 0.0
|
reward = 0.0
|
||||||
action = self._action_set[a]
|
action = self._action_set[a]
|
||||||
num_steps = np.random.randint(2, 5)
|
num_steps = np.random.randint(2, 5)
|
||||||
for _ in xrange(num_steps):
|
for _ in range(num_steps):
|
||||||
reward += self.ale.act(action)
|
reward += self.ale.act(action)
|
||||||
ob = self._get_obs()
|
ob = self._get_obs()
|
||||||
|
|
||||||
@@ -80,7 +80,7 @@ class AtariEnv(gym.Env, utils.EzPickle):
|
|||||||
self.ale.reset_game()
|
self.ale.reset_game()
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def _render(self, mode='human', close=False):
|
def _render(self, mode='human', close=False):
|
||||||
if close:
|
if close:
|
||||||
if self.viewer is not None:
|
if self.viewer is not None:
|
||||||
self.viewer.close()
|
self.viewer.close()
|
||||||
@@ -93,7 +93,7 @@ class AtariEnv(gym.Env, utils.EzPickle):
|
|||||||
if self.viewer is None:
|
if self.viewer is None:
|
||||||
self.viewer = rendering.SimpleImageViewer()
|
self.viewer = rendering.SimpleImageViewer()
|
||||||
self.viewer.imshow(img)
|
self.viewer.imshow(img)
|
||||||
|
|
||||||
def get_action_meanings(self):
|
def get_action_meanings(self):
|
||||||
return [ACTION_MEANING[i] for i in self._action_set]
|
return [ACTION_MEANING[i] for i in self._action_set]
|
||||||
|
|
||||||
@@ -118,4 +118,4 @@ ACTION_MEANING = {
|
|||||||
15 : "UPLEFTFIRE",
|
15 : "UPLEFTFIRE",
|
||||||
16 : "DOWNRIGHTFIRE",
|
16 : "DOWNRIGHTFIRE",
|
||||||
17 : "DOWNLEFTFIRE",
|
17 : "DOWNLEFTFIRE",
|
||||||
}
|
}
|
||||||
|
@@ -10,6 +10,7 @@ import gym
|
|||||||
from gym import spaces
|
from gym import spaces
|
||||||
import StringIO
|
import StringIO
|
||||||
import sys
|
import sys
|
||||||
|
import six
|
||||||
|
|
||||||
|
|
||||||
# The coordinate representation of Pachi (and pachi_py) is defined on a board
|
# The coordinate representation of Pachi (and pachi_py) is defined on a board
|
||||||
@@ -191,7 +192,7 @@ class GoEnv(gym.Env):
|
|||||||
self.state = self.state.act(action)
|
self.state = self.state.act(action)
|
||||||
except pachi_py.IllegalMove:
|
except pachi_py.IllegalMove:
|
||||||
if self.illegal_move_mode == 'raise':
|
if self.illegal_move_mode == 'raise':
|
||||||
raise
|
six.reraise(*sys.exc_info())
|
||||||
elif self.illegal_move_mode == 'lose':
|
elif self.illegal_move_mode == 'lose':
|
||||||
# Automatic loss on illegal move
|
# Automatic loss on illegal move
|
||||||
self.done = True
|
self.done = True
|
||||||
|
@@ -196,7 +196,7 @@ class FilledPolygon(Geom):
|
|||||||
|
|
||||||
def make_circle(radius=10, res=30, filled=True):
|
def make_circle(radius=10, res=30, filled=True):
|
||||||
points = []
|
points = []
|
||||||
for i in xrange(res):
|
for i in range(res):
|
||||||
ang = 2*math.pi*i / res
|
ang = 2*math.pi*i / res
|
||||||
points.append((math.cos(ang)*radius, math.sin(ang)*radius))
|
points.append((math.cos(ang)*radius, math.sin(ang)*radius))
|
||||||
if filled:
|
if filled:
|
||||||
|
@@ -64,7 +64,7 @@ class EnvSpec(object):
|
|||||||
# This likely indicates unsupported kwargs
|
# This likely indicates unsupported kwargs
|
||||||
six.reraise(type, """Could not 'make' {} ({}): {}.
|
six.reraise(type, """Could not 'make' {} ({}): {}.
|
||||||
|
|
||||||
(For reference, the environment was instantiated with kwargs: {}).""".format(self.id, cls, e.message, self._kwargs), traceback)
|
(For reference, the environment was instantiated with kwargs: {}).""".format(self.id, cls, e, self._kwargs), traceback)
|
||||||
|
|
||||||
# Make the enviroment aware of which spec it came from.
|
# Make the enviroment aware of which spec it came from.
|
||||||
env.spec = self
|
env.spec = self
|
||||||
|
@@ -27,7 +27,7 @@ def test_random_rollout():
|
|||||||
for env in [envs.make('CartPole-v0'), envs.make('FrozenLake-v0')]:
|
for env in [envs.make('CartPole-v0'), envs.make('FrozenLake-v0')]:
|
||||||
agent = lambda ob: env.action_space.sample()
|
agent = lambda ob: env.action_space.sample()
|
||||||
ob = env.reset()
|
ob = env.reset()
|
||||||
for _ in xrange(10):
|
for _ in range(10):
|
||||||
assert env.observation_space.contains(ob)
|
assert env.observation_space.contains(ob)
|
||||||
a = agent(ob)
|
a = agent(ob)
|
||||||
assert env.action_space.contains(a)
|
assert env.action_space.contains(a)
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import StringIO, sys
|
import sys
|
||||||
|
from six import StringIO
|
||||||
|
|
||||||
from gym import utils
|
from gym import utils
|
||||||
from gym.envs.toy_text import discrete
|
from gym.envs.toy_text import discrete
|
||||||
@@ -67,10 +68,10 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
|
|||||||
nA = 4
|
nA = 4
|
||||||
nS = nrow * ncol
|
nS = nrow * ncol
|
||||||
|
|
||||||
isd = (desc == 'S').ravel().astype('float64')
|
isd = np.array(desc == 'S').astype('float64')
|
||||||
isd /= isd.sum()
|
isd /= isd.sum()
|
||||||
|
|
||||||
P = {s : {a : [] for a in xrange(nA)} for s in xrange(nS)}
|
P = {s : {a : [] for a in range(nA)} for s in range(nS)}
|
||||||
|
|
||||||
def to_s(row, col):
|
def to_s(row, col):
|
||||||
return row*ncol + col
|
return row*ncol + col
|
||||||
@@ -85,24 +86,24 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
|
|||||||
row = max(row-1,0)
|
row = max(row-1,0)
|
||||||
return (row, col)
|
return (row, col)
|
||||||
|
|
||||||
for row in xrange(nrow):
|
for row in range(nrow):
|
||||||
for col in xrange(ncol):
|
for col in range(ncol):
|
||||||
s = to_s(row, col)
|
s = to_s(row, col)
|
||||||
for a in xrange(4):
|
for a in range(4):
|
||||||
li = P[s][a]
|
li = P[s][a]
|
||||||
if is_slippery:
|
if is_slippery:
|
||||||
for b in [(a-1)%4, a, (a+1)%4]:
|
for b in [(a-1)%4, a, (a+1)%4]:
|
||||||
newrow, newcol = inc(row, col, b)
|
newrow, newcol = inc(row, col, b)
|
||||||
newstate = to_s(newrow, newcol)
|
newstate = to_s(newrow, newcol)
|
||||||
letter = desc[newrow, newcol]
|
letter = desc[newrow, newcol]
|
||||||
done = letter in 'GH'
|
done = str(letter) in 'GH'
|
||||||
rew = float(letter == 'G')
|
rew = float(letter == 'G')
|
||||||
li.append((1.0/3.0, newstate, rew, done))
|
li.append((1.0/3.0, newstate, rew, done))
|
||||||
else:
|
else:
|
||||||
newrow, newcol = inc(row, col, a)
|
newrow, newcol = inc(row, col, a)
|
||||||
newstate = to_s(newrow, newcol)
|
newstate = to_s(newrow, newcol)
|
||||||
letter = desc[newrow, newcol]
|
letter = desc[newrow, newcol]
|
||||||
done = letter in 'GH'
|
done = str(letter) in 'GH'
|
||||||
rew = float(letter == 'G')
|
rew = float(letter == 'G')
|
||||||
li.append((1.0/3.0, newstate, rew, done))
|
li.append((1.0/3.0, newstate, rew, done))
|
||||||
|
|
||||||
@@ -112,7 +113,7 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
|
|||||||
if close:
|
if close:
|
||||||
return
|
return
|
||||||
|
|
||||||
outfile = StringIO.StringIO() if mode == 'ansi' else sys.stdout
|
outfile = StringIO() if mode == 'ansi' else sys.stdout
|
||||||
|
|
||||||
row, col = self.s // self.ncol, self.s % self.ncol
|
row, col = self.s // self.ncol, self.s % self.ncol
|
||||||
desc = self.desc.tolist()
|
desc = self.desc.tolist()
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import StringIO, sys
|
import sys
|
||||||
|
from six import StringIO
|
||||||
|
|
||||||
from gym import spaces, utils
|
from gym import spaces, utils
|
||||||
from gym.envs.toy_text import discrete
|
from gym.envs.toy_text import discrete
|
||||||
@@ -42,12 +43,12 @@ class TaxiEnv(discrete.DiscreteEnv):
|
|||||||
maxC = nC-1
|
maxC = nC-1
|
||||||
isd = np.zeros(nS)
|
isd = np.zeros(nS)
|
||||||
nA = 6
|
nA = 6
|
||||||
P = {s : {a : [] for a in xrange(nA)} for s in xrange(nS)}
|
P = {s : {a : [] for a in range(nA)} for s in range(nS)}
|
||||||
for row in xrange(5):
|
for row in range(5):
|
||||||
for col in xrange(5):
|
for col in range(5):
|
||||||
for passidx in xrange(5):
|
for passidx in range(5):
|
||||||
for destidx in xrange(4):
|
for destidx in range(4):
|
||||||
for a in xrange(nA):
|
for a in range(nA):
|
||||||
state = self.encode(row, col, passidx, destidx)
|
state = self.encode(row, col, passidx, destidx)
|
||||||
# defaults
|
# defaults
|
||||||
newrow, newcol, newpassidx = row, col, passidx
|
newrow, newcol, newpassidx = row, col, passidx
|
||||||
@@ -111,7 +112,7 @@ class TaxiEnv(discrete.DiscreteEnv):
|
|||||||
if close:
|
if close:
|
||||||
return
|
return
|
||||||
|
|
||||||
outfile = StringIO.StringIO() if mode == 'ansi' else sys.stdout
|
outfile = StringIO() if mode == 'ansi' else sys.stdout
|
||||||
|
|
||||||
out = self.desc.copy().tolist()
|
out = self.desc.copy().tolist()
|
||||||
taxirow, taxicol, passidx, destidx = self.decode(self.s)
|
taxirow, taxicol, passidx, destidx = self.decode(self.s)
|
||||||
|
@@ -6,7 +6,8 @@ import tempfile
|
|||||||
import os.path
|
import os.path
|
||||||
import distutils.spawn
|
import distutils.spawn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import StringIO
|
from six import StringIO
|
||||||
|
import six.moves.urllib as urlparse
|
||||||
|
|
||||||
from gym import error
|
from gym import error
|
||||||
|
|
||||||
@@ -179,7 +180,7 @@ class TextEncoder(object):
|
|||||||
string = None
|
string = None
|
||||||
if isinstance(frame, str):
|
if isinstance(frame, str):
|
||||||
string = frame
|
string = frame
|
||||||
elif isinstance(frame, StringIO.StringIO):
|
elif isinstance(frame, StringIO):
|
||||||
string = frame.getvalue()
|
string = frame.getvalue()
|
||||||
else:
|
else:
|
||||||
raise error.InvalidFrame('Wrong type {} for {}: text frame must be a string or StringIO'.format(type(frame), frame))
|
raise error.InvalidFrame('Wrong type {} for {}: text frame must be a string or StringIO'.format(type(frame), frame))
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import platform
|
import platform
|
||||||
import urlparse
|
import six.moves.urllib as urlparse
|
||||||
|
from six import iteritems
|
||||||
|
|
||||||
from gym import error, version
|
from gym import error, version
|
||||||
import gym.scoreboard.client
|
import gym.scoreboard.client
|
||||||
@@ -20,7 +21,7 @@ def _build_api_url(url, query):
|
|||||||
def _strip_nulls(params):
|
def _strip_nulls(params):
|
||||||
if isinstance(params, dict):
|
if isinstance(params, dict):
|
||||||
stripped = {}
|
stripped = {}
|
||||||
for key, value in params.iteritems():
|
for key, value in iteritems(params):
|
||||||
value = _strip_nulls(value)
|
value = _strip_nulls(value)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
stripped[key] = value
|
stripped[key] = value
|
||||||
|
@@ -1,7 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
import urllib
|
|
||||||
import warnings
|
import warnings
|
||||||
import sys
|
import sys
|
||||||
|
from six import string_types
|
||||||
|
from six import iteritems
|
||||||
|
import six.moves.urllib as urllib
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym import error
|
from gym import error
|
||||||
@@ -18,7 +20,7 @@ def convert_to_gym_object(resp, api_key):
|
|||||||
elif isinstance(resp, dict) and not isinstance(resp, GymObject):
|
elif isinstance(resp, dict) and not isinstance(resp, GymObject):
|
||||||
resp = resp.copy()
|
resp = resp.copy()
|
||||||
klass_name = resp.get('object')
|
klass_name = resp.get('object')
|
||||||
if isinstance(klass_name, basestring):
|
if isinstance(klass_name, string_types):
|
||||||
klass = types.get(klass_name, GymObject)
|
klass = types.get(klass_name, GymObject)
|
||||||
else:
|
else:
|
||||||
klass = GymObject
|
klass = GymObject
|
||||||
@@ -142,7 +144,7 @@ class GymObject(dict):
|
|||||||
|
|
||||||
self._transient_values = self._transient_values - set(values)
|
self._transient_values = self._transient_values - set(values)
|
||||||
|
|
||||||
for k, v in values.iteritems():
|
for k, v in iteritems(values):
|
||||||
super(GymObject, self).__setitem__(
|
super(GymObject, self).__setitem__(
|
||||||
k, convert_to_gym_object(v, api_key))
|
k, convert_to_gym_object(v, api_key))
|
||||||
|
|
||||||
@@ -164,10 +166,10 @@ class GymObject(dict):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
ident_parts = [type(self).__name__]
|
ident_parts = [type(self).__name__]
|
||||||
|
|
||||||
if isinstance(self.get('object'), basestring):
|
if isinstance(self.get('object'), string_types):
|
||||||
ident_parts.append(self.get('object'))
|
ident_parts.append(self.get('object'))
|
||||||
|
|
||||||
if isinstance(self.get('id'), basestring):
|
if isinstance(self.get('id'), string_types):
|
||||||
ident_parts.append('id=%s' % (self.get('id'),))
|
ident_parts.append('id=%s' % (self.get('id'),))
|
||||||
|
|
||||||
unicode_repr = '<%s at %s> JSON: %s' % (
|
unicode_repr = '<%s at %s> JSON: %s' % (
|
||||||
@@ -228,7 +230,7 @@ class APIResource(GymObject):
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
'APIResource is an abstract class. You should perform '
|
'APIResource is an abstract class. You should perform '
|
||||||
'actions on its subclasses (e.g. Charge, Customer)')
|
'actions on its subclasses (e.g. Charge, Customer)')
|
||||||
return str(urllib.quote_plus(cls.__name__.lower()))
|
return str(urllib.parse.quote_plus(cls.__name__.lower()))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def class_path(cls):
|
def class_path(cls):
|
||||||
@@ -243,7 +245,7 @@ class APIResource(GymObject):
|
|||||||
'has invalid ID: %r' % (type(self).__name__, id), 'id')
|
'has invalid ID: %r' % (type(self).__name__, id), 'id')
|
||||||
id = util.utf8(id)
|
id = util.utf8(id)
|
||||||
base = self.class_path()
|
base = self.class_path()
|
||||||
extn = urllib.quote_plus(id)
|
extn = urllib.parse.quote_plus(id)
|
||||||
return "%s/%s" % (base, extn)
|
return "%s/%s" % (base, extn)
|
||||||
|
|
||||||
class ListObject(GymObject):
|
class ListObject(GymObject):
|
||||||
@@ -280,7 +282,7 @@ class ListObject(GymObject):
|
|||||||
def retrieve(self, id, **params):
|
def retrieve(self, id, **params):
|
||||||
base = self.get('url')
|
base = self.get('url')
|
||||||
id = util.utf8(id)
|
id = util.utf8(id)
|
||||||
extn = urllib.quote_plus(id)
|
extn = urllib.parse.quote_plus(id)
|
||||||
url = "%s/%s" % (base, extn)
|
url = "%s/%s" % (base, extn)
|
||||||
|
|
||||||
return self.request('get', url, params)
|
return self.request('get', url, params)
|
||||||
|
Reference in New Issue
Block a user