
e3nn-jax's Introduction

e3nn-jax

import jax
import e3nn_jax as e3nn

# Create a random array made of a scalar (0e) and a vector (1o)
array = e3nn.normal("0e + 1o", jax.random.PRNGKey(0))

print(array)  
# 1x0e+1x1o [ 1.8160863  -0.75488514  0.33988908 -0.53483534]

# Compute the norms
norms = e3nn.norm(array)
print(norms)
# 1x0e+1x0e [1.8160863  0.98560894]

# Compute the norm of the full array
total_norm = e3nn.norm(array, per_irrep=False)
print(total_norm)
# 1x0e [2.0662997]

# Compute the tensor product of the array with itself
tp = e3nn.tensor_square(array)
print(tp)
# 2x0e+1x1o+1x2e
# [ 1.9041989   0.25082085 -1.3709364   0.61726785 -0.97130704  0.40373924
#  -0.25657722 -0.18037902 -0.18178469 -0.14190137]

🚀 44% faster than PyTorch*

*Speed comparison done with a full model (MACE) during training (revMD-17) on a GPU (NVIDIA RTX A5000).

Please always check the CHANGELOG for breaking changes.

Installation

To install the latest released version:

pip install --upgrade e3nn-jax

To install the latest GitHub version:

pip install git+https://github.com/e3nn/e3nn-jax.git

Need Help?

Ask a question in the discussions tab.

What is different from the PyTorch version?

The main difference is the IrrepsArray class, which bundles the irreps (Irreps) together with the data array.
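
A minimal sketch, assuming the two-argument IrrepsArray(irreps, array) constructor (consistent with the printed output shown above):

import jax.numpy as jnp
import e3nn_jax as e3nn

# bundle a scalar (0e) and a vector (1o) with their irreps
x = e3nn.IrrepsArray("0e + 1o", jnp.array([1.0, 0.0, 1.0, 2.0]))
print(x.irreps)  # 1x0e+1x1o
print(x.array)   # the underlying array of shape (4,)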

Citing

  • Euclidean Neural Networks
@misc{thomas2018tensorfieldnetworksrotation,
      title={Tensor field networks: Rotation- and translation-equivariant neural networks for 3D point clouds}, 
      author={Nathaniel Thomas and Tess Smidt and Steven Kearnes and Lusann Yang and Li Li and Kai Kohlhoff and Patrick Riley},
      year={2018},
      eprint={1802.08219},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/1802.08219}, 
}

@misc{weiler20183dsteerablecnnslearning,
      title={3D Steerable CNNs: Learning Rotationally Equivariant Features in Volumetric Data}, 
      author={Maurice Weiler and Mario Geiger and Max Welling and Wouter Boomsma and Taco Cohen},
      year={2018},
      eprint={1807.02547},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/1807.02547}, 
}

@misc{kondor2018clebschgordannetsfullyfourier,
      title={Clebsch-Gordan Nets: a Fully Fourier Space Spherical Convolutional Neural Network}, 
      author={Risi Kondor and Zhen Lin and Shubhendu Trivedi},
      year={2018},
      eprint={1806.09231},
      archivePrefix={arXiv},
      primaryClass={stat.ML},
      url={https://arxiv.org/abs/1806.09231}, 
}
  • e3nn
@misc{e3nn_paper,
    doi = {10.48550/ARXIV.2207.09453},
    url = {https://arxiv.org/abs/2207.09453},
    author = {Geiger, Mario and Smidt, Tess},
    keywords = {Machine Learning (cs.LG), Artificial Intelligence (cs.AI), Neural and Evolutionary Computing (cs.NE), FOS: Computer and information sciences, FOS: Computer and information sciences},
    title = {e3nn: Euclidean Neural Networks},
    publisher = {arXiv},
    year = {2022},
    copyright = {Creative Commons Attribution 4.0 International}
}

@software{e3nn,
  author       = {Mario Geiger and
                  Tess Smidt and
                  Alby M. and
                  Benjamin Kurt Miller and
                  Wouter Boomsma and
                  Bradley Dice and
                  Kostiantyn Lapchevskyi and
                  Maurice Weiler and
                  Michał Tyszkiewicz and
                  Simon Batzner and
                  Dylan Madisetti and
                  Martin Uhrin and
                  Jes Frellsen and
                  Nuri Jung and
                  Sophia Sanborn and
                  Mingjian Wen and
                  Josh Rackers and
                  Marcel Rød and
                  Michael Bailey},
  title        = {Euclidean neural networks: e3nn},
  month        = apr,
  year         = 2022,
  publisher    = {Zenodo},
  version      = {0.5.0},
  doi          = {10.5281/zenodo.6459381},
  url          = {https://doi.org/10.5281/zenodo.6459381}
}

e3nn-jax's People

Contributors

ameya98, jamaliki, mariogeiger, mitkotak, pabloferz, sauravmaheshkar, songk42


e3nn-jax's Issues

Ensure consistent code formatting.

To make the code easier to read (and contribute to), I propose using the black code formatting tool.
It is very easy to install:

pip install black

then run:

black file.py

to format file.py.

Please wrap `clebsch_gordan` inside `functools.cache`

Right now, CG coefficients are computed not once but three times at every tensor-product call, and at every inhomogeneous IrrepsArray.__mul__ call.

I'm not sure I fully understand the 3x part, though I guess it has to do with the flax module internals wrapping elementwise_tensor_product.

# test.py
import jax
import jax.numpy as np

import e3nn_jax as e3nn

# some dummy e3-array
a = e3nn.IrrepsArray.zeros("8x0e + 8x1o + 8x2e", (1,))
# some scalar array
n = a.irreps.num_irreps
b = e3nn.IrrepsArray.zeros(f"{n}x0e", (1,))

# __mul__ calls elementwise_tensor_product
for i in range(3):
    print(a * b)

Outputs:

I'm computing Clebsch-Gordan coefficients!
I'm computing Clebsch-Gordan coefficients!
I'm computing Clebsch-Gordan coefficients!
8x0e+8x1o+8x2e
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
I'm computing Clebsch-Gordan coefficients!
I'm computing Clebsch-Gordan coefficients!
I'm computing Clebsch-Gordan coefficients!
8x0e+8x1o+8x2e
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
I'm computing Clebsch-Gordan coefficients!
I'm computing Clebsch-Gordan coefficients!
I'm computing Clebsch-Gordan coefficients!
8x0e+8x1o+8x2e
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

See #71
Thanks!
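
A minimal sketch of the requested change (illustrative only, not the library's actual implementation): memoize the coefficient computation on its hashable integer arguments so repeated tensor-product calls reuse the cached array.

import functools
import numpy as np

@functools.cache  # compute once per (l1, l2, l3), then reuse
def clebsch_gordan(l1: int, l2: int, l3: int) -> np.ndarray:
    print("I'm computing Clebsch-Gordan coefficients!")
    # placeholder for the actual computation
    return np.zeros((2 * l1 + 1, 2 * l2 + 1, 2 * l3 + 1))

clebsch_gordan(1, 1, 2)  # prints once
clebsch_gordan(1, 1, 2)  # cached: no message, no recomputation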

Add instructions for local development.

Should cover (a possible workflow is sketched after this list):

  • How to create a virtual environment?
  • What dependencies to install?
  • How to install e3nn-jax from local source?
  • How to run tests?
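
A possible local-development workflow covering the points above (a sketch assuming a standard pip-based setup, not official instructions):

python -m venv .venv
source .venv/bin/activate
pip install -e .        # install e3nn-jax from the local source tree
pip install pytest      # plus any test dependencies
pytest tests/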

Currently, I get this error when running tests/dropout_test.py (after my attempt at installing from local source):

_____________________________________________________________________ ERROR collecting tests/dropout_test.py _______________________________________________________________________
tests/dropout_test.py:3: in <module>
    from e3nn_jax import Dropout, Irreps, IrrepsData
e3nn_jax/__init__.py:35: in <module>
    from ._irreps import Irrep, Irreps, IrrepsData
e3nn_jax/_irreps.py:7: in <module>
    import chex
../../miniconda3/lib/python3.8/site-packages/chex/__init__.py:17: in <module>
    from chex._src.asserts import assert_axis_dimension
../../miniconda3/lib/python3.8/site-packages/chex/_src/asserts.py:26: in <module>
    from chex._src import asserts_internal as _ai
../../miniconda3/lib/python3.8/site-packages/chex/_src/asserts_internal.py:32: in <module>
    from chex._src import pytypes
../../miniconda3/lib/python3.8/site-packages/chex/_src/pytypes.py:40: in <module>
    CpuDevice = jax.lib.xla_extension.CpuDevice
E   AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'

TensorProduct and FullyConnectedTensorProduct in e3nn_jax

Hi, do we have equivalents of e3nn.o3.TensorProduct and e3nn.o3.FullyConnectedTensorProduct (torch version) in e3nn_jax? I am attaching the code here for your reference, which is taken from here:

self.sc = FullyConnectedTensorProduct(self.irreps_in, self.irreps_node_attr, self.irreps_out)

self.lin1 = FullyConnectedTensorProduct(self.irreps_in, self.irreps_node_attr, self.irreps_in)

irreps_mid = []
instructions = []
for i, (mul, ir_in) in enumerate(self.irreps_in):
    for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
        for ir_out in ir_in * ir_edge:
            if ir_out in self.irreps_out:
                k = len(irreps_mid)
                irreps_mid.append((mul, ir_out))
                instructions.append((i, j, k, 'uvu', True))
irreps_mid = o3.Irreps(irreps_mid)
irreps_mid, p, _ = irreps_mid.sort()

instructions = [
    (i_1, i_2, p[i_out], mode, train)
    for i_1, i_2, i_out, mode, train in instructions
]

tp = TensorProduct(
    self.irreps_in,
    self.irreps_edge_attr,
    irreps_mid,
    instructions,
    internal_weights=False,
    shared_weights=False,
)
self.fc = FullyConnectedNet([number_of_basis] + radial_layers * [radial_neurons] + [tp.weight_numel], torch.nn.functional.silu)
self.tp = tp

self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_out)

Thanks for your help.
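
For what it's worth, a hedged sketch of one way to get similar behaviour in e3nn-jax (composing e3nn.tensor_product with e3nn.haiku.Linear, the same pattern used in the benchmark script further down this page), not an authoritative answer:

import jax
import haiku as hk
import e3nn_jax as e3nn

def fctp(x, y):
    # full tensor product followed by an equivariant linear layer,
    # playing a role similar to FullyConnectedTensorProduct
    return e3nn.haiku.Linear("0e + 1o")(e3nn.tensor_product(x, y))

f = hk.without_apply_rng(hk.transform(fctp))
x = e3nn.normal("8x0e + 8x1o", jax.random.PRNGKey(0))
y = e3nn.normal("0e + 1o", jax.random.PRNGKey(1))
w = f.init(jax.random.PRNGKey(2), x, y)
print(f.apply(w, x, y).irreps)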

Batchnorm

I implemented a batch norm based on the e3nn repository. I can paste the code here, or, if you give me access to make a branch, I can create a PR. @mariogeiger

Error when upgrading to `0.8.0`

Hi, and thanks again for this great package!

I recently upgraded to version 0.8.0, and now, when simply importing e3nn_jax, the decorator overload_for_irreps_without_array throws the following error:

E   TypeError: <Signature (input1: e3nn_jax._irreps_array.IrrepsArray, input2: e3nn_jax._irreps_array.IrrepsArray, filter_ir_out: Optional[List[e3nn_jax._irreps.Irrep]] = None, irrep_normalization: Optional[str] = None)> is not a callable object

Obviously replacing the decorator with

def overload_for_irreps_without_array(
    irrepsarray_argnums=None, irrepsarray_argnames=None, shape=()
):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper

    return decorator 

works but isn't satisfying.

I have:

Python 3.9.13 (main, May 18 2022, 00:00:00)
[GCC 11.3.1 20220421 (Red Hat 11.3.1-2)] on linux
absl-py==1.0.0
aiohttp==3.8.1
aiosignal==1.2.0
antlr4-python3-runtime==4.8
asttokens==2.0.5
astunparse==1.6.3
async-timeout==4.0.2
attrs==21.4.0
autograd==1.3
backcall==0.2.0
black==22.1.0
blessings==1.7
cachetools==5.0.0
certifi==2021.10.8
charset-normalizer==2.0.12
chex==0.1.3
click==8.0.4
cloudpickle==2.0.0
colorlog==6.6.0
contextlib2==21.6.0
cycler==0.11.0
debugpy==1.5.1
decorator==5.1.1
Deprecated==1.2.13
dill==0.3.4
distrax==0.1.2
dm-haiku==0.0.7
dm-tree==0.1.6
docker-pycreds==0.4.0
e3nn-jax==0.8.0
einops==0.4.1
entrypoints==0.4
etils==0.6.0
executing==0.8.3
flatbuffers==2.0
fonttools==4.30.0
frozenlist==1.3.0
fsspec==2022.5.0
future==0.18.2
gast==0.5.3
-e git+ssh://[email protected]/oxcsml/geomstats.git@78111276c9b2f98bdc24826f0c2d8729a4bbcca0#egg=geomstats&subdirectory=../../geomstats
gitdb==4.0.9
GitPython==3.1.27
google-auth==2.6.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
googleapis-common-protos==1.55.0
gpflow @ git+https://github.com/GPflow/GPflow.git@7b989c804e280ca71319ac3899c4f359f1a795ed
GPJax @ git+https://github.com/thomaspinder/GPJax.git@76527de93499dcd00523c5f88fd6d8ed827d1e88
gpustat==0.6.0
grpcio==1.44.0
h5py==3.6.0
hydra-colorlog==1.1.0
hydra-core==1.1.1
hydra-joblib-launcher==1.1.5
hydra-submitit-launcher @ git+https://github.com/emilemathieu/hydra.git@7c41ae8f080d577efccf902aa9c955d7c9a2d864#subdirectory=plugins/hydra_submitit_launcher
idna==3.3
importlib-metadata==4.11.2
importlib-resources==5.8.0
iniconfig==1.1.1
install==1.3.5
ipykernel==6.9.1
ipython==8.1.1
isort==5.10.1
jax==0.3.5
jaxlib==0.3.5+cuda11.cudnn805
jaxtyping==0.0.2
jedi==0.18.1
Jinja2==3.1.2
jmp==0.0.2
joblib==1.1.0
jupyter-client==7.1.2
jupyter-core==4.9.2
keras==2.8.0
Keras-Preprocessing==1.1.2
kiwisolver==1.3.2
lark==1.1.2
libclang==13.0.0
Markdown==3.3.6
MarkupSafe==2.1.1
matplotlib==3.5.1
matplotlib-inline==0.1.3
ml-collections==0.1.0
mpmath==1.2.1
multidict==6.0.2
multipledispatch==0.6.0
mypy-extensions==0.4.3
nest-asyncio==1.5.4
numpy==1.22.3
nvidia-ml-py3==7.352.0
oauthlib==3.2.0
omegaconf==2.1.1
opt-einsum==3.3.0
optax==0.1.3
packaging==21.3
pandas==1.4.1
parso==0.8.3
pathspec==0.9.0
pathtools==0.1.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.0.1
platformdirs==2.5.1
pluggy==1.0.0
promise==2.3
prompt-toolkit==3.0.28
protobuf==3.19.0
psutil==5.9.0
ptyprocess==0.7.0
pure-eval==0.2.2
py==1.11.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
Pygments==2.11.2
pyparsing==3.0.7
pytest==7.1.2
python-dateutil==2.8.2
pytz==2021.3
PyYAML==6.0
pyzmq==22.3.0
requests==2.27.1
requests-oauthlib==1.3.1
rsa==4.8
scikit-learn==1.0.2
scipy==1.8.0
-e git+ssh://[email protected]/oxcsml/score-sde.git@d3536f1167762eba27e9bd5b404376701d8441c6#egg=score_sde
seaborn==0.11.2
sentry-sdk==1.5.7
setGPU==0.0.7
setproctitle==1.2.2
shortuuid==1.0.8
six==1.16.0
smmap==5.0.0
stack-data==0.2.0
submitit @ git+https://github.com/emilemathieu/submitit.git@aeccb447a1e1a33ec60749c70594eef3d888b16d
sympy==1.10.1
tabulate==0.8.9
tensorboard==2.8.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.8.1
tensorflow-datasets==4.5.2
tensorflow-estimator==2.8.0
tensorflow-io-gcs-filesystem==0.24.0
tensorflow-metadata==1.7.0
tensorflow-probability==0.16.0
termcolor==1.1.0
tf-estimator-nightly==2.8.0.dev2021122109
threadpoolctl==3.1.0
tomli==2.0.1
toolz==0.11.2
tornado==6.1
tqdm==4.63.0
traitlets==5.1.1
typeguard==2.13.3
typing_extensions==4.3.0
urllib3==1.26.8
wandb==0.12.11
wcwidth==0.2.5
Werkzeug==2.0.3
wrapt==1.14.0
yarl==1.7.2
yaspin==2.1.0
zipp==3.7.0

"tensor_product_with_spherical_harmonics" slower + uses more memory

I compared (eSCN) tensor_product_with_spherical_harmonics vs. (Manual) tensor_product with spherical harmonics for speed and memory use. Despite the claims of the eSCN paper, I found the Manual variant to be faster and to use less memory. Here is my code to test speed:

import jax
import e3nn_jax as e3nn
import time


B = 128
n_edges = 600
leading_shape = (B, n_edges)
degree = 5
input_irreps = "128x0e + 128x1o"
n_iters = 100

sh_irreps = e3nn.Irreps.spherical_harmonics(degree)

@jax.jit
def manual(inputs, vectors):
    out = e3nn.tensor_product(input1=inputs,
                              input2=e3nn.spherical_harmonics(input=vectors, irreps_out=sh_irreps, normalize=True),
                              filter_ir_out=None)
    return out

@jax.jit
def escn(inputs, vectors):
    out = e3nn.tensor_product_with_spherical_harmonics(input=inputs, 
                                                       vector=vectors, 
                                                       degree=degree)
    return out

inputs = e3nn.normal(irreps=input_irreps, key=jax.random.PRNGKey(0), leading_shape=leading_shape)
vectors = e3nn.normal(irreps="1o", key=jax.random.PRNGKey(0), leading_shape=leading_shape)

out1 = manual(inputs, vectors)
print(out1.irreps, out1.shape, out1.array.sum())
del out1
out2 = escn(inputs, vectors)
print(out2.irreps, out2.shape, out2.array.sum())
del out2
print("----------------------------------")

def run_manual():
    t0 = time.time()
    for i in range(1, n_iters+1):
        inputs = e3nn.normal(irreps=input_irreps, key=jax.random.PRNGKey(i), leading_shape=leading_shape)
        vectors = e3nn.normal(irreps="1o", key=jax.random.PRNGKey(i), leading_shape=leading_shape)
        out = manual(inputs, vectors)
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)
        del out
    t1 = time.time()
    print("Manual:", (t1-t0) / n_iters, "s")

def run_escn():
    t0 = time.time()
    for i in range(1, n_iters+1):
        inputs = e3nn.normal(irreps=input_irreps, key=jax.random.PRNGKey(i), leading_shape=leading_shape)
        vectors = e3nn.normal(irreps="1o", key=jax.random.PRNGKey(i), leading_shape=leading_shape)
        out = escn(inputs, vectors)
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)
        del out
    t1 = time.time()
    print("eSCN:", (t1-t0) / n_iters, "s")

print(leading_shape)
run_manual()
run_escn()

I've attached the output on an A6000 GPU. On further inspection, most of the cost seems to be from the rotation operation at the end.

[Screenshot: benchmark output on an A6000 GPU, 2024-04-15]

Is this a bug or expected?
Also note the difference in output values.

Availability of FullTensorProduct function similar to PyTorch version

Hello,

I've been using the FullTensorProduct function in the PyTorch version of e3nn and found it very beneficial. As I'm transitioning to JAX and Haiku, I was exploring e3nn-jax and couldn't locate a similar function.

Is there an equivalent to FullTensorProduct in the JAX version? If not, are there plans to add it?

Thank you for your great work on this library.
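
To my understanding (a hedged sketch, not an authoritative answer), the functional e3nn.tensor_product with filter_ir_out=None plays the role of the full tensor product in e3nn-jax:

import jax
import e3nn_jax as e3nn

x = e3nn.normal("2x0e + 1x1o", jax.random.PRNGKey(0))
y = e3nn.normal("1x1o", jax.random.PRNGKey(1))
full = e3nn.tensor_product(x, y)  # no filtering: all output irreps are kept
print(full.irreps)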

Add custom initialization options in e3nn_jax.flax.Linear

Description:

Currently, the e3nn_jax.flax.Linear module lacks flexibility in providing custom initialization options. This limitation prevents users from initializing network weights according to specific requirements, potentially impacting model performance and convergence speed.

Expected Outcome:

It would be useful to add an option to the e3nn_jax.flax.Linear module that lets users pass a custom initialization function or parameters, giving more control over the weight initialization process.

Proposed Solution:

  1. Add a new parameter named initializer to the constructor of e3nn_jax.flax.Linear that can accept either a function or a dictionary specifying the initialization method and its parameters.
  2. Use the provided initializer parameter within the module's __init__ method to initialize the weights.

Current Code:

In e3nn_jax.flax.linear, the code is:

def param(name, shape, std, dtype):
    return self.param(
        name, flax.linen.initializers.normal(stddev=std), shape, dtype
    )

But when I want to use a uniform initializer, I can't do so unless I modify the e3nn_jax code. Could the option to change the initialization method be added to the official implementation? Thank you!
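
A self-contained sketch of the requested pattern in plain Flax (a hypothetical Linear module, not e3nn_jax's actual implementation): expose the initializer as a module attribute with a default, so users can override it.

from typing import Callable

import jax
import jax.numpy as jnp
import flax.linen as nn

class Linear(nn.Module):
    features: int
    # user-overridable weight initializer, defaulting to a normal distribution
    initializer: Callable = nn.initializers.normal(stddev=1.0)

    @nn.compact
    def __call__(self, x):
        w = self.param("w", self.initializer, (x.shape[-1], self.features), x.dtype)
        return x @ w

# override the default with a uniform initializer
model = Linear(features=4, initializer=nn.initializers.uniform(scale=0.05))
params = model.init(jax.random.PRNGKey(0), jnp.ones((2, 3)))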


IrrepsArray.transform_by_angles and LinearSHTP does not support batch operation

The experimental class LinearSHTP does not support batch operation, i.e., it only supports input of dimension [1, dim of Irreps]. However, eSCN works on IrrepsArrays of dimension [#nodes, dim of Irreps]. I wonder if this is because the rotation procedure in eSCN, implemented by transform_by_angles in LinearSHTP, only supports rotation by the same angle for all node features, while eSCN requires rotation by a different angle for each node (i.e., output[i, ...] = D[i, ...] * input[i, ...] for each node index i, where D contains representations of rotation matrices).

To support this kind of operation, a quick solution that comes to mind is to rewrite transform_by_angles as

D = {
    ir: ir.D_from_angles(alpha, beta, gamma, k)
    for ir in {ir for _, ir in self.irreps}
}
if inverse:
    D = {ir: jnp.swapaxes(D[ir], -2, -1) for ir in D}
new_chunks = [
    # batched einsum: D[ir] now carries leading batch axes, one rotation per node
    jnp.einsum("...ij,...uj->...ui", D[ir], x) if x is not None else None
    for (mul, ir), x in zip(self.irreps, self.chunks)
]
return e3nn.from_chunks(self.irreps, new_chunks, self.shape[:-1], self.dtype)

where alpha, beta, and gamma are now arrays of shape self.shape[:-1].

Since I am not a developer of e3nn, I may be missing some implementation subtleties. Are there any concerns about implementing this function?

(Although I could also use vmap with the current LinearSHTP to implement eSCN, this may require initializing an nn.Module inside __call__. I wonder if this would cause problems.)
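
A hedged sketch of the vmap alternative mentioned above (rotate_per_node is a hypothetical helper; it maps the existing, unbatched transform_by_angles over the leading node axis so each node gets its own angles):

import jax
import e3nn_jax as e3nn

def rotate_per_node(x, alpha, beta, gamma):
    # x: IrrepsArray with leading shape (num_nodes,)
    # alpha, beta, gamma: arrays of shape (num_nodes,)
    return jax.vmap(lambda xi, a, b, c: xi.transform_by_angles(a, b, c))(x, alpha, beta, gamma)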

Backward pass runtime degradation (Linear + `tensor_product`) in the latest versions

Hey. The tensor_product implementation has changed substantially from 0.17.3 to 0.19.3, which made it simpler and much faster in inference. This does not seem to be the case for backpropagation (especially at high Ls). I noticed this while training segnn with L=3 on a problem that used to take hours with e3nn-jax 0.17.3, but with 0.19.3 did not finish in days.

Model init and backward jitting also always take a lot of time, which is not a big deal in practice but could mean something.

[Chart: e3nn-jax version vs. backprop time (Linear + tensor_product)]

On the other hand, differentiating tensor_product only (without linear layer) is about the same/faster in the newer versions, which is unexpected.

[Chart: e3nn-jax version vs. backprop time (tensor_product only)]


Detailed results are in the table below (all times in ms). Forward and backward are 100 (compiled) passes on a single Linear + tensor_product layer, with the spherical harmonics at the respective order as input and 1x0e + 1x1o as output. Jit time refers to jitting the backward pass.

                 L=0                        L=3                        L=5
         Frwd    Bkwd   Jit         Frwd    Bkwd    Jit        Frwd    Bkwd     Jit
0.17.3   16.63   10.70  397.70      30.32   27.91   751.56     57.19   80.14    1232.86
0.19.1   15.14   22.61  747.90      32.16   500.49  5220.34    52.06   1275.34  34265.00
0.19.2   16.19   35.12  1123.24     26.00   499.58  8296.36    43.45   2113.12  60620.00

Reproduction script

import os
import jax
import jax.numpy as jnp 
import time
import haiku as hk
import warnings
from multiprocessing import Process, Queue

warnings.filterwarnings("ignore")


def tp_test(queue):
    import e3nn_jax as e3nn

    def run(tp_fn, ir, bs=100, reruns=100):    
        x = e3nn.normal(ir, key=jax.random.PRNGKey(0), leading_shape=(bs,))

        st = time.perf_counter_ns()
        tp = hk.without_apply_rng(hk.transform(tp_fn))
        w = tp.init(jax.random.PRNGKey(0), x)
        apply = jax.jit(tp.apply)
        jax.block_until_ready(apply(w, x))
        init_end = (time.perf_counter_ns() - st) / 1e6

        st = time.perf_counter_ns()
        for _ in range(reruns):
            jax.block_until_ready(apply(w, x))
        forward_end = (time.perf_counter_ns() - st) / 1e6

        # backwards
        @jax.jit
        def grad_fn(x):
            def loss_fn(x):
                return jnp.mean(apply(w, x))
            return jax.grad(loss_fn)(x)

        st = time.perf_counter_ns()
        jax.block_until_ready(grad_fn(x))
        jit_end = (time.perf_counter_ns() - st) / 1e6

        st = time.perf_counter_ns()
        for _ in range(reruns):
            jax.block_until_ready(grad_fn(x))
        backward_end = (time.perf_counter_ns() - st) / 1e6
        
        return forward_end, backward_end, init_end, jit_end


    print(f"Version {e3nn.__version__}")

    results = {}

    for L in [1, 3, 5]:
        ir_L = e3nn.Irreps.spherical_harmonics(L) * 8
        def only_tp(x):
            return e3nn.tensor_product(x, x, filter_ir_out="1x0e+1x1o").array
        
        frwd_time, bkwd_time, init_time, jit_time = run(only_tp, ir_L)
        print(
            f"[L={L} TP only]: forward={frwd_time:.2f}ms (init={init_time:.2f}ms) - "
            f"backward={bkwd_time:.2f}ms (jit={jit_time:.2f}ms)"
        )
                
        def linear_tp(x):
            return e3nn.haiku.Linear("1x0e+1x1o")(e3nn.tensor_product(x, x)).array
        
        frwd_time, bkwd_time, init_time, jit_time = run(linear_tp, ir_L)
        print(
            f"[L={L} Linear TP]: forward={frwd_time:.2f}ms (init={init_time:.2f}ms) - "
            f"backward={bkwd_time:.2f}ms (jit={jit_time:.2f}ms)"
        )

        results[f"L={L}"] = bkwd_time
    
    queue.put({e3nn.__version__: results})


if __name__ == "__main__":

    queue = Queue()

    os.system("pip install e3nn-jax==0.17.3 >/dev/null 2>&1")
    p = Process(target=tp_test, args=(queue,))
    p.start()
    p.join()

    print("---")

    os.system("pip install e3nn-jax==0.19.1 >/dev/null 2>&1")
    p = Process(target=tp_test, args=(queue,))
    p.start()
    p.join()

    print("---")

    os.system("pip install e3nn-jax==0.19.2 >/dev/null 2>&1")
    p = Process(target=tp_test, args=(queue,))
    p.start()
    p.join()

    try:
        import matplotlib.pyplot as plt
        import pandas as pd
        results = {}
        while not queue.empty():
            results.update(queue.get())
        results = pd.DataFrame(results)
        results.plot.bar()
        plt.title("Backward time (Linear + tensor_product only)")
        plt.ylabel("time [ms]")
        plt.savefig("e3nn_backward.png")
    
    except ImportError:
        pass

migrate to `pyproject.toml` for packaging

pyproject.toml is the new standard for packaging Python packages; setup.py-based builds are now deprecated (the standard was first introduced in PEP 518 and later expanded in PEP 517, PEP 621, and PEP 660).

Shifting to a pyproject.toml-based build would lead to a much simpler project structure (a single pyproject.toml as opposed to setup.py + pyproject.toml + requirements-dev.txt).

Nothing much would change for the end user; the way to install the package remains the same. The only minimal change would be to build releases with python -m build.

I'm happy to take this up 😄


Gate output irreps

From what I understand from here, we expect the output of the gate to be the direct sum of the scalar irreps and the gated irreps. But in e3nn_jax, it is the direct sum of the scalar irreps, the gate irreps, and the gated irreps. I think the reason is the line below. Is this an issue? @mariogeiger

self.irreps_in = irreps_scalars + irreps_gates + irreps_gated

refactor test suite

I'd like to suggest a refactoring of the current test suite: keep all tests under the tests/ directory, following the same structure as the module, i.e.

e3nn_jax/
 ...
tests/
 noxfile.py
 conftest.py
 ...

A lot of modern JAX-based frameworks follow the same structure, viz. flax, jax-md, jax (itself), and kfac-jax.

FullyConnectedTensorProduct Feature Discrepancy in JAX vs. Torch

In PyTorch's FullyConnectedTensorProduct, there's an option to set internal_weights to False. I couldn't find a similar feature in the JAX version. Also, I'm curious whether JAX and Haiku allow using predefined weights for this operation. If I'm overlooking something or I'm incorrect, please correct me. If not, integrating these features would be great.

GitHub install fails silently

Installing from GitHub used to work:

pip install git+https://github.com/e3nn/e3nn-jax

but now

import e3nn_jax

raises an ImportError.
