
hbi's Introduction

【CVPR'2023 Highlight🔥】Video-Text as Game Players: Hierarchical Banzhaf Interaction for Cross-Modal Representation Learning


This repository contains the implementation of the CVPR 2023 Highlight paper (top 10% of accepted papers) Video-Text as Game Players: Hierarchical Banzhaf Interaction for Cross-Modal Representation Learning.

In this paper, we model video and text as game players using multivariate cooperative game theory. This lets us handle the uncertainty in fine-grained semantic interaction, which has diverse granularity, flexible combinations, and vague intensity.

📌 Citation

If you find this paper useful, please consider starring 🌟 this repo and citing 📑 our paper:

@inproceedings{jin2023video,
  title={Video-text as game players: Hierarchical banzhaf interaction for cross-modal representation learning},
  author={Jin, Peng and Huang, Jinfa and Xiong, Pengfei and Tian, Shangxuan and Liu, Chang and Ji, Xiangyang and Yuan, Li and Chen, Jie},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={2472--2482},
  year={2023}
}
💡 I also have other text-video retrieval projects that may interest you ✨.

DiffusionRet: Generative Text-Video Retrieval with Diffusion Model
Accepted by ICCV 2023 | [DiffusionRet Code]
Peng Jin, Hao Li, Zesen Cheng, Kehan Li, Xiangyang Ji, Chang Liu, Li Yuan, Jie Chen

Expectation-Maximization Contrastive Learning for Compact Video-and-Language Representations
Accepted by NeurIPS 2022 | [EMCL Code]
Peng Jin, Jinfa Huang, Fenglin Liu, Xian Wu, Shen Ge, Guoli Song, David Clifton, Jie Chen

Text-Video Retrieval with Disentangled Conceptualization and Set-to-Set Alignment
Accepted by IJCAI 2023 | [DiCoSA Code]
Peng Jin, Hao Li, Zesen Cheng, Jinfa Huang, Zhennan Wang, Li Yuan, Chang Liu, Jie Chen

📣 Updates

  • [2023/10/15]: We release our pre-trained estimator weights. If you want to apply the estimator to other tasks, you can initialize a new estimator with the weights we provide. For better performance, you can train the estimator with a smaller learning rate for more epochs.
  • [2023/10/11]: We release code for Banzhaf Interaction estimator. Recommended running parameters will be provided shortly, and we will also release our pre-trained estimator weights.
  • [2023/10/08]: I am working on the code for Banzhaf Interaction estimator, which is expected to be released soon.
  • [2023/06/28]: We release code for reproducing the experiments in the paper.
  • [2023/03/28]: Our HBI has been selected as a Highlight paper at CVPR 2023! (Top 2.5% of 9155 submissions).
  • [2023/02/28]: We will release the code as soon as possible. (I am currently busy with other deadlines and will open-source the code once they are done. Thank you for your understanding.)

⚡ Demo

demo_github.mp4

😍 Visualization

Example 1


Example 2

Example 3

Example 4

Example 5

Example 6

Example 7

🚀 Quick Start

Setup

Setup code environment

conda create -n HBI python=3.9
conda activate HBI
pip install -r requirements.txt
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 -f https://download.pytorch.org/whl/torch_stable.html

Download CLIP Model

cd HBI/models
wget https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt
# wget https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt
# wget https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt
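The CLIP download URLs above embed each file's SHA-256 digest as the second-to-last path component, and the official CLIP package checks downloads against that value. A minimal sketch for verifying a manually downloaded checkpoint the same way (the helper names are hypothetical, and the path assumes you ran the commands from the directory containing HBI/):

```python
import hashlib
from pathlib import Path


def expected_sha256_from_url(url: str) -> str:
    """CLIP URLs look like .../clip/models/<sha256>/<name>.pt; return <sha256>."""
    return url.rstrip("/").split("/")[-2]


def sha256_of_file(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash the file in chunks so large checkpoints need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_clip_checkpoint(path: Path, url: str) -> bool:
    """True if the downloaded file matches the digest embedded in its URL."""
    return sha256_of_file(path) == expected_sha256_from_url(url)


if __name__ == "__main__":
    url = ("https://openaipublic.azureedge.net/clip/models/"
           "40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt")
    ckpt = Path("HBI/models/ViT-B-32.pt")  # adjust to where wget saved the file
    if ckpt.exists():
        print("checksum OK" if verify_clip_checkpoint(ckpt, url) else "checksum MISMATCH")
    else:
        print(f"{ckpt} not found; run the wget command first")
```

A mismatch usually means an interrupted download; deleting the file and re-running wget is the simplest fix.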

Download Datasets

| Dataset | Google Cloud | Baidu Yun | Peking University Yun |
| --- | --- | --- | --- |
| MSR-VTT | Download | Download | Download |
| MSVD | Download | Download | Download |
| ActivityNet | TODO | Download | Download |
| DiDeMo | TODO | Download | Download |

Train the Banzhaf Interaction Estimator

Train the estimator on the labels generated by the BanzhafInteraction class in HBI/models/banzhaf.py.
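The estimator learns to reproduce Banzhaf interaction values. As a rough illustration of what such a label measures (this is not the code in HBI/models/banzhaf.py), the sketch below computes the textbook pairwise Banzhaf interaction index for a toy cooperative game, once exactly and once by Monte Carlo sampling of coalitions. `toy_value`, its weights, and all function names are hypothetical; in HBI the players would be text and video tokens and the value function would come from cross-modal similarity, which this sketch does not model.

```python
import itertools
import random
from typing import Callable, FrozenSet, Sequence

ValueFn = Callable[[FrozenSet[int]], float]


def _pair_term(v: ValueFn, coalition: FrozenSet[int], i: int, j: int) -> float:
    """Interaction of i and j given coalition C of the other players:
    v(C + {i, j}) - v(C + {i}) - v(C + {j}) + v(C)."""
    return (v(coalition | {i, j}) - v(coalition | {i})
            - v(coalition | {j}) + v(coalition))


def banzhaf_interaction_exact(v: ValueFn, players: Sequence[int], i: int, j: int) -> float:
    """Average the pair term over all 2^(n-2) coalitions of the other players."""
    if i == j:
        raise ValueError("i and j must be different players")
    others = [p for p in players if p not in (i, j)]
    total, count = 0.0, 0
    for r in range(len(others) + 1):
        for combo in itertools.combinations(others, r):
            total += _pair_term(v, frozenset(combo), i, j)
            count += 1
    return total / count  # count == 2 ** (n - 2)


def banzhaf_interaction_mc(v: ValueFn, players: Sequence[int], i: int, j: int,
                           num_samples: int = 100, seed: int = 0) -> float:
    """Monte Carlo estimate: include each other player with probability 1/2
    (a uniform draw over coalitions), accumulate the terms, then average."""
    rng = random.Random(seed)
    others = [p for p in players if p not in (i, j)]
    total = 0.0
    for _ in range(num_samples):
        coalition = frozenset(p for p in others if rng.random() < 0.5)
        total += _pair_term(v, coalition, i, j)
    return total / num_samples


if __name__ == "__main__":
    weights = {0: 0.5, 1: 1.0, 2: 0.25, 3: 0.75}

    def toy_value(s: FrozenSet[int]) -> float:
        # Additive score plus a synergy bonus when players 0 and 1 cooperate.
        return sum(weights[k] for k in s) + (2.0 if {0, 1} <= s else 0.0)

    players = list(weights)
    print(banzhaf_interaction_exact(toy_value, players, 0, 1))  # 2.0
    print(banzhaf_interaction_mc(toy_value, players, 0, 1))     # 2.0
    print(banzhaf_interaction_exact(toy_value, players, 0, 2))  # 0.0
```

Exact enumeration grows as 2^(n-2), which is why sampling (and, in HBI, a learned estimator) is attractive when there are many tokens.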

The training code is in banzhaf_estimator.py. We also provide our trained weights; if you want to apply the estimator to other tasks, you can initialize a new estimator with them.

Using Estimator_1e-2_epoch6, we obtained an R@1 of 48.2 on MSR-VTT (see log). For better performance, train the estimator with a smaller learning rate for more epochs.
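R@1 is the percentage of text queries whose ground-truth video is ranked first. A minimal sketch of Recall@K and median rank (MdR) from a text-to-video similarity matrix, assuming one caption per video so that the correct video for query q is video q; this is a generic illustration, not the repository's evaluation code.

```python
from statistics import median
from typing import Dict, List, Sequence


def retrieval_metrics(sim: Sequence[Sequence[float]],
                      ks: Sequence[int] = (1, 5, 10)) -> Dict[str, float]:
    """Text-to-video Recall@K (in percent) and median rank.

    sim[q][c] is the score of text query q against video c; the ground-truth
    video of query q is assumed to be video q (a 1:1 test split).
    """
    n = len(sim)
    ranks: List[int] = []
    for q, row in enumerate(sim):
        if len(row) != n:
            raise ValueError("similarity matrix must be square")
        # 1-based rank of the correct video: 1 + number of videos scored
        # strictly higher (ties are resolved in the query's favour).
        ranks.append(1 + sum(1 for score in row if score > row[q]))
    metrics = {f"R@{k}": 100.0 * sum(r <= k for r in ranks) / n for k in ks}
    metrics["MdR"] = float(median(ranks))
    return metrics


if __name__ == "__main__":
    sim = [[0.9, 0.1, 0.3],
           [0.2, 0.4, 0.8],
           [0.1, 0.7, 0.6]]
    # Only query 0 ranks its own video first, so R@1 is one third.
    print(retrieval_metrics(sim, ks=(1, 2)))
```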

| Model | Google Cloud | Baidu Yun | Peking University Yun | Log |
| --- | --- | --- | --- | --- |
| Estimator_1e-2_epoch1 | Download | Download | Download | log |
| Estimator_1e-2_epoch2 | Download | Download | Download | log |
| Estimator_1e-2_epoch3 | Download | Download | Download | log |
| Estimator_1e-2_epoch4 | Download | Download | Download | log |
| Estimator_1e-2_epoch5 | Download | Download | Download | log |
| Estimator_1e-2_epoch6 | Download | Download | Download | log |

CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -m torch.distributed.launch \
--master_port 2502 \
--nproc_per_node=4 \
banzhaf_estimator.py \
--do_train 1 \
--workers 8 \
--n_display 1 \
--epochs 10 \
--lr 1e-2 \
--coef_lr 1e-3 \
--batch_size 128 \
--batch_size_val 128 \
--anno_path data/MSR-VTT/anns \
--video_path ${DATA_PATH}/MSRVTT_Videos \
--datatype msrvtt \
--max_words 24 \
--max_frames 12 \
--video_framerate 1 \
--output_dir ${OUTPUT_PATH} 

Text-video Retrieval

| Checkpoint | Google Cloud | Baidu Yun | Peking University Yun |
| --- | --- | --- | --- |
| MSR-VTT | Download | Download | Download |
| ActivityNet | Download | Download | Download |

Eval on MSR-VTT

CUDA_VISIBLE_DEVICES=0,1 \
python -m torch.distributed.launch \
--master_port 2502 \
--nproc_per_node=2 \
main_retrieval.py \
--do_eval 1 \
--workers 8 \
--n_display 50 \
--batch_size_val 128 \
--anno_path data/MSR-VTT/anns \
--video_path ${DATA_PATH}/MSRVTT_Videos \
--datatype msrvtt \
--max_words 24 \
--max_frames 12 \
--video_framerate 1 \
--init_model ${CHECKPOINT_PATH} \
--output_dir ${OUTPUT_PATH} 

Train on MSR-VTT

CUDA_VISIBLE_DEVICES=0,1 \
python -m torch.distributed.launch \
--master_port 2502 \
--nproc_per_node=2 \
main_retrieval.py \
--do_train 1 \
--workers 8 \
--n_display 50 \
--epochs 5 \
--lr 1e-4 \
--coef_lr 1e-3 \
--batch_size 128 \
--batch_size_val 128 \
--anno_path data/MSR-VTT/anns \
--video_path ${DATA_PATH}/MSRVTT_Videos \
--datatype msrvtt \
--max_words 24 \
--max_frames 12 \
--video_framerate 1 \
--estimator ${ESTIMATOR_PATH} \
--output_dir ${OUTPUT_PATH} \
--kl 2 \
--skl 1
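The training commands pass both --lr and --coef_lr. In CLIP4Clip, one of the codebases this repository builds on, coef_lr scales the learning rate of the pretrained CLIP parameters relative to newly added modules. Assuming HBI follows the same convention, the MSR-VTT setting above would train CLIP at about 1e-4 * 1e-3 = 1e-7. The sketch below shows that grouping; the helper and the `clip.` name prefix are hypothetical, so check main_retrieval.py for the actual logic.

```python
from typing import Any, Dict, Iterable, List, Tuple


def build_param_groups(
    named_params: Iterable[Tuple[str, Any]],
    lr: float,
    coef_lr: float,
    backbone_prefix: str = "clip.",  # hypothetical prefix for CLIP parameters
) -> List[Dict[str, Any]]:
    """Split parameters into a pretrained-backbone group (lr * coef_lr) and a
    group of newly initialized modules (full lr).

    The returned list follows the per-parameter-group format accepted by
    torch.optim optimizers: a list of dicts with "params" and "lr" keys.
    """
    backbone, new = [], []
    for name, param in named_params:
        (backbone if name.startswith(backbone_prefix) else new).append(param)
    return [
        {"params": backbone, "lr": lr * coef_lr},  # small lr preserves CLIP features
        {"params": new, "lr": lr},                 # full lr for new layers
    ]


if __name__ == "__main__":
    fake_params = [("clip.visual.conv1.weight", "w0"), ("head.fc.weight", "w1")]
    for group in build_param_groups(fake_params, lr=1e-4, coef_lr=1e-3):
        print(group["params"], group["lr"])
```

With real tensors you would pass model.named_parameters() and hand the result to an optimizer such as torch.optim.AdamW.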

Eval on ActivityNet Captions

CUDA_VISIBLE_DEVICES=0,1 \
python -m torch.distributed.launch \
--master_port 2502 \
--nproc_per_node=2 \
main_retrieval.py \
--do_eval 1 \
--workers 8 \
--n_display 50 \
--batch_size_val 128 \
--anno_path ${DATA_PATH}/ActivityNet \
--video_path ${DATA_PATH}/ActivityNet/Activity_Videos \
--datatype activity \
--max_words 64 \
--max_frames 64 \
--video_framerate 1 \
--init_model ${CHECKPOINT_PATH} \
--output_dir ${OUTPUT_PATH} 

Train on ActivityNet Captions

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -m torch.distributed.launch \
--master_port 2502 \
--nproc_per_node=8 \
main_retrieval.py \
--do_train 1 \
--workers 8 \
--n_display 10 \
--epochs 10 \
--lr 1e-4 \
--coef_lr 1e-3 \
--batch_size 128 \
--batch_size_val 128 \
--anno_path ${DATA_PATH}/ActivityNet \
--video_path ${DATA_PATH}/ActivityNet/Activity_Videos \
--datatype activity \
--max_words 64 \
--max_frames 64 \
--video_framerate 1 \
--estimator ${ESTIMATOR_PATH} \
--output_dir ${OUTPUT_PATH} \
--kl 2 \
--skl 1

Video-question Answering

| Checkpoint | Google Cloud | Baidu Yun | Peking University Yun |
| --- | --- | --- | --- |
| MSR-VTT-QA | Download | Download | Download |

Eval on MSR-VTT-QA

CUDA_VISIBLE_DEVICES=0,1 \
python -m torch.distributed.launch \
--master_port 2502 \
--nproc_per_node=2 \
main_vqa.py \
--do_eval \
--num_thread_reader=8 \
--train_csv data/MSR-VTT/qa/train.jsonl \
--val_csv data/MSR-VTT/qa/test.jsonl \
--data_path data/MSR-VTT/qa/train_ans2label.json \
--features_path ${DATA_PATH}/MSRVTT_Videos \
--max_words 32 \
--max_frames 12 \
--batch_size_val 16 \
--datatype msrvtt \
--expand_msrvtt_sentences  \
--feature_framerate 1 \
--freeze_layer_num 0  \
--slice_framepos 2 \
--loose_type \
--linear_patch 2d \
--init_model ${CHECKPOINT_PATH} \
--output_dir ${OUTPUT_PATH}
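The QA commands pass an answer vocabulary via --data_path (train_ans2label.json), which suggests answering is treated as classification over training-set answers. A minimal sketch of top-1 accuracy under that framing, assuming the file is a JSON object mapping answer strings to integer class ids; the helper names are hypothetical and this is not the repository's evaluation code.

```python
import json
from typing import Dict, Sequence


def load_ans2label(path: str) -> Dict[str, int]:
    """ASSUMPTION: the file is a JSON object {answer string: class id}."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def qa_accuracy(pred_ids: Sequence[int], gt_answers: Sequence[str],
                ans2label: Dict[str, int]) -> float:
    """Top-1 accuracy. A ground-truth answer missing from the training
    vocabulary can never be predicted, so it counts as wrong."""
    if len(pred_ids) != len(gt_answers) or not gt_answers:
        raise ValueError("need equally sized, non-empty predictions and answers")
    correct = sum(
        1 for pred, ans in zip(pred_ids, gt_answers)
        if ans in ans2label and ans2label[ans] == pred
    )
    return correct / len(gt_answers)


if __name__ == "__main__":
    ans2label = {"man": 0, "dog": 1, "cooking": 2}
    # "cat" is out of vocabulary and the last prediction is wrong.
    print(qa_accuracy([0, 2, 1, 1], ["man", "cooking", "cat", "man"], ans2label))  # 0.5
```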

Train on MSR-VTT-QA

CUDA_VISIBLE_DEVICES=0,1 \
python -m torch.distributed.launch \
--master_port 2502 \
--nproc_per_node=2 \
main_vqa.py \
--do_train \
--num_thread_reader=8 \
--epochs=5 \
--batch_size=32 \
--n_display=50 \
--train_csv data/MSR-VTT/qa/train.jsonl \
--val_csv data/MSR-VTT/qa/test.jsonl \
--data_path data/MSR-VTT/qa/train_ans2label.json \
--features_path ${DATA_PATH}/MSRVTT_Videos \
--lr 1e-4 \
--max_words 32 \
--max_frames 12 \
--batch_size_val 16 \
--datatype msrvtt \
--expand_msrvtt_sentences  \
--feature_framerate 1 \
--coef_lr 1e-3 \
--freeze_layer_num 0  \
--slice_framepos 2 \
--loose_type \
--linear_patch 2d \
--estimator ${ESTIMATOR_PATH} \
--output_dir ${OUTPUT_PATH} \
--kl 2 \
--skl 1

🎗️ Acknowledgments

Our code is based on EMCL, CLIP, CLIP4Clip and DRL. We sincerely appreciate their contributions.

hbi's People

Contributors

jpthu17


hbi's Issues

Banzhaf Interaction questions

Thank you for your excellent work!
But I have a question: is the following code "banzhaf[:, i, j] = self.banzhaf_interaction(retrieve_logits, text_mask, video_mask, text_weight, video_weight, i, j)" missing a plus sign (i.e., should = be +=)?

for i in range(self.t_len):
    for j in range(self.v_len):
        for _ in range(self.num):
            banzhaf[:, i, j] = self.banzhaf_interaction(retrieve_logits, text_mask, video_mask, text_weight,
                                                        video_weight, i, j)
banzhaf = banzhaf / self.num
banzhaf = torch.einsum('btv,bt->btv', [banzhaf, text_mask])
banzhaf = torch.einsum('btv,bv->btv', [banzhaf, video_mask])
return banzhaf
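The question can be checked with a scalar analogue of the loop. Assuming banzhaf starts at zero and each call to self.banzhaf_interaction returns one random Monte Carlo sample, plain assignment keeps only the last sample, so the later division by self.num shrinks it instead of averaging. The two functions below are hypothetical stand-ins for the two versions of the innermost loop.

```python
from typing import Callable


def average_overwrite(sample: Callable[[], float], num: int) -> float:
    """Mirrors `banzhaf[:, i, j] = ...` followed by `banzhaf / self.num`."""
    value = 0.0
    for _ in range(num):
        value = sample()   # each draw overwrites the previous one
    return value / num     # last draw divided by num, not a mean


def average_accumulate(sample: Callable[[], float], num: int) -> float:
    """Mirrors `banzhaf[:, i, j] += ...` followed by `banzhaf / self.num`."""
    value = 0.0
    for _ in range(num):
        value += sample()  # sum all Monte Carlo draws
    return value / num     # proper sample mean


if __name__ == "__main__":
    print(average_overwrite(lambda: 3.0, 4))   # 0.75
    print(average_accumulate(lambda: 3.0, 4))  # 3.0
```

The two versions agree only when num == 1; under the assumptions above, accumulating with += is what makes the division by self.num a true average.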

What does the following code in the BanzhafInteraction class do?

Thank you for your great work!
But I'm still wondering how the Banzhaf Interaction works in the following codes:

        s_t = (torch.rand((self.t_len)) > 0.5).long().to(retrieve_logits.device) 
        s_j = (torch.rand((self.v_len)) > 0.5).long().to(retrieve_logits.device) 
        s_t[i], s_j[j] = 0, 0

        _text_mask, _video_mask = text_mask.clone(), video_mask.clone()
        _text_mask[:, s_t] = 0
        _video_mask[:, s_j] = 0

Does _text_mask[:, s_t] = 0 mask the first and second word tokens, since the values in s_t and s_j are only 1 and 0? Or do I have the wrong understanding?
Any reply would be helpful!
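The asker's reading matches how integer indexing works: in both PyTorch and NumPy, an integer (long) index array is interpreted as a list of positions, while a boolean array of the same length acts as a mask. The sketch uses NumPy to stay dependency-light; mask_with_int_index and mask_with_bool_index are hypothetical helpers, and that the intent is to drop the positions where s_t == 1 is an assumption.

```python
import numpy as np


def mask_with_int_index(mask: np.ndarray, s: np.ndarray) -> np.ndarray:
    """Like `_text_mask[:, s_t] = 0` with a long tensor: the VALUES of s
    (only 0 or 1 here) are column indices, so at most columns 0 and 1 change."""
    out = mask.copy()
    out[:, s] = 0
    return out


def mask_with_bool_index(mask: np.ndarray, s: np.ndarray) -> np.ndarray:
    """Boolean indexing: zero exactly the columns where s == 1."""
    out = mask.copy()
    out[:, s.astype(bool)] = 0
    return out


if __name__ == "__main__":
    text_mask = np.ones((2, 6), dtype=np.int64)
    s_t = np.array([0, 0, 1, 1, 0, 1])  # a fixed example of a random 0/1 draw
    print(mask_with_int_index(text_mask, s_t)[0])   # [0 0 1 1 1 1]
    print(mask_with_bool_index(text_mask, s_t)[0])  # [1 1 0 0 1 0]
```

In PyTorch the equivalent change would be indexing with a bool tensor, for example (torch.rand(t_len) > 0.5) without the .long() conversion.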

Hi

Hello, will you be posting an arXiv version soon? I am very interested in your paper and would like to read it as soon as possible. Thanks!

banzhaf_interaction

Did you omit _text_mask[:, i] = 0 when calculating banzhaf_value3?
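For reference, the textbook pairwise Banzhaf interaction index over a player set N with n = |N| players averages four coalition values over every coalition C of the remaining players. Each term fixes which of i and j is present, so the term v(C ∪ {j}) must be evaluated with player i absent; which term banzhaf_value3 corresponds to cannot be determined from this excerpt, and the paper may use a different normalization.

```latex
\mathcal{I}([ij]) = \sum_{C \subseteq N \setminus \{i, j\}} \frac{1}{2^{n-2}}
  \Big[ v(C \cup \{i, j\}) - v(C \cup \{i\}) - v(C \cup \{j\}) + v(C) \Big]
```

Because there are exactly 2^{n-2} such coalitions, the weight 1/2^{n-2} makes the sum a plain average, which is what uniform random sampling of C estimates.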

ActivityNet training hyperparameters

Hello, for the ActivityNet dataset, with max_words and max_frames both set to 64, do v_rate0 through t_rate1 keep the original MSR-VTT settings? Also, is the training batch size for ActivityNet 64 or 128?
