jax-md
Differentiable, Hardware Accelerated, Molecular Dynamics
License: Apache License 2.0
To reproduce the problem:
!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-0.1.15-cp36-none-linux_x86_64.whl
!pip install --upgrade -q jax
!pip install -q git+https://www.github.com/google/jax-md
import numpy as onp
import jax.numpy as np
from jax import random
from jax import jit, grad
from jax.config import config
config.update("jax_enable_x64", True)
from jax_md import space, smap, energy, minimize, quantity, simulate
def SetPositionsOctahedron():
    a = 1/np.sqrt(2.)
    R = list(np.full((6, 3), 0.))
    R[0] = np.array([a, 0, 0])
    R[1] = np.array([-a, 0, 0])
    R[2] = np.array([0, a, 0])
    R[3] = np.array([0, -a, 0])
    R[4] = np.array([0, 0, a])
    R[5] = np.array([0, 0, -a])
    return np.array(R).astype(np.float64)
At this point, everything is fine, but if you then run:
box_size = 8.
Rinit = SetPositionsOctahedron()+box_size/2
you get the following error message:
GPU memory usage is close to the limit
Your GPU is close to its memory limit. You will not be able to use any additional memory in this session. Currently, 10.07 GB / 11.17 GB is being used. Would you like to terminate some sessions in order to free up GPU memory (state will be lost for those sessions)?
If I press "ignore," the issue doesn't go away, but it doesn't seem to affect me negatively. That might be because I'm not (yet) running computationally intensive work. So the good news is that I don't need this fixed in order to keep working (for now), but I thought you should be aware of it.
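A likely explanation (not confirmed in this thread): 10.07 GB / 11.17 GB is about 90% of the card, which matches JAX's default behavior at the time of preallocating most of the GPU memory as soon as the first array is created, here the first `np.sqrt` call inside `SetPositionsOctahedron`. If the warning is a nuisance, the XLA client environment variables below turn preallocation off or cap it; they only take effect if set before `jax` is first imported in the process.

```python
import os

# Must run before `import jax` anywhere in the process
# (in Colab: restart the runtime, then run this cell first).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand

# Alternatively, keep preallocation but cap it at a fraction of the GPU:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"
```

Allocating on demand can fragment GPU memory and be slightly slower, so capping the fraction is the gentler option when several notebooks share one GPU.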
Is there a paper or online document (e.g. on arXiv) that describes the ideas behind this project, or a more detailed README.md?
Hello! This is a very cool project!
I have a few questions related to using jax-md for condensed matter MD, in particular related to computing the stress. I'll give some context below; feel free to skip it if this is obvious already!
In that context, the stress tensor $\sigma$ is defined as the derivative of the total energy $E$ with respect to a $3 \times 3$ infinitesimal strain tensor $\epsilon$:

$$\sigma_{ab} = \frac{1}{V} \left. \frac{\partial E}{\partial \epsilon_{ab}} \right|_{\epsilon = 0}$$
$\epsilon$ describes a strain transformation of all real-space coordinates $R$ as:

$$R \to (\mathbb{1} + \epsilon)\, R$$

This transformation is applied to all coordinates, i.e. the atoms in the unit cell and the basis vectors.
In an autodiff framework this shouldn't be too difficult to implement. SchNetPack does it by applying the strain transformation to R and the basis before the forward pass, and then computing the gradient. In jax-md, however, the problem is that the basis is treated as a parameter of the energy function, via the displacement function, and only the coordinates are treated as an argument, so it's easy to compute the gradient with respect to R, but not with respect to the basis.
The logical way to apply the strain transformation in jax-md is in the displacement function: if one transforms the displacement_fn (and the shift_fn, though this seems unnecessary when only computing energies), the energy behaves as though it were computed in transformed coordinates, since only displacement vectors are ever used to compute energies.
I've tried to implement this experimentally here by simply wrapping the periodic_general displacement function. This works and gives correct results, but seems to be un-jittable, presumably because I'm redefining the energy function when it's called.
An alternative approach would be to "hack" the feature of jax-md that allows the T parameter of periodic_general to be a callable, by letting T be a function that accepts the strain as an argument:
def T(strain: Array = np.zeros((3, 3), dtype=np.double), t=0) -> Array:
    strain_transformation = np.eye(N=3, M=3, dtype=np.double) + strain
    return transform(strain_transformation, basis)
Then passing strain = np.zeros((3, 3), dtype=np.double) to the energy function will work. However, this runs into a problem: jax-md defines the Jacobian-vector product of the transform function such that it doesn't act on the backward pass here, so taking the gradient with respect to strain in the above example gives zero stress. If the custom jvp is commented out, it seems to work (though I haven't tested it for correctness). This approach also seems not to play well with jit, but I'm not sure why.
So, finally, my questions:
Why is the custom jvp for transform defined this way? Removing it doesn't seem to break jax-md, but I assume there's a good reason why it is this way... (Tagging in @fabiannagel since he's also involved in this!)
In https://github.com/google/jax-md/blob/64a4c80175be80391177bb549320aec387cc01b2/jax_md/energy.py#L150, r_onste should be r_onset.
import jax.numpy as np
from jax import vmap
from jax_md.util import f32
from jax_md import space, energy
batch_size = 32
N = 100
dim = 2
box_size = 10
displacement, shift = space.periodic(box_size)
neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(displacement, box_size)
R = np.zeros((batch_size, N, dim), dtype=f32)
nbrs = vmap(neighbor_fn)(R)
When I tried to vectorize the neighbor_fn, I got the following error:
~/.anaconda3/envs/md/python3.7/site-packages/jax_md/partition.py in _estimate_cell_capacity(R, box_size, cell_size, buffer_size_multiplier)
173 cell_capacity = onp.max(count_cell_filling(R, box_size, cell_size))
--> 174 return int(cell_capacity * buffer_size_multiplier)
175
FilteredStackTrace: jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
The problem arose with the `int` function. If trying to convert the data type of a value, try using `x.astype(int)` or `jnp.array(x, int)` instead.
Is there any way to work around this problem?
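The error comes from `_estimate_cell_capacity` calling `int()` on a value that, under `vmap`, is an abstract tracer: the buffer size must be a concrete Python integer because it fixes an array shape. One general workaround, sketched here without jax-md (all function names are hypothetical), is to measure the largest neighbor count eagerly on the whole concrete batch, then `vmap` only the part whose output shape is fixed by that capacity:

```python
import numpy as onp
import jax
import jax.numpy as jnp


def pairwise_distances(R, box_size):
    """Minimum-image distance matrix for a square/cubic periodic box."""
    dR = R[:, None, :] - R[None, :, :]
    dR = dR - box_size * jnp.round(dR / box_size)
    return jnp.sqrt(jnp.sum(dR ** 2, axis=-1))


def estimate_capacity(R_batch, box_size, cutoff):
    """Runs eagerly on a concrete batch, so calling int() here is fine."""
    d = jax.vmap(lambda R: pairwise_distances(R, box_size))(R_batch)
    counts = jnp.sum(d < cutoff, axis=-1)          # includes the particle itself
    return int(onp.max(counts))


def neighbor_idx(R, box_size, cutoff, capacity):
    """Fixed-shape neighbor indices per particle; padding entries equal N."""
    N = R.shape[0]
    mask = (pairwise_distances(R, box_size) < cutoff) & ~jnp.eye(N, dtype=bool)
    idx = jnp.where(mask, jnp.arange(N)[None, :], N)
    return jnp.sort(idx, axis=-1)[:, :capacity]    # real neighbors sort first


R_batch = jax.random.uniform(jax.random.PRNGKey(0), (32, 100, 2), maxval=10.0)
capacity = estimate_capacity(R_batch, 10.0, 1.0)   # eager, concrete integer
nbrs = jax.vmap(lambda R: neighbor_idx(R, 10.0, 1.0, capacity))(R_batch)
```

The capacity is shared across the batch (the maximum over all systems), which is what lets every vmapped output have the same shape.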
Looks like the neural network potentials notebook needs to be updated for the new location of optax, given the error produced in "Imports & Utils":
21
22 from jax import random
---> 23 from jax.experimental import optix
24
25 from jax_md import energy, space, simulate, quantity
ImportError: cannot import name 'optix'
Fixed by adding !pip install -q git+https://www.github.com/deepmind/optax and replacing from jax.experimental import optix with import optax as optix, without modifying the rest; alternatively, one could go through and update all the other usages of "optix".
Lmk if you want a patch or would rather do it. Awesome notebook btw!
In addition to simulate.nve and simulate.nvt, it would be nice to have simulate.brownian. This is just the over-damped limit of Langevin dynamics, but I believe typical implementations are very different so I'm listing it as a separate issue.
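For reference, the overdamped limit is commonly integrated with a single Euler-Maruyama step per time step: a drift along the force with mobility 1/gamma plus Gaussian noise whose variance is set by the diffusion constant D = kT/gamma. A minimal sketch, assuming a user-supplied `force_fn`; `brownian_step` and its parameters are hypothetical names, not jax-md's API:

```python
import jax
import jax.numpy as jnp


def brownian_step(key, R, force_fn, dt, kT, gamma):
    """One Euler-Maruyama step of overdamped Langevin (Brownian) dynamics.

    dR = F(R) / gamma * dt + sqrt(2 kT dt / gamma) * xi,  xi ~ N(0, 1),
    i.e. drift with mobility 1 / gamma plus diffusion with D = kT / gamma.
    """
    xi = jax.random.normal(key, R.shape, dtype=R.dtype)
    return R + dt * force_fn(R) / gamma + jnp.sqrt(2.0 * kT * dt / gamma) * xi
```

Unlike Langevin dynamics there are no velocities in the state, which is one reason implementations look quite different. In a harmonic trap with force -kR, repeated steps relax to positions with variance close to kT/k.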
When I try running the colab notebooks, I get the following error in the first cell:
ERROR: HTTP error 404 while getting https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.23-cp36-none-linux_x86_64.whl
ERROR: Could not install requirement jaxlib==0.1.23 from https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.23-cp36-none-linux_x86_64.whl because of error 404 Client Error: Not Found for url: https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.23-cp36-none-linux_x86_64.whl
ERROR: Could not install requirement jaxlib==0.1.23 from https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.23-cp36-none-linux_x86_64.whl because of HTTP error 404 Client Error: Not Found for url: https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.23-cp36-none-linux_x86_64.whl for URL https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.23-cp36-none-linux_x86_64.whl
|████████████████████████████████| 215kB 1.4MB/s
|████████████████████████████████| 61kB 26.2MB/s
Building wheel for jax (setup.py) ... done
Building wheel for opt-einsum (setup.py) ... done
Building wheel for jax-md (setup.py) ... done
jax-md/jax-md/energy.py line 37
check_kwargs_time_dependnece(unused_kwargs)
This is not something that I need now, but it is almost certainly a feature that you'll eventually want to include. The standard (e.g. in LAMMPS and HOOMD) for fast neighbor detection is to use a combination of cell lists (which you already have) and Verlet lists, i.e. use cell lists to create and periodically update Verlet lists. I've noticed a non-negligible speedup compared to just cell lists in my own code, and for people who really care about these things, I think it's pretty important.
I would be happy to chat more about this.
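The usual bookkeeping (LAMMPS uses the same half-skin criterion) builds each particle's list with an enlarged radius r_cut + skin and reuses it until some particle has moved more than skin / 2 since the last build, since two particles approaching each other can then close at most the whole skin. A minimal sketch of that logic; the names are hypothetical, and the list is built with an O(N^2) pass for brevity where a real implementation would build it from cell lists:

```python
import jax
import jax.numpy as jnp


def minimum_image(dR, box_size):
    return dR - box_size * jnp.round(dR / box_size)


def build_verlet_list(R, box_size, r_cut, skin):
    """Neighbor mask using the enlarged radius r_cut + skin.

    A production code would build this from cell lists; an O(N^2) pass is
    used here only to keep the sketch short.
    """
    dR = minimum_image(R[:, None, :] - R[None, :, :], box_size)
    d2 = jnp.sum(dR ** 2, axis=-1)
    N = R.shape[0]
    mask = (d2 < (r_cut + skin) ** 2) & ~jnp.eye(N, dtype=bool)
    return mask, R                      # keep reference positions for the check


def needs_rebuild(R, R_ref, box_size, skin):
    """The list stays valid while no particle has moved more than skin / 2."""
    dR = minimum_image(R - R_ref, box_size)
    max_disp = jnp.sqrt(jnp.max(jnp.sum(dR ** 2, axis=-1)))
    return max_disp > 0.5 * skin
```

Between rebuilds, the force loop only visits the masked pairs, and the cheap displacement check replaces a full neighbor search on most steps.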
This seems like a really excellent project, and I am excited about the potential for gradient-based optimization with your approach. However, in running the MD cookbook notebook, every simulation or minimization routine results in NaNs. I was surprised to find this even for a simple Lennard-Jones fluid at low density and with very small time steps (it sometimes occurs even for 1e-6). I am running directly from Google Colab, so it doesn't seem like an issue on my end. Is there a stable version, where this has been tested, that I could try rolling back to?
I also tried the jax.config.update("jax_debug_nans", True) setting and got the following output from the first Brownian dynamics simulation of the bubble raft, on the first step:
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
FloatingPointError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
528 try:
--> 529 return compiled_fun(*args)
530 except FloatingPointError:
21 frames
FloatingPointError: invalid value (nan) encountered in xla_call
During handling of the above exception, another exception occurred:
FloatingPointError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in _check_nans(name, xla_shape, buf)
350 if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
351 if np.any(np.isnan(buf.to_py())):
--> 352 raise FloatingPointError(f"invalid value (nan) encountered in {name}")
353
354 ### compiling jaxprs
FloatingPointError: invalid value (nan) encountered in mul
run_simulation is referenced in the text but never defined.
Unless I'm missing something, in the implementation of multiplicative_isotropic_cutoff, the function smooth_fn(dr) assumes that dr is an (N,N) array of distances, not an (N,N,d) array of displacements. However, when it's passed to, for example, smap.pair, it receives an (N,N,d) array of displacements. The solution is to change
def cutoff_fn(dr, *args, **kwargs):
    return smooth_fn(dr) * fn(dr, *args, **kwargs)
to
def cutoff_fn(dR, *args, **kwargs):
    dr = space.distance(dR)
    return smooth_fn(dr) * fn(dr, *args, **kwargs)
at line 232 of energy.py.
The function energy.multiplicative_isotropic_cutoff creates the need for potentials like soft_sphere to take an (N,N) array of distances rather than an (N,N,d) array of displacements. However, smap.pair, among others, still needs a function that takes (N,N,d) arrays of displacements. In the solution I just proposed to #55, energy.multiplicative_isotropic_cutoff serves this purpose by calling space.distance(dR) and passing the result to the wrapped energy function. I propose a similar function, let's call it energy.isotropic, whose only purpose is to call space.distance(dR). This would increase code reusability, for example:
def my_potential(dr, **kwargs):
    ...
def my_potential_pair(metric):
    return smap.pair(
        energy.isotropic(my_potential),
        metric)
def my_potential_pair_cutoff(metric, r_onset, r_cutoff):
    return smap.pair(
        energy.multiplicative_isotropic_cutoff(my_potential, r_onset, r_cutoff),
        metric)
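A self-contained sketch of the proposed wrappers, written as plain functions rather than jax-md code; `isotropic`, `multiplicative_cutoff`, and the polynomial taper inside it are hypothetical stand-ins (jax-md's own smoothing function may differ). The point is that the distance is computed once, in the wrapper, so the wrapped potential only ever sees an (N, N) array of distances while pair-mapping code still passes (N, N, d) displacements:

```python
import jax.numpy as jnp


def distance(dR):
    """Euclidean length over the last axis: (N, N, d) -> (N, N)."""
    return jnp.sqrt(jnp.sum(dR ** 2, axis=-1))


def isotropic(fn):
    """Lift fn(dr, ...) acting on distances to a function of displacements dR."""
    def wrapped(dR, *args, **kwargs):
        return fn(distance(dR), *args, **kwargs)
    return wrapped


def multiplicative_cutoff(fn, r_onset, r_cutoff):
    """smooth_fn(dr) * fn(dr), lifted to displacements via isotropic()."""
    def smooth_fn(dr):
        # Hypothetical C1 taper: 1 below r_onset, 0 above r_cutoff.
        x = jnp.clip((dr - r_onset) / (r_cutoff - r_onset), 0.0, 1.0)
        return 1.0 - x ** 2 * (3.0 - 2.0 * x)
    return isotropic(lambda dr, *a, **k: smooth_fn(dr) * fn(dr, *a, **k))


def soft_sphere(dr, sigma=1.0, epsilon=1.0):
    """Potential written purely in terms of distances."""
    return jnp.where(dr < sigma, epsilon / 2.0 * (1.0 - dr / sigma) ** 2, 0.0)
```

With this split, `soft_sphere` never needs to know whether it is called with or without a cutoff.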
https://github.com/google/jax-md/blob/64a4c80175be80391177bb549320aec387cc01b2/jax_md/energy.py#L166
Unless I'm totally missing something, the parameter alpha isn't ever used here. Probably a typo from when lennard_jones_neighbor_list was made based on soft_sphere_neighbor_list? :)
Using an apply function and passing apply(state) rather than apply(state, t=t) gives a NotImplementedError. Here's a pastebin with code that produces the error: https://pastebin.com/7TNqZkF2
The same code was working until an update was pushed last week.
Hi,
Should n_pos be reset to 0 when P < 0? Like
n_pos = np.where(P < 0, 0, n_pos)
after line 160 in minimize.py?
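For reference, in the FIRE scheme as published by Bitzek et al. (2006), an uphill step (P = F·v <= 0) resets everything: velocities are zeroed, dt is reduced, alpha is reset, and the count of consecutive downhill steps starts again, which is what the question suggests. A compact sketch following those rules; the state layout, the semi-implicit Euler integrator, unit masses, and the default parameters are assumptions and are not copied from minimize.py:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FireState(NamedTuple):
    R: jnp.ndarray
    V: jnp.ndarray
    dt: jnp.ndarray
    alpha: jnp.ndarray
    n_pos: jnp.ndarray    # int32 count of consecutive steps with P > 0


def fire(energy_fn, dt_start=0.1, dt_max=0.4, n_min=5, f_inc=1.1,
         f_dec=0.5, alpha_start=0.1, f_alpha=0.99):
    force_fn = jax.grad(lambda R: -energy_fn(R))

    def init(R):
        return FireState(R, jnp.zeros_like(R),
                         jnp.asarray(dt_start, dtype=R.dtype),
                         jnp.asarray(alpha_start, dtype=R.dtype),
                         jnp.asarray(0, dtype=jnp.int32))

    def step(state):
        R, V, dt, alpha, n_pos = state
        V = V + dt * force_fn(R)            # semi-implicit Euler, unit mass
        R = R + dt * V
        F = force_fn(R)
        P = jnp.vdot(F, V)
        # Mix the velocity towards the force direction.
        F_norm = jnp.sqrt(jnp.vdot(F, F))
        V_norm = jnp.sqrt(jnp.vdot(V, V))
        V = (1 - alpha) * V + alpha * V_norm * F / (F_norm + 1e-12)
        downhill = P > 0
        n_pos = jnp.where(downhill, n_pos + 1, 0)        # reset when uphill
        accelerate = downhill & (n_pos > n_min)
        dt = jnp.where(accelerate, jnp.minimum(dt * f_inc, dt_max),
                       jnp.where(downhill, dt, dt * f_dec))
        alpha = jnp.where(accelerate, alpha * f_alpha,
                          jnp.where(downhill, alpha, alpha_start))
        V = jnp.where(downhill, V, jnp.zeros_like(V))    # stop when uphill
        return FireState(R, V, dt, alpha, n_pos)

    return init, step
```

Keeping every branch of each `jnp.where` in the same dtype (int32 for `n_pos`) also keeps the state's types stable from step to step.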
Some code that was working a few days ago is now throwing an error. I know there have been a lot of changes recently, and I think this is an easy fix (but I'm not 100% sure what the right thing to do is, so I'm filing an issue rather than a PR).
Without going into the details (I can send code if necessary), I am doing a FIRE minimization using lax.scan, and it appears the type of the FireDescentState is changing. Specifically, n_pos is changing from an int64 to a float32. My guess is that in line 154 of minimize.py:
return FireDescentState(R, V, force(R, **kwargs), dt_start, alpha_start, 0) # pytype: disable=wrong-arg-count
the 0 should be explicitly cast to an f32?
Here's the error I'm getting:
FireDescentState(position=ShapedArray(float64[19,7,3]), velocity=ShapedArray(float64[19,7,3]), force=ShapedArray(float64[19,7,3]), dt=ShapedArray(float16[]), alpha=ShapedArray(float16[]), n_pos=ShapedArray(float32[]))
and
FireDescentState(position=ShapedArray(float64[19,7,3]), velocity=ShapedArray(float64[19,7,3]), force=ShapedArray(float64[19,7,3]), dt=ShapedArray(float16[]), alpha=ShapedArray(float16[]), n_pos=ShapedArray(int64[])).
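`lax.scan` requires the carry returned by the step function to have exactly the same shapes and dtypes as the carry passed in, so a bare `0` in the init state (int64 under x64) and a float32 value produced by the update trigger this error. A minimal illustration of the fix, pinning the counter's dtype explicitly and keeping the update in that dtype; the `State` tuple is hypothetical, not jax-md's `FireDescentState`:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    x: jnp.ndarray
    n_pos: jnp.ndarray


def init(x):
    # Pin the counter's dtype instead of relying on a bare Python 0,
    # whose dtype depends on context and on jax_enable_x64.
    return State(x, jnp.zeros((), dtype=jnp.int32))


def step(state, _):
    x = 0.5 * state.x
    # The weakly-typed literal 0 adopts n_pos's int32 dtype here.
    n_pos = jnp.where(jnp.sum(x) > 0, state.n_pos + 1, 0)
    return State(x, n_pos), None


final, _ = jax.lax.scan(step, init(jnp.ones(3)), None, length=10)
```

Whether the counter is stored as an integer or as an f32, as suggested above, matters less than making the init and update sides agree.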
This is straightforward. If it helps, here is how I did it (based on the soft_sphere potential).
import jax.numpy as np
from jax_md import space, smap

def morse(dR, D0=1.0, alpha=30.0, r0=1.0):
    dr = space.distance(dR)
    U = D0*(np.exp(-2.*alpha*(dr-r0))-2.*np.exp(-alpha*(dr-r0)))
    return np.array(U, dtype=dr.dtype)

def morse_pairwise(
        metric, species=None, D0=1.0, alpha=30.0, r0=1.0):
    return smap.pairwise(
        morse,
        metric,
        species=species,
        D0=D0,
        alpha=alpha,
        r0=r0)
In the new update to simulate.py, the parameter T_schedule was replaced with kT, with the idea that users can dynamically pass kT to the apply function to get a variable temperature. Is there a reason not to allow, for example, dt and gamma in simulate.brownian to be variable as well? Seems like an easy change with no downside.
Looks like c55ec95 may have broken the customizing potentials notebook. E.g., in various places, simulate.brownian calls need their parameterization changed from T_schedule=1.0 to kT=1.0.
I'll just go ahead and submit a patch and maybe a test to verify notebooks run without error?
Thanks a lot for this work, @sschoenholz and others, it's been really interesting to learn about.
Thanks for developing this promising library. I am quite new to the JAX framework, so I am not sure whether it can accommodate algorithms like the smooth particle mesh Ewald method (for which a GPU implementation is described here: https://pubs.acs.org/doi/10.1021/ct900275y).
Is this a reasonable feature to request for this library?
In functions such as soft_sphere_pair, parameters (e.g. sigma, epsilon, alpha) are explicitly cast to np.float32. However, sometimes it is actually important to keep some parameters at double precision. It would be nice if there were an option that directed soft_sphere_pair not to recast the parameters. Something like
energy_fn = energy.soft_sphere_pair(displacement, sigma=jnp.array([1.0],dtype=jnp.float64), dont_recast=True)
Is there a reason not to do this?
In the customizing_potentials_cookbook.ipynb example, it seems that the potential depends only on
However, for AMBER force field for example,
jax-md?
It is a very interesting project! I am wondering whether it is possible to include common force-field-based energy functions, for example if we hope to simulate real biological systems with both common force fields (such as Amber14SB, Charmm36) and machine-learning-based force fields?
Thanks!
When running a nose_hoover simulation that was working yesterday, today I am getting the following error:
(I can give you the specific code if needed, but I don't think it will be necessary.) However, if I change
state = init(key, Rinit)
to
state = init(key, Rinit, mass=float(1.0))
it works. My guess is this is an easy fix in line 70 of quantity.py since the default for mass is np.float32(1.0).
I've followed the instructions for the pip install, and also tried pip install'ing from the most recent clone of the repo. In both cases, I double-checked that I'm not inside the jax_md dir on my system. I then follow the space.py instructions and get the following error:
>>> from jax_md import space
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/workspace/wsa/jones289/miniconda3/envs/jaxmd/lib/python3.5/site-packages/jax_md/__init__.py", line 15, in <module>
from jax_md import space
File "/usr/workspace/wsa/jones289/miniconda3/envs/jaxmd/lib/python3.5/site-packages/jax_md/space.py", line 41, in <module>
from jax import custom_jvp
ImportError: cannot import name 'custom_jvp'
I tried previously with python 3.7, so I tried reinstalling with python=3.5, but that doesn't seem to be the issue. I'm not sure if I misread the documentation on how to setup jax_md but clearly something is wrong here.
Hi, I followed the instructions in Appendix A of your paper and ran the code below. It runs slowly, so I wanted to enable jax.jit for apply_fn and simply added apply_fn = jit(apply_fn). However, after 1000 iterations the results are totally different, while they are similar for the first 1-10 iterations. How can I deal with this? I also tried enabling 64-bit floats, but it didn't help.
from jax import random, jit
from jax_md import energy, space, simulate

N = 32
dt = 1e-1
temperature = 0.1
box_size = 5.0
key = random.PRNGKey(0)
displacement, shift = space.periodic(box_size)
energy_fn = energy.soft_sphere_pair(displacement)

def simulation(key):
    pos_key, sim_key = random.split(key)
    R = random.uniform(pos_key, (N, 2), maxval=box_size)
    init_fn, apply_fn = simulate.brownian(
        energy_fn, shift, dt, temperature)
    state = init_fn(sim_key, R)
    for i in range(1000):
        state = apply_fn(state)
    return state.position

positions = simulation(key)
print(positions)
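A plausible explanation (not confirmed in the thread): Brownian dynamics of interacting particles is chaotic, and `jit` compiles the step into a fused XLA computation whose floating-point rounding can differ in the last bits from op-by-op execution. Those tiny differences grow exponentially, so individual trajectories decorrelate after enough steps even though both are valid samples; comparing averaged quantities (energy, g(r)) rather than raw positions is the meaningful check. The toy example below uses the chaotic logistic map x -> 4x(1 - x) as a stand-in, perturbing the starting point by 1e-15 to mimic a rounding difference:

```python
import numpy as onp


def logistic_trajectory(x0, steps):
    """Iterate the chaotic logistic map x -> 4 x (1 - x) in float64."""
    xs = onp.empty(steps + 1)
    xs[0] = x0
    for n in range(steps):
        xs[n + 1] = 4.0 * xs[n] * (1.0 - xs[n])
    return xs


a = logistic_trajectory(0.3, 100)
b = logistic_trajectory(0.3 + 1e-15, 100)   # a "rounding-sized" perturbation
separation = onp.abs(a - b)
print(separation[10], separation.max())     # tiny early on, order one later
```

The separation roughly doubles each step, which mirrors the "similar for 1-10 iterations, totally different after 1000" behavior; switching to float64 only delays the divergence.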
When passing wrapped=False to the periodic space functions, it seems that in addition to displacement and shift, it is important to define a third function that allows the user to put the particles back in the box. A typical use case would be to use unwrapped positions during a simulation because they are (presumably) faster, and then do something like R = wrap(state.position) to get the final positions. My thinking is that there are two options for the API:
option 1:
displacement, shift = space.periodic(L)
displacement, shift, wrap = space.periodic(L, wrapped=False)
option 2:
displacement, shift = space.periodic(L)
displacement, shift = space.periodic(L, wrapped=False)
displacement, shift, wrap = space.periodic(L, wrapped=False, wrap_function=True)
The second option might be preferable since it would be backwards compatible. I would be happy to take a first shot at this if that would be helpful.
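A self-contained sketch of option 2, to make the proposed signatures concrete; `periodic` here is a hypothetical stand-in written with jax.numpy, not jax-md's implementation. Returning the extra `wrap` function only when asked keeps existing two-value unpacking working:

```python
import jax.numpy as jnp


def periodic(box_size, wrapped=True, wrap_function=False):
    """Hypothetical sketch of "option 2"; not jax-md's implementation."""

    def displacement(Ra, Rb):
        dR = Ra - Rb
        return dR - box_size * jnp.round(dR / box_size)   # minimum image

    def wrap(R):
        return jnp.mod(R, box_size)                        # back into [0, L)

    def shift(R, dR):
        return wrap(R + dR) if wrapped else R + dR

    if wrap_function:
        return displacement, shift, wrap
    return displacement, shift                             # backwards compatible
```

With `wrapped=False`, `shift` leaves positions unbounded, the displacement still applies the minimum image, and `wrap(state.position)` recovers in-box coordinates at the end.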
In the bubble tutorial, in pair_correlation_fun:
dr = np.where(dr > 1e-7, dr, 1e7)
I believe the last number should be 1e-7.
The first example I learned in molecular simulation was a hydration free energy calculation. Is it possible to do that with jax-md?
Hello folks! First of all, I am new to jax-md, so I ask for your understanding if my question is silly:
Is it possible to extract a single particle's contribution to the system's total energy during a simulation?
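One generic way to get per-particle contributions is to keep the full (N, N) matrix of pair energies and split each pair's energy equally between its two particles before summing over one axis, so the per-particle values add up to the total. A self-contained sketch with a harmonic soft-sphere potential; the function names are hypothetical and this does not rely on any jax-md API:

```python
import jax.numpy as jnp


def pair_energy_matrix(R, box_size, sigma=1.0, epsilon=1.0):
    """(N, N) matrix of harmonic soft-sphere pair energies in a periodic box."""
    N = R.shape[0]
    dR = R[:, None, :] - R[None, :, :]
    dR = dR - box_size * jnp.round(dR / box_size)      # minimum image
    dr = jnp.sqrt(jnp.sum(dR ** 2, axis=-1))
    U = jnp.where(dr < sigma, 0.5 * epsilon * (1.0 - dr / sigma) ** 2, 0.0)
    return jnp.where(jnp.eye(N, dtype=bool), 0.0, U)    # drop self-interaction


def per_particle_energy(R, box_size):
    # Each pair energy is split evenly between its two particles.
    return 0.5 * jnp.sum(pair_energy_matrix(R, box_size), axis=1)
```

The equal split is a convention; any partition that sums to the total is equally valid for the total energy, but the equal split is the common choice for pair potentials.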
It would be nice to be able to define rigid bodies (i.e. a collection of particles whose relative positions are fixed up to translations and rotations).
This is probably a fairly intensive request because it implies the need for rotational degrees of freedom, torques, etc.
Having the energy be time-dependent is really useful, but it seems that this may not interface with a temperature that changes with time. It would be nice to be able to make temperature schedules time-dependent in addition to the energy.
A few minor suggestions to help improve the clarity/performance of neighbor lists:
Split neighbor_fn (the function returned by partition.neighbor_list) into two, rather than having two different use cases of which only one is compatible with jit/grad. So, for example:
neighbor_fn, energy_fn = lennard_jones_neighbor_list(displacement, box_size)
#construct the neighbor list
nbrs = neighbor_fn(R)
#update the neighbor list
nbrs = neighbor_fn(R, nbrs)
would become
(construct_neighbor_fn, update_neighbor_fn), energy_fn = lennard_jones_neighbor_list(displacement, box_size)
#construct the neighbor list
nbrs = construct_neighbor_fn(R)
#update the neighbor list
nbrs = update_neighbor_fn(R, nbrs)
It would then be easier to communicate that update_neighbor_fn can be used inside jax transformations, but construct_neighbor_fn cannot.
Pass kwargs through the neighbor list convenience functions in energy.py. Unless I'm missing something, there is currently no way to change, e.g., capacity_multiplier using these convenience functions.
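A self-contained sketch of the suggested split, to show why it separates cleanly: the construct function reads concrete neighbor counts to size the buffer, while the update function takes its capacity from a static array shape and therefore jits. All names (`make_neighbor_fns`, `NeighborList`, the overflow flag) are hypothetical; this is not jax-md's partition module:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class NeighborList(NamedTuple):
    idx: jnp.ndarray           # (N, capacity) neighbor indices, padded with N
    did_overflow: jnp.ndarray  # True if some particle had > capacity neighbors


def make_neighbor_fns(box_size, cutoff, capacity_multiplier=1.25):
    """Returns (construct_neighbor_fn, update_neighbor_fn)."""

    def _mask(R):
        dR = R[:, None, :] - R[None, :, :]
        dR = dR - box_size * jnp.round(dR / box_size)
        N = R.shape[0]
        return (jnp.sum(dR ** 2, axis=-1) < cutoff ** 2) & ~jnp.eye(N, dtype=bool)

    def _build(R, capacity):
        N = R.shape[0]
        mask = _mask(R)
        idx = jnp.sort(jnp.where(mask, jnp.arange(N)[None, :], N), axis=-1)
        overflow = jnp.max(jnp.sum(mask, axis=-1)) > capacity
        return NeighborList(idx[:, :capacity], overflow)

    def construct_neighbor_fn(R):
        # Reads concrete neighbor counts to size the buffer, so it must run
        # eagerly (not under jit, vmap or grad).
        max_count = int(jnp.max(jnp.sum(_mask(R), axis=-1)))
        capacity = max(1, int(max_count * capacity_multiplier))
        return _build(R, capacity)

    def update_neighbor_fn(R, nbrs):
        # The capacity is read from a static array shape, so this is jit-safe.
        return _build(R, nbrs.idx.shape[-1])

    return construct_neighbor_fn, update_neighbor_fn
```

Typical use would be `nbrs = construct_neighbor_fn(R)` once, outside any transformation, then `jax.jit(update_neighbor_fn)(R, nbrs)` inside the simulation loop, re-running the construct step only if `did_overflow` becomes true.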
I know this has come up in the past, but I forget the rationale behind scaling dr_threshold, e.g.:
https://github.com/google/jax-md/blob/c55ec95999331844a72a1e1a3a3009276c60a098/jax_md/energy.py#L223
This is done for some of the potentials (soft_sphere, LJ) but not others (morse, bks). My feeling is that physically, dr_threshold is a buffer for how far particles can move before the neighbor lists have to be recomputed, but big particles don't typically move faster than small particles (depending on the dynamics/potentials). While I'm all for coming up with intelligent heuristics for the optimum dr_threshold, it's not clear that this is the right heuristic. But I'm happy to be wrong.
This assumes you have a system of monodisperse spheres with positions stored in R. Let me know if you want me to share my colab.
This code sets up a list of bonds from the topology of a sphere packing but sets the rest length equal to the current distance between the particles. Therefore, the energy should be zero, but for the dynamic case it is equal to the energy of the sphere packing, which is what you would expect if you use length=1.0 (i.e. the default static parameter).
# Assumes jax_enable_x64 is enabled and that N, R, displacement and
# energy_fn are already defined for the sphere packing.
import jax.numpy as np
from jax import vmap, grad
from jax_md import space, smap, energy

#use this to compare later
print(energy_fn(R))
print(np.linalg.norm(grad(energy_fn)(R)))
displacement_all = vmap(vmap(displacement, (0, None), 0), (None, 0), 0)
dR = displacement_all(R, R)
dr = space.distance(dR)
#0 if particles do not overlap
#rest length if particles do overlap
bond_lengths_matrix = np.where(dr<1.,dr,0.)
bond_lengths_matrix = np.triu(bond_lengths_matrix) #make upper triangular
index_list=np.dstack(np.meshgrid(np.arange(N), np.arange(N), indexing='ij'))
i_s = np.where(bond_lengths_matrix>0.0001, index_list[:,:,0], -1).flatten()
j_s = np.where(bond_lengths_matrix>0.0001, index_list[:,:,1], -1).flatten()
l_s = bond_lengths_matrix.flatten()
temp = np.transpose(np.array([i_s,j_s]))
bond_list = temp[(temp!=np.array([-1,-1]))[:,1]]
length_list = l_s[(temp!=np.array([-1,-1]))[:,1]]
k_list = np.full(length_list.shape, 1., dtype=np.float64)

def simple_spring_bond_float64(
        displacement_or_metric, bond, bond_type=None, length=1, epsilon=1, alpha=2):
    """Convenience wrapper to compute energy of particles bonded by springs."""
    length = np.array(length, np.float64)
    epsilon = np.array(epsilon, np.float64)
    alpha = np.array(alpha, np.float32)
    return smap.bond(
        energy.simple_spring,
        energy._canonicalize_displacement_or_metric(displacement_or_metric),
        bond,
        bond_type,
        length=length,
        epsilon=epsilon,
        alpha=alpha)

#Pass bond info dynamically (the energy here should be zero)
bond_energy_fn = simple_spring_bond_float64(displacement, None)
print(bond_energy_fn(R,bond_list, length=length_list, epsilon=k_list, alpha=2.))
print(np.linalg.norm(grad(bond_energy_fn)(R,bond_list, length=length_list, epsilon=k_list, alpha=2.)))
#pass bonds statically
bond_energy_fn = simple_spring_bond_float64(displacement, bond_list, length=length_list, epsilon=k_list, alpha=2.)
print(bond_energy_fn(R))
print(np.linalg.norm(grad(bond_energy_fn)(R)))
When passing a wrapped function to simulate.brownian, I get an "unexpected keyword argument 't'" error. This is easily resolved by passing through **kwargs, but I'd prefer if this were not necessary, as it is not documented and only needed in some cases.
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
from jax import random
from jax_md import space, energy, simulate

# Example setup (values assumed; not part of the original report).
side_length = 10.0
key = random.PRNGKey(0)
key, split = random.split(key)
R = random.uniform(split, (64, 2), maxval=side_length)

displacement, shift = space.periodic(side_length)
energy_fn = energy.soft_sphere_pair(displacement)

# Fails with "unexpected keyword argument 't'":
no_keywords_fn = lambda R: energy_fn(R)
init, apply = simulate.brownian(no_keywords_fn, shift, dt=0.00001, T_schedule=1.0, gamma=0.1)
key, split = random.split(key)
state = init(split, R)
apply(state)

# Works, because **kwargs (including t) are passed through:
keywords_fn = lambda R, **kwargs: energy_fn(R, **kwargs)
init, apply = simulate.brownian(keywords_fn, shift, dt=0.00001, T_schedule=1.0, gamma=0.1)
key, split = random.split(key)
state = init(split, R)
apply(state)
This is not an issue for nvt_nose_hoover or for minimize.fire_descent.
The same issue also appears for nve.
And for nvt_langevin this issue appears for both init and apply.
Setting the jax_debug_nans flag to True can prevent you from computing energy.lennard_jones(x).
By using np.nan_to_num separately on idr6 and idr12, this issue can be circumvented. Alternatively, jit-compiling the energy function also works.
I don't know if this issue comes up in a different place or if it is even worth solving.
Here is a minimal repro:
import jax.numpy as jnp
from jax.config import config
config.update('jax_enable_x64', True)
config.update("jax_debug_nans", True)
from jax import random, grad, jit
from jax_md import space, energy
N = 5
dimension = 2
box_size = 12
displacement, shift = space.periodic(box_size)
key = random.PRNGKey(0)
key, split = random.split(key)
R = random.uniform(key, (N,dimension), minval=0.0, maxval=box_size, dtype=jnp.float64)
energy_fn = energy.lennard_jones_pair(displacement)
#Works
print(jit(energy_fn)(R))
#Doesn't work
print(energy_fn(R))
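A plausible mechanism (not confirmed in the thread): a dense pair computation includes i == j self-pairs with dr = 0, so idr6 and idr12 are infinite there, and inf - inf or inf * 0 yields NaN before a mask removes it. With jax_debug_nans, un-jitted code checks the output of every primitive and trips on that intermediate, whereas a jitted function is only flagged if its final outputs contain NaN. The usual remedy is the "double where": substitute a harmless distance before dividing, then discard the placeholder result. A sketch for a bare Lennard-Jones pair energy (the function name is hypothetical; this is not jax-md's implementation):

```python
import jax
import jax.numpy as jnp


def safe_lennard_jones(dr, sigma=1.0, epsilon=1.0):
    """4 eps [(sigma/r)^12 - (sigma/r)^6], with r == 0 entries mapped to 0."""
    valid = dr > 0
    # First where: never divide by zero, so no inf/NaN intermediates appear.
    safe_dr = jnp.where(valid, dr, 1.0)
    idr6 = (sigma / safe_dr) ** 6
    idr12 = idr6 ** 2
    # Second where: discard the placeholder values.
    return jnp.where(valid, 4.0 * epsilon * (idr12 - idr6), 0.0)
```

The same pattern also keeps gradients finite, which `nan_to_num` on the forward values alone does not guarantee.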
This is a low priority comment as I do not currently need this for my work.
My understanding is that in energy.multiplicative_isotropic_cutoff, you can pass r_onset and r_cutoff as scalars or as [n,m] arrays, but not on a per-species level. It is often desirable for these parameters to be functions of the other parameters in the potential. For example, for the Morse potential, V(r) = D0 * (1 - exp(-alpha * (r - sigma)))^2 - D0, the value V(a / alpha + sigma) is independent of alpha and sigma. Therefore, it would be convenient to set r_onset = a_onset / alpha + sigma and r_cutoff = a_cutoff / alpha + sigma, for scalar a_onset and a_cutoff. This way, species with different alpha or sigma would still be using "the same potential", if that makes sense.
This logic also leads to a proposed change to energy.lennard_jones_pair, which currently sets r_onset and r_cutoff using np.max(sigma), violating this "same potential" idea.
I don't know if this is important enough to warrant the change, but I will say that I have seen people use the smoothing function S(r) as more than just a small perturbation at the tail of a potential, e.g. setting r_onset to the minimum of a potential (r_onset = sigma). The current implementation makes this impossible for systems with species.
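A sketch of the per-species idea: expand species-pair parameter matrices to per-particle-pair arrays, then derive r_onset and r_cutoff from them with r = a / alpha + sigma. Function names and the example parameters are hypothetical; this only illustrates the proposal and is not jax-md code. Evaluating the Morse potential at each pair's own r_cutoff gives the same value, D0 (1 - e^(-a_cutoff))^2 - D0, for every species pair, which is the "same potential" property described above:

```python
import jax.numpy as jnp


def pair_parameters(species, alpha, sigma, a_onset, a_cutoff):
    """Expand per-species-pair Morse parameters to per-particle-pair arrays.

    species: (N,) int array; alpha, sigma: symmetric (n_species, n_species).
    r_onset and r_cutoff follow r = a / alpha + sigma for every pair.
    """
    si, sj = species[:, None], species[None, :]
    alpha_ij, sigma_ij = alpha[si, sj], sigma[si, sj]
    r_onset = a_onset / alpha_ij + sigma_ij
    r_cutoff = a_cutoff / alpha_ij + sigma_ij
    return alpha_ij, sigma_ij, r_onset, r_cutoff


def morse(r, D0, alpha, sigma):
    """V(r) = D0 (1 - exp(-alpha (r - sigma)))^2 - D0."""
    return D0 * (1.0 - jnp.exp(-alpha * (r - sigma))) ** 2 - D0
```

The resulting (N, N) r_onset and r_cutoff arrays could then feed a distance-based smoothing function elementwise.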
Thanks to my intern @scnlong for first noticing this and helping put the repo together.
When simulating a system with many species but with fixed total number of particles, the computation time increases (pretty dramatically) with the number of species. See the following repo.
https://colab.research.google.com/drive/1-crXlbZ0gc94e3ECE8HMVPcpueNJRamD?usp=sharing
It's not in the repo, but the memory usage also seems to increase in a similar way.
I believe this is a feature request, though perhaps I am missing an obvious solution that is already available.
I would like to be able to selectively fix atom positions. I.e. be able to run an energy minimization with some atom locations fixed, then "release" them and run minimization again (with a different subset of atoms fixed).
I cannot figure out how to do this, since all positions are handled through the R array, whose elements are all updated by the minimization algorithms at every step.
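One common approach, not a jax-md API: keep all atoms in R but zero the forces on the fixed ones, so a force-driven minimizer never moves them; "releasing" atoms then just means building a new mask. A sketch with plain steepest descent; the function names and step size are assumptions, and the same masking idea should carry over to other force-based minimizers:

```python
import jax
import jax.numpy as jnp


def masked_force_fn(energy_fn, free):
    """Forces with the rows of fixed atoms zeroed; free is an (N,) bool mask."""
    force_fn = jax.grad(lambda R: -energy_fn(R))
    return lambda R: jnp.where(free[:, None], force_fn(R), 0.0)


def steepest_descent(energy_fn, R, free, step_size=0.1, steps=200):
    force_fn = jax.jit(masked_force_fn(energy_fn, free))
    for _ in range(steps):
        R = R + step_size * force_fn(R)    # fixed atoms never move
    return R
```

To run a second minimization with a different subset fixed, call `steepest_descent` again on the result with a new `free` mask.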
The smap.pairwise function promotes an energy function that acts on a single pair to one that acts on all pairs. Instead, it would be nice to have a version of this that promoted a function to act only on specified pairs (that in principle could change during a simulation). This could be used to create permanent bonds. Perhaps the smap.pairwise function stays the same but you just use a new metric?
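Promoting a pair function to act only on an explicit list of pairs amounts to gathering the two endpoints of each bond and summing; because the bond list is an ordinary array argument, it can change during a simulation (and, with an unchanged shape, without recompiling under jit). A self-contained sketch with harmonic springs; the names are hypothetical and this is not smap:

```python
import jax
import jax.numpy as jnp


def bond_energy(R, bonds, rest_length, k, box_size):
    """Harmonic springs acting only on the listed pairs.

    bonds: (M, 2) int array of particle indices; rest_length, k: (M,) arrays.
    """
    dR = R[bonds[:, 0]] - R[bonds[:, 1]]
    dR = dR - box_size * jnp.round(dR / box_size)      # minimum image
    dr = jnp.sqrt(jnp.sum(dR ** 2, axis=-1))
    return 0.5 * jnp.sum(k * (dr - rest_length) ** 2)
```

Forces follow from `jax.grad(bond_energy)` with respect to `R`, and the bond energy can simply be added to a pairwise energy to model permanent bonds on top of non-bonded interactions.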
This is really easy to do by hand, but might be nice to have it automatically built in as an option somehow.
Basically, whenever there is a pair interaction that doesn't have a strict cutoff, it is common to truncate it at some finite distance r_cut. However, you then get a small discontinuity in the potential at r_cut, so it is common to adjust the potential slightly so that it is C1 (continuous in the potential and its first derivative). The best way I've seen this done is the 'xplor' mode in HOOMD Blue, where you have a smoothing function S(r) that satisfies the following:
You then just multiply your potential by S(r) and choose r_on and r_cut to be in the tail of the potential.
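For reference, HOOMD-blue documents its xplor smoothing function as S(r) = 1 for r < r_on, S(r) = (r_cut^2 - r^2)^2 (r_cut^2 + 2 r^2 - 3 r_on^2) / (r_cut^2 - r_on^2)^3 for r_on <= r <= r_cut, and S(r) = 0 beyond r_cut; both S and dS/dr are continuous at r_on and r_cut. A direct transcription in JAX (function names hypothetical):

```python
import jax
import jax.numpy as jnp


def xplor_smoothing(r, r_on, r_cut):
    """S(r) = 1 (r < r_on), 0 (r >= r_cut), and in between
    (r_cut^2 - r^2)^2 (r_cut^2 + 2 r^2 - 3 r_on^2) / (r_cut^2 - r_on^2)^3."""
    r2, on2, cut2 = r ** 2, r_on ** 2, r_cut ** 2
    s = (cut2 - r2) ** 2 * (cut2 + 2 * r2 - 3 * on2) / (cut2 - on2) ** 3
    return jnp.where(r < r_on, 1.0, jnp.where(r < r_cut, s, 0.0))


def smoothed(potential, r_on, r_cut):
    """Multiply a pair potential V(r) by S(r); choose r_on, r_cut in the tail."""
    return lambda r, *a, **k: xplor_smoothing(r, r_on, r_cut) * potential(r, *a, **k)
```

Because S is C1, the product S(r) V(r) is C1 as well, so forces have no jump at r_cut.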
I want to add new 2- and 3-body potentials. Is there any documentation I can start with? In particular, I want to use an NN force field, trained using TensorFlow, in JAX-MD.
Thanks,
Alireza
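Many-body terms can be written directly in jax.numpy, with forces obtained from `jax.grad`. A dense O(N^3) sketch of a Stillinger-Weber-style angular term, using a hypothetical polynomial taper g(r) = (1 - r/cutoff)^2 and hypothetical names (not jax-md's implementation); a TensorFlow-trained network would presumably need to be re-expressed as a JAX function to be differentiated the same way:

```python
import jax
import jax.numpy as jnp


def three_body_energy(R, cutoff=1.5, lam=1.0):
    """E = lam * sum_i sum_{j<k} g(r_ij) g(r_ik) (cos theta_jik + 1/3)^2,
    with g(r) = (1 - r / cutoff)^2 for r < cutoff and 0 otherwise."""
    N = R.shape[0]
    dR = R[None, :, :] - R[:, None, :]             # dR[i, j] = R_j - R_i
    off_diag = ~jnp.eye(N, dtype=bool)
    # Double-where keeps sqrt away from 0 on the diagonal (finite gradients).
    r = jnp.sqrt(jnp.where(off_diag, jnp.sum(dR ** 2, axis=-1), 1.0))
    g = jnp.where(off_diag & (r < cutoff), (1.0 - r / cutoff) ** 2, 0.0)
    unit = dR / r[..., None]
    cos = jnp.einsum('ijd,ikd->ijk', unit, unit)   # cos of angle j-i-k
    w = g[:, :, None] * g[:, None, :]              # g(r_ij) g(r_ik)
    not_same = ~jnp.eye(N, dtype=bool)[None, :, :]  # exclude j == k
    terms = jnp.where(not_same, w * (cos + 1.0 / 3.0) ** 2, 0.0)
    return lam * 0.5 * jnp.sum(terms)              # (j, k) and (k, j) both counted
```

The (cos + 1/3)^2 factor vanishes at the tetrahedral angle, so a triplet with that geometry contributes nothing; a real implementation would restrict the triple loop with a neighbor list instead of the dense N^3 arrays.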
When creating batches
def batch(key):
    steps_per_epoch = no_training_samples // batch_size
    train_epochs = train_steps // steps_per_epoch
    for s in range(train_epochs):
        key, split = random.split(key)
        permutation = random.shuffle(split, training_samples)
        positions = train_positions[permutation]
        features = train_features[permutation]
        for i in range(0, no_training_samples, batch_size):
            batch_data = (positions[permutation[i:i + batch_size]],
                          features[permutation[i:i + batch_size]])
            yield batch_data
You first permute the training data,
positions = train_positions[permutation]
and then access the permuted data via the permutation again,
positions[permutation[i:i + batch_size]],
which effectively shuffles and then unshuffles the data, so you never get randomly shuffled batches.
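Indexing the already-permuted arrays with `permutation` a second time composes the permutation with itself instead of applying it once. A sketch of the intended behavior, drawing one permutation per epoch and indexing the original arrays with it exactly once; the function name and signature are hypothetical, and it uses `jax.random.permutation`:

```python
import jax
import jax.numpy as jnp


def batches(key, positions, features, batch_size, epochs):
    """Yield shuffled (positions, features) mini-batches.

    One permutation is drawn per epoch and used exactly once to index the
    original arrays. A trailing partial batch is dropped.
    """
    n = positions.shape[0]
    for _ in range(epochs):
        key, split = jax.random.split(key)
        perm = jax.random.permutation(split, n)
        for i in range(0, n - batch_size + 1, batch_size):
            idx = perm[i:i + batch_size]
            yield positions[idx], features[idx]
```

Indexing positions and features with the same `idx` keeps each sample paired with its label.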
In addition to simulate.nve and simulate.nvt, it would be nice to have simulate.langevin.
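For reference, one widely used discretization of underdamped Langevin dynamics is the BAOAB splitting of Leimkuhler and Matthews: half kick (B), half drift (A), an exact Ornstein-Uhlenbeck update of the velocities (O), half drift, half kick. A minimal sketch assuming a user-supplied `force_fn`; `baoab_step` and its parameters are hypothetical names, not jax-md's API:

```python
import jax
import jax.numpy as jnp


def baoab_step(key, R, V, force_fn, dt, kT, gamma, mass=1.0):
    """One BAOAB step of underdamped Langevin dynamics."""
    c1 = jnp.exp(-gamma * dt)
    c2 = jnp.sqrt(1.0 - c1 ** 2)
    V = V + 0.5 * dt * force_fn(R) / mass                  # B: half kick
    R = R + 0.5 * dt * V                                   # A: half drift
    xi = jax.random.normal(key, V.shape, dtype=V.dtype)
    V = c1 * V + c2 * jnp.sqrt(kT / mass) * xi             # O: exact OU update
    R = R + 0.5 * dt * V                                   # A: half drift
    V = V + 0.5 * dt * force_fn(R) / mass                  # B: half kick
    return R, V
```

In the limit of large gamma the velocities equilibrate almost instantly and the dynamics approaches the overdamped (Brownian) case.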
When analyzing a trajectory, it is sometimes necessary to "unwrap" motion across periodic boundaries (e.g. when calculating the diffusion constant of a particle). While this can sometimes be inferred from adjacent frames in a trajectory, this is not very robust and can lead to errors in some situations. I don't know what the best approach here is... LAMMPS has an integer variable (e.g. n1x, n1y, n1z, n2x, etc) that can be used to unwrap the positions, but there might be better ways.
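One way to follow the LAMMPS approach is to count boundary crossings at the moment positions are wrapped, rather than inferring them from adjacent frames afterwards. A sketch for a box of edge length `box_size` in every dimension; the function names are hypothetical and `images` plays the role of integer image flags:

```python
import jax
import jax.numpy as jnp


def shift_with_images(R, dR, images, box_size):
    """Apply a displacement, wrap into [0, L), and count boundary crossings.

    images plays the role of LAMMPS-style integer image flags (ix, iy, iz).
    """
    R_new = R + dR
    crossings = jnp.floor(R_new / box_size)
    return R_new - crossings * box_size, images + crossings.astype(jnp.int32)


def unwrap(R, images, box_size):
    """Recover continuous (unwrapped) positions, e.g. for diffusion constants."""
    return R + images * box_size
```

This stays correct even when a particle moves a large distance between saved frames, which is where inferring crossings from adjacent frames fails.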
The renderer has a bug: it cannot plot time-varying per-particle parameters.
Description of problem:
renderer.Disk can take time-varying per-particle parameters as arguments. See this snippet of the documentation:
position: An array of shape `(steps, count, dim)` or `(count, dim)`
specifying possibly time varying positions. Here `dim` is the spatial dimension.
size: An array of shape `(steps, count)`, `(count,)`, or `()` specifying
possibly time-varying / per-disk diameters.
color: An array of shape `(steps, count, 3)` or `(count,)` specifying
possibly time-varying / per-disk RGB colors.
count: The number of disks.
However, it does not seem able to plot time-varying parameters: it plots the parameter values at the first timepoint and does not change them at later timepoints.
How to reproduce bug: run following script
import numpy as onp
import jax.numpy as np
from jax import random
from jax_md.colab_tools import renderer
from jax.ops import index_update, index, index_add

n_timepoints = 4
n_particles = 10
box_size = 10.  # not defined in the original snippet; positions lie in [0, 10)
#make a random array of positions with size: (n_timepoints, n_particles, 2)
pos = 10 * onp.random.rand(n_timepoints, n_particles, 2)
#make an array of diameters with size: (n_timepoints, n_particles)
#the diameters of all particles could vary in time; in this example they are always equal to one
diameter = np.ones((n_timepoints, n_particles))
#set one particle's diameter equal to 2 at the first timepoint
diameter = index_update(diameter, index[0, 0], 2)
#the renderer should plot the diameters over time, but actually only plots the diameters at the first timepoint,
#and indeed one particle has diameter 2 at every timepoint rather than just at the first
renderer.render(box_size,
                { 'particles': renderer.Disk(pos, diameter)},
                buffer_size=50)
renderer.render(box_size,
                { 'particles': renderer.Disk(position, diameter, color)})
Lines 75-76 in jax-md/jax_md/energy.py:
Here dr is normalized so that a value of 1 means the bubbles are one radius apart. Therefore, your function is only nonzero if two bubbles are less than one radius apart (1); however, the bubbles begin to touch when they are 2 radii apart (2). Thus, your function only becomes nonzero after the bubbles are compressed halfway.