mpi test fixes
This commit is contained in:
@@ -1,11 +1,16 @@
|
||||
from collections import defaultdict
|
||||
from mpi4py import MPI
|
||||
import os, numpy as np
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
import warnings
|
||||
|
||||
try:
|
||||
from mpi4py import MPI
|
||||
except ImportError:
|
||||
MPI = None
|
||||
|
||||
|
||||
def sync_from_root(sess, variables, comm=None):
|
||||
"""
|
||||
Send the root node's parameters to every worker.
|
||||
|
@@ -1,10 +1,12 @@
|
||||
from baselines.common import mpi_util
|
||||
from mpi4py import MPI
|
||||
import subprocess
|
||||
import sys
|
||||
from baselines import logger
|
||||
from baselines.common.tests.test_with_mpi import with_mpi
|
||||
from baselines.common import mpi_util
|
||||
|
||||
def helper_for_mpi_weighted_mean():
|
||||
@with_mpi()
|
||||
def test_mpi_weighted_mean():
|
||||
from mpi4py import MPI
|
||||
comm = MPI.COMM_WORLD
|
||||
if comm.rank == 0:
|
||||
name2valcount = {'a' : (10, 2), 'b' : (20,3)}
|
||||
@@ -24,8 +26,3 @@ def helper_for_mpi_weighted_mean():
|
||||
d2 = logger.dumpkvs(mpi_mean=True)
|
||||
if comm.rank == 0:
|
||||
assert d2 == correctval
|
||||
|
||||
|
||||
def test_mpi_weighted_mean():
|
||||
subprocess.check_call(['mpirun', '-n', '2', sys.executable, '-c',
|
||||
'from baselines.common import test_mpi_util; test_mpi_util.helper_for_mpi_weighted_mean()'])
|
||||
|
@@ -10,7 +10,7 @@ try:
|
||||
except ImportError:
|
||||
MPI = None
|
||||
|
||||
def test_with_mpi(nproc=2, timeout=30, skip_if_no_mpi=True):
|
||||
def with_mpi(nproc=2, timeout=30, skip_if_no_mpi=True):
|
||||
def outer_thunk(fn):
|
||||
def thunk(*args, **kwargs):
|
||||
serialized_fn = base64.b64encode(cloudpickle.dumps(lambda: fn(*args, **kwargs)))
|
||||
|
@@ -46,8 +46,8 @@ class DummyVecEnv(VecEnv):
|
||||
def step_wait(self):
|
||||
for e in range(self.num_envs):
|
||||
action = self.actions[e]
|
||||
if isinstance(self.envs[e].action_space, spaces.Discrete):
|
||||
action = int(action)
|
||||
# if isinstance(self.envs[e].action_space, spaces.Discrete):
|
||||
# action = int(action)
|
||||
|
||||
obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(action)
|
||||
if self.buf_dones[e]:
|
||||
|
@@ -8,7 +8,7 @@ import pytest
|
||||
from .dummy_vec_env import DummyVecEnv
|
||||
from .shmem_vec_env import ShmemVecEnv
|
||||
from .subproc_vec_env import SubprocVecEnv
|
||||
from baselines.common.tests.test_with_mpi import test_with_mpi
|
||||
from baselines.common.tests.test_with_mpi import with_mpi
|
||||
|
||||
|
||||
def assert_envs_equal(env1, env2, num_steps):
|
||||
@@ -103,7 +103,7 @@ class SimpleEnv(gym.Env):
|
||||
|
||||
|
||||
|
||||
@test_with_mpi()
|
||||
@with_mpi()
|
||||
def test_mpi_with_subprocvecenv():
|
||||
shape = (2,3,4)
|
||||
nenv = 1
|
||||
|
Reference in New Issue
Block a user