Neural Processes

PyTorch implementation of Neural Processes. This repo follows the best practices defined in Empirical Evaluation of Neural Process Objectives.

Examples

[figures: function regression and image inpainting examples]

Usage

Simple example of training a neural process on functions or images.

import torch
from neural_process import NeuralProcess, NeuralProcessImg
from training import NeuralProcessTrainer

# Define neural process for functions...
neuralprocess = NeuralProcess(x_dim=1, y_dim=1, r_dim=10, z_dim=10, h_dim=10)

# ...or for images
neuralprocess = NeuralProcessImg(img_size=(3, 32, 32), r_dim=128, z_dim=128,
                                 h_dim=128)

# Define device, optimizer and trainer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.Adam(neuralprocess.parameters(), lr=3e-4)
np_trainer = NeuralProcessTrainer(device, neuralprocess, optimizer,
                                  num_context_range=(3, 20),
                                  num_extra_target_range=(5, 10))

# Train on your data
np_trainer.train(data_loader, epochs=30)
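
Once trained, predictions can be drawn at new target points. The following is a hedged sketch for the 1D function model defined above; the eval-mode call signature and the Normal return value are assumptions inferred from the training interface, not a documented contract.

# Hedged sketch: draw a predictive distribution at new target points.
# Assumes forward(x_context, y_context, x_target) returns a torch Normal
# over y once the model is out of training mode (an assumption).
neuralprocess.training = False

x_context = torch.Tensor([[-1.0], [0.0], [1.0]]).unsqueeze(0)  # (1, 3, x_dim)
y_context = torch.Tensor([[0.5], [0.0], [-0.5]]).unsqueeze(0)  # (1, 3, y_dim)
x_target = torch.linspace(-3, 3, 100).view(1, -1, 1)           # (1, 100, x_dim)

p_y_pred = neuralprocess(x_context, y_context, x_target)
mu, sigma = p_y_pred.loc, p_y_pred.scale  # predictive mean and std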

1D functions

For a detailed tutorial on training and using neural processes on 1D functions, see the notebook example-1d.ipynb.
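
The trainer expects a data loader yielding batches of (x, y) pairs. As a hypothetical stand-in for the repo's own data utilities, a minimal dataset of random sine curves could look like this (the class name and parameter ranges are illustrative, not the repo's):

import torch
from torch.utils.data import DataLoader, Dataset

# Hypothetical dataset of random sine curves y = a * sin(x - b); the repo's
# own data utilities may differ. Each item is an (x, y) pair of shape
# (num_points, 1), matching x_dim=1 and y_dim=1 above.
class SineData(Dataset):
    def __init__(self, num_samples=1000, num_points=100):
        self.data = []
        for _ in range(num_samples):
            a = torch.empty(1).uniform_(0.5, 1.0)   # random amplitude
            b = torch.empty(1).uniform_(-0.5, 0.5)  # random phase shift
            x = torch.linspace(-3.14, 3.14, num_points).unsqueeze(-1)
            self.data.append((x, a * torch.sin(x - b)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

data_loader = DataLoader(SineData(), batch_size=16, shuffle=True)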

Images

To train an image model, use python main_experiment.py config.json. This will log information about training and save model weights.

For a detailed tutorial on how to load a trained model and how to use neural processes for inpainting, see the notebook example-img. Trained models for MNIST and CelebA are also provided in the trained_models folder.
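
A hedged sketch of loading one of those trained models; the file names inside trained_models and the config keys below are assumptions about the folder layout, not confirmed by the repo:

import json
import torch
from neural_process import NeuralProcessImg

# Hypothetical loading sketch: "config.json" / "model.pt" file names and
# the config keys used here are assumptions about the trained_models layout.
folder = "trained_models/mnist"
with open(folder + "/config.json") as f:
    config = json.load(f)

model = NeuralProcessImg(img_size=tuple(config["img_size"]),
                         r_dim=config["r_dim"],
                         z_dim=config["z_dim"],
                         h_dim=config["h_dim"])
state_dict = torch.load(folder + "/model.pt", map_location="cpu")
model.load_state_dict(state_dict)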

Note: to train on CelebA, you will have to download the dataset separately.

Acknowledgements

Several people at OxCSML helped me understand various aspects of neural processes, especially Kaspar Märtens, Jin Xu, Jean-Francois Ton and Hyunjik Kim.

License

MIT

Issues

NP architecture is described differently in "Attentive Neural Processes" paper

Thank you for the wonderful implementation! I have a high-level question:

The code looks like it is following the latent-variable formulation of Conditional Neural Processes (CNPs), which also matches the description in Section 2.4 of the NP paper:
[figure: latent-variable CNP diagram]

  • Encoding context pairs (x, y) as representation r in xy_to_r
  • Transforming r to latent parameters of z Normal in r_to_mu_sigma
  • Decoding pairs (x, z) to parameters of y Normal in xz_to_y

However, the NP model described in the "Attentive Neural Processes" (ANP) paper has a few differences:
[figure: NP architecture as drawn in the ANP paper]

  1. The encoder (self.xy_to_r) is split into two different encoders:
  • One produces the global representation r (deterministic path)
  • The other produces a separate code s that parameterises the latent z, instead of parameterising z directly with r (latent path)
  2. The decoder also receives r, not only the target location x and z (so it should be xzr_to_y); both differences are sketched in the code below
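
For concreteness, here is a rough sketch of that split; all module and variable names are mine, purely illustrative, and simplified relative to both papers:

import torch
import torch.nn as nn

# Illustrative sketch of the ANP-style architecture: a deterministic path
# producing r, a separate latent path producing s -> z, and a decoder that
# conditions on (x_target, z, r). Names and sizes are hypothetical.
class ANPStyleModel(nn.Module):
    def __init__(self, x_dim, y_dim, r_dim, z_dim, h_dim):
        super().__init__()
        self.xy_to_r = nn.Sequential(nn.Linear(x_dim + y_dim, h_dim), nn.ReLU(),
                                     nn.Linear(h_dim, r_dim))      # deterministic path
        self.xy_to_s = nn.Sequential(nn.Linear(x_dim + y_dim, h_dim), nn.ReLU(),
                                     nn.Linear(h_dim, 2 * z_dim))  # latent path
        self.xzr_to_y = nn.Sequential(nn.Linear(x_dim + z_dim + r_dim, h_dim),
                                      nn.ReLU(), nn.Linear(h_dim, y_dim))

    def forward(self, x_context, y_context, x_target):
        pairs = torch.cat([x_context, y_context], dim=-1)
        r = self.xy_to_r(pairs).mean(dim=1)              # aggregate context -> r
        mu, log_sigma = self.xy_to_s(pairs).mean(dim=1).chunk(2, dim=-1)
        z = mu + log_sigma.exp() * torch.randn_like(mu)  # reparameterised z sample
        zr = torch.cat([z, r], dim=-1)
        zr = zr.unsqueeze(1).expand(-1, x_target.size(1), -1)
        return self.xzr_to_y(torch.cat([x_target, zr], dim=-1))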

I wonder what the reasoning behind the differences is, and why the NP model described in the NP paper differs from the NP model described in the ANP paper. I suspect difference 1 only matters because they introduced cross-attention in the deterministic path and not the latent path.

The interpretation of the latent path is that z gives rise to correlations in the marginal distribution of the target predictions y_T , modelling the global structure of the stochastic process realisation, whereas the deterministic path models the fine-grained local structure.

As for difference 2, it probably doesn't matter too much?

The official open-source implementation for Neural Processes also uses the architecture described in the ANP paper, not the NP paper.

Thanks for reading!
P.S. Do you know of any implementation of ANP for 2D functions?

[Feature Request 🚀] Add a `CITATION.cff`

GitHub recently released a new feature where repository owners can add a CITATION.cff file, making it easy for others to cite the repository. Adding a CITATION.cff would make the attribution process very easy for others (myself included 😅) who want to cite this work.

Prediction with expanded x axis is too confident?

This is not really an issue with the code, but something that I expected to work differently with neural processes and I would value your opinion.

When making a prediction, if I expand the x axis to a wider range, I would expect the sampled functions to become less certain and more varied. When I tried expanding the axis, this is what I actually saw happen...

[figure: np-prediction — sampled functions over the widened x range]

I am having a hard time justifying why the neural process remains confident in its prediction and why every sampled function just goes off in the same direction. Do you have any insight?
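
A hedged sketch of reproducing this kind of plot, assuming (as in the usage example above) that the eval-mode forward returns a Normal over y, which is an assumption rather than a documented contract:

import torch

# Sketch: draw several functions over an x range wider than the training
# interval to see how extrapolation uncertainty behaves. Assumes eval-mode
# forward(x_context, y_context, x_target) returns a torch Normal over y.
neuralprocess.training = False

x_context = torch.Tensor([[-1.0], [0.0], [1.0]]).unsqueeze(0)  # (1, 3, 1)
y_context = torch.Tensor([[0.5], [0.0], [-0.5]]).unsqueeze(0)
x_target = torch.linspace(-6, 6, 200).view(1, -1, 1)           # widened range

samples = []
for _ in range(20):
    p_y_pred = neuralprocess(x_context, y_context, x_target)
    samples.append(p_y_pred.loc.detach())  # each call resamples z -> one function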
