# Distributed Training with DistributedDataParallel

Distributed training explained using torch distributed.

Published: April 14, 2025
## What is DistributedDataParallel (DDP)?
`DistributedDataParallel` is a way to parallelize training across multiple GPUs or nodes. It is an extension of `DataParallel` that provides more flexibility and scalability. `DataParallel` (DP) is an older approach to data parallelism. DP is trivially simple (just one extra line of code) but it is less performant. DDP improves upon the architecture in a few ways:

| `DataParallel` | `DistributedDataParallel` |
|---|---|
| Simpler to use | More involved changes to use |
| More overhead; the model is replicated and destroyed at each forward pass | The model is replicated only once at the start |
| Only supports single-node parallelism | Supports single-node and multi-node parallelism |
| Slower; uses multithreading on a single process and runs into Global Interpreter Lock (GIL) contention | Faster (no GIL contention) because it uses multiprocessing |
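As a rough sketch of the difference in code (assuming `model` is an existing `nn.Module` and, for the DDP case, that a process group has already been initialized and `rank` is this process's GPU index, as covered below):

```python
import torch.nn as nn

# DataParallel: one extra line, single node only, multithreaded (GIL-bound)
model_dp = nn.DataParallel(model)

# DistributedDataParallel: one process per GPU; requires the process-group
# setup shown in the rest of this post
model_ddp = nn.parallel.DistributedDataParallel(model.to(rank), device_ids=[rank])
```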
## Multi-GPU Training with DDP
DDP uses multiprocessing to copy the model to each GPU (rank). This allows the model (and code) to only be copied to each process once at the start of the script. Multiprocessing pickles Python objects to serialize across processes. This means all objects must be pickleable.
**All objects sent to each process must be [pickleable](https://docs.python.org/3/library/pickle.html#what-can-be-pickled-and-unpickled).** Some objects that can't be pickled: generators, database connections, sockets, file descriptors, and lambdas.
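For example, a named function can be pickled but a lambda cannot, and anything passed through `torch.multiprocessing.spawn`'s `args` has to survive this round trip. A minimal illustration:

```python
import pickle

pickle.dumps(len)  # fine: a named, importable callable is pickled by reference

try:
    pickle.dumps(lambda x: x)  # lambdas can't be pickled
except pickle.PicklingError as err:
    print(f"not pickleable: {err}")
```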
The basic outline of DDP training is:

1. Set up the communications by setting the host and port
2. Spawn a training process for each GPU with `torch.multiprocessing.spawn`
3. Initialize the process group using `init_process_group`:
   - GPU: `"nccl"`
   - CPU: `"gloo"`
4. Wrap the model with `DistributedDataParallel`
5. Create a `DistributedSampler` and `DataLoader` for the dataset
6. Train the model and update the sampler with the epoch
7. Destroy the process group using `destroy_process_group`
```python
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.distributed import init_process_group, destroy_process_group


def main(
    rank: int,        # rank is the GPU number
    world_size: int,  # number of processes, typically the number of GPUs
    train_path: str,
    random_state: int,
    lr: float,
    epochs: int,
    num_workers: int,
    batch_size: int,
) -> None:
    # ... other setup

    # initialize the process group for distributed training
    init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",  # CPU only works on the gloo backend
        rank=rank,
        world_size=world_size,
    )

    # divide the workers and batch across the different processes used in distributed training
    num_workers_per_proc = num_workers // world_size  # avoids CPU contention
    batch_size_per_proc = batch_size // world_size  # avoids OOM

    # DistributedSampler ensures that training data is chunked across GPUs without overlapping samples
    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(
        val_dataset,
        shuffle=False,  # don't shuffle the validation dataset
        drop_last=True,  # don't pad the validation dataset with repeated samples
    )
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=False,  # don't shuffle if using DistributedSampler; that's done within the sampler
        sampler=train_sampler,
        num_workers=num_workers_per_proc,
        batch_size=batch_size_per_proc,
        pin_memory=True,
        collate_fn=collate_rowgroups,
    )
    val_dataloader = DataLoader(
        val_dataset,
        shuffle=False,
        sampler=val_sampler,
        num_workers=num_workers_per_proc,
        batch_size=batch_size_per_proc,
        pin_memory=True,
        collate_fn=collate_rowgroups,
    )

    # set up the NN model as normal and then wrap with DDP
    model.to(rank)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    for epch in range(0, epochs):
        # call `set_epoch()` at the beginning of each epoch, before creating the
        # `DataLoader` iterator, to make shuffling work properly across epochs
        train_sampler.set_epoch(epch)
        # ... training loop
        # ... diagnostics

    # cleanly shut down distributed processes
    destroy_process_group()


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"  # for a single node on local compute
    os.environ["MASTER_PORT"] = "12345"  # any free port

    world_size = torch.cuda.device_count()  # number of GPUs
    # ... CLI args parsing

    # spawn world_size processes; the first argument passed to main will be the rank
    torch.multiprocessing.spawn(
        main,
        args=(
            world_size,
            args.train_path,
            args.random_state,
            args.lr,
            args.epochs,
            args.num_workers,
            args.batch_size,
        ),
        nprocs=world_size,  # how many processes to start; each one receives its rank as the first argument
    )
```
### 1. Communication: host and port
Setting up distributed training on a single (local) node is as simple as setting the host and port as below. To set up multi-node training, see [torchrun](https://pytorch.org/tutorials/intermediate/ddp_series_multinode.html).
```python
import os

os.environ["MASTER_ADDR"] = "localhost"  # for a single node on local compute
os.environ["MASTER_PORT"] = "12345"  # any free port
```
### 2. Spawn a process on each rank (GPU)
`torch.multiprocessing` follows the same API as the standard library's `multiprocessing`. To spawn the processes, we pass the function to run, its arguments as a tuple, and the number of processes to start (usually the number of GPUs).
```python
import torch

# spawn multiple processes equal to world_size; the first argument passed in will be the rank
torch.multiprocessing.spawn(
    main,
    args=(
        world_size,
        args.train_path,
        args.random_state,
        args.lr,
        args.epochs,
        args.num_workers,
        args.batch_size,
    ),
    nprocs=world_size,  # how many processes to start; each one receives its rank as the first argument
)
```
### 3. Constructing the process group
- Before initializing the process group, call [`set_device`](https://pytorch.org/docs/stable/generated/torch.cuda.set_device.html?highlight=set_device#torch.cuda.set_device), which sets the default GPU for each process. This is important to prevent hangs or excessive memory utilization on `GPU:0`.
- [`init_process_group`](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group) initializes the distributed process group. It can be initialized by TCP (the default) or from a shared file system. Read more on [process group initialization](https://pytorch.org/docs/stable/distributed.html#tcp-initialization) and [choosing a DDP backend](https://pytorch.org/docs/stable/distributed.html#which-backend-to-use).
```python
import torch
from torch.distributed import init_process_group

# set the default GPU for this process before creating the process group
if torch.cuda.is_available():
    torch.cuda.set_device(rank)

# initialize the process group for distributed training
init_process_group(
    backend="nccl" if torch.cuda.is_available() else "gloo",  # CPU only works on the gloo backend
    rank=rank,
    world_size=world_size,
)
```
### 4. Constructing the DDP model
- `device_ids` - "1) For single-device modules, `device_ids` can contain exactly one device id, which represents the only CUDA device where the input module corresponding to this process resides. Alternatively, `device_ids` can also be `None`. 2) For multi-device modules and CPU modules, `device_ids` must be `None`." (from the [DDP docs](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html))
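The model is moved to this process's device and then wrapped, as in the `main()` function above:

```python
model.to(rank)
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
```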
### 5. Distributing the data with `DistributedSampler`

#### Dividing the workload

- [`DistributedSampler`](https://pytorch.org/docs/stable/data.html?highlight=distributedsampler#torch.utils.data.distributed.DistributedSampler) chunks the input data across all distributed processes, without overlap. If we have 4 GPUs, each process will only load 1/4 of the training dataset.
- The `batch_size` needs to be divided among the processes (GPUs). Each process receives an input batch of `batch_size_per_proc`, so the effective batch size is `batch_size_per_proc * world_size`; if `batch_size` is 64 and `world_size` is 4 GPUs, the effective batch size is still 64 in total.
- The `num_workers` also needs to be divided among the processes (GPUs). Each process receives `num_workers_per_proc`.
**`batch_size` and OOM:** if `batch_size` isn't divided among the processes, each process gets a full batch, the effective batch size becomes `world_size` times larger, and we are likely to run out of CPU or GPU memory if not careful.
```python
# divide the workers and batch across the different processes used in distributed training
num_workers_per_proc = num_workers // world_size  # avoids CPU contention
batch_size_per_proc = batch_size // world_size  # avoids OOM
```
#### Setting up the DistributedSampler
- `shuffle` - by default, the `DistributedSampler` shuffles the dataset. We don't want to shuffle the validation dataset.
- `drop_last` - by default, the `DistributedSampler` appends repeated samples so the data divides evenly across processes (e.g. with 100 training samples and `batch_size=64`, the last batch would only have 36 samples). We don't want to repeat samples for the validation dataset as that would change the metrics.
- `pin_memory` - for large datasets that are loaded onto the CPU in the `Dataset`, pinning the memory can speed up the host-to-device transfer (see this [NVIDIA blog](https://developer.nvidia.com/blog/how-optimize-data-transfers-cuda-cc/#pinned_host_memory) for more details).
```python
# DistributedSampler ensures that training data is chunked across GPUs without overlapping samples
train_sampler = DistributedSampler(train_dataset)
val_sampler = DistributedSampler(
    val_dataset,
    shuffle=False,  # don't shuffle the validation dataset
    drop_last=True,  # DistributedSampler would otherwise pad an incomplete batch with repeated samples
)
train_dataloader = DataLoader(
    train_dataset,
    shuffle=False,  # don't shuffle if using DistributedSampler; that's done within the sampler
    sampler=train_sampler,
    num_workers=num_workers_per_proc,
    batch_size=batch_size_per_proc,
    pin_memory=True,
    collate_fn=collate_rowgroups,
)
val_dataloader = DataLoader(
    val_dataset,
    shuffle=False,
    sampler=val_sampler,
    num_workers=num_workers_per_proc,
    batch_size=batch_size_per_proc,
    pin_memory=True,
    collate_fn=collate_rowgroups,
)
```
### 6. Shuffling across epochs
Calling the set_epoch() method on the DistributedSampler at the beginning of each epoch is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be used in each epoch.
```python
# ... Neural Network setup

for epch in range(0, epochs):
    # call `set_epoch()` at the beginning of each epoch, before creating the
    # `DataLoader` iterator, to make shuffling work properly across epochs
    train_sampler.set_epoch(epch)
    # ... training loop
```
### 7. Running the distributed training job

- `rank` is auto-allocated by DDP when calling [`torch.multiprocessing.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#spawning-subprocesses).
- `world_size` is the number of processes across the training job. For GPU training, this corresponds to the number of GPUs in use, and each process works on a dedicated GPU.
- Both `rank` and `world_size` are new parameters to `main()`. Because of how spawning processes works, `rank` needs to be the first parameter of the spawned function, `main(rank, ...)`.

**PyTorch multiprocessing:** [PyTorch's `torch.multiprocessing` package](https://pytorch.org/docs/stable/multiprocessing.html) is a wrapper around the native `multiprocessing` module and the API is 100% compatible.
```python
import torch

# ... CLI args parsing

world_size = torch.cuda.device_count()  # number of GPUs

torch.multiprocessing.spawn(
    main,
    args=(
        world_size,
        args.train_path,
        args.random_state,
        args.lr,
        args.epochs,
        args.num_workers,
        args.batch_size,
    ),
    nprocs=world_size,  # how many processes to start; each one receives its rank as the first argument
)
```
### MLFlow logging
Since there are now multiple processes running the same code, the same logging will happen in each process. MLFlow doesn't know how to distinguish that different processes are logging the same metric. We can guard against this by only logging on the main process (GPU 0):
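```python
import mlflow

if rank == 0:
    mlflow.log_metrics(
        {
            "train loss": train_loss,
            "train accuracy": train_accuracy,
            "val loss": val_loss,
            "val accuracy": val_accuracy,
        },
        step=epoch,
    )
```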
## Gradients, Losses, and Metrics

Under the hood, DDP synchronizes and gathers the gradients across all processes. However, any other ad hoc value calculated in your code is not synchronized, e.g. losses and metrics.
**Gradients are synchronized.** Model gradients are synchronized across processes during the backward pass. This means that the model in each process is the same! See [DDP: Internal Design](https://pytorch.org/docs/master/notes/ddp.html#internal-design).
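A minimal sketch of checking that claim, assuming an initialized process group and a DDP-wrapped `model`: all-reduce a checksum of the parameters and compare it to the local value.

```python
import torch
import torch.distributed as dist

# sum of all parameter values on this rank
local_sum = sum(p.detach().float().sum() for p in model.parameters())

# after all_reduce, every rank holds the total across all ranks
global_sum = local_sum.clone()
dist.all_reduce(global_sum, op=dist.ReduceOp.SUM)

# if the replicas are identical, the total is world_size copies of the local sum
assert torch.allclose(global_sum, local_sum * dist.get_world_size())
```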
**Losses are not synchronized.** Even though the model is the same in each process, the loss is calculated only on the portion of the batch that each process sees. The losses don't need to be synchronized for training, but we may want to synchronize them for logging, and we definitely need to when calculating metrics on the hold-out (validation) dataset. We could avoid this by not using a `DistributedSampler` for the validation set, but then only one process would calculate the loss for the whole validation set each epoch, which would be slow.
So if each process is calculating and accumulating losses and metrics separately, how do we log and report them as if there were a single process? Those values need to be gathered and then accumulated. Say we have 4 processes, one per GPU, each processing 1/4 of the training dataset, and we want to report the loss for each epoch. If we log the loss in each process, we get 4 different losses. We can gather and combine them in a few ways. Since the loss is just a numeric value, we can use [`torch.distributed.reduce`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.reduce) or [`torch.distributed.all_reduce`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_reduce):
### Example: `torch.distributed.reduce`
In this example, we gather all the `loss_tensor`s into the process 0 tensor (`dst=0`) and combine them by summation with `dist.ReduceOp.SUM`. Each tensor must have the same shape in every process. Since we assign each process's `loss_tensor` values to its rank, we expect the final gathered values in the main process's `loss_tensor` to be `0 + 1 + 2 + 3 = 6`.
```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed import init_process_group


def reduce_tensor(rank: int, world_size: int) -> None:
    init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",  # CPU only works on the gloo backend
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)  # tell each process which device (GPU) it owns

    loss_tensor = torch.tensor([rank, rank]).cuda()
    print(loss_tensor)

    # synchronous reduction: rank 0 ends up with the element-wise sum of every rank's tensor
    dist.reduce(loss_tensor, dst=0, op=dist.ReduceOp.SUM)
    print(loss_tensor)

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    num_gpu = torch.cuda.device_count()
    mp.spawn(reduce_tensor, nprocs=num_gpu, args=(num_gpu,))
```
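Example output on 4 GPUs: every rank prints its tensor before and after the reduce, and only rank 0 ends up with the summed values.

```text
tensor([1, 1], device='cuda:1')
tensor([3, 3], device='cuda:3')
tensor([2, 2], device='cuda:2')
tensor([0, 0], device='cuda:0')
tensor([6, 6], device='cuda:0')
tensor([1, 1], device='cuda:1')
tensor([2, 2], device='cuda:2')
tensor([3, 3], device='cuda:3')
```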
### Example: `torch.distributed.all_reduce`

In this example, we combine all the `loss_tensor`s by summation with `dist.ReduceOp.SUM`, and the result is written back to every process. Each tensor must have the same shape in every process. Since we assign each process's `loss_tensor` values to its rank, we expect every process's `loss_tensor` to end up as `0 + 1 + 2 + 3 = 6`.
```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed import init_process_group


def all_reduce_tensor(rank: int, world_size: int) -> None:
    init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",  # CPU only works on the gloo backend
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)  # tell each process which device (GPU) it owns

    loss_tensor = torch.tensor([rank, rank]).cuda()
    print(loss_tensor)

    # all_reduce takes no `dst`; every rank receives the element-wise sum
    dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
    print(loss_tensor)

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    num_gpu = torch.cuda.device_count()
    mp.spawn(all_reduce_tensor, nprocs=num_gpu, args=(num_gpu,))
```
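Example output on 4 GPUs: after `all_reduce`, every rank holds the summed tensor.

```text
tensor([2, 2], device='cuda:2')
tensor([3, 3], device='cuda:3')
tensor([1, 1], device='cuda:1')
tensor([0, 0], device='cuda:0')
tensor([6, 6], device='cuda:2')
tensor([6, 6], device='cuda:3')
tensor([6, 6], device='cuda:1')
tensor([6, 6], device='cuda:0')
```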
### Example: `torch.distributed.gather_object`

Syncing across processes is simple enough for tensors, but if we have a number of values to gather (say, a bunch of metrics), it would be easier to gather only once and store those values in an appropriate data structure. Most of the distributed gathering functions only work on tensors, but we can use [`torch.distributed.gather_object`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.gather_object) and/or [`torch.distributed.all_gather_object`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_object) to pass pickleable Python objects between ranks.

In this example, we track the loss and number of samples in a `Counter` so that we can combine them and calculate the mean loss after gathering. We gather each loss counter to rank 0. Each counter is placed into `gather_list`, which must have all elements set to `None` initially. When calling `dist.gather_object`, the `gather_list` must only be provided on the rank being gathered to (`dst=0`, i.e. rank 0 here). Then we use `functools.reduce` to sum all the Counters gathered in `gather_list`.
```python
import functools
import operator
import os
from collections import Counter

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed import init_process_group


def gather_object(rank: int, world_size: int) -> None:
    init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",  # CPU only works on the gloo backend
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)  # tell each process which device (GPU) it owns

    losses = Counter({"loss": 0.01, "num_samples": rank})
    print(losses)

    # gather_list is only provided (and populated) on the destination rank
    gather_list = [None for _ in range(world_size)]
    dist.gather_object(losses, gather_list if rank == 0 else None, dst=0)

    if rank == 0:
        losses = functools.reduce(operator.add, gather_list)
    print(losses)

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    num_gpu = torch.cuda.device_count()
    mp.spawn(gather_object, nprocs=num_gpu, args=(num_gpu,))
```
### Example: `torch.distributed.all_gather_object`

In this example, we again track the loss and number of samples in a `Counter`, but gather every rank's counter to every rank. Each counter is placed into `gather_list`, which must have all elements set to `None` initially. Then we use `functools.reduce` to sum all the Counters gathered in `gather_list`.
```python
import functools
import operator
import os
from collections import Counter

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed import init_process_group


def all_gather_object(rank: int, world_size: int) -> None:
    init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",  # CPU only works on the gloo backend
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)  # tell each process which device (GPU) it owns

    losses = Counter({"loss": 0.01, "num_samples": rank})
    print(losses)

    # every rank receives every other rank's counter
    gather_list = [None for _ in range(world_size)]
    dist.all_gather_object(gather_list, losses)

    losses = functools.reduce(operator.add, gather_list)
    print(losses)

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    num_gpu = torch.cuda.device_count()
    mp.spawn(all_gather_object, nprocs=num_gpu, args=(num_gpu,))
```
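From there, the mean loss per sample can be computed on any rank from the combined counter (a small sketch, assuming `"loss"` accumulates the summed loss):

```python
mean_loss = losses["loss"] / max(losses["num_samples"], 1)
```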
## Multi-Node Training

Now that we've set up DDP, we have the option of running on multiple nodes. There are a few extra bits that need to change before we run [multi-node training](https://pytorch.org/tutorials/intermediate/ddp_series_multinode.html):

- change a few lines to enable [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html)
- install the project as a Python package so `torchrun` works properly
- add [fault tolerance](https://pytorch.org/tutorials/beginner/ddp_series_fault_tolerance.html)
---title: "Distributed Training with DistributedDataParallel"date: "2025-04-14"categories: [deep learning, pytorch, distributed systems]description: "Distributed training explained using torch distributed."reading-time: truereference-location: documentcitation-location: document# bibliography: references.bibcitations-hover: trueformat: html: code-fold: show code-tools: true code-summary: ""---## # What is DistributedDataParallel (DDP)?`DistributedDataParallel` is a way to parallelize training across multiple GPUs or nodes. It is an extension of `DataParallel` that provides more flexibility and scalability. `DataParallel` (DP) is an older approach to data parallelism. DP is trivially simple (with just one extra line of code) but it is less performant. DDP improves upon the architecture in a few ways:|`DataParallel` | `DistributedDataParallel` ||--------------------------------------------------------------------------------------------------------|------------------------------------------------------------|| Simpler to use | More involved changes to use || More overhead; model is replicated and destroyed at each forward pass | Model is replicated only once at the start || Only supports single-node parallelism | Supports single-node and multi-node parallelism || Slower; uses multithreading on a single process and runs into Global Interpreter Lock (GIL) contention | Faster (no GIL contention) because it uses multiprocessing |## # Multi-GPU Training with DDPDDP uses multiprocessing to copy the model to each GPU (`rank`). This allows the model (and code) to only be copied to each process once at the start of the script. Multiprocessing pickles Python objects to serialize across processes. This means _all_ objects must be [pickleable](https://docs.python.org/3/library/pickle.html).::: {.callout-important title="All objects sent to each process must be pickleable"}[What can be pickled and unpickled?](https://docs.python.org/3/library/pickle.html#what-can-be-pickled-and-unpickled)**Some objects that can't be pickled:**- Generators- Database connections- Sockets- File descriptors- Lambdas:::The basic outline of DDP training is:1. Setup the communications by setting the host and port2. Spawn a training process for each GPU with `torch.multiprocessing.spawn`3. Initialize the process group using `init_process_group`: - GPU - `"nccl"` - CPU - `"gloo"`4. Wrap the model with `DistributedDataParallel`5. Create a `DistributedSampler` and `DataLoader` for the dataset6. Train the model and update sampler with the epoch7. Destroy the process group using `destroy_process_group````{python}# | label: multi-gpu# | code-fold: show# | eval: falsefrom torch.utils.data.distributed import DistributedSamplerfrom torch.distributed import init_process_group, destroy_process_groupdef main( rank: int, # rank is the GPU number world_size: int, # world_size is the number of processes, typically set to the number of GPUs train_path: str, random_state: int, lr: float, epochs: int, num_workers: int, batch_size: int,) ->None:# ... 
Other setup# initialize the process group for distributed training init_process_group( backend="nccl"if torch.cuda.is_available() else"gloo", # CPU only works on gloo backend rank=rank, world_size=world_size, )# we need to divide the workers and batch across the different processes used in distributed training num_workers_per_proc = num_workers // world_size # avoids CPU contentionn batch_size_per_proc = batch_size // world_size # avoids OOM# DistributedSampler ensures that training data is chunked across GPUs without overlapping samples train_sampler = DistributedSampler(train_dataset) val_sampler = DistributedSampler( val_dataset, shuffle=False, # don't shuffle the validation dataset drop_last=True, # DistributedSampler will append additional samples to fill an incomplete batch. We don't want that for the validation dataset. ) train_dataloader = DataLoader( train_dataset, shuffle=False, # don't shuffle if using DistributedSampler as that's done within the sampler sampler=train_sampler, num_workers=num_workers_per_proc, batch_size=batch_size_per_proc, pin_memory=True, collate_fn=collate_rowgroups, ) val_dataloader = DataLoader( val_dataset, shuffle=False, sampler=val_sampler, num_workers=num_workers_per_proc, batch_size=batch_size_per_proc, pin_memory=True, collate_fn=collate_rowgroups, )# set up the NN model as normal and then wrap with DDP model.to(rank) model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])for epch inrange(0, epochs):# need to call `set_epoch()` at the beginning of each epoch before creating the# `DataLoader` iterator to make shuffling work properly across multiple epochs train_sampler.set_epoch(epch)# ... training loop# ... diagnostics# cleanly shutdown distributed processes torch.distributed.destroy_process_group()if__name__=="__main__":import osimport torch os.environ["MASTER_ADDR"] ="localhost"# for single node on local compute os.environ["MASTER_PORT"] ="12345"# any free port world_size = torch.cuda.device_count() # number of GPUs# ... CLI args parsing# spawn multiple processes equal to world_size first argument passed in will be the rank torch.multiprocessing.spawn( main, args=( world_size, args.train_path, args.random_state, args.lr, args.epochs, args.num_workers, args.batch_size, ), nprocs=world_size, # this is used to set the `rank` parameter. It is passed as the first argument )```### 1. Communication: host and portSetting up distributed training on a single (local) node is as simple as setting the host and port as below. To setup multi-node see [torchrun](https://pytorch.org/tutorials/intermediate/ddp_series_multinode.html).```{python}# | label: host and port# | code-fold: showimport osos.environ["MASTER_ADDR"] ="localhost"# for single node on local computeos.environ["MASTER_PORT"] ="12345"# any free port```### 2. Spawn a process on each rank (GPU)Since `torch.multiprocessing` follows the same API as `multiprocessing`. To spawn a new process we pass the function to run, the arguments as a tuple, and specify the number of processes (usually the number of GPUs).```{python}# | label: spawn process on each rank# | code-fold: show# | eval: falseimport torch# spawn multiple processes equal to world_size first argument passed in will be the ranktorch.multiprocessing.spawn( main, args=( world_size, args.train_path, args.random_state, args.lr, args.epochs, args.num_workers, args.batch_size, ), nprocs=world_size, # this is used to set the `rank` parameter. It is passed as the first argument)```### 3. 
Constructing the process group- First, before initializing the group process, call [set_device](https://pytorch.org/docs/stable/generated/torch.cuda.set_device.html?highlight=set_device#torch.cuda.set_device), which sets the default GPU for each process. This is important to prevent hangs or excessive memory utilization on GPU:0- The process group can be initialized by TCP (default) or from a shared file-system. Read more on [process group initialization](https://pytorch.org/docs/stable/distributed.html#tcp-initialization).- [init_process_group](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group) initializes the distributed process group.- Read more about [choosing a DDP backend](https://pytorch.org/docs/stable/distributed.html#which-backend-to-use).```{python}# | label: initializing process group# | code-fold: show# | eval: falsefrom torch.distributed import init_process_group# initialize the process group for distributed traininginit_process_group( backend="nccl"if torch.cuda.is_available() else"gloo", # CPU only works on gloo backend rank=rank, world_size=world_size,)```### 5. Constructing the DDP model- `device_ids` - 1) For single-device modules, device_ids can contain exactly one device id, which represents the only CUDA device where the input module corresponding to this process resides. Alternatively, device_ids can also be None. 2) For multi-device modules and CPU modules, device_ids must be None. (From the [DDP docs](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html))```pythonmodel.to(rank)model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])```### 4. Distributing the data with `DistributedSampler`#### Dividing the workload- [`DistributedSampler`](https://pytorch.org/docs/stable/data.html?highlight=distributedsampler#torch.utils.data.distributed.DistributedSampler) chunks the input data across all distributed processes, without overlap. If we have 4 GPUs then each process will only load 1/4 of the training dataset.- The `batch_size` needs to be divided among the processes (GPUs). Each process will receive an input batch of `batch_size_per_proc`; the effective batch size is `batch_size_per_proc` * `world_size`, if the `batch_size` is 64 and `world_size` is 4 GPUs, then the effective batch size is still 64 in total.- The `num_workers` also needs to be divided among the processes (GPUs). Each proces will receive `num_workers_per_proc`.::: {.callout-caution title="`batch_size` and OOM"}If the batch_size isn't divided among the processes then then each process gets a full batch and the effective batch size is now x`world_size` larger and we are likely to run out of CPU or GPU memory if not careful.:::```{python}# | label: distributing workload across workers# | code-fold: show# | eval: false# we need to divide the workers and batch across the different processes used in distributed trainingnum_workers_per_proc = num_workers // world_size # avoids CPU contentionbatch_size_per_proc = batch_size // world_size # avoids OOM```#### Setting Up the DistributedSampler- `shuffle` - by default, the `DistributedSampler` will shuffle the dataset. We don't want to shuffle the validation dataset.- `drop_last` - by default, the `DistributedSampler` will append additional samples to fill an incomplete batch (e.g. there's 100 training samples with `batch_size=64` there would be one batch of 36 samples). 
We don't want to repeat samples for the validation dataset as that would change the metrics.- `pin_memory` - For large datasets that are loaded into the CPU in the `Dataset`, pinning the memory can speed up the host to device transfer (see this [NVIDIA blog](https://developer.nvidia.com/blog/how-optimize-data-transfers-cuda-cc/#pinned_host_memory) for more details).```{python}# | label: distributed sampler# | code-fold: show# | eval: false# DistributedSampler ensures that training data is chunked across GPUs without overlapping samplestrain_sampler = DistributedSampler(train_dataset)val_sampler = DistributedSampler( val_dataset, shuffle=False, # don't shuffle the validation dataset drop_last=True, # DistributedSampler will append additional samples to fill an incomplete batch. We don't want that for the validation dataset.)train_dataloader = DataLoader( train_dataset, shuffle=False, # don't shuffle if using DistributedSampler as that's done within the sampler sampler=train_sampler, num_workers=num_workers_per_proc, batch_size=batch_size_per_proc, pin_memory=True, collate_fn=collate_rowgroups,)val_dataloader = DataLoader( val_dataset, shuffle=False, sampler=val_sampler, num_workers=num_workers_per_proc, batch_size=batch_size_per_proc, pin_memory=True, collate_fn=collate_rowgroups,)```#### 7. Shuffling across epochs- Calling the `set_epoch()` method on the `DistributedSampler` at the beginning of each epoch is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be used in each epoch.```{python}# | label: shuffling across epochs# | code-fold: show# | eval: false# ... Neural Network setupfor epch inrange(0, epochs):# need to call `set_epoch()` at the beginning of each epoch before creating the# `DataLoader` iterator to make shuffling work properly across multiple epochs train_sampler.set_epoch(epch)# ... training loop```### 6. Running the distributed training job- `rank` is auto-allocated by DDP when calling [`torch.multiprocessing.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#spawning-subprocesses).- `world_size` is the number of processes across the training job. For GPU training, this corresponds to the number of GPUs in use, and each process works on a dedicated GPU.- Both `rank` and `world_size` are new parameters to `main()`. Because of how spawning processes works, `rank` _needs_ to be the first parameter to the calling function, `main(rank, ...)`.::: {.callout-tip title="PyTorch Multiprocessing"}[PyTorch's `torch.multiprocessing` package](https://pytorch.org/docs/stable/multiprocessing.html) is a wrapper around the native `multiprocessing` module and the API is 100% compatible.:::```{python}# | label: mean and variance# | code-fold: show# | eval: falseimport torch# ... CLI args parsingworld_size = torch.cuda.device_count() # number of GPUstorch.multiprocessing.spawn( main, args=( world_size, args.train_path, args.random_state, args.lr, args.epochs, args.num_workers, args.batch_size, ), nprocs=world_size, # this is used to set the `rank` parameter. It is passed as the first argument)```### MLFlow loggingSince there are now multiple processes runnning the same code, the same logging will happen on each process. MLFlow doesn't know how to distinguish that there are different processes logging the same metric. 
We can guard against this by only logging on the main process (GPU0):```{python}# | label: MLFlow and torch.distributed# | code-fold: show# | eval: falseimport mlflowif rank ==0: mlflow.log_metrics( {"train loss": train_loss,"train accuracy": train_accuracy,"val loss": val_loss,"val accuracy": val_accuracy, }, step=epoch, )```## # Gradients, Losses, and MetricsUnder the hood, DDP synchronizes and gathers the gradients across all processes. However, any other ad-hoc value calculated in your code is not; e.g. losses and metrics.::: {.callout-note title="Gradients are synchronized"}Model gradients are synchronized across processes during the backward pass. This means that the model in each process is the same! [See DDP: Internal Design](https://pytorch.org/docs/master/notes/ddp.html#internal-design).:::::: {.callout-caution title="Losses are not synchronized"}Even though the model is the same in each process, the loss is calculated on only the portion of the batch that each process sees. The losses don't need to be synchronized for training but we may want to synchronize the losses for logging or definitelty when calculating metrics on the hold-out (validation) dataset.We could avoid this by _not_ using a `DistributedSampler` for the validation set, but then only 1 process would be used to calculate the loss for the whole validation set each epoch, which will be _slow_.:::So if each process is calculating and accumulating losses and metrics separately, how do we log those and report as if there were a single process? Well, those values will need to be gathered and then accumulated. Say we have 4 processes, one for each GPU, and each is processing 1/4 of the training dataset. We want to report the loss for each epoch. If we log the loss in each process, we will have 4 different losses. We can gather and combine them in a few ways. Since the loss is just a number value we can use [`torch.distributed.reduce`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.reduce) or [`torch.distributed.all_reduce`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_reduce):### Example: `torch.distributed.reduce`In this example, we gather and combine using summation with the `dist.ReduceOp.SUM`, all the `loss_tensor`s into the process `0` tensor (`dst=0`). Each tensor in each process must be the same shape. Since we are assigning the values in each processes's `loss_tensor` to it's rank, we expect the final gathered values to be `0 + 1 + 2 + 3 = 6` in the main process `loss_tensor`.```{python}# | label: example `torch.distributed.reduce`# | code-fold: showimport osimport torchimport torch.distributed as distimport torch.multiprocessing as mpfrom torch.distributed import init_process_groupdef reduce_tensor(rank: int, world_size: int) ->None: init_process_group( backend="nccl"if torch.cuda.is_available() else"gloo", # CPU only works on gloo backend rank=rank, world_size=world_size, ) torch.cuda.set_device(rank) # tell each device (GPU) which one it is. 
loss_tensor = torch.tensor([rank, rank]).cuda()print(loss_tensor) dist.reduce(loss_tensor, op=dist.ReduceOp.SUM, dst=0, async_op=True)print(loss_tensor) torch.distributed.destroy_process_group()if__name__=="__main__": os.environ["MASTER_ADDR"] ="localhost" os.environ["MASTER_PORT"] ="12345" num_gpu = torch.cuda.device_count() mp.spawn(reduce_tensor, nprocs=num_gpu, args=(num_gpu,))``````{text}tensor([1, 1], device='cuda:1')tensor([3, 3], device='cuda:3')tensor([2, 2], device='cuda:2')tensor([0, 0], device='cuda:0')tensor([6, 6], device='cuda:0')tensor([1, 1], device='cuda:1')tensor([2, 2], device='cuda:2')tensor([3, 3], device='cuda:3')```### Example: `torch.distributed.all_reduce`In this example, we gather and combine using summation with the `dist.ReduceOp.SUM`, all the `loss_tensor`s into all the processes. Each tensor in each process must be the same shape. Since we are assigning the values in each processes's `loss_tensor` to it's rank, we expect the final gathered values to be `0 + 1 + 2 + 3 = 6` in the all the processes's `loss_tensor`.```{python}# | label: example `torch.distributed.all_reduce`# | code-fold: showimport osimport torchimport torch.distributed as distimport torch.multiprocessing as mpfrom torch.distributed import init_process_groupdef all_reduce_tensor(rank: int, world_size: int) ->None: init_process_group( backend="nccl"if torch.cuda.is_available() else"gloo", # CPU only works on gloo backend rank=rank, world_size=world_size, ) torch.cuda.set_device(rank) # tell each device (GPU) which one it is. loss_tensor = torch.tensor([rank, rank]).cuda()print(loss_tensor) dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM, dst=0, async_op=True)print(loss_tensor) torch.distributed.destroy_process_group()if__name__=="__main__": os.environ["MASTER_ADDR"] ="localhost" os.environ["MASTER_PORT"] ="12345" num_gpu = torch.cuda.device_count() mp.spawn(all_reduce_tensor, nprocs=num_gpu, args=(num_gpu,))``````{text}tensor([2, 2], device='cuda:2')tensor([3, 3], device='cuda:3')tensor([1, 1], device='cuda:1')tensor([0, 0], device='cuda:0')tensor([6, 6], device='cuda:2')tensor([6, 6], device='cuda:3')tensor([6, 6], device='cuda:1')tensor([6, 6], device='cuda:0')```### Example: `torch.distributed.gather_object`Syncing across processing is simple enough for tensors, but if we have a number of values to gather (say a bunch of metrics for example) it would be easier to only need to gather once and store those values in an appropriate data structure. Most of the distributed gathering function only work on tensors, but we can use [`torch.distributed.gather_object`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.gather_object) and / or [`torch.distributed.all_gather_object`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_object) to pass pickleable Python objects between ranks.In this example, we want to track the losses and number of samples in a `Counter` so that we can combine and calculate the mean loss after gathering. We gather each loss counter to rank 0. Each counter is placed into the `gather_list`, which must have all elements set to `None` initially. When calling, `dist.gather_object`, the `gather_list` must only exist in the rank being gathered to (`dst=0` or rank 0 in this case). 
Then we use `functools.reduce` to sum all the Counters gathered in the `gather_list`.```{python}# | label: example `torch.distributed.gather_object`# | code-fold: showimport osimport operatorfrom collections import Counterimport torchimport torch.distributed as distimport torch.multiprocessing as mpfrom torch.distributed import init_process_groupdef gather_object(rank: int, world_size: int) ->None: init_process_group( backend="nccl"if torch.cuda.is_available() else"gloo", # CPU only works on gloo backend rank=rank, world_size=world_size, ) torch.cuda.set_device(rank) # tell each device (GPU) which one it is. losses = Counter( {"loss": 0.01, "num_samples": rank} )print(losses) gather_list = [Nonefor _ inrange(world_size)] dist.gather_object(losses, gather_list if rank ==0elseNone, dst=0) losses = functools.reduce(operator.add, gather_list)print(losses) torch.distributed.destroy_process_group()if__name__=="__main__": os.environ["MASTER_ADDR"] ="localhost" os.environ["MASTER_PORT"] ="12345" num_gpu = torch.cuda.device_count() mp.spawn(gather_object, nprocs=num_gpu, args=(num_gpu,))``````{text}Counter({'num_samples': 1, 'loss': 0.01})Counter({'num_samples': 2, 'loss': 0.01})Counter({'num_samples': 4, 'loss': 0.01})Counter({'num_samples': 3, 'loss': 0.01})Counter({'num_samples': 10, 'loss': 0.04})Counter({'num_samples': 4, 'loss': 0.01})Counter({'num_samples': 3, 'loss': 0.01})Counter({'num_samples': 2, 'loss': 0.01})```### Example: `torch.distributed.all_gather_object`In this example, we want to track the losses and number of samples in a `Counter` so that we can combine and calculate the mean loss after gathering. We gather each loss counter to each rank. Each counter is placed into the `gather_list`, which must have all elements set to `None` initially. Then we use `functools.reduce` to sum all the Counters gathered in the `gather_list`.```{python}# | label: example `torch.distributed.all_gather_object`# | code-fold: showimport osimport operatorfrom collections import Counterimport torchimport torch.distributed as distimport torch.multiprocessing as mpfrom torch.distributed import init_process_groupdef all_gather_object(rank: int, world_size: int) ->None: init_process_group( backend="nccl"if torch.cuda.is_available() else"gloo", # CPU only works on gloo backend rank=rank, world_size=world_size, ) torch.cuda.set_device(rank) # tell each device (GPU) which one it is. losses = Counter( {"loss": 0.01, "num_samples": rank} )print(losses) gather_list = [Nonefor _ inrange(world_size)] dist.all_gather_object(gather_list, losses) losses = functools.reduce(operator.add, gather_list)print(losses) torch.distributed.destroy_process_group()if__name__=="__main__": os.environ["MASTER_ADDR"] ="localhost" os.environ["MASTER_PORT"] ="12345" num_gpu = torch.cuda.device_count() mp.spawn(all_gather_object, nprocs=num_gpu, args=(num_gpu,))``````{text}Counter({'num_samples': 1, 'loss': 0.01})Counter({'num_samples': 4, 'loss': 0.01})Counter({'num_samples': 2, 'loss': 0.01})Counter({'num_samples': 3, 'loss': 0.01})Counter({'num_samples': 10, 'loss': 0.04})Counter({'num_samples': 10, 'loss': 0.04})Counter({'num_samples': 10, 'loss': 0.04})Counter({'num_samples': 10, 'loss': 0.04})```## # Multi-Node TrainingNow that we've setup DDP, we have the option of running on multiple nodes. 
There are a few extra bits that need to be changed before we run [multinode training](https://pytorch.org/tutorials/intermediate/ddp_series_multinode.html):- change a few lines to enable [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html)- install the project as a python package for `torchrun` to work properly- add [fault tolerance](https://pytorch.org/tutorials/beginner/ddp_series_fault_tolerance.html)## # Resources1. [`DistributedDataParallel`](https://pytorch.org/docs/stable/notes/ddp.html) - Documentation for DDP.2. [Getting Started with Distributed Data Parallel](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) - Good starting point to understand DDP and writing a training script using DDP for single-node and multi-node.3. [Writing Distributed Applications with Pytorch](https://pytorch.org/tutorials/intermediate/dist_tuto.html) - In-depth article about writing distributed applications in PyTorch and how communication works under the hood.4. [`DistributedSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler) - Used in conjunction with a DataLoader, the DistributedSampler enables each process to only load the data it processes, rather than all the data to be processed.5. [DistributedDataParallel training in PyTorch](https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html) - Explaination of how DDP works and how to use it.6. [(AML) Distributed GPU training guide (SDK v2)](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-train-distributed-gpu?view=azureml-api-2)7. [PyTorch DistributedDataParallel Example In Azure ML - Multi-Node Multi-GPU Distributed Training](https://ochzhen.com/blog/pytorch-distributed-data-parallel-azure-ml) - Example of DDP for AML.8. [PyTorch Distributed: Experiences on Accelerating Data Parallel Training](https://arxiv.org/pdf/2006.15704) - Paper on how Facebook designed DDP to be faster for distributed training.