Merge branch 'master' of github.com:openai/games into peterz_track_baselines_branch
This commit is contained in:
@@ -213,8 +213,11 @@ class LazyFrames(object):
|
|||||||
def __getitem__(self, i):
|
def __getitem__(self, i):
|
||||||
return self._force()[i]
|
return self._force()[i]
|
||||||
|
|
||||||
def make_atari(env_id):
|
def make_atari(env_id, timelimit=True):
|
||||||
|
# XXX(john): remove timelimit argument after gym is upgraded to allow double wrapping
|
||||||
env = gym.make(env_id)
|
env = gym.make(env_id)
|
||||||
|
if not timelimit:
|
||||||
|
env = env.env
|
||||||
assert 'NoFrameskip' in env.spec.id
|
assert 'NoFrameskip' in env.spec.id
|
||||||
env = NoopResetEnv(env, noop_max=30)
|
env = NoopResetEnv(env, noop_max=30)
|
||||||
env = MaxAndSkipEnv(env, skip=4)
|
env = MaxAndSkipEnv(env, skip=4)
|
||||||
|
@@ -28,6 +28,3 @@ class VecFrameStack(VecEnvWrapper):
|
|||||||
self.stackedobs[...] = 0
|
self.stackedobs[...] = 0
|
||||||
self.stackedobs[..., -obs.shape[-1]:] = obs
|
self.stackedobs[..., -obs.shape[-1]:] = obs
|
||||||
return self.stackedobs
|
return self.stackedobs
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self.venv.close()
|
|
||||||
|
@@ -106,7 +106,8 @@ class CSVOutputFormat(KVWriter):
|
|||||||
|
|
||||||
def writekvs(self, kvs):
|
def writekvs(self, kvs):
|
||||||
# Add our current row to the history
|
# Add our current row to the history
|
||||||
extra_keys = kvs.keys() - self.keys
|
extra_keys = list(kvs.keys() - self.keys)
|
||||||
|
extra_keys.sort()
|
||||||
if extra_keys:
|
if extra_keys:
|
||||||
self.keys.extend(extra_keys)
|
self.keys.extend(extra_keys)
|
||||||
self.file.seek(0)
|
self.file.seek(0)
|
||||||
|
16
setup.py
16
setup.py
@@ -48,9 +48,13 @@ setup(name='baselines',
|
|||||||
|
|
||||||
|
|
||||||
# ensure there is some tensorflow build with version above 1.4
|
# ensure there is some tensorflow build with version above 1.4
|
||||||
try:
|
import pkg_resources
|
||||||
from distutils.version import StrictVersion
|
tf_pkg = None
|
||||||
import tensorflow
|
for tf_pkg_name in ['tensorflow', 'tensorflow-gpu']:
|
||||||
assert StrictVersion(re.sub(r'-rc\d+$', '', tensorflow.__version__)) >= StrictVersion('1.4.0')
|
try:
|
||||||
except ImportError:
|
tf_pkg = pkg_resources.get_distribution(tf_pkg_name)
|
||||||
assert False, "TensorFlow needed, of version above 1.4"
|
except pkg_resources.DistributionNotFound:
|
||||||
|
pass
|
||||||
|
assert tf_pkg is not None, 'TensorFlow needed, of version above 1.4'
|
||||||
|
from distutils.version import StrictVersion
|
||||||
|
assert StrictVersion(re.sub(r'-rc\d+$', '', tf_pkg.version)) >= StrictVersion('1.4.0')
|
||||||
|
Reference in New Issue
Block a user