Comments (5)

TopCoder2K commented on July 25, 2024

Good afternoon, @ashkamath!

Thank you for the great MDETR!
I've been trying to train MDETR on the VQA2 dataset. As I wrote above, I managed to implement the code needed to run training on VQA2. Then I decided to reproduce your experiment from Appendix E of the article. I fine-tuned on GQA balanced with the --no_detection option for 10 epochs and then fine-tuned on VQA2 for 25 epochs. But the results are quite strange. It seems that the model hasn't learned: the loss on GQA has increased, and the loss on VQA2 has barely changed. Evaluation during training was performed on the val split of GQA and the minival split of VQA2.
image
image

Here are the commands I've used.
Fine-tuning on the GQA:

python run_with_submitit.py --dataset_config configs/gqa.json --ngpus 1 --nodes 1 --ema --epochs 10 --epoch_chunks 25 --do_qa --split_qa_heads --lr_drop 150 --resume https://zenodo.org/record/4721981/files/pretrained_resnet101_checkpoint.pth --batch_size 4 --no_detection --qa_loss_coef 25 --lr 1.4e-4 --lr_backbone 1.4e-5 --text_encoder_lr 7e-5 --output-dir ~/MDETR/mdetr/checkpoint

Fine-tuning on the VQA2:

python run_with_submitit.py --dataset_config configs/vqa2.json --ngpus 1 --nodes 1  --epochs 25 --epoch_chunks 25 --do_qa --split_qa_heads --lr_drop 150 --backbone resnet101 --load ~/MDETR/mdetr/checkpoint/pchelintsev/experiments/19311/BEST_checkpoint.pth --batch_size 4 --no_aux_loss --no_contrastive_align_loss --no_detection --qa_loss_coef 25 --lr 1.4e-4 --lr_backbone 1.4e-5 --text_encoder_lr 7e-5 --output-dir ~/MDETR/mdetr/checkpoint --do_qa_with_qa_fine-tuned

Evaluation on the VQA2:

python main.py --dataset_config configs/vqa2.json --eval --do_qa --split_qa_heads --no_contrastive_align_loss --no_aux_loss --no_detection --backbone resnet101 --qa_loss_coef 25 --resume ~/MDETR/mdetr/checkpoint/pchelintsev/experiments/5063/BEST_checkpoint.pth

The only significant difference is that I did 10 epochs on GQA balanced while you did 5 epochs on GQA all. But can it have such an impact? I suspect that my hyperparameters may be wrong... Could you please share the options and hyperparameter values you used?
Also, I'm curious how you loaded the model from BEST_checkpoint.pth after fine-tuning on GQA. I used the --load option, and mismatching heads (for example, the head for question types) are deleted before running load_state_dict().
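
The head-dropping step described above could be sketched roughly as follows. This is a minimal illustration, not the actual MDETR loading code: the helper name `drop_mismatched_entries` is hypothetical, and it only assumes that each value in a state dict exposes a `.shape` attribute, as PyTorch tensors do.

```python
from typing import Any, Dict, List, Tuple


def drop_mismatched_entries(
    checkpoint_state: Dict[str, Any],
    model_state: Dict[str, Any],
) -> Tuple[Dict[str, Any], List[str]]:
    """Keep only checkpoint entries whose name and shape match the target model.

    Hypothetical helper: both arguments are state-dict-like mappings whose
    values expose a `.shape` attribute (as torch.Tensor does).
    """
    kept: Dict[str, Any] = {}
    dropped: List[str] = []
    for name, value in checkpoint_state.items():
        target = model_state.get(name)
        if target is None or tuple(value.shape) != tuple(target.shape):
            # For example, a QA head sized for a different answer vocabulary.
            dropped.append(name)
        else:
            kept[name] = value
    return kept, dropped
```

With a PyTorch model, the filtered mapping would typically be passed as `model.load_state_dict(kept, strict=False)`, so that parameters absent from `kept` retain their fresh initialization.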

from mdetr.

TopCoder2K commented on July 25, 2024

An interesting thing I've noticed is that when running GQA for 5 epochs, the graphs are different, and the training actually progresses! But I still have no good results on VQA2. Compared to my previous comment, I've improved the dataset processing, fixed a small bug, and used --ema. So I have no idea why the training doesn't go well...

Here are the graphs of the total loss and some other metrics and losses:
GQA, 5 epochs, with --no_detection, torch.set_deterministic(True)
image
VQA2, 10 epochs, torch.set_deterministic(True), after fine-tuning on GQA balanced for 5 epochs with --no_detection
image

And here are the commands I used (running on GQA and VQA respectively):

python run_with_submitit.py --dataset_config configs/gqa.json --ngpus 1 --nodes 1 --ema --epochs 5 --epoch_chunks 25 --do_qa --split_qa_heads --lr_drop 150 --load pretrained_resnet101_checkpoint.pth --batch_size 4 --no_detection --qa_loss_coef 25 --lr 1.4e-4 --lr_backbone 1.4e-5 --text_encoder_lr 7e-5 --output-dir ~/MDETR/mdetr/checkpoint
python run_with_submitit.py --dataset_config configs/vqa2.json --ngpus 1 --nodes 1 --ema --epochs 10 --epoch_chunks 25 --do_qa --split_qa_heads --lr_drop 150 --load ~/MDETR/mdetr/checkpoint/pchelintsev/experiments/26220/BEST_checkpoint.pth --batch_size 4 --no_aux_loss --no_contrastive_align_loss --no_detection --qa_loss_coef 25 --lr 1.4e-4 --lr_backbone 1.4e-5 --text_encoder_lr 7e-5 --output-dir ~/MDETR/mdetr/checkpoint --do_qa_with_qa_fine-tuned

And I still have the same request: could you please share the options and hyperparameter values you used?


ashkamath commented on July 25, 2024

Hi!
For GQA, there was a typo in the paper and in the appendix. It was fixed in the main paper, but I seem to have forgotten to update the appendix. After pre-training on modulated detection, we fine-tune with the QA queries on GQA all for 125 epochs, with the command:

python run_with_submitit.py --dataset_config configs/gqa.json --ngpus 8 --ema --epochs 125 --epoch_chunks 25 --do_qa --split_qa_heads --lr_drop 150 --load https://zenodo.org/record/4721981/files/pretrained_resnet101_checkpoint.pth --nodes 4 --batch_size 4 --no_aux_loss --qa_loss_coef 25 --lr 1.4e-4 --lr_backbone 1.4e-5 --text_encoder_lr 7e-5
Here, it is important to keep the QA loss coefficient, which puts more weight on the QA losses than on the detection losses.

Using this model, we then fine-tune on VQA for 25 epochs (we used a BCE loss on the QA head):

python run_with_submitit.py --backbone "resnet101" --dataset_config configs/vqa2.json --num_queries 100 --batch_size 4 --num_workers 5 --schedule linear_with_warmup --text_encoder_type roberta-base --ngpus 8 --nodes 4 --ema --do_qa --load path/to/gqa/model --no_aux_loss --no_detection --lr 7e-5 --lr_backbone 1.4e-5 --text_encoder_lr 7e-5 --bce_qa --epochs 25

So in short, if you're training on GQA all, don't use --no_detection; then initialize from the best model after 125 epochs and fine-tune on VQA for 25.
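
One way to read the numbers in these commands, as a rough sketch rather than a statement about the actual training loop: assuming `--epoch_chunks N` splits the training set into N chunks and each unit counted by `--epochs` trains on one chunk, and assuming `--batch_size` is per GPU, the number of full dataset passes and the effective batch size follow from simple arithmetic. The function names below are hypothetical.

```python
def full_passes(epochs: int, epoch_chunks: int) -> float:
    """Full passes over the training set, assuming one chunk per 'epoch'."""
    if epoch_chunks <= 0:
        # Assumption: a non-positive value means chunking is disabled.
        return float(epochs)
    return epochs / epoch_chunks


def effective_batch_size(batch_size: int, ngpus: int, nodes: int) -> int:
    """Samples per optimizer step, assuming --batch_size is per GPU."""
    return batch_size * ngpus * nodes


# Multi-node GQA command: --epochs 125 --epoch_chunks 25 --ngpus 8 --nodes 4
print(full_passes(125, 25), effective_batch_size(4, 8, 4))
# Single-GPU GQA command: --epochs 10 --epoch_chunks 25 --ngpus 1 --nodes 1
print(full_passes(10, 25), effective_batch_size(4, 1, 1))
```

Under those assumptions, 125 epochs with 25 chunks correspond to 5 full passes over GQA, whereas 10 epochs with 25 chunks on a single GPU cover less than one pass at a much smaller effective batch size.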

Hope this helps! Feel free to get back with questions.

Best,
Aish


TopCoder2K commented on July 25, 2024

Thank you for the reply!

Ah, so you fine-tuned it on GQA all for 125 epochs and didn't use --no_detection, okay! By the way, there is no --bce_qa flag in main.py... Anyway, I don't have enough resources to train on GQA all for 125 epochs. I'll try the second option: fine-tuning the pre-trained model.
And could you provide the command you used to fine-tune the pre-trained model on VQA2?


TopCoder2K commented on July 25, 2024

I've tried the second option by running the following command:
python run_with_submitit.py --dataset_config configs/vqa2.json --ngpus 1 --nodes 1 --ema --epochs 10 --epoch_chunks 25 --do_qa --split_qa_heads --load pretrained_resnet101_checkpoint.pth --batch_size 4 --no_aux_loss --no_detection --lr 7e-5 --lr_backbone 1.4e-5 --text_encoder_lr 7e-5 --output-dir ~/MDETR/mdetr/checkpoint
Here I've tried to adapt it from your command for fine-tuning on VQA2 after GQA fine-tuning:

python run_with_submitit.py --backbone "resnet101" --dataset_config configs/vqa2.json --num_queries 100 --batch_size 4 --num_workers 5 --schedule linear_with_warmup --text_encoder_type roberta-base --ngpus 8 --nodes 4 --ema --do_qa --load path/to/gqa/model --no_aux_loss --no_detection --lr 7e-5 --lr_backbone 1.4e-5 --text_encoder_lr 7e-5 --bce_qa --epochs 25 

So, I have removed qa_loss_coef, changed lr, removed lr_drop, and turned on contrastive_align_loss.
Unfortunately, the results are not what I would like.
image
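
Flag changes like the ones listed above can be checked mechanically. The following is a minimal sketch using only the Python standard library; the helper names `parse_flags` and `diff_flags` are hypothetical, and the two command strings are shortened excerpts for illustration, not complete commands.

```python
import shlex
from typing import Dict, Union

FlagValue = Union[str, bool]


def parse_flags(command: str) -> Dict[str, FlagValue]:
    """Map each --flag to its value, or to True for a bare switch."""
    tokens = shlex.split(command)
    flags: Dict[str, FlagValue] = {}
    i = 0
    while i < len(tokens):
        token = tokens[i]
        if token.startswith("--"):
            has_value = i + 1 < len(tokens) and not tokens[i + 1].startswith("--")
            flags[token] = tokens[i + 1] if has_value else True
            i += 2 if has_value else 1
        else:
            i += 1  # skip the interpreter and script name
    return flags


def diff_flags(old: str, new: str) -> Dict[str, tuple]:
    """Return {flag: (old_value, new_value)} for differing flags; None means absent."""
    a, b = parse_flags(old), parse_flags(new)
    return {
        flag: (a.get(flag), b.get(flag))
        for flag in sorted(set(a) | set(b))
        if a.get(flag) != b.get(flag)
    }


# Shortened excerpts of two commands, for illustration only.
before = "python run_with_submitit.py --do_qa --lr_drop 150 --qa_loss_coef 25 --lr 1.4e-4"
after = "python run_with_submitit.py --do_qa --lr 7e-5"
for flag, (old_value, new_value) in diff_flags(before, after).items():
    print(flag, old_value, "->", new_value)
```

Pasting the full commands into `before` and `after` lists every flag that was added, removed, or changed between two runs.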

Could you please post the exact command you used?

UPD1:
Oh, I've just realized that I can use your gqa_resnet101_checkpoint.pth! I've already set up the experiments with this checkpoint! Looking forward to the results!

UPD2:
Yeah, I've got good results, thank you!

