mpi test fixes

This commit is contained in:
Peter Zhokhov
2019-01-24 15:46:58 -08:00
parent e868bdaa1a
commit 0e0dd77f61
5 changed files with 16 additions and 14 deletions

View File

@@ -1,11 +1,16 @@
from collections import defaultdict
from mpi4py import MPI
import os, numpy as np
import platform
import shutil
import subprocess
import warnings
try:
from mpi4py import MPI
except ImportError:
MPI = None
def sync_from_root(sess, variables, comm=None):
"""
Send the root node's parameters to every worker.

View File

@@ -1,10 +1,12 @@
from baselines.common import mpi_util
from mpi4py import MPI
import subprocess
import sys
from baselines import logger
from baselines.common.tests.test_with_mpi import with_mpi
from baselines.common import mpi_util
def helper_for_mpi_weighted_mean():
@with_mpi()
def test_mpi_weighted_mean():
from mpi4py import MPI
comm = MPI.COMM_WORLD
if comm.rank == 0:
name2valcount = {'a' : (10, 2), 'b' : (20,3)}
@@ -24,8 +26,3 @@ def helper_for_mpi_weighted_mean():
d2 = logger.dumpkvs(mpi_mean=True)
if comm.rank == 0:
assert d2 == correctval
def test_mpi_weighted_mean():
subprocess.check_call(['mpirun', '-n', '2', sys.executable, '-c',
'from baselines.common import test_mpi_util; test_mpi_util.helper_for_mpi_weighted_mean()'])

View File

@@ -10,7 +10,7 @@ try:
except ImportError:
MPI = None
def test_with_mpi(nproc=2, timeout=30, skip_if_no_mpi=True):
def with_mpi(nproc=2, timeout=30, skip_if_no_mpi=True):
def outer_thunk(fn):
def thunk(*args, **kwargs):
serialized_fn = base64.b64encode(cloudpickle.dumps(lambda: fn(*args, **kwargs)))

View File

@@ -46,8 +46,8 @@ class DummyVecEnv(VecEnv):
def step_wait(self):
for e in range(self.num_envs):
action = self.actions[e]
if isinstance(self.envs[e].action_space, spaces.Discrete):
action = int(action)
# if isinstance(self.envs[e].action_space, spaces.Discrete):
# action = int(action)
obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(action)
if self.buf_dones[e]:

View File

@@ -8,7 +8,7 @@ import pytest
from .dummy_vec_env import DummyVecEnv
from .shmem_vec_env import ShmemVecEnv
from .subproc_vec_env import SubprocVecEnv
from baselines.common.tests.test_with_mpi import test_with_mpi
from baselines.common.tests.test_with_mpi import with_mpi
def assert_envs_equal(env1, env2, num_steps):
@@ -103,7 +103,7 @@ class SimpleEnv(gym.Env):
@test_with_mpi()
@with_mpi()
def test_mpi_with_subprocvecenv():
shape = (2,3,4)
nenv = 1