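The TTA wrapper does not work with a batch of data. Training fails with the traceback below. `grid_sampler()` receives a 2-D sampling grid of shape `[16, 40, 40, 2]` but expects a last dimension of size 3. It only expects that for 5-D (volumetric) inputs, so the tensor passed to `F.rotate` probably has an extra leading dimension (e.g. `[B, T, C, H, W]` — to be confirmed).

Traceback (most recent call last):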
File "/Users/titou/Documents/flair-2/src/models/make_train.py", line 82, in <module>
main()
File "/Users/titou/Documents/flair-2/src/models/make_train.py", line 77, in main
trainer.fit(model=lightning_model)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 529, in fit
call._call_and_handle_interrupt(
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 42, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 568, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 973, in _run
results = self._run_stage()
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1016, in _run_stage
self.fit_loop.run()
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 201, in run
self.advance()
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 354, in advance
self.epoch_loop.run(self._data_fetcher)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 133, in run
self.advance(data_fetcher)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 218, in advance
batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 185, in run
self._optimizer_step(kwargs.get("batch_idx", 0), closure)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 260, in _optimizer_step
call._call_lightning_module_hook(
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 144, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1256, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py", line 155, in step
step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 225, in optimizer_step
return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 114, in optimizer_step
return optimizer.step(closure=closure, **kwargs)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/torch/optim/optimizer.py", line 280, in wrapper
out = func(*args, **kwargs)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/torch/optim/optimizer.py", line 33, in _use_grad
ret = func(self, *args, **kwargs)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/torch/optim/adamw.py", line 148, in step
loss = closure()
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 101, in _wrap_closure
closure_result = closure()
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 140, in __call__
self._result = self.closure(*args, **kwargs)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 126, in closure
step_output = self._step_fn()
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 307, in _training_step
training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 291, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 367, in training_step
return self.model.training_step(*args, **kwargs)
File "/Users/titou/Documents/flair-2/./src/models/lightning.py", line 94, in training_step
outputs = self.forward(inputs={'aerial': aerial, 'sen': sen})
File "/Users/titou/Documents/flair-2/./src/models/lightning.py", line 86, in forward
x = self.model(inputs=inputs, step=self.step, batch_size=self.batch_size)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/titou/Documents/flair-2/./src/data/tta/wrappers.py", line 23, in forward
inputs = augmentation.augment(inputs, param)
File "/Users/titou/Documents/flair-2/./src/data/tta/augmentations.py", line 77, in augment
inputs[key] = F.rotate(inputs[key], angle=angle)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/torchvision/transforms/functional.py", line 1140, in rotate
return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/torchvision/transforms/_functional_tensor.py", line 669, in rotate
return _apply_grid_transform(img, grid, interpolation, fill=fill)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/torchvision/transforms/_functional_tensor.py", line 560, in _apply_grid_transform
img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
File "/opt/homebrew/Caskroom/miniconda/base/envs/flair-2-env/lib/python3.10/site-packages/torch/nn/functional.py", line 4244, in grid_sample
return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
RuntimeError: grid_sampler(): expected grid to have size 3 in last dimension, but got grid with sizes [16, 40, 40, 2]