Dear all,
I am running Ubuntu 22.04 with CUDA 12.1. The Anaconda version is 2023.03-1-Linux-x86_64.
`$ nvidia-smi
Mon May 22 13:57:44 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100 80GB PCIe           On | 00000000:21:00.0 Off |                    0 |
| N/A   34C    P0               45W / 300W|      4MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      3309      G   /usr/lib/xorg/Xorg                            4MiB |
+---------------------------------------------------------------------------------------+
`
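In case it helps with reproducing the setup, here is a minimal sketch for listing the Python-side package versions alongside the driver information above. The package list is my assumption about which distributions are relevant (`kfac-jax` and `chex` appear in the traceback further down); it uses only the standard library.

```python
import platform
from importlib import metadata

# Assumption: these are the distributions relevant to this report.
PACKAGES = ("jax", "jaxlib", "chex", "kfac-jax", "deeperwin")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print("python", platform.python_version())
    for name, version in installed_versions().items():
        print(name, version or "not installed")
```

Running this inside the activated `deeperwin` environment should show whether jaxlib carries a `+cuda...` tag.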
I installed deeperwin with Python 3.8.5. I can run sample_config/config_custom_molecule.yml, but I get an error from jax. The steps were:
- conda create -n deeperwin python=3.8.5
- conda activate deeperwin
- git clone https://github.com/mdsunivie/deeperwin.git
- pip install -e .
After this install, the installed jax version does not support the GPU:
`$ python
Python 3.8.16 (default, Mar 2 2023, 03:21:46)
[GCC 11.2.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.local_device_count()
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
1
`
- Following the tutorial, I upgraded jax to the CUDA version: pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
`Collecting jaxlib==0.4.10+cuda11.cudnn86
Using cached https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.10%2Bcuda11.cudnn86-cp38-cp38-manylinux2014_x86_64.whl (167.8 MB)
Installing collected packages: jaxlib
Attempting uninstall: jaxlib
Found existing installation: jaxlib 0.4.10
Uninstalling jaxlib-0.4.10:
Successfully uninstalled jaxlib-0.4.10
Successfully installed jaxlib-0.4.10+cuda11.cudnn86
(deeperwin) yzhou@Moss:~/base/deeperwin$ python
Python 3.8.16 (default, Mar 2 2023, 03:21:46)
[GCC 11.2.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.local_device_count()
1
`
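A device count of 1 is ambiguous here, because the CPU-only install earlier also reported 1. A minimal sketch of a more explicit check, using `jax.devices()` and `jax.default_backend()`; the helper name `describe_backend` is made up for illustration:

```python
def describe_backend():
    """Return a short description of the platform JAX would run on."""
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    try:
        devices = jax.devices()
    except RuntimeError as err:
        # Raised when no backend can be initialized at all.
        return f"backend initialization failed: {err}"
    platforms = sorted({d.platform for d in devices})
    return (
        f"default backend: {jax.default_backend()}, "
        f"platforms: {platforms}, device count: {len(devices)}"
    )


if __name__ == "__main__":
    print(describe_backend())
```

If the reported platform is `cpu`, the CUDA build of jaxlib is not being picked up even though the device count is still 1.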
- run the sample_configs/config_minimal.yml
$ deeperwin run sample_configs/config_minimal.yml
2023-05-22 12:48:56.136608: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:407] There was an error before creating cudnn handle (302): cudaGetErrorName symbol not found. : cudaGetErrorString symbol not found.
Traceback (most recent call last):
File "/home/yzhou/base/anaconda3/envs/deeperwin/bin/deeperwin", line 33, in <module>
sys.exit(load_entry_point('deeperwin', 'console_scripts', 'deeperwin')())
File "/media/hdd/yzhou/deeperwin/src/deeperwin/cli.py", line 52, in main
process_molecule(args.config_file)
File "/media/hdd/yzhou/deeperwin/src/deeperwin/process_molecule.py", line 36, in process_molecule
from deeperwin.optimization import optimize_wavefunction, evaluate_wavefunction, pretrain_orbitals
File "/media/hdd/yzhou/deeperwin/src/deeperwin/optimization.py", line 9, in <module>
from deeperwin.evaluation import evaluate_wavefunction
File "/media/hdd/yzhou/deeperwin/src/deeperwin/evaluation.py", line 8, in <module>
from deeperwin.loggers import DataLogger, WavefunctionLogger
File "/media/hdd/yzhou/deeperwin/src/deeperwin/loggers.py", line 18, in <module>
from deeperwin.checkpoints import save_run, RunData
File "/media/hdd/yzhou/deeperwin/src/deeperwin/checkpoints.py", line 6, in <module>
from deeperwin.mcmc import MCMCState
File "/media/hdd/yzhou/deeperwin/src/deeperwin/mcmc.py", line 17, in <module>
class MCMCState:
File "/media/hdd/yzhou/deeperwin/src/deeperwin/mcmc.py", line 27, in MCMCState
stepsize: jnp.array = jnp.array(1e-2)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2022, in array
out_array: Array = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 549, in _convert_element_type
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/core.py", line 380, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/core.py", line 790, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/dispatch.py", line 131, in apply_primitive
compiled_fun = xla_primitive_callable(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/util.py", line 284, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/util.py", line 277, in cached
return f(*args, **kwargs)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/dispatch.py", line 222, in xla_primitive_callable
compiled = _xla_callable_uncached(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/dispatch.py", line 252, in _xla_callable_uncached
return computation.compile().unsafe_call
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2633, in from_hlo
xla_executable, compile_options = _cached_compilation(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2551, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/dispatch.py", line 495, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/dispatch.py", line 463, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
The run fails with "DNN library initialization failed".
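One thing I noticed is that pip picked a wheel tagged `cuda11.cudnn86`, while nvidia-smi reports CUDA 12.1. I am not sure this is the cause, since a newer driver can usually run older CUDA builds, so the following is only a rough sketch for spotting such a difference. The helper names are hypothetical and the tag format is taken from the pip output above.

```python
import re


def parse_jaxlib_tag(version):
    """Split a jaxlib version such as '0.4.10+cuda11.cudnn86' into its parts.

    Returns (base_version, cuda_major, cudnn_tag); the last two are None
    when the version has no '+cudaNN.cudnnNN' local tag (CPU-only build).
    """
    match = re.fullmatch(r"([^+]+)(?:\+cuda(\d+)\.cudnn(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized jaxlib version: {version!r}")
    base, cuda, cudnn = match.groups()
    return base, (int(cuda) if cuda else None), cudnn


def cuda_major_matches(jaxlib_version, driver_cuda_version):
    """True if the wheel's CUDA major version equals the driver-reported one."""
    _, wheel_cuda, _ = parse_jaxlib_tag(jaxlib_version)
    return wheel_cuda == int(driver_cuda_version.split(".")[0])


if __name__ == "__main__":
    print(parse_jaxlib_tag("0.4.10+cuda11.cudnn86"))
    print(cuda_major_matches("0.4.10+cuda11.cudnn86", "12.1"))
```

A mismatch flagged this way is only a hint; the cuDNN version installed on the system may matter just as much.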
- Upgrade jaxlib to jaxlib-0.4.10+cuda12.cudnn88:
Attempting uninstall: jaxlib
Found existing installation: jaxlib 0.4.10+cuda11.cudnn86
Uninstalling jaxlib-0.4.10+cuda11.cudnn86:
Successfully uninstalled jaxlib-0.4.10+cuda11.cudnn86
- Rerun the same sample config:
$ deeperwin run sample_configs/config_minimal.yml
2023-05-22 13:02:24,242 jax._src.dispatch WARNING Finished tracing + transforming jit(reshape) in 0.0005035400390625 sec
2023-05-22 13:02:24,244 jax._src.interpreters.pxla DEBUG Compiling reshape for with global shapes and types [ShapedArray(int32[1,1])]. Argument mapping: (GSPMDSharding({replicated}),).
2023-05-22 13:02:24,247 jax._src.xla_bridge DEBUG get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]]
2023-05-22 13:02:24,271 jax._src.dispatch WARNING Finished XLA compilation of jit(reshape) in 0.022927522659301758 sec
2023-05-22 13:02:24,277 jax._src.dispatch WARNING Finished tracing + transforming <lambda> for pjit in 0.0014164447784423828 sec
2023-05-22 13:02:24,281 jax._src.dispatch WARNING Finished tracing + transforming _select_master_data for pmap in 0.0062906742095947266 sec
2023-05-22 13:02:24,281 jax._src.interpreters.pxla DEBUG sharded_avals: (ShapedArray(int32[]),)
2023-05-22 13:02:24,282 jax._src.interpreters.pxla DEBUG global_sharded_avals: (ShapedArray(int32[]),)
2023-05-22 13:02:24,282 jax._src.interpreters.pxla DEBUG num_replicas: 1 num_local_replicas: 1
2023-05-22 13:02:24,282 jax._src.interpreters.pxla DEBUG devices: None
2023-05-22 13:02:24,282 jax._src.interpreters.pxla DEBUG local_devices: None
2023-05-22 13:02:24,282 jax._src.interpreters.pxla DEBUG Compiling _select_master_data (139714514916832) for 1 devices with args (ShapedArray(int32[1]),). (num_replicas=1)
2023-05-22 13:02:24,289 jax._src.xla_bridge DEBUG get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]]
2023-05-22 13:02:24,318 jax._src.dispatch WARNING Finished XLA compilation of _select_master_data in 0.028462648391723633 sec
2023-05-22 13:02:24,323 jax._src.dispatch WARNING Finished tracing + transforming _multi_slice for pjit in 0.0012958049774169922 sec
2023-05-22 13:02:24,324 jax._src.interpreters.pxla DEBUG Compiling _multi_slice for with global shapes and types [ShapedArray(int32[1])]. Argument mapping: (GSPMDSharding({replicated}),).
2023-05-22 13:02:24,328 jax._src.xla_bridge DEBUG get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]]
2023-05-22 13:02:24,352 jax._src.dispatch WARNING Finished XLA compilation of jit(_multi_slice) in 0.023581981658935547 sec
2023-05-22 13:02:24,356 dpe DEBUG CUDA_VISIBLE_DEVICES=None
2023-05-22 13:02:24,356 dpe DEBUG Used hardware: gpu; Local device count: 1; Global device count: 1
2023-05-22 13:02:24,357 dpe DEBUG Calculating baseline solution...
/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/pyscf/gto/mole.py:1213: UserWarning: Function mol.dumps drops attribute charge because it is not JSON-serializable
warnings.warn(msg)
2023-05-22 13:02:25,298 h5py._conv DEBUG Creating converter from 5 to 3
2023-05-22 13:02:25,854 jax._src.dispatch WARNING Finished tracing + transforming jit(convert_element_type) in 0.0013532638549804688 sec
2023-05-22 13:02:25,870 jax._src.dispatch WARNING Finished tracing + transforming <lambda> for pjit in 0.0011436939239501953 sec
2023-05-22 13:02:25,872 jax._src.dispatch WARNING Finished tracing + transforming _threefry_seed for pjit in 0.004873991012573242 sec
2023-05-22 13:02:25,874 jax._src.interpreters.pxla DEBUG Compiling _threefry_seed for with global shapes and types [ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}),).
2023-05-22 13:02:25,884 jax._src.xla_bridge DEBUG get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]]
2023-05-22 13:02:26,046 jax._src.dispatch WARNING Finished XLA compilation of jit(_threefry_seed) in 0.16130280494689941 sec
2023-05-22 13:02:26,052 jax._src.dispatch WARNING Finished tracing + transforming ravel for pjit in 0.0003886222839355469 sec
2023-05-22 13:02:26,054 jax._src.dispatch WARNING Finished tracing + transforming threefry_2x32 for pjit in 0.002865314483642578 sec
2023-05-22 13:02:26,055 jax._src.dispatch WARNING Finished tracing + transforming _threefry_split_original for pjit in 0.004869937896728516 sec
2023-05-22 13:02:26,056 jax._src.interpreters.pxla DEBUG Compiling _threefry_split_original for with global shapes and types [ShapedArray(uint32[2])]. Argument mapping: (GSPMDSharding({replicated}),).
2023-05-22 13:02:26,061 jax._src.xla_bridge DEBUG get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]]
2023-05-22 13:02:26,177 jax._src.dispatch WARNING Finished XLA compilation of jit(_threefry_split_original) in 0.11521124839782715 sec
2023-05-22 13:02:26,183 jax._src.dispatch WARNING Finished tracing + transforming _unstack for pjit in 0.0017693042755126953 sec
The log contains many messages of the form "jax._src.dispatch WARNING".
Finally, the run ends with an error:
`2023-05-22 14:05:07,397 jax._src.xla_bridge DEBUG get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]]
2023-05-22 14:05:08,823 jax._src.dispatch WARNING Finished XLA compilation of run_inter_steps in 1.422454595565796 sec
Traceback (most recent call last):
File "/home/yzhou/base/anaconda3/envs/deeperwin/bin/deeperwin", line 33, in <module>
sys.exit(load_entry_point('deeperwin', 'console_scripts', 'deeperwin')())
File "/media/hdd/yzhou/deeperwin/src/deeperwin/cli.py", line 52, in main
process_molecule(args.config_file)
File "/media/hdd/yzhou/deeperwin/src/deeperwin/process_molecule.py", line 105, in process_molecule
mcmc_state, params, opt_state, clipping_state = optimize_wavefunction(
File "/media/hdd/yzhou/deeperwin/src/deeperwin/optimization.py", line 161, in optimize_wavefunction
params, opt_state, clipping_state, stats = optimizer.step(params,
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/kfac_jax/_src/optimizer.py", line 947, in step
step_counter_int = self.verify_args_and_get_step_counter(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/kfac_jax/_src/optimizer.py", line 447, in verify_args_and_get_step_counter
return int(utils.get_first(step_counter))
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/pjit.py", line 208, in cache_miss
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/pjit.py", line 150, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/api.py", line 301, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/pjit.py", line 474, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/pjit.py", line 935, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
ans = call(fun, *args)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/pjit.py", line 888, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2150, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2172, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/kfac_jax/_src/utils.py", line 248, in get_first
return jax.tree_util.tree_map(index_if_not_scalar, obj)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/kfac_jax/_src/utils.py", line 234, in index_if_not_scalar
if isinstance(value, chex.Array):
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/typing.py", line 769, in __instancecheck__
return self.__subclasscheck__(type(obj))
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/typing.py", line 777, in __subclasscheck__
raise TypeError("Subscripted generics cannot be used with"
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Subscripted generics cannot be used with class and instance checks
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/yzhou/base/anaconda3/envs/deeperwin/bin/deeperwin", line 33, in <module>
sys.exit(load_entry_point('deeperwin', 'console_scripts', 'deeperwin')())
File "/media/hdd/yzhou/deeperwin/src/deeperwin/cli.py", line 52, in main
process_molecule(args.config_file)
File "/media/hdd/yzhou/deeperwin/src/deeperwin/process_molecule.py", line 105, in process_molecule
mcmc_state, params, opt_state, clipping_state = optimize_wavefunction(
File "/media/hdd/yzhou/deeperwin/src/deeperwin/optimization.py", line 161, in optimize_wavefunction
params, opt_state, clipping_state, stats = optimizer.step(params,
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/kfac_jax/_src/optimizer.py", line 947, in step
step_counter_int = self.verify_args_and_get_step_counter(
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/kfac_jax/_src/optimizer.py", line 447, in verify_args_and_get_step_counter
return int(utils.get_first(step_counter))
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/kfac_jax/_src/utils.py", line 248, in get_first
return jax.tree_util.tree_map(index_if_not_scalar, obj)
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/site-packages/kfac_jax/_src/utils.py", line 234, in index_if_not_scalar
if isinstance(value, chex.Array):
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/typing.py", line 769, in __instancecheck__
return self.__subclasscheck__(type(obj))
File "/home/yzhou/base/anaconda3/envs/deeperwin/lib/python3.8/typing.py", line 777, in __subclasscheck__
raise TypeError("Subscripted generics cannot be used with"
TypeError: Subscripted generics cannot be used with class and instance checks
`
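As far as I can tell, the final TypeError comes from `isinstance(value, chex.Array)` in kfac_jax. Python 3.8 rejects subscripted `typing` generics in `isinstance`, whereas Python 3.10 and later accept `Union` aliases there. Below is a minimal sketch of that behaviour and of one way around it. It assumes `chex.Array` is a `typing.Union` alias in the installed chex version; `NumberLike` and `isinstance_compat` are hypothetical names, not part of chex or kfac_jax.

```python
import typing

# Hypothetical stand-in for a Union-style alias such as chex.Array.
NumberLike = typing.Union[int, float]


def isinstance_compat(value, tp):
    """isinstance() that unpacks a typing.Union alias into a tuple of classes."""
    if typing.get_origin(tp) is typing.Union:
        return isinstance(value, typing.get_args(tp))
    return isinstance(value, tp)


if __name__ == "__main__":
    try:
        # Raises TypeError on Python 3.8 and 3.9, works on 3.10 and later.
        print("plain isinstance:", isinstance(1.5, NumberLike))
    except TypeError as err:
        print("plain isinstance failed:", err)
    print("compat isinstance:", isinstance_compat(1.5, NumberLike))
```

This only illustrates the mechanism; whether the right fix is a newer Python version or a different combination of chex and kfac_jax versions is exactly what I am unsure about.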
How can I fix this problem? Thanks.