zh_cnn_text_classify's Introduction

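A CNN-based Chinese text classification algorithm

Introduction

A simple convolutional neural network for Chinese text classification, implemented with reference to IMPLEMENTING A CNN FOR TEXT CLASSIFICATION IN TENSORFLOW. The dataset used in this project comes from a Chinese spam email detection task. Dataset download: Baidu Netdisk (百度网盘)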

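Differences

The CNN in the original blog post targets English text classification. It does not use word2vec to obtain word vectors; instead, it adds an embedding layer to the network to learn them.
This project first uses word2vec to obtain a vector representation for each word in the Chinese test dataset, and then feeds those vectors into the convolutional network for classification.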

How to run

Training

Run python train.py to train the CNN on the spam and ham files (Chinese text only). Change the file paths configured in FLAGS to your own.

Viewing summaries in TensorBoard

Run tensorboard --logdir /{PATH_TO_CODE}/runs/{TIME_DIR}/summaries/ to view the summaries in a browser.

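Testing and classification

Run python eval.py --checkpoint_dir /{PATH_TO_CODE}/runs/{TIME_DIR}/checkpoints
To classify files of your own, change the relevant input parameters.

To measure accuracy, specify the corresponding label file (input_label_file):
python eval.py --input_label_file /PATH_TO_INPUT_LABEL_FILE
Note: each line of input_label_file is 0 or 1 and must correspond to the same line of input_text_file.
When this label file is provided, eval.py prints the prediction accuracy.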
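
The label-file convention above can be illustrated with a minimal sketch. It assumes the predictions are already available as a list of 0/1 values in the same order as the input text lines. The helper name accuracy_from_label_lines is hypothetical and is not taken from eval.py.

```python
def accuracy_from_label_lines(label_lines, predictions):
    """Compare 0/1 labels (one per line) with predictions in the same order."""
    # Skip blank lines, e.g. a trailing newline at the end of the file.
    labels = [int(line.strip()) for line in label_lines if line.strip()]
    if len(labels) != len(predictions):
        raise ValueError("label count %d != prediction count %d"
                         % (len(labels), len(predictions)))
    if not labels:
        return 0.0
    correct = sum(1 for y, p in zip(labels, predictions) if y == int(p))
    return correct / float(len(labels))


# Hypothetical data: four labelled lines, one wrong prediction.
label_lines = ["1\n", "0\n", "1\n", "0\n"]
predictions = [1.0, 0.0, 0.0, 0.0]
print(accuracy_from_label_lines(label_lines, predictions))  # 0.75
```

A mismatch between the number of labels and the number of predictions raises an error here, since a misaligned label file would otherwise give a meaningless accuracy.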

Recommended environment

python 2.7.13 :: Anaconda 4.3.1 (64-bit)
tensorflow 1.0.0
gensim 1.0.1
Ubuntu16.04 64bit

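Notes

If the program does not run correctly after following the steps above, please ask in Issues or on the blog, and I will reply as soon as possible.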

zh_cnn_text_classify's Issues

Problem with training

Running train.py produces the following error:
File "/home/robin/Downloads/zh_cnn_text_classify-master/train.py", line 54
out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
^
SyntaxError: invalid syntax
How should I fix this? As I understand it, out_dir is just a variable name, so in theory this line should not raise an error.

Some problems during setup; could the author please help?

1. This is the error reported when running train.py:
InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 8192 values, but the requested shape requires a multiple of 384
[[Node: Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](concat, Reshape/shape)]]
The full output is:
Traceback (most recent call last):
File "train.py", line 193, in <module>
train_step(x_batch, y_batch)
File "train.py", line 164, in train_step
feed_dict)
File "d:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 778, in run
run_metadata_ptr)
File "d:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 982, in _run
feed_dict_string, options, run_metadata)
File "d:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1032, in _do_run
target_list, options, run_metadata)
File "d:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1052, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 8192 values, but the requested shape requires a multiple of 384
[[Node: Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](concat, Reshape/shape)]]

Caused by op 'Reshape', defined at:
File "train.py", line 108, in <module>
l2_reg_lambda = FLAGS.l2_reg_lambda)
File "C:\Users\Administrator\PycharmProjects\zh_cnn_text_classify-master\text_cnn.py", line 56, in __init__
self.h_pool_flat = tf.reshape(self.h_pool, [-1, num_filters_total])
File "d:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 2510, in reshape
name=name)
File "d:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 768, in apply_op
op_def=op_def)
File "d:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 2336, in create_op
original_op=self._default_original_op, op_def=op_def)
File "d:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 1228, in __init__
self._traceback = _extract_stack()
InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 8192 values, but the requested shape requires a multiple of 384
[[Node: Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](concat, Reshape/shape)]]
2. How should the input_label_file parameter be set when running eval.py?
I am interested in this area and would appreciate the author's help!
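
The arithmetic behind the reshape error can be checked with a short sketch. The traceback shows tf.reshape(self.h_pool, [-1, num_filters_total]), which only succeeds when the number of values in h_pool is a multiple of num_filters_total. The settings below (128 filters, filter sizes 3, 4, 5) are an assumption based on the defaults of the referenced blog post; they happen to give 384, but the reporter's actual configuration is not known.

```python
def num_filters_total(num_filters, filter_sizes):
    """Width of the flattened pooled features: one block per filter size."""
    return num_filters * len(filter_sizes)


def can_reshape(num_values, row_width):
    """A [-1, row_width] reshape needs num_values to be a multiple of row_width."""
    return num_values % row_width == 0


total = num_filters_total(128, [3, 4, 5])  # assumed settings
print(total)                      # 384
print(can_reshape(8192, total))   # False: 8192 = 21 * 384 + 128
print(can_reshape(8192, 128))     # True: 8192 = 64 * 128
```

So 8192 is a multiple of 128 but not of 384. That suggests the pooled tensor holds fewer values per example than the reshape expects, though the cause cannot be determined from the traceback alone.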

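Sorry to bother you: a question about the test file

Hello, and thank you for sharing! After training, does my own test dataset also have to be in .utf8 format? Would a .txt file work?
Mainly, I do not know how to generate a utf8 file. Thank you.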
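
On the encoding question: a file's extension does not determine its encoding, so a .txt file can hold UTF-8 text. A minimal sketch of re-encoding a text file as UTF-8 follows. It assumes the source file is GBK-encoded, which is common for Chinese text saved on Windows but is only a guess here; convert_to_utf8 is a hypothetical helper and not part of this project.

```python
import io


def convert_to_utf8(src_path, dst_path, src_encoding="gbk"):
    """Read src_path using src_encoding and write the same text as UTF-8."""
    with io.open(src_path, "r", encoding=src_encoding) as src:
        text = src.read()
    with io.open(dst_path, "w", encoding="utf-8") as dst:
        dst.write(text)


# Example call (hypothetical file names):
# convert_to_utf8("my_test.txt", "my_test.utf8")
```

If the source encoding is different, pass it as src_encoding. A wrong guess typically raises UnicodeDecodeError rather than silently producing garbled text.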

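About the data

The files in your data folder appear to be raw files that have not been segmented with jieba. When training, does using raw data versus segmented data affect the results? Which works better?
Thanks!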

Question

Hello, did you design this network yourself?

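A question about the pooling layer settings

Hello. The pooling layers I have seen before use ksize = [1,4,4,1], where the second and third dimensions are equal, but your pooling layer is set to ksize=[1, sequence_length - filter_size + 1, 1, 1]. What is the purpose of this setting?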
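
The ksize in the question can be read as max-over-time pooling, as used in the referenced blog post: each filter spans the full embedding width, so its output is a column of sequence_length - filter_size + 1 values, and a pooling window of that same height keeps only the single largest activation per filter. The sketch below illustrates this in one dimension with plain Python lists; the numbers are made up and the helper names are hypothetical.

```python
def conv_output_length(sequence_length, filter_size):
    """Length of a 'VALID' convolution output along the sequence axis."""
    return sequence_length - filter_size + 1


def max_pool_1d(values, window):
    """Max pooling with stride 1 and no padding over a 1-D feature map."""
    return [max(values[i:i + window]) for i in range(len(values) - window + 1)]


sequence_length, filter_size = 10, 3
n = conv_output_length(sequence_length, filter_size)
feature_map = [3, 1, 4, 1, 5, 9, 2, 6]  # one filter's output, length n

print(n)                            # 8
print(max_pool_1d(feature_map, n))  # [9]: one value per filter
print(max_pool_1d(feature_map, 4))  # [4, 5, 9, 9, 9]: a smaller window keeps several
```

Square kernels such as 4x4 are typical for images, where both axes are spatial. Here the feature map is only one column wide, so the window covers the whole column and has width 1.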

Environment setup problem; could the author please advise?

The versions of these two packages that I installed differ from tensorflow 1.0.0 and gensim 1.0.1. Will that work? I am using anaconda3.5 with a Python 2.7.13 environment.
Running the program then fails with the error below, and as a beginner I do not know how to fix it.
/usr/anaconda3/envs/python27/bin/python2.7 /home/Snake/PycharmProjects/zh_cnn_text_classify/train.py
Traceback (most recent call last):
File "/home/Snake/PycharmProjects/zh_cnn_text_classify/train.py", line 44, in <module>
FLAGS._parse_flags()
File "/usr/anaconda3/envs/python27/lib/python2.7/site-packages/tensorflow/python/platform/flags.py", line 85, in __getattr__
return wrapped.__getattr__(name)
File "/usr/anaconda3/envs/python27/lib/python2.7/site-packages/absl/flags/_flagvalues.py", line 472, in __getattr__
raise AttributeError(name)
AttributeError: _parse_flags

Process finished with exit code 1
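
The traceback passes through absl/flags, which indicates a TensorFlow release newer than the recommended 1.0.0; FLAGS._parse_flags() was a private method of the older flags object and is not found on the absl-based one. Installing tensorflow 1.0.0 is the route the README recommends. As an alternative, the sketch below shows one possible compatibility shim. It assumes that absl-style flag objects can be called with the argument list to trigger parsing; parse_flags_compat is a hypothetical helper, not part of TensorFlow or of this project.

```python
import sys


def parse_flags_compat(flags, argv=None):
    """Parse command-line flags on either an old or an absl-style FLAGS object.

    Assumption: old objects expose _parse_flags(); absl-style objects are
    callable with the argument list instead.
    """
    argv = sys.argv if argv is None else argv
    parse = getattr(flags, "_parse_flags", None)
    if callable(parse):
        parse()       # TensorFlow 1.0-era private method
    else:
        flags(argv)   # absl-style: FLAGS(argv)
```

In train.py, the call FLAGS._parse_flags() would then become parse_flags_compat(FLAGS). Other private attributes used near that line may need similar treatment on newer releases.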

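Bug in padding_sentences

In data_helper.py, the padding_sentences function does not actually truncate sentences that are too long.

The bug is in the line sentence = sentence[:max_sentence_length]
This only rebinds the loop variable sentence; the change is not applied to the sentences list.
Test code:
a = [
    [1, 2, 3, 4],
    [1, 2, 3, 4, 5, 6],
    [1, 2, 3, 4, 5, 6, 7, 8, 9],
]

for s in a:
    if len(s) > 5:
        s = s[:5]  # rebinds s only; the list stored in a is unchanged
print(a)

Output:
[[1, 2, 3, 4], [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 7, 8, 9]]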
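
One way to avoid the rebinding problem described above is to build a new list rather than assign to the loop variable. The following is a minimal sketch, not the project's actual fix; pad_or_truncate and its parameters are hypothetical names.

```python
def pad_or_truncate(sentences, max_len, padding_token="<PADDING>"):
    """Return new token lists, each exactly max_len items long."""
    result = []
    for sentence in sentences:
        if len(sentence) > max_len:
            # Slicing creates a new list; store it instead of rebinding.
            result.append(sentence[:max_len])
        else:
            padding = [padding_token] * (max_len - len(sentence))
            result.append(sentence + padding)
    return result


a = [[1, 2, 3, 4], [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 7, 8, 9]]
print(pad_or_truncate(a, 5, padding_token=0))
# [[1, 2, 3, 4, 0], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]
```

An in-place alternative is del sentence[max_len:], which mutates the list object that the outer list still refers to.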

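A question

When I run your code, the model always outputs the same fixed value at prediction time. What could be causing this? Also, in the prediction.csv file you provided, the predicted values are all 0.0. Why is that?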

Problem running eval.py

Parameters:
ALLOW_SOFT_PLACEMENT=True
BATCH_SIZE=64
CHECKPOINT_DIR=./runs/1498711744/checkpoints
EVAL_TRAIN=True
INPUT_LABEL_FILE=./data/testlabel.txt
INPUT_TEXT_FILE=./data/test.txt
LOG_DEVICE_PLACEMENT=False

Cannot find a valid checkpoint file!

The problem above occurs when I run it. What is the cause?

Error when training the model on my own data

Hello, I want to train the model on my own text data, which has 12 classes. I rewrote data_helper.py, but the following error occurs when the batch_iter function runs. What is going on?

[root@localhost WebClassify]# python data_helper.py
len of x_train is: 144171
len of y_train is: (144171, 12)

Traceback (most recent call last):
File "data_helper.py", line 113, in <module>
for batch in batches:
File "data_helper.py", line 88, in batch_iter
shuffled_data = data[shuffle_indices]
MemoryError
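
The traceback points at data[shuffle_indices]. If data is a NumPy array, that expression builds a shuffled copy of the entire dataset, which may be what exhausts memory with 144171 samples. One possible alternative is to shuffle only the indices and assemble each batch on demand. The sketch below uses plain Python sequences and is not the project's batch_iter; the function name and parameters are hypothetical.

```python
import random


def batch_iter_by_index(x, y, batch_size, num_epochs, shuffle=True, seed=None):
    """Yield (x_batch, y_batch) pairs without copying the whole dataset."""
    rng = random.Random(seed)
    indices = list(range(len(x)))
    for _ in range(num_epochs):
        if shuffle:
            rng.shuffle(indices)  # shuffles the index list only
        for start in range(0, len(indices), batch_size):
            batch_idx = indices[start:start + batch_size]
            # Only batch_size items are materialized at a time.
            yield [x[i] for i in batch_idx], [y[i] for i in batch_idx]


# Hypothetical toy data: five samples, batches of two, one epoch.
x = ["a", "b", "c", "d", "e"]
y = [0, 1, 2, 3, 4]
for x_batch, y_batch in batch_iter_by_index(x, y, batch_size=2, num_epochs=1, seed=0):
    print(x_batch, y_batch)
```

Each sample stays paired with its label because both batches are drawn with the same index list. The last batch of an epoch may be smaller than batch_size.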

Errors at runtime; please explain what they mean and what to do

2017-05-23 17:09:08.518960: E c:\tf_jenkins\home\workspace\release-win\device\gpu\os\windows\tensorflow\stream_executor\cuda\cuda_driver.cc:1037] failed to synchronize the stop event: CUDA_ERROR_LAUNCH_FAILED
2017-05-23 17:09:08.519396: E c:\tf_jenkins\home\workspace\release-win\device\gpu\os\windows\tensorflow\stream_executor\cuda\cuda_timer.cc:54] Internal: error destroying CUDA event in context 000002A7F6F70040: CUDA_ERROR_LAUNCH_FAILED
2017-05-23 17:09:08.519922: E c:\tf_jenkins\home\workspace\release-win\device\gpu\os\windows\tensorflow\stream_executor\cuda\cuda_timer.cc:59] Internal: error destroying CUDA event in context 000002A7F6F70040: CUDA_ERROR_LAUNCH_FAILED
2017-05-23 17:09:08.520451: F c:\tf_jenkins\home\workspace\release-win\device\gpu\os\windows\tensorflow\stream_executor\cuda\cuda_dnn.cc:2478] failed to enqueue convolution on stream: CUDNN_STATUS_EXECUTION_FAILED
