
Collaborative Sampling in Generative Adversarial Networks

This repository provides a TensorFlow implementation of Collaborative Sampling in Generative Adversarial Networks.


Overview

Once GAN training completes, we use both the generator and the discriminator to produce samples collaboratively. Our sampling scheme consists of one sample proposal step and multiple sample refinement steps. (I) The fixed generator proposes samples. (II) Subsequently, the discriminator provides gradients, with respect to activation maps of the proposed samples, back to a particular layer of the generator. Gradient-based updates of the activation maps are performed repeatedly until the samples are classified as real by the discriminator.
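The two steps can be illustrated with a toy sketch. This is not the repository's code: it assumes tiny hand-written NumPy networks in place of the TensorFlow models, splits the generator at a hypothetical "particular layer" into `generator_head` (z to activation map h) and `generator_tail` (h to sample x), and refines h by plain gradient ascent on log D(x) with a fixed step size. The paper's exact loss and update rule may differ; all names and parameters here are illustrative.

```python
import numpy as np

# Hypothetical toy networks (NOT the repository's models):
#   generator head: z -> h, the activation map at the chosen layer
#   generator tail: h -> x, linear so its Jacobian is simply A
#   discriminator:  x -> D(x) = sigmoid(w . x + b)
W_head = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, -0.5]])  # 3 x 2
A = np.array([[1.0, 0.0, 0.5], [0.0, 1.0, -0.5]])         # 2 x 3
w = np.array([1.0, -1.0])
b = -3.0


def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))


def generator_head(z):
    return np.tanh(W_head @ z)


def generator_tail(h):
    return A @ h


def discriminator(x):
    return sigmoid(w @ x + b)


def collaborative_sample(z, step=0.1, max_steps=500, threshold=0.5):
    """(I) Propose a sample with the fixed generator, then (II) update the
    activation map h with discriminator gradients until D(x) > threshold.
    Assumes the refinement objective is log D(x) (an assumption)."""
    h = generator_head(z)                    # (I) proposal
    for k in range(max_steps):
        x = generator_tail(h)
        d = discriminator(x)
        if d > threshold:                    # classified as real: stop
            return x, d, k
        # d log D / dx = (1 - D) * w; chain rule through the linear tail gives A^T
        grad_h = A.T @ ((1.0 - d) * w)
        h = h + step * grad_h                # (II) refine the activation map
    x = generator_tail(h)
    return x, discriminator(x), max_steps


x, d, steps = collaborative_sample(np.zeros(2))
print(f"D(x) = {d:.3f} after {steps} refinement steps")
```

Only the activation map is updated; the generator and discriminator weights stay fixed, matching the description of a fixed generator after training.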


GANs modelling an imbalanced mixture of 8 Gaussians. Vanilla GANs are prone to mode collapse. Accept-reject sampling algorithms, including Discriminator Rejection Sampling (DRS) and the Metropolis-Hastings method (MH-GAN), suffer from severe distribution bias due to the mismatch between the supports of the generated and real distributions. Our collaborative sampling scheme, applied to early-terminated GANs, recovers all modes without compromising sample quality, significantly outperforming the baseline methods.

[Figure: panels Real, GAN at 1K iter, GAN at 9K iter, DRS at 1K iter, MH-GAN at 1K iter, Refine at 1K iter, Collab at 1K iter; metrics: Quality, Diversity, Overall]
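Why accept-reject sampling cannot fix missing modes can be seen in a toy sketch at the level of mode indices. The mixture weights below are made up, the generator is idealized as fully collapsed onto two modes, and the density ratio is assumed to be known exactly (DRS and MH-GAN instead estimate it from the discriminator). Even with this ideal ratio, rejection only reweights what the generator proposes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical imbalanced mixture weights over the 8 Gaussian modes.
target_weights = np.array([0.3, 0.2, 0.15, 0.1, 0.1, 0.05, 0.05, 0.05])
# Hypothetical mode-collapsed generator: it only ever proposes modes 0 and 1.
gen_weights = np.array([0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])


def accept_reject(n_proposals):
    """Idealized rejection sampling: accept a proposal from mode k with
    probability proportional to target_weights[k] / gen_weights[k]."""
    modes = rng.choice(8, size=n_proposals, p=gen_weights)
    support = gen_weights > 0
    ratio = np.zeros(8)
    ratio[support] = target_weights[support] / gen_weights[support]
    accept_prob = ratio[modes] / ratio.max()
    return modes[rng.random(n_proposals) < accept_prob]


accepted = accept_reject(10_000)
print(sorted(set(accepted.tolist())))  # only modes the generator proposes
```

The accepted samples follow the target restricted to modes 0 and 1 (roughly 60/40), but modes 2 to 7 never appear, because no proposal ever lands there. Refining samples with discriminator gradients, as in collaborative sampling, can move samples rather than only filter them.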

DCGAN for modelling human faces on the CelebA dataset. (Top) Samples from standard sampling. (Middle) Samples from our collaborative sampling method. (Bottom) The difference between the top and middle rows.

[Figure panels: CIFAR-10, CelebA]

CycleGAN for unpaired image-to-image translation. (Top) Samples from standard sampling. (Middle) Samples from our collaborative sampling method. (Bottom) The difference between the top and middle rows.
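The README does not say how the difference rows are computed or scaled. One plausible way, sketched here as an assumption, is a per-pixel absolute difference between the standard and collaborative samples, rescaled so the largest change maps to 255 for display; `difference_image` is a hypothetical helper.

```python
import numpy as np


def difference_image(standard, refined):
    """Per-pixel absolute difference of two uint8 images of shape (H, W, C),
    rescaled so the largest change maps to 255 (assumed normalization)."""
    # Cast to a signed type first so the subtraction cannot wrap around.
    diff = np.abs(standard.astype(np.int16) - refined.astype(np.int16))
    peak = int(diff.max())
    if peak == 0:
        return np.zeros(standard.shape, dtype=np.uint8)  # identical images
    return np.round(diff * (255.0 / peak)).astype(np.uint8)
```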


Dependencies:

  • tensorflow==1.13.0
  • CUDA==10.0
  • pillow
  • scipy==1.2
  • matplotlib
  • requests
  • tqdm

Citation:

If you use this code for your research, please cite our paper.

@inproceedings{liu2019collaborative,
  title={Collaborative Sampling in Generative Adversarial Networks},
  author={Liu, Yuejiang and Kothari, Parth and Alahi, Alexandre},
  booktitle={Thirty-Fourth AAAI Conference on Artificial Intelligence},
  year={2020}
}

Acknowledgements

The baseline implementation is based on this repository.


collaborative-gan-sampling's Issues

Ask for your help

Hi! I am using Social GAN to make predictions on our own data, but it always fails with:

File "/home/wyx/sgan-master/sgan/data/trajectories.py", line 336, in __init__
curr_seq[_idx, :, pad_front:pad_end] = curr_ped_seq # curr_seq(3220)
ValueError: could not broadcast input array from shape (2,6) into shape (2,20)

I think the cause is that the number of people changes from frame to frame in the video, but I am not sure and cannot resolve the problem. Please help! Thanks very much!
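A minimal NumPy reproduction of this error, under an assumption read off the shapes in the message rather than a diagnosis of the Social GAN loader: the slice `pad_front:pad_end` spans 20 frames, while this pedestrian's track has only 6 columns, for example because the pedestrian is missing from some frames inside the window. `try_assign` and `fill_if_complete` are hypothetical helpers, not part of Social GAN.

```python
import numpy as np

SEQ_LEN = 20  # width of the target slice in the reported error


def try_assign(curr_seq, idx, curr_ped_seq, pad_front, pad_end):
    """Run the assignment from the traceback; return the error text or None."""
    try:
        curr_seq[idx, :, pad_front:pad_end] = curr_ped_seq
        return None
    except ValueError as err:
        return str(err)


def fill_if_complete(curr_seq, idx, curr_ped_seq, pad_front, pad_end):
    """Hypothetical guard: copy the track only if it covers every frame
    of the slice; otherwise skip it and return False."""
    if curr_ped_seq.shape[1] != pad_end - pad_front:
        return False
    curr_seq[idx, :, pad_front:pad_end] = curr_ped_seq
    return True


curr_seq = np.zeros((3, 2, SEQ_LEN))  # (pedestrians, xy, frames)
short_track = np.ones((2, 6))         # pedestrian observed in only 6 frames
full_track = np.ones((2, SEQ_LEN))

print(try_assign(curr_seq, 0, short_track, 0, SEQ_LEN))
print(fill_if_complete(curr_seq, 0, short_track, 0, SEQ_LEN))  # False
print(fill_if_complete(curr_seq, 1, full_track, 0, SEQ_LEN))   # True
```

If this assumption holds, the mismatch concerns the length of one pedestrian's track rather than how many people appear per frame; checking the frame indices of that pedestrian in the input file would confirm or rule it out.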

HTTP 404 response for wget command

The command below results in an HTTP 404 response:
wget https://github.com/tensorflow/models/raw/master/research/gan/mnist/data/classify_mnist_graph_def.pb -P external/
