Distributed Training with DistributedDataParallel

Distributed training explained using torch.distributed.

Categories: deep learning, pytorch, distributed systems
Published: April 14, 2025

# What is DistributedDataParallel (DDP)?

DistributedDataParallel (DDP) is a way to parallelize training across multiple GPUs or nodes, and it offers more flexibility and scalability than DataParallel (DP), an older approach to data parallelism. DP is trivially simple to enable (just one extra line of code) but is less performant. DDP improves on DP in a few ways:

| DataParallel | DistributedDataParallel |
| --- | --- |
| Simpler to use | More involved changes to use |
| More overhead; the model is replicated and destroyed at each forward pass | The model is replicated only once at the start |
| Only supports single-node parallelism | Supports single-node and multi-node parallelism |
| Slower; uses multithreading in a single process and runs into Global Interpreter Lock (GIL) contention | Faster (no GIL contention) because it uses multiprocessing |
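
To make the difference concrete, here is a rough, illustrative comparison using a toy model (the DDP lines are commented out because they assume the process group and rank that are set up later in this post):
Code
import torch
import torch.nn as nn

model = nn.Linear(8, 2)  # toy model for illustration

# DataParallel: one extra line; a single process drives multi-threaded replicas
dp_model = nn.DataParallel(model)

# DistributedDataParallel: one process per GPU; `rank` comes from
# torch.multiprocessing.spawn and the process group must already be initialized
# model.to(rank)
# ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])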

# Multi-GPU Training with DDP

DDP uses multiprocessing to copy the model to each GPU (rank). This allows the model (and code) to be copied to each process only once, at the start of the script. Multiprocessing pickles Python objects to send them between processes, which means every object passed to a spawned process must be pickleable.

All objects sent to each process must be pickleable

What can be pickled and unpickled?

Some objects that can’t be pickled:

  • Generators
  • Database connections
  • Sockets
  • File descriptors
  • Lambdas
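
For example, a top-level function pickles fine while a lambda from the list above does not; this quick check can be run outside of DDP:
Code
import pickle

def square(x: int) -> int:
    return x * x

pickle.dumps(square)  # works: top-level functions are pickled by reference

try:
    pickle.dumps(lambda x: x * x)
except (pickle.PicklingError, AttributeError) as err:
    print(f"can't pickle a lambda: {err}")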

The basic outline of DDP training is:

  1. Setup the communications by setting the host and port
  2. Spawn a training process for each GPU with torch.multiprocessing.spawn
  3. Initialize the process group using init_process_group:
    • GPU - "nccl"
    • CPU - "gloo"
  4. Wrap the model with DistributedDataParallel
  5. Create a DistributedSampler and DataLoader for the dataset
  6. Train the model and update sampler with the epoch
  7. Destroy the process group using destroy_process_group
Code
import os

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.distributed import init_process_group, destroy_process_group

def main(
    rank: int, # rank is the GPU number
    world_size: int, # world_size is the number of processes, typically set to the number of GPUs
    train_path: str,
    random_state: int,
    lr: float,
    epochs: int,
    num_workers: int,
    batch_size: int,
) -> None:
    # ... Other setup

    # set the default device for this process before creating the process group
    # (prevents hangs and excessive memory use on GPU 0)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

    # initialize the process group for distributed training
    init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",  # CPU only works on gloo backend
        rank=rank,
        world_size=world_size,
    )

    # we need to divide the workers and batch across the different processes used in distributed training
    num_workers_per_proc = num_workers // world_size # avoids CPU contention
    batch_size_per_proc = batch_size // world_size   # avoids OOM

    # DistributedSampler ensures that training data is chunked across GPUs without overlapping samples
    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(
        val_dataset,
        shuffle=False,  # don't shuffle the validation dataset
        drop_last=True, # by default DistributedSampler pads with duplicated samples to divide evenly; drop the tail for the validation dataset instead
    )

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=False,  # don't shuffle if using DistributedSampler as that's done within the sampler
        sampler=train_sampler,
        num_workers=num_workers_per_proc,
        batch_size=batch_size_per_proc,
        pin_memory=True,
        collate_fn=collate_rowgroups,
    )
    val_dataloader = DataLoader(
        val_dataset,
        shuffle=False,
        sampler=val_sampler,
        num_workers=num_workers_per_proc,
        batch_size=batch_size_per_proc,
        pin_memory=True,
        collate_fn=collate_rowgroups,
    )

    # set up the NN model as normal and then wrap with DDP
    model.to(rank)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    for epch in range(0, epochs):
        # need to call `set_epoch()` at the beginning of each epoch before creating the
        # `DataLoader` iterator to make shuffling work properly across multiple epochs
        train_sampler.set_epoch(epch)

        # ... training loop

    # ... diagnostics

    # cleanly shut down the distributed process group
    destroy_process_group()

if __name__ == "__main__":

    os.environ["MASTER_ADDR"] = "localhost" # for single node on local compute
    os.environ["MASTER_PORT"] = "12345" # any free port

    world_size = torch.cuda.device_count()  # number of GPUs

    # ... CLI args parsing

    # spawn world_size processes; each process receives its rank as the first argument to main
    torch.multiprocessing.spawn(
        main,
        args=(
            world_size,
            args.train_path,
            args.random_state,
            args.lr,
            args.epochs,
            args.num_workers,
            args.batch_size,
        ),
        nprocs=world_size,  # number of processes to spawn; each gets a rank in [0, world_size) as its first argument
    )

1. Communication: host and port

Setting up distributed training on a single (local) node is as simple as setting the host and port as below. To set up multi-node training, see torchrun.

Code
import os
os.environ["MASTER_ADDR"] = "localhost" # for single node on local compute
os.environ["MASTER_PORT"] = "12345" # any free port

2. Spawn a process on each rank (GPU)

torch.multiprocessing follows the same API as the standard library multiprocessing module. To spawn new processes we pass the function to run, its arguments as a tuple, and the number of processes (usually the number of GPUs).

Code
import torch

# spawn world_size processes; each process receives its rank as the first argument to main
torch.multiprocessing.spawn(
    main,
    args=(
        world_size,
        args.train_path,
        args.random_state,
        args.lr,
        args.epochs,
        args.num_workers,
        args.batch_size,
    ),
    nprocs=world_size,  # number of processes to spawn; each gets a rank in [0, world_size) as its first argument
)

3. Constructing the process group

  • First, before initializing the process group, call set_device, which sets the default GPU for each process. This is important to prevent hangs or excessive memory utilization on GPU 0.
  • The process group can be initialized by TCP (default) or from a shared file-system. Read more on process group initialization.
  • init_process_group initializes the distributed process group.
  • Read more about choosing a DDP backend.
Code
from torch.distributed import init_process_group

# set the default GPU for this process before creating the process group
if torch.cuda.is_available():
    torch.cuda.set_device(rank)

# initialize the process group for distributed training
init_process_group(
    backend="nccl" if torch.cuda.is_available() else "gloo",  # CPU only works on gloo backend
    rank=rank,
    world_size=world_size,
)
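
If every rank can see the same file system, the group can instead be initialized through a file-based init_method rather than TCP. A brief sketch, assuming the same rank and world_size as above and an illustrative placeholder path:
Code
from torch.distributed import init_process_group

# shared-filesystem initialization (alternative to MASTER_ADDR/MASTER_PORT);
# every rank must be able to read and write the same path
init_process_group(
    backend="nccl" if torch.cuda.is_available() else "gloo",
    init_method="file:///tmp/ddp_init_file",  # placeholder path on a shared filesystem
    rank=rank,
    world_size=world_size,
)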

4. Constructing the DDP model

  • device_ids - 1) For single-device modules, device_ids can contain exactly one device id, which represents the only CUDA device where the input module corresponding to this process resides. Alternatively, device_ids can also be None. 2) For multi-device modules and CPU modules, device_ids must be None. (From the DDP docs)
Code
model.to(rank)
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

5. Distributing the data with DistributedSampler

Dividing the workload

  • DistributedSampler chunks the input data across all distributed processes, without overlap. If we have 4 GPUs then each process will only load 1/4 of the training dataset.
  • The batch_size needs to be divided among the processes (GPUs). Each process receives an input batch of batch_size_per_proc, so the effective batch size is batch_size_per_proc * world_size. If batch_size is 64 and world_size is 4 GPUs, each process works on batches of 16 and the effective batch size is still 64 in total.
  • The num_workers also needs to be divided among the processes (GPUs). Each process will receive num_workers_per_proc.
batch_size and OOM

If the batch_size isn’t divided among the processes, then each process gets a full batch, the effective batch size becomes world_size times larger, and we are likely to run out of CPU or GPU memory if not careful.

Code
# we need to divide the workers and batch across the different processes used in distributed training
num_workers_per_proc = num_workers // world_size # avoids CPU contention
batch_size_per_proc = batch_size // world_size   # avoids OOM

Setting Up the DistributedSampler

  • shuffle - by default, the DistributedSampler will shuffle the dataset. We don’t want to shuffle the validation dataset.
  • drop_last - by default, the DistributedSampler appends duplicated samples so that the dataset divides evenly across processes (e.g. 10 samples across 4 processes would be padded to 12). We don’t want to repeat samples for the validation dataset as that would change the metrics, so we set drop_last=True to drop the uneven tail instead.
  • pin_memory - For large datasets that are loaded into the CPU in the Dataset, pinning the memory can speed up the host to device transfer (see this NVIDIA blog for more details).
Code
# DistributedSampler ensures that training data is chunked across GPUs without overlapping samples
train_sampler = DistributedSampler(train_dataset)
val_sampler = DistributedSampler(
    val_dataset,
    shuffle=False,  # don't shuffle the validation dataset
    drop_last=True, # by default DistributedSampler pads with duplicated samples to divide evenly; drop the tail for the validation dataset instead
)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=False,  # don't shuffle if using DistributedSampler as that's done within the sampler
    sampler=train_sampler,
    num_workers=num_workers_per_proc,
    batch_size=batch_size_per_proc,
    pin_memory=True,
    collate_fn=collate_rowgroups,
)
val_dataloader = DataLoader(
    val_dataset,
    shuffle=False,
    sampler=val_sampler,
    num_workers=num_workers_per_proc,
    batch_size=batch_size_per_proc,
    pin_memory=True,
    collate_fn=collate_rowgroups,
)

6. Shuffling across epochs

  • Calling the set_epoch() method on the DistributedSampler at the beginning of each epoch is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be used in each epoch.
Code
# ... Neural Network setup

for epch in range(0, epochs):
    # need to call `set_epoch()` at the beginning of each epoch before creating the
    # `DataLoader` iterator to make shuffling work properly across multiple epochs
    train_sampler.set_epoch(epch)

    # ... training loop

7. Running the distributed training job

  • rank is auto-allocated by DDP when calling torch.multiprocessing.spawn.
  • world_size is the number of processes across the training job. For GPU training, this corresponds to the number of GPUs in use, and each process works on a dedicated GPU.
  • Both rank and world_size are new parameters to main(). Because of how spawning processes works, rank needs to be the first parameter to the calling function, main(rank, ...).
PyTorch Multiprocessing

PyTorch’s torch.multiprocessing package is a wrapper around the native multiprocessing module and the API is 100% compatible.

Code
import torch

# ... CLI args parsing

world_size = torch.cuda.device_count()  # number of GPUs

torch.multiprocessing.spawn(
    main,
    args=(
        world_size,
        args.train_path,
        args.random_state,
        args.lr,
        args.epochs,
        args.num_workers,
        args.batch_size,
    ),
    nprocs=world_size,  # number of processes to spawn; each gets a rank in [0, world_size) as its first argument
)

MLflow logging

Since there are now multiple processes running the same code, the same logging will happen in each process. MLflow has no way to distinguish between different processes logging the same metric. We can guard against this by only logging on the main process (rank 0, i.e. GPU 0):

Code
import mlflow

if rank == 0:
    mlflow.log_metrics(
        {
            "train loss": train_loss,
            "train accuracy": train_accuracy,
            "val loss": val_loss,
            "val accuracy": val_accuracy,
        },
        step=epoch,
    )

# Gradients, Losses, and Metrics

Under the hood, DDP synchronizes and gathers the gradients across all processes. However, any other ad-hoc values calculated in your code, such as losses and metrics, are not.

Gradients are synchronized

Model gradients are synchronized across processes during the backward pass. This means that the model in each process is the same! See DDP: Internal Design.
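
As a quick sanity check, the parameters on every rank should be identical after an optimizer step. This is a sketch that assumes model, rank, and world_size from the training script above:
Code
import torch
import torch.distributed as dist

# sum all parameters into a single scalar tensor on this rank's device
param_sum = torch.zeros(1, device=rank)
for p in model.parameters():
    param_sum += p.detach().float().sum()

# collect the per-rank sums and check that they all agree
sums = [torch.zeros_like(param_sum) for _ in range(world_size)]
dist.all_gather(sums, param_sum)
if rank == 0:
    assert all(torch.allclose(sums[0], s) for s in sums), "parameters differ across ranks"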

Losses are not synchronized

Even though the model is the same in each process, the loss is calculated only on the portion of the batch that each process sees. The losses don’t need to be synchronized for training, but we may want to synchronize them for logging, and definitely when calculating metrics on the hold-out (validation) dataset.

We could avoid this by not using a DistributedSampler for the validation set, but then only 1 process would be used to calculate the loss for the whole validation set each epoch, which will be slow.

So if each process is calculating and accumulating losses and metrics separately, how do we log and report them as if there were a single process? Those values need to be gathered and then accumulated. Say we have 4 processes, one for each GPU, and each is processing 1/4 of the training dataset. We want to report the loss for each epoch. If we log the loss in each process, we will have 4 different losses. We can gather and combine them in a few ways. Since the loss is just a numeric value, we can use torch.distributed.reduce or torch.distributed.all_reduce:
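
For instance, inside the training loop the per-process epoch loss could be averaged like this (a sketch; epoch_loss, rank, and world_size are assumed from the training script above):
Code
import torch
import torch.distributed as dist

# every rank contributes its local mean loss; sum them, then divide for the global mean
loss_tensor = torch.tensor([epoch_loss], device=rank)
dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
mean_loss = loss_tensor.item() / world_size

if rank == 0:
    print(f"epoch mean loss: {mean_loss:.4f}")

The standalone examples below show these collectives in isolation.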

Example: torch.distributed.reduce

In this example, we gather and combine, using summation with dist.ReduceOp.SUM, all the loss_tensors into the process 0 tensor (dst=0). The tensor must have the same shape in every process. Since we assign each process’s loss_tensor the value of its rank, we expect the final gathered values in the main process’s loss_tensor to be 0 + 1 + 2 + 3 = 6.

Code
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from torch.distributed import init_process_group

def reduce_tensor(rank: int, world_size: int) -> None:
    init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",  # CPU only works on gloo backend
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank) # set the default GPU for this process

    loss_tensor = torch.tensor([rank, rank]).cuda()
    print(loss_tensor)

    dist.reduce(loss_tensor, op=dist.ReduceOp.SUM, dst=0)  # blocking call; completes before the print below
    print(loss_tensor)

    torch.distributed.destroy_process_group()

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"

    num_gpu = torch.cuda.device_count()
    mp.spawn(reduce_tensor, nprocs=num_gpu, args=(num_gpu,))
tensor([1, 1], device='cuda:1')
tensor([3, 3], device='cuda:3')
tensor([2, 2], device='cuda:2')
tensor([0, 0], device='cuda:0')

tensor([6, 6], device='cuda:0')
tensor([1, 1], device='cuda:1')
tensor([2, 2], device='cuda:2')
tensor([3, 3], device='cuda:3')

Example: torch.distributed.all_reduce

In this example, we gather and combine, using summation with dist.ReduceOp.SUM, all the loss_tensors into every process. The tensor must have the same shape in every process. Since we assign each process’s loss_tensor the value of its rank, we expect the final gathered values to be 0 + 1 + 2 + 3 = 6 in every process’s loss_tensor.

Code
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from torch.distributed import init_process_group

def all_reduce_tensor(rank: int, world_size: int) -> None:
    init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",  # CPU only works on gloo backend
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank) # set the default GPU for this process

    loss_tensor = torch.tensor([rank, rank]).cuda()
    print(loss_tensor)

    dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)  # all_reduce has no dst; every rank receives the result
    print(loss_tensor)

    torch.distributed.destroy_process_group()

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"

    num_gpu = torch.cuda.device_count()
    mp.spawn(all_reduce_tensor, nprocs=num_gpu, args=(num_gpu,))
tensor([2, 2], device='cuda:2')
tensor([3, 3], device='cuda:3')
tensor([1, 1], device='cuda:1')
tensor([0, 0], device='cuda:0')

tensor([6, 6], device='cuda:2')
tensor([6, 6], device='cuda:3')
tensor([6, 6], device='cuda:1')
tensor([6, 6], device='cuda:0')

Example: torch.distributed.gather_object

Syncing across processes is simple enough for tensors, but if we have a number of values to gather (say a bunch of metrics) it would be easier to gather only once and store those values in an appropriate data structure. Most of the distributed gathering functions only work on tensors, but we can use torch.distributed.gather_object and/or torch.distributed.all_gather_object to pass pickleable Python objects between ranks.

In this example, we want to track the losses and number of samples in a Counter so that we can combine them and calculate the mean loss after gathering. We gather each loss counter to rank 0. Each counter is placed into the gather_list, which must have all elements set to None initially. When calling dist.gather_object, the gather_list should only be passed on the rank being gathered to (dst=0, i.e. rank 0 in this case). Then, on rank 0, we use functools.reduce to sum all the Counters gathered in the gather_list.

Code
import os
import functools
import operator
from collections import Counter
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from torch.distributed import init_process_group

def gather_object(rank: int, world_size: int) -> None:
    init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",  # CPU only works on gloo backend
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank) # set the default GPU for this process

    losses = Counter(
        {"loss": 0.01, "num_samples": rank + 1}  # rank + 1 so every process contributes a positive count
    )
    print(losses)

    gather_list = [None for _ in range(world_size)]

    dist.gather_object(losses, gather_list if rank == 0 else None, dst=0)
    if rank == 0:
        # only rank 0 receives the populated gather_list, so only rank 0 combines the Counters
        losses = functools.reduce(operator.add, gather_list)

    print(losses)
    torch.distributed.destroy_process_group()

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"

    num_gpu = torch.cuda.device_count()
    mp.spawn(gather_object, nprocs=num_gpu, args=(num_gpu,))
Counter({'num_samples': 1, 'loss': 0.01})
Counter({'num_samples': 2, 'loss': 0.01})
Counter({'num_samples': 4, 'loss': 0.01})
Counter({'num_samples': 3, 'loss': 0.01})

Counter({'num_samples': 10, 'loss': 0.04})
Counter({'num_samples': 4, 'loss': 0.01})
Counter({'num_samples': 3, 'loss': 0.01})
Counter({'num_samples': 2, 'loss': 0.01})

Example: torch.distributed.all_gather_object

In this example, we want to track the losses and number of samples in a Counter so that we can combine and calculate the mean loss after gathering. We gather each loss counter to each rank. Each counter is placed into the gather_list, which must have all elements set to None initially. Then we use functools.reduce to sum all the Counters gathered in the gather_list.

Code
import os
import functools
import operator
from collections import Counter
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from torch.distributed import init_process_group

def all_gather_object(rank: int, world_size: int) -> None:
    init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",  # CPU only works on gloo backend
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank) # set the default GPU for this process

    losses = Counter(
        {"loss": 0.01, "num_samples": rank + 1}  # rank + 1 so every process contributes a positive count
    )
    print(losses)

    gather_list = [None for _ in range(world_size)]

    dist.all_gather_object(gather_list, losses)
    losses = functools.reduce(operator.add, gather_list)

    print(losses)
    torch.distributed.destroy_process_group()

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"

    num_gpu = torch.cuda.device_count()
    mp.spawn(all_gather_object, nprocs=num_gpu, args=(num_gpu,))
Counter({'num_samples': 1, 'loss': 0.01})
Counter({'num_samples': 4, 'loss': 0.01})
Counter({'num_samples': 2, 'loss': 0.01})
Counter({'num_samples': 3, 'loss': 0.01})

Counter({'num_samples': 10, 'loss': 0.04})
Counter({'num_samples': 10, 'loss': 0.04})
Counter({'num_samples': 10, 'loss': 0.04})
Counter({'num_samples': 10, 'loss': 0.04})

# Multi-Node Training

Now that we’ve set up DDP, we have the option of running on multiple nodes. There are a few extra bits that need to change before we run multi-node training:

  • change a few lines to enable torchrun (a rough sketch follows this list)
  • install the project as a python package for torchrun to work properly
  • add fault tolerance
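
As a sketch of the torchrun changes (assumptions: torchrun’s default env:// initialization and a train.py entry point; the exact flags depend on the cluster), the script drops mp.spawn and reads its rank from environment variables that torchrun sets. It would be launched on every node with something like torchrun --nnodes=2 --nproc_per_node=4 --rdzv_backend=c10d --rdzv_endpoint=<host>:29500 train.py (flags illustrative).
Code
# train.py - torchrun-style entry point (sketch)
import os
import torch
from torch.distributed import init_process_group, destroy_process_group

def main() -> None:
    # torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for each process it launches
    local_rank = int(os.environ["LOCAL_RANK"])
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    # no rank/world_size arguments needed: the default env:// init reads them from the environment
    init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")

    # ... build the model, wrap it with DistributedDataParallel(device_ids=[local_rank]), train ...

    destroy_process_group()

if __name__ == "__main__":
    main()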

# Resources

  1. DistributedDataParallel - Documentation for DDP.
  2. Getting Started with Distributed Data Parallel - Good starting point to understand DDP and writing a training script using DDP for single-node and multi-node.
  3. Writing Distributed Applications with PyTorch - In-depth article about writing distributed applications in PyTorch and how communication works under the hood.
  4. DistributedSampler - Used in conjunction with a DataLoader, the DistributedSampler enables each process to only load the data it processes, rather than all the data to be processed.
  5. DistributedDataParallel training in PyTorch - Explanation of how DDP works and how to use it.
  6. (AML) Distributed GPU training guide (SDK v2)
  7. PyTorch DistributedDataParallel Example In Azure ML - Multi-Node Multi-GPU Distributed Training - Example of DDP for AML.
  8. PyTorch Distributed: Experiences on Accelerating Data Parallel Training - Paper on how Facebook designed DDP to be faster for distributed training.