Source code for taurex.mpi

"""
Module for wrapping MPI functions. 
Most functions will do nothing if mpi4py is not present.

"""
from functools import lru_cache
from functools import wraps
import numpy as np


[docs]def convert_op(operation): from mpi4py import MPI if operation.lower() == 'sum': return MPI.SUM else: raise NotImplementedError
[docs]@lru_cache(maxsize=10) def shared_comm(): """ Returns the process id within a node. Used for shared memory """ from mpi4py import MPI return MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)
[docs]@lru_cache(maxsize=10) def nprocs(): """Gets number of processes or returns 1 if mpi is not installed Returns ------- int: Rank of process or 1 if MPI is not installed """ try: from mpi4py import MPI except ImportError: return 1 comm = MPI.COMM_WORLD return comm.Get_size()
[docs]def allgather(value): try: from mpi4py import MPI except ImportError: return [value] comm = MPI.COMM_WORLD data = value data = comm.allgather(data) return data
[docs]def allreduce(value, op): try: from mpi4py import MPI except ImportError: return value comm = MPI.COMM_WORLD data = value data = comm.allreduce(value, op=convert_op(op)) return data
[docs]def broadcast(array, rank=0): import numpy as np try: from mpi4py import MPI except ImportError: return array comm = MPI.COMM_WORLD if isinstance(array, np.ndarray): data = None if get_rank() == rank: data = np.copy(array) else: data = np.zeros_like(array) comm.Bcast(data, root=rank) else: data = comm.bcast(array, root=rank) return data
[docs]@lru_cache(maxsize=10) def get_rank(comm=None): """Gets rank or returns 0 if mpi is not installed Parameters ---------- comm: int, optional MPI communicator, default is MPI_COMM_WORLD Returns ------- int: Rank of process in communitor or 0 if MPI is not installed """ try: from mpi4py import MPI except ImportError: return 0 comm = comm or MPI.COMM_WORLD rank = comm.Get_rank() return rank
[docs]def barrier(comm=None): """ Waits for all processes to finish. Does nothing if mpi4py not present Parameters ---------- comm: int, optional MPI communicator, default is MPI_COMM_WORLD """ try: from mpi4py import MPI except ImportError: return comm = comm or MPI.COMM_WORLD comm.Barrier()
[docs]def only_master_rank(f): """ A decorator to ensure only the master MPI rank can run it """ @wraps(f) def wrapper(*args, **kwargs): if get_rank() == 0: return f(*args, **kwargs) return wrapper
[docs]@lru_cache(maxsize=10) def shared_rank(): return shared_comm().Get_rank()
[docs]def allocate_as_shared(arr, logger=None, force_shared=False): """ Converts a numpy array into an MPI shared memory. This allow for things like opacities to be loaded only once per node when using MPI. Only activates if mpi4py installed and when enabled via the ``mpi_use_shared`` input:: [Global] mpi_use_shared = True or ``force_shared=True`` otherwise does nothing and returns the same array back Parameters ---------- arr: numpy array Array to convert logger: :class:`~taurex.log.logger.Logger` Logger object to print outputs force_shared: bool Force conversion to shared memory Returns ------- array: If enabled and MPI present, shared memory version of array otherwise the original array """ try: from mpi4py import MPI except ImportError: return arr from taurex.cache import GlobalCache if GlobalCache()['mpi_use_shared'] or force_shared: if logger is not None: logger.info('Moving to shared memory') comm = shared_comm() nbytes = arr.size*arr.itemsize window = MPI.Win.Allocate_shared(nbytes, arr.itemsize, comm=comm) buf, itemsize = window.Shared_query(0) if itemsize != arr.itemsize: raise Exception(f'Shared memory size {itemsize} != array itemsize {arr.itemsize}') shared_array = np.ndarray(buffer=buf, dtype=arr.dtype, shape=arr.shape) if shared_rank() == 0: shared_array[...] = arr[...] comm.Barrier() return shared_array else: return arr