Compare commits

...

1 Commits

View File

@@ -12,8 +12,9 @@ def mpi_mean(x, axis=0, comm=None, keepdims=False):
localsum = np.zeros(n+1, x.dtype) localsum = np.zeros(n+1, x.dtype)
localsum[:n] = xsum.ravel() localsum[:n] = xsum.ravel()
localsum[n] = x.shape[axis] localsum[n] = x.shape[axis]
globalsum = np.zeros_like(localsum) # globalsum = np.zeros_like(localsum)
comm.Allreduce(localsum, globalsum, op=MPI.SUM) # comm.Allreduce(localsum, globalsum, op=MPI.SUM)
globalsum = comm.allreduce(localsum, op=MPI.SUM)
return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n] return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n]
def mpi_moments(x, axis=0, comm=None, keepdims=False): def mpi_moments(x, axis=0, comm=None, keepdims=False):