
Augmented Neural ODEs

This repo contains code for the paper Augmented Neural ODEs (2019).


Requirements

The requirements that can be installed directly from PyPI are listed in requirements.txt. This code also builds on the awesome torchdiffeq library, which provides various ODE solvers on GPU. Instructions for installing torchdiffeq can be found in its repository.
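For example, assuming pip is available, a typical setup is:

pip install -r requirements.txt
pip install torchdiffeq

(Recent releases of torchdiffeq are published on PyPI; if the second command does not work for your setup, follow the instructions in the torchdiffeq repository.)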

Usage

The usage pattern is simple:

# ... Load some data ...

import torch
from anode.conv_models import ConvODENet
from anode.models import ODENet
from anode.training import Trainer

# Run on the GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate a model
# For regular data...
anode = ODENet(device, data_dim=2, hidden_dim=16, augment_dim=1)
# ... or for images
anode = ConvODENet(device, img_size=(1, 28, 28), num_filters=32, augment_dim=1)

# Instantiate an optimizer and a trainer
optimizer = torch.optim.Adam(anode.parameters(), lr=1e-3)
trainer = Trainer(anode, optimizer, device)

# Train model on your dataloader
trainer.train(dataloader, num_epochs=10)
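For context, here is a minimal sketch of what dataloader could look like for the 2D setting above. The toy dataset is invented purely for illustration, and it assumes the Trainer consumes standard (inputs, targets) batches:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy 2D data: label points by whether they lie outside the unit circle
inputs = torch.randn(1024, 2)
targets = (inputs.norm(dim=1, keepdim=True) > 1.0).float()

dataloader = DataLoader(TensorDataset(inputs, targets), batch_size=64, shuffle=True)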

More detailed examples and tutorials can be found in the augmented-neural-ode-example.ipynb and vector-field-visualizations.ipynb notebooks.

Running experiments

To run a large number of repeated experiments on toy datasets, use the following

python main_experiment.py config.json

where the specifications for the experiment are given in config.json. This will log all the information about the experiments and generate plots of losses, the number of function evaluations (NFEs), and so on.

Running experiments on image datasets

To run large experiments on image datasets, use the following

python main_experiment_img.py config_img.json

where the specifications for the experiment can be found in config_img.json.

Demos

We also provide two demo notebooks that show how to reproduce some of the results and figures from the paper.

Vector fields

The vector-field-visualizations.ipynb notebook contains a demo and tutorial for reproducing the experiments on 1D ODE flows in the paper.
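Independently of the notebook, one generic way to inspect a 1D vector field is to evaluate the learned derivative on a grid of states and plot it. The sketch below uses a hand-written odefunc as a stand-in for a trained model; it is not the notebook's code:

import torch
import matplotlib.pyplot as plt

def odefunc(t, h):
    # Stand-in dynamics; in practice this would be the trained ODE function
    return -h ** 3 + torch.sin(t)

h = torch.linspace(-2.0, 2.0, 200)
for t in (0.0, 0.5, 1.0):
    dh = odefunc(torch.tensor(t), h)
    plt.plot(h.numpy(), dh.numpy(), label="t = {}".format(t))
plt.xlabel("h")
plt.ylabel("dh/dt")
plt.legend()
plt.show()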

Augmented Neural ODEs

The augmented-neural-ode-example.ipynb notebook contains a demo and tutorial for reproducing the experiments comparing Neural ODEs and Augmented Neural ODEs on simple 2D functions.
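A rough sketch of how such a comparison might be set up with the API from the usage example, assuming that augment_dim=0 corresponds to a plain Neural ODE and reusing the dataloader from above:

import torch
from anode.models import ODENet
from anode.training import Trainer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for augment_dim in (0, 1):  # 0: Neural ODE, 1: Augmented Neural ODE
    model = ODENet(device, data_dim=2, hidden_dim=16, augment_dim=augment_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trainer = Trainer(model, optimizer, device)
    trainer.train(dataloader, num_epochs=10)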

Data

The MNIST and CIFAR10 datasets can be downloaded directly using torchvision (this will happen automatically when you run the code, unless you already have those datasets downloaded). To run the ImageNet experiments, you will need to download the Tiny ImageNet data from its website.
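For example, a minimal sketch of loading MNIST with torchvision into a DataLoader that matches the ConvODENet example above (the transform and batch size are illustrative choices, not the experiment settings):

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.ToTensor()  # yields (1, 28, 28) tensors scaled to [0, 1]
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist_train, batch_size=256, shuffle=True)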

Citing

If you find this code useful in your research, consider citing with

@article{dupont2019augmented,
  title={Augmented Neural ODEs},
  author={Dupont, Emilien and Doucet, Arnaud and Teh, Yee Whye},
  journal={arXiv preprint arXiv:1904.01681},
  year={2019}
}

License

MIT


augmented-neural-odes's Issues

Error running augmented-neural-ode-example.ipynb


---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input> in <module>
      1 from viz.plots import multi_feature_plt
      2 
----> 3 multi_feature_plt(feature_history, targets)

~/Box/recent_ML/augmented-neural-odes/viz/plots.py in multi_feature_plt(features, targets, save_fig)
    312                            labelbottom=False, right=False, left=False,
    313                            labelleft=False)
--> 314             ax.set_aspect(get_square_aspect_ratio(ax))
    315 
    316     fig.subplots_adjust(wspace=0.01)

~/opt/anaconda3/lib/python3.7/site-packages/mpl_toolkits/mplot3d/axes3d.py in set_aspect(self, aspect, adjustable, anchor, share)
    322         if aspect != 'auto':
    323             raise NotImplementedError(
--> 324                 "Axes3D currently only supports the aspect argument "
    325                 f"'auto'. You passed in {aspect!r}."
    326             )

NotImplementedError: Axes3D currently only supports the aspect argument 'auto'. You passed in 0.9608689503197515.

Confusion about eval_times

Hello again,
I'm pretty new to neural ODEs and differential equations in general. I hope I'm not wasting your time as this may be a sort of dumb question.

I am trying to learn the functions that govern the light curves of astronomical objects for my internship according to this data spec.
Eventually, I am going to try to classify the light curves (and use the other parameters they give), but for now I want to understand just how to learn the function for brightness over time (for a given object). So I believe this is going to be 1-dimensional data per time point.

Right now I have a DataLoader whose __getitem__() gives me a ([352, 2], class id) pair, where the tensor holds the time values and the corresponding brightness (flux) at each time. I was going to do it this way so that I could try to learn the light curve functions in batches (my intuition is that the ODE might be able to better estimate the curves of one object based on how it evaluated others).

I see that for the Trainer class, you provide an ODENet instantiation as an argument. I was wondering how you are able to then provide the ODEBlock with the eval_times for the flux data in the forward function, as the ODENet.forward() function does not take in eval_times as an argument.

So I've probably demonstrated a couple of misunderstandings. For one, is it just flawed methodology to try to chunk up my data according to individual astronomical objects? I.e., should I just train one huge ODE for the entire time series? (This seems wrong if I ignore the object id and other data, since the fluxes of one object aren't related to those of another.)

TL;DR: I'm not exactly sure how to provide the solver eval_times (from a DataLoader) in ODENet's forward function / training.Trainer class.

Thanks again.

Import error for ConvODENet but not ODENet

I was following the README instructions; however, I get an ImportError: "cannot import name 'ConvODENet'".

import torch
from anode.models import ODENet, ConvODENet

I do not have a solution, but some preliminary Googling indicates that there might be a circular dependency (see this Stack Overflow post).
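For reference, the usage example in the README imports ConvODENet from anode.conv_models rather than anode.models, which may explain the error:

from anode.conv_models import ConvODENet
from anode.models import ODENet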

Visualization of high-dimension vector field

Hi,

Thanks for delivering this amazing work!

I'm thinking about visualizing the vector field of my model; however, it is a high-dimensional one with shape [16, 16, 256].
Do you have any idea about how to visualize it?

Cheers,
Hilbert

config files for training cifar10

Hello,

Thank you very much for the nice work. I am playing around with your code and was wondering whether it would be possible for you to share the config files used to train NODE or ANODE on CIFAR-10. I am able to get good accuracy on MNIST, but for CIFAR, using the same configs as you mentioned here, I can only get around 55%. Is that normal?

About acc performance on CIFAR-10 and training time.

Hi guys, your works and codes are great.
However, I am a little confused about why there was no accuracy comparison between NODE and your ANODE on CIFAR-10, since you ran tests with both methods.
I am trying to run the test on CIFAR-10, and I noticed that the training time per epoch of ANODE increases much faster than that of NODE; is that how things should go with ANODE?

Regards

multi_feature_plt aspect error

The cell

from viz.plots import multi_feature_plt
multi_feature_plt(feature_history, targets)

raises

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-26-0ca89ec1decb> in <module>
      1 from viz.plots import multi_feature_plt
      2 
----> 3 multi_feature_plt(feature_history, targets)

~/Code/augmented-neural-odes/viz/plots.py in multi_feature_plt(features, targets, save_fig)
    312                            labelbottom=False, right=False, left=False,
    313                            labelleft=False)
--> 314             ax.set_aspect(get_square_aspect_ratio(ax))
    315 
    316     fig.subplots_adjust(wspace=0.01)

/usr/lib/python3.8/site-packages/matplotlib/axes/_base.py in set_aspect(self, aspect, adjustable, anchor, share)
   1278 
   1279         if (not cbook._str_equal(aspect, 'auto')) and self.name == '3d':
-> 1280             raise NotImplementedError(
   1281                 'It is not currently possible to manually set the aspect '
   1282                 'on 3D axes')

NotImplementedError: It is not currently possible to manually set the aspect on 3D axes

so the notebook can't be reproduced.

scipy 1.4.1
numpy 1.18.1
torch 1.4.0
torchvision 0.5.0a0
matplotlib 3.1.3
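A possible workaround, as an untested sketch rather than an official fix, is to skip the aspect-ratio call for 3D axes in viz/plots.py, since recent matplotlib versions only accept 'auto' there:

# In viz/plots.py, around the ax.set_aspect(...) call shown in the traceback:
if ax.name != '3d':
    # Only 2D axes accept a numeric aspect ratio in recent matplotlib
    ax.set_aspect(get_square_aspect_ratio(ax))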

How do you type interesting symbols in the commit messages?

Hi Emilien,

Besides your excellent work, I also like the symbols you type in the commit messages, like the 'bug' emoji.

Could you kindly let me know how you type them, so I can follow suit?

Thank you
GUANGYUAN

No module named 'anode'.

I tried to run experiments_img.py but got ModuleNotFoundError: No module named 'anode'.

I cannot simply pip install it either; pip reports "Could not find a version that satisfies the requirement anode".

Any suggestions would be very helpful.
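One likely explanation, assuming the package layout has not changed: anode is the local package at the root of this repository, not a PyPI package, so the scripts need to be run from the repo root, or the repo directory has to be on the Python path, e.g.:

import sys
# Hypothetical path; replace with wherever the repository was cloned
sys.path.insert(0, "/path/to/augmented-neural-odes")

from anode.models import ODENet  # should now resolve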

Not able to run the code for images (Request!)

Hello! Thank you so much for the work and for providing the code.
I am new to this field. I tried the code on Google Colab following the pattern you have provided, but my efforts were in vain (mainly because of the dataloader). I would be very thankful if you could provide a notebook for any of the image datasets.
Regards

architectures compared to neural ode

Hi @EmilienDupont, thanks for the code implementation!

I am concerned about the structure implemented in this repo. The architecture in the original Neural ODE (Chen et al.) has a feature extractor consisting of purely convolutional layers, followed by the ODE representation transformation layer and the final classification layer, whereas in this repo only ODE layers are present.

So in your experiments, when you compare to the Neural ODE, the Neural ODE you used does not have the feature extraction layers, and only the concatenated (zero) channels are removed compared to the Augmented Neural ODE?

I personally do not think this is a fair comparison, since removing the feature extraction layers will affect the classification model. A better comparison would be evaluating the two models when a proper feature extractor is present.

Please correct me if I am wrong

Regarding Concentric Sphere Experiment for ODE

Hey Emilien!

In my understanding, if I run main_experiment.py as is in the current repo, I would train ResNet, ANODE, and NODE on two datasets---one being concentric spheres. I also understand that in your paper you proved that NODE cannot represent this sphere dataset. Why is it that the loss for NODE is quite close to the other two (near zero)? Is that just a display of NODE "approximately solving the problem"? Thanks!

Solving parameterized ODE for prediction

Once I have trained the ODEFunc, how can I make future-time predictions as an initial value problem if the network does not accept first- and second-derivative initial values?

Or does this project only consider first-order ODEs?
