Git Product home page Git Product logo

multi-source-pointer-network's Introduction

多源指针网络生成短标题

Multi-Source Pointer Network


一、模型说明:

🙏 基于论文"Multi-Source Pointer Network for Product Title Summarization"开发,模型的具体介绍以及短标题的生成参考: https://mp.weixin.qq.com/s/5rAM44D50JHE-q1IrLEatw,在论文的基础上进行了如下几方面的优化:

  • 1️⃣ 融合词语和字符特征,解决测试阶段多个OOV词汇共享同一个Embedding的问题。

  • 2️⃣ 针对不同特征,编码器采用多组Transformer,然后对特征进行了融合。

  • 3️⃣ 推理阶段采用局部拷贝机制,解决全局拷贝无法解决#UNK#的情形。

  • 4️⃣ Beam Search阶段对初始分布添加mask,防止一开始生成 '#EOS#' 这一token。

  • ♻️ 如果采用GPU,所有batch的数据会一次性加载到GPU中,可以修改data.py分批加载。

二、网络结构:

模型编码器部分的整体结构如下图所示,解码器部分后面会进行补充。

encoder

模型在推理阶段采用了局部拷贝机制,将每个训练样本的source1和source2的tokens动态地映射成局部id,计算token概率分布,并按照局部token id进行token的概率合并。

Step-1: 对于每个训练样本,利用两个输入源的tokens来动态构地造局部的token词典。

Step-2: 基于该token词典,分别将两个输入源的tokens序列映射成局部id序列。

Step-3: 分别得到两个输入源的所有tokens关于解码器的概率分布source1 probs和source2 probs。

Step-4: 将两个输入源的概率分布进行merge,并取概率最大的token id。

Step-5: 利用Step-1构造的局部动态词典将概率最大的token id映射成对应的token,得到当前时刻的预测token。

inference

三、运行方式:

  1. 数据格式:

    • 训练/验证数据:训练数据和验证数据分别有3列,即:source1,source2和target。
    • 在数据预处理阶段过滤了字包含英文、数字、标点等token。
    • 测试数据:只需要两列,即source1,source2。
    • 需要注意的是:两个source和target的数据需要进行分词,词语之间用空格分隔。
  2. 模型训练:

    • 训练阶段利用到的是 main.py 脚本中的 train() 函数,训练阶段可以根据情况把 main.py 脚本中的 test() 函数注释掉。

    • 准备好训练数据后按照下面的示例进行训练,或者修改 config.py 中相关训练数据的路径,然后直接执行 python3 main.py如果需要使用GPU,参考config.py 中的相关参数。

      python3 main.py --train_data_path ../data/train.dat \
                      --valid_data_path ../data/valid.dat \
                      --model_dir ../model/ \
                      --max_epoch 50
  3. 模型预测:

    • 预测阶段利用到的是 main.py 脚本中的 test() 函数,预测阶段需要把 train() 函数注释掉。
    • 预测阶段,可以参考 main.py 脚本中给定的测试数据格式、模型加载方式等进行预测。
    • 需要注意的是:模型预测阶段也是以batch的方式进行的,因此在预测时需要参考 main.py 脚本中的 test() 函数,将数据准备成batch的形式。

四、注意事项:

  • 项目的data目录下给定的是demo数据,是为了测试模型训练和预测的流程,由于demo数据的量很少,所以训练时打印的mean_valid_loss可能会随着训练次数变大,或者出现BLEU值为0的情况,在训练数据足够的情况下,不会出现此种情况。

🚩TODO:

  • 在训练阶段会保存每个epoch的模型参数,后期可以只保留top n个准确率比较高的模型。
  • target和输入数据共享了embedding,训练阶段容易发生震荡,以后可以考虑尝试对输入数据和target采用不同的embedding。

multi-source-pointer-network's People

Contributors

xiaolongjean avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar

multi-source-pointer-network's Issues

真实数据训练情况

你好,请问能够给出使用论文中的数据集训练的参考结果吗?比如loss、blue值、epoch等。我使用论文的数据集训练了10个epoch但是似没有收敛,Loss在3.2左右波动,产生了很多[['Zara', '#PAD#', '#PAD#'], ['衬衫', '牛仔', '#PAD#'], ['衬衫', '#PAD#', '#PAD#']],这样的p#PAD#值,烦请解答,谢谢!

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.