esrgan-pytorch's Introduction

ESRGAN-PyTorch

Overview

This repository contains an op-for-op PyTorch reimplementation of ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

Table of contents

Download weights

Download datasets

Contains DIV2K, DIV8K, Flickr2K, OST, T91, Set5, Set14, BSDS100 and BSDS200, etc.

Please refer to README.md in the data directory for instructions on building a dataset.

How to Test and Train

Both training and testing only require editing the corresponding YAML config file.

Test ESRGAN_x4

python3 test.py --config_path ./configs/test/ESRGAN_x4-DFO2K-Set5.yaml

Train RRDBNet_x4

python3 train_net.py --config_path ./configs/train/RRDBNet_x4-DFO2K.yaml

Resume train RRDBNet_x4

Modify the ./configs/train/RRDBNet_X4.yaml file.

  • line 34: change RESUMED_G_MODEL to ./samples/RRDBNet_X4-DIV2K/g_epoch_xxx.pth.tar.

python3 train_net.py --config_path ./configs/train/RRDBNet_x4-DFO2K.yaml

Train ESRGAN_x4

Modify the ./configs/train/ESRGAN_X4.yaml file.

  • line 39: change PRETRAINED_G_MODEL to ./results/EDSRGAN_x4-DIV2K/g_last.pth.tar.

python3 train_gan.py --config_path ./configs/train/ESRGAN_x4-DFO2K.yaml

Resume train ESRGAN_x4

Modify the ./configs/train/ESRGAN_X4.yaml file.

  • line 39: change PRETRAINED_G_MODEL to ./results/RRDBNet_x4-DIV2K/g_last.pth.tar.
  • line 41: change RESUMED_G_MODEL to ./samples/EDSRGAN_x4-DIV2K/g_epoch_xxx.pth.tar.
  • line 42: change RESUMED_D_MODEL to ./samples/EDSRGAN_x4-DIV2K/d_epoch_xxx.pth.tar.

python3 train_gan.py --config_path ./configs/train/ESRGAN_x4-DFO2K.yaml

Result

Source of original paper results: https://arxiv.org/pdf/1809.00219v2.pdf

In the following table, values in parentheses are the results of this project, and - means not tested.

| Method | Scale | Set5 (PSNR/SSIM) | Set14 (PSNR/SSIM) | BSD100 (PSNR/SSIM) | Urban100 (PSNR/SSIM) | Manga109 (PSNR/SSIM) |
|--------|-------|------------------|-------------------|--------------------|----------------------|----------------------|
| RRDB | 4 | 32.73(32.71)/0.9011(0.9018) | 28.99(28.96)/0.7917(0.7917) | 27.85(27.85)/0.7455(0.7473) | 27.03(27.03)/0.8153(0.8156) | 31.66(31.60)/0.9196(0.9195) |
| ESRGAN | 4 | -(30.44)/-(0.8525) | -(26.28)/-(0.6994) | -(25.33)/-(0.6534) | -(24.36)/-(0.7341) | -(29.42)/-(0.8597) |

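The PSNR figures in the table follow the usual definition, PSNR = 10 * log10(MAX^2 / MSE). As a rough illustration only, the hypothetical `psnr` helper below computes it for two equally sized 8-bit images given as flat lists of pixel values. The repository's own evaluation code may differ (for example by cropping borders or measuring on the Y channel), which this sketch does not attempt to reproduce.

```python
import math


def psnr(reference, estimate, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equal-length pixel sequences."""
    if len(reference) != len(estimate) or not reference:
        raise ValueError("inputs must be non-empty and of equal length")
    # Mean squared error over all pixels.
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


# Toy pixel values, made up for illustration.
hr = [52, 55, 61, 59, 79, 61, 76, 61]
sr = [54, 55, 60, 59, 77, 63, 76, 60]
print(f"PSNR: {psnr(hr, sr):.2f} dB")
```

A higher value means the estimate is closer to the reference; identical inputs give infinity.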
# Download `ESRGAN_x4-DFO2K-25393df7.pth.tar` weights to `./results/pretrained_models`
# More detail see `README.md<Download weights>`
python3 ./inference.py

Input:

Output:

Build `rrdbnet_x4` model successfully.
Load `rrdbnet_x4` model weights `/ESRGAN-PyTorch/results/pretrained_models/ESRGAN_x4-DFO2K.pth.tar` successfully.
SR image save to `./figure/ESRGAN_x4_baboon.png`

Contributing

If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.

I look forward to seeing what the community does with these models!

Credit

ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks

Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Chen Change Loy, Yu Qiao, Xiaoou Tang

Abstract
The Super-Resolution Generative Adversarial Network (SRGAN) is a seminal work that is capable of generating realistic textures during single image super-resolution. However, the hallucinated details are often accompanied with unpleasant artifacts. To further enhance the visual quality, we thoroughly study three key components of SRGAN - network architecture, adversarial loss and perceptual loss, and improve each of them to derive an Enhanced SRGAN (ESRGAN). In particular, we introduce the Residual-in-Residual Dense Block (RRDB) without batch normalization as the basic network building unit. Moreover, we borrow the idea from relativistic GAN to let the discriminator predict relative realness instead of the absolute value. Finally, we improve the perceptual loss by using the features before activation, which could provide stronger supervision for brightness consistency and texture recovery. Benefiting from these improvements, the proposed ESRGAN achieves consistently better visual quality with more realistic and natural textures than SRGAN and won the first place in the PIRM2018-SR Challenge. The code is available at https://github.com/xinntao/ESRGAN.

[Paper] [Authors' implementation (PyTorch)]

@misc{wang2018esrgan,
    title={ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks},
    author={Xintao Wang and Ke Yu and Shixiang Wu and Jinjin Gu and Yihao Liu and Chao Dong and Chen Change Loy and Yu Qiao and Xiaoou Tang},
    year={2018},
    eprint={1809.00219},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

esrgan-pytorch's People

Contributors

ha0tang, lornatang

esrgan-pytorch's Issues

Missing release models.

It seems as though the models to be downloaded by download_weights.py are not in the project releases. Note ESRGAN_PSNR_X4.pth is available elsewhere online but I can't seem to find a copy of ESRGAN_RRDB_X4.pth anywhere.

Would you be able to publish those releases?

Thank you!

Question about pretrained model

Hi, Mr. Liu.
Thanks for your concise code! As a beginner, I learned a lot from reading it.

I have a question about the pretrained model. I cannot download it from the link "https://github.com/Lornatang/ESRGAN-PyTorch/releases/download/0.1.0/ESRGAN16_DF2K-a03a643d.pth" when I run "python test_benchmark.py -a esrgan16 --pretrained --gpu 0 data" on the command line. Opening the link in Chrome also gives "404 not found".
Either the pretrained model link is broken, or I am fetching the model the wrong way.

I am looking forward to your reply.
Thanks again.

Pretrained weights

Hi, Lornatang.

First of all, thanks for your great work!

I just found a new issue on Jul 17 about the Adversarial loss term.
I wonder if the provided pretrained weights of RRDB/ESRGAN/Discriminator are...

  1. trained weights from your previous code (the code with the error specified in the previous github issue)
  2. or pretrained weights from the official github (possibly slightly modified to match your implementation of ESRGAN)

If it is for the first case, are the results in the README.md file also wrong?
To me, the error in the adversarial term seems to have a severe impact on the generator performance, since the generator would have been trained in the opposite way. (But, both the qualitative and quantitative results seemed quite similar to the original paper.)

convert to onnx

Hello, how can I convert the pth.tar to ONNX? Thank you.
Sorry, solved: I converted to ONNX directly from PyTorch.

Inquiries on AMP

Hello Lornatang.

Thanks for your great work.
I'm very impressed with the simple/straightforward yet efficient re-implementation of published works.

To the best of my knowledge, a slight performance decrease can occur when training generative models with mixed precision.
Although this research was based on unconditional generative models, I believe it applies equally to conditional generative models such as SR networks.

Thus, I was wondering if there are any ablations on the performance with and without AMP. (Either RRDB or ESRGAN will help)
Also, I'm curious about how much faster the training is with and without AMP.

Thank you in advance.

Can you reproduce the psnr?

I followed the instructions and tried hyper-parameter tuning: 100, 200, and 300 epochs, learning rate 2e-4, image_size 128 or 192, batch size 16, 32, or 64. I can only get 31.7 dB PSNR on the Set5 test set, still 1 dB below the value reported in this repo. Can anyone reproduce the 32.7 dB on Set5?

Negative Values after subtracting mean in a

First of all: Thank you very much for sharing your code!

It happens from time to time that in the train_adversarial function the final tensor in the loss function becomes negative:
d_loss_hr = adversarial_criterion(hr_output - torch.mean(sr_output), real_label)
(Mostly when torch.mean(sr_output) is larger than the hr_output values)

Is that only a problem on my side, or is that known?
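A note on the question above: the difference `hr_output - torch.mean(sr_output)` is a relativistic logit, so negative values are expected whenever the discriminator scores the SR batch above a given HR sample. Assuming `adversarial_criterion` is a binary cross-entropy on logits (as with `torch.nn.BCEWithLogitsLoss`; this is an assumption about the repository, not something the issue confirms), a negative input is valid and simply yields a larger loss against a real label. A minimal pure-Python sketch of that loss, using a hypothetical `bce_with_logits` helper:

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy on a raw logit.

    Equivalent to -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))].
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# The loss stays finite and positive for positive, zero, and negative logits.
for logit in (2.0, 0.0, -2.0):
    print(f"logit={logit:+.1f}  loss vs real label={bce_with_logits(logit, 1.0):.4f}")
```

If the criterion were instead a plain BCE applied to raw values without a sigmoid, inputs outside [0, 1] would be invalid, so it is worth checking which criterion the training script actually uses.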

pin memory and PyTorch version

Hi, after training finished one epoch, the following problem occurred. Is it a PyTorch version mismatch? My PyTorch version is 1.7.

Start train ESRGAN model.
Epoch: [1][ 1000/16696]	Time  0.370 ( 0.382)	Data  0.001 ( 0.001)	Pixel loss 0.037700 (0.052648)	Content loss 0.822794 (1.104305)	Adversarial loss 0.012794 (0.010927)	D(HR)  0.995 ( 0.988)	D(SR)  0.000 ( 0.006)	PSNR 22.59 (21.44)
Epoch: [1][ 2000/16696]	Time  0.379 ( 0.380)	Data  0.001 ( 0.001)	Pixel loss 0.029337 (0.041848)	Content loss 0.697089 (0.903808)	Adversarial loss 0.013856 (0.012345)	D(HR)  0.998 ( 0.991)	D(SR)  0.000 ( 0.003)	PSNR 24.02 (22.79)
Epoch: [1][ 3000/16696]	Time  0.370 ( 0.379)	Data  0.001 ( 0.001)	Pixel loss 0.033896 (0.037278)	Content loss 0.733463 (0.811156)	Adversarial loss 0.014888 (0.013205)	D(HR)  0.998 ( 0.993)	D(SR)  0.000 ( 0.002)	PSNR 23.61 (23.53)
Epoch: [1][ 4000/16696]	Time  0.380 ( 0.379)	Data  0.001 ( 0.001)	Pixel loss 0.027832 (0.034183)	Content loss 0.640709 (0.751329)	Adversarial loss 0.014274 (0.013682)	D(HR)  0.996 ( 0.994)	D(SR)  0.000 ( 0.002)	PSNR 26.16 (24.17)
Epoch: [1][ 5000/16696]	Time  0.380 ( 0.378)	Data  0.001 ( 0.001)	Pixel loss 0.022432 (0.032105)	Content loss 0.532851 (0.710133)	Adversarial loss 0.012825 (0.013611)	D(HR)  0.813 ( 0.972)	D(SR)  0.000 ( 0.002)	PSNR 27.34 (24.69)
Epoch: [1][ 6000/16696]	Time  0.379 ( 0.378)	Data  0.001 ( 0.001)	Pixel loss 0.024761 (0.030410)	Content loss 0.607785 (0.679033)	Adversarial loss 0.014104 (0.013516)	D(HR)  0.356 ( 0.945)	D(SR)  0.000 ( 0.002)	PSNR 27.39 (25.16)
Epoch: [1][ 7000/16696]	Time  0.380 ( 0.378)	Data  0.001 ( 0.001)	Pixel loss 0.019196 (0.029033)	Content loss 0.508872 (0.654894)	Adversarial loss 0.013619 (0.013771)	D(HR)  0.991 ( 0.935)	D(SR)  0.000 ( 0.002)	PSNR 28.68 (25.58)
Epoch: [1][ 8000/16696]	Time  0.380 ( 0.379)	Data  0.001 ( 0.001)	Pixel loss 0.018326 (0.027872)	Content loss 0.526340 (0.634764)	Adversarial loss 0.011635 (0.013911)	D(HR)  0.000 ( 0.847)	D(SR)  0.000 ( 0.002)	PSNR 28.88 (25.96)
Epoch: [1][ 9000/16696]	Time  0.384 ( 0.379)	Data  0.001 ( 0.001)	Pixel loss 0.019149 (0.026855)	Content loss 0.552133 (0.617711)	Adversarial loss 0.012062 (0.014202)	D(HR)  0.000 ( 0.791)	D(SR)  0.000 ( 0.002)	PSNR 29.79 (26.32)
Epoch: [1][10000/16696]	Time  0.379 ( 0.379)	Data  0.001 ( 0.001)	Pixel loss 0.015999 (0.025921)	Content loss 0.565542 (0.602837)	Adversarial loss 0.015624 (0.014473)	D(HR)  0.014 ( 0.726)	D(SR)  0.000 ( 0.001)	PSNR 30.17 (26.66)
Epoch: [1][11000/16696]	Time  0.380 ( 0.379)	Data  0.001 ( 0.001)	Pixel loss 0.018763 (0.025180)	Content loss 0.423601 (0.590442)	Adversarial loss 0.019162 (0.014849)	D(HR)  0.000 ( 0.669)	D(SR)  0.000 ( 0.001)	PSNR 29.96 (26.96)
Epoch: [1][12000/16696]	Time  0.379 ( 0.379)	Data  0.001 ( 0.001)	Pixel loss 0.012981 (0.024501)	Content loss 0.429494 (0.579708)	Adversarial loss 0.021525 (0.015203)	D(HR)  0.006 ( 0.632)	D(SR)  0.000 ( 0.001)	PSNR 31.68 (27.22)
Epoch: [1][13000/16696]	Time  0.380 ( 0.379)	Data  0.001 ( 0.001)	Pixel loss 0.021055 (0.023893)	Content loss 0.483494 (0.569777)	Adversarial loss 0.021465 (0.015379)	D(HR)  0.000 ( 0.597)	D(SR)  0.000 ( 0.001)	PSNR 29.54 (27.47)
Epoch: [1][14000/16696]	Time  0.380 ( 0.379)	Data  0.001 ( 0.001)	Pixel loss 0.017674 (0.023372)	Content loss 0.444889 (0.561240)	Adversarial loss 0.017812 (0.015906)	D(HR)  0.000 ( 0.563)	D(SR)  0.000 ( 0.001)	PSNR 30.26 (27.69)
Epoch: [1][15000/16696]	Time  0.380 ( 0.379)	Data  0.001 ( 0.001)	Pixel loss 0.017920 (0.022861)	Content loss 0.408098 (0.553281)	Adversarial loss 0.020351 (0.016176)	D(HR)  0.002 ( 0.536)	D(SR)  0.000 ( 0.001)	PSNR 30.25 (27.92)
Epoch: [1][16000/16696]	Time  0.379 ( 0.379)	Data  0.001 ( 0.001)	Pixel loss 0.013671 (0.022394)	Content loss 0.497734 (0.545935)	Adversarial loss 0.016718 (0.016400)	D(HR)  0.000 ( 0.510)	D(SR)  0.000 ( 0.002)	PSNR 31.51 (28.12)
Valid: [   0/3843]	Time  0.408 ( 0.408)	PSNR 37.65 (37.65)
Valid: [1000/3843]	Time  0.080 ( 0.081)	PSNR 34.63 (36.69)
Valid: [2000/3843]	Time  0.081 ( 0.081)	PSNR 37.63 (36.68)
Valid: [3000/3843]	Time  0.097 ( 0.082)	PSNR 37.98 (36.69)
* PSNR: 36.67.

Exception in thread Thread-2:
Traceback (most recent call last):
  File "/home/myuser/.conda/envs/esrgan_lornatang/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/home/myuser/.conda/envs/esrgan_lornatang/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/myuser/.conda/envs/esrgan_lornatang/lib/python3.7/site-packages/torch/utils/data/_utils/pin_memory.py", line 28, in _pin_memory_loop
    idx, data = r
ValueError: not enough values to unpack (expected 2, got 0)

Traceback (most recent call last):
  File "train_esrgan.py", line 390, in <module>
    main()
  File "train_esrgan.py", line 86, in main
    writer)
  File "train_esrgan.py", line 216, in train
    for index, (lr, hr) in enumerate(train_dataloader):
  File "/home/myuser/.conda/envs/esrgan_lornatang/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 349, in __iter__
    self._iterator._reset(self)
  File "/home/myuser/.conda/envs/esrgan_lornatang/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 852, in _reset
    data = self._get_data()
  File "/home/myuser/.conda/envs/esrgan_lornatang/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1029, in _get_data
    raise RuntimeError('Pin memory thread exited unexpectedly')
RuntimeError: Pin memory thread exited unexpectedly

About train_gan.py

d_model = model.__dict__[config["MODEL"]["D"]["NAME"]](
    in_channels=config["MODEL"]["D"]["IN_CHANNELS"],
    out_channels=config["MODEL"]["D"]["OUT_CHANNELS"],
    channels=config["MODEL"]["D"]["CHANNELS"],
    upsample_method=config["MODEL"]["D"]["UPSAMPLE_METHOD"],
)
config["MODEL"]["D"]["UPSAMPLE_METHOD"] is not given in the config file. What does it do?

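Loss function of the relativistic GAN

Hi!
I don't quite understand how the adversarial loss is set up in train_esrgan.py.
In the code, the adversarial loss is the same when updating the D network and when updating the G network.
My understanding is that when updating D:
For HR: d_hr_loss = adversarial_criterion( Dra(xr,xf) , real_label)
For SR: d_sr_loss = adversarial_criterion( Dra(xf,xr) , fake_label)
and when updating G:
For HR: d_hr_loss = adversarial_criterion( Dra(xr,xf) , fake_label)
For SR: d_sr_loss = adversarial_criterion( Dra(xf,xr) , real_label)
Only then does it match the discriminator and generator loss formulas in the paper.
Is my understanding of the relativistic discriminator loss wrong, or is this a slip in the code?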
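For reference, the ESRGAN paper defines the relativistic average discriminator as D_Ra(x_r, x_f) = sigmoid(C(x_r) - E[C(x_f)]). The discriminator loss pushes D_Ra(x_r, x_f) toward 1 and D_Ra(x_f, x_r) toward 0, and the generator loss uses the same two terms with the labels swapped. The sketch below spells this out on plain lists of discriminator logits. The function names are hypothetical and this is not the repository's training code.

```python
import math


def bce_logits(x, t):
    # Binary cross-entropy on a raw logit x with target t in {0, 1}.
    return max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))


def mean(values):
    return sum(values) / len(values)


def relativistic_losses(real_logits, fake_logits):
    """Return (d_loss, g_loss) for the relativistic average GAN."""
    real_rel = [r - mean(fake_logits) for r in real_logits]  # C(x_r) - E[C(x_f)]
    fake_rel = [f - mean(real_logits) for f in fake_logits]  # C(x_f) - E[C(x_r)]
    # Discriminator: real should look more real than fake (label 1),
    # fake should look less real than real (label 0).
    d_loss = (mean([bce_logits(x, 1.0) for x in real_rel])
              + mean([bce_logits(x, 0.0) for x in fake_rel]))
    # Generator: the same two terms with the labels swapped.
    g_loss = (mean([bce_logits(x, 0.0) for x in real_rel])
              + mean([bce_logits(x, 1.0) for x in fake_rel]))
    return d_loss, g_loss


# Toy logits where the discriminator is confidently correct.
d_loss, g_loss = relativistic_losses(real_logits=[2.0, 1.5], fake_logits=[-1.0, -0.5])
print(f"d_loss={d_loss:.4f}  g_loss={g_loss:.4f}")
```

With a confident discriminator, as in the toy logits, `d_loss` is small and `g_loss` is large, which is the asymmetry the two label assignments are meant to produce.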

a small bug in train_gan.py ?

Hi, thanks for your code and your active explanations in the issues, I have benefited a lot from them. It seems that lines 147 and 148 in train_gan.py both use "g_optimizer". Should the second one be "d_optimizer"?

one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 3; expected version 2 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient

When the code reaches 'scaler.scale(d_loss).backward()', I get an error: "one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 3; expected version 2 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True)."

It has been bothering me for many days, can someone help me?

About release model

Hi

Could you release a pretrained model of this repo?
You've mentioned ESRGAN_X4.pth in the README but I just can't find it.
I'm just curious about the performance of ESRGAN, and I'm going to retrain the model with my project's dataset.
Thanks very much!

The adversarial loss

Thanks very much for your work, I have benefited a lot from it. But I found that your adversarial loss is a little different from the paper. Maybe it is similar to SRGAN's adversarial loss, which only makes fake_img more realistic? As I understand it, ESRGAN also makes real_img less realistic. Or do I misunderstand it?

I want to change it to match the paper, like this change in your code:

adversarial_loss = adversarial_criterion(fake_output - torch.mean(real_output), real_label) + \
                   adversarial_criterion(real_output - torch.mean(fake_output), fake_label)

Is it right? Thank you in advance for your help.

Why the volatile gpu util is low(less than 20%)

Hello, I trained your program on my laptop with a 2080 Ti, but during training the volatile GPU utilization is low (5% - 18%), whereas your SRGAN-PyTorch model reaches nearly 100%. Why does the ESRGAN-PyTorch model have such low GPU utilization?

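Something seems off when resuming training

Hi, I noticed that when I resume a model to continue training, the resume succeeds, but the accuracy shown during continued training looks as if training had started from scratch.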

I Want to ask about your old code

Hello,
I'm using your previous code (before the update). That code didn't use sigmoid but torch.full. Does it produce a different result image?

I wonder if it would be better to experiment with your new code to check the ESRGAN results. Is there a big difference in the results?

Changing features.34

How should I change features.34 to features.54? Editing the config file directly raises an error.
ValueError: node: 'features.54' is not present in model. Hint: use get_graph_node_names to make sure the return_nodes you specified are present. It may even be that you need to specify train_return_nodes and eval_return_nodes separately.
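One possible reading of this error: in torchvision's plain VGG19 (no batch norm), `features` has 37 modules indexed 0 to 36, so `features.54` does not exist. The "54" in the paper's VGG19-54 notation refers to the 4th convolution before the 5th max-pooling layer, which lands at index 34. The sketch below rebuilds the index layout from the standard VGG19 configuration without importing torchvision. The `vgg19_layer_names` helper and its layer names are hypothetical labels chosen for illustration.

```python
# Standard VGG19 ("configuration E") layout:
# numbers are conv output channels, "M" is max-pooling.
VGG19_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]


def vgg19_layer_names(cfg=VGG19_CFG):
    """Map each index of the `features` Sequential to a readable name."""
    names, block, conv = [], 1, 0
    for item in cfg:
        if item == "M":
            names.append(f"pool{block}")
            block, conv = block + 1, 0
        else:
            conv += 1
            names.append(f"conv{block}_{conv}")  # convolution output (before activation)
            names.append(f"relu{block}_{conv}")  # activation that follows it
    return names


names = vgg19_layer_names()
print(len(names))              # 37 modules in `features`
print(names.index("conv5_4"))  # 34: 4th conv before the 5th pooling
```

Under this layout `features.34` already is the pre-activation conv5_4 output, so the config value probably does not need changing for the paper's setting.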

Test Spatial size?

What spatial size did you use for testing the model? I did not obtain the stated results.

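Have you tried larger scale factors with ESRGAN, e.g. X8 or X16?

Hi!
I tried using ESRGAN for 8x super-resolution. I used three upsampling layers in the Generator model to upscale 8x, and I also tried switching the upsampling method to pixelshuffle. In both cases the trained image quality is very poor. I set HR_size to 256, so LR_size is 32. Is the poor result because the LR is too small and upscale_factor too large? Or is the RRDBNet architecture simply not well suited to super-resolution at very large scale factors?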
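As a quick sanity check on the sizes mentioned above: with 2x upsampling stages, a scale factor of 8 needs three stages, and a 256-pixel HR crop corresponds to a 32-pixel LR input. The `upsample_plan` helper below is a hypothetical illustration of that arithmetic, not code from the repository.

```python
def upsample_plan(hr_size, scale):
    """Return (lr_size, number of 2x stages) for a power-of-two scale factor."""
    if scale < 1 or scale & (scale - 1):
        raise ValueError("scale must be a power of two")
    if hr_size % scale:
        raise ValueError("hr_size must be divisible by scale")
    stages = scale.bit_length() - 1  # log2(scale)
    return hr_size // scale, stages


for scale in (4, 8, 16):
    lr_size, stages = upsample_plan(256, scale)
    print(f"x{scale}: LR {lr_size}x{lr_size}, {stages} upsampling stages")
```

At a fixed HR crop size, each doubling of the scale factor halves the LR side length, so the network sees a quarter as many input pixels per sample.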

ESRGAN-Training

For my own 128*128 image dataset

TypeError: 'type' object is not subscriptable

Keep ratio input image not upscale

Thanks for your great repo! I have a question: how can I keep the original image size (height, width) after inference, instead of upscaling it 4 times?
If that is possible, could you share how to configure it?
Many thanks.

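Memory usage

I'd like to ask how much memory is needed to train this ESRGAN. I have a server with 16 GB of GPU memory, and it still reports out of memory.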

PSNR

What is the final PSNR of this model?

I ran the esrgan23 model with the default number of training epochs, but the final PSNR is less than 20. What's the matter?

Version Requirements not specified in Readme.

Hey!
Can you please list the version requirements of python, PyTorch, and torchvision?
I am not able to run the code because I am getting an error in importing torchvision.transforms and torchvision.models.
I am currently running the code on below mentioned versions:
pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3

Model Weights

Hi,
I saw a lot of different models inside the Google Drive folder that you provide for the weights of this model. Do you know if the implementations of the other models can be found on GitHub?
Thank you!

How to test on rrdbnet?

Great work! Just one question: how can I test my retrained rrdbnet? test.py only supports esrgan for now.

The problem when training 2X scale dataset

Thanks for the code. I have tried to train ESRGAN on my own dataset, and I only want to train at 2X scale.
But when I adjust the parameter in config.py, the error shown below occurs.

[image]

The error looks like this:

[image]

Do I need to change the code somewhere else? Thanks.
