allegro-jax's Issues
Question about line 47 in allegro.py
Hi Mario, thank you very much for the project!
While reading the code, I was a bit confused by line 47 of allegro.py. Should the number of neurons of w be the num_irreps of V rather than its mul_gcd?
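For anyone following along, the two quantities in the question can differ a lot. As far as I understand e3nn-jax, `num_irreps` is the sum of the multiplicities of an irreps object and `mul_gcd` is the greatest common divisor of those multiplicities. The sketch below reimplements both in plain Python on a made-up irreps list (the tuples and helper names are hypothetical, not taken from allegro.py), only to show the size difference. It says nothing about what line 47 actually does.

```python
from functools import reduce
from math import gcd

# Hypothetical irreps "4x0e + 4x1o + 2x2e", written as (multiplicity, l, parity).
# This is an illustration only, not the V from allegro.py.
example_irreps = [(4, 0, "e"), (4, 1, "o"), (2, 2, "e")]


def num_irreps(irreps):
    """Total number of irreps: the sum of all multiplicities."""
    return sum(mul for mul, _, _ in irreps)


def mul_gcd(irreps):
    """Greatest common divisor of the multiplicities."""
    return reduce(gcd, (mul for mul, _, _ in irreps), 0)


print("num_irreps:", num_irreps(example_irreps))
print("mul_gcd:", mul_gcd(example_irreps))
```

So a weight vector sized by `mul_gcd` would be shared across groups of irreps, while one sized by `num_irreps` would give each irrep its own weight. Which of the two is intended here is the open question.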
Performance issues
Hi Mario,
Thanks for the implementation! I wanted to ask about the inference performance of allegro-jax, as it seems to be quite different from what I'm seeing with the official PyTorch implementation. In the lingo of Allegro's config file, here are the details of the model I'm testing:
- r_max = 6.0
- num_layer = 1
- l_max = 1
- parity = SO3
- num_tensor_features = 4
- two_body_latent_mlp_latent_dimensions = [64,64]
- latent_mlp_latent_dimensions = [64, 64, 64]
- edge_eng_mlp_latent_dimensions = [32]
Here's how I'm setting up the model in the JAX implementation:
class FlaxModel(flax.linen.Module):
    def setup(self):
        self.model = Allegro(
            avg_num_neighbors=40.0,
            max_ell=1,
            irreps=4 * e3nn.Irreps('0e + 1o'),
            mlp_n_hidden=64,
            mlp_n_layers=2,
            radial_cutoff=6.0,
            num_layers=1,
        )

    @flax.linen.compact
    def __call__(self, graph):
        node_attrs = jax.nn.one_hot(graph.nodes["species"], 3)
        vectors = e3nn.IrrepsArray(
            "1o",
            graph.nodes["positions"][graph.receivers]
            - graph.nodes["positions"][graph.senders],
        )
        return jnp.sum(self.model(node_attrs, vectors, graph.senders, graph.receivers).array)
I might be setting up something incorrectly here, but that's my best guess so far. Here's the performance of the models I've tried using a structure with 340 atoms:
PyTorch (with script): 3.5 ms / call
JAX (after JIT and warmup): 204 ms / call
Would you have a guess for what's slowing JAX down?
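To make the JAX number reproducible, here is a minimal sketch of how such a timing is usually taken. It rests on two general facts about JAX: `jax.jit` recompiles whenever input shapes change (for example, a different number of edges per structure), and dispatch is asynchronous, so the timer needs `jax.block_until_ready`. `time_jitted` and `toy_energy` are hypothetical names. `toy_energy` is a stand-in for the model and not Allegro itself, so its timing says nothing about the 204 ms figure.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, warmup=2, repeats=20):
    """Mean wall-clock time per call of jax.jit(fn), in milliseconds."""
    jitted = jax.jit(fn)
    # Warm-up calls trigger tracing and XLA compilation, which should not be timed.
    for _ in range(warmup):
        jax.block_until_ready(jitted(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        # Block on the result so the timer covers the actual computation.
        jax.block_until_ready(jitted(*args))
    return (time.perf_counter() - start) / repeats * 1e3


def toy_energy(positions, senders, receivers):
    # Stand-in for the model: sum of squared edge-vector lengths.
    vectors = positions[receivers] - positions[senders]
    return jnp.sum(vectors**2)


# 340 atoms, as in the structure above; the ring-shaped edge list is made up.
positions = jax.random.normal(jax.random.PRNGKey(0), (340, 3))
senders = jnp.arange(340)
receivers = jnp.roll(senders, 1)

ms_per_call = time_jitted(toy_energy, positions, senders, receivers)
print(f"{ms_per_call:.3f} ms / call")
```

If the measured time drops sharply when the same input is reused with fixed shapes, recompilation is a likely suspect. If it stays high, the cost is probably in the compiled computation itself.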