From 0e423a01080358376d5b862bcc68d97baced228d Mon Sep 17 00:00:00 2001 From: Peter Zhokhov Date: Fri, 6 Sep 2019 14:36:35 -0700 Subject: [PATCH] use allreduce instead of Allreduce (send pickled data instead of floats) - probably affects performance somewhat, but avoid element number mismatch. Fixes 998 --- baselines/common/mpi_moments.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/baselines/common/mpi_moments.py b/baselines/common/mpi_moments.py index 7a97a43..5510de0 100644 --- a/baselines/common/mpi_moments.py +++ b/baselines/common/mpi_moments.py @@ -12,8 +12,9 @@ def mpi_mean(x, axis=0, comm=None, keepdims=False): localsum = np.zeros(n+1, x.dtype) localsum[:n] = xsum.ravel() localsum[n] = x.shape[axis] - globalsum = np.zeros_like(localsum) - comm.Allreduce(localsum, globalsum, op=MPI.SUM) + # globalsum = np.zeros_like(localsum) + # comm.Allreduce(localsum, globalsum, op=MPI.SUM) + globalsum = comm.allreduce(localsum, op=MPI.SUM) return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n] def mpi_moments(x, axis=0, comm=None, keepdims=False):