ilabcode / pyhgf

PyHGF: A neural network library for predictive coding

Home Page: https://ilabcode.github.io/pyhgf/

License: GNU General Public License v3.0

Python 100.00%
bayesian-inference computational-psychiatry reinforcement-learning active-inference bayesian-filter predictive-coding hierarchical-gaussian-filter belief-propagation jax graph-neural-networks

pyhgf's People

Contributors

chmathys, legrandnico, lilianaweber, mkhm

pyhgf's Issues

Examples and tutorials

We need the following examples and tutorials for v0.0.1:

  • Model comparison
  • Multilevel models
  • Parameter recovery
  • Parameter recovery

prediction steps should index their own node

Currently, prediction and prediction-error update functions are indexed by a pointer node but update the parameters of its value and volatility parents. This is probably the best solution for prediction errors, but prediction steps should be indexed directly to the node they update.

This will require changes in the prediction propagation function, but it will simplify the update list and the codebase, as we will no longer need a specific prediction function for input nodes.

ensure network consistency for improper use of add_input_node

Using the add_input_node method can currently produce an invalid network structure (e.g. when setting two inputs with the same index). The method should first check whether a node with the same index already exists (in the attributes dictionary) and raise an error if it does, so that nodes are never silently overwritten.
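The proposed check can be sketched as below. This is a minimal illustration assuming a hypothetical Network class whose attributes dictionary maps node indices to parameter dictionaries; it is not the pyhgf class, and the parameter name used in the example is illustrative only.

```python
from typing import Any, Dict


class Network:
    """Hypothetical minimal network container (not the pyhgf API)."""

    def __init__(self) -> None:
        # Maps node index -> dictionary of node parameters.
        self.attributes: Dict[int, Dict[str, Any]] = {}

    def add_input_node(self, input_idx: int = 0, **params: Any) -> "Network":
        # Refuse to silently overwrite an existing node with the same index.
        if input_idx in self.attributes:
            raise ValueError(
                f"A node with index {input_idx} already exists in the network."
            )
        self.attributes[input_idx] = dict(params)
        return self


# Usage: the second call with the same index raises instead of overwriting.
net = Network().add_input_node(input_idx=0, expected_precision=1.0)
try:
    net.add_input_node(input_idx=0)
except ValueError as err:
    print(err)
```

Returning self keeps the method chainable, so several nodes can be declared in one expression while each declaration is still validated.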

use a switch condition for belief propagation

The belief propagation function consists of a for loop, which might be suboptimal when caching the function. It could instead be written as a switch in which the relevant update functions are indexed.

remove nu from the nodes' attributes

We do not need to store nu at every time step. This variable should be dropped from the default node parameters and should not be returned by the update function, as we never need to access its previous value.

This should be done once #95 is merged.

Drag and drop HGF

Create a graphical interface to edit node structures manually. Ideally, this should produce an object that can be used as input to both the Python and the Julia HGFs.

Add more decision models in the response sub-module

We need to expand our collection of decision models to get closer to what is available in the Matlab toolbox. We are currently missing:

  • Sigmoid response model with a temperature parameter.
  • Linear regression
  • Noisy state
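The first missing model could be prototyped as follows. As an assumption, this sketch uses the unit-square sigmoid, in which an inverse temperature zeta sharpens or flattens the predicted probability of choosing 1; the function names are hypothetical and this is not the pyhgf response API.

```python
import math
from typing import Sequence


def unit_square_sigmoid(mu_hat: float, zeta: float) -> float:
    """Probability of responding 1 given the predicted probability mu_hat.

    zeta acts as an inverse temperature: zeta = 1 returns mu_hat unchanged,
    larger values make responses more deterministic, smaller values flatter.
    mu_hat is assumed to lie strictly between 0 and 1.
    """
    a = mu_hat ** zeta
    b = (1.0 - mu_hat) ** zeta
    return a / (a + b)


def sigmoid_surprise(
    responses: Sequence[int], mu_hats: Sequence[float], zeta: float
) -> float:
    """Summed surprise (negative log-likelihood) of binary responses."""
    total = 0.0
    for y, m in zip(responses, mu_hats):
        p = unit_square_sigmoid(m, zeta)
        total -= math.log(p if y == 1 else 1.0 - p)
    return total
```

Returning a summed surprise makes the function directly usable as the quantity to minimise (or as a log-likelihood term) when fitting zeta.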

cleaning the update function for binary input prediction errors

The update function for binary input prediction errors probably needs some cleaning:

  • currently, the expectation is computed in the prediction error step, which is probably redundant.
  • the update function does not support multiple binary inputs as children of a single binary node (it is unclear whether this would be useful).

Use scan to broadcast model fitting along the data axis

This for loop slows down model fitting when many datasets/models are provided.

https://github.com/ilabcode/ghgf/blob/8fec5d71e1677e85103ae9dd79dda26056b5b06f/ghgf/distribution.py#L167

We should use scan (or a similar JAX loop primitive) instead, but this might require careful refactoring.

It is not clear, however, whether this refactoring would bring a significant performance improvement; if it does, the gain might only appear when n is large. It might also require harmonizing the shapes of the input arrays, which would come with its own cost.

add more tutorials

The following tutorials need to be expanded and rewritten:

  • Multilevel embedding of Hierarchical Gaussian Filters
  • Parameter recovery, prior predictive and posterior predictive sampling
  • The cardiac filtering HGF could include an example of seasonality decomposition through value coupling

We need new tutorials on the following topics:

  • Binary HGF with finite precision
  • The multivariate normal HGF (once #144 is completed)

Use dictionaries to create node structures

To support updates with multiple parents/children, we will have to refactor how nodes are declared and typed, moving from nested tuples to a more flexible dictionary structure in which nodes refer to each other by index instead of containing their parents. This will also make the code more readable and facilitate the creation of a network API in the future.
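A minimal sketch of the proposed layout, assuming nodes are integers, parameters live in an attributes dictionary, and connections are stored as lists of indices in an edges dictionary. The key names and the helper functions are hypothetical. Node 3 below is a volatility parent shared by two children, which is awkward to express with nested tuples.

```python
from typing import Dict, List

EdgeMap = Dict[int, Dict[str, List[int]]]

EDGE_KINDS = (
    "value_parents",
    "volatility_parents",
    "value_children",
    "volatility_children",
)


def empty_edges(n_nodes: int) -> EdgeMap:
    """One dictionary of index lists per node (hypothetical layout)."""
    return {idx: {kind: [] for kind in EDGE_KINDS} for idx in range(n_nodes)}


def connect(edges: EdgeMap, child: int, parent: int, coupling: str) -> None:
    """Register `parent` as a value or volatility parent of `child`, both ways."""
    if coupling not in ("value", "volatility"):
        raise ValueError("coupling must be 'value' or 'volatility'")
    edges[child][f"{coupling}_parents"].append(parent)
    edges[parent][f"{coupling}_children"].append(child)


# Node parameters are stored separately, keyed by the same indices.
attributes = {
    0: {"mean": 0.0, "precision": 1.0},  # input node
    1: {"mean": 0.0, "precision": 1.0},  # continuous state node
    2: {"mean": 0.0, "precision": 1.0},  # second continuous state node
    3: {"mean": 1.0, "precision": 1.0},  # volatility parent shared by 1 and 2
}
edges = empty_edges(len(attributes))
connect(edges, child=0, parent=1, coupling="value")
connect(edges, child=1, parent=3, coupling="volatility")
connect(edges, child=2, parent=3, coupling="volatility")
```

Storing both directions (parents and children) costs a little redundancy but lets prediction steps walk up the graph and prediction-error steps walk down it without searching.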

Remove the old Python code from the package

Once v0.0.1 is released, we should remove the old Python version of the HGF that is not based on JAX, as it will no longer be actively used. This also applies to the documentation. The code will remain available in the v0.0.1 release.

Plotting prior and posterior predictive sampling

Sampling the HGF could mean different things. For prior predictive sampling, we could try to implement:

  • Predictive sampling from the generative model using the probabilistic graph (starting from the leaves up to the roots). This would require writing sampling update steps and a sampling update sequence. The function only requires a number of time steps as input argument.
  • Prior and posterior predictive sampling from a distribution over a parameter given the input data. In that case, we just run the model forward using different parameters.

Both samplings should interface with a dedicated plotting function to visualize trajectories.
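To illustrate the first kind of sampling, the sketch below draws trajectories from a two-level continuous HGF generative model, assuming the usual Gaussian random walks in which the child's step variance is exp(kappa * x2 + omega_1) and the parent's is exp(omega_2). The function name and default values are hypothetical, and this is a hand-written loop rather than the sampling update sequence proposed above.

```python
import math
import random
from typing import List, Tuple


def sample_two_level_hgf(
    n_steps: int,
    omega_1: float = -4.0,
    omega_2: float = -6.0,
    kappa: float = 1.0,
    seed: int = 0,
) -> Tuple[List[float], List[float]]:
    """Sample trajectories from a two-level continuous HGF generative model.

    At each step the volatility parent x2 is sampled first, then the child x1,
    whose step size depends on the current value of x2.
    """
    rng = random.Random(seed)
    x1, x2 = 0.0, 0.0
    x1_traj: List[float] = []
    x2_traj: List[float] = []
    for _ in range(n_steps):
        # Volatility parent: Gaussian random walk with variance exp(omega_2).
        x2 = rng.gauss(x2, math.sqrt(math.exp(omega_2)))
        # Child: Gaussian random walk with variance exp(kappa * x2 + omega_1).
        x1 = rng.gauss(x1, math.sqrt(math.exp(kappa * x2 + omega_1)))
        x1_traj.append(x1)
        x2_traj.append(x2)
    return x1_traj, x2_traj
```

Only n_steps is required, as suggested in the first bullet; the seed makes each sampled trajectory reproducible, which helps when comparing plots.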

Implement value coupling weights

Value coupling should come with coupling weights (alpha), in the same way that volatility coupling has volatility weights (kappa). These values are node parameters.

This should also come with a proper test suite for complex node structures in which a single node can have many parents or many children.
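A sketch of how the weights could enter the prediction step, assuming a linear coupling in which a node's predicted mean is its previous mean plus the alpha-weighted sum of its value parents' means. The dictionaries and the function name are hypothetical, with one weight stored per parent in the same order as the parent list.

```python
from typing import Dict, List


def predict_value_mean(
    node_idx: int,
    means: Dict[int, float],
    value_parents: Dict[int, List[int]],
    value_coupling: Dict[int, List[float]],
) -> float:
    """Predicted mean of a node given its value parents and weights (alpha)."""
    parents = value_parents.get(node_idx, [])
    alphas = value_coupling.get(node_idx, [])
    if len(parents) != len(alphas):
        raise ValueError("Each value parent needs exactly one coupling weight.")
    # Linear coupling: every parent pushes the child by alpha * parent mean.
    drift = sum(alpha * means[p] for p, alpha in zip(parents, alphas))
    return means[node_idx] + drift


# Usage: node 1 has two value parents (2 and 3) with different weights.
means = {1: 0.5, 2: 2.0, 3: -1.0}
prediction = predict_value_mean(1, means, {1: [2, 3]}, {1: [1.0, 0.5]})
```

Keeping the weights aligned with the parent list makes multi-parent structures testable node by node, which fits the test suite requested above.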

Implement AR1 processes

Both the JAX and the pure Python versions refer to GRW and AR1 processes, but currently only GRW is implemented.

We need to:

  • Implement AR1 nodes in the JAX functions. The process type is a node parameter, so a graph can mix AR1 nodes with GRW nodes.
  • Add a section to the Theory section describing the difference between the two processes.
  • Document the functions accordingly, and remove the references to GRW and AR1 parameters that can be found in some docstrings.
  • Clean up the references to AR1 processes in the pure Python code.
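The difference between the two processes can be sketched at the level of the mean prediction. Assuming the usual AR1 parametrisation with an attractor m and a rate phi, the GRW keeps the previous mean while the AR1 moves it a fraction phi of the way toward m. Passing the process type as a string per node is a hypothetical choice that mirrors the "process type is a node parameter" item above.

```python
def predict_state_mean(
    previous_mean: float, process: str = "GRW", phi: float = 0.0, m: float = 0.0
) -> float:
    """Predicted mean of a state node under a GRW or an AR1 process."""
    if process == "GRW":
        # Gaussian random walk: the best prediction is the previous mean.
        return previous_mean
    if process == "AR1":
        # AR1: move a fraction phi of the way toward the attractor m.
        return previous_mean + phi * (m - previous_mean)
    raise ValueError(f"Unknown process type: {process!r}")
```

With phi = 0 the AR1 prediction reduces to the GRW one, so a graph mixing both node types can share a single code path.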

Vectorize the edges and attributes of the networks

The data format we currently use to represent attributes and edges is suboptimal with respect to JAX transformations. In JAX, the structure of a PyTree is fixed at compile time and cannot be indexed with Tracers, whose values are only known at run time. For this reason, update functions need to set the node_idx and edges variables as static arguments, which compiles and caches a new function for each node separately. This loses the advantages of the modular implementation, and large models would clearly benefit from a single cached update function.

The solution I see would be:

  • [x] Use a dictionary of arrays to store the edges using a connectivity matrix representation.
  • [x] Use a dictionary of arrays for each node parameter.

Update: it is currently (very) difficult to write readable update functions that pass messages across a variable number of nodes without something like dynamic shapes, which are under development in JAX but not yet available. Until such a feature exists, it seems unreasonable to move the code to this implementation. We have a working example for the two-level binary HGF, and its total execution time is longer than with the default implementation, so it is unclear whether we would benefit from this, apart from compilation time.
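The two checked items describe a data layout, sketched below with plain Python lists standing in for arrays. It assumes nodes are numbered 0 to n-1 and share the same parameter names; the helper names are hypothetical. In a JAX version these lists would become arrays, which can be indexed with a traced node index, unlike a PyTree of per-node dictionaries.

```python
from typing import Dict, List


def edges_to_matrix(
    value_parents: Dict[int, List[int]], n_nodes: int
) -> List[List[int]]:
    """Connectivity matrix: matrix[child][parent] == 1 for a value coupling."""
    matrix = [[0] * n_nodes for _ in range(n_nodes)]
    for child, parents in value_parents.items():
        for parent in parents:
            matrix[child][parent] = 1
    return matrix


def attributes_to_arrays(
    attributes: Dict[int, Dict[str, float]]
) -> Dict[str, List[float]]:
    """One array per node parameter, where position i holds node i's value."""
    n_nodes = len(attributes)
    names = sorted(attributes[0])
    return {name: [attributes[i][name] for i in range(n_nodes)] for name in names}
```

Reading a parameter of node i then becomes arrays[name][i], and finding the value parents of node i becomes reading row i of the matrix, both of which have fixed shapes.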

add a plotting function for networks using NetworkX

The plot_network function currently uses GraphViz; however, it would be convenient to have something more Matplotlib-compatible, and NetworkX seems to be the go-to library. This would also let us convert our networks into the NetworkX format, which provides many graph metrics and can be further exported to PyVis, which would also be nice to support.

The first step would be to convert the plot_network function into a NetworkX equivalent.

Type checking in CI

mypy is currently not passing. Type-checking errors should be fixed before v0.02.

add modularity in the update submodules

The update functions are currently implemented as calls to rather long functions for update/prediction error and for prediction. They should instead be built on top of small modular updates that can be fully jitted, which should reduce compilation time and the size of the compiled functions.

allow sampling the additional parameters of a response function

The current PyMC distribution only wraps parameters that are passed to the main probabilistic network, not those passed to the response model. The following features need to be added to the API for creating response models:

  • Sampling the value of the parameters that are passed to the response model. This requires updating the distribution and setting a format that can work with any data format in the response function (a 1d array will probably work).
  • Updating the tutorials to explain hierarchical modelling with different response functions (no API changes should be needed).
  • Ideally, add an example in the tutorial of a response function with a temperature parameter that would illustrate all these points.
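A sketch of the 1d-array idea combined with a temperature parameter. Named response parameters are flattened into a list (sorted by name) so a sampler only needs to handle one vector, and the response function rebuilds the names internally. The helper names are hypothetical, and the response model is assumed here to be a two-option softmax over the predicted probabilities m and 1 - m, divided by a temperature.

```python
import math
from typing import Dict, List, Sequence, Tuple


def pack_parameters(params: Dict[str, float]) -> Tuple[List[float], List[str]]:
    """Flatten named response parameters into a 1-D list (sorted by name)."""
    names = sorted(params)
    return [float(params[name]) for name in names], names


def unpack_parameters(values: Sequence[float], names: Sequence[str]) -> Dict[str, float]:
    """Rebuild the named parameters inside the response function."""
    if len(values) != len(names):
        raise ValueError("Expected one value per parameter name.")
    return dict(zip(names, values))


def binary_softmax_surprise(
    responses: Sequence[int],
    predicted_probs: Sequence[float],
    response_values: Sequence[float],
    response_names: Sequence[str],
) -> float:
    """Surprise of binary choices under a softmax with a temperature parameter."""
    params = unpack_parameters(response_values, response_names)
    temperature = params["temperature"]  # other parameters are ignored here
    surprise = 0.0
    for y, m in zip(responses, predicted_probs):
        # Two-option softmax with values m and 1 - m, i.e. a logistic
        # function of (2m - 1) / temperature.
        p_one = 1.0 / (1.0 + math.exp(-(2.0 * m - 1.0) / temperature))
        surprise -= math.log(p_one if y == 1 else 1.0 - p_one)
    return surprise


# Usage: the sampler only sees `values`; `names` travels with the model.
values, names = pack_parameters({"temperature": 0.5})
surprise = binary_softmax_surprise([1, 0, 1], [0.8, 0.3, 0.6], values, names)
```

Sorting the names gives a deterministic order, so the same flat vector always maps back to the same parameters.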
