mpi test fixes

This commit is contained in:
Peter Zhokhov
2019-01-24 15:46:58 -08:00
parent e868bdaa1a
commit 0e0dd77f61
5 changed files with 16 additions and 14 deletions

View File

@@ -1,11 +1,16 @@
from collections import defaultdict from collections import defaultdict
from mpi4py import MPI
import os, numpy as np import os, numpy as np
import platform import platform
import shutil import shutil
import subprocess import subprocess
import warnings import warnings
try:
from mpi4py import MPI
except ImportError:
MPI = None
def sync_from_root(sess, variables, comm=None): def sync_from_root(sess, variables, comm=None):
""" """
Send the root node's parameters to every worker. Send the root node's parameters to every worker.

View File

@@ -1,10 +1,12 @@
from baselines.common import mpi_util
from mpi4py import MPI
import subprocess import subprocess
import sys import sys
from baselines import logger from baselines import logger
from baselines.common.tests.test_with_mpi import with_mpi
from baselines.common import mpi_util
def helper_for_mpi_weighted_mean(): @with_mpi()
def test_mpi_weighted_mean():
from mpi4py import MPI
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
if comm.rank == 0: if comm.rank == 0:
name2valcount = {'a' : (10, 2), 'b' : (20,3)} name2valcount = {'a' : (10, 2), 'b' : (20,3)}
@@ -24,8 +26,3 @@ def helper_for_mpi_weighted_mean():
d2 = logger.dumpkvs(mpi_mean=True) d2 = logger.dumpkvs(mpi_mean=True)
if comm.rank == 0: if comm.rank == 0:
assert d2 == correctval assert d2 == correctval
def test_mpi_weighted_mean():
subprocess.check_call(['mpirun', '-n', '2', sys.executable, '-c',
'from baselines.common import test_mpi_util; test_mpi_util.helper_for_mpi_weighted_mean()'])

View File

@@ -10,7 +10,7 @@ try:
except ImportError: except ImportError:
MPI = None MPI = None
def test_with_mpi(nproc=2, timeout=30, skip_if_no_mpi=True): def with_mpi(nproc=2, timeout=30, skip_if_no_mpi=True):
def outer_thunk(fn): def outer_thunk(fn):
def thunk(*args, **kwargs): def thunk(*args, **kwargs):
serialized_fn = base64.b64encode(cloudpickle.dumps(lambda: fn(*args, **kwargs))) serialized_fn = base64.b64encode(cloudpickle.dumps(lambda: fn(*args, **kwargs)))

View File

@@ -46,8 +46,8 @@ class DummyVecEnv(VecEnv):
def step_wait(self): def step_wait(self):
for e in range(self.num_envs): for e in range(self.num_envs):
action = self.actions[e] action = self.actions[e]
if isinstance(self.envs[e].action_space, spaces.Discrete): # if isinstance(self.envs[e].action_space, spaces.Discrete):
action = int(action) # action = int(action)
obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(action) obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(action)
if self.buf_dones[e]: if self.buf_dones[e]:

View File

@@ -8,7 +8,7 @@ import pytest
from .dummy_vec_env import DummyVecEnv from .dummy_vec_env import DummyVecEnv
from .shmem_vec_env import ShmemVecEnv from .shmem_vec_env import ShmemVecEnv
from .subproc_vec_env import SubprocVecEnv from .subproc_vec_env import SubprocVecEnv
from baselines.common.tests.test_with_mpi import test_with_mpi from baselines.common.tests.test_with_mpi import with_mpi
def assert_envs_equal(env1, env2, num_steps): def assert_envs_equal(env1, env2, num_steps):
@@ -103,7 +103,7 @@ class SimpleEnv(gym.Env):
@test_with_mpi() @with_mpi()
def test_mpi_with_subprocvecenv(): def test_mpi_with_subprocvecenv():
shape = (2,3,4) shape = (2,3,4)
nenv = 1 nenv = 1