HiPPO-Jax

This repo builds on the ideas and code from HazyResearch/state-spaces and reimplements them in JAX.

Installation

There are several ways to install HiPPO-Jax:

  1. Use a package manager
    1. poetry install (recommended for users)
    2. pip install from PyPI
  2. Clone repo to local machine and install from source (recommended for developers/contributors)

Ensure your CUDA drivers are installed correctly, as this will affect dependencies such as JAX and PyTorch.

Note: these instructions are for Linux. Commands may be different for other platforms.

Installation option 1: poetry install


  1. Install poetry:
curl -sSL https://install.python-poetry.org | python3 -
  2. Ensure python version is set to 3.8:
$ python --version
> 3.8.x
  3. Activate poetry virtual environment:
poetry shell
  4. (optional) Update the dependencies to ensure they work with your system:
poetry update
  5. Install lock file dependencies:
poetry install --with jax,torch,mltools,jupyter,additional,dataset

Installation option 1: pip install


  1. Create and activate virtual environment
conda create --name hippo_jax python=3.8
conda activate hippo_jax
  2. Install dependencies
pip install -r requirements.txt

Installation option 2: clone repo and install from source


  1. Clone repo:

via HTTPS:

git clone https://github.com/Dana-Farber-AIOS/HiPPO-Jax.git
cd HiPPO-Jax

via SSH

git clone git@github.com:Dana-Farber-AIOS/HiPPO-Jax.git
cd HiPPO-Jax
  2. Create conda environment:
conda env create -f requirements.txt
conda activate hippo_jax
  3. Install HiPPO-Jax from source:
pip install -e .

That's it!
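
After installing, it can be useful to confirm which backend JAX picked up, since the CUDA note above affects whether a GPU is visible. The following is a minimal sketch; the helper name describe_jax_backend is hypothetical and not part of HiPPO-Jax. It relies only on jax.devices() and jax.default_backend(), and falls back to a plain message if JAX cannot be imported.

```python
def describe_jax_backend():
    """Return a one-line summary of the JAX backend, or a notice if JAX is missing."""
    try:
        import jax  # imported lazily so the check also runs where JAX is absent
    except ImportError:
        return "jax is not installed in this environment"
    # jax.default_backend() is typically "cpu", "gpu", or "tpu"
    return (
        f"jax {jax.__version__}: backend={jax.default_backend()}, "
        f"devices={jax.devices()}"
    )


summary = describe_jax_backend()
print(summary)
```

If the reported backend is "cpu" on a machine with a GPU, the CUDA drivers or the installed JAX build are the first things to check.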

Examples

import jax.random as jr

key, subkey = jr.split(jr.PRNGKey(0), 2)

HiPPO Matrices

from src.models.hippo.transition import TransMatrix

N = 100
measure = "legs"

matrices = TransMatrix(N=N, measure=measure)
A = matrices.A
B = matrices.B
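
For orientation, the "legs" (scaled Legendre) measure has a closed-form pair of matrices in the HiPPO formulation. The sketch below builds them with NumPy, independently of this repo. The function name legs_matrices is hypothetical, and the sign convention (negative A, so that the dynamics read dc/dt = (1/t)(A c + B f)) is an assumption; TransMatrix may use a different sign or shape convention.

```python
import numpy as np


def legs_matrices(N):
    """Return (A, B) for the HiPPO-LegS measure (hypothetical helper).

    A[n, k] = -sqrt(2n+1) * sqrt(2k+1)  if n > k
            = -(n + 1)                  if n == k
            = 0                         if n < k
    B[n]    = sqrt(2n+1)
    """
    n = np.arange(N)
    scale = np.sqrt(2 * n + 1)
    # Strictly lower triangle holds the cross terms for n > k
    A = -np.tril(np.outer(scale, scale), k=-1)
    # Diagonal entries are -(n + 1)
    A = A - np.diag(n + 1)
    B = scale[:, None]  # column vector of shape (N, 1)
    return A, B


A, B = legs_matrices(4)
```

Comparing this output against matrices.A and matrices.B for a small N is a quick way to see which convention the library follows.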

HiPPO (LTI) Operator

from src.models.hippo.hippo import HiPPOLTI

N = 50
T = 3
step = 1e-3
measure = "legs"
desc_val = 0.0

hippo = HiPPOLTI(
    N=N,
    step_size=step,
    GBT_alpha=desc_val,
    measure=measure,
    basis_size=T,
    unroll=False,
)
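
The GBT_alpha argument above is presumably the parameter of the generalized bilinear transform (GBT) used to discretize the continuous system; that reading is an assumption, not something this README states. Under that assumption, a minimal NumPy sketch of the transform looks as follows. The name discretize_gbt and the toy matrices are hypothetical.

```python
import numpy as np


def discretize_gbt(A, B, step, alpha):
    """Generalized bilinear transform of dx/dt = A x + B u (hypothetical helper).

    alpha = 0.0 -> forward Euler
    alpha = 0.5 -> bilinear (Tustin)
    alpha = 1.0 -> backward Euler
    """
    I = np.eye(A.shape[0])
    lhs = I - step * alpha * A
    # Solve instead of forming an explicit inverse
    Ad = np.linalg.solve(lhs, I + step * (1.0 - alpha) * A)
    Bd = np.linalg.solve(lhs, step * B)
    return Ad, Bd


# Toy 2x2 system, for illustration only
A = np.array([[-1.0, 0.0], [-1.0, -2.0]])
B = np.array([[1.0], [1.0]])
Ad, Bd = discretize_gbt(A, B, step=1e-3, alpha=0.0)
```

With alpha = 0.0, as in the example above, the result reduces to Ad = I + step * A and Bd = step * B.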

HiPPO (LSI) Operator

from src.models.hippo.hippo import HiPPOLSI

N = 50
T = 3
step = 1e-3
L = int(T / step)
measure = "legs"
desc_val = 0.0

hippo = HiPPOLSI(
    N=N,
    max_length=L,
    step_size=step,
    GBT_alpha=desc_val,
    measure=measure,
    unroll=True,
)
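
In the scale-invariant case the dynamics carry a 1/t factor, dc/dt = (1/t)(A c + B f), so the discrete update matrices change at every step, which is presumably why a max_length is needed. The sketch below precomputes one pair of matrices per step. It is a simplified illustration, not the library's implementation: lsi_step_matrices is a hypothetical name, and the same factor 1/k is used on both sides of the GBT update. For alpha = 0.0 this gives the forward Euler recurrence c_{k+1} = (I + A/k) c_k + (1/k) B f_k.

```python
import numpy as np


def lsi_step_matrices(A, B, max_length, alpha):
    """Per-step discrete matrices for dc/dt = (1/t)(A c + B f) (hypothetical helper).

    With t = k * step, the factor step / t equals 1 / k, so the step size
    cancels out of the update.
    """
    I = np.eye(A.shape[0])
    Ads, Bds = [], []
    for k in range(1, max_length + 1):
        lhs = I - (alpha / k) * A
        Ads.append(np.linalg.solve(lhs, I + ((1.0 - alpha) / k) * A))
        Bds.append(np.linalg.solve(lhs, B / k))
    return np.stack(Ads), np.stack(Bds)


# LegS matrices for N = 2, negative-A convention (an assumption)
A = np.array([[-1.0, 0.0], [-np.sqrt(3.0), -2.0]])
B = np.array([[1.0], [np.sqrt(3.0)]])
Ads, Bds = lsi_step_matrices(A, B, max_length=5, alpha=0.0)
```

The stacked arrays have shapes (max_length, N, N) and (max_length, N, 1), one pair per time step.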

Use right out of the box, no training needed

# `key` is the PRNG key created above; `x` is the input signal and must be defined beforehand
params = hippo.init(key, f=x)
c, y = hippo.apply(params, f=x)

Contributing

HiPPO-Jax is an open source project. Consider contributing to benefit the entire community!

There are many ways to contribute to HiPPO-Jax, including:

  • Submitting bug reports
  • Submitting feature requests
  • Writing documentation and examples
  • Fixing bugs
  • Writing code for new features
  • Sharing workflows
  • Sharing trained model parameters
  • Sharing HiPPO-Jax with colleagues, students, etc.

License

The GNU GPL v2 version of HiPPO-Jax is made available via Open Source licensing. The user is free to use, modify, and distribute under the terms of the GNU General Public License version 2.

Commercial license options are also available.

Contact

Questions? Comments? Suggestions? Get in touch!

[email protected]
