

tensorflow-MTCNN

Face detection with the MTCNN algorithm, implemented in TensorFlow. Fully commented in Chinese, covering everything from understanding the model to training it; includes both test and training code and supports webcam input. Based on AITTSMD's code, with some parts trimmed and optimized.

Understanding the Model

MTCNN is currently a popular face detection method; good face detection in turn enables more accurate face recognition. The model cascades three networks, PNet, RNet, and ONet, refining the detections step by step. The architecture diagram from the paper:


Next I will explain, based on what I learned while training it, what each part of MTCNN actually does.

The three networks must be trained in order: first PNet, then RNet, and finally ONet.

PNet:

PNet is a fully convolutional network, mainly so that it can handle inputs of different scales. It is very shallow; its job is to pull in as many face candidates as possible, preferring many false positives to a single miss. Its training data consists of four kinds of samples: pos, part, neg, and landmark, in a 1:1:3:1 ratio. Where does this data come from?

The pos, part, and neg samples come from random crops around the ground-truth faces: a crop whose maximum IoU with any ground-truth face box exceeds 0.65 is pos, above 0.4 is part, and below 0.3 is neg. The landmark samples are crops of images annotated with facial keypoints. The labels of pos and part samples contain their class (1 and -1 respectively) plus the face-box offsets relative to the crop's top-left corner, normalized by the crop size; neg labels contain only the class 0; landmark labels contain the class -2 and the five keypoint coordinate offsets, also normalized.
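The crop-labelling rule above can be sketched as follows (a minimal NumPy sketch with illustrative function names, not the repo's actual code):

```python
import numpy as np

def iou(box, gt_boxes):
    """IoU between one crop `box` and an array of ground-truth boxes.
    Boxes are [x1, y1, x2, y2]."""
    x1 = np.maximum(box[0], gt_boxes[:, 0])
    y1 = np.maximum(box[1], gt_boxes[:, 1])
    x2 = np.minimum(box[2], gt_boxes[:, 2])
    y2 = np.minimum(box[3], gt_boxes[:, 3])
    inter = np.maximum(0, x2 - x1 + 1) * np.maximum(0, y2 - y1 + 1)
    area = (box[2] - box[0] + 1) * (box[3] - box[1] + 1)
    gt_area = (gt_boxes[:, 2] - gt_boxes[:, 0] + 1) * \
              (gt_boxes[:, 3] - gt_boxes[:, 1] + 1)
    return inter / (area + gt_area - inter)

def classify_crop(box, gt_boxes):
    """Assign a crop to pos/part/neg by its best IoU, per the thresholds above."""
    best = iou(box, gt_boxes).max()
    if best > 0.65:
        return 'pos'
    if best > 0.4:
        return 'part'
    if best < 0.3:
        return 'neg'
    return None  # ambiguous crops (IoU between 0.3 and 0.4) are discarded
```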

All four kinds of crops are resized to 12x12 as PNet's input. PNet outputs the face/no-face probability [batch, 2], the face-box offsets [batch, 4], and the keypoint offsets [batch, 10]. How are the four kinds of data trained together?

The face classification loss is updated only from neg and pos samples: a mask built from the class value in the label singles them out, so the loss of all other samples is ignored. The bounding-box loss is computed only on pos and part samples, and the keypoint loss only on landmark samples. One trick from the paper is to update the parameters using only the hardest 70% of samples in each batch, which reportedly improves model accuracy; this is all reflected in the code, so see there for the details.
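The mask-plus-top-70% idea can be illustrated like this (a NumPy sketch for clarity; the repo itself does this in TensorFlow with ops like tf.nn.top_k, and the function name here is illustrative):

```python
import numpy as np

def masked_ohem_cls_loss(log_probs, labels, keep_ratio=0.7):
    """Classification loss over a mixed batch: only pos (label 1) and
    neg (label 0) samples contribute; part (-1) and landmark (-2) samples
    are masked out. Of the valid losses, only the largest `keep_ratio`
    fraction is kept (online hard example mining)."""
    valid = (labels == 0) | (labels == 1)       # mask: neg and pos only
    idx = np.clip(labels, 0, 1)                 # true-class column index
    # cross-entropy of the true class; masked rows get loss 0
    losses = -log_probs[np.arange(len(labels)), idx] * valid
    keep = int(valid.sum() * keep_ratio)
    hardest = np.sort(losses)[::-1][:keep]      # top-k largest losses
    return hardest.mean()
```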

RNet, ONet:

RNet and ONet are quite similar: both refine the face boxes, so I explain them together. RNet's landmark data is produced the same way as PNet's, by cropping keypoint-annotated images, except the crops are resized to 24x24 for input.

The pos, part, and neg data are produced by PNet. This also explains why PNet's input size for the four kinds of data is 12: for speed, and to generate RNet's input. Feeding a whole image through PNet yields prediction maps of shape [1, h, w, 2], [1, h, w, 4], and [1, h, w, 10], which is somewhat like how YOLO predicts on a grid; if YOLO is unfamiliar, see my YOLO introduction.

The image is effectively divided into a grid, with each cell predicting its own face box. Since the cells cover varying amounts of face, the resulting crops are divided into neg, pos, and part data, with landmark data playing only a supporting role. The image is also repeatedly scaled down by a fixed factor to build an image pyramid, so that more candidate face boxes are found. Boxes whose face probability exceeds a threshold are kept, and non-maximum suppression with an IoU threshold then removes heavily overlapping boxes. The boxes PNet predicts are cropped from the original image, divided into neg, pos, and part according to their maximum IoU with the ground-truth boxes, and resized to 24x24 as RNet's input.
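The NMS step mentioned above works roughly like this (a standard greedy NMS sketch; the exact thresholds used in the repo may differ):

```python
import numpy as np

def nms(boxes, scores, thresh=0.7):
    """Greedy non-maximum suppression: keep the highest-scoring box, drop
    every remaining box whose IoU with it exceeds `thresh`, repeat.
    Boxes are [x1, y1, x2, y2]; returns indices of the kept boxes."""
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]          # highest score first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # IoU of the current best box with all remaining boxes
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0, xx2 - xx1 + 1) * np.maximum(0, yy2 - yy1 + 1)
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        order = order[1:][iou <= thresh]    # drop heavy overlaps
    return keep
```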

RNet's and ONet's loss functions are the same as PNet's; only the relative weights of the three losses differ.
ONet's input is produced by taking the candidate boxes from PNet's pyramid stage, refining them through RNet, cropping the refined boxes from the image, dividing the crops into neg, pos, and part, and resizing them to 48x48; the landmark data is made the same way as RNet's, just resized to 48x48.
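The image pyramid used in the PNet stage boils down to a list of scale factors; here is a sketch, assuming a typical scale factor of 0.79 and a minimum-face-size parameter (both values are illustrative and may differ from the repo's config):

```python
def pyramid_scales(min_size, width, height, factor=0.79, net_size=12):
    """Scales for the PNet image pyramid: the image is first resized so the
    smallest face of interest (`min_size` pixels) maps onto the 12x12 input,
    then repeatedly shrunk by `factor` until the shorter side drops below
    12 px. Each scale yields one forward pass of the fully convolutional
    PNet, producing candidate boxes at that scale."""
    scale = net_size / min_size
    min_side = min(width, height) * scale
    scales = []
    while min_side >= net_size:
        scales.append(scale)
        scale *= factor
        min_side *= factor
    return scales
```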

Code

Environment

Ubuntu 16.04
Python 3.6.5
TensorFlow 1.8.0
OpenCV 3.4.3
pip install tqdm (to display progress bars)

Code layout

data holds the raw training data, the split data, the generated tfrecord files, etc.

Under detection, fcn_detector.py handles PNet inference on a single image, detector.py runs RNet and ONet batch inference on the multiple face crops PNet extracts from one image, and MtcnnDetector.py performs face detection and generates the input data for RNet and ONet

graph holds the graph files generated during training

output is where detection results for images or videos are saved

picture is where the images to be tested go

preprocess holds the data preprocessing code: BBox_utils.py, utils.py, and loader.py are helper modules; gen_12net_data.py generates PNet's pos, neg, and part data; gen_landmark_aug.py generates the landmark data; gen_imglist_pnet.py merges PNet's four kinds of data into one list; gen_hard_example.py generates the three kinds of data for RNet and ONet; gen_tfrecords.py writes the tfrecord files

In train, config.py holds the parameter settings (most paths are hard-coded, so few parameters are meant to be changed), model.py defines the models, train.py runs training, and train_model.py handles training the different networks

test.py is the test script

Downloading the Data

Download and extract the WIDER Face training data and put its WIDER_train folder under data. Extract the training set of Deep Convolutional Network Cascade for Facial Point Detection and put its lfw_5590 and net_7876 folders under data. The model folder already contains my trained weights.

Running

Training:

cd into preprocess, then run:
python gen_12net_data.py (generates the three kinds of PNet data)
python gen_landmark_aug.py 12 (generates PNet's landmark data)
python gen_imglist_pnet.py (merges the four lists)
python gen_tfrecords.py 12 (writes the tfrecord file)
Then cd into train and run python train.py 12 to train PNet.

cd into preprocess, then run:
python gen_hard_example.py 12 (generates the three kinds of RNet data)
python gen_landmark_aug.py 24 (generates RNet's landmark data)
python gen_tfrecords.py 24 (writes the tfrecord file)
Then cd into train and run python train.py 24 to train RNet.

cd into preprocess, then run:
python gen_hard_example.py 24 (generates the three kinds of ONet data)
python gen_landmark_aug.py 48 (generates ONet's landmark data)
python gen_tfrecords.py 48 (writes the tfrecord file)
Then cd into train and run python train.py 48 to train ONet.

Testing:

python test.py

Tips

Generating the hard examples takes a very long time, three to four hours, so be patient if you train from scratch. If anything in the code or in my explanation is wrong, corrections are welcome.

Results

The test images come from Baidu Images; the results are shown below:



tensorflow-mtcnn's People

Contributors: lesliezhoa

tensorflow-mtcnn's Issues

About boxes

At test time, several boxes sometimes appear on the same face. How can this be fixed?

Training on my own dataset

I replaced the landmark dataset with my own. The preprocessing ran without problems and the tfrecord files were generated successfully, but training PNet fails with the error below. Please help.
2019-08-03 17:03:27.010008: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
Traceback (most recent call last):
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1322, in _do_call
return fn(*args)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1307, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1409, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero
[[Node: gradients/PNet/TopKV2_2_grad/Reshape = Reshape[T=DT_INT32, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](PNet/TopKV2_2:1, gradients/PNet/TopKV2_2_grad/stack)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "train.py", line 58, in
main(parse_arguments(sys.argv[1:]))
File "train.py", line 41, in main
train(net_factory,prefix,end_epoch,base_dir,display,lr)
File "/home/lishuwei/dsx/tensorflow-MTCNN-master/train/train_model.py", line 122, in train
,,summary = sess.run([train_op, lr_op ,summary_op], feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array,landmark_target:landmark_batch_array})
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 900, in run
run_metadata_ptr)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1135, in _run
feed_dict_tensor, options, run_metadata)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1316, in _do_run
run_metadata)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1335, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero
[[Node: gradients/PNet/TopKV2_2_grad/Reshape = Reshape[T=DT_INT32, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](PNet/TopKV2_2:1, gradients/PNet/TopKV2_2_grad/stack)]]

Caused by op 'gradients/PNet/TopKV2_2_grad/Reshape', defined at:
File "train.py", line 58, in
main(parse_arguments(sys.argv[1:]))
File "train.py", line 41, in main
train(net_factory,prefix,end_epoch,base_dir,display,lr)
File "/home/lishuwei/dsx/tensorflow-MTCNN-master/train/train_model.py", line 82, in train
train_op,lr_op=optimize(base_lr,total_loss_op,num)
File "/home/lishuwei/dsx/tensorflow-MTCNN-master/train/train_model.py", line 159, in optimize
train_op = optimizer.minimize(loss, global_step)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/training/optimizer.py", line 414, in minimize
grad_loss=grad_loss)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/training/optimizer.py", line 526, in compute_gradients
colocate_gradients_with_ops=colocate_gradients_with_ops)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 494, in gradients
gate_gradients, aggregation_method, stop_gradients)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 636, in _GradientsHelper
lambda: grad_fn(op, *out_grads))
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 385, in _MaybeCompile
return grad_fn() # Exit early
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 636, in
lambda: grad_fn(op, *out_grads))
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/ops/nn_grad.py", line 978, in _TopKGrad
ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack([-1, ind_lastdim]))
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 6113, in reshape
"Reshape", tensor=tensor, shape=shape, name=name)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
op_def=op_def)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1718, in init
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access

...which was originally created as op 'PNet/TopKV2_2', defined at:
File "train.py", line 58, in
main(parse_arguments(sys.argv[1:]))
[elided 0 identical lines from previous traceback]
File "train.py", line 41, in main
train(net_factory,prefix,end_epoch,base_dir,display,lr)
File "/home/lishuwei/dsx/tensorflow-MTCNN-master/train/train_model.py", line 80, in train
label,bbox_target,landmark_target,training=True)
File "/home/lishuwei/dsx/tensorflow-MTCNN-master/train/model.py", line 41, in P_Net
landmark_loss=landmark_ohem(landmark_pred,landmark_target,label)
File "/home/lishuwei/dsx/tensorflow-MTCNN-master/train/model.py", line 219, in landmark_ohem
square_error,_=tf.nn.top_k(square_error,k=keep_num)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/ops/nn_ops.py", line 2352, in top_k
return gen_nn_ops.top_kv2(input, k=k, sorted=sorted, name=name)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/ops/gen_nn_ops.py", line 7660, in top_kv2
"TopKV2", input=input, k=k, sorted=sorted, name=name)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
op_def=op_def)
File "/home/lishuwei/anaconda3/envs/dsx_tf/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1718, in init
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero
[[Node: gradients/PNet/TopKV2_2_grad/Reshape = Reshape[T=DT_INT32, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](PNet/TopKV2_2:1, gradients/PNet/TopKV2_2_grad/stack)]]

about gpu

How do I set the parameters to train on the GPU?

r_net / o_net

Hi! Does RNet's neg/pos/part data come only from hard-example mining with PNet, or does the raw WIDER Face dataset need additional processing as well?

I saw the following description on another blog:
PNet generates predicted bounding boxes from the original images;
the original images and PNet's bounding boxes go through RNet, which outputs corrected bounding boxes;
the original images and RNet's bounding boxes go through ONet, which outputs corrected bounding boxes plus the facial landmark keypoints.

Training labels

Hi, I have a question about the training labels. I've finished annotating my own small dataset and created the corresponding txt file.
(screenshot)
But I see two label files for WIDER Face under data:
(screenshot)
One is wider_face_train.txt and the other is wider_face_train_bbx_gt.txt. The box format in train_bbx_gt is x1,y1,w,h, and train.txt looks like x1,y1,x2,y2, though I don't understand why those box coordinates contain decimals, even if that barely affects the box positions.
(screenshot)
My other question: judging from this line in your gen_12net_data.py,
(screenshot)
the actual training label format is x1,y1,x2,y2. Is my understanding correct?

PNet's detection boxes

Hi, aren't the boxes PNet detects regressed before being fed into RNet?

Output after testing

Why doesn't output update after I update picture and rerun the test?
Hoping you can answer this.

hard_example conversion

The convert_to_square helper in utils expands a box into a larger square, which can easily go outside the image bounds. Is that correct?
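For reference, a square conversion of this kind typically looks like the sketch below (a hypothetical reimplementation, not the repo's exact code). Going out of bounds is expected; the usual remedy is to crop the in-bounds part and zero-pad before resizing:

```python
import numpy as np

def convert_to_square(bbox):
    """Expand each [x1, y1, x2, y2] box to a square centered on the original
    box, with side max(w, h). The square may extend past the image border;
    callers are expected to clip the crop region and zero-pad the rest."""
    square = bbox.copy()
    w = bbox[:, 2] - bbox[:, 0] + 1
    h = bbox[:, 3] - bbox[:, 1] + 1
    side = np.maximum(w, h)
    square[:, 0] = bbox[:, 0] + w * 0.5 - side * 0.5
    square[:, 1] = bbox[:, 1] + h * 0.5 - side * 0.5
    square[:, 2] = square[:, 0] + side - 1
    square[:, 3] = square[:, 1] + side - 1
    return square
```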

Training my own model

Can I train a model for something else, e.g. detecting cat or dog heads? 😁
If so, does any of the code need changing? Thanks.

Missing input_size

train.py: error: the following arguments are required: input_size
What causes this error?

Choosing and building the datasets

Hi, a beginner question: since the LFW dataset has both bbox and landmark annotations, why use WIDER Face to train the bboxes and LFW to train the landmarks separately? That also means each sample in PNet training can contribute only one kind of loss. Looking forward to your reply.

About the test results

Running test.py
outputs only one result image, which never changes.
Why is that?
It also prints:
**Qt: Cannot set locale modifiers:**

About accuracy

Hi, what accuracy can this model reach on the FDDB dataset?

Webcam detection

Besides images, how do I run detection from a webcam? Many thanks.

About PNet's training data

When generating PNet's training data, are all the positive/part offsets in the range -1 to 1, and all the landmark offsets in the range 0 to 1?

Can this run on Windows?

Hi, can this MTCNN project run on Windows? If not, are there equivalent resources for Windows? Many thanks if you find time to reply.

Are there far fewer output images after PNet detection?

Are there far fewer output images after PNet detection?
These outputs are used to train RNet, but I find the three classes of data extremely sparse.
For example, I trained PNet with as many as 1,750,000 samples,
but only 7,000 remain for training RNet, with just 63 pos samples...
Did you run into this, and how can it be fixed?

Training speed on GPU

Hi, I'm trying this project. Why is training slower on the GPU than on the CPU?

Inference speed

My PNet trains to about 93% accuracy, but on larger images inference generates thousands of candidate boxes, so NMS alone takes about 2 s. The network inference itself is also slow, about 0.5 s for PNet on a 1080 Ti. I'm at a loss. What inference speed do you get, and have you looked at the number of candidate boxes?

gen_hard_example.py ran all night and only processed about 800 images??

Hi, I'm very glad to have found such readable code.
For my project I modified the code: I removed landmark detection, switched the saved model format to pb, rewrote the network with the tf.layers API, and added batch normalization.

The program is currently at the hard_example 12 stage.
The first images were processed fairly quickly, about 3 s each, but around image 800 each one started taking about 2 minutes. Has anyone else hit this? Happy to discuss.

CentOS 7
8 cores
16 GB RAM
1080 Ti

About the loss function

./train/models.py line 166

(screenshot)

For face classification, the loss used in the code is

P1 = y*log(p(y))

(which, as I understand it, is meant for multi-class problems), but the paper says this is a binary classification problem, which should use

P2 = y*log(p(y)) + (1-y)*log(1-p(y))

After all, the two cross-entropy functions are not the same, which may make a difference during optimization.
For example: y = 1, p(y) = 0.6: P1 = log(0.6); P2 = log(0.6) + log(0.4)

(The optimization point is not fully argued; results remain to be verified.)

What do you think?
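One observation that may settle this: when the network emits a two-class softmax [p0, p1] with p0 + p1 = 1 (as PNet's [batch, 2] output does), picking out -log of the true class's probability already equals the binary cross-entropy, because the (1-y) term just selects -log p0 when y = 0:

```python
import numpy as np

# With a two-class softmax, p0 = 1 - p1, so the "multi-class" form
# -log p_true coincides with -[y*log(p1) + (1-y)*log(p0)] for y in {0, 1}.
p1 = 0.6               # predicted face probability (illustrative value)
p0 = 1.0 - p1
for y in (0, 1):
    multiclass = -np.log(p1 if y == 1 else p0)
    binary = -(y * np.log(p1) + (1 - y) * np.log(p0))
    assert np.isclose(multiclass, binary)
```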

o_net hard_example error

lunar@lunar-virtual-machine:~/PycharmProjects/tensorflow_mtcnn/preprocess$ python gen_hard_example.py 24
2019-05-12 09:33:06.690596: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
载入数据 (loading data)
8247it [2:47:10, 1.26it/s]Traceback (most recent call last):
File "gen_hard_example.py", line 193, in
main(parse_arguments(sys.argv[1:]))
File "gen_hard_example.py", line 70, in main
detectors,_=mtcnn_detector.detect_face(test_data)
File "../detection/MtcnnDetector.py", line 87, in detect_face
boxes, boxes_c, landmark = self.detect_rnet(im, boxes_c)
File "../detection/MtcnnDetector.py", line 181, in detect_rnet
cls_scores, reg, _ = self.rnet_detector.predict(cropped_ims)
File "../detection/detector.py", line 63, in predict
return np.concatenate(cls_prob_list, axis=0), np.concatenate(bbox_pred_list, axis=0), np.concatenate(landmark_pred_list, axis=0)
ValueError: need at least one array to concatenate

What does this error mean? Thanks.

About the trimming and optimization

Hi, I've reached the second step,
2 python gen_landmark_aug.py 12 (generate PNet's landmark data),

but the directory only contains the file gen_landmark_aug.py. Where should input_size be assigned?
Probably a newbie question ~ looking forward to your reply.

image and landmark are flipped, but bbox is not?

`image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run([image_batch, label_batch, bbox_batch, landmark_batch])

# randomly flip images

image_batch_array, landmark_batch_array = random_flip_images(image_batch_array, label_batch_array, landmark_batch_array)
_, _, summary = sess.run([train_op, lr_op, summary_op], feed_dict={input_image: image_batch_array,label: label_batch_array,bbox_target: bbox_batch_array,landmark_target: landmark_batch_array})`
