MVANet

The official repo of the CVPR 2024 paper (Highlight), Multi-view Aggregation Network for Dichotomous Image Segmentation


Introduction

Dichotomous Image Segmentation (DIS) has recently emerged as a task of high-precision object segmentation from high-resolution natural images. The main challenge in designing an effective DIS model is balancing the semantic dispersion of high-resolution targets under a small receptive field against the loss of high-precision details under a large receptive field. Existing methods rely on tedious multiple encoder-decoder streams and stages to gradually complete global localization and local refinement.

The human visual system captures regions of interest by observing them from multiple views. Inspired by this, we model DIS as a multi-view object perception problem and propose a parsimonious multi-view aggregation network (MVANet), which unifies the feature fusion of the distant view and close-up views into a single stream with one encoder-decoder structure. Specifically, we split the high-resolution input image of the original view into a distant-view image with global information and close-up view images with local details, which together constitute a set of complementary multi-view low-resolution input patches.
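
To make the view splitting concrete, here is a minimal sketch of the multi-view input construction (the sizes are illustrative; the rearrange pattern is the one used in model/MVANet.py):

import torch
import torch.nn.functional as F
from einops import rearrange

# Illustrative sketch: one high-resolution image becomes one downsampled
# distant view plus a 2x2 grid of close-up crops at the same low resolution.
image = torch.randn(1, 3, 1024, 1024)                  # original view
distant = F.interpolate(image, size=(512, 512),
                        mode='bilinear', align_corners=False)
closeups = rearrange(image, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
views = torch.cat([closeups, distant], dim=0)          # 5 complementary views
print(views.shape)                                     # torch.Size([5, 3, 512, 512])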


Moreover, two efficient transformer-based multi-view complementary localization and refinement modules (MCLM & MCRM) are proposed to jointly capture the localization and restore the boundary details of the targets.
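
As a rough illustration of the underlying idea only (not the authors' exact modules), tokens from the distant view can attend to tokens from the close-up patches via standard cross-attention:

import torch
import torch.nn as nn
from einops import rearrange

# Illustration of the cross-view attention idea, not the MCLM/MCRM code:
# distant-view tokens query the flattened close-up patch tokens.
c, h, w = 128, 16, 16
glb = torch.randn(1, c, h, w)                  # distant-view feature map
loc = torch.randn(4, c, h, w)                  # four close-up patch features

q = rearrange(glb, 'b c h w -> (h w) b c')     # (L, N, E) layout
kv = rearrange(loc, 'n c h w -> (n h w) 1 c')  # one key/value sequence

attn = nn.MultiheadAttention(embed_dim=c, num_heads=4)
fused, _ = attn(query=q, key=kv, value=kv)     # global tokens enriched with local detail
print(fused.shape)                             # torch.Size([256, 1, 128])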


We achieve state-of-the-art performance in terms of almost all metrics on the DIS benchmark dataset.


We have optimized the code, improving inference speed to 15.2 FPS.


Here are some of our visual results:

[Figure: qualitative visual results]

I. Requirements

  • python==3.7
  • torch==1.10.0
  • torchvision==0.11.0
  • mmcv-full==1.3.17
  • mmdet==2.17.0
  • mmengine==0.8.1
  • mmsegmentation==0.19.0
  • numpy
  • ttach
  • einops
  • timm
  • scipy

II. Training

  1. Download the pretrained model at Google Drive.
  2. Then, start training by simply running:
python train.py

III. Testing

  1. Update the data paths in the config file ./utils/config.py (lines 4-8); a hypothetical sketch of this block follows the list.

  2. Replace the existing path with the path to your saved model in ./predict.py (line 14)

    You can also download our trained model at Google Drive.

  3. Start predicting by:

python predict.py
  4. Change the predicted map path in ./test.py (line 17) and start testing:
python test.py
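
For step 1, the path block in ./utils/config.py should point at your local dataset; a hypothetical sketch (variable names and paths are placeholders, check the actual file):

# Hypothetical layout only; the real variable names in ./utils/config.py may
# differ. DIS5K ships as DIS-TR / DIS-VD / DIS-TE1..4, each with im/ and gt/.
data_root = '/path/to/DIS5K'
train_image_dir = data_root + '/DIS-TR/im'
train_mask_dir = data_root + '/DIS-TR/gt'
test_image_dir = data_root + '/DIS-TE1/im'
test_mask_dir = data_root + '/DIS-TE1/gt'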

You can get our prediction maps at Google Drive.

To Do List

  • Release our camera-ready paper on arXiv (done)
  • Release our training code (done)
  • Release our model checkpoints (done)
  • Release our prediction maps (done)

Citations

@article{yu2024multi,
  title={Multi-view Aggregation Network for Dichotomous Image Segmentation},
  author={Yu, Qian and Zhao, Xiaoqi and Pang, Youwei and Zhang, Lihe and Lu, Huchuan},
  journal={arXiv preprint arXiv:2404.07445},
  year={2024}
}


Issues

inf_MCRM and MCRM weight-name discrepancy

@qianyu-dlut thanks for this great work! I have one more question regarding the MCRM module.

The MCRM module names these layers linear1 and linear2 here, while inf_MCRM names them linear3 and linear4 here.

In Model_80.pth, there are 4 different layers linear1 / linear2 / linear3 / linear4

dec_blk1.linear1.weight
dec_blk1.linear1.bias
dec_blk1.linear2.weight
dec_blk1.linear2.bias
dec_blk1.linear3.weight
dec_blk1.linear3.bias
dec_blk1.linear4.weight
dec_blk1.linear4.bias

and the values are different:

import torch

pretrained_dict = torch.load("./saved_model/MVANet/Model_80.pth", map_location='cuda')
print('dec_blk1.linear1.weight', torch.sum(pretrained_dict['dec_blk1.linear1.weight']))
print('dec_blk1.linear3.weight', torch.sum(pretrained_dict['dec_blk1.linear3.weight']))

outputs

dec_blk1.linear1.weight tensor(2.0187, device='cuda:0')
dec_blk1.linear3.weight tensor(-0.5632, device='cuda:0')
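
For a fuller comparison, a minimal inspection sketch that prints the shapes of all four layers from the same checkpoint:

import torch

# List every dec_blk1.linear* entry with its shape.
sd = torch.load("./saved_model/MVANet/Model_80.pth", map_location='cpu')
for name, tensor in sd.items():
    if name.startswith('dec_blk1.linear'):
        print(name, tuple(tensor.shape))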

What is the difference between linear1 and linear3?

Thanks for your help

Questions About 'MCRM' Module: Positional Encoding and Output Feature

Hello,

Thank you for your excellent work on this project!

While reviewing the code, I noticed a few discrepancies between the implementation and the manuscript's description, specifically in the "MCRM" module. According to the manuscript, the local feature should include positional encoding before applying the cross-attention mechanism. However, in the code, the local feature is directly used as the key and value for cross-attention without adding positional information.

loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
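
For comparison, a minimal sketch of what the manuscript's description would imply (the token layout and positional-encoding variable are assumptions, not the repository's code):

import torch
import torch.nn as nn
from einops import rearrange

# Sketch only: add a (here learned) positional encoding to the local tokens
# before they serve as key/value; names and layout are assumptions.
nl, c, h, w = 4, 128, 16, 16
loc = torch.randn(nl, c, h, w)
loc_tokens = rearrange(loc, 'nl c h w -> (h w) nl c')   # token-first layout
pos_embed = nn.Parameter(torch.zeros(h * w, 1, c))      # hypothetical PE
key = value = loc_tokens + pos_embed                    # PE added before attention
print(key.shape)                                        # torch.Size([256, 4, 128])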

Additionally, the manuscript states that the output of the "MCRM" module is derived from the element-wise sum of the "updated local feature" and the "global feature." In contrast, the code seems to compute the output feature using the "updated local feature" and the "local feature."

src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)

Could these be potential bugs in the implementation?
Thank you again for your impressive work! I look forward to your clarification.

Environment setup

Hi, thank you for sharing your code. I have tried to set up the environment with pip and conda without success; could you write a guide on how to install all the dependencies? Thank you!

Removing OpenMMLab dependencies

Any chance you'd be open to removing the openmmlab dependencies for this repository? It's a pretty hefty set of dependencies, and from what I can tell, the only thing all those libraries are used for is a logger class that is then used to log the timm library's load_checkpoint function.

I'd be happy to replace that logger with another without dependencies and submit a pull request if you're open to it. It would make using your project much simpler!

Arbitrary input size and einops error

First of all thanks a lot for pushing this repository 🙌.

I am having trouble processing inputs of arbitrary size: when processing an image of size [1, 3, 864, 1280], the model throws the following error:

einops.EinopsError:  Error while processing rearrange-reduction pattern "b c (hg h) (wg w) -> (hg wg b) c h w".
 Input tensor shape: torch.Size([1, 128, 27, 40]). Additional info: {'hg': 2, 'wg': 2}.
 Shape mismatch, can't divide axis of length 27 in chunks of 2

This seems to be caused by the following line:

patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)

I've noticed that in predict.py all inputs are resized to 1024x1024, I assume exactly for this reason. Is resizing inputs to a standard size the correct strategy here?
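
For reference, a minimal sketch of that strategy (assuming model is a loaded MVANet in eval mode that returns a single-channel map):

import torch
import torch.nn.functional as F

# Sketch: resize to the training resolution (1024x1024, as predict.py does),
# run inference, then resize the prediction back to the original size.
img = torch.randn(1, 3, 864, 1280)                     # arbitrary-size input
x = F.interpolate(img, size=(1024, 1024), mode='bilinear', align_corners=False)
with torch.no_grad():
    pred = model(x)                                    # `model`: loaded MVANet (assumption)
pred = F.interpolate(pred, size=img.shape[-2:], mode='bilinear', align_corners=False)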

Error when using batch size > 1

When training with batch size > 1, I am getting errors from the line

loc_e5, glb_e5 = e5.split([4, 1], dim=0)

(https://github.com/qianyu-dlut/MVANet/blob/main/model/MVANet.py#L418)

Here, e5 has a leading dimension of (5 * batch_size), so split([4, 1]) (which is only possible when that dimension is 5) fails for any batch size > 1.

When I dug deeper, I found that the batch index gets mixed up (for instance, in https://github.com/qianyu-dlut/MVANet/blob/main/model/MVANet.py#L38).

The exact error I got was:

Traceback (most recent call last):
File "./train.py", line 1117, in
sideout5, sideout4, sideout3, sideout2, sideout1, final, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3, tokenattmap2, tokenattmap1 = generator.forward(
File "./train.py", line 882, in forward
loc_e5, glb_e5 = e5.split([4, 1], dim=0)
File "/lib/python3.10/site-packages/torch/_tensor.py", line 921, in split
return torch._VF.split_with_sizes(self, split_size, dim)
RuntimeError: split_with_sizes expects split_sizes to sum exactly to 10 (input tensor's size at dimension 0), but got split_sizes=[4, 1]

Is it possible to fix this while still retaining the exact architecture of the model (to fine-tune on personal datasets starting from the pretrained 80th-epoch checkpoint)?
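
Not an official fix, but one way to at least make the split scale with the batch size, assuming e5 stacks 4*b local-patch views followed by b global views along dim 0:

import torch

# Sketch only: split proportionally to the batch size b instead of
# hard-coding [4, 1].
b = 2                                        # stand-in batch size
e5 = torch.randn(5 * b, 128, 16, 16)         # stand-in encoder output
loc_e5, glb_e5 = e5.split([4 * b, b], dim=0)
print(loc_e5.shape, glb_e5.shape)            # (8, 128, 16, 16) and (2, 128, 16, 16)
# Caveat: the patch/batch ordering inside the local views ('(hg wg b)' in the
# rearrange) would still need to be handled downstream.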

Training error: ReLU variable has been modified by an in-place operation

Thanks for this work, very interesting paper

The following in-place error is raised when running

python train.py

Result:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.HalfTensor [256, 1, 256]], which is output 0 of ReluBackward0

After some investigation, I believe the problem comes from self.activation = get_activation_fn('relu') together with m.inplace = True.

I found a workaround by using gelu instead of relu, but I'm still not sure about the purpose of this piece of code:

        for m in self.modules():
            if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
                m.inplace = True
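
One hedged workaround (untested against the exact training setup) is to leave those modules out-of-place so autograd keeps the tensors it needs for the backward pass:

import torch.nn as nn

def disable_inplace(model: nn.Module) -> None:
    # Workaround sketch: make ReLU/Dropout out-of-place so autograd retains
    # the activation tensors needed during backward (at a small memory cost).
    for m in model.modules():
        if isinstance(m, (nn.ReLU, nn.Dropout)):
            m.inplace = False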

Full trace

❯ python train.py 
Generator Learning Rate: 1e-05
/home/piercus/miniconda3/envs/mvanet/lib/python3.7/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
/home/piercus/miniconda3/envs/mvanet/lib/python3.7/site-packages/PIL/Image.py:3179: DecompressionBombWarning: Image size (101824320 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.
  DecompressionBombWarning,
/home/piercus/miniconda3/envs/mvanet/lib/python3.7/site-packages/PIL/Image.py:3179: DecompressionBombWarning: Image size (102717153 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.
  DecompressionBombWarning,
/home/piercus/miniconda3/envs/mvanet/lib/python3.7/site-packages/torch/utils/data/dataloader.py:557: UserWarning: This DataLoader will create 12 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  cpuset_checked))
/home/piercus/miniconda3/envs/mvanet/lib/python3.7/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
  warnings.warn(warning.format(ret))
Generator Learning Rate: 1e-05
/home/piercus/miniconda3/envs/mvanet/lib/python3.7/site-packages/torch/nn/functional.py:3734: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.")
/home/piercus/miniconda3/envs/mvanet/lib/python3.7/site-packages/PIL/Image.py:3179: DecompressionBombWarning: Image size (102521250 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.
  DecompressionBombWarning,
/home/piercus/miniconda3/envs/mvanet/lib/python3.7/site-packages/torch/autograd/__init__.py:199: UserWarning: Error detected in ReluBackward0. Traceback of forward call that caused the error:
  File "train.py", line 103, in <module>
    sideout5, sideout4, sideout3, sideout2, sideout1, final, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3,tokenattmap2,tokenattmap1= generator.forward(images)
  File "/home/piercus/repos/mvanet/model/MVANet.py", line 412, in forward
    e5 = self.multifieldcrossatt(loc_e5, glb_e5)  # (4,128,16,16)
  File "/home/piercus/miniconda3/envs/mvanet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/piercus/repos/mvanet/model/MVANet.py", line 141, in forward
    activated = self.activation(linear1)
  File "/home/piercus/miniconda3/envs/mvanet/lib/python3.7/site-packages/torch/nn/functional.py", line 1457, in relu
    result = torch.relu(input)
  File "/home/piercus/miniconda3/envs/mvanet/lib/python3.7/site-packages/torch/fx/traceback.py", line 57, in format_stack
    return traceback.format_stack()
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "train.py", line 126, in <module>
    scaler.scale(loss).backward()
  File "/home/piercus/miniconda3/envs/mvanet/lib/python3.7/site-packages/torch/_tensor.py", line 489, in backward
    self, gradient, retain_graph, create_graph, inputs=inputs
  File "/home/piercus/miniconda3/envs/mvanet/lib/python3.7/site-packages/torch/autograd/__init__.py", line 199, in backward
    allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.HalfTensor [256, 1, 256]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
