
CLIPTrans

[Project page] [arXiv]

*Training pipeline GIF*

Official implementation for the paper "CLIPTrans: Transferring Visual Knowledge with Pre-trained Models for Multimodal Machine Translation", published at ICCV'23. The aim of the paper is to leverage existing pre-trained models (the multilingual mBART and the multimodal M-CLIP) for multimodal machine translation. More generally, it proposes a framework (pictured above) for multilingual generative tasks using multimodal data.

Setup

Set up the repository with the following commands:

git clone --recursive https://github.com/xjiaf/GraphCLIP.git
cd GraphCLIP
conda env create --file environment.yml
conda activate graphclip
pip install -r requirements.txt
python -m spacy download en
cd transformers
pip install -e .

Data

All data should be organised in the data/ directory.

Multi30k

Download the images for Flickr30k and unzip them in the data/multi30k folder.

Then, simply execute setup_multi30k.sh to download the text data and organise the folders.
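As a sketch, assuming the Flickr30k images arrive as a single archive (the archive name below is an assumption; use whatever file you downloaded), the two steps amount to:

```shell
# Extract the Flickr30k image archive into data/multi30k.
# flickr30k-images.tar.gz is an assumed name, not fixed by this repo.
mkdir -p data/multi30k
tar -xzf flickr30k-images.tar.gz -C data/multi30k

# Download the text data and organise the folders.
bash setup_multi30k.sh
```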

WIT

The WIT dataset used in our paper can be downloaded from here into the data/wit folder. Once downloaded, unzip with tar -xvzf wit_mmt.tar.gz. The images for each of these can be downloaded by running the following command:

python download_images.py $FOLDER_NAME $SPLIT

where FOLDER_NAME can be one of [es_fr, en_ro, en_fr, en_es, en_de, en_af, de_es] and SPLIT can be one of [train, test, valid]. This will take a while. Also note that downloading depends on each image still being available on the hosting service, so some variance in the scores is to be expected.
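Since each language pair needs all three splits, a small loop over the documented command saves typing (a convenience sketch, not part of the repository):

```shell
# Download images for every split of one language pair (en_de here).
for split in train test valid; do
  python download_images.py en_de "$split"
done
```

The same loop can be nested over the folder names above to fetch every pair in one go.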

Pre-trained Models and Inference

The sharded pre-trained models can be found here. For a particular model, download all of its shards and place them in the corresponding folder, e.g. models/multi30k-en-de. See the --mn flag description under Flags below for more details on the naming scheme. While loading these models, the code will automatically detect the shards and use them. Inference can be run with the command:

python src/main.py --num_gpus 1 --mn multi30k --src_lang en --tgt_lang fr --prefix_length 10 --bs 32 --test_ds 2016 flickr --stage translate --test --lm model_best_test.pth

or for a multi-GPU setup:

python -m torch.distributed.run --nproc_per_node 4 src/main.py --num_gpus 4  --mn multi30k --src_lang en --tgt_lang fr --prefix_length 10 --bs 32 --test_ds 2016 flickr --stage translate --test --lm model_best_test.pth

Training

Training is done in two stages. To run the first stage (captioning), use one of the following commands, depending on the number of available GPUs:

python src/main.py --num_gpus 1 --mn multi30k --prefix_length 10 --bs 32 --update_count 4 --lr 1e-5 --test_ds 2016 val --stage caption --tgt_lang fr
python -m torch.distributed.run --nproc_per_node 4 src/main.py --num_gpus 4 --mn multi30k --prefix_length 10 --bs 32 --update_count 4 --lr 1e-5 --test_ds 2016 val --stage caption --tgt_lang fr

For stage 2, use the following commands:

python src/main.py --num_gpus 1 --mn multi30k --prefix_length 10 --bs 32 --update_count 4 --lr 1e-5 --test_ds 2016 val --stage translate --tgt_lang fr --lm model_pretrained.pth
python -m torch.distributed.run --nproc_per_node 4 src/main.py --num_gpus 4 --mn multi30k --prefix_length 10 --bs 32 --update_count 4 --lr 1e-5 --test_ds 2016 val --stage translate --tgt_lang fr --lm model_pretrained.pth

Flags

Here is a quick guide to some specifics about the flags:

  1. --stage denotes the training task. There are four choices available, which are detailed in the table below. These affect the training task, and inference is modified appropriately:

     | --stage    | CLIP input  | mBART input | mBART output  |
     |------------|-------------|-------------|---------------|
     | caption    | image       | trivial     | image caption |
     | translate  | source text | source text | target text   |
     | text_recon | source text | trivial     | source text   |
     | triplet    | image       | source text | target text   |
  2. --mn sets the name of the job. It is used to create a unique model folder where the weights are stored and can be loaded from. The source and target languages are appended to this name. It must remain the same across stage 1 and stage 2.
  3. --lm is the name of the weights file to be loaded (which gets saved in the aforementioned folder). For final results, load model_best_test.pth. Stage 1 models are saved as model_pretrained.pth. To continue training from a saved model, load its weights and add the flag --ct.
  4. --test_ds sets the dataset to be used for validation/testing. While training, pass 2016 val (Multi30k) or valid (WIT). For inference, pass 2016 flickr, 2017 flickr or 2017 mscoco for Multi30k, and test for WIT. Also add the flag --test so that only inference is run on a saved model.
  5. To finetune an mBART on a dataset, simply pass --prefix_length 0.
  6. To use images in inference, add the flags --noise_test --mask_prob 0 to the inference command.
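Putting a few of these flags together: to resume a stage-2 Multi30k run from a saved checkpoint and then evaluate it, one could use the following (a sketch combining the documented flags; adjust languages, batch size and checkpoint name to your setup):

```shell
# Resume stage-2 (translation) training from a saved checkpoint via --ct.
python src/main.py --num_gpus 1 --mn multi30k --prefix_length 10 --bs 32 \
    --update_count 4 --lr 1e-5 --test_ds 2016 val --stage translate \
    --tgt_lang fr --lm model_best_test.pth --ct

# Then run inference only (--test) on the 2017 flickr test split.
python src/main.py --num_gpus 1 --mn multi30k --src_lang en --tgt_lang fr \
    --prefix_length 10 --bs 32 --test_ds 2017 flickr --stage translate \
    --test --lm model_best_test.pth
```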

If the code and/or method was useful to your work, please consider citing us!

@inproceedings{gupta2023cliptrans,
    title={CLIPTrans: Transferring Visual Knowledge with Pre-trained Models for Multimodal Machine Translation},
    author={Gupta, Devaansh and Kharbanda, Siddhant and Zhou, Jiawei and Li, Wanhua and Pfister, Hanspeter and Wei, Donglai},
    booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
    year={2023}
}
