heykeetae / self-attention-gan
PyTorch implementation of Self-Attention Generative Adversarial Networks (SAGAN)
In trainer.py:
def build_tensorboard(self):
    from logger import Logger
    self.logger = Logger(self.log_path)
The logger module seems to be missing.
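Since logger.py is not in the repository, here is a hypothetical minimal stand-in, not the author's file. It only assumes the three method names that trainer.py calls (scalar_summary, histo_summary and image_summary, as used in the logging snippet further down this page). It forwards them to torch.utils.tensorboard.SummaryWriter when that import succeeds, and otherwise just keeps the scalars in memory.

```python
import os

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:  # torch or the tensorboard package is not installed
    SummaryWriter = None


class Logger:
    """Hypothetical replacement for the missing logger.py.

    Only the method names come from trainer.py; everything else is assumed.
    """

    def __init__(self, log_dir):
        os.makedirs(log_dir, exist_ok=True)
        self.log_dir = log_dir
        # Scalars are also kept in memory so they can be inspected without TensorBoard
        self.scalars = {}
        self.writer = SummaryWriter(log_dir) if SummaryWriter is not None else None

    def scalar_summary(self, tag, value, step):
        self.scalars.setdefault(tag, []).append((step, float(value)))
        if self.writer is not None:
            self.writer.add_scalar(tag, value, step)

    def histo_summary(self, tag, values, step):
        # values: a NumPy array, e.g. the result of self.to_np(param)
        if self.writer is not None:
            self.writer.add_histogram(tag, values, step)

    def image_summary(self, tag, images, step):
        # images: a NumPy array shaped (N, C, H, W)
        if self.writer is not None:
            self.writer.add_images(tag, images, step, dataformats="NCHW")

    def close(self):
        if self.writer is not None:
            self.writer.close()
```

Saved as logger.py next to trainer.py, it should let build_tensorboard import Logger unchanged.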
I see "from tester import Tester" in main.py, but there is no tester.py.
Hi,
I noticed something strange in sagan_models.py. In the original paper, matrix multiplication is used to calculate f_x * g_x and also h_x * attn_score. However, the code uses torch.mul(), which is an element-wise multiplication. I think you meant to use torch.matmul().
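To see why the choice matters, compare the two operations on tensors laid out as (batch, channels, positions). This is an illustrative sketch with arbitrary sizes, not code from sagan_models.py: torch.mul only multiplies matching entries, while torch.bmm (equivalent to torch.matmul for 3-D tensors) produces the N×N matrix of pairwise dot products that an attention map needs.

```python
import torch

torch.manual_seed(0)
B, C, N = 2, 8, 16  # batch, reduced channels, positions (H * W); sizes are arbitrary
q = torch.randn(B, C, N)  # per-position query features
k = torch.randn(B, C, N)  # per-position key features

# Element-wise product: same shape as the inputs, no interaction between positions
elementwise = torch.mul(q, k)  # B x C x N

# Batched matrix product: every position i against every position j
energy = torch.bmm(q.permute(0, 2, 1), k)  # B x N x N, energy[b, i, j] = q_i . k_j
```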
If I want to generate images of size 512*512, the attention module will take int(64 / 8)*(512 ** 2) = 2097152 channels, so the parameters will take a lot of memory. Is self-attention not suitable for large image generation? How can this be solved?
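A rough back-of-the-envelope sketch (my own helper functions, not part of the repo) separates two costs: the parameters of the 1×1 convolutions, which only blow up when the channel count depends on imsize, and the N×N attention map, which grows with the fourth power of the feature-map side no matter how many channels are used.

```python
def conv1x1_params(in_ch, out_ch, bias=True):
    """Number of parameters in nn.Conv2d(in_ch, out_ch, kernel_size=1)."""
    return in_ch * out_ch + (out_ch if bias else 0)


def attention_map_bytes(h, w, batch_size=1, bytes_per_elem=4):
    """Memory of one B x N x N attention map with N = h * w (float32 by default)."""
    n = h * w
    return batch_size * n * n * bytes_per_elem


in_dim = 64
# Channel count from the question: int(in_dim / 8) * (imsize ** 2)
print(conv1x1_params(in_dim, int(in_dim / 8) * 512 ** 2))  # 136314880
# Channel count independent of image size: in_dim // 8
print(conv1x1_params(in_dim, in_dim // 8))                  # 520
print(attention_map_bytes(64, 64) / 2**20, "MiB")           # 64.0 MiB
print(attention_map_bytes(512, 512) / 2**30, "GiB")         # 256.0 GiB
```

The last figure is one reason to apply attention to an intermediate, lower-resolution feature map rather than at the full output resolution.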
In Self-Attention-GAN/sagan_models.py (line 13 at commit 8714a54), the self-attention layer's constructor has an activation parameter but never uses it.
Why?
Hi, I have been reading the code and trying to fix this issue, and at the same time make it work for more image sizes. I believe the minimum imsize could be around 32 if you keep the kernel size, with minor changes to the code (you need to add the same "if imsize == 64" check in the forward pass of both models and change the structure of the generator's __init__ a little). If you go lower than 32, you may get an "input size smaller than kernel size" error. If you go up to 256, you definitely need more layers. So there might not be a one-size-fits-all solution here. Any suggestions?
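The kernel-size limit can be checked with the standard output-size formula for nn.Conv2d, floor((size + 2 * padding - kernel) / stride) + 1. The sketch below assumes a DCGAN-style discriminator made of 4×4, stride-2, padding-1 convolutions followed by a final 4×4 convolution with no padding (my reading of the models, not verified against every revision). The helper names are hypothetical.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Spatial output size of nn.Conv2d: floor((size + 2p - k) / s) + 1."""
    out = (size + 2 * padding - kernel) // stride + 1
    if out < 1:
        raise ValueError(f"input size {size} is smaller than kernel size {kernel}")
    return out


def discriminator_sizes(imsize, n_down, final_kernel=4):
    """Feature-map sizes after n_down stride-2 convs and a final unpadded conv."""
    sizes = [imsize]
    for _ in range(n_down):
        sizes.append(conv_out(sizes[-1]))  # each 4x4/s2/p1 conv halves the size
    # Final conv: kernel 4, stride 1, no padding, collapses a 4x4 map to 1x1
    sizes.append(conv_out(sizes[-1], kernel=final_kernel, stride=1, padding=0))
    return sizes


print(discriminator_sizes(64, 4))  # [64, 32, 16, 8, 4, 1]
print(discriminator_sizes(32, 3))  # [32, 16, 8, 4, 1]
```

With four stride-2 layers a 32×32 input shrinks to 2×2 before the last 4×4 conv, which is the "input size smaller than kernel size" case; under these assumptions the number of stride-2 layers should be log2(imsize) - 2.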
Batch_size = 6?
Right?
Hello,
In sagan_models.py:
out = torch.bmm(proj_value,attention.permute(0,1,2))
Do you intend to transpose attention? .permute(0,1,2) makes no change.
From eq. (2) in Han's paper, I think it should be .permute(0,2,1).
Or have I misunderstood something?
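For comparison, here is a sketch of a SAGAN-style attention block written so the index conventions are explicit. It is not the repo's code: the class name SelfAttn, the in_dim // 8 reduction and the zero-initialised scalar gamma are assumptions. With energy laid out as (query i, key j) and softmax taken over j, the value matrix has to be multiplied by the transposed attention map, which is what permute(0, 2, 1) does; permute(0, 1, 2) is the identity permutation.

```python
import torch
import torch.nn as nn


class SelfAttn(nn.Module):
    """Sketch of a SAGAN-style self-attention block (hypothetical, not the repo's class).

    Assumes query/key project to in_dim // 8 channels, value keeps in_dim channels,
    and gamma is a single learnable scalar initialised to 0.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        b, c, h, w = x.size()
        n = h * w
        q = self.query_conv(x).view(b, -1, n).permute(0, 2, 1)  # B x N x C'
        k = self.key_conv(x).view(b, -1, n)                      # B x C' x N
        energy = torch.bmm(q, k)                 # B x N x N, energy[i, j] = q_i . k_j
        attention = self.softmax(energy)         # each row i sums to 1 over j
        v = self.value_conv(x).view(b, -1, n)    # B x C x N
        # out[:, :, i] = sum_j attention[i, j] * v[:, :, j]; the transpose puts
        # the key index j on the axis that bmm contracts over
        out = torch.bmm(v, attention.permute(0, 2, 1)).view(b, c, h, w)
        return self.gamma * out + x, attention
```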
self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
Otherwise, it is not a true projection?
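Whatever the intended contrast, a 1×1 nn.Conv2d such as query_conv is exactly a linear projection of each pixel's channel vector, here from in_dim to in_dim // 8 channels. A quick numerical check with arbitrary sizes chosen for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_dim = 64
conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)  # same shape as query_conv
x = torch.randn(2, in_dim, 5, 5)

# A 1x1 conv applies one (in_dim // 8) x in_dim matrix W plus bias b to the
# channel vector at every spatial position: y[:, :, i, j] = W @ x[:, :, i, j] + b
W = conv.weight.view(in_dim // 8, in_dim)
manual = torch.einsum("oc,bchw->bohw", W, x) + conv.bias.view(1, -1, 1, 1)
```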
I just want to verify that I am correct: class-conditional batch norm is not implemented in this repo?
In the original TensorFlow code, the network blocks incorporate a ConditionalBatchNorm operation.
I would just like a simple answer on whether my conclusion is correct, as this could help others asking themselves the same question :)
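If it is missing and you want to add it, one common form of class-conditional batch norm normalises without learned affine parameters and then applies a per-class scale and shift looked up from embeddings. A minimal sketch, assuming integer class labels y; the class name and initialisation are my own choices, not taken from the original TensorFlow code:

```python
import torch
import torch.nn as nn


class ConditionalBatchNorm2d(nn.Module):
    """Hypothetical class-conditional BN: per-class gamma/beta from embeddings."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, affine=False)  # normalisation only
        self.gamma = nn.Embedding(num_classes, num_features)
        self.beta = nn.Embedding(num_classes, num_features)
        nn.init.ones_(self.gamma.weight)   # start as plain BN: scale 1
        nn.init.zeros_(self.beta.weight)   # and shift 0

    def forward(self, x, y):
        out = self.bn(x)
        gamma = self.gamma(y).view(-1, self.num_features, 1, 1)
        beta = self.beta(y).view(-1, self.num_features, 1, 1)
        return gamma * out + beta
```

Initialising the scales to 1 and the shifts to 0 makes the layer start out identical to BatchNorm2d(affine=False); the generator's forward() would then need the labels passed down to every block.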
Please refer to:
https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan_gp/wgan_gp.py
If 'wgan-gp' is chosen, d_loss and the GP are not combined into one update in train(); d_loss is simply overwritten by the penalty term:
d_loss = d_loss_real + d_loss_fake
d_loss = self.lambda_gp * d_loss_gp
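For comparison, the referenced wgan_gp.py adds the penalty to the critic loss and takes a single optimiser step. A sketch of that combined form; the function names are hypothetical, and lambda_gp = 10 is the value commonly used with WGAN-GP:

```python
import torch


def gradient_penalty(D, real, fake):
    """WGAN-GP penalty: mean((||grad_x D(x_hat)||_2 - 1)^2) on random interpolates."""
    # One mixing coefficient per sample, broadcast over the remaining dimensions
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
    out = D(x_hat)
    grad = torch.autograd.grad(
        outputs=out,
        inputs=x_hat,
        grad_outputs=torch.ones_like(out),
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )[0]
    grad_norm = grad.view(grad.size(0), -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()


def d_loss_wgan_gp(D, real, fake, lambda_gp=10.0):
    # Wasserstein terms and penalty in one objective, for one backward() and one step()
    fake = fake.detach()
    return -D(real).mean() + D(fake).mean() + lambda_gp * gradient_penalty(D, real, fake)
```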
The current code does not seem to handle images of different sizes. I want to make it more flexible, but how do I know what the network configuration should be for larger images?
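One way to derive a configuration, assuming a DCGAN-style generator that starts from a 4×4 map, doubles the resolution and halves the channels in each block, then projects to RGB in the last layer. For imsize = 64 and conv_dim = 64 this gives 128 and 64 channels at 16×16 and 32×32, consistent with the Self_Attn(128, ...) and Self_Attn(64, ...) layers in the models. The helper is hypothetical:

```python
import math


def generator_plan(imsize, conv_dim=64, start_res=4):
    """Hypothetical helper: (resolution, channels) after each generator stage.

    Assumes start_res x start_res with conv_dim * imsize // 8 channels at the
    start, resolution doubled and channels halved per block, RGB at the end.
    """
    if imsize < 2 * start_res or imsize & (imsize - 1):
        raise ValueError("imsize must be a power of two and at least 2 * start_res")
    n_up = int(math.log2(imsize // start_res))  # doublings from start_res to imsize
    res, channels = start_res, conv_dim * imsize // 8
    plan = [(res, channels)]
    for _ in range(n_up - 1):  # intermediate upsampling blocks
        res, channels = res * 2, channels // 2
        plan.append((res, channels))
    plan.append((imsize, 3))  # final upsampling layer outputs RGB
    return plan


print(generator_plan(64))   # [(4, 512), (8, 256), (16, 128), (32, 64), (64, 3)]
print(generator_plan(128))  # [(4, 1024), (8, 512), (16, 256), (32, 128), (64, 64), (128, 3)]
```

Under this scheme each doubling of imsize adds one block and doubles the width of the first layer, so for 256×256 you might want to cap the channel count; the discriminator needs the mirror-image change.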
In your code there is imsize, but in the paper there is no imsize. Does int(in_dim / 8)*(imsize ** 2) mean the number of output channels (the \hat{C}, right?)?
self.f_ = nn.Conv2d(in_dim, int(in_dim / 8)*(imsize ** 2), 1)
self.g_ = nn.Conv2d(in_dim, int(in_dim / 8)*(imsize ** 2), 1)
Could you tell me where TTUR is in the code?
Thank you
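TTUR (two time-scale update rule) amounts to giving G and D separate optimisers with different learning rates, so in the code it would show up as the learning rates passed when the two optimisers are built. A sketch using the values reported in the SAGAN paper (1e-4 for G, 4e-4 for D, Adam with beta1 = 0 and beta2 = 0.9); G and D are hypothetical stand-ins, and I have not checked which defaults this repo's parameters use.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the generator and discriminator
G = nn.Linear(128, 64)
D = nn.Linear(64, 1)

# TTUR: one optimiser per network, with a larger learning rate for D than for G
g_lr, d_lr = 1e-4, 4e-4
g_optimizer = torch.optim.Adam(G.parameters(), lr=g_lr, betas=(0.0, 0.9))
d_optimizer = torch.optim.Adam(D.parameters(), lr=d_lr, betas=(0.0, 0.9))
```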
Thanks for sharing your code with us.
When I run download.sh to download the CelebA dataset, it returns a "404" error, and I cannot open the download link in a browser either. Maybe the download link has been replaced or is no longer available?
Hi,
In https://github.com/heykeetae/Self-Attention-GAN/blob/master/trainer.py#L34
the d_iters parameter is passed to self.d_iters (right?).
However, it does not seem to be used in the training loop below.
Am I misunderstanding the training code?
Hi, I think we should detach the fake images when training D. Otherwise G gets updated as well, and training can be very unstable.
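A sketch of a discriminator update that addresses both points raised above: it detaches the generator output and runs d_iters critic steps per generator step. G, D, the sizes and the hinge loss are illustrative assumptions, not the repo's trainer.py.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
G = nn.Linear(8, 4)  # hypothetical generator
D = nn.Linear(4, 1)  # hypothetical discriminator
d_optimizer = torch.optim.Adam(D.parameters(), lr=4e-4, betas=(0.0, 0.9))
d_iters = 2          # discriminator steps per generator step (assumed value)

real = torch.randn(16, 4)
for _ in range(d_iters):
    z = torch.randn(16, 8)
    fake = G(z).detach()  # cut the graph so D's loss cannot reach G's parameters
    # Hinge loss for D, used here only as an example
    d_loss = torch.relu(1.0 - D(real)).mean() + torch.relu(1.0 + D(fake)).mean()
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()
```

After this loop G's parameters have no gradients at all; the generator step would then compute its own loss from a fresh, non-detached G(z).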
The paper and the code both apply spectral normalization to the weight w of 2D convolutions; how should w be handled in 3D convolutions?
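As far as I understand spectral normalization, the kernel shape is not treated specially: the weight tensor is reshaped into a matrix with one row per output channel, and its largest singular value is estimated by power iteration. A sketch on a Conv3d weight (the helper name and iteration count are my own choices):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def spectral_norm_estimate(weight, n_iters=100):
    """Largest singular value of a conv weight via power iteration (hypothetical helper).

    The weight (out_ch, in_ch, *kernel) is flattened to (out_ch, in_ch * prod(kernel)),
    which works the same way for 2-D and 3-D kernels.
    """
    w = weight.detach().reshape(weight.size(0), -1)
    u = F.normalize(torch.randn(w.size(0)), dim=0)
    for _ in range(n_iters):
        v = F.normalize(w.t() @ u, dim=0)
        u = F.normalize(w @ v, dim=0)
    return torch.dot(u, w @ v).item()


torch.manual_seed(0)
conv3d = nn.Conv3d(4, 8, kernel_size=3)
sigma = spectral_norm_estimate(conv3d.weight, n_iters=500)
exact = torch.linalg.svdvals(conv3d.weight.detach().reshape(8, -1))[0].item()

# PyTorch's built-in wrapper flattens the weight the same way, so it accepts Conv3d too
sn_conv3d = nn.utils.spectral_norm(nn.Conv3d(4, 8, kernel_size=3))
```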
File "E:\Users\Raytine\Anaconda3\lib\site-packages\torchvision\datasets\folder.py", line 79, in __init__
"Supported extensions are: " + ",".join(extensions)))
RuntimeError: Found 0 files in subfolders of: ./data/CelebA/
Supported extensions are: .jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif
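The error comes from torchvision's ImageFolder, which looks for images one level down (root/<class_name>/<file>), so a flat ./data/CelebA/ folder yields zero files. A small hypothetical helper that moves loose images into a single class subfolder (the folder name "images" is arbitrary); the extension list is copied from the error message:

```python
import os
import shutil

IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif")


def move_into_class_folder(root, class_name="images"):
    """Move image files lying directly in `root` into `root/class_name/`.

    Hypothetical helper: ImageFolder expects root/<class>/<image>, and images
    placed directly in root are not picked up.
    """
    target = os.path.join(root, class_name)
    os.makedirs(target, exist_ok=True)
    moved = 0
    for name in os.listdir(root):
        src = os.path.join(root, name)
        if os.path.isfile(src) and name.lower().endswith(IMAGE_EXTS):
            shutil.move(src, os.path.join(target, name))
            moved += 1
    return moved
```

With CelebA this would give ./data/CelebA/images/000001.jpg and so on; the single class label is simply ignored by an unconditional GAN.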
self.attn1 = Self_Attn( 128, 'relu')
self.attn2 = Self_Attn( 64, 'relu')
Hello,
Thank you for this excellent implementation. Everything is very clear; however, main.py refers to both a Tester file and a qgan_trainer file that appear to be missing. Please advise.
if config.train:
    if config.model=='sagan':
        trainer = Trainer(data_loader.loader(), config)
    elif config.model == 'qgan':
        trainer = qgan_trainer(data_loader.loader(), config)
    trainer.train()
else:
    tester = Tester(data_loader.loader(), config)
    tester.test()
Can anyone explain how much better the results are after adding self-attention?
I want to use this code for multi-class datasets, but nothing in the code seems to support that?
Hi
My GPU has 11 GB, but when I train I always get: RuntimeError: cuda runtime error (2) : out of memory at /pytorch/torch/lib/THC/generic/THCStorage.cu:58
raised at this line in the code: result = self.forward(*input, **kwargs)
Hello,
Why is Softmax computed on columns in your code (dim=-1) whereas it is applied on rows in the original paper (dim=1)?
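Whether dim=-1 matches the paper depends on how the energy matrix is laid out. A small check, assuming the code stores energy with the output (query) position on the rows and the attended position on the columns, i.e. the transpose of the paper's s_ij as I read eq. (1). Then softmax over dim=-1 on the code's layout is the same as normalising over i (dim=1) on s. The tensors are random stand-ins:

```python
import torch

torch.manual_seed(0)
B, N = 2, 5
# s[b, i, j] = f(x_i)^T g(x_j): i is the attended position, j the output position
s = torch.randn(B, N, N)

# Paper: beta[j, i] = exp(s_ij) / sum_i exp(s_ij), i.e. normalise over i (dim=1 of s)
beta_paper = torch.softmax(s, dim=1)      # indexed [b, i, j]

# Assumed code layout: energy[b, j, i] = s[b, i, j], softmax over the last dim
energy = s.transpose(1, 2)
attention = torch.softmax(energy, dim=-1)  # indexed [b, j, i]
```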
I wanted to change the 'imsize' parameter.
So I changed 'imsize' from 64 to 128, but I got the following error message:
AttributeError: 'Discriminator' object has no attribute 'l4'
Could you tell me a solution?
What should I change in the code?
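The error happens because forward() refers to self.l4, which only exists for one image size. One way around it, sketched under my own assumptions (4×4 stride-2 convs down to a 4×4 map, a final 4×4 conv, self-attention layers left out for brevity), is to build the layers in a loop from imsize so that forward() never names individual layers:

```python
import math

import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Hypothetical imsize-agnostic discriminator (not the repo's class).

    Layers are created in a loop, so forward() never refers to a fixed
    attribute such as self.l4.
    """

    def __init__(self, imsize=64, conv_dim=64):
        super().__init__()
        n_down = int(math.log2(imsize)) - 2  # stride-2 convs needed to reach 4x4
        layers, in_ch, out_ch = [], 3, conv_dim
        for _ in range(n_down):
            layers.append(nn.utils.spectral_norm(nn.Conv2d(in_ch, out_ch, 4, 2, 1)))
            layers.append(nn.LeakyReLU(0.1))
            in_ch, out_ch = out_ch, out_ch * 2
        self.blocks = nn.Sequential(*layers)
        self.last = nn.Conv2d(in_ch, 1, 4)  # 4x4 feature map -> 1x1 score

    def forward(self, x):
        return self.last(self.blocks(x)).view(-1)
```

Attention layers could be placed between blocks at chosen resolutions by checking the current size inside the same loop.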
Has anyone calculated FID and IS for generated images on the CelebA and LSUN datasets?
In your code, gamma has shape [batchsize,1,1,1]. I think the shape should be [1].
Besides, the attention score you compute seems to differ from Han's paper. Did you calculate the attention score using eqn. (1) in the paper?
Hi, @heykeetae
I have read the paper and found the attention map in it.
But how can I visualize the attention map?
In sagan_models.py, there is a B X N X N tensor called "attention".
How can I use this tensor?
Could you please give me some advice?
Many thanks!
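Assuming you can get the B×N×N tensor out of the model (for example by returning it from forward), each row is one query position's weights over all N positions of the feature map. A hypothetical helper that reshapes one row back into an H×W heatmap, assuming rows index the query position and both axes use the row-major order produced by .view(B, C, N):

```python
import torch


def attention_heatmap(attention, query_row, query_col, height, width, sample=0):
    """Return the (height, width) attention weights of one query pixel.

    Hypothetical helper; assumes attention is B x N x N with N = height * width,
    rows = query positions, columns = attended positions, both row-major.
    """
    query_index = query_row * width + query_col
    return attention[sample, query_index].reshape(height, width)


# Example with a random row-normalised attention map for an 8x8 feature map
torch.manual_seed(0)
h = w = 8
attn = torch.softmax(torch.randn(2, h * w, h * w), dim=-1)
heat = attention_heatmap(attn, query_row=3, query_col=5, height=h, width=w)
# heat can then be displayed, e.g. plt.imshow(heat.detach().cpu().numpy())
```

Because attention is applied to an intermediate feature map, each heatmap cell covers a patch of the output image, so it has to be upsampled before overlaying it on a generated sample.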
I would like to know the reasoning behind where spectral norm is applied in G. Specifically, why is spectral norm not applied to the last layer of G (the layer that projects to RGB)? Also, is there any reason not to apply spectral norm to the self-attention convs as well?
Thanks
In self-attention (Q, K, V), Q * K has a permute to calculate the attention weights, so why are the attention weights permuted again before multiplying V?
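The two permutes serve different purposes. The first arranges queries and keys so that bmm yields the N×N matrix of pairwise scores, with the softmax taken over keys (the last axis). The second is needed because the values are stored as C×N: to form o_i = sum_j attention[i, j] v_j as a matrix product, V must be multiplied by the transpose of the attention map. A numerical check with random tensors (sizes arbitrary):

```python
import torch

torch.manual_seed(0)
B, C, N = 2, 3, 6
v = torch.randn(B, C, N)                                   # values: B x C x N
attention = torch.softmax(torch.randn(B, N, N), dim=-1)   # attention[b, i, j]: query i over keys j

# Code form: multiply V by the transposed attention map
out_code = torch.bmm(v, attention.permute(0, 2, 1))        # B x C x N

# Intended form: o_i = sum_j attention[i, j] * v_j
out_ref = torch.einsum("bij,bcj->bci", attention, v)

# Without the permute, position i would receive sum_j attention[j, i] * v_j instead
out_no_permute = torch.bmm(v, attention)
```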
Hi, thanks for your great efforts on this project.
I have a question about the "gamma" parameter.
Is it natural for gamma to become negative during training?
Does this result mean that attention has a negative effect?
You can add TensorBoard logging to trainer.py by inserting the following code at line 180. The code needs to be modified, because l4 does not exist if imsize is less than 64 and the network fails if imsize is greater than 64; once the larger imsize is fixed I will open a pull request. build_tensorboard should also be changed. Hope this helps!
def build_tensorboard(self):
    import os
    import shutil
    from logger import Logger
    # Start each run with an empty log directory
    if os.path.exists(self.log_path):
        shutil.rmtree(self.log_path)
    os.makedirs(self.log_path)
    self.logger = Logger(self.log_path)
Insert at line 180 in trainer.py:
# Print out log info
if (step + 1) % self.log_step == 0:
    elapsed = time.time() - start_time
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, "
          " ave_gamma_l3: {:.4f}, ave_gamma_l4: {:.4f}".
          format(elapsed, step + 1, self.total_step, (step + 1),
                 self.total_step, d_loss_real.item(),
                 self.G.attn1.gamma.mean().item(), self.G.attn2.gamma.mean().item()))

    # (1) Log values of the losses (scalars)
    info = {
        'd_loss_real': d_loss_real.item(),
        'd_loss_fake': d_loss_fake.item(),
        'd_loss': d_loss.item(),
        'g_loss_fake': g_loss_fake.item(),
        'ave_gamma_l3': self.G.attn1.gamma.mean().item(),
        'ave_gamma_l4': self.G.attn2.gamma.mean().item(),
    }
    for tag, value in info.items():
        self.logger.scalar_summary(tag, value, step + 1)

# Sample images / Save and log
if (step + 1) % self.sample_step == 0:
    # (2) Log values and gradients of the parameters (histogram)
    for net, name in zip([self.G, self.D], ['G_', 'D_']):
        for tag, value in net.named_parameters():
            tag = name + tag.replace('.', '/')
            self.logger.histo_summary(tag, self.to_np(value), step + 1)

    # (3) Log the images
    info = {
        'fake_images': self.to_np(fake_images.view(*display_vars)[:10, :, :, :]),
        'real_images': self.to_np(real_images.view(*display_vars)[:10, :, :, :]),
    }

    fake_images, _, _ = self.G(fixed_z)
    save_image(denorm(fake_images.data),
               os.path.join(self.sample_path, '{}_fake.png'.format(step + 1)))
    # Log the samples generated from the fixed noise
    info['fixed_fake_images'] = self.to_np(denorm(fake_images.data).view(*display_vars)[:10, :, :, :])
    for tag, image in info.items():
        self.logger.image_summary(tag, image, step + 1)
Would the spectral norm get canceled out because it appears in both the numerator and the denominator of the batch normalization equation?
I mean:
bn(x*w/sn(w)) = gamma * (x*w/sn(w) - mean(x*w/sn(w))) / std(x*w/sn(w)) + beta = bn(x*w)
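In the forward pass the answer is essentially yes: σ(W) is a single positive scalar for the whole layer, and batch norm is invariant to positive rescaling of its input except through ε. Writing a = 1/σ(W) > 0 and z for one channel's pre-activation x*w over the batch (notation mine):

```latex
\mu(a z) = a\,\mu(z), \qquad \operatorname{Var}(a z) = a^{2} \operatorname{Var}(z)

\mathrm{BN}_{\epsilon}(a z)
  = \gamma \frac{a\,\bigl(z - \mu(z)\bigr)}{\sqrt{a^{2} \operatorname{Var}(z) + \epsilon}} + \beta
  = \gamma \frac{z - \mu(z)}{\sqrt{\operatorname{Var}(z) + \epsilon / a^{2}}} + \beta
  = \mathrm{BN}_{\epsilon / a^{2}}(z)
```

So placing spectral normalization directly before batch norm only changes the effective ε, from ε to ε/a² = ε·σ(W)².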
Hello! Thank you for your contribution to generative adversarial network research and for sharing your code! I am from China, and it is Chinese New Year now, so I wish you a happy Chinese New Year! I am very interested in your paper. When I try to add spectral normalization to my new networks, the program gives the following error:
Traceback (most recent call last):
File "SR.py", line 45, in <module>
train(opt, Gs, Zs, reals, NoiseAmp)
File "E:\SinGAN-masterplus\SinGAN\training.py", line 34, in train
D_curr,G_curr = init_models(opt)
File "E:\SinGAN-masterplus\SinGAN\training.py", line 310, in init_models
netG.apply(models.weights_init)
File "E:\abcd\lib\site-packages\torch\nn\modules\module.py", line 293, in apply
module.apply(fn)
File "E:\abcd\lib\site-packages\torch\nn\modules\module.py", line 293, in apply
module.apply(fn)
File "E:\abcd\lib\site-packages\torch\nn\modules\module.py", line 294, in apply
fn(self)
File "E:\SinGAN-masterplus\SinGAN\models.py", line 215, in weights_init
m.weight.data.normal_(0.0, 0.02)
File "E:\abcd\lib\site-packages\torch\nn\modules\module.py", line 591, in __getattr__
type(self).__name__, name))
AttributeError: 'Conv2d' object has no attribute 'weight'
I have searched extensively and couldn't solve it, so I would like to ask you. I wish you a happy life, and I look forward to your reply!
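The traceback suggests the init function runs on a Conv2d whose weight parameter has been replaced by a spectral-norm wrapper. If the wrapper is the one from this repo, my understanding is that it deletes weight and registers weight_bar (plus weight_u and weight_v), recreating weight only during forward(); torch.nn.utils.spectral_norm instead keeps the raw tensor in weight_orig. Either initialise the layers before wrapping them, or use an init function that looks for the raw parameter, as in this hypothetical sketch:

```python
import torch
import torch.nn as nn


def weights_init(m):
    """Hypothetical init that tolerates spectral-norm wrappers.

    Checks weight_bar (assumed name used by this repo's wrapper), then
    weight_orig (torch.nn.utils.spectral_norm), then the plain weight.
    """
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        for name in ("weight_bar", "weight_orig", "weight"):
            w = getattr(m, name, None)  # getattr's default swallows the AttributeError
            if isinstance(w, nn.Parameter):
                w.data.normal_(0.0, 0.02)
                break
    elif classname.find("BatchNorm") != -1:
        # Match BatchNorm only: a "Norm" check would also catch a SpectralNorm wrapper
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
```

Usage stays the same, e.g. netG.apply(weights_init).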
How can I generate images of size 256*256? How can this be done? Thank you
I tried running the code several times, but a negative gamma always emerged, sometimes in ave_gamma_l3, sometimes in ave_gamma_l4, or in both.
Hello, thanks for the implementation. I'm trying to work out why you mention Gumbel Softmax in the trainer several times. Is this an unwanted residue of another project? Thanks
P.S.: The generator and discriminator both take a batch_size argument, but this argument is never used.
Hi,
I want to use a different dataset that is not available via PyTorch. I get errors and have tried different tricks, but nothing worked. Could you please let me know how I can add the path of my dataset so that this code works with it?
Has this implementation reproduced the results of the original paper? If so, how long does it take to train?
RuntimeError: start (1431224) + length (0) exceeds dimension size (1431244). (narrow at /opt/conda/conda-bld/pytorch_1535491974311/work/aten/src/ATen/native/TensorShape.cpp:157)
frame #0: at::Type::narrow(at::Tensor const&, long, long, long) const + 0x49 (0x7fe6365a1639 in /export/home/anaconda_install/anaconda_download/installed_conda/envs/pytorch_0_4_1/lib/python3.6/site-packages/torch/lib/libcaffe2.so)
frame #1: torch::autograd::VariableType::narrow(at::Tensor const&, long, long, long) const + 0x184 (0x7fe6382c3ae4 in /export/home/anaconda_install/anaconda_download/installed_conda/envs/pytorch_0_4_1/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
frame #2: torch::cuda::broadcast_coalesced(at::ArrayRef<at::Tensor>, at::ArrayRef, unsigned long) + 0xbc0 (0x7fe6386b7210 in /export/home/anaconda_install/anaconda_download/installed_conda/envs/pytorch_0_4_1/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
frame #3: + 0xc423cb (0x7fe6386bb3cb in /export/home/anaconda_install/anaconda_download/installed_conda/envs/pytorch_0_4_1/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
frame #4: + 0x38a5cb (0x7fe637e035cb in /export/home/anaconda_install/anaconda_download/installed_conda/envs/pytorch_0_4_1/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
frame #15: THPFunction_apply(_object*, _object*) + 0x38f (0x7fe6381e1a2f in /export/home/anaconda_install/anaconda_download/installed_conda/envs/pytorch_0_4_1/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
frame #46: __libc_start_main + 0xf5 (0x7fe655cf2c05 in /usr/lib64/libc.so.6)
frame #47: python() [0x4009e9]
Hi, I encounter the error above when adding the Att module to my own model, but it works well on a single GPU.
Does anyone know why this happens? Thanks in advance.
Do you use hinge loss or Wasserstein-GP loss to get the CelebA images?
Thanks for sharing your code with us. But I can't download any of the datasets in download.sh, so please tell me the structure of your training data, and I can make it myself. Thanks.