Fast and Easy Infinite Neural Networks in Python

Home Page: https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html

License: Apache License 2.0


Stand with Ukraine! 🇺🇦

Freedom of thought is fundamental to all of science. Right now, our freedom is being suppressed by the bombing of civilians in Ukraine. Don't be against the war - fight against the war! supportukrainenow.org.

Neural Tangents

ICLR 2020 Video | Paper | Quickstart | Install guide | Reference docs | Release notes


Overview

Neural Tangents is a high-level neural network API for specifying complex, hierarchical, neural networks of both finite and infinite width. Neural Tangents allows researchers to define, train, and evaluate infinite networks as easily as finite ones. The library has been used in >100 papers.

Infinite (in width or channel count) neural networks are Gaussian Processes (GPs) with a kernel function determined by their architecture. See this listing of papers written by the creators of Neural Tangents which study the infinite width limit of neural networks.

Neural Tangents allows you to construct a neural network model from common building blocks like convolutions, pooling, residual connections, nonlinearities, and more, and obtain not only the finite model, but also the kernel function of the respective GP.

The library is written in Python using JAX, leveraging XLA to run out-of-the-box on CPU, GPU, or TPU. Kernel computation is highly optimized for speed and memory efficiency, and can be automatically distributed over multiple accelerators with near-perfect scaling.
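
For instance, kernel computations can be tiled over batches of inputs and spread across all available accelerators with nt.batch. Below is a minimal sketch (the architecture, batch size, and input shapes are illustrative; batch_size should evenly divide the number of inputs per device):

import neural_tangents as nt
from neural_tangents import stax
from jax import random

_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

# Compute the kernel in 20x20 blocks, distributed over all available devices
# (device_count=-1 means "use every device JAX can see").
batched_kernel_fn = nt.batch(kernel_fn, batch_size=20, device_count=-1)

x = random.normal(random.PRNGKey(1), (100, 50))
ntk = batched_kernel_fn(x, None, 'ntk')  # (100, 100), same values as kernel_fn(x, None, 'ntk')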

Neural Tangents is a work in progress. We happily welcome contributions!

Colab Notebooks

An easy way to get started with Neural Tangents is by playing around with the following interactive notebooks in Colaboratory. They demo the major features of Neural Tangents and show how it can be used in research.

Installation

To use GPU, first follow JAX's GPU installation instructions. Otherwise, install JAX on CPU by running

pip install jax jaxlib --upgrade

Once JAX is installed, install Neural Tangents by running

pip install neural-tangents

or, to use the bleeding-edge version from GitHub source,

git clone https://github.com/google/neural-tangents; cd neural-tangents
pip install -e .

You can now run the examples and tests by calling:

pip install .[testing]
set -e; for f in examples/*.py; do python $f; done  # Run examples
set -e; for f in tests/*.py; do python $f; done  # Run tests

5-Minute intro

See this Colab for a detailed tutorial. Below is a very quick introduction.

Our library closely follows JAX's API for specifying neural networks, stax. In stax, a network is defined by a pair of functions (init_fn, apply_fn), which respectively initialize the trainable parameters and compute the outputs of the network. Below is an example of defining a 3-layer network and computing its outputs y given inputs x.

from jax import random
from jax.example_libraries import stax

init_fn, apply_fn = stax.serial(
    stax.Dense(512), stax.Relu,
    stax.Dense(512), stax.Relu,
    stax.Dense(1)
)

key = random.PRNGKey(1)
x = random.normal(key, (10, 100))
_, params = init_fn(key, input_shape=x.shape)

y = apply_fn(params, x)  # (10, 1) jnp.ndarray outputs of the neural network

Neural Tangents is designed to serve as a drop-in replacement for stax, extending the (init_fn, apply_fn) tuple to a triple (init_fn, apply_fn, kernel_fn), where kernel_fn is the kernel function of the infinite network (GP) of the given architecture. Below is an example of computing the covariances of the GP between two batches of inputs x1 and x2.

from jax import random
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(512), stax.Relu(),
    stax.Dense(1)
)

key1, key2 = random.split(random.PRNGKey(1))
x1 = random.normal(key1, (10, 100))
x2 = random.normal(key2, (20, 100))

kernel = kernel_fn(x1, x2, 'nngp')

Note that kernel_fn can compute two covariance matrices corresponding to the Neural Network Gaussian Process (NNGP) and Neural Tangent (NT) kernels respectively. The NNGP kernel corresponds to the Bayesian infinite neural network. The NTK corresponds to the (continuous) gradient descent trained infinite network. In the above example, we compute the NNGP kernel, but we could compute the NTK or both:

# Get kernel of a single type
nngp = kernel_fn(x1, x2, 'nngp') # (10, 20) jnp.ndarray
ntk = kernel_fn(x1, x2, 'ntk') # (10, 20) jnp.ndarray

# Get kernels as a namedtuple
both = kernel_fn(x1, x2, ('nngp', 'ntk'))
both.nngp == nngp  # True
both.ntk == ntk  # True

# Unpack the kernels namedtuple
nngp, ntk = kernel_fn(x1, x2, ('nngp', 'ntk'))

Additionally, if no third argument is specified, kernel_fn will return a Kernel namedtuple that contains additional metadata. This can be useful for composing applications of kernel_fn as follows:

kernel = kernel_fn(x1, x2)
kernel = kernel_fn(kernel)
print(kernel.nngp)

Doing inference with infinite networks trained on MSE loss reduces to classical GP inference, for which we also provide convenient tools:

import neural_tangents as nt

x_train, x_test = x1, x2
y_train = random.uniform(key1, shape=(10, 1))  # training targets

predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
                                                      y_train)

y_test_nngp = predict_fn(x_test=x_test, get='nngp')
# (20, 1) jnp.ndarray test predictions of an infinite Bayesian network

y_test_ntk = predict_fn(x_test=x_test, get='ntk')
# (20, 1) jnp.ndarray test predictions of an infinite continuous
# gradient descent trained network at convergence (t = inf)

# Get predictions as a namedtuple
both = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
both.nngp == y_test_nngp  # True
both.ntk == y_test_ntk  # True

# Unpack the predictions namedtuple
y_test_nngp, y_test_ntk = predict_fn(x_test=x_test, get=('nngp', 'ntk'))

Infinitely WideResnet

We can define a more complex, (infinitely) Wide Residual Network using the same nt.stax building blocks:

from neural_tangents import stax

def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  Main = stax.serial(
      stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'),
      stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME'))
  Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
      channels, (3, 3), strides, padding='SAME')
  return stax.serial(stax.FanOut(2),
                     stax.parallel(Main, Shortcut),
                     stax.FanInSum())

def WideResnetGroup(n, channels, strides=(1, 1)):
  blocks = []
  blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
  for _ in range(n - 1):
    blocks += [WideResnetBlock(channels, (1, 1))]
  return stax.serial(*blocks)

def WideResnet(block_size, k, num_classes):
  return stax.serial(
      stax.Conv(16, (3, 3), padding='SAME'),
      WideResnetGroup(block_size, int(16 * k)),
      WideResnetGroup(block_size, int(32 * k), (2, 2)),
      WideResnetGroup(block_size, int(64 * k), (2, 2)),
      stax.AvgPool((8, 8)),
      stax.Flatten(),
      stax.Dense(num_classes, 1., 0.))

init_fn, apply_fn, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)
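
As a usage sketch (assuming the kernel_fn defined above; the input shapes are illustrative NHWC image batches, sized so the feature maps are 8x8 before stax.AvgPool((8, 8))):

from jax import random

key1, key2 = random.split(random.PRNGKey(1))
x1 = random.normal(key1, (2, 32, 32, 3))
x2 = random.normal(key2, (3, 32, 32, 3))

# Covariances of the infinite WideResnet GP between the two batches.
kernel = kernel_fn(x1, x2, ('nngp', 'ntk'))
print(kernel.nngp.shape, kernel.ntk.shape)  # (2, 3) (2, 3)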

Package description

The neural_tangents (nt) package contains the following modules and functions:

  • stax - primitives to construct neural networks like Conv, Relu, serial, parallel etc.

  • predict - predictions with infinite networks:

    • predict.gradient_descent_mse - inference with a single infinite width / linearized network trained on MSE loss with continuous gradient descent for an arbitrary finite or infinite (t=None) time. Computed in closed form.

    • predict.gradient_descent - inference with a single infinite width / linearized network trained on arbitrary loss with continuous (momentum) gradient descent for an arbitrary finite time. Computed using an ODE solver.

    • predict.gradient_descent_mse_ensemble - inference with an infinite ensemble of infinite width networks, either fully Bayesian (get='nngp') or inference with MSE loss using continuous gradient descent (get='ntk'). Finite-time Bayesian inference (e.g. t=1., get='nngp') is interpreted as gradient descent on the top layer only, since it converges to exact Gaussian process inference with NNGP (t=None, get='nngp'). Computed in closed form.

    • predict.gp_inference - exact closed form Gaussian process inference using NNGP (get='nngp'), NTK (get='ntk'), or both (get=('nngp', 'ntk')). Equivalent to predict.gradient_descent_mse_ensemble with t=None (infinite training time), but has a slightly different API (accepting precomputed kernel matrix k_train_train instead of kernel_fn and x_train).

  • monte_carlo_kernel_fn - compute a Monte Carlo kernel estimate of any (init_fn, apply_fn), not necessarily specified via nt.stax, enabling the kernel computation of infinite networks without closed-form expressions (see the sketch after this list).

  • Tools to investigate training dynamics of wide but finite neural networks, like linearize, taylor_expand, empirical_kernel_fn and more. See Training dynamics of wide but finite networks for details.
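
For instance, a Monte Carlo kernel estimate can be obtained as follows (a minimal sketch; the network, sample count, and input shapes are illustrative):

import neural_tangents as nt
from jax import random
from jax.example_libraries import stax as jax_stax

# Any (init_fn, apply_fn) pair works, including ones not built with nt.stax.
init_fn, apply_fn = jax_stax.serial(
    jax_stax.Dense(256), jax_stax.Relu, jax_stax.Dense(1))

key1, key2, key3 = random.split(random.PRNGKey(1), 3)
x1 = random.normal(key1, (5, 20))
x2 = random.normal(key2, (6, 20))

# Average the empirical kernel over 100 random initializations.
mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key3, n_samples=100)
ntk_estimate = mc_kernel_fn(x1, x2, 'ntk')  # (5, 6) Monte Carlo estimate of the NTK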

Technical gotchas

We note the following differences between our library and jax.example_libraries.stax.

  • All nt.stax layers are instantiated with a function call, i.e. nt.stax.Relu() vs jax.example_libraries.stax.Relu.
  • All layers with trainable parameters use the NTK parameterization by default. However, Dense and Conv layers also support the standard parameterization via a parameterization keyword argument (see the sketch after this list).
  • nt.stax and jax.example_libraries.stax may have different layers and options available (for example, nt.stax layers support CIRCULAR padding and have LayerNorm, but no BatchNorm).
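
The parameterization switch mentioned above, as a minimal sketch (layer widths are illustrative; under the standard parameterization the specified widths can enter the resulting kernel):

from neural_tangents import stax

# NTK parameterization (the default).
_, _, kernel_fn_ntk = stax.serial(
    stax.Dense(512), stax.Relu(), stax.Dense(1))

# Standard parameterization, selected per layer.
_, _, kernel_fn_std = stax.serial(
    stax.Dense(512, parameterization='standard'), stax.Relu(),
    stax.Dense(1, parameterization='standard'))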

CPU and TPU performance

For CNNs with pooling, our CPU and TPU performance is suboptimal, due to low core utilization (10-20%, likely an XLA:CPU issue) and excessive padding, respectively. We will look into improving performance, but recommend NVIDIA GPUs in the meantime. See Performance.

Training dynamics of wide but finite networks

The kernel of an infinite network, kernel_fn(x1, x2).ntk, combined with nt.predict.gradient_descent_mse allows one to analytically track the outputs of an infinitely wide neural network trained on MSE loss throughout training. Here we discuss the implications for wide but finite neural networks and present tools to study their evolution in weight space (trainable parameters of the network) and function space (outputs of the network).

Weight space

Continuous gradient descent in an infinite network has been shown to correspond to training a linear (in trainable parameters) model, which makes linearized neural networks an important subject of study for understanding the behavior of parameters in wide models.

For this, we provide two convenient functions:

  • nt.linearize, and
  • nt.taylor_expand,

which allow us to linearize or get an arbitrary-order Taylor expansion of any function apply_fn(params, x) around some initial parameters params_0 as apply_fn_lin = nt.linearize(apply_fn, params_0).

You can use apply_fn_lin(params, x) exactly as you would any other function (including as an input to JAX optimizers). This makes it easy to compare the training trajectory of a neural network with that of its linearization. Prior theory and experiments have examined the linearization of neural networks from inputs to logits or pre-activations, rather than from inputs to post-activations, which are substantially more nonlinear.

Example:

import jax.numpy as jnp
import neural_tangents as nt

def apply_fn(params, x):
  W, b = params
  return jnp.dot(x, W) + b

W_0 = jnp.array([[1., 0.], [0., 1.]])
b_0 = jnp.zeros((2,))

apply_fn_lin = nt.linearize(apply_fn, (W_0, b_0))
W = jnp.array([[1.5, 0.2], [0.1, 0.9]])
b = b_0 + 0.2

x = jnp.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])
logits = apply_fn_lin((W, b), x)  # (3, 2) jnp.ndarray
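
Similarly, nt.taylor_expand gives a higher-order expansion around the same initial parameters. A minimal sketch continuing the example above (the expansion degree is illustrative; since this apply_fn is linear in its parameters, the expansion coincides with apply_fn itself):

# Second-order Taylor expansion of apply_fn in the parameters around (W_0, b_0).
apply_fn_taylor = nt.taylor_expand(apply_fn, (W_0, b_0), 2)
taylor_logits = apply_fn_taylor((W, b), x)  # (3, 2) jnp.ndarray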

Function space

Outputs of a linearized model evolve identically to those of an infinite one but with a different kernel - precisely, the Neural Tangent Kernel evaluated on the specific apply_fn of the finite network given the specific params_0 that the network is initialized with. For this we provide the nt.empirical_kernel_fn function that accepts any apply_fn and returns a kernel_fn(x1, x2, get, params) that allows you to compute the empirical NTK and/or NNGP (based on get) kernels on specific params.

Example:

import jax.random as random
import jax.numpy as jnp
import neural_tangents as nt


def apply_fn(params, x):
  W, b = params
  return jnp.dot(x, W) + b


W_0 = jnp.array([[1., 0.], [0., 1.]])
b_0 = jnp.zeros((2,))
params = (W_0, b_0)

key1, key2 = random.split(random.PRNGKey(1), 2)
x_train = random.normal(key1, (3, 2))
x_test = random.normal(key2, (4, 2))
y_train = random.uniform(key1, shape=(3, 2))

kernel_fn = nt.empirical_kernel_fn(apply_fn)
ntk_train_train = kernel_fn(x_train, None, 'ntk', params)
ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
mse_predictor = nt.predict.gradient_descent_mse(ntk_train_train, y_train)

t = 5.
y_train_0 = apply_fn(params, x_train)
y_test_0 = apply_fn(params, x_test)
y_train_t, y_test_t = mse_predictor(t, y_train_0, y_test_0, ntk_test_train)
# (3, 2) and (4, 2) jnp.ndarray train and test outputs after `t` units of time
# training with continuous gradient descent

What to Expect

The success or failure of the linear approximation is highly architecture dependent. However, some rules of thumb that we've observed are:

  1. Convergence as the network size increases.

    • For fully-connected networks one generally observes very strong agreement by the time the layer-width is 512 (RMSE of about 0.05 at the end of training).

    • For convolutional networks one generally observes reasonable agreement by the time the number of channels is 512.

  2. Convergence at small learning rates.

With a new model it is therefore advisable to start with large width on a small dataset using a small learning rate.

Performance

In the tables below we measure the time to compute a single NTK entry in a 21-layer CNN (3x3 filters, no strides, SAME padding, ReLU) on inputs of shape 3x32x32. Precisely:

from neural_tangents import stax

layers = []
for _ in range(21):
  layers += [stax.Conv(1, (3, 3), (1, 1), 'SAME'), stax.Relu()]

CNN with pooling

Top layer is stax.GlobalAvgPool():

_, _, kernel_fn = stax.serial(*(layers + [stax.GlobalAvgPool()]))

| Platform | Precision | Milliseconds / NTK entry | Max batch size (NxN) |
|---|---|---|---|
| CPU, >56 cores, >700 GB RAM | 32 | 112.90 | >= 128 |
| CPU, >56 cores, >700 GB RAM | 64 | 258.55 | 95 (fastest - 72) |
| TPU v2 | 32/16 | 3.2550 | 16 |
| TPU v3 | 32/16 | 2.3022 | 24 |
| NVIDIA P100 | 32 | 5.9433 | 26 |
| NVIDIA P100 | 64 | 11.349 | 18 |
| NVIDIA V100 | 32 | 2.7001 | 26 |
| NVIDIA V100 | 64 | 6.2058 | 18 |

CNN without pooling

Top layer is stax.Flatten():

_, _, kernel_fn = stax.serial(*(layers + [stax.Flatten()]))

| Platform | Precision | Milliseconds / NTK entry | Max batch size (NxN) |
|---|---|---|---|
| CPU, >56 cores, >700 GB RAM | 32 | 0.12013 | 2048 <= N < 4096 (fastest - 512) |
| CPU, >56 cores, >700 GB RAM | 64 | 0.3414 | 2048 <= N < 4096 (fastest - 256) |
| TPU v2 | 32/16 | 0.0015722 | 512 <= N < 1024 |
| TPU v3 | 32/16 | 0.0010647 | 512 <= N < 1024 |
| NVIDIA P100 | 32 | 0.015171 | 512 <= N < 1024 |
| NVIDIA P100 | 64 | 0.019894 | 512 <= N < 1024 |
| NVIDIA V100 | 32 | 0.0046510 | 512 <= N < 1024 |
| NVIDIA V100 | 64 | 0.010822 | 512 <= N < 1024 |

Tested using version 0.2.1. All GPU results are per single accelerator. Note that runtime is proportional to the depth of your network. If your performance differs significantly, please file a bug!

Myrtle network

The Colab notebook Performance Benchmark demonstrates how to construct and benchmark kernels. To demonstrate flexibility, we took the Myrtle architecture as an example. With NVIDIA V100 64-bit precision, nt took 316/330/508 GPU-hours on the full 60k CIFAR-10 dataset for the Myrtle-5/7/10 kernels.

Citation

If you use the code in a publication, please cite our papers:

# Infinite width NTK/NNGP:
@inproceedings{neuraltangents2020,
    title={Neural Tangents: Fast and Easy Infinite Neural Networks in Python},
    author={Roman Novak and Lechao Xiao and Jiri Hron and Jaehoon Lee and Alexander A. Alemi and Jascha Sohl-Dickstein and Samuel S. Schoenholz},
    booktitle={International Conference on Learning Representations},
    year={2020},
    pdf={https://arxiv.org/abs/1912.02803},
    url={https://github.com/google/neural-tangents}
}

# Finite width, empirical NTK/NNGP:
@inproceedings{novak2022fast,
    title={Fast Finite Width Neural Tangent Kernel},
    author={Roman Novak and Jascha Sohl-Dickstein and Samuel S. Schoenholz},
    booktitle={International Conference on Machine Learning},
    year={2022},
    pdf={https://arxiv.org/abs/2206.08720},
    url={https://github.com/google/neural-tangents}
}

# Attention and variable-length inputs:
@inproceedings{hron2020infinite,
    title={Infinite attention: NNGP and NTK for deep attention networks},
    author={Jiri Hron and Yasaman Bahri and Jascha Sohl-Dickstein and Roman Novak},
    booktitle={International Conference on Machine Learning},
    year={2020},
    pdf={https://arxiv.org/abs/2006.10540},
    url={https://github.com/google/neural-tangents}
}

# Infinite-width "standard" parameterization:
@misc{sohl2020on,
    title={On the infinite width limit of neural networks with a standard parameterization},
    author={Jascha Sohl-Dickstein and Roman Novak and Samuel S. Schoenholz and Jaehoon Lee},
    publisher = {arXiv},
    year={2020},
    pdf={https://arxiv.org/abs/2001.07301},
    url={https://github.com/google/neural-tangents}
}

# Elementwise nonlinearities and sketching:
@inproceedings{han2022fast,
    title={Fast Neural Kernel Embeddings for General Activations},
    author={Insu Han and Amir Zandieh and Jaehoon Lee and Roman Novak and Lechao Xiao and Amin Karbasi},
    booktitle = {Advances in Neural Information Processing Systems},
    year={2022},
    pdf={https://arxiv.org/abs/2209.04121},
    url={https://github.com/google/neural-tangents}
}

neural-tangents's People

Contributors

alexalemi, bobby-he, darrenzhang01, erikfrey, faizan-m, froystig, gnecula, hawkinsp, jaehlee, jekbradbury, jglaser, martindemello, mattjj, mmhamdy, rchen152, romanngg, siumath, sohl-dickstein, sschoenholz, superbobry, tanyajainc137, themrzmaster, yasamanb, yashk2810


neural-tangents's Issues

Training dynamics with a custom metric

I'm having fun playing with the Neural Tangents Cookbook.ipynb and I'd like to try extending it to multivariate regression. However, when I change the output dimension of the last layer in stax.serial, the dimensions of the predicted mean and predicted covariance remain the same. Why is this, and what do I need to change to extend it to multivariate regression?

tuple index out of range error in `stax.Conv`

Any time I try to call the initialization function on a network containing a convolutional layer, I get the same "tuple index out of range" error.

Here is a minimal example using one of the code snippets provided in the preprint:

from neural_tangents import stax
from jax import random

key = random.PRNGKey(10)

def ConvolutionalNetwork(depth, W_std=1.0, b_std=0.0):
    layers = []
    for _ in range(depth):
        layers += [stax.Conv(1, (3, 3), W_std=W_std, b_std=b_std, padding='SAME'), stax.Relu()]
    layers += [stax.Flatten(), stax.Dense(1, W_std, b_std)]
    return stax.serial(*layers)

init_fn, apply_fn, kernel_fn = ConvolutionalNetwork(4)

x = random.normal(key, (10, 100))
init_fn(key, x.shape)

The same issue arises with the WideResNet code in the preprint as well, or when using CIFAR-10 data. Does anyone have insight into this?

Thanks!

Flattening issue in predict.gradient descent?

Seems to be a bug in nt.predict.gradient_descent, perhaps related to flattening of inputs. Code snippet and stacktrace below.

Code snippet:

def ntk_loss(fx,y_hat):
  return -np.mean(np.sum(jstax.logsoftmax(beta*fx) * y_hat,axis=1))
 
g_dd = kernel_fn(x_train, x_train, 'ntk') # kernel_fn from nt.stax.serial
g_td = kernel_fn(x_test, x_train, 'ntk') # test and train numpy arrays
ntk_loss = scaled_loss_for_ntk(beta)
ntk_loss = jit(ntk_loss)
 
predict_fn = nt.predict.gradient_descent(g_dd, y_train, ntk_loss, g_td)
predict_fn(0.1,fx_train_initial,fx_test_initial)

Stacktrace of error:

ValueError                                Traceback (most recent call last)
<ipython-input-97-592ef42dc248> in <module>()
      5   ntk_outputs, ntk_loss_fn, ntk_acc_fn = get_ntk_dynamics(
      6       kernel_fn,x_train,x_test,y_train,
----> 7       y_test,fx_train_initial,fx_test_initial,beta)
      8   # get results
      9   train_loss = nnp.zeros(len(ts))
 
25 frames
<ipython-input-96-9b76a38cb837> in get_ntk_dynamics(kernel_fn, x_train, x_test, y_train, y_test, fx_train_initial, fx_test_initial, beta)
     22   print('NTK initial loss: {}'.format(ntk_loss(fx_train_initial,y_train)))
     23   predict_fn = nt.predict.gradient_descent(g_dd, y_train, ntk_loss, g_td)
---> 24   predict_fn(0.1,fx_train_initial,fx_test_initial)
     25 
     26   ntk_outputs = functools.partial(
 
google3/third_party/py/neural_tangents/predict.py in predict(dt, fx_train, fx_test)
    276       train_size, output_dim = fx_train.shape
    277       r.set_initial_value(fx, 0).set_f_params(train_size * output_dim)
--> 278       r.integrate(dt)
    279       fx = ufl(r.y)
    280 
 
google3/third_party/py/scipy/integrate/_ode.py in integrate(self, t, step, relax)
    430             self._y, self.t = mth(self.f, self.jac or (lambda: None),
    431                                   self._y, self.t, t,
--> 432                                   self.f_params, self.jac_params)
    433         except SystemError:
    434             # f2py issue with tuple returns, see ticket 1187.
 
google3/third_party/py/scipy/integrate/_ode.py in run(self, f, jac, y0, t0, t1, f_params, jac_params)
   1170     def run(self, f, jac, y0, t0, t1, f_params, jac_params):
   1171         x, y, iwork, istate = self.runner(*((f, t0, y0, t1) +
-> 1172                                           tuple(self.call_args) + (f_params,)))
   1173         self.istate = istate
   1174         if istate < 0:
 
google3/third_party/py/neural_tangents/predict.py in dfx_dt(unused_t, fx, train_size)
    266     def dfx_dt(unused_t, fx, train_size):
    267       fx_train = fx[:train_size]
--> 268       dfx_train = -ifl(np.dot(g_dd, iufl(grad_loss(fx_train))))
    269       dfx_test = -ifl(np.dot(g_td, iufl(grad_loss(fx_train))))
    270       return np.concatenate((dfx_train, dfx_test), axis=0)
 
google3/third_party/py/jax/api.py in grad_f(*args, **kwargs)
    353   @wraps(fun, docstr=docstr, argnums=argnums)
    354   def grad_f(*args, **kwargs):
--> 355     _, g = value_and_grad_f(*args, **kwargs)
    356     return g
    357 
 
google3/third_party/py/jax/api.py in value_and_grad_f(*args, **kwargs)
    408     f_partial, dyn_args = _argnums_partial(f, argnums, args)
    409     if not has_aux:
--> 410       ans, vjp_py = vjp(f_partial, *dyn_args)
    411     else:
    412       ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)
 
google3/third_party/py/jax/api.py in vjp(fun, *primals, **kwargs)
   1267   if not has_aux:
   1268     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1269     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1270     out_tree = out_tree()
   1271   else:
 
google3/third_party/py/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    106 def vjp(traceable, primals, has_aux=False):
    107   if not has_aux:
--> 108     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    109   else:
    110     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
 
google3/third_party/py/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     95   _, in_tree = tree_flatten(((primals, primals), {}))
     96   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
---> 97   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
     98   pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
     99   aval_primals, const_primals = unzip2(pval_primals)
 
google3/third_party/py/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, **kwargs)
    313   with new_master(JaxprTrace) as master:
    314     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 315     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    316     assert not env
    317     del master
 
google3/third_party/py/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    151     gen = None
    152 
--> 153     ans = self.f(*args, **dict(self.params, **kwargs))
    154     del args
    155     while stack:
 
google3/third_party/py/jax/api.py in f_jitted(*args, **kwargs)
    148     _check_args(args_flat)
    149     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 150     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
    151     return tree_unflatten(out_tree(), out)
    152 
 
google3/third_party/py/jax/core.py in call_bind(primitive, f, *args, **params)
    593   else:
    594     tracers = map(top_trace.full_raise, args)
--> 595     outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
    596   return apply_todos(env_trace_todo(), outs)
    597 
 
google3/third_party/py/jax/interpreters/ad.py in process_call(self, call_primitive, f, tracers, params)
    324     nonzero_tangents, in_tree_def = tree_flatten(tangents)
    325     f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master), len(primals), in_tree_def)
--> 326     result = call_primitive.bind(f_jvp, *(primals + nonzero_tangents), **params)
    327     primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
    328     return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
 
google3/third_party/py/jax/core.py in call_bind(primitive, f, *args, **params)
    593   else:
    594     tracers = map(top_trace.full_raise, args)
--> 595     outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
    596   return apply_todos(env_trace_todo(), outs)
    597 
 
google3/third_party/py/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
    113     in_pvs, in_consts = unzip2([t.pval for t in tracers])
    114     fun, aux = partial_eval(f, self, in_pvs)
--> 115     out_flat = call_primitive.bind(fun, *in_consts, **params)
    116     out_pvs, jaxpr, env = aux()
    117     out_pv_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
 
google3/third_party/py/jax/core.py in call_bind(primitive, f, *args, **params)
    590   if top_trace is None:
    591     with new_sublevel():
--> 592       outs = primitive.impl(f, *args, **params)
    593   else:
    594     tracers = map(top_trace.full_raise, args)
 
google3/third_party/py/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
    398   device = params['device']
    399   backend = params.get('backend', None)
--> 400   compiled_fun = _xla_callable(fun, device, backend, *map(abstractify, args))
    401   try:
    402     return compiled_fun(*args)
 
google3/third_party/py/jax/linear_util.py in memoized_fun(fun, *args)
    207       fun.populate_stores(stores)
    208     else:
--> 209       ans = call(fun, *args)
    210       cache[key] = (ans, fun.stores)
    211     return ans
 
google3/third_party/py/jax/interpreters/xla.py in _xla_callable(fun, device, backend, *abstract_args)
    410   pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
    411   with core.new_master(pe.JaxprTrace, True) as master:
--> 412     jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
    413     assert not env  # no subtraces here
    414     del master, env
 
google3/third_party/py/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    151     gen = None
    152 
--> 153     ans = self.f(*args, **dict(self.params, **kwargs))
    154     del args
    155     while stack:
 
<ipython-input-96-9b76a38cb837> in ntk_loss(fx, y_hat)
      2 def scaled_loss_for_ntk(beta):
      3   def ntk_loss(fx,y_hat):
----> 4     return -np.mean(np.sum(jstax.logsoftmax(beta*fx) * y_hat,axis=1))
      5   return ntk_loss
      6 
 
google3/third_party/py/jax/numpy/lax_numpy.py in reduction(a, axis, dtype, out, keepdims)
   1184     a = a if isinstance(a, ndarray) else asarray(a)
   1185     a = preproc(a) if preproc else a
-> 1186     dims = _reduction_dims(a, axis)
   1187     result_dtype = dtype or _dtype(np_fun(onp.ones((), dtype=_dtype(a))))
   1188     if upcast_f16_for_computation and issubdtype(result_dtype, inexact):
 
google3/third_party/py/jax/numpy/lax_numpy.py in _reduction_dims(a, axis)
   1206     return tuple(_canonicalize_axis(x, ndim(a)) for x in axis)
   1207   elif isinstance(axis, int):
-> 1208     return (_canonicalize_axis(axis, ndim(a)),)
   1209   else:
   1210     raise TypeError("Unexpected type of axis argument: {}".format(type(axis)))
 
google3/third_party/py/jax/numpy/lax_numpy.py in _canonicalize_axis(axis, num_dims)
    353       raise ValueError(
    354           "axis {} is out of bounds for array of dimension {}".format(
--> 355               axis, num_dims))
    356   return axis
    357 
 
ValueError: axis 1 is out of bounds for array of dimension 1

nt.predict.gradient_descent fails with vmap

Should we expect nt.predict.gradient_descent to fail when using jax vmap due to the scipy ode solver? Are there any suggested workarounds for speeding this function up over batches?

Question about `test_composition_conv` in Stax Tests

Hi, sorry to bother you. In the test_composition_conv_avg_pool test cases, some outer products on the covariance matrices are performed during the Kernel transformation. In the outer product function, there is the interleave_ones operation, which adds ones to the covariance dimensions:

def outer_prod(x, y, start_axis, end_axis, prod_op):
  if y is None:
    y = x
  x = interleave_ones(x, start_axis, end_axis, True)
  y = interleave_ones(y, start_axis, end_axis, False)
  tf.print("x: {}, y: {}".format(x.shape, y.shape), output_stream=sys.stdout)
  return prod_op(x, y)

When I print out the shapes after interleave_ones, some shapes are like x: (5, 1, 8, 1, 8, 1), y: (1, 5, 1, 8, 1, 8), which obviously do not match. In this case, would you mind explaining the role of interleave_ones and how the unmatched shapes can be multiplied together? Thanks!

Weight Evolution and Predictions from Weights

Hi all,

I would like to (analytically) compute the evolution of the weights under the linearized dynamics (i.e., Eqn. (8) in https://arxiv.org/pdf/1902.06720.pdf) and use the resulting weights after t "steps" of gradient flow to make predictions on the training data. More specifically, I would like these predictions to match the predictions obtained by solving the function-space dynamics on the training data (Eqn. (9) in the paper).

To do this, I modified gradient_descent_mse() in predict.py to implement Eqn. (8). Specifically, I added the function

def predict_params_using_kernel(dt, fx_train=0.):
  gx_train = fl(fx_train - y_train)
  dfx = inv_expm1_dot_vec(gx_train, dt)
  dfx = np.dot(Jacobian_f0, dfx)
  return params0 - dfx

where Jacobian_f0 is the Jacobian wrt to the parameters of the NN at initialization, evaluated on the training data.

With the resulting parameters, params_t, converted back to the appropriate pytree structure, I then compute predictions on the training data by calling apply_fn(params_t, x_train).

Unfortunately, this does not seem to result in sensible predictions, since the parameters explode, i.e., become large in magnitude, for even small t, and don't match the predictions obtained by solving the function-space dynamics--even on the training data. I am aware that the mapping between parameter states and function predictions is not bijective, but shouldn't the parameters obtained from Eqn. (8) lead to the same predictions as Eqn. (9)?

NB: I did confirm that pre-multiplying dfx = np.dot(Jacobian_f0, dfx) by the transpose of Jacobian_f0 does yield the same matrix as calling the inbuilt function predict_using_kernel().

EDIT: I forgot to mention that I of course also modified the arguments of the gradient_descent_mse() to gradient_descent_mse(g_dd, y_train, params0, Jacobian_f0, g_td=None, diag_reg=0.) (i.e., I added params0, Jacobian_f0).

Any help would be much appreciated!

Thank you!

Refactor to avoid use of lax.reduce_window_shape_tuple

lax.reduce_window_shape_tuple is not a JAX public API, and we change it from time to time.

A better alternative would be to use the public API jax.eval_shape to compute the output shape of a reduce_window operator.

Thanks!

Insight Needed for the Shape Inference of the Key

Hi Roman @romanngg and Sam @sschoenholz, I am currently working on the migration/reconstruction of Neural Tangents from JAX to TensorFlow 2.x, as an R&D project for the TensorFlow team. For the shape inference based on the abstract key https://github.com/google/neural-tangents/blob/master/neural_tangents/stax.py#L2076, we have not found an appropriate way of plugging in an equivalent TF API here. I am currently considering disabling the functionality around this line, since akey only seems to be used for generating an rng key later on. Could you provide us some insight from the Neural Tangents perspective? If I choose to disable it, what would be an appropriate range? Thanks a lot!

explanation of the implicit method

Hi,

I wonder if you might possibly provide a description of the implicit method used in the _compute_ntk() function, or point me to some reference. I find the current code concise but hard to follow. Thanks!

Regards,
Jerry

Analytic kernel evaluated on sparse inputs

Hi!

A bug seems to occur when I try to evaluate analytic NTKs on sparse input data -- the evaluated kernel contains nan entries. This can be reproduced with the following lines of code:

from jax import random
from neural_tangents import stax

key = random.PRNGKey(1)

# a batch of dense inputs 
x_dense = random.normal(key, (3, 32, 32, 3))

# a batch of sparse inputs 
x_sparse = x_dense * (abs(x_dense) > 1.2)


# A CNN architecture
init_fn, apply_fn, kernel_fn = stax.serial(
     stax.Conv(128, (3, 3)),
     stax.Relu(),
     stax.Flatten(),
     stax.Dense(10) )

# Evaluate the analytic NTK upon dense inputs

print('NTK evaluated w/ dense inputs: \n', kernel_fn(x_dense, x_dense, 'ntk')) # the outputs look fine.

print('\n')

# Evaluate the analytic NTK upon sparse inputs

print('NTK evaluated w/ sparse inputs: \n', kernel_fn(x_sparse, x_sparse, 'ntk')) # the outputs contains nan

The output of the above script should be:

NTK evaluated w/ dense inputs: 
 [[0.97102666 0.16131128 0.16714054]
 [0.16131128 0.9743941  0.17580226]
 [0.16714054 0.17580226 1.0097454 ]]


NTK evaluated w/ sparse inputs: 
 [[       nan        nan        nan]
 [       nan 0.66292834        nan]
 [       nan        nan        nan]]

Thanks for your time in advance!

Can a transpose conv function be added later?

Hi! Thanks for this awesome API. I am trying to apply it to a segmentation task, so I was wondering if the code supports (or could support with simple extensions) transpose convolution layers (deconvolution layers). If not, is the obstacle practical or theoretical?

Possible typo in README.md

In the last part of 5-minute-intro

I guess the line assigning y_test_ntk should pass parameter get='ntk' instead of get='nngp'

y_test_ntk = predict_fn(x_test=x_test, get='nngp')

I think it should be:

y_test_ntk = predict_fn(x_test=x_test, get='ntk')

Failed to load neural_tangents.tangents on colab or on a local machine

Thanks a lot for making this repository public!

When running the notebooks weight_space_linearization.ipynb and function_space_linearization.ipynb on Google Colab using the links provided in these notebooks, I was unable to import neural_tangents.tangents. A screenshot is attached below:

Screen Shot 2019-09-18 at 1 36 36 PM

The same problem happens when I was trying to run the repository locally on my computer.

This issue seems to have appeared since the repository was updated about a week ago. The old version of the code (currently on the notebook branch) works fine.

MaxPool for NT

Hello, I notice there is no max pooling in the stax library. Is there any way for me to put such a max-pooling layer in my stax.serial() and compute the posterior for the NNGP or NTK?

I read the arXiv paper Neural Tangents: Fast and Easy Infinite Neural Networks in Python. The paper suggests using a Monte Carlo sampling technique to approximate the network distribution. My question is how do I construct the network in the first place, since there is no max pooling in the stax library? Currently, there are only AvgPool and SumPool. Thanks!
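
A minimal sketch of that Monte Carlo route (not an official answer; the finite network, shapes, and sample count are illustrative), using jax.example_libraries.stax, which does provide MaxPool:

import neural_tangents as nt
from jax import random
from jax.example_libraries import stax as finite_stax

# A finite network with max pooling, defined outside nt.stax.
init_fn, apply_fn = finite_stax.serial(
    finite_stax.Conv(64, (3, 3), padding='SAME'), finite_stax.Relu,
    finite_stax.MaxPool((2, 2)),
    finite_stax.Flatten, finite_stax.Dense(1))

key1, key2, key3 = random.split(random.PRNGKey(1), 3)
x1 = random.normal(key1, (3, 8, 8, 3))
x2 = random.normal(key2, (4, 8, 8, 3))

# Approximate the NNGP/NTK by averaging empirical kernels over random draws.
mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key3, n_samples=64)
kernels = mc_kernel_fn(x1, x2, ('nngp', 'ntk'))  # kernels.nngp, kernels.ntk: (3, 4)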

Does stax.serial support Tanh activation

Hi,

When I tried to compute the NTK of a fully-connected network, I couldn't find the Tanh activation in stax.serial.

For example, the following code doesn't work.

init_fn, apply_fn, kernel_fn = stax.serial( stax.Dense(512), stax.Tanh(), )

If stax.serial doesn't support the Tanh activation, what else can I do to compute the NTK of a Tanh network?

Memory Constraint (Approximation Available?)

Hello,

I currently have a dataset with 1246064 observations and 94 features. It is my understanding that GP inference would have to generate a kernel of size 1246064 x 1246064, and I am not sure if that is the reason I am currently running into the following memory error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-81-5095b4194ced> in <module>()
     29     r_mean, r_covariance = nt.predict.gp_inference(
     30         kernel_fn, z_train, r_train, z_test,
---> 31         diag_reg=1e-4, get='ntk', compute_cov=True)
     32     r_mean = np.reshape(r_mean, (-1,))[np.newaxis, ...]
     33     out_rsq_list.append((r_test.detach().cpu().numpy(), r_test.detach().cpu().numpy()))

8 frames
/usr/local/lib/python3.6/dist-packages/jaxlib/xla_client.py in compile(self, c_computation, compile_options)
    148                                         compile_options.argument_layouts,
    149                                         options, self.client,
--> 150                                         compile_options.device_assignment)
    151 
    152   def get_default_device_assignment(self, num_replicas, num_partitions=None):

RuntimeError: Resource exhausted: Out of memory while trying to allocate 6210718745600 bytes.

I was wondering if there is a way around this (for example, to create a kernel approximation of some sort, similar to this one).

Thanks!

Question: Possible to use nt's stax implementation for a slightly less-linear neural net?

Hello, I have a (probably basic) question. I was wondering if it is possible to use NT's stax implementation to do a more basic neural net. I'm attempting to embed some continuous sequences into n-dimensional space, where inputs x1 and x2 are run through two dense layers, and the final output of the neural net is the Manhattan distance between x1 and x2 after the dense layers. This is just so that the embedded representation mimics the Manhattan distance between the two continuous sequences.

Sorry if that isn't clear, my model is below:

    input1 = Input(shape=(k,5), dtype='float32', name="k1")
    input2 = Input(shape=(k,5), dtype='float32', name="k2")

    input1_flat = Flatten()(input1)
    input2_flat = Flatten()(input2)

    dense1 = Dense(1024, activation="relu", name="Dense1", use_bias=False)
    dense_out = Dense(dims, activation="linear", name="DenseOut", use_bias=False,)

    k1m = dense_out(dense1(input1_flat))
    k2m = dense_out(dense1(input2_flat))

    subtracted = Subtract()([k1m, k2m])
    abs = tf.math.abs(subtracted)
    output = tf.keras.backend.sum(abs, axis=1)

Because at the chosen sequence length the possible inputs are 5^17, I was hoping/wondering if Neural Tangents would be a good fit, but I can't quite figure out how to make the neural net work with the inputs/outputs from the Colab notebook tutorial.

If it's not possible or not a good idea, I'm definitely open. Just exploring possibilities. If it is possible, I'd appreciate some pointers, as I haven't used JAX/Stax before, and I'm not sure how to integrate the Subtract layer or make it work with two different layers as inputs. I'll keep futzing around with it in the meantime.

Cheers,
--Joseph

Questions about concat and logic of specifying number of neurons

Hi,

Thank you for sharing this great library. I have two questions which are relevant to each other:

  1. For nngp, we are assuming that the number of neurons goes to infinity. Why do you need to specify the number of neurons in the Dense layer (or the number of filters in Conv)? Is this because we are not sure whether the layer is a middle or the last layer?

  2. The answer to the first question somewhat answers this one: does it make sense to have a FanInConcat layer (the same as in stax)? From one point of view, it doesn't, because we are concatenating two infinitely wide layers. From another point of view, it does. For example, if you want to implement models like UNet, you need FanInConcat, and I personally think it makes sense to implement it, but I am not sure.

I would be thankful if you clarify.

Thanks

Non-trainable layers

Thanks for making this great resource available!

I wonder if the layers (such as Conv and Dense) in stax can be specified to be non-trainable? If not, is there a way of modifying the output apply_fn so that the layer becomes non-trainable?

very large memory footprint for a simple UNet

Hi,

I hit a roadblock! I tried to compute the kernel of a typical UNet for 10 images. The image size is not big (64, 64) and the number of images is just 10 (for testing purposes). However, it crashes complaining about memory (see below). I think intermediate layers are probably using a lot of memory, but that limits the usability. Perhaps I am missing something?

gist collab: https://gist.github.com/kayhan-batmanghelich/f444e6cec65139070f1b3e5ade230de5

Side notes:

  • If you train the model using gradient descent, the performance is not always good. You should try different seed numbers. I have a different JAX implementation that uses upsampling, but that needs developing a new layer in neural-tangents and I am not sure how to do that.

Error message:

/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py:4571: UserWarning: Explicitly requested dtype <class 'jax.numpy.lax_numpy.float64'> requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  warnings.warn(msg.format(dtype, fun_name , truncated_dtype))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-27-053c34ab30f7> in <module>()
----> 1 kernel = mykernel(random_image[:10],random_image[:10])

6 frames
/usr/local/lib/python3.6/dist-packages/jax/api.py in f_jitted(*args, **kwargs)
    147     _check_args(args_flat)
    148     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 149     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
    150     return tree_unflatten(out_tree(), out)
    151 

/usr/local/lib/python3.6/dist-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    600   if top_trace is None:
    601     with new_sublevel():
--> 602       outs = primitive.impl(f, *args, **params)
    603   else:
    604     tracers = map(top_trace.full_raise, args)

/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
    440   device = params['device']
    441   backend = params['backend']
--> 442   compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))
    443   try:
    444     return compiled_fun(*args)

/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
    221       fun.populate_stores(stores)
    222     else:
--> 223       ans = call(fun, *args)
    224       cache[key] = (ans, fun.stores)
    225     return ans

/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, *arg_specs)
    497   options = xb.get_compile_options(
    498       num_replicas=nreps, device_assignment=(device.id,) if device else None)
--> 499   compiled = built.Compile(compile_options=options, backend=xb.get_backend(backend))
    500 
    501   if nreps == 1:

/usr/local/lib/python3.6/dist-packages/jaxlib/xla_client.py in Compile(self, argument_shapes, compile_options, backend)
    607     if argument_shapes:
    608       compile_options.argument_layouts = argument_shapes
--> 609     return backend.compile(self.computation, compile_options)
    610 
    611   def GetProgramShape(self):

/usr/local/lib/python3.6/dist-packages/jaxlib/tpu_client.py in compile(self, c_computation, compile_options)
    103                                              compile_options.argument_layouts,
    104                                              options, self.client,
--> 105                                              compile_options.device_assignment)
    106 
    107   def get_default_device_assignment(self, num_replicas):

RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 25.99G of 7.48G hbm. Exceeded hbm capacity by 18.50G.

Total hbm usage >= 26.50G:
    reserved        529.00M 
    program          25.99G 
    arguments       unknown size 

Output size unknown.

Program hbm requirement 25.99G:
    reserved           4.0K
    global            36.0K
    HLO temp         25.99G (74.4% utilization: Unpadded (19.34G) Padded (25.98G), 0.0% fragmentation (10.31M))

  Largest program allocations in hbm:

  1. Size: 12.50G
     Operator: op_type="conv_general_dilated"
     Shape: f32[409600,1,64,64]{0,1,3,2:T(2,128)}
     Unpadded size: 6.25G
     Extra memory due to padding: 6.25G (2.0x expansion)
     XLA label: %convolution.5785 = f32[409600,1,64,64]{0,1,3,2:T(2,128)} convolution(bf16[409600,1,64,64]{0,1,3,2:T(4,128)(2,1)} %reshape.1452, bf16[3,3,1,1]{3,2,1,0:T(4,128)(2,1)} %constant.2723), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, metadata={op_t...
     Allocation type: HLO temp
     ==========================

  2. Size: 6.25G
     Shape: f32[409600,1,64,64]{0,3,2,1}
     Unpadded size: 6.25G
     XLA label: %copy.1516 = f32[409600,1,64,64]{0,3,2,1} copy(f32[409600,1,64,64]{0,1,3,2:T(2,128)} %convolution.4620)
     Allocation type: HLO temp
     ==========================

  3. Size: 6.25G
     Shape: f32[409600,1,64,64]{0,3,2,1}
     Unpadded size: 6.25G
     XLA label: %copy.1540 = f32[409600,1,64,64]{0,3,2,1} copy(f32[409600,1,64,64]{0,1,3,2:T(2,128)} %convolution.5785)
     Allocation type: HLO temp
     ==========================

  4. Size: 640.00M
     Operator: op_type="reshape"
     Shape: bf16[10,64,64,64,64]{2,1,0,4,3:T(8,128)(2,1)}
     Unpadded size: 320.00M
     Extra memory due to padding: 320.00M (2.0x expansion)
     XLA label: %reshape.753 = bf16[10,64,64,64,64]{2,1,0,4,3:T(8,128)(2,1)} reshape(bf16[40960,1,64,64]{0,1,3,2:T(4,128)(2,1)} %fusion.420), metadata={op_type="reshape"}
     Allocation type: HLO temp
     ==========================

  5. Size: 160.00M
     Operator: op_type="transpose"
     Shape: bf16[10,64,64,32,32]{2,1,0,4,3:T(8,128)(2,1)}
     Unpadded size: 80.00M
     Extra memory due to padding: 80.00M (2.0x expansion)
     XLA label: %copy.1153 = bf16[10,64,64,32,32]{2,1,0,4,3:T(8,128)(2,1)} copy(bf16[10,64,64,32,32]{2,1,4,3,0:T(8,128)(2,1)} %bitcast.127), metadata={op_type="transpose"}
     Allocation type: HLO temp
     ==========================

  6. Size: 100.00M
     Shape: f32[409600,64]{0,1:T(8,128)}
     Unpadded size: 100.00M
     XLA label: %reshape.1326 = f32[409600,64]{0,1:T(8,128)} reshape(f32[10,10,64,64,64]{3,2,1,0,4:T(8,128)} %broadcast.1682.remat)
     Allocation type: HLO temp
     ==========================

  7. Size: 100.00M
     Shape: f32[409600,64]{0,1:T(8,128)}
     Unpadded size: 100.00M
     XLA label: %reshape.1332 = f32[409600,64]{0,1:T(8,128)} reshape(f32[10,10,64,64,64]{3,2,1,0,4:T(8,128)} %broadcast.2053)
     Allocation type: HLO temp
     ==========================

  8. Size: 256.0K
     Operator: op_type="slice"
     Shape: f32[10,4096]{1,0:T(8,128)}
     Unpadded size: 160.0K
     Extra memory due to padding: 96.0K (1.6x expansion)
     XLA label: %fusion.671 = f32[10,4096]{1,0:T(8,128)} fusion(f32[10,4096,4096]{2,1,0:T(8,128)} %reshape.4392, pred[4096,4096]{1,0:T(8,128)E(32)} %fusion.1076.remat), kind=kLoop, calls=%fused_computation.591, metadata={op_type="slice"}
     Allocation type: HLO temp
     ==========================

  9. Size: 9.0K
     Shape: bf16[3,3,1,1]{3,2,1,0:T(4,128)(2,1)}
     Unpadded size: 18B
     Extra memory due to padding: 9.0K (512.0x expansion)
     XLA label: constant literal
     Allocation type: global
     ==========================

  10. Size: 4.0K
     XLA label: profiler
     Allocation type: reserved
     ==========================

  11. Size: 4.0K
     Shape: bf16[2,2,1,1]{3,2,1,0:T(4,128)(2,1)}
     Unpadded size: 8B
     Extra memory due to padding: 4.0K (512.0x expansion)
     XLA label: constant literal
     Allocation type: global
     ==========================

  12. Size: 4.0K
     Shape: u32[8,128]{1,0}
     Unpadded size: 4.0K
     XLA label: constant literal
     Allocation type: global
     ==========================

Using internal packages

Hi,
you are using internal packages in your code. For example, in the neural_tangents_cookbook.ipynb
"from google3.pyglib import gfile
with gfile.GFile( '/cns/od-d/home/schsam/rs=6.3/ntk/gd_inference.pdf', 'w') as f_out:"

Sparsely Connected Layers

Hi! Thanks for this awesome resource. I was wondering if the code supports (or could support with simple extensions) computing the NTK and/or linearization for sparsely connected (non-convolutional) layers. If not, is the obstacle practical or theoretical?

Question re: evaluating train and test loss with gradient_descent_mse_gp

Hi, I need to evaluate both the training and test loss using predict.gradient_descent_mse_gp. I'm wondering if there is a more efficient way to do this other than to call gradient_descent_mse_gp twice, i.e. with x_train and y_train fixed while varying the argument to x_test. I'm aiming to do this with ~10k data points so each call to this function is rather expensive. Thanks!

Some Tests Skipped

Hi, I am running the test files for Neural Tangents and a lot of cases were skipped. For example, 96 / 127 test cases were skipped in the Neural Tangents stax tests. I looked at the implementation of the tests, and it seems to be the invalid test cases that trigger the SkipTest in the unit tests. I am wondering if this is expected. Thanks!

Problem running the Colab example: error when computing the diagonal of the NNGP kernel?

I have a problem running step 13 of the Colab Cookbook example (in Colab, not my own Jupyter instance), when trying to compute the diagonal of the NNGP kernel:

kernel = kernel_fn(test_xs, test_xs, 'nngp')
std_dev = np.sqrt(np.diag(kernel))

I get the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-30-b771ce84abcc> in <module>()
----> 1 kernel = kernel_fn(test_xs, test_xs, 'nngp')
      2 std_dev = np.sqrt(np.diag(kernel))

9 frames
/usr/local/lib/python3.6/dist-packages/neural_tangents/stax.py in _inputs_to_kernel(x1, x2, marginal, cross, compute_ntk)
    284                      'Use `NO` instead to compute all covariances.')
    285 
--> 286   x1 = x1.astype(xla_bridge.canonicalize_dtype(np.float64))
    287   var1 = _get_variance(x1, marginal_type=marginal)
    288 

AttributeError: module 'jax.lib.xla_bridge' has no attribute 'canonicalize_dtype'

Looks like something is wrong with the JAX installation, but the pip install:
!pip install -q git+https://www.github.com/neural-tangents/neural-tangents
seemed to run fine.

AttributeError: module 'jax.core' has no attribute 'eval_context' - Where did I go so wrong?

Hello!

I'm trying to run simple model with ntk kernel. Running in Google Colab. Here is my code:

!pip install -q git+https://www.github.com/google/neural-tangents

import jax.numpy as jnp

from jax import random
from jax.experimental import optimizers
from jax.api import jit, grad, vmap
import jax

import functools

import neural_tangents as nt
from neural_tangents import stax

# ... making dataset cifar2 from cifar10
# >>> train_x.shape, train_y.shape, test_x.shape, test_y.shape
# ... ((300, 3072), (300, 1), (300, 3072), (300, 1))

init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(512, 1., 0.05),
      stax.Relu(),
      stax.Dense(512, 1., 0.05),
      stax.Relu(),
      stax.Dense(1, 1., 0.05),
      stax.Flatten()
      )

key = random.PRNGKey(0)
_, params = init_fn(key, (-1, 3072))

apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnums=(2,))

k_train_train = kernel_fn(train_x, None, 'ntk')
k_test_train = kernel_fn(test_x, train_x, 'ntk')
predict_fn = nt.predict.gradient_descent_mse(k_train_train, train_y)
fx_train_0 = apply_fn(params, train_x)
fx_test_0 = apply_fn(params, test_x)

predict_fn(t=1.0, fx_train_0=fx_train_0, fx_test_0=fx_test_0, k_test_train=k_test_train) # error!

Last line gives me an error:

/usr/local/lib/python3.6/dist-packages/neural_tangents/predict.py in predict_fn(t, fx_train_0, fx_test_0, k_test_train)
    246 
    247     # Finite time
--> 248     return get_predict_fn_finite()(t, fx_train_0, fx_test_0, k_test_train)
    249 
    250   return predict_fn

/usr/local/lib/python3.6/dist-packages/neural_tangents/predict.py in get_predict_fn_finite()
    161   @lru_cache(1)
    162   def get_predict_fn_finite():
--> 163     with jax.core.eval_context():
    164       expm1_fn, inv_expm1_fn = _get_fns_in_eigenbasis(
    165           k_train_train,

AttributeError: module 'jax.core' has no attribute 'eval_context'

I'd be glad to figure out what the problem is here, or maybe I made some stupid mistake.
I know that MSE loss is probably not the best choice for a classification task, but still, shouldn't this piece of code work?

Thanks for any help!

predict.gradient_descent wrong prediction dimensions

When calling predict in nt.predict.gradient_descent with variables of the following dimensions:
g_dd [256, 256]
g_dt [256, 256]
fx_train [256, 1]
fx_test [256, 1]
the tuple of predictions is ([2, 256], [0, 256]).
Running the same values through nt.predict.gradient_descent_mse returns predictions with dimensions ([1, 256], [1, 256]).
I am curious if there might be a bug in the following slicing code -

return fx[:train_size], fx[train_size:]

Also, the example documentation seems to be outdated.
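
Regarding the slicing above: the reported shapes ([2, 256], [0, 256]) are what one would get if the train and test outputs were stacked along a new leading axis instead of being concatenated along the example axis before slicing. A small plain-NumPy illustration of the difference (a sketch only, not the library's internals):

import numpy as np

fx_train_0 = np.zeros((256, 1))
fx_test_0 = np.zeros((256, 1))
train_size = fx_train_0.shape[0]

# Concatenating along the example axis keeps the slice by train_size correct.
fx = np.concatenate([fx_train_0[:, 0], fx_test_0[:, 0]])  # shape (512,)
print(fx[:train_size].shape, fx[train_size:].shape)       # (256,) (256,)

# Stacking along a new leading axis reproduces the reported ([2, 256], [0, 256]).
fx = np.stack([fx_train_0[:, 0], fx_test_0[:, 0]])        # shape (2, 256)
print(fx[:train_size].shape, fx[train_size:].shape)       # (2, 256) (0, 256)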

`test_mask_fc` for Neural Tangents stax

Hi, I am encountering the issue DarrenZhang01/TensorFlow_GSoC#26 for the test case MaskingTest.test_mask_fc_ [different_inputs_get=nngp_axis=(0, 1, 2, 3)_mask=10.0_concat=0_p=0.5] and I am trying to print out the layer and shape information for each block.

According to the code, the main component is a parallel block of three serial sub-blocks, where each sub-block consists of a Dense layer, an elementwise nonlinearity, and another Dense layer:

    nn = stax.serial(
        stax.Flatten(),
        stax.FanOut(3),
        stax.parallel(
            stax.serial(
                stax.Dense(width, 1.5, 0.1),
                stax.Abs(),
                stax.Dense(width, 1.5, 0.1),
            ),
            stax.serial(
                stax.Dense(width, 1.5, 0.1),
                stax.Erf(),
                stax.Dense(width if concat != 1 else 512, 1.5, 0.1),
            ),
            stax.serial(
                stax.Dense(width, 1.5, 0.1),
                stax.ABRelu(-0.2, 0.4),
                stax.Dense(width if concat != 1 else 1024, 3, 0.5),
            )
        ),
        (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
        stax.Dense(width, 2., 0.01),
        stax.Relu()
    )

The printed information is as follows, where I added the =========== markers to highlight the beginning and end of the parallel blocks. One thing that confuses me is where all the serial sub-blocks before the parallel block come from: according to the network design above, there should only be a Flatten and a FanOut layer there, but the printout suggests otherwise. I am sure the shape information in the second parallel block is wrong (i.e., it should all be (2, 512) rather than (4, 512)), judging by the printout of the JAX version of Neural Tangents, but I need to know where the 4 comes from in order to proceed (see also the sketch after the log below). Thanks ahead!

Flatten layer: ndarray<Tensor("zeros:0", shape=(4, 210), dtype=float64)>

Flatten layer: ndarray<Tensor("zeros:0", shape=(2, 210), dtype=float64)>

Fan out: [ndarray<<tf.Tensor 'zeros:0' shape=(4, 210) dtype=float64>>, ndarray<<tf.Tensor 'zeros:0' shape=(4, 210) dtype=float64>>, ndarray<<tf.Tensor 'zeros:0' shape=(4, 210) dtype=float64>>]
Fan out: [ndarray<<tf.Tensor 'zeros:0' shape=(2, 210) dtype=float64>>, ndarray<<tf.Tensor 'zeros:0' shape=(2, 210) dtype=float64>>, ndarray<<tf.Tensor 'zeros:0' shape=(2, 210) dtype=float64>>]
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a78a280>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78a3a0>), (<function elementwise.<locals>.<lambda> at 0x14a78aaf0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78aee0>), (<function Dense.<locals>.ntk_init_fn at 0x14a78c3a0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78c4c0>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a78a280>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78a3a0>), (<function elementwise.<locals>.<lambda> at 0x14a78aaf0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78aee0>), (<function Dense.<locals>.ntk_init_fn at 0x14a78c3a0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78c4c0>))
iteration: 0
serial shapes: (2, 512)
iteration: 1
serial shapes: (2, 512)
iteration: 2
serial shapes: (2, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a7921f0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a792310>), (<function elementwise.<locals>.<lambda> at 0x14a792a60>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a792e50>), (<function Dense.<locals>.ntk_init_fn at 0x14a795310>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a795430>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a7921f0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a792310>), (<function elementwise.<locals>.<lambda> at 0x14a792a60>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a792e50>), (<function Dense.<locals>.ntk_init_fn at 0x14a795310>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a795430>))
iteration: 0
serial shapes: (2, 512)
iteration: 1
serial shapes: (2, 512)
iteration: 2
serial shapes: (2, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a79a160>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79a280>), (<function elementwise.<locals>.<lambda> at 0x14a79a9d0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79adc0>), (<function Dense.<locals>.ntk_init_fn at 0x14a79e280>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79e3a0>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a79a160>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79a280>), (<function elementwise.<locals>.<lambda> at 0x14a79a9d0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79adc0>), (<function Dense.<locals>.ntk_init_fn at 0x14a79e280>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79e3a0>))
iteration: 0
serial shapes: (2, 512)
iteration: 1
serial shapes: (2, 512)
iteration: 2
serial shapes: (2, 512)
================================================================
parallel layer: [(4, 210), (4, 210), (4, 210)]

serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a78a280>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78a3a0>), (<function elementwise.<locals>.<lambda> at 0x14a78aaf0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78aee0>), (<function Dense.<locals>.ntk_init_fn at 0x14a78c3a0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78c4c0>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a7921f0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a792310>), (<function elementwise.<locals>.<lambda> at 0x14a792a60>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a792e50>), (<function Dense.<locals>.ntk_init_fn at 0x14a795310>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a795430>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a79a160>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79a280>), (<function elementwise.<locals>.<lambda> at 0x14a79a9d0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79adc0>), (<function Dense.<locals>.ntk_init_fn at 0x14a79e280>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79e3a0>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
result: [(ndarray<<tf.Tensor 'zeros_3:0' shape=(4, 512) dtype=float64>>, [(<tf.Tensor 'stateless_random_normal:0' shape=(210, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_1:0' shape=(1, 512) dtype=float32>), (), (<tf.Tensor 'stateless_random_normal_2:0' shape=(512, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_3:0' shape=(1, 512) dtype=float32>)]), (ndarray<<tf.Tensor 'zeros_7:0' shape=(4, 512) dtype=float64>>, [(<tf.Tensor 'stateless_random_normal_4:0' shape=(210, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_5:0' shape=(1, 512) dtype=float32>), (), (<tf.Tensor 'stateless_random_normal_6:0' shape=(512, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_7:0' shape=(1, 512) dtype=float32>)]), (ndarray<<tf.Tensor 'zeros_11:0' shape=(4, 512) dtype=float64>>, [(<tf.Tensor 'stateless_random_normal_8:0' shape=(210, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_9:0' shape=(1, 512) dtype=float32>), (), (<tf.Tensor 'stateless_random_normal_10:0' shape=(512, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_11:0' shape=(1, 512) dtype=float32>)])]

================================================================
================================================================
parallel layer: [(4, 210), (4, 210), (4, 210)]

serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a78a280>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78a3a0>), (<function elementwise.<locals>.<lambda> at 0x14a78aaf0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78aee0>), (<function Dense.<locals>.ntk_init_fn at 0x14a78c3a0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78c4c0>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a7921f0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a792310>), (<function elementwise.<locals>.<lambda> at 0x14a792a60>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a792e50>), (<function Dense.<locals>.ntk_init_fn at 0x14a795310>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a795430>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a79a160>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79a280>), (<function elementwise.<locals>.<lambda> at 0x14a79a9d0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79adc0>), (<function Dense.<locals>.ntk_init_fn at 0x14a79e280>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79e3a0>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
result: [(ndarray<<tf.Tensor 'zeros_3:0' shape=(4, 512) dtype=float64>>, [(<tf.Tensor 'stateless_random_normal:0' shape=(210, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_1:0' shape=(1, 512) dtype=float32>), (), (<tf.Tensor 'stateless_random_normal_2:0' shape=(512, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_3:0' shape=(1, 512) dtype=float32>)]), (ndarray<<tf.Tensor 'zeros_7:0' shape=(4, 512) dtype=float64>>, [(<tf.Tensor 'stateless_random_normal_4:0' shape=(210, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_5:0' shape=(1, 512) dtype=float32>), (), (<tf.Tensor 'stateless_random_normal_6:0' shape=(512, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_7:0' shape=(1, 512) dtype=float32>)]), (ndarray<<tf.Tensor 'zeros_11:0' shape=(4, 512) dtype=float64>>, [(<tf.Tensor 'stateless_random_normal_8:0' shape=(210, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_9:0' shape=(1, 512) dtype=float32>), (), (<tf.Tensor 'stateless_random_normal_10:0' shape=(512, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_11:0' shape=(1, 512) dtype=float32>)])]

================================================================
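
One possible reading of the shapes above (a hedged sketch based on the test name, not a confirmed explanation of the TF port): in the "different_inputs" parameterization, x1 and x2 appear to have 4 and 2 examples respectively, and a cross-kernel kernel_fn(x1, x2) has to process both inputs, which would account for the interleaved (4, ...) and (2, ...) shapes in the log.

import jax
from neural_tangents import stax

# The two arguments of a cross-kernel may have different batch sizes.
_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(512))
x1 = jax.random.normal(jax.random.PRNGKey(0), (4, 210))
x2 = jax.random.normal(jax.random.PRNGKey(1), (2, 210))
print(kernel_fn(x1, x2, 'nngp').shape)  # (4, 2)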

Error using batch with jit

I get the error "Too many leaves for PyTreeDef; expected 6." when trying to run the following code:

def get_network(W_std=1):
    init_fun, apply_fun, ker_fun = stax.serial(
        stax.Dense(1, W_std=W_std, b_std=0.1)
    )
    ker_fun = jit(batch(ker_fun, batch_size=25, device_count=0))
    kdd = ker_fun(train_xs, None)
    return 0
jit(get_network)(2.0)
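
A hedged workaround sketch, assuming the error comes from tracing the batched kernel construction inside the outer jit: build and batch the kernel function outside of jit and call it on concrete inputs (train_xs below is placeholder data):

import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

def get_network(W_std=1.0):
    # Only construct and batch the kernel function; no jit around the construction.
    _, _, ker_fun = stax.serial(stax.Dense(1, W_std=W_std, b_std=0.1))
    return nt.batch(ker_fun, batch_size=25, device_count=0)

train_xs = jnp.ones((100, 8))  # placeholder inputs, for illustration only
ker_fun = get_network(2.0)
kdd = ker_fun(train_xs, None)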

Memory and computation efficiency problems for empirical NTK kernel

Hi,

I need to compute the empirical NTK kernel (J @ J.T) for a NN with ~2.5M parameters, including convolution, pooling, and dense layers, for up to ~30000 examples of size 3x512x512 each. Before I realized I could do this in neural-tangents, I implemented the empirical NTK kernel computation in PyTorch (aggregating layer-wise J @ J.T kernels), but without batching and distributed computation. With my implementation, on a single 12GB GPU I can compute the kernel for ~100 examples before hitting the GPU memory limit, and the computation takes roughly 1 second. However, the full 30000x30000 kernel would then take about a day (and I would need to implement batching).

Then I realized that neural-tangents can do exactly this, and hoped it has a more efficient implementation than mine so I would be able to speed things up (and also easily use batching and multi-GPU computation). I implemented my NN in JAX's stax (it has max pooling, which is not handled by neural-tangents) and gave it a try with 100 examples:

import jax
import numpy as onp
from jax import jit
from jax.example_libraries import stax  # jax.experimental.stax in older JAX
import neural_tangents as nt

init_fn, apply_fn = stax.serial(...model definition...)
key = jax.random.PRNGKey(0)
_, params = init_fn(key, (-1, 512, 512, 3))
x_train = onp.random.randn(100, 512, 512, 3).astype(onp.float32)
ntk = nt.batch(jit(nt.empirical_ntk_fn(apply_fn)), batch_size=10, device_count=1)
kernel = ntk(x_train, None, params)

It turned out that:

  1. the batch size needs to be around 10; if it's 20 I get memory allocation errors (with my PyTorch implementation I could process 100 examples at once)
  2. the time to compute the kernel is 2 minutes (compared to 1 second with my implementation)

So there are two options here: either I'm making some mistake in how I use JAX / neural-tangents, or neural-tangents is not suitable for my use case (I really hope it's the former).
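
For what it's worth, a hedged sketch of options that can reduce the per-batch memory footprint of the empirical NTK (the trace_axes and vmap_axes arguments are assumptions about a newer neural-tangents release and may not exist in the version used above):

import jax
import numpy as onp
from jax.example_libraries import stax
import neural_tangents as nt

# Tiny stand-in model; in practice this would be the CNN from the snippet above.
init_fn, apply_fn = stax.serial(stax.Dense(256), stax.Relu, stax.Dense(10))
_, params = init_fn(jax.random.PRNGKey(0), (-1, 3072))
x_train = onp.random.randn(20, 3072).astype(onp.float32)

# trace_axes=(-1,) traces over output channels, giving an (N, N) kernel;
# vmap_axes=0 (if available) vectorizes over the batch axis. Both are assumptions.
ntk_fn = nt.empirical_ntk_fn(apply_fn, trace_axes=(-1,), vmap_axes=0)
ntk_fn = nt.batch(jax.jit(ntk_fn), batch_size=10, device_count=1)
kernel = ntk_fn(x_train, None, params)  # shape (20, 20)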

confused by the shape of analytic NTK when evaluated at data points

Hi,

It seems to me that when evaluated at input data, the analytic NTK dimensions are not consistent with the empirical NTK dimensions. Concretely, consider a small MLP as follows:

from neural_tangents import stax

init_fun, apply_fun, ker_fun = stax.serial(
    stax.Dense(5), stax.Relu(),
    stax.Dense(2))

Also, consider a set of 10 input data points, each of dimension 100.

nr_samples = 10
input_data_dim = 100 

from jax import random
# some data points which will be fed into the neural net.
x1 = random.normal(random.PRNGKey(1), (nr_samples, input_data_dim)) 

We can evaluate the empirical NTK with some random parameters rand_params

# the empirical kernel function
from neural_tangents.api import get_ker_fun_empirical
empirical_ker_fun = get_ker_fun_empirical(apply_fun)

# some random parameters
_, rand_params = init_fun(random.PRNGKey(1), (-1, 100))  

# empirical NTK matrix evaluated on data points x1
emp_kernel_mat_x1 = empirical_ker_fun(x1, x1, rand_params).ntk 
print(emp_kernel_mat_x1.shape) #  gives (10, 10, 2, 2)

With emp_kernel_mat_x1.shape, we see that the shape of emp_kernel_mat_x1 is (10, 10, 2, 2), which is as expected: the shape depends on both the output dimension 2 and the sample size 10. However, when evaluating the analytic kernel on the same data points, the shapes differ.

analytic_kernel_mat_x1 = ker_fun(x1, x1)

print(analytic_kernel_mat_x1.ntk.shape) # gives (10, 10)

Here print(analytic_kernel_mat_x1.ntk.shape) gives (10, 10), which is different from the shape of the empirical one, (10, 10, 2, 2).

I am wondering why the analytic kernel here seems to ignore the neural network output dimension (2, in this case). Would it be possible to get an analytic kernel matrix of the format (#sample, #sample, output_dim, output_dim), which is the same format as the empirical one? Many thanks!!

Best,
Tianlin
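
For reference, a hedged illustration of why the analytic kernel can drop the output dimension: in the infinite-width limit the output channels are i.i.d., so the full (N, N, d_out, d_out) kernel is the (N, N) kernel times an identity over outputs, and the empirical-style shape can be recovered as follows:

import jax.numpy as jnp

N, d_out = 10, 2
k = jnp.ones((N, N))  # stand-in for ker_fun(x1, x1).ntk
k_full = jnp.einsum('ij,kl->ijkl', k, jnp.eye(d_out))
print(k_full.shape)  # (10, 10, 2, 2)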

Question on `shape1` [Help Needed] - Thanks ahead!

Hi Roman @romanngg! I am trying to print out the `shape1` attribute inside the Kernel object during requirement checking, as follows:

  def req(kernel_fn: LayerKernelFn):
    """Returns `kernel_fn` with additional consistency checks."""

    @utils.wraps(kernel_fn)
    def new_kernel_fn(k: Kernels, **user_reqs) -> Kernels:
      """Executes `kernel_fn` on `kernels` after checking consistency."""
      fused_reqs = _fuse_reqs(static_reqs, {}, **user_reqs)

      # `FanInConcat / FanInSum` have no requirements and
      # execute custom consistency checks.
      tf.print("how many times req is getting called", output_stream=sys.stdout)
      if not isinstance(k, list):
        for key, v in fused_reqs.items():
          if v is not None:  # `None` is treated as explicitly not having a req.
            if key in ('diagonal_batch', 'diagonal_spatial'):
              if getattr(k, key) and not v:
                raise ValueError(f'{kernel_fn} requires `{key} == {v}`, but '
                                 f'input kernel has `{key} == True`, hence '
                                 f'does not contain sufficient information. '
                                 f'Please recompute the input kernel with '
                                 f'`{key} == {v}`.')
            elif key in ('batch_axis', 'channel_axis'):
              tf.print("k.shape1: {}".format(k.shape1), output_stream=sys.stdout)
              ndim = len(k.shape1)
              v_kernel = getattr(k, key)
              v_pos = v % ndim
              tf.print("v_pos is: {}".format(v_pos), output_stream=sys.stdout)
              if v_kernel != v_pos:
                raise ValueError(f'{kernel_fn} requires `{key} == {v_pos}`, '
                                 f'but input kernel has `{key} == {v_kernel}`, '
                                 f'making the infinite limit ill-defined.')

In the test case [ RUN ] StaxTest.test_sparse_inputs_act=erf_kernel=nngp for example, the standard output is as follows:

how many times req is getting called
k.shape1: (4, 128)
v_pos is: 0
k.shape1: (4, 128)
v_pos is: 1
how many times req is getting called
k.shape1: (4, 128)
v_pos is: 0
k.shape1: (4, 128)
v_pos is: 1
how many times req is getting called
how many times req is getting called
k.shape1: (4, 4096)
v_pos is: 0
k.shape1: (4, 4096)
v_pos is: 1

In the sparse-inputs test case, there seems to be one serial layer composed of a dense layer, an activation layer, and another dense layer:

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(width),
        activation,
        stax.Dense(1 if kernel == 'ntk' else width))

In this case, would you mind explaining the correspondence between the layers and the shapes above (I am quite confused about this)?

The reason I am asking is that I am running into shape conversion and evaluation problems after using a combination of TF np.zeros and TF NumPy arrays to wrap the shapes (the former for using eval_on_shapes in TF, the latter for avoiding general TF Tensors). The problem is pointed out in DarrenZhang01/TensorFlow_GSoC#11, and I want to trace the shape flow thoroughly.

Thanks very much in advance!

Why is the standard deviation always within [0, 1] and why do I get negative or NaN covariance values?

I am trying to use NNGP/NTK to fit the outputs of a black-box function. The y-axis of my data has a pretty wide range (e.g. [x, y], where x could be a large negative number and y could be as high as 20000). When I tried to use NNGP/NTK to find a suitable kernel, I realized that I get lots of NaNs in the standard deviation. Looking at the [co]variance values, I noticed that 1) they are very small (e.g. 1e-6) and 2) they are sometimes negative, which results in NaN standard deviation values. Also, it is very likely (almost certain) that I get all NaNs for the covariance if I set diag_reg to anything below 1e-3. Why is that?

Additionally, I found that the std/covariance values fall within [0, 1], which is not correct, although the means seem to be correct. I think this is a bug (relevant to this), and it's possible that the normalization/unnormalization steps have not been implemented properly.

Below I wrote some code that shows these issues:

from jax import random
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax


key = random.PRNGKey(10)

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(1, W_std=1.5, b_std=0.05)
)


train_xs = np.array([0.0000, 0.0200, 0.1000, 0.1200, 0.1400, 0.1600,
        0.1800, 0.2000, 0.2200, 0.2400, 0.2600, 0.3400,
        0.3600, 0.3800, 0.4000, 0.4200, 0.4400, 0.4600, 0.4800, 0.5000, 0.5200,
        0.5400, 0.5600, 0.5800, 0.6000, 0.6200, 0.6400, 0.6600, 0.6800, 0.7000,
        0.8000, 0.8200, 0.8400, 0.8600, 0.8800,
        0.9000, 0.9200, 0.9400, 0.9600, 0.9800, 1.0000, 1.0200, 1.0400, 1.0600,
        1.0800, 1.1000, 1.1200, 1.1400, 1.1600, 1.1800, 1.2000, 1.2200, 1.2400,
        1.2600, 1.2800, 1.3000, 1.3200, 1.3400, 1.3600, 1.3800, 1.4000, 1.4200,
        1.4400, 1.4600, 1.4800, 1.5000, 1.5200, 1.5400, 1.5600, 1.5800, 1.6000,
        1.6200, 1.6400, 1.6600, 1.6800, 1.7000, 1.7200, 1.7400, 1.7600, 1.7800,
        1.8000, 1.8200, 1.8400, 1.8600, 1.8800, 1.9000, 1.9200, 1.9400, 1.9600,
        1.9800, 2.0000, 2.0200, 2.0400, 2.0600, 2.0800, 2.1000, 2.1200, 2.1400]).reshape(-1, 1)

train_ys = np.array([0.1811, 0.1755, 0.0703, 0.0458, 0.0321, 0.0281,
        0.0314, 0.0574, 0.1113, 0.1680, 0.2007, 0.1864,
        0.1542, 0.1240, 0.1012, 0.0931, 0.0928, 0.0932, 0.0932, 0.0993, 0.1158,
        0.1359, 0.1524, 0.1587, 0.1610, 0.1610, 0.1610, 0.1610, 0.1610, 0.1610,
        0.1610, 0.1610, 0.1610, 0.1610, 0.1610,
        0.1610, 0.1610, 0.1610, 0.1610, 0.1705, 0.1995, 0.2493, 0.3048, 0.3482,
        0.3758, 0.3815, 0.3814, 0.3749, 0.3580, 0.3358, 0.3246, 0.3220, 0.3232,
        0.3352, 0.3619, 0.4008, 0.4347, 0.4507, 0.4541, 0.4534, 0.4461, 0.4272,
        0.4089, 0.4031, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025,
        0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025,
        0.4025, 0.4025, 0.4110, 0.4515, 0.5125, 0.5915, 0.6517, 0.6986, 0.7209,
        0.7261, 0.7246, 0.7246, 0.7232, 0.7122, 0.6844, 0.6524, 0.6344, 0.6308]).reshape(-1, 1)*1000
test_xs = np.linspace(0., 3.4, 70).reshape(-1, 1)

mean, covariance = nt.predict.gp_inference(kernel_fn, train_xs, train_ys, test_xs, get='ntk', diag_reg=1e-2, compute_cov=True) #you can also try get='nngp'

mean = np.reshape(mean, (-1,))
std = np.sqrt(np.diag(covariance))
print (mean)
print ('\n')
print (std) # you will get some NaNs and all stds are within (0, 1)

And here's the output:

[0.02037075 1.0244607  0.071396   0.07778245 0.01721494 0.02377584
 0.2626345  0.01136238 0.01295557 0.00731608 0.00607905 0.00560992
        nan 0.01006025 0.01062503 0.06651297 0.02710511 0.01521527
        nan        nan 0.00918696 0.01167288 0.00146484 0.00718454
 0.00580829 0.0038602  0.00803071        nan 0.00358812        nan
        nan 0.00651448 0.00179406        nan 0.00851347        nan
 0.01051223 0.00838651        nan 0.00743728 0.00571519        nan
        nan        nan 0.02975312 0.08047054 0.1433592  0.22170994
 0.3043489  0.38526937 0.46138063 0.5305047  0.591389   0.6448372
 0.6911122  0.73078007 0.76483166 0.7942772  0.819495   0.8412284
 0.859998   0.8762823  0.89045817 0.90289694 0.9137593  0.9233606
 0.9318279  0.9393407  0.94605935 0.95205545]


[177.5949     8.138118  68.319824  29.864521  51.93844  181.89476
 188.30295  178.64008  106.595825  92.69229   96.36267  137.6497
 160.45253  160.79393  160.80162  159.1641   160.38309  160.61182
 162.00093  157.15422  181.85912  287.67047  375.76706  377.4702
 334.8078   321.11066  371.47156  436.5739   451.28027  424.98627
 402.5257   397.91467  401.8235   406.74292  405.56165  393.05548
 383.6278   421.41656  512.12476  630.0038   710.7755   736.6554
 702.9308   640.1339   573.10583  510.43176  456.05994  409.10168
 367.6361   332.77747  300.55188  272.0462   247.56     225.53345
 206.23566  189.16211  174.04999  160.7716   148.92923  138.42578
 129.2586   120.928314 113.518555 106.788605 100.86783   95.36664
  90.43851   86.002396  81.907715  78.196014]

If you set diag_reg to 1e-3 or lower, you'll get NaNs for everything:

[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]


[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]
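
Two hedged numerical workarounds that are sometimes relevant in this situation (sketches only; they do not address the conditioning of a very deep Erf kernel itself): enabling float64 before any other JAX work, and clipping tiny negative diagonal entries before taking the square root.

import jax
jax.config.update('jax_enable_x64', True)  # must run before other JAX operations

import jax.numpy as np

covariance = np.array([[1e-6, 0.0], [0.0, -1e-9]])  # stand-in for the covariance above
std = np.sqrt(np.clip(np.diag(covariance), 0.0, None))
print(std)  # negative entries are clipped to 0 instead of producing NaN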

Memory and running time issues for CNN

Hi,

I am currently using Neural Tangents to compute the kernel for CIFAR-10 images. I need to compute the kernel matrix for 10000 x 10000 images, where each image has 3x32x32 pixels. With a 2-layer feedforward NN on the reshaped 3072-dimensional input, it takes about 3 GB of memory and several minutes to compute the kernel.

However, if I use a simple one-layer CNN, it fails with the error "failed to allocate request 381T memory". My only option is to reduce the minibatch size, but that makes the computation much slower. And this is just a one-layer CNN; I expect a multilayer CNN to cost even more time. Even for one batch (100 images), it takes much more time than the 2-layer feedforward NN.

Another strange thing: I expected to be able to compute the kernel matrix with a batch size of 200 (out of 10000) at a time, because the server has 394 GB of memory. But it still runs out of memory (checked manually) after several minutes and is killed without an error message.

So I am wondering how to use your tools to compute the kernel matrix for CNNs; on my end it costs either too much memory or too much time. Do you have any suggestions for dealing with this issue? I am not sure about the underlying mechanism for computing the CNN kernel, but I expect it shouldn't cost so much memory and run so slowly, since [Arora et al. 2019](https://arxiv.org/pdf/1904.11955.pdf) compute the kernel for a 21-layer CNN.

It is a really good tool, and I hope you can help with the CNN memory and running time issues.

Thanks,
Hangfeng
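
A hedged note on the mechanism (my understanding, not a statement of the library's internals): with spatial pooling, the intermediate CNN kernel has to track covariances between all pairs of pixels, roughly of shape (batch1, batch2, H, W, H, W), which for 32x32 inputs is about a million entries per pair of images; flattening after the convolutions only needs (batch1, batch2, H, W). A sketch of the cheaper variant, combined with batching:

import neural_tangents as nt
from neural_tangents import stax

# Flatten() after the convolution avoids the pixel-pixel covariance blow-up
# that pooling layers (e.g. stax.GlobalAvgPool()) require.
_, _, kernel_fn = stax.serial(
    stax.Conv(256, (3, 3), padding='SAME'),
    stax.Relu(),
    stax.Flatten(),
    stax.Dense(10))

kernel_fn = nt.batch(kernel_fn, batch_size=20)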

Test Running Error

Hi, I reinstalled some packages and reran the tests of Neural Tangents (latest version), but I am getting an interesting error and have not found a solution. Previously, when I ran the Neural Tangents tests, this error never occurred. Has anyone else encountered this issue before and could give me some hints? Thanks!

ERROR: test_sample_vs_analytic_nngp_[batch_size=4, device_count=1 store_on_device=False ] (__main__.MonteCarloTest)
test_sample_vs_analytic_nngp_[batch_size=4, device_count=1 store_on_device=False ] (__main__.MonteCarloTest)
test_sample_vs_analytic_nngp_[batch_size=4, device_count=1 store_on_device=False ](batch_size=4, device_count=1, store_on_device=False)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 1955, in shape
    result = a.shape
AttributeError: 'tuple' object has no attribute 'shape'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/absl/testing/parameterized.py", line 263, in bound_param_test
    test_method(self, **testcase_params)
  File "monte_carlo_test.py", line 152, in test_sample_vs_analytic_nngp
    ker_empirical = sample(x1, x2, 'nngp')
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/monte_carlo.py", line 103, in get_sampled_kernel
    for n, sample in get_samples(x1, x2, get, **apply_fn_kwargs):
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/monte_carlo.py", line 77, in get_samples
    one_sample = kernel_fn_sample_once(x1, x2, split, get, **apply_fn_kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 363, in serial_fn
    return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 327, in serial_fn_x1
    _, kernel = _scan(row_fn, 0, x1s)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 122, in _scan
    carry, y = f(carry, x)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 322, in row_fn
    return _, _scan(col_fn, x1, x2s)[1]
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 122, in _scan
    carry, y = f(carry, x)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 325, in col_fn
    return x1, kernel_fn(x1, x2, *args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 294, in kernel_fn
    return device_put(_kernel_fn(x1, x2, *args, **kwargs), devices('cpu')[0])
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 458, in parallel_fn
    return parallel_fn_x1(x1_or_kernel, x2, *args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 431, in parallel_fn_x1
    kernel = kernel_fn(x1, x2, *args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 593, in f_pmapped
    return _f(x_or_kernel, *args_np)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/api.py", line 169, in f_jitted
    out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/core.py", line 1100, in call_bind
    outs = primitive.impl(fun, *args, **params)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/xla.py", line 541, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/linear_util.py", line 221, in memoized_fun
    ans = call(fun, *args)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/xla.py", line 607, in _xla_callable
    jaxpr, pvals, consts = pe.trace_to_jaxpr(
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 429, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/linear_util.py", line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/api.py", line 869, in batched_fun
    out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/batching.py", line 34, in batch
    return batched_fun.call_wrapped(*in_vals)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/linear_util.py", line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 586, in _f
    return f(_x_or_kernel, *_args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/monte_carlo.py", line 53, in kernel_fn_sample_once
    keys = np.where(utils.x1_is_x2(x1, x2), dropout_key1,
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 1283, in where
    return _where(condition, x, y)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/api.py", line 169, in f_jitted
    out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/core.py", line 1103, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/core.py", line 1112, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/batching.py", line 148, in process_call
    vals_out = call_primitive.bind(f, *vals, **params)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/core.py", line 1103, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/core.py", line 1112, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 186, in process_call
    jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 298, in partial_eval
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/core.py", line 1100, in call_bind
    outs = primitive.impl(fun, *args, **params)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/xla.py", line 541, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/linear_util.py", line 221, in memoized_fun
    ans = call(fun, *args)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/xla.py", line 607, in _xla_callable
    jaxpr, pvals, consts = pe.trace_to_jaxpr(
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 429, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/linear_util.py", line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 1267, in _where
    condition, x, y = broadcast_arrays(condition, x, y)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 1328, in broadcast_arrays
    shapes = [shape(arg) for arg in args]
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 1328, in <listcomp>
    shapes = [shape(arg) for arg in args]
  File "<__array_function__ internals>", line 5, in shape
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 1957, in shape
    result = asarray(a).shape
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/numpy/core/_asarray.py", line 83, in asarray
    return array(a, dtype, copy=False, order=order)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/core.py", line 450, in __array__
    raise Exception(msg)
Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(uint32[2])>with<BatchTrace(level=0/2)>
  with val = Traced<ShapedArray(uint32[1,2]):JaxprTrace(level=-1/2)>
       batch_dim = 0.

This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.
