Comments (3)
```python
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An example doing inference with an infinitely wide fully-connected network.

By default, this example does inference on a small MNIST subset.
"""

import time

from absl import app
from absl import flags
from jax import grad
import jax.numpy as np
import eagerpy as ep
import neural_tangents as nt
from neural_tangents import stax

from examples import datasets
from examples import util

flags.DEFINE_integer('train_size', 10000,
                     'Dataset size to use for training.')
flags.DEFINE_integer('test_size', 1000,
                     'Dataset size to use for testing.')
flags.DEFINE_integer('batch_size', 0,
                     'Batch size for kernel computation. 0 for no batching.')

FLAGS = flags.FLAGS


def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = \
      datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

  # Build the infinite network.
  _, _, kernel_fn = stax.serial(
      stax.Dense(1, 2., 0.05),
      stax.Relu(),
      stax.Dense(1, 2., 0.05))

  # Optionally, compute the kernel in batches, in parallel.
  kernel_fn = nt.batch(kernel_fn,
                       device_count=0,
                       batch_size=FLAGS.batch_size)

  start = time.time()
  # Bayesian and infinite-time gradient descent inference with infinite network.
  predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
                                                        y_train, diag_reg=1e-3)
  fx_test_nngp, fx_test_ntk = predict_fn(x_test=x_test)
  fx_test_nngp.block_until_ready()
  fx_test_ntk.block_until_ready()
  duration = time.time() - start
  print('Kernel construction and inference done in %s seconds.' % duration)

  # Print out accuracy and loss for infinite network predictions.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
  util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
  util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)

  def MSELoss(x_test):
    return 0.5 * ((predict_fn(x_test=x_test, get='ntk') - y_test) ** 2).mean()

  def Norm(x_test):
    return (x_test ** 2).mean()

  # Plain jax.grad differentiates MSELoss without trouble.
  print(grad(MSELoss)(x_test).shape)
  print(x_test.shape)

  print(type(x_test))
  x_test = np.array(x_test)
  print(type(x_test))
  x_test = ep.astensor(x_test)
  print(type(x_test))

  loss, g = ep.value_and_grad(MSELoss, x_test)  # Error!
  loss, g = ep.value_and_grad(Norm, x_test)     # Works.
  print(g.shape)


if __name__ == '__main__':
  app.run(main)
```
As can be seen from the code above, eagerpy works well with JAX for pure functions, but it breaks as soon as predict_fn is involved.
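For reference, here is a minimal pure-JAX sketch (with made-up arrays standing in for the NTK predictor, not the real predict_fn) confirming that jax.grad itself differentiates through a closure of this shape without issue, so the failure appears specific to the eagerpy wrapping:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for predict_fn: a closure over fixed arrays,
# mimicking predict_fn(x_test=x, get='ntk') but with toy values.
w = jnp.array([[1.0, -2.0], [0.5, 3.0]])
y_test = jnp.array([[0.0, 1.0]])

def predict(x):
  return x @ w

def mse_loss(x):
  # Same structure as MSELoss above, but over the toy predictor.
  return 0.5 * jnp.mean((predict(x) - y_test) ** 2)

x = jnp.array([[1.0, 2.0]])
g = jax.grad(mse_loss)(x)  # jax.grad traces through the closure fine
print(g.shape)             # (1, 2)
```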
Thanks for reporting this. Could you add syntax highlighting to your code and share the exact error message? Could you also try value_and_grad_fn instead of value_and_grad? I think we should be able to fix this once we know what the exact error message is.