Neural Processes

Replication of the Conditional Neural Processes paper by Marta Garnelo et al., 2018 [arXiv] and the Neural Processes paper by Marta Garnelo et al., 2018 [arXiv]. The models and data pipeline were implemented using TensorFlow 2.10.

Code released as a companion to the report and the poster.

Project for the Advanced Machine Learning module, MPhil in MLMI, University of Cambridge.

William Baker, Alexandra Shaw, Tony Wu

1. Introduction

While neural networks excel at function approximation, Gaussian Processes (GPs) address different challenges, such as uncertainty estimation, continual learning, and coping with data scarcity. Each model is therefore suited only to a restricted spectrum of tasks, one that strongly depends on the nature of the available data.

Neural Processes (NPs) use neural networks to encode distributions over functions, approximating the distributions over functions given by stochastic processes such as GPs. This allows for efficient inference and scalability to large datasets. The performance of these models is evaluated on 1D regression and image completion to demonstrate visually how they learn distributions over complex functions.
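Concretely, in the CNP formulation (notation sketched here from the paper), an encoder $h_\theta$ embeds each context pair, the embeddings are aggregated into a single representation $r_C$ (a mean in practice), and the predictive distribution factorises over the targets:

$$
r_C = \frac{1}{|C|} \sum_{(x_i, y_i) \in C} h_\theta(x_i, y_i),
\qquad
p(y_T \mid x_T, C) = \prod_{t \in T} \mathcal{N}\!\left(y_t \mid \mu_\theta(x_t, r_C),\, \sigma_\theta^2(x_t, r_C)\right)
$$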

Figure 1: Comparison between a Gaussian Process, a Neural Network, and a Neural Process

2. Instructions

  1. Using an environment with Python 3.10.8, install the required modules with:

    pip install -r requirements.txt
    
  2. To create, train, and evaluate instances of neural processes, run the train.py script. Use python train.py --help to display its arguments. In particular, set the --model flag to CNP, HNP, LNP, or HNPC to select the model. Example:

    python train.py --task regression --model cnp --epochs 50 --batch 128
  3. The model will be saved in the checkpoints directory.

3. Data pipeline

Figure 2: Data pipeline and examples of generated data for Neural Processes

Unlike neural networks, which predict functions, NPs predict distributions over functions. For this reason, we built a dedicated data loader class using the tf.data API to produce the examples for both training and validation. The class definitions for the data generators can be found in the dataloader module directory.
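As an illustration of the idea (a minimal sketch, not the repository's dataloader code; all names here are made up), a pipeline in this spirit samples curves from a GP prior and splits each batch into context and target sets:

```python
import tensorflow as tf

def sample_gp_batch(batch_size=64, num_points=100, length_scale=0.4):
    """Sample a batch of curves from a GP prior with an RBF kernel."""
    x = tf.random.uniform([batch_size, num_points, 1], -2.0, 2.0)
    diff = x - tf.transpose(x, [0, 2, 1])                      # [B, N, N]
    kernel = tf.exp(-0.5 * tf.square(diff) / length_scale**2)  # RBF kernel
    kernel += 1e-4 * tf.eye(num_points, batch_shape=[batch_size])
    # y ~ N(0, K), drawn via the Cholesky factor of the kernel matrix
    y = tf.linalg.cholesky(kernel) @ tf.random.normal([batch_size, num_points, 1])
    return x, y

def split_context_target(x, y, max_context=10):
    """Randomly choose a number of context points; targets are all points."""
    num_context = tf.random.uniform([], 3, max_context + 1, dtype=tf.int32)
    return (x[:, :num_context], y[:, :num_context]), (x, y)

# An infinite tf.data stream of freshly sampled (context, target) examples.
dataset = tf.data.Dataset.from_tensors(0).repeat().map(
    lambda _: split_context_target(*sample_gp_batch()))
```

Because the x locations are sampled uniformly at random, taking the first few points of each curve as context is itself a random subset.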

4. Models

Figure 3: Architecture diagram of CNP, LNP, and HNP

CNP, LNP, and HNP all share a similar encoder-decoder architecture. They are implemented as classes that inherit from tf.keras.Model, so training with the tf.data API is straightforward and optimized.
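For concreteness, a minimal CNP-style model in this pattern might look like the following (an illustrative sketch, not the repository's classes; layer sizes are arbitrary):

```python
import tensorflow as tf

class MinimalCNP(tf.keras.Model):
    """Illustrative CNP-style encoder-decoder; not the repository's code."""

    def __init__(self, hidden=128):
        super().__init__()
        # Encoder: embeds each (x, y) context pair independently.
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.Dense(hidden, activation="relu"),
            tf.keras.layers.Dense(hidden),
        ])
        # Decoder: maps (representation, target x) to a mean and raw scale.
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.Dense(hidden, activation="relu"),
            tf.keras.layers.Dense(2),
        ])

    def call(self, context_x, context_y, target_x):
        r = self.encoder(tf.concat([context_x, context_y], -1))  # [B, Nc, h]
        r = tf.reduce_mean(r, axis=1, keepdims=True)             # [B, 1, h]
        r = tf.tile(r, [1, tf.shape(target_x)[1], 1])            # [B, Nt, h]
        mean, raw_std = tf.split(self.decoder(tf.concat([r, target_x], -1)), 2, -1)
        std = 0.1 + 0.9 * tf.nn.softplus(raw_std)  # keep the scale bounded below
        return mean, std
```

Training then amounts to minimising the Gaussian negative log-likelihood of the target y values under the predicted mean and scale.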

5. Experiments

Training can be conducted either in an interactive session (IPython), with the arguments set in the section beginning at ln 40 (Training parameters), or from the terminal with command-line arguments, by commenting out the section at ln 40 and uncommenting the section at ln 25 (Parse Training parameters).

5.1. Regression training

python train.py --task regression

Example result:

Figure 4: Comparison between GP, CNP, and LNP on the 1D-regression task (fixed kernel parameter)

5.2. MNIST training

python train.py --task mnist
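As in the papers, image completion is cast as regression from normalised pixel coordinates to intensities. A minimal sketch of that conversion for 28x28 greyscale images (illustrative only; the function name is hypothetical, not the repository's API):

```python
import tensorflow as tf

def image_to_regression(images, num_context=100, size=28):
    """Turn greyscale images [B, 28, 28] into (context, target) regression sets."""
    batch = tf.shape(images)[0]
    # Normalised (row, col) coordinates in [0, 1] serve as the inputs x.
    ii, jj = tf.meshgrid(tf.range(size), tf.range(size), indexing="ij")
    coords = tf.cast(tf.stack([ii, jj], axis=-1), tf.float32) / (size - 1)
    x = tf.tile(tf.reshape(coords, [1, size * size, 2]), [batch, 1, 1])
    # Pixel intensities in [0, 1] serve as the outputs y.
    y = tf.reshape(tf.cast(images, tf.float32) / 255.0, [-1, size * size, 1])
    idx = tf.random.shuffle(tf.range(size * size))[:num_context]
    # Context: a random subset of pixels; targets: every pixel in the image.
    return (tf.gather(x, idx, axis=1), tf.gather(y, idx, axis=1)), (x, y)
```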

Example result:

Figure 5: CNP pixel mean and variance predictions on images from MNIST

5.3. CelebA training

Instructions: Download the aligned and cropped images from here and extract the files into the ./data directory.

python train.py --task celeb

Example result:

Figure 6: CNP pixel mean and variance predictions on images from CelebA

5.4. Extension: HNP and HNPC

Objective: Combine the deterministic link between the context representations (used by CNP) with the stochastic link through the latent representation space (used by LNP) to produce a model with a richer embedding space.
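A minimal sketch of what such hybrid conditioning could look like, assuming an encoder has already produced a deterministic representation r_det and latent parameters (all names here are hypothetical, not the repository's API):

```python
import tensorflow as tf

def hybrid_decode(decoder, r_det, z_mean, z_log_std, target_x):
    """Condition the decoder on both a deterministic representation
    (CNP path) and a reparameterised latent sample (LNP path)."""
    z = z_mean + tf.exp(z_log_std) * tf.random.normal(tf.shape(z_mean))
    rep = tf.concat([r_det, z], axis=-1)                   # [B, 1, 2h]
    rep = tf.tile(rep, [1, tf.shape(target_x)[1], 1])      # [B, Nt, 2h]
    return decoder(tf.concat([rep, target_x], axis=-1))
```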

Figure 7: Latent variable distribution - mean and standard deviation statistics during training

6. Appendix

To go further, read the poster and the report, both of which can be found in the report folder of this repository.

Figure 8: Thumbnail of the CNP/LNP poster
