
escnn's Introduction

E(n)-equivariant Steerable CNNs (escnn)

escnn is a PyTorch extension for equivariant deep learning. It is the successor of the e2cnn library, which supported only planar isometries; escnn supports steerable CNNs equivariant to both 2D and 3D isometries, as well as equivariant MLPs.

If you prefer using Jax, check out the escnn_jax fork of our library!


Equivariant neural networks guarantee a specified transformation behavior of their feature spaces under transformations of their input. For instance, classical convolutional neural networks (CNNs) are by design equivariant to translations of their input. This means that a translation of an image leads to a corresponding translation of the network's feature maps. This package provides implementations of neural network modules which are equivariant under all isometries $\mathrm{E}(2)$ of the image plane $\mathbb{R}^2$ and all isometries $\mathrm{E}(3)$ of the 3D space $\mathbb{R}^3$, that is, under translations, rotations and reflections (and can, potentially, be extended to all isometries $\mathrm{E}(n)$ of $\mathbb{R}^n$). In contrast to conventional CNNs, $\mathrm{E}(n)$-equivariant models are guaranteed to generalize over such transformations, and are therefore more data efficient.

The feature spaces of $\mathrm{E}(n)$-equivariant Steerable CNNs are defined as spaces of feature fields, being characterized by their transformation law under rotations and reflections. Typical examples are scalar fields (e.g. gray-scale images or temperature fields) or vector fields (e.g. optical flow or electromagnetic fields).
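As a minimal illustration of these two field types (a sketch using the gspaces/nn interface introduced in the Getting Started section below):

from escnn import gspaces, nn

# continuous rotations SO(2) acting on the image plane
r2_act = gspaces.rot2dOnR2(N=-1)

scalar_field = nn.FieldType(r2_act, [r2_act.trivial_repr])  # e.g. a gray-scale image
vector_field = nn.FieldType(r2_act, [r2_act.irrep(1)])      # e.g. an optical flow field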

[Figure: examples of feature fields]

Instead of specifying a number of channels, the user specifies the field types and their multiplicities in order to define a feature space. Given specified input and output feature spaces, our R2Conv and R3Conv modules instantiate the most general convolutional mapping between them. Our library provides many other equivariant operations to process feature fields, including nonlinearities, mappings to produce invariant features, batch normalization and dropout.

In theory, feature fields are defined on continuous space $\mathbb{R}^n$. In practice, they are either sampled on a pixel grid or given as a point cloud. escnn represents feature fields by GeometricTensor objects, which wrap a torch.Tensor with the corresponding transformation law. All equivariant operations perform a dynamic type-checking in order to guarantee a geometrically sound processing of the feature fields.
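As a small sketch of this wrapping (reusing the field types defined above; the grid size is arbitrary):

import torch

x = vector_field(torch.randn(1, 2, 9, 9))   # wrap a raw tensor together with its transformation law
g = r2_act.fibergroup.sample()              # a random rotation
x_rot = x.transform(g)                      # rotates both the pixel grid and the 2D vector channels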

To parameterize steerable kernel spaces equivariant to an arbitrary compact group $G$, our paper generalizes the Wigner-Eckart theorem of A Wigner-Eckart Theorem for Group Equivariant Convolution Kernels from $G$-homogeneous spaces to more general spaces $X$ carrying a $G$-action. In short, our method leverages a $G$-steerable basis for unconstrained scalar filters over the whole Euclidean space $\mathbb{R}^n$ to generate steerable kernel spaces with arbitrary input and output field types. For example, the left side of the next image shows two elements of an $\mathrm{SO}(2)$-steerable basis for functions on $\mathbb{R}^2$, which are used to generate the two basis elements for $\mathrm{SO}(2)$-equivariant steerable kernels on the right. In particular, the steerable kernels considered map a frequency $l=1$ vector field (2 channels) to a frequency $J=2$ vector field (2 channels).

[Figure: Wigner-Eckart theorem example]

$\mathrm{E}(n)$-Equivariant Steerable CNNs unify and generalize a wide range of isometry-equivariant CNNs in one single framework. For more details, we refer to our ICLR 2022 paper A Program to Build E(N)-Equivariant Steerable CNNs and our NeurIPS 2019 paper General E(2)-Equivariant Steerable CNNs.


The library is structured into four subpackages with different high-level features:

Component Description
escnn.group implements basic concepts of group and representation theory
escnn.kernels solves for spaces of equivariant convolution kernels
escnn.gspaces defines the Euclidean spaces and their symmetries
escnn.nn contains equivariant modules to build deep neural networks

WARNING: escnn.kernels underwent a major refactoring in version 1.0.0 and is not compatible with previous versions of the library. These changes do not affect the interface provided by the rest of the library, but the weights of a network trained with a previous version might not load correctly in a newly instantiated model. For backward compatibility, we recommend using version v0.1.9.
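For example, the older release can be pinned with:

pip install escnn==0.1.9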

Demo

Since $\mathrm{E}(2)$-steerable CNNs are equivariant under rotations and reflections, their inference is independent of the choice of image orientation. The visualization below demonstrates this claim by feeding rotated images into a randomly initialized $\mathrm{E}(2)$-steerable CNN (left). The middle plot shows the equivariant transformation of a feature space, consisting of one scalar field (color-coded) and one vector field (arrows), after a few layers. In the right plot, the feature space is transformed into a comoving reference frame by rotating the response fields back (stabilized view).

[Animation: equivariant CNN output]

The invariance of the features in the comoving frame validates the rotational equivariance of $\mathrm{E}(2)$-steerable CNNs empirically. Note that the fluctuations of responses are discretization artifacts due to the sampling of the image on a pixel grid, which does not allow for exact continuous rotations.

For comparison, we show a feature map response of a conventional CNN for different image orientations below.

[Animation: conventional CNN output]

Since conventional CNNs are not equivariant under rotations, the response varies randomly with the image orientation. This prevents CNNs from automatically generalizing learned patterns between different reference frames.

Experimental results

$\mathrm{E}(n)$-steerable convolutions can be used as a drop-in replacement for the conventional convolutions used in CNNs. With the same base architecture (and similar memory and computational cost), this leads to significant performance boosts compared to CNN baselines (values are test accuracies in percent).

model Rotated ModelNet10
CNN baseline 82.5 ± 1.4
SO(2)-CNN 86.9 ± 1.9
Octa-CNN 89.7 ± 0.6
Ico-CNN 90.0 ± 0.6
SO(3)-CNN 89.5 ± 1.0

All models share approximately the same architecture and width. For more details we refer to our paper.

This library supports $\mathrm{E}(2)$-steerable CNNs implemented in our previous e2cnn library as a special case; we include some representative results in the 2D setting from there:

model CIFAR-10 CIFAR-100 STL-10
CNN baseline 2.6   ± 0.1   17.1   ± 0.3   12.74 ± 0.23
E(2)-CNN * 2.39 ± 0.11 15.55 ± 0.13 10.57 ± 0.70
E(2)-CNN 2.05 ± 0.03 14.30 ± 0.09   9.80 ± 0.40

Using the same training setup as the CNN baselines (no further hyperparameter tuning), the equivariant models achieve significantly better results (values are test errors in percent). For a fair comparison, the models without * are designed such that the number of parameters of the baseline is approximately preserved, while the models with * preserve the number of channels, and hence compute. For more details we refer to our previous e2cnn paper.

Getting Started

escnn is easy to use since it provides a high-level user interface which abstracts away most intricacies of group and representation theory. The following code snippet shows how to perform an equivariant convolution from an RGB image to 10 regular feature fields (corresponding to a group convolution).

from escnn import gspaces                                          #  1
from escnn import nn                                               #  2
import torch                                                       #  3
                                                                   #  4
r2_act = gspaces.rot2dOnR2(N=8)                                    #  5
feat_type_in  = nn.FieldType(r2_act,  3*[r2_act.trivial_repr])     #  6
feat_type_out = nn.FieldType(r2_act, 10*[r2_act.regular_repr])     #  7
                                                                   #  8
conv = nn.R2Conv(feat_type_in, feat_type_out, kernel_size=5)       #  9
relu = nn.ReLU(feat_type_out)                                      # 10
                                                                   # 11
x = torch.randn(16, 3, 32, 32)                                     # 12
x = feat_type_in(x)                                                # 13
                                                                   # 14
y = relu(conv(x))                                                  # 15

Line 5 specifies the symmetry group action on the image plane $\mathbb{R}^2$ under which the network should be equivariant. We choose the cyclic group $\mathrm{C}_8$, which describes discrete rotations by multiples of $2\pi/8$. Line 6 specifies the input feature field types. The three color channels of an RGB image are thereby to be identified as three independent scalar fields, which transform under the trivial representation of $\mathrm{C}_8$ (when the input image is rotated, the RGB values do not change; compare the scalar and vector fields in the first image above). Similarly, the output feature space in line 7 is specified to consist of 10 feature fields which transform under the regular representation of $\mathrm{C}_8$. The $\mathrm{C}_8$-equivariant convolution is then instantiated by passing the input and output type as well as the kernel size to the constructor (line 9). Line 10 instantiates an equivariant ReLU nonlinearity which will operate on the output field and is therefore passed the output field type.
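As a concrete check of what the regular representation does (a small sketch building on the snippet above), the rotation by $2\pi/8$ acts on the 8-dimensional regular representation of $\mathrm{C}_8$ by cyclically permuting its channels:

g = r2_act.fibergroup.element(1)   # the rotation by 2*pi/8
print(r2_act.regular_repr(g))      # an 8x8 cyclic permutation matrix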

Lines 12 and 13 generate a random minibatch of RGB images and wrap them into an nn.GeometricTensor to associate them with their correct field type, feat_type_in. The equivariant modules process the geometric tensor in line 15. Each module thereby checks that the geometric tensor passed to it satisfies the expected transformation law.

Because the parameters no longer need to be updated at test time, any equivariant network can, after training, be converted into a pure PyTorch model with no additional computational overhead compared to conventional CNNs. The code currently supports the automatic conversion of a few commonly used modules through the .export() method; check the documentation for more details.
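For instance, the convolution layer from the snippet above can be exported as follows (a minimal sketch; note that the module must be switched to eval mode first):

conv.eval()
torch_conv = conv.export()   # a standard torch.nn.Conv2d with the expanded, fixed filter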

To get started, we provide some examples and tutorials:

  • The introductory tutorial introduces the basic functionality of the library.
  • A second tutorial goes through building and training an equivariant model on the rotated MNIST dataset.
  • Note that escnn also supports equivariant MLPs; see these examples.
  • Check also the tutorial on Steerable CNNs using our library in the Deep Learning 2 course at the University of Amsterdam.

More complex 2D equivariant Wide ResNet models are implemented in e2wrn.py. To try a model which is equivariant under reflections, call:

cd examples
python e2wrn.py

A version of the same model which is simultaneously equivariant under reflections and rotations by multiples of 90 degrees can be run via:

python e2wrn.py --rot90

You can find more examples in the examples folder. For instance, se3_3Dcnn.py implements a 3D CNN equivariant to rotations and translations in 3D. You can try it with

cd examples
python se3_3Dcnn.py

Useful material to learn about Equivariance and Steerable CNNs

If you want to better understand the theory behind equivariant and steerable neural networks, you can check these references:

  • Erik Bekkers' lectures on Geometric Deep Learning in the Deep Learning 2 course at the University of Amsterdam
  • The course material also includes a tutorial on group convolution and another about Steerable CNNs, using this library.
  • Gabriele's MSc thesis provides a brief overview of the essential mathematical ingredients needed to understand Steerable CNNs.
  • Maurice's PhD thesis develops the representation theory of steerable CNNs, deriving the most prominent layers and explaining the gauge theoretic viewpoint.

Dependencies

The library is based on Python 3.7 and depends on the following packages:

torch>=1.3
numpy
scipy
lie_learn
joblib
py3nj

Optional:

torch-geometric
pymanopt>=1.0.0
autograd

WARNING: py3nj enables fast computation of Clebsch-Gordan coefficients. If this package is not installed, our library relies on a numerical method to estimate them. This numerical method is not guaranteed to return the same coefficients computed by py3nj (they can differ by a sign). For this reason, models built with and without py3nj might not be compatible.

To successfully install py3nj, you may need a Fortran compiler installed in your environment.

Installation

You can install the latest release as

pip install escnn

or you can install the latest version directly from this repository with

pip install git+https://github.com/QUVA-Lab/escnn

Contributing

Would you like to contribute to escnn? That's great!

Then, check the instructions in CONTRIBUTING.md and help us to improve the library!

Do you have any doubts? Do you have an idea you would like to discuss? Feel free to open a new thread in Discussions!

Cite

The development of this library was part of the work done for our papers A Program to Build E(N)-Equivariant Steerable CNNs and General E(2)-Equivariant Steerable CNNs. Please cite these works if you use our code:


   @inproceedings{cesa2022a,
        title={A Program to Build {E(N)}-Equivariant Steerable {CNN}s},
        author={Gabriele Cesa and Leon Lang and Maurice Weiler},
        booktitle={International Conference on Learning Representations},
        year={2022},
        url={https://openreview.net/forum?id=WE4qe9xlnQw}
    }
    
   @inproceedings{e2cnn,
       title={{General E(2)-Equivariant Steerable CNNs}},
       author={Weiler, Maurice and Cesa, Gabriele},
       booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
       year={2019},
       url={https://arxiv.org/abs/1911.08251}
   }

Feel free to contact us.

License

escnn is distributed under the BSD Clear license. See the LICENSE file.

escnn's People

Contributors

chawater, danfoa, gabri95, kalekundert, kartikchincholikar, mauriceweiler, psteinb, rlee3359, zxp-s-works


escnn's Issues

Fully Conv Nets - Architecture Design and Training

Hi there! Thanks for the library, really enjoying using it so far, and it seems really promising for my application!

I had a couple of questions regarding design and training choices for fully conv nets, would greatly appreciate some help! For context, I'm trying to track points in an image by training the network output heatmap against masks with 2D gaussians located at the label co-ordinates, with BCE + Sigmoid.

  1. Training. When training such a model, how should the output of the network and the label be handled?
    I'm using output_type = nn.FieldType(gc, 1*[gc.trivial_repr]) for the final layer of the network. Do I need to cast the label to the same GeometricTensor:
    y = model.output_type(y) # (results in an error in BCELoss - no attribute 'numel')
    Or do I pull out the torch tensor of the output?
    y_hat = y_hat.tensor # (this trains but seems to break the equivariance of the model)

To test, I'm just training the model on one image and then testing the equivariance by rotating that one example and visualizing the output (using an architecture and viz setup similar to the animation example in the readme).

  2. Upsampling: I noticed that R2Upsampling doesn't support a 'size' parameter in addition to 'scale_factor', like torch.nn.Upsample does. Is there a reason this is not implemented? Code-wise it seems like a straightforward addition to the wrapper around torch's interpolate, but I thought it might affect equivariance if the scaling factor is not a perfect integer multiple, is that correct? I'm trying to make my output a heatmap with a specific shape (for example the same as the input shape, or a specific down-scaling of the input shape). If this is not possible, I'll have to design the layers very deliberately to achieve a specific output size, I think.

  3. On that note, I was wondering about conv layer design choices. I read through the paper and the documentation, but I'm still not sure how to choose kernel_size, stride and padding in the best way to preserve equivariance.
    I did notice that the mnist example mentioned that stride 2 should not be used with odd kernel sizes (unless image padding is added), could you expand on this? What combination of kernel_size and stride is safe to use? I'm trying to re-create a model that has kernel of 5 and stride 2, but is it best to modify my network to avoid odd kernel sizes altogether? The example also mentioned that adding padding can allow the network to observe rotation artefacts and thus breaks equivariance somewhat. Is adding no padding safe? Sorry if that's a redundant question, I just thought I'd clarify to be sure :)

  4. Lastly, I was curious about image padding and the MaskModule. Is it best practice to use one of these approaches? For now I'm manually masking everything outside the "image circle"; is this strictly necessary? Is the MaskModule or padding out the image to prevent losing the image corners the better option?

Sorry for the long post! Really appreciate any help I can get with these questions. Thanks so much!

Saving and loading escnn models

Hi,

Firstly, thanks for making such an accessible library to implement equivariant models alongside such informative documentation!

I wanted to train a toy model with MNIST before moving on to a bigger architecture and chose the model provided in the model.ipynb notebook in the 'examples' folder. After plugging it into my training script, I saved it using the regular PyTorch save procedure:
torch.save(model.state_dict(), 'mnist_model_e2cnn_{}.pt'.format(n_orientations))

In my test script, I try to load this model using:
model.load_state_dict(torch.load('mnist_model_e2cnn_{}.pt'.format(n_orientations), map_location='cpu'))

However, loading the model throws the following error:

RuntimeError: Error(s) in loading state_dict for MNISTE2CNN:
Missing key(s) in state_dict: "block1.1.filter", "block2.0.filter", "block3.0.filter", "block4.0.filter", "block5.0.filter", "block6.0.filter".

I'm not sure what I'm doing incorrectly, is there a special procedure involved in saving models that use escnn.nn.SequentialModule to stack ops?

EDIT: The torch version I am using is 1.7.0

Cheers,
Ishaan

Specifying gspaces to negate the input

Hi, I am a beginner and want to ask this. I have a vector input (not an image) of size (2, 1). I would like to define a group action that negates all the input, i.e., input (x, y) --> (-x, -y). How would I do that using gspaces? And is there any way that I can check that it would indeed negate all the input? I tried the below code, but it didn't work. Thanks!

import torch
from escnn import gspaces, nn

r2_act = gspaces.flip2dOnR2()
feat_type_in  = nn.FieldType(r2_act,  2*[r2_act.trivial_repr])

x = torch.randn(1, 2, 1, 1)
x = feat_type_in(x)

for g in r2_act.testing_elements:
    print(x.transform(g))
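For reference, one way to obtain a pure sign flip is to drop the base space entirely and let each channel transform under the one-dimensional sign representation of the 2-element group. The following is only a sketch (not from this thread), assuming cyclic_group(2).irrep(1) is that sign representation:

import torch
from escnn import group, gspaces, nn

C2 = group.cyclic_group(2)
gspace = gspaces.no_base_space(C2)

# two channels, each negated by the non-trivial element of C2
neg_type = nn.FieldType(gspace, 2 * [C2.irrep(1)])

x = nn.GeometricTensor(torch.randn(1, 2), neg_type)
for g in C2.elements:
    print(g, x.transform(g).tensor)   # the identity leaves x unchanged, the flip negates it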

Give devs the flexibility to control PointwiseAvgPoolAntialiased2D output

Just bumped my head into this issue for a while during debugging. It would be great if the implicit kernel_size could be made an explicit part of the API of PointwiseAvgPoolAntialiased2D and others:

class PointwiseAvgPoolAntialiased2D(EquivariantModule):

    def __init__(self,
                 in_type: FieldType,
                 sigma: float,
                 stride: Union[int, Tuple[int, int]],
                 # kernel_size: Union[int, Tuple[int, int]] = None, #<- this provides only rigid control of this class
                 padding: Union[int, Tuple[int, int]] = None,
                 #dilation: Union[int, Tuple[int, int]] = 1,
                 ):

I suggest making kernel_size part of the API again and, for example, filling the filter with zeros. This would give downstream users more freedom to play with the output of the underlying conv operation.

Installation issue

I am unable to install the package. The problem is coming from the package py3nj for which I am unable to build the wheel. I have posted an issue on the repo, please take a look at fujiisoup/py3nj#18

Does anyone else have installation issues?

Gimbal lock warning

The following code gives a Gimbal lock warning. I suppose I can ignore?

from escnn.group import so3_group

G = so3_group()
rho = G.irrep(1)
subgroup_id = (False, -1)
G.restrict_representation(subgroup_id, rho)

Gives

/usr2/pim/.local/lib/python3.8/site-packages/escnn/group/groups/so3_utils.py:96: UserWarning: Gimbal lock detected. Setting third angle to zero since it is not possible to uniquely determine all angles.
  return element.as_euler(param)
Out[72]: SO(2)|[SO(3):irrep_1]:3

Real irreps incorrectly typed as Complex in `C(n)` and `SO(2)`

In the function Group.irrep of both C(n) and SO(2), the convention of real irreducible representations is used. However, the type of these irreps is mislabeled as complex, which will create issues for applications where the distinction between real and complex irreps matters.

More precisely, the irreducibles associated with rotations of the plane that are not aligned with any coordinate axis are constructed in:

# 2 dimensional Irreducible Representations
irrep = _build_irrep_cn(k)
character = _build_char_cn(k)
supported_nonlinearities = ['norm', 'gated']
self._irreps[id] = IrreducibleRepresentation(
    self, id, name, irrep, 2, 'C',
    supported_nonlinearities=supported_nonlinearities,
    character=character,
    frequency=k
)

# 2 dimensional Irreducible Representations
supported_nonlinearities = ['norm', 'gated']
self._irreps[id] = IrreducibleRepresentation(
    self, id, name, irrep, 2, 'C',
    supported_nonlinearities=supported_nonlinearities,
    character=character,
    frequency=k
)

As can be clearly seen from the comments and from the method group.utils.psi used to build the 2-dimensional representations:

def psi(theta: float, k: int = 1, gamma: float = 0.):
    r"""
    Rotation matrix corresponding to the angle :math:`k \theta + \gamma`.
    """
    x = k * theta + gamma
    c, s = np.cos(x), np.sin(x)
    return np.array([
        [c, -s],
        [s, c],
    ])

These representations are real rotation matrices. As is well known, any 2D rotation matrix is decomposable into two 1-dimensional complex irreducible representations.

This brings several problems, all emerging from the fact that a complex representation is labeled as irreducible when in fact it is decomposable. Beyond the trivial fix of changing the type of these irreps to real, this issue points to a potential source of error in ESCNN: there is no method for actually checking whether an irrep is in fact an irrep.

At least for complex irreducibles, this can be easily checked with:

def is_complex_irreducible(
    G: Group, representation: Union[Dict[GroupElement, np.ndarray], Callable[[GroupElement], np.ndarray]]
):
    """
    Check if a representation is complex irreducible. We check this by testing whether a non-scalar
    (i.e., not a multiple of the identity) Hermitian matrix `H` exists such that `H` commutes with the
    representations of all group elements.
    If rho is irreducible, this function returns (True, H=I), where I is the identity matrix.
    Otherwise, it returns (False, H), where H is a non-scalar matrix that commutes with all elements' representations.
    """
    if isinstance(representation, dict):
        rep = lambda g: representation[g]
    else:
        rep = representation

    # Compute the dimension of the representation
    n = rep(G.sample()).shape[0]

    # Run through all pairs r, s = 1, 2, ..., n
    for r in range(n):
        for s in range(n):
            # Define H_rs
            H_rs = np.zeros((n, n), dtype=complex)
            if r == s:
                H_rs[r, s] = 1
            elif r > s:
                H_rs[r, s] = 1
                H_rs[s, r] = 1
            else:  # r < s
                H_rs[r, s] = 1j
                H_rs[s, r] = -1j

            # Compute H
            H = sum([rep(g).conj().T @ H_rs @ rep(g) for g in G.elements]) / G.order()

            # If H is not a scalar matrix, then it is a matrix that commutes with all group actions.
            if not np.allclose(H[0, 0] * np.eye(H.shape[0]), H):
                return False, H
    # No Hermitian matrix was found to commute with all group actions. This is an irreducible rep
    return True, np.eye(n)

Let me know if this sounds like a good possible PR. I will contribute it gladly.

pymanopt version

Hi,
I am trying to use group gspaces.octaOnR3(). However, it requires pymanopt and I cannot find the correct version of pymanopt. It seems like the package changed API that escnn is no longer compatible with it. Could you tell me which version of pymanopt should I use?

Thank you!

Best,
XP

advanced indexing is not supported

Hi @Gabri95,

So far, only basic slicing is supported. Would be cool to have advanced indexing :).

import torch
import escnn
print(escnn.__version__)
from escnn import nn, group, gspaces

G = group.so3_group()
gspace = gspaces.no_base_space(G)
in_type = gspace.type(G.standard_representation()) # input: 3D coordinates
x = nn.GeometricTensor(torch.randn(10,3), in_type)
indices = torch.randint(0, 10, (5,))

Slicing would work:
print(x[1:5])

But indexing will give an error:
print(x[indices])
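A possible workaround until advanced indexing is supported (a sketch, not from this thread): index the underlying tensor and re-wrap it, which is safe here because the indexing only touches the batch dimension:

x_indexed = nn.GeometricTensor(x.tensor[indices], in_type)
print(x_indexed.shape)   # torch.Size([5, 3])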

Missing indexing dimensions in GeometricTensors

Hi @Gabri95,

I found some limitations in the GeometricTensor class when one needs to define GeometricTensors
with indexing dimensions, not in the fiber nor base space, such as time. Consider the following case:

For an equivariant dynamical system, you define the state s as a GeometricTensor, and often you handle trajectories of state Ts = [s_1, s_2, s_3].

In practice, to process these motion trajectories, your tensors should have a shape similar to (batch, time <indexing dimension>, points, features) or (batch, time <indexing dimension>, features), where the last two dimensions are the points in the base space and the feature fields in the fiber space for each point.

The problem with restricting the shape of the GeometricTensor as it is now lies in the fact that I cannot process a state trajectory efficiently, as I have no way to concatenate GeometricTensors (along indexing dimensions). This means I cannot apply the group action to the state trajectories as a single tensor operation, but instead have to apply it to each individual state (GeometricTensor) instance.

I could build a bigger GeometricTensor by concatenating tensors over the batch dimension (I am currently bypassing this limitation this way; see the sketch below). Still, the definition of GeometricTensor might benefit from enabling these indexing dimensions, on which no symmetry action is applied. I expect people to need indexing dimensions often.
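A minimal sketch of this batch-folding workaround (assuming an SO(3) state with the standard representation, as in the examples above):

import torch
from escnn import group, gspaces, nn

G = group.so3_group()
state_type = gspaces.no_base_space(G).type(G.standard_representation())

traj = torch.randn(8, 5, 3)   # (batch, time, features)
B, T, D = traj.shape

# fold the time axis into the batch axis, transform, then unfold again
g = G.sample()
flat = nn.GeometricTensor(traj.reshape(B * T, D), state_type)
traj_transformed = flat.transform(g).tensor.reshape(B, T, D)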

Let me know if it's a good idea. In that case, I can append it to the list of contributions I have promised to make.

check_equivariance test failed

from escnn import gspaces
import escnn.nn as enn

gspace = gspaces.rot2dOnR2(4)
n_feat = 16
in_type = enn.FieldType(gspace, [gspace.trivial_repr]*3)
out_type = enn.FieldType(gspace, [gspace.regular_repr]*n_feat)

net = enn.R2Conv(in_type, out_type, kernel_size=3, stride=1, padding=1, dilation=1, bias=True, sigma=None)
net.check_equivariance()

I got the error: AssertionError: The error found during equivariance check with element "10[2pi/12]" is too high: max = 3.4398996829986572, mean = 0.3915270268917084, var = 0.21077707409858704

Is this failure expected?

The order of basis

Hi, I noticed that the standard representation is not exactly the same as irreps[1]: they are off by a permutation of the axes. Is there a reason why this is the case?

import escnn.group as g

Group = g.so3_group()
element = Group.sample()
Group.standard_representation()(element)
#array([[-0.59306326, -0.24286753,  0.76765313],
#       [-0.13035531,  0.96980604,  0.20611582],
#       [-0.79453349,  0.02217205, -0.60681541]])
Group.irreps()[1](element)
#array([[ 0.96980604,  0.20611582, -0.13035531],
#       [ 0.02217205, -0.60681541, -0.79453349],
#       [-0.24286753,  0.76765313, -0.59306326]])

Mismatch in coordinate convention?

I'm working with rot3dOnR3, and I don't understand the following coordinate convention change that seems to happen in transform_fibers:

import numpy as np
import torch
from escnn.gspaces import rot3dOnR3

a = 2*np.pi/3
R = np.array([[np.cos(a), -np.sin(a),  0.0000],
              [np.sin(a),  np.cos(a),  0.0000],
              [   0.0000,     0.0000,  1.0000]])
# array([[-0.5000, -0.8660,  0.0000],
#        [ 0.8660, -0.5000,  0.0000],
#        [ 0.0000,  0.0000,  1.0000]])

gspace = rot3dOnR3()
el = gspace.fibergroup.element(R, param="MAT")

field_type = gspace.type(gspace.irrep(1))
field_type.representations[0](el)
# array([[-0.5      ,  0.       ,  0.8660254],
#        [ 0.       ,  1.       ,  0.       ],
#        [-0.8660254,  0.       , -0.5      ]])

I assumed the vector channels had XYZ order, but this seems to suggest they have ZYX? (Is this related to -Z -Y X somehow?)

I don't think it's -Z -Y -X because of this example
z = torch.tensor([1, 1, 0]).float().reshape(1, 1, 3) # batch, channels, vector
z = z[:, :, [2, 1, 0]] * torch.tensor((-1, -1, 1))   # XYZ -> -Z -Y X

# convert to grid
z = z.reshape(1, 3, 1, 1, 1) # batch, vector, *coords

# transform
zrot = gspace.type(gspace.irrep(1))(z).transform(el)

# convert back to vector
z = z.mean(dim=(2, 3, 4)).reshape(1, 1, 3)
zrot = zrot.tensor.mean(dim=(2, 3, 4)).reshape(1, 1, 3)

z = z[:, :, [2, 1, 0]] * torch.tensor((1, -1, -1)) # -Z -Y X -> XYZ
zrot = zrot[:, :, [2, 1, 0]]  * torch.tensor((1, -1, -1)) # -Z -Y X -> XYZ
zrot
# tensor([[[-0.5000,  1.0000, -0.8660]]])

This is wrong, since this is a 2D rotation (the z coordinate shouldn't change), and x and y are also wrong.

However, this example, where I do XYZ -> YZX, seems to be correct:
z = torch.tensor([1, 1, 0]).float().reshape(1, 1, 3) # batch, channels, vector
z = z[:, :, [1, 2, 0]] # XYZ -> YZX

# convert to grid
z = z.reshape(1, 3, 1, 1, 1) # batch, vector, *coords

# transform
zrot = gspace.type(gspace.irrep(1))(z).transform(el)

# convert back to vector
z = z.mean(dim=(2, 3, 4)).reshape(1, 1, 3)
zrot = zrot.tensor.mean(dim=(2, 3, 4)).reshape(1, 1, 3)

z = z[:, :, [2, 0, 1]] # YZX -> XYZ
zrot = zrot[:, :, [2, 0, 1]] # YZX -> XYZ
zrot
# tensor([[[-1.3660e+00,  3.6603e-01,  1.3004e-17]]])

Any pointers or ideas?

Slight inconsistencies between groups

When using the groups for an implementation which is supposed to work regardless of the group, but does need to adapt a few things depending on certain characteristics of the group, I noticed an inconsistency with the 'rotation_order' attribute between the groups, which can make it a bit convoluted to code up some parts of the logic. Namely: the O(2) and Dihedral groups have the rotation_order attribute, but the O(3), SO(2) and Cyclic groups do not (in the case of the cyclic group, this rotation order is effectively stored as the N attribute). I think it might make more sense for all of these to have the 'rotation_order' attribute.

Additionally, I am working with the Fourier transform on different groups, and as far as I can tell there is no nice way to retrieve from the code which grid types are supported by which group. It would be nice to have this as a stored property of each group, so that it can easily be obtained.

Finally, I also noticed that certain groups, such as the Octahedral and DirectProduct groups, do not have the bl_irreps method. I am not sure whether it would be possible to provide it for these groups, since I am not too familiar with them.

No module named 'lie_learn.representations.SO3.irrep_bases'

Running from escnn.group import *, I get the following error:

>>> from  escnn.group import *
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/key/.virtualenvs/pt1.13.1/lib64/python3.11/site-packages/escnn-1.0.3-py3.11.egg/escnn/group/__init__.py", line 22, in <module>
    from .groups.factory import *
  File "/home/key/.virtualenvs/pt1.13.1/lib64/python3.11/site-packages/escnn-1.0.3-py3.11.egg/escnn/group/groups/__init__.py", line 2, in <module>
    from .factory import *
  File "/home/key/.virtualenvs/pt1.13.1/lib64/python3.11/site-packages/escnn-1.0.3-py3.11.egg/escnn/group/groups/factory.py", line 8, in <module>
    from .so3group import SO3
  File "/home/key/.virtualenvs/pt1.13.1/lib64/python3.11/site-packages/escnn-1.0.3-py3.11.egg/escnn/group/groups/so3group.py", line 8, in <module>
    from .so3_utils import *
  File "/home/key/.virtualenvs/pt1.13.1/lib64/python3.11/site-packages/escnn-1.0.3-py3.11.egg/escnn/group/groups/so3_utils.py", line 1, in <module>
    from lie_learn.representations.SO3.wigner_d import wigner_D_matrix
  File "/home/key/.virtualenvs/pt1.13.1/lib64/python3.11/site-packages/lie_learn-0.0.1.post1-py3.11.egg/lie_learn/representations/SO3/wigner_d.py", line 5, in <module>
    from lie_learn.representations.SO3.irrep_bases import change_of_basis_matrix
ModuleNotFoundError: No module named 'lie_learn.representations.SO3.irrep_bases'

I've tried various ways of installing escnn as well as lie_learn, with no success. (The last method I tried was cloning the repo and doing python setup.py install.)
My Python version is 3.11.1.

I'm aware of a similar issue in lie-learn, but it is 2.5 years old, and unsolved: AMLab-Amsterdam/lie_learn#16

Could someone please take a look? Thanks!

Unable to instantiate `R3IcoConv`

I'm running into an error when trying to use the R3IcoConv module. Here's some code that generates the error:

from escnn.nn import *
from escnn.gspaces import *

gspace = icoOnR3()

in_type = FieldType(gspace, [gspace.trivial_repr])
out_type = FieldType(gspace, [gspace.regular_repr])

conv = R3IcoConv(in_type, out_type, 3)

And here's the error itself:

________________________________________________________________________________
[Memory] Calling escnn.group.groups.ico._build_ico_irrep...
_build_ico_irrep(Icosahedral, 3)
__________________________________________________build_ico_irrep - 9.3s, 0.2min
________________________________________________________________________________
[Memory] Calling escnn.group.groups.ico._build_ico_irrep...
_build_ico_irrep(Icosahedral, 4)
__________________________________________________build_ico_irrep - 9.2s, 0.2min
Traceback (most recent call last):
  File "/home/kale/research/software/projects/atompaint/tests/foo.py", line 11, in <module>
    conv = R3IcoConv(in_type, out_type, 3)
  File "/home/kale/.pyenv/versions/3.10.0/lib/python3.10/site-packages/escnn/nn/modules/conv/r3_ico_convolution.py", line 99, in __init__
    super(R3IcoConv, self).__init__(
  File "/home/kale/.pyenv/versions/3.10.0/lib/python3.10/site-packages/escnn/nn/modules/conv/r3convolution.py", line 160, in __init__
    super(R3Conv, self).__init__(
  File "/home/kale/.pyenv/versions/3.10.0/lib/python3.10/site-packages/escnn/nn/modules/conv/rd_convolution.py", line 223, in __init__
    self._basisexpansion = BlocksBasisExpansion(in_type.representations, out_type.representations,
  File "/home/kale/.pyenv/versions/3.10.0/lib/python3.10/site-packages/escnn/nn/modules/basismanager/basisexpansion_blocks.py", line 72, in __init__
    block_expansion = block_basisexpansion(basis, points, basis_filter=basis_filter, recompute=recompute)
  File "/home/kale/.pyenv/versions/3.10.0/lib/python3.10/site-packages/escnn/nn/modules/basismanager/basisexpansion_singleblock.py", line 144, in block_basisexpansion
    for b, attr in enumerate(basis):
  File "/home/kale/.pyenv/versions/3.10.0/lib/python3.10/site-packages/escnn/kernels/basis.py", line 247, in __iter__
    for attr in basis:
  File "/home/kale/.pyenv/versions/3.10.0/lib/python3.10/site-packages/escnn/kernels/steerable_basis.py", line 547, in __iter__
    assert idx < self.dim
AssertionError

I tried a few different kernel sizes and both the 'ico' and 'dodeca' samples, but neither had any effect. I'm not sure if this is a bug, or if I'm doing something wrong.


On a related note, I'm not totally sure I understand when to use R3IcoConv vs R3Conv. Below are some bullet points outlining my thoughts on this. Can you take a look and let me know if I'm on the right track or not?

  • With R3Conv, the kernel parameters are used to make linear combinations of spherical harmonics, so you can think of the parameters as Fourier coefficients.
  • With R3IcoConv, the kernel parameters are associated with the vertices of an icosahedron/dodecahedron/icosidodecahedron, so they live entirely in the spatial domain and aren't Fourier coefficients.
  • Smooth functions don't require many frequencies to represent in the Fourier domain, but "spiky" functions do. So if you want spiky filters, R3IcoConv should work better. If you want smooth filters, R3Conv is fine.
  • I also assume that R3IcoConv is more efficient, since it's specialized to the specific case of icosahedral symmetry, but I have no idea if this is true.

This leads to a few follow-up questions:

  • Would you ever recommend using R3Conv instead of R3IcoConv with the icoOnR3 gspace?
  • I notice that it takes ≈20s to build the icosahedral irreps when instantiating the icoOnR3 gspace. Are these irreps actually used by R3IcoConv? I know that in the 2D case, you can do group convolutions that are C(n)-equivariant with standard, unconstrained kernels by "lifting" the input to a higher dimension. Can R3IcoConv be thought of as a 3D version of this? If so, I'd expect that it wouldn't be necessary to calculate irreps. I suspect that this analogy is just wrong, but I'm not sure.

Thanks for taking the time to help me understand this. I've watched Erik Bekkers' lecture series on equivariant learning, so I feel like I have a some grasp on the fundamentals, but this is all still quite new to me.

Utility functions to save and load instances of Group and Representations

Hi @Gabri95,

I have been facing some (perhaps unnecessary) issues in saving and loading instances of classes from ESCNN.

Say I want to generate a dataset of equivariant dynamical systems: I would like to save to files the actual symmetry group (subgroup) and the representations of each of the measurements taken.

The representations are OK to save using the irrep ids and the change of basis, but the group itself has proven a bit harder to save and load elegantly. The group and representation classes are not serializable and thus cannot be pickled, but storing the entire object also seems unnecessary.
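For what it's worth, a sketch of that approach (the payload layout here is hypothetical; it assumes escnn.group.directsum and the group factory functions can rebuild the representation):

from escnn import group

G = group.so3_group()
rho = group.directsum([G.irrep(0), G.irrep(1)])

# store only the factory name, the irrep ids and the change of basis
payload = {
    'group_factory': 'so3_group',
    'irreps': list(rho.irreps),              # list of irrep ids (tuples)
    'change_of_basis': rho.change_of_basis,  # a numpy array
}

# rebuild on load
G2 = getattr(group, payload['group_factory'])()
rho2 = group.directsum(
    [G2.irrep(*irr) for irr in payload['irreps']],
    change_of_basis=payload['change_of_basis'],
)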

I am assuming you have already faced this issue or, if not, that you might have a good idea of how to do it. Any ideas?

Perhaps I can build (at some point) the protocol for storing and loading these instances.

PS: I will answer all other threads we have active after ICLR. Sorry for the delay.

Missing grid() method for DihedralGroup class

I noticed that the DihedralGroup class does not have a definition for the grid() method, resulting in an error when trying to use a Fourier-based activation function with layers using this group. The CyclicGroup class, on the other hand, does have a grid() method, so it would be nice if the DihedralGroup also had it.

`id` argument documentation hidden in R3 and R2 Gspaces.

The documentation of the GSpace restrict method mentions that the id parameter is documented in the non-abstract subclasses of GSpace:

Build the GSpace associated with the subgroup of the current fiber group identified by the input id. This reduces the level of symmetries of the base space to be considered. Check the restrict method’s documentation in the non-abstract subclass used for a description of the parameter id.

However, this documentation is not exposed for R3 or R2.

ESCNN SO(3) 3D CNN example

Hello!
I was wondering if there is an example of a 3D steerable CNN (using R3Conv). Furthermore, is it possible to use this library to train a model on a tensor of shape (batch, channels, width, length, depth)?

Thanks!
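For reference, a minimal sketch of a 3D steerable convolution on a (batch, channels, width, length, depth) tensor (not an official example; the layer sizes here are arbitrary):

import torch
from escnn import gspaces, nn

r3_act = gspaces.rot3dOnR3()
feat_in = nn.FieldType(r3_act, [r3_act.trivial_repr])   # one scalar channel
feat_out = nn.FieldType(r3_act, 4 * [r3_act.irrep(1)])  # four vector fields (12 channels)

conv = nn.R3Conv(feat_in, feat_out, kernel_size=5, padding=2)

x = torch.randn(2, 1, 16, 16, 16)   # (batch, channels, width, length, depth)
y = conv(feat_in(x))
print(y.shape)                      # torch.Size([2, 12, 16, 16, 16])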

`rot3dOnR3` API is not in line with `rot2dOnR2`

I wanted to implement a small 3D ESCNN. I saw that the rot3dOnR3 API is not in line with rot2dOnR2: rot3dOnR3 misses the N parameter describing discrete rotations in 3D.

[Screenshots comparing the rot3dOnR3 and rot2dOnR2 signatures]

Can the library be used to compute point convolutions like Tensor Field Networks?

So far, I have seen this library only being used for data on the grid (e.g. images in pixels). However, can it also be used for point clouds like Tensor Field Networks do?

In principle, the theory should work the same way: just do point convolutions with equivariant kernels based on the basis defined here. However, I see that the GeometricTensor class only supports regular tensors, i.e., data that lies on a grid. Is there any thought of implementing that in the future?

TypeError: 'SequentialModule' object is not subscriptable

Hi Gabriele,

It seems that SequentialModule is not subscriptable now.

to reproduce:

from escnn import group
from escnn import gspaces
from escnn import nn

G = group.so3_group()
gspace = gspaces.no_base_space(G)
in_type = gspace.type(G.standard_representation())

seq = nn.SequentialModule(nn.IIDBatchNorm1d(in_type), nn.IIDBatchNorm1d(in_type))
print(seq[0])

Missing the tetrahedron group (`tetraOnR3` in `escnn.gspaces`?)

Hi,

I'm trying to build a 3D equivariant CNN, and I found that the tetrahedron symmetry group ("Describes 3D rotation symmetries of a tetrahedron in the space R3") is missing in the action-on-volume gspaces.

Is there any way to build it (by subgroup or isomorphism)? Thanks!


Indeed, in 3D, the only discrete 3D rotational symmetry groups are the symmetries of the few platonic solids (the
tetrahedron group, the octahedron group and the icosahedron group).

Conversely, since there are only a few options in 3D, we provide a different method for each of them
(e.g. :func:`~escnn.gspaces.icoOnR3` or :func:`~escnn.gspaces.octaOnR3`).


BTW, I tried

from escnn.gspaces import GSpace3D
GSpace3D((False, 'tetra'))

but got

elif so3_sg_id[0] == 'tetra':
    # sg = escnn.group.tetra_group()
    # parent_mapping = so3_to_o3(adj, sg)
    # child_mapping = o3_to_so3(adj, sg)
    raise NotImplementedError()

Use of `np.float` and `np.int` etc

Thanks @Gabri95 et al. for providing this library. I am currently ramping up my use of it. However, I discovered a small bug that makes escnn unusable with any numpy version greater than or equal to 1.24.

The numpy team deprecated the np.float, np.int and np.bool aliases to the respective Python built-ins:
https://numpy.org/doc/stable/release/1.20.0-notes.html#using-the-aliases-of-builtin-types-like-np-int-is-deprecated
From numpy 1.24 on, their use produces an error, which makes the example given in the README.md break immediately.

Instance Norm as normalization?

Dear @Gabri95
sorry to bug you. I am currently trying to come up with an equivariant UNet architecture which is very close to a "standard" UNet that I use as a reference. In doing so, I came across the matter of different normalization schemes. I looked at your implementations here, and you appear to focus on batch norm only.

However, I was wondering if anything speaks against implementing InstanceNorm? The difference is that the mean/variance is not computed across the entire batch, but rather per sample in the batch.

Pytorch's Automatic Mixed Precision (AMP) is not supported for non-uniform field types.

Currently, Automatic Mixed Precision (AMP) is not supported for non-uniform field types. This is the script used to replicate the problem:

from escnn import group
from escnn import gspaces
from escnn import nn
import torch
from functools import partial


class O2SteerableCNNGated(torch.nn.Module):
    def __init__(self):
        super(O2SteerableCNNGated, self).__init__()

        self.act_r2 = gspaces.flipRot2dOnR2(N=-1)

        self.in_type = nn.FieldType(self.act_r2, [self.act_r2.trivial_repr])

        self.upsample = nn.R2Upsampling(self.in_type, size=(29, 29))

        self.mask = nn.MaskModule(self.in_type, 29, margin=1)

        channels = 12
        L = 4
        ################# This doesn't work
        out_type = nn.FieldType(
            self.act_r2,
            channels * [self.act_r2.irreps[0] for _ in range(L + 1)]
            + channels * [self.act_r2.irreps[i] for i in range(L + 1)],
        )

        ############# This does, since the field type is uniform
        # out_type = nn.FieldType(
        #     self.act_r2,
        #     channels * [self.act_r2.irreps[0]]
        #     + channels
        #     * [group.directsum([self.act_r2.irreps[i] for i in range(L + 1)])],
        # )

        self.block_1 = nn.SequentialModule(
            nn.R2Conv(self.in_type, out_type, kernel_size=5, padding=1),
        )

    def forward(self, x):
        x = self.in_type(x)
        x = self.upsample(x)
        x = self.mask(x)
        x = self.block_1(x)

        return x


if __name__ == "__main__":
    device = "cuda"
    USE_AMP = True

    o2_gated_nonlinearity = O2SteerableCNNGated().to(device)
    data = torch.randn(128, 1, 28, 28).to(device)
    with torch.autocast(device_type=device, dtype=torch.float16, enabled=USE_AMP):
        o2_gated_nonlinearity(data)

Running the script results in the following error:

/home/lars/miniconda3/envs/thesis/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525541990/work/aten/src/ATen/native/TensorShape.cpp:3190.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Traceback (most recent call last):
  File "/home/lars/Studie/Thesis/experimenting_code/old_tests/bug_report.py", line 67, in <module>
    o2_gated_nonlinearity(data)
  File "/home/lars/miniconda3/envs/thesis/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lars/Studie/Thesis/experimenting_code/old_tests/bug_report.py", line 55, in forward
    x = self.block_1(x)
  File "/home/lars/miniconda3/envs/thesis/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lars/miniconda3/envs/thesis/lib/python3.10/site-packages/escnn/nn/modules/sequential_module.py", line 88, in forward
    x = m(x)
  File "/home/lars/miniconda3/envs/thesis/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lars/miniconda3/envs/thesis/lib/python3.10/site-packages/escnn/nn/modules/conv/r2convolution.py", line 208, in forward
    _filter, _bias = self.expand_parameters()
  File "/home/lars/miniconda3/envs/thesis/lib/python3.10/site-packages/escnn/nn/modules/conv/rd_convolution.py", line 269, in expand_parameters
    _filter = self.basisexpansion(self.weights)
  File "/home/lars/miniconda3/envs/thesis/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lars/miniconda3/envs/thesis/lib/python3.10/site-packages/escnn/nn/modules/basismanager/basisexpansion_blocks.py", line 354, in forward
    _filter[
RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and Half for the source.

This seems to suggest that somewhere in the code a tensor is created with a hardcoded Float dtype, which is not compatible with the Half dtype. I have spent a little time trying to find where this happens, but I haven't found the exact spot yet.

Note: Uncommenting lines 29 through 35 is an example where it does work, since a uniform field type is used there.
In this particular example, it would make more sense to use the uniform field type anyways. But for cases where you'd want a non-uniform field type this issue could be troublesome.

type.representations are tuple or list, combinations break

Hi @Gabri95, thanks for making this nice library ;).

I've managed to break it. You seem to use type.representations sometimes as a tuple and sometimes as a list. When one tries to combine those, it breaks, because Python is fine with list+list and tuple+tuple, but doesn't like tuple+list.

from escnn.group import so3_group
from escnn.gspaces import no_base_space

G = so3_group()
rho = G.irrep(0) + G.irrep(1)
gspace = no_base_space(G)
in_type = gspace.type(rho)
subgroup_id = (False, -1)  # identifies the SO(2) subgroup of rotations around the Z axis
h, _, _ = G.subgroup(subgroup_id)
gspace_h = no_base_space(h)
in_type_h = in_type.restrict(subgroup_id)
radial_type = gspace_h.type(h.irrep(0))
print(type(in_type.representations))
print(type(in_type_h.representations))
print(type(radial_type.representations))
in_type_h + radial_type

Gives error:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3441, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-71-4cf0c77ed4a3>", line 13, in <module>
    in_type_h+radial_type
  File "/usr2/pim/.local/lib/python3.8/site-packages/escnn/nn/field_type.py", line 436, in __add__
    return FieldType(self.gspace, self.representations + other.representations)
TypeError: can only concatenate list (not "tuple") to list

I'm working around this by adding:

in_type_h.representations = tuple(in_type_h.representations)

So no rush in fixing this. Just wanted to let you know.

Cheers,
Pim

testimage.jpeg missing

Just wanted to run the test suite to double check #29, but then I ran into another problem:

        if filename:
>           fp = builtins.open(filename, "rb")
E           FileNotFoundError: [Errno 2] No such file or directory: '../group/testimage.jpeg'

which occurred in escnn/nn/modules/conv/r2convolution.py:245 (see

x = mpimg.imread('../group/testimage.jpeg').transpose((2, 0, 1))[np.newaxis, 0:c, :, :]

).

Irreps of direct product

Hi, I am testing the irreps of G $\times$ G, where G is SO(3). I found that the direct product leads to different frequencies, which causes some difficulties for me in picking the right irreps. I am not sure whether this is a design choice or a bug.

import escnn.group as g
r1 = g.so3_group(1)
r2 = g.so3_group(1)

r1.irreps()
# [SO(3)|[irrep_0]:1, SO(3)|[irrep_1]:3] # which is correct according to the definition

r12 = g.direct_product(r1, r2, name='1_order')

r1.irreps()
# [SO(3)|[irrep_0]:1, SO(3)|[irrep_1]:3, SO(3)|[irrep_2]:5, SO(3)|[irrep_3]:7] # get more frequencies here

r12.irreps()
#[1_order|[irrep_[(0,),(0,)](0)]:1,
# 1_order|[irrep_[(0,),(1,)](0)]:3,
# 1_order|[irrep_[(0,),(2,)](0)]:5,
# 1_order|[irrep_[(0,),(3,)](0)]:7,
# 1_order|[irrep_[(1,),(0,)](0)]:3,
# 1_order|[irrep_[(1,),(1,)](0)]:9,
# 1_order|[irrep_[(1,),(2,)](0)]:15,
# 1_order|[irrep_[(1,),(3,)](0)]:21,
# 1_order|[irrep_[(2,),(0,)](0)]:5,
# 1_order|[irrep_[(2,),(1,)](0)]:15,
# 1_order|[irrep_[(2,),(2,)](0)]:25,
# 1_order|[irrep_[(2,),(3,)](0)]:35,
# 1_order|[irrep_[(3,),(0,)](0)]:7,
# 1_order|[irrep_[(3,),(1,)](0)]:21,
# 1_order|[irrep_[(3,),(2,)](0)]:35,
# 1_order|[irrep_[(3,),(3,)](0)]:49]

# I do not need frequencies 2 and 3 in r12, r1 and r2

Model Initialization Extremely slow

Is there a way to speed up the model initialization process? Every time I initialize the model, it takes over 30 minutes before training starts.

Blog post about escnn: https://blogs.rstudio.com/ai/posts/2023-05-09-group-equivariant-cnn-3/

Hi @Gabri95,

just wanted to let you (and maybe others who could be interested in this) know that I've written a post about escnn that

  • uses escnn from R, by means of the interface package reticulate
  • provides a high-level introduction to how the math, the code, and the task to be solved go together (I call it an "introduction to an introduction" - the latter being your introduction.ipynb)

Here it is: https://blogs.rstudio.com/ai/posts/2023-05-09-group-equivariant-cnn-3/

Of course, close this issue any time - I'm just taking the opportunity for a "sales pitch" of sorts ;-):
In R, we have a port of PyTorch, as well as a steadily growing ecosystem based on it: see

https://github.com/mlverse/torch
https://github.com/mlverse

So in case someone reads this and could imagine porting escnn, we'd be more than happy to help :-)
Thanks!

Export linear layer

Hi!

I am trying to export a network containing linear layers, but I'm getting the following error:

File "X/site-packages/escnn/nn/modules/linear.py", line 305, in export
linear.weight.data[:] = self
TypeError: can't assign a Linear to a torch.FloatTensor

Do you know what's causing it?

Thanks!

The MaskModule does not support 3D input

When using the MaskModule on 3D input, the following error is raised:

File "/home/lars/anaconda3/envs/thesis/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
File "/home/lars/anaconda3/envs/thesis/lib/python3.10/site-packages/escnn/nn/modules/sequential_module.py", line 88, in forward
    x = m(x)
File "/home/lars/anaconda3/envs/thesis/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
File "/home/lars/anaconda3/envs/thesis/lib/python3.10/site-packages/escnn/nn/modules/masking_module.py", line 70, in forward
    assert input.tensor.shape[2:] == self.mask.shape[2:]
AssertionError

The two asserted shapes report as follows:
torch.Size([29, 29, 29]) torch.Size([29, 29])

So it seems that the MaskModule does not support 3D input. It would be nice if it also supported 3D input (either dynamically or through a flag/separate input). Would that be possible?

strided r2conv operations as pooling substitute fails to construct

I was playing with strided convolutions as a pooling operation that retains equivariance (more robustly than Avg/MaxPool). For this, I was testing a simple setup:

import torch
from escnn import gspaces, nn

def test_r2conv_fails_with_ksize2_stride2():

    reg_sconv = torch.nn.Conv2d(64, 64, kernel_size=2, stride=2, padding=1)
    assert hasattr(reg_sconv, "forward")

    gspace = gspaces.rot2dOnR2(N=4)
    iotype = nn.FieldType(gspace, 16 * [gspace.regular_repr])

    poolop = nn.R2Conv(
            iotype, iotype, kernel_size=2, stride=2, padding=1, bias=False
        )
    assert hasattr(poolop, "forward")

Unfortunately, the nn.R2Conv constructor triggers a ValueError:

E           ValueError: 
E                           The basis for the steerable filter is empty!
E                           Tune the `frequencies_cutoff`, `kernel_size`, `rings`, `sigma` or `basis_filter` parameters to allow
E                           for a larger basis.

py310/lib64/python3.10/site-packages/escnn/nn/modules/conv/rd_convolution.py:230: ValueError

I was trying to debug this, but couldn't get any insights out of this function

def compute_basis_params(

Model is SO(3) equivariant only for the first 12 testing elements

Hello,

I would like to construct an SO(3)-equivariant model that takes a vector field as input and outputs a vector field. My attempt is equivariant to the first 12 elements of rot3dOnR3.testing_elements, but fails for the remaining 48.

I think it is because of the field type of the hidden layers. I am mapping the input vector field to a scalar field, and then back again. I also tried using irrep(j), j=1,2,3 for the hidden feature type after removing the normalization and activation functions, but it still passed/failed for the same elements as when using the trivial representation.

Is there a different hidden feature type for which the equivariance holds for all testing elements? Based on rot3dOnR3.representations, it would seem that the trivial feature type and irrep are my only options.

Thank you,

Jacob

import escnn
import torch

class eqCNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.r3_act = escnn.gspaces.rot3dOnR3()
        self.feat_type_in = escnn.nn.FieldType(self.r3_act, [self.r3_act.irrep(1)])
        self.feat_type_hid = escnn.nn.FieldType(self.r3_act, 8 * [self.r3_act.trivial_repr])
        self.feat_type_out = escnn.nn.FieldType(self.r3_act, [self.r3_act.irrep(1)])
        self.model = escnn.nn.SequentialModule(
            escnn.nn.R3Conv(self.feat_type_in, self.feat_type_hid, kernel_size=3, padding=1),
            escnn.nn.InnerBatchNorm(self.feat_type_hid),
            escnn.nn.ReLU(self.feat_type_hid),
            escnn.nn.R3Conv(self.feat_type_hid, self.feat_type_hid, kernel_size=3, padding=1),
            escnn.nn.InnerBatchNorm(self.feat_type_hid),
            escnn.nn.ReLU(self.feat_type_hid),
            escnn.nn.R3Conv(self.feat_type_hid, self.feat_type_out, kernel_size=3, padding=1)
        )

    def forward(self, x):
        return self.model(escnn.nn.GeometricTensor(x, self.feat_type_in)).tensor

torch.manual_seed(1)
model = eqCNN()
x = torch.randn(1, 3, 8, 8, 8)

for j, g in enumerate(model.r3_act.testing_elements):
    x_transformed = escnn.nn.GeometricTensor(x, model.feat_type_in).transform(g).tensor
    y = model(x_transformed)
    y_transformed = escnn.nn.GeometricTensor(model(x), model.feat_type_in).transform(g).tensor
    if not torch.allclose(y, y_transformed, atol=1e-5):
        print(f"Failed for element {j}; mean error: {(y - y_transformed).abs().mean()}")

Inferring the flip in an O(2)-equivariant model

Hi,

First of all, congrats on this great piece of work.

I'm working on a model whose inputs are shapes given by their spherical harmonic coefficients. The problem I'm interested in requires O(2)-equivariance in the XY-plane.

I'm essentially interested in learning O(2)-invariant embeddings, plus a function that recognizes the element of O(2) that transformed the input, in similar fashion to what is done here

I'm able to easily construct the O(2)-invariant part, but I'm having issues with the group-recognition function (which should be O(2)-equivariant).

My idea was to output two steerable vectors acted on by O2.irrep(1, 1), restricting them to (0., -1). I would use one of them to infer the flip, and the other to infer the rotation. However, I don't quite understand how the flip axis works. I thought that by restricting the irreps to (0., 1) I would be fixing the flip axis to the X-axis, so I expected that flipping my input (along any axis) would translate into a corresponding flip (along the X-axis) of my output steerable vectors. This is not what I observe in practice.

Rotations with no flips are working perfectly (i.e. the output vectors come rotated by the appropriate amount), and the invariant part is working for both flips and rotations. But when I flip the input, the output does not come flipped along the X-axis and I'm not able to tell what the actual flip axis is or how to fix it.
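
One way to inspect directly which matrix the flip applies to the output vectors (a small sketch; I'm assuming escnn's (flip, rotation angle) parametrization of O(2) elements):

from escnn.group import o2_group

G = o2_group()
rho = G.irrep(1, 1)
flip = G.element((1, 0.))   # reflection combined with a zero rotation
print(rho(flip))            # the 2x2 matrix acting on the steerable vector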

What am I missing?
Thanks in advance!

Suggestions on group and group representation

Hello,

I want to define a group such that given a vector input (a, b, c, d), the group actions would transform it into:

  1. (a, b, c, d)
  2. (-a, b, -c, d)
  3. (-a, -b, -c, -d)
  4. (a, -b, c, -d)

Any suggestions on how to define such a group and group representation from existing ones? I tried the irreducible representations of C4, but they didn't work.
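
A possibly useful observation: every listed action squares to the identity, so the four elements form the Klein four-group C2 x C2 rather than C4, which would explain why the C4 irreps failed. escnn exposes this group as dihedral_group(2), and (a, b, c, d) then transforms as a direct sum of two of its 1-dimensional irreps, each used twice. A hedged sketch (which irrep matches which sign pattern needs to be checked against escnn's conventions):

from escnn.group import dihedral_group, directsum

G = dihedral_group(2)   # isomorphic to the Klein four-group C2 x C2
print(G.irreps())       # four 1-dimensional irreps

# assumption: irrep(1, 0) flips the sign of (a, c) and irrep(0, 1) flips the
# sign of (b, d); verify by printing rho(g) for each of the four elements g
rho = directsum([G.irrep(1, 0), G.irrep(0, 1), G.irrep(1, 0), G.irrep(0, 1)])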

Thanks!

Support navigating the API and migrating equivariant application to escnn

This is not a bug; however, since discussions are not enabled for this repo (you might consider enabling them?), I ask here. This is a call for help in navigating the API to migrate a working application of equivariant learning that is outside the typical uses of escnn, but well within its capabilities.

Say I have a trivial principal fiber bundle $A = \mathbb{E}_d \times B$, where the principal fiber is the usual Euclidean group $\mathbb{E}_d$ ($O(d) \ltimes T_d$) and the base space is $B \subseteq \mathbb{R}^n$ (a manifold with some symmetry properties).

In my application the bundle is trivial, meaning that $A$ is globally a product of two manifolds (a Lie group and a manifold).

Note that, because of this triviality, I am not sure the fiber-bundle formalism brings valuable mathematical machinery, since we can simply understand $A$ as a product space. However, I am still becoming familiar with the theory, so there is a good chance I am wrong.

I am interested in using ESCNN to define the symmetries of this space (to be exploited later in learning applications), but I am having some trouble finding my way.

Say the space $A$ has a known finite symmetry group $\mathcal{G}$, whose action is given by the representation:

$$\rho_A(g)\, a = \begin{bmatrix} \rho_{\mathbb{E}_d}(g) & \boldsymbol{0} \\ \boldsymbol{0} & \rho_B(g) \end{bmatrix} \begin{bmatrix} e \\ b \end{bmatrix}, \qquad a = (e, b) \in A,\ e \in \mathbb{E}_d,\ b \in B,\ \forall g \in \mathcal{G}$$

where $\rho_{\mathbb{E}_d}$ is the representation of $\mathcal{G}$ on the fiber space, taking the form of a finite group of Euclidean isometries (translations, rotations, reflections), while $\rho_B$ is the representation of $\mathcal{G}$ on the base space. Depending on the data I assume to live in $B$ (scalars, vectors, other fields), $\rho_B$ will have a "known" form, but not necessarily that of a regular or permutation representation, or an irrep.

I have two questions that I want to solve before migrating stuff.

  1. Will the GSpace formalism offer some utility in handling trivial bundles? Do I really need a GSpace? Or should I instead construct the two representations of the group independently and define $\rho_A(g)$ as a direct sum of the fiber and base representations ($\mathbb{E}_d$, $B$)?
  2. I think I understand that if I have a data-dependent representation of the finite group on $B$, then to create the representation for my input data I have to find the orthogonal transformation $Q$ mapping, say, a "permutation rep" to the desired representation of my data. Am I right? Is there a way to define the representation directly by "inputting" the known matrices and letting the framework compute $Q$? That is, to define my group representation in a specific vector-space basis and have the framework compute all the required transformations to obtain a "permutation rep", "decomposed rep", etc.? (See the sketch after this list.)
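
As far as I understand, escnn stores a generic Representation as a change of basis $Q$ applied to a direct sum of irreps, so the decomposition is supplied by the user rather than inferred. A hedged sketch of defining a representation in a known basis (the group C4, the choice of irreps, and the matrix Q below are placeholders):

import numpy as np
from escnn.group import cyclic_group, directsum

G = cyclic_group(4)

# escnn models a representation as Q (irrep_1 + irrep_2 + ...) Q^{-1};
# both the list of irreps and the orthogonal change of basis Q are supplied
Q = np.eye(4)   # placeholder for the basis change of the "known" rep
rho = directsum([G.irrep(0), G.irrep(1), G.irrep(2)],
                change_of_basis=Q, name='my_known_rep')

Whether the framework can recover $Q$ automatically from raw matrices, I don't know; this sketch assumes the irrep decomposition is already known.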

I understand the question is long; thanks for taking the time even to read it, and thank you for providing this tool! I hope I can soon contribute to this project with some valuable additions, examples and tutorials like these.

Partial parameter expansion to subgroup

Hello,

I am interested in changing the group of an instantiated (and trained) group-convolutional layer while preserving weight information. The export function accomplishes this when going from equivariance to the original group down to translations only (in the case of convolutions), but I would like to know how to do this in the general case of expanding the learnable parameters from equivariance to the original group to equivariance to any subgroup, such as going from C8 on R2 to C4 on R2.
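
For reference, a hedged sketch of one possible approach, not an existing escnn API: since every C8-steerable kernel is in particular C4-steerable, one can expand the trained C8 filter and least-squares fit the basis coefficients of a fresh C4 layer. expand_parameters and FieldType.restrict do exist in escnn, but the integer subgroup id and the fitting loop below are my own illustration:

import torch
from escnn import gspaces, nn

gs8 = gspaces.rot2dOnR2(N=8)
ft8 = nn.FieldType(gs8, [gs8.regular_repr])
conv8 = nn.R2Conv(ft8, ft8, kernel_size=5, bias=False)   # the "trained" C8 layer

ft4 = ft8.restrict(4)   # same channels, now typed over the C4 subgroup
conv4 = nn.R2Conv(ft4, ft4, kernel_size=5, bias=False, initialize=False)

with torch.no_grad():
    target, _ = conv8.expand_parameters()        # the trained C8 filter
    # expand each unit coefficient of conv4 to materialize its filter basis
    basis = []
    for i in range(conv4.weights.numel()):
        conv4.weights.zero_()
        conv4.weights[i] = 1.0
        basis.append(conv4.expand_parameters()[0].flatten())
    A = torch.stack(basis, dim=1)                # (filter_dim, n_coeffs)
    sol = torch.linalg.lstsq(A, target.flatten().unsqueeze(1)).solution
    # exact fit: C8-steerable kernels lie inside the C4-steerable space
    conv4.weights.copy_(sol.flatten())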

Thanks,
Kaitlin

escnn's conv, BN, relu is not equivariant?

I'm trying to train an image recognition network that should be invariant under SO(2). Here is a minimal example of my network using escnn's modules:

import cv2
import numpy as np
import torch
from torch import nn
from escnn import gspaces
from escnn import nn as enn

class EquivariantConvBN(enn.EquivariantModule):
  def __init__(self, in_type, out_type, kernel_size=3, stride=1, dilation=1, bias=True, use_norm=True):
    super().__init__()
    self.in_type = in_type
    self.out_type = out_type
    padding = (kernel_size - 1) // 2
    layers = []
    layers.append(enn.R2Conv(in_type, out_type, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, sigma=None))
    if use_norm:
      layers.append(enn.InnerBatchNorm(out_type))
    layers.append(enn.ReLU(out_type, inplace=True))

    self.layer = enn.SequentialModule(*layers)

  def forward(self, x):
    return self.layer(x)

  def evaluate_output_shape(self, input_shape):
    return self.layer.evaluate_output_shape(input_shape)

gspace = gspaces.rot2dOnR2(4)
n_feat = 16
in_type = enn.FieldType(gspace, [gspace.trivial_repr]*3)
out_type = enn.FieldType(gspace, [gspace.regular_repr]*n_feat)
conv = EquivariantConvBN(in_type, out_type, kernel_size=7, stride=2).cuda().eval()
pool = nn.AdaptiveAvgPool2d(1).cuda()

rgb = np.random.uniform(0, 1, size=(480,480,3))   # Generate a random image
img = torch.tensor(rgb).float().cuda().permute(2,0,1)[None]
with torch.no_grad():
  out = pool(conv(enn.GeometricTensor(img, in_type)).tensor)
  for n_rot in range(0,3):
    rgb_tmp = cv2.rotate(rgb, n_rot)   # Rotate the image by 90/180/270 degs
    out_cur = pool(conv(enn.GeometricTensor(torch.tensor(rgb_tmp).float().cuda().permute(2,0,1)[None], in_type)).tensor)
    print(f'diff:{torch.abs(out-out_cur).max()}')

The output is

diff:0.00496324896812439
diff:0.007862985134124756
diff:0.007146209478378296

The idea is to test whether the final output stays the same under 90/180/270-degree rotations of the input image. My understanding is that, with an equivariant network, the output tensor should be identical under these rotations. However, I found that it is not. Is there anything I'm missing?
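
A possibly relevant detail: a 90-degree rotation permutes the four group channels inside each regular field, so the spatially pooled raw tensor can differ entry-wise even for a perfectly equivariant network. A sketch of a check that pools over the group channels first (enn.GroupPooling exists in escnn; whether this fully removes the discrepancy here is untested):

gpool = enn.GroupPooling(out_type).cuda().eval()
with torch.no_grad():
    feat = conv(enn.GeometricTensor(img, in_type))   # regular fields
    inv = pool(gpool(feat).tensor)                   # should be ~invariant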

Jax support

Hi,

Thanks for this really cool library!

Being a jax user, I was wondering whether you've thought about extending it to support jax, akin to e3nn-jax?

At first glance it seems that most of the library is 'pure' Python, apart from the GeometricTensor and FieldType classes in escnn/nn/*, which seem easily translatable to jax.Array etc., and obviously all the layers in escnn/nn/modules, which would need to be rewritten for flax / haiku / equinox.

I'd be happy to help out with this :)

Best,
Emile

Weight initialization to 0

Thank you for your great work on this library @Gabri95! I observed some behavior in the implemented R2 convolution layer and was wondering if you know why it happens.

I'm using the O(2) group in my network and was wondering why certain filters get initialized to zero. This only occurs for the O(2) group (not, e.g., for SO(2)) and happens in the forward pass of the BlocksBasisExpansion. As far as I can observe, the same filters are always set to zero, resulting in the same output dimensions being zero (1, 19, ...).

import numpy as np
import torch

from escnn.nn import *
from escnn.gspaces import *
from escnn.group import *

input = torch.rand(128, 3, 32, 32)
frequency = 4
gspace = flipRot2dOnR2(-1, maximum_frequency=frequency)
input_field_type = FieldType(gspace, [gspace.trivial_repr] * 3)
N = np.array([6, 11, 15, 19, 24]) * 2

fourier_elu = FourierELU(
    gspace,
    8,
    irreps=[(0, 0)] + [(1, f) for f in range(frequency + 1)],
    N=N[frequency],
    inplace=True,
)

conv = R2Conv(
    input_field_type,
    fourier_elu.in_type,
    kernel_size=7,
    padding=3,
    frequencies_cutoff=lambda r: 3 * r,
)

x = conv(GeometricTensor(input, input_field_type))
print(x.tensor[0, 1, 0])
print(x.tensor[0, 19, 0])

Am I doing something wrong here or is this desired behavior?
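
For reference, a small sketch to list which output channels are exactly zero across the batch and spatial dimensions:

with torch.no_grad():
    # per-channel max of |x|; channels where it is exactly 0 never activate
    zero_channels = (x.tensor.abs().amax(dim=(0, 2, 3)) == 0).nonzero().flatten()
    print(zero_channels)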
