本项目是基于对论文《Attention Is All You Need》的学习和理解,对其中描述的模型和方法的一个简化实现。
项目在《WMT18》数据集上进行了简单的训练。《WMT18》是一个新闻评论数据集,包含训练集、测试集和验证集。训练集包含252,700对中英文句子,验证集和测试集各有大约2,000对句子。数据处理部分针对《WMT18》数据集进行了设计,若要训练其他数据集,则需要修改数据处理部分的代码。
目前训练尚未收敛,此代码实现的目的是为了更深入地学习和理解Transformer模型,而非追求最佳性能。
代码实现较为简单,并不完善。如果您在阅读代码后有更好的想法或改进建议,欢迎提出。
- Python (3.8及以上版本)
- PyTorch (1.31及以上版本)
- NVIDIA V100 8G GPU
- 若要使用自己的数据集,请修改
utils.py
中的文本读取和分词相关代码。 - 可以在
config.py
文件中修改超参数。 - 学习率调整策略已在
schedule.py
中实现。 - 运行
main.py
之前,需要先运行create_vocab.py
生成词表。 eval.py
中定义了评估函数,使用BLEU作为评价指标,可在测试集上评估模型训练结果。demo.py
中可以对训练好的模型进行演示。
关于论文的详细讲解,推荐观看B站上李沐老师的视频:链接。
GitHub仓库地址:mli/paper-reading。
作为一个学习者,我意识到项目中存在许多不足之处,期待得到大家的宝贵建议。由于算力限制,训练尚未完成,我将在适当的时候继续训练,并分享训练成果。