Comments (2)
@SidShenoy Thank you for your comment. That is a very sharp observation!
Yes, indeed, changing the pooling layers affects the following feature maps. Since the max-pooling is now replaced with the LL filter (which behaves like average pooling), the feature maps after the pooling layer are slightly different. (Note that the feature maps from the other three wavelet filters (LH, HL, HH) do not propagate to the next layer of the encoder. They are skipped to the decoder, so the only change the encoder has to accommodate comes from the max-pooling being replaced by the LL filter.)
However, as we wrote in the paper, we decided not to touch the encoder and instead let the decoder adapt to those changes. You can fine-tune the encoder by partially or entirely unfreezing its weight parameters, and we actually tried some variants, such as unfreezing only the convolutions that immediately follow the pooling layers, so that the change is handled in the encoder as well. There was not much difference in the final outcomes, so we chose to stick with the simpler training strategy.
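One way to picture the partial-unfreezing variant mentioned above is as a filter over the encoder's named parameters. This is only an illustrative sketch: the toy encoder and its layer names here are assumptions, not the actual WCT2 module names.

```python
import torch.nn as nn

# Toy encoder standing in for the VGG-style encoder; the layer names are
# hypothetical, chosen only to illustrate the idea.
encoder = nn.Sequential()
encoder.add_module("conv1_1", nn.Conv2d(3, 8, 3, padding=1))
encoder.add_module("pool1", nn.AvgPool2d(2))  # stand-in for the LL pooling
encoder.add_module("conv2_1", nn.Conv2d(8, 8, 3, padding=1))

# Unfreeze only the convolutions that directly follow a pooling stage.
finetune = ["conv2_1"]

for name, param in encoder.named_parameters():
    # A parameter is trainable only if it belongs to a post-pooling conv.
    param.requires_grad = any(name.startswith(l) for l in finetune)
```

With this, an optimizer built over `filter(lambda p: p.requires_grad, ...)` would update only the post-pooling convolutions while the rest of the encoder stays fixed.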
This can be explained in two ways: 1) It is already a well-known phenomenon, consistently reported in many observations, that style transfer still works when max-pooling is replaced with average pooling (even though the VGG network was trained with max-pooling), and the effect is sometimes even better. Similarly, our encoder with the LL filter, which is average pooling up to a scaling factor, shares this characteristic. 2) Since the decoder is newly trained, it has enough capacity to compensate for such shifts in the encoder feature maps and still produce a good reconstruction.
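The claim that the Haar LL filter is "average pooling up to a scaling factor" can be checked numerically. Each 2x2 Haar LL coefficient is the patch sum scaled by 1/2, which is exactly twice the 2x2 average-pool output. A minimal sketch (toy input, not the paper's code):

```python
import torch
import torch.nn.functional as F

# Haar LL analysis filter: sums each 2x2 patch and scales by 1/2.
ll_filter = torch.full((1, 1, 2, 2), 0.5)

x = torch.randn(1, 1, 8, 8)  # single-channel toy feature map

ll = F.conv2d(x, ll_filter, stride=2)           # LL band of the wavelet pooling
avg = F.avg_pool2d(x, kernel_size=2, stride=2)  # plain average pooling

# The LL band equals average pooling up to a constant factor of 2.
assert torch.allclose(ll, 2 * avg, atol=1e-6)
```

So swapping max-pooling for the LL filter only rescales what average pooling would produce, which is why the known max-to-average-pooling robustness carries over.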
Still, your comment is very valuable, and we will describe the training procedure in more detail to clarify this point. Thanks a lot for your attention :)
from wct2.
Maybe this partial code snippet will help you understand what we did:
# Freeze all encoder parameters; only the decoder is trained.
for param in self.encoder.parameters():
    param.requires_grad = False

# Optimize only the parameters that still require gradients (the decoder's).
self.dec_optim = torch.optim.Adam(
    filter(lambda p: p.requires_grad, self.decoder.parameters()),
    lr=self.lr,
    betas=(self.beta1, self.beta2),
)

# Reconstruction pass: encode, decode with the wavelet skips, then re-encode.
feature, skips = self.encoder(real_image)
recon_image = self.decoder(feature, skips)
feature_recon, _ = self.encoder(recon_image)

# Pixel-level reconstruction loss plus feature-level consistency loss.
recon_loss = self.MSE_loss(recon_image, real_image)
feature_loss = torch.zeros(1).to(self.device)
feature_loss += self.MSE_loss(feature_recon, feature.detach())
loss = recon_loss * self.recon_weight + feature_loss * self.feature_weight

self.reset_grad()
loss.backward()
self.dec_optim.step()