google / neural-tangents
Fast and Easy Infinite Neural Networks in Python
Home Page: https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
License: Apache License 2.0
I get the error "Too many leaves for PyTreeDef; expected 6." when I'm trying to run the following code:
def get_network(W_std=1):
    init_fun, apply_fun, ker_fun = stax.serial(
        stax.Dense(1, W_std=W_std, b_std=0.1)
    )
    ker_fun = jit(batch(ker_fun, batch_size=25, device_count=0))
    kdd = ker_fun(train_xs, None)
    return 0
jit(get_network)(2.0)
When installing using pip3 on my PC, the bug in line 421 of neural_tangents/utils/batch.py still exists (i.e., np.onp), although it seems to have already been revised in commit 272dc5e.
Would there be significant obstacles in implementing a recurrent neural network using this library?
Hi Roman @romanngg! I am trying to print out the shape1 attribute inside the Kernel object when doing requirement checking, as follows:
def req(kernel_fn: LayerKernelFn):
  """Returns `kernel_fn` with additional consistency checks."""
  @utils.wraps(kernel_fn)
  def new_kernel_fn(k: Kernels, **user_reqs) -> Kernels:
    """Executes `kernel_fn` on `kernels` after checking consistency."""
    fused_reqs = _fuse_reqs(static_reqs, {}, **user_reqs)
    # `FanInConcat / FanInSum` have no requirements and
    # execute custom consistency checks.
    tf.print("how many times req is getting called", output_stream=sys.stdout)
    if not isinstance(k, list):
      for key, v in fused_reqs.items():
        if v is not None:  # `None` is treated as explicitly not having a req.
          if key in ('diagonal_batch', 'diagonal_spatial'):
            if getattr(k, key) and not v:
              raise ValueError(f'{kernel_fn} requires `{key} == {v}`, but '
                               f'input kernel has `{key} == True`, hence '
                               f'does not contain sufficient information. '
                               f'Please recompute the input kernel with '
                               f'`{key} == {v}`.')
          elif key in ('batch_axis', 'channel_axis'):
            tf.print("k.shape1: {}".format(k.shape1), output_stream=sys.stdout)
            ndim = len(k.shape1)
            v_kernel = getattr(k, key)
            v_pos = v % ndim
            tf.print("v_pos is: {}".format(v_pos), output_stream=sys.stdout)
            if v_kernel != v_pos:
              raise ValueError(f'{kernel_fn} requires `{key} == {v_pos}`, '
                               f'but input kernel has `{key} == {v_kernel}`, '
                               f'making the infinite limit ill-defined.')
In the test case [ RUN ] StaxTest.test_sparse_inputs_act=erf_kernel=nngp, for example, the standard output is as follows:
how many times req is getting called
k.shape1: (4, 128)
v_pos is: 0
k.shape1: (4, 128)
v_pos is: 1
how many times req is getting called
k.shape1: (4, 128)
v_pos is: 0
k.shape1: (4, 128)
v_pos is: 1
how many times req is getting called
how many times req is getting called
k.shape1: (4, 4096)
v_pos is: 0
k.shape1: (4, 4096)
v_pos is: 1
In the sparse inputs test case, there seems to be one serial layer composed of one dense layer, one activation layer and another dense layer:
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(width),
    activation,
    stax.Dense(1 if kernel == 'ntk' else width))
In this case, would you mind explaining the correspondence between the layers and the above shapes (I am very confused about this)?
The reason I am asking is that I am coming across some shape conversion and evaluation problems after using the combination of TF np.zeros and TF Numpy arrays to wrap the shapes (the former is for using eval_on_shapes in TF and the latter is for avoiding a general TF Tensor). The problem is pointed out in DarrenZhang01/TensorFlow_GSoC#11, and I want to thoroughly trace the shape flow process.
Thanks very much in advance!
Hi all,
I would like to (analytically) compute the evolution of the weights under the linearized dynamics (i.e., Eqn. (8) in https://arxiv.org/pdf/1902.06720.pdf) and use the resulting weights after t "steps" of gradient flow to make predictions on the training data. More specifically, I would like these predictions to match the predictions obtained by solving the function-space dynamics on the training data (Eqn. (9) in the paper).
To do this, I modified gradient_descent_mse() in predict.py to implement Eqn. (8). Specifically, I added the function
def predict_params_using_kernel(dt, fx_train=0.):
    gx_train = fl(fx_train - y_train)
    dfx = inv_expm1_dot_vec(gx_train, dt)
    dfx = np.dot(Jacobian_f0, dfx)
    return params0 - dfx
where Jacobian_f0 is the Jacobian with respect to the parameters of the NN at initialization, evaluated on the training data.
With the resulting parameters, params_t, converted back to the appropriate pytree structure, I then compute predictions on the training data by calling apply_fn(params_t, x_train).
Unfortunately, this does not seem to result in sensible predictions, since the parameters explode, i.e., become large in magnitude, even for small t, and don't match the predictions obtained by solving the function-space dynamics, even on the training data. I am aware that the mapping between parameter states and function predictions is not bijective, but shouldn't the parameters obtained from Eqn. (8) lead to the same predictions as Eqn. (9)?
NB: I did confirm that pre-multiplying dfx = np.dot(Jacobian_f0, dfx) by the transpose of Jacobian_f0 does yield the same matrix as calling the inbuilt function predict_using_kernel().
EDIT: I forgot to mention that I of course also modified the arguments of gradient_descent_mse() to gradient_descent_mse(g_dd, y_train, params0, Jacobian_f0, g_td=None, diag_reg=0.) (i.e., I added params0, Jacobian_f0).
Any help would be much appreciated!
Thank you!
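The relationship asked about here can be checked numerically on a toy problem. The sketch below is only an illustration under stated assumptions: it takes Eqn. (8) and Eqn. (9) to have the forms written in the comments, uses a random matrix `jac` of shape (n_train, n_params) as a stand-in for the Jacobian at initialization, and evaluates the parameters through the linearized model f0 + J ω rather than through the original nonlinear apply_fn. All names are hypothetical and not part of predict.py.

```python
import numpy as np

rng = np.random.default_rng(0)
n_train, n_params = 5, 40

# Stand-in for the Jacobian of the network outputs w.r.t. the parameters
# at initialization, with rows indexed by training points (an assumption).
jac = rng.normal(size=(n_train, n_params)) / np.sqrt(n_params)
f0 = rng.normal(size=(n_train,))   # network outputs at initialization
y = rng.normal(size=(n_train,))    # training targets
lr_times_t = 3.0                   # learning rate times time, eta * t

# Empirical NTK on the training set and exp(-eta * Theta * t).
ntk = jac @ jac.T
evals, evecs = np.linalg.eigh(ntk)
decay = evecs @ np.diag(np.exp(-lr_times_t * evals)) @ evecs.T
identity = np.eye(n_train)

# Assumed form of Eqn. (8): omega_t = -J^T Theta^{-1} (I - e^{-eta Theta t}) (f0 - y).
omega = -jac.T @ np.linalg.solve(ntk, (identity - decay) @ (f0 - y))

# Assumed form of Eqn. (9): f_t = (I - e^{-eta Theta t}) y + e^{-eta Theta t} f0.
f_function_space = (identity - decay) @ y + decay @ f0

# Predictions of the linearized model evaluated at the displaced parameters.
f_from_params = f0 + jac @ omega

print(np.allclose(f_from_params, f_function_space))
```

Two things stand out in this sketch. With the (n_train, n_params) layout, the parameter displacement is multiplied by the transpose of the Jacobian, so which of `Jacobian_f0` or its transpose is needed depends on how the Jacobian is stored. And the match holds for the linearized model; plugging the same parameters into the original nonlinear network is only expected to agree to the extent that the linearization is accurate.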
Any time I try to call the initialization function on a network containing a convolutional layer I get the same "tuple index out of range" error.
Here is a minimal example using one of the code snippets provided in the preprint:
from neural_tangents import stax
from jax import random
key = random.PRNGKey(10)
def ConvolutionalNetwork(depth, W_std=1.0, b_std=0.0):
    layers = []
    for _ in range(depth):
        layers += [stax.Conv(1, (3, 3), W_std=W_std, b_std=b_std, padding='SAME'), stax.Relu()]
    layers += [stax.Flatten(), stax.Dense(1, W_std, b_std)]
    return stax.serial(*layers)
init_fn, apply_fn, kernel_fn = ConvolutionalNetwork(4)
x = random.normal(key, (10, 100))
init_fn(key, x.shape)
The same issue arises using the WideResNet code in the preprint as well, or while using Cifar-10 data. Does anyone have insight on this?
Thanks!
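One possible explanation, offered as an assumption rather than a confirmed diagnosis: a convolution with a (3, 3) filter needs inputs with a batch dimension, two spatial dimensions and a channel dimension, while x above has shape (10, 100) with only two dimensions. The sketch below uses plain NumPy to show the reshape; the 10x10 single-channel layout is hypothetical and chosen only so that the 100 features fit.

```python
import numpy as np

# Flattened inputs as in the snippet above: 10 examples with 100 features each.
x_flat = np.random.normal(size=(10, 100))

# Hypothetical image layout (batch, height, width, channels).
height, width, channels = 10, 10, 1
x_image = x_flat.reshape(-1, height, width, channels)

print(x_flat.shape)
print(x_image.shape)
```

Under that assumption, the shape passed to init_fn would be x_image.shape instead of x.shape.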
Hi,
I hit a roadblock! I tried to compute the kernel for a typical UNet for 10 images. The image size is not big (64, 64) and the number of images is just 10 (for testing purposes). However, it crashes complaining about memory (see below). I think the intermediate layers are probably using a lot of memory, but that limits the usability. Perhaps I am missing something?
gist collab: https://gist.github.com/kayhan-batmanghelich/f444e6cec65139070f1b3e5ade230de5
Side notes:
upsample, but that needs developing a new layer in neural-tangents, and I am not sure how to do that.
Error message:
/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py:4571: UserWarning: Explicitly requested dtype <class 'jax.numpy.lax_numpy.float64'> requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
warnings.warn(msg.format(dtype, fun_name , truncated_dtype))
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-27-053c34ab30f7> in <module>()
----> 1 kernel = mykernel(random_image[:10],random_image[:10])
6 frames
/usr/local/lib/python3.6/dist-packages/jax/api.py in f_jitted(*args, **kwargs)
147 _check_args(args_flat)
148 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 149 out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
150 return tree_unflatten(out_tree(), out)
151
/usr/local/lib/python3.6/dist-packages/jax/core.py in call_bind(primitive, f, *args, **params)
600 if top_trace is None:
601 with new_sublevel():
--> 602 outs = primitive.impl(f, *args, **params)
603 else:
604 tracers = map(top_trace.full_raise, args)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
440 device = params['device']
441 backend = params['backend']
--> 442 compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))
443 try:
444 return compiled_fun(*args)
/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
221 fun.populate_stores(stores)
222 else:
--> 223 ans = call(fun, *args)
224 cache[key] = (ans, fun.stores)
225 return ans
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, *arg_specs)
497 options = xb.get_compile_options(
498 num_replicas=nreps, device_assignment=(device.id,) if device else None)
--> 499 compiled = built.Compile(compile_options=options, backend=xb.get_backend(backend))
500
501 if nreps == 1:
/usr/local/lib/python3.6/dist-packages/jaxlib/xla_client.py in Compile(self, argument_shapes, compile_options, backend)
607 if argument_shapes:
608 compile_options.argument_layouts = argument_shapes
--> 609 return backend.compile(self.computation, compile_options)
610
611 def GetProgramShape(self):
/usr/local/lib/python3.6/dist-packages/jaxlib/tpu_client.py in compile(self, c_computation, compile_options)
103 compile_options.argument_layouts,
104 options, self.client,
--> 105 compile_options.device_assignment)
106
107 def get_default_device_assignment(self, num_replicas):
RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 25.99G of 7.48G hbm. Exceeded hbm capacity by 18.50G.
Total hbm usage >= 26.50G:
reserved 529.00M
program 25.99G
arguments unknown size
Output size unknown.
Program hbm requirement 25.99G:
reserved 4.0K
global 36.0K
HLO temp 25.99G (74.4% utilization: Unpadded (19.34G) Padded (25.98G), 0.0% fragmentation (10.31M))
Largest program allocations in hbm:
1. Size: 12.50G
Operator: op_type="conv_general_dilated"
Shape: f32[409600,1,64,64]{0,1,3,2:T(2,128)}
Unpadded size: 6.25G
Extra memory due to padding: 6.25G (2.0x expansion)
XLA label: %convolution.5785 = f32[409600,1,64,64]{0,1,3,2:T(2,128)} convolution(bf16[409600,1,64,64]{0,1,3,2:T(4,128)(2,1)} %reshape.1452, bf16[3,3,1,1]{3,2,1,0:T(4,128)(2,1)} %constant.2723), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, metadata={op_t...
Allocation type: HLO temp
==========================
2. Size: 6.25G
Shape: f32[409600,1,64,64]{0,3,2,1}
Unpadded size: 6.25G
XLA label: %copy.1516 = f32[409600,1,64,64]{0,3,2,1} copy(f32[409600,1,64,64]{0,1,3,2:T(2,128)} %convolution.4620)
Allocation type: HLO temp
==========================
3. Size: 6.25G
Shape: f32[409600,1,64,64]{0,3,2,1}
Unpadded size: 6.25G
XLA label: %copy.1540 = f32[409600,1,64,64]{0,3,2,1} copy(f32[409600,1,64,64]{0,1,3,2:T(2,128)} %convolution.5785)
Allocation type: HLO temp
==========================
4. Size: 640.00M
Operator: op_type="reshape"
Shape: bf16[10,64,64,64,64]{2,1,0,4,3:T(8,128)(2,1)}
Unpadded size: 320.00M
Extra memory due to padding: 320.00M (2.0x expansion)
XLA label: %reshape.753 = bf16[10,64,64,64,64]{2,1,0,4,3:T(8,128)(2,1)} reshape(bf16[40960,1,64,64]{0,1,3,2:T(4,128)(2,1)} %fusion.420), metadata={op_type="reshape"}
Allocation type: HLO temp
==========================
5. Size: 160.00M
Operator: op_type="transpose"
Shape: bf16[10,64,64,32,32]{2,1,0,4,3:T(8,128)(2,1)}
Unpadded size: 80.00M
Extra memory due to padding: 80.00M (2.0x expansion)
XLA label: %copy.1153 = bf16[10,64,64,32,32]{2,1,0,4,3:T(8,128)(2,1)} copy(bf16[10,64,64,32,32]{2,1,4,3,0:T(8,128)(2,1)} %bitcast.127), metadata={op_type="transpose"}
Allocation type: HLO temp
==========================
6. Size: 100.00M
Shape: f32[409600,64]{0,1:T(8,128)}
Unpadded size: 100.00M
XLA label: %reshape.1326 = f32[409600,64]{0,1:T(8,128)} reshape(f32[10,10,64,64,64]{3,2,1,0,4:T(8,128)} %broadcast.1682.remat)
Allocation type: HLO temp
==========================
7. Size: 100.00M
Shape: f32[409600,64]{0,1:T(8,128)}
Unpadded size: 100.00M
XLA label: %reshape.1332 = f32[409600,64]{0,1:T(8,128)} reshape(f32[10,10,64,64,64]{3,2,1,0,4:T(8,128)} %broadcast.2053)
Allocation type: HLO temp
==========================
8. Size: 256.0K
Operator: op_type="slice"
Shape: f32[10,4096]{1,0:T(8,128)}
Unpadded size: 160.0K
Extra memory due to padding: 96.0K (1.6x expansion)
XLA label: %fusion.671 = f32[10,4096]{1,0:T(8,128)} fusion(f32[10,4096,4096]{2,1,0:T(8,128)} %reshape.4392, pred[4096,4096]{1,0:T(8,128)E(32)} %fusion.1076.remat), kind=kLoop, calls=%fused_computation.591, metadata={op_type="slice"}
Allocation type: HLO temp
==========================
9. Size: 9.0K
Shape: bf16[3,3,1,1]{3,2,1,0:T(4,128)(2,1)}
Unpadded size: 18B
Extra memory due to padding: 9.0K (512.0x expansion)
XLA label: constant literal
Allocation type: global
==========================
10. Size: 4.0K
XLA label: profiler
Allocation type: reserved
==========================
11. Size: 4.0K
Shape: bf16[2,2,1,1]{3,2,1,0:T(4,128)(2,1)}
Unpadded size: 8B
Extra memory due to padding: 4.0K (512.0x expansion)
XLA label: constant literal
Allocation type: global
==========================
12. Size: 4.0K
Shape: u32[8,128]{1,0}
Unpadded size: 4.0K
XLA label: constant literal
Allocation type: global
==========================
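The largest sizes in this report can be reproduced with simple arithmetic. The sketch below assumes that the leading dimension 409600 in f32[409600,1,64,64] is the product 10 * 10 * 64 * 64 (image pairs times pixel positions), which matches the numbers but is an interpretation rather than something the log states.

```python
def gib(num_elements, bytes_per_element):
    """Size in GiB of an array with the given element count and element width."""
    return num_elements * bytes_per_element / 2**30

n_images, height, width = 10, 64, 64

# Assumed origin of the leading dimension: image pairs times pixel positions.
leading_dim = n_images * n_images * height * width

# f32[409600, 1, 64, 64]: 4 bytes per element, no padding.
unpadded_gib = gib(leading_dim * 1 * height * width, 4)

# The same array for a single pair of images (n_images = 1).
single_pair_gib = gib((height * width) ** 2, 4)

print(leading_dim, unpadded_gib, single_pair_gib)
```

If that reading is right, the intermediate array grows with the square of the number of images and the square of the number of pixels, so computing the kernel a few image pairs at a time reduces the first factor but not the second.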
Hi,
I need to compute the empirical NTK kernel (J @ J.T) for a NN with ~2.5M parameters including convolution, pooling and dense layers. I need to compute the kernel for up to ~30000 examples of size 3x512x512 each. Before I realized I can do this in neural-tangents, I implemented empirical NTK kernel computation in PyTorch (aggregating layer-wise J @ J.T kernels) but without batching and distributed computation. With my implementation on a single 12GB GPU I can compute the kernel for ~100 examples before I hit the GPU memory limit. The computation takes roughly 1 second. However, if I want to get the full 30000x30000 kernel, it will take about a day (and I need to implement batching).
Then I realized that neural-tangents can do exactly this and hoped it has a more efficient implementation than mine, so that I would be able to speed it up (and could also easily use batching and multi-GPU computation). I implemented my NN in JAX's stax (it has max pooling, which is not handled by neural-tangents) and gave it a try with 100 examples:
init_fn, apply_fn = stax.serial(...model definition...)
key = jax.random.PRNGKey(0)
_, params = init_fn(key, (-1, 512, 512, 3))
x_train = onp.random.randn(100, 512, 512, 3).astype(onp.float32)
ntk = nt.batch(jit(nt.empirical_ntk_fn(apply_fn)), batch_size=10, device_count=1)
kernel = ntk(x_train, None, params)
It turned out that:
So there are two options here: either I'm making some mistake in how I use jax / neural-tangents, or neural-tangents is not suitable for my use case (I really hope it's the former).
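The "about a day" estimate earlier in this post follows from the quoted figures. The sketch below is back-of-the-envelope arithmetic only; it assumes each 100x100 block of the kernel costs the quoted ~1 second and that the output kernel is stored as float32 scalars, both of which are simplifications.

```python
n_examples = 30_000
block_size = 100          # examples per block, from the post
seconds_per_block = 1.0   # quoted cost of one 100x100 block

blocks_per_side = n_examples // block_size
all_blocks = blocks_per_side ** 2

# Exploiting symmetry, only the upper triangle (with diagonal) is needed.
upper_blocks = blocks_per_side * (blocks_per_side + 1) // 2

hours_full = all_blocks * seconds_per_block / 3600
hours_symmetric = upper_blocks * seconds_per_block / 3600

# Memory for the final 30000x30000 float32 kernel.
kernel_gib = n_examples ** 2 * 4 / 2**30

print(all_blocks, hours_full, upper_blocks, round(hours_symmetric, 2), round(kernel_gib, 2))
```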
Hi, would you mind explaining the difference between a general Kernel type and a <class 'neural_tangents.utils.utils.AnalyticKernel'> type? Thanks in advance!
I'm getting an 'out of memory' error when I use the 'gradient_descent_mse_gp' function, caused by the np.einsum in the 'prediction' function.
Is there a way to make a batched version of this function?
Hi, I need to evaluate both the training and test loss using predict.gradient_descent_mse_gp. I'm wondering if there is a more efficient way to do this other than to call gradient_descent_mse_gp twice, i.e. with x_train and y_train fixed while varying the argument to x_test. I'm aiming to do this with ~10k data points so each call to this function is rather expensive. Thanks!
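The general idea behind avoiding the repeated work can be sketched outside the library. The example below is a minimal illustration, assuming the infinite-time posterior mean under MSE loss (finite training times would need an extra matrix-exponential factor) and using a hypothetical `linear_kernel` as a stand-in for kernel_fn. It solves the train-train system once and reuses the result for both the train and test means; it does not show how gradient_descent_mse_gp itself is implemented.

```python
import numpy as np

def linear_kernel(a, b):
    """Hypothetical stand-in for a kernel function; returns a (len(a), len(b)) matrix."""
    return a @ b.T + 1.0

def train_and_test_means(x_train, y_train, x_test, diag_reg=1e-3):
    k_dd = linear_kernel(x_train, x_train)
    k_td = linear_kernel(x_test, x_train)
    k_reg = k_dd + diag_reg * np.eye(len(x_train))
    # The expensive solve happens once and is shared by both predictions.
    alpha = np.linalg.solve(k_reg, y_train)
    return k_dd @ alpha, k_td @ alpha

rng = np.random.default_rng(0)
x_train = rng.normal(size=(20, 3))
y_train = rng.normal(size=(20, 1))
x_test = rng.normal(size=(5, 3))

mean_train, mean_test = train_and_test_means(x_train, y_train, x_test)
train_loss = 0.5 * np.mean((mean_train - y_train) ** 2)
print(mean_train.shape, mean_test.shape, train_loss)
```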
Hi, I reinstalled some packages and I reran the tests of Neural Tangents (latest version). But I am getting an interesting error and have not found a solution. Previously I ran Neural Tangents tests and this error never occurred. Has anyone else encountered this issue before and can give me some hints? Thanks!
ERROR: test_sample_vs_analytic_nngp_[batch_size=4, device_count=1 store_on_device=False ] (__main__.MonteCarloTest)
test_sample_vs_analytic_nngp_[batch_size=4, device_count=1 store_on_device=False ] (__main__.MonteCarloTest)
test_sample_vs_analytic_nngp_[batch_size=4, device_count=1 store_on_device=False ](batch_size=4, device_count=1, store_on_device=False)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 1955, in shape
result = a.shape
AttributeError: 'tuple' object has no attribute 'shape'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/absl/testing/parameterized.py", line 263, in bound_param_test
test_method(self, **testcase_params)
File "monte_carlo_test.py", line 152, in test_sample_vs_analytic_nngp
ker_empirical = sample(x1, x2, 'nngp')
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/utils.py", line 84, in h
return g(*args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
fn_out = fn(*canonicalized_args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/monte_carlo.py", line 103, in get_sampled_kernel
for n, sample in get_samples(x1, x2, get, **apply_fn_kwargs):
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/monte_carlo.py", line 77, in get_samples
one_sample = kernel_fn_sample_once(x1, x2, split, get, **apply_fn_kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/utils.py", line 84, in h
return g(*args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 363, in serial_fn
return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 327, in serial_fn_x1
_, kernel = _scan(row_fn, 0, x1s)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 122, in _scan
carry, y = f(carry, x)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 322, in row_fn
return _, _scan(col_fn, x1, x2s)[1]
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 122, in _scan
carry, y = f(carry, x)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 325, in col_fn
return x1, kernel_fn(x1, x2, *args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/utils.py", line 84, in h
return g(*args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 294, in kernel_fn
return device_put(_kernel_fn(x1, x2, *args, **kwargs), devices('cpu')[0])
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/utils.py", line 84, in h
return g(*args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 458, in parallel_fn
return parallel_fn_x1(x1_or_kernel, x2, *args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 431, in parallel_fn_x1
kernel = kernel_fn(x1, x2, *args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/utils.py", line 84, in h
return g(*args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 593, in f_pmapped
return _f(x_or_kernel, *args_np)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/api.py", line 169, in f_jitted
out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/core.py", line 1100, in call_bind
outs = primitive.impl(fun, *args, **params)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/xla.py", line 541, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/linear_util.py", line 221, in memoized_fun
ans = call(fun, *args)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/xla.py", line 607, in _xla_callable
jaxpr, pvals, consts = pe.trace_to_jaxpr(
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 429, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/linear_util.py", line 150, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/api.py", line 869, in batched_fun
out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/batching.py", line 34, in batch
return batched_fun.call_wrapped(*in_vals)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/linear_util.py", line 150, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/batch.py", line 586, in _f
return f(_x_or_kernel, *_args, **kwargs)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/neural_tangents/utils/monte_carlo.py", line 53, in kernel_fn_sample_once
keys = np.where(utils.x1_is_x2(x1, x2), dropout_key1,
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 1283, in where
return _where(condition, x, y)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/api.py", line 169, in f_jitted
out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/core.py", line 1103, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/core.py", line 1112, in process
return trace.process_call(self, fun, tracers, params)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/batching.py", line 148, in process_call
vals_out = call_primitive.bind(f, *vals, **params)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/core.py", line 1103, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/core.py", line 1112, in process
return trace.process_call(self, fun, tracers, params)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 186, in process_call
jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 298, in partial_eval
out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/core.py", line 1100, in call_bind
outs = primitive.impl(fun, *args, **params)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/xla.py", line 541, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/linear_util.py", line 221, in memoized_fun
ans = call(fun, *args)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/xla.py", line 607, in _xla_callable
jaxpr, pvals, consts = pe.trace_to_jaxpr(
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 429, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/linear_util.py", line 150, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 1267, in _where
condition, x, y = broadcast_arrays(condition, x, y)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 1328, in broadcast_arrays
shapes = [shape(arg) for arg in args]
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 1328, in <listcomp>
shapes = [shape(arg) for arg in args]
File "<__array_function__ internals>", line 5, in shape
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 1957, in shape
result = asarray(a).shape
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/numpy/core/_asarray.py", line 83, in asarray
return array(a, dtype, copy=False, order=order)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/core.py", line 450, in __array__
raise Exception(msg)
Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(uint32[2])>with<BatchTrace(level=0/2)>
with val = Traced<ShapedArray(uint32[1,2]):JaxprTrace(level=-1/2)>
batch_dim = 0.
This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.
lax.reduce_window_shape_tuple is not a JAX public API, and we change it from time to time.
A better alternative would be to use the public API jax.eval_shape to compute the output shape of a reduce_window operator.
Thanks!
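A minimal sketch of that suggestion, assuming a sum-pooling window (the helper name `pooled_shape` and the example shapes are hypothetical): jax.eval_shape traces the function on abstract shapes only, so no pooling is actually computed.

```python
import jax
import jax.numpy as jnp
from jax import lax

def pooled_shape(input_shape, window, strides, padding):
    """Output shape of a sum-pooling reduce_window, obtained without running it."""
    def pool(x):
        return lax.reduce_window(x, 0.0, lax.add, window, strides, padding)

    # Abstract description of the input: shape and dtype only, no data.
    spec = jax.ShapeDtypeStruct(input_shape, jnp.float32)
    return jax.eval_shape(pool, spec).shape

out_shape = pooled_shape((8, 32, 32, 3), (1, 2, 2, 1), (1, 2, 2, 1), "VALID")
print(out_shape)
```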
I am trying to use NNGP/NTK to fit outputs of a black-box function. The y axis of my data has a pretty wide range (e.g. [x, y] where x could be as low as a large negative number and y could be as high as 20000). When I tried to use NNGP/NTK to find a suitable kernel I realized that I get lots of NaNs as standard deviation. When I looked at the [co]variance values I realized that 1) they are super small (e.g. 1e-6) and 2) they are sometimes negative, which results in NaN standard deviation values. Also, it is very likely (or almost certain) that I will get all NaNs for the covariance if I set diag_reg to anything below 1e-3. Why is that?
Additionally, I learned the range of std/covariance is [0, 1] which is not correct but the means seem to be correct. I think this should be a bug (relevant to this) and it's possible that the normalization/unnormalization steps have not been implemented properly.
Below I wrote some code that shows these issues:
from jax import random
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax
key = random.PRNGKey(10)
init_fn, apply_fn, kernel_fn = stax.serial(
stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
stax.Dense(1, W_std=1.5, b_std=0.05)
)
train_xs = np.array([0.0000, 0.0200, 0.1000, 0.1200, 0.1400, 0.1600,
0.1800, 0.2000, 0.2200, 0.2400, 0.2600, 0.3400,
0.3600, 0.3800, 0.4000, 0.4200, 0.4400, 0.4600, 0.4800, 0.5000, 0.5200,
0.5400, 0.5600, 0.5800, 0.6000, 0.6200, 0.6400, 0.6600, 0.6800, 0.7000,
0.8000, 0.8200, 0.8400, 0.8600, 0.8800,
0.9000, 0.9200, 0.9400, 0.9600, 0.9800, 1.0000, 1.0200, 1.0400, 1.0600,
1.0800, 1.1000, 1.1200, 1.1400, 1.1600, 1.1800, 1.2000, 1.2200, 1.2400,
1.2600, 1.2800, 1.3000, 1.3200, 1.3400, 1.3600, 1.3800, 1.4000, 1.4200,
1.4400, 1.4600, 1.4800, 1.5000, 1.5200, 1.5400, 1.5600, 1.5800, 1.6000,
1.6200, 1.6400, 1.6600, 1.6800, 1.7000, 1.7200, 1.7400, 1.7600, 1.7800,
1.8000, 1.8200, 1.8400, 1.8600, 1.8800, 1.9000, 1.9200, 1.9400, 1.9600,
1.9800, 2.0000, 2.0200, 2.0400, 2.0600, 2.0800, 2.1000, 2.1200, 2.1400]).reshape(-1, 1)
train_ys = np.array([0.1811, 0.1755, 0.0703, 0.0458, 0.0321, 0.0281,
0.0314, 0.0574, 0.1113, 0.1680, 0.2007, 0.1864,
0.1542, 0.1240, 0.1012, 0.0931, 0.0928, 0.0932, 0.0932, 0.0993, 0.1158,
0.1359, 0.1524, 0.1587, 0.1610, 0.1610, 0.1610, 0.1610, 0.1610, 0.1610,
0.1610, 0.1610, 0.1610, 0.1610, 0.1610,
0.1610, 0.1610, 0.1610, 0.1610, 0.1705, 0.1995, 0.2493, 0.3048, 0.3482,
0.3758, 0.3815, 0.3814, 0.3749, 0.3580, 0.3358, 0.3246, 0.3220, 0.3232,
0.3352, 0.3619, 0.4008, 0.4347, 0.4507, 0.4541, 0.4534, 0.4461, 0.4272,
0.4089, 0.4031, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025,
0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025,
0.4025, 0.4025, 0.4110, 0.4515, 0.5125, 0.5915, 0.6517, 0.6986, 0.7209,
0.7261, 0.7246, 0.7246, 0.7232, 0.7122, 0.6844, 0.6524, 0.6344, 0.6308]).reshape(-1, 1)*1000
test_xs = np.linspace(0., 3.4, 70).reshape(-1, 1)
mean, covariance = nt.predict.gp_inference(kernel_fn, train_xs, train_ys, test_xs, get='ntk', diag_reg=1e-2, compute_cov=True) #you can also try get='nngp'
mean = np.reshape(mean, (-1,))
std = np.sqrt(np.diag(covariance))
print (mean)
print ('\n')
print (std) # you will get some NaNs and all stds are within (0, 1)
And here's the output:
[0.02037075 1.0244607 0.071396 0.07778245 0.01721494 0.02377584
0.2626345 0.01136238 0.01295557 0.00731608 0.00607905 0.00560992
nan 0.01006025 0.01062503 0.06651297 0.02710511 0.01521527
nan nan 0.00918696 0.01167288 0.00146484 0.00718454
0.00580829 0.0038602 0.00803071 nan 0.00358812 nan
nan 0.00651448 0.00179406 nan 0.00851347 nan
0.01051223 0.00838651 nan 0.00743728 0.00571519 nan
nan nan 0.02975312 0.08047054 0.1433592 0.22170994
0.3043489 0.38526937 0.46138063 0.5305047 0.591389 0.6448372
0.6911122 0.73078007 0.76483166 0.7942772 0.819495 0.8412284
0.859998 0.8762823 0.89045817 0.90289694 0.9137593 0.9233606
0.9318279 0.9393407 0.94605935 0.95205545]
[177.5949 8.138118 68.319824 29.864521 51.93844 181.89476
188.30295 178.64008 106.595825 92.69229 96.36267 137.6497
160.45253 160.79393 160.80162 159.1641 160.38309 160.61182
162.00093 157.15422 181.85912 287.67047 375.76706 377.4702
334.8078 321.11066 371.47156 436.5739 451.28027 424.98627
402.5257 397.91467 401.8235 406.74292 405.56165 393.05548
383.6278 421.41656 512.12476 630.0038 710.7755 736.6554
702.9308 640.1339 573.10583 510.43176 456.05994 409.10168
367.6361 332.77747 300.55188 272.0462 247.56 225.53345
206.23566 189.16211 174.04999 160.7716 148.92923 138.42578
129.2586 120.928314 113.518555 106.788605 100.86783 95.36664
90.43851 86.002396 81.907715 78.196014]
If you set diag_reg to anything like 1e-3 or lower you'll get NaNs for everything:
[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]
[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]
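One common source of NaNs in a standard deviation computed this way is a predictive covariance whose diagonal comes out slightly negative from floating-point round-off when the regularizer is small; np.sqrt then returns NaN for those entries. Whether that is what happens here is an assumption, not something confirmed in this thread. A minimal NumPy sketch with a made-up covariance matrix:

```python
import numpy as np

# Hypothetical covariance with a tiny negative diagonal entry, as can arise
# from round-off when K_dd + diag_reg * I is close to singular.
covariance = np.array([[4.0e-4, 0.0, 0.0],
                       [0.0, -1.0e-7, 0.0],
                       [0.0, 0.0, 9.0e-4]])

with np.errstate(invalid='ignore'):
    # NaN wherever the estimated variance is negative.
    naive_std = np.sqrt(np.diag(covariance))

# Clip negative variances to zero before taking the square root.
safe_std = np.sqrt(np.maximum(np.diag(covariance), 0.0))
```

Clipping only hides the symptom. If conditioning is the real problem, a larger diag_reg, rescaled targets, or float64 arithmetic would address it more directly.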
I believe the following line should be kernel_fn(x_test, x_train, params) instead of kernel_fn(x_train, x_test, params).
When calling predict in nt.predict.gradient_descent with variables of the following dimensions,
g_dd [256,256]
g_dt [256,256]
fx_train [256,1]
fx_test [256,1]
the returned tuple of predictions has shapes ([2,256], [0,256]).
Running the same values in nt.predict.gradient_descent_mse returns predictions with shapes ([1,256], [1,256]).
I am curious if there might be a bug in the following slicing code -
neural-tangents/neural_tangents/predict.py
Line 278 in 38e9ba9
Also, the example documentation seems to be outdated:
neural-tangents/neural_tangents/predict.py
Line 215 in 38e9ba9
neural-tangents/neural_tangents/predict.py
Line 184 in 38e9ba9
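The reported shapes can be reproduced with plain NumPy. The sketch below is a hypothetical illustration of how slicing along the batch axis after an unflatten with the wrong trailing dimension yields exactly ([2,256], [0,256]). It is not a reading of the code at the referenced line, and all variable names are made up:

```python
import numpy as np

train_size, test_size, output_dim = 256, 256, 1
fx_train = np.zeros((train_size, output_dim))
fx_test = np.zeros((test_size, output_dim))

# Flat state vector, as an ODE solver would carry it.
flat = np.concatenate([fx_train, fx_test], axis=0).ravel()     # shape (512,)

# Hypothetical bug: unflatten with the wrong trailing dimension, then slice.
wrong = flat.reshape(-1, train_size)                            # shape (2, 256)
bad_train, bad_test = wrong[:train_size], wrong[train_size:]    # (2, 256), (0, 256)

# Unflatten with output_dim first, then slice along the batch axis.
right = flat.reshape(-1, output_dim)                            # shape (512, 1)
ok_train, ok_test = right[:train_size], right[train_size:]      # (256, 1), (256, 1)
```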
Hi, I am encountering the issue DarrenZhang01/TensorFlow_GSoC#26 for the test case MaskingTest.test_mask_fc_ [different_inputs_get=nngp_axis=(0, 1, 2, 3)_mask=10.0_concat=0_p=0.5] and I am trying to print out the layer and shape information for each block.
According to the code, the main component is a parallel block of three serial blocks, where each sub-block consists of Dense, elementwise and Dense layers:
nn = stax.serial(
  stax.Flatten(),
  stax.FanOut(3),
  stax.parallel(
    stax.serial(
      stax.Dense(width, 1.5, 0.1),
      stax.Abs(),
      stax.Dense(width, 1.5, 0.1),
    ),
    stax.serial(
      stax.Dense(width, 1.5, 0.1),
      stax.Erf(),
      stax.Dense(width if concat != 1 else 512, 1.5, 0.1),
    ),
    stax.serial(
      stax.Dense(width, 1.5, 0.1),
      stax.ABRelu(-0.2, 0.4),
      stax.Dense(width if concat != 1 else 1024, 3, 0.5),
    )
  ),
  (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
  stax.Dense(width, 2., 0.01),
  stax.Relu()
)
The printed information is as follows, where I added the =========== separators to mark the beginning and end of each parallel block. One thing that confuses me: where do the many serial sub-blocks before the parallel block come from? According to the network design above, there should only be the Flatten and FanOut layers, but the printed output suggests otherwise. I am sure the shape information in the second parallel block is wrong (it should be (2, 512) throughout rather than (4, 512)), according to the printed output of the JAX version of Neural Tangents. But I need to know where the 4 comes from in order to proceed. Thanks in advance!
Flatten layer: ndarray<Tensor("zeros:0", shape=(4, 210), dtype=float64)>
Flatten layer: ndarray<Tensor("zeros:0", shape=(2, 210), dtype=float64)>
Fan out: [ndarray<<tf.Tensor 'zeros:0' shape=(4, 210) dtype=float64>>, ndarray<<tf.Tensor 'zeros:0' shape=(4, 210) dtype=float64>>, ndarray<<tf.Tensor 'zeros:0' shape=(4, 210) dtype=float64>>]
Fan out: [ndarray<<tf.Tensor 'zeros:0' shape=(2, 210) dtype=float64>>, ndarray<<tf.Tensor 'zeros:0' shape=(2, 210) dtype=float64>>, ndarray<<tf.Tensor 'zeros:0' shape=(2, 210) dtype=float64>>]
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a78a280>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78a3a0>), (<function elementwise.<locals>.<lambda> at 0x14a78aaf0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78aee0>), (<function Dense.<locals>.ntk_init_fn at 0x14a78c3a0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78c4c0>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a78a280>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78a3a0>), (<function elementwise.<locals>.<lambda> at 0x14a78aaf0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78aee0>), (<function Dense.<locals>.ntk_init_fn at 0x14a78c3a0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78c4c0>))
iteration: 0
serial shapes: (2, 512)
iteration: 1
serial shapes: (2, 512)
iteration: 2
serial shapes: (2, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a7921f0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a792310>), (<function elementwise.<locals>.<lambda> at 0x14a792a60>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a792e50>), (<function Dense.<locals>.ntk_init_fn at 0x14a795310>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a795430>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a7921f0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a792310>), (<function elementwise.<locals>.<lambda> at 0x14a792a60>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a792e50>), (<function Dense.<locals>.ntk_init_fn at 0x14a795310>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a795430>))
iteration: 0
serial shapes: (2, 512)
iteration: 1
serial shapes: (2, 512)
iteration: 2
serial shapes: (2, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a79a160>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79a280>), (<function elementwise.<locals>.<lambda> at 0x14a79a9d0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79adc0>), (<function Dense.<locals>.ntk_init_fn at 0x14a79e280>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79e3a0>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a79a160>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79a280>), (<function elementwise.<locals>.<lambda> at 0x14a79a9d0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79adc0>), (<function Dense.<locals>.ntk_init_fn at 0x14a79e280>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79e3a0>))
iteration: 0
serial shapes: (2, 512)
iteration: 1
serial shapes: (2, 512)
iteration: 2
serial shapes: (2, 512)
================================================================
parallel layer: [(4, 210), (4, 210), (4, 210)]
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a78a280>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78a3a0>), (<function elementwise.<locals>.<lambda> at 0x14a78aaf0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78aee0>), (<function Dense.<locals>.ntk_init_fn at 0x14a78c3a0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78c4c0>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a7921f0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a792310>), (<function elementwise.<locals>.<lambda> at 0x14a792a60>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a792e50>), (<function Dense.<locals>.ntk_init_fn at 0x14a795310>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a795430>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a79a160>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79a280>), (<function elementwise.<locals>.<lambda> at 0x14a79a9d0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79adc0>), (<function Dense.<locals>.ntk_init_fn at 0x14a79e280>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79e3a0>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
result: [(ndarray<<tf.Tensor 'zeros_3:0' shape=(4, 512) dtype=float64>>, [(<tf.Tensor 'stateless_random_normal:0' shape=(210, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_1:0' shape=(1, 512) dtype=float32>), (), (<tf.Tensor 'stateless_random_normal_2:0' shape=(512, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_3:0' shape=(1, 512) dtype=float32>)]), (ndarray<<tf.Tensor 'zeros_7:0' shape=(4, 512) dtype=float64>>, [(<tf.Tensor 'stateless_random_normal_4:0' shape=(210, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_5:0' shape=(1, 512) dtype=float32>), (), (<tf.Tensor 'stateless_random_normal_6:0' shape=(512, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_7:0' shape=(1, 512) dtype=float32>)]), (ndarray<<tf.Tensor 'zeros_11:0' shape=(4, 512) dtype=float64>>, [(<tf.Tensor 'stateless_random_normal_8:0' shape=(210, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_9:0' shape=(1, 512) dtype=float32>), (), (<tf.Tensor 'stateless_random_normal_10:0' shape=(512, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_11:0' shape=(1, 512) dtype=float32>)])]
================================================================
================================================================
parallel layer: [(4, 210), (4, 210), (4, 210)]
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a78a280>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78a3a0>), (<function elementwise.<locals>.<lambda> at 0x14a78aaf0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78aee0>), (<function Dense.<locals>.ntk_init_fn at 0x14a78c3a0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a78c4c0>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a7921f0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a792310>), (<function elementwise.<locals>.<lambda> at 0x14a792a60>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a792e50>), (<function Dense.<locals>.ntk_init_fn at 0x14a795310>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a795430>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
serial layer: ((<function Dense.<locals>.ntk_init_fn at 0x14a79a160>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79a280>), (<function elementwise.<locals>.<lambda> at 0x14a79a9d0>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79adc0>), (<function Dense.<locals>.ntk_init_fn at 0x14a79e280>, <function _supports_masking.<locals>.supports_masking.<locals>.layer_with_masking.<locals>.apply_fn_with_masking at 0x14a79e3a0>))
iteration: 0
serial shapes: (4, 512)
iteration: 1
serial shapes: (4, 512)
iteration: 2
serial shapes: (4, 512)
result: [(ndarray<<tf.Tensor 'zeros_3:0' shape=(4, 512) dtype=float64>>, [(<tf.Tensor 'stateless_random_normal:0' shape=(210, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_1:0' shape=(1, 512) dtype=float32>), (), (<tf.Tensor 'stateless_random_normal_2:0' shape=(512, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_3:0' shape=(1, 512) dtype=float32>)]), (ndarray<<tf.Tensor 'zeros_7:0' shape=(4, 512) dtype=float64>>, [(<tf.Tensor 'stateless_random_normal_4:0' shape=(210, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_5:0' shape=(1, 512) dtype=float32>), (), (<tf.Tensor 'stateless_random_normal_6:0' shape=(512, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_7:0' shape=(1, 512) dtype=float32>)]), (ndarray<<tf.Tensor 'zeros_11:0' shape=(4, 512) dtype=float64>>, [(<tf.Tensor 'stateless_random_normal_8:0' shape=(210, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_9:0' shape=(1, 512) dtype=float32>), (), (<tf.Tensor 'stateless_random_normal_10:0' shape=(512, 512) dtype=float32>, <tf.Tensor 'stateless_random_normal_11:0' shape=(1, 512) dtype=float32>)])]
================================================================
Hi! Thanks for this awesome resource. I was wondering if the code supports (or could support with simple extensions) computing the NTK and/or linearization for sparsely connected (non-convolutional) layers. If not, is the obstacle practical or theoretical?
For example,
def f(x):
  print(_arr_is_on_cpu(x))
f(np.array([1.0]))  # Prints False
jit(f)(np.array([1.0]))  # Prints True
Thanks for making this great resource available!
I wonder if the layers (such as Conv and Dense) in stax can be specified to be non-trainable? If not, is there a way of modifying the output apply_fn so that the layer becomes non-trainable?
Hello, I notice there is no max pooling in the stax library. Is there any way for me to put a max-pooling layer in my stax.serial() and compute the posterior for nngp or ntk?
I read the arXiv paper "Neural Tangents: Fast and Easy Infinite Neural Networks in Python". The paper suggests using Monte Carlo sampling to approximate the network distribution. My question is how to construct the network in the first place, since there is no max pooling in the stax library. Currently, there are only avgPool and sumPool. Thanks!
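The Monte Carlo idea can be illustrated without the library: draw many random finite-width networks that contain a max-pooling step, evaluate them on the inputs, and average the outer products of their outputs to estimate the NNGP kernel. The sketch below is plain NumPy under simplifying assumptions. The architecture, the helper names sample_outputs and monte_carlo_nngp, and the choice to pool over groups of hidden units (standing in for spatial pooling) are all hypothetical, and the library's own Monte Carlo utilities are not used.

```python
import numpy as np

def sample_outputs(x, width, pool, rng):
    """Outputs of one randomly initialized network: Dense -> ReLU -> max pool -> Dense."""
    n, d = x.shape
    w1 = rng.normal(size=(d, width)) / np.sqrt(d)
    h = np.maximum(x @ w1, 0.0)                          # ReLU
    h = h.reshape(n, width // pool, pool).max(axis=-1)   # max over groups of hidden units
    w2 = rng.normal(size=(width // pool, 1)) / np.sqrt(width // pool)
    return h @ w2                                        # shape (n, 1)

def monte_carlo_nngp(x, n_samples=200, width=256, pool=4, seed=0):
    """Average f(x) f(x)^T over independent random initializations."""
    rng = np.random.default_rng(seed)
    k = np.zeros((x.shape[0], x.shape[0]))
    for _ in range(n_samples):
        f = sample_outputs(x, width, pool, rng)
        k += f @ f.T
    return k / n_samples

x = np.random.default_rng(1).normal(size=(5, 8))
k_mc = monte_carlo_nngp(x)   # (5, 5) estimate of the NNGP kernel
```

The estimate gets less noisy as n_samples and width grow. The resulting matrix could then be used in the usual closed-form GP posterior in place of an analytic kernel.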
Seems to be a bug in nt.predict.gradient_descent, perhaps related to flattening of inputs. Code snippet and stacktrace below.
Code snippet:
def scaled_loss_for_ntk(beta):
  def ntk_loss(fx, y_hat):
    return -np.mean(np.sum(jstax.logsoftmax(beta*fx) * y_hat, axis=1))
  return ntk_loss
g_dd = kernel_fn(x_train, x_train, 'ntk')  # kernel_fn from nt.stax.serial
g_td = kernel_fn(x_test, x_train, 'ntk')  # test and train numpy arrays
ntk_loss = scaled_loss_for_ntk(beta)
ntk_loss = jit(ntk_loss)
predict_fn = nt.predict.gradient_descent(g_dd, y_train, ntk_loss, g_td)
predict_fn(0.1, fx_train_initial, fx_test_initial)
Stacktrace of error:
ValueError Traceback (most recent call last)
<ipython-input-97-592ef42dc248> in <module>()
5 ntk_outputs, ntk_loss_fn, ntk_acc_fn = get_ntk_dynamics(
6 kernel_fn,x_train,x_test,y_train,
----> 7 y_test,fx_train_initial,fx_test_initial,beta)
8 # get results
9 train_loss = nnp.zeros(len(ts))
25 frames
<ipython-input-96-9b76a38cb837> in get_ntk_dynamics(kernel_fn, x_train, x_test, y_train, y_test, fx_train_initial, fx_test_initial, beta)
22 print('NTK initial loss: {}'.format(ntk_loss(fx_train_initial,y_train)))
23 predict_fn = nt.predict.gradient_descent(g_dd, y_train, ntk_loss, g_td)
---> 24 predict_fn(0.1,fx_train_initial,fx_test_initial)
25
26 ntk_outputs = functools.partial(
google3/third_party/py/neural_tangents/predict.py in predict(dt, fx_train, fx_test)
276 train_size, output_dim = fx_train.shape
277 r.set_initial_value(fx, 0).set_f_params(train_size * output_dim)
--> 278 r.integrate(dt)
279 fx = ufl(r.y)
280
google3/third_party/py/scipy/integrate/_ode.py in integrate(self, t, step, relax)
430 self._y, self.t = mth(self.f, self.jac or (lambda: None),
431 self._y, self.t, t,
--> 432 self.f_params, self.jac_params)
433 except SystemError:
434 # f2py issue with tuple returns, see ticket 1187.
google3/third_party/py/scipy/integrate/_ode.py in run(self, f, jac, y0, t0, t1, f_params, jac_params)
1170 def run(self, f, jac, y0, t0, t1, f_params, jac_params):
1171 x, y, iwork, istate = self.runner(*((f, t0, y0, t1) +
-> 1172 tuple(self.call_args) + (f_params,)))
1173 self.istate = istate
1174 if istate < 0:
google3/third_party/py/neural_tangents/predict.py in dfx_dt(unused_t, fx, train_size)
266 def dfx_dt(unused_t, fx, train_size):
267 fx_train = fx[:train_size]
--> 268 dfx_train = -ifl(np.dot(g_dd, iufl(grad_loss(fx_train))))
269 dfx_test = -ifl(np.dot(g_td, iufl(grad_loss(fx_train))))
270 return np.concatenate((dfx_train, dfx_test), axis=0)
google3/third_party/py/jax/api.py in grad_f(*args, **kwargs)
353 @wraps(fun, docstr=docstr, argnums=argnums)
354 def grad_f(*args, **kwargs):
--> 355 _, g = value_and_grad_f(*args, **kwargs)
356 return g
357
google3/third_party/py/jax/api.py in value_and_grad_f(*args, **kwargs)
408 f_partial, dyn_args = _argnums_partial(f, argnums, args)
409 if not has_aux:
--> 410 ans, vjp_py = vjp(f_partial, *dyn_args)
411 else:
412 ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)
google3/third_party/py/jax/api.py in vjp(fun, *primals, **kwargs)
1267 if not has_aux:
1268 flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1269 out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
1270 out_tree = out_tree()
1271 else:
google3/third_party/py/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
106 def vjp(traceable, primals, has_aux=False):
107 if not has_aux:
--> 108 out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
109 else:
110 out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
google3/third_party/py/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
95 _, in_tree = tree_flatten(((primals, primals), {}))
96 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
---> 97 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
98 pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
99 aval_primals, const_primals = unzip2(pval_primals)
google3/third_party/py/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, **kwargs)
313 with new_master(JaxprTrace) as master:
314 fun = trace_to_subjaxpr(fun, master, instantiate)
--> 315 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
316 assert not env
317 del master
google3/third_party/py/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
151 gen = None
152
--> 153 ans = self.f(*args, **dict(self.params, **kwargs))
154 del args
155 while stack:
google3/third_party/py/jax/api.py in f_jitted(*args, **kwargs)
148 _check_args(args_flat)
149 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 150 out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
151 return tree_unflatten(out_tree(), out)
152
google3/third_party/py/jax/core.py in call_bind(primitive, f, *args, **params)
593 else:
594 tracers = map(top_trace.full_raise, args)
--> 595 outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
596 return apply_todos(env_trace_todo(), outs)
597
google3/third_party/py/jax/interpreters/ad.py in process_call(self, call_primitive, f, tracers, params)
324 nonzero_tangents, in_tree_def = tree_flatten(tangents)
325 f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master), len(primals), in_tree_def)
--> 326 result = call_primitive.bind(f_jvp, *(primals + nonzero_tangents), **params)
327 primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
328 return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
google3/third_party/py/jax/core.py in call_bind(primitive, f, *args, **params)
593 else:
594 tracers = map(top_trace.full_raise, args)
--> 595 outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
596 return apply_todos(env_trace_todo(), outs)
597
google3/third_party/py/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
113 in_pvs, in_consts = unzip2([t.pval for t in tracers])
114 fun, aux = partial_eval(f, self, in_pvs)
--> 115 out_flat = call_primitive.bind(fun, *in_consts, **params)
116 out_pvs, jaxpr, env = aux()
117 out_pv_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
google3/third_party/py/jax/core.py in call_bind(primitive, f, *args, **params)
590 if top_trace is None:
591 with new_sublevel():
--> 592 outs = primitive.impl(f, *args, **params)
593 else:
594 tracers = map(top_trace.full_raise, args)
google3/third_party/py/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
398 device = params['device']
399 backend = params.get('backend', None)
--> 400 compiled_fun = _xla_callable(fun, device, backend, *map(abstractify, args))
401 try:
402 return compiled_fun(*args)
google3/third_party/py/jax/linear_util.py in memoized_fun(fun, *args)
207 fun.populate_stores(stores)
208 else:
--> 209 ans = call(fun, *args)
210 cache[key] = (ans, fun.stores)
211 return ans
google3/third_party/py/jax/interpreters/xla.py in _xla_callable(fun, device, backend, *abstract_args)
410 pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
411 with core.new_master(pe.JaxprTrace, True) as master:
--> 412 jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
413 assert not env # no subtraces here
414 del master, env
google3/third_party/py/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
151 gen = None
152
--> 153 ans = self.f(*args, **dict(self.params, **kwargs))
154 del args
155 while stack:
<ipython-input-96-9b76a38cb837> in ntk_loss(fx, y_hat)
2 def scaled_loss_for_ntk(beta):
3 def ntk_loss(fx,y_hat):
----> 4 return -np.mean(np.sum(jstax.logsoftmax(beta*fx) * y_hat,axis=1))
5 return ntk_loss
6
google3/third_party/py/jax/numpy/lax_numpy.py in reduction(a, axis, dtype, out, keepdims)
1184 a = a if isinstance(a, ndarray) else asarray(a)
1185 a = preproc(a) if preproc else a
-> 1186 dims = _reduction_dims(a, axis)
1187 result_dtype = dtype or _dtype(np_fun(onp.ones((), dtype=_dtype(a))))
1188 if upcast_f16_for_computation and issubdtype(result_dtype, inexact):
google3/third_party/py/jax/numpy/lax_numpy.py in _reduction_dims(a, axis)
1206 return tuple(_canonicalize_axis(x, ndim(a)) for x in axis)
1207 elif isinstance(axis, int):
-> 1208 return (_canonicalize_axis(axis, ndim(a)),)
1209 else:
1210 raise TypeError("Unexpected type of axis argument: {}".format(type(axis)))
google3/third_party/py/jax/numpy/lax_numpy.py in _canonicalize_axis(axis, num_dims)
353 raise ValueError(
354 "axis {} is out of bounds for array of dimension {}".format(
--> 355 axis, num_dims))
356 return axis
357
ValueError: axis 1 is out of bounds for array of dimension 1
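The traceback shows dfx_dt slicing fx[:train_size] out of a flat ODE state, so the loss receives a 1-D array and the reduction over axis=1 fails. One possible workaround, sketched here in plain NumPy with a hand-written log-softmax standing in for jstax.logsoftmax, is to restore the 2-D shape inside the loss. The names make_loss and num_classes are hypothetical and not part of the library:

```python
import numpy as np

def log_softmax(z):
    # Numerically stable log-softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def make_loss(beta, num_classes):
    def loss(fx, y_hat):
        # Accept either (batch, classes) or a flattened vector of logits.
        fx = fx.reshape(-1, num_classes)
        y_hat = y_hat.reshape(-1, num_classes)
        return -np.mean(np.sum(log_softmax(beta * fx) * y_hat, axis=1))
    return loss

loss = make_loss(beta=1.0, num_classes=3)
y = np.eye(3)[[0, 2]]                  # two one-hot labels
fx_2d = np.zeros((2, 3))               # uniform logits
value_2d = loss(fx_2d, y)
value_flat = loss(fx_2d.ravel(), y)    # same value from the flattened input
```

The same reshape should carry over to jax.numpy, since the target shape is static, but that has not been checked against the version in the traceback.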
Hi Roman @romanngg and Sam @sschoenholz, I am currently working on the migration/reconstruction of Neural Tangents from JAX to TensorFlow 2.x, as an R&D project for the TensorFlow team. For the shape inference based on the abstract key https://github.com/google/neural-tangents/blob/master/neural_tangents/stax.py#L2076, we have not found an appropriate way to plug in an equivalent TF API. I am currently considering disabling the functionality around this line, since akey only seems to be used for generating the rng key later on. Could you provide some insight from the Neural Tangents perspective? If I choose to disable it, what would be an appropriate range? Thanks a lot!
Hello, I have a (probably basic) question. I was wondering if it is possible to use NT's stax implementation to do a more basic neural net. I'm attempting to embed some continuous sequences into n-dimensional space, where inputs x1 and x2 are run through two dense layers, and the final output of the neural net is the manhattan distance between x1 and x2 after the dense layers. This is just so that embedded representation mimics the manhattan distance between the two continuous sequences.
Sorry if that isn't clear, my model is below:
input1 = Input(shape=(k,5), dtype='float32', name="k1")
input2 = Input(shape=(k,5), dtype='float32', name="k2")
input1_flat = Flatten()(input1)
input2_flat = Flatten()(input2)
dense1 = Dense(1024, activation="relu", name="Dense1", use_bias=False)
dense_out = Dense(dims, activation="linear", name="DenseOut", use_bias=False,)
k1m = dense_out(dense1(input1_flat))
k2m = dense_out(dense1(input2_flat))
subtracted = Subtract()([k1m, k2m])
abs = tf.math.abs(subtracted)
output = tf.keras.backend.sum(abs, axis=1)
Because at the chosen sequence length the possible inputs are 5^17, I was hoping/wondering if neural tangent would be a good fit, but I can't quite figure out how to make the neural net work with the inputs/outputs from the colab notebook tutorial.
If it's not possible or not a good idea, I'm definitely open. Just exploring possibilities. If it is possible I'd appreciate some pointers, as I haven't used JAX/Stax before, and not sure how to integrate the Subtract layer or make it work with 2 different layers as inputs. I'll keep futzing around with it too in the meantime.
Cheers,
--Joseph
Hi, sorry for bothering. In the test_composition_conv_avg_pool test cases, some outer products on the covariance matrices are performed during the Kernel transformation. In the outer product function, there is the interleave_ones operation, which adds ones to the covariance dimensions:
def outer_prod(x, y, start_axis, end_axis, prod_op):
  if y is None:
    y = x
  x = interleave_ones(x, start_axis, end_axis, True)
  y = interleave_ones(y, start_axis, end_axis, False)
  tf.print("x: {}, y: {}".format(x.shape, y.shape), output_stream=sys.stdout)
  return prod_op(x, y)
When I print out the shapes after interleave_ones, some shapes are like x: (5, 1, 8, 1, 8, 1), y: (1, 5, 1, 8, 1, 8), which obviously do not match. In this case, would you mind explaining the role of interleave_ones and how the unmatched shapes can be multiplied together? Thanks!
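Shapes such as (5, 1, 8, 1, 8, 1) and (1, 5, 1, 8, 1, 8) are compatible under NumPy-style broadcasting: every size-1 axis is stretched to match the other operand, so an elementwise product becomes an outer product over each pair of adjacent axes. The sketch below assumes interleave_ones does what its name and the printed shapes suggest. It uses plain indexing rather than the library function:

```python
import numpy as np

x = np.random.default_rng(0).normal(size=(5, 8, 8))

# Insert singleton axes so the two operands occupy alternating positions.
x_i = x[:, None, :, None, :, None]   # shape (5, 1, 8, 1, 8, 1)
y_i = x[None, :, None, :, None, :]   # shape (1, 5, 1, 8, 1, 8)

# Broadcasting stretches each size-1 axis, giving an outer product
# over every pair of adjacent axes.
outer = x_i * y_i                    # shape (5, 5, 8, 8, 8, 8)

# Same result written explicitly: outer[a, d, b, e, c, f] = x[a, b, c] * x[d, e, f].
reference = np.einsum('abc,def->adbecf', x, x)
```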
Should we expect nt.predict.gradient_descent to fail when using jax vmap due to the scipy ode solver? Are there any suggested workarounds for speeding this function up over batches?
Hi,
you are using internal packages in your code. For example, in the neural_tangents_cookbook.ipynb
"from google3.pyglib import gfile
with gfile.GFile( '/cns/od-d/home/schsam/rs=6.3/ntk/gd_inference.pdf', 'w') as f_out:"
Hello,
I currently have a dataset that has 1246064 observations and 94 features. It is my understanding that GP inference would have to build a kernel of size 1246064 * 1246064, and I am not sure if that is the reason that I am currently running into the following memory error:
RuntimeError Traceback (most recent call last)
<ipython-input-81-5095b4194ced> in <module>()
29 r_mean, r_covariance = nt.predict.gp_inference(
30 kernel_fn, z_train, r_train, z_test,
---> 31 diag_reg=1e-4, get='ntk', compute_cov=True)
32 r_mean = np.reshape(r_mean, (-1,))[np.newaxis, ...]
33 out_rsq_list.append((r_test.detach().cpu().numpy(), r_test.detach().cpu().numpy()))
8 frames
/usr/local/lib/python3.6/dist-packages/jaxlib/xla_client.py in compile(self, c_computation, compile_options)
148 compile_options.argument_layouts,
149 options, self.client,
--> 150 compile_options.device_assignment)
151
152 def get_default_device_assignment(self, num_replicas, num_partitions=None):
RuntimeError: Resource exhausted: Out of memory while trying to allocate 6210718745600 bytes.
I was wondering if there is a way around this (for example, to create a kernel approximation of some sort, similar to this one).
Thanks!
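A back-of-the-envelope check supports the guess about the kernel size: a dense float32 matrix with 1246064 rows and columns needs roughly the amount of memory named in the error message. Treating the failed allocation as the full train-train kernel is an inference from the numbers, not something the traceback states.

```python
n = 1_246_064             # number of training points from the question
bytes_per_float32 = 4

kernel_bytes = n * n * bytes_per_float32   # 6_210_701_968_384 bytes
kernel_tib = kernel_bytes / 2**40          # roughly 5.6 TiB
reported_bytes = 6_210_718_745_600         # figure from the error message
gap_bytes = reported_bytes - kernel_bytes  # 16_777_216, i.e. 2**24
```

Since no single accelerator holds a matrix of this size, exact inference on the full dataset would need a smaller training subset or some form of kernel approximation.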
I'm having fun playing with the Neural Tangents Cookbook.ipynb and I'd like to try extending it to multivariate regression. However, when I changed the output dimension of the last layer in stax.serial, the dimensions of the predicted mean and predicted covariance remain the same. Why is this, and what do I need to change to extend to multivariate regression?
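In closed-form kernel regression, the number of predicted outputs follows the shape of the training targets, not the width of the network's last layer. The NumPy sketch below illustrates this with a stand-in RBF kernel (an assumption made for brevity, not the NTK or NNGP kernel) and made-up data. The predicted mean gets one column per target column, while the covariance over test points is shared across outputs.

```python
import numpy as np

def rbf_kernel(a, b, length_scale=1.0):
    # Stand-in kernel for illustration only; not the NTK/NNGP kernel.
    sq_dist = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq_dist / length_scale**2)

rng = np.random.default_rng(0)
x_train = rng.uniform(-3, 3, size=(20, 1))
x_test = np.linspace(-3, 3, 7).reshape(-1, 1)
y_train = np.hstack([np.sin(x_train), np.cos(x_train)])   # two targets, shape (20, 2)

diag_reg = 1e-4
k_dd = rbf_kernel(x_train, x_train) + diag_reg * np.eye(len(x_train))
k_td = rbf_kernel(x_test, x_train)
k_tt = rbf_kernel(x_test, x_test)

mean = k_td @ np.linalg.solve(k_dd, y_train)         # shape (7, 2): one column per output
cov = k_tt - k_td @ np.linalg.solve(k_dd, k_td.T)    # shape (7, 7): shared by all outputs
```

Under this reading, passing multi-column training targets (and not reshaping the mean to a flat vector afterwards) would be the change to try in the notebook.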
Hi,
I wonder if you might possibly provide a description of the implicit method used in the _compute_ntk() function or point me to some reference. I find the current code concise but hard to follow. Thanks!
Regards,
Jerry
Hi,
It seems to me that when evaluated at input data, the analytic NTK dimensions are not consistent with the empirical NTK dimensions. Concretely, consider a small MLP as follows:
from neural_tangents import stax
init_fun, apply_fun, ker_fun = stax.serial(
stax.Dense(5), stax.Relu(),
stax.Dense(2))
Also, consider a set of 10 input data points, each of dimension 100.
nr_samples = 10
input_data_dim = 100
from jax import random
# some data points which will be fed into the neural net.
x1 = random.normal(random.PRNGKey(1), (nr_samples, input_data_dim))
We can evaluate the empirical NTK with some random parameters rand_params
# the empirical kernel function
from neural_tangents.api import get_ker_fun_empirical
empirical_ker_fun = get_ker_fun_empirical(apply_fun)
# some random parameters
_, rand_params = init_fun(random.PRNGKey(1), (-1, 100))
# empirical NTK matrix evaluated on data points x1
emp_kernel_mat_x1 = empirical_ker_fun(x1, x1, rand_params).ntk
print(emp_kernel_mat_x1.shape) # gives (10, 10, 2, 2)
With emp_kernel_mat_x1.shape, we see that the shape of emp_kernel_mat_x1 is (10, 10, 2, 2), which is as expected: the shape depends on both the output dimension 2 and the sample size 10. However, when evaluating the analytic kernel on the same data points, the shape differs.
analytic_kernel_mat_x1 = ker_fun(x1, x1)
print(analytic_kernel_mat_x1.ntk.shape) # gives (10, 10)
Here print(analytic_kernel_mat_x1.ntk.shape) gives (10, 10), which is different from the shape of the empirical one, (10, 10, 2, 2).
I am wondering why the analytic kernel here seems to ignore the neural network output dimension (2, in this case). Would it be possible to get an analytic kernel matrix of the format (#sample, #sample, output_dim, output_dim), which is the same format as the empirical one? Many thanks!!
Best,
Tianlin
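One way to reconcile the two shapes: in the infinite-width limit with a dense readout layer, distinct output units are independent and identically distributed, so the full kernel is the (samples, samples) matrix times an identity over the output indices. The sketch below expands a placeholder (10, 10) matrix into that (10, 10, 2, 2) layout. The matrix k is random stand-in data, not an actual NTK.

```python
import numpy as np

n_samples, output_dim = 10, 2
rng = np.random.default_rng(0)
a = rng.normal(size=(n_samples, n_samples))
k = a @ a.T                                   # placeholder for the (10, 10) analytic kernel

# full[i, j, p, q] = k[i, j] * (1 if p == q else 0)
full = np.einsum('ij,ab->ijab', k, np.eye(output_dim))   # shape (10, 10, 2, 2)
```

At finite width the empirical kernel is only approximately of this form, which is why it carries the extra output axes explicitly.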
Hi,
Thank you for sharing this great library. I have two related questions:
For nngp, we are assuming that the number of neurons goes to infinity. Why do you need to specify the number of neurons in the Dense layer (or the number of filters in the conv layer)? Is this because we are not sure whether the layer is a middle or last layer?
The answer to the first question partly answers this one: does it make sense to have a FanInConcat layer (the same as stax)? From one point of view it doesn't, because we are concatenating two infinitely wide layers. From another point of view it does: for example, if you want to implement models like UNet, you need FanInConcat. I personally think it makes sense to implement it, but I am not sure.
I would be thankful if you clarify.
Thanks
Hi!
A bug seems to occur when I try to evaluate analytic NTKs on sparse input data: the evaluated kernel contains nan entries. This can be reproduced with the following lines of code:
from jax import random
from neural_tangents import stax
key = random.PRNGKey(1)
# a batch of dense inputs
x_dense = random.normal(key, (3, 32, 32, 3))
# a batch of sparse inputs
x_sparse = x_dense * (abs(x_dense) > 1.2)
# A CNN architecture
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Conv(128, (3, 3)),
    stax.Relu(),
    stax.Flatten(),
    stax.Dense(10))
# Evaluate the analytic NTK upon dense inputs
print('NTK evaluated w/ dense inputs: \n', kernel_fn(x_dense, x_dense, 'ntk')) # the outputs look fine.
print('\n')
# Evaluate the analytic NTK upon sparse inputs
print('NTK evaluated w/ sparse inputs: \n', kernel_fn(x_sparse, x_sparse, 'ntk')) # the output contains nan
The output of the above script should be:
NTK evaluated w/ dense inputs:
[[0.97102666 0.16131128 0.16714054]
[0.16131128 0.9743941 0.17580226]
[0.16714054 0.17580226 1.0097454 ]]
NTK evaluated w/ sparse inputs:
[[ nan nan nan]
[ nan 0.66292834 nan]
[ nan nan nan]]
Thanks for your time in advance!
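A possible explanation, offered as an assumption and not confirmed in the report: the closed-form ReLU kernel normalizes the covariance by the square root of the two variances, so an all-zero patch (plausible after thresholding) has zero variance and the normalization becomes 0/0. The sketch below applies the standard arc-cosine formula to a single pair of scalars; `relu_nngp` and its guard are illustrative and are not the library's implementation.

```python
import math


def relu_nngp(q11, q12, q22):
    """Arc-cosine (ReLU) kernel for one pair of inputs.

    q11 and q22 are the two input variances, q12 is their covariance.
    """
    norm = math.sqrt(q11 * q22)
    if norm == 0.0:
        # Zero variance: the pre-activation is identically zero, so the
        # ReLU output is zero too. Without this guard the ratio below
        # would be 0 / 0.
        return 0.0
    # Clamp to [-1, 1] so rounding error cannot push acos out of its domain.
    cos_theta = max(-1.0, min(1.0, q12 / norm))
    theta = math.acos(cos_theta)
    return norm * (math.sin(theta) + (math.pi - theta) * cos_theta) / (2.0 * math.pi)


print(relu_nngp(1.0, 1.0, 1.0))  # identical unit-variance inputs
print(relu_nngp(0.0, 0.0, 1.0))  # one input has zero variance
```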
Hi Roman @romanngg, would you mind explaining a bit about the kernel shape calculation, especially the _propagate_shape function? I am also curious about the use of akey. Thanks a lot!
https://github.com/google/neural-tangents/blob/master/neural_tangents/stax.py#L2072
Hi, could anyone explain to me the use of _reduce_window_sum in the core file stax.py? It appears twice, on lines 3095 and 3049, and it is an internal method of JAX lax, so documentation is hard to find. Thanks in advance!
https://github.com/google/neural-tangents/blob/master/neural_tangents/stax.py#L3049
https://github.com/google/neural-tangents/blob/master/neural_tangents/stax.py#L3095
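As a rough illustration of what a windowed sum does (an assumption based on the function's name, not a description of the JAX internals), the one-dimensional sketch below adds up the entries under a window as it slides along a list. The name `window_sum_1d` and the no-padding behavior are illustrative choices.

```python
def window_sum_1d(values, window, stride=1):
    """Sum each length-`window` slice of `values`, moving by `stride`.

    No padding is applied, so the output is shorter than the input.
    """
    return [
        sum(values[start:start + window])
        for start in range(0, len(values) - window + 1, stride)
    ]


print(window_sum_1d([1, 2, 3, 4, 5], window=3))
```

For a convolutional kernel the same idea would apply along the spatial axes, with the window playing the role of the filter shape.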
Hi,
I currently use Neural Tangents to compute the kernel for CIFAR-10 images. I need to compute the kernel matrix for 10000 x 10000 images, each with 3x32x32 pixels. With a 2-layer feedforward network on inputs reshaped to 3072 dimensions, it took about 3G of memory and several minutes to compute the kernel.
However, with a simple CNN (a single convolutional layer), it fails with the error "failed to allocate request 381T memory". All I can do is reduce the minibatch size, which makes the computation much slower. And this is just a one-layer CNN; I expect a multilayer CNN to take even longer. Even for one batch (100 images), it still takes much more time than the 2-layer feedforward network.
Another strange thing: I expected to be able to compute the kernel matrix with a batch size of 200 (out of 10000), because the server has 394G of memory. But it still runs out of memory (checked manually) after running for several minutes, and the process is killed without an error message.
So I am wondering how to use your tools to compute the kernel matrix for CNNs. It costs either too much memory or too much time on my end. Do you have any suggestions for dealing with this issue? I am not sure about the underlying mechanism for computing the CNN kernel, but I would not expect it to need so much memory or to run so slowly, because [Arora et al. 2019](https://arxiv.org/pdf/1904.11955.pdf) compute the kernel for a 21-layer CNN.
It is a really good tool, and I hope you can help with the CNN memory and running time issue.
Thanks,
Hangfeng
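The 381T figure is consistent with a back-of-the-envelope estimate, under the assumption (not stated in the post) that the convolutional kernel keeps a covariance for every pair of pixel positions, giving an array of shape (N1, N2, H, W, H, W) stored in float32. The helper name `cnn_kernel_bytes` is hypothetical.

```python
def cnn_kernel_bytes(n1, n2, height, width, bytes_per_entry=4):
    """Bytes for an (n1, n2, H, W, H, W) pixel-pixel covariance array."""
    return n1 * n2 * (height * width) ** 2 * bytes_per_entry


full = cnn_kernel_bytes(10_000, 10_000, 32, 32)
batch = cnn_kernel_bytes(200, 200, 32, 32)
print(f"full 10000 x 10000: {full / 2**40:.1f} TiB")  # 381.5 TiB
print(f"batch 200 x 200: {batch / 2**30:.2f} GiB")    # 156.25 GiB
```

Under this assumption a single 200 x 200 batch already needs about 156 GiB for one intermediate array, so a few temporaries of that size could exceed the 394G available.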
Hi Roman @romanngg, would you mind explaining the use of jax.lax.cond in https://github.com/google/neural-tangents/blob/master/neural_tangents/utils/utils.py#L359-L361? (My question is: one boolean is enough for cond, but three are given?) Thanks!
Hi, could anyone explain to me the reason for composing (akey,) and {} as the input for the tree_flatten function in https://github.com/google/neural-tangents/blob/master/neural_tangents/stax.py#L2074? Thanks!
Hi,
When I tried to compute the NTK of a fully-connected network, I couldn't find the Tanh activation in stax.serial.
For example, the following code doesn't work.
init_fn, apply_fn, kernel_fn = stax.serial( stax.Dense(512), stax.Tanh(), )
If stax.serial doesn't support the Tanh activation, what else can I do to compute the NTK of a Tanh network?
In the setup.py file, it is required that 'jaxlib>=0.1.37'. However, the latest version of jax (updated 18 days ago) seems to be 0.1.36. So perhaps this is a typo (or maybe an internal jax version was used).
Hi, I am running the test files for Neural Tangents and a lot of cases were skipped. For example, 96 / 127 test cases were skipped in the Neural Tangents stax tests. I looked at the implementation of the tests, and it seems that invalid test cases trigger the SkipTest in the unit tests. I am wondering if this is expected. Thanks!
Hi, could you confirm that the input is (-1, 7) in https://github.com/google/neural-tangents/blob/master/tests/stax_test.py#L1017? Since this is not reshaping, what does -1 represent here as the first dimension? Thanks!
Hi, maybe somewhat similar to #48, but what would be a good alternative to passing -1 as the first input dimension here? (TF Numpy arrays do not recognize negative shapes.) Thanks!
https://github.com/google/neural-tangents/blob/master/tests/empirical_test.py#L85
_, params = init_fn(key, (-1,) + input_shape)
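One possible workaround, assuming the leading -1 is only a placeholder for an unspecified batch size and that parameter shapes depend on the trailing dimensions alone: substitute a concrete batch size before calling init_fn. The helper `concrete_shape` is hypothetical and not part of the library.

```python
def concrete_shape(shape, batch_size=1):
    """Replace a leading -1 (unknown batch size) with a concrete value."""
    if shape and shape[0] == -1:
        return (batch_size,) + tuple(shape[1:])
    return tuple(shape)


input_shape = (7,)
print(concrete_shape((-1,) + input_shape))
```

The result could then be passed in place of `(-1,) + input_shape`, provided the assumption about parameter shapes holds for the layers in use.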
Thanks a lot for making this repository public!
When running the notebooks weight_space_linearization.ipynb and function_space_linearization.ipynb on Google Colab using the link provided in these notebooks, I was unable to import neural_tangents.tangents. A screenshot is attached below:
The same problem happens when I try to run the repository locally on my computer.
This issue seems to have started when the repository was updated about a week ago. The old version of the code (currently on the notebook branch) works fine.
Hi, can we tell which dimension is the channel dimension for the input matrix mat and for the parameter filter_shape? Or can I assume there is only one channel by default? Thanks a lot!
https://github.com/google/neural-tangents/blob/master/neural_tangents/stax.py#L3058
I guess the line assigning y_test_ntk should pass parameter get='ntk' instead of get='nngp':
y_test_ntk = predict_fn(x_test=x_test, get='nngp')
I think it should be:
y_test_ntk = predict_fn(x_test=x_test, get='ntk')
When trying to run step 13 in the Colab Cookbook examples (in Colab, not my own Jupyter instance), I hit an error while computing the diagonal of the NNGP kernel:
kernel = kernel_fn(test_xs, test_xs, 'nngp')
std_dev = np.sqrt(np.diag(kernel))
I get the following error:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-30-b771ce84abcc> in <module>()
----> 1 kernel = kernel_fn(test_xs, test_xs, 'nngp')
2 std_dev = np.sqrt(np.diag(kernel))
9 frames
/usr/local/lib/python3.6/dist-packages/neural_tangents/stax.py in _inputs_to_kernel(x1, x2, marginal, cross, compute_ntk)
284 'Use `NO` instead to compute all covariances.')
285
--> 286 x1 = x1.astype(xla_bridge.canonicalize_dtype(np.float64))
287 var1 = _get_variance(x1, marginal_type=marginal)
288
AttributeError: module 'jax.lib.xla_bridge' has no attribute 'canonicalize_dtype'
Looks like something is wrong with the JAX installation, but the pip install:
!pip install -q git+https://www.github.com/neural-tangents/neural-tangents
Seemed to run fine.
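Since the install step itself reported no error, a quick check is to print the versions that actually ended up installed. This is a minimal sketch using the standard library's `importlib.metadata`, which requires Python 3.8 or later (newer than the Python 3.6 shown in the traceback); the package names are assumed to match their PyPI distribution names, and `installed_version` is a hypothetical helper.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


for name in ("jax", "jaxlib", "neural-tangents"):
    print(name, installed_version(name))
```

Comparing the printed jax and jaxlib versions against the requirements in the project's setup.py may show whether a mismatch is behind the missing attribute.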
Hello!
I'm trying to run a simple model with the NTK kernel in Google Colab. Here is my code:
!pip install -q git+https://www.github.com/google/neural-tangents
import jax.numpy as jnp
from jax import random
from jax.experimental import optimizers
from jax.api import jit, grad, vmap
import jax
import functools
import neural_tangents as nt
from neural_tangents import stax
# ... making dataset cifar2 from cifar10
# >>> train_x.shape, train_y.shape, test_x.shape, test_y.shape
# ... ((300, 3072), (300, 1), (300, 3072), (300, 1))
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512, 1., 0.05),
    stax.Relu(),
    stax.Dense(512, 1., 0.05),
    stax.Relu(),
    stax.Dense(1, 1., 0.05),
    stax.Flatten()
)
key = random.PRNGKey(0)
_, params = init_fn(key, (-1, 3072))
apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnums=(2,))
k_train_train = kernel_fn(train_x, None, 'ntk')
k_test_train = kernel_fn(test_x, train_x, 'ntk')
predict_fn = nt.predict.gradient_descent_mse(k_train_train, train_y)
fx_train_0 = apply_fn(params, train_x)
fx_test_0 = apply_fn(params, test_x)
predict_fn(t=1.0, fx_train_0=fx_train_0, fx_test_0=fx_test_0, k_test_train=k_test_train) # error!
Last line gives me an error:
/usr/local/lib/python3.6/dist-packages/neural_tangents/predict.py in predict_fn(t, fx_train_0, fx_test_0, k_test_train)
246
247 # Finite time
--> 248 return get_predict_fn_finite()(t, fx_train_0, fx_test_0, k_test_train)
249
250 return predict_fn
/usr/local/lib/python3.6/dist-packages/neural_tangents/predict.py in get_predict_fn_finite()
161 @lru_cache(1)
162 def get_predict_fn_finite():
--> 163 with jax.core.eval_context():
164 expm1_fn, inv_expm1_fn = _get_fns_in_eigenbasis(
165 k_train_train,
AttributeError: module 'jax.core' has no attribute 'eval_context'
I'd be glad to find out what the problem is here, or whether I made some stupid mistake.
I know that MSE loss is probably not the best choice for a classification task, but shouldn't this piece of code work anyway?
Thanks for any help!