
deep-text-edit's Introduction

Deep-text-edit

This project aims to implement the neural network architecture described in Krishnan et al. (2021), TextStyleBrush.

Our implementation is unofficial and may differ from the original implementation. A link to the slides from our project presentation is also available.

How to run?

  • Install the requirements: pip install -r requirements.txt
  • Choose a config file in the src/config folder.
  • Log in to wandb if needed: wandb login
  • Download the models folder from the cloud. This folder contains all the pretrained models we use and should be placed in the root folder, as shown below.
  • Download the IMGUR5K dataset using the original download_imgur5k.py script, which you can find here. Cloning the whole original repo is the easiest way. Tip: there is a PR that parallelizes the image download.
  • Add the prepare_dataset.py script to that repo and run it to preprocess the files the way we did.
  • Put the prepared dataset in the data/ folder of this project.
  • Run python3 run.py './src/config/<chosen config>'. In most cases: python3 run.py './src/config/stylegan_adversarial.py'.
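Before launching a config, it can help to confirm that the folders from the steps above are in place. The helper below is a hypothetical sketch (missing_paths is not part of this repo, and run.py does not call it); it only checks that paths exist. The data/IMGUR5K/train/words.json path used in the test comes from an error reported in the issues below and may differ between configs.

```python
from pathlib import Path
from typing import Iterable, List, Union


def missing_paths(
    root: Union[str, Path], config_path: str, extra: Iterable[str] = ()
) -> List[str]:
    """Return the required paths (relative to root) that do not exist yet.

    Hypothetical pre-flight check, not part of the project.
    """
    root = Path(root)
    required = [
        "models",     # pretrained models downloaded from the cloud
        "data",       # prepared IMGUR5K dataset goes here
        config_path,  # e.g. './src/config/stylegan_adversarial.py'
        *extra,       # any further files a particular config reads
    ]
    # Path() collapses a leading './', so paths are reported in normalized form
    return [str(Path(p)) for p in required if not (root / p).exists()]
```

If missing_paths('.', './src/config/stylegan_adversarial.py') returns an empty list, the layout matches the steps above and you can run python3 run.py './src/config/stylegan_adversarial.py' as usual.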

Repo structure

├── run.py                  <- [entry point]
│
├── prepare_dataset.py      <- [our preprocess of images]
│
├── requirements.txt        <- [necessary requirements]
│
├── data                    <- [necessary data (including downloaded datasets)]
|
├── docs                    <- [docs and images]
|
├── models                  <- [pretrained models -- download this folder from cloud]
|
├── src                     <- [project source code]
│   ├── config 
│   │   ├── simple.py           <- [Template Config]
│   │   ├── gan.py
│   │   ├── ...
│   │
│   ├── data
│   │   ├── simple.py           <- [Template CustomDataset]
│   │   ├── ...
│   │
│   ├── disk
│   │   ├── disk.py             <- [Disk class to upload and download data from cloud]
│   │   ├── ...
│   │
│   ├── logger
│   │   ├── simple.py           <- [Logger class to log train and validation process]
│   │   ├── ...
│   │ 
│   ├── losses
│   │   ├── ocr.py              <- [Recognizer Loss]
│   │   ├── perceptual.py
│   │   ├── ...
│   │
│   ├── metrics
│   │   ├── accuracy.py         <- [Accuracy Metric]
│   │   ├── ...
│   │
│   ├── models
│   │   ├── ocr.py              <- [Model for CTC Loss]
│   │   ├── ...
│   │
│   ├── storage
│   │   ├── simple.py           <- [Storage class to save models' checkpoints]
│   │   ├── ...
│   │
│   ├── training
│   │   ├── simple.py           <- [Template Trainer]
│   │   ├── stylegan.py
│   │   ├── ...
│   │
│   ├── utils
│   │   ├── download.py         <- [Tool to download data from remote to cloud]
│   │   ├── ...
│   │
│   ├── ...

Architecture

We started with a very simple architecture, shown below:

Baseline

We call it the baseline; you can find its config here. We built it because we could, and because we needed something to set up the workspace.

In the end, we arrived at this architecture, which is very similar to TextStyleBrush:

final architecture

You can find its config here. It's not perfect, but we did our best; you can check out the results below.

Before you do, note the following differences from the original paper:

Subject             Us                          TextStyleBrush
Generator           StyleGAN                    StyleGAN2
Encoders            ResNet18                    ResNet34
Style loss model    VGG16                       VGG19
Input style size    64 x 192                    256 x 256
Input content size  64 x 192                    64 x 256
Soft masks          no                          yes
Adversarial loss    MSE                         non-saturating loss with regularization
Discriminator       NLayerDiscriminator         unknown
Text recognizer     TRBA                        unknown
Hardware            Google Colab resources :)   8 GPUs with 16 GB of RAM
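The adversarial loss row is one of the main training differences. The sketch below contrasts an MSE (least-squares) adversarial loss with the non-saturating loss. It is written in plain Python over lists of raw discriminator outputs (logits) so it runs without PyTorch. Using target labels 1 for real and 0 for fake in the MSE loss is an assumption about how it is set up here, and the regularization term listed for TextStyleBrush is omitted.

```python
import math
from typing import Sequence


def _mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


# MSE (least-squares) loss: regress D outputs to 1 for real and 0 for fake
def mse_d_loss(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    return _mean([(r - 1.0) ** 2 for r in d_real]) + _mean([f ** 2 for f in d_fake])


def mse_g_loss(d_fake: Sequence[float]) -> float:
    # the generator wants D to output the "real" label on its fakes
    return _mean([(f - 1.0) ** 2 for f in d_fake])


# Non-saturating loss on logits, using -log(sigmoid(x)) == softplus(-x)
def ns_d_loss(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    # -log D(real) - log(1 - D(fake)), where 1 - sigmoid(x) == sigmoid(-x)
    return _mean([softplus(-r) for r in d_real]) + _mean([softplus(f) for f in d_fake])


def ns_g_loss(d_fake: Sequence[float]) -> float:
    # maximize log D(G(z)) instead of minimizing log(1 - D(G(z)))
    return _mean([softplus(-f) for f in d_fake])
```

The non-saturating generator loss keeps gradients large when the discriminator confidently rejects fakes (very negative logits), which is the usual motivation for preferring it.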

Results

results 1

Datasets

Imgur5K

We trained our model on the Imgur5K dataset. You can download it by following the instructions in the original repo.

What we did: we downloaded the original repo from the link above and modified download_imgur5k.py slightly. We added the ability to resume the download from where it stopped after an exception, and the ability to run it in parallel. We do not publish this version because we were concerned about conflicts with their license. You can make these changes yourself or use the code from the PR in the official repo.
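The two changes described above, resuming after a failure and downloading in parallel, can be sketched as follows. This is not our modified script: download_all and the fetch callback are hypothetical, and saving each image as <id>.jpg is an assumption. Files that already exist are skipped, so rerunning after an exception picks up where the previous run stopped; threads suit this I/O-bound work.

```python
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Callable, Iterable, List, Optional, Union


def download_all(
    image_ids: Iterable[str],
    out_dir: Union[str, Path],
    fetch: Callable[[str], bytes],
    workers: int = 8,
) -> List[str]:
    """Download missing images in parallel and return the ids that failed."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Resume: only ids without a finished file are scheduled
    pending = [i for i in image_ids if not (out_dir / f"{i}.jpg").exists()]

    def job(image_id: str) -> Optional[str]:
        target = out_dir / f"{image_id}.jpg"
        try:
            data = fetch(image_id)
        except Exception:  # network errors, deleted images, ...
            return image_id  # report the failure; it is retried on the next run
        tmp = target.with_suffix(".part")
        tmp.write_bytes(data)
        tmp.replace(target)  # rename last so a half-written file never counts as done
        return None

    with ThreadPoolExecutor(max_workers=workers) as pool:
        results = list(pool.map(job, pending))
    return [r for r in results if r is not None]
```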

After that, we added prepare_dataset.py to that folder and ran it. The output of this script is the dataset we use. Put it into the data/ folder of this project and you are ready to go.

Classes design

We did our best to make class names self-explanatory. Still, here is a short introduction:

  • The Config class, stored in src/config, contains the experiment configuration: model, loss functions, coefficients, optimizer, dataset info, tools, etc.
  • The Trainer class, stored in src/training, describes the training process: train and validation steps, loss calculation and backpropagation, etc.
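To make the split concrete, here is a toy, dependency-free sketch of the same idea: the Config only describes the experiment (losses, their coefficients, learning rate, data), while the Trainer owns the train and validation steps and combines the weighted losses. Everything here is hypothetical; the real classes in src/config and src/training work with PyTorch models and optimizers, and their names and methods may differ. The one-parameter model and hand-written gradients stand in for autograd.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Sequence, Tuple

# A loss maps (prediction, target) to (value, d value / d prediction)
LossFn = Callable[[float, float], Tuple[float, float]]


def mse(pred: float, target: float) -> Tuple[float, float]:
    return (pred - target) ** 2, 2.0 * (pred - target)


def l1(pred: float, target: float) -> Tuple[float, float]:
    diff = pred - target
    return abs(diff), float((diff > 0) - (diff < 0))  # sign(diff) as subgradient


@dataclass
class Config:
    """Describes what to train: losses, their coefficients, optimizer, data."""
    losses: Dict[str, LossFn]
    coefficients: Dict[str, float]
    lr: float
    epochs: int
    train_data: Sequence[Tuple[float, float]]
    val_data: Sequence[Tuple[float, float]]


class Trainer:
    """Runs the training process described by a Config."""

    def __init__(self, config: Config) -> None:
        self.config = config
        self.weight = 0.0  # toy one-parameter model: prediction = weight * x

    def total_loss(self, x: float, y: float) -> Tuple[float, float]:
        """Weighted sum of all configured losses and its gradient w.r.t. weight."""
        pred = self.weight * x
        value, grad_pred = 0.0, 0.0
        for name, loss_fn in self.config.losses.items():
            loss_value, loss_grad = loss_fn(pred, y)
            coef = self.config.coefficients[name]
            value += coef * loss_value
            grad_pred += coef * loss_grad
        return value, grad_pred * x  # chain rule: d pred / d weight = x

    def train_step(self, x: float, y: float) -> float:
        value, grad = self.total_loss(x, y)
        self.weight -= self.config.lr * grad  # plain gradient descent
        return value

    def val_step(self, x: float, y: float) -> float:
        return self.total_loss(x, y)[0]  # no parameter update

    def run(self) -> List[float]:
        """Train for config.epochs epochs; return the mean validation loss per epoch."""
        history = []
        for _ in range(self.config.epochs):
            for x, y in self.config.train_data:
                self.train_step(x, y)
            val = [self.val_step(x, y) for x, y in self.config.val_data]
            history.append(sum(val) / len(val))
        return history
```

Changing an experiment then means writing a new Config (for example, a different coefficient for a loss) without touching the Trainer.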

File storage

We use Yandex.Disk with 1 TB of storage for the dataset, logs, and checkpoints. The main reason we use it is that we had free access to the service.

We understand that this service is not convenient for other users and plan to provide an alternative soon. For now, you can comment out the Disk class in the code and manually download the necessary datasets into the data folder.
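If you would rather not edit the code each time, one option is a local stand-in for the cloud storage. The sketch below assumes, hypothetically, that the Disk class is used through upload and download methods taking a local and a remote path; check src/disk/disk.py for the actual method names and signatures before substituting it. "Remote" paths are simply resolved inside a local folder.

```python
import shutil
from pathlib import Path
from typing import Union

PathLike = Union[str, Path]


class LocalDisk:
    """Hypothetical local replacement for the cloud Disk class."""

    def __init__(self, root: PathLike) -> None:
        self.root = Path(root)  # local folder that plays the role of the cloud

    def _remote(self, remote_path: PathLike) -> Path:
        # Strip a leading '/' so absolute-looking remote paths stay inside root
        return self.root / str(remote_path).lstrip("/")

    def upload(self, local_path: PathLike, remote_path: PathLike) -> None:
        target = self._remote(remote_path)
        target.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(local_path, target)

    def download(self, remote_path: PathLike, local_path: PathLike) -> None:
        target = Path(local_path)
        target.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(self._remote(remote_path), target)
```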

Requirements & restrictions

  • PyTorch framework
  • Python 3.7.13
  • Type Annotations

Future plans

Tests

  • CI tests

Acknowledgements

deep-text-edit's People

Contributors

grenlayk, nikoryagin, sphericalpotatoinvacuum


deep-text-edit's Issues

Needed Training Time

Hi authors,

Thanks for the great work! I am curious how long it takes to train the model.

RRDB in colorization

  • Use RRDB in colorization model
  • Train on 32x32 crops
  • 1 channel input and 3 channel output

No such file or directory: 'data/IMGUR5K/train/words.json' ?

This is interesting work, but I got an error when I tried to run the project with the following command:
python run.py './src/config/baseline.py'
I have downloaded the IMGUR5K dataset and put it under the data folder. What should I do to generate the training dataset from IMGUR5K?

Inference on provided models

I wrote inference code based on the stylegan_adversarial.py config and training code, but the results look nothing like those shown in the repository. Have I skipped an important step? Code and results are attached below.

import torch
from PIL import Image
from pathlib import Path
from random import shuffle
from stylegan import StyleBased_Generator
from embedders import ContentResnet, StyleResnet
from draw import draw_word

from torchvision import transforms as T

def main():
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    weights_folder_name = 'Stylegan (pretrained on content)'
    weights_folder = f'models/{weights_folder_name}'

    # map_location lets checkpoints saved on a GPU load on a CPU-only machine
    model_G = StyleBased_Generator(dim_latent=512)
    model_G.load_state_dict(torch.load(f'{weights_folder}/model', map_location=device))
    model_G.to(device)

    style_embedder = StyleResnet().to(device)
    style_embedder.load_state_dict(torch.load(f'{weights_folder}/style_embedder', map_location=device))

    content_embedder = ContentResnet().to(device)
    content_embedder.load_state_dict(torch.load(f'{weights_folder}/content_embedder', map_location=device))

    model_G.eval()
    content_embedder.eval()
    style_embedder.eval()

    # Word 'dictionary'; strip trailing newlines so they are not drawn, skip empty lines
    with open('br-utf8.txt', 'r', encoding='UTF-8') as f:
        lines = [line.strip() for line in f if line.strip()]

    shuffle(lines)

    style_imgs = []

    # same as draw.img_to_tensor
    transform = T.Compose([
        T.ToTensor(),
        T.Resize((64, 192)),
    ])

    to_pil_image = T.ToPILImage()
    Path('results').mkdir(exist_ok=True)

    for img_path in Path('images').glob('*.png'):
        with Image.open(img_path) as im:
            img_style = transform(im.convert('RGB'))
            style_imgs.append(img_style)

    desired_content = []
    for i, word in enumerate(lines[:len(style_imgs)]):
        img = draw_word(word)

        img.save(f'results/sample_{i}.png')

        img_content = transform(img)
        desired_content.append(img_content)

    # Stack the per-image (C, 64, 192) tensors into (N, C, 64, 192) batches
    style_imgs = torch.stack(style_imgs).to(device)
    desired_content = torch.stack(desired_content).to(device)

    # Inference only: disable gradient tracking
    with torch.no_grad():
        style_embeds = style_embedder(style_imgs)
        content_embeds = content_embedder(desired_content)
        preds = model_G(content_embeds, style_embeds)

    for i, pred in enumerate(preds):
        # ToPILImage needs a CPU tensor with float values in [0, 1]; if the generator
        # outputs [-1, 1] (e.g. a tanh head), rescale with (pred + 1) / 2 first
        img = to_pil_image(pred.clamp(0, 1).cpu())
        img.save(f'results/{i}.png')

if __name__ == '__main__':
    main()


error of pretrain models

When training the model, there is an error:

RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.Transformation.GridGenerator.inv_delta_C", "module.Transformation.GridGenerator.P_hat".

[BUG] extracting info['bounding_box'] in prepare_dataset.py

Encountered an error when running prepare_dataset.py with the following command.

python prepare_dataset.py

mainly due to the following error.

TypeError: crop_minAreaRect() takes 6 positional arguments but 33 were given.

This is caused by the format of imageur5k_annotations_train.json, where the "bounding_box" data is stored as a string.

This can be fixed by replacing *info['bounding_box'] with *eval(info['bounding_box']).
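The TypeError comes from star-unpacking a string: *info['bounding_box'] passes each of the 32 characters of '[97.0, 1011.5, 103.0, 36.0, 4.5]' as a separate argument, which together with img gives the 33 arguments in the message. Instead of eval, which would execute arbitrary code found in the annotation file, ast.literal_eval parses the list safely. parse_bounding_box is a hypothetical helper, not part of prepare_dataset.py; it returns the five values crop_minAreaRect expects after img (presumably the box centre, width, height, and rotation angle).

```python
import ast
from typing import Sequence, Tuple, Union

BBox = Tuple[float, float, float, float, float]


def parse_bounding_box(raw: Union[str, Sequence[float]]) -> BBox:
    """Turn an IMGUR5K 'bounding_box' entry into five floats.

    Accepts the string form seen in the annotations and an already-parsed list,
    in case other versions of the annotation file store a real list.
    """
    # ast.literal_eval only evaluates Python literals, unlike eval()
    values = ast.literal_eval(raw) if isinstance(raw, str) else raw
    if len(values) != 5:
        raise ValueError(f"expected 5 bounding box values, got {len(values)}")
    return tuple(float(v) for v in values)
```

In prepare_dataset.py the call would then become crop_minAreaRect(img, *parse_bounding_box(info['bounding_box'])).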

How to get deep text edit to generate document-like text?

Hi, I'm fine-tuning this model for my problem, where I need to edit document text. However, even after fine-tuning (200 images, 20 epochs), the generated images still look handwritten, whereas I need them to resemble printed text, like my ground truth. Is there a way to adapt this approach to generate proper (not handwritten or scene) text? Should I try training from scratch on my images?

Some generated images:

It should resemble the Arial font here (the background and the colors are pretty spot on though)

Is there any code available?

I am really interested in this amazing work, and I am currently looking for a code implementation to study.
Really looking forward to your code.🙂️

prepare_dataset.py does not work

Traceback (most recent call last):
  File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/types.py", line 608, in convert
    st = os.stat(rv)
FileNotFoundError: [Errno 2] No such file or directory: 'dataset_info'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "prepare_dataset.py", line 90, in <module>
    main() # pylint: disable=no-value-for-parameter
  File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 829, in __call__
    return self.main(*args, **kwargs)
  File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 781, in main
    with self.make_context(prog_name, args, **extra) as ctx:
  File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 700, in make_context
    self.parse_args(ctx, args)
  File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 1048, in parse_args
    value, args = param.handle_parse_result(ctx, opts, args)
  File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 1623, in handle_parse_result
    value = self.full_process_value(ctx, value)
  File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 1965, in full_process_value
    return Parameter.full_process_value(self, ctx, value)
  File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 1592, in full_process_value
    value = self.get_default(ctx)
  File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 1917, in get_default
    return Parameter.get_default(self, ctx)
  File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 1534, in get_default
    return self.type_cast_value(ctx, rv)
  File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 1568, in type_cast_value
    return _convert(value, (self.nargs != 1) + bool(self.multiple))
  File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 1565, in _convert
    return self.type(value, self, ctx)
  File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/types.py", line 46, in __call__
    return self.convert(value, param, ctx)
  File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/types.py", line 614, in convert
    self.path_type, filename_to_ui(value)
  File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/_compat.py", line 474, in filename_to_ui
    value = value.encode("utf-8", "surrogateescape").decode("utf-8", "replace")
AttributeError: 'PosixPath' object has no attribute 'encode'
(venv) [jorj@jorj-systemproductname text-deep-fake-main]$ cd /home/jorj/Загрузки/text-deep-fake-main/IMGUR5K-Handwriting-Dataset-main/
(venv) [jorj@jorj-systemproductname IMGUR5K-Handwriting-Dataset-main]$ python prepare_dataset.py
0it [00:00, ?it/s] 2023-07-07 18:57:47.034 | ERROR | __main__:main:84 - An error has been caught in function 'main', process 'MainProcess' (252331), thread 'MainThread' (140470161389376):
Traceback (most recent call last):

File "prepare_dataset.py", line 90, in <module>
main() # pylint: disable=no-value-for-parameter

File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 829, in __call__
return self.main(*args, **kwargs)
│ │ │ └ {}
│ │ └ ()
│ └ <function BaseCommand.main at 0x7fc1c157a8c0>

File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 782, in main
rv = self.invoke(ctx)
│ │ └ <click.core.Context object at 0x7fc1baaf79d0>
│ └ <function Command.invoke at 0x7fc1c150ab90>

File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 1066, in invoke
return ctx.invoke(self.callback, **ctx.params)
│ │ │ │ │ └ {'annotations_path': PosixPath('dataset_info'), 'save_path': PosixPath('cropped'), 'no_split': False, 'reduce': None}
│ │ │ │ └ <click.core.Context object at 0x7fc1baaf79d0>
│ │ │ └ <function main at 0x7fc1ba7d9e60>
│ │ └
│ └ <function Context.invoke at 0x7fc1c150b050>
└ <click.core.Context object at 0x7fc1baaf79d0>
File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 610, in invoke
return callback(*args, **kwargs)
│ │ └ {'annotations_path': PosixPath('dataset_info'), 'save_path': PosixPath('cropped'), 'no_split': False, 'reduce': None}
│ └ ()
└ <function main at 0x7fc1ba7d9e60>

File "prepare_dataset.py", line 84, in main
img_cropped = crop_minAreaRect(img, *info['bounding_box'])
│ │ └ {'word': 'SIRIÜS', 'bounding_box': '[97.0, 1011.5, 103.0, 36.0, 4.5]'}
│ └ array([[[165, 210, 224],
│ [163, 210, 224],
│ [163, 210, 224],
│ ...,
│ [148, 200, 213],
│ [150...
└ <function crop_minAreaRect at 0x7fc1c156ecb0>

TypeError: crop_minAreaRect() takes 6 positional arguments but 33 were given
0it [00:00, ?it/s]
2023-07-07 18:57:47.039 | ERROR | click.core:invoke:610 - An error has been caught in function 'invoke', process 'MainProcess' (252331), thread 'MainThread' (140470161389376):
Traceback (most recent call last):

File "prepare_dataset.py", line 90, in <module>
main() # pylint: disable=no-value-for-parameter

File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 829, in __call__
return self.main(*args, **kwargs)
│ │ │ └ {}
│ │ └ ()
│ └ <function BaseCommand.main at 0x7fc1c157a8c0>

File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 782, in main
rv = self.invoke(ctx)
│ │ └ <click.core.Context object at 0x7fc1baaf79d0>
│ └ <function Command.invoke at 0x7fc1c150ab90>

File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 1066, in invoke
return ctx.invoke(self.callback, **ctx.params)
│ │ │ │ │ └ {'annotations_path': PosixPath('dataset_info'), 'save_path': PosixPath('cropped'), 'no_split': False, 'reduce': None}
│ │ │ │ └ <click.core.Context object at 0x7fc1baaf79d0>
│ │ │ └ <function main at 0x7fc1ba7d9e60>
│ │ └
│ └ <function Context.invoke at 0x7fc1c150b050>
└ <click.core.Context object at 0x7fc1baaf79d0>

File "/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/click/core.py", line 610, in invoke
return callback(*args, **kwargs)
│ │ └ {'annotations_path': PosixPath('dataset_info'), 'save_path': PosixPath('cropped'), 'no_split': False, 'reduce': None}
│ └ ()
└ <function main at 0x7fc1ba7d9e60>

File "prepare_dataset.py", line 85, in main
cv2.imwrite(str(output_path / f'{ann_id}.png'), img_cropped)
│ │ │ └ None
│ │ └ PosixPath('cropped/train')
│ └
└ <module 'cv2.cv2' from '/home/jorj/Загрузки/text-deep-fake-main/venv/lib/python3.7/site-packages/cv2/cv2.cpython-37m-x86_64-l...

cv2.error: OpenCV(4.1.2) /io/opencv/modules/imgcodecs/src/loadsave.cpp:715: error: (-215:Assertion failed) !_img.empty() in function 'imwrite'
