rainnet's Introduction

RainNet — Official PyTorch Implementation

Sample image

Region-aware Adaptive Instance Normalization for Image Harmonization
Jun Ling, Han Xue, Li Song*, Rong Xie, Xiao Gu

Paper: link
Video: link

Update

  • 2021.07. We trained a 512x512-resolution model with several data augmentation methods, including random flips (horizontal and vertical) and random crops. The PSNR score is now 38.14. Download the model via Google Drive or Baidu Drive (code: n1fl).

Table of Contents

  1. Introduction
  2. Preparation
  3. Usage
  4. Results
  5. Citation
  6. Acknowledgement

Introduction

This work treats image harmonization as a style transfer problem. In particular, we propose a simple yet effective Region-aware Adaptive Instance Normalization (RAIN) module, which explicitly formulates the visual style of the background and adaptively applies it to the foreground. With our settings, the RAIN module can be used as a drop-in module for existing image harmonization networks and brings significant improvements. Extensive experiments on existing image harmonization benchmark datasets show the superior capability of the proposed method.
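
As a rough illustration of the idea, a region-aware AdaIN-style operation computes feature statistics under the background region of the mask and uses them to re-style the foreground features. The snippet below is a simplified sketch, not the exact implementation in this repository:

import torch

def rain_sketch(feat, mask, eps=1e-5):
    # feat: (N, C, H, W) feature map; mask: (N, 1, H, W), 1 = foreground.
    fg, bg = mask, 1.0 - mask

    def region_stats(x, region):
        # Channel-wise mean/std over the spatial positions inside `region`.
        area = region.sum(dim=(2, 3), keepdim=True).clamp(min=1.0)
        mean = (x * region).sum(dim=(2, 3), keepdim=True) / area
        var = ((x - mean) ** 2 * region).sum(dim=(2, 3), keepdim=True) / area
        return mean, (var + eps).sqrt()

    mu_fg, std_fg = region_stats(feat, fg)
    mu_bg, std_bg = region_stats(feat, bg)

    # Normalize the foreground with its own statistics, then re-style it
    # with the background statistics; the background is left unchanged here.
    fg_restyled = (feat - mu_fg) / std_fg * std_bg + mu_bg
    return feat * bg + fg_restyled * fg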

Preparation

1. Clone this repo:

git clone https://github.com/junleen/RainNet
cd RainNet

2. Requirements

  • Both Linux and Windows are supported, but Linux is recommended for compatibility reasons.
  • We have tested on Python 3.6 with PyTorch 1.4.0 and PyTorch 1.8.1+cu11.

Install the required packages using pip:

pip3 install -r requirements.txt

or conda:

conda create -n rainnet python=3.6
conda activate rainnet
pip install -r requirements.txt

3. Prepare the data

  • Download the iHarmony4 dataset and extract the images. Because the images in the original dataset are large, we suggest resizing them (e.g., to 512x512 or 256x256) and saving the resized copies on your local device.
  • We provide the code in data/preprocess_iharmony4.py. For example, you can run:
    python data/preprocess_iharmony4.py --dir_iharmony4 <DIR_of_iHarmony4> --save_dir <SAVE_DIR> --image_size <IMAGE_SIZE>
    This resizes the images to a fixed size, e.g., <image_size, image_size> (a minimal standalone sketch of this resizing step follows this list). If you want to keep the aspect ratio of the original images, run:
    python data/preprocess_iharmony4.py --dir_iharmony4 <DIR_of_iHarmony4> --save_dir <SAVE_DIR> --image_size <IMAGE_SIZE> --keep_aspect_ratio
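
If you only need the resizing behaviour, the core step is simple. Below is a minimal standalone sketch, assuming a flat folder of images; the real data/preprocess_iharmony4.py additionally handles the iHarmony4 directory layout and split files:

import os
from PIL import Image

def resize_folder(src_dir, save_dir, image_size=256, keep_aspect_ratio=False):
    # Resize every image in src_dir and save the result under save_dir.
    os.makedirs(save_dir, exist_ok=True)
    for name in os.listdir(src_dir):
        img = Image.open(os.path.join(src_dir, name))
        if keep_aspect_ratio:
            img.thumbnail((image_size, image_size), Image.BICUBIC)
        else:
            img = img.resize((image_size, image_size), Image.BICUBIC)
        img.save(os.path.join(save_dir, name))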

4. Download our pre-trained model

  • Download the pretrained model from Google Drive or Baidu Drive (code: 3qjk ), and put net_G_last.pth (not net_G.pth) in the directory checkpoints/experiment_train. You can also save the checkpoint in other directories and change the checkpoints_dir and name in /util/config.py accordingly.
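
To sanity-check the downloaded file, you can inspect the checkpoint directly. The snippet below is a minimal sketch, assuming net_G_last.pth stores a plain state dict; building the generator itself is done with the repo's own model code, so that part is only hinted at in comments:

import torch

# Load the generator weights on the CPU first to avoid device issues.
state_dict = torch.load('checkpoints/experiment_train/net_G_last.pth',
                        map_location='cpu')
print(len(state_dict), 'tensors, e.g.', list(state_dict)[:3])

# Construct the generator with the repo's model code, then load the weights:
# net_G = <build the generator as evaluate.py / test.py does>
# net_G.load_state_dict(state_dict)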

Usage

1. Evaluation

We provide the code in evaluate.py, which supports model evaluation on the iHarmony4 dataset.

Run:

python evaluate.py --dataset_root <DATA_DIR> --save_dir evaluated --batch_size 16 --device cuda 

If you want to save the harmonized images, add --store_image at the end of the command. The evaluation results will be saved in the evaluated directory.

2. Testing with your own examples

In this project, we also provide simple testing code in test.py to help you test other cases. You are required to assign the image paths in the file for each trial. For example:

comp_path = 'examples/1.png' or ['examples/1.png', 'examples/2.png']
mask_path = 'examples/1-mask.png' or ['examples/1-mask.png', 'examples/2-mask.png']
real_path = 'examples/1-gt.png' or ['examples/1-gt.png', 'examples/2-gt.png']

If there is no ground-truth image, you can set real_path to None.
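
The sketch below shows how the three lists pair up per example; the exact resizing and normalization pipeline lives in test.py and is only hinted at here:

from PIL import Image
import torchvision.transforms.functional as TF

comp_path = ['examples/1.png', 'examples/2.png']
mask_path = ['examples/1-mask.png', 'examples/2-mask.png']
real_path = [None, None]  # no ground truth available

for comp_p, mask_p, real_p in zip(comp_path, mask_path, real_path):
    comp = TF.to_tensor(Image.open(comp_p).convert('RGB')).unsqueeze(0)
    mask = TF.to_tensor(Image.open(mask_p).convert('L')).unsqueeze(0)
    # test.py additionally resizes and normalizes these tensors before
    # feeding them to the generator; see that file for the exact pipeline.
    print(comp.shape, mask.shape, 'ground truth:', real_p)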

3. Training your own model

Please update the command arguments in scripts/train.sh and run:

bash scripts/train.sh

Results

[Comparison figure 1] [Comparison figure 2]

Citation

If you use our code or find this work useful for your future research, please kindly cite our paper:

@inproceedings{ling2021region,
  title={Region-aware Adaptive Instance Normalization for Image Harmonization},
  author={Ling, Jun and Xue, Han and Song, Li and Xie, Rong and Gu, Xiao},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={9361--9370},
  year={2021}
}

Acknowledgement

Some of the data modules and model functions in this source code are adapted from the DoveNet and pix2pix repositories, which we gratefully acknowledge.


rainnet's Issues

Testing with my own images

Hello, I ran into some problems when using the provided test code on a single image:
[screenshot]
My image does not have a foreground mask, so I cannot provide a mask_path. Can the model perform harmonization without a foreground mask?

About the dataset

Training uses the resized dataset; for testing, should I use the original dataset or the resized one?

test accuracy

Hi, I reproduced the test performance (PSNR: 35.8788, MSE: 44.5023 in my experiment) using your provided model weights (net_G.pth, net_G_last.pth).
However, the results differ from your paper (PSNR: 36.12, MSE: 40.29); in particular, the MSE is much higher than reported in the paper. Could you help me solve this issue? Thanks.

About resolution

I have a question: can the model in this project handle high-resolution (2K, 4K) images or videos, or is the output always 256 regardless of the input resolution? Is there any way to achieve high-resolution image harmonization?
[screenshot]

This is my result when testing with a 1920*1080 input; the output is at 256 resolution.
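
One common workaround (not specific to this repo) for fixed-resolution harmonization models is to harmonize a downscaled copy and then blend the upsampled harmonized foreground back into the full-resolution composite using the mask; a minimal sketch:

import torch.nn.functional as F

def paste_back(harmonized_lowres, comp_full, mask_full):
    # Upsample the low-resolution harmonized output to the composite's size
    # and blend it back inside the foreground mask.
    up = F.interpolate(harmonized_lowres, size=comp_full.shape[-2:],
                       mode='bilinear', align_corners=False)
    return up * mask_full + comp_full * (1.0 - mask_full)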

Interesting work, but gamma and beta are handled on shifted distributions (background style distribution)

Interesting idea.

However, the $\gamma$ and $\beta$ of the background are applied to features that follow a standard normal distribution, while the foreground parameters are applied to features whose distribution has been style-shifted (to the mean and std of the background).

It just doesn't seem intuitive to me to transfer the style of the background to the foreground.

Maybe performing only the region norm is enough, with the background norm and foreground norm sharing the same $\gamma$ and $\beta$ shifting.

A problem with test.py

Following the instructions, I tried to run test.py on the provided examples (examples/1.png, etc.), but I got the error: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same.
[screenshot]
Has anyone encountered a similar problem?
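
This error usually means the input tensors were moved to the GPU while the model weights stayed on the CPU (or vice versa); putting the model and its inputs on the same device normally resolves it. A minimal sketch, where net_G, comp and mask stand in for whatever test.py builds:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Both the model and all of its inputs must live on the same device:
# net_G = net_G.to(device)
# comp, mask = comp.to(device), mask.to(device)
# output = net_G(comp, mask)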

About running evaluate.py

The installation completed without errors, but when running I got this error:
[screenshot]
I added the following code at the top of evaluate.py:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
After that, this problem appeared:
[screenshots]

Training dataset

Hi, I have downloaded the iHarmony4 dataset and extracted it, but when I run the script python data/preprocess_iharmony4.py --dir_iharmony4 <DIR_of_iHarmony4> --save_dir <SAVE_DIR> --image_size <IMAGE_SIZE>, it raises:

Traceback (most recent call last):
  File "data/preprocess_iharmony4.py", line 22, in <module>
    with open(os.path.join(args.dir_iharmony4, 'IHD_train.txt'), 'r') as f:
FileNotFoundError: [Errno 2] No such file or directory: '.../datasets/iHarmony4/IHD_train.txt'

I have searched through the dataset but the file IHD_train.txt is not found.

The PyTorch version for the new model weights trained on 512-resolution images

I tried the new weight file trained on 512x512 images, but I get the following error from state_dict = torch.load(load_path):

RuntimeError: version_ <= kMaxSupportedFileFormatVersion INTERNAL ASSERT FAILED at /opt/conda/conda-bld/pytorch_1579022060824/work/caffe2/serialize/inline_container.cc:132, please report a bug to PyTorch. Attempted to read a PyTorch file with version 3, but the maximum supported version for reading is 2. Your PyTorch installation may be too old. (init at /opt/conda/conda-bld/pytorch_1579022060824/work/caffe2/serialize/inline_container.cc:132)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x47 (0x7f7024113627 in /opt/conda/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: caffe2::serialize::PyTorchStreamReader::init() + 0x1f5b (0x7f70286ac9ab in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch.so)

According to the suggestion in the error message, my PyTorch version (1.4.0) is too old. What PyTorch version was used when training this model? Thanks :)
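
That RuntimeError indicates the 512x512 checkpoint was saved with the zip-based serialization format introduced in PyTorch 1.6, which PyTorch 1.4 cannot read. One option is simply to upgrade PyTorch (the README mentions 1.8.1 was tested); alternatively, from an environment with PyTorch >= 1.6 you can re-save the weights once in the legacy format, as sketched below:

import torch

# Run once with PyTorch >= 1.6; the resulting file can be loaded by 1.4.
state_dict = torch.load('net_G_last.pth', map_location='cpu')
torch.save(state_dict, 'net_G_last_legacy.pth',
           _use_new_zipfile_serialization=False)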

About the training procedure

Hello, are the discriminator and generator both updated on every batch? When I train this way the model does not seem to converge; the MSE will not go down.

About GPU selection during training

Hello, I could not find any code in the training scripts that specifies which GPU IDs to use, but train.sh contains a related argument. Does this gpu nums argument enable multi-GPU training? Which file contains the relevant training code? Looking forward to your reply.

Question about normalization

Is there any specific reason why the discriminator applies both spectral and instance normalization in its forward function?
Based on the paper, only spectral normalization is used for the discriminator network.

A question about an argument in options

Why is the '--load_size' argument of type str? Running the training script raises an error; after changing the type to int and the default value from '0' to 0, it runs normally.
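
For reference, the fix described above amounts to declaring the option as an integer; a minimal sketch, assuming a standard argparse setup:

import argparse

parser = argparse.ArgumentParser()
# Declare --load_size as an int with an integer default instead of a string.
parser.add_argument('--load_size', type=int, default=0)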

About the total number of epochs

The paper uses 100 epochs in total, but the default options give 100 + 100 = 200 epochs. Should they be set to 50 + 50 instead?

Training data format

Hi, when training RainNet, does it need any ground-truth labels (real images), or is a pair consisting of a composite image and its segmentation mask enough? Thank you.

error in line 778, networks.py

Hi junleen

I tried to run the training code, but an error occurs at

feat_l, feat_g = torch.cat([xf, xb])

File "/ssd3/vis/lintianwei/project/harmonization/RainNet-main/models/networks.py", line 778, in forward
feat_l, feat_g = torch.cat([xf, xb])
ValueError: too many values to unpack (expected 2)

Actually, feat_l and feat_g are not used during training. Is this a bug?
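
For reference, torch.cat returns a single tensor, so unpacking its result into two names raises exactly this ValueError. A minimal sketch of the behaviour and one possible workaround (shapes are illustrative):

import torch

xf = torch.randn(4, 64, 32, 32)   # foreground features (illustrative shape)
xb = torch.randn(4, 64, 32, 32)   # background features (illustrative shape)

# torch.cat returns one tensor of shape (8, 64, 32, 32); it cannot be
# unpacked into two variables.
feat = torch.cat([xf, xb], dim=0)

# If two tensors really are needed, split the result back explicitly:
feat_l, feat_g = feat.split(xf.size(0), dim=0)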
