This is the source code for the paper *Robust Neural Text Classification and Entailment via Mixup Regularized Adversarial Training* (MRAT), in Proceedings of the 44th International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR ’21). [link]
- Create the environment and install the required packages using the provided `environment.yml`:

  ```
  conda env create -f environment.yml
  conda activate MRAT
  ```
- Download the datasets. We follow the practice of TextFooler and use the datasets they provide.
  - Download the `mr.zip` and `snli.zip` datasets from Google Drive.
  - Extract `mr.zip` to `train/data/mr` and `snli.zip` to `train/data/snli`.
  - Download the 1k split datasets of `mr` and `snli` from url.
  - Put `mr` and `snli` into `attack/data` (the expected layout is sketched below).
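  For orientation, the directories above should end up looking roughly like this (the folder names come from the steps above; the files inside each folder depend on the downloads):

  ```
  train/
    data/
      mr/      # extracted from mr.zip
      snli/    # extracted from snli.zip
  attack/
    data/
      mr/      # 1k split
      snli/    # 1k split
  ```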
- [Optional] For ease of replication, we share the generated adversarial examples and trained BERT models. See here for details.
- Train the victim model. Alternatively, you can download our trained victim model from here.

  ```
  cd train/train_command_mr/
  python bert_mr_normal.py
  ```

  For the `snli` dataset this step is not needed: we use `bert-base-uncased-snli` from the TextAttack Model Zoo as the attack target model.
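  For reference, this step is plain BERT fine-tuning on MR. Below is a minimal sketch of one such training step with Hugging Face `transformers`; the checkpoint name, hyperparameters, and inline data are illustrative assumptions, not the contents of `bert_mr_normal.py`.

  ```python
  import torch
  from transformers import AutoModelForSequenceClassification, AutoTokenizer

  device = "cuda" if torch.cuda.is_available() else "cpu"
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
  model = AutoModelForSequenceClassification.from_pretrained(
      "bert-base-uncased", num_labels=2).to(device)
  optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

  # Stand-in for a batch of the MR training split from train/data/mr.
  texts = ["a gripping, well-acted drama", "tedious and instantly forgettable"]
  labels = torch.tensor([1, 0], device=device)

  model.train()
  batch = tokenizer(texts, padding=True, truncation=True,
                    return_tensors="pt").to(device)
  loss = model(**batch, labels=labels).loss  # standard cross-entropy
  loss.backward()
  optimizer.step()
  optimizer.zero_grad()
  ```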
- Test the adversarial robustness of the victim model.

  ```
  cd ../../
  cd attack/attack_command_mr/
  python attack_bert_mr_test_textbugger.py
  python attack_bert_mr_test_deepwordbug.py
  python attack_bert_mr_test_textfooler.py
  ```
- Attack the victim model on the training set and save the generated adversarial examples. Alternatively, you can download our generated adversarial examples from here.

  ```
  python attack_bert_mr_train_textbugger.py
  python attack_bert_mr_train_deepwordbug.py
  python attack_bert_mr_train_textfooler.py
  ```
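  The scripts above drive the pinned copy of TextAttack shipped with this repo (see the note at the bottom) and are the authoritative way to generate the examples. Purely as an illustration, doing the same with a recent standalone TextAttack release looks roughly like this; the checkpoint name, dataset identifier, and output path are assumptions:

  ```python
  import transformers
  import textattack
  from textattack.attack_recipes import TextFoolerJin2019
  from textattack.datasets import HuggingFaceDataset
  from textattack.models.wrappers import HuggingFaceModelWrapper

  # Assumed victim checkpoint; in this repo it would be the model
  # trained in the previous step.
  name = "textattack/bert-base-uncased-rotten-tomatoes"
  model = transformers.AutoModelForSequenceClassification.from_pretrained(name)
  tokenizer = transformers.AutoTokenizer.from_pretrained(name)
  wrapper = HuggingFaceModelWrapper(model, tokenizer)

  attack = TextFoolerJin2019.build(wrapper)
  dataset = HuggingFaceDataset("rotten_tomatoes", split="train")
  args = textattack.AttackArgs(num_examples=1000,
                               log_to_csv="mr_train_textfooler.csv")
  textattack.Attacker(attack, dataset, args).attack_dataset()
  ```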
- Train the `MRAT` and `MRAT+` models. Alternatively, you can download our trained models from here.

  ```
  cd ../../
  cd train/train_command_mr/
  python bert_mr_mix_multi.py   # MRAT
  python bert_mr_mixN_multi.py  # MRAT+
  ```
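  As the name suggests, the core of MRAT is mixup applied during adversarial training: two training inputs (e.g., a clean example and an adversarial example, padded to the same length) are interpolated in embedding space, and their labels are interpolated with the same coefficient. Below is a minimal sketch of one such mixup step; all names here are illustrative, and `bert_mr_mix_multi.py` is the actual implementation.

  ```python
  import torch
  import torch.nn.functional as F

  def mixup_step(model, emb_a, emb_b, y_a, y_b, alpha=1.0):
      """emb_a/emb_b: (batch, seq, dim) token embeddings of two inputs,
      e.g. a clean example and its adversarial counterpart.
      y_a/y_b: (batch, num_classes) one-hot labels."""
      # Draw the interpolation coefficient from Beta(alpha, alpha).
      lam = torch.distributions.Beta(alpha, alpha).sample().item()
      mixed_emb = lam * emb_a + (1.0 - lam) * emb_b
      mixed_y = lam * y_a + (1.0 - lam) * y_b
      logits = model(inputs_embeds=mixed_emb).logits
      # Soft cross-entropy against the mixed targets.
      return -(mixed_y * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
  ```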
- Test the adversarial robustness of `MRAT` and `MRAT+`.

  ```
  cd ../../
  cd attack/attack_command_mr/
  # MRAT
  python attack_bert_mr_test_mix-multi_textbugger.py
  python attack_bert_mr_test_mix-multi_deepwordbug.py
  python attack_bert_mr_test_mix-multi_textfooler.py
  # MRAT+
  python attack_bert_mr_test_mixN-multi_textbugger.py
  python attack_bert_mr_test_mixN-multi_deepwordbug.py
  python attack_bert_mr_test_mixN-multi_textfooler.py
  ```
- The steps above use the `mr` dataset as an example. For the `snli` dataset, the usage is the same: change `attack_command_mr` and `train_command_mr` to `attack_command_snli` and `train_command_snli`, and run the Python scripts inside those folders.
- Following the steps above, you can reproduce the results in Table 2 of our paper. For additional results, small changes to the argument settings in these scripts may be needed.
- The code has been tested on CentOS 7 with multiple GPUs. For single-GPU usage, a small change may be needed (see the sketch below).
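  If the training scripts wrap the model in `torch.nn.DataParallel` for multi-GPU training (an assumption; check the scripts themselves), the usual single-GPU change is simply to skip the wrapper:

  ```python
  import torch
  from transformers import AutoModelForSequenceClassification

  model = AutoModelForSequenceClassification.from_pretrained(
      "bert-base-uncased", num_labels=2)  # illustrative model construction
  if torch.cuda.device_count() > 1:       # multi-GPU path (assumed)
      model = torch.nn.DataParallel(model)
  model = model.cuda()                    # a single GPU needs only this line
  ```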
- If you need help, feel free to open an issue. 😊
🎉 A huge thanks to TextAttack! Most of the code in this project is derived from the helpful TextAttack toolbox. Because TextAttack updates frequently, we keep a copy in the `attack/textattack` folder.