
viscy's People

Contributors

edyoshikun, mattersoflight, ziw-liu


viscy's Issues

image translation exercise for DL@MBL

The students will program a UNet near the start of the course, and we have scheduled image translation after the segmentation and denoising exercises. We have a 6-hour window for this exercise, which will be divided between deterministic image translation and generative image translation. For deterministic image translation, we'd use the UNeXt2 architecture, and for generative translation we'll probably use the conditional GAN used in this paper.

Considering the above plan and the need for a demo notebook for release 0.2.0 of VisCy, I suggest that we develop a demo notebook that illustrates the training of the VSCyto3D and VSNeuromast models.

Here are Alishba's fixes to last year's exercise:
https://github.com/alishbaimran/image_translation/blob/solution/solution.ipynb
https://docs.google.com/document/d/1h3u42hodHN7nQz9qND-NQc7uOm72fBi7DxURNYuuYPM/edit

draft demo dataset and notebook

Dataset
HEK293T cells with phase, membrane, and nuclei channels. Let's start with 50 FOVs.

checkpoint 1
Load the zarr store, view the label-free and fluorescence channels, configure the model, browse the 2D UNet with TensorBoard, and start training a phase->nuclei model.

checkpoint 2
Examine the loss after lunch, review the regression metrics for the phase->nuclei model, train a nuclei->phase model, and review the regression metrics for the nuclei->phase model.

checkpoint 3
Adjust the network capacity by different amounts; each student trains one model (phase -> nuclei, phase -> membrane, or phase -> nuclei + membrane). Record the metrics in a Google Doc.

sampling strategy to improve class balance across cell cycle

At a recent meeting, we discussed strategies to achieve class balance across the cell cycle. @ziw-liu proposed a selection of FOVs based on the rough measure of the shape of the cells, which I think is a good way to digitally sort the FOVs while constructing a batch.

Let's continue to think about this. We need:
a) rough measures of the cell cycle stage:

  • cytoplasm/nucleus ratio as measured from 'dirty' segmentations of target channels comes to my mind as the first measure to try. It should work for multiple cell types and microscopes.

b) strategies to assign a probability of sampling to a FOV or a patch:

  • @ziw-liu I recall you used the fluorescence channel itself as a weight mask. Can you point to that call?
  • We can preprocess or annotate each FOV to assign it a score and use the score to achieve class balance.
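A minimal sketch of strategy (b), assuming each FOV already has a scalar score such as the cytoplasm/nucleus ratio from (a). Scores are grouped into equal-width bins, and each FOV gets a sampling probability inversely proportional to the size of its bin, so rare cell-cycle stages are drawn as often as common ones. The function name and the binning scheme are illustrative assumptions, not existing VisCy code.

```python
import numpy as np


def balanced_sampling_probs(scores, n_bins: int = 5) -> np.ndarray:
    """Per-FOV sampling probabilities that flatten the score histogram.

    Each FOV gets weight 1 / (number of FOVs in its score bin), so every
    non-empty bin receives the same total probability.
    """
    scores = np.asarray(scores, dtype=float)
    # equal-width bins spanning the observed score range
    edges = np.linspace(scores.min(), scores.max(), n_bins + 1)
    # digitize against the inner edges maps each score to a bin index in [0, n_bins - 1]
    bin_idx = np.digitize(scores, edges[1:-1])
    counts = np.bincount(bin_idx, minlength=n_bins)
    weights = 1.0 / counts[bin_idx]
    return weights / weights.sum()


# toy example: four similar FOVs and one rare FOV with a large cytoplasm/nucleus ratio
scores = np.array([0.5, 0.55, 0.6, 0.52, 2.0])
probs = balanced_sampling_probs(scores, n_bins=2)
rng = np.random.default_rng(0)
fov_indices = rng.choice(len(scores), size=8, p=probs)
```

Here the rare FOV gets probability 0.5 and each of the other four gets 0.125. If sampling happens at the patch level instead, the same kind of weights could be passed to torch.utils.data.WeightedRandomSampler.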

clean up viscy cli display

viscy --help prints a useful and succinct help message.

However, viscy <subcommand> --help prints a lot of irrelevant Lightning CLI information, e.g.,

viscy preprocess --help prints:

  --lr_scheduler CONFIG | CLASS_PATH_OR_NAME | .INIT_ARG_NAME VALUE
                        One or more arguments specifying "class_path" and
                        "init_args" for any subclass of {torch.optim.lr_schedu
                        ler.LRScheduler,lightning.pytorch.cli.ReduceLROnPlatea
                        u}. (type: Union[LRScheduler, ReduceLROnPlateau],
                        known subclasses:
                        torch.optim.lr_scheduler.LRScheduler,
                        monai.optimizers.LinearLR,
                        monai.optimizers.ExponentialLR,
                        torch.optim.lr_scheduler.LambdaLR,
                        monai.optimizers.WarmupCosineSchedule,
                        torch.optim.lr_scheduler.MultiplicativeLR,
                        torch.optim.lr_scheduler.StepLR,
                        torch.optim.lr_scheduler.MultiStepLR,
                        torch.optim.lr_scheduler.ConstantLR,
                        torch.optim.lr_scheduler.LinearLR,
                        torch.optim.lr_scheduler.ExponentialLR,
                        torch.optim.lr_scheduler.SequentialLR,
                        torch.optim.lr_scheduler.PolynomialLR,
                        torch.optim.lr_scheduler.CosineAnnealingLR,
                        torch.optim.lr_scheduler.ChainedScheduler,
                        torch.optim.lr_scheduler.ReduceLROnPlateau,
                        lightning.pytorch.cli.ReduceLROnPlateau,
                        torch.optim.lr_scheduler.CyclicLR,
                        torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
                        torch.optim.lr_scheduler.OneCycleLR,
                        torch.optim.swa_utils.SWALR,
                        lightning.pytorch.cli.ReduceLROnPlateau)

Compute dataset statistics before training or testing for normalization:
  --data_path DATA_PATH
                        (required, type: <class 'Path'>)
  --channel_names CHANNEL_NAMES, --channel_names+ CHANNEL_NAMES
                        (type: Union[list[str], Literal[-1]], default: -1)
  --num_workers NUM_WORKERS
                        (type: int, default: 1)
  --block_size BLOCK_SIZE
                        (type: int, default: 32)

@ziw-liu please fix this.

Compatibility issues

The models we report in the preprints were trained with different versions of the codebase:

  • VSCyto3D and the model used in the Mantis preprint were trained with v0.1.0a0.
    • Old config files are not compatible with the current HEAD of main (d7d1200)
    • Weights can be loaded through Python API.
  • VSNeuromast was trained with the current HEAD of main (after v0.1.0a1?)
  • VSCyto2D was trained with #71.
    • Weights are not compatible with main due to architectural changes.

The VSCyto3D weights can be provided with an updated config file, while #71 needs API changes that separate 2D model building, along with matching config file changes, before merging.

Viscy workflow questions and experience from DL@MBL

I am documenting here some of the questions that arose while using VisCy and our experience running it during the DL@MBL course.

I list them as tasks (please cross them out if you think they have been solved) and/or open new issues if they are high priority.

  • Capability to select the normalization strategy: std or IQR.
  • TensorBoard validation logging shows a single Z-stack rather than samples from multiple batches.
  • Difficulty determining which config file to use and which parameters to change.
  • Ability to do tiled predictions.
  • HCS DataLoader outputting the wrong shape #47
  • Tool to easily crop and/or rechunk OME-Zarr stores. The current mantis chunking of (ZYX) < 500 MB chunks creates I/O bottlenecks.
  • Unclear preprocessing steps, config.yml, and order of CLI calls. Mostly solved by #43 and #45.
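For the first item, a minimal sketch of what a selectable normalization could look like: "std" subtracts the mean and divides by the standard deviation, while "iqr" subtracts the median and divides by the interquartile range, which is less sensitive to bright outliers. The function name and signature are hypothetical, and VisCy precomputes its statistics over the dataset (see the viscy preprocess help text above) rather than per image as done here for brevity.

```python
from typing import Literal

import numpy as np


def normalize(image: np.ndarray, strategy: Literal["std", "iqr"] = "std") -> np.ndarray:
    """Normalize an image with either mean/std or median/IQR statistics."""
    image = np.asarray(image, dtype=np.float64)
    if strategy == "std":
        center = image.mean()
        scale = image.std()
    elif strategy == "iqr":
        center = np.median(image)
        q1, q3 = np.percentile(image, [25, 75])
        scale = q3 - q1
    else:
        raise ValueError(f"Unknown normalization strategy: {strategy!r}")
    # a constant image has zero spread and cannot be normalized
    if scale == 0:
        raise ValueError("Cannot normalize: scale statistic is zero")
    return (image - center) / scale
```

With the "iqr" strategy, replacing the brightest pixel of [1, 2, 3, 4, 5] with 100 leaves the normalized values of the other four pixels unchanged, whereas the mean and standard deviation would both shift.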

Test framework

The test infrastructure of microDL 1.0 was a mixture of unittest and nose. However, since we are rewriting it in a different framework, very few (if any) tests can be reused. This is an opportunity to switch to pytest, which most contributors to this project will be more familiar with, since all of our other projects use it.

Configure sample image logging during training

For 2D models, logging one sample from each batch resulted in too few samples being logged for the validation set, since the number of validation FOVs is small relative to typical batch sizes. However, changing the logging scheme to use the first N samples from only the first batch would log samples from only the first validation FOV, due to sliding-window sampling. To support both use cases, I propose the following change to the logging interface:

Current:

model:
  log_num_samples: 12

Change to:

model:
  log_batches_per_epoch: 4
  log_samples_per_batch: 3

So for 2D training, these values could be (1, 12): even if the validation epoch has only 1 batch, the same number of images (12) will still be logged.
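A sketch of how the two proposed settings could select which samples to log, assuming batches are indexed from 0 within each epoch. The setting names come from the proposal; the helper functions are hypothetical.

```python
def samples_to_log(
    batch_idx: int, batch_size: int, log_batches_per_epoch: int, log_samples_per_batch: int
) -> list[int]:
    """Indices within the current batch that should be logged.

    Only the first `log_batches_per_epoch` batches are logged, and at most
    `log_samples_per_batch` samples are taken from each of them (fewer if
    the batch is smaller, e.g. the last batch of an epoch).
    """
    if batch_idx >= log_batches_per_epoch:
        return []
    return list(range(min(log_samples_per_batch, batch_size)))


def total_logged(
    num_batches: int, batch_size: int, log_batches_per_epoch: int, log_samples_per_batch: int
) -> int:
    """Number of samples logged over one epoch with the given settings."""
    return sum(
        len(samples_to_log(i, batch_size, log_batches_per_epoch, log_samples_per_batch))
        for i in range(num_batches)
    )
```

For example, a 2D validation epoch with a single batch and settings (1, 12) logs 12 images, and a 3D epoch with 10 batches and settings (4, 3) also logs 12, spread across four batches and therefore more FOVs.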

Inference output image format

@Christianfoley and I are trying to decide which image format is best for saving the predicted images from inference: zarr or TIFF. Zarr is better for storing the data, but some software used to process the predicted images works with single-page TIFFs. @mattersoflight has commented that we should aim to store the predictions as zarr. We can read zarr into a NumPy array and then perform downstream analysis (i.e., metrics evaluation, which links to issue #202). Anything to add, @Christianfoley, @mattersoflight, @ziw-liu?

Unable to load two channels as inputs for fluorescence-to-phase image translation

I was running the MBL DL2023 example notebook and ran into this issue at the end trying to predict the phase using 2 fluorescence channels.

tune_data = HCSDataModule(
    data_path,
    source_channel=["Nuclei", "Membrane"],
    target_channel="Phase",
    z_window_size=1,
    split_ratio=0.8,
    batch_size=BATCH_SIZE,
    num_workers=10,
    architecture="2D",
    yx_patch_size=YX_PATCH_SIZE,
    augment=True,
)
tune_data.setup("fit")

tune_config = {
    "architecture": "2D",
    "num_filters": [24, 48, 96, 192, 384],
    "in_channels": 2,
    "out_channels": 1,
    "residual": True,
    "dropout": 0.1,  # dropout randomly turns off weights to avoid overfitting of the model to data.
    "task": "reg",  # reg = regression task.
}
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[30], line 58
     42 n_epochs = 50 # Set this to 50 or the number of epochs you want to train for.
     44 trainer = VSTrainer(
     45     accelerator="gpu",
     46     devices=[GPU_ID],
   (...)
     55         ),
     56 )  
---> 58 trainer.fit(fluor2phase_model, datamodule=fluor2phase_data)
     61 # Visualize the graph of fluor2phase model as image.
     62 model_graph_fluor2phase = torchview.draw_graph(
     63     fluor2phase_model,
     64     fluor2phase_data.train_dataset[0]["source"],
     65     depth=2,  # adjust depth to zoom in.
     66     device="cpu",
     67 )

File ~/conda/envs/04_image_translation/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:532, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    530 self.strategy._lightning_module = model
    531 _verify_strategy_supports_compile(model, self.strategy)
--> 532 call._call_and_handle_interrupt(
    533     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    534 )
...
    458                     _pair(0), self.dilation, self.groups)
--> 459 return F.conv2d(input, weight, bias, self.stride,
    460                 self.padding, self.dilation, self.groups)

RuntimeError: Given groups=1, weight of size [24, 2, 3, 3], expected input[1, 1, 512, 512] to have 2 channels, but got 1 channels instead
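The error says the first convolution expects 2 input channels (weight shape [24, 2, 3, 3]), but the tensor reaching it has shape [1, 1, 512, 512], i.e. a single channel. Note also that the traceback calls trainer.fit with fluor2phase_model and fluor2phase_data rather than the tune_data configured above, so it may be worth confirming that the data module actually used was built with both source channels. A check like the hypothetical helper below, run on one batch before trainer.fit, could surface such a mismatch with a clearer message; it assumes the source is laid out as (B, C, ...), matching the input shape in the error, and works the same on torch tensors since it only uses .ndim and .shape.

```python
import numpy as np


def check_source_channels(source: np.ndarray, in_channels: int) -> None:
    """Raise a descriptive error if a (B, C, ...) source batch has the wrong channel count."""
    if source.ndim < 3:
        raise ValueError(f"Expected a (B, C, ...) array, got shape {tuple(source.shape)}")
    n_channels = source.shape[1]
    if n_channels != in_channels:
        raise ValueError(
            f"Model expects {in_channels} input channels but the source has {n_channels}; "
            "check source_channel in the data module and in_channels in the model config."
        )
```

In the notebook this might be called as check_source_channels(batch["source"], in_channels=2) on one batch drawn from the data module's train_dataloader().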

Save augmented tiles

Feature request: saving a fraction of the augmented tiles would help build intuition for the augmentation parameters by making it possible to visualize them. Because many tiles are created during training, only a small fraction should be saved. For example, the training config could define whether to save tiles, a path to save them to, and how many to save.
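A minimal sketch of the requested option, assuming tiles arrive as NumPy arrays after augmentation. The settings mirror the request (an on/off switch, an output path, and how many tiles to save); the class name, the per-tile keep probability, and saving as .npy files are illustrative assumptions.

```python
from pathlib import Path

import numpy as np


class AugmentedTileSaver:
    """Save a random fraction of augmented tiles, up to a fixed maximum."""

    def __init__(self, save_dir: str, max_tiles: int = 20, fraction: float = 0.01,
                 enabled: bool = True, seed: int = 0):
        self.save_dir = Path(save_dir)
        self.max_tiles = max_tiles
        self.fraction = fraction
        self.enabled = enabled
        self.rng = np.random.default_rng(seed)
        self.n_saved = 0
        if enabled:
            self.save_dir.mkdir(parents=True, exist_ok=True)

    def __call__(self, tile: np.ndarray) -> np.ndarray:
        # keep each tile with probability `fraction` until `max_tiles` have been saved
        if self.enabled and self.n_saved < self.max_tiles and self.rng.random() < self.fraction:
            np.save(self.save_dir / f"tile_{self.n_saved:04d}.npy", tile)
            self.n_saved += 1
        # return the tile unchanged so the saver can sit at the end of a transform chain
        return tile
```

In a MONAI dictionary-based pipeline the same logic would read the tile from the sample dict by key; saving only after augmentation shows exactly what the model sees.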

Keep segmentation task for U-Nets?

In microDL, using U-Nets for segmentation was explored, leaving branching code paths for both regression (virtual staining) and segmentation tasks (e.g. in preprocessing and model architecture).

@mattersoflight Given the strategy of using virtual staining models with off-the-shelf segmentation tools such as CellPose, is it still useful to keep this under-tested code?

strategies to improve invariance with data augmentation

We use intensity scaling and noise augmentations to make the virtual staining model invariant to variations in intensity and noise. We should leverage augmentations while keeping the training process stable and efficient.

This paper suggests a simple strategy and reports that it is effective: include many augmentations of the same sample to construct the batch, and average the losses (which happens naturally). @ziw-liu what is the current strategy in HCSDataModule? Can you test the strategy reported in Fig. 1B (top) of this paper?

PS: The paper also reports the regularization of a classification model with KL divergence over the augmentations. This doesn't translate naturally to virtual staining.
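A minimal NumPy sketch of the batch construction described above: each sample is repeated n_repeats times with independently drawn intensity scaling and Gaussian noise, so taking the mean loss over the batch averages over augmentations automatically. The parameter ranges and the function name are placeholders, not VisCy's current settings.

```python
import numpy as np


def repeat_augment(batch: np.ndarray, n_repeats: int, rng: np.random.Generator,
                   scale_range=(0.8, 1.2), noise_std=0.05) -> np.ndarray:
    """Return a batch of size B * n_repeats with independent augmentations.

    batch: (B, C, ...) array. Copies of the same sample are adjacent, so
    rows [i * n_repeats, (i + 1) * n_repeats) come from sample i.
    """
    repeated = np.repeat(batch, n_repeats, axis=0)
    # one random intensity scale per augmented copy, broadcast over channels and space
    shape = (repeated.shape[0],) + (1,) * (repeated.ndim - 1)
    scales = rng.uniform(*scale_range, size=shape)
    noise = rng.normal(0.0, noise_std, size=repeated.shape)
    return repeated * scales + noise
```

For a fixed batch size, this trades distinct samples for augmentations: the number of distinct samples per batch drops by a factor of n_repeats, which is the trade-off an experiment would need to evaluate.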

Clarify how to use different stages of VisCy

I expected that calling HCSDataModule().setup('fit') would fit the data and rewrite the normalization dictionary. However, when it is called twice in a row, we get:

KeyError                                  Traceback (most recent call last)
/home/eduardoh/vs_data/1-test_dataloader/edhirata_dataloader.py in line 14
     34 # %%
     35 data_module = HCSDataModule(
     36     input_data_path,
     37     source_channel="Phase3D",
   (...)
     45     augment=False,  # Turn off augmentation for now.
     46 )
---> 47 data_module.setup("fit")

File ~/VisCy/viscy/light/data.py:404, in HCSDataModule.setup(self, stage)
    402 dataset_settings = dict(channels=channels, z_window_size=self.z_window_size)
    403 if stage in ("fit", "validate"):
--> 404     self._setup_fit(dataset_settings)
    405 elif stage == "test":
    406     self._setup_test(dataset_settings)

File ~/VisCy/viscy/light/data.py:429, in HCSDataModule._setup_fit(self, dataset_settings)
    428 def _setup_fit(self, dataset_settings: dict):
--> 429     plate, normalize_transform = self._setup_eval(dataset_settings)
    430     fit_transform = self._fit_transform()
    431     train_transform = Compose(
    432         [normalize_transform] + self._train_transform() + fit_transform
    433     )

File ~/VisCy/viscy/light/data.py:424, in HCSDataModule._setup_eval(self, dataset_settings)
    420 if self.normalize_source:
    421     norm_keys += self.source_channel
    422 normalize_transform = NormalizeSampled(
    423     norm_keys,
--> 424     plate.zattrs["normalization"],
    425 )
    426 return plate, normalize_transform

File ~/conda/envs/viscy/lib/python3.10/site-packages/zarr/attrs.py:73, in Attributes.__getitem__(self, item)
     72 def __getitem__(self, item):
---> 73     return self.asdict()[item]

KeyError: 'normalization'
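The traceback shows that setup("fit") reads precomputed statistics from plate.zattrs["normalization"] rather than computing them, and the viscy preprocess help text in the CLI issue describes that subcommand as computing dataset statistics for normalization. A guard like the hypothetical helper below could turn the bare KeyError into an actionable message; it operates on any dict-like attributes object, with a plain dict standing in for the zarr attributes.

```python
from collections.abc import Mapping


def get_normalization_metadata(attrs: Mapping, data_path: str) -> dict:
    """Return the 'normalization' entry of a plate's attributes, or explain how to create it."""
    try:
        return attrs["normalization"]
    except KeyError:
        raise KeyError(
            f"No 'normalization' metadata found in {data_path}. "
            "Dataset statistics are read, not computed, during setup; "
            f"run `viscy preprocess --data_path {data_path}` first."
        ) from None
```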

Variable input size training and data pooling

To train a model with datasets of different dimensionalities (2D/3D) and imaging modalities (bright field, quantitative phase, Zernike phase contrast etc.), the training pipeline needs to have these new features:

  • Data loading that draws samples from different data stores
  • Batch collation and training loop that works with 2D and 3D data at the same time
    • Use loss aggregation or alternation?
  • A dynamic model stem with projection layers for different input dimensions
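For the last item, a minimal NumPy sketch of the idea, assuming each input is (C, Z, Y, X) with Z = 1 for 2D data: each supported input depth gets its own projection weights that fold C x Z into a shared number of stem channels, so the rest of the network sees the same shape regardless of dimensionality. A real stem would likely use learnable convolutions (for example a 3D convolution whose kernel spans the full depth); the fixed random per-depth 1x1 projection and the class name here are illustrative stand-ins.

```python
import numpy as np


class DepthAdaptiveStem:
    """Project (C, Z, Y, X) inputs of different depths to (out_channels, Y, X)."""

    def __init__(self, in_channels: int, depths: list[int], out_channels: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        # one (out_channels, C * Z) weight matrix per supported input depth
        self.weights = {
            z: rng.normal(0.0, 1.0 / np.sqrt(in_channels * z), size=(out_channels, in_channels * z))
            for z in depths
        }

    def __call__(self, x: np.ndarray) -> np.ndarray:
        c, z, h, w = x.shape
        if z not in self.weights:
            raise ValueError(f"No projection for input depth {z}")
        # fold channels and depth together, then apply a per-pixel (1x1) projection
        flat = x.reshape(c * z, h * w)
        return (self.weights[z] @ flat).reshape(-1, h, w)
```

With a stem like this, 2D and 3D samples could be projected to the same shape before being collated into one batch, which is one possible answer to the batch collation item above.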

configurable augmentations

A key hyperparameter of our training process is the choice of augmentations. At this point, the augmentations are hard-coded. @ziw-liu, how can we make the augmentations configurable via the PyTorch Lightning config? A useful side effect is that the augmentations used in each computational experiment will be documented automatically.
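One way to do this, sketched under the assumption that augmentations are listed in the config with the same class_path / init_args convention the Lightning CLI already uses (visible in the lr_scheduler help text in the CLI issue): each entry is resolved by import path and instantiated with its arguments. The builder functions are hypothetical; LightningCLI, which is built on jsonargparse, may be able to instantiate such entries itself if the data module argument is annotated as a list of MapTransform.

```python
import importlib
from typing import Any


def build_from_config(entry: dict[str, Any]) -> Any:
    """Instantiate an object from {'class_path': 'pkg.module.Class', 'init_args': {...}}."""
    module_name, _, class_name = entry["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**entry.get("init_args", {}))


def build_augmentations(entries: list[dict[str, Any]]) -> list[Any]:
    """Build the list of augmentation objects in config order."""
    return [build_from_config(entry) for entry in entries]
```

A config entry would then name a transform class such as monai.transforms.RandFlipd under class_path, with its constructor arguments (e.g. keys and prob) under init_args, so the full augmentation list is saved alongside each experiment's config.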

Three-channel output using the ConvNeXt model

I want to train the ConvNeXt model for the infection classifier problem. The model will use two or more input channels (phase + HSP90 + infection score channels, etc.) and should output three channels (background + uninfected + infected). Currently, I can perform this training with 2D and 2.5D UNets, but not with ConvNeXt: it does not allow a three-channel output because of a hard-coded scaling parameter value. @ziw-liu, can you help with this? Thanks!

Improving visualization of predictions in TensorBoard with an appropriate intensity scale

We should improve the visualization of predictions logged to TensorBoard by setting an appropriate intensity scale. Currently, the intensity scaling of the predicted images amplifies artifacts from cell structures not involved in the fluorescence expression, producing false signals in the visualization. Using the same intensity scale for all prediction windows can reduce these artifacts. @ziw-liu, can you help implement this? Thank you!
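A minimal sketch of one way to do this, assuming predictions and targets are NumPy arrays: compute robust display limits once from a reference image (here percentiles, with the 1st and 99th as an arbitrary default), then map every prediction window through the same limits. If each image is instead rescaled to its own minimum and maximum, faint artifacts get stretched to full brightness. Function names and defaults are illustrative.

```python
import numpy as np


def shared_display_range(reference: np.ndarray, low: float = 1.0, high: float = 99.0) -> tuple[float, float]:
    """Percentile-based display limits computed once from a reference image."""
    lo, hi = np.percentile(reference, [low, high])
    return float(lo), float(hi)


def to_display(image: np.ndarray, vmin: float, vmax: float) -> np.ndarray:
    """Map [vmin, vmax] to [0, 1] and clip, so all windows share one intensity scale."""
    if vmax <= vmin:
        raise ValueError("vmax must be greater than vmin")
    return np.clip((image - vmin) / (vmax - vmin), 0.0, 1.0)
```

A faint prediction then stays faint: an image filled with 0.05 on a target scale of [0, 1] is displayed at 0.05 rather than being stretched.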

Does phase input need to be normalized with dataset statistics?

When deploying on a microscope system for online inference, we may not have access to dataset-level summary statistics for normalization.

During a discussion with @ieivanov and @Soorya19Pradeep, the necessity of normalizing the phase input was called into question. If deconvolution already guarantees a zero mean for the background, and the exact values of foreground pixels carry physical meaning (relative phase delay), further subtraction and scaling seem redundant for thin samples such as monolayer cell cultures. The remaining small variations can be handled by augmentation during model training.

@mattersoflight may have more input on whether there is empirical evidence that normalization for phase is still necessary at inference time.

Augmentation strategy for mantis data

Label-free datasets acquired on the mantis have an in-plane pixel size similar to that of hummingbird@63x, where much of the training data was acquired. To apply 2.5D virtual staining across microscopes, however, axial spatial augmentation is needed to match the Z sampling of the training data (250 nm) to the wide range of existing (570 nm, czbiohub-sf/shrimPy#69) and future (205 nm) Z sampling on mantis.

@talonchandler suggests that trilinear is a good interpolation to start with (0.4x to 1.2x scaling in Z). We will also investigate the unit of voxel intensity values in the reconstructed phase images, since mantis has a different illumination wavelength (450 nm) than hummingbird (532 nm).

Pinging @Soorya19Pradeep and @edyoshikun in case these numbers are not accurate.
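The proposed 0.4x to 1.2x range is consistent with the sampling ratios: 250/570 ≈ 0.44 and 250/205 ≈ 1.22. Since only Z is rescaled, trilinear interpolation reduces to linear interpolation along Z, which this NumPy sketch implements for a (Z, Y, X) stack. The function name and the endpoint-aligned resampling grid are illustrative choices, not an existing VisCy transform.

```python
import numpy as np


def rescale_z(volume: np.ndarray, factor: float) -> np.ndarray:
    """Resample a (Z, Y, X) stack along Z by `factor` using linear interpolation.

    factor < 1 mimics coarser Z sampling (fewer slices), factor > 1 finer sampling.
    First and last slices are kept aligned, so the new spacing is
    (n_z - 1) / (n_out - 1) input slices.
    """
    n_z = volume.shape[0]
    n_out = max(1, int(round(n_z * factor)))
    # positions of the output slices in input-slice coordinates
    z_new = np.linspace(0, n_z - 1, n_out)
    z0 = np.floor(z_new).astype(int)
    z1 = np.minimum(z0 + 1, n_z - 1)
    frac = (z_new - z0)[:, None, None]
    return (1 - frac) * volume[z0] + frac * volume[z1]
```

During training, the factor could be drawn per sample, e.g. rng.uniform(0.4, 1.2), before cropping the Z window.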

Augmentation strategy for generalization across magnification

Changing the magnification of the microscope alters the sampling in all three spatial dimensions, and the change in Z differs from the change in XY. If we want to train a model on a single dataset that generalizes across magnifications, we need augmentation strategies that simulate these changes in spatial sampling.

@mattersoflight pointed out that scaling down is not a good approximation of reducing magnification. Blurring (e.g., Gaussian filtering) before rescaling can simulate the integration of information along the light path and reduce artifacts.

Another question is how to determine the Z sampling at training time to make better use of defocus information. This could potentially be estimated from the magnification, the Z step size, and the NA of illumination and detection.
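A minimal NumPy sketch of blurring before downsampling for a 2D (Y, X) image. The rule sigma = (stride - 1) / 2 is borrowed from the default anti-aliasing heuristic in scikit-image's resize and is an assumption here, not a value from this discussion; integer-stride downsampling keeps the example short, whereas a real augmentation would resample by arbitrary factors.

```python
import numpy as np


def gaussian_kernel(sigma: float) -> np.ndarray:
    """Normalized 1D Gaussian kernel truncated at 3 sigma."""
    radius = max(1, int(np.ceil(3 * sigma)))
    x = np.arange(-radius, radius + 1)
    k = np.exp(-0.5 * (x / sigma) ** 2)
    return k / k.sum()


def blur_then_downsample(image: np.ndarray, stride: int) -> np.ndarray:
    """Anti-aliased downsampling of a 2D image by an integer stride."""
    sigma = (stride - 1) / 2
    out = image.astype(np.float64)
    if sigma > 0:
        k = gaussian_kernel(sigma)
        # separable blur: filter along rows, then columns; mode="same" keeps the
        # shape, and pixels beyond the border are treated as zero
        out = np.apply_along_axis(np.convolve, 1, out, k, mode="same")
        out = np.apply_along_axis(np.convolve, 0, out, k, mode="same")
    return out[::stride, ::stride]
```

On a one-pixel checkerboard, naive subsampling by 2 returns a flat image of zeros (pure aliasing), while the blurred version retains an intermediate gray level in the interior, closer to what a lower-magnification system would record.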

Discussion of what metrics to prioritize in inference tensorboard

Each time inference is run, it generates TensorBoard output displaying the inference results and metrics calculated along multiple orientations. Each scalar added (each sub-tag corresponding to a calculated metric) creates a new folder in a recursive structure, which can lead to prohibitively long session load times depending on the number and type of metrics you calculate.

Let's have a discussion here about:

  1. Which metrics we have implemented (@Soorya19Pradeep) and whether they are integrated into the inference module.
  2. Which of these metrics should we prioritize supporting on inference calls.

Normalization for patches ome-zarr

We want to compute statistics from the FOV-scale zarr store and store them with the patch-scale zarr store, which the dataloader will parse.

The process for this is:

  • Compute statistics per FOV (not patch) and store these with metadata.
  • Normalize using FOV statistics at training and test time - use existing CLI or preprocessing script.

This is the current structure of the patches ome_zarr:

=== Summary ===
Format: omezarr v0.4
Axes: T (time); C (channel); Z (space); Y (space); X (space);
Channel names: ['RFP', 'Phase3D']
Row names: ['A', 'B']
Column names: ['3', '4']
Wells: 4
Positions: 2629


This is the structure of the track_labels zarr, where we can store the metadata (FOV information can be stored here):
=== Summary ===
Format: omezarr v0.4
Axes: T (time); C (channel); Z (space); Y (space); X (space);
Channel names: ['tracking']
Row names: ['A', 'B']
Column names: ['3', '4']
Wells: 4
Positions: 61

[Screenshot, 2024-06-28]

We need help with how to integrate the metadata into the track ome_zarr, how it will be used to generate the patch ome_zarr, and eventually how the dataloader will use it for normalization.
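A minimal sketch of the two steps, with plain dictionaries in place of zarr attributes: statistics are computed once per FOV, each patch records the FOV it was cropped from, and normalization uses that FOV's statistics. The metadata layout (statistics keyed by FOV name and then channel, a "source_fov" field on each patch, and FOV names like "A/3/0") is an assumption for illustration, not the format VisCy or iohub uses.

```python
import numpy as np


def fov_statistics(fov: np.ndarray, channel_names: list[str]) -> dict[str, dict[str, float]]:
    """Per-channel statistics for one FOV with layout (C, Z, Y, X)."""
    stats = {}
    for c, name in enumerate(channel_names):
        data = fov[c]
        q1, median, q3 = np.percentile(data, [25, 50, 75])
        stats[name] = {"mean": float(data.mean()), "std": float(data.std()),
                       "median": float(median), "iqr": float(q3 - q1)}
    return stats


def normalize_patch(patch: np.ndarray, patch_meta: dict, fov_stats: dict,
                    channel_names: list[str]) -> np.ndarray:
    """Normalize a (C, Z, Y, X) patch with the statistics of the FOV it came from."""
    stats = fov_stats[patch_meta["source_fov"]]
    out = np.empty_like(patch, dtype=np.float64)
    for c, name in enumerate(channel_names):
        s = stats[name]
        out[c] = (patch[c] - s["median"]) / s["iqr"]
    return out


# toy example: one FOV with two channels and a 2x2 patch cropped from its corner
channels = ["RFP", "Phase3D"]
fov = np.arange(2 * 1 * 4 * 4, dtype=float).reshape(2, 1, 4, 4)
all_stats = {"A/3/0": fov_statistics(fov, channels)}
patch = fov[:, :, :2, :2]
normalized = normalize_patch(patch, {"source_fov": "A/3/0"}, all_stats, channels)
```

Because the statistics come from the whole FOV, a patch that happens to contain only background is not stretched to the same range as a patch full of cells.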

Testing the training pipeline

The training loop lacks automated tests. Although unit-testing DL code is not as straightforward as other software, there are some strategies to improve the coverage. See these blog posts for some ideas: 1, 2.
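One commonly suggested strategy is an overfitting smoke test: if a model and training step cannot drive the loss toward zero on a single small batch, there is probably a bug in the data path, loss, or optimizer step. A pytest-style sketch of the idea, using a NumPy linear model as a stand-in for the VisCy network and Trainer, which a real test would use instead:

```python
import numpy as np


def train_on_single_batch(x: np.ndarray, y: np.ndarray, steps: int = 500, lr: float = 0.5) -> list[float]:
    """Fit y ≈ x @ w by gradient descent on one fixed batch and return the loss history."""
    rng = np.random.default_rng(0)
    w = rng.normal(size=(x.shape[1], y.shape[1]))
    losses = []
    for _ in range(steps):
        residual = x @ w - y
        losses.append(float(np.mean(residual ** 2)))
        # gradient of the mean squared error with respect to w
        grad = 2 * x.T @ residual / residual.size
        w -= lr * grad
    return losses


def test_overfits_single_batch():
    rng = np.random.default_rng(1)
    x = rng.normal(size=(32, 4))
    y = x @ rng.normal(size=(4, 2))  # target is exactly reachable by a linear model
    losses = train_on_single_batch(x, y)
    assert losses[-1] < 1e-3 * losses[0]
```

For the real model, Lightning's Trainer(overfit_batches=1) runs a similar check by training repeatedly on a fixed batch, and Trainer(fast_dev_run=True) gives a quick end-to-end smoke test.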

organize application scripts on GitHub rather than disk

@ziw-liu @edyoshikun I propose we organize the scripts that use VisCy (with other libraries) for specific applications in applications/infection_screens, applications/organelle_phenotyping, and applications/ultrack. I suggest this approach because the scope of the repo is single-cell phenotyping. It is much easier to find code on GitHub than on disk, and it encourages clean code organization.

documentation of data.py

Since we are training a number of models with HCSDataModule, it is timely to document data.py to clarify the flow of data through different methods.

I suggest the following:

  • Add a note to the HCSDataModule docstring that explains how the data on disk is turned into a batch (which public and private methods are called in the process).
  • Add docstrings to all of the above methods.

@ziw-liu additional improvements are welcome, but the above should be sufficient to understand the design.

sample['target'] from a dataloader batch returns the wrong z-shape

We expected that, when iterating over HCSDataModule().test_dataloader(), the target would contain a single Z slice:

for i, sample in enumerate(HCSDataModule().test_dataloader()):
    break
sample['target'].shape  # returns torch.Size([64, 2, 5, 512, 512])

The expected behavior is that the Z dimension of sample['target'] is 1.

Note:
In on_before_batch_transfer, I see that we select the middle of the stack during training and testing, but by default the target should always have a single Z slice.

def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample:
    predicting = False
    if self.trainer:
        if self.trainer.predicting:
            predicting = True
    if predicting or isinstance(batch, torch.Tensor):
        # skipping example input array
        return batch
    if self.target_2d:
        # slice the center during training or testing
        z_index = self.z_window_size // 2
        batch["target"] = batch["target"][:, :, slice(z_index, z_index + 1)]
    return batch
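A likely explanation: Lightning calls on_before_batch_transfer from inside the Trainer loop when a batch is moved to the device, so iterating over test_dataloader() directly skips the center slicing above and returns the full z_window_size stack. The NumPy sketch below reproduces that slicing on a (B, C, Z, Y, X) array so the two shapes can be compared outside the Trainer; the helper name is hypothetical.

```python
import numpy as np


def center_z_slice(target: np.ndarray, z_window_size: int) -> np.ndarray:
    """Keep only the central Z slice of a (B, C, Z, Y, X) target, keeping the Z axis."""
    if target.shape[2] != z_window_size:
        raise ValueError(f"Expected Z = {z_window_size}, got {target.shape[2]}")
    z_index = z_window_size // 2
    # slicing with a range (not an integer index) keeps Z as a length-1 axis
    return target[:, :, z_index : z_index + 1]
```

Applied to a batch shaped like the one above, (64, 2, 5, 512, 512), this returns shape (64, 2, 1, 512, 512).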

Implement test time augmentations to avoid high frequency fluctuations

With a variety of models, we see that virtually stained structures show high-frequency fluctuations over time. This may be due to the sensitivity of ConvNets to small changes in local texture and orientation of edges. One way to average over these effects is to do test time augmentation. Let's experiment with one of our trained models and compare the temporal consistency of predictions with and without test-time augmentations.

Relevant paper: https://www.nature.com/articles/s41598-020-61808-3
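A minimal sketch of test-time augmentation over the eight 90° rotation and flip combinations of the YX plane: each augmented input is predicted, each prediction is mapped back with the inverse transform, and the results are averaged. Here predict stands in for any model that maps a (..., Y, X) array to a same-layout array; the choice of transforms is an assumption, and the referenced paper may use a different set.

```python
from typing import Callable

import numpy as np


def predict_with_tta(image: np.ndarray, predict: Callable[[np.ndarray], np.ndarray]) -> np.ndarray:
    """Average predictions over the 8 rotations/flips of the last two (Y, X) axes.

    Works for non-square images as long as `predict` accepts both (Y, X)
    and (X, Y) inputs.
    """
    preds = []
    for flip in (False, True):
        flipped = np.flip(image, axis=-1) if flip else image
        for k in range(4):
            augmented = np.rot90(flipped, k, axes=(-2, -1))
            pred = predict(augmented)
            # undo the transforms in reverse order: rotate back, then un-flip
            pred = np.rot90(pred, -k, axes=(-2, -1))
            if flip:
                pred = np.flip(pred, axis=-1)
            preds.append(pred)
    return np.mean(preds, axis=0)
```

Temporal consistency could then be compared, for example, via the mean absolute difference between predictions at consecutive time points, with and without TTA.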

Contrastive Learning Implementation

My understanding of the current classifier and task with 60X res data can be found here: https://docs.google.com/document/d/1j3UePmDJL_1V_9j7v3I4nLgAgKuFXXuqmlW8ZFOyTk0/edit?usp=sharing.

Given this approach, we'd need to modify HCSDataModule to support triplet sampling. Specifically:

  • We can apply transformations like rotation, cropping, and color jitter to create different views of the same cell.
  • Generate an anchor and a positive sample using different augmentations of the same cell image.
  • Select a different cell with a different label (infected vs. uninfected) as the negative sample.

The goal of triplet sampling is to minimize the distance between the anchor and the positive while maximizing the distance between the anchor and the negative in the learned embedding space.
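This objective corresponds to the standard triplet margin loss, L = max(0, d(a, p) - d(a, n) + margin), averaged over the batch, where d is a distance in embedding space. A NumPy sketch with Euclidean distance follows; PyTorch's torch.nn.TripletMarginLoss computes the same quantity for tensors (up to a small eps in the distance), with a default margin of 1.0.

```python
import numpy as np


def triplet_margin_loss(anchor: np.ndarray, positive: np.ndarray, negative: np.ndarray,
                        margin: float = 1.0) -> float:
    """Mean triplet margin loss for (B, D) embedding arrays using Euclidean distance."""
    d_pos = np.linalg.norm(anchor - positive, axis=1)
    d_neg = np.linalg.norm(anchor - negative, axis=1)
    # a triplet contributes zero once the negative is at least `margin` farther than the positive
    return float(np.mean(np.maximum(d_pos - d_neg + margin, 0.0)))
```

For example, with identical anchor and positive embeddings, a negative at distance 5 contributes 0 and a negative at distance 0.5 contributes 0.5.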

import random

from torch.utils.data import Dataset


# TripletTransform takes a base_transform and applies it to a sample to generate the anchor and positive samples.
# When __call__ is invoked with a sample, it applies base_transform to the sample twice: first to create the anchor, then to create the positive.

class TripletTransform:
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        # random augmentations make the two calls produce different views of the same sample
        anchor = self.transform(sample)
        positive = self.transform(sample)
        return anchor, positive


# The TripletDataset class is initialized with the data and a base transform. When __getitem__ is called with an index (idx):
# Anchor and positive: the same data sample is retrieved for both the anchor and the positive.
# Negative sampling: a different sample is randomly selected as the negative.
# If a transform is provided:
# TripletTransform applies the base transform to the sample to create augmented anchor and positive views.
# The base transform is applied directly to the negative sample to create its augmented version.

class TripletDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        # simple negative sampling: draw uniformly from all indices except idx
        # (labels are ignored here; label-aware sampling would restrict the draw
        # to cells with a different infection state)
        negative_idx = random.randrange(len(self.data) - 1)
        if negative_idx >= idx:
            negative_idx += 1
        negative = self.data[negative_idx]
        if self.transform:
            anchor, positive = TripletTransform(self.transform)(sample)
            negative = self.transform(negative)
        else:
            anchor, positive = sample, sample
        return (anchor, positive, negative)


Here the TripletTransform class takes a base transformation (defined in base_transform) and applies it to create the anchor and positive samples.

Modify HCSDataModule:

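from pathlib import Path
from typing import Literal, Optional, Sequence, Union

from monai.transforms import MapTransform
from torch.utils.data import DataLoader
from torchvision import transforms

# module path as shown in the traceback in the setup issue; may differ in other versions
from viscy.light.data import HCSDataModule


class TripletHCSDataModule(HCSDataModule):
    def __init__(
        self,
        data_path: str,
        source_channel: Union[str, Sequence[str]],
        target_channel: Union[str, Sequence[str]],
        z_window_size: int,
        split_ratio: float = 0.8,
        batch_size: int = 16,
        num_workers: int = 8,
        architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D",
        yx_patch_size: tuple[int, int] = (256, 256),
        normalizations: Optional[list[MapTransform]] = None,
        augmentations: Optional[list[MapTransform]] = None,
        caching: bool = False,
        ground_truth_masks: Optional[Path] = None,
    ):
        # avoid mutable default arguments shared across instances
        normalizations = normalizations or []
        augmentations = augmentations or []
        super().__init__(
            data_path,
            source_channel,
            target_channel,
            z_window_size,
            split_ratio,
            batch_size,
            num_workers,
            architecture,
            yx_patch_size,
            normalizations,
            augmentations,
            caching,
            ground_truth_masks,
        )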
        self.triplet_transform = TripletTransform(transforms.Compose(normalizations + augmentations))

    # wrap the parent's train/val datasets in TripletDataset
    # (assumes the parent datasets expose their samples as `.data`)
    def setup(self, stage: Optional[str] = None):
        super().setup(stage)
        if stage in ("fit", "validate"):
            self.train_dataset = TripletDataset(self.train_dataset.data, transform=self.triplet_transform)
            self.val_dataset = TripletDataset(self.val_dataset.data, transform=self.triplet_transform)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size // 3,  # each item is a triplet, so this keeps ~batch_size samples per batch
            num_workers=self.num_workers,
            shuffle=True,
            persistent_workers=bool(self.num_workers),
            prefetch_factor=4 if self.num_workers else None,
            drop_last=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size // 3,  # each item is a triplet, so this keeps ~batch_size samples per batch
            num_workers=self.num_workers,
            shuffle=False,
            prefetch_factor=4 if self.num_workers else None,
            persistent_workers=bool(self.num_workers),
        )

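# example image-level augmentations (torchvision) that could be used;
# note that the `augmentations` argument of the data module expects
# dictionary-based MONAI MapTransforms, so these would need dictionary-based
# equivalents before being passed there
from torchvision import transforms

base_transform = transforms.Compose([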
    transforms.RandomApply([transforms.ColorJitter(0.2, 0.2, 0.2, 0.2)], p=0.5),
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(128, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor()
])

Using the updated data module:

data_module = TripletHCSDataModule(
    data_path="...",  # matches the __init__ parameter name
    source_channel=["Phase", "Sensor"],
    target_channel=["Inf_mask"],
    yx_patch_size=[128, 128],
    split_ratio=0.8,
    z_window_size=1,
    architecture="2D",
    num_workers=4,
    batch_size=64,
    normalizations=[
        NormalizeSampled(
            keys=["Sensor", "Phase"],
            level="fov_statistics",
            subtrahend="median",
            divisor="iqr",
        )
    ],
    augmentations=[
        RandWeightedCropd(
            num_samples=8,
            spatial_size=[-1, 128, 128],
            keys=["Sensor", "Phase", "Inf_mask"],
            w_key="Inf_mask",
        )
    ]
)

Model details:

  • Encoder: the contrastive learning model uses an encoder to map each input to an embedding, and the loss is computed on these embeddings. Candidate losses and miners: Triplet Margin Loss, AllTripletMiner, NTXent.
  • Input: the same source channels as in the data module above (e.g., phase and sensor).
  • Output: The model outputs embeddings.
  • Loss Function: Triplet loss is used to train the model to minimize the distance between embeddings of similar samples and maximize the distance between embeddings of dissimilar samples.
  • Validation: the validation step computes the same triplet loss on held-out triplets, to check that the model learns useful representations of the cells.
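For reference, the triplet loss described above is usually written as L = max(d(a, p) - d(a, n) + margin, 0), averaged over the batch. The stdlib sketch below uses Euclidean distance and margin = 1.0, the same form and defaults as `torch.nn.TripletMarginLoss` (which additionally adds a small epsilon inside its distance). It is only meant to make the formula concrete; in training one would use that class or a metric-learning library instead of this loop.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def euclidean(u: Vector, v: Vector) -> float:
    if len(u) != len(v):
        raise ValueError("embeddings must have the same dimension")
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_margin_loss(
    anchors: Sequence[Vector],
    positives: Sequence[Vector],
    negatives: Sequence[Vector],
    margin: float = 1.0,
) -> float:
    """Mean of max(d(a, p) - d(a, n) + margin, 0) over the batch."""
    if not anchors or not (len(anchors) == len(positives) == len(negatives)):
        raise ValueError("need equally sized, non-empty batches")
    losses = [
        # zero once the negative is at least `margin` farther than the positive
        max(euclidean(a, p) - euclidean(a, n) + margin, 0.0)
        for a, p, n in zip(anchors, positives, negatives)
    ]
    return sum(losses) / len(losses)
```

A triplet whose negative is already far away contributes nothing, so the gradient signal comes only from triplets that still violate the margin.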

Other ideas: compare SimCLR with triplet sampling

  • SimCLR: generates positive pairs by applying different augmentations to the same sample. Negative samples are implicitly created from other samples in the same batch.
  • Triplet Sampling: explicitly forms triplets consisting of an anchor, a positive, and a negative.
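To make the SimCLR side concrete: its NT-Xent loss treats the two augmented views of each sample as the positive pair and every other embedding in the batch as a negative. Below is a minimal stdlib sketch with cosine similarity and a temperature, assuming a batch of N samples yields 2N embeddings; the function name `nt_xent_loss` and the default temperature of 0.5 are illustrative choices, not VisCy settings.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine(u: Vector, v: Vector) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    if norm == 0.0:
        raise ValueError("zero-length embedding")
    return dot / norm


def _logsumexp(xs: List[float]) -> float:
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def nt_xent_loss(
    view1: Sequence[Vector], view2: Sequence[Vector], temperature: float = 0.5
) -> float:
    """view1[i] and view2[i] embed two augmentations of the same sample i."""
    n = len(view1)
    if n < 2 or len(view2) != n:
        raise ValueError("need two equally sized views with at least 2 samples")
    z = list(view1) + list(view2)  # 2N embeddings
    total = 0.0
    for i in range(2 * n):
        partner = (i + n) % (2 * n)  # the other view of the same sample
        # logits against every other embedding; all but the partner are negatives
        logits = [cosine(z[i], z[k]) / temperature for k in range(2 * n) if k != i]
        positive_logit = cosine(z[i], z[partner]) / temperature
        # -log(exp(positive) / sum(exp(all others))), computed stably
        total += _logsumexp(logits) - positive_logit
    return total / (2 * n)
```

Unlike the triplet setup, no explicit negative index is sampled: a batch of N samples supplies 2(N - 1) negatives per anchor, which is one reason SimCLR is usually trained with large batches.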

Good resource: https://lilianweng.github.io/posts/2021-05-31-contrastive/

upgrading lightning to >2.0.8 results in trainer issues

This error does not happen with lightning==2.0.1, which viscy installed by default for me. However, when I tried upgrading to lightning 2.3.0.dev0 to circumvent the caching timeout issue here, I got the following error; according to this, we will have to make sure our tensors are on the right device. Flagging it in case you also encounter it @ziw-liu .

Traceback (most recent call last):
  File "/hpc/projects/comp.micro/virtual_staining/models/fcmae-3d/fit/pretrain_scratch_path.py", line 141, in <module>
    trainer.fit(model, data)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1031, in _run_stage
    self._run_sanity_check()
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1060, in _run_sanity_check
    val_loop.run()
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 142, in run
    return self.on_run_end()
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 254, in on_run_end
    self._on_evaluation_epoch_end()
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 336, in _on_evaluation_epoch_end
    trainer._logger_connector.on_epoch_end()
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py", line 195, in on_epoch_end
    metrics = self.metrics
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py", line 234, in metrics
    return self.trainer._results.metrics(on_step)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 483, in metrics
    value = self._get_cache(result_metric, on_step)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 447, in _get_cache
    result_metric.compute()
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 289, in wrapped_func
    self._computed = compute(*args, **kwargs)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 249, in compute
    value = self.meta.sync(self.value.clone())  # `clone` because `sync` is in-place
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/strategies/ddp.py", line 342, in reduce
    return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/fabric/utilities/distributed.py", line 173, in _sync_ddp_if_available
    return _sync_ddp(result, group=group, reduce_op=reduce_op)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/fabric/utilities/distributed.py", line 223, in _sync_ddp
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 72, in wrapper
    return func(*args, **kwargs)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1992, in all_reduce
    work = group.allreduce([tensor], opts)
RuntimeError: No backend type associated with device type cpu

architectures and nomenclature

Now that the 2D and 2.5D UNets from our 2020 paper are implemented in PyTorch, we are exploring the space of architectures in two ways:
a) input-output tensor dimensions;
b) state-of-the-art (SOTA) convolutional layers, particularly those inspired by ConvNeXt.

At this point, the 2.1D network combines both. It is useful to have distinct nomenclature and models to compare these two innovations.

I suggest:

  • 2D, 2.5D, 2.1D, 3D architectures use classical convolutional layers and activations.
  • Architectures that use ConvNeXt design principles can use the 2NeX, 2.5NeX, 2.1NeX, ... nomenclature.
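The proposal separates two independent axes: input-output dimensionality and block style. A hypothetical sketch of how a model registry could encode that, so that, e.g., 2.5D and 2.5NeX differ only in block style. `ArchitectureSpec`, `BlockStyle` and the allowed dimensionality strings (taken from the `architecture` options in the data module above, minus `fcmae`) are illustrative, not existing VisCy code.

```python
from dataclasses import dataclass
from enum import Enum

ALLOWED_DIMS = ("2", "2.1", "2.2", "2.5", "3")


class BlockStyle(Enum):
    CLASSICAL = "D"  # classical convolutional layers and activations
    CONVNEXT = "NeX"  # ConvNeXt design principles


@dataclass(frozen=True)
class ArchitectureSpec:
    dims: str  # input-output tensor dimensionality, e.g. "2.5"
    block: BlockStyle

    def __post_init__(self) -> None:
        if self.dims not in ALLOWED_DIMS:
            raise ValueError(f"unknown dimensionality {self.dims!r}")

    @property
    def name(self) -> str:
        # prefix encodes the tensor dimensions, suffix encodes the block style
        return f"{self.dims}{self.block.value}"
```

Keeping the two fields separate lets a comparison vary one axis while holding the other fixed, which is the point of the proposed naming.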
