
cnn-vae's People

Contributors

dependabot[bot], lukeditria


cnn-vae's Issues

Thank you

Hey, I just wanted to say I really like your project. I have tried out other VAE repos in the past, and this one is much easier to use and performs much better. I was able to get set up training on my dataset very quickly. The VGG19 model vastly improves training performance for me compared to other network structures I've used before. The code is very clean as well. Cheers!

size of VAE image

Can the size of the VAE image input be changed? Why must it be 64x64?

Question about class ResUp

In class Decoder, the number of convolutional channels decreases (z -> ch*8 -> ch*8 -> ch*4 -> ch*2 -> ch).

In class ResUp, the first convolution is defined as self.conv1 = nn.Conv2d(channel_in, channel_out//2, 3, 1, 1).
This way, the number of output channels is smaller than channel_out.
Taking ResUp(ch*8, ch*4) as an example, the channel count goes ch*8 -> ch*2 -> ch*4.

Could there be a mistake in self.conv1 in class ResUp?
I think it should be self.conv1 = nn.Conv2d(channel_in, channel_in//2, 3, 1, 1), or nn.Conv2d(channel_in, channel_out*2, 3, 1, 1), or something else that doesn't "shrink" the channel count partway through the block. A small sketch of the channel arithmetic follows below.
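For reference, here is a minimal standalone sketch (not the repo's actual ResUp block, just the channel arithmetic at issue) comparing the current conv1 with the two suggested alternatives, using ch = 64 as an example:

import torch
import torch.nn as nn

ch = 64
channel_in, channel_out = ch * 8, ch * 4        # as in ResUp(ch*8, ch*4)

# conv1 as quoted in the issue: bottlenecks to channel_out//2 = ch*2
conv1_current = nn.Conv2d(channel_in, channel_out // 2, 3, 1, 1)

# the two suggested alternatives, which keep at least channel_out channels
conv1_alt_a = nn.Conv2d(channel_in, channel_in // 2, 3, 1, 1)    # ch*8 -> ch*4
conv1_alt_b = nn.Conv2d(channel_in, channel_out * 2, 3, 1, 1)    # ch*8 -> ch*8

x = torch.randn(1, channel_in, 8, 8)
print(conv1_current(x).shape)   # torch.Size([1, 128, 8, 8])  i.e. ch*2
print(conv1_alt_a(x).shape)     # torch.Size([1, 256, 8, 8])  i.e. ch*4
print(conv1_alt_b(x).shape)     # torch.Size([1, 512, 8, 8])  i.e. ch*8

Whether the ch*2 bottleneck is intentional (a compression step inside the block) or a typo is something only the author can confirm.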

Cool! How well does it work?

I've been thinking about implementing a ResNet-style VAE for images, and while preparing I came across your repo -- very interesting! I'm wondering if you'd be willing to share your thoughts on how well this works? In particular, I'd be interested to understand the impact of the skip connections. Also, qualitatively, what fidelity do you see in the reconstructed images?

Kind regards,
Hudson

How to modify the image size?

Thanks for your code! It worked when I trained with image_size=64, but when I tried to change image_size to 128, the following error occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-23-cd633620fbd2> in <module>
      6 
      7         #VAE loss
----> 8         loss = vae_loss(recon_data, data[0].to(device), mu, logvar)
      9 
     10         #Perception loss

<ipython-input-4-ba4d722462c7> in vae_loss(recon, x, mu, logvar)
     14 
     15 def vae_loss(recon, x, mu, logvar):
---> 16     recon_loss = F.binary_cross_entropy_with_logits(recon, x)
     17     KL_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
     18     loss = recon_loss + 0.01 * KL_loss

~/anaconda3/envs/pt1.0/lib/python3.6/site-packages/torch/nn/functional.py in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
   2096 
   2097     if not (target.size() == input.size()):
-> 2098         raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
   2099 
   2100     return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)

ValueError: Target size (torch.Size([8, 3, 128, 128])) must be the same as input size (torch.Size([8, 3, 320, 320]))

So how can I correctly modify the input image size? Thank you very much if you can answer my question!
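Not the repo's code, but a toy, self-contained sketch of one common cause of this kind of mismatch: the decoder's output resolution is fixed by the number of upsampling blocks it contains, so changing only image_size in the data transform leaves the reconstruction at a different size than the targets. (The block counts and the 4x4 latent map below are illustrative assumptions, not the repo's actual values.)

import torch
import torch.nn as nn

def toy_decoder(num_up_blocks, latent_hw=4, ch=32):
    # Each block doubles the spatial size, so the output is latent_hw * 2**num_up_blocks.
    layers = []
    for _ in range(num_up_blocks):
        layers += [nn.Upsample(scale_factor=2), nn.Conv2d(ch, ch, 3, 1, 1), nn.ReLU()]
    layers += [nn.Conv2d(ch, 3, 3, 1, 1)]
    return nn.Sequential(*layers), latent_hw

dec, hw = toy_decoder(num_up_blocks=4)
print(dec(torch.randn(1, 32, hw, hw)).shape)   # torch.Size([1, 3, 64, 64])

dec, hw = toy_decoder(num_up_blocks=5)         # one extra up-block
print(dec(torch.randn(1, 32, hw, hw)).shape)   # torch.Size([1, 3, 128, 128])

So to move from 64x64 to 128x128 inputs, the encoder/decoder depth (or the latent spatial size) has to change consistently with the data transform, not just the image_size passed to the loader.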

About training on 1024x1024 images

Dear LukeDitria,
I am a long-time user of this repo, and I raised an issue earlier about scaling to 256 resolution (if you remember). Now I want to train on much higher-resolution images (1024). Do you think it is possible, and what changes would I need to make?
Best regards,
Jay
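Not an answer from the author, just back-of-the-envelope arithmetic on what the scaling implies: assuming stride-2 down-blocks and a small bottleneck feature map (8x8 here is an illustrative assumption, not the repo's actual value), 1024x1024 inputs need several more down/up blocks than 64x64 ones, and activation memory grows accordingly:

import math

# stride-2 blocks needed to reduce an image to an assumed 8x8 bottleneck
for image_size in (64, 256, 1024):
    n_blocks = int(math.log2(image_size // 8))
    print(f"{image_size}x{image_size}: {n_blocks} stride-2 down-blocks (and {n_blocks} matching up-blocks)")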

How about 1-channel image reconstruction?

Thanks for your code! It worked when I trained with image_channel=3, but when I tried to test an image with channel=1, the following error occurred:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-38-0565c2bab7d7> in <module>
      9 
     10         #Perception loss
---> 11         loss_feature = feature_loss(data[0].to(device), recon_data, feature_extractor)
     12 
     13         loss += loss_feature

<ipython-input-30-2deca954bbe0> in feature_loss(img, recon_data, feature_extractor)
     26 def feature_loss(img, recon_data, feature_extractor):
     27     img_cat = torch.cat((img, torch.sigmoid(recon_data)), 0)
---> 28     out = feature_extractor(img_cat)
     29     loss = 0
     30     for i in range(len(feature_extractor)):

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    115     def forward(self, input):
    116         for module in self:
--> 117             input = module(input)
    118         return input
    119 

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py in forward(self, input)
    421 
    422     def forward(self, input: Tensor) -> Tensor:
--> 423         return self._conv_forward(input, self.weight)
    424 
    425 class Conv3d(_ConvNd):

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
    417                             weight, self.bias, self.stride,
    418                             _pair(0), self.dilation, self.groups)
--> 419         return F.conv2d(input, weight, self.bias, self.stride,
    420                         self.padding, self.dilation, self.groups)
    421 

RuntimeError: Given groups=1, weight of size [64, 3, 3, 3], expected input[64, 1, 256, 256] to have 3 channels, but got 1 channels instead

So how can I modify the code correctly? Thank you very much if you could answer my question!
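Not an official fix, but the traceback shows the error comes from the VGG feature extractor used for the perception loss: its first conv has weights of shape [64, 3, 3, 3], so it only accepts 3-channel input. One common workaround, sketched below (the helper name and the commented call are hypothetical, mirroring the snippet in the traceback), is to repeat the single channel three times just before the perception loss, leaving the VAE itself 1-channel:

import torch

def to_three_channels(x):
    # ImageNet-pretrained VGG expects RGB, so tile a 1-channel tensor to 3 channels;
    # 3-channel tensors pass through unchanged.
    return x.repeat(1, 3, 1, 1) if x.shape[1] == 1 else x

# hypothetical usage, mirroring the call in the traceback:
# loss_feature = feature_loss(to_three_channels(data[0].to(device)),
#                             to_three_channels(recon_data),
#                             feature_extractor)

Alternatively, the perception loss could simply be dropped for grayscale data.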

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.