minor fixes from internal (#665)
* sync internal changes. Make ddpg work with vecenvs * B -> nenvs for consistency with other algos, small cleanups * eval_done[d]==True -> eval_done[d] * flake8 and numpy.random.random_integers deprecation warning * Merge branch 'master' of github.com:openai/games into peterz_track_baselines_branch
This commit is contained in:
@@ -213,8 +213,11 @@ class LazyFrames(object):
|
||||
def __getitem__(self, i):
|
||||
return self._force()[i]
|
||||
|
||||
def make_atari(env_id):
|
||||
def make_atari(env_id, timelimit=True):
|
||||
# XXX(john): remove timelimit argument after gym is upgraded to allow double wrapping
|
||||
env = gym.make(env_id)
|
||||
if not timelimit:
|
||||
env = env.env
|
||||
assert 'NoFrameskip' in env.spec.id
|
||||
env = NoopResetEnv(env, noop_max=30)
|
||||
env = MaxAndSkipEnv(env, skip=4)
|
||||
|
@@ -28,6 +28,3 @@ class VecFrameStack(VecEnvWrapper):
|
||||
self.stackedobs[...] = 0
|
||||
self.stackedobs[..., -obs.shape[-1]:] = obs
|
||||
return self.stackedobs
|
||||
|
||||
def close(self):
|
||||
self.venv.close()
|
||||
|
@@ -106,7 +106,8 @@ class CSVOutputFormat(KVWriter):
|
||||
|
||||
def writekvs(self, kvs):
|
||||
# Add our current row to the history
|
||||
extra_keys = kvs.keys() - self.keys
|
||||
extra_keys = list(kvs.keys() - self.keys)
|
||||
extra_keys.sort()
|
||||
if extra_keys:
|
||||
self.keys.extend(extra_keys)
|
||||
self.file.seek(0)
|
||||
|
16
setup.py
16
setup.py
@@ -48,9 +48,13 @@ setup(name='baselines',
|
||||
|
||||
|
||||
# ensure there is some tensorflow build with version above 1.4
|
||||
try:
|
||||
from distutils.version import StrictVersion
|
||||
import tensorflow
|
||||
assert StrictVersion(re.sub(r'-rc\d+$', '', tensorflow.__version__)) >= StrictVersion('1.4.0')
|
||||
except ImportError:
|
||||
assert False, "TensorFlow needed, of version above 1.4"
|
||||
import pkg_resources
|
||||
tf_pkg = None
|
||||
for tf_pkg_name in ['tensorflow', 'tensorflow-gpu']:
|
||||
try:
|
||||
tf_pkg = pkg_resources.get_distribution(tf_pkg_name)
|
||||
except pkg_resources.DistributionNotFound:
|
||||
pass
|
||||
assert tf_pkg is not None, 'TensorFlow needed, of version above 1.4'
|
||||
from distutils.version import StrictVersion
|
||||
assert StrictVersion(re.sub(r'-rc\d+$', '', tf_pkg.version)) >= StrictVersion('1.4.0')
|
||||
|
Reference in New Issue
Block a user