jameschapman19 / cca_zoo

Canonical Correlation Analysis Zoo: A collection of Regularized, Deep Learning based, Kernel, and Probabilistic methods in a scikit-learn style framework

Home Page: https://cca-zoo.readthedocs.io/en/latest/

License: MIT License

Python 99.70% Shell 0.30%
dcca deep cca canonical-correlation-analysis kernel multiview pytorch cca-zoo multiset-cca tensor-cca

cca_zoo's Introduction

CCA-Zoo

Unlock the hidden relationships in multiview data.

Introduction

In today's data-driven world, revealing hidden relationships across multiview datasets is critical. CCA-Zoo is your go-to library for canonical correlation analysis (CCA), featuring a robust selection of regularized linear, kernel, deep, and probabilistic methods.

Designed to be user-friendly, CCA-Zoo is inspired by the ease of use in scikit-learn and mvlearn. It provides a seamless programming experience with familiar fit, transform, and fit_transform methods.
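To make the fit/transform pattern concrete, here is a minimal two-view linear CCA written directly in NumPy. It is a sketch, not CCA-Zoo's implementation: the class `MinimalCCA` and its attributes are hypothetical, and it assumes two views with more samples than features and full column rank. It mirrors the tuple-of-views call style used with CCA-Zoo estimators in the issues below (e.g. `basic_cca.fit((layer, second_layer))`).

```python
import numpy as np


class MinimalCCA:
    """Toy two-view linear CCA with a scikit-learn-style interface (hypothetical, illustrative)."""

    def __init__(self, latent_dims=1):
        self.latent_dims = latent_dims

    def fit(self, views):
        X, Y = views
        # Store the training means so transform() centres new data with the same statistics.
        self.means_ = [X.mean(axis=0), Y.mean(axis=0)]
        Xc, Yc = X - self.means_[0], Y - self.means_[1]
        # Orthonormal bases of each centred view (assumes full column rank).
        Qx, Rx = np.linalg.qr(Xc)
        Qy, Ry = np.linalg.qr(Yc)
        # The singular values of Qx^T Qy are the canonical correlations.
        U, s, Vt = np.linalg.svd(Qx.T @ Qy)
        k = self.latent_dims
        self.canonical_correlations_ = s[:k]
        # Map singular vectors back to weights on the original features: Xc @ w = Qx @ u.
        self.weights_ = [np.linalg.solve(Rx, U[:, :k]), np.linalg.solve(Ry, Vt.T[:, :k])]
        return self

    def transform(self, views):
        # One (n_samples, latent_dims) score matrix per view.
        return [(V - m) @ W for V, m, W in zip(views, self.means_, self.weights_)]

    def fit_transform(self, views):
        return self.fit(views).transform(views)


# Synthetic example: both views share one latent signal z.
rng = np.random.default_rng(0)
z = rng.normal(size=(500, 1))
X = np.hstack([z, rng.normal(size=(500, 4))]) + 0.1 * rng.normal(size=(500, 5))
Y = np.hstack([z, rng.normal(size=(500, 2))]) + 0.1 * rng.normal(size=(500, 3))
U, V = MinimalCCA(latent_dims=2).fit_transform((X, Y))
```

Note that `transform` subtracts the *training* means stored by `fit`; skipping that step for out-of-sample data is the bug described under "Problem with linear transforms" below.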

🚀 Quick Start

Installation

Whether you're a pip enthusiast or a poetry aficionado, installing CCA-Zoo is a breeze:

pip install cca-zoo
# For additional features (quote the extras so the shell does not expand the brackets)
pip install "cca-zoo[probabilistic,visualisation,deep]"

For Poetry users:

poetry add cca-zoo
# For extra features
poetry add "cca-zoo[probabilistic,visualisation,deep]"

Note that the deep extra requires torch and lightning, which may be better installed separately by following the PyTorch installation guide.

The probabilistic extra requires numpyro, which may be better installed separately by following the NumPyro installation guide.

The visualisation extra requires matplotlib and seaborn.

Plug into the Machine Learning Ecosystem

CCA-Zoo is designed to be compatible with the machine learning ecosystem. It is built on top of scikit-learn, tensorly, torch, pytorch-lightning, and numpyro.

🏎️ Performance Highlights

CCA-Zoo shines when it comes to high-dimensional data analysis. In our benchmarks it significantly outperforms scikit-learn's CCA and PLS implementations, and the gap grows as dimensionality increases. For comprehensive benchmarks, see our script and the plots below.

[Figure: Benchmark Plot CCA | Benchmark Plot PLS]

📚 Detailed Documentation

Embark on a journey through multiview correlations with our comprehensive guide: https://cca-zoo.readthedocs.io/en/latest/

🙏 How to Cite

Your support means a lot to us! If CCA-Zoo has been beneficial for your research, there are two ways to show your appreciation:

  1. Star our GitHub repository.
  2. Cite our research paper in your publications.

For citing our work, please use the following BibTeX entry:

@software{Chapman_CCA-Zoo_2023,
author = {Chapman, James and Wang, Hao-Ting and Wells, Lennie and Wiesner, Johannes},
doi = {10.5281/zenodo.4382739},
month = aug,
title = {{CCA-Zoo}},
url = {https://github.com/jameschapman19/cca_zoo},
version = {2.3.0},
year = {2023}
}

Or check out our JOSS paper:

📜 Chapman et al., (2021). CCA-Zoo: A collection of Regularized, Deep Learning based, Kernel, and Probabilistic CCA methods in a scikit-learn style framework. Journal of Open Source Software, 6(68), 3823, Link.

👩‍💻 Contribute

Every idea, every line of code adds value. Check out our contribution guide and help CCA-Zoo soar to new heights!

🙌 Acknowledgments

Special thanks to the pioneers whose work has shaped this field. Explore their work:

cca_zoo's People

Contributors

htwangtw, jameschapman19, johanneswiesner, w-l-w

cca_zoo's Issues

Problem with linear transforms

Had forgotten to subtract the mean of the training data when performing out of sample transformations. This resulted in correlations of transforms not matching the correlations found by the model.

RuntimeError: symeig_cpu: the algorithm failed to converge; 149 off-diagonal elements of an intermediate tridiagonal form did not converge to zero.

When I run the following code, I get the above error

import numpy as np
from torch import optim
from cca_zoo import data
import scipy.io as sio
from cca_zoo.deepmodels import DCCA, DCCAE, DVCCA, DCCA_NOI, DTCCA, SplitAE, DeepWrapper
from cca_zoo.deepmodels import objectives, architectures
X = np.random.rand(2000, 2048)
Y = np.random.rand(2000, 2048)
latent_dims = 150
device = 'cpu'
encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=2048)
encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=2048)

# DCCA

dcca_model = DCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2],
                  objective=objectives.CCA)

# hidden_layer_sizes are shown explicitly but these are also the defaults

dcca_model = DeepWrapper(dcca_model, device=device)
dcca_model.fit((X, Y), epochs=10)

nan in transformed matrix

Hi,
I run basic CCA on a dataset and it works well, but sometimes I get NaNs in the final column of one of the transformed matrices, for example:

basic_cca = CCA(latent_dims=768)
basic_cca.fit((layer, second_layer))
U, V = basic_cca.transform((layer, second_layer))

Here layer and second_layer are both matrices of shape [N x 768], and sometimes the last column of U (i.e. U[:, 767]) is all NaNs. If I change latent_dims to 767 or lower, everything works (but I need it to be 768). Any idea how to solve this, or what I can change to fix it?
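One possible cause, offered as an assumption rather than a diagnosis from this thread: each view supports at most as many canonical dimensions as the rank of its centred data matrix, and centring N samples leaves a rank of at most min(N - 1, p). If N is at most 768, or some features are collinear, the last requested direction has zero variance and normalising it divides by zero. A quick NumPy check before fitting (the helper name and data shapes are hypothetical):

```python
import numpy as np


def max_latent_dims(*views):
    """Upper bound on usable canonical dimensions: the smallest rank among the centred views."""
    return min(np.linalg.matrix_rank(V - V.mean(axis=0)) for V in views)


# Hypothetical data: only 100 samples of 768 features per view.
rng = np.random.default_rng(0)
layer = rng.normal(size=(100, 768))
second_layer = rng.normal(size=(100, 768))
print(max_latent_dims(layer, second_layer))  # 99: centring 100 samples caps the rank at 99
```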
Thanks

Flexible model architectures

The latest versions, which pass architectures as classes using the Config() class, are close to giving the right amount of flexibility. However, the functions that build the models look a bit clunky and only pass the input size, latent dimension size, and hidden layer sizes, which limits flexibility with, e.g., stride/padding. So I will look to tidy this up.

Deflation

Need to have a think about how to implement deflation for iterative methods. In particular methods based on regularised alternating least squares need to be treated differently to the rotations method used by sklearn and adapted here.
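For context, a sketch of one common deflation scheme for iterative methods: after extracting a weight vector w and its score t = Xw, remove from X the rank-one part explained by t (regression deflation), so the next component is searched for in the residual. This is a generic NumPy illustration under that assumption; it is not claimed to be how cca_zoo implements its `deflation` options.

```python
import numpy as np


def deflate(X, w):
    """Regression deflation: subtract the rank-one part of X explained by the score t = Xw."""
    t = X @ w                   # score for the component just found
    p = X.T @ t / (t @ t)       # loadings: regression of each column of X on t
    return X - np.outer(t, p)   # residual columns are orthogonal to t


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
X -= X.mean(axis=0)
# Example direction: the leading right singular vector of X (its first principal axis).
w = np.linalg.svd(X, full_matrices=False)[2][0]
X1 = deflate(X, w)
```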

weighted GCCA

Hi,
First of all, I want to thank you for such a great CCA-family package.
For multiview input as in GCCA, is it possible to set different weights for different views? This may be quite common in practical use, for example when some views are more important than others.
May we expect this in a future release?
This paper mentions such an idea: http://www.cs.jhu.edu/~mdredze/publications/2016_acl_multiview.pdf

thanks again for your wonderful package
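A minimal sketch of what view weighting could look like, roughly in the spirit of the linked paper. The formulation is an assumption: a MAXVAR-style GCCA whose shared representation G is given by the top eigenvectors of a weighted sum of ridge-regularised projection matrices, one per view. The function name and parameters are hypothetical, and this is not cca_zoo's GCCA.

```python
import numpy as np


def weighted_gcca(views, weights, latent_dims, reg=1e-3):
    """Sketch of weighted MAXVAR-style GCCA.

    G (n x latent_dims) holds the top eigenvectors of sum_i weights[i] * P_i, where
    P_i = X_i (X_i^T X_i + reg * I)^{-1} X_i^T projects onto view i's (centred) column space.
    """
    n = views[0].shape[0]
    M = np.zeros((n, n))
    for X, weight in zip(views, weights):
        Xc = X - X.mean(axis=0)
        P = Xc @ np.linalg.solve(Xc.T @ Xc + reg * np.eye(Xc.shape[1]), Xc.T)
        M += weight * P  # a larger weight lets this view pull G towards its own subspace
    eigvals, eigvecs = np.linalg.eigh(M)        # eigenvalues in ascending order
    return eigvecs[:, ::-1][:, :latent_dims]    # top eigenvectors, orthonormal columns


# Hypothetical example: three views sharing a 2-D signal; the third view is down-weighted.
rng = np.random.default_rng(0)
z = rng.normal(size=(100, 2))
views = [np.hstack([z, rng.normal(size=(100, p))]) for p in (3, 4, 5)]
G = weighted_gcca(views, weights=[1.0, 1.0, 0.5], latent_dims=2)
```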

All estimators failed to fit

I am using cca_zoo.models.PMD in combination with cca_zoo.model_selection.GridSearchCV, but I am getting this error message (even when using artificial data):

  File "C:\Users\Johannes.Wiesner\Documents\projects\hcp_project\testing\cca_zoo_bug.py", line 49, in <module>
    grid.fit([train_view_1,train_view_2])

  File "C:\Users\Johannes.Wiesner\Miniconda3\envs\csp_wiesner_johannes\lib\site-packages\cca_zoo\model_selection\_search.py", line 344, in fit
    self._run_search(evaluate_candidates)

  File "C:\Users\Johannes.Wiesner\Miniconda3\envs\csp_wiesner_johannes\lib\site-packages\cca_zoo\model_selection\_search.py", line 657, in _run_search
    evaluate_candidates(ParameterGrid(self.param_grid))

  File "C:\Users\Johannes.Wiesner\Miniconda3\envs\csp_wiesner_johannes\lib\site-packages\cca_zoo\model_selection\_search.py", line 328, in evaluate_candidates
    _insert_error_scores(out, self.error_score)

  File "C:\Users\Johannes.Wiesner\Miniconda3\envs\csp_wiesner_johannes\lib\site-packages\sklearn\model_selection\_validation.py", line 331, in _insert_error_scores
    raise NotFittedError("All estimators failed to fit")

NotFittedError: All estimators failed to fit
import numpy as np
from cca_zoo.models import PMD
from cca_zoo.model_selection import GridSearchCV
from cca_zoo.data import generate_covariance_data
(train_view_1,train_view_2),_=generate_covariance_data(n=200,view_features=[10,10],latent_dims=1,correlation=1)

rng = np.random.RandomState(42)

## Set parameters #############################################################

deflation='pls'
max_iter= 100
latent_dims = 3
tol = 1e-15
c1 = np.linspace(0.1,1,10)
c2 = np.linspace(0.1,1,10)
param_grid = {'c': [c1,c2]}

###############################################################################
## Run CCA analysis with cca-zoo package ######################################
###############################################################################

estimator = PMD(latent_dims=latent_dims,
                random_state=rng,
                deflation=deflation,
                max_iter=max_iter,
                tol=tol)

def scorer(estimator,X):
    dim_corrs=estimator.score(X)
    return dim_corrs.mean()

grid = GridSearchCV(estimator,
                    param_grid=param_grid,
                    cv=None,
                    verbose=True,
                    scoring=scorer)

grid.fit([train_view_1,train_view_2])

Ridge CCA

Realised that I don't actually have ridge CCA.
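For reference, a sketch of ridge-regularised CCA under the usual formulation (an assumption here, not the package's code): each within-view covariance is shrunk towards the identity, (1 - c) * C + c * I, so c = 0 recovers CCA and c = 1 gives PLS. The function names are hypothetical.

```python
import numpy as np


def _inv_sqrt(C):
    """Inverse square root of a symmetric positive-definite matrix."""
    vals, vecs = np.linalg.eigh(C)
    return vecs @ np.diag(vals ** -0.5) @ vecs.T


def ridge_cca(X, Y, c=0.0, latent_dims=1):
    """Ridge CCA sketch: c=0 gives CCA, c=1 gives PLS. Returns one weight matrix per view."""
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    n = X.shape[0]
    Cxx, Cyy, Cxy = X.T @ X / n, Y.T @ Y / n, X.T @ Y / n
    # Shrink each within-view covariance towards the identity before whitening.
    Rx = _inv_sqrt((1 - c) * Cxx + c * np.eye(X.shape[1]))
    Ry = _inv_sqrt((1 - c) * Cyy + c * np.eye(Y.shape[1]))
    U, s, Vt = np.linalg.svd(Rx @ Cxy @ Ry)
    return Rx @ U[:, :latent_dims], Ry @ Vt.T[:, :latent_dims]


rng = np.random.default_rng(1)
z = rng.normal(size=(200, 1))
X = np.hstack([z, rng.normal(size=(200, 3))])
Y = np.hstack([2 * z, rng.normal(size=(200, 2))]) + 0.5 * rng.normal(size=(200, 3))
u_cca, v_cca = ridge_cca(X, Y, c=0.0)
u_pls, v_pls = ridge_cca(X, Y, c=1.0)
```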

Replace assertion with exception

Hi, thanks for creating this super useful project!
I noticed you use assertions for error checking in the modules.
It would be clearer for users if these were replaced with exceptions carrying more informative error messages.
Let me know if you think this is a good idea. Happy to work on a PR for this.
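A sketch of the proposed change. Besides giving clearer messages, explicit exceptions still fire when Python runs with `-O`, which strips `assert` statements. The helper name and the specific rule it checks are hypothetical, chosen only to illustrate the pattern.

```python
def check_latent_dims(latent_dims, n_features):
    """Hypothetical validator: raise a descriptive ValueError instead of using assert.

    Before: assert latent_dims <= min(n_features)   # bare AssertionError, skipped under -O
    """
    if not isinstance(latent_dims, int) or latent_dims < 1:
        raise ValueError(f"latent_dims must be a positive integer, got {latent_dims!r}")
    if latent_dims > min(n_features):
        raise ValueError(
            f"latent_dims={latent_dims} exceeds the smallest view's feature count "
            f"({min(n_features)})"
        )


check_latent_dims(2, [5, 3])  # passes silently
```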

Permutation tests for CCA

I would like to implement (if not already available elsewhere) a Python version of the permutation tests for CCA described in

Winkler AM, Renaud O, Smith SM, Nichols TE. Permutation Inference for Canonical Correlation Analysis. NeuroImage. 2020; 117065 (see article here)

The paper comes with a repository (Matlab) that could be ported to Python without requiring additional dependencies.
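As a starting point, a simplified permutation test for the first canonical correlation: shuffle the rows of one view to break its pairing with the other, refit, and compare. This sketch deliberately omits the nuisance-variable handling and stepwise testing of later components described by Winkler et al.; the function names are hypothetical.

```python
import numpy as np


def first_canonical_correlation(X, Y):
    """Largest canonical correlation via QR of the centred views (assumes full column rank)."""
    Qx, _ = np.linalg.qr(X - X.mean(axis=0))
    Qy, _ = np.linalg.qr(Y - Y.mean(axis=0))
    return np.linalg.svd(Qx.T @ Qy, compute_uv=False)[0]


def permutation_test(X, Y, n_permutations=999, seed=0):
    """Permutation p-value for the first canonical correlation (simplified sketch)."""
    rng = np.random.default_rng(seed)
    observed = first_canonical_correlation(X, Y)
    null = np.array([
        first_canonical_correlation(X, Y[rng.permutation(len(Y))])
        for _ in range(n_permutations)
    ])
    # Count the observed statistic as one member of the null distribution.
    p_value = (1 + np.sum(null >= observed)) / (1 + n_permutations)
    return observed, p_value


rng = np.random.default_rng(2)
z = rng.normal(size=(100, 1))
X = np.hstack([z, rng.normal(size=(100, 2))]) + 0.3 * rng.normal(size=(100, 3))
Y = np.hstack([z, rng.normal(size=(100, 2))]) + 0.3 * rng.normal(size=(100, 3))
observed, p_value = permutation_test(X, Y, n_permutations=199)
```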

nice plotting functions?

Used to have these in the package and at some point removed them but now think they should be back

About DeCCA

Hi,

Thanks for the excellent toolbox in the first place!

When I tried to run the toolbox ("example_main.py"), there is an error saying that "No module named 'DeCCA'". I checked and found that files about DeCCA are indeed missing in the repo. Would you please confirm that? Are they not released yet?

Thanks and have a nice day!

DCCA transform method requires dataloader

Hi,

After the recent update that replaced DeepWrapper with PyTorch Lightning, the transform function no longer accepts a tuple of arrays; it accepts only a PyTorch DataLoader. For example, this simple code worked before:

a = np.random.randn(2000, 50)
b = np.random.randn(2000, 100)
m1 = min(a.shape[1], b.shape[1])
train_dataset = data.CCA_Dataset([a, b])
encoder_a = Encoder(latent_dims=m1, feature_size=50, layer_sizes=[128, 256])
encoder_b = Encoder(latent_dims=m1, feature_size=100, layer_sizes=[128, 256])
dcca = DCCA(latent_dims=m1, objective=objectives.CCA, encoders=[encoder_a, encoder_b])
dcca = DeepWrapper(dcca, device='cpu').fit(train_dataset, epochs=10)
U, V = dcca.transform((a, b))

Now I have rewritten it using PyTorch Lightning:

a = np.random.randn(2000, 50)
b = np.random.randn(2000, 100)
c = np.random.randn(2000, 50)
d = np.random.randn(2000, 100)
m1 = min(a.shape[1], b.shape[1])

train_dataset = data.CCA_Dataset([a, b])
val_dataset = data.CCA_Dataset([c, d])
train_loader, val_loader = get_dataloaders(train_dataset, val_dataset)

# feature_size - input, latent_dim - output
encoder_a = Encoder(latent_dims=m1, feature_size=50, layer_sizes=[128, 256])
encoder_b = Encoder(latent_dims=m1, feature_size=100, layer_sizes=[128, 256])

dcca = DCCA(latent_dims=m1, objective=objectives.CCA, encoders=[encoder_a, encoder_b])
optimizer = optim.Adam(dcca.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 1)
dcca = CCALightning(dcca, optimizer=optimizer, lr_scheduler=scheduler)
trainer = pl.Trainer(max_epochs=2, enable_checkpointing=False, gpus=1 if torch.cuda.is_available() else 0)
trainer.fit(dcca, train_loader, val_loader)
U, V = dcca.transform((a, b))

This now raises an error. Could the option to pass a tuple be added back?

Not able to run examples

I installed both of the following:

pip install cca-zoo
pip install "cca-zoo[deep]"

and I am trying to run the example in the link below
https://cca-zoo.readthedocs.io/en/latest/auto_examples/plot_dcca.html#sphx-glr-auto-examples-plot-dcca-py

I am getting an error
(ImportError: cannot import name 'Split_MNIST_Dataset')
But I am able to import Tangled_MNIST_dataset

By executing "dir(cca_zoo)", I get the following
['__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'cca_zoo', 'data', 'deepmodels', 'models', 'utils']

Has Split_MNIST_Dataset been removed ?

Several other functions are also unavailable.
For example, "dir(cca_zoo.deepmodels)" gives:

['DCCA', 'DCCAE', 'DCCA_NOI', 'DTCCA', 'DVCCA', 'DeepWrapper', 'SplitAE', '_DCCA_base', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_dcca_base', 'architectures', 'cca_zoo', 'dcca', 'dcca_noi', 'dccae', 'deepwrapper', 'dtcca', 'dvcca', 'objectives', 'splitae']

So I cannot find "CCALightning", "get_dataloaders" etc.

I am currently using the following versions
PyTorch: Version: 1.10.2
PyTorch-Lightning: Version: 1.5.9
cca-zoo: Version: 1.7.13

Please help with this and please let me know if anything has been changed within the versions.
Thank you in advance for the help.
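When a tutorial imports names that the installed release may not expose, a small standard-library helper can report installed versions and missing attributes in one go. The helper names are hypothetical; the distribution names are the pip names used above.

```python
import importlib
from importlib import metadata


def installed_versions(*distributions):
    """Return {distribution: version, or None if it is not installed}."""
    versions = {}
    for name in distributions:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def missing_names(module_name, names):
    """List which of `names` the installed module does not expose."""
    module = importlib.import_module(module_name)
    return [n for n in names if not hasattr(module, n)]


print(installed_versions("torch", "pytorch-lightning", "cca-zoo"))
# e.g. missing_names("cca_zoo.deepmodels", ["CCALightning", "get_dataloaders"])
```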

[JOSS Review] - CCA_Tutorial notebook produces errors on fresh install

Including this here as part of my JOSS review:

Looks like the first notebook produces errors in the following sections. These can also be seen in the notebook cell outputs when opening it in Google Colab. Building the tutorial notebook as part of your CI might be a helpful way to catch these and other issues in the future:

  • The "Overfitting and Sample Size" section produces an error: rCCA object has no attribute scores
  • The "Ridge Regularised CCA: from CCA to PLS" section produces an error: Expected 2D array, got 1D array instead
  • The "Regularised CCA" under the neuroimaging example also produces a rCCA object has no attribute scores
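"Expected 2D array, got 1D array instead" is scikit-learn's input-validation message; the usual fix is to pass a single feature as a column of shape (n_samples, 1). A small NumPy helper for doing that consistently (the name is hypothetical):

```python
import numpy as np


def as_2d_view(x):
    """Return x as (n_samples, n_features); a 1D array is treated as a single feature."""
    x = np.asarray(x)
    return x.reshape(-1, 1) if x.ndim == 1 else x


y = as_2d_view(np.arange(5.0))  # shape (5, 1) instead of (5,)
```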

DCCA float input [help]

Does the DCCA model not support float values? I get the error below when I try to train the model.

I got the code working before the switch to PyTorch Lightning. My data looks like this:
[image]

Can anyone help me get this working?

RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_22306/563663734.py in <module>
     12 
     13 trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False)
---> 14 trainer.fit(dcca.float(), train_loader, val_loader)
     15 print("Time taken to train:", datetime.now() - then)

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
    735             )
    736             train_dataloaders = train_dataloader
--> 737         self._call_and_handle_interrupt(
    738             self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    739         )

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    680         """
    681         try:
--> 682             return trainer_fn(*args, **kwargs)
    683         # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    684         except KeyboardInterrupt as exception:

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    770         # TODO: ckpt_path only in v1.7
    771         ckpt_path = ckpt_path or self.resume_from_checkpoint
--> 772         self._run(model, ckpt_path=ckpt_path)
    773 
    774         assert self.state.stopped

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
   1193 
   1194         # dispatch `start_training` or `start_evaluating` or `start_predicting`
-> 1195         self._dispatch()
   1196 
   1197         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _dispatch(self)
   1273             self.training_type_plugin.start_predicting(self)
   1274         else:
-> 1275             self.training_type_plugin.start_training(self)
   1276 
   1277     def run_stage(self):

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    200     def start_training(self, trainer: "pl.Trainer") -> None:
    201         # double dispatch to initiate the training loop
--> 202         self._results = trainer.run_stage()
    203 
    204     def start_evaluating(self, trainer: "pl.Trainer") -> None:

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
   1283         if self.predicting:
   1284             return self._run_predict()
-> 1285         return self._run_train()
   1286 
   1287     def _pre_training_routine(self):

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
   1305             self.progress_bar_callback.disable()
   1306 
-> 1307         self._run_sanity_check(self.lightning_module)
   1308 
   1309         # enable train mode

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _run_sanity_check(self, ref_model)
   1369             # run eval step
   1370             with torch.no_grad():
-> 1371                 self._evaluation_loop.run()
   1372 
   1373             self.call_hook("on_sanity_check_end")

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
    143             try:
    144                 self.on_advance_start(*args, **kwargs)
--> 145                 self.advance(*args, **kwargs)
    146                 self.on_advance_end()
    147                 self.restarting = False

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py in advance(self, *args, **kwargs)
    108         dl_max_batches = self._max_batches[dataloader_idx]
    109 
--> 110         dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
    111 
    112         # store batch level output per dataloader

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
    143             try:
    144                 self.on_advance_start(*args, **kwargs)
--> 145                 self.advance(*args, **kwargs)
    146                 self.on_advance_end()
    147                 self.restarting = False

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in advance(self, data_fetcher, dataloader_idx, dl_max_batches, num_dataloaders)
    120         # lightning module methods
    121         with self.trainer.profiler.profile("evaluation_step_and_end"):
--> 122             output = self._evaluation_step(batch, batch_idx, dataloader_idx)
    123             output = self._evaluation_step_end(output)
    124 

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in _evaluation_step(self, batch, batch_idx, dataloader_idx)
    215             self.trainer.lightning_module._current_fx_name = "validation_step"
    216             with self.trainer.profiler.profile("validation_step"):
--> 217                 output = self.trainer.accelerator.validation_step(step_kwargs)
    218 
    219         return output

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in validation_step(self, step_kwargs)
    234         """
    235         with self.precision_plugin.val_step_context():
--> 236             return self.training_type_plugin.validation_step(*step_kwargs.values())
    237 
    238     def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in validation_step(self, *args, **kwargs)
    217 
    218     def validation_step(self, *args, **kwargs):
--> 219         return self.model.validation_step(*args, **kwargs)
    220 
    221     def test_step(self, *args, **kwargs):

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/cca_zoo/deepmodels/trainers.py in validation_step(self, batch, batch_idx)
     55     def validation_step(self, batch, batch_idx):
     56         data, label = batch
---> 57         loss = self.model.loss(*data)
     58         self.log("val loss", loss)
     59         return loss

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/cca_zoo/deepmodels/dcca.py in loss(self, *args)
     53         :return:
     54         """
---> 55         z = self(*args)
     56         return self.objective.loss(*z)
     57 

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/cca_zoo/deepmodels/dcca.py in forward(self, *args)
     43         z = []
     44         for i, encoder in enumerate(self.encoders):
---> 45             z.append(encoder(args[i]))
     46         return z
     47 

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/cca_zoo/deepmodels/architectures.py in forward(self, x)
     66 
     67     def forward(self, x):
---> 68         x = self.layers(x)
     69         if self.variational:
     70             mu = self.fc_mu(x)

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    139     def forward(self, input):
    140         for module in self:
--> 141             input = module(input)
    142         return input
    143 

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    139     def forward(self, input):
    140         for module in self:
--> 141             input = module(input)
    142         return input
    143 

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/torch/nn/modules/linear.py in forward(self, input)
    101 
    102     def forward(self, input: Tensor) -> Tensor:
--> 103         return F.linear(input, self.weight, self.bias)
    104 
    105     def extra_repr(self) -> str:

~/nobackup/miniconda3/envs/matchms/lib/python3.8/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1846     if has_torch_function_variadic(input, weight, bias):
   1847         return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848     return torch._C._nn.linear(input, weight, bias)
   1849 
   1850 

RuntimeError: expected scalar type Double but found Float
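The message says the input is float64 ("Double") while the layer weights are float32 ("Float"). NumPy creates float64 arrays by default, whereas PyTorch layers default to float32, so a common fix is to cast the views to float32 before building the dataset (or, alternatively, to convert the model with `.double()`). A NumPy-only sketch of the cast, with a hypothetical helper name:

```python
import numpy as np


def to_float32(*views):
    """Cast each view to float32 so it matches PyTorch's default parameter dtype."""
    return [np.asarray(v, dtype=np.float32) for v in views]


a = np.random.randn(2000, 50)    # float64 by default
b = np.random.randn(2000, 100)
a32, b32 = to_float32(a, b)      # build the dataset from these instead
```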

extracting latent dimensions from CCA fit + GridsearchCV import error

Hi,

Thank you for this amazing toolbox, I think that is something that was really missing in the Python ecosystem 🙌

I would like to apply different flavours of CCA to resting-state fMRI connectivity and behavioural data, and would like to be able to reproduce the analyses described in this paper:

Wang HT, Poerio G, Murphy C, Bzdok D, Jefferies E, Smallwood J. Dimensions of Experience: Exploring the Heterogeneity of the Wandering Mind. Psychol Sci. 2018 Jan;29(1):56-71. doi: 10.1177/0956797617728727. Epub 2017 Nov 13. PMID: 29131720; PMCID: PMC6346304.

However, I am not sure exactly how the API lets you extract the different views of the CCA fit (i.e. the different orthogonal dimensions). I tried changing the latent_dims parameter, but the size of the CCA weights stays the same (should we get different weights for each view?). Sorry if this is a very simple question; I am just starting with CCA.

Then I tried to fit_transform the data using a code example from the user manual to apply recursively many CCA, but get this error:

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
/tmp/ipykernel_1377153/194783063.py in <module>
      1 from cca_zoo.models import rCCA
----> 2 from cca_zoo.model_selection import GridsearchCV
      3 
      4 def scorer(estimator,X):
      5     dim_corrs=estimator.score(X)

ImportError: cannot import name 'GridsearchCV' from 'cca_zoo.model_selection' (/opt/anaconda3/lib/python3.8/site-packages/cca_zoo/model_selection/__init__.py)

I am happy to help writing documentation, plotting functionalities or testing things if that can help.
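The ImportError arises because the class is spelled `GridSearchCV`, with a capital S, as in the import used in the "All estimators failed to fit" report above. For typos like this, a standard-library sketch can suggest the closest public name a module actually exposes (the helper is hypothetical):

```python
import difflib
import importlib


def suggest_name(module_name, wrong_name):
    """Suggest the closest public name in a module for a failed import."""
    module = importlib.import_module(module_name)
    candidates = [n for n in dir(module) if not n.startswith("_")]
    return difflib.get_close_matches(wrong_name, candidates, n=1)


# e.g. suggest_name("cca_zoo.model_selection", "GridsearchCV")
```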

RuntimeError: expected scalar type Double but found Float

When running the tutorial Fashion-MNIST examples of Deep CCA, the error "RuntimeError: expected scalar type Double but found Float" appears in the training phase. Moreover, when running the DCCAE example, the same error is thrown by the predict_corr function. The last part of the traceback seems to point to a problem inside Torch.

~\AppData\Local\Temp/ipykernel.py in
13
14 dccae_results = np.stack(
---> 15 (dccae_model.train_correlations[0, 1], dccae_model.predict_corr(test_dataset)[0, 1]))

~\anaconda3\envs\CCA\lib\site-packages\cca_zoo\deepmodels\deepwrapper.py in predict_corr(self, test_dataset, train, batch_size)
191 :return: numpy array containing correlations between each pair of views for each dimension (#views*#views*#latent_dimensions)
192 """
--> 193 transformed_views = self.transform(test_dataset, train=train, batch_size=batch_size)
194 all_corrs = []
195 for x, y in itertools.product(transformed_views, repeat=2):

~\anaconda3\envs\CCA\lib\site-packages\cca_zoo\deepmodels\deepwrapper.py in transform(self, test_dataset, test_labels, train, batch_size)
209 for batch_idx, (data, label) in enumerate(test_dataloader):
210 data = [d.to(self.device) for d in list(data)]
--> 211 z = self.model(*data)
212 if batch_idx == 0:
213 z_list = [z_i.detach().cpu().numpy() for i, z_i in enumerate(z)]

~\anaconda3\envs\CCA\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []

................................

~\anaconda3\envs\CCA\lib\site-packages\torch\nn\modules\linear.py in forward(self, input)
94
95 def forward(self, input: Tensor) -> Tensor:
---> 96 return F.linear(input, self.weight, self.bias)
97
98 def extra_repr(self) -> str:

~\anaconda3\envs\CCA\lib\site-packages\torch\nn\functional.py in linear(input, weight, bias)
1845 if has_torch_function_variadic(input, weight):
1846 return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
-> 1847 return torch._C._nn.linear(input, weight, bias)
1848
1849

TCCA Comments

As raised on Twitter, the code around TCCA could do with some more explanation.
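As background for such an explanation: tensor CCA generalises CCA to three or more views by working with the cross-covariance tensor of the centred (and, in TCCA, whitened) views; contracting that tensor with one weight vector per view gives the quantity being maximised. A NumPy sketch of that object for three views, as an illustration rather than cca_zoo's TCCA code (the whitening and the decomposition step are omitted):

```python
import numpy as np


def covariance_tensor(X, Y, Z):
    """Third-order cross-covariance tensor C[i, j, k] = mean_n x_ni * y_nj * z_nk (centred views)."""
    X, Y, Z = (V - V.mean(axis=0) for V in (X, Y, Z))
    return np.einsum("ni,nj,nk->ijk", X, Y, Z) / X.shape[0]


def contract(C, u, v, w):
    """Contract the tensor with one weight vector per view: sum_ijk C_ijk u_i v_j w_k."""
    return np.einsum("ijk,i,j,k->", C, u, v, w)


rng = np.random.default_rng(0)
X, Y, Z = rng.normal(size=(50, 3)), rng.normal(size=(50, 4)), rng.normal(size=(50, 2))
C = covariance_tensor(X, Y, Z)
u, v, w = rng.normal(size=3), rng.normal(size=4), rng.normal(size=2)
```

The contraction equals the mean of the elementwise product of the three views' scores, which is how TCCA extends pairwise correlation to more than two views.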

Implement reordering function to assess feature significance?

In their 2018 paper, Xia et al. implemented a method to match canonical variates from resampled datasets to the original dataset in order to compute p-values for their canonical weights.

Page 12 (Methods Section):

As permutation could induce arbitrary axis rotation, which changes the order of canonical variates, or axis reflection, which causes a sign change for the weights, we matched the canonical variates resulting from permuted data matrices to the ones derived from the original data matrix by comparing the clinical loadings (v) (75. Mišić, B. et al. Network-level structure-function relationships in human neocortex. Cereb. Cortex 26, 3285–96 (2016).)

Here's the code they implemented to achieve this:
https://github.com/cedricx/sCCA/blob/d5a2f4cb071bddd3f7d805e02ff27828b8494c66/sCCA/code/final/cca_functions.R#L191

Would it make sense to implement this method for cca-zoo? I am not even sure whether this is a 'good' method, given this issue. But if I understand correctly, assessing the overall significance of the canonical variates themselves is one thing, and assessing the significance of the feature weights on those variates is another?
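A minimal sketch of the matching step as described in the quoted methods text: reorder the resampled components to best match the original loadings by absolute correlation, then undo sign flips. It is not a port of the linked R code; the exhaustive search over orderings is an assumption that only suits a handful of components, and the function name is hypothetical.

```python
from itertools import permutations

import numpy as np


def match_components(reference, resampled):
    """Reorder and sign-flip the columns of `resampled` to best match `reference`."""
    k = reference.shape[1]
    corr = np.corrcoef(reference.T, resampled.T)[:k, k:]  # k x k cross-correlations
    # Pick the column order maximising total absolute correlation (exhaustive; small k only).
    order = list(max(permutations(range(k)),
                     key=lambda o: sum(abs(corr[i, j]) for i, j in enumerate(o))))
    # Flip any matched column that correlates negatively with its reference.
    signs = np.sign([corr[i, j] for i, j in enumerate(order)])
    return resampled[:, order] * signs


rng = np.random.default_rng(0)
loadings = rng.normal(size=(50, 3))                     # e.g. original loadings (v)
shuffled = loadings[:, [2, 0, 1]] * np.array([-1, 1, -1])  # reordered and reflected copy
aligned = match_components(loadings, shuffled)
```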

AttributeError: 'GCCA' object has no attribute 'tol'

First off, thanks for the package!

When I run the code below:

from cca_zoo import wrappers
import numpy as np

linear_cca = wrappers.GCCA(latent_dims=22, tol=1e-05)

# create data in advance
a = np.random.rand(50, 50)
b = np.random.rand(50, 60)
x = np.random.rand(50, 90)

linear_cca.fit(a,b,x)

it raises the following error at the last line:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\TEMP\python_envs\notebook_env\lib\site-packages\sklearn\base.py", line 260, in __repr__
    repr_ = pp.pformat(self)
  File "C:\TEMP\python_envs\notebook_env\lib\pprint.py", line 144, in pformat
    self._format(object, sio, 0, 0, {}, 0)
  File "C:\TEMP\python_envs\notebook_env\lib\pprint.py", line 161, in _format
    rep = self._repr(object, context, level)
  File "C:\TEMP\python_envs\notebook_env\lib\pprint.py", line 393, in _repr
    self._depth, level)
  File "C:\TEMP\python_envs\notebook_env\lib\site-packages\sklearn\utils\_pprint.py", line 181, in format
    changed_only=self._changed_only)
  File "C:\TEMP\python_envs\notebook_env\lib\site-packages\sklearn\utils\_pprint.py", line 425, in _safe_repr
    params = _changed_params(object)
  File "C:\TEMP\python_envs\notebook_env\lib\site-packages\sklearn\utils\_pprint.py", line 91, in _changed_params
    params = estimator.get_params(deep=False)
  File "C:\TEMP\python_envs\notebook_env\lib\site-packages\sklearn\base.py", line 195, in get_params
    value = getattr(self, key)
AttributeError: 'GCCA' object has no attribute 'tol'

Can you reproduce this error?

Thanks again
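The traceback ends in scikit-learn's `get_params`, which reads every `__init__` argument back as an attribute of the same name (`value = getattr(self, key)` above). The error appears to surface at `fit` only because the interactive prompt echoes the returned estimator, which calls `__repr__`. A minimal sketch of that convention, assuming the wrapper accepted `tol` without storing it on `self`; `BrokenModel` and `FixedModel` are hypothetical classes, not cca-zoo code:

```python
from sklearn.base import BaseEstimator


class BrokenModel(BaseEstimator):
    # Hypothetical estimator: accepts `tol` but never stores it as `self.tol`
    def __init__(self, latent_dims=1, tol=1e-5):
        self.latent_dims = latent_dims


class FixedModel(BaseEstimator):
    # Hypothetical estimator: every __init__ argument is stored unchanged
    def __init__(self, latent_dims=1, tol=1e-5):
        self.latent_dims = latent_dims
        self.tol = tol


def params_or_error(model):
    """Return get_params(), or the AttributeError message it raises."""
    try:
        return model.get_params()
    except AttributeError as err:
        return str(err)


print(params_or_error(BrokenModel(latent_dims=22, tol=1e-05)))
print(params_or_error(FixedModel(latent_dims=22, tol=1e-05)))
```

The same rule applies to any other constructor argument that is renamed, transformed, or dropped inside `__init__`.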

Using scikit BaseEstimator effectively.

Ideally I'd like to adjust the API slightly so that it works effectively with the BaseEstimator class. At the moment, parameters have to be passed with fit(*views, params=dict) or gridsearch_fit(*views, param_candidates=dict). A better solution would be to follow the scikit-learn API and have model(params=X).fit(*views). However, BaseEstimator requires all arguments to be specified as keyword arguments, and so far I have been using lists of parameters (one for each view). That would in turn prevent me from using all the grid search functionality down the line. If I can come up with a solution, it would be really valuable.
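A minimal sketch of how per-view lists can coexist with `BaseEstimator`, assuming one list-valued keyword argument per hyperparameter: the list is stored unchanged in `__init__` and only validated in `fit`, so `get_params`, `set_params` and `clone` keep working. `PerViewModel` and its `c` parameter are hypothetical. Note that in scikit-learn's own `ParameterGrid`, each candidate is a whole list (one value per view).

```python
from sklearn.base import BaseEstimator, clone
from sklearn.model_selection import ParameterGrid


class PerViewModel(BaseEstimator):
    """Hypothetical estimator taking one regularisation value per view."""

    def __init__(self, latent_dims=1, c=None):
        # Store arguments unchanged, as BaseEstimator expects
        self.latent_dims = latent_dims
        self.c = c

    def fit(self, views, y=None):
        # Resolve defaults and validate at fit time, not in __init__
        c = self.c if self.c is not None else [0.0] * len(views)
        if len(c) != len(views):
            raise ValueError("c must contain one value per view")
        self.c_ = list(c)
        return self


base = PerViewModel(latent_dims=2)
# Each grid candidate is a complete list: one value per view
grid = ParameterGrid({"c": [[0.1, 0.1], [0.1, 0.9], [0.9, 0.9]]})
# Mirrors what a grid search does: clone the base model, then set parameters
candidates = [clone(base).set_params(**params) for params in grid]
print([model.get_params()["c"] for model in candidates])
```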

Support for different dimension vector for DCCA

Hi,

In the current version, DCCA only works if the two vectors have the same dimensionality. However, this is not a necessary assumption, as DCCA also works with vectors of different sizes. Would it be possible to add support for vectors of different dimensions in DCCA?
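In DCCA each view has its own encoder, and the correlation objective only sees the encoder outputs, so the raw inputs do not need to share a dimensionality. The NumPy sketch below illustrates this with the DCCA objective (the sum of canonical correlations between the two encoder outputs). The linear least-squares maps are stand-ins for trained neural encoders, chosen only to keep the example self-contained, and `dcca_correlation` is a hypothetical helper, not the cca-zoo loss.

```python
import numpy as np


def dcca_correlation(z1, z2, eps=1e-6):
    """Sum of canonical correlations between two encoded views.

    z1 and z2 need the same number of rows (samples); the raw inputs that
    produced them can have any number of features.
    """
    n = z1.shape[0]
    z1 = z1 - z1.mean(axis=0)
    z2 = z2 - z2.mean(axis=0)
    # Regularised covariance and cross-covariance matrices
    s11 = z1.T @ z1 / (n - 1) + eps * np.eye(z1.shape[1])
    s22 = z2.T @ z2 / (n - 1) + eps * np.eye(z2.shape[1])
    s12 = z1.T @ z2 / (n - 1)

    def inv_sqrt(s):
        # Inverse square root of a symmetric positive definite matrix
        vals, vecs = np.linalg.eigh(s)
        return vecs @ np.diag(vals ** -0.5) @ vecs.T

    t = inv_sqrt(s11) @ s12 @ inv_sqrt(s22)
    # Singular values of T are the canonical correlations
    return np.linalg.svd(t, compute_uv=False).sum()


rng = np.random.default_rng(0)
latent = rng.standard_normal((500, 2))
# Two views with different input dimensionality: 30 and 50 features
x1 = latent @ rng.standard_normal((2, 30)) + 0.1 * rng.standard_normal((500, 30))
x2 = latent @ rng.standard_normal((2, 50)) + 0.1 * rng.standard_normal((500, 50))
# Stand-in "encoders": each maps its own input width to 2 latent dimensions
w1 = np.linalg.lstsq(x1, latent, rcond=None)[0]  # shape (30, 2)
w2 = np.linalg.lstsq(x2, latent, rcond=None)[0]  # shape (50, 2)
objective = dcca_correlation(x1 @ w1, x2 @ w2)
```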

AttributeError: 'DeepWrapper' object has no attribute 'lr'

I get this error twice (see output below) after my deep CCA model has finished training. It seems to be related to printing the model, and I can still use the trained model, so hopefully it doesn't affect much.

My data are NumPy arrays combined into tensor datasets.

Output

====> Epoch: 49 Average val loss: -5.5405
Min loss -5.54
====> Epoch: 50 Average train loss: -5.9762
====> Epoch: 50 Average val loss: -5.5670
Min loss -5.57
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/site-packages/IPython/core/formatters.py in __call__(self, obj, include, exclude)
    968 
    969             if method is not None:
--> 970                 return method(include=include, exclude=exclude)
    971             return None
    972         else:

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/site-packages/sklearn/base.py in _repr_mimebundle_(self, **kwargs)
    462     def _repr_mimebundle_(self, **kwargs):
    463         """Mime bundle used by jupyter kernels to display estimator"""
--> 464         output = {"text/plain": repr(self)}
    465         if get_config()["display"] == 'diagram':
    466             output["text/html"] = estimator_html_repr(self)

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/site-packages/sklearn/base.py in __repr__(self, N_CHAR_MAX)
    258             n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW)
    259 
--> 260         repr_ = pp.pformat(self)
    261 
    262         # Use bruteforce ellipsis when there are a lot of non-blank characters

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/pprint.py in pformat(self, object)
    151     def pformat(self, object):
    152         sio = _StringIO()
--> 153         self._format(object, sio, 0, 0, {}, 0)
    154         return sio.getvalue()
    155 

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/pprint.py in _format(self, object, stream, indent, allowance, context, level)
    168             self._readable = False
    169             return
--> 170         rep = self._repr(object, context, level)
    171         max_width = self._width - indent - allowance
    172         if len(rep) > max_width:

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/pprint.py in _repr(self, object, context, level)
    402 
    403     def _repr(self, object, context, level):
--> 404         repr, readable, recursive = self.format(object, context.copy(),
    405                                                 self._depth, level)
    406         if not readable:

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/site-packages/sklearn/utils/_pprint.py in format(self, object, context, maxlevels, level)
    178 
    179     def format(self, object, context, maxlevels, level):
--> 180         return _safe_repr(object, context, maxlevels, level,
    181                           changed_only=self._changed_only)
    182 

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/site-packages/sklearn/utils/_pprint.py in _safe_repr(object, context, maxlevels, level, changed_only)
    423         recursive = False
    424         if changed_only:
--> 425             params = _changed_params(object)
    426         else:
    427             params = object.get_params(deep=False)

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/site-packages/sklearn/utils/_pprint.py in _changed_params(estimator)
     89     estimator with non-default values."""
     90 
---> 91     params = estimator.get_params(deep=False)
     92     init_func = getattr(estimator.__init__, 'deprecated_original',
     93                         estimator.__init__)

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/site-packages/sklearn/base.py in get_params(self, deep)
    193         out = dict()
    194         for key in self._get_param_names():
--> 195             value = getattr(self, key)
    196             if deep and hasattr(value, 'get_params'):
    197                 deep_items = value.get_params().items()

AttributeError: 'DeepWrapper' object has no attribute 'lr'

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/site-packages/IPython/core/formatters.py in __call__(self, obj)
    700                 type_pprinters=self.type_printers,
    701                 deferred_pprinters=self.deferred_printers)
--> 702             printer.pretty(obj)
    703             printer.flush()
    704             return stream.getvalue()

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/site-packages/IPython/lib/pretty.py in pretty(self, obj)
    392                         if cls is not object \
    393                                 and callable(cls.__dict__.get('__repr__')):
--> 394                             return _repr_pprint(obj, self, cycle)
    395 
    396             return _default_pprint(obj, self, cycle)

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/site-packages/IPython/lib/pretty.py in _repr_pprint(obj, p, cycle)
    698     """A pprint that just redirects to the normal repr function."""
    699     # Find newlines and replace them with p.break_()
--> 700     output = repr(obj)
    701     lines = output.splitlines()
    702     with p.group():

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/site-packages/sklearn/base.py in __repr__(self, N_CHAR_MAX)
    258             n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW)
    259 
--> 260         repr_ = pp.pformat(self)
    261 
    262         # Use bruteforce ellipsis when there are a lot of non-blank characters

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/pprint.py in pformat(self, object)
    151     def pformat(self, object):
    152         sio = _StringIO()
--> 153         self._format(object, sio, 0, 0, {}, 0)
    154         return sio.getvalue()
    155 

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/pprint.py in _format(self, object, stream, indent, allowance, context, level)
    168             self._readable = False
    169             return
--> 170         rep = self._repr(object, context, level)
    171         max_width = self._width - indent - allowance
    172         if len(rep) > max_width:

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/pprint.py in _repr(self, object, context, level)
    402 
    403     def _repr(self, object, context, level):
--> 404         repr, readable, recursive = self.format(object, context.copy(),
    405                                                 self._depth, level)
    406         if not readable:

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/site-packages/sklearn/utils/_pprint.py in format(self, object, context, maxlevels, level)
    178 
    179     def format(self, object, context, maxlevels, level):
--> 180         return _safe_repr(object, context, maxlevels, level,
    181                           changed_only=self._changed_only)
    182 

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/site-packages/sklearn/utils/_pprint.py in _safe_repr(object, context, maxlevels, level, changed_only)
    423         recursive = False
    424         if changed_only:
--> 425             params = _changed_params(object)
    426         else:
    427             params = object.get_params(deep=False)

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/site-packages/sklearn/utils/_pprint.py in _changed_params(estimator)
     89     estimator with non-default values."""
     90 
---> 91     params = estimator.get_params(deep=False)
     92     init_func = getattr(estimator.__init__, 'deprecated_original',
     93                         estimator.__init__)

/mnt/scratch/unen004/miniconda3/envs/matchms/lib/python3.8/site-packages/sklearn/base.py in get_params(self, deep)
    193         out = dict()
    194         for key in self._get_param_names():
--> 195             value = getattr(self, key)
    196             if deep and hasattr(value, 'get_params'):
    197                 deep_items = value.get_params().items()

AttributeError: 'DeepWrapper' object has no attribute 'lr'

latent_dims, max_iter meaning

I received a question by email about the meaning of these parameters, so this needs to be made clearer in the documentation and docstrings.

latent_dims refers to the number of orthogonal latent dimensions to find (sometimes called 'effects' in the literature). max_iter refers to the maximum number of iterations for iterative (as opposed to eigenvalue-based) model solutions. For unregularised CCA and PLS, this refers to the NIPALS algorithm.
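To make the two parameters concrete, here is an illustrative NumPy sketch of PLS fitted by alternating NIPALS-style updates: `latent_dims` sets how many dimensions are extracted (one after another, with deflation in between), and `max_iter` caps the inner loop for each dimension. The function name, the convergence test on `tol`, and the deflation scheme are assumptions for illustration, not cca-zoo's implementation.

```python
import numpy as np


def pls_power_method(x, y, latent_dims=2, max_iter=100, tol=1e-8):
    """Illustrative PLS via alternating (NIPALS-style) updates with deflation.

    latent_dims: number of orthogonal latent dimensions (effects) to extract.
    max_iter: cap on the inner iterations used for each dimension (>= 1).
    """
    x = x - x.mean(axis=0)
    y = y - y.mean(axis=0)
    weights_x, weights_y, n_iters = [], [], []
    for _ in range(latent_dims):
        v = np.ones(y.shape[1]) / np.sqrt(y.shape[1])
        for it in range(1, max_iter + 1):
            # Alternate between the two views' weight vectors
            u = x.T @ (y @ v)
            u /= np.linalg.norm(u)
            v_new = y.T @ (x @ u)
            v_new /= np.linalg.norm(v_new)
            converged = np.linalg.norm(v_new - v) < tol
            v = v_new
            if converged:
                break
        weights_x.append(u)
        weights_y.append(v)
        n_iters.append(it)
        # Deflate so later scores are orthogonal to the ones already found
        tx, ty = x @ u, y @ v
        x = x - np.outer(tx, tx @ x) / (tx @ tx)
        y = y - np.outer(ty, ty @ y) / (ty @ ty)
    return np.column_stack(weights_x), np.column_stack(weights_y), n_iters
```

Each entry of the returned `n_iters` is at most `max_iter`; hitting the cap simply means that dimension stopped before meeting `tol`.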

Update tests

Reduce the dimensionality of the tests substantially. Add hard checks, such as correlation thresholds, where they are missing.

TypeError: cannot pickle 'generator' object

I am using cca-zoo's GridSearchCV to run a grid search on my training set. Because the whole process takes a long time, I would like to cache the best estimator found by the grid search, and I was thinking about using joblib for that. However, when I try to pickle it, I get the following error:


  File "C:\Users\Johannes.Wiesner\Documents\projects\hcp_project\testing\test_pickle_cca.py", line 52, in <module>
    dump(best_estimator,"./cache/gridsearch_estimator.pkl")

  File "C:\Users\Johannes.Wiesner\Miniconda3\envs\csp_wiesner_johannes\lib\site-packages\joblib\numpy_pickle.py", line 482, in dump
    NumpyPickler(f, protocol=protocol).dump(value)

  File "C:\Users\Johannes.Wiesner\Miniconda3\envs\csp_wiesner_johannes\lib\pickle.py", line 487, in dump
    self.save(obj)

  File "C:\Users\Johannes.Wiesner\Miniconda3\envs\csp_wiesner_johannes\lib\site-packages\joblib\numpy_pickle.py", line 284, in save
    return Pickler.save(self, obj)

  File "C:\Users\Johannes.Wiesner\Miniconda3\envs\csp_wiesner_johannes\lib\pickle.py", line 603, in save
    self.save_reduce(obj=obj, *rv)

  File "C:\Users\Johannes.Wiesner\Miniconda3\envs\csp_wiesner_johannes\lib\pickle.py", line 717, in save_reduce
    save(state)

  File "C:\Users\Johannes.Wiesner\Miniconda3\envs\csp_wiesner_johannes\lib\site-packages\joblib\numpy_pickle.py", line 284, in save
    return Pickler.save(self, obj)

  File "C:\Users\Johannes.Wiesner\Miniconda3\envs\csp_wiesner_johannes\lib\pickle.py", line 560, in save
    f(self, obj)  # Call unbound method with explicit self

  File "C:\Users\Johannes.Wiesner\Miniconda3\envs\csp_wiesner_johannes\lib\pickle.py", line 971, in save_dict
    self._batch_setitems(obj.items())

  File "C:\Users\Johannes.Wiesner\Miniconda3\envs\csp_wiesner_johannes\lib\pickle.py", line 997, in _batch_setitems
    save(v)

  File "C:\Users\Johannes.Wiesner\Miniconda3\envs\csp_wiesner_johannes\lib\site-packages\joblib\numpy_pickle.py", line 284, in save
    return Pickler.save(self, obj)

  File "C:\Users\Johannes.Wiesner\Miniconda3\envs\csp_wiesner_johannes\lib\pickle.py", line 578, in save
    rv = reduce(self.proto)

TypeError: cannot pickle 'generator' object

Here's a script to reproduce the error:

# -*- coding: utf-8 -*-
"""
Created on Mon Jan 31 13:53:30 2022

@author: Johannes.Wiesner
"""

import os
import numpy as np
from cca_zoo.models import PMD
from cca_zoo.model_selection import GridSearchCV
from cca_zoo.data import generate_covariance_data
from joblib import dump,load
(train_view_1,train_view_2),_=generate_covariance_data(n=200,view_features=[10,10],latent_dims=1,correlation=1)

rng = np.random.RandomState(42)

## Set parameters #############################################################

deflation='pls'
max_iter= 100
latent_dims = 3
tol = 1e-15
param_grid = {'c':[[0.1,0.9],[0.1,0.9]]}

###############################################################################
## Run CCA analysis with cca-zoo package ######################################
###############################################################################

estimator = PMD(latent_dims=latent_dims,
                random_state=rng,
                deflation=deflation,
                max_iter=max_iter,
                tol=tol)

def scorer(estimator,X):
    dim_corrs=estimator.score(X)
    return dim_corrs.mean()


if not os.path.isdir("./cache/"):
    os.makedirs("./cache")
    grid = GridSearchCV(estimator,
                        param_grid=param_grid,
                        cv=None,
                        verbose=True,
                        scoring=scorer)

    grid.fit([train_view_1,train_view_2])
    best_estimator = grid.best_estimator_
    dump(best_estimator,"./cache/gridsearch_estimator.pkl")
else:
    best_estimator = load("./cache/gridsearch_estimator.pkl")
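The traceback shows pickle walking the estimator's attribute dictionary (`save_dict`) and failing on a generator. Which attribute holds the generator is not visible here; the sketch below only assumes that some stored attribute is a generator, and shows why that breaks `pickle` (and therefore `joblib.dump`, which builds on it, as the `NumpyPickler` frames show). Plain dictionaries stand in for an estimator's `__dict__`; `roundtrip` is a hypothetical helper.

```python
import pickle


def roundtrip(obj):
    """Pickle and unpickle `obj`, returning the TypeError if pickling fails."""
    try:
        return pickle.loads(pickle.dumps(obj))
    except TypeError as err:
        return err


values = [1, 2, 3]
# Stand-in for an estimator's __dict__ holding a generator attribute
state_with_generator = {"latent_dims": 3, "scores": (v * 2 for v in values)}
# Same state, with the generator materialised into a list before saving
state_with_list = {"latent_dims": 3, "scores": [v * 2 for v in values]}

print(roundtrip(state_with_generator))
print(roundtrip(state_with_list))
```

Materialising generators into lists (or tuples) before storing them as attributes avoids the problem, at the cost of holding all values in memory.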

Adapting to numpy convention for random state setting

Currently, to get consistent numerical results, one needs to set a global random seed, as there is no random state parameter in the CCA classes.
Adding a random state parameter would greatly improve the consistency of the numerical output.

It would also be worth making sure the implementation is up to date with the latest best practice recommended by NumPy.
The current example will not be supported after NumPy 1.16; see: https://numpy.org/doc/stable/reference/random/legacy.html
More discussion about the new best practice:
https://numpy.org/neps/nep-0019-rng-policy.html
https://albertcthomas.github.io/good-practices-random-number-generators/
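A minimal sketch of the suggested convention, assuming a per-instance `random_state` argument handled with `np.random.default_rng`, which accepts `None`, an integer seed, or an existing `Generator` (returned unaltered), in line with NEP 19. `SeededModel` and `initial_weights_` are hypothetical names.

```python
import numpy as np


class SeededModel:
    """Hypothetical model showing a per-instance random_state parameter."""

    def __init__(self, latent_dims=1, random_state=None):
        # Store the argument unchanged (scikit-learn convention)
        self.latent_dims = latent_dims
        self.random_state = random_state

    def fit(self, n_features):
        # default_rng accepts None, an int seed, or an existing Generator
        rng = np.random.default_rng(self.random_state)
        # Random initialisation, e.g. of weights for an iterative solver
        self.initial_weights_ = rng.standard_normal((n_features, self.latent_dims))
        return self


weights_a = SeededModel(latent_dims=2, random_state=0).fit(5).initial_weights_
weights_b = SeededModel(latent_dims=2, random_state=0).fit(5).initial_weights_
```

Passing an integer reproduces the same initialisation on every `fit` without touching global state; passing a `Generator` shares its state, so repeated fits draw different numbers.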
