
Comments (3)

Gabri95 commented on May 29, 2024

Hi @Guptajakala

I suspect the problem is that you are testing your model's output for invariance, but the output is equivariant rather than invariant.
Indeed, your output type contains n_feat copies of the regular representation, i.e. your 16*4-dimensional output splits into 16 blocks of size 4. The 4 channels within each block permute when the input rotates.
Your test is not accounting for this.

To properly check for equivariance, you should "rotate back" the output out by n_rot.
You could do that by wrapping out in a new GeometricTensor with type out_type and then using the transform_fibers method.
In other words, you should replace out with

GeometricTensor(out, out_type).transform_fibers(gspace.fibergroup.element(-n_rot)).tensor

This is a bit verbose since you unwrapped the GeometricTensors and used cv2 to rotate.
The code is a bit shorter if you use one of our pooling operators (which return GeometricTensors) and loop over gspace.testing_elements() (which already returns a list of GroupElements).
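
To illustrate the "rotate back" step without any library, here is a minimal plain-Python sketch. It is not escnn code: act_regular is a hypothetical helper that only models how a stack of regular fields of the 4-rotation group is permuted, and the direction of the cyclic shift is an assumed convention.

```python
GROUP_ORDER = 4  # C4: rotations by multiples of 90 degrees


def act_regular(vec, n_rot, order=GROUP_ORDER):
    """Cyclically shift the channels inside each block of `order` channels.

    Toy model of how n_feat regular fields transform when the input is
    rotated by n_rot steps. The shift direction is an assumed convention.
    """
    assert len(vec) % order == 0, "channel count must be a multiple of the group order"
    out = []
    for start in range(0, len(vec), order):
        block = vec[start:start + order]
        out.extend(block[(i - n_rot) % order] for i in range(order))
    return out


n_feat = 3
features = [float(i) for i in range(n_feat * GROUP_ORDER)]  # stand-in for model(x)
n_rot = 1

# Stand-in for model(rotate(x, n_rot)): same values, permuted within each block.
rotated = act_regular(features, n_rot)

# Naive invariance test: compares the raw outputs and fails.
naive_ok = rotated == features

# Equivariance test: act with the inverse element first, then compare.
restored = act_regular(rotated, -n_rot)
equivariant_ok = restored == features

print(naive_ok, equivariant_ok)  # False True
```

The naive comparison fails even though nothing is wrong with the features; only the comparison after undoing the channel shift is a meaningful test.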

Hope this helps,
Gabriele

from escnn.

Guptajakala commented on May 29, 2024

@Gabri95 There is an AdaptiveAvgPool2d pooling layer at the end. After the equivariant conv, suppose the feature shape is (B,D,H,W). After pooling, isn't it (B,D,1,1) and thus invariant? I guess even if I "rotate back", it doesn't make any difference, since only a single scalar is left over the HW dimensions of each channel?


Gabri95 commented on May 29, 2024

Hi @Guptajakala

With this pooling, the output will be (approximately) invariant to translations but not to rotations.
This is because the D channels in the output of shape (B, D, 1, 1) are still permuted when the input rotates, since you chose out_type = enn.FieldType(gspace, [gspace.regular_repr]*n_feat).
What you say would be correct if you used out_type = enn.FieldType(gspace, [gspace.trivial_repr]*n_feat).

When using out_type = enn.FieldType(gspace, [gspace.regular_repr]*n_feat), you can think of the channel dimension as carrying features over the rotation subgroup.
Check our tutorial notebook for a more intuitive description of the features of Steerable CNNs.
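
The following toy sketch, again in plain Python and not using escnn, shows why spatial pooling alone leaves the channels permuting. It assumes a simplified model in which one regular field is produced by correlating the image with the four rotated copies of a single filter; all function names are hypothetical, and max pooling is used in place of average pooling so that the toy numbers stay distinct.

```python
def rot90(grid):
    """Rotate a square 2D list by 90 degrees counter-clockwise."""
    n = len(grid)
    return [[grid[c][n - 1 - r] for c in range(n)] for r in range(n)]


def pooled_response(image, kernel):
    """Valid cross-correlation followed by global max pooling: one scalar."""
    n, k = len(image), len(kernel)
    return max(
        sum(image[i + a][j + b] * kernel[a][b] for a in range(k) for b in range(k))
        for i in range(n - k + 1)
        for j in range(n - k + 1)
    )


def regular_block(image, kernel):
    """One block of 4 channels: the pooled response to each rotated filter."""
    feats = []
    for _ in range(4):
        feats.append(pooled_response(image, kernel))
        kernel = rot90(kernel)
    return feats


def group_pool(feats):
    """Max over the 4 rotation channels: a rotation-invariant scalar."""
    return max(feats)


image = [[5, 0, 1],
         [0, 0, 0],
         [2, 0, 3]]
kernel = [[1, 0],
          [0, 0]]

base = regular_block(image, kernel)           # [5, 2, 3, 1]
turned = regular_block(rot90(image), kernel)  # [1, 5, 2, 3]

print(base, turned)                           # same values, cyclically shifted
print(group_pool(base), group_pool(turned))   # 5 5
```

Each channel is already a single scalar after spatial pooling, yet the block as a whole is cyclically shifted by the rotation. Only after pooling over the four rotation channels as well does the result behave like a trivial-representation output.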

Hope this helps!
Gabriele

