Comments (9)
Talking with Matt on the main issue confirmed what I had sort of suspected. When you jit-compile a block of code (like the `apply_fn`), XLA optimizes the whole program and, in particular, it will rewrite certain numerical operations with a preference for making them faster. This can, however, lead to small numerical differences between the `jit` and non-`jit` versions of the code. Here I plot the error as a function of time and, as you can see, the error on any given step is small and of the order we would expect (around 10^-30, which is the limit of double precision). It is still, nonetheless, interesting how large the differences become after accumulating over even a relatively modest number of simulation steps.
While it is important to validate this, I think that statistically both trajectories should be equivalent, although the trajectories themselves can obviously be quite different due to the different numerical errors. I plan on filing a bug with the XLA team to see whether they can identify the optimization causing this particular discrepancy (OTOH, it could also be the case that the XLA-fused version is the more accurate one...).
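A rough sketch of the kind of comparison behind that plot (a toy update rule, not the actual jax-md simulation): evolve the same state with and without `jit` and track the gap.
```python
import jax
import jax.numpy as jnp

def step(x):
    # Toy "simulation step"; any nonlinear update illustrates the point.
    return x + 0.01 * jnp.sin(x) * jnp.cos(3.0 * x)

step_jit = jax.jit(step)

x_eager = jnp.linspace(0.0, 1.0, 1000, dtype=jnp.float32)
x_jit = x_eager
for i in range(1, 1001):
    x_eager = step(x_eager)
    x_jit = step_jit(x_jit)
    if i % 200 == 0:
        # Any per-step difference comes from XLA reordering or fusing the math;
        # how fast it grows depends on how chaotic the update is.
        print(i, float(jnp.max(jnp.abs(x_eager - x_jit))))
```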
Hey! Thanks for reporting this; it definitely seems like a problem. I've surfaced the issue to the JAX folks, since they or the XLA team will likely have to find the offending code. Once the situation is resolved, we'll add a test to check for this at the jax-md level.
Hi, thanks for your fast and detailed response. I'm looking forward to a resolution of the problem.
@sschoenholz, if I can chime in briefly: we've been working on identifying and addressing some reproducibility issues in molecular simulation (not related to JAX MD specifically), because it matters for our applications, which involve frequent energy minimizations. If you are coupled to a thermostat it doesn't really matter ("because noise"), and the trajectories will be statistically equivalent.
I am not sure whether this is the origin of the inconsistency here, but the problem usually emerges when parallelization affects the order of summation. You will almost certainly run into it as you change the parallelization details (e.g. the number of processes). To properly characterize the error, I suggest measuring it in units in the last place (ULPs), which tell you how many representable floating-point numbers lie between the two values you are comparing (see the sketch below).
There are two ways of fixing this particular problem: the hack is to increase the precision (128-bit floats); the solution with strong theoretical guarantees is to use an exact summation algorithm (compensated summation won't suffice; we used R. Neal's algorithm). Both come at a computational cost. My student @spraharsh has been looking at this in detail.
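A minimal sketch of the ULP comparison mentioned above (the helper name here is ours, not a library function):
```python
import numpy as np

def ulp_distance(a, b):
    # Count representable float64 values between a and b by comparing their
    # IEEE-754 bit patterns (finite, same-sign values only).
    ia = np.array(a, dtype=np.float64).view(np.int64)
    ib = np.array(b, dtype=np.float64).view(np.int64)
    return int(abs(ia - ib))

print(ulp_distance(1.0, 1.0 + 2**-52))          # adjacent doubles: 1 ULP
print(ulp_distance(1000.0, 1000.0 + 1.14e-13))  # the ~1e-13 gaps seen later in this thread are ~1 ULP at magnitude 1000
```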
Hi guys, I stumbled upon a somewhat similar precision problem with jax.jit. My gradient-based optimization came to a completely different solution when calculating ratios inside the jit-wrapped function instead of passing them as parameters to the function. Below is a simplified code snippet:
```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

n_iter = 1000  # not defined in the original snippet; inferred from the output below

def f1(x, y, z):
    # Compute the ratio inside the (possibly jitted) function.
    out = 0
    ratio = y / z
    for i in range(n_iter):
        out += x * ratio
    return out

def f2(x, ratio):
    # Take the precomputed ratio as a parameter instead.
    out = 0
    for i in range(n_iter):
        out += x * ratio
    return out

# JAX implementation:
n = 10
jit_f1, jit_f2 = jax.jit(f1), jax.jit(f2)
x, y, z = jnp.array([1.]), jnp.array([1.]), jnp.array([1.])
for i in range(n):
    f1_i = f1(x, y, z)[0]
    f1_jit_i = jit_f1(x, y, z)[0]
    f2_i = f2(x, y / z)[0]
    f2_jit_i = jit_f2(x, y / z)[0]
    print(f"{i:04d}) - {f1_i:017.12f} vs {f1_jit_i:017.12f} vs {f2_i:017.12f} vs {f2_jit_i:017.12f}", end="\t")
    print(f"Error = {jnp.abs(f1_i - f1_jit_i):.2e} - {jnp.abs(f2_i - f2_jit_i):.2e} - {jnp.abs(f1_i - f2_i):.2e} - {jnp.abs(f1_jit_i - f2_jit_i):.2e}")
    y += 0.1
    z += 0.1
```
The output will be:
```
0000) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 Error = 0.00e+00 - 0.00e+00 - 0.00e+00 - 0.00e+00
0001) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 Error = 1.14e-13 - 0.00e+00 - 0.00e+00 - 1.14e-13
0002) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 Error = 0.00e+00 - 0.00e+00 - 0.00e+00 - 0.00e+00
0003) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 Error = 0.00e+00 - 0.00e+00 - 0.00e+00 - 0.00e+00
0004) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 Error = 1.14e-13 - 0.00e+00 - 0.00e+00 - 1.14e-13
0005) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 Error = 0.00e+00 - 0.00e+00 - 0.00e+00 - 0.00e+00
0006) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 Error = 0.00e+00 - 0.00e+00 - 0.00e+00 - 0.00e+00
0007) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 Error = 0.00e+00 - 0.00e+00 - 0.00e+00 - 0.00e+00
0008) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 Error = 0.00e+00 - 0.00e+00 - 0.00e+00 - 0.00e+00
0009) - 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 vs 1000.000000000000 Error = 1.14e-13 - 0.00e+00 - 0.00e+00 - 1.14e-13
```
You can see that small errors occur, and after many iterations these errors lead to completely different results. Is there a way to fix this problem? From what I have read, it seems to be more related to XLA.
Hey @sschoenholz, was there any progress on solving this numerical-error accumulation issue? This error accumulation is killing my results.
Thanks, Eyal
Hi @EyalRozenberg1, thanks for pinging this thread; it had slipped off my radar. My best understanding at the moment is that these issues are pretty difficult to resolve. These systems are inherently chaotic, in the sense that the largest Lyapunov exponent is typically positive. As such, one would expect small numerical errors to grow exponentially as the simulation progresses, which is more or less what we see in the plot above.
At the same time, XLA does not preserve exact reduction orders when `jit` is used (I gather this is important for the global optimizations that XLA performs). As such, we should expect small numerical differences between the `jit` and no-`jit` codepaths, and by the above argument we should expect those differences to grow exponentially as the system evolves (see the sketch below).
Having said that, I would not expect observables to be sensitive to these issues, so quantities like average energy, correlation functions, pressure, etc. should be stable with respect to `jit` / no-`jit`. Can you describe your problem in more detail? What results are getting killed by numerical errors? It is possible there are other issues in JAX MD causing the problems you're facing, and I'd love to fix them.
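As a toy illustration of that exponential growth (a standard chaotic map, nothing to do with jax-md itself):
```python
# A perturbation at the double-precision roundoff level grows roughly
# exponentially under a chaotic map until the two trajectories decorrelate.
def step(x):
    return 3.9 * x * (1.0 - x)   # logistic map in its chaotic regime

x, y = 0.3, 0.3 + 1e-15          # differ by roundoff-scale noise
for i in range(61):
    if i % 10 == 0:
        print(i, abs(x - y))
    x, y = step(x), step(y)
```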
Thank you for your reply, @sschoenholz.
I have a normalizing-flow model with a varying number of layers; to get a better approximation, I have to increase the number of layers. The layers are very simple invertible transformations. The numerical error accumulation appears as the number of layers grows (30 layers is the maximum I was able to use safely).
Hey Eyal, thanks for your reply. I have never actually implemented a normalizing flow. Out of curiosity, is this using jax-md, or is it a neural-network layer? If the latter, it might be worth looping in the JAX or Flax folks by opening an issue in one of those repos. One point I would make is that the number of layers you can use ought to be related to how chaotic (as measured by the Lyapunov exponent) the normalizing-flow layer is. Is it possible to initialize the layers differently to make them less chaotic, i.e., to pick parameters so that the layer-to-layer Jacobian determinant is smaller (see the sketch below)?
Anyway, I'd love to help you get this working! Let's iterate on it.
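For concreteness, here is a minimal sketch of what I mean (a hypothetical affine layer, not Eyal's model): look at a single layer's Jacobian to see how much it can amplify a small perturbation; stacking many layers multiplies these factors.
```python
import jax
import jax.numpy as jnp

def affine_layer(x, scale, shift):
    # A very simple invertible transformation of the kind used in many flows.
    return x * jnp.exp(scale) + shift

x = jnp.ones(4)
scale = 0.1 * jnp.ones(4)
shift = jnp.zeros(4)

jac = jax.jacobian(lambda x: affine_layer(x, scale, shift))(x)
print("largest singular value of J:", float(jnp.linalg.norm(jac, ord=2)))
print("log|det J|:", float(jnp.linalg.slogdet(jac)[1]))
# Values near 1 (and log|det J| near 0) mean the layer neither amplifies nor
# shrinks perturbations much, which should make deep stacks better behaved.
```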