
bert-classification-tutorial's People

Contributors

ejld, linktopast1990


bert-classification-tutorial's Issues

Hi, do I just copy train.csv and dev.csv from data/ into the MRPC directory, overwriting the originals?

Hi, do I just copy train.csv and dev.csv from data/ into the MRPC directory, overwriting the originals? After overwriting them and running, I get the following error:
InvalidArgumentError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Assign requires shapes of both tensors to match. lhs shape= [10,768] rhs shape= [3668,768]
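
A possible explanation (my reading of run_classifier.py, not the maintainer's answer): the left-hand shape [10, 768] is the classification layer of the current graph, which is created as [num_labels, hidden_size]; the [3668, 768] tensor on the right comes from whatever checkpoint is being restored, which suggests the restore is picking up stale checkpoints from an earlier run (for example in --output_dir) rather than the clean bert_model.ckpt. Pointing --output_dir at an empty directory usually avoids this. A minimal sketch, assuming TF 1.x:

import tensorflow as tf

# The classification layer is sized by the number of labels, exactly as in
# create_model() of run_classifier.py, so it can only be restored from a
# checkpoint that was saved with the same number of labels.
num_labels, hidden_size = 10, 768
output_weights = tf.get_variable(
    "output_weights", [num_labels, hidden_size],
    initializer=tf.truncated_normal_initializer(stddev=0.02))
print(output_weights.shape)  # (10, 768); a saved (3668, 768) tensor cannot be assigned to it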

How do I run in GPU mode?

I installed tensorflow-gpu==1.10.0, but when training starts the CPU is still maxed out and GPU usage barely changes. The log also shows:
INFO:tensorflow:***** Running training *****
INFO:tensorflow: Num examples = 567
INFO:tensorflow: Batch size = 32
INFO:tensorflow: Num steps = 177
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Running train on CPU
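
Two notes that may help (my own suggestions, not the maintainer's). The line "Running train on CPU" is printed by TPUEstimator whenever use_tpu is False, so by itself it does not prove the GPU is idle. Separately, you can check whether this TensorFlow build actually sees the GPU; a minimal check under TF 1.x:

import tensorflow as tf
from tensorflow.python.client import device_lib

# If no GPU device is listed, TensorFlow will silently fall back to the CPU
# (for example a CPU-only build, or a CUDA/cuDNN version mismatch).
print(tf.test.is_gpu_available())
print([d.name for d in device_lib.list_local_devices() if d.device_type == "GPU"])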

Why is the evaluation accuracy only 0.1?

The command used was:
python run_classifier.py --task_name=MRPC --do_train=true --do_eval=true --data_dir=data/ --vocab_file=uncased_L-12_H-768_A-12/vocab.txt --bert_config_file=uncased_L-12_H-768_A-12/bert_config.json --init_checkpoint=uncased_L-12_H-768_A-12/bert_model.ckpt --output_dir=mrpc_output/

The output was:
2019-03-03 17:41:02.999284: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1103] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 9947 MB memory) -> physical GPU (device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:01:00.0, compute capability: 6.1)
INFO:tensorflow:Restoring parameters from mrpc_output/model.ckpt-312
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2019-03-03-09:41:04
INFO:tensorflow:Saving dict for global step 312: eval_accuracy = 0.1, eval_loss = 2.3766909, global_step = 312, loss = 2.3766909
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 312: mrpc_output/model.ckpt-312
INFO:tensorflow:evaluation_loop marked as finished
INFO:tensorflow:***** Eval results *****
INFO:tensorflow: eval_accuracy = 0.1
INFO:tensorflow: eval_loss = 2.3766909
INFO:tensorflow: global_step = 312
INFO:tensorflow: loss = 2.3766909
Is this caused by the number of training steps?
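
For what it's worth, 0.1 is exactly chance level for a 10-class task, which usually points to a label problem or to restoring an incompatible checkpoint left in --output_dir, rather than to the number of training steps. A quick sanity check, purely as a sketch and assuming the label is the last tab-separated field of dev.tsv (which may not match this repository's exact format):

import collections
import csv

# Count the label values actually present in the eval file; they should match
# what the processor's get_labels() returns and should not all be identical.
with open("data/dev.tsv", encoding="utf-8") as f:
    labels = [row[-1] for row in csv.reader(f, delimiter="\t")]
print(collections.Counter(labels))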

Issue about new data with 7 classes

My new data is English text with 7 classes. I converted train and dev into the same format as the sample data, but I get an error when running.
It seems the 7 classes do not match the 10 classes; this code does 10-class classification, and I would like to know where to modify it.

Caused by op 'save/Assign_602', defined at:
File "run_classifier.py", line 929, in <module>
tf.app.run()
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 125, in run
_sys.exit(main(argv))
File "run_classifier.py", line 848, in main
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2403, in train
saving_listeners=saving_listeners
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 354, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1207, in _train_model
return self._train_model_default(input_fn, hooks, saving_listeners)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1241, in _train_model_default
saving_listeners)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1468, in _train_with_estimator_spec
log_step_count_steps=log_step_count_steps) as mon_sess:
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 504, in MonitoredTrainingSession
stop_grace_period_secs=stop_grace_period_secs)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 921, in init
stop_grace_period_secs=stop_grace_period_secs)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 643, in init
self._sess = _RecoverableSession(self._coordinated_creator)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 1107, in init
_WrappedSession.init(self, self._create_session())
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 1112, in _create_session
return self._sess_creator.create_session()
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 800, in create_session
self.tf_sess = self._session_creator.create_session()
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 557, in create_session
self._scaffold.finalize()
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py", line 215, in finalize
self._saver.build()
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1114, in build
self._build(self._filename, build_save=True, build_restore=True)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1151, in _build
build_save=build_save, build_restore=build_restore)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 789, in _build_internal
restore_sequentially, reshape)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 459, in _AddShardedRestoreOps
name="restore_shard"))
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 428, in _AddRestoreOps
assign_ops.append(saveable.restore(saveable_tensors, shapes))
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 119, in restore
self.op.get_shape().is_fully_defined())
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/state_ops.py", line 221, in assign
validate_shape=validate_shape)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_state_ops.py", line 61, in assign
use_locking=use_locking, name=name)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
return func(*args, **kwargs)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3274, in create_op
op_def=op_def)
File "/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1770, in init
self._traceback = tf_stack.extract_stack()

InvalidArgumentError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Assign requires shapes of both tensors to match. lhs shape= [7,768] rhs shape= [10,768]
[[node save/Assign_602 (defined at /anaconda3/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py:2403) = Assign[T=DT_FLOAT, _class=["loc:@output_weights/adam_v"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](output_weights/adam_v, save/RestoreV2:603)]]
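
A hedged sketch of the usual fix (not necessarily the maintainer's intended one): the class count comes from the label list returned by the processor's get_labels(), so a 7-class task needs a 7-entry list, and --output_dir must point at an empty directory; otherwise the saver tries to restore the old [10, 768] output layer into the new [7, 768] variable, which is exactly the Assign error above.

# Inside run_classifier.py; the class name and label strings are placeholders.
class MyTaskProcessor(DataProcessor):
    def get_labels(self):
        # one string per class; 7 entries for a 7-class task
        return ["0", "1", "2", "3", "4", "5", "6"]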

Could you export an environment config file? Your project throws errors when I run it

When running your project, the Git Bash terminal reports an error about numpy failing to import, and I cannot tell which numpy version needs to be installed.

I tried tensorflow==1.15, but it still errors.

Could the project maintainer please run:

# conda
conda env export --file requirements.txt
# plain Python (pip)
pip freeze > requirements.txt

to export requirements.txt and add the environment config file to the project.

About GPU memory

Your article says that loading BERT uses 9.5 GB of GPU memory. Were you using a single card with 11 GB, or multiple cards in parallel?

About preprocessing

Hi, from a quick look, BERT uses character-level embeddings for Chinese. After switching to BERT, do I still need to do my own text cleaning, stop-word removal, and so on? My task is sentiment analysis on mixed Chinese and English text; after getting your code running, I get roughly 90% accuracy on the validation set. Do you have any suggestions for improving it further?

About Chinese binary classification

I switched to a Chinese data source with two label classes and 1000 samples per class.
The format is roughly: character1 character2 character3\tlabel (space-separated characters, a tab, then the label).
The accuracy in the final eval_result.txt stays at 0.5.
I stepped through the code and the input matches what I expect, so I don't know why it stays at 0.5. Has anyone run into something similar?
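
One thing worth checking (my suggestion, not the maintainer's): accuracy pinned at 0.5 over two classes is chance level, which often means the label column is being read incorrectly, so every example ends up with the same or a meaningless label. A hedged sketch of what the example-building code in run_classifier.py typically looks like; the column indices are the part to verify against the "characters\tlabel" layout:

# Sketch only; the column indices are assumptions to check against your TSV.
def _create_examples(self, lines, set_type):
    examples = []
    for (i, line) in enumerate(lines):
        guid = "%s-%s" % (set_type, i)
        text_a = tokenization.convert_to_unicode(line[0])  # the space-separated characters
        label = tokenization.convert_to_unicode(line[1])   # must match get_labels() exactly
        examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
    return examples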

About multi-class text classification

I have also been trying multi-class fine-tuning recently. After reading the paper and the source code, it seems you can do multi-class classification directly with the ColaProcessor in run_classifier.py: there is no need to change _create_examples, only the label list returned by get_labels.
The paper also illustrates fine-tuning usage in Figure 3, so why did you base this on MRPC?
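
A hedged sketch of that approach (the names are placeholders, not this repository's code): reuse ColaProcessor's single-sentence example building, change only the label list, and register the new processor under its own task name.

# Inside run_classifier.py; "news7" is a made-up task name.
class SevenClassProcessor(ColaProcessor):
    def get_labels(self):
        return ["0", "1", "2", "3", "4", "5", "6"]

# In main(), add it to the processors dict so --task_name=news7 selects it.
processors = {
    "cola": ColaProcessor,
    "mrpc": MrpcProcessor,
    "news7": SevenClassProcessor,
}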

The maximum value of max_seq_length cannot exceed 512

Long sequences (length 512) cost far more compute than short ones (length 128).
Has anyone tried setting max_seq_length to 512?
For Chinese text classification, if max_seq_length is 128, isn't the text length too short for typical use cases?
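
For reference (my own note, not the maintainer's): self-attention cost grows roughly quadratically with sequence length, so 512 is much heavier than 128 and usually forces a smaller batch size on a single 11 GB card. Setting it is just a flag; an illustrative command (the batch size of 8 and the output directory name are guesses that may need tuning):

python run_classifier.py --task_name=MRPC --do_train=true --do_eval=true --data_dir=data/ --vocab_file=uncased_L-12_H-768_A-12/vocab.txt --bert_config_file=uncased_L-12_H-768_A-12/bert_config.json --init_checkpoint=uncased_L-12_H-768_A-12/bert_model.ckpt --max_seq_length=512 --train_batch_size=8 --output_dir=mrpc_output_512/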

About prediction accuracy

I ran my data with your code and the eval accuracy is only 0.1; the prediction results are nowhere near the 80% shown in your screenshot?

Token differences between the BERT project and yours. Is run_classifier.py the only changed file?

For English, the tokens are the same.
But for Chinese, the tokens differ even though I use the same run_classifier.py.

Using https://github.com/google-research/bert
INFO:tensorflow:*** Example ***
INFO:tensorflow:guid: train-5
INFO:tensorflow:tokens: [CLS] 1 。 我 住 的 是 靠 马 路 的 标 准 间 。 房 间 内 设 施 简 陋 , 并 且 的 房 间 玻 璃 窗 户 外 还 有 一 层 幕 墙 玻 璃 , 而 且 不 能 打 开 , 导 致 房 间 不 能 自 然 通 风 , 采 光 不 好 。 [SEP]

Using your project
INFO:tensorflow:*** Example ***
INFO:tensorflow:guid: train-5
INFO:tensorflow:tokens: [CLS] 1 。 我 ##住 ##的 ##是 ##靠 ##马 ##路 ##的 ##标 ##准 ##间 。 房 ##间 ##内 ##设 ##施 ##简 ##陋 , 并 ##且 ##的 ##房 ##间 ##玻 ##璃 ##窗 ##户 ##外 ##还 ##有 ##一 ##层 ##幕 ##墙 ##玻 ##璃 , 而 ##且 ##不 ##能 ##打 ##开 , 导 ##致 ##房 ##间 ##不 ##能 ##自 ##然 ##通 ##风 , 采 ##光 ##不 ##好 。 [SEP]
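
A hedged guess at where the difference comes from (based on reading Google's tokenization.py, not on a confirmed diff of this repository): the upstream BasicTokenizer surrounds every CJK character with spaces before WordPiece runs, so each Chinese character becomes its own token and no "##" pieces appear; if that CJK-splitting step is skipped or a modified tokenizer is used, WordPiece runs over whole spans and produces pieces like "##住 ##的". A minimal check with the upstream module:

# Assumes the google-research/bert tokenization.py is on the path.
import tokenization

tokenizer = tokenization.FullTokenizer(
    vocab_file="chinese_L-12_H-768_A-12/vocab.txt", do_lower_case=True)
print(tokenizer.tokenize(u"我住的是靠马路的标准间"))
# Expected upstream behaviour: one token per Chinese character, no "##" pieces.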

Hmm, the classification code here seems to have changed more than a little

def get_test_examples(self, data_dir):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "test")

Another change in run_classifier.py is the method above: test.tsv was replaced with dev.tsv, so prediction actually uses dev.tsv.
That is also why there is no test set under data/.
The output is not very intuitive to begin with, and it took me a long time to figure out why the example counts were wrong.
If a file is quietly swapped like this, it would be better to mention it in the README.

Question

Hi,
Why is there no test.tsv in the data directory?
For the final parameter setup, is it enough to set only

flags.DEFINE_string(
    "data_dir", None,
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.")

For example, in your project, do I change None to '/data'? And then just run run_classifier.py, is that right?
Thanks for the reply!

eval_drop_remainder = True if FLAGS.use_tpu else False

On CPU, if the number of eval examples is not divisible by batch_size, then:

INFO:tensorflow: name = input_ids, shape = (?, 128)
INFO:tensorflow: name = input_mask, shape = (?, 128)
INFO:tensorflow: name = label_ids, shape = (?,)
INFO:tensorflow: name = segment_ids, shape = (?, 128)

Change it to: eval_drop_remainder = True
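
A hedged sketch of that change in context, following the upstream run_classifier.py layout; note that dropping the remainder silently excludes the last partial batch from evaluation:

# Originally: eval_drop_remainder = True if FLAGS.use_tpu else False
eval_drop_remainder = True
eval_input_fn = file_based_input_fn_builder(
    input_file=eval_file,
    seq_length=FLAGS.max_seq_length,
    is_training=False,
    drop_remainder=eval_drop_remainder)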

Problem with do_eval

The results I get from "do_eval" during training and from running "do_eval" with the already-trained model differ enormously: the former is around 80%, the latter only 1%. This seems very strange.

Garbled Chinese characters

I am using chinese_L-12_H-768_A-12, and train.tsv and dev.tsv are both UTF-8.

It also reports the error:
KeyError: "cad"
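
A possible cause (my reading of run_classifier.py, not a confirmed answer): convert_single_example builds a label_map from get_labels() and then looks up each example's label in it, so a KeyError such as KeyError: "cad" usually means the TSV contains a label string, possibly a mis-encoded one, that is not in the processor's label list.

# Inside convert_single_example in the upstream run_classifier.py:
label_map = {}
for (i, label) in enumerate(label_list):
    label_map[label] = i
# later in the same function:
label_id = label_map[example.label]  # KeyError if the TSV label is not in get_labels()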
