alexlee-gk commented on June 2, 2024

The topological sort error can be ignored; the code works correctly despite it.

Did you change this line https://github.com/alexlee-gk/video_prediction/blob/master/video_prediction/datasets/kth_dataset.py#L17 to use (32, 32, 3)?

mitkina commented on June 2, 2024

Yup! Doing this caused the reshape error. I have an additional question: should the batch size parameter be changed in base_dataset.py, base_model.py, or both?

And a further clarification: what is the clip_length parameter in base_model.py?

alexlee-gk commented on June 2, 2024

It depends on how the dataset and model are being used. For example, in the training script, the dataset is created with the batch_size specified by the model: https://github.com/alexlee-gk/video_prediction/blob/master/scripts/train.py#L141-L146

clip_length is the number of consecutive frames that are given to the discriminator at each training iteration. This is useful when training with a high sequence_length but wanting to limit the temporal receptive field of the discriminator (e.g. as regularization or to save computation). In such cases, a subsequence of length clip_length is sampled from the generated sequence. In our final experiments, we ended up using the whole generated sequence, so clip_length = sequence_length - context_frames.
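A minimal NumPy sketch of the subsequence sampling described above, assuming the generated frames are stored as a (T, H, W, C) array with T = sequence_length - context_frames. The function name sample_clip is hypothetical and not taken from the repository; it also enforces the clip_length <= sequence_length - context_frames constraint discussed further down in this thread.

```python
import numpy as np

def sample_clip(gen_frames, clip_length, rng=None):
    """Return a random window of clip_length consecutive generated frames.

    gen_frames: array of shape (T, H, W, C), assumed here to hold
    T = sequence_length - context_frames generated frames.
    """
    rng = np.random.default_rng() if rng is None else rng
    num_gen = gen_frames.shape[0]
    if not 1 <= clip_length <= num_gen:
        raise ValueError(f"clip_length must be in [1, {num_gen}], got {clip_length}")
    # Pick a start index so that the whole window fits inside the sequence.
    start = rng.integers(0, num_gen - clip_length + 1)
    return gen_frames[start:start + clip_length]

sequence_length, context_frames = 12, 2
gen = np.zeros((sequence_length - context_frames, 64, 64, 3), dtype=np.float32)

clip = sample_clip(gen, clip_length=4)  # a random 4-frame window
# Using the whole generated sequence, as in the final experiments:
full = sample_clip(gen, clip_length=sequence_length - context_frames)
print(clip.shape, full.shape)  # (4, 64, 64, 3) (10, 64, 64, 3)
```

When clip_length equals the number of generated frames, the only valid start index is 0, so the discriminator sees the entire generated sequence.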

mitkina commented on June 2, 2024

Ah okay, got it! I've managed to get the model running on the kth dataset (I was changing the wrong batch_size before).

I've been getting errors related to clip_length, so just to confirm: is it alright for the clip length to be greater than the number of context frames? As long as clip_length <= (sequence_length - context_frames), it should be fine?

Thank you very much for your help and quick response!

alexlee-gk commented on June 2, 2024

Great! Yes, clip_length just needs to satisfy clip_length <= (sequence_length - context_frames). Do let me know if you still encounter an error after adjusting clip_length.

mitkina commented on June 2, 2024

Unfortunately, it still eventually runs out of memory. Just to confirm: model checkpoints are still saved, so I should be able to run the evaluate.py script on the intermediate results?

I am getting the following error when running evaluate.py:

assert true.shape == pred.shape
AttributeError: 'NoneType' object has no attribute 'shape'

Thanks again for your help!

alexlee-gk commented on June 2, 2024

Yes, you should be able to run the evaluate.py script on the intermediate results. The evaluated tensors true and pred should not be None, so it's worth looking into why they are.

Regarding running out of memory in the middle of training (rather than at the beginning), it might be caused by computing the losses on the validation set (which is done only for the TensorBoard summaries) at the same time as training. You can try the experimental branch, where the validation evaluation is done in a separate session.run().

mitkina commented on June 2, 2024

Thanks! I am now trying my own data, shaped (128, 128, 1), but I am still getting a reshape error, even though my dataset correctly outputs 128x128 images. Is there another place where I need to change a parameter to switch from 3-channel to 1-channel images? Could you also clarify how the SAVP model's encoder/decoder structure is generated from 'self.hparams.ngf'? Does this somehow depend on the number of channels?

I am also having reshape issues with (256, 256, 1); the size is off by a factor of 8 in both cases. I am using your pre-written 256 scale architecture. I have additionally tried copying my data into 3 channels and passing in (256, 256, 3), but I still get a factor-of-8 reshape error. I've also tried the KTH dataset reworked to (256, 256, 3), and that appears to run without the reshape error. Is there a particular detail in passing in a different dataset that might be causing this factor-of-8 difference in the session.run call with 'fetches'?

EDIT: I figured out what was going on! My inputs are float arrays ranging from 0 to 1, not uint8 arrays, and this was causing issues with the tf.train.Feature BytesList. Do you know the best way to create a tf.train.Feature that takes in a list of float NumPy arrays?
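The factor of 8 is consistent with NumPy's default float64 dtype (8 bytes per element) being decoded as uint8 (1 byte per element). A NumPy-only illustration, assuming the frames were float64 arrays serialized with tobytes(); the actual serialization code isn't shown in this thread, so treat this as a plausible explanation rather than a diagnosis.

```python
import numpy as np

# A 128x128 single-channel frame with values in [0, 1].
# np.random.rand returns float64, i.e. 8 bytes per pixel.
frame = np.random.rand(128, 128, 1)
raw = frame.tobytes()

# Decoding those bytes as uint8 (1 byte per pixel) yields 8x as many values,
# so a reshape to (128, 128, 1) would fail.
as_uint8 = np.frombuffer(raw, dtype=np.uint8)
print(as_uint8.size // frame.size)  # 8

# Decoding with the matching dtype recovers the original frame.
restored = np.frombuffer(raw, dtype=np.float64).reshape(128, 128, 1)
assert np.array_equal(restored, frame)
```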

Thanks again!

alexlee-gk commented on June 2, 2024

Hi mitkina, sorry for the late reply. If you still want to try, I have made several improvements in the experimental branch (which will soon be merged into master). The dataset and model now run with other dimensions, including (128, 128, 1) images. The KTH dataset can be preprocessed to 128x128 resolution with this command:

bash data/download_and_preprocess_dataset.sh kth 128

The encoder/decoder structure is built from encoder_layer_specs/decoder_layer_specs, which specify the number of output channels for each Conv2D and whether a Conv2DRNN layer should follow that Conv2D. Although the SAVP model has some predefined layer specs, I have not tried them out with datasets at that resolution, so you might need to adjust them to use fewer parameters, increase speed, use less memory, etc. The code in experimental is able to train a VAE on the 128x128 dataset with the current defaults, with a batch size of 4 on a Titan X.
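To see how much adjusting the layer specs can save: a Conv2D layer with a k x k kernel has k*k*C_in*C_out weights plus C_out biases, so halving every channel count roughly quarters the parameter count. A sketch assuming specs are (out_channels, use_conv_rnn) pairs, 3x3 kernels, and hypothetical channel counts (not the SAVP defaults); the Conv2DRNN layers are not counted.

```python
def conv2d_params(in_channels, out_channels, kernel_size=3, bias=True):
    """Number of trainable weights in a single Conv2D layer."""
    weights = kernel_size * kernel_size * in_channels * out_channels
    return weights + (out_channels if bias else 0)

def encoder_conv_params(layer_specs, in_channels, kernel_size=3):
    """Sum Conv2D parameters over (out_channels, use_conv_rnn) specs.

    The Conv2DRNN layers flagged by use_conv_rnn are ignored here.
    """
    total = 0
    for out_channels, _use_conv_rnn in layer_specs:
        total += conv2d_params(in_channels, out_channels, kernel_size)
        in_channels = out_channels  # next layer consumes this layer's output
    return total

# Hypothetical specs for 1-channel input images.
wide = [(64, True), (128, True), (256, True)]
narrow = [(32, True), (64, True), (128, True)]
print(encoder_conv_params(wide, in_channels=1))    # 369664
print(encoder_conv_params(narrow, in_channels=1))  # 92672
```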

I'd also recommend first training either a deterministic model or a VAE-only model (i.e. no GANs). Those might already give pretty good KTH predictions, and they will use less memory and be faster to train because they won't be using discriminators. I'd also recommend using a much shorter context_frames, e.g. --model_hparams context_frames=2,sequence_length=12 as in the BAIR dataset.
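With those overrides the model sees 2 context frames and predicts 10, so the largest valid clip_length is 10. The sketch below parses a comma-separated key=value string of that form into a dict; parse_hparams is a hypothetical stand-in for illustration, not the repository's own parser.

```python
def parse_hparams(overrides):
    """Parse 'key=value,key=value' into a dict, casting to int/float when possible."""
    parsed = {}
    for item in overrides.split(","):
        key, value = item.split("=", 1)
        key, value = key.strip(), value.strip()
        for cast in (int, float):
            try:
                parsed[key] = cast(value)
                break
            except ValueError:
                continue
        else:
            parsed[key] = value  # leave non-numeric values (e.g. "cdna") as strings
    return parsed

hparams = parse_hparams("context_frames=2,sequence_length=12")
num_predicted = hparams["sequence_length"] - hparams["context_frames"]
print(num_predicted)  # 10
```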

In regards to using float images, you can probably use FloatList instead of BytesList, and remove tf.image.convert_image_dtype from the decode_and_preprocess_images method.
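A FloatList holds a flat list of floats, so a list of equally shaped float arrays can be stored by stacking and flattening it, with the shape saved separately (for example in an Int64List). A NumPy-only sketch under that assumption; in TensorFlow the flat values would go into tf.train.Feature(float_list=tf.train.FloatList(value=values)), which is not executed here. The helper names are hypothetical.

```python
import numpy as np

def flatten_frames(frames):
    """Stack equally shaped float frames and flatten them into one list.

    Returns the flat values (what a FloatList would hold) and the shape
    needed to restore them.
    """
    video = np.stack(frames).astype(np.float32)
    return video.ravel().tolist(), video.shape

def restore_frames(values, shape):
    """Inverse of flatten_frames: rebuild the (T, H, W, C) float32 array."""
    return np.asarray(values, dtype=np.float32).reshape(shape)

frames = [np.random.rand(128, 128, 1) for _ in range(10)]
values, shape = flatten_frames(frames)
video = restore_frames(values, shape)
print(shape, len(values))  # (10, 128, 128, 1) 163840
```

On the decoding side, the float32 values would already be in [0, 1], which is why the tf.image.convert_image_dtype step would no longer be needed.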

mitkina commented on June 2, 2024

Hi Alex,

Thanks for the information! I have already gotten the CVAE and the SAVP model to work with 128x128x1 images, with fewer context frames :) I actually found the CVAE was doing better than the SAVP model on my dataset, but it is still overfitting.

In TensorBoard, there are summary and summary_1 entries; what is the difference between these two? They both seem to be from the training set? Do you evaluate the validation set anywhere?

The FloatList does not work in this case because I have a list of arrays of floats, not a list of floats. I could unroll the image into a list of floats I guess... For now I've just discretized my continuous values into 256 bins for the network to work.
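The 256-bin workaround can be sketched as mapping [0, 1] floats to uint8 bins and back, which keeps the data compatible with a uint8 BytesList pipeline at a maximum rounding error of half a bin (0.5/255). A NumPy sketch; the function names are hypothetical, and levels must not exceed 256 for the uint8 storage to hold.

```python
import numpy as np

def quantize(x, levels=256):
    """Map floats in [0, 1] to integer bins 0..levels-1, stored as uint8.

    Values outside [0, 1] are clipped first. Requires levels <= 256.
    """
    x = np.clip(x, 0.0, 1.0)
    return np.round(x * (levels - 1)).astype(np.uint8)

def dequantize(q, levels=256):
    """Map bin indices back to floats in [0, 1]."""
    return q.astype(np.float32) / (levels - 1)

x = np.random.rand(128, 128, 1)
q = quantize(x)
max_err = np.abs(dequantize(q) - x).max()
print(q.dtype, max_err)  # uint8, and an error no larger than about 0.5/255
```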

Thanks!

alexlee-gk commented on June 2, 2024

That's great! Yes, there is a trade-off between the VAE and the SAVP model. The former might be better at traditional metrics that penalize spatial distortion (e.g. PSNR and SSIM), but the latter might be better at producing more realistic images, assuming that the dataset can benefit from that (e.g. real-world datasets with a lot of texture details). I haven't noticed overfitting on the datasets I've tried. Is your dataset large enough? You can also try standard data augmentation preprocessing (e.g. random cropping and horizontal flipping).
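For reference, PSNR is defined from the per-pixel mean squared error as 10 * log10(MAX^2 / MSE), so it directly rewards small pixel-wise deviations rather than realism. A NumPy sketch for images scaled to [0, 1] (SSIM is more involved and omitted); this is the standard definition, not necessarily the exact implementation used in the repository's evaluation.

```python
import numpy as np

def psnr(true, pred, max_val=1.0):
    """Peak signal-to-noise ratio in dB for images scaled to [0, max_val]."""
    true = np.asarray(true, dtype=np.float64)
    pred = np.asarray(pred, dtype=np.float64)
    mse = np.mean((true - pred) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(max_val ** 2 / mse)

true = np.full((64, 64, 1), 0.5)
pred = true + 0.1  # uniform error of 0.1, so MSE is about 0.01
print(psnr(true, pred))  # about 20.0 dB
```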

The difference is that the TensorBoard summaries ending in "_1" correspond to the validation set. Furthermore, if you are using the newest version of my experimental branch, there are summaries ending in "_2", which correspond to the validation set but using longer sequences than those used for training (assuming that the dataset's long_sequence_length != sequence_length). I have also repurposed the PR curve summaries for 2D plotting of the metrics over the prediction steps (to see them, you need to build TensorBoard from source, see tensorflow/tensorboard#1110).

alexlee-gk commented on June 2, 2024

Update: I added details about the summaries in the README.

mitkina commented on June 2, 2024

Awesome, thanks! My dataset is rather small, so that's definitely part of the problem. I was also training on video sequences of uneven lengths, so, as far as I understand, the network was taking samples from each video during training. I am now retraining both SAVP and CVAE on the videos split into the correct sequence length (10 steps total, with 5-step prediction) to see if that performs any better. I have also been seeing the interesting result that a ConvLSTM which learns an internal representation of the dynamics was performing better on all the metrics (PSNR, MSE, SSIM, TPR, TNR, FPR, FNR). I am hoping that with this new training procedure the CVAE will perform better, but perhaps a CVAE is just more data-intensive. Although data augmentation might be useful, I can only really do video flipping; I don't think cropping makes much sense in my context.
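Splitting uneven-length videos into fixed 10-frame sequences (5 context, 5 predicted) and adding horizontally flipped copies could look like the NumPy sketch below. It uses non-overlapping windows and drops leftover frames; that is one choice among several (overlapping sliding windows would yield more samples from a small dataset). The helper names are hypothetical.

```python
import numpy as np

def split_into_clips(video, clip_len=10):
    """Cut a (T, H, W, C) video into non-overlapping clips of clip_len frames.

    Trailing frames that do not fill a whole clip are dropped.
    """
    num_clips = video.shape[0] // clip_len
    return [video[i * clip_len:(i + 1) * clip_len] for i in range(num_clips)]

def hflip(clip):
    """Flip every frame of a (T, H, W, C) clip left-right (along the width axis)."""
    return clip[:, :, ::-1, :]

videos = [np.random.rand(t, 32, 32, 1) for t in (23, 10, 37)]  # uneven lengths
clips = [c for v in videos for c in split_into_clips(v)]
augmented = clips + [hflip(c) for c in clips]

# Each clip: first 5 frames as context, last 5 as prediction targets.
context, target = clips[0][:5], clips[0][5:]
print(len(clips), len(augmented))  # 6 12
```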

mitkina commented on June 2, 2024

In the SAVP network during training, I have been seeing the discriminator loss decrease considerably, but the generator loss increasing. It seems the discriminator is overpowering the generator :S

mitkina commented on June 2, 2024

Hi Alex,

Just an update and a quick question. The CVAE now performs much better (due to the dataset split), but the SAVP network is performing very poorly (it might also just need more epochs than the CVAE). I was wondering if you could clarify the difference in TensorBoard between the transformed images summary and the gen_images_enc_summary? My understanding is that the former is the generator output and the latter is the encoding received by the generator?

Thanks again!

alexlee-gk commented on June 2, 2024

I've retrained the models with the new hyperparameters, and I noticed that the SAVP model with flow transformations doesn't do as well as the one with CDNA transformations. So, you can try the latter again with tv_weight=0,transformation=cdna.

The summaries with the "_enc" postfix refer to generations sampled from the approximate posterior q (i.e. the encoder), whereas the corresponding summaries without that postfix refer to generations sampled from the prior.
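The two sampling modes can be sketched in NumPy, assuming a diagonal Gaussian posterior described by hypothetical encoder outputs mu and log_sigma_sq, and a fixed N(0, I) prior. The networks and the per-time-step structure of the latents are left out; this only illustrates where each z comes from.

```python
import numpy as np

def sample_posterior(mu, log_sigma_sq, rng):
    """Reparameterized sample z = mu + sigma * eps from q(z | x) = N(mu, diag(sigma^2))."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_sigma_sq) * eps

def sample_prior(shape, rng):
    """Sample z from a fixed standard normal prior p(z) = N(0, I)."""
    return rng.standard_normal(shape)

rng = np.random.default_rng(0)
nz = 32  # latent size (assumed here)

# Hypothetical encoder outputs for one sequence.
mu = np.full(nz, 2.0)
log_sigma_sq = np.full(nz, np.log(0.01))  # sigma = 0.1

z_enc = sample_posterior(mu, log_sigma_sq, rng)  # drives the "_enc" generations
z_prior = sample_prior((nz,), rng)               # drives the generations without "_enc"
```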

Also, the TP/FP/TN/FN numbers in the PR curves tab don't mean anything in my setup. These are just dummy values that were needed to hack the PR curve summaries into plotting arbitrary 2D plots.

For the SAVP model, the GAN loss usually decreases over training iterations for the discriminator, but increases for the generator. It should be okay as long as the values are between 0 and 1.

mitkina commented on June 2, 2024

Thanks for the feedback! I will give the cdna transformation another shot!
Re: TP/FP/TN/FN - I wrote my own, so no worries!
That's good to know regarding the GAN!

fatemehtd commented on June 2, 2024

Hi Alex,
I trained SAVP on my dataset, but the GAN loss increases over iterations: its lowest value is around 5, and it goes up to 22. What does this mean? Can I trust the generator in this state?

alexlee-gk commented on June 2, 2024

For the BAIR dataset, the GAN loss for the discriminator converges to around 0.15-0.2, and the GAN loss for the generator does go up slightly over training iterations but the final value is around 1. You might need to increase the relative weighting of the GAN loss vs the other losses for your dataset, such that the GAN loss is not too high.

fatemehtd commented on June 2, 2024

Thanks for your feedback. I should definitely change the weights for my dataset; I've used the KTH parameters up to now. The value of gen_kl_loss increases over iterations, up to 182 so far, and may increase further in later iterations.
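For context on what a growing KL term measures: for a diagonal Gaussian posterior and a standard normal prior, KL = 0.5 * sum(mu^2 + sigma^2 - log sigma^2 - 1) over the latent dimensions, so it grows as the encoder's means move away from zero or its variances move away from 1. A NumPy sketch under that assumption; the repository's exact reduction (e.g. summing vs averaging over time steps and batch) is not confirmed here.

```python
import numpy as np

def kl_to_standard_normal(mu, log_sigma_sq):
    """KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent dimensions."""
    return 0.5 * np.sum(mu ** 2 + np.exp(log_sigma_sq) - log_sigma_sq - 1.0, axis=-1)

nz = 32
# Posterior equal to the prior: the KL is zero.
print(kl_to_standard_normal(np.zeros(nz), np.zeros(nz)))  # 0.0
# Means pushed away from zero make the KL grow quadratically.
print(kl_to_standard_normal(np.full(nz, 3.0), np.zeros(nz)))  # 144.0
```

Because this term is multiplied by kl_weight before entering the total loss, a large raw value only affects training in proportion to that weight.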

Here are the values of the different losses midway through training:
learning_rate 0.0002
progress global step 14400 epoch 10.3
image/sec 13.2 remaining 5774m (96.2h) (4.0d)
d_loss 0.023427725
discrim_video_sn_gan_loss (0.14210741, 0.1)
discrim_video_sn_vae_gan_loss (0.09216982, 0.1)
g_loss 20.750145
gen_l1_loss (0.032766685, 100.0)
gen_video_sn_gan_loss (0.5542432, 0.1)
gen_video_sn_vae_gan_loss (0.8700853, 0.1)
gen_video_sn_vae_gan_feature_cdist_loss (1.7331043, 10.0)
gen_kl_loss (175.23955, 0.0)
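The printed g_loss and d_loss are numerically consistent with each being the weighted sum of the (value, weight) pairs listed with them (an observation from these numbers, not from the code). Under that reading, the short script below reproduces both totals from the log and ranks each generator term's share, which shows which weight dominates g_loss.

```python
# (value, weight) pairs copied from the log above.
gen_losses = {
    "gen_l1_loss": (0.032766685, 100.0),
    "gen_video_sn_gan_loss": (0.5542432, 0.1),
    "gen_video_sn_vae_gan_loss": (0.8700853, 0.1),
    "gen_video_sn_vae_gan_feature_cdist_loss": (1.7331043, 10.0),
    "gen_kl_loss": (175.23955, 0.0),
}
discrim_losses = {
    "discrim_video_sn_gan_loss": (0.14210741, 0.1),
    "discrim_video_sn_vae_gan_loss": (0.09216982, 0.1),
}

def weighted_total(losses):
    """Sum of value * weight over all loss terms."""
    return sum(value * weight for value, weight in losses.values())

def contributions(losses):
    """Each term's share of the weighted total, largest first."""
    total = weighted_total(losses)
    shares = {name: value * weight / total for name, (value, weight) in losses.items()}
    return sorted(shares.items(), key=lambda kv: kv[1], reverse=True)

print(round(weighted_total(gen_losses), 4))      # 20.7501 (log: g_loss 20.750145)
print(round(weighted_total(discrim_losses), 6))  # 0.023428 (log: d_loss 0.023427725)
for name, share in contributions(gen_losses):
    print(f"{name}: {share:.1%}")
```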

Could you please suggest how to change the following weights?
"l1_weight": 100.0,
"kl_weight": 0.01,
"video_sn_vae_gan_weight": 0.1,
"video_sn_gan_weight": 0.1,
"vae_gan_feature_cdist_weight": 10.0,
"gan_feature_cdist_weight": 0.0,
"state_weight": 0.0,
"nz": 32
