Files
baselines/baselines/common/tests/test_with_mpi.py
pzhokhov 8e0282ee94 ci/runtests.sh - pass all folders to pytest (#342)
* ci/runtests.sh - pass all folders to pytest

* mpi_optimizer_test precision 1e-4

* fixes to tests

* search for tests in the entire jax folder, also remove unnecessary humor
2019-05-03 15:54:25 -07:00

39 lines
997 B
Python

import os
import sys
import subprocess
import cloudpickle
import base64
import pytest
from functools import wraps
try:
from mpi4py import MPI
except ImportError:
MPI = None
def with_mpi(nproc=2, timeout=30, skip_if_no_mpi=True):
    """Build a decorator that runs the wrapped function under ``mpiexec``.

    Calling the decorated function does not execute it in the current
    process. The call (function plus its arguments) is serialized with
    cloudpickle, base64-encoded and passed as a command-line argument to
    ``mpiexec``, which launches ``nproc`` interpreters running the
    ``baselines.common.tests.test_with_mpi`` module as a script.

    Args:
        nproc: Process count given to ``mpiexec -n``.
        timeout: Seconds forwarded to ``subprocess.check_call`` as its
            ``timeout`` argument.
        skip_if_no_mpi: When true, the wrapper is additionally marked with
            ``pytest.mark.skipif`` so it is skipped if ``mpi4py`` could not
            be imported (module-level ``MPI`` is ``None``).

    Returns:
        A decorator. The wrapper it produces returns ``None``; the wrapped
        function's return value is not propagated back from the
        subprocesses.

    Raises:
        subprocess.CalledProcessError: From the wrapper, when ``mpiexec``
            exits with a non-zero status.
        subprocess.TimeoutExpired: From the wrapper, when the run exceeds
            ``timeout`` seconds.
    """
    def outer_thunk(fn):
        @wraps(fn)
        def thunk(*args, **kwargs):
            # Bundle the function and its arguments into one zero-argument
            # callable so the child only has to unpickle and call it.
            payload = cloudpickle.dumps(lambda: fn(*args, **kwargs))
            # b64encode returns bytes; decode so that every argv entry is a
            # str. Mixing bytes and str in the argument list is not portable
            # across platforms / Python versions.
            serialized_fn = base64.b64encode(payload).decode('ascii')
            subprocess.check_call([
                'mpiexec', '-n', str(nproc),
                sys.executable,
                '-m', 'baselines.common.tests.test_with_mpi',
                serialized_fn,
            ], env=os.environ, timeout=timeout)

        if skip_if_no_mpi:
            return pytest.mark.skipif(MPI is None, reason="MPI not present")(thunk)
        return thunk

    return outer_thunk
if __name__ == '__main__':
    # Child-process entry point: when this module is run as a script with a
    # base64-encoded cloudpickle payload as its first argument, decode the
    # payload and invoke it. With no argument, do nothing.
    cli_args = sys.argv[1:]
    if cli_args:
        # NOTE(review): unpickling executes arbitrary code. This presumably
        # only ever receives payloads produced by with_mpi in the parent
        # test process -- do not feed it untrusted input.
        decoded = base64.b64decode(cli_args[0])
        remote_call = cloudpickle.loads(decoded)
        assert callable(remote_call)
        remote_call()