Comments (4)
Sorry to bother the author again, just an update. I adjusted the dataset, increasing it to 320k samples (originally I had only used 20k for a quick test), and now the loss starts producing a value after only 200 steps of the first epoch, sitting at around 70+. Nothing else was changed. What puzzles me is that, in principle, adding data should need at least one full epoch before showing any effect, yet a mere 200 steps (batch_size=8) already makes such a big difference compared with before.
from crnn.tf2.
Hello, I am writing my own data-preprocessing pipeline, but with it the accuracy stays at 0. I think I have not fully understood your padding scheme.
- The build(self, data, batch_size, is_training) method in dataset_factory.py calls ds.padded_batch. Looking at the source, if padding_values is not set it pads with 0 by default, so your labels are padded with 0, and 0 corresponds to an entry in table.txt. Yet blank_index in losses.py is set to -1. Why is that? I printed the y_true in your losses, and my guess is that because dataset_factory outputs a SparseTensor, the 0 padding gets erased during sparsification, even though 0 was used to pad?
- Also, the docstring of tf.nn.ctc_loss contains this sentence: "On TPU and GPU: Only dense padded labels are supported." My own preprocessing outputs a DenseTensor rather than a SparseTensor: I pad the labels with 0 and feed the dense matrix into tf.nn.ctc_loss with blank_index=0. But with blank_index=0 the training speed drops sharply, nearly 3x slower. If I instead pad the labels with -1 and set blank_index=-1, it is fast again. I do not know why.
- But no matter how I set it, accuracy = 0. I am using the 3.6M Chinese text dataset, a public and widely used Chinese dataset, so the data itself should be fine.
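The sparsification guess above can be checked directly: tf.sparse.from_dense keeps only the non-zero entries of a dense tensor, so 0 padding does disappear when the labels go sparse. A small numpy mimic of that behavior (the helper name is hypothetical, not from the repo):

```python
import numpy as np

def from_dense_like(dense):
    """Hypothetical numpy mimic of tf.sparse.from_dense:
    keep only the non-zero entries and their indices."""
    mask = dense != 0
    indices = np.argwhere(mask)
    values = dense[mask]
    return indices, values

# A 0-padded label batch, as ds.padded_batch would produce by default.
dense = np.array([[3, 7, 0, 0],
                  [5, 0, 0, 0]])
indices, values = from_dense_like(dense)
print(values)  # [3 7 5]  -- the 0 padding is gone
```

This is also why a padding value of 0 and a blank_index of -1 can coexist in the sparse path: the padding never survives to the loss.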
Part of the key code is pasted below.
For per-batch processing I switched to tensorflow.keras.utils.Sequence; here is part of the code from def __getitem__(self, idx):
# Pad this batch's images to the same width
for i in range(len(train_image_data)):
    template = np.zeros((32, batch_max_width, 3))
    template[:32, :train_image_data[i].shape[1], :] = train_image_data[i]
    train_image_data[i] = template
train_image_data = np.array(train_image_data)
# Pad this batch's labels to the same length
for i in range(len(labels)):
    template = np.zeros(batch_max_label_length)
    template[:len(labels[i])] = labels[i]
    labels[i] = template
labels = np.array(labels)
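For reference, the image-padding loop above can be wrapped into a standalone, testable helper (the function name pad_images is illustrative, not from the repo):

```python
import numpy as np

def pad_images(images, height=32, channels=3):
    """Zero-pad a list of (height, w_i, channels) images to the batch max width."""
    batch_max_width = max(img.shape[1] for img in images)
    padded = []
    for img in images:
        template = np.zeros((height, batch_max_width, channels))
        template[:, :img.shape[1], :] = img
        padded.append(template)
    return np.array(padded)

batch = pad_images([np.ones((32, 10, 3)), np.ones((32, 7, 3))])
print(batch.shape)  # (2, 32, 10, 3)
```

Isolating the loop like this makes it easy to verify that the padded region is really zero and the original pixels are untouched.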
losses.py was changed accordingly to:
def call(self, y_true, y_pred):
    y_true = tf.cast(y_true, tf.int32)
    logit_length = tf.fill([tf.shape(y_pred)[0]], tf.shape(y_pred)[1])
    label_length = tf.fill([tf.shape(y_true)[0]], tf.shape(y_true)[1])
    loss = tf.nn.ctc_loss(
        labels=y_true,
        logits=y_pred,
        label_length=label_length,
        logit_length=logit_length,
        logits_time_major=False,
        blank_index=0)
    return tf.reduce_mean(loss)
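One detail worth double-checking in the call above: label_length is filled with the padded width tf.shape(y_true)[1] for every sample, rather than each sample's true label length, so the padding may be counted as part of the reference sequence by tf.nn.ctc_loss. A numpy sketch (assuming 0 is used only for padding, never as a real label) of recovering the per-sample lengths instead:

```python
import numpy as np

# A hypothetical 0-padded dense label batch.
labels = np.array([[12, 34, 56,  0,  0],
                   [ 7,  8,  0,  0,  0]])

# True per-sample lengths: count the non-padding entries in each row.
label_length = np.count_nonzero(labels, axis=1)
print(label_length)  # [3 2]
```

In TensorFlow the same idea would be something like tf.math.count_nonzero(y_true, axis=1), assuming the padding value is 0.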
metrics.py was changed accordingly to:
def update_state(self, y_true, y_pred, sample_weight=None):
    y_true_shape = tf.shape(y_true)
    batch_size = y_true_shape[0]
    y_pred_shape = tf.shape(y_pred)
    max_width = tf.maximum(y_true_shape[1], y_pred_shape[1])
    logit_length = tf.fill([batch_size], y_pred_shape[1])
    decoded, _ = tf.nn.ctc_greedy_decoder(
        inputs=tf.transpose(y_pred, perm=[1, 0, 2]),
        sequence_length=logit_length)
    y_pred = self.to_dense(decoded[0], [batch_size, max_width])
    # Only the next two lines changed. I think my understanding of your -1 is
    # lacking, so here I use from_dense to strip the padded 0s (back to a
    # sparse tensor), then your to_dense to re-pad with -1, restoring the data
    # to the form your source code feeds into tf.math.reduce_any.
    y_true = tf.sparse.from_dense(y_true)
    y_true = self.to_dense(y_true, [batch_size, max_width])
    num_errors = tf.math.reduce_any(
        tf.math.not_equal(y_true, y_pred), axis=1)
    num_errors = tf.cast(num_errors, tf.float32)
    num_errors = tf.reduce_sum(num_errors)
    batch_size = tf.cast(batch_size, tf.float32)
    self.total.assign_add(batch_size)
    self.count.assign_add(batch_size - num_errors)
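The from_dense/to_dense round trip described above can be mimicked in numpy to see what it produces; the helper below is hypothetical and only imitates densifying a sparse tensor with a -1 fill, as tf.sparse.to_dense(sp, default_value=-1) presumably does in the repo's to_dense:

```python
import numpy as np

def to_dense_like(indices, values, shape, default=-1):
    """Hypothetical numpy mimic of tf.sparse.to_dense with default_value=-1."""
    out = np.full(shape, default, dtype=np.asarray(values).dtype)
    for (i, j), v in zip(indices, values):
        out[i, j] = v
    return out

# Sparse form of a two-label batch: entries (0,0)=3, (0,1)=7, (1,0)=5.
indices = [(0, 0), (0, 1), (1, 0)]
values = np.array([3, 7, 5])
dense = to_dense_like(indices, values, (2, 4))
print(dense)  # [[ 3  7 -1 -1]
              #  [ 5 -1 -1 -1]]
```

After this round trip, both y_true and the greedy-decoded y_pred are padded with -1, so an element-wise not_equal comparison is meaningful.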
The labels are built in the same way as yours: UNK at the head, BLK at the tail.
- Please note the order of the code around the padding: at the time of padding, the labels are still strings, so they are not actually padded with 0.
- I have tried a DenseTensor before; under this logical structure it is not as good as using a SparseTensor. It is faster on the CPU.
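A minimal sketch of the -1 convention discussed in this thread: pad variable-length labels with -1 so the padding can never collide with a real class index from table.txt (illustrative numpy only, not the repo's code):

```python
import numpy as np

# Hypothetical variable-length integer labels for one batch.
labels = [[12, 34, 56], [7, 8]]
max_len = max(len(l) for l in labels)

# Pad with -1, a value outside the valid class-index range:
padded = np.array([l + [-1] * (max_len - len(l)) for l in labels])
print(padded)  # [[12 34 56]
               #  [ 7  8 -1]]
```

With this convention, a -1-padded dense y_true and a -1-filled dense decode can be compared directly, and no real label is ever mistaken for padding.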