Image Captioning

Overview

This repository contains PyTorch implementations of Show and Tell: A Neural Image Caption Generator and Show, Attend and Tell: Neural Image Caption Generation with Visual Attention. These models were among the first neural approaches to image captioning and remain useful benchmarks against newer models.
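
The two models differ mainly in how the decoder consumes image features: Show and Tell feeds a single image embedding into the LSTM, while Show, Attend and Tell computes a soft attention over spatial CNN features at every decoding step. As a rough illustration (a minimal sketch, not this repository's code; all dimensions are assumptions), a soft-attention step looks like:

import torch
import torch.nn as nn

class SoftAttention(nn.Module):
    # Minimal sketch of the soft attention in Show, Attend and Tell.
    # Dimensions are illustrative, not taken from this repository.
    def __init__(self, feature_dim=512, hidden_dim=512, attn_dim=512):
        super().__init__()
        self.feat_proj = nn.Linear(feature_dim, attn_dim)
        self.hidden_proj = nn.Linear(hidden_dim, attn_dim)
        self.score = nn.Linear(attn_dim, 1)

    def forward(self, features, h):
        # features: (batch, num_regions, feature_dim), e.g. 14*14 VGG regions
        # h: (batch, hidden_dim), the decoder LSTM state
        energy = torch.tanh(self.feat_proj(features) + self.hidden_proj(h).unsqueeze(1))
        alpha = torch.softmax(self.score(energy).squeeze(-1), dim=1)  # (batch, num_regions)
        context = (alpha.unsqueeze(-1) * features).sum(dim=1)         # (batch, feature_dim)
        return context, alpha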


Installation

The code was written for Python 3.6 or higher and has been tested with PyTorch 0.4.1. Training requires a GPU. To get started, clone the repository:

git clone https://github.com/tangbinh/image-captioning
cd image-captioning
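
Since training requires a GPU, it's worth confirming that your PyTorch installation can actually see one before going further:

import torch
print(torch.__version__, 'CUDA available:', torch.cuda.is_available())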

Preprocessing

First, you need to download images and captions from the COCO website. By default, we use train2014, val2014, and val2017 for training, validation, and testing, respectively. The data directory should have the following structure:

.
├── annotations
│   ├── captions_train2014.json
│   ├── captions_val2014.json
│   └── captions_val2017.json
└── images
    ├── train2014
    │   └── COCO_train2014_000000000092.jpg
    ├── val2014
    │   └── COCO_val2014_000000000042.jpg
    └── val2017
        └── 000000000139.jpg
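
Each annotation file uses the standard COCO captions format: a top-level images list plus an annotations list whose entries pair an image_id with a caption string. If you want to sanity-check a download, a quick (hypothetical) inspection script would be:

import json

with open('annotations/captions_train2014.json') as f:
    data = json.load(f)

print(len(data['images']), 'images,', len(data['annotations']), 'captions')
print(data['annotations'][0])  # {'image_id': ..., 'id': ..., 'caption': '...'}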

Once all the annotations and images are downloaded to, say, DATA_DIR, you can run the following command to map caption words to dictionary indices and extract image features from a pretrained VGG19 network:

python preprocess.py --data $DATA_DIR --dest-dir $DEST_DIR

Note that the resulting directory DEST_DIR will be quite large: the features for the training and validation images alone take up 157GB and 77GB, respectively. Experiments with HDF5 show a significant slowdown due to concurrent access with multiple data workers (see this discussion and this note). Hence, the preprocessing script saves the CNN features of different images into separate files.
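
For reference, the per-image feature dumping might look roughly like the sketch below; the exact layer, file format, and paths used by preprocess.py may differ (the .npy format here is an assumption):

import os
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# VGG19 up to conv5_4 (dropping the final max-pool) yields a 512x14x14
# feature map for a 224x224 input -- the spatial "regions" the attention
# model attends over.
vgg = models.vgg19(pretrained=True).features[:-1].eval().cuda()

normalize = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

def dump_features(image_path, dest_dir):
    image = normalize(Image.open(image_path).convert('RGB'))
    with torch.no_grad():
        features = vgg(image.unsqueeze(0).cuda())  # (1, 512, 14, 14)
    name = os.path.splitext(os.path.basename(image_path))[0]
    # One file per image sidesteps the HDF5 concurrent-access issue.
    np.save(os.path.join(dest_dir, name + '.npy'), features.squeeze(0).cpu().numpy())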


Training

To get started with training a model on COCO, you might find the following command helpful:

python train.py --arch show_attend_tell --data $DEST_DIR --save-dir checkpoints/show_attend_tell --log-file logs/show_attend_tell.log

The show-attend-tell model reaches a validation loss of 2.761 after the first epoch. The loss decreases to 2.298 after 20 epochs and does not drop below 2.266 after 50 epochs. Although the implementation doesn't support fine-tuning the CNN, this feature can be added quite easily and would probably yield better performance.
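
If you want to experiment with fine-tuning, a common recipe is to unfreeze only the top convolutional block and train it with a much smaller learning rate. The sketch below is hypothetical; encoder and decoder are stand-in names, not modules from this repository:

import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

# Hypothetical stand-ins: a VGG19 encoder and a placeholder caption decoder.
encoder = models.vgg19(pretrained=True).features[:-1]
decoder = nn.LSTM(input_size=512, hidden_size=512)

for param in encoder.parameters():
    param.requires_grad = False
for param in list(encoder.parameters())[-8:]:  # the four conv5 layers (weight and bias each)
    param.requires_grad = True

# Give the unfrozen CNN layers a much smaller learning rate than the decoder.
optimizer = optim.Adam([
    {'params': decoder.parameters(), 'lr': 1e-3},
    {'params': [p for p in encoder.parameters() if p.requires_grad], 'lr': 1e-5},
])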

Prediction

When training is done, you can generate captions for the test set and compute BLEU scores:

python generate.py --checkpoint-path checkpoints/show_attend_tell/checkpoint_best.pt > /tmp/show_attend_tell.out
grep ^H /tmp/show_attend_tell.out | cut -f2- | sed -r 's/'$(echo -e "\033")'\[[0-9]{1,2}(;([0-9]{1,2})?)?[mK]//g' > /tmp/show_attend_tell.sys
grep ^T /tmp/show_attend_tell.out | cut -f2- | sed -r 's/'$(echo -e "\033")'\[[0-9]{1,2}(;([0-9]{1,2})?)?[mK]//g' > /tmp/show_attend_tell.ref
python score.py --reference /tmp/show_attend_tell.ref --system /tmp/show_attend_tell.sys
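
The grep calls pick out the hypothesis (H) and reference (T) lines from the generated output, and the sed expressions strip ANSI color codes. If you want to double-check score.py's numbers independently, a corpus-level BLEU via NLTK (an assumption; any BLEU implementation works) would be:

from nltk.translate.bleu_score import corpus_bleu

with open('/tmp/show_attend_tell.sys') as f:
    hypotheses = [line.split() for line in f]
with open('/tmp/show_attend_tell.ref') as f:
    references = [[line.split()] for line in f]  # one reference per hypothesis

print('BLEU-4: %.4f' % corpus_bleu(references, hypotheses))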

Visualization

To display generated captions alongside their corresponding images, run the following command:

python visualize.py --checkpoint-path checkpoints/show_attend_tell/checkpoint_best.pt --coco-path $DATA_DIR
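
For the attention model, you can also render the per-word attention weights yourself. The sketch below makes assumptions about names and shapes and is not visualize.py's actual code: it upsamples each 14x14 attention map and overlays it on the resized image:

import numpy as np
import matplotlib.pyplot as plt
import skimage.transform
from PIL import Image

def show_attention(image_path, words, alphas):
    # words: generated caption tokens; alphas: (len(words), 14, 14) attention
    # weights from the decoder. Both names and shapes are assumptions.
    image = np.asarray(Image.open(image_path).resize((224, 224)))
    for i, word in enumerate(words):
        plt.subplot(int(np.ceil(len(words) / 5.0)), 5, i + 1)
        plt.imshow(image)
        # Smoothly upsample the 14x14 map to 224x224 before overlaying.
        alpha = skimage.transform.pyramid_expand(alphas[i], upscale=16, sigma=8)
        plt.imshow(alpha, alpha=0.7, cmap='gray')
        plt.title(word)
        plt.axis('off')
    plt.show()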
