Comments (9)
Hey Marcel,
Thanks for the great questions! Sorry about the delay in response; it's been a busy week pre-ICML.
I have been thinking about these issues quite a bit, and actually have had a replacement for periodic_general since the fall; but I was waiting until I had finished a refactor of the NVT simulations (along with new NPT simulations) to check it in.
New Periodic General
In any case, here is the new version of the code. Instead of passing the transformation as a callable (which was originally designed to mimic JAX's optimizers) you can now override the box directly. For example,
displacement_fn, shift_fn = periodic_general(box)
energy_fn = energy.soft_sphere_pair(displacement_fn)
E = energy_fn(position) # Will compute energy using `box`.
E_new = energy_fn(position, box=new_box) # Will compute energy using `new_box`.
As you suggest, this makes it easy to leverage autodiff to compute the stress tensor:
deformation_energy = lambda epsilon: energy_fn(position, box=box + epsilon)
stress_tensor = grad(deformation_energy)(np.zeros_like(box)) / np.linalg.det(box)
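The same pattern can be tried in isolation with plain JAX, without jax_md. This is only a sketch: `toy_energy`, the single harmonic bond, and the two-particle setup are assumptions made up for illustration (there is no periodic wrapping here), and `jnp` stands in for the `np` alias used above. Positions are held in fractional coordinates so that perturbing the box deforms the system.

```python
import jax.numpy as jnp
from jax import grad


def toy_energy(frac_position, box):
    """Harmonic energy of one bond between two particles (illustrative only)."""
    real = frac_position @ box.T  # map fractional -> real space
    bond = real[1] - real[0]
    return 0.5 * jnp.sum(bond ** 2)


box = jnp.array([[2.0, 0.0], [0.0, 2.0]])
frac_position = jnp.array([[0.25, 0.25], [0.75, 0.25]])

# Differentiate with respect to an additive perturbation of the box,
# evaluated at zero perturbation, then normalise by the box volume.
deformation_energy = lambda epsilon: toy_energy(frac_position, box + epsilon)
stress_tensor = grad(deformation_energy)(jnp.zeros_like(box)) / jnp.linalg.det(box)
print(stress_tensor)
```

Only the xx component is non-zero here, because the single bond lies along x.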
Here is a colab notebook that puts it all together.
Use of Custom JVP
One of the changes to the new periodic general implementation that you'll probably notice is that we changed the custom JVP so that gradient information about the box does flow through the transformation. I think the behavior of the current version of the code (zeroing out gradients wrt box) is incorrect. Having said that I'd love to explain why we use a custom JVP in the first place to get your thoughts.
periodic_general is supposed to take positions in the unit cube. If you then use grad(energy_fn) to produce forces, the derivatives will be backpropagated all the way to the unit cube. Therefore, periodic_general uses a custom JVP that doesn't do this final step of the chain rule.
At the time of writing the original periodic_general I hadn't realized that we could simultaneously 1) not differentiate with respect to positions in the unit cube and 2) faithfully propagate gradients about the box. However, I have since realized that it is possible to do both, and the new version of the code takes this into account and should work properly.
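To make the two requirements concrete, here is a minimal sketch in plain JAX of a transform whose custom JVP passes the position tangent through unchanged while still propagating the box tangent. The names `to_real_space` and `energy` and the quadratic toy energy are assumptions for illustration, not the jax_md implementation.

```python
import jax.numpy as jnp
from jax import custom_jvp, grad


@custom_jvp
def to_real_space(box, R):
    """Map unit-cube positions R (shape [N, d]) to real space with a [d, d] box."""
    return R @ box.T


@to_real_space.defjvp
def to_real_space_jvp(primals, tangents):
    box, R = primals
    dbox, dR = tangents
    # dR is passed through unchanged, i.e. the last chain-rule step
    # (multiplying by the box) is skipped, while the box tangent is kept.
    return to_real_space(box, R), dR + R @ dbox.T


def energy(R, box):
    real = to_real_space(box, R)
    return 0.5 * jnp.sum(real ** 2)


box = jnp.array([[2.0, 0.0], [0.0, 4.0]])
R = jnp.array([[0.5, 0.25]])

dE_dR = grad(energy, argnums=0)(R, box)    # gradient with respect to the real-space position
dE_dbox = grad(energy, argnums=1)(R, box)  # gradient with respect to the box
```

With the ordinary chain rule, `dE_dR` would pick up an extra factor of the box; with the custom rule it equals the real-space gradient, and `dE_dbox` is still non-zero.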
Finally
Thanks for trying out JAX MD! Please let me know if you have any thoughts about how to improve the periodic_general code, if you notice bugs, or any other feedback about the library itself. I think the new version here is significantly better than the version that is currently checked in, but it's still not completely vetted and feedback would be greatly appreciated!
The new changes do break backward compatibility, so I have been waiting until after ICML to check them in, in case they break people's code. However, hopefully this Friday I will officially update the periodic_general code.
from jax-md.
As of version 0.1.13 the new periodic_general function is in.
from jax-md.
Good question! I've gone back and forth on the question of whether to keep things in real space or the unit cube, and I'm very open to changing things around. One advantage of the current architecture is that none of the code relies on the implementation of the spaces, so if you want to use a version of periodic_general that keeps everything in real space, the simulations, neural networks, and energy functions should just work.
A few notes, perhaps, on why I favor the unit-cube implementation at the moment:
- It seems to me that when writing a simulation with fixed strain (e.g. shear), it is a bit nicer to keep things in the unit cube because then you can write a simulation as,
    for t in range(steps):
      strain = t / steps * strain_rate
      box = np.array([[box_size, strain], [0.0, box_size]])
      state = step_fn(state, box=box)
  without having to worry about re-mapping the particle positions.
- Although I'm not quite happy with the current interaction, it seems like one needs to project into the unit cube before doing spatial partitioning. If particle positions are stored in the unit cube then there isn't as much mental overhead when combining periodic_general with cell lists. If particles are stored in real space, I don't see a way of avoiding putting some burden on the user to get this right.
- Last time I tried to reason about this, it seemed like there was some extra efficiency to storing in the unit cube (e.g. you only have a single transformation call to compute displacements rather than three). I know LAMMPS stores positions in the unit cube, so that was also a source of inspiration.
However, as I said above, I'm not very confident in this decision and so I'm very open to changing the design. It also would be completely fine for you all to use a periodic_general that stores particle positions in real space. As long as you are a bit careful about the interactions with cell lists, nothing else should require special care. I'm quite sure that everything will work OK from the autodiff perspective, and JAX MD is agnostic about the form of displacement / shift functions.
from jax-md.
Thanks for the clarification! These seem like fine reasons, especially if you do the whole MD within jax-md and entirely in fractional coordinates. From an implementation perspective it seems more convenient to work in fractional coordinates as much as possible. Our use case, with ase as an external MD driver, might be a bit special in that regard, since that requires all real-space coordinates.
It might be worth mentioning the reasoning explicitly in the docs, and potentially giving an example of transforming into fractional coordinates from real space. In the end it's all very straightforward, but it took me a while to wrap my head around the "unit hypercube". ;)
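As a rough sketch of what such an example could look like in plain JAX: the helper names `to_fractional` and `to_real` are made up for illustration, and the convention assumed here (columns of `box` are the lattice vectors, so real = box · fractional) is the one implied by the `ij,...j->...i` contraction in the `transform` function posted in this thread.

```python
import jax.numpy as jnp


def to_fractional(box, R_real):
    """Real-space positions [N, d] -> fractional coordinates wrapped into [0, 1)."""
    R_frac = R_real @ jnp.linalg.inv(box).T
    return R_frac % 1.0  # wrap into the unit cell


def to_real(box, R_frac):
    """Fractional coordinates [N, d] -> real-space positions."""
    return R_frac @ box.T


# Sheared 2D box; its columns (4, 0) and (1, 2) are the lattice vectors.
box = jnp.array([[4.0, 1.0], [0.0, 2.0]])

# The second particle is the first one shifted by the lattice vector (4, 0),
# so both map to the same point of the unit cell.
R_real = jnp.array([[1.0, 1.0], [5.0, 1.0]])
R_frac = to_fractional(box, R_real)
R_back = to_real(box, R_frac)
```

Mapping back with `to_real` returns the wrapped image inside the cell rather than the original, unwrapped position.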
from jax-md.
Good point! I also think you're right that "unit_cube" is a bit nonstandard. Here is an example of a periodic_general function I've been playing around with that has an option to either use real space or fractional coordinates. I think there's no downside to including both. I still have to write docstrings and utilities, but I checked pretty closely that it correctly produces energies, stresses, pressure, and elastic constants in both modes. Let me know if you find any issues. I'll write back here when it's checked in.
from jax-md.
Thanks for reaching out! It seems that I copied an old version of the new periodic_general above, my apologies! I really need to get the code checked in, but I have had some technical issues lately (computer broke) and so there has been a little lag.
Here is the correct version, let me know if anything still seems out of the ordinary. I'll try to check in the code by Monday so that it's tested.
from jax_md import space
from jax import custom_jvp
from jax import lax
from functools import partial
import jax.numpy as np  # `np` below refers to jax.numpy

periodic_displacement = space.periodic_displacement
pairwise_displacement = space.pairwise_displacement
periodic_shift = space.periodic_shift

f32 = np.float32


def inverse(box):
  # Scalar and per-axis (1D) boxes invert elementwise; a 2D box is a matrix.
  if np.isscalar(box) or box.size == 1:
    return 1 / box
  elif box.ndim == 1:
    return 1 / box
  elif box.ndim == 2:
    return np.linalg.inv(box)
  raise ValueError()


def get_free_indices(n):
  return ''.join([chr(ord('a') + i) for i in range(n)])


@custom_jvp
def transform(box, R):
  if np.isscalar(box) or box.size == 1:
    return R * box
  elif box.ndim == 1:
    indices = get_free_indices(R.ndim - 1) + 'i'
    return np.einsum(f'i,{indices}->{indices}', box, R)
  elif box.ndim == 2:
    free_indices = get_free_indices(R.ndim - 1)
    left_indices = free_indices + 'j'
    right_indices = free_indices + 'i'
    return np.einsum(f'ij,{left_indices}->{right_indices}', box, R)
  raise ValueError()


@transform.defjvp
def transform_jvp(primals, tangents):
  box, R = primals
  dbox, dR = tangents
  # dR passes through unchanged; the box tangent is propagated.
  return (transform(box, R), dR + transform(dbox, R))


def periodic_general(box, fractional_coordinates=True, wrapped=True):
  inv_box = inverse(box)

  def displacement_fn(Ra, Rb, **kwargs):
    _box, _inv_box = box, inv_box

    if 'box' in kwargs:
      _box = kwargs['box']

      if not fractional_coordinates:
        _inv_box = inverse(_box)

    if 'new_box' in kwargs:
      _box = kwargs['new_box']

    if not fractional_coordinates:
      Ra = transform(_inv_box, Ra)
      Rb = transform(_inv_box, Rb)

    dR = periodic_displacement(f32(1.0), pairwise_displacement(Ra, Rb))
    return transform(_box, dR)

  def u(R, dR):
    if wrapped:
      return periodic_shift(f32(1.0), R, dR)
    return R + dR

  def shift_fn(R, dR, **kwargs):
    if not fractional_coordinates and not wrapped:
      return R + dR

    _box, _inv_box = box, inv_box
    if 'box' in kwargs:
      _box = kwargs['box']
      _inv_box = inverse(_box)

    if 'new_box' in kwargs:
      _box = kwargs['new_box']

    dR = transform(_inv_box, dR)
    if not fractional_coordinates:
      R = transform(_inv_box, R)

    R = u(R, dR)

    if not fractional_coordinates:
      R = transform(_box, R)
    return R

  return displacement_fn, shift_fn
from jax-md.
So, it looks like the "callable T" route doesn't play well with jit; Fabian spent some time testing it. Once the displacement gets combined with an energy function, something seems to prevent jit from working properly.
from jax-md.
Hi Samuel,
Thanks for the quick and thorough reply, and good luck with ICML!
Fabian, or I, will reply separately to the stress-related points once we've had a closer look! At first glance this looks like precisely what we need, which is great.
On the periodic_general issue, I have a general question: Wouldn't it be more convenient to work entirely in real space, including positions, and treat the transformation into the unit cube as a purely internal intermediate step? In other words, you make the transformation to scaled coordinates part of the displacement_fn and shift_fn, and purely work in "real space" when it comes to input and output. That way, I think one can entirely sidestep the issue of having some quantities be in fractional coordinates and some in real coordinates. I don't have a great intuition for autograd yet, but it seems like this would also apply for gradients?
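To illustrate what I mean, here is a rough sketch in plain JAX of a displacement function that takes and returns real-space vectors and only uses fractional coordinates internally. The name `real_space_displacement` is made up for this example, and wrapping in fractional coordinates is assumed to be good enough; for strongly skewed boxes it need not return the shortest real-space vector.

```python
import jax.numpy as jnp


def real_space_displacement(box, Ra, Rb):
    """Periodic displacement Ra - Rb; inputs and output are in real space."""
    inv_box = jnp.linalg.inv(box)
    d_frac = inv_box @ (Ra - Rb)         # to fractional coordinates
    d_frac = d_frac - jnp.round(d_frac)  # wrap each component into [-0.5, 0.5]
    return box @ d_frac                  # back to real space


box = jnp.array([[10.0, 0.0], [0.0, 10.0]])
Ra = jnp.array([9.5, 5.0])
Rb = jnp.array([0.5, 5.0])

# The particles are 9 apart inside the cell but only 1 apart across the boundary.
dR = real_space_displacement(box, Ra, Rb)
```

Since everything here is ordinary array arithmetic, I would expect gradients with respect to both the positions and the box to flow through without any custom rules.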
I'll have a closer look at the actual implementation you posted soon, I'm looking forward to it.
from jax-md.
This new periodic_general box looks very useful to me, both with and without fractional coordinates.
For some of my examples, using the new periodic_general with fractional_coordinates = True works just fine; however, for others I think I get incorrect results.
With fractional_coordinates = False, I frequently get incorrect results; maybe there's an error on my side.
I've written a small example where the standard periodic box and the new periodic_general with fractional_coordinates = True work, but fractional_coordinates = False diverges.
Interestingly, all boxes give the same energy initially. This could suggest that the error occurs either in the backward pass or when updating particle positions (or I'm just using fractional_coordinates = False incorrectly).
I appreciate any insights!
from jax_md import space, energy, simulate, quantity
from jax import custom_jvp, jit, random, lax
import jax.numpy as np
import numpy as onp

"""new implementation of periodic_general from issue 116"""

periodic_displacement = space.periodic_displacement
pairwise_displacement = space.pairwise_displacement
periodic_shift = space.periodic_shift

f32 = np.float32


def inverse(box):
  if np.isscalar(box) or box.size == 1:
    return 1 / box
  elif box.ndim == 1:
    return 1 / box
  elif box.ndim == 2:
    return np.linalg.inv(box)
  raise ValueError()


def get_free_indices(n):
  return ''.join([chr(ord('a') + i) for i in range(n)])


def base_transform(box, R):
  if np.isscalar(box) or box.size == 1:
    return R * box
  elif box.ndim == 1:
    indices = get_free_indices(R.ndim - 1) + 'i'
    return np.einsum(f'i,{indices}->{indices}', box, R)
  elif box.ndim == 2:
    free_indices = get_free_indices(R.ndim - 1)
    left_indices = free_indices + 'j'
    right_indices = free_indices + 'i'
    return np.einsum(f'ij,{left_indices}->{right_indices}', box, R)
  raise ValueError()


@custom_jvp
def transform_without_tangents(box, R):
  return base_transform(box, R)


@transform_without_tangents.defjvp
def transform_without_tangents_jvp(primals, tangents):
  box, R = primals
  dbox, dR = tangents
  return (transform_without_tangents(box, R),
          dR + transform_without_tangents(dbox, R))


def transform(box, R, fractional_coordinates=True):
  if not fractional_coordinates:
    return base_transform(box, R)
  return transform_without_tangents(box, R)


def periodic_general(box, fractional_coordinates=True, wrapped=True):
  inv_box = inverse(box)

  def displacement_fn(Ra, Rb, **kwargs):
    _box, _inv_box = box, inv_box

    if 'box' in kwargs:
      _box = kwargs['box']

      if not fractional_coordinates:
        _inv_box = inverse(_box)

    if 'new_box' in kwargs:
      _box = kwargs['new_box']

    if not fractional_coordinates:
      Ra = transform(_inv_box, Ra)
      Rb = transform(_inv_box, Rb)

    dR = periodic_displacement(f32(1.0), pairwise_displacement(Ra, Rb))
    return transform(_box, dR, fractional_coordinates=fractional_coordinates)

  def u(R, dR):
    if wrapped:
      return periodic_shift(f32(1.0), R, dR)
    return R + dR

  def shift_fn(R, dR, **kwargs):
    if not fractional_coordinates and not wrapped:
      return R + dR

    _box, _inv_box = box, inv_box
    if 'box' in kwargs:
      _box = kwargs['box']
      _inv_box = inverse(_box)

    if 'new_box' in kwargs:
      _box = kwargs['new_box']

    dR = transform(_inv_box, dR, fractional_coordinates=fractional_coordinates)
    if not fractional_coordinates:
      R = transform(_inv_box, R)

    R = u(R, dR)

    if not fractional_coordinates:
      R = transform(_box, R)
    return R

  return displacement_fn, shift_fn


"""LJ system adapted from nve_neighbor_list jupyter notebook"""

Nx = particles_per_side = 80
spacing = np.float32(1.25)
side_length = Nx * spacing

R = onp.stack([onp.array(r) for r in onp.ndindex(Nx, Nx)]) * spacing
R = np.array(R, np.float64)

# standard box works, gives stable temperatures below 1
# periodic general with fractional_coordinates=True also works
# periodic general with fractional_coordinates=False quickly diverges

# switch between different boxes:
standard_box = False
fractional_coordinates = False

box = np.ones(2) * side_length  # standard definition of rectangular box
if standard_box:
  displacement, shift = space.periodic(box)
else:
  box = np.array([[box[0], 0.], [0., box[1]]])  # same box, only represented as tensor
  displacement, shift = periodic_general(box, fractional_coordinates=fractional_coordinates)
  if fractional_coordinates:  # scale R to unit hypercube
    inv_box = inverse(box)
    R = np.dot(R, inv_box)

energy_fn = jit(energy.lennard_jones_pair(displacement))
print('E = {}'.format(energy_fn(R)))  # energies are initially the same for all boxes! -11525,65

init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
state = init_fn(random.PRNGKey(0), R)

body_fn = lambda _, state: (apply_fn(state))

step = 0
while step < 30:
  state = lax.fori_loop(0, 100, body_fn, state)
  print('Temperature at step', step, ':', quantity.temperature(state.velocity, state.mass))
  step += 1
from jax-md.