Compare commits
1 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
0e423a0108 |
@@ -12,8 +12,9 @@ def mpi_mean(x, axis=0, comm=None, keepdims=False):
|
|||||||
localsum = np.zeros(n+1, x.dtype)
|
localsum = np.zeros(n+1, x.dtype)
|
||||||
localsum[:n] = xsum.ravel()
|
localsum[:n] = xsum.ravel()
|
||||||
localsum[n] = x.shape[axis]
|
localsum[n] = x.shape[axis]
|
||||||
globalsum = np.zeros_like(localsum)
|
# globalsum = np.zeros_like(localsum)
|
||||||
comm.Allreduce(localsum, globalsum, op=MPI.SUM)
|
# comm.Allreduce(localsum, globalsum, op=MPI.SUM)
|
||||||
|
globalsum = comm.allreduce(localsum, op=MPI.SUM)
|
||||||
return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n]
|
return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n]
|
||||||
|
|
||||||
def mpi_moments(x, axis=0, comm=None, keepdims=False):
|
def mpi_moments(x, axis=0, comm=None, keepdims=False):
|
||||||
|
Reference in New Issue
Block a user