Comments (4)
Could this be caused by inference mode? I added a `torch.is_inference_mode_enabled()` check inside `with torch.enable_grad():` and found that it returns `False` in `validation_step` but `True` in `test_step`.
```python
import torch

def compute_jacobian():  # hypothetical wrapper; the original snippet ran inside a step method
    with torch.enable_grad():
        N = 5
        f = lambda x: x ** 2
        x = torch.randn(N, requires_grad=True)
        # print(x.requires_grad)
        y = f(x)
        I_N = torch.eye(N)
        print(torch.is_inference_mode_enabled())  # False in train and validation, but True in test.
        # Sequential approach: one vector-Jacobian product per standard basis vector
        jacobian_rows = [
            torch.autograd.grad(y, x, v, retain_graph=True, allow_unused=True)[0]
            for v in I_N.unbind()
        ]
        jacobian = torch.stack(jacobian_rows)
        return jacobian
```
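A minimal sketch reproducing the failure outside Lightning, assuming that an outer `torch.inference_mode()` block is a fair stand-in for the Trainer's test loop. Wrapping the same sequential Jacobian in `torch.enable_grad()` is not enough, because `y` is never recorded by autograd and `torch.autograd.grad` raises a `RuntimeError`; leaving inference mode explicitly makes it work. The helper names `sequential_jacobian` and `square` are hypothetical.

```python
import torch


def sequential_jacobian(f, x):
    # Row i is the vector-Jacobian product of y = f(x) with the i-th basis vector.
    y = f(x)
    basis = torch.eye(y.numel())
    rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0] for v in basis.unbind()]
    return torch.stack(rows)


def square(t):
    return t ** 2


with torch.inference_mode():  # assumed stand-in for Lightning's test loop
    with torch.enable_grad():  # does NOT leave inference mode
        x_inf = torch.randn(5).requires_grad_(True)
        try:
            sequential_jacobian(square, x_inf)
            enable_grad_failed = False
        except RuntimeError:
            # y = x_inf ** 2 was not recorded by autograd, so grad() refuses it
            enable_grad_failed = True

    with torch.inference_mode(mode=False):  # actually leaves inference mode
        x = torch.randn(5).requires_grad_(True)
        jacobian = sequential_jacobian(square, x)

print(enable_grad_failed)
# For an elementwise square the Jacobian is diagonal with entries 2 * x.
print(torch.allclose(jacobian, torch.diag(2 * x)))
```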
from pytorch-lightning.
```python
import torch

with torch.enable_grad():
    x = torch.randn(3, 10).requires_grad_(True)
    y = x + 1
    print(f"{torch.is_inference_mode_enabled()=}")
    print(f"{x.requires_grad=}")
    print(f"{y.requires_grad=}")
```

```text
torch.is_inference_mode_enabled()=False
x.requires_grad=True
y.requires_grad=True
```
```python
with torch.inference_mode():
    with torch.enable_grad():
        x = torch.randn(3, 1).requires_grad_(True)
        y = x + 1
        print(f"{torch.is_inference_mode_enabled()=}")
        print(f"{x.requires_grad=}")
        print(f"{y.requires_grad=}")
```

```text
torch.is_inference_mode_enabled()=True
x.requires_grad=True
y.requires_grad=False
```
I want gradients to be tracked for all subsequent operations in `test_step` as well. How do I do that?
I seem to have found a solution: use `with torch.inference_mode(mode=False):` instead of `with torch.enable_grad():`.
```python
with torch.inference_mode():
    x = torch.randn(3, 10).requires_grad_(True)
    y = x + 1
    print(f"{torch.is_inference_mode_enabled()=}")
    print(f"{x.requires_grad=}")
    print(f"{y.requires_grad=}")
    with torch.inference_mode(mode=False):
        # with torch.enable_grad():
        x = torch.randn(3, 10).requires_grad_(True)
        y = x + 1
        print(f"{torch.is_inference_mode_enabled()=}")
        print(f"{x.requires_grad=}")
        print(f"{y.requires_grad=}")
```
```text
torch.is_inference_mode_enabled()=True
x.requires_grad=True
y.requires_grad=False
torch.is_inference_mode_enabled()=False
x.requires_grad=True
y.requires_grad=True
```
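One caveat worth checking with this approach: tensors allocated while inference mode is active are "inference tensors", and PyTorch refuses to save them for backward or to set `requires_grad` on them once you leave inference mode. Depending on how the batch reaches `test_step`, it may be such a tensor. The sketch below assumes that case; `model` and `batch` are hypothetical stand-ins, and the outer `torch.inference_mode()` only mimics the test loop. Cloning inside `torch.inference_mode(mode=False)` yields a normal tensor that autograd can use.

```python
import torch

# Hypothetical stand-ins for a LightningModule layer and a test batch.
model = torch.nn.Linear(3, 1)

with torch.inference_mode():  # assumed to mimic the Trainer's test loop
    batch = torch.randn(4, 3)  # allocated under inference mode -> inference tensor

    # Inside test_step: leave inference mode only for the part that needs autograd.
    with torch.inference_mode(mode=False):
        # Inference tensors cannot be saved for backward outside inference mode,
        # so clone first to get a normal tensor, then track gradients on it.
        x = batch.clone().requires_grad_(True)
        out = model(x).sum()
        (grad_x,) = torch.autograd.grad(out, x)

print(batch.is_inference(), x.is_inference())
print(grad_x.shape)
```

Since `out` sums a single linear output, every row of `grad_x` equals the layer's weight row, which is an easy sanity check.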
Inference mode is the default for validation and testing in the Trainer:
https://lightning.ai/docs/pytorch/stable/common/trainer.html#inference-mode
You can't take gradients in `validation_step`/`test_step` by default. But you can turn it off by setting `Trainer(inference_mode=False)` and using `torch.enable_grad()` inside the step, or as you've done above.
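To see why the Trainer flag matters, here is a small pure-PyTorch sketch. It assumes, based on the linked Trainer docs, that `inference_mode=False` makes the evaluation loop use `torch.no_grad()` instead of `torch.inference_mode()`; check this against your Lightning version. `torch.enable_grad()` can override `no_grad`, but not inference mode, which matches the outputs earlier in this thread. `grads_tracked` is a hypothetical helper.

```python
import torch


def grads_tracked(outer_context):
    """Return whether enable_grad() nested inside `outer_context` records a graph."""
    with outer_context():
        with torch.enable_grad():
            x = torch.ones(3).requires_grad_(True)
            y = (x * 2).sum()
            return y.requires_grad


# Assumed behaviour with Trainer(inference_mode=False): the loop uses no_grad,
# which enable_grad() can override inside the step.
print(grads_tracked(torch.no_grad))
# The default inference mode cannot be overridden by enable_grad().
print(grads_tracked(torch.inference_mode))
```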
Related Issues (20)
- MisconfigurationException: Do not set `gradient_accumulation_steps` in the DeepSpeed config
- Dataloader on multi-gpu jobs only surpport to manipulate on local_rank=0, is there a way tom manipulate every device?
- Error when fast_dev_run=True or num_sanity_val_steps=0 and using torchmetrics MetricTracker
- Fabric: Incorrect `num_replicas` (ddp/fsdp) when number of GPUs on each node is different
- Creating A Second Comet Logger Disables The First
- CUDA unknown error
- AttributeError: type object 'Trainer' has no attribute 'add_argparse_args'
- Add functionality to save nn.Modules supplied as arguments when initialising LightningModule
- I think it's deadly necessary to add docs or tutorials for handling the case when We return multiple loaders in test_dataloaders() method? I think it
- "save_last" could not save a complete checkpoint
- LR_FIND() does not work in DDP anymore, RuntimeError: No backend type associated with device type cpu
- KeyboardInterrupt raises an exception which results in a zero exit code
- XLA FSDP strategy has undocumented requirement for using activation checkpointing
- The training process will stop unexpectedly
- forward method missing required positional argument ‘masks’ in PyTorch Lightning
- Lightning Fabric: generic method to get the full state dict
- ModelCheckpoint does not work when using the monitor
- Continuing training with `ckpt_path="last"` and MLFLowLogger fails in distributed setting
- is `lightning` and `pytorch_lightning` the same?