attention_ocr.pytorch's Introduction

attention-ocr.pytorch: Encoder + Decoder + attention model

This repository implements an encoder-decoder model with attention for OCR. The encoder uses a CNN + Bi-LSTM and the decoder uses a GRU. This repository is modified from https://github.com/meijieru/crnn.pytorch
I released an open-source version earlier, but it could only recognize images of a fixed width. I have recently modified the model to support images of variable width, the same capability as CRNN. Due to time constraints, there is no pretrained model this time; one will be added later.
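To make "attention" concrete, here is a minimal pure-Python sketch of one attention step: each encoder output (one per horizontal position of the CNN + Bi-LSTM feature sequence) is scored against a decoder query, the scores are softmax-normalized, and the weighted sum becomes the context vector passed to the decoder. Dot-product scoring is an assumption chosen for brevity; the repository's GRU decoder may use a different scoring function, and `dot_attention` is a hypothetical helper, not part of the code base.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot_attention(query, encoder_states):
    """Dot-product attention (an assumed scoring choice).

    Scores every encoder state against the decoder query, normalizes the
    scores into weights, and returns (weights, context), where context is
    the weight-averaged encoder state.
    """
    scores = [sum(q * h for q, h in zip(query, state)) for state in encoder_states]
    weights = softmax(scores)
    dim = len(encoder_states[0])
    context = [sum(w * state[d] for w, state in zip(weights, encoder_states))
               for d in range(dim)]
    return weights, context


# Three encoder time steps (horizontal slices of the image), feature dimension 2.
states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weights, context = dot_attention([2.0, 0.0], states)
```

The query matches the first and third time steps equally, so they receive equal, larger weights, and the context vector leans toward their features.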

Requirements

pytorch 0.4.1
opencv_python

cd Attention_ocr.pytorch
pip install -r requirements.txt

Test

pretrained model coming soon

Train

  1. Here I use a small subset of Synthetic_Chinese_String_Dataset, with 270,000+ images for training and 20,000 images for testing. Download the image data from Baidu.
  2. Create train_list.txt and test_list.txt in the following format:
# path/to/image_name.jpg label
path/AttentionData/50843500_2726670787.jpg 情笼罩在他们满是沧桑
path/AttentionData/57724421_3902051606.jpg 心态的松弛决定了比赛
path/AttentionData/52041437_3766953320.jpg 虾的鲜美自是不可待言
  3. Set the trainlist and vallist parameters in train.py (or pass them on the command line as below), then start training:
cd Attention_ocr.pytorch
python train.py --trainlist ./data/ch_train.txt --vallist ./data/ch_test.txt
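A minimal sketch of reading a list file in the format above, assuming one `path label` pair per line, split at the first space, in UTF-8 text. Skipping `#` comment lines and the helper names `parse_list_line` / `read_list_file` are assumptions for illustration, not the repository's own loader.

```python
def parse_list_line(line):
    """Split one 'path/to/image.jpg label' line into (image_path, label).

    Only the first space separates path and label, so a label that itself
    contains spaces is kept whole (an assumption about the format).
    A line without any space raises ValueError.
    """
    line = line.rstrip('\r\n')
    path, label = line.split(' ', 1)
    return path, label


def read_list_file(lines):
    """Parse an iterable of lines, skipping blank lines and '#' comments."""
    samples = []
    for line in lines:
        if not line.strip() or line.lstrip().startswith('#'):
            continue
        samples.append(parse_list_line(line))
    return samples


example = [
    "# path/to/image_name.jpg label\n",
    "path/AttentionData/50843500_2726670787.jpg 情笼罩在他们满是沧桑\n",
    "path/AttentionData/57724421_3902051606.jpg 心态的松弛决定了比赛\n",
]
samples = read_list_file(example)
```

On Python 3, `with open('./data/ch_train.txt', encoding='utf-8') as f: samples = read_list_file(f)` would produce the `(path, label)` pairs for a real file.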

Training progress is then printed to the terminal. In this version, attention OCR uses the decoderV2 model as the decoder.

The previous version

git checkout AttentionOcrV1

Reference

  1. crnn.pytorch
  2. Attention-OCR
  3. Seq2Seq-PyTorch
  4. caffe_ocr

TO DO

  • Replace the LSTM with Conv1D, which can greatly accelerate inference
  • Replace the CNN backbone with Inception or DenseNet
  • Implement the decoder with a Transformer

attention_ocr.pytorch's People

Contributors

chenjun2hao, marvis, meijieru, zhangxinnan

attention_ocr.pytorch's Issues

Variable-length recognition problems

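Hello, when testing variable-length inputs with the open-source model you provided, I ran into these two problems:
1. Variable-length images:
transformer = dataset.resizeNormalize((280, 32)), and any width other than 280 raises an error. CRNN instead scales the image to a height of 32 and resizes the width proportionally, so its input is (x, 32).
2. Variable-length text:
Perhaps because every training sample has 10 characters, the prediction is always around 10 characters no matter how many characters the image contains?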

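For example, I took the image
20436312_1683447152
removed a few of its characters, and still fed it in at 280*32 for recognition.
[image: 2043]
The results were:
predict_str:,__不愿意意意资 (9 characters) => prob:0.002346405293792486
[image: 20437421_]
predict_str:**通信信位主办、《 (10 characters) => prob:0.05960559844970703
[image: 20437421_21]
predict_str:,(通信学会主主府 (9 characters) => prob:0.000349084148183465
[image: 204]
predict_str:叶国通信学会主里”《 (10 characters) => prob:0.01799328438937664

How can I solve this? Does training need variable-length samples? Thanks~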
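The CRNN-style preprocessing described in this issue (height fixed at 32, width scaled by the same factor) comes down to a small size computation. A minimal sketch follows; `crnn_style_size` and its `min_width` guard are hypothetical names, not functions from this repository or from crnn.pytorch.

```python
def crnn_style_size(width, height, target_height=32, min_width=None):
    """Return (new_width, new_height) that scales an image to target_height
    while keeping its aspect ratio, giving an (x, 32) input as CRNN does.

    min_width is a hypothetical guard for very narrow crops, so the CNN's
    downsampling still leaves a few time steps.
    """
    if width <= 0 or height <= 0:
        raise ValueError("image dimensions must be positive")
    new_width = max(1, int(round(width * target_height / float(height))))
    if min_width is not None:
        new_width = max(new_width, min_width)
    return new_width, target_height
```

With OpenCV, `h, w = img.shape[:2]` followed by `cv2.resize(img, crnn_style_size(w, h))` would apply it, since `cv2.resize` takes the target size as `(width, height)`. Images in one batch must still share a width, so variable-width inputs are usually padded to the widest image in the batch, or run with a batch size of 1 at test time.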

'unexpected key "cnn.7.num_batches_tracked" in state_dict'

Running demo.py raises an error:
loading pretrained models ......
Traceback (most recent call last):
File "demo.py", line 34, in
encoder.load_state_dict(torch.load(encoder_path))
File "/usr/lib64/python2.7/site-packages/torch/nn/modules/module.py", line 522, in load_state_dict
.format(name))
KeyError: 'unexpected key "cnn.7.num_batches_tracked" in state_dict'
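This error usually means the checkpoint was saved by a PyTorch version whose BatchNorm layers already carry a `num_batches_tracked` buffer (added around 0.4.1, the version listed under requirements), while demo.py runs on an older PyTorch that does not expect it. Upgrading PyTorch to match is the cleaner fix. As a workaround, a minimal sketch that drops those keys before loading; `strip_num_batches_tracked` is a hypothetical helper.

```python
from collections import OrderedDict


def strip_num_batches_tracked(state_dict):
    """Return a copy of state_dict without BatchNorm 'num_batches_tracked'
    buffers, so an older PyTorch that does not know them can load it.
    Key order is preserved."""
    return OrderedDict(
        (key, value) for key, value in state_dict.items()
        if not key.endswith('num_batches_tracked')
    )
```

In demo.py this would read `encoder.load_state_dict(strip_num_batches_tracked(torch.load(encoder_path)))`, and likewise for the decoder. Dropping the buffer is harmless for inference, because it only counts training batches.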

Error: KeyError: ' '

Hi, I get KeyError: ' ' when running your code. What could cause this?
if isinstance(text, str):
    text = [self.dict[item] for item in text]
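`KeyError: ' '` means a label contains a space that has no entry in the converter's `self.dict`. One possible cause, assuming the alphabet is built with the `[x.rstrip() for x in data]` line that appears in the utils.py traceback further down this page: `rstrip()` with no argument removes all trailing whitespace, so a line holding only a space becomes empty and the space never reaches the alphabet. The sketch below shows that pitfall and a hypothetical `missing_characters` check that lists unknown characters before training.

```python
def missing_characters(texts, alphabet):
    """Return the sorted characters that occur in texts but not in alphabet.

    Running this over every label before training turns an opaque KeyError
    inside the converter's dictionary lookup into a readable list of
    characters to add to the alphabet file.
    """
    known = set(alphabet)
    missing = set()
    for text in texts:
        missing.update(ch for ch in text if ch not in known)
    return sorted(missing)


# A space line in the alphabet file is silently lost by a bare rstrip(),
# because ' \n'.rstrip() == '' (all whitespace is stripped, not just '\n').
lines = ['a\n', ' \n', 'b\n']
lossy = ''.join(x.rstrip() for x in lines)        # 'ab'
safe = ''.join(x.rstrip('\r\n') for x in lines)   # 'a b'
```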

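Model accuracy problem

Hello, I reran your code, but after 21 epochs the recognition accuracy is still very low and does not reach the results you showed. What do you think could be the cause?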

GO and END_TOKEN?

Does training add GO (START_TOKEN) and END_TOKEN? I only see them added in target_txt_decode in crnn.lang, but that function is never called.
data = val_iter.next()
cpu_images, cpu_texts = data
...
target_variable = converter.encode(cpu_texts)
target_variable = target_variable.cuda()
decoder_input = target_variable[0].cuda()
Here decoder_input, the decoder input used for validation (no teacher forcing), should be a GO (START_TOKEN), but it looks like it takes the first character of cpu_texts instead?
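For reference, a minimal sketch of the usual seq2seq convention this issue asks about: the target ends with EOS, and the decoder input is the target shifted right by one step with SOS (GO) in front, so the very first decoder input is SOS whether or not teacher forcing is used. The token indices and helper names are hypothetical; whether this repository follows the convention is exactly the open question in the issue.

```python
SOS_TOKEN = 0  # hypothetical index of the GO / start-of-sequence token
EOS_TOKEN = 1  # hypothetical index of the end-of-sequence token


def encode_with_markers(text, char_to_index):
    """Map characters to class indices and append EOS (the decoder target)."""
    return [char_to_index[ch] for ch in text] + [EOS_TOKEN]


def teacher_forcing_inputs(target):
    """Decoder inputs: SOS followed by the target shifted right by one, so
    at step t the decoder sees ground-truth token t-1 and predicts token t."""
    return [SOS_TOKEN] + target[:-1]


# Indices 0 and 1 are reserved for SOS/EOS, so characters start at 2.
char_to_index = {ch: i + 2 for i, ch in enumerate('abc')}
target = encode_with_markers('cab', char_to_index)   # [4, 2, 3, 1]
decoder_in = teacher_forcing_inputs(target)          # [0, 4, 2, 3]
```

Without teacher forcing, only the first input (`SOS_TOKEN`) comes from outside. Every later input is the decoder's own previous prediction.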

Hyperparameter settings

Are the author's hyperparameters the program defaults? Roughly how many epochs does the model need to converge?

Training error: AttributeError: 'str' object has no attribute 'decode'

AttributeError Traceback (most recent call last)
~/Attention_ocr.pytorch-master/train.py in
9 import numpy as np
10 import os
---> 11 import src.utils as utils
12 import src.dataset as dataset
13 import time

~/Attention_ocr.pytorch-master/src/utils.py in
16 data = f.readlines()
17 alphabet = [x.rstrip() for x in data]
---> 18 alphabet = ''.join(alphabet).decode('utf-8') # on python2, the text is garbled without decode
19
20
The error is raised when decode is called.
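On Python 3, `str` is already text and has no `.decode` method, which is what this traceback shows. A minimal sketch of loading the alphabet in a way that yields text on both Python 2 and Python 3, so no `.decode('utf-8')` call is needed. It assumes a UTF-8 file with one character per line, and `read_alphabet` / `load_alphabet` are hypothetical names.

```python
import io


def read_alphabet(f):
    """Join one-character-per-line text, stripping only line endings
    (a bare rstrip() would also remove a character that is a space)."""
    return u''.join(line.rstrip(u'\r\n') for line in f)


def load_alphabet(path):
    """io.open with an explicit encoding returns decoded text on both
    Python 2 and Python 3, replacing the .decode('utf-8') call."""
    with io.open(path, encoding='utf-8') as f:
        return read_alphabet(f)
```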

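Testing images of variable length

Hello,
I am not yet able to run your program.
I would first like to ask: can this model recognize longer text-line images?
In the demo program the maximum number of characters is set to 15. Is this value fixed?
Thanks.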
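In a typical attention decoder, a maximum length like the 15 in the demo is simply the cap on a greedy decoding loop that otherwise stops when the model emits EOS. A minimal sketch under that assumption: `greedy_decode` and `step_fn` are hypothetical, with `step_fn` standing in for one decoder step plus an argmax over the output classes. Raising the cap allows longer outputs, but a model trained only on short labels may still not read long lines well.

```python
def greedy_decode(step_fn, sos_token, eos_token, max_length=15):
    """Feed the decoder its own previous prediction until EOS or max_length.

    step_fn(prev_token, state) -> (next_token, state). The returned list
    excludes EOS and holds at most max_length tokens.
    """
    tokens, prev, state = [], sos_token, None
    for _ in range(max_length):
        prev, state = step_fn(prev, state)
        if prev == eos_token:
            break
        tokens.append(prev)
    return tokens


def make_scripted_step(outputs):
    """Toy step function that replays a fixed sequence of predictions."""
    def step(prev_token, state):
        i = 0 if state is None else state
        return outputs[i], i + 1
    return step
```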

Related paper title?

Thanks a lot for sharing your experience and PyTorch code. I would be happy to read the paper on the attention encoder-decoder OCR algorithm.

Thanks.

python3 change

Many thanks for sharing the code!

It looks like this code was developed with Python 2.7.

To run experiments with Python 3.x, I had to change the parts that deal with unicode and UTF-8.

Here is what I changed:
dataset.py:
label = line_splits[1]#.decode('utf-8')
utils.py (line 53):
if isinstance(text, str): # python3 string default is unicode #unicode):

ref: https://stackoverflow.com/questions/4987327/how-do-i-check-if-a-string-is-unicode-or-ascii

Thanks again for sharing the code. It is very helpful for studying DNNs.
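Both changes above deal with the same Python 2/3 difference: whether a string is bytes or text. A minimal sketch of a helper that works on both versions, so neither the `.decode('utf-8')` call nor the `isinstance(text, str)` check needs to be edited per version. `to_text` is a hypothetical name, not part of the repository.

```python
def to_text(value, encoding='utf-8'):
    """Return value as a text string on both Python 2 and Python 3.

    Byte strings (Python 2 str, Python 3 bytes) are decoded; text
    (Python 2 unicode, Python 3 str) is returned unchanged. On Python 2,
    bytes is an alias of str, so the same check works there too.
    """
    if isinstance(value, bytes):
        return value.decode(encoding)
    return value
```

In dataset.py this could read `label = to_text(line_splits[1])`, and in utils.py `text = to_text(text)` before the dictionary lookup.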

load decoder_path error

Hi, thanks for your excellent work. I get this error:
RuntimeError: Error(s) in loading state_dict for decoder:
size mismatch for decoder.embedding.weight: copying a param of torch.Size([17765, 256]) from checkpoint, where the shape is torch.Size([5992, 256]) in current model.
size mismatch for decoder.out.bias: copying a param of torch.Size([17765]) from checkpoint, where the shape is torch.Size([5992]) in current model.
size mismatch for decoder.out.weight: copying a param of torch.Size([17765, 256]) from checkpoint, where the shape is torch.Size([5992, 256]) in current model.

It looks like the character list I use for inference does not match the size of the model. How can I fix it?
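The shapes say the checkpoint's output layer has 17765 classes, while the model built for inference has 5992, so the model was constructed from a different (smaller) alphabet than the one the checkpoint was trained with. A minimal sketch of checking this before loading. The `decoder.out.bias` key comes from the error message; `num_special_tokens` (extra classes such as SOS/EOS added on top of the alphabet) is an assumption to adjust to the actual model, and both helpers are hypothetical.

```python
def infer_num_classes(state_dict, key='decoder.out.bias'):
    """Class count of a checkpoint: the output bias has one entry per class
    (len() works on a 1-D torch tensor as well as on a list)."""
    return len(state_dict[key])


def check_alphabet(state_dict, alphabet, num_special_tokens):
    """Raise a readable error if alphabet does not match the checkpoint."""
    expected = len(alphabet) + num_special_tokens
    actual = infer_num_classes(state_dict)
    if expected != actual:
        raise ValueError(
            'checkpoint has %d classes but the alphabet gives %d; '
            'use the same alphabet file that was used for training'
            % (actual, expected))
```

With a real checkpoint this would be `check_alphabet(torch.load(decoder_path), alphabet, num_special_tokens)`. The fix itself is to build the model from the alphabet file the checkpoint was trained with.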

Decoder weights are loaded into the encoder by mistake

if opt.decoder:
    print('loading pretrained encoder model from %s' % opt.decoder)
    encoder.load_state_dict(torch.load(opt.encoder))

The code above is supposed to load the decoder, but it actually loads the encoder, so every result is wrong at test time.
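A minimal sketch of the corrected loading logic, in which the decoder branch loads `opt.decoder` into `decoder`. It is wrapped in a hypothetical `load_pretrained` function with an injectable `load_fn` (defaulting to `torch.load`), so the logic can be checked without PyTorch or real checkpoint files; in train.py it would be called as `load_pretrained(encoder, decoder, opt.encoder, opt.decoder)`.

```python
def load_pretrained(encoder, decoder, encoder_path=None, decoder_path=None,
                    load_fn=None):
    """Load each pretrained state dict into its own module."""
    if load_fn is None:
        import torch  # imported lazily so the logic is testable without torch
        load_fn = torch.load
    if encoder_path:
        print('loading pretrained encoder model from %s' % encoder_path)
        encoder.load_state_dict(load_fn(encoder_path))
    if decoder_path:
        # Fixed: the decoder checkpoint goes into the decoder, not the encoder.
        print('loading pretrained decoder model from %s' % decoder_path)
        decoder.load_state_dict(load_fn(decoder_path))
```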
