G-CASCADE

Official PyTorch implementation of G-CASCADE: Efficient Cascaded Graph Convolutional Decoding for 2D Medical Image Segmentation (WACV 2024).
Md Mostafijur Rahman, Radu Marculescu

The University of Texas at Austin

Architecture

Qualitative Results

Usage:

Recommended environment:

Python 3.8
PyTorch 1.11.0
torchvision 0.12.0

Please use pip install -r requirements.txt to install the dependencies.

Data preparation:

  • Synapse Multi-organ dataset: Sign up on the official Synapse website and download the dataset. Then split the 'RawData' folder into 'TrainSet' (18 scans) and 'TestSet' (12 scans) following TransUNet's lists and put them in the './data/synapse/Abdomen/RawData/' folder. Finally, preprocess using python ./utils/preprocess_synapse_data.py, or download the preprocessed data and save it in the './data/synapse/' folder. Note: If you use the preprocessed data from TransUNet, please make the necessary changes in utils/dataset_synapse.py (i.e., remove the code segment at lines 88-94 that converts groundtruth labels from 14 to 9 classes).

  • ACDC dataset: Download the preprocessed ACDC dataset from the Google Drive of MT-UNet and move it into the './data/ACDC/' folder.

  • Polyp datasets: Download the training and testing datasets from Google Drive and move them into the './data/polyp/' folder.

  • ISIC2018 dataset: Download the training and validation datasets from https://challenge.isic-archive.com/landing/2018/ and merge them together. Afterwards, split the merged dataset into training (80%), validation (10%), and testing (10%) sets. Move the split dataset into the './data/ISIC2018/' folder.
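
The 80/10/10 split described for ISIC2018 could be sketched as follows. This is only an illustration, not the authors' script: the random shuffle, the seed, the function name split_dataset, and the file names are all assumptions, since the README does not say how the split was drawn.

```python
import random


def split_dataset(names, ratios=(0.8, 0.1, 0.1), seed=0):
    """Shuffle file names reproducibly and split them into train/val/test.

    Hypothetical helper: the seed and the shuffling strategy are assumptions.
    """
    names = sorted(names)                  # fixed starting order
    random.Random(seed).shuffle(names)     # reproducible shuffle
    n_train = int(len(names) * ratios[0])
    n_val = int(len(names) * ratios[1])
    train = names[:n_train]
    val = names[n_train:n_train + n_val]
    test = names[n_train + n_val:]         # remainder goes to the test set
    return train, val, test


# Placeholder file names; replace with the merged ISIC2018 image list.
image_names = [f"ISIC_{i:07d}.jpg" for i in range(100)]
train, val, test = split_dataset(image_names)
print(len(train), len(val), len(test))
```

Fixing the seed keeps the three subsets identical between runs, which matters if results are to be compared across experiments.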

Pretrained model:

Download the pretrained PVTv2 model from Google Drive and put it in the './pretrained_pth/pvt/' folder for initialization. Similarly, download the pretrained MaxViT models from Google Drive and put them in the './pretrained_pth/maxvit/' folder for initialization.
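
Before training, it may help to confirm that the folders named in the data preparation and pretrained model steps exist. The sketch below is a hypothetical helper (missing_dirs is not part of the repository) that only checks for the directories mentioned in this README, relative to the repository root.

```python
from pathlib import Path

# Folders named in this README, relative to the repository root.
EXPECTED_DIRS = [
    "data/synapse",
    "data/ACDC",
    "data/polyp",
    "data/ISIC2018",
    "pretrained_pth/pvt",
    "pretrained_pth/maxvit",
]


def missing_dirs(root="."):
    """Return the expected folders that do not exist under root."""
    root = Path(root)
    return [d for d in EXPECTED_DIRS if not (root / d).is_dir()]


# Report anything missing relative to the current working directory.
for d in missing_dirs("."):
    print("missing:", d)
```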

Training:

cd into G-CASCADE

For Synapse Multi-organ dataset training, run CUDA_VISIBLE_DEVICES=0 python -W ignore train_synapse.py

For ACDC dataset training, run CUDA_VISIBLE_DEVICES=0 python -W ignore train_ACDC.py

For Polyp datasets training, run CUDA_VISIBLE_DEVICES=0 python -W ignore train_polyp.py

For ISIC2018 dataset training, run CUDA_VISIBLE_DEVICES=0 python -W ignore train_ISIC2018.py

Testing:

cd into G-CASCADE 

For Synapse Multi-organ dataset testing, run CUDA_VISIBLE_DEVICES=0 python -W ignore test_synapse.py

For ACDC dataset testing, run CUDA_VISIBLE_DEVICES=0 python -W ignore test_ACDC.py

For Polyp dataset testing, run CUDA_VISIBLE_DEVICES=0 python -W ignore test_polyp.py

For ISIC2018 dataset testing, run CUDA_VISIBLE_DEVICES=0 python -W ignore test_ISIC2018.py
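
The training and testing commands above all follow one pattern, so they can be wrapped in a small launcher. This is a convenience sketch under assumptions: build_command and run are hypothetical names, and the script is expected to be started from inside the G-CASCADE folder.

```python
import os
import subprocess
import sys

# Dataset suffixes used in the script names listed in this README.
DATASETS = ("synapse", "ACDC", "polyp", "ISIC2018")


def build_command(mode, dataset):
    """Build the argument list for e.g. 'python -W ignore train_synapse.py'."""
    if mode not in ("train", "test"):
        raise ValueError(f"unknown mode: {mode}")
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset}")
    return [sys.executable, "-W", "ignore", f"{mode}_{dataset}.py"]


def run(mode, dataset, gpu="0"):
    """Run the script with CUDA_VISIBLE_DEVICES set, as in the README commands."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpu)
    return subprocess.run(build_command(mode, dataset), env=env, check=True)


# Show the arguments that would be passed after the interpreter path.
print(build_command("train", "synapse")[1:])
```

Calling run("test", "ACDC") would then be equivalent to the ACDC testing command, using the interpreter that launched the wrapper.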

Acknowledgement

We are very grateful for these excellent works: timm, MERIT, CASCADE, PraNet, Polyp-PVT and TransUNet, which have provided the basis for our framework.

Citations

@inproceedings{rahman2024g,
  title={G-CASCADE: Efficient Cascaded Graph Convolutional Decoding for 2D Medical Image Segmentation},
  author={Rahman, Md Mostafijur and Marculescu, Radu},
  booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
  pages={7728--7737},
  year={2024}
}

g-cascade's Issues

Help!Training on Synapse dataset!

Sorry, this is my first time asking a question on GitHub. This question can be directly deleted or ignored (I really can't find a way to delete it).

Question about LV MYO RV values in your paper?

Thanks for your outstanding work. Can you share the version of the test.py code that outputs the LV, MYO, and RV values? I have successfully run your code and obtained the values for the 12 test cases on the ACDC dataset, but I am still struggling to understand how you obtained the LV, MYO, and RV values.
[screenshot]
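
For readers with the same question, per-class scores are usually obtained by computing Dice separately for each label instead of averaging over cases. The sketch below is not the authors' evaluation code; it assumes the common ACDC label convention (1 = RV, 2 = Myo, 3 = LV) and works on flattened label lists. The name dice_per_class is hypothetical.

```python
# Assumed ACDC label convention; verify against the dataset you use.
ACDC_CLASSES = {1: "RV", 2: "Myo", 3: "LV"}


def dice_per_class(pred, gt, classes=ACDC_CLASSES):
    """Return a Dice score per class for flattened prediction and ground truth."""
    scores = {}
    for label, name in classes.items():
        p = [v == label for v in pred]
        g = [v == label for v in gt]
        inter = sum(a and b for a, b in zip(p, g))
        denom = sum(p) + sum(g)
        # Dice is undefined when the class is absent from both masks.
        scores[name] = 2.0 * inter / denom if denom else float("nan")
    return scores


# Toy example with eight voxels.
pred = [0, 1, 1, 2, 2, 3, 3, 3]
gt = [0, 1, 2, 2, 2, 3, 3, 0]
print(dice_per_class(pred, gt))
```

Averaging each class's score over all test cases would give one value per structure.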

Help!Training on Synapse dataset!

Hello author, after configuring the environment using your project, my first attempt to train the ACDC dataset went smoothly.
However, when I tried to train the Synapse dataset, I encountered the following problem:
[screenshot]

These are the details that I consider useful:
[screenshots]

I am looking forward to your reply. Thank you!

Experiment result!

Hi!
Thank you for your excellent work. May I ask why I can't find the segmentation results for Polyp and ISIC? I look forward to your reply.

about polyp

What do I need to do with the downloaded polyp dataset to make it usable? The downloaded dataset cannot be used directly. How do you divide and process the data? If possible, could you share the data you have processed?
[screenshot]

height (28) must be divisible by window (8)

Hi;

I got the error below when trying to test with the Synapse dataset. Note that this problem appears only when I use "MERIT_GCASCADE", not with "PVT_GCASCADE".

File /scratch/ahmed/lib/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /scratch/ahmed/lib/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Input In [2], in MERIT_GCASCADE.forward(self, x)
    408     x = self.conv_1cto3c(x)
    410 # transformer backbone as encoder
--> 411 f1 = self.backbone1(F.interpolate(x, size=self.img_size_s1, mode=self.interpolation))                
    412 #print([f1[3].shape,f1[2].shape,f1[1].shape,f1[0].shape])
    413 
    414 # decoder
    415 x11_o, x12_o, x13_o, x14_o = self.decoder1(f1[3], [f1[2], f1[1], f1[0]])

File /scratch/ahmed/lib/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /scratch/ahmed/lib/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /scratch/ahmed/G-CASCADE-main/lib/maxxvit_4out.py:1730, in MaxxVit.forward(self, x)
   1729 def forward(self, x):
-> 1730     x = self.forward_features(x)
   1731     #x = self.forward_head(x)
   1732     return x

File /scratch/ahmed/G-CASCADE-main/lib/maxxvit_4out.py:1714, in MaxxVit.forward_features(self, x)
   1712 features = []
   1713 for i in range(len(self.stages)):
-> 1714     x = self.stages[i](x)
   1715     #print(x.shape)
   1716     if(i==len(self.stages)-1):

File /scratch/ahmed/lib/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /scratch/ahmed/lib/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /scratch/ahmed/G-CASCADE-main/lib/maxxvit_4out.py:1550, in MaxxVitStage.forward(self, x)
   1548     x = checkpoint_seq(self.blocks, x)
   1549 else:
-> 1550     x = self.blocks(x)
   1551 return x

File /scratch/ahmed/lib/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /scratch/ahmed/lib/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /scratch/ahmed/lib/torch/nn/modules/container.py:215, in Sequential.forward(self, input)
    213 def forward(self, input):
    214     for module in self:
--> 215         input = module(input)
    216     return input

File /scratch/ahmed/lib/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /scratch/ahmed/lib/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /scratch/ahmed/G-CASCADE-main/lib/maxxvit_4out.py:1440, in MaxxVitBlock.forward(self, x)
   1438 if not self.nchw_attn:
   1439     x = x.permute(0, 2, 3, 1)  # to NHWC (channels-last)
-> 1440 x = self.attn_block(x)
   1441 x = self.attn_grid(x)
   1442 if not self.nchw_attn:

File /scratch/ahmed/lib/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /scratch/ahmed/lib/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /scratch/ahmed/G-CASCADE-main/lib/maxxvit_4out.py:1228, in PartitionAttentionCl.forward(self, x)
   1227 def forward(self, x):
-> 1228     x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x))))
   1229     x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
   1230     return x

File /scratch/ahmed/G-CASCADE-main/lib/maxxvit_4out.py:1215, in PartitionAttentionCl._partition_attn(self, x)
   1213 img_size = x.shape[1:3]
   1214 if self.partition_block:
-> 1215     partitioned = window_partition(x, self.partition_size)
   1216 else:
   1217     partitioned = grid_partition(x, self.partition_size)

File /scratch/ahmed/G-CASCADE-main/lib/maxxvit_4out.py:1127, in window_partition(x, window_size)
   1125 def window_partition(x, window_size: List[int]):
   1126     B, H, W, C = x.shape
-> 1127     _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
   1128     _assert(W % window_size[1] == 0, '')
   1129     x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)

File /scratch/ahmed/lib/torch/__init__.py:1404, in _assert(condition, message)
   1402 if type(condition) is not torch.Tensor and has_torch_function((condition,)):
   1403     return handle_torch_function(_assert, (condition,), condition, message)
-> 1404 assert condition, message

AssertionError: height (28) must be divisible by window (8)
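
One reading consistent with this assertion is that a backbone configured with window size 8 received a 224x224 input. The sketch below checks which stage resolutions are divisible by a given window size. It assumes a stride-2 stem followed by four stages that each halve the resolution; that layout is an assumption about the MaxViT-style backbone and is not confirmed by the thread. The function names are hypothetical.

```python
def stage_sizes(img_size, stem_stride=2, num_stages=4):
    """Feature-map side length at each stage (assumed: every stage halves it)."""
    size = img_size // stem_stride
    sizes = []
    for _ in range(num_stages):
        size //= 2
        sizes.append(size)
    return sizes


def first_incompatible_stage(img_size, window):
    """Return (stage index, size) of the first stage not divisible by window."""
    for i, size in enumerate(stage_sizes(img_size)):
        if size % window != 0:
            return i, size
    return None


print(stage_sizes(224))                   # [56, 28, 14, 7]
print(first_incompatible_stage(224, 8))   # (1, 28)
print(first_incompatible_stage(256, 8))   # None
print(first_incompatible_stage(224, 7))   # None
```

Under these assumptions, a 224 input is compatible with window size 7 and a 256 input with window size 8, so the failure at height 28 would point to the input size and the backbone's window size being mismatched.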

I have encountered some errors and need help

Hello, I encountered some errors while replicating this project. When I use MERIT as my encoder, the following error occurs.
[screenshot]
Later I traced this problem to the train() function: when using the PVT model, the input image size does not change, while with the MERIT model the input image size is changed to 256, whereas the original input image size is 256*0.75=192. How can I solve this problem?
[screenshot]
