Compare commits

...

21 Commits

Author SHA1 Message Date
Peter Zhokhov
58801032fc install mpi4py in mpi dockerfile 2018-10-31 11:34:10 -07:00
Peter Zhokhov
b4a149a75f fix .travis.yml 2018-10-31 11:32:03 -07:00
Peter Zhokhov
c248bf9a46 CI dockerfiles with and without mpi 2018-10-31 11:27:45 -07:00
Peter Zhokhov
d1f7d12743 mpiless ddpg 2018-10-31 09:48:41 -07:00
Peter Zhokhov
f0d49fb67d add assertion to test in mpi_adam; fix trpo_mpi failure without MPI on cartpole 2018-10-30 14:45:20 -07:00
Peter Zhokhov
ef2e7246c9 autopep8 2018-10-30 14:11:38 -07:00
Peter Zhokhov
3e3e2b7998 MpiAdam becomes regular Adam if Mpi not present 2018-10-30 14:04:30 -07:00
Peter Zhokhov
d00f3bce34 syntax and flake8 2018-10-30 09:47:39 -07:00
Peter Zhokhov
72aa2f1251 more MPI removal 2018-10-29 15:43:56 -07:00
Peter Zhokhov
ea7a52b652 further removing MPI references where unnecessary 2018-10-29 15:38:16 -07:00
Peter Zhokhov
064c45fa76 Merge branch 'master' of github.com:openai/baselines into peterz_mpiless 2018-10-29 15:31:37 -07:00
Peter Zhokhov
6f148fdb0d squash-merged latest master 2018-10-29 15:28:59 -07:00
Peter Zhokhov
93c7cc202c Merge branch 'master' of github.com:openai/baselines 2018-10-29 15:25:38 -07:00
Peter Zhokhov
de36116e3b update tensorflow version check regex to parse version like 1.2.3rc4 (previously only 1.2.3-rc4) 2018-10-29 15:25:31 -07:00
Mathieu Poliquin
e2b41828af Set 'cnn' as default network for retro (#683) 2018-10-29 13:30:41 -07:00
pzhokhov
8e56ddeac2 Multidiscrete action space compatibility for policy gradient-based methods (#677)
* multidiscrete space compatibility

* flake8 and syntax
2018-10-24 11:01:59 -07:00
Juliano Laganá
c3bd8cea66 Adds description of param_noise parameter in deepq.learn method (#675) 2018-10-24 10:00:31 -07:00
AurelianTactics
84ea7aa1fd DDPG has unused 'seed' argument (#676)
DeepQ, PPO2, ACER, trpo_mpi, A2C, and ACKTR have the code for:

```
from baselines.common import set_global_seeds
...
def learn(...):
...
   set_global_seeds(seed)
```

DDPG has the argument 'seed=None' but doesn't have the two lines of code needed to set the global seeds.
2018-10-24 09:59:46 -07:00
Peter Zhokhov
88300ed54c fix raise NotImplemented() complaints of latest flake8 2018-10-24 09:57:57 -07:00
pzhokhov
583ba082a2 Update cmd_util.py 2018-10-23 11:22:27 -07:00
Peter Zhokhov
d96e20ff27 make baselines run without mpi wip 2018-10-19 17:00:41 -07:00
21 changed files with 229 additions and 82 deletions

View File

@@ -5,9 +5,13 @@ python:
services:
- docker
env:
- DOCKER_SUFFIX=py36-nompi
- DOCKER_SUFFIX=py36-mpi
install:
- pip install flake8
- docker build . -t baselines-test
- docker build -f test.dockerfile.${DOCKER_SUFFIX} -t baselines-test .
script:
- flake8 . --show-source --statistics

View File

@@ -1,25 +0,0 @@
FROM ubuntu:16.04
RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake python-opencv
ENV CODE_DIR /root/code
ENV VENV /root/venv
RUN \
pip install virtualenv && \
virtualenv $VENV --python=python3 && \
. $VENV/bin/activate && \
pip install --upgrade pip
ENV PATH=$VENV/bin:$PATH
COPY . $CODE_DIR/baselines
WORKDIR $CODE_DIR/baselines
# Clean up pycache and pyc files
RUN rm -rf __pycache__ && \
find . -name "*.pyc" -delete && \
pip install tensorflow && \
pip install -e .[test]
CMD /bin/bash

View File

@@ -21,16 +21,16 @@ class Model(object):
self.sess = sess = get_session()
nbatch = nenvs * nsteps
A = tf.placeholder(ac_space.dtype, [nbatch,] + list(ac_space.shape))
with tf.variable_scope('acktr_model', reuse=tf.AUTO_REUSE):
self.model = step_model = policy(nenvs, 1, sess=sess)
self.model2 = train_model = policy(nenvs*nsteps, nsteps, sess=sess)
A = train_model.pdtype.sample_placeholder([None])
ADV = tf.placeholder(tf.float32, [nbatch])
R = tf.placeholder(tf.float32, [nbatch])
PG_LR = tf.placeholder(tf.float32, [])
VF_LR = tf.placeholder(tf.float32, [])
with tf.variable_scope('acktr_model', reuse=tf.AUTO_REUSE):
self.model = step_model = policy(nenvs, 1, sess=sess)
self.model2 = train_model = policy(nenvs*nsteps, nsteps, sess=sess)
neglogpac = train_model.pd.neglogp(A)
self.logits = train_model.pi

View File

@@ -43,7 +43,7 @@ def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_ind
return DummyVecEnv([make_thunk(start_index)])
def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs=None):
def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs={}):
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
if env_type == 'atari':
env = make_atari(env_id)

View File

@@ -39,7 +39,7 @@ class PdType(object):
raise NotImplementedError
def pdfromflat(self, flat):
return self.pdclass()(flat)
def pdfromlatent(self, latent_vector):
def pdfromlatent(self, latent_vector, init_scale, init_bias):
raise NotImplementedError
def param_shape(self):
raise NotImplementedError
@@ -80,6 +80,11 @@ class MultiCategoricalPdType(PdType):
return MultiCategoricalPd
def pdfromflat(self, flat):
return MultiCategoricalPd(self.ncats, flat)
def pdfromlatent(self, latent, init_scale=1.0, init_bias=0.0):
pdparam = fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
return self.pdfromflat(pdparam), pdparam
def param_shape(self):
return [sum(self.ncats)]
def sample_shape(self):

View File

@@ -1,5 +1,6 @@
import numpy as np
import tensorflow as tf
from gym.spaces import Discrete, Box
from gym.spaces import Discrete, Box, MultiDiscrete
def observation_placeholder(ob_space, batch_size=None, name='Ob'):
'''
@@ -20,10 +21,14 @@ def observation_placeholder(ob_space, batch_size=None, name='Ob'):
tensorflow placeholder tensor
'''
assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box), \
assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \
'Can only deal with Discrete and Box observation spaces for now'
return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=ob_space.dtype, name=name)
dtype = ob_space.dtype
if dtype == np.int8:
dtype = np.uint8
return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)
def observation_input(ob_space, batch_size=None, name='Ob'):
@@ -48,9 +53,12 @@ def encode_observation(ob_space, placeholder):
'''
if isinstance(ob_space, Discrete):
return tf.to_float(tf.one_hot(placeholder, ob_space.n))
elif isinstance(ob_space, Box):
return tf.to_float(placeholder)
elif isinstance(ob_space, MultiDiscrete):
placeholder = tf.cast(placeholder, tf.int32)
one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
return tf.concat(one_hots, axis=-1)
else:
raise NotImplementedError

View File

@@ -1,7 +1,11 @@
from mpi4py import MPI
import baselines.common.tf_util as U
import tensorflow as tf
import numpy as np
try:
from mpi4py import MPI
except ImportError:
MPI = None
class MpiAdam(object):
def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None):
@@ -16,16 +20,19 @@ class MpiAdam(object):
self.t = 0
self.setfromflat = U.SetFromFlat(var_list)
self.getflat = U.GetFlat(var_list)
self.comm = MPI.COMM_WORLD if comm is None else comm
self.comm = MPI.COMM_WORLD if comm is None and MPI is not None else comm
def update(self, localg, stepsize):
if self.t % 100 == 0:
self.check_synced()
localg = localg.astype('float32')
if self.comm is not None:
globalg = np.zeros_like(localg)
self.comm.Allreduce(localg, globalg, op=MPI.SUM)
if self.scale_grad_by_procs:
globalg /= self.comm.Get_size()
else:
globalg = np.copy(localg)
self.t += 1
a = stepsize * np.sqrt(1 - self.beta2**self.t)/(1 - self.beta1**self.t)
@@ -35,11 +42,15 @@ class MpiAdam(object):
self.setfromflat(self.getflat() + step)
def sync(self):
if self.comm is None:
return
theta = self.getflat()
self.comm.Bcast(theta, root=0)
self.setfromflat(theta)
def check_synced(self):
if self.comm is None:
return
if self.comm.Get_rank() == 0: # this is root
theta = self.getflat()
self.comm.Bcast(theta, root=0)
@@ -63,17 +74,30 @@ def test_MpiAdam():
do_update = U.function([], loss, updates=[update_op])
tf.get_default_session().run(tf.global_variables_initializer())
losslist_ref = []
for i in range(10):
print(i,do_update())
l = do_update()
print(i, l)
losslist_ref.append(l)
tf.set_random_seed(0)
tf.get_default_session().run(tf.global_variables_initializer())
var_list = [a,b]
lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)], updates=[update_op])
lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)])
adam = MpiAdam(var_list)
losslist_test = []
for i in range(10):
l,g = lossandgrad()
adam.update(g, stepsize)
print(i,l)
losslist_test.append(l)
np.testing.assert_allclose(np.array(losslist_ref), np.array(losslist_test), atol=1e-4)
if __name__ == '__main__':
test_MpiAdam()

View File

@@ -1,4 +1,8 @@
try:
from mpi4py import MPI
except ImportError:
MPI = None
import tensorflow as tf, baselines.common.tf_util as U, numpy as np
class RunningMeanStd(object):
@@ -39,6 +43,7 @@ class RunningMeanStd(object):
n = int(np.prod(self.shape))
totalvec = np.zeros(n*2+1, 'float64')
addvec = np.concatenate([x.sum(axis=0).ravel(), np.square(x).sum(axis=0).ravel(), np.array([len(x)],dtype='float64')])
if MPI is not None:
MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
self.incfiltparams(totalvec[0:n].reshape(self.shape), totalvec[n:2*n].reshape(self.shape), totalvec[2*n])

View File

@@ -1,7 +1,7 @@
import numpy as np
from abc import abstractmethod
from gym import Env
from gym.spaces import Discrete, Box
from gym.spaces import MultiDiscrete, Discrete, Box
class IdentityEnv(Env):
@@ -53,6 +53,19 @@ class DiscreteIdentityEnv(IdentityEnv):
def _get_reward(self, actions):
return 1 if self.state == actions else 0
class MultiDiscreteIdentityEnv(IdentityEnv):
def __init__(
self,
dims,
episode_len=None,
):
self.action_space = MultiDiscrete(dims)
super().__init__(episode_len=episode_len)
def _get_reward(self, actions):
return 1 if all(self.state == actions) else 0
class BoxIdentityEnv(IdentityEnv):
def __init__(

View File

@@ -1,5 +1,5 @@
import pytest
from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv, BoxIdentityEnv
from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv, BoxIdentityEnv, MultiDiscreteIdentityEnv
from baselines.run import get_learn_function
from baselines.common.tests.util import simple_test
@@ -21,6 +21,7 @@ learn_kwargs = {
algos_disc = ['a2c', 'acktr', 'deepq', 'ppo2', 'trpo_mpi']
algos_multidisc = ['a2c', 'acktr', 'ppo2', 'trpo_mpi']
algos_cont = ['a2c', 'acktr', 'ddpg', 'ppo2', 'trpo_mpi']
@pytest.mark.slow
@@ -38,6 +39,21 @@ def test_discrete_identity(alg):
env_fn = lambda: DiscreteIdentityEnv(10, episode_len=100)
simple_test(env_fn, learn_fn, 0.9)
@pytest.mark.slow
@pytest.mark.parametrize("alg", algos_multidisc)
def test_multidiscrete_identity(alg):
'''
Test if the algorithm (with an mlp policy)
can learn an identity transformation (i.e. return observation as an action)
'''
kwargs = learn_kwargs[alg]
kwargs.update(common_kwargs)
learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
env_fn = lambda: MultiDiscreteIdentityEnv((3,3), episode_len=100)
simple_test(env_fn, learn_fn, 0.9)
@pytest.mark.slow
@pytest.mark.parametrize("alg", algos_cont)
def test_continuous_identity(alg):
@@ -55,5 +71,5 @@ def test_continuous_identity(alg):
simple_test(env_fn, learn_fn, -0.1)
if __name__ == '__main__':
test_continuous_identity('ddpg')
test_multidiscrete_identity('acktr')

View File

@@ -20,8 +20,11 @@ class DummyVecEnv(VecEnv):
env = self.envs[0]
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
obs_space = env.observation_space
if isinstance(obs_space, spaces.MultiDiscrete):
obs_space.shape = obs_space.shape[0]
self.keys, shapes, dtypes = obs_space_info(obs_space)
self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys }
self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)

View File

@@ -7,13 +7,16 @@ from baselines.ddpg.ddpg_learner import DDPG
from baselines.ddpg.models import Actor, Critic
from baselines.ddpg.memory import Memory
from baselines.ddpg.noise import AdaptiveParamNoiseSpec, NormalActionNoise, OrnsteinUhlenbeckActionNoise
from baselines.common import set_global_seeds
import baselines.common.tf_util as U
from baselines import logger
import numpy as np
from mpi4py import MPI
try:
from mpi4py import MPI
except ImportError:
MPI = None
def learn(network, env,
seed=None,
@@ -41,6 +44,7 @@ def learn(network, env,
param_noise_adaption_interval=50,
**network_kwargs):
set_global_seeds(seed)
if total_timesteps is not None:
assert nb_epochs is None
@@ -48,7 +52,11 @@ def learn(network, env,
else:
nb_epochs = 500
if MPI is not None:
rank = MPI.COMM_WORLD.Get_rank()
else:
rank = 0
nb_actions = env.action_space.shape[-1]
assert (np.abs(env.action_space.low) == env.action_space.high).all() # we assume symmetric actions.
@@ -199,7 +207,11 @@ def learn(network, env,
eval_episode_rewards_history.append(eval_episode_reward[d])
eval_episode_reward[d] = 0.0
if MPI is not None:
mpi_size = MPI.COMM_WORLD.Get_size()
else:
mpi_size = 1
# Log stats.
# XXX shouldn't call np.mean on variable length lists
duration = time.time() - start_time
@@ -233,7 +245,10 @@ def learn(network, env,
else:
raise ValueError('expected scalar, got %s'%x)
combined_stats_sums = MPI.COMM_WORLD.allreduce(np.array([ np.array(x).flatten()[0] for x in combined_stats.values()]))
combined_stats_sums = np.array([ np.array(x).flatten()[0] for x in combined_stats.values()])
if MPI is not None:
combined_stats_sums = MPI.COMM_WORLD.allreduce(combined_stats_sums)
combined_stats = {k : v / mpi_size for (k,v) in zip(combined_stats.keys(), combined_stats_sums)}
# Total statistics.

View File

@@ -9,7 +9,10 @@ from baselines import logger
from baselines.common.mpi_adam import MpiAdam
import baselines.common.tf_util as U
from baselines.common.mpi_running_mean_std import RunningMeanStd
try:
from mpi4py import MPI
except ImportError:
MPI = None
def normalize(x, stats):
if stats is None:
@@ -358,6 +361,11 @@ class DDPG(object):
return stats
def adapt_param_noise(self):
try:
from mpi4py import MPI
except ImportError:
MPI = None
if self.param_noise is None:
return 0.
@@ -371,7 +379,16 @@ class DDPG(object):
self.param_noise_stddev: self.param_noise.current_stddev,
})
if MPI is not None:
mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
else:
mean_distance = distance
if MPI is not None:
mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
else:
mean_distance = distance
self.param_noise.adapt(mean_distance)
return mean_distance

View File

@@ -169,6 +169,8 @@ def learn(env,
to 1.0. If set to None equals to total_timesteps.
prioritized_replay_eps: float
epsilon to add to the TD errors when updating priorities.
param_noise: bool
whether or not to use parameter space noise (https://arxiv.org/abs/1706.01905)
callback: (locals, globals) -> None
function called at every steps with state of the algorithm.
If callback returns true training stops.

View File

@@ -18,11 +18,11 @@ class TfInput(object):
"""Return the tf variable(s) representing the possibly postprocessed value
of placeholder(s).
"""
raise NotImplemented()
raise NotImplementedError
def make_feed_dict(data):
"""Given data input it to the placeholder(s)."""
raise NotImplemented()
raise NotImplementedError
class PlaceholderTfInput(TfInput):

View File

@@ -10,11 +10,15 @@ from baselines.common import explained_variance, set_global_seeds
from baselines.common.policies import build_policy
from baselines.common.runners import AbstractEnvRunner
from baselines.common.tf_util import get_session, save_variables, load_variables
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
try:
from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
from mpi4py import MPI
from baselines.common.tf_util import initialize
from baselines.common.mpi_util import sync_from_root
except ImportError:
MPI = None
from baselines.common.tf_util import initialize
class Model(object):
"""
@@ -93,7 +97,10 @@ class Model(object):
# 1. Get the model parameters
params = tf.trainable_variables('ppo2_model')
# 2. Build our trainer
if MPI is not None:
trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=LR, epsilon=1e-5)
else:
trainer = tf.train.AdamOptimizer(learning_rate=LR, epsilon=1e-5)
# 3. Calculate the gradients
grads_and_var = trainer.compute_gradients(loss, params)
grads, var = zip(*grads_and_var)
@@ -136,9 +143,11 @@ class Model(object):
self.save = functools.partial(save_variables, sess=sess)
self.load = functools.partial(load_variables, sess=sess)
if MPI.COMM_WORLD.Get_rank() == 0:
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
initialize()
global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="")
if MPI is not None:
sync_from_root(sess, global_variables) #pylint: disable=E1101
class Runner(AbstractEnvRunner):
@@ -392,9 +401,9 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
logger.logkv('time_elapsed', tnow - tfirststart)
for (lossval, lossname) in zip(lossvals, model.loss_names):
logger.logkv(lossname, lossval)
if MPI.COMM_WORLD.Get_rank() == 0:
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
logger.dumpkvs()
if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and MPI.COMM_WORLD.Get_rank() == 0:
if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and (MPI is None or MPI.COMM_WORLD.Get_rank() == 0):
checkdir = osp.join(logger.get_dir(), 'checkpoints')
os.makedirs(checkdir, exist_ok=True)
savepath = osp.join(checkdir, '%.5i'%update)

View File

@@ -131,7 +131,7 @@ def get_env_type(env_id):
def get_default_network(env_type):
if env_type == 'atari':
if env_type in {'atari', 'retro'}:
return 'cnn'
else:
return 'mlp'

View File

@@ -4,7 +4,6 @@ import baselines.common.tf_util as U
import tensorflow as tf, numpy as np
import time
from baselines.common import colorize
from mpi4py import MPI
from collections import deque
from baselines.common import set_global_seeds
from baselines.common.mpi_adam import MpiAdam
@@ -13,6 +12,11 @@ from baselines.common.input import observation_placeholder
from baselines.common.policies import build_policy
from contextlib import contextmanager
try:
from mpi4py import MPI
except ImportError:
MPI = None
def traj_segment_generator(pi, env, horizon, stochastic):
# Initialize state variables
t = 0
@@ -146,9 +150,12 @@ def learn(*,
'''
if MPI is not None:
nworkers = MPI.COMM_WORLD.Get_size()
rank = MPI.COMM_WORLD.Get_rank()
else:
nworkers = 1
rank = 0
cpus_per_worker = 1
U.get_session(config=tf.ConfigProto(
@@ -237,9 +244,13 @@ def learn(*,
def allmean(x):
assert isinstance(x, np.ndarray)
if MPI is not None:
out = np.empty_like(x)
MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
out /= nworkers
else:
out = np.copy(x)
return out
U.initialize()
@@ -247,7 +258,9 @@ def learn(*,
pi.load(load_path)
th_init = get_flat()
if MPI is not None:
MPI.COMM_WORLD.Bcast(th_init, root=0)
set_from_flat(th_init)
vfadam.sync()
print("Init param sum", th_init.sum(), flush=True)
@@ -353,7 +366,11 @@ def learn(*,
logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
if MPI is not None:
listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
else:
listoflrpairs = [lrlocal]
lens, rews = map(flatten_lists, zip(*listoflrpairs))
lenbuffer.extend(lens)
rewbuffer.extend(rews)

View File

@@ -15,6 +15,9 @@ extras = {
],
'bullet': [
'pybullet',
],
'mpi': [
'mpi4py'
]
}
@@ -34,7 +37,6 @@ setup(name='baselines',
'joblib',
'dill',
'progressbar2',
'mpi4py',
'cloudpickle',
'click',
'opencv-python'
@@ -57,4 +59,4 @@ for tf_pkg_name in ['tensorflow', 'tensorflow-gpu']:
pass
assert tf_pkg is not None, 'TensorFlow needed, of version above 1.4'
from distutils.version import StrictVersion
assert StrictVersion(re.sub(r'-rc\d+$', '', tf_pkg.version)) >= StrictVersion('1.4.0')
assert StrictVersion(re.sub(r'-?rc\d+$', '', tf_pkg.version)) >= StrictVersion('1.4.0')

16
test.dockerfile.py36-mpi Normal file
View File

@@ -0,0 +1,16 @@
FROM python:3.6
RUN apt-get -y update && apt-get -y install ffmpeg libopenmpi-dev
ENV CODE_DIR /root/code
COPY . $CODE_DIR/baselines
WORKDIR $CODE_DIR/baselines
# Clean up pycache and pyc files
RUN rm -rf __pycache__ && \
find . -name "*.pyc" -delete && \
pip install tensorflow && \
pip install -e .[test,mpi]
CMD /bin/bash

View File

@@ -0,0 +1,16 @@
FROM python:3.6
RUN apt-get -y update && apt-get -y install ffmpeg
ENV CODE_DIR /root/code
COPY . $CODE_DIR/baselines
WORKDIR $CODE_DIR/baselines
# Clean up pycache and pyc files
RUN rm -rf __pycache__ && \
find . -name "*.pyc" -delete && \
pip install tensorflow && \
pip install -e .[test]
CMD /bin/bash