Thanks for the report! The issue is that if the variance of an input (here taken to mean a single pixel of a single datapoint) is zero, it can cause NaNs. A simple fix is to add a small stability term to our normalization. I expect us to push a fix today.
from neural-tangents.
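A minimal NumPy sketch of why a zero input variance can produce NaNs, using the standard arc-cosine (ReLU) kernel formula for a single pair of inputs. The function `relu_nngp` and its `eps` argument are hypothetical and for illustration only; this is not the library's actual implementation. The formula divides the covariance by sqrt(var1 * var2), so a zero variance gives 0/0 = NaN, and adding a small stability term `eps` to that normalization keeps the result finite:

```python
import numpy as np

def relu_nngp(var1, var2, cov12, eps=0.0):
    """Arc-cosine (ReLU) kernel for one pair of inputs.

    var1, var2: variances of the two inputs; cov12: their covariance.
    eps: hypothetical stability term added to the normalization.
    """
    var1, var2, cov12 = (np.asarray(v, dtype=float) for v in (var1, var2, cov12))
    prod = np.sqrt(var1 * var2)
    # Cosine of the angle between the inputs; 0/0 -> nan when a variance is zero.
    # The clip only guards against round-off pushing it outside [-1, 1].
    cos_theta = np.clip(cov12 / (prod + eps), -1.0, 1.0)
    theta = np.arccos(cos_theta)
    return prod * (np.sin(theta) + (np.pi - theta) * cos_theta) / (2 * np.pi)

with np.errstate(invalid="ignore"):
    unstable = relu_nngp(0.0, 1.0, 0.0)           # zero variance, no eps
stable = relu_nngp(0.0, 1.0, 0.0, eps=1e-12)      # zero variance, with eps
print(unstable, stable)  # nan 0.0
```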
Thanks so much for your reply!
I am not sure whether the input variance plays a major role here. In the example above, the sparse inputs are drawn from a Gaussian distribution and then truncated by magnitude, so their values should be spread symmetrically around 0. However, I also see a similar NaN problem with non-negative sparse inputs. The script below shows this with MNIST images:
import tensorflow as tf
import numpy as np
from neural_tangents import stax
mnist = tf.keras.datasets.mnist
(x_train, _), (_, _) = mnist.load_data()
x_train = x_train / 255.0  # scale pixel values to [0, 1]
x_train_subset_sparse = x_train[:3].reshape([-1, 28, 28, 1])  # sparse (mostly zero) input samples, NHWC
# Standardize the data with dataset-wide statistics
mean = np.mean(x_train)
std = np.std(x_train)
x_train_dense = (x_train - mean) / std
x_train_subset_dense = x_train_dense[:3].reshape([-1, 28, 28, 1])  # dense input samples
# A CNN architecture
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Conv(128, (3, 3)),
    stax.Relu(),
    stax.Flatten(),
    stax.Dense(10),
)
print('NTK evaluated w/ sparse MNIST images:\n', kernel_fn(x_train_subset_sparse, x_train_subset_sparse, 'ntk'))  # the output contains NaN
print('NTK evaluated w/ dense, standardized MNIST images:\n', kernel_fn(x_train_subset_dense, x_train_subset_dense, 'ntk'))  # the output looks fine
The script prints output like the following:
NTK evaluated w/ sparse MNIST images:
[[nan nan nan]
[nan nan nan]
[nan nan nan]]
NTK evaluated w/ dense, standardized MNIST images:
[[1.1637697 0.6116184 0.21783468]
[0.6116184 1.3009455 0.213599 ]
[0.21783468 0.213599 0.79291606]]
So, without standardization, the sparse MNIST images seem to cause the NaN problem, while the standardized, zero-mean MNIST images do not.
Thanks for your time!
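A rough sketch of why standardization helps, assuming the first layer is a 3x3 'VALID', bias-free convolution on a single-channel image, so that the first-layer variance at an output location is zero exactly when its input patch is all zeros. The image below is a hypothetical MNIST-like stand-in (zero border, positive "ink" in the centre), not real MNIST data, and the helper `zero_variance_patches` is illustrative:

```python
import numpy as np

def zero_variance_patches(image, k=3):
    """Count k x k 'VALID' patches whose sum of squared pixels is zero.

    Assumes a single-channel image and a bias-free convolution, so an
    all-zero patch gives a zero first-layer variance at that location.
    """
    h, w = image.shape
    sq = image ** 2
    count = 0
    for i in range(h - k + 1):
        for j in range(w - k + 1):
            if sq[i:i + k, j:j + k].sum() == 0.0:
                count += 1
    return count

rng = np.random.default_rng(0)
# Hypothetical MNIST-like image: zero border, positive values in the centre.
img = np.zeros((28, 28))
img[8:20, 8:20] = rng.uniform(0.1, 1.0, size=(12, 12))

# Standardize with this image's own mean/std (the thread uses dataset-wide stats).
std_img = (img - img.mean()) / img.std()

print(zero_variance_patches(img))      # 480
print(zero_variance_patches(std_img))  # 0
```

Standardizing maps every zero pixel to `-mean / std`, which is nonzero whenever the mean is positive, so no patch is left with zero variance. Per-image statistics are used here only to keep the sketch self-contained.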
@SiuMath @sschoenholz Many thanks for your explanations! I had previously misunderstood the variance discussed here as the one defined across multiple samples for a single pixel :)
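To make the distinction concrete, a small NumPy sketch contrasting the two notions of variance on hypothetical sparse inputs built like the ones in the first example (Gaussian draws with small magnitudes zeroed out; the threshold of 1.0 is an arbitrary assumption). The per-datapoint quantity is taken here as the mean of squared values over channels at each pixel, which is an assumption about how the input kernel is formed:

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical sparse inputs: Gaussian values with small magnitudes zeroed out.
x = rng.standard_normal((100, 8, 8, 1))   # (batch, height, width, channels)
x[np.abs(x) < 1.0] = 0.0

# Variance across samples for each pixel (the earlier reading).
across_samples = x.var(axis=0)            # shape (8, 8, 1)

# Per-datapoint, per-pixel second moment over channels: the quantity
# whose zeros can make a kernel normalization divide by zero.
per_datapoint = (x ** 2).mean(axis=-1)    # shape (100, 8, 8)

print(across_samples.min() > 0)           # every pixel varies across the batch
print((per_datapoint == 0).mean())        # fraction of zero-variance pixels, roughly two thirds
```

Even though every pixel has a positive variance across the batch, most individual (datapoint, pixel) entries are exactly zero, which is the case that triggers the NaNs.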
FYI, I believe Sam has fixed it in 0e92b0f!
@romanngg many thanks!!
Related Issues (20)
- Report of Abuse of misinformation in scientific research
- How to do Aggregate on a Graph whose nodes are all vectors
- The analytical output of GP can not fit the result of NNGP generated by the nt.predict.gp_inference
- Question: Relu Kernel Computation
- Question: Connection MLE "parametrized" GP in infinite Width Limit vs minimizing MSE "parametrized" Kernel in infinite Width
- Question regarding OOM issues
- Question regarding lr in Neural Tangents Cookbook
- eNTK implementation uses deprecated xla attribute
- Colab notebooks issue
- How to obtain aleatoric uncertainty?
- How to compute the empirical after kernel?
- pip install issues
- Erf function goes beyond [-1,1]
- using stax.Cos(a=1.0, b=1.0, c=0.0) to get kernel from conv layer gives error
- NTK is not PD
- stax.serial PSDness
- How to use batch to gradient_descent_mse_ensemble ?
- NTK/NNGP behavior in the infinite regime when weights are drawn from Gaussians with high standard deviation
- NKT_mean output Nan, when the number of training sample is increased
- Inefficient jacobian computation for embedding layers.