Comments (2)
This should not happen. Can you update the snippet below to show the problem?
import os
import torch
from lightning.pytorch import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset
class RandomDataset(Dataset):
    """Map-style dataset yielding ``length`` random vectors of width ``size``."""

    def __init__(self, size, length):
        # Pre-generate every sample up front; each row is one item.
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        # Plain row lookup into the pre-generated tensor.
        return self.data[index]

    def __len__(self):
        return self.len
class BoringModel(LightningModule):
    """Minimal LightningModule: one linear layer, loss is the summed output.

    The train/val dataloader hooks deliberately raise so this reproduction
    fails loudly if `trainer.test(...)` ever requests them.
    """

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        out = self(batch).sum()
        self.log("train_loss", out)
        return {"loss": out}

    def validation_step(self, batch, batch_idx):
        out = self(batch).sum()
        self.log("valid_loss", out)

    def test_step(self, batch, batch_idx):
        out = self(batch).sum()
        self.log("test_loss", out)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        # Intentionally unusable: testing must never build the train loader.
        raise RuntimeError

    def val_dataloader(self):
        # Intentionally unusable: testing must never build the val loader.
        raise RuntimeError

    def test_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)
def run():
    """Run `trainer.test` on a fresh BoringModel, one batch per loop stage."""
    trainer_kwargs = dict(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    Trainer(**trainer_kwargs).test(BoringModel())


if __name__ == "__main__":
    run()
from pytorch-lightning.
I found the bug. It appears when you use the strategy "deepspeed" in the trainer. Code below :)
import os
import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
import lightning as L
class RandomDataset(Dataset):
    """Fixed pool of ``length`` Gaussian feature vectors, each of width ``size``."""

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        item = self.data[index]
        return item

    def __len__(self):
        return self.len
class LDataset(L.LightningDataModule):
    """DataModule with a weighted-sampling train loader and a plain test loader.

    `val_dataloader` deliberately raises so this reproduction fails loudly if
    validation is ever requested.
    """

    def __init__(self):
        super().__init__()
        # All state is populated lazily in `setup`, per stage.
        self.num_samples = None
        self.weights = None
        self.len = None
        self.train_data = None
        self.test_data = None

    def setup(self, stage: str):
        if stage == "fit":
            self.train_data = RandomDataset(32, 14)
            # since RandomSampler only balances train data, the weights are calculated here naturally
            # BUG FIX: the original hard-coded 6 weights for 14 samples, which
            # restricts WeightedRandomSampler to the first 6 indices. Use one
            # uniform weight per training sample instead.
            self.weights = [1] * len(self.train_data)
            self.num_samples = len(self.train_data)
        if stage == "test":
            self.test_data = RandomDataset(32, 14)

    def train_dataloader(self):
        return DataLoader(
            self.train_data,
            sampler=WeightedRandomSampler(
                replacement=True,
                weights=self.weights,
                num_samples=self.num_samples,
            ),
            batch_size=2,
        )

    def val_dataloader(self):
        # Intentionally unusable: validation must never run in this repro.
        raise RuntimeError

    def test_dataloader(self):
        return DataLoader(self.test_data, batch_size=2)
class BoringModel(L.LightningModule):
    """Single-linear-layer module whose loss is simply the summed output."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def _step_loss(self, batch):
        # Shared by every step: forward pass, then reduce to a scalar.
        return self(batch).sum()

    def training_step(self, batch, batch_idx):
        loss = self._step_loss(batch)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self._step_loss(batch)
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self._step_loss(batch)
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
def run():
    """Run `trainer.test` with the deepspeed strategy to trigger the bug."""
    trainer = L.Trainer(
        strategy="deepspeed",
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        log_every_n_steps=10,
        enable_checkpointing=True,
        check_val_every_n_epoch=5,
    )
    trainer.test(BoringModel(), datamodule=LDataset())


if __name__ == "__main__":
    run()
from pytorch-lightning.
Related Issues (20)
- Unable to extract confusion matrix as a metric from trainer HOT 1
- Torchmetrics Accuracy issue when you don't shuffle test data. HOT 1
- ModelCheckpoint: Using save_top_k, only the first k models are stored, not the best k models HOT 2
- trainer.fit from checkpoint without performance improvement will break 'last' link to checkpoint on window11
- Exception in RecordFunction callback: state_ptr INTERNAL ASSERT FAILED at "../torch/csrc/profiler/standalone/nvtx_observer.cpp":115
- `ckpt_path` in `Trainer` accepts URIs to automatically load checkpoints from remote paths
- When doing tuner.scale_batch_size, check full dataset length first
- RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [68]] is at version 3; expected version 2 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
- Is the Lightning App deprecated? (Lightning App doc is not found) HOT 1
- Dynamically link arguments in `LightningCLI`? HOT 2
- Support get optimizer and lr_schedulers from deepspeed config
- Validation dataloader is added to train dataloader after first epoch
- Add dog has an error: FileNotFoundError: HOT 1
- Resume training, how to change learning scheduler? HOT 1
- Possible bug in recognizing `mps` accelerator even though PyTorch seems to register the `mps` device?
- Error loading a saved model to run inference (using ddp_notebook strategy)
- AttributeError: module 'pytorch_lightning.callbacks' has no attribute 'ProgressBarBase'. Did you mean: 'ProgressBar'?
- Using the MLflow logger produces Inconsistent metric plots HOT 2
- can't fit with ddp_notebook on a Vertex AI Workbench instance (CUDA initialized)
- Lightning stalls with 2 GPUs on 1 node with SLURM (and apptainer)
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from pytorch-lightning.