Spyx: Spiking Neural Networks in JAX
Home Page: https://spyx.readthedocs.io/en/latest/
License: MIT License
If the batch size evenly divides the dataset, the shuffling function produces an array with a leading axis of zero, inducing a bug; a sketch of a guard follows.
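A minimal sketch of the guard, assuming a hypothetical `shuffle_and_batch` helper (names illustrative, not Spyx's actual API):

```python
import jax
import jax.numpy as jnp

def shuffle_and_batch(rng, xs, ys, batch_size):
    # Hypothetical helper illustrating the fix: only keep a remainder
    # batch when batch_size does NOT evenly divide the dataset.
    n = xs.shape[0]
    perm = jax.random.permutation(rng, n)
    xs, ys = xs[perm], ys[perm]
    n_full = n // batch_size
    cutoff = n_full * batch_size
    x_batches = xs[:cutoff].reshape(n_full, batch_size, *xs.shape[1:])
    y_batches = ys[:cutoff].reshape(n_full, batch_size, *ys.shape[1:])
    if cutoff == n:
        # Without this guard, xs[cutoff:] would be appended as an empty
        # batch with a leading axis of zero.
        return x_batches, y_batches, None
    return x_batches, y_batches, (xs[cutoff:], ys[cutoff:])
```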
Could implement EXODUS to drastically cut gradient calculation time for certain classes of models and improve training speed even further.
https://www.frontiersin.org/articles/10.3389/fnins.2023.1110444/full
Add Stochastic Parallelizable Spiking Neuron model.
Torch implementation:
https://github.com/NECOTIS/Stochastic-Parallelizable-Spiking-Neuron-SPSN/blob/main/neurons/spsn.py
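The core trick is that the non-resetting membrane update is a linear recurrence, so it can be computed in parallel over time with an associative scan. A minimal sketch under those assumptions, with a scalar decay `beta` and sigmoid-based stochastic firing (illustrative, not a final Spyx API):

```python
import jax
import jax.numpy as jnp

def spsn_forward(key, x, beta=0.9):
    # Compose steps of u[t] = decay * u[t-1] + input; composing two such
    # steps stays in the same (decay, input) form, so the op is associative.
    def combine(a, b):
        return (a[0] * b[0], a[1] * b[0] + b[1])

    decays = jnp.full_like(x, beta)
    # Prefix scan over the time axis yields all membrane potentials at once.
    _, u = jax.lax.associative_scan(combine, (decays, x), axis=0)
    p_fire = jax.nn.sigmoid(u)                       # firing probability
    spikes = jax.random.bernoulli(key, p_fire).astype(x.dtype)
    return spikes, u
```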
Use Chex to add type checking to the library in spots not already handled by Haiku.
It would be awesome to implement the ability to train spiking phasor networks in Spyx. JAX supports complex-valued automatic differentiation, so this should be possible. Doing so would enable extremely fast training by eliminating recurrence during learning, before converting to a recurrent architecture for inference.
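As a quick sanity check that the machinery exists, JAX differentiates cleanly through complex intermediates (a toy example, not a phasor SNN):

```python
import jax
import jax.numpy as jnp

def loss(phases):
    # Real-valued loss computed through complex phasor arithmetic.
    z = jnp.exp(1j * phases)
    return jnp.abs(jnp.sum(z)) ** 2

grads = jax.grad(loss)(jnp.array([0.1, 0.2, 0.3]))
```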
It'd be great to provide some template layers within the library for more complex SNN architectures.
I think it's appropriate for Spyx to adopt a functional approach as much as possible, so using functions that return other functions (higher-order functions) rather than classes would be fitting, since it guides the user toward JIT compilation. See the sketch below.
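A minimal sketch of the style in mind, with illustrative names:

```python
import jax
import jax.numpy as jnp

def make_dense(w):
    # Factory (higher-order function): closes over the weights and
    # returns a pure function, which JIT-compiles cleanly.
    def dense(x):
        return x @ w
    return dense

layer = jax.jit(make_dense(jnp.ones((4, 8))))
out = layer(jnp.ones((2, 4)))  # shape (2, 8)
```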
The current documentation is passable but it could always be better.
Create functions to load/save models to HDF5 under the Neuromorphic Intermediate Representation (NIR) standard to facilitate cross-platform deployment; a sketch of the export path follows the checklist below.
https://nnir.readthedocs.io/en/latest/what.html
Implement spyx.nir.to_nir()
Implement spyx.nir.from_nir()
Support feed-forward network import
Support ConvNet import
Support explicitly recurrent import
Implement FFN exporting
Implement CSNN exporting
Implement RSNN exporting
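A rough sketch of the feed-forward export path using the `nir` package's node classes; the parameter-tree layout, layer-name conventions, and helper name are assumptions, not Spyx's final API:

```python
import nir
import numpy as np

def to_nir(spyx_params, dt=1e-4):
    # Hypothetical export: walk the Haiku parameter PyTree and emit
    # matching NIR nodes for each layer.
    nodes = []
    for layer_name, params in spyx_params.items():
        if "linear" in layer_name:
            w = np.asarray(params["w"]).T  # NIR stores (out, in) weights
            b = np.asarray(params["b"]) if "b" in params else np.zeros(w.shape[0])
            nodes.append(nir.Affine(weight=w, bias=b))
        elif "lif" in layer_name:
            beta = np.asarray(params["beta"])
            tau = dt / (1.0 - beta)  # beta ~ 1 - dt/tau for small dt
            nodes.append(nir.LIF(tau=tau, r=np.ones_like(tau),
                                 v_leak=np.zeros_like(tau),
                                 v_threshold=np.ones_like(tau)))
    return nir.NIRGraph.from_list(*nodes)

# nir.write("model.nir", to_nir(params))  # HDF5 save
```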
Currently, if the user specifies the inverse time constant/beta value, it will not be tracked in the network's PyTree, making the layer invisible when exporting it to NIR for cross-platform use.
Each neuron model needs an "else" clause that calls hk.get_parameter() with the init argument set to the user-specified value in order to fix this.
See the fixed LI neuron as an example of what needs to be done for the other neuron models (except for IF, which will need a different approach to become visible).
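A sketch of the needed change for a representative neuron (module layout illustrative, not the actual Spyx source):

```python
import haiku as hk
import jax.numpy as jnp

class LIF(hk.Module):
    def __init__(self, size, beta=None, name=None):
        super().__init__(name=name)
        self.size = size
        self.user_beta = beta

    def __call__(self, x, v):
        if self.user_beta is None:
            beta = hk.get_parameter("beta", [self.size],
                                    init=hk.initializers.Constant(0.8))
        else:
            # The "else" clause: still register via hk.get_parameter, but
            # seed init with the user-specified value so the parameter
            # remains visible in the PyTree (and thus exportable to NIR).
            beta = hk.get_parameter(
                "beta", [self.size],
                init=hk.initializers.Constant(self.user_beta))
        v = beta * v + x
        spikes = (v > 1.0).astype(x.dtype)
        return spikes, v - spikes  # soft reset
```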
Make spyx.optimize:
https://jax.readthedocs.io/en/latest/pallas/design.html
Create a Pallas kernel for LIF neurons and investigate whether it is faster than statically unrolling the LIF as defined in Spyx.
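A minimal Pallas sketch of a fused single-step LIF update to benchmark against the unrolled version (the shapes, threshold, and soft-reset convention are assumptions):

```python
import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def _lif_kernel(x_ref, v_ref, s_out_ref, v_out_ref, *, beta):
    # Fused decay + integrate + threshold + soft reset in one kernel.
    v = beta * v_ref[...] + x_ref[...]
    s = (v > 1.0).astype(v.dtype)
    s_out_ref[...] = s
    v_out_ref[...] = v - s

@functools.partial(jax.jit, static_argnames="beta")
def lif_step(x, v, beta=0.9):
    return pl.pallas_call(
        functools.partial(_lif_kernel, beta=beta),
        out_shape=(jax.ShapeDtypeStruct(x.shape, x.dtype),
                   jax.ShapeDtypeStruct(v.shape, v.dtype)),
    )(x, v)
```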
There might be a bug with respect to time_major = True/False that could produce incorrect results, since scanning over one axis is faster than scanning over the other (time steps vs. the channel dimension...).
Test and make an example using AEStream to load data into Spyx.
An interesting question is whether AEStream can write to JAX arrays being used by Spyx inside a compiled loop.
I am experimenting with CartPole using a spiking neural network in Spyx, but I got the following error. Could you please assist?
AttributeError Traceback (most recent call last)
in <cell line: 4>()
2 init_state = (jnp.zeros(64), jnp.zeros(2))
3 policy = hk.without_apply_rng(hk.transform(controller))
----> 4 policy_params = policy.init(rng=key, x=adapter(obs), state=init_state)
2 frames
in controller(x, state)
7 core = hk.DeepRNN([
8 hk.Linear(64, with_bias=False),
----> 9 snn.LIF(64, beta=0.8, activation=spyx.axn.Axon()),
10 hk.Linear(2, with_bias=False),
11 snn.LI(2)
AttributeError: module 'spyx.axn' has no attribute 'Axon'
Right now there are no test cases verifying that changes to the code base still train models correctly, and no way to detect whether changes to upstream packages break the library.
Right now the Sphinx documentation isn't showing the members of each submodule.
I would love to not need to install torchvision and PyTorch just to load data into Spyx, as Torch often lags JAX in terms of CUDA support.
It's a common mistake to get the output shape wrong relative to the target label shape; adding a check in the loss and accuracy functions that throws an error would be useful.
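A sketch of such a guard in an accuracy helper (illustrative, assuming integer class labels, not the current Spyx code):

```python
import jax.numpy as jnp

def accuracy(preds, targets):
    # Proposed guard: fail loudly on mismatched shapes instead of
    # letting broadcasting silently produce a wrong metric.
    if preds.shape[:-1] != targets.shape:
        raise ValueError(
            f"prediction shape {preds.shape} is incompatible with "
            f"target shape {targets.shape}")
    return jnp.mean(jnp.argmax(preds, axis=-1) == targets)
```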
Right now the dataloading part of the library uses torch dataloaders and torchvision. This increases the dependency requirements and makes installing less convenient, so it would be nice to use a more JAX native approach if possible.
This might be a potential solution:
https://github.com/birkhoffg/jax-dataloader
Ideally, make the data dependencies optional so that they are only installed explicitly via "pip install spyx[data]".
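A sketch of the packaging change via setuptools extras (dependency names illustrative):

```python
# setup.py sketch: keep torch-based loaders optional so "pip install spyx"
# stays lightweight and "pip install spyx[data]" pulls in the data stack.
from setuptools import setup, find_packages

setup(
    name="spyx",
    packages=find_packages(),
    install_requires=["jax", "dm-haiku", "optax"],
    extras_require={"data": ["torch", "torchvision", "tonic"]},
)
```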
It would be cool to see how synthetic gradients could fit into the SNN training schema:
https://arxiv.org/abs/1608.05343
https://greydanus.github.io/2016/11/26/synthetic-gradients/