
vious / lbam_pytorch

Stars: 131 · Watchers: 3 · Forks: 21 · Size: 1.86 MB

PyTorch re-implementation of the paper Image Inpainting with Learnable Bidirectional Attention Maps (ICCV 2019)

Home Page: http://openaccess.thecvf.com/content_ICCV_2019/papers/Xie_Image_Inpainting_With_Learnable_Bidirectional_Attention_Maps_ICCV_2019_paper.pdf

License: MIT License

Python 100.00%
image-inpainting iccv pytorch-implementation

lbam_pytorch's Introduction

LBAM_inpainting

Introduction

This is the PyTorch implementation of the paper Image Inpainting With Learnable Bidirectional Attention Maps (ICCV 2019).

Model Architecture

We propose a Bidirectional Attention model based on the U-Net architecture.

Bidirectional Attention Layer
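The attention step of the layer can be sketched per element. This is a minimal scalar sketch, assuming the asymmetric Gaussian-shaped attention activation g_A described in the paper; the constants a = 1.1, μ = 2.0, γ_l = γ_r = 1.0 are assumed initial values (learnable in the real model), and the names `gauss_activation` and `attend` are hypothetical, not functions from this repository. The convolutions are omitted: `mask_conv` stands for an already-convolved mask value.

```python
import math

# Assumed initial values of the learnable parameters (illustrative only).
A, MU, GAMMA_L, GAMMA_R = 1.1, 2.0, 1.0, 1.0


def gauss_activation(m, a=A, mu=MU, gamma_l=GAMMA_L, gamma_r=GAMMA_R):
    """Asymmetric Gaussian-shaped attention activation g_A(m) (assumed form).

    Left of mu it rises from ~0 to a; right of mu it decays from a toward 1.
    Both branches equal a at m == mu, so the function is continuous.
    """
    if m < mu:
        return a * math.exp(-gamma_l * (m - mu) ** 2)
    return 1.0 + (a - 1.0) * math.exp(-gamma_r * (m - mu) ** 2)


def attend(features, mask_conv):
    """Re-weight each feature value by the attention map g_A(conv(mask))."""
    return [f * gauss_activation(m) for f, m in zip(features, mask_conv)]
```

Mask responses near μ are slightly amplified (g_A(μ) = a), large responses are left close to 1, and small responses, typical of hole regions, are suppressed toward 0.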

Prerequisites

  • Python 3.6
  • PyTorch >= 1.0 (tested on PyTorch 1.0.0, 1.2.0, and 1.3.0)
  • CPU or NVIDIA GPU + CUDA + cuDNN

Training

To train the LBAM model:

python train.py --batchSize numOf_batch_size --dataRoot your_image_path \
--maskRoot your_mask_root --modelsSavePath path_to_save_your_model \
--logPath path_to_save_tensorboard_log [--pretrain pretrained_model_path]

The --pretrain argument is optional.

Testing

To test the model:

python test.py --input input_image --mask your_mask --output output_file_prefix --pretrain pretrained_model_path

To test on a random batch with random masks:

python test_random_batch.py --dataRoot your_image_path \
--maskRoot your_mask_path --batchSize numOf_batch_size --pretrain pretrained_model_path

Some Results

We suggest training our model with a large batch size (>= 48 or so). When we re-trained the model with batch size 10, the results degraded slightly; I suspect this is due to the batch-normalization operation (I will try removing BN from LBAM to see how it affects the results).

The pretrained model is available on Google Drive, or on Baidu Cloud with extraction code mvzh. In this version I made a slight change: batch normalization is disabled (bn set to false), and the final activation is changed from the absolute value of tanh to (tanh() + 1) / 2.
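The output-activation change can be illustrated on scalars with Python's `math` module; `abs_tanh` and `shifted_tanh` are illustrative names, not code from this repository. Both keep outputs within [0, 1], but |tanh(x)| sends x and -x to the same value, whereas (tanh(x) + 1) / 2 is monotonic, so each output corresponds to a single pre-activation. It is also identical to the logistic sigmoid of 2x.

```python
import math


def abs_tanh(x):
    """Old output activation: |tanh(x)|; folds negative inputs onto positive ones."""
    return abs(math.tanh(x))


def shifted_tanh(x):
    """New output activation: (tanh(x) + 1) / 2, a monotonic map onto (0, 1)."""
    return (math.tanh(x) + 1.0) / 2.0
```

For example, `abs_tanh(-1.0)` and `abs_tanh(1.0)` coincide, while `shifted_tanh(0.0)` is 0.5 and `shifted_tanh` keeps -1.0 and 1.0 apart.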

Here are some inpainting results from the model trained with a batch size of 10 on the Paris StreetView dataset:

[Figure: Input | Results | Ground-Truth]

If you find this code useful

Please cite our paper:

@InProceedings{Xie_2019_ICCV,
author = {Xie, Chaohao and Liu, Shaohui and Li, Chao and Cheng, Ming-Ming and Zuo, Wangmeng and Liu, Xiao and Wen, Shilei and Ding, Errui},
title = {Image Inpainting With Learnable Bidirectional Attention Maps},
booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
month = {October},
year = {2019}
}

Acknowledgement

We benefited a lot from NVIDIA-partialconv and naoto0804-pytorch-inpainting-with-partial-conv; thanks for their excellent work.


lbam_pytorch's Issues

About code error

Hi, thank you for your great project!
The code works up to PyTorch 1.4, but there seems to be a problem with PyTorch 1.6. The error is as follows: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 1024, 4, 4]] is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
Could you update the code for PyTorch 1.5 or 1.6? 😂

about inputImage, groundTruth, mask

Thanks for your excellent work! I have some question.
In dataloader.py:
mask = 1 - mask
inputImage = groundTruth * mask
Here, do zeros in mask denote the area to be inpainted and ones the pixels that are kept?
If so, zeros denote the area to be inpainted in inputImage.
Second question:
For large holes the results are somewhat poor, with severe artifacts. Are there ways to address this?
Looking forward to your reply.
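The two dataloader lines quoted in the first question can be traced on a tiny example. This is a minimal sketch on nested lists, assuming the mask file as loaded marks holes with 1 and kept pixels with 0 (whether your mask images follow that convention is an assumption to check); `apply_mask` is a hypothetical helper, not a function from dataloader.py.

```python
def apply_mask(ground_truth, mask):
    """Mirror `mask = 1 - mask` and `inputImage = groundTruth * mask`.

    Assumes the loaded mask uses 1 for hole pixels and 0 for kept pixels.
    Returns (input_image, inverted_mask).
    """
    # mask = 1 - mask: after inversion, 1 = kept pixel, 0 = hole
    inverted = [[1 - v for v in row] for row in mask]
    # inputImage = groundTruth * mask: hole pixels become 0
    input_image = [[g * m for g, m in zip(g_row, m_row)]
                   for g_row, m_row in zip(ground_truth, inverted)]
    return input_image, inverted
```

With ground truth `[[5, 6], [7, 8]]` and a loaded mask `[[0, 1], [0, 0]]`, the inverted mask is `[[1, 0], [1, 1]]` and the input image is `[[5, 0], [7, 8]]`: under this assumption, zeros mark the area to be inpainted after inversion.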

Help with generateMask.py

Hi! Thank you for your great project!
Could you explain in detail how to use generateMask.py to generate the mask files?
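While waiting for details on generateMask.py, a simple rectangular-hole mask can be produced as shown below. This is a hypothetical sketch, not the repository's generateMask.py: the function names are illustrative, it only covers rectangular holes (centred or randomly placed), and it assumes the 1 = hole, 0 = kept convention. Writing the result out as an image file is left to whichever image library you use.

```python
import random


def rectangle_mask(height, width, hole_h, hole_w, top, left):
    """height x width mask: 1 inside the rectangular hole, 0 elsewhere."""
    return [[1 if top <= r < top + hole_h and left <= c < left + hole_w else 0
             for c in range(width)]
            for r in range(height)]


def central_mask(height, width, hole_h, hole_w):
    """Hole centred in the image (as in a central-square-hole setting)."""
    return rectangle_mask(height, width, hole_h, hole_w,
                          (height - hole_h) // 2, (width - hole_w) // 2)


def random_mask(height, width, hole_h, hole_w, rng=random):
    """Hole placed uniformly at random, always fully inside the image."""
    top = rng.randint(0, height - hole_h)   # randint is inclusive on both ends
    left = rng.randint(0, width - hole_w)
    return rectangle_mask(height, width, hole_h, hole_w, top, left)
```

Passing a seeded `random.Random(seed)` as `rng` makes a generated mask set reproducible.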

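Questions about the program

Could I add you as a contact? I've run into quite a few problems and would like your guidance. I'm a beginner trying to reproduce a paper to keep myself motivated to learn. If that's OK, you can add me on QQ: 2635505974.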

Question about number of parameters

Hi,
I have a short question: how many parameters does your network have?
I printed the count with:

print(sum([np.prod(_.shape) for _ in netG.parameters()]))

in test.py and got 68.3M parameters for the generator. Could you confirm that this is correct?

Thanks!
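A count like this can be sanity-checked layer by layer. The helpers below are a hypothetical sketch of the standard per-layer arithmetic, not part of this repository, and the layer sizes in the usage note are illustrative rather than LBAM's actual configuration. A 2-D convolution has in·out·k·k weights plus one bias per output channel; an affine BatchNorm2d has a scale and a shift per channel, while its running statistics are buffers and are not returned by parameters().

```python
def conv2d_params(in_ch, out_ch, kernel, bias=True):
    """Parameters of a square-kernel 2-D convolution: weights plus optional bias."""
    return in_ch * out_ch * kernel * kernel + (out_ch if bias else 0)


def batchnorm_params(channels):
    """Affine BatchNorm2d: one weight and one bias per channel."""
    return 2 * channels
```

For instance, a single 4x4 convolution from 512 to 512 channels without bias already has `conv2d_params(512, 512, 4, bias=False)` = 4,194,304 parameters, so a few such layers in both encoder and decoder add up quickly. In PyTorch, `sum(p.numel() for p in netG.parameters())` gives the same total as the `np.prod` expression above.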

About the training mask set

Hi! Thank you for your great project!
Could you share the details of the training mask set? What are the 18,000 masks mentioned in your paper?

Possible problem in the loss code?

Hello, thank you for your excellent work and for open-sourcing the code.
I have some questions about InpaintingLoss.py:

D_loss = D_fake - D_real + gp

Why is it D_loss = D_fake - D_real + gp?
If D_real = D_real.mean().sum() * -1 and D_fake = D_fake.mean().sum() * 1, then D_real already carries the minus sign, so I think it should be
D_loss = D_fake + D_real + gp

Also, line 120:
GLoss = holeLoss + validAreaLoss + prcLoss + styleLoss + 0.1 * D_fake
should be
GLoss = holeLoss + validAreaLoss + prcLoss + styleLoss - 0.1 * D_fake

Thank you for your reply.
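The sign question can be checked against the standard WGAN-GP formulation (Improved Training of Wasserstein GANs), where the critic minimizes E[D(fake)] - E[D(real)] + gp and the generator's adversarial term is -E[D(fake)]. The sketch below uses plain lists of critic scores and hypothetical function names; it reproduces the sign conventions stated in the issue and does not claim to reflect what InpaintingLoss.py actually computes.

```python
def critic_loss(scores_real, scores_fake, gp):
    """WGAN-GP critic loss E[D(fake)] - E[D(real)] + gp."""
    d_real = -sum(scores_real) / len(scores_real)  # like D_real.mean().sum() * -1
    d_fake = sum(scores_fake) / len(scores_fake)   # like D_fake.mean().sum() * 1
    # d_real already carries the minus sign, so the terms are added.
    return d_fake + d_real + gp


def generator_adv_loss(scores_fake):
    """Generator adversarial term -E[D(fake)]: it wants fakes scored high."""
    return -sum(scores_fake) / len(scores_fake)
```

Under these conventions, if D_fake holds +E[D(fake)], the generator term enters GLoss with a minus sign, which matches the issue's suggested `- 0.1 * D_fake`.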

Training error

Hello, I'm sorry to disturb you. I am training the model on my own data, but training fails with this error:
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/usr/local/python3/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/usr/local/python3/lib/python3.6/site-packages/tensorboardX/event_file_writer.py", line 180, in run
    self._ev_writer.write_event(event)
  File "/usr/local/python3/lib/python3.6/site-packages/tensorboardX/event_file_writer.py", line 61, in write_event
    return self._write_serialized_event(event.SerializeToString())
  File "/usr/local/python3/lib/python3.6/site-packages/tensorboardX/event_file_writer.py", line 65, in _write_serialized_event
    self._py_recordio_writer.write(event_str)
  File "/usr/local/python3/lib/python3.6/site-packages/tensorboardX/record_writer.py", line 121, in write
    self._writer.flush()
OSError: [Errno 5] Input/output error
Do you have the same error?

Severe artifacts for central hole

@Vious Thanks for your excellent work! I'm training your model on the Paris StreetView dataset with a central square hole, but all the results suffer from severe artifacts like these:
[Result images: 350_100, 340_1, 350_16]
I'm training with batch size 16; do you have any idea why?

About the α

Hello, I would like to ask why α in formula 8 of the paper is set to 0.8. On what basis was this value chosen?
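To see what α controls, assume it is the exponent of the paper's mask-updating function g_M(M) = (ReLU(M))^α; that reading is an assumption here, and `mask_update` is a hypothetical name rather than a function from this repository. The sketch applies the function to scalars.

```python
def mask_update(m, alpha=0.8):
    """Assumed mask-updating function g_M(m) = ReLU(m) ** alpha."""
    return max(m, 0.0) ** alpha
```

With α < 1, partially valid responses between 0 and 1 are pushed up (0.5 becomes about 0.57), and large responses are compressed (4 becomes about 3.03), while negative responses are clipped to 0. This only illustrates the effect of the exponent; it does not explain why 0.8 in particular was chosen.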
