sam_myelin_seg_tem's Introduction

Segment axon and myelin from microscopy data using deep learning. Written in Python on the TensorFlow framework, it uses a convolutional neural network to classify each pixel as axon, myelin, or background.

For more information, see the documentation website.

Help

Whether you are a newcomer or an experienced user, we will do our best to help and reply to you as soon as possible. Of course, please be considerate and respectful of all people participating in our community interactions.

  • If you encounter difficulties during installation and/or while using AxonDeepSeg, or have general questions about the project, you can start a new discussion on the AxonDeepSeg GitHub Discussions forum. We also encourage you, once you've familiarized yourself with the software, to continue participating in the forum by helping answer future questions from fellow users!
  • If you encounter bugs during installation and/or use of AxonDeepSeg, you can open a new issue ticket on the AxonDeepSeg GitHub issues webpage.

Napari plugin

A tutorial demonstrating the basic features of our plugin for Napari is hosted on YouTube.

References

Citation

If you use this work in your research, please cite it as follows:

Zaimi, A., Wabartha, M., Herman, V., Antonsanti, P.-L., Perone, C. S., & Cohen-Adad, J. (2018). AxonDeepSeg: automatic axon and myelin segmentation from microscopy data using convolutional neural networks. Scientific Reports, 8(1), 3816. Link to paper: https://doi.org/10.1038/s41598-018-22181-4.

Copyright (c) 2018 NeuroPoly (Polytechnique Montreal)

Licence

The MIT License (MIT)

Copyright (c) 2018 NeuroPoly, École Polytechnique, Université de Montréal

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

Contributors

Pierre-Louis Antonsanti, Stoyan Asenov, Mathieu Boudreau, Oumayma Bounou, Marie-Hélène Bourget, Julien Cohen-Adad, Victor Herman, Melanie Lubrano, Antoine Moevus, Christian Perone, Vasudev Sharma, Thibault Tabarin, Maxime Wabartha, Aldo Zaimi.

sam_myelin_seg_tem's Issues

Train object detection model to predict axon centroids

Pretty straightforward. This would make this pipeline fully automatic.

From my understanding, accurate results should be achievable because state-of-the-art object detectors are quite powerful. The data is also already there. For datasets other than TEM, some minor preprocessing would be required to extract the axon centroids.
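Until a detector is trained, centroid targets can be derived from the existing instance masks. A minimal sketch using SciPy (the helper name is hypothetical):

```python
import numpy as np
from scipy import ndimage

def axon_centroids(instance_map):
    """Return one (y, x) centroid per labeled axon instance.

    instance_map: 2D array, 0 = background, 1..N = axon instances.
    These centroids could serve as targets for an object detection model.
    """
    labels = np.unique(instance_map)
    labels = labels[labels != 0]
    # uniform weights inside each instance -> geometric centroid
    return np.array(ndimage.center_of_mass(np.ones_like(instance_map),
                                           instance_map, labels))
```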

Move on to other datasets and aggregate datasets

The TEM dataset was the biggest of our private datasets, but we will also need to move on to other modalities to compare performance with ivadomed and nnunetv2.

The ultimate goal of using SAM is to train a single "foundation" model for every type of contrast/resolution, so we will eventually want to train SAM on an aggregation of all our private datasets. For comparison, we will need to train a model on every dataset individually before aggregating. This necessary step will also let us fix bugs in these datasets before we use them all at once.

Save best model based on validation dice

It is important to keep the model that gives the best validation metrics. Currently, we use the final checkpoint for testing, which is not optimal.
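A minimal, framework-agnostic sketch of the intended behaviour; `save_fn` stands in for the actual persistence call, e.g. `torch.save(model.state_dict(), path)`:

```python
def save_if_best(val_dice, best_dice, save_fn):
    """Persist a checkpoint only when validation Dice improves.

    val_dice: Dice score on the validation set for the current epoch.
    best_dice: best Dice seen so far (start with float('-inf')).
    save_fn: zero-argument callable that writes the checkpoint.
    Returns the updated best Dice.
    """
    if val_dice > best_dice:
        save_fn()
        return val_dice
    return best_dice
```

Called once per epoch after validation, this keeps exactly one checkpoint on disk: the one with the best validation Dice.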

Try fully automatic axon segmentation

We would like to know how easily SAM can be fine-tuned for fully automatic segmentation (i.e. prompting with the whole image as a bounding box instead of prompting with a specific region of interest).

The perfect pretext to try this is axon segmentation.

  1. If it works well, this axon segmentation could then be used to generate the bbox prompts for the subsequent myelin segmentation.
  2. There is no overlap between instances in the axon class. The myelin class would not be well suited because there are many overlaps (or rather touching myelin sheaths). A big advantage of segmenting the myelin with localized bbox prompts is that we get a clean and reliable instance segmentation. The way AxonDeepSeg currently works is that the semantic segmentation is "subdivided" (semantic to instance) with a watershed algorithm. Although this works fine in many cases, it sometimes deteriorates the segmentation. See the example below, and look for small axons touching big axons: in the instance segmentation, a "leak" artifact occurs, where the myelin of the small axon is wrongly attributed to the big axon.
    (figure: instance segmentation example showing the "leak" artifact)
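Because axon instances never overlap (point 2 above), bbox prompts for the myelin stage can be derived from an automatic axon segmentation with plain connected components. A minimal sketch (the helper name is hypothetical):

```python
import numpy as np
from scipy import ndimage

def axon_instances_to_bboxes(axon_mask):
    """Label non-overlapping axons and return one bbox per instance.

    axon_mask: 2D binary array from the axon segmentation.
    Returns an (N, 4) array of [x_min, y_min, x_max, y_max] boxes,
    usable as prompts for the myelin segmentation.
    """
    # axons do not overlap, so connected components give instances directly
    labeled, _ = ndimage.label(axon_mask)
    bboxes = []
    for sl in ndimage.find_objects(labeled):
        ys, xs = sl
        bboxes.append([xs.start, ys.start, xs.stop - 1, ys.stop - 1])
    return np.array(bboxes)
```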

Randomize bounding box

def get_myelin_bbox(bbox_df, axon_id):
    return np.array(bbox_df.iloc[axon_id])

Currently, bounding boxes are loaded directly for training. They are generated by extracting the tightest bounding box around the annotation. The exact coordinates of the bounding box should include a random component to avoid overfitting, similar to what was done in MedSAM:

https://github.com/bowang-lab/MedSAM/blob/8432244ac07be6baba120dcb786e8a694c188eb9/train_one_gpu.py#L104-L107
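The perturbation could look like the following sketch (function name and shift range are placeholders; MedSAM applies a comparable random shift to each coordinate):

```python
import numpy as np

def jitter_bbox(bbox, img_h, img_w, max_shift=5, rng=None):
    """Randomly expand each side of a tight bbox by up to max_shift pixels.

    bbox: [x_min, y_min, x_max, y_max] in pixel coordinates.
    The result is clipped to the image bounds, so the jittered box always
    contains the original one and stays inside the image.
    """
    rng = np.random.default_rng() if rng is None else rng
    x_min, y_min, x_max, y_max = bbox
    x_min = max(0, x_min - int(rng.integers(0, max_shift + 1)))
    y_min = max(0, y_min - int(rng.integers(0, max_shift + 1)))
    x_max = min(img_w - 1, x_max + int(rng.integers(0, max_shift + 1)))
    y_max = min(img_h - 1, y_max + int(rng.integers(0, max_shift + 1)))
    return np.array([x_min, y_min, x_max, y_max])
```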

Train image encoder + mask decoder

It would be interesting to see if training the image encoder as well could help. MedSAM also trains the encoder; the authors specify that all weights in the image encoder were updated.
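A sketch of what unfreezing could look like for a model exposing an `image_encoder` attribute (as the SAM reference implementation does; the helper itself is hypothetical):

```python
import torch.nn as nn

def unfreeze_encoder(model, encoder_attr="image_encoder"):
    """Enable gradients on the (possibly frozen) image encoder.

    After this, the optimizer updates the encoder alongside the mask
    decoder during fine-tuning. Returns the number of trainable
    encoder parameters for logging.
    """
    encoder = getattr(model, encoder_attr)
    for p in encoder.parameters():
        p.requires_grad = True
    return sum(p.numel() for p in encoder.parameters() if p.requires_grad)
```

Note that training the encoder greatly increases memory use, since activations for the full ViT backbone must be kept for the backward pass.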

Train cascaded pipeline

Currently, the axon and myelin segmentation models are trained separately and independently. The two cascaded models should be trained jointly, which would also allow parameter sharing, such as a common image encoder for both models.

This month, the myelin segmentation should get good enough to move on to this "cascaded" training. I already expect a lot of autograd problems...

However, after this, the model should be ready for a public release.
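A toy PyTorch sketch of the parameter-sharing idea (layer types and sizes are placeholders, not the actual SAM architecture):

```python
import torch
import torch.nn as nn

class CascadedSegmenter(nn.Module):
    """One shared image encoder feeding two task-specific decoders.

    In the real pipeline, the encoder would be the SAM image encoder and
    each decoder a SAM mask decoder; here simple conv layers stand in.
    """
    def __init__(self, emb_dim=32):
        super().__init__()
        self.encoder = nn.Conv2d(1, emb_dim, kernel_size=3, padding=1)
        self.axon_decoder = nn.Conv2d(emb_dim, 1, kernel_size=1)
        self.myelin_decoder = nn.Conv2d(emb_dim, 1, kernel_size=1)

    def forward(self, x):
        z = self.encoder(x)  # single shared forward pass
        return self.axon_decoder(z), self.myelin_decoder(z)
```

Summing the axon and myelin losses before `backward()` lets autograd propagate both gradients through the shared encoder in one step.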

Roadmap

This is a general overview of what needs to be done in this project before moving on to other datasets. Currently, both the axon and myelin seg models outperform ivadomed, but the overall pipeline is not efficient and nnUNet still beats SAM.

  • integrate "patch-based" training: similar to how we trained U-Nets, this would allow bigger batch sizes and would eventually allow data augmentation. Ideally, implement this with the MONAI dataloader for easier data augmentation integration
  • merge axon and myelin image encoders (halves overall model size, allows parameter sharing, more efficient training pipeline); see #10. Eventually, all datasets would be aggregated and the image encoder would learn to process all modalities.
  • implement multi-GPU training to further increase the batch size and be able to train longer
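The patch-based idea in the first bullet can be sketched without MONAI as follows (the helper name and uniform sampling strategy are illustrative only):

```python
import numpy as np

def random_patches(image, mask, patch_size=256, n_patches=4, rng=None):
    """Yield aligned random square crops from an image/mask pair.

    Cropping large microscopy images to fixed-size patches allows
    bigger batch sizes than whole-image training.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = image.shape[:2]
    for _ in range(n_patches):
        y = int(rng.integers(0, h - patch_size + 1))
        x = int(rng.integers(0, w - patch_size + 1))
        yield (image[y:y + patch_size, x:x + patch_size],
               mask[y:y + patch_size, x:x + patch_size])
```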

ViT_H was trained with ViT_B image embeddings...

I just realized that I forgot to re-compute the image embeddings for the ViT_H training. I am not sure why the training could still be completed... This needs to be fixed to get proper ViT_H results.

Add regularization to loss

I would like to add some regularization to the loss function to make the model more robust and discourage it from producing "glitchy" segmentations. For an illustration, see the image below, taken from the validation set of https://github.com/brainhack-school2023/collin_project/tree/main (first iteration of this project).

(screenshot: validation-set segmentation showing the artifacts listed below)

  1. The pink axon has discontinuities.
  2. The brown axon is incomplete.

I am not yet entirely sure how to regularize the myelin prediction, but will update this issue later.
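One candidate is a total-variation penalty on the predicted probability map, which discourages speckled or discontinuous predictions. A sketch (the choice of penalty and its weight are open questions, not a decided design):

```python
import torch

def tv_regularizer(prob):
    """Total-variation penalty on a (B, 1, H, W) probability map.

    Large local jumps between neighbouring pixels (speckles,
    discontinuities) increase the penalty; smooth, contiguous
    regions are cheap.
    """
    dh = (prob[..., 1:, :] - prob[..., :-1, :]).abs().mean()
    dw = (prob[..., :, 1:] - prob[..., :, :-1]).abs().mean()
    return dh + dw

# usage: loss = dice_loss + lambda_tv * tv_regularizer(prob)
```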

Integrate SAM into ADS

(Maybe we should keep the training scripts separate from the rest)
This would be pretty straightforward to integrate. The model checkpoints will need a release, and we can use the inference script as a reference. The only time-consuming part would be to add tests.

Add shuffling in dataloader

def bids_dataloader(data_dict, maps_path, embeddings_path, sub_list):
    '''
    :param data_dict: contains img, mask and px_size info per sample per subject
    :param maps_path: paths to myelin maps (instance masks)
    :param embeddings_path: paths to pre-computed image embeddings
    :param sub_list: subjects included
    '''
    subjects = list(data_dict.keys())
    # # we keep the last subject for testing
    # for sub in subjects[:-1]:
    for sub in subjects:
        if sub in sub_list:
            samples = (s for s in data_dict[sub].keys() if 'sample' in s)
            for sample in samples:
                emb_path = embeddings_path / sub / 'micr' / f'{sub}_{sample}_TEM_embedding.pt'
                bboxes = get_sample_bboxes(sub, sample, maps_path)
                myelin_map = get_myelin_map(sub, sample, maps_path)
                yield (emb_path, bboxes, myelin_map)

The order in which samples are loaded should be shuffled.
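A minimal fix is to materialize the generator's output and shuffle it once per epoch (a sketch; it assumes the per-epoch sample list fits in memory):

```python
import random

def shuffled(sample_generator, seed=None):
    """Yield the generator's items in random order.

    Materializes all (emb_path, bboxes, myelin_map) tuples first; for
    the current dataset sizes this should be acceptable. Pass a seed
    for reproducible epochs.
    """
    items = list(sample_generator)
    random.Random(seed).shuffle(items)
    yield from items
```

Wrapping the existing dataloader (`shuffled(bids_dataloader(...))`) changes nothing else in the training loop.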
