
hr-nas's Introduction

HR-NAS: Searching Efficient High-Resolution Neural Architectures with Lightweight Transformers (CVPR21 Oral)

Environment

Requires Python 3, CUDA >= 10.1, and torch >= 1.4. All dependencies can be installed as follows:

pip3 install torch==1.4.0 torchvision==0.5.0 opencv-python tqdm tensorboard lmdb pyyaml packaging Pillow==6.2.2 matplotlib yacs pyarrow==0.17.1
pip3 install cityscapesscripts  # for Cityscapes segmentation
pip3 install mmcv-full==latest+torch1.4.0+cu101 -f https://openmmlab.oss-accelerate.aliyuncs.com/mmcv/dist/index.html  # for Segmentation data loader
pip3 install pycocotools shapely==1.6.4 Cython pandas pyyaml json_tricks scikit-image  # for COCO keypoint estimation

or pip3 install -r requirements.txt

Setup

Optionally configure NCCL before running:

export NCCL_IB_DISABLE=1 
export NCCL_IB_HCA=mlx5_0 
export NCCL_IB_GID_INDEX=3 
export NCCL_SOCKET_IFNAME=eth0 
export HOROVOD_MPI_THREADS_DISABLE=1
export OMP_NUM_THREADS=56
export KMP_AFFINITY=granularity=fine,compact,1,0

Set the following ENV variables:

$MASTER_ADDR: IP address of node 0 (not required if you have only one node (machine))
$MASTER_PORT: Port used for initializing the distributed environment
$NODE_RANK: Index of the node
$N_NODES: Number of nodes
$NPROC_PER_NODE: Number of GPUs per node (NOTE: should exactly match the number of local GPUs visible via `CUDA_VISIBLE_DEVICES`)
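
For example, on node 0 of the two-node setup used in Example 2 below (the address, port, and GPU count are illustrative placeholders):

export MASTER_ADDR=192.168.1.1   # IP address of node 0
export MASTER_PORT=1234          # any free port on node 0
export NODE_RANK=0               # 0 on the first node, 1 on the second
export N_NODES=2
export NPROC_PER_NODE=8          # e.g. with CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7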

Example 1 (one machine with 8 GPUs):

Node 1:
>>> python -m torch.distributed.launch --nproc_per_node=8 \
    --nnodes=1 --node_rank=0 --master_port=1234 train.py

Example 2 (two machines, each with 8 GPUs):

Node 1: (IP: 192.168.1.1, with a free port 1234)
>>> python -m torch.distributed.launch --nproc_per_node=8 \
    --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" \
    --master_port=1234 train.py

Node 2:
>>> python -m torch.distributed.launch --nproc_per_node=8 \
    --nnodes=2 --node_rank=1 --master_addr="192.168.1.1" \
    --master_port=1234 train.py

Datasets

  1. ImageNet

    • Prepare the ImageNet data following the PyTorch example.
    • Optional: generate an lmdb dataset using utils/lmdb_dataset.py. If you skip this step, overwrite dataset:imagenet1k_lmdb in the yaml with dataset:imagenet1k (see the sketch at the end of this item).
    • The directory structure of $DATA_ROOT should look like this:
    ${DATA_ROOT}
    ├── imagenet
    └── imagenet_lmdb
    
    • Link the data:
    ln -s YOUR_LMDB_DIR data/imagenet_lmdb
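    • If you skip the lmdb step, a minimal sketch of the corresponding yaml change (the exact surrounding keys depend on your task config):
    # dataset: imagenet1k_lmdb   <- default
    dataset: imagenet1k          # use the raw ImageNet folder instead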
  2. Cityscapes

    • Download data from Cityscapes.
    • unzip gtFine_trainvaltest.zip leftImg8bit_trainvaltest.zip
    • Link the data:
    ln -s YOUR_DATA_DIR data/cityscapes
    • Preprocess the data:
    python3 tools/convert_cityscapes.py data/cityscapes --nproc 8
  3. ADE20K

    • Download data from ADE20K.
    • unzip ADEChallengeData2016.zip
    • Link the data:
    ln -s YOUR_DATA_DIR data/ade20k
  4. COCO keypoints

    • Download data from COCO.
    • Build the COCO API tools:
    git clone https://github.com/cocodataset/cocoapi.git
    cd cocoapi/PythonAPI
    python3 setup.py build_ext --inplace
    python3 setup.py build_ext install
    make  # for nms
    • Unzip and link the data:
    ln -s YOUR_DATA_DIR data/coco
    • We also provide person detection results for COCO val2017 and test-dev2017 to reproduce our multi-person pose estimation results. Please download them from OneDrive or GoogleDrive.
    • Download and extract them under data/coco/person_detection_results, and make them look like this:
    ${POSE_ROOT}
    |-- data
    `-- |-- coco
        `-- |-- annotations
            |   |-- person_keypoints_train2017.json
            |   `-- person_keypoints_val2017.json
            |-- person_detection_results
            |   |-- COCO_val2017_detections_AP_H_56_person.json
            |   |-- COCO_test-dev2017_detections_AP_H_609_person.json
            `-- images
                |-- train2017
                |   |-- 000000000009.jpg
                |   |-- 000000000025.jpg
                |   |-- 000000000030.jpg
                |   |-- ... 
                `-- val2017
                    |-- 000000000139.jpg
                    |-- 000000000285.jpg
                    |-- 000000000632.jpg
                    |-- ... 
    

Running (train & evaluation)

  • Search for NAS models.

    python3 -m torch.distributed.launch --nproc_per_node=${NPROC_PER_NODE} --nnodes=${N_NODES} \
        --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
        --use_env train.py app:configs/YOUR_TASK.yml

    Supported tasks:

    • cls_imagenet
    • seg_cityscapes
    • seg_ade20k
    • keypoint_coco

    The super network is constructed using model_kwparams in YOUR_TASK.yml.
    To enable searching of Transformers, set prune_params.use_transformer=True in YOUR_TASK.yml; the token number of each Transformer will then be printed during training.
    The searched architecture can be found in best_model.json in the output dir.
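
    For reference, a minimal sketch of the corresponding YOUR_TASK.yml fragment (the nesting is assumed from the dotted path above):

    prune_params:
      use_transformer: True  # search Transformers; token numbers are logged during training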

  • Retrain the searched models.

    • For retraining the searched classification model, use best_model.json to overwrite checkpoint.json in the root dir of this project.

    • Modify models/hrnet.py to set checkpoint_kwparams = json.load(open('checkpoint.json')) and class InvertedResidual(InvertedResidualChannelsFused) (see the sketch after the retraining command below).

    • Retrain the model.

    python3 -m torch.distributed.launch --nproc_per_node=${NPROC_PER_NODE} --nnodes=${N_NODES} \
        --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
        --use_env train.py app:configs/cls_retrain.yml
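
    A minimal sketch of the models/hrnet.py edit described above (names are as given in this README; only these two definitions change):

    import json

    # load the searched architecture instead of the default channel configuration
    checkpoint_kwparams = json.load(open('checkpoint.json'))

    # switch the base class to the fused variant; keep the existing class body
    class InvertedResidual(InvertedResidualChannelsFused):
        ...  # existing methods unchanged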

Miscellaneous

  1. Plot keypoint detection results.

    python3 tools/plot_coco.py --prediction output/results/keypoints_val2017_results_0.json --save-path output/vis/
  2. About YAML config

  • The codebase is a general ImageNet training framework based on PyTorch, driven by yaml configs, with several extensions under the apps dir.
    • The yaml configs support additional features:
      • ${ENV} substitution in the yaml config.
      • _include for hierarchical configs.
      • _default key for overwriting.
      • xxx.yyy.zzz for partial overwriting.
      • --{{opt}} {{new_val}} for command-line overwriting.
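    • For illustration, a hypothetical pair of configs exercising these features (file and key names are invented for the example):
      # base.yml
      optimizer:
        lr: 0.1
      data_root: ${DATA_ROOT}    # substituted from the environment

      # task.yml
      _include: base.yml         # inherit base.yml
      optimizer.lr: 0.05         # partial overwrite of a nested key

      The same dotted syntax works on the command line, e.g. python3 train.py app:task.yml --optimizer.lr 0.01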
  3. For any questions regarding HR-NAS, feel free to contact the author ([email protected]).
  4. If you find our work useful in your research, please consider citing our paper:
    @inproceedings{ding2021hrnas,
      title={HR-NAS: Searching Efficient High-Resolution Neural Architectures with Lightweight Transformers},
      author={Ding, Mingyu and Lian, Xiaochen and Yang, Linjie and Wang, Peng and Jin, Xiaojie and Lu, Zhiwu and Luo, Ping},
      booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
      year={2021}
    }
    


hr-nas's Issues

No retrain.yml for segmentation tasks

Hi,

There is no retrain.yml available in the configs folder. I am wondering how to retrain the searched models for Cityscapes.

Also, when I train the search model with seg_cityscapes.yml, an issue comes up when evaluating the model:

Traceback (most recent call last):
  File "train.py", line 655, in <module>
    main()
  File "train.py", line 651, in main
    train_val_test()
  File "train.py", line 507, in train_val_test
    get_prune_weights(model_eval_wrapper), prune_threshold)  # get mask for all bn weights (depth-wise)
  File "/mnt/lustre/zengguohang/HR-NAS/utils/prune.py", line 274, in cal_mask_network_slimming_by_threshold
    weights = torch.cat(bn_weights_abs)
RuntimeError: There were no tensor arguments to this function (e.g., you passed an empty list of Tensors), but no fallback function is registered for schema aten::_cat. This usually means that this function requires a non-empty list of Tensors. Available functions are [CPU, CUDA, QuantizedCPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradNestedTensor, UNKNOWN_TENSOR_TYPE_ID, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode]

pretrained models

@dingmyu hi,
Were the pretrained models uploaded, such as those for ImageNet and Cityscapes?
Can it run inference using CPU only?
Can HR-NAS detect small objects?

How is L1 penalty term coefficient implemented?

Hi. I'm trying to reproduce the mIoU performance on ADE20K and notice that your results are divided into HRNet-A/B variants using different L1 penalty terms (lambda). But I couldn't find the lambda adjustment protocol in the code. Could you please tell me how this lambda adjustment is implemented? Thank you!

Welcome update to OpenMMLab 2.0

I am Vansin, the technical operator of OpenMMLab. In September of last year, we announced the release of OpenMMLab 2.0 at the World Artificial Intelligence Conference in Shanghai. We invite you to upgrade your algorithm library to OpenMMLab 2.0 using MMEngine, which can be used for both research and commercial purposes. If you have any questions, please feel free to join us on the OpenMMLab Discord at https://discord.gg/amFNsyUBvm or add me on WeChat (van-sin) and I will invite you to the OpenMMLab WeChat group.

Here are the OpenMMLab repos and their 1.0 and 2.0 branches:

Repo               OpenMMLab 1.0 branch    OpenMMLab 2.0 branch
MMEngine           -                       0.x
MMCV               1.x                     2.x
MMDetection        0.x, 1.x, 2.x           3.x
MMAction2          0.x                     1.x
MMClassification   0.x                     1.x
MMSegmentation     0.x                     1.x
MMDetection3D      0.x                     1.x
MMEditing          0.x                     1.x
MMPose             0.x                     1.x
MMDeploy           0.x                     1.x
MMTracking         0.x                     1.x
MMOCR              0.x                     1.x
MMRazor            0.x                     1.x
MMSelfSup          0.x                     1.x
MMRotate           1.x                     1.x
MMYOLO             -                       0.x

Attention: please create a new virtual environment for OpenMMLab 2.0.

use_transformer Always = False ?

Hello, I read your code, and I wonder why the variable use_transformer is always set to False when building the network.

transformer_dict in mobilenet_base.py is also None. If I want to activate the transformer path, how can I set transformer_dict?

Thanks for the help in advance.

The problem about importance factor

Hello. I am a master's student at Xi'an Jiaotong University. This is an outstanding paper, but I have one question about the importance factor.
Reading the source code, I find that the importance factor is defined as the weight sum of each block:
loss_bn_l1 = prune.cal_bn_l1_loss(get_prune_weights(model), FLAGS._bn_to_prune.penalty, rho)
Do you think my understanding is correct?
I cannot understand why the importance factor is the weight sum of each block.
Thank you very much!

Search Cost

Hi!

Could you tell me the search cost on different tasks?
