Files
baselines/baselines/common/mpi_util.py
Peter Zhokhov 217b111c88 merged refactor
2018-08-10 14:14:46 -07:00

102 lines
3.0 KiB
Python

from collections import defaultdict
from mpi4py import MPI
import os, numpy as np
import platform
import shutil
import subprocess
def sync_from_root(sess, variables, comm=None):
    """
    Send the root node's parameters to every worker.

    Every rank in ``comm`` must call this with the same ``variables`` in the
    same order, because each variable is transferred with a collective Bcast.

    Arguments:
        sess: the TensorFlow session.
        variables: all parameter variables including optimizer's
        comm: MPI communicator to use; defaults to MPI.COMM_WORLD.
    """
    if comm is None:
        comm = MPI.COMM_WORLD
    # Materialize once: the non-root path iterates the variables twice.
    variables = list(variables)
    rank = comm.Get_rank()
    if rank == 0:
        for var in variables:
            # Bcast sends a raw buffer, so it must be contiguous.
            comm.Bcast(np.ascontiguousarray(sess.run(var)))
        return

    import tensorflow as tf
    received = []
    for var in variables:
        # Allocate the receive buffer with the variable's own dtype.
        # A hard-coded float32 buffer would not match the root's buffer
        # for float64 or integer variables.
        buf = np.empty(var.shape.as_list(),
                       dtype=var.dtype.base_dtype.as_numpy_dtype)
        comm.Bcast(buf)
        received.append(buf)
    if received:
        # Apply all received values in a single session call.
        sess.run([tf.assign(var, buf) for var, buf in zip(variables, received)])
def gpu_count():
    """
    Count the GPUs on this machine.

    Returns:
        The number of GPUs listed by ``nvidia-smi``, or 0 when ``nvidia-smi``
        is not on PATH or cannot be run successfully.
    """
    if shutil.which('nvidia-smi') is None:
        return 0
    try:
        output = subprocess.check_output(
            ['nvidia-smi', '--query-gpu=gpu_name', '--format=csv'])
    except (subprocess.CalledProcessError, OSError):
        # The tool exists but failed (non-zero exit or not executable).
        # Report no usable GPUs rather than crashing the caller.
        return 0
    # CSV output is one header row followed by one row per GPU.
    # Ignoring blank lines keeps the count right regardless of trailing
    # newlines or \r\n line endings.
    rows = [line for line in output.splitlines() if line.strip()]
    return max(0, len(rows) - 1)
def setup_mpi_gpus():
    """
    Set CUDA_VISIBLE_DEVICES using MPI.

    Each process is given one GPU index: its rank among the processes on the
    same machine, modulo the number of GPUs found there. When no GPUs are
    found, CUDA_VISIBLE_DEVICES is left untouched.
    """
    # get_local_rank_size performs an allgather on COMM_WORLD, so every rank
    # must call it before any early return. Otherwise ranks on GPU-less
    # machines would skip the collective and leave the other ranks blocked.
    local_rank, _ = get_local_rank_size(MPI.COMM_WORLD)
    num_gpus = gpu_count()
    if num_gpus == 0:
        return
    os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank % num_gpus)
def get_local_rank_size(comm):
    """
    Return (local_rank, local_size) for the calling process.

    local_rank is this process's position among the processes that share
    its host name (as reported by platform.node()): 0, 1, ..., N-1, in the
    order returned by comm.allgather. local_size is N, the number of
    processes on this host. Handy for mapping one GPU to each local process.
    Every rank in ``comm`` must call this, since it performs an allgather.
    """
    my_rank = comm.Get_rank()
    my_node = platform.node()
    everyone = comm.allgather((my_rank, my_node))

    seen_per_node = {}
    local_rank = None
    for other_rank, other_node in everyone:
        already_seen = seen_per_node.get(other_node, 0)
        if other_rank == my_rank:
            local_rank = already_seen
        seen_per_node[other_node] = already_seen + 1

    assert local_rank is not None
    return local_rank, seen_per_node.get(my_node, 0)
def share_file(comm, path):
    """
    Copy the file at ``path`` from rank 0 to every other machine.

    Rank 0 reads the file and broadcasts its bytes. On every other machine,
    the process with local rank 0 writes them to the same ``path``, creating
    parent directories if needed. All ranks synchronize on a barrier before
    returning. Every rank in ``comm`` must call this.
    """
    localrank, _ = get_local_rank_size(comm)
    if comm.Get_rank() == 0:
        with open(path, 'rb') as fh:
            data = fh.read()
        comm.bcast(data)
        # Rank 0 already has the file; rewriting it would be pointless.
    else:
        data = comm.bcast(None)
        if localrank == 0:
            dirname = os.path.dirname(path)
            # A bare filename has an empty dirname; os.makedirs('') raises.
            if dirname:
                os.makedirs(dirname, exist_ok=True)
            with open(path, 'wb') as fh:
                fh.write(data)
    comm.Barrier()
def dict_gather(comm, d, op='mean', assert_all_have_data=True):
    """
    Combine per-worker dicts of numeric values across an MPI communicator.

    Arguments:
        comm: MPI communicator, or None to return ``d`` unchanged.
        d: this worker's dict. Values for the same key are reduced with
           np.mean / np.sum along axis 0, so they must be scalars or
           equally shaped arrays.
        op: 'mean' or 'sum' -- how the values for each key are reduced.
        assert_all_have_data: if True, require every worker to send every key.

    Returns:
        dict mapping each key seen on any worker to its reduced value.

    Raises:
        ValueError: if ``op`` is not 'mean' or 'sum'.
        AssertionError: if ``assert_all_have_data`` is set and some key was
            not sent by every worker.
    """
    if comm is None:
        return d
    reducers = {'mean': np.mean, 'sum': np.sum}
    # Validate before the collective so a bad op fails fast on every rank.
    # Explicit raises (unlike `assert`) also survive `python -O`.
    if op not in reducers:
        raise ValueError("unsupported op %r; expected 'mean' or 'sum'" % (op,))
    reduce_fn = reducers[op]

    size = comm.size
    key2values = defaultdict(list)
    for worker_dict in comm.allgather(d):
        for k, v in worker_dict.items():
            key2values[k].append(v)

    result = {}
    for k, values in key2values.items():
        if assert_all_have_data and len(values) != size:
            raise AssertionError(
                "only %i out of %i MPI workers have sent '%s'" % (len(values), size, k))
        result[k] = reduce_fn(values, axis=0)
    return result