jaxgaussianprocesses / jaxutils
Model training utilities in JAX.
License: Apache License 2.0
Create a workflow to create a jaxutils-nightly build.
There are no test-workflow checks in place on PRs to this library.
Move the package's CI/CD to CircleCI.
It would be nice to have a train/test split, akin to scikit-learn's, for the Dataset.
from typing import Tuple

import jax.random as jr
from jax.random import KeyArray
from jaxutils import Dataset

# Need to define this function (further options elided):
def train_test_split(data: Dataset, key: KeyArray, test_size: float) -> Tuple[Dataset, Dataset]:
    ...
# Example usage:
data = Dataset(...)
key = jr.PRNGKey(42)
size = 0.3
train, test = train_test_split(data, key, size)
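A minimal sketch of how this could look, continuing from the imports above and assuming Dataset exposes X, y, and the number of observations n as in GPJax (the function itself is the proposal, not an existing API):

def train_test_split(data: Dataset, key: KeyArray, test_size: float = 0.2) -> Tuple[Dataset, Dataset]:
    # Shuffle the observation indices, then carve off a test fraction.
    indices = jr.permutation(key, data.n)
    n_test = int(data.n * test_size)
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    return (
        Dataset(X=data.X[train_idx], y=data.y[train_idx]),
        Dataset(X=data.X[test_idx], y=data.y[test_idx]),
    )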
Use versioneer for automatic package versioning.
Right now, a custom PyTree object is used. This should be replaced with Flax's PyTreeNode, which would then be consumed by the Parameters object to enable more robust parameter handling.
For some applications of GPs, like Bayesian optimisation, the dataset grows dynamically over time. Unfortunately, dynamic array sizes cause JAX jit-compiled functions to be recompiled for every different buffer size, which means the computation takes much longer than necessary...
In my own code I was able to work around the recompilation caused by dynamic shapes by using a fixed buffer and modifying the Gaussian-process logic with a dynamic mask that treats all data at index i > t as independent of indices j <= t in the kernel computation. One downside, of course, is that every iteration from t = 1, ..., n incurs time and memory complexity proportional to n. For most applications, however, the speed-up provided by jit makes this completely negligible.
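As a concrete illustration of that trick (masked_gram and rbf are illustrative names, not gpjax API): the Gram matrix is computed over the full fixed buffer, and entries involving indices beyond the current time step t are replaced with identity entries.

import jax
import jax.numpy as jnp

def rbf(a, b, lengthscale=1.0):
    # Plain RBF kernel on a pair of points; illustrative stand-in for gpx.RBF.
    return jnp.exp(-0.5 * jnp.sum((a - b) ** 2) / lengthscale**2)

def masked_gram(X_buffer, t, kernel_fn=rbf):
    # Gram matrix over a fixed-size buffer. Points with index > t are
    # treated as independent unit-variance entries (rows/columns of the
    # identity), so all shapes stay static and a jitted caller never
    # recompiles as t grows.
    n = X_buffer.shape[0]
    gram = jax.vmap(lambda a: jax.vmap(lambda b: kernel_fn(a, b))(X_buffer))(X_buffer)
    idx = jnp.arange(n)
    observed = (idx[:, None] <= t) & (idx[None, :] <= t)
    return jnp.where(observed, gram, jnp.eye(n))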
I am not sure whether a solution already exists within gpjax, as I'm still relatively new to this cool library :).
Describe Preferred Solution
I believe something like this can be implemented as follows, though I haven't yet tried it:
- Take gpx.Dataset and create a sub-class gpx.OnlineDataset(gpx.Dataset) with a new integer time_step variable, requiring the exact shapes of the data buffer at initialization (new observations could then be written into the fixed buffer with jax.ops).
- Add a DynamicKernel class that wraps the standard kernel computation K along the lines of K(a, b, a_idx, b_idx, t), returning K(a, b) if a_idx <= b_idx <= t and otherwise int(a_idx == b_idx); see the sketch below.
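A pointwise sketch matching that signature (hypothetical and untested; the condition is symmetrised to (a_idx <= t) & (b_idx <= t) so the resulting Gram matrix stays symmetric):

import jax.numpy as jnp

def dynamic_kernel(kernel_fn, a, b, a_idx, b_idx, t):
    # k(a, b) when both indices have been observed, identity entries otherwise.
    observed = (a_idx <= t) & (b_idx <= t)
    return jnp.where(observed, kernel_fn(a, b), (a_idx == b_idx).astype(jnp.float32))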
Describe Alternatives
NA
Related Code
An example of the jit recompilation, based on the documentation's regression notebook:
import gpjax as gpx
from jax import jit, random
from jax import numpy as jnp
n = 5
x = jnp.linspace(-1, 1, n)[..., None]
y = jnp.sin(x)  # x already has shape (n, 1), so no extra axis is needed
xtest = jnp.linspace(-2, 2, 100)[..., None]
@jit
def gp_predict(xs, x_train, y_train):
    posterior = gpx.Prior(kernel=gpx.RBF()) * gpx.Gaussian(num_datapoints=len(x_train))
    params, *_ = gpx.initialise(
        posterior, random.PRNGKey(0), kernel={"lengthscale": jnp.array([0.5])}
    ).unpack()
    post_predictive = posterior(params, gpx.Dataset(X=x_train, y=y_train))
    out_dist = post_predictive(xs)
    return out_dist.mean(), out_dist.stddev()
# First call - compile
print('compile')
for i in range(len(x)):
    %time gp_predict(xtest, x[:i+1], y[:i+1])
print()
# Second call - use cached
print('jitted')
for i in range(len(x)):
    %time gp_predict(xtest, x[:i+1], y[:i+1])
# Output
compile
CPU times: user 519 ms, sys: 1.64 ms, total: 521 ms
Wall time: 293 ms
CPU times: user 1.06 s, sys: 0 ns, total: 1.06 s
Wall time: 316 ms
CPU times: user 956 ms, sys: 17.9 ms, total: 974 ms
Wall time: 219 ms
jitted
CPU times: user 3.66 ms, sys: 443 µs, total: 4.11 ms
Wall time: 2.46 ms
CPU times: user 2.89 ms, sys: 348 µs, total: 3.23 ms
Wall time: 1.84 ms
CPU times: user 894 µs, sys: 0 ns, total: 894 µs
Wall time: 568 µs
Additional Context
Related issue on the JAX tracker: google/jax#2521
If the feature request is approved, would you be willing to submit a PR?
When I have time available, I can try to port my solution to the gpjax API, though I am still quite new to the library.
It would be nice to have a Scaler object that scales the inputs and/or outputs of a jaxutils.Dataset and saves the mean and variance, so that test inputs can be scaled later.
from jaxutils import PyTree

class Scaler(PyTree):
    # Calling the instance scales the data and "fits" the scale transform
    # on the first call.
    ...

train = jaxutils.Dataset(X=..., y=...)
test = jaxutils.Dataset(X=..., y=...)
scaler = Scaler(...)
scaled_train = scaler(train)  # learn the transform from the training data
scaled_test = scaler(test)    # scale the test data under the transform learnt on train
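A minimal sketch of how Scaler could work, assuming Dataset exposes X and y arrays. The fit-on-first-call statefulness simply mirrors the usage above; a pure fit/transform split would be more idiomatic in JAX:

from jaxutils import Dataset, PyTree

class Scaler(PyTree):
    # Hypothetical sketch, not an existing jaxutils API.
    def __init__(self):
        self.stats = None  # (mean_X, std_X, mean_y, std_y), fitted lazily

    def __call__(self, data: Dataset) -> Dataset:
        if self.stats is None:
            # First call: learn the transform from the (training) data.
            self.stats = (
                data.X.mean(0), data.X.std(0),
                data.y.mean(0), data.y.std(0),
            )
        mu_X, s_X, mu_y, s_y = self.stats
        return Dataset(X=(data.X - mu_X) / s_X, y=(data.y - mu_y) / s_y)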
I just stumbled across jaxutils and spotted jaxutils.PyTree. I can see that this is based on Distrax's Jittable base class.
I wanted to give a heads-up that Distrax's approach has some performance and compatibility issues, so I'd really recommend against using it.
Equinox has an equinox.Module which accomplishes the same thing (registering a class as a pytree) and also automatically handles a lot of edge cases (e.g. bound methods are pytrees too, multiple inheritance works smoothly, good performance, pretty-printing, etc.). I realise I am advertising my own library here... but hopefully it's of interest!
The __add__ method that concatenates two jaxutils.Datasets does not perform any checks on the datasets' batch shapes. Although a mismatch will eventually error inside jnp.concatenate, it would be nice to have a function that checks the shapes and gives users a clear error message.
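A sketch of what such a check might look like (helper name hypothetical), raised before jnp.concatenate is ever reached:

def _check_batch_shapes(lhs: Dataset, rhs: Dataset) -> None:
    # Compare everything except the leading (batch) axis.
    if lhs.X.shape[1:] != rhs.X.shape[1:] or lhs.y.shape[1:] != rhs.y.shape[1:]:
        raise ValueError(
            "Cannot concatenate Datasets with incompatible shapes: "
            f"X {lhs.X.shape} vs. {rhs.X.shape}, y {lhs.y.shape} vs. {rhs.y.shape}."
        )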
Adapt the progress bar from GPJax for general purposes and add it to a .scan submodule. Add a compilation message to let users know the code is still compiling.
Thanks to @fonnesbeck.
Add a check to jaxutils.Dataset to ensure an error is raised when the arguments X and y are not 2D.
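A sketch of the check (helper name hypothetical), to be applied to both arrays at construction time:

def _check_2d(name: str, array) -> None:
    if array.ndim != 2:
        raise ValueError(
            f"{name} must be two-dimensional, got shape {array.shape}. "
            "A trailing axis can be added with, e.g., x[:, None]."
        )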
When handling complex dictionary structures, Benedict makes tasks such as indexing easier. We should transition all dictionaries to use Benedict.
A clear place where this will be helpful is in selecting and updating individual parameters. In Benedict, the syntax could be as simple as
def update_param_value(self, key: str, value: jax.Array) -> None:
    self.params[key] = value
With Benedict, the key could simply be kernel.lengthscale, whereas with regular nested dictionaries one would have to write something more complex to index with ['kernel']['lengthscale'].
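For illustration, keypath indexing with python-benedict (assuming the params dict is wrapped in a benedict):

from benedict import benedict

params = benedict({"kernel": {"lengthscale": 0.5, "variance": 1.0}})
params["kernel.lengthscale"] = 1.0             # keypath access with a single string
assert params["kernel"]["lengthscale"] == 1.0  # plain nested access still works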