2019-01-15 09:59:27 -08:00
|
|
|
import os
|
|
|
|
import sys
|
|
|
|
import subprocess
|
|
|
|
import cloudpickle
|
|
|
|
import base64
|
|
|
|
import pytest
|
2019-04-22 14:41:46 -07:00
|
|
|
from functools import wraps
|
2019-01-15 09:59:27 -08:00
|
|
|
|
2019-01-24 14:35:41 -08:00
|
|
|
try:
|
|
|
|
from mpi4py import MPI
|
|
|
|
except ImportError:
|
|
|
|
MPI = None
|
2019-01-15 09:59:27 -08:00
|
|
|
|
2019-01-24 15:46:58 -08:00
|
|
|
def with_mpi(nproc=2, timeout=30, skip_if_no_mpi=True):
    """Build a decorator that runs the decorated function under ``mpiexec``.

    Calling the decorated function does not execute it in the current
    process. The call (function plus its arguments) is pickled with
    cloudpickle, base64-encoded, and handed on the command line to
    ``mpiexec -n <nproc> <python> -m baselines.common.tests.test_with_mpi``,
    whose ``__main__`` section decodes and invokes it.

    Args:
        nproc: Number of processes passed to ``mpiexec -n``.
        timeout: Seconds to wait for the ``mpiexec`` command before
            ``subprocess.check_call`` gives up.
        skip_if_no_mpi: When true, the wrapper is additionally marked with
            ``pytest.mark.skipif`` so it is skipped if ``mpi4py`` could not
            be imported (module-level ``MPI`` is ``None``).

    Returns:
        A decorator. The wrapper it produces returns ``None``; the wrapped
        function's return value is not transported back from the
        subprocess.

    Raises:
        subprocess.CalledProcessError: (from the wrapper) if the launched
            command exits with a non-zero status.
        subprocess.TimeoutExpired: (from the wrapper) if the command does
            not finish within ``timeout`` seconds.
    """
    def outer_thunk(fn):
        @wraps(fn)
        def thunk(*args, **kwargs):
            # Bind the arguments now so the worker only has to call a
            # zero-argument callable.
            payload = cloudpickle.dumps(lambda: fn(*args, **kwargs))
            # b64encode returns bytes; decode so that every argv element is
            # a str. Base64 output is pure ASCII, and b64decode on the
            # receiving side accepts ASCII str as well as bytes.
            serialized_fn = base64.b64encode(payload).decode('ascii')
            subprocess.check_call([
                'mpiexec', '-n', str(nproc),
                sys.executable,
                '-m', 'baselines.common.tests.test_with_mpi',
                serialized_fn,
            ], env=os.environ, timeout=timeout)

        if skip_if_no_mpi:
            return pytest.mark.skipif(MPI is None, reason="MPI not present")(thunk)
        return thunk

    return outer_thunk
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Worker entry point: when this module is run with an argument, that
    # argument is treated as a base64-encoded cloudpickle payload holding a
    # zero-argument callable, which is decoded and invoked.
    if len(sys.argv) > 1:
        # NOTE(security): unpickling executes arbitrary code. This is only
        # appropriate because the payload is expected to come from the
        # with_mpi wrapper in this same module; never feed it untrusted
        # input.
        fn = cloudpickle.loads(base64.b64decode(sys.argv[1]))
        # Explicit check instead of `assert`, which is stripped when Python
        # runs with -O / PYTHONOPTIMIZE.
        if not callable(fn):
            raise TypeError(
                'decoded payload is not callable: {!r}'.format(type(fn)))
        fn()
|