Git Product home page Git Product logo

keras-ctpn's Introduction

keras-ctpn

[TOC]

  1. 说明
  2. 预测
  3. 训练
  4. 例子
    4.1 ICDAR2015
    4.1.1 带侧边细化
    4.1.2 不带带侧边细化
    4.1.3 做数据增广-水平翻转
    4.2 ICDAR2017
    4.3 其它数据集
  5. toDoList
  6. 总结

说明

​ 本工程是keras实现的CPTN: Detecting Text in Natural Image with Connectionist Text Proposal Network . 本工程实现主要参考了keras-faster-rcnn ; 并在ICDAR2015和ICDAR2017数据集上训练和测试。

​ 工程地址: keras-ctpn

​ cptn论文翻译:CTPN.md

效果

​ 使用ICDAR2015的1000张图像训练在500张测试集上结果为:Recall: 37.07 % Precision: 42.94 % Hmean: 39.79 %; 原文中的F值为61%;使用了额外的3000张图像训练。

关键点说明:

a.骨干网络使用的是resnet50

b.训练输入图像大小为720*720; 将图像的长边缩放到720,保持长宽比,短边padding;原文是短边600;预测时使用1024*1024

c.batch_size为4, 每张图像训练128个anchor,正负样本比为1:1;

d.分类、边框回归以及侧边细化的损失函数权重为1:1:1;原论文中是1:1:2

e.侧边细化与边框回归选择一样的正样本anchor;原文中应该是分开选择的

f.侧边细化还是有效果的(注:网上很多人说没有啥效果)

g.由于有双向GRU,水平翻转会影响效果(见样例做数据增广-水平翻转)

h.随机裁剪做数据增广,网络不收敛

预测

a. 工程下载

git clone https://github.com/yizt/keras-ctpn

b. 预训练模型下载

​ ICDAR2015训练集上训练好的模型下载地址: google drive百度云盘 取码:wm47

c.修改配置类config.py中如下属性

	WEIGHT_PATH = '/tmp/ctpn.h5'

d. 检测文本

python predict.py --image_path image_3.jpg

评估

a. 执行如下命令,并将输出的txt压缩为zip包

python evaluate.py --weight_path /tmp/ctpn.100.h5 --image_dir /opt/dataset/OCR/ICDAR_2015/test_images/ --output_dir /tmp/output_2015/

b. 提交在线评估 将压缩的zip包提交评估,评估地址:http://rrc.cvc.uab.es/?ch=4&com=mymethods&task=1

训练

a. 训练数据下载

#icdar2013
wget http://rrc.cvc.uab.es/downloads/Challenge2_Training_Task12_Images.zip
wget http://rrc.cvc.uab.es/downloads/Challenge2_Training_Task1_GT.zip
wget http://rrc.cvc.uab.es/downloads/Challenge2_Test_Task12_Images.zip
#icdar2015
wget http://rrc.cvc.uab.es/downloads/ch4_training_images.zip
wget http://rrc.cvc.uab.es/downloads/ch4_training_localization_transcription_gt.zip
wget http://rrc.cvc.uab.es/downloads/ch4_test_images.zip
#icdar2017
wget -c -t 0 http://datasets.cvc.uab.es/rrc/ch8_training_images_1~8.zip
wget -c -t 0 http://datasets.cvc.uab.es/rrc/ch8_training_localization_transcription_gt_v2.zip
wget -c -t 0 http://datasets.cvc.uab.es/rrc/ch8_test_images.zip

b. resnet50与训练模型下载

wget https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5

c. 修改配置类config.py中,如下属性

	# 预训练模型
    PRE_TRAINED_WEIGHT = '/opt/pretrained_model/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'

    # 数据集路径
    IMAGE_DIR = '/opt/dataset/OCR/ICDAR_2015/train_images'
    IMAGE_GT_DIR = '/opt/dataset/OCR/ICDAR_2015/train_gt'

d.训练

python train.py --epochs 50

例子

ICDAR2015

带侧边细化

不带侧边细化

做数据增广-水平翻转

ICDAR2017

其它数据集

toDoList

  1. 侧边细化(已完成)
  2. ICDAR2017数据集训练(已完成)
  3. 检测文本行坐标映射到原图(已完成)
  4. 精度评估(已完成)
  5. 侧边回归,限制在边框内(已完成)
  6. 增加水平翻转(已完成)
  7. 增加随机裁剪(已完成)

总结

  1. ctpn对水平文字检测效果不错
  2. 整个网络对于数据集很敏感;在2017上训练的模型到2015上测试效果很不好;同样2015训练的在2013上测试效果也很差
  3. 推测由于双向GRU,网络有存储记忆的缘故?在使用随机裁剪作数据增广时网络不收敛,使用水平翻转时预测结果也水平对称出现

keras-ctpn's People

Contributors

yizt avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

keras-ctpn's Issues

预测时报错numpy.linalg.linalg.LinAlgError: SVD did not converge

您好,我在使用您的模型预测时,报了错误:
File predict.py line 64 in module main(argments)
File "...image_utils.py" in load_image_gt
image, window, scale,padding = resize_image(image, output_size)
File "...image_utils.py" in resize_image
order=1, mode='constant',cval=0,clip=True,preserve_range=True
FIle ".../skimage/transform/_warps.py", line 165 in resize
tform.estimate(src_corners,dst_corners)
File ".../skimage/transform/_geometric.py", line679 in estimate
_,_,V=np.linalg.svd(A)
...
numpy.linalg.linalg.LinAlgError: SVD did not converge
请问是什么原因呢?谢谢!

看源码时有一个疑问

def smooth_l1_loss(y_true, y_predict, sigma2=9.0):
    abs_diff = tf.abs(y_true - y_predict, name='abs_diff')
    loss = tf.where(tf.less(abs_diff, 1. / sigma2), 0.5 * sigma2 * tf.pow(abs_diff, 2), abs_diff - 0.5 / sigma2)
    return loss

在源码中有一个关于smooth l1的loss的函数,我网上查询到的smooth l1的定义是
image
当函数sigma2=1的时候和网上定义的是一致的,
但是我看大佬这里用的是9.0的默认值,说实话以我的水平第一眼觉得是不是应该是0.9?
想请问大佬这里sigma2的默认值为什么要设置成9.0,有什么用意或者是经过测试这个数字比较好吗

求助

如果我想基于您的模型进行证件文本检测,是否需要用自己已标记的数据集重新训练?我最近也在想用什么模型或框架来处理这个问题。
因为证件文本识别相比于自然场景下的文本识别要规范很多,是否可以先训练一个模型识别出证件的边框,然后就可以取固定坐标的文本了?这样是否更简单
我这方面刚接触并不懂,期待您的回复

text_proposals.py文件apply_regress函数的侧边精调代码是不是有问题?第38行

text_proposals.py文件apply_regress函数第38行代码
cx += dx * w

而target.py文件中side_regress_target函数中代码(第83行),dx计算方式为:
dx = (gt_center_x - center_x) * 2 / w

貌似 不能这么计算吧? cx += dx * w

后面有cx + w * 0.5, cx - w * 0.5 ;
这里w是anchor_box的width,也不是预测box的width,是不是有问题?

环境版本不适配

tensorflow-gpu 1.14.0 requires keras-applications>=1.0.6, but you have keras-applications 1.0.2 which is incompatible.
tensorflow-gpu 1.14.0 requires keras-preprocessing>=1.0.5, but you have keras-preprocessing 1.0.1 which is incompatible.
你好 我安装了python3.7和tensorflow-gpu==1.14和keras==2.2.0 有如上错误,请问作者是否遇到过,谢谢!

target.py正样本问题?不同的gt选择同一个anchor?

合并两部分正样本索引

positive_bool_matrix = tf.logical_or(gt_iou_max_bool, anchors_iou_max_bool)

感觉按照target.py文件前面的规则,有可能不同的gt选择了同一个anchor;

而后面的逻辑选择正样本anchor时,每次随机抽取一部分(正样本anchor),其中会不会多次出现同一个anchor选择了不同的gt?
或者不同epoch中,同一张图片里面,相同索引的anchor选择了不同的gt?

标志位问题

tag 1:正样本,0:负样本,-1 padding

true_cls_ids和indices的标志位说明是否错误?在target.py文件中,cls_ids的正负样本tag均为1,而padding样本为0;indices的tag生成为,正样本为1,负样本为-1,padding样本为0。
请查看,谢谢!

关于多分类的问题

代码中和训练有关的代码反复看了好几遍,算大致搞懂了流程,想基于这个模型做点调整,代码中文本分类只有两类,文字和背景,我想修改成文字多分类,比如有10类,维度什么就跟着调整了下从原来的2 变成了11,但是训练开始没多久loss就开始无限增大了,估计是和修改的地方有关系 。自己折腾了很久也没看出哪里不对。。大佬有什么建议么

有个疑问
一开始的这个Input 参数2 代表的是2个分类还是第一索引是分类,第二索引是padding?

gt_class_ids = Input(shape=(config.MAX_GT_INSTANCES, 2), name='gt_class_ids')

predict

预测识别效果不大好,请问什么原因。还有请问下,一行字是如何cv2.line的?

训练模型时的问题

AttributeError: 'google.protobuf.pyext._message.FieldProperty' object has no attribute 'allow_growth'

关于度量的数据类型

# dtype=[tf.float32] * 2 + [tf.int64] + [tf.float32] + [tf.int64] + [tf.float32] * 3)

在前面ctpn_target_graph的输出中,后三个度量的类型均为tf.float32,可是这里倒数第四个指定类型为tf.int64,是否应为tf.float32呢?

Change config.py file

Thanks for your sharing!
But I'm very confusing when change TRAIN_ANCHORS_PER_IMAGE and TEXT_PROPOSALS_MAX_NUM. How can I decide exactly what it is with my dataset?

多GPU训练

请问一下楼主,您有使用过多GPU训练吗?
我使用keras的multi_gpu_model报错了:
IndexError: list assignment index out of range

预测图片生成结果

检测好文本行后生成的结果图片固定设置到了1600*1600,怎么保持测试图片原尺寸输出?

debug code

@yizt How to debug function like ctpn_target_graph() in target.py?I used pycharm to debug but value =tensor()...I want to see the specific values when putting an image into the network.

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.