
self-attention-gan's People

Contributors

cbokpark, heykeetae


self-attention-gan's Issues

Tensorboard Logger missing

In trainer.py

    def build_tensorboard(self):
        from logger import Logger
        self.logger = Logger(self.log_path)

The logger module seems to be missing.

Is there a bug in Self_Attn module?

Hi,

I noticed something strange in sagan_models.py. In the original paper, matrix multiplication is used to calculate f_x * g_x and also h_x * attn_score. However, the code uses torch.mul(), which is an elementwise multiplication operation. I think you meant to use torch.matmul().
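For reference, here is a minimal sketch (my own illustration, not the repo's code) of the intended batched matrix multiplications, assuming B x C x H x W feature maps flattened to B x C' x N:

    import torch

    B, C, H, W = 2, 64, 8, 8
    f = torch.randn(B, C // 8, H * W)   # f(x), flattened to B x C' x N
    g = torch.randn(B, C // 8, H * W)   # g(x)
    h = torch.randn(B, C, H * W)        # h(x)

    # s = f(x)^T g(x)  ->  B x N x N attention map, softmax over the last dim
    attn = torch.softmax(torch.bmm(f.permute(0, 2, 1), g), dim=-1)

    # o = h(x) applied to the attention map with bmm, not elementwise torch.mul
    out = torch.bmm(h, attn.permute(0, 2, 1)).view(B, C, H, W)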

large image generation will take a lot of memory.

If I want to generate images of size 512*512, the attention module's convolution will have int(64 / 8) * (512 ** 2) = 2097152 output channels, so the parameters will take a lot of memory. Is self-attention not suitable for large image generation? How can this be solved?
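A rough back-of-the-envelope calculation (my own estimate, not from the repo) of why this blows up for a 512x512 feature map:

    imsize = 512
    n = imsize * imsize                       # 262144 spatial positions
    channels = int(64 / 8) * (imsize ** 2)    # 2097152 output channels if the conv width scales with imsize**2
    attn_entries = n * n                      # entries in a single N x N attention map
    print(channels, attn_entries * 4 / 1e9)   # roughly 275 GB for one float32 attention map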

model.py is working only for imsize=64

Hi, I have been reading the code and trying to fix this so that it works for more image sizes. I believe the minimum imsize could be around 32 if you keep the kernel sizes, with a minor change to the code (you need to add the same `if imsize == 64` branch to the forward of both models and change the generator's `__init__` structure a little). If you go lower than 32, you may hit an error about the input being smaller than the kernel size. If you go up to 256, you definitely need more layers. So there might not be a one-size-fits-all solution here. Any suggestions?

Is this a writing error: attention.permute(0,1,2)?

Hello,
in sagan_models.py:

    out = torch.bmm(proj_value, attention.permute(0,1,2))

Did you want to transpose attention here? .permute(0,1,2) makes no change. From Han Zhang's paper, Eq. (2), I think it should be .permute(0,2,1). Or did I misunderstand something?
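A quick check confirms that .permute(0, 1, 2) is the identity on a 3-D tensor, while .permute(0, 2, 1) transposes the last two dimensions:

    import torch

    attention = torch.randn(4, 16, 16)   # B x N x N
    assert torch.equal(attention.permute(0, 1, 2), attention)
    assert torch.equal(attention.permute(0, 2, 1), attention.transpose(1, 2))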

class conditional batch norm not implemented

I just want to verify that I am correct: class-conditional batch norm is not implemented in this repo, right?
In the original TensorFlow code, the network blocks incorporate a ConditionalBatchNorm operation.

I would just like a simple yes/no answer, as it could help others asking themselves the same question :)
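For context, a class-conditional batch norm layer typically looks something like this minimal sketch (my own illustration of the general technique, not code from this repo or the original TensorFlow implementation):

    import torch
    import torch.nn as nn

    class ConditionalBatchNorm2d(nn.Module):
        # parameter-free BatchNorm followed by a per-class scale and shift
        def __init__(self, num_features, num_classes):
            super().__init__()
            self.bn = nn.BatchNorm2d(num_features, affine=False)
            self.gamma = nn.Embedding(num_classes, num_features)
            self.beta = nn.Embedding(num_classes, num_features)
            nn.init.ones_(self.gamma.weight)
            nn.init.zeros_(self.beta.weight)

        def forward(self, x, y):
            out = self.bn(x)
            g = self.gamma(y).unsqueeze(-1).unsqueeze(-1)   # B x C x 1 x 1
            b = self.beta(y).unsqueeze(-1).unsqueeze(-1)
            return g * out + b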

Variable image size.

The current code does not seem to be able to handle images of various sizes. I want to make it more flexible, but how do I know what the network configuration should be for larger images?

hinge loss

I found citations 13, 16 and 30 but still do not know the exact principle behind the hinge loss.
I am confused about why the WGAN loss function is not used instead.
Is it because the hinge loss performs better than the WGAN loss?
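For reference, the hinge adversarial loss used in SAGAN-style training is usually written as below; this is a minimal sketch in my own words, not this repo's exact code:

    import torch

    def d_hinge_loss(d_real, d_fake):
        # push D(real) above +1 and D(fake) below -1
        return torch.relu(1.0 - d_real).mean() + torch.relu(1.0 + d_fake).mean()

    def g_hinge_loss(d_fake):
        # the generator simply maximizes the critic's score on fake samples
        return -d_fake.mean()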

the imsize

In your code there is an imsize factor, but in the paper there is no imsize. Does int(in_dim / 8) * (imsize ** 2) mean the number of output channels (the \hat{C}, right)?

    self.f_ = nn.Conv2d(in_dim, int(in_dim / 8) * (imsize ** 2), 1)
    self.g_ = nn.Conv2d(in_dim, int(in_dim / 8) * (imsize ** 2), 1)

Cannot Download CelebA dataset

Thanks for sharing your code with us.
When I run download.sh to download the CelebA dataset, it returns a "404" error, and I cannot open the download link in a browser either. Maybe the download link has been replaced or is no longer available?

How to use in 3D conv?

The paper and the code both apply the spectral norm constraint to the weights w of 2D convolutions; how should w be handled for 3D convolutions?
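One observation (my note, not from this repo): PyTorch's built-in spectral norm wrapper works on any module with a weight parameter, so it can be applied to a 3D convolution directly:

    import torch.nn as nn

    # spectral norm reshapes the weight to a 2-D matrix internally, so Conv3d works as well
    conv3d = nn.utils.spectral_norm(nn.Conv3d(64, 64, kernel_size=3, padding=1))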

RuntimeError: Found 0 files in subfolders of: ./data/CelebA/

File "E:\Users\Raytine\Anaconda3\lib\site-packages\torchvision\datasets\folder.py", line 79, in init
"Supported extensions are: " + ",".join(extensions)))
RuntimeError: Found 0 files in subfolders of: ./data/CelebA/
Supported extensions are: .jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif
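This error usually means that torchvision's ImageFolder found no images under the dataset root; ImageFolder expects at least one class subfolder containing the image files (for example ./data/CelebA/images/000001.jpg, an illustrative path). A quick way to check the layout, using the standard torchvision API:

    from torchvision import datasets, transforms

    # ImageFolder treats each immediate subdirectory of the root as one class
    dataset = datasets.ImageFolder('./data/CelebA', transform=transforms.ToTensor())
    print(len(dataset), dataset.classes)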

Missing tester file

Hello,

Thank you for this excellent implementation. Everything is very clear; however, main.py references both a Tester class and a qgan_trainer that appear to be missing from the repository. Please advise.

    if config.train:
        if config.model == 'sagan':
            trainer = Trainer(data_loader.loader(), config)
        elif config.model == 'qgan':
            trainer = qgan_trainer(data_loader.loader(), config)
        trainer.train()
    else:
        tester = Tester(data_loader.loader(), config)
        tester.test()

about out of memory

Hi,
my GPU has 11 GB of memory, but during training I always get: RuntimeError: cuda runtime error (2) : out of memory at /pytorch/torch/lib/THC/generic/THCStorage.cu:58
The error is raised from the line: result = self.forward(*input, **kwargs)

Softmax

Hello,
Why is Softmax computed on columns in your code (dim=-1) whereas it is applied on rows in the original paper (dim=1)?
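For a B x N x N attention tensor, softmax(dim=-1) normalizes each row while softmax(dim=1) normalizes each column. A quick way to see the difference (my own check, not code from the repo):

    import torch

    s = torch.randn(2, 5, 5)              # B x N x N scores
    rows = torch.softmax(s, dim=-1)       # each row sums to 1
    cols = torch.softmax(s, dim=1)        # each column sums to 1
    print(rows.sum(dim=-1))               # all ones
    print(cols.sum(dim=1))                # all ones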

I want to change 'imsize'.

I wanted to change the 'imsize' parameter, so I changed it from 64 to 128, but I got the following error message:
AttributeError: 'Discriminator' object has no attribute 'l4'

Could you tell me a solution?
What should I change in the code?

FID and IS

Has anyone calculated FID and IS for images generated on the CelebA and LSUN datasets?

about the gamma parameter

In your code, the shape of gamma is [batchsize, 1, 1, 1]; I think the shape should be [1].
Besides, the attention score you compute seems to differ from Han Zhang's paper. Did you calculate the attention score using the same equation as Eq. (1) in the paper?
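For reference, a single learnable scalar gamma initialized to zero, as described in the paper, can be declared like this (an illustrative sketch, not necessarily this repo's exact code):

    import torch
    import torch.nn as nn

    gamma = nn.Parameter(torch.zeros(1))   # shape [1], shared across the whole batch
    # out = gamma * attention_output + x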

About the attention map

Hi, @heykeetae
I have read the paper and seen the attention maps shown there.
But how can I visualize the attention map?
In sagan_models.py there is a B x N x N tensor called "attention".
How can I make use of this tensor?
Could you please give me some advice?
Many thanks!
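One possible approach (my own sketch, assuming the B x N x N "attention" tensor with N = H * W): pick one query position, reshape its N attention weights back to H x W, and save the result as an image:

    import torch
    from torchvision.utils import save_image

    def save_attention_map(attention, H, W, query_idx=0, path='attn.png'):
        # attention: B x N x N with N = H * W; take the weights for one query position
        attn_map = attention[0, query_idx].view(1, 1, H, W)
        attn_map = attn_map / (attn_map.max() + 1e-8)   # normalize to [0, 1] for display
        save_image(attn_map, path)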

Spectral norm on Generator

I would like to know the reasoning behind where spectral norm is applied in G. Specifically, why is spectral norm not applied to the last layer of G (the layer that projects to RGB)? Also, is there any reason not to apply spectral norm to the self-attention convs as well?

Thanks

permute twice?

In the self-attention module (Q, K, V), Q * K already uses a permute to compute the attention weights, so why is the attention weight permuted again before being multiplied by V?

About negative gamma

Hi, thanks for your great efforts on this project.
I have a question about the "gamma" parameter.
Is it natural for gamma to be trained to a negative value?
Does this result mean that attention has a negative effect?

Added tensorboard logger

You can add tensorboard logging to train.py by inserting the following code at line 180. It still needs some modification, because l4 does not exist if imsize is less than 64 and the network fails if imsize is greater than 64; once the larger imsize case is fixed I will open a pull request. build_tensorboard should also be changed as shown below. Hope this helps!

    def build_tensorboard(self):
        # note: requires `import os` and `import shutil` at the top of trainer.py
        from logger import Logger
        if os.path.exists(self.log_path):
            shutil.rmtree(self.log_path)   # start from a clean log directory
        os.makedirs(self.log_path)
        self.logger = Logger(self.log_path)

Insert at line 180 in train.py
    # Print out log info
    if (step + 1) % self.log_step == 0:
        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, "
              " ave_gamma_l3: {:.4f}, ave_gamma_l4: {:.4f}".
              format(elapsed, step + 1, self.total_step, (step + 1),
                     self.total_step, d_loss_real.item(),
                     self.G.attn1.gamma.mean().item(), self.G.attn2.gamma.mean().item()))

        # (1) Log values of the losses (scalars)
        info = {
            'd_loss_real': d_loss_real.item(),
            'd_loss_fake': d_loss_fake.item(),
            'd_loss': d_loss.item(),
            'g_loss_fake': g_loss_fake.item(),
            'ave_gamma_l3': self.G.attn1.gamma.mean().item(),
            'ave_gamma_l4': self.G.attn2.gamma.mean().item(),
        }

        for tag, value in info.items():
            self.logger.scalar_summary(tag, value, step + 1)

    # Sample images / Save and log
    if (step + 1) % self.sample_step == 0:

        # (2) Log values and gradients of the parameters (histogram)
        for net, name in zip([self.G, self.D], ['G_', 'D_']):
            for tag, value in net.named_parameters():
                tag = name + tag.replace('.', '/')
                self.logger.histo_summary(tag, self.to_np(value), step + 1)

        # (3) Log the images (display_vars, self.to_np and denorm are helpers
        # assumed to be defined elsewhere in the trainer)
        info = {
            'fake_images': self.to_np(fake_images.view(*display_vars)[:10, :, :, :]),
            'real_images': self.to_np(real_images.view(*display_vars)[:10, :, :, :]),
        }

        fake_images, _, _ = self.G(fixed_z)
        save_image(denorm(fake_images.data),
                   os.path.join(self.sample_path, '{}_fake.png'.format(step + 1)))

        # log the samples generated from fixed_z (fake_images, not real_images)
        info['fixed_fake_images'] = self.to_np(denorm(fake_images.data).view(*display_vars)[:10, :, :, :])

        for tag, image in info.items():
            self.logger.image_summary(tag, image, step + 1)

About:AttributeError: 'Conv2d' object has no attribute 'weight'

Hello! Thank you for your contribution to generative adversarial network research and for sharing your code! I am from China, and it is Chinese New Year right now, so I wish you a happy Chinese New Year! I am very interested in your paper. When I tried to add spectral normalization to my new network, the program gave the following error:
    Traceback (most recent call last):
      File "SR.py", line 45, in <module>
        train(opt, Gs, Zs, reals, NoiseAmp)
      File "E:\SinGAN-masterplus\SinGAN\training.py", line 34, in train
        D_curr,G_curr = init_models(opt)
      File "E:\SinGAN-masterplus\SinGAN\training.py", line 310, in init_models
        netG.apply(models.weights_init)
      File "E:\abcd\lib\site-packages\torch\nn\modules\module.py", line 293, in apply
        module.apply(fn)
      File "E:\abcd\lib\site-packages\torch\nn\modules\module.py", line 293, in apply
        module.apply(fn)
      File "E:\abcd\lib\site-packages\torch\nn\modules\module.py", line 294, in apply
        fn(self)
      File "E:\SinGAN-masterplus\SinGAN\models.py", line 215, in weights_init
        m.weight.data.normal_(0.0, 0.02)
      File "E:\abcd\lib\site-packages\torch\nn\modules\module.py", line 591, in __getattr__
        type(self).__name__, name))
    AttributeError: 'Conv2d' object has no attribute 'weight'

I have searched a lot of information but couldn't solve it, so I wanted to ask you. I wish you a happy life and look forward to your reply!
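A possible explanation and workaround (my own sketch, not an official fix): torch.nn.utils.spectral_norm re-registers the wrapped parameter under the name weight_orig, so an init function that unconditionally accesses m.weight can raise this AttributeError on spectral-norm-wrapped convolutions. Guarding the access avoids it:

    import torch.nn as nn

    def weights_init(m):
        # hypothetical replacement for the failing init function
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            if hasattr(m, 'weight_orig'):                 # spectral-norm wrapped layer
                nn.init.normal_(m.weight_orig, 0.0, 0.02)
            elif getattr(m, 'weight', None) is not None:  # plain conv layer
                nn.init.normal_(m.weight, 0.0, 0.02)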

256*256

How can I generate images of size 256*256? How should this be implemented? Thank you.

Gumbel Softmax?

Hello, thanks for the implementation. I'm trying to work out why you mention Gumbel Softmax in the trainer several times. Is this an unwanted leftover from another project? Thanks

P.S.: The generator's and the discriminator's constructors both take a batch_size argument, but it is never used.

new data set

Hi,
I want to use a different dataset that is not available via PyTorch. I get an error and have tried different tricks, but nothing worked. Could you please let me know how I can add the path of my dataset so that this code works?

Weird error when using multiple GPUs.

    RuntimeError: start (1431224) + length (0) exceeds dimension size (1431244). (narrow at /opt/conda/conda-bld/pytorch_1535491974311/work/aten/src/ATen/native/TensorShape.cpp:157)
    frame #0: at::Type::narrow(at::Tensor const&, long, long, long) const + 0x49 (0x7fe6365a1639 in /export/home/anaconda_install/anaconda_download/installed_conda/envs/pytorch_0_4_1/lib/python3.6/site-packages/torch/lib/libcaffe2.so)
    frame #1: torch::autograd::VariableType::narrow(at::Tensor const&, long, long, long) const + 0x184 (0x7fe6382c3ae4 in /export/home/anaconda_install/anaconda_download/installed_conda/envs/pytorch_0_4_1/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
    frame #2: torch::cuda::broadcast_coalesced(at::ArrayRef<at::Tensor>, at::ArrayRef, unsigned long) + 0xbc0 (0x7fe6386b7210 in /export/home/anaconda_install/anaconda_download/installed_conda/envs/pytorch_0_4_1/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
    frame #3: + 0xc423cb (0x7fe6386bb3cb in /export/home/anaconda_install/anaconda_download/installed_conda/envs/pytorch_0_4_1/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
    frame #4: + 0x38a5cb (0x7fe637e035cb in /export/home/anaconda_install/anaconda_download/installed_conda/envs/pytorch_0_4_1/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)

    frame #15: THPFunction_apply(_object*, _object*) + 0x38f (0x7fe6381e1a2f in /export/home/anaconda_install/anaconda_download/installed_conda/envs/pytorch_0_4_1/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so)
    frame #46: __libc_start_main + 0xf5 (0x7fe655cf2c05 in /usr/lib64/libc.so.6)
    frame #47: python() [0x4009e9]


Hi, after adding the attention module to my own model, I encounter the error above. It works fine on a single GPU, though.

Does anyone know why this happens? Thanks in advance.

dataset structure

Thanks for sharing your code with us. But I can't download any of the datasets from download.sh, so could you please tell me the structure of your training data? Then I can prepare it myself. Thanks.
