
medsegdiff's Introduction

MedSegDiff: Medical Image Segmentation with Diffusion Model

MedSegDiff is a Diffusion Probabilistic Model (DPM) based framework for Medical Image Segmentation. The algorithm is elaborated in our papers MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model and MedSegDiff-V2: Diffusion based Medical Image Segmentation with Transformer.

Diffusion models work by destroying training data through the successive addition of Gaussian noise, and then learning to recover the data by reversing this noising process. After training, we can use the diffusion model to generate data by simply passing randomly sampled noise through the learned denoising process. In this project, we extend this idea to medical image segmentation: we use the original image as a condition, generate multiple segmentation maps from random noise, and then ensemble them to obtain the final result. This approach captures the uncertainty in medical images and outperforms previous methods on several benchmarks.
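To make the idea concrete, here is a toy sketch of conditional reverse diffusion with ensembling. It is illustrative only and not the repository's sampler: toy_denoiser stands in for the trained network, the schedule values are arbitrary, and simple averaging plus thresholding is used as the fusion step.

import torch

def toy_denoiser(noisy_mask, image, t):
    # Placeholder for the learned noise predictor conditioned on the image.
    return torch.zeros_like(noisy_mask)

@torch.no_grad()
def sample_mask(image, steps=20):
    betas = torch.linspace(1e-4, 0.02, steps)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    x = torch.randn(1, 1, *image.shape[-2:])   # start the mask from pure noise
    for t in reversed(range(steps)):
        eps = toy_denoiser(x, image, t)        # predict the noise given the image condition
        x = (x - betas[t] / torch.sqrt(1 - alpha_bars[t]) * eps) / torch.sqrt(alphas[t])
        if t > 0:
            x = x + torch.sqrt(betas[t]) * torch.randn_like(x)  # re-inject noise except at the last step
    return torch.sigmoid(x)

image = torch.randn(1, 3, 64, 64)              # dummy conditioning image
ensemble = torch.stack([sample_mask(image) for _ in range(5)])
final_mask = (ensemble.mean(dim=0) > 0.5).float()  # fuse the runs by averaging and thresholding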

A Quick Overview

(Figure: overview diagrams of MedSegDiff-V1 and MedSegDiff-V2)

News

  • [TOP] Join our Discord to ask questions and discuss with others.
  • 22-11-30. This project is still being updated quickly. Check the TODO list to see what will be released next.
  • 22-12-03. BraTS2020 bugs fixed. Example case added.
  • 22-12-15. Fixed multi-GPU distributed training.
  • 22-12-16. DPM-Solver ✖️ MedSegDiff DONE 🥳 Now DPM-Solver is available in MedSegDiff. Enjoy its lightning-fast sampling (1000 steps ❌ 20 steps ⭕️) by setting --dpm_solver True.
  • 22-12-23. Fixed some DPM-Solver bugs.
  • 23-01-31. MedSegDiff-V2 will be available soon 🥳. Check our paper MedSegDiff-V2: Diffusion based Medical Image Segmentation with Transformer first.
  • 23-02-07. Optimized the BRATS sampling workflow. Added a dataloader for processing raw 3D BRATS data.
  • 23-02-11. Fixed 3D BRATS data training bugs, issue 31.
  • 23-03-04. Paper MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model has been officially accepted by MIDL 2023 🥳
  • 23-04-11. A new version based on the V2 framework has been released 🥳. It's more accurate, stable, and domain-adaptable than the previous version, while still not hogging too much of your resources. We've also fixed a bunch of small things, like requirement.txt and the ISIC CSV files. Huge thanks to all of you who reported issues, you really helped us a lot 🤗. By the way, the new version runs by default; add --version 1 if you want to run the previous version.
  • 23-04-12. Added a simple evaluation file for the ISIC dataset (scripts/segmentation_env.py). Usage: python scripts/segmentation_env.py --inp_pth *folder you save prediction images* --out_pth *folder you save ground truth images*
  • 23-12-05. Paper MedSegDiff-V2: Diffusion based Medical Image Segmentation with Transformer has been officially accepted by AAAI 2024 🥳

Requirement

pip install -r requirement.txt

Example Cases

Melanoma Segmentation from Skin Images

  1. Download the ISIC dataset from https://challenge.isic-archive.com/data/. Your dataset folder under "data" should look like this:
data
|   ----ISIC
|       ----Test
|       |   |   ISBI2016_ISIC_Part1_Test_GroundTruth.csv
|       |   |   
|       |   ----ISBI2016_ISIC_Part1_Test_Data
|       |   |       ISIC_0000003.jpg
|       |   |       .....
|       |   |
|       |   ----ISBI2016_ISIC_Part1_Test_GroundTruth
|       |           ISIC_0000003_Segmentation.png
|       |   |       .....
|       |           
|       ----Train
|           |   ISBI2016_ISIC_Part1_Training_GroundTruth.csv
|           |   
|           ----ISBI2016_ISIC_Part1_Training_Data
|           |       ISIC_0000000.jpg
|           |       .....
|           |       
|           ----ISBI2016_ISIC_Part1_Training_GroundTruth
|           |       ISIC_0000000_Segmentation.png
|           |       .....
  2. For training, run: python scripts/segmentation_train.py --data_name ISIC --data_dir *input data directory* --out_dir *output data directory* --image_size 256 --num_channels 128 --class_cond False --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16 --diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False --lr 1e-4 --batch_size 8

  3. For sampling, run: python scripts/segmentation_sample.py --data_name ISIC --data_dir *input data directory* --out_dir *output data directory* --model_path *saved model* --image_size 256 --num_channels 128 --class_cond False --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16 --diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False --num_ensemble 5

  4. For evaluation, run: python scripts/segmentation_env.py --inp_pth *folder you save prediction images* --out_pth *folder you save ground truth images* (a minimal sketch of the metrics involved is shown below)

By default, the samples will be saved at ./results/
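For reference, here is a minimal sketch of the overlap metrics this kind of evaluation computes. This is not the repository's script; the file loading and the binarisation threshold are assumptions, so adapt them to your masks.

import numpy as np
from PIL import Image

def dice_iou(pred_path, gt_path, thresh=128):
    # Binarise both masks, then compute Dice and IoU overlap scores.
    pred = np.array(Image.open(pred_path).convert("L")) > thresh
    gt = np.array(Image.open(gt_path).convert("L")) > thresh
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    dice = 2 * inter / (pred.sum() + gt.sum() + 1e-8)
    iou = inter / (union + 1e-8)
    return dice, iou

# Example: dice, iou = dice_iou(prediction_png, ground_truth_png)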

Brain Tumor Segmentation from MRI

  1. Download the BRATS2020 dataset from https://www.med.upenn.edu/cbica/brats2020/data.html. Your dataset folder should look like this (a sketch for slicing the raw 3D volumes into this layout is shown after the commands below):
data
└───training
│   └───slice0001
│       │   brats_train_001_t1_123_w.nii.gz
│       │   brats_train_001_t2_123_w.nii.gz
│       │   brats_train_001_flair_123_w.nii.gz
│       │   brats_train_001_t1ce_123_w.nii.gz
│       │   brats_train_001_seg_123_w.nii.gz
│   └───slice0002
│       │  ...
└───testing
│   └───slice1000
│       │  ...
│   └───slice1001
│       │  ...
  2. For training, run: python scripts/segmentation_train.py --data_dir (where you put data folder)/data/training --out_dir *output data directory* --image_size 256 --num_channels 128 --class_cond False --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16 --diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False --lr 1e-4 --batch_size 8

  3. For sampling, run: python scripts/segmentation_sample.py --data_dir (where you put data folder)/data/testing --out_dir *output data directory* --model_path *saved model* --image_size 256 --num_channels 128 --class_cond False --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16 --diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False --num_ensemble 5
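If you start from the raw 3D BraTS release rather than pre-sliced files, a helper along the lines of the sketch below is one way to produce the per-slice layout shown above. The file and folder names are assumptions modelled on that layout, not the repository's own preprocessing; note that since 23-02-07 a dataloader for raw 3D BRATS data is also provided.

import os
import numpy as np
import nibabel as nib

def slice_case(case_dir, case_id, out_root, modalities=("t1", "t2", "flair", "t1ce", "seg")):
    # Load one 3D volume per modality, then write each axial slice to its own folder.
    vols = {m: nib.load(os.path.join(case_dir, f"{case_id}_{m}.nii.gz")) for m in modalities}
    depth = vols["seg"].shape[2]
    for z in range(depth):
        slice_dir = os.path.join(out_root, f"slice{z:04d}_{case_id}")
        os.makedirs(slice_dir, exist_ok=True)
        for m, vol in vols.items():
            data = np.asarray(vol.dataobj)[..., z]
            out_name = f"brats_train_{case_id}_{m}_{z}_w.nii.gz"
            nib.save(nib.Nifti1Image(data, vol.affine), os.path.join(slice_dir, out_name))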

Other Examples

...

Run on your own dataset

It is simple to run MedSegDiff on other datasets. Just write another data loader file following ./guided_diffusion/isicloader.py or ./guided_diffusion/bratsloader.py. Feel free to open an issue if you run into any problems, and contributions of dataset extensions are appreciated. Unlike natural images, medical images vary a lot across tasks, so expanding the generality of a method requires everyone's effort. A minimal loader sketch is shown below.
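For illustration, a minimal loader in the spirit of isicloader.py might look like the sketch below. The directory layout, file naming, and the returned tuple are assumptions; adapt them to your data and to what the training script expects.

import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T

class MyDataset(Dataset):
    def __init__(self, root, image_size=256):
        # Assumed layout: root/images/<name> and root/masks/<name> with matching file names.
        self.img_dir = os.path.join(root, "images")
        self.mask_dir = os.path.join(root, "masks")
        self.names = sorted(os.listdir(self.img_dir))
        self.tf_img = T.Compose([T.Resize((image_size, image_size)), T.ToTensor()])
        self.tf_mask = T.Compose([T.Resize((image_size, image_size)), T.ToTensor()])

    def __len__(self):
        return len(self.names)

    def __getitem__(self, i):
        name = self.names[i]
        img = self.tf_img(Image.open(os.path.join(self.img_dir, name)).convert("RGB"))
        mask = self.tf_mask(Image.open(os.path.join(self.mask_dir, name)).convert("L"))
        return img, mask, name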

Suggestions for Hyperparameters and Training

To train a fine model, i.e., MedSegDiff-B in the paper, set the model hyperparameters as:

--image_size 256 --num_channels 128 --class_cond False --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16 

diffusion hyperparameters as:

--diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False

To speed up the sampling:

--diffusion_steps 50 --dpm_solver True 

run on multiple GPUs:

--multi-gpu 0,1,2 (for example)

training hyperparameters as:

--lr 5e-5 --batch_size 8

and set --num_ensemble 5 in sampling.

Training for about 100,000 steps will converge on most of the datasets. Note that although the loss does not decrease during most of the later steps, the quality of the results is still improving. The same behavior is observed in other DPM applications, such as image generation. Hope someone smart can tell me why🥲.

I will soon publish its performance with a smaller batch size (suitable for running on a 24GB GPU) for comparison🤗.

A setting to unleash all its potential is (MedSegDiff++):

--image_size 256 --num_channels 512 --class_cond False --num_res_blocks 12 --num_heads 8 --learn_sigma True --use_scale_shift_norm True --attention_resolutions 24 

Then train it with batch size --batch_size 64 and sample it with ensemble number --num_ensemble 25.

Be a part of MedSegDiff! Authors are YOU!

You are welcome to contribute to MedSegDiff. Any technique that can improve the performance or speed up the algorithm is appreciated🙏. I am writing MedSegDiff V2, aiming at a Nature journal/CVPR-level publication. I'm glad to list contributors as my co-authors🤗.

TODO LIST

  • Fix bugs in BRATS. Add BRATS example.
  • Release REFUGE and DDTI dataloaders and examples
  • Speed up sampling by DPM-solver
  • Inference of depth
  • Fix bugs in Multi-GPU parallel
  • Sample and Vis in training
  • Release pre-processing and post-processing
  • Release evaluation
  • Deploy on HuggingFace
  • configuration

Thanks

A lot of the code is borrowed from openai/improved-diffusion, WuJunde/MrPrism, WuJunde/DiagnosisFirst, LuChengTHU/dpm-solver, JuliaWolleb/Diffusion-based-Segmentation, hojonathanho/diffusion, guided-diffusion, bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets, nnUnet, lucidrains/vit-pytorch

Cite

Please cite

@inproceedings{wu2023medsegdiff,
  title={MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model},
  author={Wu, Junde and FU, RAO and Fang, Huihui and Zhang, Yu and Yang, Yehui and Xiong, Haoyi and Liu, Huiying and Xu, Yanwu},
  booktitle={Medical Imaging with Deep Learning},
  year={2023}
}
@article{wu2023medsegdiff,
  title={MedSegDiff-V2: Diffusion based Medical Image Segmentation with Transformer},
  author={Wu, Junde and Ji, Wei and Fu, Huazhu and Xu, Min and Jin, Yueming and Xu, Yanwu},
  journal={arXiv preprint arXiv:2301.11798},
  year={2023}
}

Buy Me A Coffee 🥤😉

https://ko-fi.com/jundewu

medsegdiff's People

Contributors

baiduihu, heikeyuhuajia, jiayuanz3, jiwei0921, lin-tianyu, nobleaustine, utkarshtambe10, wujunde


medsegdiff's Issues

About DDTI dataset

Hi Wu, I am a beginner in medical image processing. Could you share the DDTI dataset and example cases? Thanks a lot.

Sample Visualization and Metrics

We have trained the diffusion model for more than 100,000 steps and sampled the test images.
However, the predictions seem wrong, as pixel values vary from 0 to several tens instead of 0/1.
How do we obtain the final segmentation mask?
(attached sample: 0000015_output)

How to train this model on my own dataset

Thanks for the great work; I have a problem. My dataset has four parts: train_images, train_mask, test_images, test_mask. There are no JSON or CSV documents. Should I create a JSON document in COCO format for my dataset, or just use the mask images? Thank you!!

Problems during sample

I ran segmentation_sample.py and met this problem:

Logging to /root/autodl-tmp/MedSegDif/med_results/img_out/
creating model and diffusion...
sampling...
no dpm-solver
/root/miniconda3/lib/python3.8/site-packages/torch/nn/functional.py:1709: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
Traceback (most recent call last):
File "MedSegDif/med_scripts/segmentation_sample.py", line 163, in
main()
File "MedSegDif/med_scripts/segmentation_sample.py", line 109, in main
sample, x_noisy, org, cal, cal_out = sample_fn(
File "/root/autodl-tmp/./MedSegDif/med_guided_diffusion/gaussian_diffusion.py", line 553, in p_sample_loop_known
for sample in self.p_sample_loop_progressive(
File "/root/autodl-tmp/./MedSegDif/med_guided_diffusion/gaussian_diffusion.py", line 624, in p_sample_loop_progressive
out = self.p_sample(
File "/root/autodl-tmp/./MedSegDif/med_guided_diffusion/gaussian_diffusion.py", line 435, in p_sample
out = self.p_mean_variance(
File "/root/autodl-tmp/./MedSegDif/med_guided_diffusion/respace.py", line 90, in p_mean_variance
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
File "/root/autodl-tmp/./MedSegDif/med_guided_diffusion/gaussian_diffusion.py", line 319, in p_mean_variance
model_mean, _, _ = self.q_posterior_mean_variance(
File "/root/autodl-tmp/./MedSegDif/med_guided_diffusion/gaussian_diffusion.py", line 219, in q_posterior_mean_variance
assert x_start.shape == x_t.shape
AssertionError

I know it is because x_start.shape is not equal to x_t.shape. However, my dataset is similar to ISICDataset, so this seems very strange.
Thanks a lot if you can reply.

Code Training Problem

Dear author, after downloading the dataset and running the code, the following issues arise. (screenshot attached)
After debugging, we found that the length of hs is 12, so we changed the index to 11 and ran it again, which caused the following problems. (screenshots attached)
Could the author please help answer this? Thank you.

Training with different image size

Training with image size = 128 and the ISIC dataset fails.

The failure occurs when values pass through FFParser. (screenshot attached)

I assume this is related to hardcoded values in the instantiation of the FFParser modules (unet.py file). (screenshot attached)

Do you see the same issue, and what would be a smart fix?

brats data slice

Hello, may I know how you sliced the 3D BraTS data into 2D data in order to put it in the directory?

Loading custom MRI datasets

Hi!

Thanks for this repo, really exciting stuff!

I have a sagittal MRI dataset with dimensions (512, 512, 7) (H, W, Slice) in NIfTI format. What should the input look like for the network to train? In my understanding, since the autoencoder is a 2D U-Net, the network will be trained on each slice of each patient individually; however, I'm a bit confused about what the input to the network should be.

questions about getting test scores

Are the scores in Table I of MedSegDiff the official test scores or 5-fold cross-validation scores? nii.gz files need to be uploaded to the BraTS2020 website to get the official test score, but I cannot find a nii.gz-generating section in the source code, so I don't know how to get the test score of a trained model.

Forward function in Generic_UNet

Hello, I notice that the code in Generic_UNet defines a conv inside the forward function (screenshots attached), so it will use different weights on the next call. Can you tell me the reason?

Question about the mse loss for training segmentation tasks

Thanks for your great work and your effort in sharing this code. I am wondering: is it stable to use MSE loss for training segmentation tasks? Usually we use cross-entropy loss for this task, and that is what I am curious about.

Thanks for reading this issue and I am looking forward to your reply!

BRATS Dataset training testing split

Hi there, nice work.
Can you provide your training and testing split for the BraTS21 dataset? I am trying to reproduce your work, so I would like to know how to create the actual samples I need to train and infer on. In the paper you wrote "Train/validation/test sets are split following the default settings of the dataset", but their validation and test splits don't have labels. Can you tell me how to find them?

Also, did you do any preprocessing other than slicing the images from 3D to 2D?

loss problem

(screenshot attached)
(1) As I understand it, mse_diff here predicts the noise; the target (with noise added) has shape [b,1,h,w], but model_output has shape [b,2,h,w]. In a previous issue you answered that the two channels represent the mean and variance. Can you explain the significance of computing MSE between them?
(2) For loss_cal, where the target is the segmentation GT, does the model's cal output represent the predicted segmentation result? Can the cal output be used directly to represent the segmentation accuracy of the model in the inference stage? (screenshot attached)

(3) Can you explain the meaning of sample, x_noisy, org, cal, and cal_out, respectively?

When will the V2 code be released?

I have high hopes for diffusion-based semantic segmentation, but the V1 code gives unsatisfactory results on my own dataset. I hope the V2 version can be released as soon as possible.

When will the training stop?

Thank you for your excellent work. I wonder how many iterations should be used for training, since I cannot find the stopping condition for training. Thank you.

Epochs of training

May I ask how many epochs you trained for to obtain the results in this paper?

Pretrained model

Hi, could you please provide your pre-trained models? I trained a model, but the sampling result is not right: the maximum pixel value is about 10, so the pictures are all black.

sample running error

Error using official run mode.

mentation_sample.py --data_dir /home/yp/diskdata/workspace/medsegdiff/dataset/ISIC --model_path /home/yp/diskdata/workspace/medsegdiff/results/savedmodel020000.pt --image_size 256 --num_channels 128 --class_cond False --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16 --diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False --num_ensemble 5

about the training

When I run scripts/segmentation_train.py, I have a problem:

Traceback (most recent call last):
File "D:\jace\pythonProject\MedSegDiffv2-master\scripts\segmentation_train.py", line 110, in
main()
File "D:\jace\pythonProject\MedSegDiffv2-master\scripts\segmentation_train.py", line 62, in main
TrainLoop(
File "D:\jace\pythonProject\MedSegDiffv2-master\guided_diffusion\train_util.py", line 83, in init
self._load_and_sync_parameters()
File "D:\jace\pythonProject\MedSegDiffv2-master\guided_diffusion\train_util.py", line 139, in _load_and_sync_parameters
dist_util.sync_params(self.model.parameters())
File "D:\jace\pythonProject\MedSegDiffv2-master\guided_diffusion\dist_util.py", line 78, in sync_params
dist.broadcast(p, 0)
File "C:\SoftWare\python 3.10\lib\site-packages\torch\distributed\distributed_c10d.py", line 1408, in broadcast
work.wait()
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

Model differences when I run segmentation_sample.py

Excellent work!
I'm a beginner in the field of deep learning.
I have a question: when I run segmentation_sample.py, what's the difference between savedmodel_XXXX.pt, optsavedmodel_XXXX.pt, and emasavedmodel_XXXX.pt?
Thanks a lot.

When I use DPM-Solver, CUDA runs out of memory

I looked at other people's issues; someone said you need to change the PyTorch version to 1.8.1, but when I tried that it didn't work. I also tried other PyTorch versions and it still doesn't work.

Problem of dimension

I am curious why the model output channel dimension is 2; my output is [b, image_size, image_size], but your code needs the output to be [b, 2, image_size, image_size].
(screenshot attached)

When args.in_ch = 5, the following error will occur

Original Traceback (most recent call last):
File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/data1/ppw/works/MedSegDiff/guided_diffusion/unet.py", line 773, in forward
h = module(h, emb)
File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/data1/ppw/works/MedSegDiff/guided_diffusion/unet.py", line 86, in forward
x = layer(x)
File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 446, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 442, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [128, 5, 3, 3], expected input[8, 4, 256, 256] to have 5 channels, but got 4 channels instead

The number of steps used in training

Hi,

It is an excellent project to share. I have a question about running the program: is the number of steps set to 1000 during training, while only 100 steps are used during inference?

Thanks in advance if the question can be answered~

Best,
CaviarLover

Problems with training

I encountered the following problems when training with the BRATS dataset! Can you help me? Thanks!

File "D:\jace\pythonProject\MedSegDiff-master\MedSegDiff-master\guided_diffusion\train_util.py", line 83, in init
self._load_and_sync_parameters()
File "D:\jace\pythonProject\MedSegDiff-master\MedSegDiff-master\guided_diffusion\train_util.py", line 139, in _load_and_sync_parameters
dist_util.sync_params(self.model.parameters())
File "D:\jace\pythonProject\MedSegDiff-master\MedSegDiff-master\guided_diffusion\dist_util.py", line 76, in sync_params
dist.broadcast(p, 0)
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

training&inference time cost

Hi! I am trying to use your code on BRATS2020 with sliced input images.

If I follow the README and use this command on a single 3090 GPU (24G), what is the expected time cost?

(screenshot attached)

By the way, could you please share one of your training logs as an example? Many thanks!!

Problems of sample

  1. log.txt and progress.csv do not output anything.
  2. When segmentation_sample.py is running, the terminal reports connection errors.
  3. I set num_ensemble=5 but only get one output image; from the terminal, it seems that something stops the iteration.

What parameters or arguments should I revise, or what else can I do?


Problem about calculating loss

Hello, I ran scripts/segmentation_train.py on my own dataset and met this problem:
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel; (2) making sure all forward function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's forward function. Please include the loss function and the structure of the return value of forward of your module when reporting this issue (e.g. list, dict, iterable).

Thank you !

For multi-class seg, num_class=3 for example

Hi Junde Wu,

I have some questions for you.

The hyper-parameter in_ch=2 is fixed regardless of binary or multi-class task, where the two dimensions are the image and the mask.
For a multi-class task, is the calibration output the only thing we are supposed to change, i.e. sigmoid to softmax, so that we get a [1 3 H W] calibration and a [1 2 H W] model_output? Is that correct?

If we change in_ch = 3 + 1 (one-hot plus the image condition), we can get the [1 3 H W] calibration; however, I do not know what the model_output is. Is it something like [1 3 2 H W], or is it still [1 2 H W]? If so, using the mask rather than a one-hot encoding as the input of the diffusion model seems meaningful?

I grouped a 5-class task into a binary case to check the results. Here is one visualization; is it correct? From top to bottom: img, recovery from the diffusion model, calibration, and the linear combination of the recovery and calibration. (screenshot attached)

Thanks!
Ping

An error will be reported when the image size is set to 512

This error occurs:
Original Traceback (most recent call last):
File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/data1/ppw/works/MedSegDiff/guided_diffusion/unet.py", line 775, in forward
uemb, cal = self.highway_forward(c, [hs[3],hs[6],hs[9],hs[12]])
File "/data1/ppw/works/MedSegDiff/guided_diffusion/unet.py", line 744, in highway_forward
return self.hwm(x,hs)
File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/data1/ppw/works/MedSegDiff/guided_diffusion/unet.py", line 2152, in forward
h = self.ffparserd
File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/data1/ppw/works/MedSegDiff/guided_diffusion/unet.py", line 479, in forward
x = x * weight
RuntimeError: The size of tensor a (129) must match the size of tensor b (65) at non-singleton dimension 3

Questions about the forward function of the UNetMode

Dear authors:

I have some questions about the function of the highway_forward (Generic_UNet). Detailed as follows:

  1. On the ISIC dataset, the resolution of the input image is [batch, 3, 64, 64], which means that c is [batch, 3, 64, 64]. But hs[12] is out of range, so we changed the indices to 2, 5, 8, and 11, corresponding to resolutions of […,64,64], […,32,32], […,16,16], and […,8,8], respectively. Is this operation right?
  2. The h and hb on L768 have different resolutions. Should they be resized?

(screenshot attached)

  3. When calculating x = x*ha*hb in the forward function of Generic_UNet, x, ha, and hb have different resolutions. Should they be resized?

(screenshot attached)

I hope for your response sincerely. Thanks a lot!

Segmentation question

I would like to ask which part of the path should be extracted here as the slice ID. Maybe my data is different from yours, and the error is reported when the code reaches this point.

elif args.data_name == 'BRATS':
    # slice_ID=path[0].split("_")[2] + "_" + path[0].split("_")[4]
    slice_ID=path[0].split("_")[-3] + "_" + path[0].split("slice")[-1].split('.nii')[0]

Why there are some unused parameters?

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel, and by
making sure all forward function outputs participate in calculating loss.

Hyperparameter setting issue

Thank you for your great work!

Can I use the parameters --diffusion_steps 50 --dpm_solver True during the training process, or can they only be used for sampling?

DPM-Solver Memory Problem

Hi! Thanks for your excellent work. I successfully trained a MedSegDiff-B model on my dataset but have trouble sampling.

Specifically, while using DPM-Solver to sample, GPU memory usage grows with the 'num_ensemble' parameter. On each ensemble run (1/5), GPU memory increases by around 2GB and finally crashes with a "CUDA out of memory" error.

This problem allows me to sample only one image before the inference process crashes. Is this a normal phenomenon? If not, how can I deal with it?

PS: the original inference process can sample images without increasing GPU memory.

Sampling output image visualization problem?

Hello,

I'm currently training this model on my own dataset. I have created a separate dataloader Python file, and the file and folder structure of the dataset is the same as ISIC. No code other than segmentation_train.py and segmentation_sample.py was changed, and only to load the data. The model has been trained for 30,000 steps so far. But when I use segmentation_sample.py on the test images, I am getting these masks.

Are these mask outputs normal for this model?
(attached samples: 89_16_output_ens, 89_24_output_ens)

python 3.8.16
torch 1.13.1
torchvision 0.14.1
torchsummary 1.5.1
opencv 4.7.0.68
scikit-image 0.19.3

question about loss calculation

I have a question regarding loss calculation:
For training, loss = (losses["loss"] * weights + losses['loss_cal'] * 10).mean() is used.
Why do you weight the direct prediction of the ground truth higher than the comparison with a less noisy version?

Is there a reason that, for inference, a different composition of cal and sample is used depending on the Dice score?

Thanks in advance!

error in create_argparser

defaults.update({k: v for k, v in model_and_diffusion_defaults().items() if k not in defaults})
Hi, I believe this is what you want to have; otherwise the values will be overwritten by those in the predefined defaults. A small sketch of the difference is below.
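Illustrative sketch of the behaviour described above (the key names are made up for illustration; only the update pattern matters). A plain dict.update() lets the library-wide defaults overwrite the script-specific ones, while the filtered update keeps any key the script already set:

script_defaults = {"lr": 1e-4, "batch_size": 8}
library_defaults = {"lr": 3e-4, "image_size": 64}

clobbered = dict(script_defaults)
clobbered.update(library_defaults)  # lr silently becomes 3e-4

filtered = dict(script_defaults)
filtered.update({k: v for k, v in library_defaults.items() if k not in filtered})
print(clobbered["lr"], filtered["lr"])  # 0.0003 0.0001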
