
Comments (9)

sschoenholz commented on September 7, 2024

Talking with Matt on the main issue confirmed what I had suspected. When you jit compile a block of code (like the apply_fn), XLA optimizes the whole program and, in particular, it will rewrite certain numerical operations with a preference for making them faster. This can, however, lead to small numerical differences between jit and non-jit versions of the code. Here I plot the error as a function of time and, as you can see, the error on any given step is small and at the order we would expect (like 10^-30 which is the limit of double precision). It is nonetheless interesting how large the differences become after accumulating for even a relatively modest number of simulation steps.

[Figure: error between the jit and non-jit trajectories as a function of simulation time]

While it is important to validate, I think that statistically both trajectories should be equivalent although the trajectories themselves obviously can be quite different due to different numerical errors. I plan on filing a bug with the XLA team to see whether they can identify the optimization that's causing particular uncertainty (OTOH, it could also be the case that the XLA fused version is the more accurate one...).
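As a minimal sketch of why rewriting or reordering arithmetic can change a result at all, the snippet below uses plain Python doubles (not XLA itself). The helper `sequential_sum` and the variable names are illustrative assumptions, not anything from JAX MD.

```python
import random

# Floating-point addition is not associative: regrouping changes the rounding.
a, b, c = 0.1, 0.2, 0.3
left_to_right = (a + b) + c
right_to_left = a + (b + c)
print(left_to_right == right_to_left)        # False
print(abs(left_to_right - right_to_left))    # about one rounding error near 0.6

# The same effect appears when a long sum is accumulated in a different order.
random.seed(0)
values = [random.uniform(-1.0, 1.0) for _ in range(100_000)]

def sequential_sum(xs):
    """Accumulate strictly in iteration order, one addition at a time."""
    total = 0.0
    for x in xs:
        total += x
    return total

forward = sequential_sum(values)
backward = sequential_sum(reversed(values))
reorder_gap = abs(forward - backward)
print(f"forward={forward!r} backward={backward!r} gap={reorder_gap:.3e}")
```

Neither ordering is "the correct one"; both are valid roundings of the same mathematical sum, which is why the fused version could just as well be the more accurate of the two.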

from jax-md.

sschoenholz commented on September 7, 2024

Hey! Thanks for reporting this, it definitely seems like a problem. I surfaced the issue to the JAX folks since they or the XLA team will likely have to find the offending code. Once the situation is resolved, we'll add a test to check for this at the JAX MD level.


LyricZhao commented on September 7, 2024

Hi, thanks for your fast and detailed response. I'm looking forward to a solution to the problem.


smcantab commented on September 7, 2024

@sschoenholz , if I can chime in briefly, we've been working on identifying and addressing some reproducibility issues in molecular simulation (not related to JAX MD specifically) because for our applications that involve frequent energy minimizations it matters. If you are coupled to a thermostat it doesn't really matter "because noise", and trajectories will be equivalent.

I am not sure if this is the origin of the inconsistency here, but usually the problem emerges when parallelization affects the order of summation. You'll almost certainly run into this problem as you change the parallelization details (e.g. number of processes). To properly characterize the error I suggest using Units in the Last Place (ULP), which tell you how many representable floating-point numbers lie between the two numbers you are comparing.
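For readers who want to try the ULP measure, here is a minimal sketch for IEEE-754 doubles in plain Python. The function names `ordered_int` and `ulp_distance` are hypothetical helpers written for this illustration, and the code assumes finite inputs (no NaN or infinity handling).

```python
import struct

def ordered_int(x: float) -> int:
    """Map a finite double to an integer so that adjacent doubles map to adjacent integers."""
    (bits,) = struct.unpack("<q", struct.pack("<d", x))
    # Doubles are stored as sign + magnitude. Mirror the negative half so the
    # integer ordering matches the numeric ordering (and -0.0 maps to 0).
    return bits if bits >= 0 else -(bits & 0x7FFFFFFFFFFFFFFF)

def ulp_distance(a: float, b: float) -> int:
    """Number of representable doubles you step over going from a to b."""
    return abs(ordered_int(a) - ordered_int(b))

print(ulp_distance(1.0, 1.0))                    # 0
print(ulp_distance(0.1 + 0.2, 0.3))              # 1
print(ulp_distance(1000.0, 1000.0 + 2.0 ** -43)) # 1
```

One ULP at a value near 1000 is 2^-43, roughly 1.14e-13, so an absolute gap of that size between two results of about 1000 means they are adjacent doubles.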

There are two ways of fixing this particular problem: the hack is to increase the precision (128-bit floats); the solution with strong theoretical guarantees is to use exact summation algorithms (compensated summation won't suffice; we used R. Neal's algorithm). Both come at a computational cost. My student @spraharsh has been looking at this in detail.
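To make the distinction concrete, here is a small sketch comparing naive, compensated (Kahan), and exactly rounded summation under a reshuffle of the inputs. Note the swap: it uses Python's built-in `math.fsum` (based on Shewchuk's algorithm) as a stand-in for an exact summation routine, not R. Neal's algorithm mentioned above. The helpers `naive_sum` and `kahan_sum` and the test data are assumptions made for illustration.

```python
import math
import random

random.seed(0)
# Mix magnitudes so that different orderings round differently.
values = [random.uniform(-1.0, 1.0) * 10.0 ** random.randint(-8, 8)
          for _ in range(10_000)]
shuffled = values[:]
random.shuffle(shuffled)

def naive_sum(xs):
    """Plain left-to-right accumulation."""
    total = 0.0
    for x in xs:
        total += x
    return total

def kahan_sum(xs):
    """Compensated summation: carries the rounding error of each addition forward."""
    total = 0.0
    comp = 0.0
    for x in xs:
        y = x - comp
        t = total + y
        comp = (t - total) - y
        total = t
    return total

exact_a = math.fsum(values)
exact_b = math.fsum(shuffled)

print("naive :", naive_sum(values), naive_sum(shuffled))
print("kahan :", kahan_sum(values), kahan_sum(shuffled))
print("fsum  :", exact_a, exact_b)
print("fsum order-independent:", exact_a == exact_b)
```

Because `math.fsum` returns the correctly rounded value of the exact sum, its result cannot depend on the order of the inputs. The naive and Kahan results are close to it but carry no such guarantee.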


khansel01 commented on September 7, 2024

Hi guys, I stumbled upon a similar precision problem with jax.jit. My gradient optimization algorithm came to a completely different solution when calculating ratios inside the jit-wrapped function instead of passing them as parameters to the function. Below is a simplified code snippet:

import jax
import jax.numpy as jnp

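# Enable double precision (JAX defaults to float32).
jax.config.update("jax_enable_x64", True)

# Inner-loop length. The original snippet left n_iter undefined; 1000 is an
# assumption that matches the totals of 1000 printed below (x = 1, ratio = 1).
n_iter = 1000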

def f1(x, y, z):
    out = 0
    ratio = y/z
    for i in range(n_iter):
        out += x * ratio
    return out

def f2(x, ratio):
    out = 0
    for i in range(n_iter):
        out += x * ratio
    return out

# Jax Implementation: 
n = 10

jit_f1, jit_f2= jax.jit(f1), jax.jit(f2)

x, y, z = jnp.array([1.]), jnp.array([1.]), jnp.array([1.])

for i in range(n):
  f1_i = f1(x, y, z)[0]
  f1_jit_i = jit_f1(x, y, z)[0]
  f2_i = f2(x, y / z)[0]
  f2_jit_i = jit_f2(x, y / z)[0]

  print(f"{i:04d}) - {f1_i:017.12f} vs {f1_jit_i:017.12f} vs {f2_i:017.12f} vs {f2_jit_i:017.12f}", end="\t")
  print(f"Error = {jnp.abs(f1_i - f1_jit_i):.2e} - {jnp.abs(f2_i - f2_jit_i):.2e} - {jnp.abs(f1_i - f2_i):.2e} - {jnp.abs(f1_jit_i - f2_jit_i):.2e}")
  # print(f"{i:04d}) - {f1_i:017.12f} vs {f2_i:017.12f} vs {f3_i:017.12f}", end="\t")
  # print(f"Error = {jnp.abs(f1_i - f2_i):.2e} - {jnp.abs(f1_i - f3_i):.2e} - {jnp.abs(f3_i - f2_i):.2e}")

  y += 0.1
  z += 0.1

The output will be

0000) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000	Error = 0.00e+00 - 0.00e+00 - 0.00e+00 - 0.00e+00
0001) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000	Error = 1.14e-13 - 0.00e+00 - 0.00e+00 - 1.14e-13
0002) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000	Error = 0.00e+00 - 0.00e+00 - 0.00e+00 - 0.00e+00
0003) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000	Error = 0.00e+00 - 0.00e+00 - 0.00e+00 - 0.00e+00
0004) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000	Error = 1.14e-13 - 0.00e+00 - 0.00e+00 - 1.14e-13
0005) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000	Error = 0.00e+00 - 0.00e+00 - 0.00e+00 - 0.00e+00
0006) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000	Error = 0.00e+00 - 0.00e+00 - 0.00e+00 - 0.00e+00
0007) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000	Error = 0.00e+00 - 0.00e+00 - 0.00e+00 - 0.00e+00
0008) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000	Error = 0.00e+00 - 0.00e+00 - 0.00e+00 - 0.00e+00
0009) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000	Error = 1.14e-13 - 0.00e+00 - 0.00e+00 - 1.14e-13

You can see that small errors occur. However, after many iterations, these errors lead to completely different results. Is there a way to fix this problem? I have already read that it is more related to XLA.


EyalRozenberg1 commented on September 7, 2024

Hey @sschoenholz.
Was there any progress with solving this numerical error accumulation issue?
I suffer from this error aggregation that kills my results.

thanks, Eyal


sschoenholz commented on September 7, 2024

Hi @EyalRozenberg1, thanks for pinging this thread. It had slipped off my radar. My best understanding at the moment is that these issues are pretty difficult to resolve. These systems are inherently chaotic, in the sense that the largest Lyapunov exponent is typically positive. As such, one would expect small numerical errors to grow exponentially as the simulation progresses. This is more-or-less what we see in the plot above.

At the same time, XLA does not preserve exact reduction orders when jit is used (I gather this is important for the global optimizations that XLA performs). As such, we should expect small numerical differences between jit and no-jit codepaths. However, by the above argument, we expect these differences to grow exponentially as the system evolves.
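The exponential growth argument can be seen in a toy system. The sketch below uses the logistic map x -> 4x(1 - x), a standard chaotic map whose Lyapunov exponent is ln 2, as a stand-in for a molecular dynamics step. It is not JAX MD code, and the starting value and perturbation size are arbitrary choices for illustration.

```python
def logistic(x):
    """One step of the chaotic logistic map at r = 4."""
    return 4.0 * x * (1.0 - x)

def trajectory(x0, n_steps):
    """Iterate the map n_steps times, returning every state including x0."""
    xs = [x0]
    for _ in range(n_steps):
        xs.append(logistic(xs[-1]))
    return xs

n_steps = 200
reference = trajectory(0.3, n_steps)
# A perturbation of roughly rounding-error size in the initial condition.
perturbed = trajectory(0.3 + 1e-15, n_steps)

gaps = [abs(a - b) for a, b in zip(reference, perturbed)]
for step in (0, 10, 20, 30, 40, 50, 60):
    print(f"step {step:3d}: gap = {gaps[step]:.3e}")
```

The gap roughly doubles per step on average, so a difference at the level of rounding error reaches order one within a few dozen steps, after which the two trajectories are unrelated point by point.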

Having said this, I would not expect observables to be sensitive to these issues. So quantities like average energy, correlation functions, pressure, etc... should be stable wrt jit/ no-jit. Can you describe your problem in more detail? What results are getting killed by numerical errors? It is possible there are other issues in JAX MD that are causing the problems you're facing and I'd love to fix them.


EyalRozenberg1 commented on September 7, 2024

Thank you for your reply, @sschoenholz.
I have a normalizing flow model with a varying number of layers. In order to get a better approximation, I have to increase the number of layers. The layers are very simple invertible transformations. The numerical error accumulation occurs when the number of layers is increased. (30 layers is the maximum value I was able to safely use.)


sschoenholz commented on September 7, 2024

Hey Eyal, thanks for your reply. I have never actually tried implementing a normalizing flow. Out of curiosity, is this using jax-md or is it a neural network layer? If the latter, it might be worth looping in the JAX or FLAX folks by making an issue in one of those two repos. One point that I might make is that I think the number of layers you should be able to use ought to be related to how chaotic (as measured by the Lyapunov exponent) the normalizing flow layer is. Is it possible to initialize the layer differently to make it less chaotic (i.e., picking parameters so that the layer-to-layer Jacobian determinant is smaller)?
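As a rough sketch of that idea, consider a stack of identical scalar affine layers y = scale * x + shift, a deliberately simplified, hypothetical stand-in for an invertible flow layer. Each layer's Jacobian is just `scale`, so a small input perturbation is multiplied by about `scale ** n_layers`. The functions `apply_layers` and `amplification` are illustrative names, not part of any library.

```python
def apply_layers(x, scale, shift, n_layers):
    """Push x through n_layers identical affine layers."""
    for _ in range(n_layers):
        x = scale * x + shift
    return x

def amplification(scale, n_layers, eps=1e-9):
    """Finite-difference estimate of how much a small input error grows."""
    base = apply_layers(0.5, scale, 0.1, n_layers)
    bumped = apply_layers(0.5 + eps, scale, 0.1, n_layers)
    return abs(bumped - base) / eps

for scale in (0.9, 1.0, 1.5):
    print(f"scale={scale}: amplification over 30 layers = {amplification(scale, 30):.3e}")
```

With a per-layer Jacobian of 1.5, thirty layers magnify an input error by roughly five orders of magnitude, while a Jacobian at or below one keeps it bounded. Real flow layers are nonlinear and multidimensional, so this only indicates the trend.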

Anyway, I'd love to help you get this working! Let's iterate on it.

