
mobilestylegan.pytorch's Introduction

MobileStyleGAN: A Lightweight Convolutional Neural Network for High-Fidelity Image Synthesis

Official PyTorch Implementation

The accompanying videos can be found on YouTube. For more details, please refer to the paper.

Requirements

  • Python 3.8+
  • 1–8 high-end NVIDIA GPUs with at least 12 GB of memory. We have done all testing and development using a DL Workstation with 4x 2080 Ti.

Training

pip install -r requirements.txt
python train.py --cfg configs/mobile_stylegan_ffhq.json --gpus <n_gpus>

Convert checkpoint from rosinality/stylegan2-pytorch

Our framework supports the StyleGAN2 checkpoint format from rosinality/stylegan2-pytorch. To convert your own StyleGAN2 checkpoint to our framework:

python convert_rosinality_ckpt.py --ckpt <path_to_rosinality_stylegan2_ckpt> --ckpt-mnet <path_to_output_mapping_network_ckpt> --ckpt-snet <path_to_output_synthesis_network_ckpt> --cfg-path <path_to_output_config_json>

Check converted checkpoint

To check that your checkpoint was converted correctly, just run the demo visualization:

python demo.py --cfg <path_to_output_config_json> --ckpt "" --generator teacher

Generate images using MobileStyleGAN

python generate.py --cfg configs/mobile_stylegan_ffhq.json --device cuda --ckpt <path_to_ckpt> --output-path <path_to_store_imgs> --batch-size <batch_size> --n-batches <n_batches>

Evaluate FID score

To evaluate the FID score, we use a modified version of the pytorch-fid library:

python evaluate_fid.py <path_to_ref_dataset> <path_to_generated_imgs>
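FID is the Fréchet distance between two Gaussians fitted to Inception-v3 activations of the reference and generated images: FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)). A minimal NumPy/SciPy sketch of that final step follows. It follows the same formula as pytorch-fid's calculate_frechet_distance but is not the repository's modified version, and it leaves out feature extraction entirely; the function names here are illustrative.

```python
import numpy as np
from scipy import linalg


def activation_stats(acts):
    """Mean vector and covariance matrix of an (N, D) activation matrix."""
    acts = np.asarray(acts, dtype=np.float64)
    return np.mean(acts, axis=0), np.cov(acts, rowvar=False)


def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2

    # Matrix square root of the product of the two covariance matrices.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Near-singular product: add a small ridge to both covariances and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    # Round-off can leave a negligible imaginary part; keep the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))
```

For two identical sets of statistics the distance is 0. In one dimension, N(0, 1) vs. N(3, 4) gives 9 + 1 + 4 - 2 * 2 = 10.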

Demo

Run the demo visualization using MobileStyleGAN:

python demo.py --cfg configs/mobile_stylegan_ffhq.json --ckpt <path_to_ckpt>

Run a visual comparison of StyleGAN2 vs. MobileStyleGAN:

python compare.py --cfg configs/mobile_stylegan_ffhq.json --ckpt <path_to_ckpt>

Convert to ONNX

python train.py --cfg configs/mobile_stylegan_ffhq.json --ckpt <path_to_ckpt> --export-model onnx --export-dir <output_dir>

Convert to CoreML

python train.py --cfg configs/mobile_stylegan_ffhq.json --ckpt <path_to_ckpt> --export-model coreml --export-dir <output_dir>

Deployment using OpenVINO

We provide the external library random_face as an example of deploying our model on edge devices using the OpenVINO framework.

Pretrained models

Name                        FID
mobilestylegan_ffhq.ckpt    7.75

(*) Our framework supports automatic download of pretrained models: just use --ckpt <pretrained_model_name>.

Legacy license

Code                   Source                                             License
Custom CUDA kernels    https://github.com/NVlabs/stylegan2                Nvidia License
StyleGAN2 blocks       https://github.com/rosinality/stylegan2-pytorch    MIT

Acknowledgements

We want to thank the people whose work contributed to our project:

  • Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen, Timo Aila for research related to style based generative models.
  • Kim Seonghyeon for implementation of StyleGAN2 in PyTorch.
  • Fergal Cotter for implementation of Discrete Wavelet Transforms and Inverse Discrete Wavelet Transforms in PyTorch.
  • Cyril Diagne for the excellent demo of how to run MobileStyleGAN directly in the web browser.

Citation

If you are using the results and code of this work, please cite it as:

@misc{belousov2021mobilestylegan,
      title={MobileStyleGAN: A Lightweight Convolutional Neural Network for High-Fidelity Image Synthesis},
      author={Sergei Belousov},
      year={2021},
      eprint={2104.04767},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

@article{BELOUSOV2021100115,
      title = {MobileStyleGAN.pytorch: PyTorch-based toolkit to compress StyleGAN2 model},
      journal = {Software Impacts},
      year = {2021},
      issn = {2665-9638},
      doi = {https://doi.org/10.1016/j.simpa.2021.100115},
      url = {https://www.sciencedirect.com/science/article/pii/S2665963821000452},
      author = {Sergei Belousov},
}

mobilestylegan.pytorch's People

Contributors

bes-dev

mobilestylegan.pytorch's Issues

When using batch_size > 2

Hi, thank you for sharing the code! With the default batch_size=2, training is fine. But when I use batch_size = 8, I get the error below:
[image: error screenshot]
Hoping for your reply!

integrate with Lightning ecosystem CI

Hello, and we are so happy to see you use PyTorch Lightning! 🎉
Just wondering if you have already heard about the new PyTorch Lightning (PL) ecosystem CI, which we would like to invite you to join. You can check out our blog post about it: Stay Ahead of Breaking Changes with the New Lightning Ecosystem CI ⚡
As you use the PL framework for your cool project, we would like to enhance your experience and offer you safe updates to our future releases. At the moment, you run tests with a particular PL version, but the next version may accidentally turn out to be incompatible with your project... 😕 We do not intend to change anything on your project side, but we do have a solution: by running ecosystem CI against both your project and our latest development head, we can catch such problems very early and avoid releasing a bad version... 👍

What needs to be done?

What will you get?

  • scheduled nightly testing configured for development/stable versions
  • slack notification if something went wrong to investigate
  • testing on multi-GPU machines as our gift to you 🍰

cc: @Borda

Doubts about ckpt

Hello, you point out in the paper that the whole network contains 8.01M parameters and has a computational complexity of 15.09 GMAC. But your shared .ckpt file (mobilestylegan_ffhq_v2.ckpt) is about 689 MB. Can you give me some advice on why there is such a big difference? Thanks again for your guidance.
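Some quick arithmetic frames the question: at 4 bytes per float32 value, 8.01M parameters take about 32 MB, so the student's weights alone cannot explain a 689 MB file. A reasonable first step is to look at what the checkpoint actually stores. The sketch below assumes a PyTorch Lightning checkpoint whose state_dict keys are prefixed by submodule name, which you should confirm by printing the keys. It groups tensor sizes by that prefix. The helper names and example keys are hypothetical.

```python
import math
from collections import defaultdict


def float32_megabytes(n_params):
    """Storage needed for n_params float32 values, in megabytes (10**6 bytes)."""
    return n_params * 4 / 1e6


def bytes_by_prefix(shapes, bytes_per_elem=4):
    """Sum tensor sizes grouped by the first dotted component of each key.

    `shapes` maps state_dict-style keys (e.g. "student.conv1.weight") to
    tensor shapes. With real tensors you would use
    tensor.numel() * tensor.element_size() instead of a fixed element size.
    """
    totals = defaultdict(int)
    for key, shape in shapes.items():
        prefix = key.split(".", 1)[0]
        totals[prefix] += math.prod(shape) * bytes_per_elem
    return dict(totals)


print(float32_megabytes(8_010_000))  # 32.04
```

With a real file, you could build `shapes` from `{k: tuple(v.shape) for k, v in ckpt["state_dict"].items()}` after loading it with torch.load. The "state_dict" key is an assumption about the Lightning checkpoint layout.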

Image to image translation like CycleGAN

Is there a way to use MobileStyleGAN as an image-to-image (A->B) style transfer model, similar to CycleGAN, rather than just an image synthesizer from no input? I have a custom dataset of cat faces, one set is real (domain A) and the other set are fake (domain B). I want an input of A to translate into domain B.

Observations on eyeglasses and textures

Thanks for sharing your ideas and code. It is rather fun to compare it to StyleGAN2. I am wondering about the following:

  1. Why does your algorithm do poorly with eyeglasses?
    [image: yycomp_5]

  2. There is a certain blockiness to the images (almost like JPEG artifacts). I am not sure why it is more common.

  3. I would have expected hair to be rendered better by your architecture, but for some odd reason it (especially facial hair) looks more iffy, almost as if there is too much regularity in the wavelet directions.

Thanks if you have any information to share. You propose an interesting architecture, but I am missing the intuition behind it.

Finetuning of student network?

Is it possible to initialise student network with other weights than random?
If my understanding is correct, the flag --ckpt affects both teacher and student. I'd like to load student from a checkpoint but have teacher as specified in the config file.
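One way to approach this is to load the checkpoint yourself, keep only the entries that belong to the student, strip their prefix, and pass the result to the student module's load_state_dict. This is a sketch under assumptions: it assumes a PyTorch Lightning checkpoint whose weights sit under a "state_dict" key with names prefixed "student.", so check the real key names first. The function name is hypothetical.

```python
def select_prefixed_weights(state_dict, prefix="student."):
    """Return the entries whose key starts with `prefix`, with the prefix removed."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Hypothetical usage with PyTorch (checkpoint layout assumed, not verified):
# ckpt = torch.load("student.ckpt", map_location="cpu")
# student_weights = select_prefixed_weights(ckpt["state_dict"], "student.")
# distiller.student.load_state_dict(student_weights)
```

Slicing off the prefix is safer than str.replace, which would also rewrite the prefix if it happened to occur later in a key.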

TypeError: 'NoneType' object is not iterable

Traceback (most recent call last):
File "train.py", line 60, in <module>
main(args)
File "train.py", line 42, in main
trainer.fit(distiller)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 458, in fit
self._run(model)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 756, in _run
self.dispatch()
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 797, in dispatch
self.accelerator.start_training(self)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
self.training_type_plugin.start_training(trainer)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
self._results = trainer.run_stage()
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 807, in run_stage
return self.run_train()
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 869, in run_train
self.train_loop.run_training_epoch()
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 489, in run_training_epoch
batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 728, in run_training_batch
self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 432, in optimizer_step
using_lbfgs=is_lbfgs,
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/core/lightning.py", line 1403, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/core/optimizer.py", line 214, in step
self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/core/optimizer.py", line 134, in __optimizer_step
trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py", line 329, in optimizer_step
self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py", line 336, in run_optimizer_step
self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 193, in optimizer_step
optimizer.step(closure=lambda_closure, **kwargs)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
return func(*args, **kwargs)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/torch/optim/adam.py", line 66, in step
loss = closure()
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 723, in train_step_and_backward_closure
split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 813, in training_step_and_backward
result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 280, in training_step
training_step_output = self.trainer.accelerator.training_step(args)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py", line 204, in training_step
return self.training_type_plugin.training_step(*args)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 155, in training_step
return self.lightning_module.training_step(*args, **kwargs)
File "/home/opu/MobileStyleGAN/core/distiller.py", line 70, in training_step
loss = self.generator_step(batch, batch_nb)
File "/home/opu/MobileStyleGAN/core/distiller.py", line 107, in generator_step
style, pred_t, gt_images = self.make_sample(batch)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
return func(*args, **kwargs)
File "/home/opu/MobileStyleGAN/core/distiller.py", line 122, in make_sample
gt = self.synthesis_net(style)
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/opu/MobileStyleGAN/core/models/synthesis_network.py", line 96, in forward
img = self.upsample(img) + rgb
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/opu/MobileStyleGAN/core/models/modules/legacy.py", line 40, in forward
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
File "/home/opu/MobileStyleGAN/core/models/modules/ops/upfirdn2d.py", line 14, in upfirdn2d
input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
File "/home/opu/MobileStyleGAN/core/models/modules/ops/upfirdn2d_cuda.py", line 99, in forward
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
RuntimeError: CUDA error: an illegal memory access was encountered
Exception ignored in: <bound method tqdm.__del__ of <pytorch_lightning.callbacks.progress.tqdm object at 0x7f944c6c4240>>
Traceback (most recent call last):
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/tqdm/std.py", line 1145, in __del__
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/tqdm/std.py", line 1299, in close
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/tqdm/std.py", line 1492, in display
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/tqdm/std.py", line 1148, in __str__
File "/home/opu/anaconda3/envs/mobilestylegan/lib/python3.6/site-packages/tqdm/std.py", line 1450, in format_dict
TypeError: 'NoneType' object is not iterable

When I tried to run the project, this error was raised. I need help.

Can this be used inside GFP - GAN?

Hi Sergei, thank you for your amazing research and this repo. By following the readme, I was able to generate the CoreML model and I noticed that the generated model fully runs on ANE in my testing which is amazing.

So far so great. My goal is to use MobileStyleGAN inside GFP-GAN, which looks like this:

[image: GFP-GAN architecture diagram]

In the diagram, the green "Pretrained GAN as prior" block is StyleGAN2 with some additions. Their StyleGAN2 has additions called CS-SFT that transform the output at each layer. This StyleGAN2 takes two inputs, style and conditions:

style.shape = [
    (16,512)
]

conditions.shape = [
    (1,256,8,8),
    (1,256,8,8),
    (1,256,16,16),
    (1,256,16,16),
    (1,256,32,32),
    (1,256,32,32),
    (1,256,64,64),
    (1,256,64,64),
    (1,128,128,128),
    (1,128,128,128),
    (1,64,256,256),
    (1,64,256,256),
    (1,32,512,512),
    (1,32,512,512)
]

So, as far as you can tell, is it possible to use MobileStyleGAN in GFP-GAN? I've been playing with your repo for the past few days. If you can advise on whether this will work, I would really appreciate it.

If you think it's not possible then I will stop here.

Thank you.

Smaller images look interesting!

Hi, I've been using MobileStyleGAN and it has been wonderful so far. 😃

When I run python compare.py --cfg configs/mobile_stylegan_ffhq.json the 1024x1024 images are great.

But when I generate 512x512 images, it doesn't look normal. Problem for 512x512 images:

  • Teacher is darker (left)
  • Student is blurrier (right)

[images: compare2, compare3, compare1]

To generate 512x512 images, I removed the last channel from synthesis_network.py and mobile_synthesis_network.py (channels[1:-1]), and I set strict=False when loading.

How can I fix the output so that it looks normal at 512x512?

Thank you for your help and your work.

Where can I find my training results?

Thanks for your MobileStyleGAN.
I don't know where to find my training results. Is the training result stored in the mobilestylegan_ffhq.ckpt file, or in other files?

How often does it save the model?

Generated image for CoreML is all black

Hi @bes-dev, thanks for your work!

I am actually not able to convert MobileStyleGAN correctly. I am using the following command:

python train.py --cfg configs/mobile_stylegan_ffhq.json --checkpoint_dir . --ckpt mobilestylegan_ffhq_v2.ckpt --export-model coreml

But the generated image is all black. If I replace the synthesis network with the one provided here, the image is generated fine. This suggests that MappingNetwork.mlmodel's parameters are loaded but SynthesisNetwork.mlmodel's are not. Note that I did download the mobilestylegan_ffhq_v2.ckpt checkpoint and saved it in the root directory of this repository.

model_zoo.json needs an update

While trying to run the demo on my local computer, I used this command, as stated in the README:

python demo.py --cfg configs/mobile_stylegan_ffhq.json

As a result, I got this error:

actual: c6a23e76375fbbf2d554e199718d73e1
expected: 8de24b4d08049c32dbec8beca5ed2074
Cached Downloading: /tmp/mobilestylegan_ffhq.ckpt
Traceback (most recent call last):
  File "demo.py", line 35, in <module>
    main(args)
  File "demo.py", line 15, in main
    ckpt = model_zoo(args.ckpt)
  File "/home/syn1650/Desktop/mobilestyle/MobileStyleGAN.pytorch/core/model_zoo.py", line 8, in model_zoo
    ckpt = download_ckpt(**zoo[name])
  File "/home/syn1650/Desktop/mobilestyle/MobileStyleGAN.pytorch/core/utils.py", line 29, in download_ckpt
    gdown.cached_download(url, ckpt_path, md5=md5)
  File "/home/syn1650/anaconda3/envs/comparison/lib/python3.7/site-packages/gdown/cached_download.py", line 123, in cached_download
    download(url, temp_path, quiet=quiet, proxy=proxy, speed=speed)
  File "/home/syn1650/anaconda3/envs/comparison/lib/python3.7/site-packages/gdown/download.py", line 110, in download
    res = sess.get(url, stream=True)
  File "/home/syn1650/anaconda3/envs/comparison/lib/python3.7/site-packages/requests/sessions.py", line 600, in get
    return self.request("GET", url, **kwargs)
  File "/home/syn1650/anaconda3/envs/comparison/lib/python3.7/site-packages/requests/sessions.py", line 573, in request
    prep = self.prepare_request(req)
  File "/home/syn1650/anaconda3/envs/comparison/lib/python3.7/site-packages/requests/sessions.py", line 496, in prepare_request
    hooks=merge_hooks(request.hooks, self.hooks),
  File "/home/syn1650/anaconda3/envs/comparison/lib/python3.7/site-packages/requests/models.py", line 368, in prepare
    self.prepare_url(url, params)
  File "/home/syn1650/anaconda3/envs/comparison/lib/python3.7/site-packages/requests/models.py", line 440, in prepare_url
    f"Invalid URL {url!r}: No scheme supplied. "
requests.exceptions.MissingSchema: Invalid URL '': No scheme supplied. Perhaps you meant http://?

Then I changed the following entry:

 "mobilestylegan_ffhq.ckpt": {
        "url": "https://drive.google.com/uc?id=11Kja0XGE8liLb6R5slNZjF3j3v_6xydt",
        "name": "mobilestylegan_ffhq.ckpt",
        "md5": "8de24b4d08049c32dbec8beca5ed2074"
    }
}

to

  "mobilestylegan_ffhq.ckpt": {
        "url": "https://drive.google.com/uc?id=11Kja0XGE8liLb6R5slNZjF3j3v_6xydt",
        "name": "mobilestylegan_ffhq_v2.ckpt",
        "md5": "8de24b4d08049c32dbec8beca5ed2074"
    }
}

And it worked.

Also, I have downloaded all the files manually. I think something is wrong with either the code or my PC: the code only downloads mapping_network.ckpt.
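The "actual"/"expected" lines in the log come from an MD5 integrity check on the cached file: the traceback shows gdown.cached_download being called with md5=.... If you download checkpoints manually, you can run the same kind of check with the standard library. The sketch below compares a file against the md5 value listed in model_zoo.json. The function names are illustrative and not part of the repository.

```python
import hashlib


def file_md5(path, chunk_size=1 << 20):
    """Compute the hex MD5 digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_checkpoint(path, expected_md5):
    """Return True if the file's MD5 matches, otherwise raise ValueError."""
    actual = file_md5(path)
    if actual != expected_md5:
        raise ValueError(f"md5 mismatch for {path}: actual {actual}, expected {expected_md5}")
    return True


# Example: verify_checkpoint("/tmp/mobilestylegan_ffhq.ckpt", "8de24b4d08049c32dbec8beca5ed2074")
```

Reading in chunks keeps memory usage flat even for multi-hundred-megabyte checkpoints.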

6GB of memory.

Q1: If I have just 6 GB of GPU memory, can I still train a model with good results, or do I need at least 12 GB?

Q2: Why don't you load the teacher's discriminator?

thx.

convert to coreml with w+

I'm trying to convert the model that works in the W+ latent space. How can I use the convert-to-CoreML command for that?

How to use W+ space to generate samples from MobileStyleGAN

Hi!

I want to sample images from W+ space using PyTorch checkpoints. But there doesn't seem to exist any argument to generate.py script for that. Could you please guide me regarding this?

The images I sampled from the CoreML W+ space models (using both Mapping and Synthesis) had a weird bluish color. These models were exported using the --export-w-plus argument. I've attached a few of them here.

[images: coreml_w_plus_3, coreml_w_plus_2, coreml_w_plus_0]

When I use W space in CoreML models then the samples are colored correctly.

[images: 0, 2, 4]

Any help is highly appreciated!

Regards
Rahul Bhalley
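For context on the shapes involved: a W+ code is commonly formed from a single W vector by repeating it once per style input of the synthesis network, so W and W+ only produce different images once the per-layer vectors are changed independently. The NumPy sketch below shows that broadcast. The number of style vectors depends on the network. The helper uses the rule from rosinality/stylegan2-pytorch (n_latent = 2 * log2(size) - 2, i.e. 18 at 1024x1024 and 16 at 512x512), which is an assumption here, not a value read from MobileStyleGAN. The random w is only a placeholder for mapping-network output, and the sketch does not explain the color shift above.

```python
import math

import numpy as np


def n_styles_for(resolution):
    """W+ length in rosinality/stylegan2-pytorch: 2 * log2(size) - 2."""
    return 2 * int(math.log2(resolution)) - 2


def w_to_w_plus(w, n_styles):
    """Repeat a batch of W codes (B, D) into W+ codes (B, n_styles, D)."""
    w = np.asarray(w)
    return np.repeat(w[:, np.newaxis, :], n_styles, axis=1)


rng = np.random.default_rng(0)
w = rng.standard_normal((2, 512)).astype(np.float32)  # placeholder for mapping-network output
w_plus = w_to_w_plus(w, n_styles_for(1024))
print(w_plus.shape)  # (2, 18, 512)
```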

Question about replacing the StyleGAN2 decoder in another GAN network

Hi, thanks for your amazing work. I am trying to use MobileStyleGAN to replace the StyleGAN2 decoder in GFPGAN. To do this:

  1. The mapping network in GFPGAN is a U-Net and outputs a feature with shape [batch, 16, 512]. So I replaced the freq input with style[1] (it was style[2]), so that freq shares the style with hidden.
  2. GFPGAN outputs an image with shape [batch, 3, 512, 512]. In MobileStyleGAN, I changed channels[1:] to channels[2:].
  3. I fixed the noise dataset to generate [3, 512, 512] images with torch.randn().

Then I successfully trained the network with batchsize=4 and dataset length=10000. But after 1 epoch, kid_val is 0.0007 and never changes, and the training loss oscillates around 0.5. I trained for 20 epochs and used the 20th checkpoint for prediction, but got this result:
[image: test]

Another problem: I find that the mapping network's parameters are updated during training, but shouldn't this network's parameters be frozen?
Can you share your training loss, or some training tips? Thanks for your help.

About Mobile Device Benchmarks

Hi,

Thanks for the great work. Do you have benchmarks on the performance of the model, like CoreML prediction speed on specific devices? Also, I wonder how prediction times change by the resolution of the model (256, 512, 1024) and how big these models are in size. Thanks...

Need some advice on how to use the repo to train on a big dataset.

Hello:
I am new to PyTorch, and I want to learn how to use StyleGAN2. I found a big dataset here:
https://github.com/NVlabs/metfaces-dataset

I downloaded the dataset and found that it is huge, more than 17 GB on disk.
I want to train on this dataset and export the trained model in ONNX format.
How can I start?
I have Python 3.9 on Windows 10, and PyTorch version 1.11.0.
My PC has a 16-core CPU with 128 GB RAM and an NVIDIA graphics card, so I think the hardware is good enough to start a big project.
Thanks,

How can I convert my .pt file to a .ckpt file?

Hi, thanks for your MobileStyleGAN!
I noticed that you provide convert_rosinality_ckpt.py, but none of the script's parameters is a path to a .pt file.
It only has the parameter '--ckpt', but the output of StyleGAN2 is a .pt file.
So I still don't know how to convert my .pt file. Please give me some help, thanks.

support stylespace

Thanks for your amazing work.
This model supports W and W+ as input, but can it support StyleSpace?
[image: Screenshot 2022-11-06 at 22 44 04]

Can I do face editing?

Can I use MobileStyleGAN's latent space to perform face editing like anycost-gan?
I have read both papers, and I hope to be able to do face editing with MobileStyleGAN, since it has a faster inference time.
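For reference, a common form of latent face editing (popularized by methods such as InterFaceGAN) moves a latent code along a semantic direction, w' = w + alpha * n, and then re-synthesizes the image. A NumPy sketch of that step follows. Whether MobileStyleGAN's latent space yields good editing directions is exactly what this issue asks, and the direction below is a random placeholder rather than a learned attribute direction. The names are hypothetical.

```python
import numpy as np


def edit_latent(w, direction, alpha):
    """Shift latent codes w (B, D) by alpha along the unit-normalized direction (D,)."""
    direction = np.asarray(direction, dtype=w.dtype)
    direction = direction / np.linalg.norm(direction)
    return w + alpha * direction


rng = np.random.default_rng(0)
w = rng.standard_normal((1, 512)).astype(np.float32)  # placeholder latent code
smile_direction = rng.standard_normal(512)            # hypothetical attribute direction
edits = [edit_latent(w, smile_direction, a) for a in (-3.0, 0.0, 3.0)]
```

Each edited code would then be passed to the synthesis network in place of w. Normalizing the direction makes alpha a distance in latent space, so step sizes are comparable across attributes.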

cannot find a way to encode image

StyleGAN2 has an (18,512) latent vector, while MobileStyleGAN maps a (1,512) vector to a (1,512) style.
I don't understand how I can encode an image into the (1,512) style that the model takes as input.
Thanks for the help.

Project dependencies may have API risk issues

Hi, in MobileStyleGAN.pytorch, inappropriate dependency version constraints can cause risks.

Below are the dependencies and version constraints that the project is using:

wheel
torch
pytorch-lightning==1.0.2
gdown==3.12.2
addict==2.2.1
piq==0.5.2
numpy==1.17.5
PyWavelets==1.1.1
git+://github.com/fbcotter/pytorch_wavelets.git
neptune-client==0.4.132
kornia==0.4.1
pytorch_fid
coremltools

The == version constraint introduces a risk of dependency conflicts because the allowed version range is too strict.
A constraint with no upper bound (or *) introduces a risk of missing-API errors, because the latest version of a dependency may remove some APIs.

After further analysis, for this project:
The version constraint of dependency pytorch-lightning can be changed to >=0.3.6.9,<=0.5.2.1.
The version constraint of dependency gdown can be changed to >=3.7.0,<=4.5.1.
The version constraint of dependency kornia can be changed to >=0.2.1,<=0.6.2.

The above suggestions reduce dependency conflicts as much as possible,
while allowing the latest versions to be used as far as possible without causing errors in the project.
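You can check at runtime whether an installed package falls inside one of the suggested ranges. The ranges come from this report and are not verified here. The minimal standard-library sketch below assumes purely numeric dotted versions and does not handle tags such as "rc1" or ".post1". For real use, the packaging library's SpecifierSet is the more robust choice. The function names are illustrative.

```python
from importlib import metadata


def parse_version(text):
    """Turn a purely numeric dotted version such as '0.5.2.1' into a tuple of ints."""
    return tuple(int(part) for part in text.split("."))


def in_range(version, lower, upper):
    """True if lower <= version <= upper, inclusive at both ends as in the suggestions."""
    return parse_version(lower) <= parse_version(version) <= parse_version(upper)


def check_installed(package, lower, upper):
    """Describe how an installed distribution's version relates to a suggested range."""
    try:
        version = metadata.version(package)
    except metadata.PackageNotFoundError:
        return f"{package}: not installed"
    try:
        ok = in_range(version, lower, upper)
    except ValueError:
        return f"{package} {version}: not a purely numeric version, check manually"
    return f"{package} {version}: {'ok' if ok else 'outside suggested range'}"


print(check_installed("kornia", "0.2.1", "0.6.2"))
```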

The current project invokes all of the following methods.

Methods called from pytorch-lightning:
pytorch_lightning.callbacks.ModelCheckpoint
pytorch_lightning.Trainer.fit
pytorch_lightning.Trainer
Methods called from gdown:
gdown.cached_download
Methods called from kornia:
kornia.augmentation.RandomHorizontalFlip
kornia.augmentation.RandomAffine
kornia.augmentation.RandomErasing
All methods called:
json.dump
UpFirDn2dBackward.apply
torch.no_grad
torch.rsqrt.view
core.models.discriminator.Discriminator
real_pred.F.softplus.mean
SynthesisBlock
self.idwt.size
cfg.type.pl_loggers.getattr
fid_inception_v3
Wrapper
torch.unbind
ToRGB
self.perceptual_loss
distiller
prep_filt_sfb2d
pred.self.inception.view
self.branch3x3dbl_1
self.dwt_to_img.size
R1Regularization
torch.nn.Parameter
t.add_
json.load
torch.nn.functional.max_pool2d
FIDInceptionE_2
self.modulation
j.img_s.cpu
core.models.mobile_synthesis_network.MobileSynthesisNetwork
ConstantInput
gradgrad_input.reshape.reshape
x_pred.sum
style_b.unsqueeze.repeat.unsqueeze
cv2.imwrite
style.view.unsqueeze
self.layers
idwt.DWTInverse
size.size.b.torch.randn.to
k.source_state.size
torch.nn.functional.conv_transpose2d
self.branch7x7dbl_4
self.FIDInceptionC.super.__init__
torch.tensor
modules.DWTInverse
self.loss.reg_d
range
self.branch3x3_1
self.get_demodulation.view
core.models.mapping_network.MappingNetwork
torch.tensor.sum
style_a.unsqueeze.repeat
torch.load
self.mapping_net.apply
FIDInceptionC
random.randint
cv2.waitKey
self.branch5x5_1
t.mul.add_.clamp_
UpFirDn2dBackward.apply.view
pytorch_wavelets.DWTInverse
torch.utils.data.DataLoader
self.get_demodulation
torch.utils.cpp_extension.load.fused_bias_act
gdown.cached_download
ResBlock
pytorch_lightning.callbacks.ModelCheckpoint
core.models.mapping_network.MappingNetwork.state_dict
high.view.view
torch.autograd.grad.size
torch.onnx.export
self.weight.unsqueeze
t.mul.add_.clamp_.permute
pred.squeeze.squeeze.cpu.numpy.size
self.net._modules.items
image.new_empty
distiller.size
self.to_img1
numpy.iscomplexobj
EqualConv2d
self.to_img1.view
img_t.cpu
block.conv2.load_state_dict
self.parameters
enumerate
self.branch3x3dbl_3b
self.synthesis_net.append
Blur
core.loss.perceptual_loss.PerceptualLoss
scipy.linalg.sqrtm
modules.legacy.PixelNorm.append
getattr
self.noise
torch.nn.LeakyReLU
torch.cat
grad_x.size.grad_x.view.norm
core.models.synthesis_network.SynthesisNetwork
torch.nn.ModuleList
self.student
int_to_mode
FIDInceptionE_1
torch.uint8.img.to.numpy
style_a.unsqueeze.repeat.unsqueeze
gradgrad_out.view.view
covmean.np.isfinite.all
format
UpFirDn2d.apply
conv_module
build_logger
sorted
torch.cat.view
core.utils.select_weights.size
torch.nn.BatchNorm2d
upfirdn2d_native
stddev.mean.squeeze
torch.nn.BatchNorm2d.train.to
numpy.isfinite
FusedLeakyReLU
torch.nn.functional.conv2d.size
input.reshape.reshape
block.to_rgb.load_state_dict
core.utils.download_ckpt
ctx.save_for_backward
self.branch7x7dbl_1
style.view.size
self.register_buffer
self.dwt
stddev.repeat.var
self.mapping_net
self.branch7x7dbl_2
cv2.imshow
core.distiller.Distiller.to_onnx
multiprocessing.cpu_count
ValueError
torch.nn.AdaptiveAvgPool2d
self.mapping_net.style_dim.self.wsize.torch.randn.to
weight.transpose.reshape.pow
hasattr
self.generator_step
EqualLinear
StyledConv2d
noise_injection.NoiseInjection
batch.to.to
torch.rsqrt
core.utils.tensor_to_img
torchvision.models.inception_v3.load_state_dict
self.net
bias.view
self.weight_dw.transpose
core.distiller.Distiller.to_coreml
self.student.append
self.branch3x3dbl_2
core.utils.save_cfg
distiller.cpu
torch.save
self.branch3x3_2a
torch.optim.Adam
StyledConv
out.view.view
offset.sigma1.dot
torch.cat.sigmoid
stddev.repeat.mean
modules.MultichannelIamge
img.size
core.utils.apply_trace_model_mode
block_idx.InceptionV3.to.eval
self.final_linear
self.branch7x7_3
self.get_modulation
opts.append
self.loss_weights.items
t.mul
self.branch1x1
torch.zeros
torch.nn.L1Loss
extract_snet
torch.nn.functional.interpolate
torchvision.transforms.ToTensor
pathlib.Path.glob
batch.self.mapping_net.unsqueeze.repeat
self.blocks.append
torch.autograd.grad
var.self.mapping_net.view
img.view.size
self.modulation.bias.data.fill_
self
self.conv.view
stddev.repeat.repeat
style.self.modulation.view
self.l1_loss
torch.nn.functional.conv2d
fake.self.F.softplus.mean
v.size
self.branch5x5_2
self.make_sample
target.state_dict.items
diff.dot
hidden.size.hidden.size.torch.randn.to
self.branch3x3dbl_3
math.log
self.branch3x3dbl_3a
chr
@developer
Could you please help me check this issue?
May I open a pull request to fix it?
Thank you very much.

A few questions

Hi there, many thanks for your work on MobileStyleGAN. I have a few questions.

  1. How long should the training last?
    I have been training on my custom StyleGAN2 model for 70+ epochs, and the original loss seems to stay at around 7. For how many epochs was the original MobileNet FFHQ student network trained?

  2. What are the losses to look out for?

Epoch 2:  18%|▉    | 2546/14000 [59:37<4:28:15,  1.41s/it, loss=10.041, v_num=1234, l1=5.73, l2=4.59, loss_p=2.04, loss_g=7.92, d_reg=0.0189, loss_d=0.00057, kid_val=0.375, loss_val=18.9]

For the losses above, roughly what target values should they reach, and which of them should I pay attention to or is the most important in this case? Also, what is the kid_val target for the FFHQ model in the paper?

  3. Does the "kid_val was not in the top True" message affect training in any way?

I would much appreciate your reply. Thanks

Edit: Apologies for tagging @bes-dev

Unable to convert to CoreML model. Got Neptune-related error.

Thank you so much for your work! I have an issue using it.

I ran the following command:

python train.py --cfg configs/mobile_stylegan_ffhq.json --checkpoint_dir mobilestylegan_ffhq_v2.ckpt --export-model coreml --export-dir .

I get an error related to Neptune:

load mapping network...
load pretrained model: stylegan2_ffhq_config_f_mapping_network.ckpt...
Computing MD5: /tmp/stylegan2_ffhq_config_f_mapping_network.ckpt
MD5 matches: /tmp/stylegan2_ffhq_config_f_mapping_network.ckpt
load synthesis network...
load pretrained model: stylegan2_ffhq_config_f_synthesis_network.ckpt...
Computing MD5: /tmp/stylegan2_ffhq_config_f_synthesis_network.ckpt
MD5 matches: /tmp/stylegan2_ffhq_config_f_synthesis_network.ckpt
/usr/local/lib/python3.7/site-packages/torchvision/models/inception.py:83: FutureWarning: The default weight initialization of inception_v3 will be changed in future releases of torchvision. If you wish to keep the old behavior (which leads to long initialization times due to scipy/scipy#11299), please set init_weights=True.
  ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
Traceback (most recent call last):
  File "train.py", line 65, in <module>
    main(args)
  File "train.py", line 21, in main
    logger = build_logger(cfg.logger)
  File "train.py", line 12, in build_logger
    **cfg.params
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loggers/neptune.py", line 270, in __init__
    self._verify_input_arguments(api_key, project, name, run, neptune_run_kwargs)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loggers/neptune.py", line 358, in _verify_input_arguments
    raise ValueError(legacy_kwargs_msg.format(legacy_kwargs=used_legacy_kwargs))
ValueError: Following kwargs are deprecated: ['offline_mode', 'project_name', 'experiment_name'].
If you are looking for the Neptune logger using legacy Python API, it's still available as part of neptune-contrib package:
  - https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html
The NeptuneLogger was re-written to use the neptune.new Python API
  - https://neptune.ai/blog/neptune-new
  - https://docs.neptune.ai/integrations-and-supported-tools/model-training/pytorch-lightning
You should use arguments accepted by either NeptuneLogger.init() or neptune.init()

How can I fix this?
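The traceback shows that `build_logger` forwards the config's logger params straight into the rewritten `NeptuneLogger`, and that three of those params (`offline_mode`, `project_name`, `experiment_name`) use legacy names. One possible workaround is to rename them before constructing the logger. The sketch below is not the repository's code. `translate_legacy_neptune_params` is a hypothetical helper. The mapping `project_name -> project` and `experiment_name -> name` follows the new `NeptuneLogger` arguments. Treating `offline_mode=True` as `mode="offline"` assumes that extra keyword arguments are passed through to `neptune.init()`.

```python
from typing import Any, Dict

# Hypothetical mapping from legacy NeptuneLogger kwargs (named in the error
# message) to the argument names used by the rewritten NeptuneLogger.
_RENAMED = {"project_name": "project", "experiment_name": "name"}


def translate_legacy_neptune_params(params: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of a logger params dict with legacy kwargs replaced."""
    new_params: Dict[str, Any] = {}
    for key, value in params.items():
        if key in _RENAMED:
            new_params[_RENAMED[key]] = value
        elif key == "offline_mode":
            # Assumption: offline_mode=True corresponds to neptune's
            # mode="offline" run argument; False simply keeps the default mode.
            if value:
                new_params["mode"] = "offline"
        else:
            # Any other key is passed through unchanged.
            new_params[key] = value
    return new_params
```

In `build_logger`, this would mean calling the logger class with `**translate_legacy_neptune_params(dict(cfg.params))` instead of `**cfg.params` (a hypothetical edit). Editing the JSON config by hand to use the new names should have the same effect.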

wrong pred

I have a 6 GB GPU. The batch size is 2 and the resolution is 512×512.
I have only trained for 7 epochs, and the predicted image is entirely pink or another solid color.
Is this because I haven't trained long enough, or am I doing something wrong?


Thanks.
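To track whether outputs like these are still collapsed to a single color as training continues, a simple numeric check can help. For each RGB channel it measures the spread of the pixel values, and it treats an image as a "solid color" when every channel is essentially constant. This is only a sketch and does not diagnose the cause. `is_solid_color` is a hypothetical helper. It assumes the image has already been converted to a list of `(r, g, b)` float tuples in `[0, 1]`. For 0 to 255 pixel values, `tol` would need to be scaled up accordingly.

```python
import statistics
from typing import Sequence, Tuple

Pixel = Tuple[float, float, float]


def is_solid_color(pixels: Sequence[Pixel], tol: float = 1e-3) -> bool:
    """Return True if every RGB channel is (nearly) constant across the image."""
    if not pixels:
        raise ValueError("empty image")
    for channel in range(3):
        values = [p[channel] for p in pixels]
        # Population standard deviation of this channel over all pixels.
        if statistics.pstdev(values) > tol:
            return False
    return True
```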

`compute_mean_style` method generates gray image

Hi @bes-dev,

In issue #36 you told me that the compute_mean_style method can be used to compute the latent average vector for MobileStyleGAN. But the image produced from this vector by the student network is all gray.

I ran the following code to produce image using latent average vector.

# Import libraries.
import cv2
from core.utils import load_cfg, load_weights, tensor_to_img
from core.distiller import Distiller

# Load configuration.
cfg = load_cfg("configs/mobile_stylegan_ffhq.json")

distiller = Distiller(cfg)
style_mean = distiller.compute_mean_style(style_dim=512, wsize=23)
out = distiller.student(style_mean)['img']
cv2.imwrite('style_mean.png', tensor_to_img(out[0]))

I get the following average image as output.


Furthermore, every call to compute_mean_style returns a different vector. It's very weird. Shouldn't it be the same? Also, a larger batch_size doesn't make any difference. :(

style_mean1 = distiller.compute_mean_style(style_dim=512, wsize=23)
style_mean2 = distiller.compute_mean_style(style_dim=512, wsize=23)
style_mean3 = distiller.compute_mean_style(style_dim=512, wsize=23)

# Comparing different means. 😅 Shouldn't they all be the same, since it's just an average, right?
style_mean1 == style_mean2
style_mean1 == style_mean3
style_mean2 == style_mean3

Output

tensor([[[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]]])
tensor([[[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]]])
tensor([[[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]]])

I could be missing something. Please let me know what could be the issue on my side.

Regards
Rahul Bhalley
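On the "different vector every call" part: a mean style computed by averaging the mapping network's output over randomly drawn latents is a Monte Carlo estimate. If each call draws fresh noise, the result shifts slightly from call to call. Elementwise `==` on floating-point tensors will then be `False` almost everywhere, even when the estimates are very close. The sketch below illustrates this in plain Python and is not the repository's `compute_mean_style`. `toy_mapping` is a hypothetical stand-in for the mapping network, and `estimate_mean_style` and `all_close` are hypothetical helpers. It shows that a fixed seed makes the estimate reproducible, and that different seeds agree within a tolerance but not exactly. This does not address the gray image, which is a separate question.

```python
import random
from typing import Callable, List, Optional


def toy_mapping(z: List[float]) -> List[float]:
    """Hypothetical stand-in for a mapping network (leaky ReLU plus an offset)."""
    return [max(0.2 * v, v) + 1.0 for v in z]


def estimate_mean_style(
    mapping: Callable[[List[float]], List[float]],
    style_dim: int,
    n_samples: int,
    seed: Optional[int] = None,
) -> List[float]:
    """Monte Carlo estimate of E[mapping(z)] with z ~ N(0, I)."""
    rng = random.Random(seed)  # seed=None draws fresh randomness on every call
    total = [0.0] * style_dim
    for _ in range(n_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(style_dim)]
        for i, w in enumerate(mapping(z)):
            total[i] += w
    return [t / n_samples for t in total]


def all_close(a: List[float], b: List[float], atol: float) -> bool:
    """Tolerance-based comparison, similar in spirit to torch.allclose."""
    return len(a) == len(b) and all(abs(x - y) <= atol for x, y in zip(a, b))
```

In the PyTorch code above, comparing with `torch.allclose(style_mean1, style_mean2, atol=...)` is more informative than `==`. I have not verified how `compute_mean_style` draws its noise. If it uses `torch.randn`, calling `torch.manual_seed(0)` before each call should make the results identical.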

The converted model generates blurry images

After conversion, the model generates blurry images. I used "convert_rosinality_ckpt.py" to convert "550000.pt" (the 256×256 pretrained model provided by "https://github.com/rosinality/stylegan2-pytorch"). To see the images generated with the pretrained StyleGAN2 parameters, I replaced

img = self.student(style)["img"]
with "img = self.synthesis_net(style)["img"]" and generated some images. Here are some of them.
I don't know what caused this. Have you ever encountered this problem? @bes-dev
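One quick sanity check, offered as an assumption rather than a confirmed cause, is to confirm that the resolution recorded in the config written by `convert_rosinality_ckpt.py` matches the 256×256 source checkpoint. The number of per-layer style vectors should also be consistent with that resolution. In rosinality's `Generator`, `n_latent = log_size * 2 - 2`, where `log_size = log2(size)`. The helper below is hypothetical and only computes that expected value for comparison.

```python
import math


def rosinality_n_latent(size: int) -> int:
    """Expected number of W+ latents in rosinality's Generator for a given size."""
    log_size = int(math.log2(size))
    if 2 ** log_size != size:
        raise ValueError(f"size must be a power of two, got {size}")
    # Matches rosinality/stylegan2-pytorch: n_latent = log_size * 2 - 2
    return log_size * 2 - 2
```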

MobileStyleGAN Checkpoint converted to ONNX generates grey images

Hi!

Thank you for an amazing repository.
I successfully converted my StyleGAN2-ada rosinality checkpoint by running the following line:
python convert_rosinality_ckpt.py --ckpt {path_to_rosinality_stylegan2_ckpt} --ckpt-mnet output/mnet.ckpt --ckpt-snet output/snet.ckpt --cfg-path output/config.json

I tested the checkpoint with demo.py and it produces images as expected.

I then converted it to ONNX by running
python train.py --cfg output/config.json --export-model onnx --export-dir onnx-2
and tried to use the converted checkpoint in MobileStyleGAN web demo (https://github.com/cyrildiagne/mobilestylegan-web-demo).
It produces uniform grey images for all seeds. The web demo works fine with the authors' FFHQ checkpoint, so it seems to be an issue with the converted model.

Do you have any thoughts on what might be causing this?

Screenshot 2022-10-05 at 17 57 19
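Since `demo.py` produces correct images while the ONNX file produces grey ones, a parity check can show whether the problem was introduced at the ONNX export step or later in the web demo. The idea is to feed the same input to the PyTorch model and to the exported ONNX model, then compare the flattened outputs. The sketch below is hypothetical and covers only the comparison. Obtaining the outputs is left to the reader: for example, `tensor.flatten().tolist()` in PyTorch, and `onnxruntime.InferenceSession(path).run(None, {input_name: array})` followed by `.ravel().tolist()`. The input name is whatever the exported graph declares.

```python
from typing import Sequence


def max_abs_diff(reference: Sequence[float], candidate: Sequence[float]) -> float:
    """Largest elementwise absolute difference between two flattened outputs."""
    if len(reference) != len(candidate):
        raise ValueError(f"length mismatch: {len(reference)} vs {len(candidate)}")
    return max((abs(r - c) for r, c in zip(reference, candidate)), default=0.0)


def outputs_agree(
    reference: Sequence[float], candidate: Sequence[float], atol: float = 1e-3
) -> bool:
    """True if the candidate output matches the reference within atol."""
    return max_abs_diff(reference, candidate) <= atol
```

If the ONNX output already disagrees with PyTorch for the same input, the export is the likely culprit. If the two outputs agree, the difference is more likely in how the web demo feeds inputs or post-processes outputs.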

Is there a PyTorch Lite version?

Thank you for the amazing work. I need a PyTorch Lite (.ptl) version of the PyTorch (.pt) model. Is there a way to download your .pt or .ptl model? Do you have a torch.hub.load version we can download?
