michael-psenka / manifold-linearization

Companion repository for the paper "Representation Learning via Manifold Flattening and Reconstruction"

Home Page: https://www.michaelpsenka.io/papers/flatteningnetwork/

License: MIT License

Jupyter Notebook 78.38% Python 21.62%
ai autoencoder deep-learning geometry machine-learning manifold-learning unsupervised-learning

Manifold Linearization for Representation Learning

This is a research project on automatically generating autoencoders with minimal feature size when the data is supported near an embedded submanifold. Using the geometric structure of the manifold, we can equivalently treat this as a manifold flattening problem whenever the manifold is flattenable [1]. See our paper, Representation Learning via Manifold Flattening and Reconstruction, for more details.
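
One way to write the goal down is sketched below. The notation is illustrative and not taken from the paper: the data dimension $D$, feature dimension $p$, manifold $\mathcal{M}$ of intrinsic dimension $d$, and subspace $S$ are all assumed symbols.

```latex
\[
f : \mathbb{R}^D \to \mathbb{R}^p, \qquad
g : \mathbb{R}^p \to \mathbb{R}^D,
\]
\[
g(f(x)) \approx x \quad \text{for all } x \in \mathcal{M}
\qquad \text{(reconstruction)},
\]
\[
f(\mathcal{M}) \subseteq S, \quad
S \subset \mathbb{R}^p \text{ an affine subspace with } \dim S = d
\qquad \text{(flattening)}.
\]
```

Under this reading, "minimal feature size" means the encoded data needs only $d = \dim \mathcal{M}$ linear coordinates, and the encoder $f$ and decoder $g$ play the roles of the flattening and reconstruction maps.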

The main practical benefit: neural network autoencoders that build themselves from your data! Instead of guessing architecture dimensions by trial and error, simply run a geometric process that automatically optimizes layer width, depth, and stopping time based on your data [2].

Installation

The pure FlatNet construction (training) code is available as a pip package and can be installed in the following way:

pip install flatnet

Optional installation

This repo also contains a number of testing and illustrative files to both familiarize new users with the framework and show the experiments run in the main paper. To install the appropriate remaining dependencies for this repo, first navigate to the project directory, then run the following command:

pip install -r requirements.txt

Quickstart usage

The following is a simple example script using the pip package:

import torch
import flatnet

# create sine wave dataset
t = torch.linspace(0, 1, 50)
y = torch.sin(t * 2 * 3.14)

# format dataset of N points of dimension D as (N, D) matrix
X = torch.stack([t, y], dim=1)
# add noise
X = X + 0.02*torch.randn(*X.shape)

# normalize data
X = (X - X.mean(dim=0)) / X.std(dim=0)

# train encoder f and decoder g, then save a gif of the
# manifold evolution through the constructed layers of f
f, g = flatnet.train(X, n_iter=150, save_gif=True)
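
After training, it can be useful to check how well the pair reconstructs the data and how flat the encoded features are. The sketch below is one way to do that, under the assumption that f and g can be called on an (N, D) tensor and return tensors; this README does not confirm that interface. The helpers reconstruction_rmse and effective_dimension are hypothetical names and not part of flatnet, and the demo uses a stand-in rotation instead of a trained model so that it runs without flatnet installed.

```python
import torch


def reconstruction_rmse(f, g, X):
    """Root-mean-square error between X and g(f(X)).

    Assumes f and g are callables mapping tensors to tensors
    (an assumption about flatnet's return values, not a documented fact).
    """
    with torch.no_grad():
        Z = f(X)        # encoded features
        X_hat = g(Z)    # reconstruction in the original space
    return Z, torch.sqrt(torch.mean((X_hat - X) ** 2)).item()


def effective_dimension(Z, rel_tol=1e-3):
    """Count singular values of the centered features above rel_tol * largest.

    A value close to the intrinsic dimension suggests the features
    lie near a low-dimensional affine subspace. Assumes Z is not constant.
    """
    Zc = Z - Z.mean(dim=0)
    s = torch.linalg.svdvals(Zc)
    return int((s > rel_tol * s[0]).sum())


# Stand-in encoder/decoder: a fixed rotation and its inverse.
theta = torch.tensor(0.3)
R = torch.stack([
    torch.stack([torch.cos(theta), -torch.sin(theta)]),
    torch.stack([torch.sin(theta), torch.cos(theta)]),
])
f_demo = lambda x: x @ R
g_demo = lambda z: z @ R.T

X_demo = torch.randn(50, 2)
Z_demo, rmse = reconstruction_rmse(f_demo, g_demo, X_demo)
print(f"feature shape: {tuple(Z_demo.shape)}, RMSE: {rmse:.2e}")

# Points on a straight line in 2D have one effective dimension.
t_demo = torch.linspace(0, 1, 50)
Z_line = torch.stack([t_demo, 2 * t_demo], dim=1)
print("effective dimension of a line:", effective_dimension(Z_line))
```

If the assumption about f and g holds, the same two helpers could be applied to the trained pair by calling reconstruction_rmse(f, g, X) and then effective_dimension on the returned features.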

The script flatnet_test.py includes many example experiments on which to run FlatNet constructions. For a first example, run python flatnet_test.py in the main directory to see the flattening and reconstruction of a simple sine wave. Further experiments and options can be specified through command line arguments, managed through tyro; to see the full list of arguments, run python flatnet_test.py --help.

Directory Structure

  • flatnet_test.py: main test script, as described in the section above.
  • flatnet/train.py: contains the main FlatNet construction (training) code.
  • flatnet/modules: contains code for the neural network modules used in FlatNet.
  • experiments-paper: contains scripts and results from experiments done in the paper.
  • models: contains code for various models that FlatNet was compared against in the paper.
  • tools: contains auxiliary tools for evaluating the method, such as random manifold generators, one of which is from the randman repo.

Citation

If you use this work in your research, please cite the following paper:

@article{psenka2023flatnet,
  author = {Psenka, Michael and Pai, Druv and Raman, Vishal and Sastry, Shankar and Ma, Yi},
  title = {Representation Learning via Manifold Flattening and Reconstruction},
  year = {2023},
  eprint = {2305.01777},
  url = {https://arxiv.org/abs/2305.01777},
}

We hope that you find this project useful. If you have any questions or suggestions, please feel free to contact us.

Copyright notice

THIS SOFTWARE AND/OR DATA WAS DEPOSITED IN THE BAIR OPEN RESEARCH COMMONS REPOSITORY ON 5/3/2023.

Footnotes

  1. Geometric note: while flattenability does not hold in general, there are heuristic reasons to motivate this assumption for real-world data. For example, if a dataset permits a VAE-like autoencoder, where samples from the data distribution can be generated from a standard Gaussian in the latent space, then the samples lie close in probability to a flattenable manifold, since this VAE has constructed a single-chart atlas.

  2. Of course, there is no free lunch. There are still a few scalar hyperparameters that need to be set, but (a) this search space is drastically lower-dimensional than choosing the entire architecture structure, and (b) these are intuitive to choose and are incredibly robust (for example, these were hardly changed between the experiments on synthetic manifolds and MNIST).

Contributors

druvpai, michael-psenka, vraman23
