Hello,
I hope you are doing well. I have been using and building on your codebase for a while now, and I have just discovered a critical bug in the way the nerv toolkit implements multi-GPU training with Distributed Data Parallel (DDP). I am opening this issue because it directly affects the results of the paper with the parameters provided.
Summary: DDP is not implemented correctly. There are two main problems, one major and one minor:
- The models on the different GPUs never communicate with each other, so they are trained independently and diverge, which is not the intended behavior. In essence, when training with 8 GPUs for 80 epochs, every GPU computes gradients and updates its own copy of the starting model based only on its own batches. The result is 8 separate models, each trained on the equivalent of roughly 10 epochs.
- Within an epoch, the indices sampled on the different GPUs are not mutually exclusive, so the behavior differs slightly from single-GPU training. When torch's DistributedSampler is used properly, each GPU gets a unique subset of the data that does not intersect with any other GPU's subset. This does not affect performance significantly, since the models still see the entire dataset on average, but it might be good to keep in mind.
Multi GPU training issue
The main issue arises in the forward pass logic here.
When a model is wrapped in the DDP module, DDP handles the forward and backward passes across the GPUs such that the gradients are synchronized during the backward pass. However, for this to work correctly, the forward pass must go through the DDP wrapper itself, i.e. its own forward method.
In the code, the forward pass bypasses this wrapper and instead accesses the inner module attribute to call the wrapped model directly. This completely skips the gradient synchronization, so each GPU computes gradients only for its own batch and updates its own copy of the weights. This is easy to verify by comparing the parameters across ranks: they diverge substantially as training progresses. In essence, every GPU trains a separate model independently. This does not break correctness per se, because each individual model is still trained properly, and at the end of training the model on rank 0 / GPU 0 is the one that is saved.
However, this means that the computation on every GPU other than the first is completely wasted, and that in each epoch the model on GPU 0 is trained on only 1/N of the data, where N is the number of GPUs. As a result, increasing the number of GPUs means the saved model is trained on less data per epoch than with a single GPU.
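To make the difference concrete, here is a minimal sketch of the two call patterns (the names are illustrative, not the repo's exact code):

ddp_model = DDP(model, device_ids=[local_rank])

# Current behavior: the unwrapped module is called directly. This runs
# outside DDP's forward, so no gradient synchronization is set up for the
# backward pass, and each rank updates its own private copy of the weights.
out = ddp_model.module(inputs)

# Expected usage: call the DDP wrapper itself, so its forward/backward
# machinery fires and gradients are averaged across ranks during backward().
out = ddp_model(inputs)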
After 5 epochs, I checked the L2 norm of the difference of each parameter between the main GPU and the other ranks. Here are the largest differences, measured with my toolkit here just for convenience:
----------------------------------------------------------------------
| Parameter                 | Ranks compared | L2 norm of difference |
----------------------------------------------------------------------
| module.decoder.3.0.weight | 0 <-> 3        | 2.4202e+01            |
| module.decoder.3.0.weight | 0 <-> 2        | 2.3539e+01            |
| module.decoder.3.0.weight | 0 <-> 1        | 2.2490e+01            |
| module.decoder.0.0.weight | 0 <-> 2        | 1.5555e+01            |
| module.decoder.0.0.weight | 0 <-> 1        | 1.5359e+01            |
| module.decoder.0.0.weight | 0 <-> 3        | 1.5071e+01            |
| module.decoder.2.0.weight | 0 <-> 2        | 1.4756e+01            |
| module.decoder.2.0.weight | 0 <-> 3        | 1.4714e+01            |
| module.decoder.2.0.weight | 0 <-> 1        | 1.4555e+01            |
| module.encoder.2.0.weight | 0 <-> 1        | 1.4034e+01            |
----------------------------------------------------------------------
As you can see, the L2 norm of the difference becomes large after only a few epochs, meaning the models on each GPU end up completely different, which is not the desired behavior.
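For reference, a minimal sketch of how such a comparison can be done (my own diagnostic approach, not code from the repo): broadcast rank 0's parameters and compute the L2 norm of the difference on every other rank.

import torch.distributed as dist

def param_divergence(model):
    diffs = {}
    for name, param in model.named_parameters():
        reference = param.detach().clone()
        dist.broadcast(reference, src=0)   # every rank receives rank 0's copy
        if dist.get_rank() != 0:
            diffs[name] = (param.detach() - reference).norm(p=2).item()
    return diffs                           # empty on rank 0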
The solution is to rework the loss calculation logic so that the module wrapped in DDP returns the loss to be optimized directly from its forward pass, so that calling
loss = ddp_wrapped_model(batch)
can be used as is. This ensures gradient synchronization actually happens and that the training signal comes from all available GPUs.
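A minimal sketch of what I mean (criterion, local_rank, optimizer and the batch layout are assumptions for illustration, not names from the repo):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class LossModule(nn.Module):
    """Wraps a model so that forward() returns the loss to optimize."""
    def __init__(self, model, criterion):
        super().__init__()
        self.model = model
        self.criterion = criterion

    def forward(self, batch):
        inputs, targets = batch            # assumed batch layout
        return self.criterion(self.model(inputs), targets)

ddp_wrapped_model = DDP(LossModule(model, criterion).cuda(local_rank),
                        device_ids=[local_rank])

loss = ddp_wrapped_model(batch)            # goes through DDP's forward
loss.backward()                            # gradients are all-reduced here
optimizer.step()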
If this issue is fixed, the models in this project can be trained much faster on multiple GPUs than they are currently, which should be a very strong improvement.
Moreover, this also means that results reported for X epochs on N GPUs should be roughly equivalent to training with 1 GPU for X/N epochs, assuming the last checkpoint is taken in both cases. This implies that the models trained with multiple GPUs are either 1) undertrained and could perform better, or 2) trained enough, but could have been trained much faster with proper multi-GPU support.
Multi GPU data sampling issue
This is a relatively minor issue, but I still wanted to bring it to your attention. In the distributed sampler, this is the logic that generates the indices on each GPU:
indices = torch.randperm(len(self.dataset))
which are then divided into partitions, one for each GPU here:
indices = indices[self.rank:self.total_size:self.num_replicas]
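As a toy illustration of that slicing (made-up numbers, not from the repo), with 8 samples and 4 replicas each rank takes every 4th index starting at its own rank:

indices = list(range(8))                   # stand-in for a shuffled permutation
num_replicas, total_size = 4, 8
for rank in range(num_replicas):
    print(rank, indices[rank:total_size:num_replicas])
# rank 0 -> [0, 4], rank 1 -> [1, 5], rank 2 -> [2, 6], rank 3 -> [3, 7]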
However, the permutation generated on each GPU is independent of the others (each process draws its own random permutation), so the resulting subsets of indices on the different GPUs overlap.
The best way to solve this is how it is implemented in the official torch DistributedSampler:
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
With this arrangement, the permutation on every GPU is fixed by the seed and the current epoch, which must be set explicitly via sampler.set_epoch at the beginning of each epoch before iterating over the data. Since all GPUs then share the same permutation, each GPU gets a non-overlapping subset of the data.
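For completeness, a sketch of the intended usage with the standard torch API (dataset, num_epochs and the batch size are placeholders):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

sampler = DistributedSampler(dataset, shuffle=True, seed=42)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)   # makes the permutation identical on all ranks
    for batch in loader:
        ...                    # training step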
I checked for overlapping indices in your training pipeline and found overlap in every epoch I tested, unlike with the torch sampler.
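The check itself was along these lines (my own diagnostic, sketched here assuming local_indices is the per-rank index list): gather every rank's indices and intersect them.

import torch.distributed as dist

def check_overlap(local_indices):
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, list(local_indices))
    if dist.get_rank() == 0:
        for i in range(len(gathered)):
            for j in range(i + 1, len(gathered)):
                shared = set(gathered[i]) & set(gathered[j])
                print(f"ranks {i} and {j}: {len(shared)} shared indices")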
In the end, this is not a critical bug like the first one, but it does mean that sampling differs between single-GPU training (where the model sees the entire dataset each epoch) and multi-GPU training (where each GPU sees a random 1/N of the data effectively sampled with replacement, so all GPUs together almost certainly do not cover the entire dataset). It might not matter much, if at all, for final performance.
Since I have been using this code for some experiments, please let me know if you need any help addressing the core of the issue; I would be happy to connect and explain in more detail or assist in any way I can.