use allreduce instead of Allreduce (send pickled data instead of floats) - probably affects performance somewhat, but avoid element number mismatch. Fixes 998 (#1000)
This commit is contained in:
@@ -12,8 +12,9 @@ def mpi_mean(x, axis=0, comm=None, keepdims=False):
|
|||||||
localsum = np.zeros(n+1, x.dtype)
|
localsum = np.zeros(n+1, x.dtype)
|
||||||
localsum[:n] = xsum.ravel()
|
localsum[:n] = xsum.ravel()
|
||||||
localsum[n] = x.shape[axis]
|
localsum[n] = x.shape[axis]
|
||||||
globalsum = np.zeros_like(localsum)
|
# globalsum = np.zeros_like(localsum)
|
||||||
comm.Allreduce(localsum, globalsum, op=MPI.SUM)
|
# comm.Allreduce(localsum, globalsum, op=MPI.SUM)
|
||||||
|
globalsum = comm.allreduce(localsum, op=MPI.SUM)
|
||||||
return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n]
|
return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n]
|
||||||
|
|
||||||
def mpi_moments(x, axis=0, comm=None, keepdims=False):
|
def mpi_moments(x, axis=0, comm=None, keepdims=False):
|
||||||
|
Reference in New Issue
Block a user