Comments (9)

sschoenholz avatar sschoenholz commented on April 28, 2024 4

Hey Marcel,

Thanks for the great questions! Sorry about the delay in response; it's been a busy week pre-ICML.

I have been thinking about these issues quite a bit, and actually have had a replacement for periodic_general since the fall; but I was waiting until I had finished a refactor of the NVT simulations (along with new NPT simulations) to check it in.

New Periodic General

In any case, here is the new version of the code. Instead of passing the transformation as a callable (which was originally designed to mimic JAX's optimizers), you can now override the box directly. For example,

displacement_fn, shift_fn = periodic_general(box)
energy_fn = energy.soft_sphere_pair(displacement_fn)
E = energy_fn(position)  # Will compute energy using `box`.
E_new = energy_fn(position, box=new_box)  # Will compute energy using `new_box`.

As you suggest, this makes it easy to leverage autodiff to compute the stress tensor:

deformation_energy = lambda epsilon: energy_fn(position, box=box + epsilon)
stress_tensor = grad(deformation_energy)(np.zeros_like(box)) / np.linalg.det(box)
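For readers without jax-md at hand, the same pattern can be tried on a toy system. The sketch below is only an illustration under stated assumptions: `energy_fn` here is a hypothetical harmonic bond between two particles stored in fractional coordinates (not a jax-md energy), there is no periodic wrapping, and the sign and normalization follow the two-line snippet above rather than any particular stress convention.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy energy (not part of jax-md): one harmonic bond between
# two particles whose positions are stored in fractional coordinates.
def energy_fn(position, box, k=1.0, r0=0.5):
    real = position @ box.T            # r_i = box @ s_i for each row s_i
    d = real[0] - real[1]
    dist = jnp.sqrt(jnp.sum(d ** 2))
    return 0.5 * k * (dist - r0) ** 2

box = jnp.array([[2.0, 0.0], [0.0, 2.0]])
position = jnp.array([[0.25, 0.25], [0.75, 0.25]])

# Same construction as in the snippet above: perturb the box, differentiate
# with respect to the perturbation, and normalize by the box volume.
deformation_energy = lambda epsilon: energy_fn(position, box=box + epsilon)
stress_tensor = jax.grad(deformation_energy)(jnp.zeros_like(box)) / jnp.linalg.det(box)
```

For this configuration the bond lies along x with length 1, so only the xx component is nonzero.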

Here is a colab notebook that puts it all together.

Use of Custom JVP

One of the changes to the new periodic general implementation that you'll probably notice is that we changed the custom JVP so that gradient information about the box does flow through the transformation. I think the behavior of the current version of the code (zeroing out gradients wrt box) is incorrect. Having said that, I'd love to explain why we use a custom JVP in the first place to get your thoughts.

periodic_general is supposed to take positions in the unit cube $U=[0, 1]^d$ and map them into a general triclinic box $V$ when computing displacements. By contrast, when shifting particle positions, it seems sensible for the shift function to take positions in the unit cube $U$ but forces (or displacements) that are already in the triclinic box $V$. This is important, for example, because in simulations like NVT we would like velocities to have the correct units when computing the temperature. However, with autodiff, if we take grad(energy_fn) to produce forces, then the derivatives will be backpropagated all the way to the unit cube. Therefore, periodic_general uses a custom JVP that doesn't do this final step of the chain rule.

At the time of writing the original periodic_general I hadn't realized that we could simultaneously 1) not differentiate with respect to positions in the unit cube and 2) faithfully propagate gradients about the box. However, I have since realized that it is possible to do both, and the new version of the code takes this into account and should work properly.
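The two requirements can be seen side by side in a small plain-JAX sketch. It is not the jax-md implementation: `to_real` and `energy` are hypothetical names, and the energy is a trivial quadratic chosen so the gradients are easy to check by hand. The custom JVP passes position tangents through unscaled (so the gradient with respect to unit-cube positions comes out in real-space units) while using the true derivative for the box.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def to_real(box, R):
    # Map fractional positions (rows of R) into the box: r_i = box @ R_i.
    return R @ box.T

@to_real.defjvp
def to_real_jvp(primals, tangents):
    box, R = primals
    dbox, dR = tangents
    # Position tangents are passed through without the factor of `box`;
    # box tangents use the ordinary derivative of box @ R_i.
    return to_real(box, R), dR + R @ dbox.T

def energy(R, box):
    # Hypothetical toy energy: each particle sits in a harmonic well at the origin.
    r = to_real(box, R)
    return 0.5 * jnp.sum(r ** 2)

box = jnp.array([[2.0, 0.5], [0.0, 3.0]])
R = jnp.array([[0.1, 0.2], [0.4, 0.3]])

grad_R = jax.grad(energy, argnums=0)(R, box)      # real-space gradient
grad_box = jax.grad(energy, argnums=1)(R, box)    # true derivative wrt box
```

Here `grad_R` equals the real-space positions `R @ box.T`, which is the gradient of the energy with respect to real-space coordinates; without the custom rule it would carry an extra factor of the box.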

Finally

Thanks for trying out jax md! Please let me know if you have any thoughts about how to improve the periodic_general code, if you notice bugs, or any other feedback about the library itself. I think the new version here is significantly better than the version that is currently checked in, but it's still not completely vetted and feedback would be greatly appreciated!

The new changes do break backward compatibility, so I have been waiting until after ICML to check them in, in the event that it breaks people's code. However, hopefully this Friday I will officially update the periodic_general code.

from jax-md.

sschoenholz avatar sschoenholz commented on April 28, 2024 3

As of version 0.1.13 the new periodic_general function is in.

sschoenholz avatar sschoenholz commented on April 28, 2024 1

Good question! I've gone back and forth on the question of whether to keep things in real space or unit cube and I'm very open to changing things around. One advantage of the current architecture is that none of the code relies on the implementation of the spaces, so if you want to use a version of periodic_general that keeps everything in real space, the simulations, neural networks, and energy functions should just work.

A few notes, perhaps, on why I favor the unit-cube implementation at the moment:

  1. It seems to me that when writing a simulation with fixed strain (e.g. shear), it is a bit nicer to keep things in the unit cube because then you can write a simulation as,
    for t in range(steps):
      strain = t / steps * strain_rate
      box = np.array([[box_size, strain], [0.0, box_size]])
      state = step_fn(state, box=box)
    without having to worry about re-mapping the particle positions.
  2. Although I'm not quite happy with the current interaction, it seems like one needs to project into the unit-cube before doing spatial partitioning. If particle positions are stored in the unit cube then there isn't as much mental overhead when combining periodic_general with cell lists. If particles are stored in real space, I don't see a way of avoiding putting some burden on the user to get this right.
  3. Last time I tried to reason about this, it seemed like there was some extra efficiency to storing in the unit-cube (e.g. you only have a single transformation call to compute displacements rather than three). I know LAMMPS stores positions in the unit-cube, so that was also a source of inspiration.
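Points 1 and 3 can be illustrated with a short sketch, assuming positions are stored in the unit cube. The function name `fractional_displacement` is hypothetical; it wraps the difference per axis in fractional coordinates and then applies a single transformation into real space. Wrapping per axis in the unit cube is only guaranteed to give the minimum image for modest tilts, so treat this as a sketch rather than a general solution.

```python
import jax.numpy as jnp

def fractional_displacement(box, Ra, Rb):
    # Difference in the unit cube, wrapped into [-1/2, 1/2) along each axis.
    dR = jnp.mod(Ra - Rb + 0.5, 1.0) - 0.5
    # One transformation maps the wrapped difference into real space.
    return box @ dR

box_size = 10.0
Ra = jnp.array([0.95, 0.5])   # fractional coordinates, unchanged by the shear
Rb = jnp.array([0.05, 0.4])

unstrained_box = jnp.array([[box_size, 0.0], [0.0, box_size]])
sheared_box = jnp.array([[box_size, 1.0], [0.0, box_size]])

dr_unstrained = fractional_displacement(unstrained_box, Ra, Rb)
dr_sheared = fractional_displacement(sheared_box, Ra, Rb)
```

The stored positions never change between the two calls; only the box does, which is why the shear loop above needs no re-mapping step.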

However, as I said above I'm not very confident in this decision and so I'm very open to changing the design. It also would be completely fine for you all to use a periodic general that stored particle positions in real space. As long as you were a bit careful about the interactions with cell lists, nothing else should require special care. I'm quite sure that everything will work OK from the autodiff perspective and JAX MD is agnostic about the form of displacement / shift functions.

sirmarcel avatar sirmarcel commented on April 28, 2024 1

Thanks for the clarification! These seem like fine reasons, especially if you do the whole MD within jax-md and entirely in fractional coordinates. From an implementation perspective it seems more convenient to work in fractional coordinates as much as possible. Our use case, with ase as an external MD driver, might be a bit special in that regard, since it requires all coordinates to be in real space.

It might be worth mentioning the reasoning explicitly in the docs, and potentially giving an example of transforming into fractional coordinates from real space. In the end it's all very straightforward, but it took me a while to wrap my head around the "unit hypercube". ;)
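A minimal sketch of such a conversion, assuming the convention used elsewhere in this thread that a real-space position is `box @ s` for a fractional position `s` (so the box columns are the lattice vectors). The helper names `to_fractional` and `to_real` are hypothetical and not part of jax-md.

```python
import jax.numpy as jnp

def to_fractional(box, R_real):
    # Solve r = box @ s for every row r of R_real, then wrap into [0, 1).
    inv_box = jnp.linalg.inv(box)
    return jnp.mod(R_real @ inv_box.T, 1.0)

def to_real(box, R_frac):
    # Inverse mapping: r = box @ s for every row s of R_frac.
    return R_frac @ box.T

box = jnp.array([[10.0, 2.0], [0.0, 10.0]])
R_real = jnp.array([[1.0, 2.0], [12.0, 5.0]])   # second particle lies outside the box

R_frac = to_fractional(box, R_real)
R_back = to_real(box, R_frac)
```

The second particle comes back shifted by one lattice vector, i.e. as its periodic image inside the box.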

sschoenholz avatar sschoenholz commented on April 28, 2024 1

Good point! I also think you're right that "unit_cube" is a bit nonstandard. Here is an example of a periodic_general function I've been playing around with that has an option to either use real space or fractional coordinates. I think there's no downside to including both. Still have to write docstrings and utilities, but I pretty closely checked that it correctly produces energies, stresses, pressure, and elastic constants in both modes. Let me know if you find any issues. I'll write back here when it's checked in.

sschoenholz avatar sschoenholz commented on April 28, 2024 1

Thanks for reaching out! It seems that I copied an old version of the new periodic_general above, my apologies! I really need to get the code checked in, but I have had some technical issues lately (computer broke) and so there has been a little lag.

Here is the correct version, let me know if anything still seems out of the ordinary. I'll try to check in the code by Monday so that it's tested.

from jax_md import space
from jax import custom_jvp
from jax import lax
from functools import partial
import jax.numpy as np

periodic_displacement = space.periodic_displacement
pairwise_displacement = space.pairwise_displacement
periodic_shift = space.periodic_shift

f32 = np.float32

def inverse(box):
  if np.isscalar(box) or box.size == 1:
    return 1 / box
  elif box.ndim == 1:
    return 1 / box
  elif box.ndim == 2:
    return np.linalg.inv(box)
  
  raise ValueError()

def get_free_indices(n):
  return ''.join([chr(ord('a') + i) for i in range(n)])

@custom_jvp
def transform(box, R):
  if np.isscalar(box) or box.size == 1:
    return R * box
  elif box.ndim == 1:
    indices = get_free_indices(R.ndim - 1) + 'i'
    return np.einsum(f'i,{indices}->{indices}', box, R)
  elif box.ndim == 2:
    free_indices = get_free_indices(R.ndim - 1)
    left_indices = free_indices + 'j'
    right_indices = free_indices + 'i'
    return np.einsum(f'ij,{left_indices}->{right_indices}', box, R)
  raise ValueError()

@transform.defjvp
def transform_jvp(primals, tangents):
  box, R = primals
  dbox, dR = tangents
  # Position tangents pass through unscaled; box tangents are propagated.
  return (transform(box, R), dR + transform(dbox, R))
  
def periodic_general(box, fractional_coordinates=True, wrapped=True):
  inv_box = inverse(box)

  def displacement_fn(Ra, Rb, **kwargs):
    _box, _inv_box = box, inv_box

    if 'box' in kwargs:      
      _box = kwargs['box']

      if not fractional_coordinates: 
        _inv_box = inverse(_box)

    if 'new_box' in kwargs:
      _box = kwargs['new_box']

    if not fractional_coordinates:
      Ra = transform(_inv_box, Ra)
      Rb = transform(_inv_box, Rb)

    dR = periodic_displacement(f32(1.0), pairwise_displacement(Ra, Rb))
    return transform(_box, dR) 

  def u(R, dR):
    if wrapped:
      return periodic_shift(f32(1.0), R, dR)
    return R + dR

  def shift_fn(R, dR, **kwargs):
    if not fractional_coordinates and not wrapped:
      return R + dR

    _box, _inv_box = box, inv_box
    if 'box' in kwargs:
      _box = kwargs['box']
      _inv_box = inverse(_box)

    if 'new_box' in kwargs:
      _box = kwargs['new_box']

    dR = transform(_inv_box, dR)
    if not fractional_coordinates:
      R = transform(_inv_box, R)

    R = u(R, dR)

    if not fractional_coordinates:
      R = transform(_box, R)
    return R
  
  return displacement_fn, shift_fn

sirmarcel avatar sirmarcel commented on April 28, 2024

So, it looks like the "callable T" route doesn't play well with jit, Fabian spent some time testing it. Once the displacement gets combined with an energy function, something seems to prevent jit from properly working.
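This does not diagnose the problem reported above, but for comparison: a box passed as an array argument is an ordinary traced input under jit, whereas a Python callable is not a valid JAX type and cannot be passed as a traced argument. The sketch below uses a hypothetical toy `energy` (not a jax-md function) and a Python-side list to show that two boxes of the same shape reuse a single trace.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def energy(R, box):
    # Python side effect: executed only while JAX traces the function.
    trace_log.append("traced")
    r = R @ box.T
    d = r[0] - r[1]
    return 0.5 * jnp.sum(d ** 2)

R = jnp.array([[0.25, 0.25], [0.75, 0.75]])
box = jnp.array([[2.0, 0.0], [0.0, 2.0]])
new_box = jnp.array([[2.0, 0.5], [0.0, 2.0]])

E = energy(R, box)
E_new = energy(R, new_box)   # same shape and dtype, so no retracing
```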

sirmarcel avatar sirmarcel commented on April 28, 2024

Hi Samuel,

Thanks for the quick and thorough reply, and good luck with ICML!

Fabian, or I, will reply separately to the stress-related points once we've had a closer look! At first glance this looks like precisely what we need, which is great.

On the periodic_general issue, I have a general question: Wouldn't it be more convenient to work entirely in real space, including positions, and treat the transformation into the unit cube as a purely internal intermediate step? In other words, you make the transformation to scaled coordinates part of the displacement_fn and shift_fn, and purely work in "real space" when it comes to input and output. That way, I think one can entirely sidestep the issue of having some quantities be in fractional coordinates and some in real coordinates. I don't have a great intuition for autograd yet, but it seems like this would also apply for gradients?
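A rough sketch of that idea, for concreteness. `real_space_periodic` is a hypothetical name, not a jax-md function; inputs and outputs are in real space and the unit cube only appears internally. As with any per-axis wrap in fractional coordinates, the result is only guaranteed to be the minimum image for modest tilts.

```python
import jax
import jax.numpy as jnp

def real_space_periodic(box):
    inv_box = jnp.linalg.inv(box)

    def displacement_fn(ra, rb):
        # Go to fractional coordinates, wrap into [-1/2, 1/2), and come back.
        s = jnp.mod(inv_box @ (ra - rb) + 0.5, 1.0) - 0.5
        return box @ s

    def shift_fn(r, dr):
        # Move in real space, then wrap the result back into the box.
        s = jnp.mod(inv_box @ (r + dr), 1.0)
        return box @ s

    return displacement_fn, shift_fn

box = jnp.array([[10.0, 2.0], [0.0, 10.0]])
displacement_fn, shift_fn = real_space_periodic(box)

ra = jnp.array([9.5, 5.0])
rb = jnp.array([0.5, 4.0])
d = displacement_fn(ra, rb)
moved = shift_fn(ra, jnp.array([3.0, 0.0]))

# Plain autodiff already yields a real-space gradient here, because the
# factors of inv_box and box cancel in the chain rule.
grad_ra = jax.grad(lambda r: 0.5 * jnp.sum(displacement_fn(r, rb) ** 2))(ra)
```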

I'll have a closer look at the actual implementation you posted soon, I'm looking forward to it.

S-Thaler avatar S-Thaler commented on April 28, 2024

This new periodic_general box looks very useful to me, both with and without fractional coordinates.

For some of my examples, the new periodic_general with fractional_coordinates = True works just fine; for others, however, I think I get incorrect results.
With fractional_coordinates = False, I frequently get incorrect results; maybe there's an error on my side.

I've written a small example where the standard periodic box and the new periodic_general with fractional_coordinates = True work, but fractional_coordinates = False diverges.

Interestingly, all boxes give the same energy initially. This could suggest that the error occurs either in the backward pass or when updating particle positions (or I'm just using fractional_coordinates = False incorrectly).
I appreciate any insights!
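One generic way to separate the two hypotheses is to compare the autodiff gradient with central finite differences before running any dynamics. The sketch below uses a hypothetical `toy_energy` in place of the real energy function. With fractional coordinates, the custom JVP described earlier in the thread deliberately departs from the plain derivative, so agreement would presumably only be expected in real-space mode.

```python
import jax
import jax.numpy as jnp

def finite_difference_grad(f, x, eps=1e-2):
    # Central differences, one coordinate at a time (slow; for debugging only).
    flat = x.reshape(-1)
    grads = []
    for i in range(flat.size):
        step = jnp.zeros_like(flat).at[i].set(eps)
        plus = f((flat + step).reshape(x.shape))
        minus = f((flat - step).reshape(x.shape))
        grads.append((plus - minus) / (2 * eps))
    return jnp.stack(grads).reshape(x.shape)

def toy_energy(R):
    # Hypothetical stand-in for energy_fn: a harmonic bond between two particles.
    d = R[0] - R[1]
    return 0.5 * jnp.sum(d ** 2)

R_test = jnp.array([[0.0, 0.0], [1.0, 0.5]])
autodiff_grad = jax.grad(toy_energy)(R_test)
numeric_grad = finite_difference_grad(toy_energy, R_test)
max_error = jnp.max(jnp.abs(autodiff_grad - numeric_grad))
```

If the two gradients agree for the real energy function, the position update would be the next place to look.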

from jax_md import space, energy, simulate, quantity
from jax import custom_jvp, jit, random, lax
import jax.numpy as np
import numpy as onp


"""new implementation of periodic_general from issue 116"""

periodic_displacement = space.periodic_displacement
pairwise_displacement = space.pairwise_displacement
periodic_shift = space.periodic_shift

f32 = np.float32


def inverse(box):
    if np.isscalar(box) or box.size == 1:
        return 1 / box
    elif box.ndim == 1:
        return 1 / box
    elif box.ndim == 2:
        return np.linalg.inv(box)

    raise ValueError()


def get_free_indices(n):
    return ''.join([chr(ord('a') + i) for i in range(n)])


def base_transform(box, R):
    if np.isscalar(box) or box.size == 1:
        return R * box
    elif box.ndim == 1:
        indices = get_free_indices(R.ndim - 1) + 'i'
        return np.einsum(f'i,{indices}->{indices}', box, R)
    elif box.ndim == 2:
        free_indices = get_free_indices(R.ndim - 1)
        left_indices = free_indices + 'j'
        right_indices = free_indices + 'i'
        return np.einsum(f'ij,{left_indices}->{right_indices}', box, R)
    raise ValueError()


@custom_jvp
def transform_without_tangents(box, R):
    return base_transform(box, R)


@transform_without_tangents.defjvp
def transform_without_tangents_jvp(primals, tangents):
    box, R = primals
    dbox, dR = tangents

    return (transform_without_tangents(box, R),
            dR + transform_without_tangents(dbox, R))


def transform(box, R, fractional_coordinates=True):
    if not fractional_coordinates:
        return base_transform(box, R)
    return transform_without_tangents(box, R)


def periodic_general(box, fractional_coordinates=True, wrapped=True):
    inv_box = inverse(box)

    def displacement_fn(Ra, Rb, **kwargs):
        _box, _inv_box = box, inv_box

        if 'box' in kwargs:
            _box = kwargs['box']

            if not fractional_coordinates:
                _inv_box = inverse(_box)

        if 'new_box' in kwargs:
            _box = kwargs['new_box']

        if not fractional_coordinates:
            Ra = transform(_inv_box, Ra)
            Rb = transform(_inv_box, Rb)

        dR = periodic_displacement(f32(1.0), pairwise_displacement(Ra, Rb))
        return transform(_box, dR, fractional_coordinates=fractional_coordinates)

    def u(R, dR):
        if wrapped:
            return periodic_shift(f32(1.0), R, dR)
        return R + dR

    def shift_fn(R, dR, **kwargs):
        if not fractional_coordinates and not wrapped:
            return R + dR

        _box, _inv_box = box, inv_box
        if 'box' in kwargs:
            _box = kwargs['box']
            _inv_box = inverse(_box)

        if 'new_box' in kwargs:
            _box = kwargs['new_box']

        dR = transform(_inv_box, dR, fractional_coordinates=fractional_coordinates)
        if not fractional_coordinates:
            R = transform(_inv_box, R)

        R = u(R, dR)

        if not fractional_coordinates:
            R = transform(_box, R)

        return R

    return displacement_fn, shift_fn


"""LJ system adapted from nve_neighbor_list jupyter notebook"""

Nx = particles_per_side = 80
spacing = np.float32(1.25)
side_length = Nx * spacing

R = onp.stack([onp.array(r) for r in onp.ndindex(Nx, Nx)]) * spacing
R = np.array(R, np.float64)


# standard box works, gives stable temperatures below 1
# periodic general with fractional_coordinates=True also works
# periodic general with fractional_coordinates=False quickly diverges

# switch between different boxes:
standard_box = False
fractional_coordinates = False

box = np.ones(2) * side_length  # standard definition of rectangular box
if standard_box:
    displacement, shift = space.periodic(box)
else:
    box = np.array([[box[0], 0.], [0., box[1]]])  # same box, only represented as tensor
    displacement, shift = periodic_general(box, fractional_coordinates=fractional_coordinates)
    if fractional_coordinates:  # scale R to unit hypercube
        inv_box = inverse(box)
        R = np.dot(R, inv_box)


energy_fn = jit(energy.lennard_jones_pair(displacement))
print('E = {}'.format(energy_fn(R)))  # energies are initially the same for all boxes! -11525.65

init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
state = init_fn(random.PRNGKey(0), R)

body_fn = lambda _, state: (apply_fn(state))

step = 0
while step < 30:
    state = lax.fori_loop(0, 100, body_fn, state)
    print('Temperature at step', step, ':', quantity.temperature(state.velocity, state.mass))
    step += 1
