mpi test fixes
This commit is contained in:
@@ -1,11 +1,16 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from mpi4py import MPI
|
|
||||||
import os, numpy as np
|
import os, numpy as np
|
||||||
import platform
|
import platform
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
try:
|
||||||
|
from mpi4py import MPI
|
||||||
|
except ImportError:
|
||||||
|
MPI = None
|
||||||
|
|
||||||
|
|
||||||
def sync_from_root(sess, variables, comm=None):
|
def sync_from_root(sess, variables, comm=None):
|
||||||
"""
|
"""
|
||||||
Send the root node's parameters to every worker.
|
Send the root node's parameters to every worker.
|
||||||
|
@@ -1,10 +1,12 @@
|
|||||||
from baselines.common import mpi_util
|
|
||||||
from mpi4py import MPI
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
from baselines import logger
|
from baselines import logger
|
||||||
|
from baselines.common.tests.test_with_mpi import with_mpi
|
||||||
|
from baselines.common import mpi_util
|
||||||
|
|
||||||
def helper_for_mpi_weighted_mean():
|
@with_mpi()
|
||||||
|
def test_mpi_weighted_mean():
|
||||||
|
from mpi4py import MPI
|
||||||
comm = MPI.COMM_WORLD
|
comm = MPI.COMM_WORLD
|
||||||
if comm.rank == 0:
|
if comm.rank == 0:
|
||||||
name2valcount = {'a' : (10, 2), 'b' : (20,3)}
|
name2valcount = {'a' : (10, 2), 'b' : (20,3)}
|
||||||
@@ -24,8 +26,3 @@ def helper_for_mpi_weighted_mean():
|
|||||||
d2 = logger.dumpkvs(mpi_mean=True)
|
d2 = logger.dumpkvs(mpi_mean=True)
|
||||||
if comm.rank == 0:
|
if comm.rank == 0:
|
||||||
assert d2 == correctval
|
assert d2 == correctval
|
||||||
|
|
||||||
|
|
||||||
def test_mpi_weighted_mean():
|
|
||||||
subprocess.check_call(['mpirun', '-n', '2', sys.executable, '-c',
|
|
||||||
'from baselines.common import test_mpi_util; test_mpi_util.helper_for_mpi_weighted_mean()'])
|
|
||||||
|
@@ -10,7 +10,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
MPI = None
|
MPI = None
|
||||||
|
|
||||||
def test_with_mpi(nproc=2, timeout=30, skip_if_no_mpi=True):
|
def with_mpi(nproc=2, timeout=30, skip_if_no_mpi=True):
|
||||||
def outer_thunk(fn):
|
def outer_thunk(fn):
|
||||||
def thunk(*args, **kwargs):
|
def thunk(*args, **kwargs):
|
||||||
serialized_fn = base64.b64encode(cloudpickle.dumps(lambda: fn(*args, **kwargs)))
|
serialized_fn = base64.b64encode(cloudpickle.dumps(lambda: fn(*args, **kwargs)))
|
||||||
|
@@ -46,8 +46,8 @@ class DummyVecEnv(VecEnv):
|
|||||||
def step_wait(self):
|
def step_wait(self):
|
||||||
for e in range(self.num_envs):
|
for e in range(self.num_envs):
|
||||||
action = self.actions[e]
|
action = self.actions[e]
|
||||||
if isinstance(self.envs[e].action_space, spaces.Discrete):
|
# if isinstance(self.envs[e].action_space, spaces.Discrete):
|
||||||
action = int(action)
|
# action = int(action)
|
||||||
|
|
||||||
obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(action)
|
obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(action)
|
||||||
if self.buf_dones[e]:
|
if self.buf_dones[e]:
|
||||||
|
@@ -8,7 +8,7 @@ import pytest
|
|||||||
from .dummy_vec_env import DummyVecEnv
|
from .dummy_vec_env import DummyVecEnv
|
||||||
from .shmem_vec_env import ShmemVecEnv
|
from .shmem_vec_env import ShmemVecEnv
|
||||||
from .subproc_vec_env import SubprocVecEnv
|
from .subproc_vec_env import SubprocVecEnv
|
||||||
from baselines.common.tests.test_with_mpi import test_with_mpi
|
from baselines.common.tests.test_with_mpi import with_mpi
|
||||||
|
|
||||||
|
|
||||||
def assert_envs_equal(env1, env2, num_steps):
|
def assert_envs_equal(env1, env2, num_steps):
|
||||||
@@ -103,7 +103,7 @@ class SimpleEnv(gym.Env):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
@test_with_mpi()
|
@with_mpi()
|
||||||
def test_mpi_with_subprocvecenv():
|
def test_mpi_with_subprocvecenv():
|
||||||
shape = (2,3,4)
|
shape = (2,3,4)
|
||||||
nenv = 1
|
nenv = 1
|
||||||
|
Reference in New Issue
Block a user