slotformer's Introduction

SlotFormer

SlotFormer: Unsupervised Visual Dynamics Simulation with Object-Centric Models
Ziyi Wu, Nikita Dvornik, Klaus Greff, Thomas Kipf, Animesh Garg
ICLR'23 | GitHub | arXiv | Project page

[Teaser animations: Ground-Truth vs. Our Prediction]

Introduction

This is the official PyTorch implementation of the paper SlotFormer: Unsupervised Visual Dynamics Simulation with Object-Centric Models, accepted at ICLR 2023. The code contains:

  • Training base object-centric slot models
  • Video prediction task on OBJ3D and CLEVRER
  • VQA task on CLEVRER
  • VQA task on Physion
  • Planning task on PHYRE

Update

  • 2023.9.20: BC-breaking change! We fixed an error in the mIoU calculation code. This won't change the ranking of benchmarked methods, but will change their absolute values. See this PR for more details. Please re-run the evaluation code on your trained models to get the correct results. The updated mIoU of SlotFormer on CLEVRER is 49.42 (using the provided pre-trained weights).
  • 2023.1.20: The paper is accepted to ICLR 2023!
  • 2022.10.26: Support Physion VQA task and PHYRE planning task
  • 2022.10.16: Initial code release!
    • Support base object-centric model training
    • Support SlotFormer training
    • Support evaluation on the video prediction task
    • Support evaluation on the CLEVRER VQA task

Installation

Please refer to install.md for step-by-step guidance on how to install the packages.

Experiments

This codebase is tailored to Slurm GPU clusters with a preemption mechanism. For the configs, we mainly use RTX6000 GPUs with 24GB of memory (though many experiments don't require that much). Please modify the code accordingly if you are using a different hardware setup:

  • Please go through scripts/train.py and change the fields marked with TODO.
  • Please read the config file for the model you want to train. We use DDP with multiple GPUs to accelerate training. You can use fewer GPUs to achieve a better memory-speed trade-off.

Dataset Preparation

Please refer to data.md for steps to download and pre-process each dataset.

Reproduce Results

Please see benchmark.md for detailed instructions on how to reproduce our results in the paper.

Citation

Please cite our paper if you find it useful in your research:

@article{wu2022slotformer,
  title={SlotFormer: Unsupervised Visual Dynamics Simulation with Object-Centric Models},
  author={Wu, Ziyi and Dvornik, Nikita and Greff, Klaus and Kipf, Thomas and Garg, Animesh},
  journal={arXiv preprint arXiv:2210.05861},
  year={2022}
}

Acknowledgement

We thank the authors of Slot-Attention, slot_attention.pytorch, SAVi, RPIN and Aloe for open-sourcing their wonderful works.

License

SlotFormer is released under the MIT License. See the LICENSE file for more details.

Contact

If you have any questions about the code, please contact Ziyi Wu at [email protected].


slotformer's Issues

DDP settings

Hi there,

While working with your great research code, I ran into a problem: the DDP setting doesn't seem to work.

I tried to run SAVi model training and SlotFormer training on OBJ3D, following your obj3d.md instructions.
I simply copied and pasted the commands given in obj3d.md, but failed to train the models on multiple GPUs.
Training did run, but only on a single GPU.
(I found that the extract_slot code does run on multiple GPUs.)

Could you give me a hand?

Weird bug about module import

Hi there,

While evaluating SlotFormer on video prediction on OBJ3D, I ran into a weird bug. I built the environment following your instructions in install.md with environment.yml and set up nerv. I even tried it on different computers, and it still raises the same error.

When I run python slotformer/video_prediction/test_vp.py, I get this error:

  File ".../SlotFormer/slotformer/video_prediction/vp_utils.py", line 13, in <module>
    from slotformer.base_slots.models import to_rgb_from_tensor
ModuleNotFoundError: No module named 'slotformer'

I checked the slotformer import in your code and it looks correct to me. When I type the same import in PyCharm, slotformer even gets auto-completed.
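
For what it's worth, running the script as a module from the repo root (python -m slotformer.video_prediction.test_vp) should avoid the error, and so does prepending the repo root to sys.path. A minimal sketch of the latter, only as a workaround and not a proper fix:

import os
import sys

# Workaround at the top of test_vp.py: prepend the repo root (two levels above
# slotformer/video_prediction/) so the top-level 'slotformer' package is
# importable when the script is run directly as a file.
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
sys.path.insert(0, repo_root)

from slotformer.base_slots.models import to_rgb_from_tensor  # resolves now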

Could you give me a hand?

CRITICAL BUG: Multi-GPU training possibly not implemented correctly, resulting in potentially under-trained models

Hello,

I hope you are doing well. After using and building on your codebase for a while, I have just discovered a critical bug in the way the nerv toolkit implements multi-GPU training with DistributedDataParallel. I am opening this issue here because it is critical for the results of the paper with the provided parameters.

Summary: DDP is not implemented correctly. There are two main problems, one major and one minor:

  • The models on the individual GPUs do not communicate with each other; they are trained differently and diverge, which is incorrect behavior. In essence, when training with 8 GPUs for 80 epochs, every GPU trains the starting model separately, computing gradients and updating its own copy based only on the batches it receives. The result is 8 separate models, each trained on the equivalent of roughly 10 epochs.
  • When training for an epoch, the indices sampled on the different GPUs are not mutually exclusive, so we technically get slightly different behavior from single-GPU training. When torch's distributed sampler is used properly, each GPU gets a unique subset of the data that does not intersect with any other GPU's subset. Although this does not affect performance significantly, since the models still see the entire dataset on average, it might be good to keep in mind.

Multi GPU training issue

The main issue arises in the forward pass logic here.

When a model is wrapped in a DDP module, DDP handles the forward and backward passes across GPUs such that the gradients are synced during the backward pass. However, for this to work correctly, the forward pass must be called through the DDP wrapper's forward specifically.

In the code, the forward pass is done by bypassing this wrapper and accessing the inner module attribute to reach the wrapped model. This completely skips gradient synchronization and thus leads to each GPU computing the gradient for only its own batch and updating accordingly. This can easily be verified by comparing the parameters across ranks: they diverge to a large extent as training goes on. In essence, every GPU trains a separate model independently. This does not break correctness, because each model is still trained properly, and at the end of training the model on rank 0 / GPU 0 is saved.

However, this means that the computation on all GPUs other than the first is completely wasted, and that in every epoch the model on GPU 0 is trained on only 1/N of the data, where N is the number of GPUs. As a result, increasing the number of GPUs means the model is trained on less data per epoch than with a single GPU.
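
For reference, such a cross-rank comparison needs nothing more than plain torch.distributed; here is a minimal sketch (the param_divergence helper is made up for illustration and is not part of this codebase):

import torch.distributed as dist

def param_divergence(ddp_model):
    # Broadcast rank 0's copy of every parameter and report, on this rank,
    # the L2 norm of the difference to the local copy.
    diffs = {}
    for name, param in ddp_model.named_parameters():
        ref = param.detach().clone()
        dist.broadcast(ref, src=0)  # all ranks receive rank 0's weights
        diffs[name] = (param.detach() - ref).norm().item()
    return diffs

# After a few epochs (process group already initialized by the launcher):
# diffs = param_divergence(ddp_model)
# print(sorted(diffs.items(), key=lambda kv: -kv[1])[:10])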

After 5 epochs, I checked the L2 norm of the per-parameter difference between the main GPU and the other ranks. Here are the largest differences, which I measured using my toolkit here just for convenience:

----------------------------------------------------------------------
|        Parameter        | Ranks compared |  L2 norm of difference  |
----------------------------------------------------------------------
|module.decoder.3.0.weight|    0 <-> 3     |        2.4202e+01       |
|module.decoder.3.0.weight|    0 <-> 2     |        2.3539e+01       |
|module.decoder.3.0.weight|    0 <-> 1     |        2.2490e+01       |
|module.decoder.0.0.weight|    0 <-> 2     |        1.5555e+01       |
|module.decoder.0.0.weight|    0 <-> 1     |        1.5359e+01       |
|module.decoder.0.0.weight|    0 <-> 3     |        1.5071e+01       |
|module.decoder.2.0.weight|    0 <-> 2     |        1.4756e+01       |
|module.decoder.2.0.weight|    0 <-> 3     |        1.4714e+01       |
|module.decoder.2.0.weight|    0 <-> 1     |        1.4555e+01       |
|module.encoder.2.0.weight|    0 <-> 1     |        1.4034e+01       |
----------------------------------------------------------------------

As you can see, the L2 norm becomes massive after a while, meaning the models are completely different on each GPU, which is not desired behavior.

The solution to this problem is to rework the loss calculation logic under DDP, such that the wrapped model is a module that returns the desired loss to be optimized directly from its forward pass, so that calling

loss = ddp_wrapped_model(batch)

can be used as is. This ensures the gradient syncing happens and that the training signal comes from all available GPUs.
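
To make the intended arrangement concrete, here is a minimal, self-contained sketch (the LossWrapper class and the toy linear model are illustrative only, not the actual nerv code; it assumes the script is launched with torchrun):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class LossWrapper(nn.Module):
    # Wrap the model so that forward() returns the scalar loss; calling the
    # DDP wrapper itself then triggers the gradient all-reduce in backward().
    def __init__(self, model, criterion):
        super().__init__()
        self.model = model
        self.criterion = criterion

    def forward(self, inputs, targets):
        return self.criterion(self.model(inputs), targets)

def main():
    dist.init_process_group("nccl")  # launched via torchrun
    local_rank = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    model = nn.Linear(16, 4)  # toy stand-in model
    ddp_model = DDP(LossWrapper(model, nn.MSELoss()).to(device),
                    device_ids=[local_rank])
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)

    inputs = torch.randn(8, 16, device=device)
    targets = torch.randn(8, 4, device=device)

    loss = ddp_model(inputs, targets)  # go through the DDP forward
    loss.backward()                    # gradients are synced here
    optimizer.step()

if __name__ == "__main__":
    main()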

If this issue is fixed, the models in this project can be trained much faster with multiple GPUs than they currently are, which should be a very strong improvement.

Moreover, this also means that results reported for X epochs on N GPUs should be roughly equivalent to training with 1 GPU for X/N epochs, assuming the last checkpoint is taken in both cases. This might imply that the models trained with multiple GPUs are either 1) under-trained and could perform better, or 2) trained enough, but could have been trained much faster with proper multi-GPU support.

Multi GPU data sampling issue

This is a relatively minor issue, but I still wanted to bring attention to it. In the distributed sampler, the indices on each GPU are generated as follows:

indices = torch.randperm(len(self.dataset))

which are then divided into partitions, one for each GPU here:

indices = indices[self.rank:self.total_size:self.num_replicas]

However, the permutations generated on the different GPUs are independent and therefore not identical, which leads to the per-GPU index subsets overlapping.

The best way to solve this is how it is implemented in the official torch repository:

g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist() 

In this arrangement, the index generation on every GPU is fixed by the seed and the current epoch (which must be explicitly set via sampler.set_epoch at the beginning of each epoch, before iterating over the data), so each GPU gets a non-overlapping subset of the data.
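
For completeness, the standard end-to-end usage of the built-in sampler looks roughly like this (a sketch with a toy dataset, assuming the process group has already been initialized by the launcher):

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# A shared seed plus set_epoch() gives every rank the same permutation,
# which DistributedSampler then slices into disjoint shards.
dataset = TensorDataset(torch.randn(1000, 16))  # toy stand-in dataset
sampler = DistributedSampler(dataset, shuffle=True, seed=42)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(10):
    sampler.set_epoch(epoch)  # must be called before iterating every epoch
    for (batch,) in loader:
        pass                  # training step goes here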

I checked for overlapping indices with the current method and found such overlap in every epoch I tested in your training pipeline, unlike with the torch sampler.

In the end, this is not a critical bug like the first one, but it does mean the sampling differs between the single-GPU case (the GPU sees the entire dataset) and the multi-GPU case (each GPU sees a 1/N random sample of the data drawn effectively with replacement, so all GPUs together are very unlikely to cover the entire dataset). It might not matter much, if at all, for final performance.

Since I have been using this code for some experiments, please let me know if you need any help addressing the core of the issue; I would be happy to connect and explain in more detail or assist in any way I can.
