

All in Tokens: Unifying Output Space of Visual Tasks via Soft Token


By Jia Ning*, Chen Li*, Zheng Zhang*, Zigang Geng, Qi Dai, Kun He, Han Hu

Introduction

AiT was initially described in a paper on arXiv. It is a framework that unifies the output space of visual tasks. We demonstrate a single unified model that simultaneously handles two typical visual tasks, instance segmentation and depth estimation, which have discrete/fixed-length and continuous/varied-length outputs, respectively. We propose several new techniques that take into account the particularities of visual tasks:

  1. Soft tokens. We employ soft tokens to represent the task output. Unlike the hard tokens of a common VQ-VAE, which are assigned one-hot to a discrete codebook (vocabulary), soft tokens are assigned softly to the codebook embeddings. Soft tokens improve the accuracy of both next-token inference and decoding of the task output.
  2. Mask augmentation. Many visual tasks have corrupted, undefined, or invalid values in their label annotations, e.g., occluded areas of depth maps. We show that a mask augmentation technique can greatly benefit these tasks.

With these new techniques and other designs, we show that the proposed general-purpose task solver performs well on both instance segmentation and depth estimation. In particular, we achieve 0.275 RMSE on NYUv2 depth estimation, setting a new record on this benchmark.
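This README does not spell out how mask augmentation is implemented. As a rough illustration only, the sketch below shows one generic form of it: a random fraction of the cells in a 2D grid (for example, a depth label map or a map of tokens) is replaced with a mask value, and the boolean mask is returned so that a model can be trained to fill in those positions. The helper name mask_augment, the uniform random choice of cells, and the mask_value default are all assumptions, not the repository's code.

```python
import random
from typing import List, Optional, Tuple

Grid = List[List[float]]


def mask_augment(
    grid: Grid,
    mask_ratio: float,
    mask_value: float = 0.0,
    seed: Optional[int] = None,
) -> Tuple[Grid, List[List[bool]]]:
    """Hypothetical helper: mask round(mask_ratio * H * W) random cells.

    Returns an augmented copy of `grid` and a boolean mask that is True
    wherever a cell was replaced by `mask_value`. The input is not modified.
    """
    if not 0.0 <= mask_ratio <= 1.0:
        raise ValueError("mask_ratio must be in [0, 1]")
    h, w = len(grid), len(grid[0])
    num_masked = round(mask_ratio * h * w)
    rng = random.Random(seed)
    # Sample distinct flat indices, then map each back to (row, col).
    chosen = rng.sample(range(h * w), num_masked)
    out = [row[:] for row in grid]
    mask = [[False] * w for _ in range(h)]
    for flat in chosen:
        r, c = divmod(flat, w)
        out[r][c] = mask_value
        mask[r][c] = True
    return out, mask


depth = [[1.0, 1.2, 1.4, 1.6] for _ in range(4)]
augmented, mask = mask_augment(depth, mask_ratio=0.25, seed=0)
print(sum(cell for row in mask for cell in row))  # 4 of the 16 cells are masked
```

Returning the mask alongside the augmented grid lets a training loss be restricted to, or weighted by, the masked positions.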


Results and Models

Results on COCO instance segmentation

| Model | Box AP | Mask AP | VQ-VAE Model | Task-Solver Model |
|---|---|---|---|---|
| AiT(SwinV2-B) | 43.3 | 34.2 | vqvae_insseg.pt | model |
| AiT(SwinV2-B) w/o soft token | 43.6 | 31.1 (-3.1) | vqvae_insseg.pt | model |
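The "w/o soft token" row presumably replaces soft tokens with standard hard tokens. The difference between the two can be sketched as follows, assuming that at each token position the task solver produces a vector of logits over the VQ-VAE codebook: a hard token takes the embedding of the arg-max code (a one-hot assignment), while a soft token takes the probability-weighted average of all codebook embeddings. This is an illustrative pure-Python sketch with hypothetical function names, not the repository's implementation.

```python
import math
from typing import List, Sequence

Vector = List[float]


def softmax(logits: Sequence[float]) -> Vector:
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def hard_token_embedding(logits: Sequence[float], codebook: Sequence[Sequence[float]]) -> Vector:
    # One-hot assignment: take the single code with the highest logit.
    idx = max(range(len(logits)), key=lambda i: logits[i])
    return list(codebook[idx])


def soft_token_embedding(logits: Sequence[float], codebook: Sequence[Sequence[float]]) -> Vector:
    # Soft assignment: weight every code embedding by its softmax probability.
    probs = softmax(logits)
    dim = len(codebook[0])
    return [sum(p * code[d] for p, code in zip(probs, codebook)) for d in range(dim)]


codebook = [[0.0, 0.0], [1.0, 2.0]]
logits = [0.0, 0.0]  # the model is undecided between the two codes
print(hard_token_embedding(logits, codebook))  # [0.0, 0.0]
print(soft_token_embedding(logits, codebook))  # [0.5, 1.0]
```

Note that the soft embedding is differentiable with respect to the logits, whereas the arg-max in the hard version is not.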

Results on NYUv2 depth estimation

| Model | D1 | D2 | D3 | Abs Rel | RMSE | Log10 | VQ-VAE Model | Task-Solver Model |
|---|---|---|---|---|---|---|---|---|
| AiT(SwinV2-B) | 0.934 | 0.991 | 0.998 | 0.087 | 0.305 | 0.037 | vqvae_depth.pt | model |
| AiT-P(SwinV2-B) | 0.940 | 0.992 | 0.998 | 0.085 | 0.301 | 0.036 | vqvae_depth.pt | model |
| AiT(SwinV2-B) w/o soft token | 0.932 | 0.991 | 0.998 | 0.089 | 0.318 | 0.038 | vqvae_depth.pt | model |
| AiT(SwinV2-L) | 0.949 | 0.993 | 0.999 | 0.079 | 0.284 | 0.034 | vqvae_depth.pt | model |
| AiT-P(SwinV2-L) | 0.954 | 0.994 | 0.999 | 0.076 | 0.275 | 0.033 | vqvae_depth.pt | model |
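The columns presumably follow the evaluation metrics commonly used on NYUv2: D1, D2, and D3 are threshold accuracies (the fraction of pixels whose ratio max(pred/gt, gt/pred) is below 1.25, 1.25², and 1.25³; higher is better), while Abs Rel, RMSE, and Log10 are errors (lower is better). The sketch below computes these standard definitions over the valid pixels (gt > 0) of flattened depth maps. The function name and the valid-pixel rule are assumptions, and the repository's evaluation code may apply additional steps such as cropping or depth clamping.

```python
import math
from typing import Dict, Sequence


def depth_metrics(pred: Sequence[float], gt: Sequence[float]) -> Dict[str, float]:
    # Keep only pixels with a valid (positive) ground-truth depth.
    # Predictions are assumed to be strictly positive.
    pairs = [(p, g) for p, g in zip(pred, gt) if g > 0]
    if not pairs:
        raise ValueError("no valid pixels")
    n = len(pairs)
    ratios = [max(p / g, g / p) for p, g in pairs]
    return {
        "d1": sum(r < 1.25 for r in ratios) / n,
        "d2": sum(r < 1.25 ** 2 for r in ratios) / n,
        "d3": sum(r < 1.25 ** 3 for r in ratios) / n,
        "abs_rel": sum(abs(p - g) / g for p, g in pairs) / n,
        "rmse": math.sqrt(sum((p - g) ** 2 for p, g in pairs) / n),
        "log10": sum(abs(math.log10(p) - math.log10(g)) for p, g in pairs) / n,
    }


# The third pixel has no ground truth (gt = 0) and is ignored.
m = depth_metrics(pred=[2.0, 1.0, 3.0], gt=[1.0, 1.0, 0.0])
print(m["d1"], m["abs_rel"])  # 0.5 0.5
```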

Joint training results on COCO and NYUv2

| Model | Box AP | Mask AP | RMSE | VQ-VAE Model | Task-Solver Model |
|---|---|---|---|---|---|
| AiT(SwinV2-B) | 42.2 | 34.1 | 0.310 | vqvae_depth.pt / vqvae_insseg.pt | model |

Usage

Installation

We recommend using pytorch>=1.10; the other required packages are listed in requirements.txt. To install boundary-iou-api, use the following command:

git clone https://github.com/bowenc0221/boundary-iou-api && cd boundary-iou-api && pip install -e .
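To confirm that the installed PyTorch meets the pytorch>=1.10 recommendation, a small check such as the following can be run. It is only a sketch: meets_min_version is a hypothetical helper that compares the leading numeric components of a version string (so local suffixes such as +cu117 are ignored), and the installed version is read through the standard-library importlib.metadata module.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple


def meets_min_version(ver: str, minimum: Tuple[int, ...]) -> bool:
    # Take the leading dot-separated integers, e.g. "1.13.1+cu117" -> (1, 13, 1).
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    parts = tuple(int(x) for x in match.group(0).split("."))
    # Tuple comparison is element-wise, so (1, 13, 1) >= (1, 10) is True.
    return parts >= minimum


if __name__ == "__main__":
    try:
        installed = version("torch")
    except PackageNotFoundError:
        print("torch is not installed")
    else:
        print(installed, "OK" if meets_min_version(installed, (1, 10)) else "older than 1.10")
```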

Data/Pre-training model Preparation

  1. Download the NYU Depth V2 dataset, the COCO dataset, and our preprocessed box-cropped binary instance masks (named maskcoco), and organize the data according to the following directory structure:
AiT
├── ait
├── vae
├── data
│   ├── coco
│   │   ├── annotations
│   │   ├── train2017
│   │   ├── val2017
│   │   ├── test2017
│   ├── maskcoco
│   ├── nyu_depth_v2
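  2. Create the data links by running the following commands from the AiT root directory (a relative link target is resolved from the link's own directory, so the target must be ../data):
ln -s ../data ait/data
ln -s ../data vae/data
  3. Download the pre-trained backbone models swin_v2_base_densesimmim.pth and swin_v2_large_densesimmim.pth.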
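Before training, the layout and links can be checked with a short script. The sketch below is hypothetical: the helper names and the list of required paths are derived from the directory tree above, not taken from the repository. link_data_dirs creates ait/data and vae/data as relative symbolic links pointing at ../data using os.symlink, and missing_paths reports which expected directories are absent (following symlinks).

```python
import os
from pathlib import Path
from typing import List

# Directories expected under the AiT root, mirroring the tree above.
REQUIRED = [
    "data/coco/annotations",
    "data/coco/train2017",
    "data/coco/val2017",
    "data/coco/test2017",
    "data/maskcoco",
    "data/nyu_depth_v2",
]


def link_data_dirs(root: Path) -> None:
    # A relative symlink target is resolved from the link's own directory,
    # so ait/data must point at ../data rather than at "data".
    for sub in ("ait", "vae"):
        link = root / sub / "data"
        if not link.exists() and not link.is_symlink():
            os.symlink(os.path.join("..", "data"), link, target_is_directory=True)


def missing_paths(root: Path) -> List[str]:
    # Also check that the data is reachable through both links.
    expected = REQUIRED + ["ait/data/coco", "vae/data/nyu_depth_v2"]
    return [p for p in expected if not (root / p).is_dir()]
```

From the AiT root, calling link_data_dirs(Path(".")) and then missing_paths(Path(".")) should return an empty list once the data is in place.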

Training

Training VQ-VAE on depth estimation:

cd vae
python -m torch.distributed.launch --nproc_per_node=${N_GPUS} train_depth_vqvae_dist.py  configs/depth/ait_depth_vqvae.py --cfg-options <custom-configs>

Training VQ-VAE on instance segmentation:

cd vae
python -m torch.distributed.launch --nproc_per_node=${N_GPUS} train_insseg_vqvae_dist.py  configs/insseg/ait_insseg_vqvae.py --cfg-options <custom-configs>

Training task-solver on depth estimation:

cd ait

# Train auto-regressive model
python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_480reso_depthonly.py --cfg-options model.backbone.init_cfg.checkpoint=swin_v2_base_densesimmim.pth model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt # for AR training

# Train parallel model
python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_480reso_parallel_depthonly.py --cfg-options model.backbone.init_cfg.checkpoint=swin_v2_base_densesimmim.pth model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt # for parallel training
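Each --cfg-options entry is a dotted key=value pair that overrides one field of the nested Python config file, following the convention popularized by mmcv. Purely as an illustration of that convention, the hypothetical helper below applies such overrides to a plain nested dict. It keeps every value as a string and is not the repository's or mmcv's actual parser (mmcv's own DictAction, for instance, also converts numeric, boolean, and list values).

```python
from typing import Any, Dict, Iterable


def apply_cfg_options(cfg: Dict[str, Any], options: Iterable[str]) -> Dict[str, Any]:
    """Hypothetical helper: apply 'a.b.c=value' overrides to a nested dict in place."""
    for opt in options:
        key, sep, value = opt.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {opt!r}")
        *parents, leaf = key.split(".")
        node = cfg
        for name in parents:
            # Create intermediate dicts for keys that do not exist yet.
            node = node.setdefault(name, {})
        node[leaf] = value
    return cfg


# Hypothetical starting config; the real one lives in the configs/*.py files.
cfg = {"model": {"backbone": {"init_cfg": {"type": "Pretrained"}}}}
apply_cfg_options(cfg, [
    "model.backbone.init_cfg.checkpoint=swin_v2_base_densesimmim.pth",
    "model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt",
])
print(cfg["model"]["backbone"]["init_cfg"])
# {'type': 'Pretrained', 'checkpoint': 'swin_v2_base_densesimmim.pth'}
```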

Training task-solver on object detection

cd ait
python -m torch.distributed.launch --nproc_per_node=16 --nnodes=2 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} code/train.py configs/swinv2b_640reso_detonly.py --cfg-options model.backbone.init_cfg.checkpoint=swin_v2_base_densesimmim.pth

Note: To save training cost, we initialize the instance segmentation and joint-training models from a pre-trained object detection model. Please download this pre-trained model (ait_det_swinv2b_wodec.pth) before training in the instance segmentation or joint-training settings.

Training task-solver on instance segmentation

cd ait
python -m torch.distributed.launch --nproc_per_node=16 code/train.py configs/swinv2b_640reso_inssegonly.py --cfg-options model.backbone.init_cfg.checkpoint=swin_v2_base_densesimmim.pth model.task_heads.insseg.vae_cfg.pretrained=vqvae_insseg.pt load_from=ait_det_swinv2b_wodec.pth

Joint training on instance segmentation and depth estimation

cd ait
python -m torch.distributed.launch --nproc_per_node=16 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} code/train.py configs/swinv2b_640reso_joint.py --cfg-options model.backbone.init_cfg.checkpoint=swin_v2_base_densesimmim.pth model.task_heads.insseg.vae_cfg.pretrained=vqvae_insseg.pt model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt load_from=ait_det_swinv2b_wodec.pth
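With torch.distributed.launch, every node starts --nproc_per_node worker processes, so the total number of processes (the world size) is nnodes × nproc_per_node, and each process's global rank is node_rank × nproc_per_node + local_rank. The joint-training command above therefore runs 4 × 16 = 64 processes, and the two-node object detection command runs 32. The same command must be started on every node, each with its own NODE_RANK (0 to nnodes - 1) and with identical MASTER_ADDR and MASTER_PORT values. The arithmetic is sketched below; the function names are hypothetical, and the sketch assumes every node launches the same number of processes.

```python
from typing import List


def world_size(nnodes: int, nproc_per_node: int) -> int:
    # Every node launches the same number of worker processes.
    return nnodes * nproc_per_node


def global_rank(node_rank: int, local_rank: int, nproc_per_node: int) -> int:
    if not 0 <= local_rank < nproc_per_node:
        raise ValueError("local_rank out of range")
    return node_rank * nproc_per_node + local_rank


def ranks_on_node(node_rank: int, nproc_per_node: int) -> List[int]:
    return [global_rank(node_rank, lr, nproc_per_node) for lr in range(nproc_per_node)]


print(world_size(4, 16))  # 64, as in the joint-training command
print(ranks_on_node(3, 16)[0], ranks_on_node(3, 16)[-1])  # 48 63
```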

Inference

Evaluate on depth estimation

cd ait

# Evaluating auto-regressive model
python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_480reso_depthonly.py --cfg-options model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt --eval <model_checkpoint>

# Evaluating parallel model
python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_480reso_parallel_depthonly.py --cfg-options model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt --eval <model_checkpoint>

Evaluate on instance segmentation

cd ait

python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_640reso_inssegonly.py --cfg-options model.task_heads.insseg.vae_cfg.pretrained=vqvae_insseg.pt --eval <model_checkpoint>

Evaluate on both depth estimation and instance segmentation

cd ait

python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_640reso_joint.py --cfg-options model.task_heads.insseg.vae_cfg.pretrained=vqvae_insseg.pt model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt --eval <model_checkpoint>

Citation

@article{ning2023all,
  title={All in Tokens: Unifying Output Space of Visual Tasks via Soft Token},
  author={Ning, Jia and Li, Chen and Zhang, Zheng and Geng, Zigang and Dai, Qi and He, Kun and Hu, Han},
  journal={arXiv preprint arXiv:2301.02229},
  year={2023}
}

