@carlos-havier As mentioned in the PyTorch Lightning documentation, the downside of `ddp_notebook` is: "GPU operations such as moving tensors to the GPU or calling torch.cuda functions before invoking Trainer.fit is not allowed."

This means there must be no CUDA tensors before `Trainer.fit` is called. By default, when training on GPU, PyTorch Lightning saves the `state_dict` in the checkpoint with CUDA tensors. So when you load from that checkpoint, CUDA gets initialised. You can verify this with a simple check:

```python
print(torch.cuda.is_initialized())
```

placed before and after the call:

```python
pl_model = LT_timm_model.load_from_checkpoint(def_log_chkpt)
```

You'll observe that CUDA becomes initialised during `load_from_checkpoint`, and once CUDA is initialised in the main process, it cannot be re-initialised in the subprocesses that `ddp_notebook` forks.
**The fix:** pass `map_location=torch.device('cpu')` when calling `load_from_checkpoint`:

```python
pl_model = LT_timm_model.load_from_checkpoint(def_log_chkpt, map_location=torch.device('cpu'))
```
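The same behaviour can be reproduced without Lightning at all, since `load_from_checkpoint` ultimately deserialises the checkpoint via `torch.load`. Below is a minimal, self-contained sketch (the checkpoint file and its contents are made up for illustration; a real Lightning checkpoint saved during GPU training would contain CUDA-tagged storages):

```python
import os
import tempfile

import torch

# A small tensor checkpoint stands in for a Lightning .ckpt file.
path = os.path.join(tempfile.mkdtemp(), "demo.ckpt")
torch.save({"weight": torch.zeros(2, 2)}, path)

print(torch.cuda.is_initialized())  # False: nothing has touched CUDA yet

# map_location="cpu" remaps every storage onto the CPU during loading, so
# deserialisation never triggers CUDA initialisation, even for checkpoints
# whose tensors were saved from the GPU.
state = torch.load(path, map_location="cpu")

print(state["weight"].device)       # cpu
print(torch.cuda.is_initialized())  # still False: safe to launch ddp_notebook
```

Because the loaded tensors live on the CPU, the forked `ddp_notebook` workers can each initialise CUDA on their own device when `Trainer.fit` moves the model.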
For your reference, I have added a few debugging lines based on your notebook here.
from pytorch-lightning.