MONAI Generative Models makes it easy to train, evaluate, and deploy generative models and related applications.

License: Apache License 2.0


MONAI Generative Models

Prototyping repository for generative models to be integrated into MONAI core, MONAI tutorials, and MONAI model zoo.

Features

  • Network architectures: Diffusion Model, Autoencoder-KL, VQ-VAE, Autoregressive transformers, (Multi-scale) Patch-GAN discriminator.
  • Diffusion Model Noise Schedulers: DDPM, DDIM, and PNDM.
  • Losses: Adversarial losses, Spectral losses, and Perceptual losses (for 2D and 3D data using LPIPS, RadImageNet, and 3DMedicalNet pre-trained models).
  • Metrics: Multi-Scale Structural Similarity Index Measure (MS-SSIM) and Fréchet inception distance (FID).
  • Diffusion Models, Latent Diffusion Models, and VQ-VAE + Transformer Inferers classes (compatible with MONAI style) containing methods to train, sample synthetic images, and obtain the likelihood of input data.
  • MONAI-compatible trainer engine (based on Ignite) to train models with reconstruction and adversarial components.
  • Tutorials including:
    • How to train VQ-VAEs, VQ-GANs, VQ-VAE + Transformers, AutoencoderKLs, Diffusion Models, and Latent Diffusion Models on 2D and 3D data.
    • Training diffusion models for conditional image generation with classifier-free guidance.
    • Comparison of different diffusion model schedulers.
    • Diffusion models with different parameterizations (e.g., v-prediction and epsilon parameterization).
    • Anomaly Detection using VQ-VAE + Transformers and Diffusion Models.
    • Inpainting with diffusion models (using the RePaint method).
    • Super-resolution with Latent Diffusion Models (using Noise Conditioning Augmentation).

Roadmap

Our short-term goals are available in the Milestones section of the repository.

In the longer term, we aim to integrate the generative models into the MONAI core repository, supporting tasks such as image synthesis, anomaly detection, MRI reconstruction, and domain transfer.

Installation

To install the current release of MONAI Generative Models, you can run:

pip install monai-generative

To install the current main branch of the repository, run:

pip install git+https://github.com/Project-MONAI/GenerativeModels.git

Requires Python >= 3.8.

Contributing

For guidance on making a contribution to MONAI, see the contributing guidelines.

Community

Join the conversation on Twitter @ProjectMONAI or join our Slack channel.

Citation

If you use MONAI Generative in your research, please cite us! The citation can be exported from the paper.


Contributors

aamir-m-khan, ashayp31, danieltudosiu, ericspod, guopengf, jessyd, juliawolleb, kumoliu, marksgraham, matanat, mingxin-zheng, nic-ma, oesllelucena, pedroferreiradacosta, sanches-pedro, stijnvwijn, vacmar01, virginiafdez, warvito, ycremar, yiheng-wang-nv


Issues

PatchDiscriminator has a different architecture from original paper and VQGAN/LDM implementation

Comparing the code of PatchDiscriminator with its original implementation (https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/e2c7618a2f2bf4ee012f43f96d1f62fd3c3bec89/models/networks.py#L539) and the VQGAN implementation (https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/discriminator/model.py#L17), it looks like our implementation diverges from the original in a few points.

Using these parameters (to simulate the VQGAN network):

    spatial_dims: 2
    num_channels: 64
    num_layers_d: 3
    in_channels: 1
    out_channels: 1
    kernel_size: 4
    activation: "LEAKYRELU"
    norm: "BATCH"
    bias: False
    padding: 1
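
For reference, a minimal sketch that instantiates the network with these parameters (assuming they map one-to-one onto PatchDiscriminator's constructor arguments, which may vary between releases); printing it yields the second module tree below:

from generative.networks.nets import PatchDiscriminator

# Instantiate the discriminator with the configuration above and inspect it.
discriminator = PatchDiscriminator(
    spatial_dims=2,
    num_channels=64,
    num_layers_d=3,
    in_channels=1,
    out_channels=1,
    kernel_size=4,
    activation="LEAKYRELU",
    norm="BATCH",
    bias=False,
    padding=1,
)
print(discriminator)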

We get this from the original implementation:

Discriminator(
  (main): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)

and this is what we get from our implementation:

PatchDiscriminator(
  (0): Convolution(
    (conv): Conv2d(1, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (adn): ADN(
      (N): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (D): Dropout(p=0.0, inplace=False)
      (A): LeakyReLU(negative_slope=0.01)
    )
  )
  (1): Convolution(
    (conv): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (adn): ADN(
      (N): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (D): Dropout(p=0.0, inplace=False)
      (A): LeakyReLU(negative_slope=0.01)
    )
  )
  (2): Convolution(
    (conv): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (adn): ADN(
      (N): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (D): Dropout(p=0.0, inplace=False)
      (A): LeakyReLU(negative_slope=0.01)
    )
  )
  (final_conv): Convolution(
    (conv): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (adn): ADN(
      (N): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (D): Dropout(p=0.0, inplace=False)
      (A): LeakyReLU(negative_slope=0.01)
    )
  )
)

There are three issues in our model:

  1. Our implementation skips the first convolution and activation ((0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) followed by (1): LeakyReLU(negative_slope=0.2, inplace=True)), so it starts with 128 channels instead of 64. This is caused by the line

     output_channels = num_channels * 2

  2. In the (2): Convolution, we are using stride 2 instead of stride 1 (as shown in https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/e2c7618a2f2bf4ee012f43f96d1f62fd3c3bec89/models/networks.py#L574).

  3. In the original, final_conv has no norm/dropout/activation (ADN) block; ours should use conv_only=True.

Should we create a common interface to support second-stage models?

@ericspod Should we create a common interface between VQ-VAE and AEKL to make them easy to interchange when obtaining latent representations (example proposed here: #13 (comment))?

It would make these two models different from the MONAI models (such as https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/nets/varautoencoder.py and https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/nets/autoencoder.py) and future generative models.
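
One possible shape for such an interface, sketched here with hypothetical names (encode/decode are placeholders, not the repository's API):

from abc import ABC, abstractmethod

import torch

class SecondStageModel(ABC):
    # Hypothetical shared interface so VQ-VAE and AutoencoderKL can be swapped
    # when a downstream model (e.g., a diffusion model or transformer) only
    # needs latent representations.

    @abstractmethod
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # Map an image batch to its latent representation.
        ...

    @abstractmethod
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        # Map latents back to image space.
        ...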

Create AutoencoderKL tutorial using 2D data and torch

Create a simple tutorial using AutoencoderKL on the 2D MedNIST data. To make training faster, do not use all classes from MedNIST (a couple of them will be enough to later train the conditioned diffusion models). In this tutorial, use pure torch for the training and validation steps (no frameworks like Ignite or Lightning).
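
A minimal sketch of such a pure-torch training step, reusing the reconstruction and KL terms that appear elsewhere in the repository's examples (the model is assumed to return (reconstruction, z_mu, z_sigma)):

import torch
import torch.nn.functional as F

def train_step(model, images, optimizer, kl_weight=1e-6):
    # One optimisation step: MSE reconstruction loss plus a KL penalty on the
    # latent distribution, as in the AutoencoderKL examples.
    optimizer.zero_grad(set_to_none=True)
    reconstruction, z_mu, z_sigma = model(images)
    recon_loss = F.mse_loss(reconstruction.float(), images.float())
    kl = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3])
    loss = recon_loss + kl_weight * torch.sum(kl) / kl.shape[0]
    loss.backward()
    optimizer.step()
    return loss.item()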

Add 3D Feature Loss

Add a 3D feature loss based on MedicalNet as an alternative to the 2D and 2.5D perceptual losses.

Add VQ-VAE network

Create the VQ-VAE network for 2D and 3D cases with a Vector Quantisation component using exponential moving averages to update the dictionary. Add the relevant unit tests and documentation.
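
For illustration, a minimal sketch of the EMA dictionary update (following van den Oord et al.; names are illustrative, not the repository's API):

import torch

def ema_update(codebook, cluster_size, embed_avg, flat_inputs, encodings, decay=0.99, eps=1e-5):
    # codebook: (K, D) embeddings; flat_inputs: (B, D) encoder outputs;
    # encodings: (B, K) float one-hot assignments of inputs to codebook entries.
    cluster_size.mul_(decay).add_(encodings.sum(0), alpha=1 - decay)
    embed_avg.mul_(decay).add_(encodings.t() @ flat_inputs, alpha=1 - decay)
    # Laplace smoothing avoids division by zero for unused codes.
    n = cluster_size.sum()
    smoothed = (cluster_size + eps) / (n + codebook.shape[0] * eps) * n
    codebook.copy_(embed_avg / smoothed.unsqueeze(1))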

Add AutoencoderKL

Create an AutoencoderKL for 2D and 3D, including unit tests and docstrings.

Add Transformer network

Add the transformer network and components to make it compatible with the VQ-VAE network. Create the components necessary to generate samples and obtain the likelihood of input data from the model. Add the relevant unit tests and documentation.
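
For instance, the likelihood component could look like this sketch (assuming the transformer returns per-position logits over the VQ-VAE codebook and sequences start with a begin-of-sequence token; names are illustrative):

import torch
import torch.nn.functional as F

def sequence_log_likelihood(model, tokens):
    # tokens: (B, L) integer codebook indices, first entry a BOS token.
    logits = model(tokens[:, :-1])                  # (B, L-1, vocab_size)
    log_probs = F.log_softmax(logits, dim=-1)
    targets = tokens[:, 1:].unsqueeze(-1)           # next-token targets
    return log_probs.gather(-1, targets).squeeze(-1).sum(dim=1)  # (B,) log-likelihoods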

Add installer

Add setup.py and the files necessary to install this package with its prototypes.

VQGAN tutorial uses all features in the discriminator loss instead of just the last layer

To compute the discriminator loss, we are using all of the discriminator's feature maps in

logits_fake = discriminator(reconstruction.contiguous().detach())

and

logits_real = discriminator(images.contiguous().detach())

but it should use just the last feature map, as the generator loss does:

logits_fake = discriminator(reconstruction.contiguous().float())[-1]
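
Applying the same indexing on the discriminator side would give (a sketch of the proposed fix):

# Use only the last output (the logits) from the multi-output discriminator.
logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
logits_real = discriminator(images.contiguous().detach())[-1]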

Improve slow and overlapping network unit tests

Currently the network unit tests take a significant amount of time. This can be improved by removing redundancy and by using smaller networks in the tests.

For example, for the AutoencoderKL, two of the linked test cases build similar networks with the same components, so the second case does not increase coverage.

Another example: for the VQVAE, the parameterisation

[1, 3],  # Batch size

tests the network forward pass with either a single image or 3 examples in the minibatch. These two test cases do not exercise different methods of the network or different conditions, and they do not test anything added by the VQVAE class.

Add Perceptual loss

Add a Perceptual Similarity metric to be used as a loss. It should be compatible with 2D and 2.5D approaches.
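
For illustration, a minimal sketch of the 2.5D idea (not the repository's implementation): treat the slices along each spatial axis as a batch of 2D images and average a 2D perceptual loss over the three axes.

import torch

def fake_3d_perceptual_loss(perceptual_2d, x, y):
    # x, y: (N, C, D, H, W) volumes; perceptual_2d: any 2D perceptual loss.
    losses = []
    for axis in (2, 3, 4):
        xs = x.movedim(axis, 1).flatten(0, 1)  # slices as a batch: (N * S, C, H', W')
        ys = y.movedim(axis, 1).flatten(0, 1)
        losses.append(perceptual_2d(xs, ys))
    return torch.stack(losses).mean()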

Fix TorchScript error in the latent diffusion UNet network

Try to include TorchScript tests for the latent diffusion UNet networks. For this, it might be necessary to remove the use of TimestepEmbedSequential and TimestepBlock by creating blocks exclusive to the case where the time embedding is included.
A solution similar to the Hugging Face diffusers one might be useful:
https://github.com/huggingface/diffusers/blob/2d35f6733a2d698e8917896071444a5923993ae7/src/diffusers/models/unet_blocks.py#L461
https://github.com/huggingface/diffusers/blob/2d35f6733a2d698e8917896071444a5923993ae7/src/diffusers/models/unet_blocks.py#L379
https://github.com/huggingface/diffusers/blob/2d35f6733a2d698e8917896071444a5923993ae7/src/diffusers/models/unet_blocks.py#L576

Add DDPM Scheduler

Add the variance scheduler proposed in the DDPM paper (with linear and cosine options), in a code style similar to Hugging Face's.
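
A sketch of the two beta schedules (the cosine variant follows the improved-DDPM formulation of Nichol & Dhariwal; names are illustrative):

import math
import torch

def linear_beta_schedule(num_timesteps, beta_start=1e-4, beta_end=0.02):
    # Linearly spaced variances, as in the original DDPM paper.
    return torch.linspace(beta_start, beta_end, num_timesteps)

def cosine_beta_schedule(num_timesteps, s=0.008):
    # Betas derived from a cosine alpha-bar curve.
    t = torch.arange(num_timesteps + 1) / num_timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi / 2) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return betas.clamp(max=0.999)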

DataLoader Usage

Adding persistent_workers=True to the arguments of any DataLoader object will speed up training, since worker processes won't have to be recreated at each epoch. This helped a lot on Windows and may be less helpful elsewhere; it should be tested with the notebooks and scripts. ThreadDataLoader may also provide some additional improvement.
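
For example (with a placeholder dataset; persistent_workers only takes effect when num_workers > 0):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(128, 1, 64, 64))  # placeholder data
loader = DataLoader(
    dataset,
    batch_size=16,
    num_workers=4,            # persistent_workers requires num_workers > 0
    persistent_workers=True,  # keep worker processes alive across epochs
)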

Add MSSIM metric

Add Mean Structural Similarity as a metric, including unit tests and documentation.

Fix perceptual loss 3D

When calling the perceptual loss with 3D images I saw the following error:

TypeError                                 Traceback (most recent call last)
Cell In [12], line 25
22 reconstruction, z_mu, z_sigma = model(images)
24 mse_loss = F.mse_loss(reconstruction.float(), images.float())
---> 25 p_loss = perceptual_loss(reconstruction.float(), images.float())
27 kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3])
28 kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

File ~/miniconda3/envs/genmodels/lib/python3.9/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []

File /mnt_homes/home4T7/jdafflon/GenerativeModels/generative/losses/perceptual.py:124, in PerceptualLoss.forward(self, input, target)
121     loss = self.perceptual_function(input, target)
122 elif self.spatial_dims == 3 and self.is_fake_3d:
123     # Compute 2.5D approach
--> 124     loss_sagittal = self._calculate_axis_loss(input, target, spatial_axis=2)
125     loss_coronal = self._calculate_axis_loss(input, target, spatial_axis=3)
126     loss_axial = self._calculate_axis_loss(input, target, spatial_axis=4)

File /mnt_homes/home4T7/jdafflon/GenerativeModels/generative/losses/perceptual.py:96, in PerceptualLoss._calculate_axis_loss(self, input, target, spatial_axis)
87 input_slices = batchify_axis(
88     x=input,
89     fake_3d_perm=(
(...)
93     + tuple(preserved_axes),
94 )
95 indices = torch.randperm(input_slices.shape[0])[: int(input_slices.shape[0] * self.fake_3d_ratio)]
---> 96 input_slices = input_slices[indices]
97 target_slices = batchify_axis(
98     x=target,
99     fake_3d_perm=(
(...)
103     + tuple(preserved_axes),
104 )
105 target_slices = target_slices[indices]

File ~/miniconda3/envs/genmodels/lib/python3.9/site-packages/monai/data/meta_tensor.py:274, in MetaTensor.torch_function(cls, func, types, args, kwargs)
272 else:
273     unpack = False
--> 274 ret = MetaTensor.update_meta(ret, func, args, kwargs)
275 return ret[0] if unpack else ret

File ~/miniconda3/envs/genmodels/lib/python3.9/site-packages/monai/data/meta_tensor.py:218, in MetaTensor.update_meta(rets, func, args, kwargs)
214 # if using e.g., batch[:, -1] or batch[..., -1], then the
215 # first element will be slice(None, None, None) and Ellipsis,
216 # respectively. Don't need to do anything with the metadata.
217 if batch_idx not in (slice(None, None, None), Ellipsis, None) and idx == 0:
--> 218     ret_meta = decollate_batch(args[0], detach=False)[batch_idx]
219     if isinstance(ret_meta, list):  # e.g. batch[0:2], re-collate
220         ret_meta = list_data_collate(ret_meta)

TypeError: only integer tensors of a single element can be converted to an index
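
One plausible workaround (an assumption, not necessarily the repository's fix) is to drop the MONAI metadata before the fancy indexing, so MetaTensor.update_meta is never invoked:

import torch
from monai.data import MetaTensor

def as_plain_tensor(x: torch.Tensor) -> torch.Tensor:
    # Strip MONAI metadata so plain torch indexing semantics apply.
    return x.as_tensor() if isinstance(x, MetaTensor) else x

def subsample_slices(input_slices, fake_3d_ratio):
    # Mirrors the slice subsampling in generative/losses/perceptual.py, but on
    # a plain tensor, avoiding the MetaTensor.update_meta code path above.
    input_slices = as_plain_tensor(input_slices)
    indices = torch.randperm(input_slices.shape[0])[: int(input_slices.shape[0] * fake_3d_ratio)]
    return input_slices[indices]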
