Comments (6)

canergen commented on June 30, 2024

Hi. Can you please verify the installed version? Since scvi-tools 1.1.0, scvi-tools no longer has an explicit dependency on chex. It seems that flax is causing this ImportError. Can you run from flax.training import train_state without errors? If flax can't be imported correctly, the easiest solution would be to set up a new conda environment. Otherwise, fixing these dependency issues can unfortunately be a lengthy process.
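A minimal sketch of the two checks requested above, assuming Python 3.8+ (for importlib.metadata). It prints the installed versions of scvi-tools and the JAX-stack packages mentioned in this thread, then attempts the flax.training.train_state import and reports any exception raised along the way. The helper names installed_version and try_import are hypothetical and not part of scvi-tools.

```python
import importlib
import importlib.metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def try_import(module_name: str) -> Optional[str]:
    """Import a module; return None on success or 'ExcType: message' on failure."""
    try:
        importlib.import_module(module_name)
    except Exception as exc:  # ImportError, AttributeError, etc. raised during import
        return f"{type(exc).__name__}: {exc}"
    return None


if __name__ == "__main__":
    # Distributions involved in the flax -> optax -> chex -> jax import chain.
    for dist in ("scvi-tools", "flax", "jax", "jaxlib", "chex", "optax"):
        print(f"{dist:12s} {installed_version(dist) or 'not installed'}")
    error = try_import("flax.training.train_state")
    print("flax.training.train_state:", error or "imports cleanly")
```

Catching Exception rather than only ImportError is deliberate: as the traceback in this thread shows, a broken dependency can surface as an AttributeError during import.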

from scvi-tools.

gouinK commented on June 30, 2024

Thank you for the quick response!

I have confirmed the version:

scvi-tools                1.1.2                    pypi_0    pypi

Running from flax.training import train_state does indeed throw the same error; see the traceback below:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[1], line 1
----> 1 from flax.training import train_state

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/flax/training/train_state.py:17
      1 # Copyright 2024 The Flax Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     15 from typing import Any, Callable
---> 17 import optax
     19 from flax import core, struct
     20 from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/optax/__init__.py:17
      1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import contrib
     18 from optax import losses
     19 from optax import monte_carlo

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/optax/contrib/__init__.py:17
      1 # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Contributed optimizers in Optax."""
---> 17 from optax.contrib.cocob import cocob
     18 from optax.contrib.cocob import COCOBState
     19 from optax.contrib.complex_valued import split_real_and_imaginary

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/optax/contrib/cocob.py:25
     23 import jax.numpy as jnp
     24 import jax.tree_util as jtu
---> 25 from optax._src import base
     28 class COCOBState(NamedTuple):
     29   """State for COntinuous COin Betting."""

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/optax/_src/base.py:19
     15 """Base interfaces and datatypes."""
     17 from typing import Any, Callable, NamedTuple, Optional, Protocol, runtime_checkable, Sequence, Union
---> 19 import chex
     20 import jax
     21 import jax.numpy as jnp

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/chex/__init__.py:17
      1 # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Chex: Testing made fun, in JAX!"""
---> 17 from chex._src.asserts import assert_axis_dimension
     18 from chex._src.asserts import assert_axis_dimension_comparator
     19 from chex._src.asserts import assert_axis_dimension_gt

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/chex/_src/asserts.py:26
     23 import unittest
     24 from unittest import mock
---> 26 from chex._src import asserts_internal as _ai
     27 from chex._src import pytypes
     28 import jax

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/chex/_src/asserts_internal.py:34
     31 from typing import Any, Sequence, Union, Callable, List, Optional, Set, Tuple, Type
     33 from absl import logging
---> 34 from chex._src import pytypes
     35 import jax
     36 from jax.experimental import checkify

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/chex/_src/pytypes.py:54
     52 Numeric = Union[Array, Scalar]
     53 Shape = jax.core.Shape
---> 54 PRNGKey = jax.random.KeyArray
     55 PyTreeDef = jax.tree_util.PyTreeDef
     56 Device = jax.Device

File ~/miniconda3/envs/test_env/lib/python3.10/site-packages/jax/_src/deprecations.py:54, in deprecation_getattr.<locals>.getattr(name)
     52   warnings.warn(message, DeprecationWarning, stacklevel=2)
     53   return fn
---> 54 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax.random' has no attribute 'KeyArray'
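The last frames show where the import actually breaks: chex/_src/pytypes.py (line 54) evaluates jax.random.KeyArray at import time, and the installed JAX raises AttributeError for that name. A minimal probe, assuming only that JAX may or may not be installed (the function name has_jax_random_keyarray is hypothetical), checks whether the running JAX still exposes the attribute:

```python
import importlib
from typing import Optional


def has_jax_random_keyarray() -> Optional[bool]:
    """Report whether jax.random.KeyArray resolves.

    Returns True/False, or None if JAX (or jaxlib) cannot be imported at all.
    """
    try:
        jax_random = importlib.import_module("jax.random")
    except ImportError:
        return None
    # When the name has been removed, JAX's module-level deprecation __getattr__
    # raises AttributeError (the last frame of the traceback); hasattr() turns
    # that into False instead of propagating it.
    return hasattr(jax_random, "KeyArray")


if __name__ == "__main__":
    print("jax.random.KeyArray available:", has_jax_random_keyarray())
```

A False result means any chex release that still references jax.random.KeyArray at import time will fail exactly as above, whichever flax version sits on top of it.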

Looking at the traceback more closely, the import chain seems to be:
flax --> optax --> chex --> jax
Is flax a dependency of scvi-tools and if so is there a suggested version of flax to use?

Thanks!

canergen commented on June 30, 2024

Flax is used within scvi-tools, but there is no specific version requirement for it. However, the JAX and Flax versions installed in your environment are mismatched (Flax is older than JAX), and this is causing the issue. If you install JAX from scratch in a new environment using PyPI, the error shouldn't occur. You can also try uninstalling Flax and JAX in the current environment, reinstalling JAX (which will install a compatible version of Flax), and hoping that fixes it. In my experience, it's easier to set up a new environment.

gouinK commented on June 30, 2024

Thanks, I will give that a try!

gouinK commented on June 30, 2024

Looking into the versions, this is what I have. Both flax and jax appear to be the latest versions shown on their respective GitHub pages, so I'm not sure that is the issue here.

flax                      0.8.2 
jax                       0.4.26
jaxlib                    0.4.26
chex                      0.1.7
optax                     0.2.1

I went ahead and uninstalled flax, jax, and jaxlib, then ran:
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
which resulted in this:

scvi-tools 1.1.2 requires flax, which is not installed.
Successfully installed jax-0.4.26 jaxlib-0.4.26+cuda12.cudnn89 nvidia-cuda-nvcc-cu12-12.4.131

canergen commented on June 30, 2024

Sorry, you also need to install Flax. Could you please try installing JAX and Flax in a new environment and check that it works? It's very difficult to fix a conda environment with wrong dependencies. What we can support is creating a new conda environment and following the installation instructions: https://docs.scvi-tools.org/en/stable/installation.html.
