
fast_adversarial_for_text_classification

This project uses TextCNN as the base model and compares three adversarial training methods (FGSM, PGD, FREE) on text classification. It mainly follows the paper *Fast is better than free: Revisiting adversarial training*, which covers three approaches: FGSM (Fast Gradient Sign Method), PGD (Projected Gradient Descent), and FREE (free adversarial training, the method the fast FGSM variant builds on). The three methods differ mainly in how the perturbation delta and the step size alpha are initialized and updated; the differences can be seen in the pseudocode for the three models below.

Pseudocode for the PGD, FREE, and FGSM training loops:
(figure: pseudocode comparison of the three algorithms)
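The three update rules can be sketched in a few lines. Below is a minimal NumPy illustration: `grad` is a stand-in for the gradient of the loss with respect to delta (in real training it comes from backpropagation through the model), and the FREE step uses epsilon as its step size as in the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
epsilon, alpha, attack_iters = 0.1, 0.04, 5
shape = (4, 8)  # stand-in for an embedding tensor shape

def grad(delta):
    # Placeholder for d(loss)/d(delta); real training gets this via backprop.
    return np.ones_like(delta)

# FGSM (fast variant): random init in [-eps, eps], one signed step, projection.
delta = rng.uniform(-epsilon, epsilon, size=shape)
fgsm = np.clip(delta + alpha * np.sign(grad(delta)), -epsilon, epsilon)

# PGD: zero init, several signed steps, projecting back after each one.
delta = np.zeros(shape)
for _ in range(attack_iters):
    delta = np.clip(delta + alpha * np.sign(grad(delta)), -epsilon, epsilon)
pgd = delta

# FREE: delta persists across "m" replays of the same minibatch; each replay
# updates model weights and delta from the same backward pass (only the
# delta update is shown here, with step size epsilon as in the paper).
delta = np.zeros(shape)
for _ in range(attack_iters):
    delta = np.clip(delta + epsilon * np.sign(grad(delta)), -epsilon, epsilon)
free = delta
```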

1 Environment

python3.7
torch 1.8.0+cu111
scikit-learn
scipy
numpy

2 Dataset

This experiment also uses a subset of THUCNews for training and testing. Please download the dataset via THUCTC (an efficient Chinese text classification toolkit) and follow the data provider's open-source license.
The texts cover 10 categories: categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐'];
cnews.train.txt: training set (5000 × 10)
cnews.val.txt: validation set (500 × 10)
cnews.test.txt: test set (1000 × 10)
The training data and the pretrained word vectors can be downloaded from: https://pan.baidu.com/s/1DOgxlY42roBpOKAMKPPKWA (password: up9d)
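The cnews.* files are commonly distributed with one sample per line, tab-separated into label and text. Under that assumption (the `load_cnews` helper below is illustrative, not part of this repository), loading looks like:

```python
import os
import tempfile

def load_cnews(path):
    # Assumes each line is "<label>\t<text>"; malformed lines are skipped.
    labels, texts = [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            label, sep, text = line.rstrip("\n").partition("\t")
            if sep and text:
                labels.append(label)
                texts.append(text)
    return labels, texts

# Tiny self-contained demo with two fabricated lines in the assumed format.
path = os.path.join(tempfile.mkdtemp(), "cnews.sample.txt")
with open(path, "w", encoding="utf-8") as f:
    f.write("体育\t比赛很精彩\n财经\t股市上涨\n")

labels, texts = load_cnews(path)
```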

3 Generating adversarial examples

The experiment works on character-level embeddings: an adversarial training method generates the attack (perturbation), which is added to the embeddings before the CNN learns the text features. The relevant part of the implementation:

    def forward(self, inputs_ids, attack=None, is_training=True):
        embs = self.embedding(inputs_ids)
        if attack is not None:
            embs = embs + attack        # inject the perturbation
        embs = embs.unsqueeze(1)
        ...
        out = self.fc2(fc)
        return out
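As a hedged sketch of how the perturbation itself can be produced (`TinyClf`, `fgsm_attack`, and the mean-pooling head are illustrative stand-ins for the repository's TextCNN, assuming PyTorch): random-initialize delta inside the epsilon ball, backpropagate the loss to it, take one signed step, and clamp.

```python
import torch
import torch.nn as nn

class TinyClf(nn.Module):
    # Minimal stand-in for the TextCNN: embedding + mean-pool + linear head.
    def __init__(self, vocab=50, dim=8, classes=10):
        super().__init__()
        self.embedding = nn.Embedding(vocab, dim)
        self.fc = nn.Linear(dim, classes)

    def forward(self, inputs_ids, attack=None):
        embs = self.embedding(inputs_ids)
        if attack is not None:
            embs = embs + attack           # inject the perturbation
        return self.fc(embs.mean(dim=1))   # mean-pool instead of CNN for brevity

def fgsm_attack(model, ids, labels, epsilon=0.1, alpha=0.04):
    # One FGSM step on the embeddings: random init in [-eps, eps],
    # signed gradient step, then clamp back into the epsilon ball.
    delta = torch.empty_like(model.embedding(ids)).uniform_(-epsilon, epsilon)
    delta.requires_grad_(True)
    loss = nn.functional.cross_entropy(model(ids, attack=delta), labels)
    loss.backward()
    with torch.no_grad():
        delta = (delta + alpha * delta.grad.sign()).clamp(-epsilon, epsilon)
    return delta.detach()
```

The returned delta is then passed as the `attack` argument of `forward()` when computing the adversarial weight update.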

The paper involves three hyperparameters: epsilon, alpha, and attack_iters, where epsilon also controls how delta is generated. Following the paper's settings for the image domain, plus some subjective judgment and a small grid search, this experiment uses the values below. Compared with the paper's approach of setting them from the mean and variance of the data distribution this is cruder, but it is adequate for the purpose of a comparative experiment.

epsilon = torch.tensor(0.1)
alpha = 0.04
attack_iters = 5

4 How to run

First select the mode to run in config.py; then train and test with:
python run.py train
python run.py test

5 Results

After training each of the four models for 20 epochs, test-set results are:

| Model   | Accuracy | Precision | Recall | F1-score |
|---------|----------|-----------|--------|----------|
| TextCNN | 95.14    | 95.16     | 95.14  | 95.11    |
| FGSM    | 95.53    | 95.60     | 95.53  | 95.50    |
| PGD     | 95.63    | 95.67     | 95.63  | 95.60    |
| FREE    | 95.49    | 95.54     | 94.49  | 95.46    |

Total and per-epoch training time (minutes) for the four models:

| Model   | total_cost | mean_cost |
|---------|------------|-----------|
| TextCNN | 3.7        | 0.185     |
| FGSM    | 5.83       | 0.2915    |
| PGD     | 13.54      | 0.677     |
| FREE    | 12.22      | 0.611     |

6 Conclusions

Takeaways from this experiment:

  • Adversarial training does help improve text classification performance;
  • Although FGSM improves training efficiency, none of these methods affect inference speed, and NLP tasks like this one need relatively few epochs anyway, so PGD is the better fit here;
  • All three methods require initializing the delta and alpha hyperparameters, and suitable values shift from task to task, which makes finding good settings harder;
  • For text classification, initializing the embeddings with word2vec or BERT vectors should produce better perturbations than random embedding initialization, and setting delta and alpha from the distribution of high-frequency tokens would be more principled;
  • Building on the improved FGSM proposed in the paper, finding a more stable or automated way to initialize delta and related parameters is a direction worth pursuing.

Reference

1. Fast Is Better Than Free: Revisiting Adversarial Training
2. https://github.com/locuslab/fast_adversarial
3. Adversarial Training Methods for Semi-Supervised Text Classification
4. Paper notes: Projected Gradient Descent (PGD)
