Introduction to Scatter Operation in MPI

Afzal Badshah, PhD
Jun 24, 2024

In MPI (Message Passing Interface) programming, scatter is a collective communication operation that distributes data from one process, the root, to every process in a communicator. The root holds an array or list divided into one chunk per process, and the operation delivers the i-th chunk to the process with rank i, the root included. Each process can then work on its own subset of the data in parallel.

Code:

from mpi4py import MPI  # Import MPI module from mpi4py library
comm = MPI.COMM_WORLD  # Use the default communicator containing all processes
size = comm.Get_size()  # Get the size of the communicator (total number of processes)
rank = comm.Get_rank()  # Get the rank of the current process
print("Total number of processes:", size)  # Print the total number of processes
print("Rank of current process:", rank)  # Print the rank of the current process
if rank == 0:  # Check if the current process is the root process (rank 0)
    data = [(i+1)**2 for i in range(size)]  # Generate one item per process on the root
    print("Data generated by root process:", data)  # Print the generated data by root process
else:  # For non-root processes
    data = None  # Set data to None
data = comm.scatter(data, root=0)  # Scatter data from the root to every process, root included
print("Process", rank, "received data:", data)  # Print the item received by the current process
assert data == (rank+1)**2  # Verify correctness of received data
print("Data verification successful for process", rank)  # Print success message if data verification passes
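Programs like this are usually launched through an MPI launcher, for example mpiexec -n 4 python scatter_demo.py (the file name is hypothetical). To preview which item each rank would receive without starting MPI at all, the following is a minimal pure-Python model of the scatter step. The function name scatter_model is an assumption made for illustration; it is not part of mpi4py and only mimics the "item i goes to rank i" rule.

```python
def scatter_model(data, size):
    """Pure-Python model of scatter (hypothetical helper, not an mpi4py API).

    Returns a dict mapping each rank to the item it would receive.
    """
    # The root is expected to supply exactly one item per process.
    if len(data) != size:
        raise ValueError("expected %d items, got %d" % (size, len(data)))
    return {rank: data[rank] for rank in range(size)}


size = 4  # pretend the program was started with four processes
data = [(i + 1) ** 2 for i in range(size)]  # same data the root generates above
received = scatter_model(data, size)
for rank, item in received.items():
    print("Process", rank, "would receive:", item)
```

In a real run the per-process print lines can appear in any order, because the processes execute independently.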

Explanation:

from mpi4py import MPI  # Import MPI module from mpi4py library

Import the MPI module from the mpi4py library, which provides MPI functionality for Python programs.

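comm = MPI.COMM_WORLD  # Use the default communicator containing all processes

Bind comm to MPI.COMM_WORLD, the predefined communicator that contains every process started with the program. mpi4py initializes MPI automatically when the MPI module is imported, so no explicit initialization call is needed here.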

size = comm.Get_size()  # Get the size of the communicator (total number of processes)

Retrieve the total number of processes in the communicator and store it in the variable size.

rank = comm.Get_rank()  # Get the rank of the current process

Retrieve the rank of the current process within the communicator and store it in the variable rank.

print("Total number of processes:", size)  # Print the total number of processes
print("Rank of current process:", rank) # Print the rank of the current process

Print the total number of processes and the rank of the current process.

if rank == 0:  # Check if the current process is the root process (rank 0)
    data = [(i+1)**2 for i in range(size)]  # Generate data to be scattered by root process
    print("Data generated by root process:", data)  # Print the generated data by root process
else:  # For non-root processes
    data = None  # Set data to None

If the current process is the root process (rank 0), generate data to be scattered. Otherwise, set data to None.

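data = comm.scatter(data, root=0)  # Scatter data from the root to every process, root included

Scatter the data from the root process (rank 0) to every process in the communicator, the root included. The list passed on the root must contain exactly one element per process, and the process with rank i receives element i as the return value. On non-root processes the first argument is ignored, which is why None is passed there.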
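Because comm.scatter expects one item per process, a dataset with more elements than processes has to be grouped into chunks on the root first. The sketch below shows one possible way to do that. The helper name split_into_chunks is an assumption made for illustration and is not provided by mpi4py; it produces contiguous chunks whose lengths differ by at most one.

```python
def split_into_chunks(items, n_chunks):
    """Split items into n_chunks contiguous lists (hypothetical helper).

    Chunk lengths differ by at most one; the first chunks take the extras.
    """
    base, extra = divmod(len(items), n_chunks)
    chunks = []
    start = 0
    for i in range(n_chunks):
        length = base + (1 if i < extra else 0)
        chunks.append(items[start:start + length])
        start += length
    return chunks


# Ten work items grouped for four processes.
chunks = split_into_chunks(list(range(10)), 4)
print(chunks)

# On the root, the result could then be passed to the scatter call, e.g.:
#     data = comm.scatter(chunks, root=0)
# Each process would receive one sub-list instead of a single number.
```

Contiguous chunks keep neighbouring elements on the same process; a round-robin split would be an equally valid choice when the order of elements does not matter.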

print("Process", rank, "received data:", data)  # Print the data received by the current process

Print the data received by the current process after the scatter operation.

assert data == (rank+1)**2  # Verify correctness of received data
print("Data verification successful for process", rank) # Print success message if data verification passes

Verify the received data by checking that the value equals (rank+1)**2, the square of the rank plus one, which is the element the root placed at index rank. If the assertion passes, print a success message; if it fails, the process raises an AssertionError.



Dr Afzal Badshah focuses on academic skills, pedagogy (teaching skills) and life skills.