mehta-lab / viscy
Computer vision models for single-cell phenotyping
License: BSD 3-Clause "New" or "Revised" License
PSNR is a commonly used metric for evaluating high-resolution features in image reconstruction problems such as deconvolution and natural-image super-resolution. It would be a useful addition to MSE and SSIM for resolution-sensitive tasks.
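For reference, a minimal sketch of computing PSNR with torchmetrics (the tensors here are random placeholders):

import torch
from torchmetrics.image import PeakSignalNoiseRatio

# PSNR = 10 * log10(data_range**2 / MSE); data_range is the dynamic range
# of the target images (1.0 for normalized data).
psnr = PeakSignalNoiseRatio(data_range=1.0)
pred, target = torch.rand(1, 1, 256, 256), torch.rand(1, 1, 256, 256)
print(psnr(pred, target))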
The students will program a UNet near the start of the course, and we have scheduled image translation after the segmentation and denoising exercises. We have a 6-hour window for this exercise, which will be divided between deterministic image translation and generative image translation. For deterministic image translation we'd use the UNeXt2 architecture, and for generative translation we'll probably use the conditional GAN used in this paper.
Considering the above plan and the need for a demo notebook for release 0.2.0 of VisCy, I suggest that we develop a demo notebook that illustrates the training of the VSCyto3D and VSNeuromast models.
Here are Alishba's fixes to last year's exercise:
https://github.com/alishbaimran/image_translation/blob/solution/solution.ipynb
https://docs.google.com/document/d/1h3u42hodHN7nQz9qND-NQc7uOm72fBi7DxURNYuuYPM/edit
Dataset
HEK293T cells with phase, membrane, and nuclei channels. Let's start with 50 FOVs.
checkpoint 1
Load the zarr store, view label-free and fluorescence channels, configure the model, browse the 2D UNet with TensorBoard, and start training a phase → nuclei model.
checkpoint 2
Examine the loss after lunch, see the regression metrics for the phase → nuclei model, train a nuclei → phase model, and see its regression metrics.
checkpoint 3
Adjust the network capacity by different amounts, with each student training one model (phase → nuclei, phase → membrane, or phase → nuclei + membrane). Record the metrics in a Google Doc.
At a recent meeting, we discussed strategies to achieve class balance across the cell cycle. @ziw-liu proposed selecting FOVs based on a rough measure of cell shape, which I think is a good way to digitally sort the FOVs while constructing a batch.
Let's continue to think about this. We need:
a) rough measures of the cell cycle stage, and
b) strategies to assign a sampling probability to an FOV or a patch (see the sketch below).
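One way to implement (b), sketched here with torch's WeightedRandomSampler; the per-FOV scores are placeholders for whatever rough measure (a) produces:

import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

# hypothetical per-FOV cell-shape scores standing in for a cell cycle measure
fov_scores = np.random.rand(50)
# bin FOVs into quartiles of the score and weight inversely to bin frequency
bins = np.digitize(fov_scores, np.quantile(fov_scores, [0.25, 0.5, 0.75]))
weights = 1.0 / np.bincount(bins, minlength=4)[bins]
sampler = WeightedRandomSampler(
    torch.as_tensor(weights), num_samples=len(weights), replacement=True
)
# pass `sampler` to the DataLoader instead of shuffle=True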
viscy --help
prints a useful and succinct help message.
But viscy subcommand --help prints a lot of Lightning CLI info that is not relevant. For example, viscy preprocess --help prints:
--lr_scheduler CONFIG | CLASS_PATH_OR_NAME | .INIT_ARG_NAME VALUE
One or more arguments specifying "class_path" and "init_args" for any
subclass of {torch.optim.lr_scheduler.LRScheduler, lightning.pytorch.cli.ReduceLROnPlateau}.
(type: Union[LRScheduler, ReduceLROnPlateau], known subclasses:
torch.optim.lr_scheduler.LRScheduler,
monai.optimizers.LinearLR,
monai.optimizers.ExponentialLR,
torch.optim.lr_scheduler.LambdaLR,
monai.optimizers.WarmupCosineSchedule,
torch.optim.lr_scheduler.MultiplicativeLR,
torch.optim.lr_scheduler.StepLR,
torch.optim.lr_scheduler.MultiStepLR,
torch.optim.lr_scheduler.ConstantLR,
torch.optim.lr_scheduler.LinearLR,
torch.optim.lr_scheduler.ExponentialLR,
torch.optim.lr_scheduler.SequentialLR,
torch.optim.lr_scheduler.PolynomialLR,
torch.optim.lr_scheduler.CosineAnnealingLR,
torch.optim.lr_scheduler.ChainedScheduler,
torch.optim.lr_scheduler.ReduceLROnPlateau,
lightning.pytorch.cli.ReduceLROnPlateau,
torch.optim.lr_scheduler.CyclicLR,
torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
torch.optim.lr_scheduler.OneCycleLR,
torch.optim.swa_utils.SWALR,
lightning.pytorch.cli.ReduceLROnPlateau)
Compute dataset statistics before training or testing for normalization:
--data_path DATA_PATH
(required, type: <class 'Path'>)
--channel_names CHANNEL_NAMES, --channel_names+ CHANNEL_NAMES
(type: Union[list[str], Literal[-1]], default: -1)
--num_workers NUM_WORKERS
(type: int, default: 1)
--block_size BLOCK_SIZE
(type: int, default: 32)
@ziw-liu please fix this.
The models we report in the preprints were trained with different versions of the codebase:
The VSCyto3D weights can be provided with an updated config file, while #71 will need API changes to separate 2D model building, and the config files will need to change accordingly before merging.
I am documenting here some of the questions that arose during the use of VisCy and the experience of running it during the DL@MBL course.
I list them as tasks (please cross them out if you think they have been solved) and/or open new issues if they are high priority:
- Choice of normalization statistics: std or iqr.
- .ome-zarr stores: the current mantis chunking of (ZYX) < 500 MB creates IO bottlenecks.
- config.yml and the order of CLI calls. Mostly solved by #43 #45.

The test infrastructure of microDL 1.0 was a mixture of unittest and nose. However, since we are rewriting it in a different framework, very few (if any) tests can be reused. This brings up an opportunity to switch to pytest, which most contributors of this project will be more familiar with, since all of our other projects use it.
For 2D models, logging one sample from each batch resulted in too few samples being logged for the validation set, since the number of FOVs for validation is small relative to typical batch sizes. However, changing the logging scheme to use the first N samples from only the first batch would log samples from only the first FOV in validation, due to sliding window sampling. To support both use cases, I propose the following change to the logging interface:
Current:

model:
  log_num_samples: 12

Change to:

model:
  log_batches_per_epoch: 4
  log_samples_per_batch: 3
So for 2D training, these numbers can be (1, 12): even if the validation epoch has only one batch, the same number of images will still be logged.
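A minimal sketch of how the proposed keys could drive sample selection (the hook body and attribute names are assumptions for illustration, not the current VisCy implementation):

def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
    # keep images from the first `log_batches_per_epoch` batches only...
    if batch_idx < self.log_batches_per_epoch:
        # ...and only the first `log_samples_per_batch` samples of each
        self._logged_samples.extend(batch["source"][: self.log_samples_per_batch])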
@mattersoflight should we add a citation for the Mantis paper and Zenodo entry in the README?
@Christianfoley and I are trying to decide which image format is best for saving the inference predictions: zarr or tiff. Zarr is better for storing the data, but some software used to process the predicted images only works with single-page TIFFs. @mattersoflight has commented that we should aim to store the predictions as zarr. We can read zarr into a numpy array and then perform downstream analysis (i.e., metrics evaluation; this links to issue #202). Anything to add @Christianfoley, @mattersoflight, @ziw-liu?
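For the record, the zarr → numpy route is short; a sketch with plain zarr (the store path and internal position path are hypothetical):

import numpy as np
import zarr

store = zarr.open("predictions.zarr", mode="r")
volume = np.asarray(store["A/3/0/0"])  # load one position (TCZYX) into memory
# ... metrics evaluation and other downstream analysis (see #202)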
I was running the MBL DL2023 example notebook and ran into this issue at the end, trying to predict the phase from 2 fluorescence channels.
tune_data = HCSDataModule(
    data_path,
    source_channel=["Nuclei", "Membrane"],
    target_channel="Phase",
    z_window_size=1,
    split_ratio=0.8,
    batch_size=BATCH_SIZE,
    num_workers=10,
    architecture="2D",
    yx_patch_size=YX_PATCH_SIZE,
    augment=True,
)
tune_data.setup("fit")

tune_config = {
    "architecture": "2D",
    "num_filters": [24, 48, 96, 192, 384],
    "in_channels": 2,
    "out_channels": 1,
    "residual": True,
    "dropout": 0.1,  # dropout randomly turns off weights to avoid overfitting the model to data
    "task": "reg",  # reg = regression task
}
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[30], line 58
42 n_epochs = 50 # Set this to 50 or the number of epochs you want to train for.
44 trainer = VSTrainer(
45 accelerator="gpu",
46 devices=[GPU_ID],
(...)
55 ),
56 )
---> 58 trainer.fit(fluor2phase_model, datamodule=fluor2phase_data)
61 # Visualize the graph of fluor2phase model as image.
62 model_graph_fluor2phase = torchview.draw_graph(
63 fluor2phase_model,
64 fluor2phase_data.train_dataset[0]["source"],
65 depth=2, # adjust depth to zoom in.
66 device="cpu",
67 )
File ~/conda/envs/04_image_translation/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:532, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
530 self.strategy._lightning_module = model
531 _verify_strategy_supports_compile(model, self.strategy)
--> 532 call._call_and_handle_interrupt(
533 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
534 )
...
458 _pair(0), self.dilation, self.groups)
--> 459 return F.conv2d(input, weight, bias, self.stride,
460 self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size [24, 2, 3, 3], expected input[1, 1, 512, 512] to have 2 channels, but got 1 channels instead
@Soorya19Pradeep @mattersoflight is the code for aligning images still needed in this repo?
One inconvenient step in the mantis analysis pipeline is having to go and remember to change the scale metadata on the virtually stained predictions. Is it possible to write the input dataset's scale as the output scale?
Feature request: saving a fraction of the augmented tiles would give a better feel for the augmentation parameters by making them visualizable. Since a lot of tiles are created during training, really only a fraction should be saved. For example, the training config could define a path to save the tiles, how many to save, and whether or not to save them at all.
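Sketching the requested options as config keys (all names here are hypothetical, not an existing VisCy API):

model:
  save_augmented_tiles: true
  augmented_tile_dir: /path/to/saved_tiles
  augmented_tiles_per_epoch: 16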
BERT-style layer-wise learning rate decay, following ConvNeXt V2, during fine-tuning.
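A minimal sketch of the idea, assuming an encoder that exposes an ordered list of stages (the attribute name is hypothetical):

def layerwise_lr_groups(stages, base_lr=1e-4, decay=0.9):
    """Deeper (later) stages get larger learning rates, as in ConvNeXt V2."""
    n = len(stages)
    return [
        {"params": stage.parameters(), "lr": base_lr * decay ** (n - 1 - i)}
        for i, stage in enumerate(stages)
    ]

# e.g. optimizer = torch.optim.AdamW(layerwise_lr_groups(model.encoder.stages))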
In microDL, using U-Nets for segmentation was explored, leaving branching code paths for both regression (virtual staining) and segmentation tasks (e.g. in preprocessing and model architecture).
@mattersoflight Given the strategy of using virtual staining models with off-the-shelf segmentation tools such as Cellpose, is it still useful to keep this under-tested code?
We use intensity scaling and noise augmentations to make the virtual staining model invariant to these perturbations. We should leverage augmentations while keeping the training process stable and efficient.
This paper suggests a simple strategy and reports that it is effective: construct the batch from many augmentations of the same sample and average the losses (which happens naturally). @ziw-liu, what is the current strategy in HCSDataModule? Can you test the strategy reported in Fig. 1B (top) of this paper?
PS: The paper also reports the regularization of a classification model with KL divergence over the augmentations. This doesn't translate naturally to virtual staining.
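On the batching question, a sketch of the Fig. 1B (top) strategy as I read it, with a generic `augment` callable standing in for our MONAI transforms:

import torch

def repeated_augmentation_batch(samples, augment, views_per_sample=4):
    """Fill the batch with several independent augmentations of each sample,
    so the (naturally averaged) loss spans the augmented views."""
    return torch.stack(
        [augment(s) for s in samples for _ in range(views_per_sample)]
    )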
I was expecting that calling HCSDataModule().setup('fit') would fit the data and re-write the normalization dictionary. However, when this is called twice in a row, we get:
KeyError Traceback (most recent call last)
/home/eduardoh/vs_data/1-test_dataloader/edhirata_dataloader.py in line 14
     34 # %%
     35 data_module = HCSDataModule(
     36     input_data_path,
     37     source_channel="Phase3D",
    (...)
     45     augment=False,  # Turn off augmentation for now.
     46 )
---> 47 data_module.setup("fit")

File ~/VisCy/viscy/light/data.py:404, in HCSDataModule.setup(self, stage)
    402 dataset_settings = dict(channels=channels, z_window_size=self.z_window_size)
    403 if stage in ("fit", "validate"):
--> 404     self._setup_fit(dataset_settings)
    405 elif stage == "test":
    406     self._setup_test(dataset_settings)

File ~/VisCy/viscy/light/data.py:429, in HCSDataModule._setup_fit(self, dataset_settings)
    428 def _setup_fit(self, dataset_settings: dict):
--> 429     plate, normalize_transform = self._setup_eval(dataset_settings)
    430     fit_transform = self._fit_transform()
    431     train_transform = Compose(
    432         [normalize_transform] + self._train_transform() + fit_transform
    433     )

File ~/VisCy/viscy/light/data.py:424, in HCSDataModule._setup_eval(self, dataset_settings)
    420 if self.normalize_source:
    421     norm_keys += self.source_channel
    422 normalize_transform = NormalizeSampled(
    423     norm_keys,
--> 424     plate.zattrs["normalization"],
    425 )
    426 return plate, normalize_transform

File ~/conda/envs/viscy/lib/python3.10/site-packages/zarr/attrs.py:73, in Attributes.__getitem__(self, item)
    72 def __getitem__(self, item):
---> 73     return self.asdict()[item]
KeyError: 'normalization'
To train a model with datasets of different dimensionalities (2D/3D) and imaging modalities (bright field, quantitative phase, Zernike phase contrast, etc.), the training pipeline needs these new features:
A key hyperparameter of our training process is the choice of augmentations. At this point, the augmentations are hard-coded. @ziw-liu, how can we make the augmentations configurable via the PyTorch Lightning config? A useful side effect is that the augmentations that go into each computational experiment will be automatically documented.
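A sketch of what this could look like with the standard jsonargparse/LightningCLI pattern, assuming the data module accepts a list of MONAI transforms:

data:
  augmentations:
    - class_path: monai.transforms.RandAffined
      init_args:
        keys: [source, target]
        prob: 0.5
        rotate_range: [0.0, 0.0, 0.2]
    - class_path: monai.transforms.RandGaussianNoised
      init_args:
        keys: [source]
        prob: 0.5
        std: 0.1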
I want to train the ConvNeXt model for the infection classifier problem. The model will use two or more input channels (phase + HSP90 + infection score channels, etc.) and should output three channels (background + uninfected + infected). Currently, I can perform this training using the 2D and 2.5D UNets, but not using ConvNeXt: it doesn't allow three-channel output due to a hardcoded scaling parameter value. @ziw-liu, can you help with this? Thanks!
Improve the visualization of predictions logged to TensorBoard by setting an appropriate intensity scale. Currently, the intensity scaling of the predicted images shows artifacts from cellular structures not involved in the fluorescence expression; the scaling enhances them and produces false signals in the visualization. Setting the same intensity scale for all prediction windows can reduce this artifact. @ziw-liu, can you help implement this? Thank you!
When deploying on a microscope system for online inference, we may not have access to dataset-level summary statistics for normalization.
During a discussion with @ieivanov and @Soorya19Pradeep, the necessity of normalizing the phase input came into question. If deconvolution already guarantees a zero-mean background, and the exact values of foreground pixels carry physical meaning (relative phase delay), further subtraction and scaling seem redundant for thin samples such as monolayer cell cultures. The remaining small variations can be adjusted by augmentation during model training.
@mattersoflight may have more input on whether there is empirical evidence that normalization for phase is still necessary at inference time.
Label-free datasets acquired on the mantis have a similar in-plane pixel size to hummingbird@63x, where a lot of the training data was acquired. To apply 2.5D virtual staining across microscopes, however, axial spatial augmentation is needed to match the training data distribution (250 nm) to the wide range of existing (570 nm, czbiohub-sf/shrimPy#69) and future (205 nm) Z-sampling on mantis.
@talonchandler suggests that trilinear is a good interpolation to start with (0.4x to 1.2x scaling in Z). We will also investigate the unit of voxel intensity values in the reconstructed phase images, since mantis has a different illumination wavelength (450 nm) than hummingbird (532 nm).
Pinging @Soorya19Pradeep and @edyoshikun in case these numbers are not accurate.
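A minimal sketch of such an axial augmentation with trilinear interpolation, assuming (C, Z, Y, X) tensors:

import torch
import torch.nn.functional as F

def random_axial_rescale(volume: torch.Tensor, low=0.4, high=1.2) -> torch.Tensor:
    """volume: (C, Z, Y, X); rescale Z only, keeping YX sampling unchanged."""
    scale = float(torch.empty(1).uniform_(low, high))
    return F.interpolate(
        volume.unsqueeze(0),  # add batch dim -> (1, C, Z, Y, X)
        scale_factor=(scale, 1.0, 1.0),
        mode="trilinear",
        align_corners=False,
    ).squeeze(0)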
Different magnifications of the microscope alter the sampling of all 3 spatial dimensions, and the change in Z differs from that in XY. If we want to train a model on a single dataset that generalizes across magnifications, we need to employ augmentation strategies that simulate the changes in spatial sampling.
@mattersoflight pointed out that scaling down is not a good approximation of reducing magnification. Blurring (e.g., Gaussian filtering) before rescaling can simulate the integration of information along the light path and reduce artifacts.
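For instance, a sketch of blur-then-rescale (the sigma here is a placeholder; in practice it would be derived from the magnification ratio and NA):

import torch.nn.functional as F
from monai.transforms import GaussianSmooth

def simulate_lower_magnification(volume, scale=0.5, sigma=1.0):
    """volume: (C, Z, Y, X); blur to mimic optical integration, then downsample YX."""
    blurred = GaussianSmooth(sigma=sigma)(volume)
    return F.interpolate(
        blurred.unsqueeze(0),
        scale_factor=(1.0, scale, scale),
        mode="trilinear",
        align_corners=False,
    ).squeeze(0)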
Another question is how to determine the training-time Z sampling for better utilization of defocus information. This can potentially be estimated from the magnification, Z step size, and the NA of illumination and detection.
Each time inference is run, it generates a TensorBoard output displaying the outputs of inference and metrics calculated along multiple orientations. Each scalar added (each sub-tag corresponding to a calculated metric) generates a new folder in a recursive structure, resulting in potentially prohibitively long session load times depending on the number and type of metrics calculated.
Let's have a discussion here about:
We want to compute statistics from the FOV-scale zarr store and store them with the patch-scale zarr store, which the dataloader will parse.
The process for this is:
This is the current structure of the patches ome_zarr:
=== Summary ===
Format: omezarr v0.4
Axes: T (time); C (channel); Z (space); Y (space); X (space);
Channel names: ['RFP', 'Phase3D']
Row names: ['A', 'B']
Column names: ['3', '4']
Wells: 4
Positions: 2629
This is the structure of the track_labels zarr, where we can store the metadata (FOV information can be stored here):
=== Summary ===
Format: omezarr v0.4
Axes: T (time); C (channel); Z (space); Y (space); X (space);
Channel names: ['tracking']
Row names: ['A', 'B']
Column names: ['3', '4']
Wells: 4
Positions: 61
I need help on how to integrate metadata into the track ome_zarr, how this will be used to generate the patch ome_zarr, and eventually how the dataloader will use it for normalization. A sketch of one option follows.
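One option is attaching per-FOV statistics to the store's attributes (the key name and values below are placeholders):

import zarr

store = zarr.open("track_labels.zarr", mode="a")
position = store["A/3/0"]  # hypothetical HCS position path
# the dataloader could then parse this key when building normalization transforms
position.attrs["fov_statistics"] = {"median": 110.0, "iqr": 25.0}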
@ziw-liu @edyoshikun I propose we organize the scripts that use VisCy (with other libraries) for specific applications in applications/infection_screens, applications/organelle_phenotyping, and applications/ultrack. I suggest this approach because the scope of the repo is single-cell phenotyping. It is much easier to find code on git than on disk, and this encourages clean code organization.
Since we are training a number of models with HCSDataModule, it is timely to document data.py to clarify the flow of data through different methods.
I suggest the following:
@ziw-liu additional improvements are welcome, but the above should be sufficient to understand the design.
We were expecting that after setting up an HCSDataModule(), its test_dataloader() would yield targets with a single Z slice:

for i, sample in enumerate(HCSDataModule().test_dataloader()):
    break
sample["target"].shape  # returns torch.Size([64, 2, 5, 512, 512])

The expected behavior is that the Z dimension of sample["target"] is 1.
Note:
When predicting, I see that we select the middle of the stack, but the target depth should always default to one:
def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample:
    predicting = False
    if self.trainer:
        if self.trainer.predicting:
            predicting = True
    if predicting or isinstance(batch, torch.Tensor):
        # skipping example input array
        return batch
    if self.target_2d:
        # slice the center during training or testing
        z_index = self.z_window_size // 2
        batch["target"] = batch["target"][:, :, slice(z_index, z_index + 1)]
    return batch
With a variety of models, we see that virtually stained structures show high-frequency fluctuations over time. This may be due to the sensitivity of ConvNets to small changes in local texture and orientation of edges. One way to average over these effects is to do test time augmentation. Let's experiment with one of our trained models and compare the temporal consistency of predictions with and without test-time augmentations.
Relevant paper: https://www.nature.com/articles/s41598-020-61808-3
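A minimal test-time augmentation sketch for a first experiment, averaging predictions over YX flips (assuming the task is equivariant to flips):

import torch

def predict_tta(model, x: torch.Tensor) -> torch.Tensor:
    """x: (B, C, Z, Y, X); average predictions over in-plane flips."""
    preds = []
    for dims in [(), (-1,), (-2,), (-2, -1)]:
        xf = torch.flip(x, dims) if dims else x
        y = model(xf)
        preds.append(torch.flip(y, dims) if dims else y)  # undo the flip
    return torch.stack(preds).mean(dim=0)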
My understanding of the current classifier and task with 60X res data can be found here: https://docs.google.com/document/d/1j3UePmDJL_1V_9j7v3I4nLgAgKuFXXuqmlW8ZFOyTk0/edit?usp=sharing.
Given this approach, we'd need to modify HCSDataModule to support triplet sampling. Specifically:
The goal of triplet sampling is to minimize the distance between the anchor and the positive while maximizing the distance between the anchor and the negative in the learned embedding space.
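For reference, this objective is exactly what torch.nn.TripletMarginLoss computes on batches of embeddings (random placeholders below):

import torch
import torch.nn as nn

triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
anchor, positive, negative = (torch.randn(16, 128) for _ in range(3))
loss = triplet_loss(anchor, positive, negative)  # hinge on d(a, p) - d(a, n) + margin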
# Takes a base_transform and applies it to a sample to generate the anchor
# and positive samples. When the __call__ method is invoked with a sample,
# it applies the base_transform twice: first to create the anchor and
# second to create the positive.
class TripletTransform:
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        anchor = self.transform(sample)
        positive = self.transform(sample)
        return anchor, positive
# The TripletDataset class is initialized with the dataset and a transform
# function. When the __getitem__ method is called with an index (idx):
# - Anchor and positive: the same data sample is retrieved for both.
# - Negative sampling: a different sample is randomly selected as the negative.
# If a transform is provided:
# - The TripletTransform applies the base_transform to both the anchor and
#   the positive sample, creating two augmented versions.
# - The base_transform is applied directly to the negative sample (if wanted).
class TripletDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        anchor = self.data[idx]
        # simple negative sampling
        negative_idx = ...
        negative = self.data[negative_idx]
        if self.transform:
            # TripletTransform returns two independent augmented views
            anchor, positive = self.transform(anchor)
            # apply the base_transform directly to the negative
            negative = self.transform.transform(negative)
        else:
            positive = anchor
        return (anchor, positive, negative)
Here the TripletTransform class takes a base transformation (defined in base_transform) and applies it to create the anchor and positive samples.
Modify HCSDataModule:
class TripletHCSDataModule(HCSDataModule):
    def __init__(
        self,
        data_path: str,
        source_channel: Union[str, Sequence[str]],
        target_channel: Union[str, Sequence[str]],
        z_window_size: int,
        split_ratio: float = 0.8,
        batch_size: int = 16,
        num_workers: int = 8,
        architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D",
        yx_patch_size: tuple[int, int] = (256, 256),
        normalizations: list[MapTransform] = [],
        augmentations: list[MapTransform] = [],
        caching: bool = False,
        ground_truth_masks: Optional[Path] = None,
    ):
        super().__init__(
            data_path,
            source_channel,
            target_channel,
            z_window_size,
            split_ratio,
            batch_size,
            num_workers,
            architecture,
            yx_patch_size,
            normalizations,
            augmentations,
            caching,
            ground_truth_masks,
        )
        self.triplet_transform = TripletTransform(
            transforms.Compose(normalizations + augmentations)
        )

    # update setup to use TripletDataset
    def setup(self, stage: Optional[str] = None):
        super().setup(stage)
        if stage in ("fit", "validate"):
            self.train_dataset = TripletDataset(
                self.train_dataset.data, transform=self.triplet_transform
            )
            self.val_dataset = TripletDataset(
                self.val_dataset.data, transform=self.triplet_transform
            )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size // 3,  # adjust batch size for triplets
            num_workers=self.num_workers,
            shuffle=True,
            persistent_workers=bool(self.num_workers),
            prefetch_factor=4 if self.num_workers else None,
            drop_last=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size // 3,  # adjust batch size for triplets
            num_workers=self.num_workers,
            shuffle=False,
            prefetch_factor=4 if self.num_workers else None,
            persistent_workers=bool(self.num_workers),
        )
# example of what could be included in the augmentations list
base_transform = transforms.Compose([
    transforms.RandomApply([transforms.ColorJitter(0.2, 0.2, 0.2, 0.2)], p=0.5),
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(128, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
])
Using this updated dataloader:

data_module = TripletHCSDataModule(
    data_path="...",
    source_channel=["Phase", "Sensor"],
    target_channel=["Inf_mask"],
    yx_patch_size=[128, 128],
    split_ratio=0.8,
    z_window_size=1,
    architecture="2D",
    num_workers=4,
    batch_size=64,
    normalizations=[
        NormalizeSampled(
            keys=["Sensor", "Phase"],
            level="fov_statistics",
            subtrahend="median",
            divisor="iqr",
        )
    ],
    augmentations=[
        RandWeightedCropd(
            num_samples=8,
            spatial_size=[-1, 128, 128],
            keys=["Sensor", "Phase", "Inf_mask"],
            w_key="Inf_mask",
        )
    ],
)
Model details:
Other ideas: try SimCLR vs. triplet sampling.
Good resource: https://lilianweng.github.io/posts/2021-05-31-contrastive/
Use ModuleList instead of the repetitive custom method.
Originally posted in mehta-lab/microDL#214 (comment)
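For illustration, a minimal example of the suggestion (class and attribute names are hypothetical):

import torch.nn as nn

class Head(nn.Module):
    def __init__(self, widths=(96, 48, 24)):
        super().__init__()
        # one indexable container instead of block1/block2/... attributes
        # dispatched by a hand-written method
        self.blocks = nn.ModuleList(
            nn.Conv2d(w, w // 2, kernel_size=3, padding=1) for w in widths
        )

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x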
This error does not happen in lightning==2.0.1, which is what I had installed by default from viscy. However, I tried upgrading to lightning 2.3.0.dev0 to circumvent the caching timeout issue here, but got the following error, for which we will have to make sure our tensors are on the right device, according to this. Flagging it just in case you also encounter it, @ziw-liu.
Traceback (most recent call last):
File "/hpc/projects/comp.micro/virtual_staining/models/fcmae-3d/fit/pretrain_scratch_path.py", line 141, in <module>
trainer.fit(model, data)
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
call._call_and_handle_interrupt(
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
return function(*args, **kwargs)
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 987, in _run
results = self._run_stage()
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1031, in _run_stage
self._run_sanity_check()
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1060, in _run_sanity_check
val_loop.run()
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
return loop_run(self, *args, **kwargs)
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 142, in run
return self.on_run_end()
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 254, in on_run_end
self._on_evaluation_epoch_end()
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 336, in _on_evaluation_epoch_end
trainer._logger_connector.on_epoch_end()
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py", line 195, in on_epoch_end
metrics = self.metrics
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py", line 234, in metrics
return self.trainer._results.metrics(on_step)
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 483, in metrics
value = self._get_cache(result_metric, on_step)
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 447, in _get_cache
result_metric.compute()
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 289, in wrapped_func
self._computed = compute(*args, **kwargs)
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 249, in compute
value = self.meta.sync(self.value.clone()) # `clone` because `sync` is in-place
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/strategies/ddp.py", line 342, in reduce
return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/fabric/utilities/distributed.py", line 173, in _sync_ddp_if_available
return _sync_ddp(result, group=group, reduce_op=reduce_op)
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/fabric/utilities/distributed.py", line 223, in _sync_ddp
torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 72, in wrapper
return func(*args, **kwargs)
File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1992, in all_reduce
work = group.allreduce([tensor], opts)
RuntimeError: No backend type associated with device type cpu
Now that the 2D and 2.5D UNets from our 2020 paper are implemented in PyTorch, we are exploring the space of architectures in two ways:
a) input-output tensor dimensions, and
b) SOTA convolutional layers, particularly inspired by ConvNeXt.
At this point, the 2.1D network combines both. It is useful to have distinct nomenclature and models to compare these two innovations.
I suggest: