
funit's Introduction

License CC BY-NC-SA 4.0 Python 3.7

FUNIT: Few-Shot Unsupervised Image-to-Image Translation

animal swap gif

Few-shot Unsupervised Image-to-Image Translation
Ming-Yu Liu, Xun Huang, Arun Mallya, Tero Karras, Timo Aila, Jaakko Lehtinen, and Jan Kautz.
In arXiv 2019.

Copyright (C) 2019 NVIDIA Corporation.

All rights reserved. Licensed under the CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International)

The code is released for academic research use only. For commercial use, please contact [email protected].

For press and other inquiries, please contact Hector Marinez

Installation

  • Clone this repo git clone https://github.com/NVlabs/FUNIT.git
  • Install CUDA10.0+
  • Install cuDNN7.5
  • Install Anaconda3
  • Install required python packages
    • conda install -y pytorch torchvision cudatoolkit=10.0 -c pytorch
    • conda install -y -c anaconda pip
    • pip install pyyaml tensorboardX
    • conda install -y -c menpo opencv3

To reproduce the results reported in the paper, you would need an NVIDIA DGX1 machine with 8 V100 GPUs.

Hardware Requirement

To reproduce the experiment results reported in our ICCV paper, you would need an NVIDIA DGX1 machine with 8 V100 32GB GPUs. The training uses all 8 GPUs and takes almost all of the GPU memory. It takes about 2 weeks to finish the training.

Dataset Preparation

Animal Face Dataset

We are releasing the Animal Face dataset. If you use this dataset in your publication, please cite the FUNIT paper.

cd dataset
wget http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar
tar xvf ILSVRC2012_img_train.tar
  • The training images should be in datasets/ILSVRC/Data/CLS-LOC/train. Now, extract the animal face images by running
python tools/extract_animalfaces.py datasets/ILSVRC/Data/CLS-LOC/train --output_folder datasets/animals --coor_file datasets/animalface_coordinates.txt
  • The animal face images should be in datasets/animals. Note there are 149 folders. Each folder contains images of one animal kind. The number of images of the dataset is 117,484.
  • We use 119 animal kinds for training and the remaining 30 animal kinds for evaluation.

Training

Once the animal face dataset is prepared, you can train an animal face translation model by running:

python train.py --config configs/funit_animals.yaml --multigpus

The training results including the checkpoints and intermediate results will be stored in outputs/funit_animals.

For a custom dataset, you would need to write a new configuration file. Please create one based on the example config file.
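Before writing your own config, it can help to dump the example one to see which fields it defines. A quick sketch, assuming get_config from utils returns a plain dictionary (it is used that way elsewhere in this repo):

from utils import get_config

cfg = get_config('configs/funit_animals.yaml')
for key, value in cfg.items():
    print(key, ':', value)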

Testing pretrained model

To test the pretrained model, first create a folder pretrained under the root folder. Then, download the pretrained models via the link and save them in pretrained. Untar the file: tar xvf pretrained.tar.gz.

Now, we can test the translation

python test_k_shot.py --config configs/funit_animals.yaml --ckpt pretrained/animal149_gen.pt --input images/input_content.jpg --class_image_folder images/n02138411 --output images/output.jpg

The above command will translate the input image

images/input_content.jpg

input image

to an output meerkat image

output image

by using a set of 5 example meerkat images

Citation

If you use this code for your research, please cite our papers.

@inproceedings{liu2019few,
  title={Few-shot Unsupervised Image-to-Image Translation},
  author={Ming-Yu Liu and Xun Huang and Arun Mallya and Tero Karras and Timo Aila and Jaakko Lehtinen and Jan Kautz},
  booktitle={arxiv},
  year={2019}
}

funit's People

Contributors

arunmallya, mingyuliutw


funit's Issues

network configuration advice.

If I want to train with 4096 persons and 1 million faces, should I change something in the default config? Should I increase the nf dims?

No train file

Looks like awesome work, seems you are missing the training scripts at the moment? Will you also be releasing the pre-trained models?

Cheers

Discriminator architecture question

Hi, I read your paper, but I'm not sure about the output of the discriminator.

As in your paper:

It consists of one convolution layer followed by 4 AvgPool2x2 layers; if the input image size is 128x128, then the output should be 4x4 with |S| channels?

So the output of the discriminator has size (bs, |S|, 4, 4)?

About tools/extract_animal_faces.py

Would the file 'tools/extract_animal_faces.py' be released soon? I'm so interested by this dataset and eager to use this in some other work.

do we need to separate content and style image during training?

It seems that during training, we use the same dataloader for content and style image generation. Say we have class A and class B; there is a chance that we get a content image from A and a style image from A, too. Is that reasonable?

Also, how should "num_classes" in the discriminator be defined? Is it equal to the number of classes in the whole training set? What if I use different classes in the content data loader and the style data loader?

G acc: 0.0000

Elapsed time in update: 1.213263
Iteration: 00065911/00100000
D acc: 0.9995 G acc: 0.0000
Elapsed time in update: 1.134419
Iteration: 00065912/00100000
D acc: 0.9999 G acc: 0.0000
Elapsed time in update: 1.168156
Iteration: 00065913/00100000
D acc: 1.0000 G acc: 0.0000
Elapsed time in update: 1.190268
Iteration: 00065914/00100000
D acc: 0.9999 G acc: 0.0000
Elapsed time in update: 1.129889
Iteration: 00065915/00100000
D acc: 1.0000 G acc: 0.0000
Elapsed time in update: 1.128156
Iteration: 00065916/00100000
D acc: 0.9999 G acc: 0.0000
Elapsed time in update: 1.135906
Iteration: 00065917/00100000
D acc: 1.0000 G acc: 0.0000
Elapsed time in update: 1.108964
Iteration: 00065918/00100000

G acc is 0, is this normal?

how to train fully connected layers?

When calculating the affine parameters, z_y is fed into three fully connected layers, which then output the mean and standard deviation. Why do we do this? How are the fully connected layers trained?
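For reference, a minimal sketch of such a mapping in PyTorch (dimensions and names are hypothetical, not the repo's):

import torch
import torch.nn as nn

class ClassCodeToAdaIN(nn.Module):
    # Map the class latent code z_y to AdaIN means and standard deviations.
    def __init__(self, latent_dim=64, hidden_dim=256, num_adain_params=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, num_adain_params))  # concatenated [means, stds]

    def forward(self, z_y):
        return self.net(z_y)

These layers have no separate objective: the predicted means and standard deviations parameterize the decoder's AdaIN layers, so the gradients of the adversarial, reconstruction, and feature-matching losses flow back through the AdaIN layers into the linear layers, and they are trained jointly with the rest of G.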

Fine tuning for a new dataset

Hi all!
Thanks for open sourcing the research work at NVIDIA. I wanted to know what the procedure is to fine-tune the existing model on a different dataset.

assign_adain_params function

Hi,

I hope you are doing well. I am just wondering why we use the .view(-1) function to fill the weight and bias of the AdaIN layers. This function seems to combine all samples in one batch with each other. I would be grateful if you could correct my misunderstanding.

Thanks
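For what it's worth, the .view(-1) does not mix samples: the AdaIN forward pass reshapes x from (b, c, H, W) to (1, b*c, H, W), so the affine parameters must be flattened to length b*c to line up with that layout. A sketch of the MUNIT-style routine (the repo's version may differ in details):

import torch

def assign_adain_params(adain_params, model):
    # adain_params: (batch, total_params) tensor produced by the MLP.
    for m in model.modules():
        if m.__class__.__name__ == 'AdaptiveInstanceNorm2d':
            mean = adain_params[:, :m.num_features]
            std = adain_params[:, m.num_features:2 * m.num_features]
            # view(-1) lays the (batch, channels) matrix out as
            # [sample0_ch0 .. sample0_chC, sample1_ch0 ..], matching the
            # (1, b*c, H, W) reshape done inside the AdaIN forward pass.
            m.bias = mean.contiguous().view(-1)
            m.weight = std.contiguous().view(-1)
            if adain_params.size(1) > 2 * m.num_features:
                adain_params = adain_params[:, 2 * m.num_features:]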

GauGAN Page scrolls when painting using mobile device

I am trying to use this in my lessons during lockdown as a primary school teacher. Most students don't have computers at home (inner-city school), but their parents all have nice phones. When using GauGAN on any mobile device (Android/iPhone/iPad), the page keeps scrolling around when painting with a finger, or it is zoomed in too far to see the download button for the output on the right.
Does anyone know if a mobile version of the page will be made soon?
Thank you.

subpixel upscaler?

Would this be better:

x = Conv2D(4 * (nf // 2**(i+1)), kernel_size=3, strides=1, padding='same')(x)
x = SubpixelUpscaler()(x)

than the current

x = UpSampling2D()(x)
x = Conv2D(nf // 2**(i+1), kernel_size=5, strides=1, padding='same')(x)

?
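In PyTorch terms (the snippets above look Keras-like), the two options being compared correspond roughly to the following sketch; in_ch and out_ch are placeholders:

import torch.nn as nn

def upsample_then_conv(in_ch, out_ch):
    # the "current" approach: nearest-neighbor upsample, then convolution
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=1, padding=2))

def subpixel_upsample(in_ch, out_ch):
    # sub-pixel alternative: convolve to 4x channels, then PixelShuffle(2)
    return nn.Sequential(
        nn.Conv2d(in_ch, 4 * out_ch, kernel_size=3, stride=1, padding=1),
        nn.PixelShuffle(2))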

train

Hello, where is train.py? Have you uploaded all the files? Thank you.


Possibility to translate target class image to source class image?

Let's say the model was trained to translate images between cats, dogs, tigers, and lions. Given some pig images, the model can translate a cat content image into a pig.

I wonder if it's possible to do the opposite. Given some pig images that the model has never seen, is there a way to translate a pig into a cat?

In short, is there a way to translate unseen class to seen class?

Why not use the reconstruction loss of style?

I find that FUNIT uses two different losses to make the translated image have the same class as the guide image: $L_{adv}$ and $L_F$. Why don't you use a reconstruction of the style (or class code) to replace $L_F$, as in MUNIT?

AdaptiveInstanceNorm2D issues...

In your code, AdaptiveInstanceNorm2d normalizes a tensor across the batch and uses running_mean and running_var, so should it be called AdaptiveBatchNorm2d?

seems not save images when write_html

In the function write_html:
_write_row(html_file, it, '%s/gen_train_current.jpg' % img_dir, all_size)
It seems the image named gen_train_current.jpg is not saved, so I cannot browse the HTML successfully.

Why batch_norm is used inside AdaptiveInstanceNorm2d?

Hi! I have a doubt about the code in blocks.py (L188-L192), as shown below:

class AdaptiveInstanceNorm2d(nn.Module):
        ...
        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
        out = F.batch_norm(
            x_reshaped, running_mean, running_var, self.weight, self.bias,
            True, self.momentum, self.eps)
        return out.view(b, c, *x.size()[2:])

This is the definition of adaptive instance normalization. It looks like you reshape a batch of images into a "bigger" single-sample batch, apply "batch normalization" to it, and finally recover it back to (batch, channel, height, width). But whether or not it is reshaped into a single batch, the features of each channel seem to be normalized across the whole batch. I am wondering how this can be an instance normalization.

I believe the code is perfectly correct, but please explain the tricks that were used here, thanks in advance!
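The trick is that folding the batch into the channel dimension makes batch norm compute its statistics over (1, H, W) for each of the b*c pseudo-channels, i.e. per sample and per channel, which is exactly instance normalization. A minimal check (illustrative, not from the repo):

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 8, 8)                       # (batch, channels, H, W)
b, c = x.shape[:2]

ref = F.instance_norm(x)                          # per-sample, per-channel stats

x_reshaped = x.contiguous().view(1, b * c, *x.shape[2:])
out = F.batch_norm(x_reshaped, None, None, training=True)
out = out.view(b, c, *x.shape[2:])

print(torch.allclose(ref, out, atol=1e-5))        # True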

Long training time

Why does it take such a long time (two weeks) to train the FUNIT model? How can it be accelerated?

NotImplementedError: Only 3D, 4D, 5D padding with non-constant padding are supported for now

Traceback (most recent call last):
  File "/home/hu/disk1/EXP_Part2/FUNIT/train.py", line 112, in <module>
    train()
  File "/home/hu/disk1/EXP_Part2/FUNIT/train.py", line 91, in train
    d_acc = trainer.dis_update(mc, mc, config)
  File "/home/hu/disk1/EXP_Part2/FUNIT/trainer.py", line 65, in dis_update
    al, lfa, lre, reg, acc = self.model(co_data, cl_data, hp, 'dis_update')
  File "/home/hu/anaconda3/envs/hhh/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/hu/anaconda3/envs/hhh/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/hu/anaconda3/envs/hhh/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/hu/disk1/EXP_Part2/FUNIT/funit_model.py", line 53, in forward
    l_real_pre, acc_r, resp_r = self.dis.calc_dis_real_loss(xb, lb)
  File "/home/hu/disk1/EXP_Part2/FUNIT/networks.py", line 81, in calc_dis_real_loss
    resp_real, gan_feat = self.forward(input_real, input_label)
  File "/home/hu/disk1/EXP_Part2/FUNIT/networks.py", line 65, in forward
    feat = self.cnn_f(x)
  File "/home/hu/anaconda3/envs/hhh/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/hu/anaconda3/envs/hhh/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/home/hu/anaconda3/envs/hhh/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/hu/disk1/EXP_Part2/FUNIT/blocks.py", line 163, in forward
    x = self.conv(self.pad(x))
  File "/home/hu/anaconda3/envs/hhh/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/hu/anaconda3/envs/hhh/lib/python3.7/site-packages/torch/nn/modules/padding.py", line 171, in forward
    return F.pad(input, self.padding, 'reflect')
  File "/home/hu/anaconda3/envs/hhh/lib/python3.7/site-packages/torch/nn/functional.py", line 2840, in pad
    raise NotImplementedError("Only 3D, 4D, 5D padding with non-constant padding are supported for now")
NotImplementedError: Only 3D, 4D, 5D padding with non-constant padding are supported for now

any ideas? thanks!

Why are you using this format D ?

Hi, your discriminator design is unique, so I want to know: what is the meaning of the discriminator output having num_classes channels?

And why do you use a projection discriminator? Did you test other kinds of D?

How do you guarantee that the identity is preserved?

In your paper, you mention that the discriminator solves a series of "binary classification tasks determining whether an input image is a real image of the source class or a translation output coming from G. As there are |S| source classes, D produces |S| outputs" But you "do not penalize D for not predicting false for images of other classes (S{c_x})," which means D can predict any number of classes as long as class c_x is predicted false/positive (depending on whether that is a real or fake sample).

Then there is the feature matching loss, which forces the identity feature from G(x, y) to be as similar to Df(y) as possible. All cool until here.

Now, if D can output any number of classes as positives, how do you make sure that the last layer of D is not just broadcasting the same value to all classes, meaning it just determines whether the sample is real or fake, independent of the class? That would make the identity features of class c_y almost the same as those of class c_x, with minimal feature matching loss no matter whether the output is from class c_x or class c_y.

Any clarification would be appreciated :)
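For reference, the per-class selection described in the paper can be sketched as follows (hypothetical code, not the repo's): only the response map of the sample's own class enters the loss, which is why the other |S|-1 outputs are never penalized, and why D can in principle assign high scores to several classes at once.

import torch
import torch.nn.functional as F

def class_response(resp, labels):
    # resp: (B, |S|, H, W) patch responses; labels: (B,) class indices.
    idx = torch.arange(resp.size(0), device=resp.device)
    return resp[idx, labels]                      # (B, H, W), own class only

def dis_real_hinge(resp, labels):
    return F.relu(1.0 - class_response(resp, labels)).mean()

def dis_fake_hinge(resp, labels):
    return F.relu(1.0 + class_response(resp, labels)).mean()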

The purpose of gen_test object in FUNITModel

The gen_test object in FUNITModel seems to be a clone of the whole model except the discriminator. The gen_test object is only used in generating testing results, and it seems redundant because we could run the gen object in evaluation mode. Why make a clone of the gen object? Was it designed for the multi-GPU environment?

self.gen_test = copy.deepcopy(self.gen)
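One common reason for keeping such a copy, beyond multi-GPU concerns, is to maintain an exponential moving average of the generator weights for evaluation, which cannot be achieved by just switching gen to eval mode. A minimal sketch of such an update (check trainer.py for what the repo actually does):

import torch

def update_average(model_tgt, model_src, beta=0.999):
    # smooth model_src's parameters into model_tgt in place
    with torch.no_grad():
        for p_tgt, p_src in zip(model_tgt.parameters(), model_src.parameters()):
            p_tgt.copy_(beta * p_tgt + (1.0 - beta) * p_src)

# e.g. after every generator step: update_average(model.gen_test, model.gen)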

AssertionError: 3D tensors expect 2 values for padding

"""
Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license
(https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import copy

import torch
import torch.nn as nn

from networks import FewShotGen, GPPatchMcResDis


def recon_criterion(predict, target):
    return torch.mean(torch.abs(predict - target))


class FUNITModel(nn.Module):
    def __init__(self, hp):
        super(FUNITModel, self).__init__()
        self.gen = FewShotGen(hp['gen'])
        self.dis = GPPatchMcResDis(hp['dis'])
        self.gen_test = copy.deepcopy(self.gen)

    def forward(self, co_data, cl_data, hp, mode):
        xa = co_data[0].cuda()
        la = co_data[1].cuda()
        xb = cl_data[0].cuda()
        lb = cl_data[1].cuda()
        if mode == 'gen_update':
            c_xa = self.gen.enc_content(xa)
            s_xa = self.gen.enc_class_model(xa)
            s_xb = self.gen.enc_class_model(xb)
            xt = self.gen.decode(c_xa, s_xb)  # translation
            xr = self.gen.decode(c_xa, s_xa)  # reconstruction
            l_adv_t, gacc_t, xt_gan_feat = self.dis.calc_gen_loss(xt, lb)
            l_adv_r, gacc_r, xr_gan_feat = self.dis.calc_gen_loss(xr, la)
            _, xb_gan_feat = self.dis(xb, lb)
            _, xa_gan_feat = self.dis(xa, la)
            l_c_rec = recon_criterion(xr_gan_feat.mean(3).mean(2),
                                      xa_gan_feat.mean(3).mean(2))
            l_m_rec = recon_criterion(xt_gan_feat.mean(3).mean(2),
                                      xb_gan_feat.mean(3).mean(2))
            l_x_rec = recon_criterion(xr, xa)
            l_adv = 0.5 * (l_adv_t + l_adv_r)
            acc = 0.5 * (gacc_t + gacc_r)
            l_total = (hp['gan_w'] * l_adv + hp['r_w'] * l_x_rec + hp[
                'fm_w'] * (l_c_rec + l_m_rec))
            l_total.backward()
            return l_total, l_adv, l_x_rec, l_c_rec, l_m_rec, acc
        elif mode == 'dis_update':
            xb.requires_grad_()
            l_real_pre, acc_r, resp_r = self.dis.calc_dis_real_loss(xb, lb)
            l_real = hp['gan_w'] * l_real_pre
            l_real.backward(retain_graph=True)
            l_reg_pre = self.dis.calc_grad2(resp_r, xb)
            l_reg = 10 * l_reg_pre
            l_reg.backward()
            with torch.no_grad():
                c_xa = self.gen.enc_content(xa)
                s_xb = self.gen.enc_class_model(xb)
                xt = self.gen.decode(c_xa, s_xb)
            l_fake_p, acc_f, resp_f = self.dis.calc_dis_fake_loss(xt.detach(),
                                                                  lb)
            l_fake = hp['gan_w'] * l_fake_p
            l_fake.backward()
            l_total = l_fake + l_real + l_reg
            acc = 0.5 * (acc_f + acc_r)
            return l_total, l_fake_p, l_real_pre, l_reg_pre, acc
        else:
            assert 0, 'Not support operation'

    def test(self, co_data, cl_data):
        self.eval()
        self.gen.eval()
        self.gen_test.eval()
        xa = co_data[0]
        xb = cl_data[0]
        c_xa_current = self.gen.enc_content(xa)
        s_xa_current = self.gen.enc_class_model(xa)
        s_xb_current = self.gen.enc_class_model(xb)
        xt_current = self.gen.decode(c_xa_current, s_xb_current)
        xr_current = self.gen.decode(c_xa_current, s_xa_current)
        c_xa = self.gen_test.enc_content(xa)
        s_xa = self.gen_test.enc_class_model(xa)
        s_xb = self.gen_test.enc_class_model(xb)
        xt = self.gen_test.decode(c_xa, s_xb)
        xr = self.gen_test.decode(c_xa, s_xa)
        self.train()
        return xa, xr_current, xt_current, xb, xr, xt

    def translate_k_shot(self, co_data, cl_data, k):
        self.eval()
        xa = co_data[0].cuda()
        xb = cl_data[0].cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        if k == 1:
            c_xa_current = self.gen_test.enc_content(xa)
            s_xb_current = self.gen_test.enc_class_model(xb)
            xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        else:
            s_xb_current_before = self.gen_test.enc_class_model(xb)
            s_xb_current_after = s_xb_current_before.squeeze(-1).permute(1,
                                                                         2,
                                                                         0)
            s_xb_current_pool = torch.nn.functional.avg_pool1d(
                s_xb_current_after, k)
            s_xb_current = s_xb_current_pool.permute(2, 0, 1).unsqueeze(-1)
            xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        return xt_current

    def compute_k_style(self, style_batch, k):
        self.eval()
        style_batch = style_batch.cuda()
        s_xb_before = self.gen_test.enc_class_model(style_batch)
        s_xb_after = s_xb_before.squeeze(-1).permute(1, 2, 0)
        s_xb_pool = torch.nn.functional.avg_pool1d(s_xb_after, k)
        s_xb = s_xb_pool.permute(2, 0, 1).unsqueeze(-1)
        return s_xb

    def translate_simple(self, content_image, class_code):
        self.eval()
        xa = content_image.cuda()
        s_xb_current = class_code.cuda()
        c_xa_current = self.gen_test.enc_content(xa)
        xt_current = self.gen_test.decode(c_xa_current, s_xb_current)
        return xt_current

if __name__ == "__main__":
    from utils import get_config, get_train_loaders, make_result_folders
    config = get_config('configs/funit_animals_test.yaml')
    model = FUNITModel(config)
    # print(model.gen)

    t = torch.randn((2,1,150,160))
    xa, xr_current, xt_current, xb, xr, xt = model.test(t, t)
    print(xa.shape, xr_current.shape, xt_current.shape, xb.shape, xr.shape, xt.shape)


 

The file is funit_model.py. Running it on the test data t = torch.randn((2,1,150,160)) gives the errors:

(hhh) hu@hu-D520MT-K:~/disk1/EXP_Part2/FUNIT$ /home/hu/anaconda3/envs/hhh/bin/python /home/hu/disk1/EXP_Part2/FUNIT/funit_model.py
Traceback (most recent call last):
  File "/home/hu/disk1/EXP_Part2/FUNIT/funit_model.py", line 136, in <module>
    xa, xr_current, xt_current, xb, xr, xt = model.test(t, t)
  File "/home/hu/disk1/EXP_Part2/FUNIT/funit_model.py", line 79, in test
    c_xa_current = self.gen.enc_content(xa)
  File "/home/hu/anaconda3/envs/hhh/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/hu/disk1/EXP_Part2/FUNIT/networks.py", line 225, in forward
    return self.model(x)
  File "/home/hu/anaconda3/envs/hhh/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/hu/anaconda3/envs/hhh/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/home/hu/anaconda3/envs/hhh/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/hu/disk1/EXP_Part2/FUNIT/blocks.py", line 163, in forward
    x = self.conv(self.pad(x))
  File "/home/hu/anaconda3/envs/hhh/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/hu/anaconda3/envs/hhh/lib/python3.7/site-packages/torch/nn/modules/padding.py", line 171, in forward
    return F.pad(input, self.padding, 'reflect')
  File "/home/hu/anaconda3/envs/hhh/lib/python3.7/site-packages/torch/nn/functional.py", line 2803, in pad
    assert len(pad) == 2, '3D tensors expect 2 values for padding'
AssertionError: 3D tensors expect 2 values for padding

I don't know the reason; could anyone help me, please?
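One likely cause, reading the script above: FUNITModel.test expects co_data and cl_data to be (images, labels) pairs, so passing the bare tensor t makes co_data[0] equal to t[0], a 3D tensor, which is exactly what the reflection padding rejects. A sketch of a well-formed call (shapes are illustrative; 3-channel inputs assumed):

import torch

images = torch.randn(2, 3, 128, 128)           # (batch, RGB, H, W)
labels = torch.zeros(2, dtype=torch.long)      # class indices
xa, xr_current, xt_current, xb, xr, xt = model.test((images, labels), (images, labels))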

GANimal demo is broken when uploading an image

Hello! Not sure if this is the correct place to report this. In any case, I cannot use the GANimal Demo on http://imaginaire.cc/ganimal/ as when I click on Upload after selecting an image, it throws an error:

Uncaught ReferenceError: loadImage is not defined
    onclick http://imaginaire.cc/ganimal/:1

I have the same problem on Firefox (96.0) and Google Chrome (95.0.4638.54)

Seek the AnimalFaces Dataset

Hello!
I'm interested in your experiments on the AnimalFaces dataset, but I didn't find a standalone AnimalFaces release anywhere public, and the cost of downloading the original ImageNet is large for me. So I wonder if you have uploaded the dataset to some online location? I believe it would be very helpful if it has already been uploaded.
Best!

Dataset download link is broken

Hi,
I couldn't download the dataset using the commands
cd dataset
wget http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar
tar xvf ILSVRC2012_img_train.tar

Is there an alternative link I can use?

Thanks

Finetune on a new sample

Thanks for the code and excellent work. I have one question about the few-shot step. After training, do we need to fine-tune the AdaIN layers on new samples? I found that you just calculate the latent code from the new samples and test without any fine-tuning. Why do we say this is few-shot? Thanks in advance!

When will tool.py be shared?

Hi,
Great job. Can you release tool.py first? I need this dataset and will cite your paper.

Pretrained model on other datasets

Hi,
Thank you for sharing your work and code!
Could you share your pretrained models on datasets other than Animal Faces? Specifically, the pretrained model on the Birds dataset?

Unable to download the dataset - Error 404

Hi,

I am attempting to download the dataset used for FUNIT, however I get the error below when attempting to wget:

$ wget http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar
--2021-03-08 17:45:41--  http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar
Resolving www.image-net.org (www.image-net.org)...
Connecting to www.image-net.org (www.image-net.org) ... connected.
HTTP request sent, awaiting response... 404 Not Found
2021-03-08 17:45:42 ERROR 404: Not Found.

Could you please advise on this issue?

Thank You!

BatchNorm layers cause training error when track_running_stats=True with DistributedDataParallel

When using DDP (pytorch 12.1), some of my batch norm layers cause the training to fail due to an in-place operation, with the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: torch.cuda.FloatTensor [65]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

The operation that failed was simply:

class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2, **bn_params)
        self.act = nn.SiLU(inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)    # <------------------------------------------------------------- this fails
        x = self.act(x)
        return x

On a whim I tried passing:
self.bn = nn.BatchNorm2d(c2, track_running_stats=False, **bn_params)
to all my batch norm layers and the training ran, but of course this is not a viable solution.

For the record, I also tried cloning x and setting nn.SiLU(inplace=False) but got the same error.

RuntimeError: CUDA error: device-side assert triggered

Traceback (most recent call last):
  File "train.py", line 83, in <module>
    d_acc = trainer.dis_update(co_data, cl_data, config)
  File "/home/xxx/FUNIT-master/trainer.py", line 62, in dis_update
    al, lfa, lre, reg, acc = self.model(co_data, cl_data, hp, 'dis_update')
  File "/root/anaconda3/envs/funit/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/xxx/FUNIT-master/funit_model.py", line 53, in forward
    l_real_pre, acc_r, resp_r = self.dis.calc_dis_real_loss(xb, lb)
  File "/home/xxx/FUNIT-master/networks.py", line 83, in calc_dis_real_loss
    total_count = torch.tensor(np.prod(resp_real.size()), dtype=torch.float).cuda()
RuntimeError: CUDA error: device-side assert triggered

How the gradient penalty works? Where is the code?

First, thank you very much for sharing the code with the community.

I am curious about why and how the gradient penalty works.

As you show in Table 5 (on page 13 of the arXiv version of the paper), the gradient penalty loss is very important for the results, but I didn't find a specific description of it in the paper.

So can you point out how it works and where the code is?

Thanks again for the attractive work.
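Reading funit_model.py (quoted in an issue above), calc_grad2 is applied to the real responses resp_r and the real images xb, which matches a zero-centered gradient penalty on real data (an R1-style regularizer). A minimal sketch under that reading, not the repo's exact code:

import torch

def real_grad_penalty(d_out, x_real):
    # penalize ||dD/dx||^2 at real samples; x_real must have requires_grad=True
    grad = torch.autograd.grad(outputs=d_out.sum(), inputs=x_real,
                               create_graph=True, retain_graph=True,
                               only_inputs=True)[0]
    return grad.pow(2).reshape(grad.size(0), -1).sum(1).mean()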

The training code does not work well...

The result below is generated by a model trained for 400 iterations.

It looks like the model fails to learn.
I modified inplace=False in nn.ReLU and nn.LeakyReLU in blocks.py, following issue #23.

gen_train_00000400_02

Thanks.
