CFA for SGG in PyTorch

Our paper Compositional Feature Augmentation for Unbiased Scene Graph Generation has been accepted by ICCV 2023.

Installation

Check INSTALL.md for installation instructions.

Dataset

Check DATASET.md for instructions on dataset preprocessing.

Extract Features

CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port 10032 --nproc_per_node=1 tools/generate_aug_feature.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX True MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL True MODEL.ROI_RELATION_HEAD.PREDICTOR MotifPredictor  TEST.IMS_PER_BATCH 1 DTYPE "float16" GLOVE_DIR glove MODEL.PRETRAINED_DETECTOR_CKPT checkpoints/pretrained_faster_rcnn/model_final.pth OUTPUT_DIR exp/motif-precls MIXUP.FEAT_PATH feats TYPE extract_aug

Processing Features

python tools/processing_features.py

Training Models with CFA

# for PredCls
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --master_port 10054 --nproc_per_node=2 tools/relation_train_net.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX True MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL True MODEL.ROI_RELATION_HEAD.PREDICTOR MotifPredictor SOLVER.IMS_PER_BATCH 12 TEST.IMS_PER_BATCH 2 DTYPE "float16" SOLVER.MAX_ITER 50000 SOLVER.VAL_PERIOD 2000 SOLVER.CHECKPOINT_PERIOD 2000 GLOVE_DIR glove MODEL.PRETRAINED_DETECTOR_CKPT checkpoints/pretrained_faster_rcnn/model_final.pth OUTPUT_DIR ./exp/motifs_cfa_predcls TYPE cfa MIXUP.FEAT_PATH feats MIXUP.MIXUP_BG True MIXUP.MIXUP_FG True MIXUP.BG_LAMBDA 0.5 MIXUP.FG_LAMBDA 0.5 MIXUP.PREDICATE_LOSS_TYPE MIXUP_CE MIXUP.MIXUP_ADD_TAIL True FG_TAIL True FG_BODY True BG_TAIL True CL_TAIL True USE_PREDCLS_FEATURE True CONTRA False PKO False
# for SGCls
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --master_port 10054 --nproc_per_node=2 tools/relation_train_net.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX True MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False MODEL.ROI_RELATION_HEAD.PREDICTOR MotifPredictor SOLVER.IMS_PER_BATCH 12 TEST.IMS_PER_BATCH 2 DTYPE "float16" SOLVER.MAX_ITER 50000 SOLVER.VAL_PERIOD 2000 SOLVER.CHECKPOINT_PERIOD 2000 GLOVE_DIR glove MODEL.PRETRAINED_DETECTOR_CKPT checkpoints/pretrained_faster_rcnn/model_final.pth OUTPUT_DIR ./exp/motifs_cfa_sgcls TYPE cfa MIXUP.FEAT_PATH feats MIXUP.MIXUP_BG True MIXUP.MIXUP_FG True MIXUP.BG_LAMBDA 0.5 MIXUP.FG_LAMBDA 0.5 MIXUP.PREDICATE_LOSS_TYPE MIXUP_CE MIXUP.MIXUP_ADD_TAIL True FG_TAIL True FG_BODY True BG_TAIL True CL_TAIL True USE_PREDCLS_FEATURE False CONTRA True PKO False
# for SGDet
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --master_port 10054 --nproc_per_node=2 tools/relation_train_net.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX False MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False MODEL.ROI_RELATION_HEAD.PREDICTOR MotifPredictor SOLVER.IMS_PER_BATCH 12 TEST.IMS_PER_BATCH 2 DTYPE "float16" SOLVER.MAX_ITER 50000 SOLVER.VAL_PERIOD 2000 SOLVER.CHECKPOINT_PERIOD 2000 GLOVE_DIR glove MODEL.PRETRAINED_DETECTOR_CKPT checkpoints/pretrained_faster_rcnn/model_final.pth OUTPUT_DIR ./exp/motifs_cfa_sgdet TYPE cfa MIXUP.FEAT_PATH feats MIXUP.MIXUP_BG True MIXUP.MIXUP_FG True MIXUP.BG_LAMBDA 0.5 MIXUP.FG_LAMBDA 0.5 MIXUP.PREDICATE_LOSS_TYPE MIXUP_CE MIXUP.MIXUP_ADD_TAIL True FG_TAIL True FG_BODY True BG_TAIL True CL_TAIL True USE_PREDCLS_FEATURE False CONTRA True PKO False
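
Everything after --config-file in the commands above is a flat list of alternating KEY VALUE tokens that override entries of the YAML config. The sketch below is not code from this repository; it is a standard-library illustration of how such a list can be read into a nested dictionary, assuming dotted keys denote nesting. The function name parse_overrides is hypothetical.

```python
import ast


def parse_overrides(tokens):
    """Turn ['A.B', 'True', 'C', '0.5'] into {'A': {'B': True}, 'C': 0.5}.

    Hypothetical helper for illustration only; the actual training script
    has its own config handling.
    """
    if len(tokens) % 2 != 0:
        raise ValueError("overrides must come in KEY VALUE pairs")
    config = {}
    for key, raw in zip(tokens[0::2], tokens[1::2]):
        try:
            # Interpret literals such as True, 12, or 0.5.
            value = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            # Keep plain strings such as MotifPredictor or file paths.
            value = raw
        node = config
        *parents, leaf = key.split(".")
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return config


overrides = parse_overrides(
    "MODEL.ROI_RELATION_HEAD.USE_GT_BOX True MIXUP.FG_LAMBDA 0.5 TYPE cfa".split()
)
print(overrides)
```

Reading the commands this way makes it easier to see that a misspelled key, or a key without a value, shifts every following pair.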

Test Models with Prior Knowledge

CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port 10054 --nproc_per_node=1 tools/relation_test_net.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX True MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False MODEL.ROI_RELATION_HEAD.PREDICTOR MotifPredictor SOLVER.IMS_PER_BATCH 12 TEST.IMS_PER_BATCH 1 DTYPE "float16" SOLVER.MAX_ITER 50000 SOLVER.VAL_PERIOD 2000 SOLVER.CHECKPOINT_PERIOD 2000 GLOVE_DIR glove MODEL.PRETRAINED_DETECTOR_CKPT checkpoints/pretrained_faster_rcnn/model_final.pth OUTPUT_DIR ./exp/motifs_cfa_sgcls TYPE cfa MIXUP.FEAT_PATH feats MIXUP.MIXUP_BG True MIXUP.MIXUP_FG True MIXUP.BG_LAMBDA 0.5 MIXUP.FG_LAMBDA 0.5 MIXUP.PREDICATE_LOSS_TYPE MIXUP_CE MIXUP.MIXUP_ADD_TAIL True FG_TAIL True FG_BODY True BG_TAIL True CL_TAIL True USE_PREDCLS_FEATURE False CONTRA True PKO True

Comments for Parameters in Command

To make it easier to run our code, the parameters used in the commands are explained here:

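  • --master_port: The port used by torch.distributed.launch for communication between processes. Use a different port for each job running at the same time on one machine.
  • CUDA_VISIBLE_DEVICES: The GPUs you are going to use. For example, CUDA_VISIBLE_DEVICES=0,1 uses the first two GPUs.
  • --nproc_per_node: The number of processes to launch, one per GPU, so it should match the number of GPUs listed in CUDA_VISIBLE_DEVICES.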
  • SOLVER.IMS_PER_BATCH: It is the training batch size.
  • TEST.IMS_PER_BATCH: It is the testing batch size.
  • SOLVER.MAX_ITER: It is the maximum number of training iterations.
  • SOLVER.STEPS: It is the list of iterations at which the learning rate is decayed.
  • SOLVER.VAL_PERIOD: It is the interval, in iterations, between validation runs.
  • SOLVER.CHECKPOINT_PERIOD: It is the interval, in iterations, between checkpoint saves.
  • MODEL.RELATION_ON: It turns the relationship head on or off. The head is turned off when pretraining Faster R-CNN alone.
  • OUTPUT_DIR: It is the output directory where checkpoints are saved.
  • MODEL.ROI_RELATION_HEAD.USE_GT_BOX and MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL: They select the protocol. (1) PredCls: both are set to True. (2) SGCls: MODEL.ROI_RELATION_HEAD.USE_GT_BOX is set to True, while MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL is set to False. (3) SGDet: both are set to False.
  • MODEL.ROI_RELATION_HEAD.PREDICTOR: It is the SGG backbone you are going to use. The MOTIFS backbone (MotifPredictor) is used by default.
  • MIXUP.FEAT_PATH: It is the path where the extracted and processed features are stored.
  • EXTRACT_GROUP: It selects the predicate group(s) whose features are extracted. The options are head, body, tail, or any combination of them, separated by commas.
  • TYPE: The type of operation. If it is set to 'cfa', the model is trained with CFA. If it is set to 'extract_aug', features are extracted.
  • FG_HEAD/FG_BODY/FG_TAIL: Whether the Extrinsic-CFA operation is performed on the foreground of the given group.
  • BG_HEAD/BG_BODY/BG_TAIL: Whether the Extrinsic-CFA operation is performed on the background of the given group.
  • CL_HEAD/CL_BODY/CL_TAIL: Whether the Intrinsic-CFA operation is performed for the given group.
  • CONTRA: Whether to use the contrastive loss.
  • PKO: Whether to use the prior knowledge during inference.
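
The protocol is selected by the two USE_GT_* flags, and the GPU list and --nproc_per_node have to agree. The sketch below shows both rules in Python. It is a minimal illustration and not part of the codebase: the names PROTOCOL_FLAGS, protocol_overrides, and visible_gpu_count are hypothetical, and the flag values are taken from the training commands above.

```python
# (USE_GT_BOX, USE_GT_OBJECT_LABEL) for each protocol, as in the commands above.
PROTOCOL_FLAGS = {
    "predcls": (True, True),
    "sgcls": (True, False),
    "sgdet": (False, False),
}


def protocol_overrides(protocol):
    """Return the KEY VALUE tokens that select a protocol (hypothetical helper)."""
    use_gt_box, use_gt_label = PROTOCOL_FLAGS[protocol.lower()]
    return [
        "MODEL.ROI_RELATION_HEAD.USE_GT_BOX", str(use_gt_box),
        "MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL", str(use_gt_label),
    ]


def visible_gpu_count(cuda_visible_devices):
    """Count the GPU ids in a CUDA_VISIBLE_DEVICES string such as '0,1'."""
    return len([d for d in cuda_visible_devices.split(",") if d.strip()])


# Example: SGCls on two GPUs, so --nproc_per_node should be 2.
print(" ".join(protocol_overrides("sgcls")))
print("--nproc_per_node=%d" % visible_gpu_count("0,1"))
```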

Models and Generated Files

For Motifs-CFA, we provide the trained models (checkpoints) for verification purposes. Please download them from here* and unzip them to checkpoints. We also provide the extracted feature files, which you can download from here*.

Citations

If you find this project helpful for your research, please consider citing our paper in your publications.

@InProceedings{Li_2023_ICCV,
    author    = {Li, Lin and Chen, Guikun and Xiao, Jun and Yang, Yi and Wang, Chunping and Chen, Long},
    title     = {Compositional Feature Augmentation for Unbiased Scene Graph Generation},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2023},
    pages     = {21685-21695}
}

Credits

Our codebase is based on Scene-Graph-Benchmark.pytorch.

cfa's Issues

Regarding the pre-trained model that extracts the feature augmentations

I really appreciate your work, which gives me a lot of insight. I succeeded in reproducing the result of your work. But I have remaining questions (these questions may be the last..)

First, I downloaded the feature files for the memory bank (e.g., sgcls_body_feature_with_proposal_dict_motf). In the name of the file, the task 'sgcls' is included. Should I generate different features for each task to perform CFA, i.e., predcls_feature, sgdet_feature ?

Second, does CFA require a pre-trained SGG model for each task, or just pre-trained Faster R-CNN?

I thought CFA needed pre-trained Motifs models for PredCls, SGCls, and SGDet, respectively, to extract the union features. However, in the command below, which you provided, MODEL.PRETRAINED_DETECTOR_CKPT points to the pretrained_faster_rcnn directory, which confuses me.

CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port 10032 --nproc_per_node=1 tools/generate_aug_feature.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX True MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL True MODEL.ROI_RELATION_HEAD.PREDICTOR MotifPredictor  TEST.IMS_PER_BATCH 1 DTYPE "float16" GLOVE_DIR glove MODEL.PRETRAINED_DETECTOR_CKPT checkpoints/pretrained_faster_rcnn/model_final.pth OUTPUT_DIR exp/motif-precls MIXUP.FEAT_PATH feats TYPE extract_aug

Again, I really appreciate your kind and professional responses. Thank you :)

Details about GPUs

Hi,

I'm curious which GPU model you used for training during your research.

Thx.

Regarding the input of Relation Feature Extractor for CFA

Hello.

I tried to run the command from the guidelines in your README as follows:

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --master_port 10054 --nproc_per_node=2 tools/relation_train_net.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX True MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False MODEL.ROI_RELATION_HEAD.PREDICTOR MotifPredictor SOLVER.IMS_PER_BATCH 12 TEST.IMS_PER_BATCH 2 DTYPE "float16" SOLVER.MAX_ITER 50000 SOLVER.VAL_PERIOD 2000 SOLVER.CHECKPOINT_PERIOD 2000 GLOVE_DIR glove MODEL.PRETRAINED_DETECTOR_CKPT checkpoints/pretrained_faster_rcnn/model_final.pth OUTPUT_DIR ./exp/motifs_cfa_sgcls TYPE cfa MIXUP.FEAT_PATH feats MIXUP.MIXUP_BG True MIXUP.MIXUP_FG True MIXUP.BG_LAMBDA 0.5 MIXUP.FG_LAMBDA 0.5 MIXUP.PREDICATE_LOSS_TYPE MIXUP_CE MIXUP.MIXUP_ADD_TAIL True FG_TAIL True FG_BODY True BG_TAIL True CL_TAIL True USE_PREDCLS_FEATURE False CONTRA True PKO False

However, I found that the feature extractor, AugBilvlMxiUpRelationFeatureExtractor, does not receive MIX_UP_FG and MIX_UP_BG from the above command. As far as I can tell, mix_up_fg and mix_up_bg default to False in the forward function of AugBilvlMxiUpRelationFeatureExtractor, and your code does not pass cfg.MIX_UP_FG into the forward function of the relation feature extractor. So, when I run the code, mix_up_fg and mix_up_bg are False even though the above command sets these variables to True.

I think this needs to be corrected. Is there something I have misunderstood?
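
The pattern described in this issue can be shown with a small, self-contained sketch. It does not reproduce the repository's code: both classes, the dictionary-style cfg, and its keys are hypothetical stand-ins. It only illustrates how keyword defaults in forward can silently win over a config value that is never passed in. Whether the actual extractor behaves this way is the open question of this issue.

```python
class ExtractorIgnoringConfig:
    """Hypothetical extractor whose forward() never consults the config."""

    def __init__(self, cfg):
        self.cfg = cfg

    def forward(self, features, mix_up_fg=False, mix_up_bg=False):
        # The keyword defaults apply unless the caller passes the flags.
        return {"mix_up_fg": mix_up_fg, "mix_up_bg": mix_up_bg}


class ExtractorUsingConfig(ExtractorIgnoringConfig):
    """Variant that falls back to the config when a flag is not given."""

    def forward(self, features, mix_up_fg=None, mix_up_bg=None):
        if mix_up_fg is None:
            mix_up_fg = self.cfg["MIXUP"]["MIXUP_FG"]
        if mix_up_bg is None:
            mix_up_bg = self.cfg["MIXUP"]["MIXUP_BG"]
        return super().forward(features, mix_up_fg, mix_up_bg)


# Config as the command line would set it (assumed structure).
cfg = {"MIXUP": {"MIXUP_FG": True, "MIXUP_BG": True}}

ignored = ExtractorIgnoringConfig(cfg).forward(features=None)
wired = ExtractorUsingConfig(cfg).forward(features=None)
print(ignored)
print(wired)
```

In the first class the command-line setting has no effect unless every caller forwards the flags explicitly. The second class reads them from the config by default.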

Purpose of 'extract_cfa_feat'

Thank you for sharing your code.

I am currently going through the details of the code, and I found three options: 'extract_cfa_feat', 'cfa', and 'extract_aug'.

'extract_aug' seems to save the RoI features in the memory bank; 'cfa' is the option used to train Motifs-CFA.
But I can't see where 'extract_cfa_feat' is used.

Am I correctly understanding the function of each option?
If I am wrong, can you explain the purpose of each option?

Thank you
