Comments (6)

sschoenholz commented on May 6, 2024

Thanks for the report! The issue is that if the variance of an input (here taken to mean a single pixel for a single datapoint) is zero, it can cause NaNs. There is a simple fix: add a small stability term to our normalization. I expect us to push a fix today.
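To illustrate the failure mode described above, here is a minimal NumPy sketch of the ReLU (arc-cosine) kernel map for a single pair of inputs. The function name `relu_kernel` and the `eps` parameter are hypothetical, and putting the stability term only in the denominator of the cosine is an assumption for illustration, not necessarily how neural-tangents implements its fix. When one input's variance is exactly zero, the cosine of the angle between the inputs is computed as `0 / 0`, which yields NaN.

```python
import numpy as np


def relu_kernel(var1, var2, cov12, eps=0.0):
    """E[relu(u) * relu(v)] for jointly Gaussian (u, v) with zero mean.

    var1, var2: variances of u and v; cov12: their covariance.
    eps: hypothetical stability term added only where we divide by the
    normalizer (an assumption, not the library's exact fix).
    """
    var1, var2, cov12 = (np.asarray(v, dtype=np.float64) for v in (var1, var2, cov12))
    # Silence NumPy's 0/0 warning so the NaN is visible in the return value.
    with np.errstate(invalid="ignore", divide="ignore"):
        norm = np.sqrt(var1 * var2)
        # With var1 == 0 and eps == 0 this is 0 / 0 -> nan.
        cos_theta = np.clip(cov12 / (norm + eps), -1.0, 1.0)
        theta = np.arccos(cos_theta)
        return norm / (2 * np.pi) * (np.sin(theta) + (np.pi - theta) * cos_theta)


print(relu_kernel(1.0, 1.0, 1.0))              # about 0.5, i.e. E[relu(z)^2] for z ~ N(0, 1)
print(relu_kernel(0.0, 1.0, 0.0))              # nan: zero variance for one input
print(relu_kernel(0.0, 1.0, 0.0, eps=1e-12))   # 0.0 once the stability term is added
```

With `eps > 0` the zero-variance case returns 0, which matches the exact limit: if a pre-activation is deterministically 0, its ReLU output is 0, so the expected product is 0.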

from neural-tangents.

liutianlin0121 commented on May 6, 2024

Thanks so much for your reply!

I am not sure whether the input variance plays a major role here. In the example above, the sparse inputs are drawn from a Gaussian distribution and then truncated based on magnitude, so their values should be spread symmetrically around 0. But I also found a similar NaN problem with non-negative sparse inputs. The script below shows this with MNIST images:


import tensorflow as tf
import numpy as np
from jax import random
from neural_tangents import stax

mnist = tf.keras.datasets.mnist

(x_train, _), (_, _) = mnist.load_data()


x_train = x_train / 255.0  # scale pixel values from [0, 255] to [0, 1]

x_train_subset_sparse = x_train[:3].reshape([-1, 28, 28, 1])  # first 3 images in NHWC layout; most pixels are exactly 0 (sparse)


# standardize the data with the global mean and standard deviation
mean = np.mean(x_train)
std = np.std(x_train)
x_train_dense = (x_train - mean) / std

x_train_subset_dense = x_train_dense[:3].reshape([-1, 28, 28, 1])  # dense input samples

# A CNN architecture
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Conv(128, (3, 3)),
    stax.Relu(),
    stax.Flatten(),
    stax.Dense(10))

print('NTK evaluated w/ sparse MNIST images: \n', kernel_fn(x_train_subset_sparse, x_train_subset_sparse, 'ntk'))  # the output contains NaN

print('NTK evaluated w/ dense, standardized MNIST images: \n', kernel_fn(x_train_subset_dense, x_train_subset_dense, 'ntk'))  # the output looks fine

The output of the script above should look like this:

NTK evaluated w/ sparse MNIST images: 
 [[nan nan nan]
 [nan nan nan]
 [nan nan nan]]
NTK evaluated w/ dense, standardized MNIST images: 
 [[1.1637697  0.6116184  0.21783468]
 [0.6116184  1.3009455  0.213599  ]
 [0.21783468 0.213599   0.79291606]]

So, without standardization, the sparse MNIST images seem to cause the NaN problem, while standardizing the images to zero mean seems to avoid it.

Thanks for your time!


SiuMath commented on May 6, 2024


liutianlin0121 commented on May 6, 2024

@SiuMath @sschoenholz Many thanks for your explanations! I had misunderstood the variance in question as the one computed across multiple samples for a single pixel :)
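The per-datapoint variance discussed in this thread can be checked with a small NumPy sketch. Assuming a `VALID` 3x3 convolution with no bias term, the variance reaching the ReLU at each output location for a single datapoint is roughly proportional to the mean of the squared inputs in that 3x3 window (scaled by the weight variance). The helper `patch_second_moment` and the synthetic MNIST-like images are hypothetical stand-ins, not the real dataset. In raw sparse images, every window lying entirely in the zero background has exactly zero variance for that datapoint. After global standardization the background becomes the nonzero constant `-mean / std`, so no window is exactly zero.

```python
import numpy as np


def patch_second_moment(x, k=3):
    """Mean of x**2 over each k x k 'VALID' window and all channels.

    x has shape (batch, height, width, channels). Returns an array of shape
    (batch, height - k + 1, width - k + 1); an entry of exactly 0 means that
    datapoint has zero variance at that window (assuming no bias term).
    """
    n, h, w, _ = x.shape
    out = np.empty((n, h - k + 1, w - k + 1))
    for i in range(h - k + 1):
        for j in range(w - k + 1):
            out[:, i, j] = np.mean(x[:, i:i + k, j:j + k, :] ** 2, axis=(1, 2, 3))
    return out


# Hypothetical MNIST-like data: non-negative images, zero outside a 12x12 blob.
rng = np.random.default_rng(0)
sparse = np.zeros((3, 28, 28, 1))
sparse[:, 8:20, 8:20, 0] = rng.uniform(0.1, 1.0, size=(3, 12, 12))

# Standardize with a single global mean and std, as in the script above.
dense = (sparse - sparse.mean()) / sparse.std()

print((patch_second_moment(sparse) == 0).sum())  # 1440 zero-variance windows (480 per image)
print((patch_second_moment(dense) == 0).sum())   # 0
```

Of the 26 x 26 = 676 windows per image, only the 14 x 14 = 196 that overlap the blob have a nonzero second moment in the raw data, which is where the 480 zero-variance windows per image come from.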


romanngg commented on May 6, 2024

FYI, I believe Sam has fixed it in 0e92b0f!


liutianlin0121 commented on May 6, 2024

@romanngg many thanks!!
