Forecast-MAE: Self-supervised Pre-training for Motion Forecasting with Masked Autoencoders

Jie Cheng1    Xiaodong Mei1    Ming Liu1,2   
HKUST1    HKUST(GZ)1,2   

arXiv PDF

Highlight

  • A neat yet effective MAE-based pre-training scheme for motion forecasting.
  • A simple forecasting model (essentially pure transformer encoders) with relatively good performance.

Getting Started

Setup Environment

1. Clone this repository:

git clone https://github.com/jchengai/forecast-mae.git
cd forecast-mae

2. Set up the conda environment:

conda create -n forecast_mae python=3.8
conda activate forecast_mae
sh ./scripts/setup.sh

3. Set up the Argoverse 2 Motion Forecasting Dataset. The expected data structure is:

data_root
    ├── train
    │   ├── 0000b0f9-99f9-4a1f-a231-5be9e4c523f7
    │   ├── 0000b6ab-e100-4f6b-aee8-b520b57c0530
    │   ├── ...
    ├── val
    │   ├── 00010486-9a07-48ae-b493-cf4545855937
    │   ├── 00062a32-8d6d-4449-9948-6fedac67bfcd
    │   ├── ...
    ├── test
    │   ├── 0000b329-f890-4c2b-93f2-7e2413d4ca5b
    │   ├── 0008c251-e9b0-4708-b762-b15cb6effc27
    │   ├── ...
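Before preprocessing, it can help to confirm that the dataset is laid out as shown above. The following is a minimal sketch, not part of this repository: the function name `check_data_root` is hypothetical, and it only checks that the three split folders exist and counts the scenario folders inside each one.

```python
from pathlib import Path

# Split names taken from the directory tree above.
EXPECTED_SPLITS = ("train", "val", "test")


def check_data_root(data_root):
    """Return a dict mapping each split name to its number of scenario folders.

    Raises FileNotFoundError if one of the split directories is missing.
    (Hypothetical helper for illustration; not provided by the repository.)
    """
    root = Path(data_root)
    counts = {}
    for split in EXPECTED_SPLITS:
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        # Each scenario is stored in its own sub-directory named by its ID.
        counts[split] = sum(1 for p in split_dir.iterdir() if p.is_dir())
    return counts


# Usage: print(check_data_root("/path/to/data_root"))
```

A count of zero for a split usually means the archive was extracted one level too deep or too shallow.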

Preprocess

(Recommended) By default, preprocessing uses Ray and 16 CPU cores, and takes about 30 minutes to finish.

python3 preprocess.py --data_root=/path/to/data_root -p

Alternatively, disable parallel preprocessing by removing -p.
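To illustrate what a parallel switch such as -p typically controls, here is a rough sketch. It is not the repository's implementation: the repository uses Ray, while this example substitutes the standard-library `ThreadPoolExecutor` so that it runs without extra dependencies. The names `preprocess_one` and `preprocess_all` are hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor


def preprocess_one(scenario_id):
    # Placeholder for the real per-scenario work (assumption for illustration).
    return scenario_id.upper()


def preprocess_all(scenario_ids, parallel=False, workers=16):
    """Process every scenario, either serially or with a worker pool."""
    if not parallel:
        # Equivalent of running without -p: one scenario at a time.
        return [preprocess_one(s) for s in scenario_ids]
    # Equivalent of running with -p: spread the scenarios over the workers.
    # pool.map returns results in the same order as the inputs.
    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(preprocess_one, scenario_ids))


ids = ["0000b0f9", "0000b6ab"]
print(preprocess_all(ids, parallel=True) == preprocess_all(ids))  # True
```

Both paths produce the same output; the parallel path only changes how long it takes. For CPU-bound work in Python, a process-based pool or a framework such as Ray is the more realistic choice than threads.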

Training

  • For single-GPU training, remove gpus=4 from the following commands. batch_size is the batch size per GPU.
  • If you use WandB, enable logging by adding the option wandb=online.
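Because batch_size is a per-GPU value, the number of samples consumed per optimizer step grows with the number of GPUs. The sketch below shows the arithmetic, assuming plain data-parallel training without gradient accumulation (an assumption, not something stated in this README). The helper name `effective_batch_size` is hypothetical.

```python
def effective_batch_size(batch_size_per_gpu, gpus=1):
    """Total samples per optimizer step under data-parallel training.

    Assumes every GPU processes its own batch and no gradient accumulation.
    """
    if batch_size_per_gpu < 1 or gpus < 1:
        raise ValueError("batch size and GPU count must be positive")
    return batch_size_per_gpu * gpus


print(effective_batch_size(32, gpus=4))  # 128
print(effective_batch_size(32))          # 32
```

When dropping from four GPUs to one, the total batch size therefore shrinks from 128 to 32 unless batch_size is raised, which may affect results.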

1. Pre-training + fine-tuning

phase 1 - pre-training:

python3 train.py data_root=/path/to/data_root model=model_mae gpus=4 batch_size=32

phase 2 - fine-tuning:

(Note that the quotes in 'pretrained_weights="/path/to/pretrain_ckpt"' are necessary.)

python3 train.py data_root=/path/to/data_root model=model_forecast gpus=4 batch_size=32 monitor=val_minFDE 'pretrained_weights="/path/to/pretrain_ckpt"'
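The effect of the nested quotes can be checked with Python's `shlex` module, which splits a command line using POSIX shell rules. This sketch only shows what the program receives as its argument. Why the training script needs the inner double quotes is not stated here; presumably its configuration parser should treat the path as a string literal, but that is an assumption.

```python
import shlex

# With outer single quotes: the shell passes the inner double quotes through.
quoted = """python3 train.py 'pretrained_weights="/path/to/pretrain_ckpt"'"""
# Without them: the shell consumes the double quotes itself.
unquoted = 'python3 train.py pretrained_weights="/path/to/pretrain_ckpt"'


def last_argument(command):
    """Return the final argument as a POSIX shell would pass it to the program."""
    return shlex.split(command)[-1]


print(last_argument(quoted))    # pretrained_weights="/path/to/pretrain_ckpt"
print(last_argument(unquoted))  # pretrained_weights=/path/to/pretrain_ckpt
```

The same reasoning applies to the checkpoint argument of the evaluation commands.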

2. Training from scratch

python3 train.py data_root=/path/to/data_root model=model_forecast gpus=4 batch_size=32 monitor=val_minFDE

Evaluation

Evaluate on the validation set

python3 eval.py data_root=/path/to/data_root batch_size=64 'checkpoint="/path/to/checkpoint"'

Generate submission for the AV2 single-agent motion forecasting benchmark

python3 eval.py data_root=/path/to/data_root batch_size=64 'checkpoint="/path/to/checkpoint"' test=true

Results and checkpoints

For this repository, the expected performance on the Argoverse 2 validation set is:

Models                    minADE1  minFDE1  minADE6  minFDE6  MR6
Forecast-MAE (scratch)    1.802    4.529    0.8104   1.430    0.187
Forecast-MAE (fine-tune)  1.744    4.37     0.7984   1.408    0.178
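For readers unfamiliar with the column names, the sketch below computes these metrics for a toy example using their commonly used definitions: the subscript is the number of predicted trajectories K, minADE is the lowest average displacement error among them, minFDE is the lowest final-step error, and MR is the fraction of scenarios whose minFDE exceeds a threshold. The 2.0 m threshold and the function names are assumptions for illustration, and the official benchmark tooling may differ in detail (for example, in how the best trajectory is selected).

```python
import math


def displacement_errors(pred, gt):
    """Per-timestep Euclidean distance between two 2D trajectories."""
    return [math.hypot(px - gx, py - gy) for (px, py), (gx, gy) in zip(pred, gt)]


def min_ade_fde(preds, gt):
    """Return (minADE, minFDE) over the K candidate trajectories in preds."""
    ades, fdes = [], []
    for pred in preds:
        errs = displacement_errors(pred, gt)
        ades.append(sum(errs) / len(errs))  # average over all timesteps
        fdes.append(errs[-1])               # error at the final timestep
    return min(ades), min(fdes)


def miss_rate(min_fdes, threshold=2.0):
    """Fraction of scenarios whose minFDE is above the threshold (assumed 2.0 m)."""
    return sum(f > threshold for f in min_fdes) / len(min_fdes)


gt = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
preds = [
    [(0.0, 1.0), (1.0, 1.0), (2.0, 1.0)],  # constant 1 m offset
    [(0.0, 0.0), (1.0, 0.0), (2.0, 1.5)],  # exact until the last step
]
print(min_ade_fde(preds, gt))   # (0.5, 1.0)
print(miss_rate([1.0, 3.0]))    # 0.5
```

In this toy case the two minima come from different candidates, which is why minADE and minFDE are reported separately.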

You can download the checkpoints with the corresponding link.

Qualitative Results

demo

Acknowledgements

This repo benefits from MAE, Point-BERT, Point-MAE, NATTEN and HiVT. Thanks for their great work.

Citation

If you find this repository useful, please consider citing our work:

@article{cheng2023forecast,
  title={{Forecast-MAE}: Self-supervised Pre-training for Motion Forecasting with Masked Autoencoders},
  author={Cheng, Jie and Mei, Xiaodong and Liu, Ming},
  journal={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  year={2023}
}
