Comments (2)
Hey @kvndhrty
I think a pretty easy way to work around this is to not register your meta-template model as a submodule. You can easily do that by packing it into a list:
def __init__(self):
    super().__init__()
    with torch.device("meta"):
        self._template_model = [TemplateModel()]

# then access it like so in your other code:
self._template_model[0]
# ... or write a getter to return you the template model without indexing
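To make the effect of the list trick concrete, here is a minimal, self-contained sketch. `TemplateModel` and `Wrapper` are hypothetical stand-ins (the real classes from the issue are not shown in this thread), and the `template_model` property is just one way to write the getter mentioned above. It assumes PyTorch 2.0 or newer, where `torch.device` can be used as a context manager.

```python
import torch
from torch import nn


class TemplateModel(nn.Module):
    """Hypothetical stand-in for the template model from the issue."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)


class Wrapper(nn.Module):
    """Hypothetical owner module; in the issue this would be the LightningModule."""

    def __init__(self):
        super().__init__()
        self.real = nn.Linear(4, 2)  # a normal, registered submodule
        with torch.device("meta"):
            # A plain Python list is not an nn.Module, so nn.Module.__setattr__
            # stores it as an ordinary attribute and does not register its contents.
            self._template_model = [TemplateModel()]

    @property
    def template_model(self):
        # Getter so callers do not need to index into the list.
        return self._template_model[0]


model = Wrapper()
child_names = [name for name, _ in model.named_children()]
param_devices = {p.device.type for p in model.parameters()}
template_device = model.template_model.linear.weight.device.type

print(child_names)      # only the registered submodule shows up
print(param_devices)    # no meta parameters are visible through the wrapper
print(template_device)  # the template itself still lives on the meta device
```

Because the template is invisible to `named_children()` and `parameters()`, it is also skipped by `.to(...)` and `state_dict()`, which is presumably why Lightning no longer trips over it.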
I think the assumption Lightning makes, that your model is not on the meta device after training, is a reasonable one. The same applies before training, since Lightning eventually moves the model to the GPU before training starts. I think it would become quite complex if we had to add logic to ignore such submodules on the meta device. It would also be error-prone, because meta-device initialization is needed for large-model training.
So I would like to suggest we don't treat this as a bug.
One other thing you could do is ask yourself whether it is even necessary to keep your template model as an attribute at all. Since creation on the meta device is basically free, you could also just create it on the fly whenever you need it, read the properties you need, and store those somewhere. Then you don't need to keep the template model around.
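A rough sketch of that on-the-fly approach, assuming the properties you need are the parameter shapes. `TemplateModel` and `template_shapes` are hypothetical names for illustration; the issue does not say which properties are actually needed.

```python
import torch
from torch import nn


class TemplateModel(nn.Module):
    """Hypothetical stand-in for the template model from the issue."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)


def template_shapes():
    """Build the template on the meta device, read what is needed, and drop it."""
    with torch.device("meta"):
        template = TemplateModel()  # allocates no real parameter storage
    # Keep only plain Python data; the template goes out of scope on return.
    return {name: tuple(p.shape) for name, p in template.named_parameters()}


shapes = template_shapes()
print(shapes)
```

Only the dictionary of shapes is stored, so nothing on the meta device is left attached to the module.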
from pytorch-lightning.
@awaelchli I think that is entirely reasonable; I'll pack my module into a list for now. The small refactor required to init the meta module each time isn't something I'll do this week, but maybe in the near future.
Thank you for the quick response!