Comments (20)
The topological sort error can be ignored -- the code works correctly despite it.
Did you change this line https://github.com/alexlee-gk/video_prediction/blob/master/video_prediction/datasets/kth_dataset.py#L17 to use (32, 32, 3)?
from video_prediction.
Yup! Doing this caused the reshape error. I have an additional question: should the batch size parameter be changed in base_dataset.py, in base_model.py, or in both?
And a further clarification: what is the clip_length parameter in base_model.py?
from video_prediction.
It depends on how the dataset and model are being used. For example, in the training script, the dataset is created with the batch_size specified by the model: https://github.com/alexlee-gk/video_prediction/blob/master/scripts/train.py#L141-L146
clip_length is the number of consecutive frames that are given to the discriminator at each training iteration. This is useful when training with a high sequence_length but wanting to limit the temporal receptive field of the discriminator (e.g. as regularization or to save computation). In such cases, a subsequence of length clip_length is sampled from the generated sequence. In our final experiments, we ended up using the whole generated sequence, so clip_length = sequence_length - context_frames.
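To make the clip sampling concrete, here is a minimal NumPy sketch. It is not the repository's code: `sample_clip` is a hypothetical helper, and it assumes the generator produces sequence_length - context_frames frames, as the description above implies.

```python
import numpy as np


def sample_clip(gen_frames, clip_length, rng=None):
    """Return a random run of `clip_length` consecutive generated frames.

    gen_frames: array of shape (num_generated_frames, H, W, C).
    Hypothetical helper for illustration only.
    """
    rng = np.random.default_rng() if rng is None else rng
    num_frames = gen_frames.shape[0]
    if not 1 <= clip_length <= num_frames:
        raise ValueError(
            "clip_length must be between 1 and %d, got %d" % (num_frames, clip_length)
        )
    # Any start index that leaves room for a full clip is valid.
    start = int(rng.integers(0, num_frames - clip_length + 1))
    return gen_frames[start:start + clip_length]


# Example values, chosen only for illustration.
sequence_length, context_frames = 20, 10
gen_frames = np.zeros((sequence_length - context_frames, 64, 64, 1), dtype=np.float32)

clip = sample_clip(gen_frames, clip_length=6, rng=np.random.default_rng(0))
# Using the whole generated sequence, as in the final experiments.
full = sample_clip(gen_frames, clip_length=sequence_length - context_frames)
```

With clip_length equal to the number of generated frames, the only valid start index is 0, so the discriminator sees the entire generated sequence.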
from video_prediction.
Ah okay, got it! I've managed to get the model running on the kth dataset (I was changing the wrong batch_size before).
I've gotten errors with respect to the clip_length, so just to confirm - is it alright for the clip length to be greater than the number of context frames? As long as clip length is <= (sequence_length - context_frames), it should be alright?
Thank you very much for your help and quick response!
from video_prediction.
Great! Yes, it is fine as long as clip_length <= (sequence_length - context_frames). Do let me know if you still encounter an error after adjusting the clip_length.
from video_prediction.
So eventually it is still running out of memory unfortunately. Just confirming that model checkpoints are still saved - so I should be able to run the evaluate.py script on the intermediate result?
I am getting the following error when running evaluate.py:
assert true.shape == pred.shape
AttributeError: 'NoneType' object has no attribute 'shape'
Thanks again for your help!
from video_prediction.
Yes, you should be able to run the evaluate.py script on the intermediate result. The evaluated tensors true and pred should not be None, so you can look into why that's the case.
Regarding running out of memory in the middle of training (and not at the beginning), it might be caused by computing the losses on the validation set (which is done only for the tensorboard summaries) at the same time as training. You can try the experimental branch, where the validation evaluation is done in a separate session.run().
from video_prediction.
Thanks! I am trying my own data now shaped (128, 128, 1), but I am still getting a reshape error. My dataset is correctly outputting 128x128. Is there another place I need to change a parameter to switch from 3-channel to 1-channel images? Could you clarify the process for generating the savp model encoder/decoder structure 'self.hparams.ngf'? Does this somehow depend on the number of channels?
I am also having reshape issues with (256, 256, 1) -> off by 8 in both cases. I am using your pre-written 256 scale architecture. I have additionally tried copying my data into 3 channels -> passing in (256, 256, 3) - I still have a factor of 8 reshape error. I've tried the kth dataset - reworked to (256, 256, 3), and that appears to run without the reshape error. Is there a particular detail in passing in a different dataset that might be causing this factor of 8 difference in the session run with 'fetches'?
EDIT: I figured out what was going on! My input is float arrays ranging from 0 to 1, not uint8 arrays. This was causing issues for the tf.train.Feature -> BytesList. Do you know what's the best way to create a trainable feature such that it takes in a list of float numpy arrays?
Thanks again!
from video_prediction.
Hi mitkina, sorry for the late reply. If you still want to try, I have made several improvements in the experimental branch (will soon be merged into master). The dataset and model now run with other dimensions, including (128, 128, 1) images. The KTH dataset can be preprocessed to 128x128 resolution with this command:
bash data/download_and_preprocess_dataset.sh kth 128
The encoder/decoder structure is built from the encoder_layer_specs/decoder_layer_specs, which specify the number of output channels for each Conv2D and whether a Conv2DRNN layer should follow that Conv2D. Although the SAVP model has some predefined layer specs, I have not tried them out with datasets at that resolution, so you might need to adjust them to use fewer parameters, increase speed, use less memory, etc. The code in experimental is able to train a VAE on the 128x128 dataset with the current defaults, with a batch size of 4 on a Titan X.
I'd also recommend first training either a deterministic model or a VAE-only model (i.e. no GANs). Those might already give pretty good KTH predictions, and they will use less memory and be faster to train because they won't be using discriminators. I'd also recommend using a much shorter context_frames, e.g. --model_hparams context_frames=2,sequence_length=12 as in the BAIR dataset.
Regarding float images, you can probably use FloatList instead of BytesList, and remove tf.image.convert_image_dtype from the decode_and_preprocess_images method.
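A minimal sketch of the FloatList route, assuming each frame is unrolled into a flat list of floats before it is wrapped in a feature. The helper names (`flatten_frame`, `restore_frame`, `make_float_feature`) are hypothetical and not part of the repository; only tf.train.Feature and tf.train.FloatList are real TensorFlow calls, and the reading side would need a matching float feature spec.

```python
import numpy as np


def flatten_frame(frame):
    """Unroll an (H, W, C) float image into a flat list of Python floats."""
    return np.asarray(frame, dtype=np.float32).ravel().tolist()


def restore_frame(flat_values, shape):
    """Inverse of flatten_frame: rebuild the (H, W, C) float32 array."""
    return np.asarray(flat_values, dtype=np.float32).reshape(shape)


def make_float_feature(frame):
    """Wrap one frame in a tf.train.Feature backed by a FloatList."""
    import tensorflow as tf  # imported lazily so the NumPy part runs without TF

    return tf.train.Feature(float_list=tf.train.FloatList(value=flatten_frame(frame)))


# Example frame with values in [0, 1); the shape is only illustrative.
frame = np.random.default_rng(0).random((128, 128, 1)).astype(np.float32)
flat = flatten_frame(frame)
restored = restore_frame(flat, frame.shape)
```

The image shape is not stored in the flat list, so it has to be saved separately (or be fixed and known) in order to reshape the values when reading the record back.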
from video_prediction.
Hi Alex,
Thanks for the information! I have already gotten the CVAE and the SAVP model to work with 128x128x1 images, with fewer context frames :) I actually found the CVAE was doing better than the SAVP for my dataset. But still overfitting.
In tensorboard, there's summary and summary_1 - what is the difference between these two? They both seem to be from the training set? Do you evaluate the validation set anywhere?
The FloatList does not work in this case because I have a list of arrays of floats, not a list of floats. I could unroll the image into a list of floats I guess... For now I've just discretized my continuous values into 256 bins for the network to work.
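The 256-bin workaround can be sketched as follows, assuming the inputs are floats in [0, 1]. The function names are hypothetical; the point is only that quantized uint8 frames can be serialized as raw bytes, at the cost of a bounded rounding error.

```python
import numpy as np


def quantize_to_uint8(frame):
    """Map floats in [0, 1] onto 256 integer levels (0..255)."""
    clipped = np.clip(frame, 0.0, 1.0)
    return np.round(clipped * 255.0).astype(np.uint8)


def dequantize(frame_u8):
    """Map uint8 levels back to floats in [0, 1]."""
    return frame_u8.astype(np.float32) / 255.0


frame = np.random.default_rng(0).random((128, 128, 1)).astype(np.float32)
frame_u8 = quantize_to_uint8(frame)
raw_bytes = frame_u8.tobytes()  # one byte per pixel, suitable for a bytes feature

# The round trip loses at most half a quantization step per pixel.
max_error = float(np.abs(dequantize(frame_u8) - frame).max())
```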
Thanks!
from video_prediction.
That's great! Yes, there is a trade-off between the VAE and the SAVP model. The former might be better at traditional metrics that penalize for spatial distortion (e.g. PSNR and SSIM) but the latter might be better at producing more realistic images, assuming that the dataset can benefit from that (e.g. real-world datasets with a lot of texture details). I haven't noticed overfitting on the datasets I've tried. Is your dataset large enough? You can also try standard data augmentation preprocessing (e.g. random cropping and horizontal flipping).
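As a rough illustration of those two augmentations for video, here is a NumPy sketch with hypothetical helpers (not the repository's preprocessing). The key assumption is that the same flip decision and the same crop window are applied to every frame of a clip, so that the motion stays consistent.

```python
import numpy as np


def random_horizontal_flip(video, rng, p=0.5):
    """Flip all frames of a (T, H, W, C) clip left-right with probability p."""
    if rng.random() < p:
        return video[:, :, ::-1, :]
    return video


def random_crop(video, crop_h, crop_w, rng):
    """Cut the same random (crop_h, crop_w) window out of every frame."""
    _, height, width, _ = video.shape
    top = int(rng.integers(0, height - crop_h + 1))
    left = int(rng.integers(0, width - crop_w + 1))
    return video[:, top:top + crop_h, left:left + crop_w, :]


rng = np.random.default_rng(0)
video = rng.random((10, 128, 128, 1)).astype(np.float32)

always_flipped = random_horizontal_flip(video, rng, p=1.0)
never_flipped = random_horizontal_flip(video, rng, p=0.0)
cropped = random_crop(video, 112, 112, rng)
```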
The difference is that the tensorboard summaries ending in "_1" correspond to the validation set. Furthermore, if you are using the newest version of my experimental branch, there are summaries ending in "_2", which correspond to the validation set but using longer sequences than used for training (assuming that the dataset's long_sequence_length != sequence_length). I have also repurposed the PR curves summaries for 2D plotting of the metrics over the prediction steps (to see them, you need to build tensorboard from source, see tensorflow/tensorboard#1110).
from video_prediction.
Update: I added details about the summaries in the README.
from video_prediction.
Awesome, thanks! My dataset is rather small, so that's definitely part of the problem. I was also training on video sequences of uneven length, so during training the network was taking samples from each video, as far as I understand. I am now retraining both SAVP and CVAE on the videos split up into the correct sequence length (10 steps total, with 5-step prediction) to see if that performs any better. I have been seeing interesting results: a ConvLSTM that learns the internal representation of the dynamics was performing better in all the metrics (PSNR, MSE, SSIM, TPR, TNR, FPR, FNR). I am hoping that with this new training procedure the CVAE will perform better, but perhaps a CVAE is just more data intensive. Although data augmentation might be useful, I can only really do the video flipping; cropping does not make a lot of sense in my context, I think.
from video_prediction.
In the SAVP network during training, I have been seeing the discriminator loss decrease considerably, but the generator loss increasing. It seems the discriminator is overpowering the generator :S
from video_prediction.
Hi Alex,
Just an update and a quick question. The CVAE now performs much better (due to the dataset split), but the SAVP network is performing very poorly (it might also just need more epochs than the CVAE). I was wondering if you could clarify the difference in TensorBoard of the transformed images summary and the gen_images_enc_summary? My understanding is that the former is the generator output and the latter is the encoding received by the generator?
Thanks again!
from video_prediction.
I've retrained the models with the new hyperparameters, and I noticed that the SAVP model with flow transformations doesn't do as well as the one with CDNA transformations. So, you can try the latter again with tv_weight=0,transformation=cdna.
The summaries with the "_enc" postfix refer to generations sampled from the approximate posterior q (i.e. the encoder), whereas the corresponding summaries without that postfix refer to generations sampled from the prior.
Also, the TP/FP/TN/FN numbers in the PR curves tab don't mean anything in my setup. These are just dummy values that were needed to hack the PR curve summaries to plot arbitrary 2D plots.
For the SAVP model, the GAN loss usually decreases over training iterations for the discriminator, but increases for the generator. It should be okay as long as the values are between 0 and 1.
from video_prediction.
Thanks for the feedback! I will give the cdna transformation another shot!
Re: TP/FP/TN/FN - I wrote my own, so no worries!
That's good to know regarding the GAN!
from video_prediction.
Hi Alex
I trained the SAVP on my dataset, but the GAN loss increases over iterations and its lowest value is around 5 and increases up to 22. What does it mean? Could I trust the generator in this state?
from video_prediction.
For the BAIR dataset, the GAN loss for the discriminator converges to around 0.15-0.2, and the GAN loss for the generator does go up slightly over training iterations but the final value is around 1. You might need to increase the relative weighting of the GAN loss vs the other losses for your dataset, such that the GAN loss is not too high.
from video_prediction.
Thanks for your feedback. I should definitely change the weights for my dataset; I've used the KTH parameters up to now. The value of gen_kl_loss increases over the iterations up to 182, and may increase further in the following iterations.
Here are the values of the different losses in the middle of training:
learning_rate 0.0002
progress global step 14400 epoch 10.3
image/sec 13.2 remaining 5774m (96.2h) (4.0d)
d_loss 0.023427725
discrim_video_sn_gan_loss (0.14210741, 0.1)
discrim_video_sn_vae_gan_loss (0.09216982, 0.1)
g_loss 20.750145
gen_l1_loss (0.032766685, 100.0)
gen_video_sn_gan_loss (0.5542432, 0.1)
gen_video_sn_vae_gan_loss (0.8700853, 0.1)
gen_video_sn_vae_gan_feature_cdist_loss (1.7331043, 10.0)
gen_kl_loss (175.23955, 0.0)
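One way to read this log, assuming each pair is (unweighted value, weight) and that each total is the weighted sum of its terms: the arithmetic below reproduces the logged g_loss and d_loss under that assumption and shows which term dominates. The dictionaries simply copy the numbers above; `weighted_total` is a hypothetical helper.

```python
# (unweighted value, weight) pairs copied from the log above
gen_losses = {
    "gen_l1_loss": (0.032766685, 100.0),
    "gen_video_sn_gan_loss": (0.5542432, 0.1),
    "gen_video_sn_vae_gan_loss": (0.8700853, 0.1),
    "gen_video_sn_vae_gan_feature_cdist_loss": (1.7331043, 10.0),
    "gen_kl_loss": (175.23955, 0.0),
}
discrim_losses = {
    "discrim_video_sn_gan_loss": (0.14210741, 0.1),
    "discrim_video_sn_vae_gan_loss": (0.09216982, 0.1),
}


def weighted_total(losses):
    """Sum of value * weight over all loss terms."""
    return sum(value * weight for value, weight in losses.values())


g_loss = weighted_total(gen_losses)      # compare with the logged g_loss 20.750145
d_loss = weighted_total(discrim_losses)  # compare with the logged d_loss 0.023427725

# Per-term contribution to g_loss, largest first.
contributions = {name: value * weight for name, (value, weight) in gen_losses.items()}
for name in sorted(contributions, key=contributions.get, reverse=True):
    print("%-45s %10.6f" % (name, contributions[name]))

largest = max(contributions, key=contributions.get)
```

Under this reading, most of g_loss comes from the feature cdist term (about 17.3) and the L1 term (about 3.3), while the two GAN terms together contribute roughly 0.14. The KL term contributes nothing at this step because its logged weight is 0.0.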
Could you please suggest how to change the following weights?
"l1_weight": 100.0,
"kl_weight": 0.01,
"video_sn_vae_gan_weight": 0.1,
"video_sn_gan_weight": 0.1,
"vae_gan_feature_cdist_weight": 10.0,
"gan_feature_cdist_weight": 0.0,
"state_weight": 0.0,
"nz": 32
from video_prediction.