
Comments (7)

frgfm commented on June 1, 2024

Hello @dayunyan 👋

Thanks for reporting this!
So this is actually an advanced topic; here's what's happening under the hood:

  • when you create a CAM extractor, two things happen in terms of hooks:
    • if no target layer is passed, it sets temporary forward hooks & makes a forward pass to select one by default
    • once this is resolved, it sets & enables the hooks necessary for the CAM computation. For gradient-based methods, that means backward hooks 😅
  • so now that a first extractor is set:
    • by default, its hooks are enabled (waiting for a model forward)
    • creating the second extractor without explicitly passing the target layer triggers the layer resolution mechanism again. However, the forward pass of this resolution is run without gradient tracking, and because you already have the first extractor & its gradient hooks enabled, that triggers the error 🙃 (see the minimal sketch below this list)
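
Here's a minimal sketch of the failing pattern (same model & extractor as in the snippets below):

from torchvision.models import resnet50
from torchcam.methods import XGradCAM

model = resnet50().eval()
# First extractor: resolves a target layer, then sets & enables its gradient hooks
extractor_1 = XGradCAM(model)
# Second extractor: its layer resolution runs a forward pass while extractor_1's
# gradient hooks are still enabled -> this is where the error is raised
extractor_2 = XGradCAM(model)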

Now you have two solutions:

  1. Disable the hooks at the right moment
from torchvision.models import resnet50
from torchcam.methods import XGradCAM

model = resnet50().eval()
# Disable CAM computation hooks
extractor_1 = XGradCAM(model, enable_hooks=False)
extractor_2 = XGradCAM(model)
# Re-enable them
extractor_1._hooks_enabled = True
  2. Pass the target layer explicitly to avoid the problem
from torchvision.models import resnet50
from torchcam.methods import XGradCAM

model = resnet50().eval()
extractor_1 = XGradCAM(model)
# Use the layer resolution of the first extractor to avoid the double resolution
extractor_2 = XGradCAM(model, target_layer=extractor_1.target_names)

I would highly recommend the second one, which is more efficient / faster :)

I'm trying to think about a way to prevent this behavior, but that means I should disable hooks automatically before attempting to locate a layer 🤯

One question though: why would you create two identical CAM extractors for the same model? 🤔

Anyway, I hope this helped!


dayunyan commented on June 1, 2024

Thank you very much! What an excellent answer! @frgfm

As for your question: in the project source code, I want to extract the CAM during model training and use it as a reference to compute a loss. In my mind, since I updated the model parameters with loss.backward() & optimizer.step(), I should create a new extractor for the updated model before the next training iteration, and that's how I ran into this error.
The code structure looks like the following:

for epoch in range(num_epochs):  # training loop
    for batch, (data, labels) in enumerate(dataloader):
        extractor = XGradCAM(model) # The error always occurs the second time the extractor is created
        # Input data and extract the CAM
        ........
        # Get the loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Now I wonder whether I only need to create the extractor once, outside the loop, and whether, as the model's parameters are updated, the extractor will keep reflecting the updated model as well.

I'm not sure if you understand what I mean, but anyway, I really appreciate your answer!


frgfm commented on June 1, 2024

You're very much welcome 😄
I originally started this project simply to implement research papers on DL interpretability in order to understand them, so the design might differ a bit from other neighbouring efforts!

About your question:

  • every time you create an extractor, it creates hooks (I haven't implemented automatic removal when it's dereferenced; I opened #197 for that)
  • so I suggest creating the extractor once, outside the loop
  • about the CAM computation, there are two things to consider:
    1. the backprop of the loss does populate the gradients. But those are the partial derivatives of the loss, a quantity which, mathematically speaking, increases with the model's inaccuracies.
    2. the CAM computation backprops a synthetic loss, if you will (the output logits multiplied by a one-hot vector for the class you're interested in). This happens when extractor(class_idx, out) is called (see the conceptual sketch below this list).
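
To make point 2 a bit more concrete, here's a conceptual sketch of that synthetic loss (just the idea, not torchcam's actual internals):

import torch
from torchvision.models import resnet50

model = resnet50().eval()
x = torch.rand(2, 3, 224, 224)
class_idx = [0, 1]  # the class(es) you're interested in, one per sample
out = model(x)  # output logits of shape (N, num_classes)
# One-hot selection of the target logits, then backprop of that "synthetic loss"
one_hot = torch.zeros_like(out)
one_hot[torch.arange(out.shape[0]), class_idx] = 1.0
synthetic_loss = (one_hot * out).sum()
synthetic_loss.backward()  # populates the gradients used by gradient-based CAM methods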

Now, here are my suggestions:

  • if you want to use the gradient of the CAM computation for parameter updates, bear in mind that an optimizer normally minimizes a quantity that increases with errors. Here, the backpropagated quantity is the logit of class_idx, which you actually want to increase, so use maximize=True in your optimizer (cf. https://pytorch.org/docs/stable/generated/torch.optim.Adam.html). That will reverse the sign of the update for the parameters.
from torch.optim import Adam
from torchcam.methods import LayerCAM

model = ...
optimizer = Adam(model.parameters(), maximize=True)
extractor = LayerCAM(model)

# Epoch loop
for x, target in dataloader:
    optimizer.zero_grad()
    out = model(x)
    # Backprop the CAM grads
    extractor(target.numpy().tolist(), out)
    optimizer.step()

Please understand that this sounds highly experimental; I cannot vouch for the outcome.

  • if you want to use the gradient of the loss for the CAM computation, you'll have to manually modify torchcam since the automatic backprop is enforced. Sorry 🙃

Hope this isn't too obscure 😅


dayunyan commented on June 1, 2024

To be honest, it's a little obscure 😂 But I will try to understand.

Thanks very very very very much again!🥰


frgfm commented on June 1, 2024

Feel free to ask if you have specific questions later on, I'll try my best to answer :)


niniack commented on June 1, 2024

Hi @frgfm

I'm not sure if this is the best spot to ask this question, but it seems somewhat relevant to this topic.

somewhere outside the training loop:

model = resnet50()
cam_extractor = GradCAM(model, target_layer='conv1')

This is a simplified excerpt from my training loop:

optimizer.zero_grad()

# Forward pass
outputs = model(inputs)

# Batch processing
cams = cam_extractor(outputs.argmax(dim=1).tolist(), outputs, retain_graph=True)

# Grab the first item from the list (we only have one target layer)
cams = cams[0]

# Simplified, but nothing wild happening here
cam_loss = custom_loss(cams)

loss = criterion(outputs, labels)
loss += cam_loss
loss.backward()
optimizer.step()

This doesn't work too well for me, and I suspect I'm losing the computational graph somewhere. When I do:

print(cams.requires_grad)

I get False. So I'm not really able to backpropagate through it (I think?). Setting cams.requires_grad=True doesn't seem like the right answer either.

Am I on the right track? I posted this here because I am also trying to use the CAM for updating parameters.


frgfm commented on June 1, 2024

Hi @niniack 👋

Actually, I think this is unrelated to this topic, but let's check this:

  • first thing: for me to debug, it's much easier when I have a minimal snippet to reproduce the problem on my end. I understand you might not want to share all of that, and that's quite fine :) However, "This doesn't work too well for me" is not very helpful haha. What do you mean? NaNs? A uniform CAM?
  • I don't have the context here, but without the training loop I can only try to guess: you are correct that what you pass to a loss function/criterion needs to require gradients. The goal is to compute the derivative of the loss with respect to the parameters. I have no clue what custom_loss does, but CAMs by nature don't need to propagate gradients (i.e. I didn't ensure that the operations used to compute the CAMs preserve the gradient flow).
  • That means you can easily use them as targets, like any other tensor that doesn't require gradients. I imagine that will be easier than making CAMs backprop-able (see the sketch below this list).
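
As a rough sketch of that last point, reusing cam_extractor & outputs from your excerpt (here predicted_map is a hypothetical placeholder for whatever differentiable tensor you want to pull toward the CAM, e.g. the output of a small head of your model; it's not a torchcam object):

import torch
import torch.nn.functional as F

# Compute the CAM and use it as a fixed, gradient-free target
cams = cam_extractor(outputs.argmax(dim=1).tolist(), outputs, retain_graph=True)
cam_target = cams[0].detach()  # no gradient flows through the CAM anyway
# predicted_map stands in for a differentiable tensor from your own model/head
# (a dummy is used here just to make the sketch self-contained)
predicted_map = torch.rand_like(cam_target, requires_grad=True)
aux_loss = F.mse_loss(predicted_map, cam_target)  # gradients flow only through predicted_map
aux_loss.backward()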

I hope this helps a bit!
Cheers ✌️
