Git Product home page Git Product logo

multi-source-pointer-network's Introduction

多源指针网络生成短标题

Multi-Source Pointer Network


一、模型说明:

🙏 基于论文"Multi-Source Pointer Network for Product Title Summarization"开发,模型的具体介绍以及短标题的生成参考: https://mp.weixin.qq.com/s/5rAM44D50JHE-q1IrLEatw,在论文的基础上进行了如下几方面的优化:

  • 1️⃣ 融合词语和字符特征,解决测试阶段多个OOV词汇共享同一个Embedding的问题。

  • 2️⃣ 针对不同特征,编码器采用多组Transformer,然后对特征进行了融合。

  • 3️⃣ 推理阶段采用局部拷贝机制,解决全局拷贝无法解决#UNK#的情形。

  • 4️⃣ Beam Search阶段对初始分布添加mask,防止一开始生成 '#EOS#' 这一token。

  • ♻️ 如果采用GPU,所有batch的数据会一次性加载到GPU中,可以修改data.py分批加载。

二、运行方式:

  1. 数据格式:

    • 训练/验证数据:训练数据和验证数据分别有3列,即:source1,source2和target。
    • 在数据预处理阶段过滤了字包含英文、数字、标点等token。
    • 测试数据:只需要两列,即source1,source2。
    • 需要注意的是:两个source和target的数据需要进行分词,词语之间用空格分隔。
  2. 模型训练:

    • 训练阶段利用到的是 main.py 脚本中的 train() 函数,训练阶段可以根据情况把 main.py 脚本中的 test() 函数注释掉。

    • 准备好训练数据后按照下面的示例进行训练,或者修改 config.py 中相关训练数据的路径,然后直接执行 python3 main.py如果需要使用GPU,参考config.py 中的相关参数。

      python3 main.py --train_data_path ../data/train.dat \
                      --valid_data_path ../data/valid.dat \
                      --model_dir ../model/ \
                      --max_epoch 50
  3. 模型预测:

    • 预测阶段利用到的是 main.py 脚本中的 test() 函数,预测阶段需要把 train() 函数注释掉。
    • 预测阶段,可以参考 main.py 脚本中给定的测试数据格式、模型加载方式等进行预测。
    • 需要注意的是:模型预测阶段也是以batch的方式进行的,因此在预测时需要参考 main.py 脚本中的 test() 函数,将数据准备成batch的形式。

三、注意事项:

  • 项目的data目录下给定的是demo数据,是为了测试模型训练和预测的流程,由于demo数据的量很少,所以训练时打印的mean_valid_loss可能会随着训练次数变大,或者出现BLEU值为0的情况,在训练数据足够的情况下,不会出现此种情况。

🚩TODO:

  • 在训练阶段会保存每个epoch的模型参数,后期可以只保留top n个准确率比较高的模型。
  • target和输入数据共享了embedding,训练阶段容易发生震荡,以后可以考虑尝试对输入数据和target采用不同的embedding。

multi-source-pointer-network's People

Contributors

xiaolongjean avatar

Watchers

James Cloos avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.