Comments (4)

ddddddreamcastle commented on June 4, 2024

Sorry to bother you again; just an update. I adjusted the dataset and increased it to 320k samples (originally I had only used 20k for a quick test), and now the loss starts showing values after only 200 steps of the first epoch, at which point the loss is around 70+. No other changes were made. That said, adding data should in principle require at least one full epoch before the effect becomes visible; it is puzzling that a mere 200 steps (batch_size=8) already makes such a big difference compared to before.

FLming commented on June 4, 2024

ddddddreamcastle commented on June 4, 2024

Hello, I am writing my own data preprocessing pipeline, but after finishing it the accuracy stays at 0. I think I have not fully understood how your padding works.

  1. The def build(self, data, batch_size, is_training) method in dataset_factory.py calls ds.padded_batch. Looking at the source, if padding_values is not set it pads with 0 by default, so your labels are padded with 0, and 0 corresponds to an entry in table.txt. Yet blank_index in losses.py is set to -1. Why is that? I printed y_true in your losses, and my guess is that because dataset_factory outputs a SparseTensor, the 0 padding gets wiped out during sparsification even though 0 is used for padding (see the sketch after this list)?
  2. Also, the docstring of tf.nn.ctc_loss contains the note "On TPU and GPU: Only dense padded labels are supported." My own preprocessing outputs a DenseTensor rather than a SparseTensor: I pad the labels with 0 and feed the dense matrix into tf.nn.ctc_loss with blank_index=0. But with blank_index=0 the training speed drops sharply, nearly 3x slower. If I instead pad the labels with -1 and set blank_index=-1, it is fast again; I don't know why.
  3. However, no matter how I set it, accuracy = 0. The dataset is the 3.6-million-sample Chinese text dataset; the data itself should be fine, as it is a public, widely used Chinese dataset.
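
A minimal sketch (with made-up label ids, not the repository's actual pipeline) of the behaviour item 1 guesses at: tf.sparse.from_dense keeps only the non-zero entries, so any 0 padding disappears as soon as the labels are sparsified.

    import tensorflow as tf

    # Hypothetical labels padded with 0, the way ds.padded_batch pads by default.
    dense_labels = tf.constant([[5, 12, 3, 0, 0],
                                [7, 9, 0, 0, 0]], dtype=tf.int32)

    # tf.sparse.from_dense records only the non-zero entries, so the 0 padding is gone.
    sparse_labels = tf.sparse.from_dense(dense_labels)
    print(sparse_labels.values.numpy())       # [ 5 12  3  7  9]
    print(sparse_labels.dense_shape.numpy())  # [2 5]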

The key parts of my code are pasted below.

For per-batch processing I switched to tensorflow.keras.utils.Sequence; what follows is part of the code in def __getitem__(self, idx):

        # pad the images in this batch to the same width
        for i in range(len(train_image_data)):
            template = np.zeros((32, batch_max_width, 3))
            template[:32, :train_image_data[i].shape[1], :] = train_image_data[i]
            train_image_data[i] = template
        train_image_data = np.array(train_image_data)

        # pad the labels in this batch to the same length
        for i in range(len(labels)):
            template = np.zeros(batch_max_label_length)
            template[:len(labels[i])] = labels[i]
            labels[i] = template
        labels = np.array(labels)

The corresponding losses.py was changed to:

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.int32)
        logit_length = tf.fill([tf.shape(y_pred)[0]], tf.shape(y_pred)[1])
        label_length = tf.fill([tf.shape(y_true)[0]], tf.shape(y_true)[1])
        loss = tf.nn.ctc_loss(
            labels=y_true,
            logits=y_pred,
            label_length=label_length,
            logit_length=logit_length,
            logits_time_major=False,
            blank_index=0)
        return tf.reduce_mean(loss)
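
One thing worth flagging about tf.nn.ctc_loss with dense labels: label_length is documented as the true length of each reference label sequence, while the code above fills it with the padded width, so every padding position is counted as part of the label. A hedged sketch of deriving the true lengths from the padded labels, assuming the padding value never occurs as a real character id:

    # Sketch only: count the non-padding entries per row to get the true label lengths.
    # With UNK at index 0 this assumption may not hold for pad_value = 0, which is one
    # reason padding with -1 can be the safer choice.
    pad_value = 0  # or -1, whichever value the labels were padded with
    label_length = tf.reduce_sum(
        tf.cast(tf.math.not_equal(y_true, pad_value), tf.int32), axis=1)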

The corresponding metrics.py was changed to:

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true_shape = tf.shape(y_true)
        batch_size = y_true_shape[0]
        y_pred_shape = tf.shape(y_pred)
        max_width = tf.maximum(y_true_shape[1], y_pred_shape[1])
        logit_length = tf.fill([batch_size], y_pred_shape[1])
        decoded, _ = tf.nn.ctc_greedy_decoder(
            inputs=tf.transpose(y_pred, perm=[1, 0, 2]),
            sequence_length=logit_length)
        y_pred = self.to_dense(decoded[0], [batch_size, max_width])
        # Only this part changed. I think my understanding of your -1 is incomplete, so
        # here I use from_dense to strip the padded 0s and turn the labels back into a
        # sparse tensor, then use your to_dense, i.e. pad with -1, restoring the data to
        # the form your original code feeds into tf.math.reduce_any.
        y_true = tf.sparse.from_dense(y_true)
        y_true = self.to_dense(y_true, [batch_size, max_width])
        num_errors = tf.math.reduce_any(
            tf.math.not_equal(y_true, y_pred), axis=1)
        num_errors = tf.cast(num_errors, tf.float32)
        num_errors = tf.reduce_sum(num_errors)
        batch_size = tf.cast(batch_size, tf.float32)
        self.total.assign_add(batch_size)
        self.count.assign_add(batch_size - num_errors)
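
For reference, a sketch of what a to_dense helper along the lines described in the comment above might look like (an assumption based on this thread's description of padding with -1, not the repository's actual implementation):

    def to_dense(self, tensor, shape):
        # Grow the sparse tensor to the shared [batch_size, max_width] shape and fill
        # the empty positions with -1, so padded positions compare equal across tensors.
        tensor = tf.sparse.reset_shape(tensor, shape)
        tensor = tf.sparse.to_dense(tensor, default_value=-1)
        return tf.cast(tensor, tf.int32)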

The labels are built the same way as yours: UNK at the head, BLK at the tail.

FLming commented on June 4, 2024

@ddddddreamcastle

  1. Note the order of the code around the padding: when padded_batch runs, the labels are still strings, so they are not padded with 0 (see the sketch after this list).
  2. I have tried a DenseTensor before; under this logic it is not as good as using a SparseTensor. Using the CPU is faster.
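
A minimal sketch (hypothetical table contents, not the repository's actual code) of the ordering described in item 1: padded_batch runs while the labels are still tf.string values, so only the images get padded, and the later character-to-id lookup yields a ragged/sparse label tensor into which no numeric 0 padding is ever inserted.

    import tensorflow as tf

    labels = tf.constant(['ab', 'abc'])                    # labels after batching: still strings
    chars = tf.strings.unicode_split(labels, 'UTF-8')      # RaggedTensor of characters
    table = tf.lookup.StaticHashTable(                     # hypothetical character table
        tf.lookup.KeyValueTensorInitializer(['a', 'b', 'c'], [0, 1, 2]),
        default_value=-1)
    ids = tf.ragged.map_flat_values(table.lookup, chars)   # RaggedTensor of character ids
    sparse_labels = ids.to_sparse()                        # SparseTensor: no padding values at all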
