Associative Embedding: Training Code

Multi-person pose estimation with PyTorch based on:

Associative Embedding: End-to-end Learning for Joint Detection and Grouping. Alejandro Newell, Zhiao Huang, and Jia Deng. Neural Information Processing Systems (NIPS), 2017.

(A pretrained model in TensorFlow is also available here: https://github.com/umich-vl/pose-ae-demo)

Getting Started

This repository provides everything necessary to train and evaluate a multi-person pose estimation model on COCO keypoints. If you plan on training your own model from scratch, we highly recommend using multiple GPUs. We also provide a pretrained model.

Requirements:

  • Python 3 (code has been tested on Python 3.6)
  • PyTorch
  • CUDA and cuDNN
  • Python packages (not exhaustive): opencv-python, cffi, munkres, tqdm (json is also used, but it ships with the Python standard library)
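As a quick sanity check before going further, the requirements above can be probed from Python. The sketch below is not part of the repository; it only assumes the usual import names for the listed packages (opencv-python is imported as cv2 and PyTorch as torch), and the helper name find_missing is made up for illustration.

```python
import importlib.util

# Import names for the packages listed above (assumed, not taken from the repo):
# opencv-python is imported as cv2, PyTorch as torch.
REQUIRED_MODULES = ["torch", "cv2", "cffi", "munkres", "tqdm"]


def find_missing(modules):
    """Return the top-level module names that cannot be located for import."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = find_missing(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All listed modules were found.")
```

Since the package list is described as not exhaustive, a clean result here does not guarantee that training will run; it only catches the most common omissions early.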

Before using the repository there are a couple of setup steps:

First, you must compile the C implementation of the associative embedding loss. Go to extensions/AE/ and call python build.py install. If you run into errors with missing include files for CUDA, this can be addressed by first calling export CPATH=/path/to/cuda/include.
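If you prefer to script this step, the same build can be driven from Python. This is a minimal sketch rather than anything shipped with the repository: build_env and build_extension are hypothetical helper names, and /path/to/cuda/include is the same placeholder used above. Unlike a plain export, the sketch prepends to an existing CPATH instead of overwriting it.

```python
import os
import subprocess
import sys


def build_env(cuda_include=None, base_env=None):
    """Copy the environment and optionally put the CUDA include dir on CPATH."""
    env = dict(os.environ if base_env is None else base_env)
    if cuda_include:
        existing = env.get("CPATH")
        # Prepend so the given directory is searched first, keeping any old value.
        env["CPATH"] = (
            cuda_include if not existing else cuda_include + os.pathsep + existing
        )
    return env


def build_extension(ext_dir="extensions/AE", cuda_include=None):
    """Run `python build.py install` inside ext_dir; raises if the build fails."""
    return subprocess.run(
        [sys.executable, "build.py", "install"],
        cwd=ext_dir,
        env=build_env(cuda_include),
        check=True,
    )
```

From the repository root, build_extension(cuda_include="/path/to/cuda/include") would then mirror the two manual commands described above.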

Next, set up the COCO dataset. You can download it from here, and update the paths in data/coco_pose/ref.py to the correct directories for both images and annotations. After that, make sure to install the COCO PythonAPI from here.

You should be all set after that! For reference, the code is organized as follows:

  • data/: data loading and data augmentation code
  • models/: network architecture definitions
  • task/: task-specific functions and training configuration
  • utils/: image processing code and miscellaneous helper functions
  • extensions/: custom C code that needs to be compiled
  • train.py: code for model training
  • test.py: code for model evaluation

Training and Testing

To train a network, call:

python train.py -e test_run_001

The -e,--exp argument specifies an experiment name.

To continue an experiment where it left off, you can call:

python train.py -c test_run_001

All training hyperparameters are defined in task/pose.py, and you can modify __config__ to test different options. It is likely you will have to change the batchsize to accommodate the number of GPUs you have available.
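One way to keep the batch size in step with your hardware is to derive it from the GPU count. The sketch below is illustrative only: it assumes __config__ is a nested dictionary with a 'train' section holding a 'batchsize' entry, which you should verify against task/pose.py, and with_batchsize is a hypothetical helper that does not exist in the repository.

```python
import copy


def with_batchsize(config, num_gpus, per_gpu_batch, section="train", key="batchsize"):
    """Return a copy of config with the batch size set to num_gpus * per_gpu_batch.

    The section/key layout is an assumption about task/pose.py, not a confirmed fact.
    """
    if num_gpus < 1 or per_gpu_batch < 1:
        raise ValueError("num_gpus and per_gpu_batch must both be at least 1")
    new_config = copy.deepcopy(config)  # leave the original dictionary untouched
    new_config.setdefault(section, {})[key] = num_gpus * per_gpu_batch
    return new_config


# Example with a stand-in dictionary (values are made up for illustration).
example_config = {"train": {"batchsize": 32}}
print(with_batchsize(example_config, num_gpus=2, per_gpu_batch=8))
```

Working on a copy keeps the defaults in task/pose.py intact, which makes it easier to compare runs against the original settings.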

Once a model has been trained, you can evaluate it with:

python test.py -c test_run_001 -m [single|multi]

The argument -m,--mode selects single- or multi-scale evaluation. Single-scale evaluation is faster, but multi-scale evaluation yields large gains in performance. You can edit test.py to evaluate at more scales for further improvements.
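To give a feel for why extra scales help, the sketch below shows the general idea of combining detection heatmaps from several scales by averaging them. It is a generic illustration, not the code in test.py: it assumes the per-scale heatmaps have already been resized to one common resolution, represents them as plain nested lists, and average_heatmaps is a hypothetical name.

```python
def average_heatmaps(heatmaps):
    """Average a list of equally sized 2D heatmaps given as lists of rows."""
    if not heatmaps:
        raise ValueError("need at least one heatmap")
    height, width = len(heatmaps[0]), len(heatmaps[0][0])
    for hm in heatmaps:
        if len(hm) != height or any(len(row) != width for row in hm):
            raise ValueError("all heatmaps must share one resolution")
    count = len(heatmaps)
    # Element-wise mean across scales.
    return [
        [sum(hm[y][x] for hm in heatmaps) / count for x in range(width)]
        for y in range(height)
    ]


# Two tiny 1x2 "heatmaps" standing in for predictions at two scales.
print(average_heatmaps([[[0.0, 1.0]], [[1.0, 1.0]]]))
```

Averaging smooths out scale-specific errors, so a joint that is missed at one scale can still be recovered from the others, at the cost of one forward pass per scale.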

Training/Validation split

This repository includes the predefined training/validation split that we use in our experiments. The file data/coco_pose/valid_id lists all images used for validation.
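If you need the same split in your own scripts, it can be reproduced along these lines. The sketch assumes that valid_id is a text file with one integer image id per line, which you should confirm by inspecting the file; parse_ids and split_ids are hypothetical helpers, not functions from the repository.

```python
def parse_ids(lines):
    """Parse an iterable of text lines into a set of integer ids, skipping blanks."""
    return {int(line) for line in lines if line.strip()}


def split_ids(all_ids, valid_ids):
    """Partition all_ids into sorted (train, valid) lists using the validation set."""
    train = sorted(i for i in all_ids if i not in valid_ids)
    valid = sorted(i for i in all_ids if i in valid_ids)
    return train, valid


# Example with made-up ids; with the real file you would pass an open file object:
#     with open("data/coco_pose/valid_id") as f:
#         valid_ids = parse_ids(f)
valid_ids = parse_ids(["3\n", "\n", "5\n"])
print(split_ids([1, 2, 3, 4, 5], valid_ids))
```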

Pretrained model

To evaluate the pretrained model, download it from here and unpack the file into exp/. Then call:

python test.py -c pretrained -m single

That should return an mAP of about 0.59 for single-scale evaluation and 0.66 for multi-scale evaluation (performance can be improved further by evaluating at more than the default 3 scales). Results will not necessarily be the same on the COCO test sets.

To use this model for your own images, you can set up code to pass your own data to the multiperson function in test.py.
