Fix Python3 compat of import dependencies

This commit is contained in:
Jonas Schneider
2016-04-27 18:03:29 -07:00
parent 7d2630b82e
commit 5065950a09
11 changed files with 48 additions and 41 deletions

View File

@@ -3,7 +3,7 @@ from gym.spaces import Discrete, Tuple
from gym.utils import colorize from gym.utils import colorize
import numpy as np import numpy as np
import random import random
import StringIO from six import StringIO
import sys import sys
import math import math
@@ -91,7 +91,7 @@ class AlgorithmicEnv(Env):
# Nothing interesting to close # Nothing interesting to close
return return
outfile = StringIO.StringIO() if mode == 'ansi' else sys.stdout outfile = StringIO() if mode == 'ansi' else sys.stdout
inp = "Total length of input instance: %d, step: %d\n" % (self.total_len, self.time) inp = "Total length of input instance: %d, step: %d\n" % (self.total_len, self.time)
outfile.write(inp) outfile.write(inp)
x, y, action = self.x, self.y, self.last_action x, y, action = self.x, self.y, self.last_action

View File

@@ -6,8 +6,8 @@ from gym import utils
try: try:
import atari_py import atari_py
except ImportError: except ImportError as e:
raise error.DependencyNotInstalled("{}. (HINT: you can install Atari dependencies with 'pip install gym[atari].)'") raise error.DependencyNotInstalled("{}. (HINT: you can install Atari dependencies with 'pip install gym[atari].)'".format(e))
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -53,7 +53,7 @@ class AtariEnv(gym.Env, utils.EzPickle):
reward = 0.0 reward = 0.0
action = self._action_set[a] action = self._action_set[a]
num_steps = np.random.randint(2, 5) num_steps = np.random.randint(2, 5)
for _ in xrange(num_steps): for _ in range(num_steps):
reward += self.ale.act(action) reward += self.ale.act(action)
ob = self._get_obs() ob = self._get_obs()
@@ -80,7 +80,7 @@ class AtariEnv(gym.Env, utils.EzPickle):
self.ale.reset_game() self.ale.reset_game()
return self._get_obs() return self._get_obs()
def _render(self, mode='human', close=False): def _render(self, mode='human', close=False):
if close: if close:
if self.viewer is not None: if self.viewer is not None:
self.viewer.close() self.viewer.close()
@@ -93,7 +93,7 @@ class AtariEnv(gym.Env, utils.EzPickle):
if self.viewer is None: if self.viewer is None:
self.viewer = rendering.SimpleImageViewer() self.viewer = rendering.SimpleImageViewer()
self.viewer.imshow(img) self.viewer.imshow(img)
def get_action_meanings(self): def get_action_meanings(self):
return [ACTION_MEANING[i] for i in self._action_set] return [ACTION_MEANING[i] for i in self._action_set]
@@ -118,4 +118,4 @@ ACTION_MEANING = {
15 : "UPLEFTFIRE", 15 : "UPLEFTFIRE",
16 : "DOWNRIGHTFIRE", 16 : "DOWNRIGHTFIRE",
17 : "DOWNLEFTFIRE", 17 : "DOWNLEFTFIRE",
} }

View File

@@ -10,6 +10,7 @@ import gym
from gym import spaces from gym import spaces
import StringIO import StringIO
import sys import sys
import six
# The coordinate representation of Pachi (and pachi_py) is defined on a board # The coordinate representation of Pachi (and pachi_py) is defined on a board
@@ -191,7 +192,7 @@ class GoEnv(gym.Env):
self.state = self.state.act(action) self.state = self.state.act(action)
except pachi_py.IllegalMove: except pachi_py.IllegalMove:
if self.illegal_move_mode == 'raise': if self.illegal_move_mode == 'raise':
raise six.reraise(*sys.exc_info())
elif self.illegal_move_mode == 'lose': elif self.illegal_move_mode == 'lose':
# Automatic loss on illegal move # Automatic loss on illegal move
self.done = True self.done = True

View File

@@ -196,7 +196,7 @@ class FilledPolygon(Geom):
def make_circle(radius=10, res=30, filled=True): def make_circle(radius=10, res=30, filled=True):
points = [] points = []
for i in xrange(res): for i in range(res):
ang = 2*math.pi*i / res ang = 2*math.pi*i / res
points.append((math.cos(ang)*radius, math.sin(ang)*radius)) points.append((math.cos(ang)*radius, math.sin(ang)*radius))
if filled: if filled:

View File

@@ -64,7 +64,7 @@ class EnvSpec(object):
# This likely indicates unsupported kwargs # This likely indicates unsupported kwargs
six.reraise(type, """Could not 'make' {} ({}): {}. six.reraise(type, """Could not 'make' {} ({}): {}.
(For reference, the environment was instantiated with kwargs: {}).""".format(self.id, cls, e.message, self._kwargs), traceback) (For reference, the environment was instantiated with kwargs: {}).""".format(self.id, cls, e, self._kwargs), traceback)
# Make the enviroment aware of which spec it came from. # Make the enviroment aware of which spec it came from.
env.spec = self env.spec = self

View File

@@ -27,7 +27,7 @@ def test_random_rollout():
for env in [envs.make('CartPole-v0'), envs.make('FrozenLake-v0')]: for env in [envs.make('CartPole-v0'), envs.make('FrozenLake-v0')]:
agent = lambda ob: env.action_space.sample() agent = lambda ob: env.action_space.sample()
ob = env.reset() ob = env.reset()
for _ in xrange(10): for _ in range(10):
assert env.observation_space.contains(ob) assert env.observation_space.contains(ob)
a = agent(ob) a = agent(ob)
assert env.action_space.contains(a) assert env.action_space.contains(a)

View File

@@ -1,5 +1,6 @@
import numpy as np import numpy as np
import StringIO, sys import sys
from six import StringIO
from gym import utils from gym import utils
from gym.envs.toy_text import discrete from gym.envs.toy_text import discrete
@@ -67,10 +68,10 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
nA = 4 nA = 4
nS = nrow * ncol nS = nrow * ncol
isd = (desc == 'S').ravel().astype('float64') isd = np.array(desc == 'S').astype('float64')
isd /= isd.sum() isd /= isd.sum()
P = {s : {a : [] for a in xrange(nA)} for s in xrange(nS)} P = {s : {a : [] for a in range(nA)} for s in range(nS)}
def to_s(row, col): def to_s(row, col):
return row*ncol + col return row*ncol + col
@@ -85,24 +86,24 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
row = max(row-1,0) row = max(row-1,0)
return (row, col) return (row, col)
for row in xrange(nrow): for row in range(nrow):
for col in xrange(ncol): for col in range(ncol):
s = to_s(row, col) s = to_s(row, col)
for a in xrange(4): for a in range(4):
li = P[s][a] li = P[s][a]
if is_slippery: if is_slippery:
for b in [(a-1)%4, a, (a+1)%4]: for b in [(a-1)%4, a, (a+1)%4]:
newrow, newcol = inc(row, col, b) newrow, newcol = inc(row, col, b)
newstate = to_s(newrow, newcol) newstate = to_s(newrow, newcol)
letter = desc[newrow, newcol] letter = desc[newrow, newcol]
done = letter in 'GH' done = str(letter) in 'GH'
rew = float(letter == 'G') rew = float(letter == 'G')
li.append((1.0/3.0, newstate, rew, done)) li.append((1.0/3.0, newstate, rew, done))
else: else:
newrow, newcol = inc(row, col, a) newrow, newcol = inc(row, col, a)
newstate = to_s(newrow, newcol) newstate = to_s(newrow, newcol)
letter = desc[newrow, newcol] letter = desc[newrow, newcol]
done = letter in 'GH' done = str(letter) in 'GH'
rew = float(letter == 'G') rew = float(letter == 'G')
li.append((1.0/3.0, newstate, rew, done)) li.append((1.0/3.0, newstate, rew, done))
@@ -112,7 +113,7 @@ class FrozenLakeEnv(discrete.DiscreteEnv):
if close: if close:
return return
outfile = StringIO.StringIO() if mode == 'ansi' else sys.stdout outfile = StringIO() if mode == 'ansi' else sys.stdout
row, col = self.s // self.ncol, self.s % self.ncol row, col = self.s // self.ncol, self.s % self.ncol
desc = self.desc.tolist() desc = self.desc.tolist()

View File

@@ -1,5 +1,6 @@
import numpy as np import numpy as np
import StringIO, sys import sys
from six import StringIO
from gym import spaces, utils from gym import spaces, utils
from gym.envs.toy_text import discrete from gym.envs.toy_text import discrete
@@ -42,12 +43,12 @@ class TaxiEnv(discrete.DiscreteEnv):
maxC = nC-1 maxC = nC-1
isd = np.zeros(nS) isd = np.zeros(nS)
nA = 6 nA = 6
P = {s : {a : [] for a in xrange(nA)} for s in xrange(nS)} P = {s : {a : [] for a in range(nA)} for s in range(nS)}
for row in xrange(5): for row in range(5):
for col in xrange(5): for col in range(5):
for passidx in xrange(5): for passidx in range(5):
for destidx in xrange(4): for destidx in range(4):
for a in xrange(nA): for a in range(nA):
state = self.encode(row, col, passidx, destidx) state = self.encode(row, col, passidx, destidx)
# defaults # defaults
newrow, newcol, newpassidx = row, col, passidx newrow, newcol, newpassidx = row, col, passidx
@@ -111,7 +112,7 @@ class TaxiEnv(discrete.DiscreteEnv):
if close: if close:
return return
outfile = StringIO.StringIO() if mode == 'ansi' else sys.stdout outfile = StringIO() if mode == 'ansi' else sys.stdout
out = self.desc.copy().tolist() out = self.desc.copy().tolist()
taxirow, taxicol, passidx, destidx = self.decode(self.s) taxirow, taxicol, passidx, destidx = self.decode(self.s)

View File

@@ -6,7 +6,8 @@ import tempfile
import os.path import os.path
import distutils.spawn import distutils.spawn
import numpy as np import numpy as np
import StringIO from six import StringIO
import six.moves.urllib as urlparse
from gym import error from gym import error
@@ -179,7 +180,7 @@ class TextEncoder(object):
string = None string = None
if isinstance(frame, str): if isinstance(frame, str):
string = frame string = frame
elif isinstance(frame, StringIO.StringIO): elif isinstance(frame, StringIO):
string = frame.getvalue() string = frame.getvalue()
else: else:
raise error.InvalidFrame('Wrong type {} for {}: text frame must be a string or StringIO'.format(type(frame), frame)) raise error.InvalidFrame('Wrong type {} for {}: text frame must be a string or StringIO'.format(type(frame), frame))

View File

@@ -1,6 +1,7 @@
import json import json
import platform import platform
import urlparse import six.moves.urllib as urlparse
from six import iteritems
from gym import error, version from gym import error, version
import gym.scoreboard.client import gym.scoreboard.client
@@ -20,7 +21,7 @@ def _build_api_url(url, query):
def _strip_nulls(params): def _strip_nulls(params):
if isinstance(params, dict): if isinstance(params, dict):
stripped = {} stripped = {}
for key, value in params.iteritems(): for key, value in iteritems(params):
value = _strip_nulls(value) value = _strip_nulls(value)
if value is not None: if value is not None:
stripped[key] = value stripped[key] = value

View File

@@ -1,7 +1,9 @@
import json import json
import urllib
import warnings import warnings
import sys import sys
from six import string_types
from six import iteritems
import six.moves.urllib as urllib
import gym import gym
from gym import error from gym import error
@@ -18,7 +20,7 @@ def convert_to_gym_object(resp, api_key):
elif isinstance(resp, dict) and not isinstance(resp, GymObject): elif isinstance(resp, dict) and not isinstance(resp, GymObject):
resp = resp.copy() resp = resp.copy()
klass_name = resp.get('object') klass_name = resp.get('object')
if isinstance(klass_name, basestring): if isinstance(klass_name, string_types):
klass = types.get(klass_name, GymObject) klass = types.get(klass_name, GymObject)
else: else:
klass = GymObject klass = GymObject
@@ -142,7 +144,7 @@ class GymObject(dict):
self._transient_values = self._transient_values - set(values) self._transient_values = self._transient_values - set(values)
for k, v in values.iteritems(): for k, v in iteritems(values):
super(GymObject, self).__setitem__( super(GymObject, self).__setitem__(
k, convert_to_gym_object(v, api_key)) k, convert_to_gym_object(v, api_key))
@@ -164,10 +166,10 @@ class GymObject(dict):
def __repr__(self): def __repr__(self):
ident_parts = [type(self).__name__] ident_parts = [type(self).__name__]
if isinstance(self.get('object'), basestring): if isinstance(self.get('object'), string_types):
ident_parts.append(self.get('object')) ident_parts.append(self.get('object'))
if isinstance(self.get('id'), basestring): if isinstance(self.get('id'), string_types):
ident_parts.append('id=%s' % (self.get('id'),)) ident_parts.append('id=%s' % (self.get('id'),))
unicode_repr = '<%s at %s> JSON: %s' % ( unicode_repr = '<%s at %s> JSON: %s' % (
@@ -228,7 +230,7 @@ class APIResource(GymObject):
raise NotImplementedError( raise NotImplementedError(
'APIResource is an abstract class. You should perform ' 'APIResource is an abstract class. You should perform '
'actions on its subclasses (e.g. Charge, Customer)') 'actions on its subclasses (e.g. Charge, Customer)')
return str(urllib.quote_plus(cls.__name__.lower())) return str(urllib.parse.quote_plus(cls.__name__.lower()))
@classmethod @classmethod
def class_path(cls): def class_path(cls):
@@ -243,7 +245,7 @@ class APIResource(GymObject):
'has invalid ID: %r' % (type(self).__name__, id), 'id') 'has invalid ID: %r' % (type(self).__name__, id), 'id')
id = util.utf8(id) id = util.utf8(id)
base = self.class_path() base = self.class_path()
extn = urllib.quote_plus(id) extn = urllib.parse.quote_plus(id)
return "%s/%s" % (base, extn) return "%s/%s" % (base, extn)
class ListObject(GymObject): class ListObject(GymObject):
@@ -280,7 +282,7 @@ class ListObject(GymObject):
def retrieve(self, id, **params): def retrieve(self, id, **params):
base = self.get('url') base = self.get('url')
id = util.utf8(id) id = util.utf8(id)
extn = urllib.quote_plus(id) extn = urllib.parse.quote_plus(id)
url = "%s/%s" % (base, extn) url = "%s/%s" % (base, extn)
return self.request('get', url, params) return self.request('get', url, params)