Thanks for reporting this! I agree, it would be nicer if potentials didn't interfere with NaN checking.
I think the difference between the jit and non-jit behavior is expected. JAX checks for NaNs op-by-op, except inside jitted code. When it encounters a jitted block, it executes the whole block and checks only the block's outputs for NaNs. If it finds NaNs in those outputs, it re-executes the block op-by-op to find their source. Since the NaNs never "escape" the Lennard-Jones (LJ) potential, the NaN detector only goes off when the code isn't jitted.
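To illustrate the op-by-op versus jit behavior described above, here is a minimal sketch in plain JAX. The function `masked_log` is a hypothetical stand-in for a potential that produces a NaN internally and then discards it with `jnp.where`; it is not part of jax_md.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)


def masked_log(x):
    # Hypothetical example: log(-1.0) is NaN, but jnp.where discards it,
    # so the final output is finite.
    return jnp.where(x > 0, jnp.log(x), 0.0)


x = jnp.array([1.0, -1.0])

# Jitted: only the outputs of the compiled block are checked, and they are finite.
jitted = jax.jit(masked_log)(x)

# Eager: each op is checked as it runs, so the intermediate NaN from
# jnp.log raises FloatingPointError.
try:
    masked_log(x)
    eager_raised = False
except FloatingPointError:
    eager_raised = True
```

Under these assumptions the jitted call returns `[0., 0.]`, while the eager call fails at `jnp.log` even though the NaN would have been masked out.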
Actually, I think the solution to this behavior is just to JIT all of the potential functions. I'll think about this a little bit more and circle back with the change.
Yes, that should do it. If I understand the "Debugging NaNs" section here https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-NaNs correctly, then a jitted lennard_jones will never be de-optimized if a NaN shows up after lennard_jones has already been called. On the other hand, lennard_jones will simply remove any NaNs that are put into it.
I only found this issue because a scan was returning NaNs, and when I set the jax_debug_nans flag to True, computing forces suddenly produced NaNs. As before, this is also solved by jitting the force function:
```python
import jax
import jax.numpy as jnp
from jax import random, grad, jit
from jax_md import space, energy, quantity

# Enable float64 and NaN checking before running any computation.
# (jax.config.update replaces the older `from jax.config import config` import.)
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_debug_nans", True)

N = 5
dimension = 2
box_size = 12
displacement, shift = space.periodic(box_size)

key = random.PRNGKey(0)
key, split = random.split(key)
R = random.uniform(key, (N, dimension), minval=0.0, maxval=box_size, dtype=jnp.float64)

energy_fn = energy.soft_sphere_pair(displacement)
force_fn = quantity.force(energy_fn)

# Works: only the outputs of the jitted force function are checked for NaNs.
print(jit(force_fn)(R))
# Doesn't work: op-by-op NaN checking raises an error.
print(force_fn(R))
```
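The same effect shows up with forces, where the NaN appears only in the backward pass, and jitting the force function hides it. The sketch below avoids jax_md and uses a hypothetical `soft_sphere_energy` (free space, no periodic box) together with a hypothetical `safe_mask` helper written in the masking style discussed in this thread; both are illustrations, not the library's code.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)


def safe_mask(mask, fn, operand, placeholder=0.0):
    # Hypothetical masking helper: replace masked entries with 0 before
    # calling fn, then discard fn's result at those entries.
    masked = jnp.where(mask, operand, 0.0)
    return jnp.where(mask, fn(masked), placeholder)


def soft_sphere_energy(R, sigma=1.0):
    # Hypothetical soft-sphere energy: sum over pairs of (1 - r / sigma)^2 for r < sigma.
    dR = R[:, None, :] - R[None, :, :]
    dr2 = jnp.sum(dR ** 2, axis=-1)
    # The diagonal (i == j) has dr2 == 0, where sqrt's derivative is infinite.
    dr = safe_mask(dr2 > 0, jnp.sqrt, dr2)
    off_diagonal = ~jnp.eye(R.shape[0], dtype=bool)
    e = jnp.where(off_diagonal & (dr < sigma), (1.0 - dr / sigma) ** 2, 0.0)
    return 0.5 * jnp.sum(e)  # 0.5 corrects for counting each pair twice


def force_fn(R):
    # The force is the negative gradient of the energy.
    return -jax.grad(soft_sphere_energy)(R)


R = jnp.array([[0.0, 0.0], [0.5, 0.0], [3.0, 3.0]])

# Jitted: the 0 * inf NaN from differentiating sqrt at 0 stays inside the
# compiled block and is zeroed before it reaches the output.
F = jax.jit(force_fn)(R)

# Eager: the backward pass runs op by op, so that intermediate NaN is caught.
try:
    force_fn(R)
    eager_raised = False
except FloatingPointError:
    eager_raised = True
```

Here only the first two particles overlap, so the jitted forces push them apart along x and leave the third particle untouched.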
I tried to simplify the issue and submitted a bug report: google/jax#5698.
The helpful folks at JAX helped me figure out the issue. It happens because util.safe_mask only masks out NaNs: they are still computed, so the debug flag still catches them. I wonder if there is any harm in simply decorating safe_mask with jit?
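A minimal sketch of that suggestion, using a hypothetical reimplementation of `safe_mask` (not copied from jax_md), is to jit it with `fn` as a static argument. Because the gradient of a jitted function also runs as a compiled block, the `0 * inf` NaN from differentiating `jnp.sqrt` at zero stays inside that block and is replaced by zero when the gradient passes back through the `jnp.where` that masks `operand`, before JAX checks any outputs.

```python
import functools

import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)


@functools.partial(jax.jit, static_argnums=(1,))
def safe_mask(mask, fn, operand, placeholder=0.0):
    # fn is a Python callable, so it must be a static (hashable) argument to jit.
    masked = jnp.where(mask, operand, 0.0)
    return jnp.where(mask, fn(masked), placeholder)


def distance_sum(dr2):
    # Hypothetical example: sum of distances, with dr2 == 0 entries masked to 0.
    return jnp.sum(safe_mask(dr2 > 0, jnp.sqrt, dr2))


dr2 = jnp.array([0.0, 4.0, 9.0])

# No top-level jit here: only safe_mask's own compiled block (and its
# gradient) is checked, and its outputs are finite.
g = jax.grad(distance_sum)(dr2)
```

The gradient is `0.5 / sqrt(dr2)` for the unmasked entries and 0 for the masked one. The trade-off is that, with the NaN hidden inside the compiled block, the debug flag can no longer point at it if it ever does leak out of safe_mask.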
Good idea! I'll make the change in the next version. Thanks for the suggestion.