
CaraNet's Introduction

CaraNet: Context Axial Reverse Attention Network for Small Medical Objects Segmentation


This repository contains the implementation of a novel attention-based network (CaraNet) for segmenting polyps (CVC-T, CVC-ClinicDB, CVC-ColonDB, ETIS, and Kvasir) and brain tumors (BraTS). CaraNet not only shows strong overall segmentation performance (mean Dice) on polyps and brain tumors, but also performs well on small medical objects (small polyps and tumors).

🔥 NEWS 🔥 The full paper is available: CaraNet

The journal version is available: CaraNet

Architecture of CaraNet

Backbone

We use Res2Net as our backbone.

Context module

We use our CFP module as the context module, with a dilation rate of 8. Details of the CFP module can be found here: CFPNet. The architecture of the CFP module is shown in the following figure:

[figure: CFP module architecture]
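For intuition, a dilation rate of 8 enlarges the receptive field of a 3x3 convolution without adding parameters. A minimal PyTorch sketch (illustrative only, not the repository's CFP code; channel and input sizes are arbitrary):

import torch
import torch.nn as nn

# A 3x3 convolution with dilation rate 8. Setting padding = dilation
# keeps the spatial size unchanged, as a parallel context branch requires.
branch = nn.Conv2d(32, 32, kernel_size=3, dilation=8, padding=8)
x = torch.randn(1, 32, 44, 44)
print(branch(x).shape)  # torch.Size([1, 32, 44, 44])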

Axial Reverse Attention

As shown in the architecture of CaraNet, the Axial Reverse Attention (A-RA) module contains two routes: 1) reverse attention; 2) axial attention (the axial-attention code is adapted from UACANet).
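For orientation, a reverse-attention branch weights encoder features by the complement of a sigmoid-activated coarse prediction, steering the network toward regions the current prediction misses. A minimal sketch of the PraNet-style formulation (not the repository code; names are illustrative):

import torch
import torch.nn.functional as F

def reverse_attention(features, coarse_map):
    # features: (B, C, H, W); coarse_map: (B, 1, h, w) logits from a deeper stage
    up = F.interpolate(coarse_map, size=features.shape[2:], mode='bilinear', align_corners=False)
    attn = 1.0 - torch.sigmoid(up)  # reverse: emphasize currently-missed regions
    return attn.expand(-1, features.shape[1], -1, -1) * features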

Installation & Usage

Environment

  • Environment: Python 3.6
  • Install some packages:
conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=10.0 -c pytorch
pip install opencv-python pillow numpy matplotlib
  • Clone this repository
git clone https://github.com/AngeLouCN/CaraNet

Training

  • Download the training and testing dataset from this link: Experiment Dataset
  • Change the --train_path & --test_path in Train.py
  • Run Train.py (an example invocation is shown after the directory tree below)
  • The testing dataset is organized as follows:
|-- TestDataset
|   |-- CVC-300
|   |   |-- images
|   |   |-- masks
|   |-- CVC-ClinicDB
|   |   |-- images
|   |   |-- masks
|   |-- CVC-ColonDB
|   |   |-- images
|   |   |-- masks
|   |-- ETIS-LaribPolypDB
|   |   |-- images
|   |   |-- masks
|   |-- Kvasir
|       |-- images
|       |-- masks
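With the dataset laid out as above, training might be launched as follows (the paths are placeholders, not defaults taken from the repository):

python Train.py --train_path ./TrainDataset --test_path ./TestDataset/Kvasir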

Testing

  • Change the data_path in Test.py
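After editing data_path, the script runs without arguments (assuming it reads its paths from the edited variables, as the step above implies):

python Test.py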

Evaluation

  • Change the image_root and gt_root in eval_Kvasir.py
  • You can also run the MATLAB code in the eval folder; it produces results for four additional metrics.
  • You can download the segmentation maps of CaraNet from this link: CaraNet
  • dice_average.m computes the average Dice value according to object size, for small-area analysis (a Python sketch of the computation follows).
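For reference, the per-image Dice value and the object-size proportion that the size analysis bins by can be computed as below (a minimal Python sketch assuming binary masks; not a translation of dice_average.m):

import numpy as np

def dice_and_size(pred, gt):
    # pred, gt: binary numpy arrays of identical shape
    inter = np.logical_and(pred, gt).sum()
    dice = 2.0 * inter / (pred.sum() + gt.sum() + 1e-8)
    size_pct = 100.0 * gt.sum() / gt.size  # object proportion of the image, in %
    return dice, size_pct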

Segmentation Results

  • Polyp Segmentation Results
[figure: polyp segmentation results]
  • Conditions of test datasets:
[figures: conditions of the test datasets]
  • Small polyp analysis

The x-axis is the proportional size (%) of the polyp; the y-axis is the average mean Dice coefficient.

[figures: small-polyp analysis for each dataset]

Brain Tumor Segmentation

  • Dataset
[figures: BraTS input and segmentation ground truth]
  • Results
[figure: brain-tumor segmentation results]
  • Small tumor analysis

For very small areas (<1%):

[figure: results for very small areas]

The difference between the results of CaraNet and PraNet:

[figure: CaraNet vs. PraNet]

Citation

If you find our work helpful, please cite both the conference and journal versions.

@inproceedings{lou2021caranet,
  author = {Ange Lou and Shuyue Guan and Hanseok Ko and Murray H. Loew},
  title = {{CaraNet: context axial reverse attention network for segmentation of small medical objects}},
  booktitle = {Medical Imaging 2022: Image Processing},
  volume = {12032},
  pages = {81 -- 92},
  organization = {International Society for Optics and Photonics},
  publisher = {SPIE},
  year = {2022},
  doi = {10.1117/12.2611802}}

@inproceedings{9506485,
  author={Lou, Ange and Loew, Murray},
  booktitle={2021 IEEE International Conference on Image Processing (ICIP)}, 
  title={CFPNET: Channel-Wise Feature Pyramid For Real-Time Semantic Segmentation}, 
  year={2021},
  volume={},
  number={},
  pages={1894-1898},
  doi={10.1109/ICIP42928.2021.9506485}}
  
@article{lou2023caranet,
  title={CaraNet: context axial reverse attention network for segmentation of small medical objects},
  author={Lou, Ange and Guan, Shuyue and Loew, Murray},
  journal={Journal of Medical Imaging},
  volume={10},
  number={1},
  pages={014005},
  year={2023},
  publisher={SPIE}
}

CaraNet's People

Contributors

angeloucn, shuyueg


CaraNet's Issues

About the CaraNet.py

Thanks for your code! But I still cannot figure out what cfpnet_res2net_v4() is. It is called on line 135 of CaraNet.py.

About the calculations of the model indicators of every epoch

Sorry to bother you. May I ask how to calculate the IoU, precision, accuracy, recall, and DSC for every epoch? In other words, which variables should be used, and which of them correspond to TP, FP, TN, and FN? Sorry to trouble you!
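For reference (not the repository's evaluation code), these metrics follow directly from the confusion-matrix counts on binary masks; a sketch:

import numpy as np

def metrics(pred, gt):
    # pred, gt: binary numpy arrays
    tp = np.sum((pred == 1) & (gt == 1))
    fp = np.sum((pred == 1) & (gt == 0))
    tn = np.sum((pred == 0) & (gt == 0))
    fn = np.sum((pred == 0) & (gt == 1))
    eps = 1e-8
    iou = tp / (tp + fp + fn + eps)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)  # sensitivity
    accuracy = (tp + tn) / (tp + tn + fp + fn + eps)
    dsc = 2 * tp / (2 * tp + fp + fn + eps)
    return iou, precision, recall, accuracy, dsc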

dice_average.m and compute size of polyp

Hi, I would like to know how to use dice_average.m to compute the relationship between Dice and polyp size. I'm not familiar with MATLAB, and dice_average.m seems to contain only a function definition; I tried calling the function from main.m, but failed. I'm also curious how you computed the polyp size ratio: do the x-axis values 0, 1, 2, 3, 4, 5, 6 denote relative size?

Some questions about self_attn

Hi.

  1. Why is there no permute operation before view in mode h?
# for mode h
projected_query = self.query_conv(x).permute(0, 1, 3, 2).view(*view).permute(0, 2, 1)
  2. Why use sigmoid instead of softmax?

About the main.m

Hi, sorry to bother you. I'm wondering why I get NaN for every Dice value. I've checked the paths, but I still get NaN in each Dice cell. The _result.txt looks like:
"(Dataset:CVC-ClinicDB; Model:PraNet) meanDic:NaN;meanIoU:NaN;wFm:NaN;Sm:NaN;meanEm:NaN;MAE:NaN;maxEm:NaN;maxDice:NaN;maxIoU:NaN;meanSen:NaN;maxSen:NaN;meanSpe:NaN;maxSpe:NaN."

Fracture segmentation

Hi! I am really interested in your proposal, and I was wondering whether it could also fit bone fracture segmentation tasks. Unfortunately, it does not seem to perform as expected, with a Dice coefficient of 0.26. Do you have any insights about the code, and what changes might I apply to make your network fit my segmentation problem?

About Query, Key, and Value

Hi,
I'm studying medical segmentation for my job.
I read your paper and implementation on this repo.

I have questions about your implementation.

Questions

In the axial reverse attention module, self-attention is used:

self.query_conv = Conv(in_channels, in_channels // 8, kSize=(1, 1),stride=1,padding=0)
self.key_conv = Conv(in_channels, in_channels // 8, kSize=(1, 1),stride=1,padding=0)
self.value_conv = Conv(in_channels, in_channels, kSize=(1, 1),stride=1,padding=0)

I think the Query, Key, and Value in attention modules are generally created by a fully connected layer (nn.Linear), but convolutional layers are used in this network.

For example, the ViT repo:

https://github.com/lucidrains/vit-pytorch/blob/4b8f5bc90002a5506d765c811b554760d8dd6ee7/vit_pytorch/vit.py#L47

I'm not a specialist in Transformers, so perhaps there are established ideas about using convolutional layers to extract Q, K, and V.

Please tell me about :

  • Is using convolutional layers as the embedding to extract Q, K, and V a popular or commonly adopted DL architecture choice?
  • Why did you choose convolutional layers rather than fully connected layers to extract Q, K, and V?
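As background for the question (an illustration, not a statement of the authors' rationale): a 1x1 convolution applies the same linear map at every spatial position, so it is mathematically equivalent to nn.Linear applied per position. A quick check:

import torch
import torch.nn as nn

conv = nn.Conv2d(64, 8, kernel_size=1, bias=False)
lin = nn.Linear(64, 8, bias=False)
lin.weight.data = conv.weight.data.view(8, 64)          # share the same weights

x = torch.randn(2, 64, 11, 11)
y_conv = conv(x)                                        # (2, 8, 11, 11)
y_lin = lin(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # identical result
print(torch.allclose(y_conv, y_lin, atol=1e-6))         # True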

About Axial-Attention

Hi
I notice that the derivation of axial-attention in the code is as follows:
[image]

But in the original version of axial attention it would look like this:
[image]
I don't fully understand your code; could you explain it to me? I am looking forward to hearing from you.
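For orientation, the standard height-axis attention (the generic formulation the question refers to, not the repository's variant) can be sketched as:

import torch

def axial_attention_h(q, k, v):
    # q, k, v: (B, C, H, W); attend along the height axis, independently per column
    b, c, h, w = q.shape
    q = q.permute(0, 3, 2, 1).reshape(b * w, h, c)
    k = k.permute(0, 3, 2, 1).reshape(b * w, h, c)
    v = v.permute(0, 3, 2, 1).reshape(b * w, h, c)
    attn = torch.softmax(q @ k.transpose(1, 2) / c ** 0.5, dim=-1)  # (B*W, H, H)
    out = attn @ v
    return out.reshape(b, w, h, c).permute(0, 3, 2, 1)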

Preprocessing of BraTS data

Hi,
How do you preprocess the BraTS data? Do you input it into the model slice by slice? May I have your datasets.py and the preprocessed data?
I'm looking forward to your reply.

About Self-attn

Regarding the use of self-attention: since the activation uses sigmoid instead of softmax, do you still need to scale by d**-0.5? The Transformer uses softmax and scales the scores before applying it.
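For reference, the Transformer scaling the question refers to (a generic sketch, not this repository's code):

import torch

def scaled_scores(q, k):
    # q, k: (batch, length, d); divide by sqrt(d) before the activation
    return torch.bmm(q, k.transpose(1, 2)) / q.size(-1) ** 0.5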

Did you do ablation experiments on your modules?

Hi, thanks for your great work, but I don't see ablation experiments for your network, e.g., which part affects the performance more. Have you done any? I'm very interested in which modules matter most. Thanks.

Why add a PReLU output to a linear output?

The output of the partial decoder (https://github.com/AngeLouCN/CaraNet/blob/main/lib/partial_decoder.py#L30) has a linear activation. The outputs of the axial attention modules, which are designed for residual learning, go through a BN-PReLU (https://github.com/AngeLouCN/CaraNet/blob/main/CaraNet.py#L47).

The output (decoder + axial transformer 1, 2, 3, 4) then gets a sigmoid activation to generate class probabilities.

Why modify the original linear output (from the partial decoder) with a nonlinear function that biases positive (PReLU)? Doesn't this make it more likely to saturate the sigmoid with a large input, or at the very least cause exploding biases in the partial decoder?

My understanding of residual learning is that it is commonly done with no activation function before the summation, to prevent exploding biases (continually adding a positive value to a positive value).
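The composition being questioned, paraphrased as a sketch (illustrative, not the repository code):

import torch
import torch.nn as nn

bn_prelu = nn.Sequential(nn.BatchNorm2d(1), nn.PReLU())

def prediction_head(decoder_out, attn_out):
    # a linear decoder output summed with a BN-PReLU'd attention branch,
    # then sigmoid to produce class probabilities
    return torch.sigmoid(decoder_out + bn_prelu(attn_out))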

About the best model

Hi, thanks for your code! Did you save an optimal model for each dataset in the polyp segmentation experiments?

After several experiments on the same computer, the test metrics differ between runs.

Hello, I am very sorry to have taken up your study time; I have a question to ask! My desktop's graphics card is a 3090. I downloaded the source code from your homepage (GitHub) and ran several experiments, adding mDice and mIoU as metrics. After testing and comparing the results of many runs, I found that on the Kvasir test dataset the mDice difference was as high as 1.27 percentage points; for the other four test sets the gap may be even larger, which is not normal. Why is this? Sorry to trouble you, and sorry again for taking up your time!
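Run-to-run variation like this usually comes from nondeterministic CUDA kernels and unseeded RNGs. A common mitigation in generic PyTorch practice (not part of this repository):

import random
import numpy as np
import torch

def set_seed(seed=2021):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True  # trades speed for determinism
    torch.backends.cudnn.benchmark = False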

binary segmentation

Hi, thank you for sharing your code; it is really helpful for me.

Anyway, I have a question. I tried applying your code to my data, whose label values consist of 0 and 1. With this code, the output values are 0-255. Can you tell me which part needs to be modified so that the output values are 0 and 1?
I will wait for your reply. Thank you.
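A common post-processing step for this situation (a generic sketch, not a pointer to a specific line in the repository): threshold the prediction and save 0/1 values instead of 0-255:

import numpy as np

def binarize(pred, threshold=0.5):
    # pred: sigmoid output in [0, 1]; divide by 255 first if it is already scaled
    return (pred > threshold).astype(np.uint8)  # values are 0 and 1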

About axial attention module gamma

Hello,

I am currently using your architecture as the basis for my thesis work, and while training I wondered about the axial attention module. You set a gamma value this way: self.gamma = nn.Parameter(torch.zeros(1)).

Afterwards you multiply the result of the axial attention by the gamma value, self.gamma * out, and then apply the residual connection with x: out = self.gamma * out + x.

Does this not mean that you don't utilize axial attention at all, by setting its output to zero, and only use the residual output?

Is this a different version of the code, or am I missing something?

I would be very thankful if you could elaborate on this.
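The pattern in question, isolated as a sketch matching the snippet quoted above: gamma is an nn.Parameter, so it is zero only at initialization and is updated by backpropagation like any other weight.

import torch
import torch.nn as nn

class GatedResidual(nn.Module):
    def __init__(self):
        super().__init__()
        # starts at 0, so the branch contributes nothing at initialization,
        # but it is trainable and typically grows during training
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, branch_out, x):
        return self.gamma * branch_out + x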

dice_average

Hello, I want to ask whether the dice_average code is correct.

About Test.py file

When I run Test.py file, I am getting the following error:

Traceback (most recent call last):
  File "Test.py", line 6, in <module>
    from lib.HarDMSEG import HarDMSEG
ModuleNotFoundError: No module named 'lib.HarDMSEG'

How to resolve it?
