sschoenholz commented on April 28, 2024

Thanks for reporting this! I agree, it would be nicer if potentials didn't interfere with NaN checking.

I think the JIT vs. non-JIT behavior is expected. JAX's NaN checking runs op by op, except when it encounters a jitted block of code. In that case it executes the whole block and checks only the block's outputs for NaNs; if it finds any, it re-executes the block op by op to locate their source. Since the NaNs never "escape" the LJ potential, the NaN detector only goes off when the code isn't jitted.

Actually, I think the solution to this behavior is just to JIT all of the potential functions. I'll think about this a little bit more and circle back with the change.
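The op-by-op versus jitted checking described above can be reproduced without jax_md. This is a minimal sketch in which the hypothetical `masked_log` function stands in for a potential that masks out an invalid region with `jnp.where`. Called directly, `jnp.log` produces a NaN that `jax_debug_nans` reports immediately. Under `jax.jit`, only the final output is inspected, and it contains no NaN.

```python
import jax
import jax.numpy as jnp

# Raise as soon as JAX detects a NaN (op by op, or in the outputs of a jitted call).
jax.config.update("jax_debug_nans", True)


def masked_log(x):
    # Hypothetical stand-in for a masked potential: log(-1.0) is NaN,
    # but jnp.where replaces that entry with 0.0.
    return jnp.where(x > 0, jnp.log(x), 0.0)


x = jnp.array([1.0, -1.0])

# Jitted: only the output of the compiled block is checked, and it is NaN-free.
jit_out = jax.jit(masked_log)(x)
print("jit:", jit_out)

# Not jitted: the intermediate jnp.log(x) contains a NaN, which is reported
# before jnp.where gets a chance to mask it.
try:
    masked_log(x)
    eager_raised = False
except FloatingPointError as err:
    eager_raised = True
    print("eager:", err)
```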

MaxiLechner commented on April 28, 2024

Yes, that should do it. If I understand the Debugging NaNs section here https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-NaNs correctly, a jitted lennard_jones will never be de-optimized if a NaN shows up after lennard_jones has already been called. On the other hand, lennard_jones will simply mask out any NaNs that are put into it.

I only found this issue because a scan was returning NaNs, and after setting the jax_debug_nans flag to True, computing forces suddenly started reporting NaNs as well. Just as before, this is solved by jitting the force function.

import jax
import jax.numpy as jnp
from jax import random, jit

# Use float64 and raise an error whenever JAX detects a NaN.
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_debug_nans', True)

from jax_md import space, energy, quantity

N = 5
dimension = 2
box_size = 12.0
displacement, shift = space.periodic(box_size)

key = random.PRNGKey(0)
key, split = random.split(key)
R = random.uniform(split, (N, dimension), minval=0.0, maxval=box_size, dtype=jnp.float64)

energy_fn = energy.soft_sphere_pair(displacement)
force_fn = quantity.force(energy_fn)

# Works: under jit, only the final forces are checked for NaNs.
print(jit(force_fn)(R))
# Doesn't work: run op by op, the intermediate NaNs that util.safe_mask
# later masks out are detected and a FloatingPointError is raised.
print(force_fn(R))

I have tried to simplify the issue and submitted a bug report: google/jax#5698.

MaxiLechner commented on April 28, 2024

The helpful folks at JAX helped me figure out the issue. It arises because util.safe_mask only masks out NaNs: they are still computed, so the debug flag catches them. I wonder if there is any harm in simply decorating safe_mask with jit?
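A minimal sketch of this mechanism, assuming `safe_mask` works roughly like the simplified helper below (the real `util.safe_mask` may differ in details). In the backward pass of `sqrt` at 0, JAX computes `0 * inf = NaN` before the `jnp.where` that built the masked input zeroes it out. Run op by op, that intermediate NaN trips `jax_debug_nans`. Wrapping the helper in `jax.jit` (with the callable `fn` marked static) keeps the NaN inside one compiled block, so only its finite output is checked, even though `jax.grad` itself is not jitted. The names `total_eager`, `total_jitted_mask` and `jitted_safe_mask` are illustrative.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)


def safe_mask(mask, fn, operand, placeholder=0.0):
    # Simplified, assumed version of a safe_mask helper: evaluate fn on a
    # harmless value (0.0) where mask is False, then overwrite those entries.
    masked = jnp.where(mask, operand, 0.0)
    return jnp.where(mask, fn(masked), placeholder)


# Same helper compiled as a single block. fn is a Python callable, so it
# must be a static argument.
jitted_safe_mask = jax.jit(safe_mask, static_argnums=1)


def total_eager(x):
    # Sum of sqrt(x) over the positive entries, using the un-jitted helper.
    return jnp.sum(safe_mask(x > 0, jnp.sqrt, x))


def total_jitted_mask(x):
    # Same quantity, using the jitted helper.
    return jnp.sum(jitted_safe_mask(x > 0, jnp.sqrt, x))


x = jnp.array([4.0, 0.0])

# Op by op: the backward pass of sqrt at 0 evaluates 0 * inf = NaN before
# the jnp.where that built `masked` zeroes it out, so the NaN checker raises.
try:
    jax.grad(total_eager)(x)
    eager_raised = False
except FloatingPointError:
    eager_raised = True

# Jitted helper: that NaN stays inside the compiled block, whose output
# (the gradient with respect to x) is finite.
grad_jitted = jax.grad(total_jitted_mask)(x)
print(eager_raised, grad_jitted)
```

Here the jitted helper returns the gradient `[0.25, 0.0]`: jitting changes only where the NaN check looks, not the result.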

sschoenholz commented on April 28, 2024

Good idea! I'll make the change in the next version. Thanks for the suggestion.
