
ramsey's Introduction

Ramsey


Probabilistic deep learning using JAX

About

Ramsey is a library for probabilistic modelling using JAX, Flax and NumPyro. It offers high-quality implementations of neural processes, Gaussian processes, Bayesian time series and state-space models, clustering processes, and everything else Bayesian.

Ramsey makes use of

  • Flax's module system for models with trainable parameters (such as neural or Gaussian processes),
  • NumPyro for models where parameters are endowed with prior distributions (such as Gaussian processes, Bayesian neural networks, ARMA models),

and hence aims to be fully compatible with both.

Example usage

You can, for instance, construct a simple neural process like this:

from jax import random as jr

from ramsey import NP
from ramsey.nn import MLP
from ramsey.data import sample_from_sine_function

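def get_neural_process():
    dim = 128
    # the decoder outputs 2 values per target point (presumably mean and scale);
    # the latent encoder is given as a pair of MLPs
    model = NP(
        decoder=MLP([dim] * 3 + [2]),
        latent_encoder=(
            MLP([dim] * 3), MLP([dim, dim * 2])
        )
    )
    return model

key = jr.PRNGKey(23)
data = sample_from_sine_function(key)

neural_process = get_neural_process()
# initialise the parameters as for any other Flax module
params = neural_process.init(key, x_context=data.x, y_context=data.y, x_target=data.x)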

The neural process takes a decoder and a latent encoder as arguments. The latent encoder is given as a pair of networks. All of these are typically MLPs, but Ramsey is flexible enough that you can replace them with, for instance, CNNs or RNNs. Once the model is defined, you can initialize its parameters just as you would for any other Flax module.
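The following pure-JAX sketch shows the latent path of a neural process, to clarify what the two latent-encoder networks do. The first network is applied to every (x, y) context pair. The per-point outputs are then averaged, so the result does not depend on the order of the context set. Finally, the second network's output is split into the mean and scale of a Gaussian latent variable. This is not Ramsey's implementation. The helpers `init_mlp`, `mlp` and `latent_distribution` are assumptions made for illustration. So are the ReLU activations, the linear output layers, and the softplus that keeps the scale positive. Only the layer sizes mirror the example above.

```python
import jax
import jax.numpy as jnp
from jax import random as jr


def init_mlp(key, sizes):
    """Hypothetical helper: random weights for a plain MLP with the given layer sizes."""
    keys = jr.split(key, len(sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        w = jr.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


def mlp(params, x):
    """Apply the MLP: ReLU on hidden layers, linear output layer (an assumption)."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def latent_distribution(enc1, enc2, x_context, y_context):
    """Hypothetical helper: encode a context set into the mean and scale of a Gaussian latent."""
    xy = jnp.concatenate([x_context, y_context], axis=-1)  # (n, x_dim + y_dim)
    r = mlp(enc1, xy)                                     # one representation per context point
    r = jnp.mean(r, axis=0)                               # order-invariant aggregation
    out = mlp(enc2, r)                                    # (2 * latent_dim,)
    mean, raw_scale = jnp.split(out, 2, axis=-1)
    return mean, jax.nn.softplus(raw_scale) + 1e-3        # keep the scale strictly positive


dim = 128
k1, k2, k3 = jr.split(jr.PRNGKey(23), 3)
x = jnp.linspace(-1.0, 1.0, 10).reshape(-1, 1)
y = jnp.sin(3.0 * x)
enc1 = init_mlp(k1, [2] + [dim] * 3)      # mirrors MLP([dim] * 3) applied to (x, y) pairs
enc2 = init_mlp(k2, [dim, dim, dim * 2])  # mirrors MLP([dim, dim * 2])
mean, scale = latent_distribution(enc1, enc2, x, y)
z = mean + scale * jr.normal(k3, mean.shape)  # reparameterised sample of the latent
print(mean.shape, scale.shape, z.shape)  # (128,) (128,) (128,)
```

Sampling `z` as `mean + scale * noise` (the reparameterisation trick) keeps the sample differentiable with respect to the encoder weights. This is what allows a latent path like this one to be trained by gradient descent.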

Installation

To install from PyPI, call:

pip install ramsey

To install the latest GitHub release, call the following on the command line, replacing <RELEASE> with the release tag:

pip install git+https://github.com/ramsey-devs/ramsey@<RELEASE>

See also the JAX installation instructions if you plan to use Ramsey on GPU or TPU.
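After installing a GPU- or TPU-enabled build of JAX, you can run a quick sanity check by asking JAX which backend and devices it found. This small sketch is independent of Ramsey:

```python
import jax

# Platform JAX uses by default, e.g. "cpu", "gpu" or "tpu".
backend = jax.default_backend()
# Devices visible to that backend; a CPU-only install typically reports a single CPU device.
devices = jax.devices()
print(backend, devices)
```

If this prints `cpu` on a machine that has an accelerator, the accelerator-specific JAX package is most likely not installed.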

Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled "good first issue".

In order to contribute:

  1. Install Ramsey and its dev dependencies via pip install -e '.[dev]'.
  2. Test your contribution by calling tox on the (Unix) command line before submitting a PR.

Why Ramsey

Just as the names of other probabilistic languages are inspired by researchers in the field (e.g., Stan, Edward, Turing), Ramsey takes its name from one of my favourite philosophers/mathematicians, Frank Ramsey.

ramsey's People

Contributors

dirmeier, mamarder


ramsey's Issues

Use TFDS iterators

The current data iterators are slow. We could use the tf.data iterators.

Does that work with dynamic input/output dimensions?

Fix code coverage

Code coverage does not update properly (wrong link in README) and does not show coverage reports on PRs.

High-level API

Add a high-level API that provides methods to easily train and predict with every model, following GPflow and similar libraries.

How to dispatch?

Not sure which is better:

def __call__(self, method="predict", **kwargs):
    return getattr(self, method)(**kwargs)

or dispatch based on which arguments are passed:

def __call__(self, x, **kwargs):
    if "y" in kwargs:
        return self._this_method(x, kwargs["y"])
    return self._that_method(x)

Check what Haiku/Flax/... recommend.
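Both options can be sketched on a toy class. Everything in the sketch is hypothetical and not part of Ramsey, including `Model`, `fit`, `predict` and `call_by_name`. For comparison, Flax's `Module.apply` already accepts a `method` argument that selects which method to run, which is close in spirit to the first option.

```python
class Model:
    """Hypothetical toy model illustrating two ways to dispatch from a single entry point."""

    def __init__(self, slope):
        self.slope = slope

    def fit(self, x, y):
        # toy "training": least-squares slope of a line through the origin
        self.slope = sum(a * b for a, b in zip(x, y)) / sum(a * a for a in x)
        return self.slope

    def predict(self, x):
        return [self.slope * a for a in x]

    # Option 1: the caller names the method explicitly
    def call_by_name(self, *args, method="predict", **kwargs):
        return getattr(self, method)(*args, **kwargs)

    # Option 2: infer the method from the arguments that were passed
    def __call__(self, x, **kwargs):
        if "y" in kwargs:
            return self.fit(x, kwargs["y"])
        return self.predict(x)


model = Model(slope=0.0)
model([1.0, 2.0], y=[2.0, 4.0])                     # y given, so this dispatches to fit
print(model([3.0]))                                  # [6.0]
print(model.call_by_name([1.0], method="predict"))   # [2.0]
```

The name-based variant makes the choice explicit at the call site. The argument-based variant is shorter to call, but its behaviour silently changes when an extra keyword such as `y` is passed.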

As batch iterators

This is not good. Rewrite like this and document the fn:

from jax import random

# assumes rng_key, data (a namedtuple of arrays), its constructor ctor,
# n (number of rows), n_train and shuffle are already defined
shuffle_key, rng_key = random.split(rng_key)
if shuffle:
    # a random permutation of all n row indices
    shuffle_idxs = random.permutation(shuffle_key, n)
    data = ctor(*[el[shuffle_idxs] for el in data])

y_train = ctor(*[el[:n_train] for el in data])
y_val = ctor(*[el[n_train:] for el in data])
train_rng_key, val_rng_key = random.split(rng_key)
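Since the issue is about batch iterators, here is a minimal sketch of a generator that yields shuffled minibatches from a namedtuple of equally long arrays. It reuses the `ctor(*[...])` pattern from the snippet above. The names `Data` and `iterate_batches` are hypothetical, not existing Ramsey functions.

```python
from collections import namedtuple

import jax.numpy as jnp
from jax import random

# hypothetical container; Ramsey's own data containers may differ
Data = namedtuple("Data", ["x", "y"])


def iterate_batches(rng_key, data, batch_size, shuffle=True):
    """Yield minibatches from a namedtuple of equally long arrays.

    If shuffle is True, the rows are permuted once with rng_key before batching.
    The final batch may be smaller than batch_size.
    """
    ctor = type(data)
    n = data[0].shape[0]
    idxs = random.permutation(rng_key, n) if shuffle else jnp.arange(n)
    for start in range(0, n, batch_size):
        batch_idxs = idxs[start:start + batch_size]
        yield ctor(*[el[batch_idxs] for el in data])


x = jnp.arange(10.0).reshape(-1, 1)
data = Data(x=x, y=2.0 * x)
batches = list(iterate_batches(random.PRNGKey(0), data, batch_size=4))
print([b.x.shape[0] for b in batches])  # [4, 4, 2]
```

Drawing a fresh key for each epoch (for example with `random.split`) gives a different row order every epoch.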

Cleanup

  • fix codecov
  • fix documentation not showing
  • replace all old links (one in index.rst to dirmeier/...)
  • beautify code
  • gitlint config, pylintrc, jupytext config, mypy.ini, pytest.ini, setup.cfg, MANIFEST, Makefile, pydocstyle, bandit.yaml
  • code of conduct, contributing guidelines
  • codecov/coveragerc
