mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-09-02 10:35:52 +00:00
Fix Python3 compat of import dependencies
This commit is contained in:
@@ -3,7 +3,7 @@ from gym.spaces import Discrete, Tuple
|
||||
from gym.utils import colorize
|
||||
import numpy as np
|
||||
import random
|
||||
import StringIO
|
||||
from six import StringIO
|
||||
import sys
|
||||
import math
|
||||
|
||||
@@ -91,7 +91,7 @@ class AlgorithmicEnv(Env):
|
||||
# Nothing interesting to close
|
||||
return
|
||||
|
||||
outfile = StringIO.StringIO() if mode == 'ansi' else sys.stdout
|
||||
outfile = StringIO() if mode == 'ansi' else sys.stdout
|
||||
inp = "Total length of input instance: %d, step: %d\n" % (self.total_len, self.time)
|
||||
outfile.write(inp)
|
||||
x, y, action = self.x, self.y, self.last_action
|
||||
|
@@ -6,8 +6,8 @@ from gym import utils
|
||||
|
||||
try:
|
||||
import atari_py
|
||||
except ImportError:
|
||||
raise error.DependencyNotInstalled("{}. (HINT: you can install Atari dependencies with 'pip install gym[atari].)'")
|
||||
except ImportError as e:
|
||||
raise error.DependencyNotInstalled("{}. (HINT: you can install Atari dependencies with 'pip install gym[atari].)'".format(e))
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -53,7 +53,7 @@ class AtariEnv(gym.Env, utils.EzPickle):
|
||||
reward = 0.0
|
||||
action = self._action_set[a]
|
||||
num_steps = np.random.randint(2, 5)
|
||||
for _ in xrange(num_steps):
|
||||
for _ in range(num_steps):
|
||||
reward += self.ale.act(action)
|
||||
ob = self._get_obs()
|
||||
|
||||
@@ -80,7 +80,7 @@ class AtariEnv(gym.Env, utils.EzPickle):
|
||||
self.ale.reset_game()
|
||||
return self._get_obs()
|
||||
|
||||
def _render(self, mode='human', close=False):
|
||||
def _render(self, mode='human', close=False):
|
||||
if close:
|
||||
if self.viewer is not None:
|
||||
self.viewer.close()
|
||||
@@ -93,7 +93,7 @@ class AtariEnv(gym.Env, utils.EzPickle):
|
||||
if self.viewer is None:
|
||||
self.viewer = rendering.SimpleImageViewer()
|
||||
self.viewer.imshow(img)
|
||||
|
||||
|
||||
def get_action_meanings(self):
|
||||
return [ACTION_MEANING[i] for i in self._action_set]
|
||||
|
||||
@@ -118,4 +118,4 @@ ACTION_MEANING = {
|
||||
15 : "UPLEFTFIRE",
|
||||
16 : "DOWNRIGHTFIRE",
|
||||
17 : "DOWNLEFTFIRE",
|
||||
}
|
||||
}
|
||||
|
@@ -10,6 +10,7 @@ import gym
|
||||
from gym import spaces
|
||||
import StringIO
|
||||
import sys
|
||||
import six
|
||||
|
||||
|
||||
# The coordinate representation of Pachi (and pachi_py) is defined on a board
|
||||
@@ -191,7 +192,7 @@ class GoEnv(gym.Env):
|
||||
self.state = self.state.act(action)
|
||||
except pachi_py.IllegalMove:
|
||||
if self.illegal_move_mode == 'raise':
|
||||
raise
|
||||
six.reraise(*sys.exc_info())
|
||||
elif self.illegal_move_mode == 'lose':
|
||||
# Automatic loss on illegal move
|
||||
self.done = True
|
||||
|
@@ -196,7 +196,7 @@ class FilledPolygon(Geom):
|
||||
|
||||
def make_circle(radius=10, res=30, filled=True):
|
||||
points = []
|
||||
for i in xrange(res):
|
||||
for i in range(res):
|
||||
ang = 2*math.pi*i / res
|
||||
points.append((math.cos(ang)*radius, math.sin(ang)*radius))
|
||||
if filled:
|
||||
|
@@ -64,7 +64,7 @@ class EnvSpec(object):
|
||||
# This likely indicates unsupported kwargs
|
||||
six.reraise(type, """Could not 'make' {} ({}): {}.
|
||||
|
||||
(For reference, the environment was instantiated with kwargs: {}).""".format(self.id, cls, e.message, self._kwargs), traceback)
|
||||
(For reference, the environment was instantiated with kwargs: {}).""".format(self.id, cls, e, self._kwargs), traceback)
|
||||
|
||||
# Make the enviroment aware of which spec it came from.
|
||||
env.spec = self
|
||||
|
@@ -27,7 +27,7 @@ def test_random_rollout():
|
||||
for env in [envs.make('CartPole-v0'), envs.make('FrozenLake-v0')]:
|
||||
agent = lambda ob: env.action_space.sample()
|
||||
ob = env.reset()
|
||||
for _ in xrange(10):
|
||||
for _ in range(10):
|
||||
assert env.observation_space.contains(ob)
|
||||
a = agent(ob)
|
||||
assert env.action_space.contains(a)
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import numpy as np
|
||||
import StringIO, sys
|
||||
import sys
|
||||
from six import StringIO
|
||||
|
||||
from gym import utils
|
||||
from gym.envs.toy_text import discrete
|
||||
@@ -67,10 +68,10 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
|
||||
nA = 4
|
||||
nS = nrow * ncol
|
||||
|
||||
isd = (desc == 'S').ravel().astype('float64')
|
||||
isd = np.array(desc == 'S').astype('float64')
|
||||
isd /= isd.sum()
|
||||
|
||||
P = {s : {a : [] for a in xrange(nA)} for s in xrange(nS)}
|
||||
P = {s : {a : [] for a in range(nA)} for s in range(nS)}
|
||||
|
||||
def to_s(row, col):
|
||||
return row*ncol + col
|
||||
@@ -85,24 +86,24 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
|
||||
row = max(row-1,0)
|
||||
return (row, col)
|
||||
|
||||
for row in xrange(nrow):
|
||||
for col in xrange(ncol):
|
||||
for row in range(nrow):
|
||||
for col in range(ncol):
|
||||
s = to_s(row, col)
|
||||
for a in xrange(4):
|
||||
for a in range(4):
|
||||
li = P[s][a]
|
||||
if is_slippery:
|
||||
for b in [(a-1)%4, a, (a+1)%4]:
|
||||
newrow, newcol = inc(row, col, b)
|
||||
newstate = to_s(newrow, newcol)
|
||||
letter = desc[newrow, newcol]
|
||||
done = letter in 'GH'
|
||||
done = str(letter) in 'GH'
|
||||
rew = float(letter == 'G')
|
||||
li.append((1.0/3.0, newstate, rew, done))
|
||||
else:
|
||||
newrow, newcol = inc(row, col, a)
|
||||
newstate = to_s(newrow, newcol)
|
||||
letter = desc[newrow, newcol]
|
||||
done = letter in 'GH'
|
||||
done = str(letter) in 'GH'
|
||||
rew = float(letter == 'G')
|
||||
li.append((1.0/3.0, newstate, rew, done))
|
||||
|
||||
@@ -112,7 +113,7 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
|
||||
if close:
|
||||
return
|
||||
|
||||
outfile = StringIO.StringIO() if mode == 'ansi' else sys.stdout
|
||||
outfile = StringIO() if mode == 'ansi' else sys.stdout
|
||||
|
||||
row, col = self.s // self.ncol, self.s % self.ncol
|
||||
desc = self.desc.tolist()
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import numpy as np
|
||||
import StringIO, sys
|
||||
import sys
|
||||
from six import StringIO
|
||||
|
||||
from gym import spaces, utils
|
||||
from gym.envs.toy_text import discrete
|
||||
@@ -42,12 +43,12 @@ class TaxiEnv(discrete.DiscreteEnv):
|
||||
maxC = nC-1
|
||||
isd = np.zeros(nS)
|
||||
nA = 6
|
||||
P = {s : {a : [] for a in xrange(nA)} for s in xrange(nS)}
|
||||
for row in xrange(5):
|
||||
for col in xrange(5):
|
||||
for passidx in xrange(5):
|
||||
for destidx in xrange(4):
|
||||
for a in xrange(nA):
|
||||
P = {s : {a : [] for a in range(nA)} for s in range(nS)}
|
||||
for row in range(5):
|
||||
for col in range(5):
|
||||
for passidx in range(5):
|
||||
for destidx in range(4):
|
||||
for a in range(nA):
|
||||
state = self.encode(row, col, passidx, destidx)
|
||||
# defaults
|
||||
newrow, newcol, newpassidx = row, col, passidx
|
||||
@@ -111,7 +112,7 @@ class TaxiEnv(discrete.DiscreteEnv):
|
||||
if close:
|
||||
return
|
||||
|
||||
outfile = StringIO.StringIO() if mode == 'ansi' else sys.stdout
|
||||
outfile = StringIO() if mode == 'ansi' else sys.stdout
|
||||
|
||||
out = self.desc.copy().tolist()
|
||||
taxirow, taxicol, passidx, destidx = self.decode(self.s)
|
||||
|
@@ -6,7 +6,8 @@ import tempfile
|
||||
import os.path
|
||||
import distutils.spawn
|
||||
import numpy as np
|
||||
import StringIO
|
||||
from six import StringIO
|
||||
import six.moves.urllib as urlparse
|
||||
|
||||
from gym import error
|
||||
|
||||
@@ -179,7 +180,7 @@ class TextEncoder(object):
|
||||
string = None
|
||||
if isinstance(frame, str):
|
||||
string = frame
|
||||
elif isinstance(frame, StringIO.StringIO):
|
||||
elif isinstance(frame, StringIO):
|
||||
string = frame.getvalue()
|
||||
else:
|
||||
raise error.InvalidFrame('Wrong type {} for {}: text frame must be a string or StringIO'.format(type(frame), frame))
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import platform
|
||||
import urlparse
|
||||
import six.moves.urllib as urlparse
|
||||
from six import iteritems
|
||||
|
||||
from gym import error, version
|
||||
import gym.scoreboard.client
|
||||
@@ -20,7 +21,7 @@ def _build_api_url(url, query):
|
||||
def _strip_nulls(params):
|
||||
if isinstance(params, dict):
|
||||
stripped = {}
|
||||
for key, value in params.iteritems():
|
||||
for key, value in iteritems(params):
|
||||
value = _strip_nulls(value)
|
||||
if value is not None:
|
||||
stripped[key] = value
|
||||
|
@@ -1,7 +1,9 @@
|
||||
import json
|
||||
import urllib
|
||||
import warnings
|
||||
import sys
|
||||
from six import string_types
|
||||
from six import iteritems
|
||||
import six.moves.urllib as urllib
|
||||
|
||||
import gym
|
||||
from gym import error
|
||||
@@ -18,7 +20,7 @@ def convert_to_gym_object(resp, api_key):
|
||||
elif isinstance(resp, dict) and not isinstance(resp, GymObject):
|
||||
resp = resp.copy()
|
||||
klass_name = resp.get('object')
|
||||
if isinstance(klass_name, basestring):
|
||||
if isinstance(klass_name, string_types):
|
||||
klass = types.get(klass_name, GymObject)
|
||||
else:
|
||||
klass = GymObject
|
||||
@@ -142,7 +144,7 @@ class GymObject(dict):
|
||||
|
||||
self._transient_values = self._transient_values - set(values)
|
||||
|
||||
for k, v in values.iteritems():
|
||||
for k, v in iteritems(values):
|
||||
super(GymObject, self).__setitem__(
|
||||
k, convert_to_gym_object(v, api_key))
|
||||
|
||||
@@ -164,10 +166,10 @@ class GymObject(dict):
|
||||
def __repr__(self):
|
||||
ident_parts = [type(self).__name__]
|
||||
|
||||
if isinstance(self.get('object'), basestring):
|
||||
if isinstance(self.get('object'), string_types):
|
||||
ident_parts.append(self.get('object'))
|
||||
|
||||
if isinstance(self.get('id'), basestring):
|
||||
if isinstance(self.get('id'), string_types):
|
||||
ident_parts.append('id=%s' % (self.get('id'),))
|
||||
|
||||
unicode_repr = '<%s at %s> JSON: %s' % (
|
||||
@@ -228,7 +230,7 @@ class APIResource(GymObject):
|
||||
raise NotImplementedError(
|
||||
'APIResource is an abstract class. You should perform '
|
||||
'actions on its subclasses (e.g. Charge, Customer)')
|
||||
return str(urllib.quote_plus(cls.__name__.lower()))
|
||||
return str(urllib.parse.quote_plus(cls.__name__.lower()))
|
||||
|
||||
@classmethod
|
||||
def class_path(cls):
|
||||
@@ -243,7 +245,7 @@ class APIResource(GymObject):
|
||||
'has invalid ID: %r' % (type(self).__name__, id), 'id')
|
||||
id = util.utf8(id)
|
||||
base = self.class_path()
|
||||
extn = urllib.quote_plus(id)
|
||||
extn = urllib.parse.quote_plus(id)
|
||||
return "%s/%s" % (base, extn)
|
||||
|
||||
class ListObject(GymObject):
|
||||
@@ -280,7 +282,7 @@ class ListObject(GymObject):
|
||||
def retrieve(self, id, **params):
|
||||
base = self.get('url')
|
||||
id = util.utf8(id)
|
||||
extn = urllib.quote_plus(id)
|
||||
extn = urllib.parse.quote_plus(id)
|
||||
url = "%s/%s" % (base, extn)
|
||||
|
||||
return self.request('get', url, params)
|
||||
|
Reference in New Issue
Block a user