
mdsunivie / deeperwin

Stars: 44, Watchers: 3, Forks: 6, Size: 8.39 MB

DeepErwin is a Python 3.8+ package that implements and optimizes JAX wave function models for numerical solutions to the multi-electron Schrödinger equation. DeepErwin supports weight-sharing when optimizing wave functions for multiple nuclear geometries and the use of pre-trained neural network weights to accelerate optimization.

License: Other

Python 99.95% Shell 0.05%
deep-learning deep-neural-networks quantum-monte-carlo variational-monte-carlo schrodinger-equation physical-chem weight-sharing transfer-learning


deeperwin's Issues

Error during `pip install -e .`

Collecting orbax (from flax->deeperwin==1.2.0)
Using cached orbax-0.1.9-py3-none-any.whl
Using cached orbax-0.1.8.tar.gz (1.6 kB)
Preparing metadata (setup.py) ... error
error: subprocess-exited-with-error

× python setup.py egg_info did not run successfully.
│ exit code: 1
╰─> [3 lines of output]

  *** Orbax is a namespace, and not a standalone package. For model checkpointing and exporting utilities, please install `orbax-checkpoint` and `orbax-export` respectively (instead of `orbax`). ***

  [end of output]

note: This error originates from a subprocess, and is likely not a problem with pip.
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.

TypeError: 'NoneType' object is not iterable

When running `!deeperwin -f lih.yml` in Google Colab, the following error occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/deeperwin/process_molecules_shared.py", line 227, in <module>
    log_psi_squared, mcmc, wfs, init_shared_params, optimize_epoch, opt_state, opt_get_params, opt_set_params = init_wfs(
  File "/usr/local/lib/python3.9/dist-packages/deeperwin/process_molecules_shared.py", line 61, in init_wfs
    new_shared_params, wf.unique_trainable_params = split_trainable_params(new_trainable_params,
  File "/usr/local/lib/python3.9/dist-packages/deeperwin/utils.py", line 363, in split_trainable_params
    for module in shared_modules:
TypeError: 'NoneType' object is not iterable

The file lih.yml is identical to the one in the documentation:

physical:
    name: LiH
    changes:
      - R: [[0,0,0],[3.0,0,0]]
        comment: "Equilibrium bond length"
      - R: [[0,0,0],[2.8,0,0]]
        comment: "Compressed molecule"
      - R: [[0,0,0],[3.2,0,0]]
        comment: "Stretched molecule"
optimization:
    shared_optimization:
        use: True

Python version: 3.9
jax==0.3.1
jaxlib==0.1.74+cuda11.cudnn805
deeperwin installed via !pip install git+https://github.com/mdsunivie/deeperwin.git


Can you give some hints about how to deal with this?
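The traceback shows `split_trainable_params` looping over `shared_modules` while it is `None`. One plausible reading is that `shared_optimization: use: True` is set without any list of modules to share, so that value is never filled in. The following is a hypothetical, simplified stand-in for the function (not the deeperwin implementation), assuming parameters are a flat dict keyed by module name; it only shows how treating `None` as "nothing shared" avoids the crash.

```python
from typing import Any, Dict, Iterable, Optional, Tuple


def split_trainable_params(
    params: Dict[str, Any],
    shared_modules: Optional[Iterable[str]],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical sketch: split {module_name: params} into (shared, unique).

    shared_modules=None is treated as "no module is shared" instead of
    failing with TypeError: 'NoneType' object is not iterable.
    """
    shared_names = set(shared_modules or ())  # None -> empty set
    shared = {name: p for name, p in params.items() if name in shared_names}
    unique = {name: p for name, p in params.items() if name not in shared_names}
    return shared, unique


# Hypothetical module names, for illustration only.
params = {"embedding": [1.0], "jastrow": [2.0], "orbitals": [3.0]}
print(split_trainable_params(params, None))
print(split_trainable_params(params, ["embedding"]))
```

In the real code the cleaner fix is probably on the configuration side (supplying the modules to share), but the sketch shows exactly where a missing value turns into this `TypeError`.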

Jax version issue: "AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'"

Hello all,

I'm trying to run our code and encounter this error after all the packages are installed:
AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'
Searching online suggests this might be related to the jax version. The setup file pins a very specific jax version, but that version does not seem to have this attribute. I'm on jax==0.3.23, the same version as in the setup file.
Thank you so much for your attention!

Best,
Sherry Cheng
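One way to narrow this down is to check whether the installed `jax.tree_util` exposes the attribute at all; the error only names what the failing code expects, not which package is asking for it. The helper below is a hypothetical sketch that relies only on `getattr`, so it can be pointed at `jax.tree_util` or at any stand-in object.

```python
from typing import Iterable, List

# Attribute reported missing in the AttributeError above.
REQUIRED_TREE_UTIL_APIS = ("register_pytree_with_keys_class",)


def missing_tree_util_apis(tree_util, names: Iterable[str] = REQUIRED_TREE_UTIL_APIS) -> List[str]:
    """Return the names from `names` that `tree_util` does not provide as callables."""
    return [name for name in names if not callable(getattr(tree_util, name, None))]


# Usage, with the jax version from this issue installed:
#   import jax
#   print(jax.__version__, missing_tree_util_apis(jax.tree_util))
```

A non-empty result means some installed package expects a newer jax than the pinned one; a later issue on this page traces it to the most recent e3nn-jax.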

DNN library initialization failed

I use conda to manage the packages. These are the installation steps:

  1. Create a conda environment
    conda create -n deeperwin python=3.8
  2. Activate the environment
    conda activate deeperwin
    The prompt changes to (deeperwin) yzhou@Moss:~/base/, so the environment is active.
  3. Clone the source code
    git clone https://github.com/mdsunivie/deeperwin.git
  4. Install deeperwin
    cd deeperwin
    pip install -e .
  5. After installation, I checked whether jax is using the GPU
    `python
    Python 3.8.16 (default, Mar 2 2023, 03:21:46)
    [GCC 11.2.0] :: Anaconda, Inc. on linux
    Type "help", "copyright", "credits" or "license" for more information.
    >>> import jax
    >>> jax.local_device_count()
    No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    1`
    No GPU is found, although the machine has an A100 GPU.

  6. Upgrade jax to the GPU version and check again
    `$ pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    python
    Python 3.8.16 (default, Mar 2 2023, 03:21:46)
    [GCC 11.2.0] :: Anaconda, Inc. on linux
    Type "help", "copyright", "credits" or "license" for more information.
    >>> import jax
    >>> jax.local_device_count()
    1
    >>> print(jax.devices()[0])
    gpu:0`
    So the GPU is ready.

  7. Run sample_configs/config_basic.yml
    Something went wrong; the output is:
    $ deeperwin run sample_configs/config_basic.yml
    Traceback (most recent call last):
      File "/home/yzhou/base/anaconda3/envs/deeperwin/bin/deeperwin", line 33, in <module>
        sys.exit(load_entry_point('deeperwin', 'console_scripts', 'deeperwin')())
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/cli.py", line 52, in main
        process_molecule(args.config_file)
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/process_molecule.py", line 14, in process_molecule
        config: Configuration = Configuration.parse_obj(raw_config)
      File "pydantic/main.py", line 526, in pydantic.main.BaseModel.parse_obj
      File "pydantic/main.py", line 339, in pydantic.main.BaseModel.__init__
      File "pydantic/main.py", line 1102, in pydantic.main.validate_model
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/configuration.py", line 1164, in no_reuse_while_shared
        if (values['optimization'].shared_optimization is not None) and (values['reuse'] is not None):
    KeyError: 'optimization'
  8. Run sample_configs/config_minimal.yml
    This fails with the following output:
    $ deeperwin run sample_configs/config_minimal.yml
    2023-05-22 12:48:56.136608: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:407] There was an error before creating cudnn handle (302): cudaGetErrorName symbol not found. : cudaGetErrorString symbol not found.
    Traceback (most recent call last):
      File "/home/yzhou/base/anaconda3/envs/deeperwin/bin/deeperwin", line 33, in <module>
        sys.exit(load_entry_point('deeperwin', 'console_scripts', 'deeperwin')())
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/cli.py", line 52, in main
        process_molecule(args.config_file)
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/process_molecule.py", line 36, in process_molecule
        from deeperwin.optimization import optimize_wavefunction, evaluate_wavefunction, pretrain_orbitals
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/optimization.py", line 9, in <module>
        from deeperwin.evaluation import evaluate_wavefunction
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/evaluation.py", line 8, in <module>
        from deeperwin.loggers import DataLogger, WavefunctionLogger
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/loggers.py", line 18, in <module>
        from deeperwin.checkpoints import save_run, RunData
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/checkpoints.py", line 6, in <module>
        from deeperwin.mcmc import MCMCState
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/mcmc.py", line 17, in <module>
        class MCMCState:
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/mcmc.py", line 27, in MCMCState
        stepsize: jnp.array = jnp.array(1e-2)
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2022, in array
        out_array: Array = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 549, in _convert_element_type
        return convert_element_type_p.bind(operand, new_dtype=new_dtype,
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/core.py", line 380, in bind
        return self.bind_with_trace(find_top_trace(args), args, params)
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
        out = trace.process_primitive(self, map(trace.full_raise, args), params)
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/core.py", line 790, in process_primitive
        return primitive.impl(*tracers, **params)
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/dispatch.py", line 131, in apply_primitive
        compiled_fun = xla_primitive_callable(
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/util.py", line 284, in wrapper
        return cached(config._trace_context(), *args, **kwargs)
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/util.py", line 277, in cached
        return f(*args, **kwargs)
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/dispatch.py", line 222, in xla_primitive_callable
        compiled = _xla_callable_uncached(
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/dispatch.py", line 252, in _xla_callable_uncached
        return computation.compile().unsafe_call
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
        executable = UnloadedMeshExecutable.from_hlo(
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2633, in from_hlo
        xla_executable, compile_options = _cached_compilation(
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2551, in _cached_compilation
        xla_executable = dispatch.compile_or_get_cached(
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/dispatch.py", line 495, in compile_or_get_cached
        return backend_compile(backend, computation, compile_options,
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
        return func(*args, **kwargs)
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/dispatch.py", line 463, in backend_compile
        return backend.compile(built_c, compile_options=options)
    jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
    It seems the DNN library is not available.
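jaxlib GPU wheels of this era encode the CUDA and cuDNN versions they were built for in the local version tag, for example `jaxlib==0.1.74+cuda11.cudnn805` in an earlier issue on this page. A mismatch between that tag and the CUDA/cuDNN libraries actually installed is a common cause of "DNN library initialization failed". A minimal sketch, assuming the `<version>+cuda<major>.cudnn<digits>` pattern seen on this page, that pulls the tag apart so it can be compared with the local installation:

```python
import re
from typing import NamedTuple, Optional

# Pattern of jaxlib version strings seen on this page, e.g. "0.4.10+cuda11.cudnn86".
_JAXLIB_RE = re.compile(r"^(?P<version>[^+\s]+)(?:\+cuda(?P<cuda>\d+)\.cudnn(?P<cudnn>\d+))?$")


class JaxlibBuild(NamedTuple):
    version: str               # base jaxlib version, e.g. "0.4.10"
    cuda_major: Optional[int]  # CUDA major version of the build, None for untagged wheels
    cudnn: Optional[str]       # raw cuDNN digits ("86", "805"), left undecoded on purpose


def parse_jaxlib_build(version_string: str) -> JaxlibBuild:
    """Split a jaxlib version string into base version, CUDA major and cuDNN tag."""
    match = _JAXLIB_RE.match(version_string.strip())
    if match is None:
        raise ValueError(f"unrecognized jaxlib version string: {version_string!r}")
    cuda = match.group("cuda")
    return JaxlibBuild(match.group("version"), int(cuda) if cuda else None, match.group("cudnn"))


# Usage against the installed wheel (importlib.metadata is in the stdlib since Python 3.8):
#   from importlib.metadata import version
#   print(parse_jaxlib_build(version("jaxlib")))
```

Note that the CUDA version printed by `nvidia-smi` is the newest version the driver supports, not necessarily the CUDA toolkit or cuDNN that jaxlib ends up loading, so it is worth comparing against the installed libraries too.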

How to finetune on multiple geometries

Hi,

I have read the paper "Towards a Foundation Model for Neural Network Wavefunctions" and I am interested in how to fine-tune a pre-trained model/checkpoint on different geometries of the same molecule, similar to Figure 5 in the paper. Could you tell me how to do so, or how to construct the necessary config file?

Thanks in advance!

deeperwin: error: unrecognized arguments: config.yml

I successfully installed deeperwin with conda, using Python 3.8 on Ubuntu 22.04.
Following the tutorial, I added a config YAML file:
vi config.yml
physical:
  name: LiH
  R: [[0,0,0],[3.5,0,0]]
Then I saved the file and ran it from the command line:
$ deeperwin run config.yml
usage: deeperwin [-h] [--parameter PARAMETER [PARAMETER ...]] [--force] [--wandb-sweep WANDB_SWEEP WANDB_SWEEP WANDB_SWEEP] [--exclude-param-name] [--dry-run] [--start-time-offset START_TIME_OFFSET]
[config_file]
deeperwin: error: unrecognized arguments: config.yml
Using the YAML files in sample_configs gives the same error.
Why?
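The usage line printed by this installation lists a single optional positional `[config_file]` and no `run` subcommand, so `run` is taken as the config file and `config.yml` is left over. A minimal reproduction, assuming the CLI is built with argparse (the usage format and `args.config_file` in other tracebacks on this page suggest so) and keeping only that positional argument; the reduced parser is hypothetical:

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical reduced parser mirroring the usage line: one optional positional."""
    parser = argparse.ArgumentParser(prog="deeperwin")
    parser.add_argument("config_file", nargs="?")
    return parser


parser = build_parser()

# "deeperwin run config.yml": "run" fills config_file and "config.yml" is left over.
# parse_args() would turn the leftover into "error: unrecognized arguments: config.yml".
args, leftover = parser.parse_known_args(["run", "config.yml"])
print(args.config_file, leftover)

# "deeperwin config.yml": the file is consumed as config_file, nothing is left over.
args, leftover = parser.parse_known_args(["config.yml"])
print(args.config_file, leftover)
```

If this reading is right, `deeperwin config.yml` should work with this installed version, while the `deeperwin run ...` form used elsewhere on this page belongs to a different CLI version.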

Availability of the checkpoints

Hi,
I read your paper "Towards a Foundation Model for Neural Network Wavefunctions", found it interesting, and would like to try the pretrained models. Are the checkpoints used in the paper available?
Also, how long did it take to train your models on the 18-compound dataset?
In the paper, the models were pretrained for 500k steps on the dataset of 360 geometries. Does one pretraining step mean one gradient update using one geometry?
Thank you!

Document page error

The documentation link https://mipunivie.github.io/deeperwin/ in the README is dead. Could you update it to the correct one?

Still having problems with Python 3.8.5

Dear all,
The OS is Ubuntu 22.04 and the CUDA version is 12.1. The Conda version is 2023.03-1-Linux-x86_64.
`$ nvidia-smi
Mon May 22 13:57:44 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02 Driver Version: 530.30.02 CUDA Version: 12.1 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA A100 80GB PCIe On | 00000000:21:00.0 Off | 0 |
| N/A 34C P0 45W / 300W| 4MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 3309 G /usr/lib/xorg/Xorg 4MiB |
+---------------------------------------------------------------------------------------+
`

With deeperwin installed under Python 3.8.5, I can run sample_configs/config_custom_molecule.yml, but I get an error from jax.

  1. conda create -n deeperwin python=3.8.5
  2. conda activate deeperwin
  3. git clone https://github.com/mdsunivie/deeperwin.git
  4. pip install -e .
    After this install, jax does not use the GPU:
    `$ python
    Python 3.8.16 (default, Mar 2 2023, 03:21:46)
    [GCC 11.2.0] :: Anaconda, Inc. on linux
    Type "help", "copyright", "credits" or "license" for more information.
    >>> import jax
    >>> jax.local_device_count()
    No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    1
    `

  5. Following the tutorial, I upgraded jax to the CUDA version with pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    `Collecting jaxlib==0.4.10+cuda11.cudnn86
    Using cached https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.10%2Bcuda11.cudnn86-cp38-cp38-manylinux2014_x86_64.whl (167.8 MB)
    Installing collected packages: jaxlib
    Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.4.10
    Uninstalling jaxlib-0.4.10:
    Successfully uninstalled jaxlib-0.4.10
    Successfully installed jaxlib-0.4.10+cuda11.cudnn86
    (deeperwin) yzhou@Moss:~/base/deeperwin$ python
    Python 3.8.16 (default, Mar 2 2023, 03:21:46)
    [GCC 11.2.0] :: Anaconda, Inc. on linux
    Type "help", "copyright", "credits" or "license" for more information.
    >>> import jax
    >>> jax.local_device_count()
    1
    `

  6. Run sample_configs/config_minimal.yml
    $ deeperwin run sample_configs/config_minimal.yml
    2023-05-22 12:48:56.136608: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:407] There was an error before creating cudnn handle (302): cudaGetErrorName symbol not found. : cudaGetErrorString symbol not found.
    Traceback (most recent call last):
      File "/home/yzhou/base/anaconda3/envs/deeperwin/bin/deeperwin", line 33, in <module>
        sys.exit(load_entry_point('deeperwin', 'console_scripts', 'deeperwin')())
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/cli.py", line 52, in main
        process_molecule(args.config_file)
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/process_molecule.py", line 36, in process_molecule
        from deeperwin.optimization import optimize_wavefunction, evaluate_wavefunction, pretrain_orbitals
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/optimization.py", line 9, in <module>
        from deeperwin.evaluation import evaluate_wavefunction
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/evaluation.py", line 8, in <module>
        from deeperwin.loggers import DataLogger, WavefunctionLogger
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/loggers.py", line 18, in <module>
        from deeperwin.checkpoints import save_run, RunData
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/checkpoints.py", line 6, in <module>
        from deeperwin.mcmc import MCMCState
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/mcmc.py", line 17, in <module>
        class MCMCState:
      File "/media/hdd/yzhou/deeperwin/src/deeperwin/mcmc.py", line 27, in MCMCState
        stepsize: jnp.array = jnp.array(1e-2)
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2022, in array
        out_array: Array = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 549, in _convert_element_type
        return convert_element_type_p.bind(operand, new_dtype=new_dtype,
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/core.py", line 380, in bind
        return self.bind_with_trace(find_top_trace(args), args, params)
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
        out = trace.process_primitive(self, map(trace.full_raise, args), params)
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/core.py", line 790, in process_primitive
        return primitive.impl(*tracers, **params)
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/dispatch.py", line 131, in apply_primitive
        compiled_fun = xla_primitive_callable(
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/util.py", line 284, in wrapper
        return cached(config._trace_context(), *args, **kwargs)
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/util.py", line 277, in cached
        return f(*args, **kwargs)
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/dispatch.py", line 222, in xla_primitive_callable
        compiled = _xla_callable_uncached(
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/dispatch.py", line 252, in _xla_callable_uncached
        return computation.compile().unsafe_call
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
        executable = UnloadedMeshExecutable.from_hlo(
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2633, in from_hlo
        xla_executable, compile_options = _cached_compilation(
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2551, in _cached_compilation
        xla_executable = dispatch.compile_or_get_cached(
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/dispatch.py", line 495, in compile_or_get_cached
        return backend_compile(backend, computation, compile_options,
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
        return func(*args, **kwargs)
      File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/dispatch.py", line 463, in backend_compile
        return backend.compile(built_c, compile_options=options)
    jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
    This fails with "DNN library initialization failed".
  7. Upgrade jaxlib to jaxlib-0.4.10+cuda12.cudnn88
    Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.4.10+cuda11.cudnn86
    Uninstalling jaxlib-0.4.10+cuda11.cudnn86:
    Successfully uninstalled jaxlib-0.4.10+cuda11.cudnn86
  8. Rerun the same sample config
    $ deeperwin run sample_configs/config_minimal.yml
    2023-05-22 13:02:24,242 jax._src.dispatch WARNING Finished tracing + transforming jit(reshape) in 0.0005035400390625 sec
    2023-05-22 13:02:24,244 jax._src.interpreters.pxla DEBUG Compiling reshape for with global shapes and types [ShapedArray(int32[1,1])]. Argument mapping: (GSPMDSharding({replicated}),).
    2023-05-22 13:02:24,247 jax._src.xla_bridge DEBUG get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]]
    2023-05-22 13:02:24,271 jax._src.dispatch WARNING Finished XLA compilation of jit(reshape) in 0.022927522659301758 sec
    2023-05-22 13:02:24,277 jax._src.dispatch WARNING Finished tracing + transforming <lambda> for pjit in 0.0014164447784423828 sec
    2023-05-22 13:02:24,281 jax._src.dispatch WARNING Finished tracing + transforming _select_master_data for pmap in 0.0062906742095947266 sec
    2023-05-22 13:02:24,281 jax._src.interpreters.pxla DEBUG sharded_avals: (ShapedArray(int32[]),)
    2023-05-22 13:02:24,282 jax._src.interpreters.pxla DEBUG global_sharded_avals: (ShapedArray(int32[]),)
    2023-05-22 13:02:24,282 jax._src.interpreters.pxla DEBUG num_replicas: 1 num_local_replicas: 1
    2023-05-22 13:02:24,282 jax._src.interpreters.pxla DEBUG devices: None
    2023-05-22 13:02:24,282 jax._src.interpreters.pxla DEBUG local_devices: None
    2023-05-22 13:02:24,282 jax._src.interpreters.pxla DEBUG Compiling _select_master_data (139714514916832) for 1 devices with args (ShapedArray(int32[1]),). (num_replicas=1)
    2023-05-22 13:02:24,289 jax._src.xla_bridge DEBUG get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]]
    2023-05-22 13:02:24,318 jax._src.dispatch WARNING Finished XLA compilation of _select_master_data in 0.028462648391723633 sec
    2023-05-22 13:02:24,323 jax._src.dispatch WARNING Finished tracing + transforming _multi_slice for pjit in 0.0012958049774169922 sec
    2023-05-22 13:02:24,324 jax._src.interpreters.pxla DEBUG Compiling _multi_slice for with global shapes and types [ShapedArray(int32[1])]. Argument mapping: (GSPMDSharding({replicated}),).
    2023-05-22 13:02:24,328 jax._src.xla_bridge DEBUG get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]]
    2023-05-22 13:02:24,352 jax._src.dispatch WARNING Finished XLA compilation of jit(_multi_slice) in 0.023581981658935547 sec
    2023-05-22 13:02:24,356 dpe DEBUG CUDA_VISIBLE_DEVICES=None
    2023-05-22 13:02:24,356 dpe DEBUG Used hardware: gpu; Local device count: 1; Global device count: 1
    2023-05-22 13:02:24,357 dpe DEBUG Calculating baseline solution...
    /home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/pyscf/gto/mole.py:1213: UserWarning: Function mol.dumps drops attribute charge because it is not JSON-serializable
      warnings.warn(msg)
    2023-05-22 13:02:25,298 h5py._conv DEBUG Creating converter from 5 to 3
    2023-05-22 13:02:25,854 jax._src.dispatch WARNING Finished tracing + transforming jit(convert_element_type) in 0.0013532638549804688 sec
    2023-05-22 13:02:25,870 jax._src.dispatch WARNING Finished tracing + transforming <lambda> for pjit in 0.0011436939239501953 sec
    2023-05-22 13:02:25,872 jax._src.dispatch WARNING Finished tracing + transforming _threefry_seed for pjit in 0.004873991012573242 sec
    2023-05-22 13:02:25,874 jax._src.interpreters.pxla DEBUG Compiling _threefry_seed for with global shapes and types [ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}),).
    2023-05-22 13:02:25,884 jax._src.xla_bridge DEBUG get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]]
    2023-05-22 13:02:26,046 jax._src.dispatch WARNING Finished XLA compilation of jit(_threefry_seed) in 0.16130280494689941 sec
    2023-05-22 13:02:26,052 jax._src.dispatch WARNING Finished tracing + transforming ravel for pjit in 0.0003886222839355469 sec
    2023-05-22 13:02:26,054 jax._src.dispatch WARNING Finished tracing + transforming threefry_2x32 for pjit in 0.002865314483642578 sec
    2023-05-22 13:02:26,055 jax._src.dispatch WARNING Finished tracing + transforming _threefry_split_original for pjit in 0.004869937896728516 sec
    2023-05-22 13:02:26,056 jax._src.interpreters.pxla DEBUG Compiling _threefry_split_original for with global shapes and types [ShapedArray(uint32[2])]. Argument mapping: (GSPMDSharding({replicated}),).
    2023-05-22 13:02:26,061 jax._src.xla_bridge DEBUG get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]]
    2023-05-22 13:02:26,177 jax._src.dispatch WARNING Finished XLA compilation of jit(_threefry_split_original) in 0.11521124839782715 sec
    2023-05-22 13:02:26,183 jax._src.dispatch WARNING Finished tracing + transforming _unstack for pjit in 0.0017693042755126953 sec

The log contains many "jax._src.dispatch WARNING" messages.
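The `Finished tracing + transforming ...` and `Finished XLA compilation ...` lines are compile-progress messages; JAX logs them at WARNING level when its `jax_log_compiles` option is switched on, and the DEBUG lines in the same log come from neighbouring JAX loggers. They report progress rather than a problem. Assuming they go through Python's standard `logging` module (the `logger-name LEVEL message` format suggests so), a minimal sketch that raises the threshold only for the JAX loggers named in the log:

```python
import logging

# Logger names copied from the log output above.
JAX_LOGGERS = ("jax._src.dispatch", "jax._src.interpreters.pxla", "jax._src.xla_bridge")


def quiet_jax_compile_logs(level: int = logging.ERROR) -> None:
    """Only let messages at `level` or above through from the listed JAX loggers."""
    for name in JAX_LOGGERS:
        logging.getLogger(name).setLevel(level)


quiet_jax_compile_logs()
```

This only hides messages and changes nothing about what gets compiled; if something later reconfigures those loggers, the levels may be reset, so it would need to run after any such setup.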

Finally, the run fails with this error:

`2023-05-22 14:05:07,397 jax._src.xla_bridge DEBUG get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]]
2023-05-22 14:05:08,823 jax._src.dispatch WARNING Finished XLA compilation of run_inter_steps in 1.422454595565796 sec
Traceback (most recent call last):
File "/home/yzhou/base/anaconda3/envs/deeperwin/bin/deeperwin", line 33, in <module>
sys.exit(load_entry_point('deeperwin', 'console_scripts', 'deeperwin')())
File "/media/hdd/yzhou/deeperwin/src/deeperwin/cli.py", line 52, in main
process_molecule(args.config_file)
File "/media/hdd/yzhou/deeperwin/src/deeperwin/process_molecule.py", line 105, in process_molecule
mcmc_state, params, opt_state, clipping_state = optimize_wavefunction(
File "/media/hdd/yzhou/deeperwin/src/deeperwin/optimization.py", line 161, in optimize_wavefunction
params, opt_state, clipping_state, stats = optimizer.step(params,
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/kfac_jax/_src/optimizer.py", line 947, in step
step_counter_int = self.verify_args_and_get_step_counter(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/kfac_jax/_src/optimizer.py", line 447, in verify_args_and_get_step_counter
return int(utils.get_first(step_counter))
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/pjit.py", line 208, in cache_miss
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/pjit.py", line 150, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/api.py", line 301, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/pjit.py", line 474, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/pjit.py", line 935, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
ans = call(fun, *args)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/pjit.py", line 888, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2150, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2172, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/kfac_jax/_src/utils.py", line 248, in get_first
return jax.tree_util.tree_map(index_if_not_scalar, obj)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/kfac_jax/_src/utils.py", line 234, in index_if_not_scalar
if isinstance(value, chex.Array):
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/typing.py", line 769, in __instancecheck__
return self.__subclasscheck__(type(obj))
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/typing.py", line 777, in __subclasscheck__
raise TypeError("Subscripted generics cannot be used with"
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Subscripted generics cannot be used with class and instance checks

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/yzhou/base/anaconda3/envs/deeperwin/bin/deeperwin", line 33, in <module>
sys.exit(load_entry_point('deeperwin', 'console_scripts', 'deeperwin')())
File "/media/hdd/yzhou/deeperwin/src/deeperwin/cli.py", line 52, in main
process_molecule(args.config_file)
File "/media/hdd/yzhou/deeperwin/src/deeperwin/process_molecule.py", line 105, in process_molecule
mcmc_state, params, opt_state, clipping_state = optimize_wavefunction(
File "/media/hdd/yzhou/deeperwin/src/deeperwin/optimization.py", line 161, in optimize_wavefunction
params, opt_state, clipping_state, stats = optimizer.step(params,
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/kfac_jax/_src/optimizer.py", line 947, in step
step_counter_int = self.verify_args_and_get_step_counter(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/kfac_jax/_src/optimizer.py", line 447, in verify_args_and_get_step_counter
return int(utils.get_first(step_counter))
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/kfac_jax/_src/utils.py", line 248, in get_first
return jax.tree_util.tree_map(index_if_not_scalar, obj)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/kfac_jax/_src/utils.py", line 234, in index_if_not_scalar
if isinstance(value, chex.Array):
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/typing.py", line 769, in __instancecheck__
return self.__subclasscheck__(type(obj))
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/typing.py", line 777, in __subclasscheck__
raise TypeError("Subscripted generics cannot be used with"
TypeError: Subscripted generics cannot be used with class and instance checks
`

How to fix this problem? Thanks.
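The last frames show kfac_jax calling `isinstance(value, chex.Array)` and Python 3.8's `typing.py` rejecting it. If `chex.Array` is a `typing.Union` alias in the installed chex (the error message points that way), this is a Python-version incompatibility: `isinstance` against a subscripted `Union` raises exactly this `TypeError` on Python 3.8 and 3.9, while Python 3.10 and later accept union types in `isinstance`. The sketch uses a hypothetical stand-in alias rather than chex itself to show the failure and a portable check via `typing.get_args`; it does not patch kfac_jax, and the practical fix is more likely a Python version (or kfac_jax/chex versions) that agree with each other.

```python
import sys
import typing
from typing import Union

# Hypothetical stand-in for a Union alias such as chex.Array.
ArrayLike = Union[int, float]


def direct_union_isinstance_works() -> bool:
    """True if isinstance() accepts a typing.Union alias on this interpreter."""
    try:
        isinstance(1, ArrayLike)
    except TypeError:  # "Subscripted generics cannot be used with class and instance checks"
        return False
    return True


def isinstance_union(value, union_type) -> bool:
    """Portable check: unpack Union[...] into a tuple of its member classes.

    Works when every member of the Union is a plain class.
    """
    members = typing.get_args(union_type)
    return isinstance(value, members if members else union_type)


print(sys.version_info[:2], direct_union_isinstance_works())
print(isinstance_union(2.5, ArrayLike), isinstance_union("x", ArrayLike))
```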

A small bug in `config.physical` when loading data

Hello everyone,

Thank you so much for your previous reply and update. I would like to report another issue.

When I tried to run config_bm_hfcoeff.yml with `deeperwin run config_bm_hfcoeff.yml`, I received the following error:
Traceback (most recent call last):
  File "/anaconda/envs/deeperwin/bin/deeperwin", line 33, in <module>
    sys.exit(load_entry_point('deeperwin', 'console_scripts', 'deeperwin')())
  File "/home/t-shcheng/work/deeperwin/src/deeperwin/cli.py", line 16, in main
    process_single_molecule(args.config_file)
  File "/home/t-shcheng/work/deeperwin/src/deeperwin/process_molecule.py", line 99, in process_single_molecule
    tuple(config.physical.Z))
AttributeError: 'str' object has no attribute 'Z'

When I print `config.physical`, it gives TinyMol_CNO_rot_dist_train_18compounds_360geoms, which is the name of the dataset, not a geometry. Single-molecule jobs run perfectly, so I suspect that datasets need a modified data loader.

I would really appreciate it if anyone could take a look at this code. Many thanks in advance for your help!

Best,
Sherry Cheng
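The traceback suggests that for dataset runs `config.physical` holds the dataset name (a `str`), while `process_single_molecule` expects a physical-configuration object with a `.Z` attribute. A hypothetical guard, assuming a minimal stand-in class with only a `Z` field (deeperwin's real configuration class has more fields and may be named differently), that reports the situation clearly instead of raising the `AttributeError`:

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass
class PhysicalConfigStub:
    """Hypothetical stand-in for a single-molecule physical config."""
    Z: List[int]


def nuclear_charges(physical: Union[str, PhysicalConfigStub]) -> Tuple[int, ...]:
    """Return the nuclear charges, rejecting a bare dataset name with a clear error."""
    if isinstance(physical, str):
        raise ValueError(
            f"config.physical is the dataset name {physical!r}, "
            "but a single-molecule physical config was expected"
        )
    return tuple(physical.Z)


print(nuclear_charges(PhysicalConfigStub(Z=[3, 1])))  # LiH-like charges
```

Where the dataset should actually be expanded into per-geometry configs is a question for the deeperwin code itself; the guard only turns the confusing `AttributeError` into an explicit message.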

`e3nn-jax` required but not listed in the setup.cfg (the most recent `e3nn-jax` does not work)

Hello all,

I would like to report that the package e3nn-jax is missing from the setup files. Also, since deeperwin needs a specific jax version that differs from the requirements of the current e3nn-jax, it would be great to pin the specific e3nn-jax version used by deeperwin.
I closed my other issue, since it was actually caused by the most recent version of e3nn-jax. After downgrading to e3nn-jax==0.11.0, the error "AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'" disappears.
Thank you so much for your attention and help!
Best,
Sherry Cheng
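Until a pin lands in the setup files, one way to catch this early is to compare installed versions against the ones reported to work on this page (`e3nn-jax==0.11.0` here, `jax==0.3.23` from the earlier issue). A minimal sketch using only the standard library; `KNOWN_GOOD` is an illustrative assumption built from those reports, not an official deeperwin requirement list, and the version lookup is injectable so the function can be tested without the packages installed:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, List

# Versions reported on this page; illustrative, not an official requirement list.
KNOWN_GOOD = {"jax": "0.3.23", "e3nn-jax": "0.11.0"}


def pin_mismatches(
    pins: Dict[str, str],
    get_version: Callable[[str], str] = version,
) -> List[str]:
    """Describe every package whose installed version differs from its pin."""
    problems = []
    for package, wanted in pins.items():
        try:
            installed = get_version(package)
        except PackageNotFoundError:
            problems.append(f"{package}: not installed (want {wanted})")
            continue
        if installed != wanted:
            problems.append(f"{package}: installed {installed}, want {wanted}")
    return problems


# Usage against the current environment:
#   print(pin_mismatches(KNOWN_GOOD) or "all pins match")
```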
