
jax-md's Introduction

JAX, M.D.

Accelerated, Differentiable, Molecular Dynamics

Quickstart | Reference docs | Paper | NeurIPS 2020


Molecular dynamics is a workhorse of modern computational condensed matter physics. It is frequently used to simulate materials to observe how small-scale interactions can give rise to complex large-scale phenomenology. Most molecular dynamics packages (e.g. HOOMD Blue or LAMMPS) are complicated, specialized pieces of code that are many thousands of lines long. They typically involve significant code duplication to allow for running simulations on CPU and GPU. Additionally, large amounts of code are often devoted to taking derivatives of quantities to compute functions of interest (e.g. gradients of energies to compute forces).

However, recent work in machine learning has led to significant software developments that might make it possible to write more concise molecular dynamics simulations that offer a range of benefits. Here we target JAX, which allows us to write Python code that gets compiled to XLA and run on CPU, GPU, or TPU. Moreover, JAX allows us to take derivatives of Python code. Thus, not only is this molecular dynamics simulation automatically hardware accelerated, it is also end-to-end differentiable. This should allow for some interesting experiments that we're excited to explore.

JAX MD is a research project that is currently under development. Expect sharp edges and possibly some API-breaking changes as we continue to support a broader set of simulations. JAX MD is a functional and data-driven library. Data is stored in arrays or tuples of arrays, and functions transform data from one state to another.

Getting Started

For a video introducing JAX MD along with a demo, check out this talk from the Physics meets Machine Learning series:

Science Meets ML Talk

To get started playing around with JAX MD, check out the following Colab notebooks on Google Cloud without needing to install anything. For a very simple introduction, I would recommend the Minimization example. For an example showcasing many of the features of JAX MD, check out the JAX MD cookbook.

You can install JAX MD locally with pip,

pip install jax-md --upgrade

If you want the latest version, you can install from head:

git clone https://github.com/google/jax-md
pip install -e jax-md

Overview

We now summarize the main components of the library.

Spaces (space.py)

In general we must have a way of computing the pairwise distance between atoms. We must also have efficient strategies for moving atoms in some space that may or may not be globally isomorphic to R^N. For example, periodic boundary conditions are commonplace in simulations and must be respected. Spaces are defined as a pair of functions, (displacement_fn, shift_fn). Given two points, displacement_fn(R_1, R_2) computes the displacement vector between them. If you would like to compute displacement vectors between all pairs of points in a given (N, dim) matrix, the function space.map_product appropriately vectorizes displacement_fn. It is often useful to define a metric instead of a displacement function, in which case you can use the helper function space.metric to convert a displacement function to a metric function. Given a point and a shift, shift_fn(R, dR) displaces the point R by an amount dR.

The following spaces are currently supported: space.free() for unbounded systems, space.periodic(box_size) for periodic boundary conditions, and space.periodic_general(T) for periodic boundary conditions under a general affine transformation T.

Example:

from jax_md import space
box_size = 25.0
displacement_fn, shift_fn = space.periodic(box_size)
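
These helpers compose; here is a short sketch of the vectorization and metric helpers described above:

metric_fn = space.metric(displacement_fn)                      # displacement -> distance
pairwise_displacement_fn = space.map_product(displacement_fn)  # all-pairs displacements
pairwise_distance_fn = space.map_product(metric_fn)            # all-pairs distances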

Potential Energy (energy.py)

In the simplest case, molecular dynamics calculations are often based on a pair potential that is defined by a user. This is then used to compute a total energy whose negative gradient gives forces. One of the very nice things about JAX is that we get forces for free! The second part of the code is devoted to computing energies.

We provide a number of classical potentials, including soft spheres, Lennard-Jones, Morse, simple springs (for bonds), and BKS silica.

We also provide neural network potentials, including Behler-Parrinello networks and graph neural network potentials.

For finite-ranged potentials it is often useful to consider only interactions within a certain neighborhood. We include _neighbor_list variants of the above potentials that use a list of neighbors (see below) for optimization.
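
For example, a sketch using the Lennard-Jones variant (displacement_fn and box_size as in the Spaces example; positions R as in the example below):

neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(displacement_fn, box_size)
nbrs = neighbor_fn.allocate(R)      # build the neighbor list
print(energy_fn(R, neighbor=nbrs))  # energy restricted to neighboring pairs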

Example:

import jax.numpy as np
from jax import random
from jax_md import energy, quantity
N = 1000
spatial_dimension = 2
key = random.PRNGKey(0)
R = random.uniform(key, (N, spatial_dimension), minval=0.0, maxval=1.0)
energy_fn = energy.lennard_jones_pair(displacement_fn)
print('E = {}'.format(energy_fn(R)))
force_fn = quantity.force(energy_fn)
print('Total Squared Force = {}'.format(np.sum(force_fn(R) ** 2)))
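
Note that quantity.force(energy_fn) is just the negative gradient of the energy; a one-line sketch of the equivalence:

from jax import grad
manual_force_fn = lambda R: -grad(energy_fn)(R)  # equivalent to quantity.force(energy_fn)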

Dynamics (simulate.py, minimize.py)

Given an energy function and a system, there are a number of dynamics that are useful to simulate. The simulation code is based on the structure of the optimizers found in JAX. In particular, each simulation function returns an initialization function and an update function. The initialization function takes a set of positions and creates the necessary dynamical state variables. The update function applies a single step of dynamics to the dynamical state variables and returns an updated state.

We include several different kinds of dynamics. However, there is certainly room to add more, e.g. constant-strain simulations.

It is often desirable to find an energy minimum of the system. We provide two methods to do this: simple gradient descent minimization, which is mostly for pedagogical purposes since it often performs poorly, and the FIRE algorithm, which often converges significantly faster. Moreover, a common experiment to run in the context of molecular dynamics is to simulate a system at fixed volume and temperature.

We provide the following dynamics: NVE, NVT (both Nose-Hoover and Langevin), and Brownian dynamics, along with gradient descent and FIRE energy minimization.

Example:

from jax_md import simulate
temperature = 1.0
dt = 1e-3
init, update = simulate.nvt_nose_hoover(energy_fn, shift_fn, dt, temperature)
state = init(key, R)
for _ in range(100):
  state = update(state)
R = state.position
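
In practice the update function is usually jit-compiled so that each step runs as a single XLA computation; a minimal sketch:

from jax import jit
update = jit(update)
for _ in range(100):
  state = update(state)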

Spatial Partitioning (partition.py)

In many applications, it is useful to construct spatial partitions of particles / objects in a simulation.

We provide the following methods: cell lists and neighbor lists.

Cell List Example:

from jax_md import partition

cell_size = 5.0
capacity = 10
cell_list_fn = partition.cell_list(box_size, cell_size, capacity)
cell_list_data = cell_list_fn.allocate(R)

Neighbor List Example:

from jax_md import partition

neighbor_list_fn = partition.neighbor_list(displacement_fn, box_size, cell_size)
neighbors = neighbor_list_fn.allocate(R) # Create a new neighbor list.

# Do some simulating....

neighbors = neighbors.update(R)  # Update the neighbor list without resizing.
if neighbors.did_buffer_overflow:  # Couldn't fit all the neighbors into the list.
  neighbors = neighbor_list_fn.allocate(R)  # So create a new neighbor list.

There are three different formats of neighbor list supported: Dense, Sparse, and OrderedSparse. Dense neighbor lists store neighbors in a (particle_count, neighbors_per_particle) array; Sparse neighbor lists store neighbors in a (2, total_neighbors) array of pairs; OrderedSparse neighbor lists are like Sparse neighbor lists, but they only store pairs such that i < j.
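
For example, a sketch of selecting a format (assuming the partition.NeighborListFormat enum; Dense is the default):

neighbor_list_fn = partition.neighbor_list(
    displacement_fn, box_size, cell_size,
    format=partition.NeighborListFormat.Sparse)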

Development

JAX MD is under active development. We have very limited development resources, so we typically focus on adding features that will have high impact for researchers using JAX MD (including us). Please don't hesitate to open feature requests to help us guide development. We more than welcome contributions!

Technical gotchas

GPU

You must follow JAX's GPU installation instructions to enable GPU support.

64-bit precision

To enable 64-bit precision, set the respective JAX flag before importing jax_md (see the JAX guide), for example:

from jax.config import config
config.update("jax_enable_x64", True)

Publications

JAX MD has been used in the following publications. If you don't see your paper on the list but you used JAX MD, let us know and we'll add it to the list!

  1. A Differentiable Neural-Network Force Field for Ionic Liquids. (J. Chem. Inf. Model. 2022)
    H. Montes-Campos, J. Carrete, S. Bichelmaier, L. M. Varela, and G. K. H. Madsen
  2. Correlation Tracking: Using simulations to interpolate highly correlated particle tracks. (Phys. Rev. E. 2022)
    E. M. King, Z. Wang, D. A. Weitz, F. Spaepen, and M. P. Brenner
  3. Optimal Control of Nonequilibrium Systems Through Automatic Differentiation.
    M. C. Engel, J. A. Smith, and M. P. Brenner
  4. Graph Neural Networks Accelerated Molecular Dynamics. (J. Chem. Phys. 2022)
    Z. Li, K. Meidani, P. Yadav, and A. B. Farimani
  5. Gradients are Not All You Need.
    L. Metz, C. D. Freeman, S. S. Schoenholz, and T. Kachman
  6. Lagrangian Neural Network with Differential Symmetries and Relational Inductive Bias.
    R. Bhattoo, S. Ranu, and N. M. A. Krishnan
  7. Efficient and Modular Implicit Differentiation.
    M. Blondel, Q. Berthet, M. Cuturi, R. Frostig, S. Hoyer, F. Llinares-López, F. Pedregosa, and J.-P. Vert
  8. Learning neural network potentials from experimental data via Differentiable Trajectory Reweighting. (Nature Communications 2021)
    S. Thaler and J. Zavadlav
  9. Learn2Hop: Learned Optimization on Rough Landscapes. (ICML 2021)
    A. Merchant, L. Metz, S. S. Schoenholz, and E. D. Cubuk
  10. Designing self-assembling kinetics with differentiable statistical physics models. (PNAS 2021)
    C. P. Goodrich, E. M. King, S. S. Schoenholz, E. D. Cubuk, and M. P. Brenner

Citation

If you use the code in a publication, please cite the repo using the .bib,

@inproceedings{jaxmd2020,
 author = {Schoenholz, Samuel S. and Cubuk, Ekin D.},
 booktitle = {Advances in Neural Information Processing Systems},
 publisher = {Curran Associates, Inc.},
 title = {JAX M.D. A Framework for Differentiable Physics},
 url = {https://papers.nips.cc/paper/2020/file/83d3d4b6c9579515e1679aca8cbc8033-Paper.pdf},
 volume = {33},
 year = {2020}
}


jax-md's Issues

Precision problem with jax.jit

Hi, I followed the instructions in Appendix A of your paper and ran the code below. It runs slowly, so I want to enable jax.jit for apply_fn and simply add apply_fn = jit(apply_fn). However, after 1000 iterations the results are totally different, while the results are similar for 1-10 iterations. How can I deal with this? I also tried enabling 64-bit floats, but it didn't help.

from jax import random, jit
from jax_md import energy, space, simulate


N = 32
dt = 1e-1
temperature = 0.1
box_size = 5.0
key = random.PRNGKey(0)
displacement, shift = space.periodic(box_size)
energy_fn = energy.soft_sphere_pair(displacement)


def simulation(key):
    pos_key, sim_key = random.split(key)
    R = random.uniform(pos_key, (N, 2), maxval=box_size)
    init_fn, apply_fn = simulate.brownian(
        energy_fn, shift, dt, temperature)
    state = init_fn(sim_key, R)
    for i in range(1000):
        state = apply_fn(state)
    return state.position


positions = simulation(key)
print(positions)

scaling typo

In the bubble tutorial, in pair_correlation_fun:

dr = np.where(dr > 1e-7, dr, 1e7)

I believe the last number should be 1e-7

Possible bug with floats when setting the mass

When running a nose_hoover simulation that was working yesterday, today I am getting the following error:

https://pastebin.com/PPMuSTVh

(I can give you the specific code if needed, but I don't think it will be necessary.) However, if I change

state = init(key, Rinit)

to

state = init(key, Rinit, mass=float(1.0))

it works. My guess is this is an easy fix in line 70 of quantity.py since the default for mass is np.float32(1.0).

Computing stress

Hello! This is a very cool project!

I have a few questions related to using jax-md for condensed matter MD, in particular related to computing the stress. I'll give some context below, feel free to skip if this is obvious already!


In that context, the stress tensor sigma is defined as the derivative of the total energy E with respect to a 3x3 infinitesimal strain tensor epsilon:

sigma_ab = 1/V * (d E / d epsilon_ab) at epsilon = 0

epsilon describes a strain transformation of all real space coordinates R as:

R -> (1 + epsilon) @ R

This transformation is applied to all coordinates, i.e. the atoms in the unit cell and the basis vectors.


In an autodiff framework this shouldn't be too difficult to implement. SchNetPack does it by applying the strain transformation to R and the basis before the forward pass, and then computing the gradient. In jax-md, however, the problem is that the basis is treated as a parameter to the energy function (via the displacement function) and only the coordinates are treated as an argument, so it's easy to compute the gradient with respect to R, but not with respect to the basis.

The logical way to apply the strain transformation in jax-md is in the displacement function, since if one transforms the displacement_fn (and the shift_fn, but this seems to not be needed to only compute energies), the energy behaves as though it was computed in transformed coordinates. (Since only displacement vectors are ever used to compute energies.)

I've tried to implement this experimentally here by simply wrapping the periodic_general displacement function. This works and gives correct results, but seems to be un-jittable, presumably because I'm re-defining the energy function when it's called.
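
Roughly, the wrapping I mean looks like this (a simplified, untested sketch of the idea, not my actual implementation):

import jax.numpy as np

def strained_displacement(displacement_fn):
    # Apply (1 + strain) to each displacement vector; differentiating the
    # energy with respect to `strain` at zero then gives V * sigma.
    def wrapped(Ra, Rb, strain=np.zeros((3, 3)), **kwargs):
        dR = displacement_fn(Ra, Rb, **kwargs)
        return dR + np.dot(strain, dR)
    return wrapped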

An alternative approach would be to "hack" the feature of jax-md that allows the T parameter of periodic_general to be a callable: if one lets T be a function that accepts the strain as argument:

def T(strain: Array = np.zeros((3, 3), dtype=np.double), t=0) -> Array:
    strain_transformation = np.eye(N=3, M=3, dtype=np.double) + strain
    return transform(strain_transformation, basis)

Then passing strain = np.zeros((3, 3), dtype=np.double) to the energy function will work. However, this runs into a problem: jax-md defines the Jacobian-vector product of the transform function such that it doesn't act on the backwards pass here, so taking the gradient with respect to strain in the above example will give zero stress. If the custom jvp is commented out it seems to work (even though I haven't tested it for correctness). This approach also seems to not play well with jit, but I'm not sure why.

So, finally, my questions:

  • Why is the jvp for transform defined this way? Removing it doesn't seem to break jax-md, but I assume there's a good reason why it is this way...
  • Do you have any thoughts on a reasonable implementation for stress? For our systems, we can't really use a framework that can't compute it.

(Tagging in @fabiannagel since he's also involved in this!)

Feature request: Potential truncation and smoothing

This is really easy to do by hand, but might be nice to have it automatically built in as an option somehow.

Basically, whenever there is a pair interaction that doesn't have a strict cutoff, it is common to truncate it at some finite distance r_cut. However, you then get a small discontinuity in the potential at r_cut, so it is common to adjust the potential slightly so that it is C1 (continuous in the potential and first derivative). The best way I've seen this done is the 'xplor' mode in HOOMD Blue, where you have a smoothing function S(r) that satisfies the following:

  • S(r) = 1 when r < r_on
  • S(r) = 0 when r > r_cut
  • S(r) is C1 everywhere

You then just multiply your potential by S(r) and choose r_on and r_cut to be in the tail of the potential.
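
For reference, a sketch of what that smoothing function looks like (my transcription of the HOOMD-blue xplor form; worth double-checking the coefficients):

import jax.numpy as np

def xplor_smoothing(r, r_on, r_cut):
  # 1 below r_on, 0 above r_cut, C1 polynomial crossover in between.
  numerator = (r_cut ** 2 - r ** 2) ** 2 * (r_cut ** 2 + 2 * r ** 2 - 3 * r_on ** 2)
  denominator = (r_cut ** 2 - r_on ** 2) ** 3
  return np.where(r < r_on, 1.0, np.where(r > r_cut, 0.0, numerator / denominator))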

nans in MD cookbook notebook

This seems like a really excellent project and I am excited about the potential for gradient-based optimization with your approach. However, in running the MD cookbook notebook, every simulation or minimization routine results in nans. I was surprised to find this even for a simple Lennard-Jones fluid at low density and with very small time steps (it occurs even for 1e-6 sometimes). I am running directly from Google Colab, so it doesn't seem like an issue on my end. Is there a stable version I could try rolling back to where this has been tested?

I also tried the jax.config.update("jax_debug_nans", True) setting and got the following output from the first Brownian dynamics simulation of the bubble raft on the first step:

Invalid value encountered in the output of a jit function. Calling the de-optimized version.


FloatingPointError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
528 try:
--> 529 return compiled_fun(*args)
530 except FloatingPointError:

21 frames
FloatingPointError: invalid value (nan) encountered in xla_call

During handling of the above exception, another exception occurred:

FloatingPointError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in _check_nans(name, xla_shape, buf)
350 if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
351 if np.any(np.isnan(buf.to_py())):
--> 352 raise FloatingPointError(f"invalid value (nan) encountered in {name}")
353
354 ### compiling jaxprs

FloatingPointError: invalid value (nan) encountered in mul

Selectively fix atom positions

I believe this is a feature request -- though perhaps I am missing an obvious solution already available.

I would like to be able to selectively fix atom positions. I.e. be able to run an energy minimization with some atom locations fixed, then "release" them and run minimization again (with a different subset of atoms fixed).

I cannot figure out how to do this, since all positions are handled through the R array, whose elements are all updated by the minimization algorithms at every step.
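
One workaround I've considered (an untested sketch; shift_fn is from space.periodic as in the docs, and `free` is a hypothetical boolean mask) is to mask the shift function so that fixed atoms ignore their updates:

import jax.numpy as np

free = np.ones(R.shape[0], dtype=bool)  # hypothetical mask; False = fixed atom

def masked_shift(R, dR, **kwargs):
  # Zero out the update for fixed atoms; minimizers move particles only
  # through the shift function, so masked atoms should stay put.
  return shift_fn(R, np.where(free[:, None], dR, 0.0), **kwargs)

But this still leaves the fixed atoms contributing forces, and I'm not sure it interacts correctly with FIRE's velocity updates.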

Usage of nan_to_num can be in conflict with jax_debug_nans.

Setting the jax_debug_nans flag to True can prevent you from computing energy.lennard_jones(x).
By using np.nan_to_num separately on idr6 and idr12 this issue can be circumvented. Alternatively, jit compiling the energy function also works.
I don't know if this issue comes up in a different place or if it is even worth solving.

Here is a minimal repro:

import jax.numpy as jnp
from jax.config import config
config.update('jax_enable_x64', True)
config.update("jax_debug_nans", True)

from jax import random, grad, jit

from jax_md import space, energy

N = 5

dimension = 2
box_size = 12
displacement, shift = space.periodic(box_size) 

key = random.PRNGKey(0)
key, split = random.split(key)
R = random.uniform(key, (N,dimension), minval=0.0, maxval=box_size, dtype=jnp.float64) 

energy_fn = energy.lennard_jones_pair(displacement)

#Works
print(jit(energy_fn)(R))
#Doesn't work
print(energy_fn(R))

pass species to energy.multiplicitive_isotropic_cutoff

This is a low priority comment as I do not currently need this for my work.

My understanding is that in energy.multiplicitive_isotropic_cutoff, you can pass r_onset and r_cutoff as scalars or as [n,m] arrays, but not on a per-species level. It is often desirable for these parameters to be functions of the other parameters in the potential. For example, in the Morse potential, V(r) = D0 * (1 - exp(-alpha * (r - sigma)))^2 - D0, V(a / alpha + sigma) is independent of alpha and sigma. Therefore, it would be convenient to set r_onset = a_onset / alpha + sigma and r_cutoff = a_cutoff / alpha + sigma, for scalar a_onset and a_cutoff. This way, species with different alpha or sigma would still be using "the same potential", if that makes sense.

This logic also leads to a proposed change to energy.lennard_jones_pair, which currently sets r_onset and r_cutoff using np.max(sigma), which violates this "same potential" idea.

I don't know if this is important enough to warrant the change, but I will say that I have seen people use the smoothing function S(r) as more than just a small perturbation at the tail of a potential: e.g. setting r_onset to be the minimum of a potential (r_onset = sigma). The current implementation makes this impossible for systems with species.
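
Concretely, what I have in mind (illustrative names only):

a_onset, a_cutoff = 2.5, 3.0
r_onset = a_onset / alpha + sigma    # per-species if alpha and sigma are [n, m] arrays
r_cutoff = a_cutoff / alpha + sigma  # so every species pair sees "the same potential"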

Feature Request or Modifying Pair Potential

I want to add new 2- and 3-body potentials. Is there any documentation that I can start with? In particular, I want to use a NN force field in JAX-MD, trained using TensorFlow.

Thanks,
Alireza

Feature request: bonds

The smap.pairwise function promotes an energy function that acts on a single pair to one that acts on all pairs. Instead, it would be nice to have a version of this that promoted a function to act only on specified pairs (that in principle could change during a simulation). This could be used to create permanent bonds. Perhaps the smap.pairwise function stays the same but you just use a new metric?

Feature Request: Time-dependent temperature

Having the energy be time-dependent is really useful, but it seems like it may not be able to interface with a temperature that changes with time. It would be nice to have temperature schedules be time-dependent in addition to the energy.
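
For example, something like this would be natural (a sketch; it assumes kT can be passed dynamically to the apply function, as described in a later issue about simulate.py):

def kT_schedule(step):
  return 1.0 - 0.9 * step / 1000.0  # illustrative linear cooling

for step in range(1000):
  state = apply_fn(state, kT=kT_schedule(step))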

Bug in jax_md.colab_tools.renderer: The renderer does not plot time varying per particle parameters (e.g. diameter)

The renderer has a bug: it cannot plot time-varying per-particle parameters.

Description of problem:
renderer.Disk can take time-varying per-particle parameters as arguments. See this snippet of the documentation:

 position: An array of shape `(steps, count, dim)` or `(count, dim)`
 specifying possibly time varying positions. Here `dim` is the spatial dimension.
 size: An array of shape `(steps, count)`, `(count,)`, or `()` specifying
 possibly time-varying / per-disk diameters.
 color: An array of shape `(steps, count, 3)` or `(count,)` specifying
 possibly time-varying / per-disk RGB colors.
 count: The number of disks.

However, it seems not to be able to plot time-varying parameters: it plots the parameter values at the first timepoint and does not change them at later timepoints.

How to reproduce bug: run following script

import numpy as onp
import jax.numpy as np
from jax_md.colab_tools import renderer
from jax.ops import index_update, index

box_size = 10.0  # assumed for this repro

n_timepoints = 4
n_particles = 10
# make a random array of positions with shape (n_timepoints, n_particles, 2)
pos = 10 * onp.random.rand(n_timepoints, n_particles, 2)
# make an array of diameters with shape (n_timepoints, n_particles);
# the diameters of all particles could vary in time, here always equal to one
diameter = np.ones((n_timepoints, n_particles))
# set one particle's diameter equal to 2 at the first timepoint
diameter = index_update(diameter, index[0, 0], 2)

# the renderer should plot the diameters in time, but actually only plots the
# diameters at the first timepoint: the particle has diameter 2 for all
# timepoints rather than just the first
renderer.render(box_size,
                {'particles': renderer.Disk(pos, diameter)},
                buffer_size=50)

Feature Request: keep track of movement across periodic boundary conditions

When analyzing a trajectory, it is sometimes necessary to "unwrap" motion across periodic boundaries (e.g. when calculating the diffusion constant of a particle). While this can sometimes be inferred from adjacent frames in a trajectory, this is not very robust and can lead to errors in some situations. I don't know what the best approach here is... LAMMPS has an integer variable (e.g. n1x, n1y, n1z, n2x, etc) that can be used to unwrap the positions, but there might be better ways.

Feature request: wrap function when wrapped=False in space.periodic and space.periodic_general

When passing wrapped=False to the periodic space functions, it seems that in addition to displacement and shift, it is important to define a third function that allows the user to put the particles back in the box. A typical use case would be to use unwrapped positions during a simulation because they are (presumably) faster, and then do something like R = wrap(state.position) to get the final positions. My thinking is that there are two options for the api:

option 1:
displacement, shift = space.periodic(L)
displacement, shift, wrap = space.periodic(L, wrapped=False)

option 2:
displacement, shift = space.periodic(L)
displacement, shift = space.periodic(L, wrapped=False)
displacement, shift, wrap = space.periodic(L, wrapped=False, wrap_function=True)

The second option might be preferable since it would be backwards compatible. I would be happy to take a first shot at this if that would be helpful.

Neural networks potential example and optax import

Looks like the neural network potentials notebook needs to be updated for the new location of optax given error produced in "Imports & Utils":

     21 
     22 from jax import random
---> 23 from jax.experimental import optix
     24 
     25 from jax_md import energy, space, simulate, quantity

ImportError: cannot import name 'optix'

Fixed by adding !pip install -q git+https://www.github.com/deepmind/optax and replacing from jax.experimental import optix with import optax as optix without modifying the rest; alternatively, one could go through and update all the other usages of "optix".

Lmk if you want a patch or would rather do it. Awesome notebook btw! 😀

energy.soft_sphere radii vs diameter

Line 75-76 in jax-md/jax_md/energy.py

Here dr is normalized so a value of 1 means they are one radius apart. Therefore, your function is only non-zero if two bubbles are less than one radius apart (1); however, the bubbles begin to touch when they are 2 radii apart (2). Thus, your function is only nonzero after the bubbles are compressed halfway.

Bug/Documentation problem with **kwargs for simulate/minimize methods.

When passing a wrapped function to simulate.brownian I get an "unexpected keyword argument 't'" error. This is easily resolved by passing through **kwargs, but I'd prefer if this were not necessary, as it is not documented and only needed in some cases.

example

import jax.numpy as jnp

from jax.config import config
config.update("jax_enable_x64", True)

from jax import random
from jax_md import space, energy, simulate

# (setup values assumed for this repro)
side_length = 10.0
key = random.PRNGKey(0)
R = random.uniform(key, (50, 2), maxval=side_length)

displacement, shift = space.periodic(side_length)
energy_fn = energy.soft_sphere_pair(displacement)

This does not work

no_keywords_fn = lambda R: energy_fn(R)

init, apply = simulate.brownian(no_keywords_fn, shift, dt=0.00001, T_schedule=1.0, gamma=0.1)
key, split = random.split(key)
state = init(split, R)
apply(state)

This does work

keywords_fn = lambda R, **kwargs: energy_fn(R,**kwargs)

init, apply = simulate.brownian(keywords_fn, shift, dt=0.00001, T_schedule=1.0, gamma=0.1)
key, split = random.split(key)
state = init(split, R)
apply(state)

This is not an issue for nvt_nose_hoover or for minimize.fire_descent.

The same issue also appears for nve.

And for nvt_langevin this issue appears for both init and apply.

Possible improvements to neighbor lists

A few minor suggestions to help improve the clarity/performance of neighbor lists:

  1. Split neighbor_fn (the returned function of partition.neighbor_list) into two functions, rather than having two different use cases where only one is compatible with jit/grad. So, for example:
neighbor_fn, energy_fn = lennard_jones_neighbor_list(displacement, box_size)

#construct the neighbor list
nbrs = neighbor_fn(R)

#update the neighbor list
nbrs = neighbor_fn(R, nbrs)

would become

(construct_neighbor_fn, update_neighbor_fn), energy_fn = lennard_jones_neighbor_list(displacement, box_size)

#construct the neighbor list
nbrs = construct_neighbor_fn(R)

#update the neighbor list
nbrs = update_neighbor_fn(R, nbrs)

It would then be easier to communicate that update_neighbor_fn can be composed with jax transformations, but construct_neighbor_fn cannot.

  2. Pass kwargs through the neighbor list convenience functions in energy.py. Unless I'm missing something, there currently is not a way to change, e.g., capacity_multiplier using these convenience functions.

  3. I know this has come up in the past, but I forget the rationale behind scaling dr_threshold, e.g.:
    https://github.com/google/jax-md/blob/c55ec95999331844a72a1e1a3a3009276c60a098/jax_md/energy.py#L223
    This is done for some of the potentials (soft_sphere, LJ) but not others (morse, bks). My feeling is that physically dr_threshold is a buffer for how far particles can move before the neighbor lists have to be recomputed, but big particles don't typically move faster than small particles (depending on the dynamics/potentials). I'm all for coming up with intelligent heuristics for the optimum dr_threshold; it's just not clear that this is the right heuristic. But I'm happy to be wrong.

ImportError: cannot import name 'custom_jvp'

I've followed the instructions for the pip install as well as pip install'ing from the most recent clone of the repo. In both cases, I double check that I'm not inside of the jax_md dir on my system. So I follow the space.py instructions and I get the following error:

>>> from jax_md import space
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/workspace/wsa/jones289/miniconda3/envs/jaxmd/lib/python3.5/site-packages/jax_md/__init__.py", line 15, in <module>
    from jax_md import space
  File "/usr/workspace/wsa/jones289/miniconda3/envs/jaxmd/lib/python3.5/site-packages/jax_md/space.py", line 41, in <module>
    from jax import custom_jvp
ImportError: cannot import name 'custom_jvp'

I previously tried with python 3.7, so I tried reinstalling with python=3.5, but that doesn't seem to be the issue. I'm not sure if I misread the documentation on how to set up jax_md, but clearly something is wrong here.

Static bond parameters being used instead of dynamic bond parameters

This assumes you have a system of monodisperse spheres with positions stored in R. Let me know if you want me to share my colab.

This code sets up a list of bonds from the topology of a sphere packing but sets the rest length equal to the current distance between the particles. Therefore, the energy should be zero, but for the dynamic case it is equal to the energy of the sphere packing, which is what you would expect if you use length=1.0 (i.e. the default static parameter).

import jax.numpy as np
from jax import grad, vmap
from jax_md import space, smap, energy

N = R.shape[0]  # number of particles (R assumed given)

#use this to compare later
print(energy_fn(R))
print(np.linalg.norm(grad(energy_fn)(R)))
displacement_all = vmap(vmap(displacement, (0, None), 0), (None, 0), 0)
dR = displacement_all(R, R)
dr = space.distance(dR)

#0 if particles do not overlap
#rest length if particles do overlap
bond_lengths_matrix = np.where(dr<1.,dr,0.)
bond_lengths_matrix = np.triu(bond_lengths_matrix) #make upper triangular

index_list=np.dstack(np.meshgrid(np.arange(N), np.arange(N), indexing='ij'))

i_s = np.where(bond_lengths_matrix>0.0001, index_list[:,:,0], -1).flatten()
j_s = np.where(bond_lengths_matrix>0.0001, index_list[:,:,1], -1).flatten()
l_s = bond_lengths_matrix.flatten()
temp = np.transpose(np.array([i_s,j_s]))

bond_list = temp[(temp!=np.array([-1,-1]))[:,1]]
length_list = l_s[(temp!=np.array([-1,-1]))[:,1]]
k_list = np.full(length_list.shape, 1., dtype=np.float64)

def simple_spring_bond_float64(
    displacement_or_metric, bond, bond_type=None, length=1, epsilon=1, alpha=2):
  """Convenience wrapper to compute energy of particles bonded by springs."""
  length = np.array(length, np.float64)
  epsilon = np.array(epsilon, np.float64)
  alpha = np.array(alpha, np.float32)
  return smap.bond(
    energy.simple_spring,
    energy._canonicalize_displacement_or_metric(displacement_or_metric),
    bond,
    bond_type,
    length=length,
    epsilon=epsilon,
    alpha=alpha)

#Pass bond info dynamically (the energy here should be zero)
bond_energy_fn = simple_spring_bond_float64(displacement, None)
print(bond_energy_fn(R,bond_list, length=length_list, epsilon=k_list, alpha=2.))
print(np.linalg.norm(grad(bond_energy_fn)(R,bond_list, length=length_list, epsilon=k_list, alpha=2.)))

#pass bonds statically
bond_energy_fn = simple_spring_bond_float64(displacement, bond_list, length=length_list, epsilon=k_list, alpha=2.)
print(bond_energy_fn(R))
print(np.linalg.norm(grad(bond_energy_fn)(R)))

Broken link to jax wheel in Colab notebook

When I try running the colab notebooks, I get the following error in the first cell:

ERROR: HTTP error 404 while getting https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.23-cp36-none-linux_x86_64.whl
  ERROR: Could not install requirement jaxlib==0.1.23 from https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.23-cp36-none-linux_x86_64.whl because of error 404 Client Error: Not Found for url: https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.23-cp36-none-linux_x86_64.whl
ERROR: Could not install requirement jaxlib==0.1.23 from https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.23-cp36-none-linux_x86_64.whl because of HTTP error 404 Client Error: Not Found for url: https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.23-cp36-none-linux_x86_64.whl for URL https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.23-cp36-none-linux_x86_64.whl
  Building wheel for jax (setup.py) ... done
  Building wheel for opt-einsum (setup.py) ... done
  Building wheel for jax-md (setup.py) ... done

Feature request: rigid bodies

It would be nice to be able to define rigid bodies (i.e. a collection of particles whose relative positions are fixed up to translations and rotations).

This is probably a fairly intensive request because it implies the need for rotational degrees of freedom, torques, etc.

option to not recast parameters in energy pair functions

In functions such as soft_sphere_pair, parameters (e.g. sigma, epsilon, alpha) are explicitly cast to np.float32. However, sometimes it is actually important to keep some parameters at double precision. It would be nice if there was an option that would direct soft_sphere_pair not to recast the parameters. Something like

energy_fn = energy.soft_sphere_pair(displacement, sigma=jnp.array([1.0],dtype=jnp.float64), dont_recast=True)

Is there a reason not to do this?

FireDescentState changing type and causing error

Some code that was working a few days ago is now throwing an error. I know there have been a lot of changes recently, and I think this is an easy fix (but I'm not 100% sure what the right thing to do is, so I'm filing an issue rather than a PR).

Without going into the details (I can send code if necessary), I am doing a FIRE minimization using lax.scan, and it appears the type of the FireDescentState is changing. Specifically n_pos is changing from an int64 to a float32. My guess is that in line 154 of minimize.py:

    return FireDescentState(R, V, force(R, **kwargs), dt_start, alpha_start, 0)  # pytype: disable=wrong-arg-count

the 0 should be explicitly cast to an f32?

Here's the error I'm getting:

FireDescentState(position=ShapedArray(float64[19,7,3]), velocity=ShapedArray(float64[19,7,3]), force=ShapedArray(float64[19,7,3]), dt=ShapedArray(float16[]), alpha=ShapedArray(float16[]), n_pos=ShapedArray(float32[]))
and
FireDescentState(position=ShapedArray(float64[19,7,3]), velocity=ShapedArray(float64[19,7,3]), force=ShapedArray(float64[19,7,3]), dt=ShapedArray(float16[]), alpha=ShapedArray(float16[]), n_pos=ShapedArray(int64[])).

Feature Request: Verlet lists

This is not something that I need now, but it is almost certainly a feature that you'll eventually want to include. The standard (e.g. in LAMMPS and HOOMD) for fast neighbor detection is to use a combination of cell lists (which you already have) and Verlet lists, i.e. use cell lists to create and periodically update Verlet lists. I've noticed a non-negligible speedup compared to just cell lists in my own code, and for people who really care about these things, I think it's pretty important.

I would be happy to chat more about this.

error: GPU memory usage is close to the limit

To reproduce the problem:

  1. Runtime: reset all runtimes
  2. run:
!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-0.1.15-cp36-none-linux_x86_64.whl
!pip install --upgrade -q jax
!pip install -q git+https://www.github.com/google/jax-md

import numpy as onp
import jax.numpy as np
from jax import random
from jax import jit, grad
from jax.config import config
config.update("jax_enable_x64", True)
from jax_md import space, smap, energy, minimize, quantity, simulate

def SetPositionsOctahedron():
  a=1/np.sqrt(2.)
  R=list(np.full((6,3),0.))
  R[0]=np.array([a,0,0])
  R[1]=np.array([-a,0,0])
  R[2]=np.array([0,a,0])
  R[3]=np.array([0,-a,0])
  R[4]=np.array([0,0,a])
  R[5]=np.array([0,0,-a])
  return np.array(R).astype(np.float64)

At this point, everything is fine, but if you then run:

box_size = 8.
Rinit = SetPositionsOctahedron()+box_size/2

you get the following error message:

GPU memory usage is close to the limit
Your GPU is close to its memory limit. You will not be able to use any additional memory in this session. Currently, 10.07 GB / 11.17 GB is being used. Would you like to terminate some sessions in order to free up GPU memory (state will be lost for those sessions)?

If I press "ignore," the issue doesn't go away, but it doesn't seem to negatively affect me. That might be because I'm not (yet) running computationally intensive work. So the good news is I don't need this fixed in order to keep working (for now), but I thought you should be aware of it.

Vectorizing neighbor function: 'vmap(neighbor_fn)' doesn't work.

import jax.numpy as np
from jax import vmap
from jax_md.util import f32
from jax_md import space, energy

batch_size = 32
N = 100
dim = 2
box_size = 10

displacement, shift = space.periodic(box_size)
neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(displacement, box_size)
R = np.zeros((batch_size, N, dim), dtype=f32)

nbrs = vmap(neighbor_fn)(R)

When I tried to vectorize the neighbor_fn, I got the following error:

~/.anaconda3/envs/md/python3.7/site-packages/jax_md/partition.py in _estimate_cell_capacity(R, box_size, cell_size, buffer_size_multiplier)
    173   cell_capacity = onp.max(count_cell_filling(R, box_size, cell_size))
--> 174   return int(cell_capacity * buffer_size_multiplier)
    175 

FilteredStackTrace: jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The problem arose with the `int` function. If trying to convert the data type of a value, try using `x.astype(int)` or `jnp.array(x, int)` instead.

Is there any solution to bypass this problem?
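
One workaround that might be worth trying (an untested sketch): allocate a single neighbor list outside of vmap, so the capacity estimate sees concrete values, then vmap only the update path, which has static shapes:

nbrs0 = neighbor_fn(R[0])                        # concrete allocation, outside of tracing
nbrs = vmap(lambda r: neighbor_fn(r, nbrs0))(R)  # update-only path per batch element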

Feature Request: energy function loader for common force fields

It is a very interesting project! I am wondering whether it would be possible to include a way to load common force-field-based energy functions?

For example, if we hope to simulate real biological systems with both common force fields (such as Amber14SB or Charmm36) and machine-learning-based force fields?

Thanks!

Individual particle energy

Hello folks! First of all, I am new to the use of jax-md, so I ask for your understanding if my question is silly:
Is it possible to extract/obtain a single particle's energy contribution to the system's total energy during a simulation?

Bug in multiplicative_isotropic_cutoff

Unless I'm missing something, in the implementation of multiplicative_isotropic_cutoff, the function smooth_fn(dr) assumes that dr is an (N,N) array of distances, not an (N,N,d) array of displacements. However, when it's passed to, for example, smap.pair, it gets passed an (N,N,d) array of displacements. The solution is to change

def cutoff_fn(dr, *args, **kwargs): 
    return smooth_fn(dr) * fn(dr, *args, **kwargs)

to

def cutoff_fn(dR, *args, **kwargs):
    dr = space.distance(dR)
    return smooth_fn(dr) * fn(dr, *args, **kwargs)

at line 232 of energy.py

Feature request: energy.isotropic function

The function energy.multiplicative_isotropic_cutoff creates the need for potentials like soft_sphere to take an (N,N) array of distances rather than an (N,N,d) array of displacements. However, smap.pair, among others, still needs a function that takes (N,N,d) arrays of displacements. In the solution I just proposed to #55, energy.multiplicative_isotropic_cutoff serves the purpose of calling space.distance(dR) and passing the result to the wrapped energy function. I propose a similar function, let's call it energy.isotropic, whose only purpose is to call space.distance(dR). This would increase code reusability, for example:

def my_potential(dr, **kwargs):
  ...

def my_potential_pair(metric):
  return smap.pair(
      energy.isotropic(my_potential),
      metric)

def my_potential_pair_cutoff(metric, r_onset, r_cutoff):
  return smap.pair(
      energy.multiplicative_isotropic_cutoff(my_potential, r_onset, r_cutoff),
      metric)

Feature Request: Brownian dynamics

In addition to simulate.nve and simulate.nvt, it would be nice to have simulate.brownian. This is just the over-damped limit of Langevin dynamics, but I believe typical implementations are very different so I'm listing it as a separate issue.

More variable simulation parameters

In the new update to simulate.py, the parameter T_schedule was replaced with kT, with the idea that users can dynamically pass kT to the apply function to get a variable temperature. Is there a reason not to allow, for example, dt and gamma in simulate.brownian to be variable as well? Seems like an easy change with no downside.

cookbook: shuffling batches

When creating batches

def batch(key):
  steps_per_epoch = no_training_samples // batch_size
  train_epochs = train_steps // steps_per_epoch
  for s in range(train_epochs):
    key, split = random.split(key)
    permutation = random.shuffle(split, training_samples)
    positions = train_positions[permutation]
    features = train_features[permutation]
    for i in range(0, no_training_samples, batch_size):
      batch_data = (positions[permutation[i:i + batch_size]],
                    features[permutation[i:i + batch_size]])
      yield batch_data

you first permute the training data
positions = train_positions[permutation]
and then access the permuted data via the permutation again
positions[permutation[i:i + batch_size]]

This effectively shuffles and then unshuffles the data, so you never have randomly shuffled batches.
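
The fix is presumably to slice the permuted arrays directly, e.g. (sketch):

positions = train_positions[permutation]
features = train_features[permutation]
for i in range(0, no_training_samples, batch_size):
  yield (positions[i:i + batch_size], features[i:i + batch_size])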

Feature request: Morse potential

This is straightforward. If it helps, here is how I did it (based on the soft_sphere potential).

def morse(dR, D0=1.0, alpha=30.0, r0=1.0):
  dr = space.distance(dR)
  U = D0*(np.exp(-2.*alpha*(dr-r0))-2.*np.exp(-alpha*(dr-r0)))
  return np.array(U, dtype=dr.dtype)

def morse_pairwise(
    metric, species=None, D0=1.0, alpha=30.0, r0=1.0):
  return smap.pairwise(
      morse,
      metric,
      species=species,
      D0=D0,
      alpha=alpha,
      r0=r0)
