Fix Python3 compat of import dependencies

This commit is contained in:
Jonas Schneider
2016-04-27 18:03:29 -07:00
parent 7d2630b82e
commit 5065950a09
11 changed files with 48 additions and 41 deletions

View File

@@ -3,7 +3,7 @@ from gym.spaces import Discrete, Tuple
from gym.utils import colorize
import numpy as np
import random
import StringIO
from six import StringIO
import sys
import math
@@ -91,7 +91,7 @@ class AlgorithmicEnv(Env):
# Nothing interesting to close
return
outfile = StringIO.StringIO() if mode == 'ansi' else sys.stdout
outfile = StringIO() if mode == 'ansi' else sys.stdout
inp = "Total length of input instance: %d, step: %d\n" % (self.total_len, self.time)
outfile.write(inp)
x, y, action = self.x, self.y, self.last_action

View File

@@ -6,8 +6,8 @@ from gym import utils
try:
import atari_py
except ImportError:
raise error.DependencyNotInstalled("{}. (HINT: you can install Atari dependencies with 'pip install gym[atari].)'")
except ImportError as e:
raise error.DependencyNotInstalled("{}. (HINT: you can install Atari dependencies with 'pip install gym[atari].)'".format(e))
import logging
logger = logging.getLogger(__name__)
@@ -53,7 +53,7 @@ class AtariEnv(gym.Env, utils.EzPickle):
reward = 0.0
action = self._action_set[a]
num_steps = np.random.randint(2, 5)
for _ in xrange(num_steps):
for _ in range(num_steps):
reward += self.ale.act(action)
ob = self._get_obs()
@@ -80,7 +80,7 @@ class AtariEnv(gym.Env, utils.EzPickle):
self.ale.reset_game()
return self._get_obs()
def _render(self, mode='human', close=False):
def _render(self, mode='human', close=False):
if close:
if self.viewer is not None:
self.viewer.close()
@@ -93,7 +93,7 @@ class AtariEnv(gym.Env, utils.EzPickle):
if self.viewer is None:
self.viewer = rendering.SimpleImageViewer()
self.viewer.imshow(img)
def get_action_meanings(self):
return [ACTION_MEANING[i] for i in self._action_set]
@@ -118,4 +118,4 @@ ACTION_MEANING = {
15 : "UPLEFTFIRE",
16 : "DOWNRIGHTFIRE",
17 : "DOWNLEFTFIRE",
}
}

View File

@@ -10,6 +10,7 @@ import gym
from gym import spaces
import StringIO
import sys
import six
# The coordinate representation of Pachi (and pachi_py) is defined on a board
@@ -191,7 +192,7 @@ class GoEnv(gym.Env):
self.state = self.state.act(action)
except pachi_py.IllegalMove:
if self.illegal_move_mode == 'raise':
raise
six.reraise(*sys.exc_info())
elif self.illegal_move_mode == 'lose':
# Automatic loss on illegal move
self.done = True

View File

@@ -196,7 +196,7 @@ class FilledPolygon(Geom):
def make_circle(radius=10, res=30, filled=True):
points = []
for i in xrange(res):
for i in range(res):
ang = 2*math.pi*i / res
points.append((math.cos(ang)*radius, math.sin(ang)*radius))
if filled:

View File

@@ -64,7 +64,7 @@ class EnvSpec(object):
# This likely indicates unsupported kwargs
six.reraise(type, """Could not 'make' {} ({}): {}.
(For reference, the environment was instantiated with kwargs: {}).""".format(self.id, cls, e.message, self._kwargs), traceback)
(For reference, the environment was instantiated with kwargs: {}).""".format(self.id, cls, e, self._kwargs), traceback)
# Make the enviroment aware of which spec it came from.
env.spec = self

View File

@@ -27,7 +27,7 @@ def test_random_rollout():
for env in [envs.make('CartPole-v0'), envs.make('FrozenLake-v0')]:
agent = lambda ob: env.action_space.sample()
ob = env.reset()
for _ in xrange(10):
for _ in range(10):
assert env.observation_space.contains(ob)
a = agent(ob)
assert env.action_space.contains(a)

View File

@@ -1,5 +1,6 @@
import numpy as np
import StringIO, sys
import sys
from six import StringIO
from gym import utils
from gym.envs.toy_text import discrete
@@ -67,10 +68,10 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
nA = 4
nS = nrow * ncol
isd = (desc == 'S').ravel().astype('float64')
isd = np.array(desc == 'S').astype('float64')
isd /= isd.sum()
P = {s : {a : [] for a in xrange(nA)} for s in xrange(nS)}
P = {s : {a : [] for a in range(nA)} for s in range(nS)}
def to_s(row, col):
return row*ncol + col
@@ -85,24 +86,24 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
row = max(row-1,0)
return (row, col)
for row in xrange(nrow):
for col in xrange(ncol):
for row in range(nrow):
for col in range(ncol):
s = to_s(row, col)
for a in xrange(4):
for a in range(4):
li = P[s][a]
if is_slippery:
for b in [(a-1)%4, a, (a+1)%4]:
newrow, newcol = inc(row, col, b)
newstate = to_s(newrow, newcol)
letter = desc[newrow, newcol]
done = letter in 'GH'
done = str(letter) in 'GH'
rew = float(letter == 'G')
li.append((1.0/3.0, newstate, rew, done))
else:
newrow, newcol = inc(row, col, a)
newstate = to_s(newrow, newcol)
letter = desc[newrow, newcol]
done = letter in 'GH'
done = str(letter) in 'GH'
rew = float(letter == 'G')
li.append((1.0/3.0, newstate, rew, done))
@@ -112,7 +113,7 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
if close:
return
outfile = StringIO.StringIO() if mode == 'ansi' else sys.stdout
outfile = StringIO() if mode == 'ansi' else sys.stdout
row, col = self.s // self.ncol, self.s % self.ncol
desc = self.desc.tolist()

View File

@@ -1,5 +1,6 @@
import numpy as np
import StringIO, sys
import sys
from six import StringIO
from gym import spaces, utils
from gym.envs.toy_text import discrete
@@ -42,12 +43,12 @@ class TaxiEnv(discrete.DiscreteEnv):
maxC = nC-1
isd = np.zeros(nS)
nA = 6
P = {s : {a : [] for a in xrange(nA)} for s in xrange(nS)}
for row in xrange(5):
for col in xrange(5):
for passidx in xrange(5):
for destidx in xrange(4):
for a in xrange(nA):
P = {s : {a : [] for a in range(nA)} for s in range(nS)}
for row in range(5):
for col in range(5):
for passidx in range(5):
for destidx in range(4):
for a in range(nA):
state = self.encode(row, col, passidx, destidx)
# defaults
newrow, newcol, newpassidx = row, col, passidx
@@ -111,7 +112,7 @@ class TaxiEnv(discrete.DiscreteEnv):
if close:
return
outfile = StringIO.StringIO() if mode == 'ansi' else sys.stdout
outfile = StringIO() if mode == 'ansi' else sys.stdout
out = self.desc.copy().tolist()
taxirow, taxicol, passidx, destidx = self.decode(self.s)

View File

@@ -6,7 +6,8 @@ import tempfile
import os.path
import distutils.spawn
import numpy as np
import StringIO
from six import StringIO
import six.moves.urllib as urlparse
from gym import error
@@ -179,7 +180,7 @@ class TextEncoder(object):
string = None
if isinstance(frame, str):
string = frame
elif isinstance(frame, StringIO.StringIO):
elif isinstance(frame, StringIO):
string = frame.getvalue()
else:
raise error.InvalidFrame('Wrong type {} for {}: text frame must be a string or StringIO'.format(type(frame), frame))

View File

@@ -1,6 +1,7 @@
import json
import platform
import urlparse
import six.moves.urllib as urlparse
from six import iteritems
from gym import error, version
import gym.scoreboard.client
@@ -20,7 +21,7 @@ def _build_api_url(url, query):
def _strip_nulls(params):
if isinstance(params, dict):
stripped = {}
for key, value in params.iteritems():
for key, value in iteritems(params):
value = _strip_nulls(value)
if value is not None:
stripped[key] = value

View File

@@ -1,7 +1,9 @@
import json
import urllib
import warnings
import sys
from six import string_types
from six import iteritems
import six.moves.urllib as urllib
import gym
from gym import error
@@ -18,7 +20,7 @@ def convert_to_gym_object(resp, api_key):
elif isinstance(resp, dict) and not isinstance(resp, GymObject):
resp = resp.copy()
klass_name = resp.get('object')
if isinstance(klass_name, basestring):
if isinstance(klass_name, string_types):
klass = types.get(klass_name, GymObject)
else:
klass = GymObject
@@ -142,7 +144,7 @@ class GymObject(dict):
self._transient_values = self._transient_values - set(values)
for k, v in values.iteritems():
for k, v in iteritems(values):
super(GymObject, self).__setitem__(
k, convert_to_gym_object(v, api_key))
@@ -164,10 +166,10 @@ class GymObject(dict):
def __repr__(self):
ident_parts = [type(self).__name__]
if isinstance(self.get('object'), basestring):
if isinstance(self.get('object'), string_types):
ident_parts.append(self.get('object'))
if isinstance(self.get('id'), basestring):
if isinstance(self.get('id'), string_types):
ident_parts.append('id=%s' % (self.get('id'),))
unicode_repr = '<%s at %s> JSON: %s' % (
@@ -228,7 +230,7 @@ class APIResource(GymObject):
raise NotImplementedError(
'APIResource is an abstract class. You should perform '
'actions on its subclasses (e.g. Charge, Customer)')
return str(urllib.quote_plus(cls.__name__.lower()))
return str(urllib.parse.quote_plus(cls.__name__.lower()))
@classmethod
def class_path(cls):
@@ -243,7 +245,7 @@ class APIResource(GymObject):
'has invalid ID: %r' % (type(self).__name__, id), 'id')
id = util.utf8(id)
base = self.class_path()
extn = urllib.quote_plus(id)
extn = urllib.parse.quote_plus(id)
return "%s/%s" % (base, extn)
class ListObject(GymObject):
@@ -280,7 +282,7 @@ class ListObject(GymObject):
def retrieve(self, id, **params):
base = self.get('url')
id = util.utf8(id)
extn = urllib.quote_plus(id)
extn = urllib.parse.quote_plus(id)
url = "%s/%s" % (base, extn)
return self.request('get', url, params)