Comments (10)
Our GauGAN training code has been released. You can check training_tutorial.md to set up GauGAN experiments.
from gan-compression.
We will add our GauGAN model later. Stay tuned.
Hi, can you let me know when you will release the GauGAN code? Thanks.
We will release our compressed GauGAN model and the test code in 2 or 3 days.
The training code may come later. We are trying to merge it into our repository, and that may take some time.
Hi, we have released our compressed GauGAN model and the test code. Check the README for instructions on using our compressed model.
@lmxyy while waiting for your official SPADE release, I'm trying to fill in the gaps myself. It is mostly going well; however, one thing I can't figure out is how to adapt weight transfer to SPADE blocks.
What I'm trying is this:
```python
idxs = transfer_Conv2d(m1.conv_0, m2.conv_0, input_index=input_index)
idxs = transfer_Conv2d(m1.conv_1, m2.conv_1, input_index=idxs, output_index=input_index)
if m1.learned_shortcut and m2.learned_shortcut:
    transfer_Conv2d(m1.conv_s, m2.conv_s, input_index=input_index)
```
but I keep getting index-out-of-bounds errors. I think the error comes from the fact that each SPADE block shrinks the number of channels.
Could you please share a snippet on how to transfer weights from teacher SPADE block to a student one? Thanks!
Edit:
I tried changing the snippet to:
```python
idxs = transfer_Conv2d(m1.conv_0, m2.conv_0, input_index=input_index)
idxs = transfer_Conv2d(m1.conv_1, m2.conv_1, input_index=idxs)
if m1.learned_shortcut and m2.learned_shortcut:
    transfer_Conv2d(m1.conv_s, m2.conv_s, input_index=input_index)
```
and now it passes, but I'm not sure if it really works.
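The index-out-of-bounds error is consistent with the channel-count change: in a SPADE residual block with a learned shortcut, `fin != fout`, so the block's input indices cannot double as `conv_1`'s output indices, whereas with an identity shortcut they must. A minimal sketch of threading indices through one block in both cases follows. It is a NumPy stand-in, not the repository's code: arrays play the role of PyTorch conv weights `(out, in, kh, kw)`, `np.argsort` replaces `torch.topk`, and `select_out`/`transfer_block` are hypothetical helpers.

```python
import numpy as np


def select_out(weight, in_idx, n_out, out_idx=None):
    """Slice a teacher conv weight of shape (out, in, kh, kw) for a student.

    in_idx:  teacher input channels that survive from the previous layer.
    out_idx: output channels to force; if None, keep the n_out filters
             with the largest L1 norm.
    """
    w = weight[:, in_idx]
    if out_idx is None:
        scores = np.abs(w).sum(axis=(1, 2, 3))  # L1 norm of each filter
        out_idx = np.argsort(-scores)[:n_out]
    return w[out_idx], np.asarray(out_idx)


def transfer_block(w0, w1, ws, in_idx, n_mid, n_out):
    """Thread channel indices through conv_0 -> conv_1 (and conv_s if present)."""
    w0_s, mid_idx = select_out(w0, in_idx, n_mid)
    if ws is None:
        # Identity shortcut: x + f(x) needs conv_1 to emit exactly the
        # channels the block received, so its outputs are forced to in_idx.
        w1_s, out_idx = select_out(w1, mid_idx, n_out, out_idx=in_idx)
        return (w0_s, w1_s), out_idx
    # Learned shortcut (fin != fout): conv_1 ranks its own outputs and
    # conv_s maps the block input onto that same output set.
    w1_s, out_idx = select_out(w1, mid_idx, n_out)
    ws_s, _ = select_out(ws, in_idx, n_out, out_idx=out_idx)
    return (w0_s, w1_s, ws_s), out_idx


rng = np.random.default_rng(0)
# Hypothetical teacher block with a learned shortcut: 8 -> 4 channels.
w0 = rng.normal(size=(4, 8, 3, 3))
w1 = rng.normal(size=(4, 4, 3, 3))
ws = rng.normal(size=(4, 8, 1, 1))
in_idx = np.array([0, 2, 5, 7])  # channels kept by the previous student layer

(w0_s, w1_s, ws_s), out_idx = transfer_block(w0, w1, ws, in_idx, n_mid=2, n_out=2)
print(w0_s.shape, w1_s.shape, ws_s.shape)  # (2, 4, 3, 3) (2, 2, 3, 3) (2, 4, 1, 1)

# Forcing conv_1's outputs to the block's *input* indices (the pattern
# output_index=input_index) fails here: conv_1 only has 4 output filters.
try:
    select_out(w1, np.array([0, 1]), 2, out_idx=in_idx)
except IndexError as err:
    print("IndexError:", err)
```

Each layer's selected outputs become the next layer's `input_index`, and only the shortcut type decides whether `conv_1`'s outputs are free or pinned to the block's input set.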
Edit 2:
It also looks like there is a typo here in the transfer_Conv2d implementation:
```python
if input_index is not None:
    q = p.abs().sum([0, 2, 3])
    _, idxs = q.topk(m2.in_channels, largest=True)
    p = p[:, idxs]
else:
    p = p[:, input_index]
```
Should the condition be `is None`?
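With the condition as written, a supplied `input_index` is ignored (the top-k branch overwrites it), and a missing one reaches `p[:, input_index]` with `None`, which adds an axis instead of selecting channels. A minimal NumPy sketch of the corrected logic, assuming the same L1-norm ranking over input channels; `select_inputs` and the toy weight are hypothetical, and `np.argsort` stands in for `torch.topk`:

```python
import numpy as np


def select_inputs(p, n_in, input_index=None):
    """Pick the input channels of a conv weight p with shape (out, in, kh, kw).

    Corrected condition: rank channels only when no index was passed in.
    """
    if input_index is None:
        q = np.abs(p).sum(axis=(0, 2, 3))  # L1 norm of each input channel
        input_index = np.argsort(-q)[:n_in]
    input_index = np.asarray(input_index)
    return p[:, input_index], input_index


# Toy weight: input channel c has constant value [1, 4, 2, 3][c].
p = np.ones((2, 4, 1, 1)) * np.array([1.0, 4.0, 2.0, 3.0]).reshape(1, 4, 1, 1)

sub, idx = select_inputs(p, n_in=2)
print(idx, sub.shape)  # [1 3] (2, 2, 1, 1)

sub, idx = select_inputs(p, n_in=2, input_index=[0, 2])
print(idx, sub.shape)  # [0 2] (2, 2, 1, 1)

# With the original `is not None` test, None would reach p[:, input_index];
# indexing with None inserts a new axis (in NumPy and PyTorch alike).
print(p[:, None].shape)  # (2, 1, 4, 1, 1)
```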
Yes, Edit 2 is a typo. Thank you for pointing it out.
Here is a snippet of my weight-transfer implementation for MobileSPADEGenerator. I haven't cleaned it up yet, but I hope it helps:
```python
import torch.nn as nn

# SeparableConv2d, MobileSPADE, MobileSPADEResnetBlock and MobileSPADEGenerator
# are the model classes from the gan-compression code base (imports omitted).


def transfer_conv(m1, m2, input_index, output_index=None):
    """Copy a pruned slice of teacher conv m1 into student conv m2."""
    assert isinstance(m1, nn.Conv2d) and isinstance(m2, nn.Conv2d)
    p = m1.weight.data
    assert input_index is not None
    p = p[:, input_index]  # keep only the input channels the student receives
    if output_index is None:
        # Keep the m2.out_channels filters with the largest L1 norm.
        q = p.abs().sum([1, 2, 3])
        _, idxs = q.topk(m2.out_channels, largest=True)
    else:
        idxs = output_index
    m2.weight.data = p[idxs].clone()
    if m2.bias is not None:
        m2.bias.data = m1.bias.data[idxs].clone()
    return idxs


def transfer_spconv(m1, m2, input_index, output_index=None):
    """Transfer a depthwise (conv[0]) + pointwise (conv[2]) separable conv."""
    assert isinstance(m1, SeparableConv2d) and isinstance(m2, SeparableConv2d)

    def transfer_dw(dw1, dw2):
        # Depthwise: one filter per channel, so select along dim 0.
        p = dw1.weight.data
        dw2.weight.data = p[input_index].clone()
        if dw2.bias is not None:
            dw2.bias.data = dw1.bias.data[input_index].clone()

    def transfer_pw(pw1, pw2):
        p = pw1.weight.data
        p = p[:, input_index]
        if output_index is None:
            q = p.abs().sum([1, 2, 3])
            _, idxs = q.topk(pw2.out_channels, largest=True)
        else:
            idxs = output_index
        pw2.weight.data = p[idxs].clone()
        if pw2.bias is not None:
            pw2.bias.data = pw1.bias.data[idxs].clone()
        return idxs

    transfer_dw(m1.conv[0], m2.conv[0])
    idxs = transfer_pw(m1.conv[2], m2.conv[2])
    return idxs


def transfer_mbspade(m1, m2, input_index=None):
    """Transfer a MobileSPADE normalization layer over channels input_index."""
    assert isinstance(m1, MobileSPADE) and isinstance(m2, MobileSPADE)
    m2.param_free_norm.running_mean = m1.param_free_norm.running_mean[input_index].clone()
    m2.param_free_norm.running_var = m1.param_free_norm.running_var[input_index].clone()
    # mlp_shared sees the full segmentation map, so all of its inputs are kept.
    idxs = transfer_conv(m1.mlp_shared[0], m2.mlp_shared[0], list(range(m1.mlp_shared[0].in_channels)))
    # gamma/beta must produce exactly the channels being normalized.
    transfer_spconv(m1.mlp_gamma, m2.mlp_gamma, idxs, input_index)
    transfer_spconv(m1.mlp_beta, m2.mlp_beta, idxs, input_index)
    return input_index


def transfer_mbresnetblock1(m1, m2, input_index):
    assert input_index is not None
    assert isinstance(m1, MobileSPADEResnetBlock) and isinstance(m2, MobileSPADEResnetBlock)
    if m1.learned_shortcut:
        assert m2.learned_shortcut
        idxs = transfer_mbspade(m1.norm_0, m2.norm_0, input_index)
        idxs = transfer_conv(m1.conv_0, m2.conv_0, idxs)
        idxs = transfer_mbspade(m1.norm_1, m2.norm_1, idxs)
        idxs = transfer_conv(m1.conv_1, m2.conv_1, idxs)
        # The shortcut maps the block input onto the same outputs as conv_1.
        transfer_mbspade(m1.norm_s, m2.norm_s, input_index)
        transfer_conv(m1.conv_s, m2.conv_s, input_index, idxs)
        return idxs
    else:
        assert not m2.learned_shortcut
        idxs = transfer_mbspade(m1.norm_0, m2.norm_0, input_index)
        idxs = transfer_conv(m1.conv_0, m2.conv_0, idxs)
        idxs = transfer_mbspade(m1.norm_1, m2.norm_1, idxs)
        # Identity shortcut: conv_1 must output the block's input channels.
        idxs = transfer_conv(m1.conv_1, m2.conv_1, idxs, input_index)
        return idxs


def transfer_weight(netA, netB):
    if isinstance(netA, MobileSPADEGenerator):
        assert isinstance(netB, MobileSPADEGenerator)
        idxs = transfer_conv(netA.fc, netB.fc, list(range(netA.fc.in_channels)))
        idxs = transfer_mbresnetblock1(netA.head_0, netB.head_0, idxs)
        idxs = transfer_mbresnetblock1(netA.G_middle_0, netB.G_middle_0, idxs)
        idxs = transfer_mbresnetblock1(netA.G_middle_1, netB.G_middle_1, idxs)
        idxs = transfer_mbresnetblock1(netA.up_0, netB.up_0, idxs)
        idxs = transfer_mbresnetblock1(netA.up_1, netB.up_1, idxs)
        idxs = transfer_mbresnetblock1(netA.up_2, netB.up_2, idxs)
        idxs = transfer_mbresnetblock1(netA.up_3, netB.up_3, idxs)
    else:
        raise NotImplementedError
```
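In `transfer_mbspade`, `mlp_gamma` and `mlp_beta` take the hidden channels kept by `mlp_shared` as input and must output exactly the channels the layer normalizes (`output_index=input_index`). The NumPy sketch below mirrors the indexing in `transfer_spconv`, assuming (as the snippet's `conv[0]`/`conv[2]` suggest) a depthwise weight of shape `(C, 1, k, k)` followed by a pointwise weight of shape `(C_out, C, 1, 1)`; `transfer_separable`, the shapes, and the index values are hypothetical.

```python
import numpy as np


def transfer_separable(dw_w, pw_w, input_index, n_out, output_index=None):
    """Slice a depthwise + pointwise pair the way transfer_spconv does.

    dw_w: depthwise weight, shape (C, 1, k, k), one filter per channel.
    pw_w: pointwise weight, shape (C_out, C, 1, 1).
    """
    input_index = np.asarray(input_index)
    dw_s = dw_w[input_index]      # depthwise: select whole per-channel filters
    p = pw_w[:, input_index]      # pointwise: keep only surviving inputs
    if output_index is None:
        q = np.abs(p).sum(axis=(1, 2, 3))  # L1 norm of each output filter
        output_index = np.argsort(-q)[:n_out]
    output_index = np.asarray(output_index)
    return dw_s, p[output_index], output_index


rng = np.random.default_rng(0)
# Hypothetical teacher mlp_gamma: 6 hidden channels -> 8 normalized channels.
dw_w = rng.normal(size=(6, 1, 3, 3))
pw_w = rng.normal(size=(8, 6, 1, 1))
hidden_idx = [0, 3, 4]   # outputs kept by the student's mlp_shared conv
norm_idx = [1, 2, 6, 7]  # channels the student's SPADE layer normalizes

dw_s, pw_s, out = transfer_separable(dw_w, pw_w, hidden_idx, n_out=4, output_index=norm_idx)
print(dw_s.shape, pw_s.shape, out)  # (3, 1, 3, 3) (4, 3, 1, 1) [1 2 6 7]
```

Because the depthwise conv has one filter per channel, the same index list drives its rows and the pointwise conv's columns; only the pointwise rows can be re-ranked or forced.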
@lmxyy Thanks a lot! I will now try swapping it in and see how it goes.
@lmxyy, any estimate for when you might release the GauGAN training code?
Thank you!
We will release the training code in one or two weeks.