SongJgit commented on July 22, 2024

Could this be caused by inference mode?
I added a `torch.is_inference_mode_enabled()` check inside the `with torch.enable_grad():` block and found that it returns False in `validation_step` but True in `test_step`.

```python
import torch

def compute_jacobian():  # enclosing function name is hypothetical; the snippet omitted it
    with torch.enable_grad():
        N = 5
        f = lambda x: x ** 2
        x = torch.randn(N, requires_grad=True)
        y = f(x)
        I_N = torch.eye(N)
        # False in training and validation_step, but True in test_step
        print(torch.is_inference_mode_enabled())
        # Sequential approach: one vector-Jacobian product per row of I_N
        jacobian_rows = [
            torch.autograd.grad(y, x, v, retain_graph=True, allow_unused=True)[0]
            for v in I_N.unbind()
        ]
        jacobian = torch.stack(jacobian_rows)
    return jacobian
```
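
The failure can be reproduced without Lightning. A minimal sketch, assuming only that `test_step` runs under `torch.inference_mode()` (as the `print` above indicates): wrapping the Jacobian computation in `torch.enable_grad()` still fails there, while `torch.inference_mode(mode=False)` makes it work. `jacobian_of_square` is a hypothetical helper, not part of the original code.

```python
import torch

def jacobian_of_square(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: row-by-row Jacobian of f(x) = x**2."""
    y = x ** 2
    rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
            for v in torch.eye(x.numel()).unbind()]
    return torch.stack(rows)

N = 5

with torch.inference_mode():  # stand-in for how test_step is run
    # enable_grad() alone: y gets no grad_fn, so autograd.grad raises.
    try:
        with torch.enable_grad():
            jacobian_of_square(torch.randn(N, requires_grad=True))
        enable_grad_failed = False
    except RuntimeError:
        enable_grad_failed = True

    # Leaving inference mode locally restores normal autograd tracking.
    with torch.inference_mode(mode=False):
        x = torch.randn(N, requires_grad=True)
        jac = jacobian_of_square(x)

print(enable_grad_failed)  # True
# d(x_i**2)/dx_j is 2*x_i on the diagonal and 0 elsewhere
print(torch.allclose(jac, torch.diag(2 * x.detach())))  # True
```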

from pytorch-lightning.

SongJgit commented on July 22, 2024

```python
import torch

with torch.enable_grad():
    x = torch.randn(3, 10).requires_grad_(True)
    y = x + 1
    print(f"{torch.is_inference_mode_enabled()=}")
    print(f"{x.requires_grad=}")
    print(f"{y.requires_grad=}")
```

```
torch.is_inference_mode_enabled()=False
x.requires_grad=True
y.requires_grad=True
```

Nested inside `torch.inference_mode()`, the same code gives:

```python
with torch.inference_mode():
    with torch.enable_grad():
        x = torch.randn(3, 1).requires_grad_(True)
        y = x + 1
        print(f"{torch.is_inference_mode_enabled()=}")
        print(f"{x.requires_grad=}")
        print(f"{y.requires_grad=}")
```

```
torch.is_inference_mode_enabled()=True
x.requires_grad=True
y.requires_grad=False
```

I want gradients to be tracked in `test_step` as well. How do I do that?

SongJgit commented on July 22, 2024

I think I have found a solution: use `with torch.inference_mode(mode=False):` instead of `with torch.enable_grad():`.

```python
import torch

with torch.inference_mode():
    x = torch.randn(3, 10).requires_grad_(True)
    y = x + 1
    print(f"{torch.is_inference_mode_enabled()=}")
    print(f"{x.requires_grad=}")
    print(f"{y.requires_grad=}")
    # Locally leave inference mode (instead of `with torch.enable_grad():`)
    with torch.inference_mode(mode=False):
        x = torch.randn(3, 10).requires_grad_(True)
        y = x + 1
        print(f"{torch.is_inference_mode_enabled()=}")
        print(f"{x.requires_grad=}")
        print(f"{y.requires_grad=}")
```

```
torch.is_inference_mode_enabled()=True
x.requires_grad=True
y.requires_grad=False
torch.is_inference_mode_enabled()=False
x.requires_grad=True
y.requires_grad=True
```
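
A caveat worth checking, stated as an assumption about Lightning's data flow: tensors created while inference mode is active (possibly including the batch that reaches `test_step`) remain inference tensors even inside `torch.inference_mode(mode=False)`, and autograd refuses to save them for backward. A minimal sketch in plain PyTorch using `torch.is_inference` and `.clone()` to obtain a normal tensor; `batch` and `w` are illustrative names only.

```python
import torch

with torch.inference_mode():              # stand-in for the test loop
    batch = torch.randn(4, 3)             # assumption: batch created under inference mode
    with torch.inference_mode(mode=False):
        w = torch.randn(3, requires_grad=True)
        # mul must save `batch` for w's gradient; autograd refuses to save
        # an inference tensor for backward, so this raises.
        try:
            (batch * w).sum().backward()
            plain_batch_failed = False
        except RuntimeError:
            plain_batch_failed = True

        # clone() outside inference mode returns a normal tensor autograd can use.
        normal_batch = batch.clone()
        (normal_batch * w).sum().backward()

print(torch.is_inference(batch))         # True
print(torch.is_inference(normal_batch))  # False
```

Only the `batch` used in autograd-tracked operations needs cloning; the rest of the step can stay under inference mode.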

awaelchli commented on July 22, 2024

Inference mode is the default for validation/testing in the Trainer:
https://lightning.ai/docs/pytorch/stable/common/trainer.html#inference-mode
You can't take gradients in `validation_step`/`test_step` by default. You can turn this off by setting `Trainer(inference_mode=False)` and then enabling gradients inside the step with `torch.enable_grad()`, or by doing what you've done above.
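
A minimal plain-PyTorch sketch of the `Trainer(inference_mode=False)` route, assuming that with this flag the evaluation loop still disables gradients, but via `torch.no_grad()` rather than inference mode. `gradient_in_eval_step` is a hypothetical stand-in for `test_step` logic; no Lightning objects are used.

```python
import torch

def gradient_in_eval_step(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for test_step logic that needs d/dx of sum(x**2)."""
    with torch.enable_grad():
        x = x.detach().requires_grad_(True)
        (x ** 2).sum().backward()
    return x.grad

data = torch.arange(3.0)

# Assumption: with Trainer(inference_mode=False) the step runs under no_grad,
# which torch.enable_grad() can override.
with torch.no_grad():
    grad = gradient_in_eval_step(data)
print(grad)  # tensor([0., 2., 4.])

# Under inference mode (the Trainer default), enable_grad() alone is not enough.
with torch.inference_mode():
    try:
        gradient_in_eval_step(data)
        inference_failed = False
    except RuntimeError:
        inference_failed = True
print(inference_failed)  # True
```

`torch.no_grad()` only switches off grad mode, which `torch.enable_grad()` can switch back on; inference mode additionally stops operations from being recorded, which is why the user's `torch.inference_mode(mode=False)` is needed when the default is kept.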
