
coupled-vae's People

Contributors

cwloka, hxyue1, jkclem, kenricnelson, kevin-chen0, thistleton


Forkers

hxyue1

coupled-vae's Issues

n_epoch_display is a bogus parameter

The n_epoch_display argument of the train_VAEs function in the notebook has no effect: although it is passed along to other functions, the VAEMNIST and VAECIFAR modules never actually use it.
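A minimal sketch of the presumably intended behavior, with a hypothetical training loop standing in for the modules' internals:

def train(model, dataset, n_epochs, n_epoch_display=10):
    # Hypothetical fix: actually consume n_epoch_display by reporting
    # progress every n_epoch_display epochs (plots could go here too).
    for epoch in range(1, n_epochs + 1):
        epoch_loss = 0.0
        for batch in dataset:
            epoch_loss += model.train_step(batch)  # assumed to return a scalar loss
        if epoch % n_epoch_display == 0:
            print(f"epoch {epoch}: loss = {epoch_loss:.4f}")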

Calculate_generalizedmean with coupled_loss as the riskBias

Use calculate_robustness as a starting point. Generalize it by making the constant -2/3 a variable called riskBias, and call the result calculate_generalizedmean.

Then compute a particular riskBias using the formula riskBias = -2 * coupling_loss / (1 + dim * coupling_loss). Compute the generalized mean with this particular riskBias and call it calculate_coupling_loss_bias.
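A minimal sketch of the two functions, assuming calculate_robustness computes the generalized (power) mean of per-image probabilities with exponent -2/3; everything beyond the names in this issue is a placeholder:

import numpy as np

def calculate_generalizedmean(values, riskBias):
    # Generalized (power) mean with exponent riskBias; riskBias = -2/3
    # would recover the assumed behavior of calculate_robustness.
    values = np.asarray(values, dtype=float)
    if riskBias == 0.0:
        # The power mean converges to the geometric mean as the exponent -> 0.
        return float(np.exp(np.mean(np.log(values))))
    return float(np.mean(values ** riskBias) ** (1.0 / riskBias))

def calculate_coupling_loss_bias(values, coupling_loss, dim):
    # Evaluate the generalized mean at the particular riskBias above.
    riskBias = -2.0 * coupling_loss / (1.0 + dim * coupling_loss)
    return calculate_generalizedmean(values, riskBias)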

Develop Coupled ELBO Probability Graphic

There is a need to expand the analysis shown in Figs. 5 and 8 of the Coupled VAE paper. That paper presented a graphic of the reconstruction loss by plotting a histogram of the loss per image and overlaying it with the Accuracy, Robustness, and Decisiveness generalized-mean metrics. The analysis needs to add a plot for the latent-layer divergence and to combine the divergence and reconstruction into the probabilities that represent the ELBO metric.

The design for this analysis is being developed in the Mathematica file Generalized_ELBO.nb. There are three components:

  1. Probability histogram of reconstruction with overlay of generalized mean metrics
  2. Probability histogram of divergence with overlay of generalized mean metrics
  3. Probability histogram of ELBO with overlay of generalized mean metrics
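As a reference for the implementation, here is a minimal Python sketch of one such panel. It assumes the exponents commonly used with these metrics (decisiveness r = 1, accuracy r → 0, robustness r = -2/3) and reuses the calculate_generalizedmean sketch from the issue above:

import matplotlib.pyplot as plt

def plot_probability_panel(probs, title):
    # Histogram of per-image probabilities with the three generalized-mean
    # metrics overlaid as vertical dashed lines.
    plt.hist(probs, bins=50, density=True, alpha=0.6)
    for name, r in [('Decisiveness', 1.0), ('Accuracy', 0.0), ('Robustness', -2/3)]:
        plt.axvline(calculate_generalizedmean(probs, r), linestyle='--', label=name)
    plt.xlabel('probability per image')
    plt.title(title)
    plt.legend()
    plt.show()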

Dependency conflicts when installing coupledvae

Dependency conflicts can arise when running the command pip install -i https://test.pypi.org/simple/ coupledvae==0.0.15. scikit-learn seems to be the major culprit.

INFO: pip is looking at multiple versions of coupledvae to determine which version is compatible with other requirements. This could take a while.

ERROR: Cannot install coupledvae because these package versions have conflicting dependencies.

scikit-learn 1.2.1 depends on threadpoolctl>=2.0.0
scikit-learn 1.2.0 depends on threadpoolctl>=2.0.0
scikit-learn 1.1.3 depends on threadpoolctl>=2.0.0
scikit-learn 1.1.2 depends on threadpoolctl>=2.0.0
scikit-learn 1.1.1 depends on threadpoolctl>=2.0.0
scikit-learn 1.1.0 depends on threadpoolctl>=2.0.0
scikit-learn 1.0.2 depends on threadpoolctl>=2.0.0
scikit-learn 1.0.1 depends on threadpoolctl>=2.0.0
scikit-learn 1.0 depends on threadpoolctl>=2.0.0
scikit-learn 0.24.2 depends on joblib>=0.11
scikit-learn 0.24.1 depends on joblib>=0.11
scikit-learn 0.24.0 depends on joblib>=0.11
scikit-learn 0.23.2 depends on joblib>=0.11
scikit-learn 0.23.1 depends on joblib>=0.11
scikit-learn 0.23.0 depends on joblib>=0.11
scikit-learn 0.22 depends on joblib>=0.11
scikit-learn 0.21.0 depends on scipy>=0.17.0
scikit-learn 0.20.3 depends on scipy>=0.13.3

This can be circumvented by installing `nsc` first, but I'm raising this as an issue because we'll need to come up with a better solution in the long term.
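Spelled out, the interim workaround is just the two installs in order:

pip install nsc
pip install -i https://test.pypi.org/simple/ coupledvae==0.0.15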

coupledvae version release roadmap

coupledvae_v3.1: Comment the notebook using terms from the PDL - TF2 Coursera course.

coupledvae_v4.x: Rearchitect the existing Coupled VAE model using the contents of the PDL - TF2 Coursera course, including:

  • Encoder/decoder
  • Gradient Tape
  • Display plots every n epochs rather than only after training ends

coupledvae_v5.x: Integrate the nsc library.

  • v5 will be the version in which we conduct experiments for our next paper

Review and add metrics/latent space features for v4.1

I have enhanced the VAE model in vae_v4.1b_wip.ipynb. v4.1b is a branch of v4 and is independent of v4.1a. The additional features you can see there are the following:

  1. A scatterplot of the latent space, plotted each epoch
  2. Saving the loss and epoch values each epoch and then plotting them as a graph at the end

Please do the following:

  1. Play around with the BATCH_SIZE, n_epoch, n_sample, and z_dim hyperparameters. See how changing them affects the ELBO/loss value per epoch as well as the runtime. One thing to note: n_sample is bounded by BATCH_SIZE, so n_sample cannot be greater than the batch size.

  2. Store the kl_div and neg_ll values in the self.metrics defaultdict, just as I did for loss and elbo. You may need to find a way to output these values together from compute_loss(), for example by concatenating them into a single tensor (see the sketch after this list).

  3. Plot the kl_div and neg_ll values at the end of the notebook, just as is done for loss and elbo.

I have marked the above 3 points in vae_v4.1b_wip for easy reference.
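As a reference for points 2 and 3, here is a minimal sketch of the logging side, assuming the wrapper's self.metrics is a collections.defaultdict(list) and substituting a generic Gaussian-VAE loss for the notebook's actual compute_loss():

import collections
import tensorflow as tf

class VAETrainerSketch:
    # Placeholder wrapper showing kl_div and neg_ll returned from
    # compute_loss() as one tensor and logged into self.metrics.
    def __init__(self, model, optimizer):
        self.model = model          # assumed to expose encode()/decode()
        self.optimizer = optimizer
        self.metrics = collections.defaultdict(list)

    def compute_loss(self, x):
        z_mean, z_log_var, z = self.model.encode(x)
        x_logit = self.model.decode(z)
        cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_logit)
        neg_ll = tf.reduce_mean(tf.reduce_sum(cross_ent, axis=[1, 2, 3]))
        kl_div = tf.reduce_mean(-0.5 * tf.reduce_sum(
            1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1))
        # Concatenate the components into a single tensor, as suggested in point 2.
        return tf.stack([neg_ll + kl_div, neg_ll, kl_div])

    def train_step(self, x):
        with tf.GradientTape() as tape:
            loss, neg_ll, kl_div = tf.unstack(self.compute_loss(x))
        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
        # Per-step logging enables the end-of-training plots in point 3.
        for name, value in (("loss", loss), ("neg_ll", neg_ll), ("kl_div", kl_div)):
            self.metrics[name].append(float(value))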

Review and/or fix changes to the VAE architecture for v4.1

I have looked at the VAE architecture from Keras here, and propose making the following changes to VAE v4.1:

  1. Make the encoder/decoder a tfk.Model instead of a tfk.Layer
  2. Have the encoder/decoder pair passed into the VAE model like other hyperparameters
  3. Remove the VAE wrapper class and move the necessary functions into the VAEModel class, including:
    - train_step()
    - compute_loss()
    - log_normal_pdf()
    - generate_and_save_images()
  4. Use .fit() to train VAEModel instead of .train().

The benefit of this approach is simplification of the code by not having to use a wrapper class. It also aligns with how Keras models are usually trained, via .fit().

Assuming this is the way to go, I have created a new VAE version, vae_v4.1a_wip.ipynb, but have encountered a handful of errors:

  1. The Encoder/Decoder model "has not been built" and therefore I am not able to view its summary(). Alternatively, we may need to write it the way the Keras tutorial already does, but I would prefer not to:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class Sampling(layers.Layer):
    # Reparameterization trick: sample z from (z_mean, z_log_var).
    def call(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

latent_dim = 2

encoder_inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()
  2. Passing tensorflow_datasets objects into VAEModel to fit. For some reason, in the Keras tutorial, they concatenate x_train and x_test into mnist_digits and then fit on just mnist_digits. They can do that because x_train and x_test are NumPy arrays. In our case, we currently pass in fm_train_dataset and fm_test_dataset separately. Do we have to mimic what Keras is doing here? Would tensorflow.python.data.ops.dataset_ops.ShuffleDataset even work here?
import numpy as np
from tensorflow import keras

(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255

vae = VAE(encoder, decoder)  # VAE, encoder, decoder as defined in the Keras tutorial
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(mnist_digits, epochs=30, batch_size=128)

Here is the link to the Keras Model class.

  3. The ability to train without the .train() from the wrapper class. Do we simply move those contents into the .train_step() function in VAEModel? How about the test_step() or fit() functions? Do we need to override them from tfk.Model? (See the sketch after this list.)

  4. Where to put generate_and_save_images()? Into .train_step(), so that there can be plots for every train step? Also, generate_and_save_images() takes only test images while self.compute_loss() takes only train images, so it may be problematic if we concatenate the train and test images together as mentioned in 2.
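For point 3, a minimal sketch of the standard Keras pattern: override train_step() in a tfk.Model subclass so that .fit() drives training. The loss here is a generic Gaussian-VAE stand-in, not the notebook's coupled loss:

import tensorflow as tf

class VAEModel(tf.keras.Model):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            neg_ll = tf.reduce_mean(tf.reduce_sum(
                tf.keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)))
            kl_div = tf.reduce_mean(-0.5 * tf.reduce_sum(
                1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1))
            loss = neg_ll + kl_div
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        # The returned dict shows up in fit()'s progress bar and History.
        return {"loss": loss, "neg_ll": neg_ll, "kl_div": kl_div}

On point 2, .fit() also accepts a tf.data.Dataset directly, so passing fm_train_dataset without the NumPy concatenation should work with this pattern.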

Can you take a look at the points above?

Thanks!

Look at why cifar10_corrupted does not contain any training set

Running the following code:

import tensorflow_datasets as tfds

datasets, datasets_info = tfds.load(name='cifar10_corrupted',
                                    with_info=True,
                                    as_supervised=False
                                    )

print(datasets)

It only contains the test split; there is no train split.

{'test': <PrefetchDataset shapes: {image: (32, 32, 3), label: ()}, types: {image: tf.uint8, label: tf.int64}>}

Other datasets, such as mnist_corrupted, contain both train and test splits.

datasets, datasets_info = tfds.load(name='mnist_corrupted',
                                    with_info=True,
                                    as_supervised=False
                                    )

print(datasets)
{'test': <PrefetchDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>,
 'train': <PrefetchDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>}
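For reference, the splits a dataset provides can be checked without loading any data, using the standard tfds builder API:

import tensorflow_datasets as tfds

# Inspect the available splits before calling tfds.load().
builder = tfds.builder('cifar10_corrupted')
print(builder.info.splits)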
