- Install HuggingFace dependencies:
conda install -c huggingface transformers
pip install datasets
- (Optional) For attacks against DBPedia14, download the dataset from Kaggle and set up the data directory to contain:
<data_dir>/dbpedia_csv/
    train.csv
    test.csv
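Before running the scripts, a quick sanity check can catch a missing package or a misplaced CSV file. The sketch below is a hypothetical helper (`check_setup` and the file name check_setup.py are not part of this repository). It only checks that `transformers` and `datasets` are importable and that the two DBPedia CSV files exist at the paths listed above; it does not validate their contents.

```python
import importlib.util
from pathlib import Path


def check_setup(data_dir=None, packages=("transformers", "datasets")):
    """Return a list of problems found; an empty list means every check passed."""
    problems = []
    for name in packages:
        # find_spec returns None when a top-level package cannot be found.
        if importlib.util.find_spec(name) is None:
            problems.append(f"missing Python package: {name}")
    if data_dir is not None:
        # Optional DBPedia14 layout: <data_dir>/dbpedia_csv/{train,test}.csv
        csv_dir = Path(data_dir) / "dbpedia_csv"
        for fname in ("train.csv", "test.csv"):
            path = csv_dir / fname
            if not path.is_file():
                problems.append(f"missing file: {path}")
    return problems


if __name__ == "__main__":
    import sys

    # Usage: python check_setup.py [<data_dir>]
    data_dir = sys.argv[1] if len(sys.argv) > 1 else None
    for problem in check_setup(data_dir):
        print(problem)
```

Omitting `<data_dir>` skips the DBPedia check, matching the fact that the dataset is optional.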
Use the following training script to finetune a pretrained transformer model from HuggingFace:
python text_classification.py --data_folder <data_dir> --dataset <dataset_name> --model <model_name> --finetune True
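To fine-tune several dataset/model combinations in one go, the command above can be assembled programmatically. This is only a convenience sketch: `finetune_command` is a hypothetical helper, the flags are copied verbatim from the command above, and the placeholder names must be replaced with values that text_classification.py itself accepts.

```python
import shlex


def finetune_command(data_dir, dataset, model, script="text_classification.py"):
    """Build the argv list for the fine-tuning command (hypothetical helper)."""
    return [
        "python", script,
        "--data_folder", str(data_dir),
        "--dataset", dataset,
        "--model", model,
        "--finetune", "True",
    ]


if __name__ == "__main__":
    # Placeholder names; substitute values supported by text_classification.py.
    for dataset in ["<dataset_a>", "<dataset_b>"]:
        cmd = finetune_command("<data_dir>", dataset, "<model_name>")
        print(shlex.join(cmd))
```

Printing `shlex.join(cmd)` shows the exact, safely quoted shell line; passing the list to `subprocess.run(cmd, check=True)` would execute it.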
To attack a model finetuned with text_classification.py, or a finetuned model from the TextAttack library, run:
python whitebox_attack.py --data_folder <data_dir> --dataset <dataset_name> --model <model_name> --finetune True --start_index 0 --num_samples 100 --gumbel_samples 100
This runs GBDA (Gradient-based Distributional Attack) on the first 100 samples of the test set.
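GBDA searches over a distribution of adversarial sentences rather than a single sentence. As we read the paper, that distribution is parameterized by a logit matrix Θ with one row per token position and one column per vocabulary entry, and the Gumbel-softmax relaxation is used to draw differentiable "soft" sentences from it; the `--gumbel_samples` flag likely controls how many such draws are used (an assumption, check the script's argument help). The stdlib-only sketch below illustrates the sampling step alone on a toy 3-token vocabulary. It is not the repository's implementation, which works on tensors and backpropagates through the samples; `gumbel_softmax` and `sample_soft_sentence` are hypothetical names.

```python
import math
import random


def gumbel_softmax(logits, tau=1.0, rng=random):
    """Draw one relaxed one-hot sample from Categorical(softmax(logits))."""
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise via -log(-log(U)); clamp U away from 0 and 1.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        noisy.append((logit - math.log(-math.log(u))) / tau)
    # Numerically stable softmax over the perturbed, temperature-scaled logits.
    top = max(noisy)
    exps = [math.exp(x - top) for x in noisy]
    total = sum(exps)
    return [e / total for e in exps]


def sample_soft_sentence(theta, tau=1.0, rng=random):
    """theta has one row of vocabulary logits per token position (n x V)."""
    return [gumbel_softmax(row, tau, rng) for row in theta]


if __name__ == "__main__":
    rng = random.Random(0)
    # Toy parameters: 2 positions over a 3-token vocabulary.
    theta = [[2.0, 0.5, -1.0], [0.0, 0.0, 3.0]]
    # Several independent draws, analogous in spirit to --gumbel_samples.
    draws = [sample_soft_sentence(theta, tau=0.5, rng=rng) for _ in range(100)]
    print(draws[0])
```

Lowering `tau` pushes each row toward a one-hot vector (a concrete token); in a differentiable implementation this typically comes at the cost of higher-variance gradients.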
To attack a BERT model, GBDA requires a causal language model trained with the BERT tokenizer. We provide a pretrained GPT-2 model for this purpose. Before running the attack, download the model from the Amazon S3 bucket with the following command:
curl https://dl.fbaipublicfiles.com/text-adversarial-attack/transformer_wikitext-103.pth -o transformer_wikitext-103.pth
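Why must the language model share BERT's tokenizer? As we understand the GBDA paper, the causal LM supplies a fluency term: for a soft sentence π (row i is a distribution over the vocabulary at position i), the attack penalizes the LM's expected negative log-likelihood, −Σ_i Σ_j π_{i,j} log p(j | π_1, …, π_{i−1}). That sum pairs each target-model vocabulary index j with the LM's probability for the same index, so the two vocabularies must coincide. The sketch below computes only this sum from precomputed LM probabilities (`soft_causal_nll` is a hypothetical name); in the real attack, the probabilities would come from running the downloaded model on the soft prefix.

```python
import math


def soft_causal_nll(pi, lm_probs, eps=1e-12):
    """Expected causal-LM negative log-likelihood of a soft sentence.

    pi[i][j]:       probability of vocabulary token j at position i.
    lm_probs[i][j]: the LM's probability of token j at position i, given the
                    soft prefix pi[0..i-1] (precomputed here for illustration).
    """
    if len(pi) != len(lm_probs):
        raise ValueError("pi and lm_probs must have the same number of positions")
    total = 0.0
    for p_row, q_row in zip(pi, lm_probs):
        # A vocabulary mismatch (e.g. an LM with a different tokenizer) is an error.
        if len(p_row) != len(q_row):
            raise ValueError("vocabulary sizes differ")
        # eps guards against log(0) for tokens the LM assigns zero probability.
        total -= sum(p * math.log(max(q, eps)) for p, q in zip(p_row, q_row))
    return total


if __name__ == "__main__":
    pi = [[0.9, 0.1], [0.2, 0.8]]
    lm_probs = [[0.6, 0.4], [0.3, 0.7]]
    print(soft_causal_nll(pi, lm_probs))
```

When every row of π is one-hot, the value reduces to the ordinary negative log-likelihood of that concrete token sequence.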
After attacking a model, run the following script to query a target model with adversarial examples sampled from the optimized adversarial distribution:
python evaluate_adv_samples.py --data_folder <data_dir> --dataset <dataset_name> --surrogate_model <surrogate_model_name> --target_model <target_model_name> --finetune True --start_index 0 --num_samples 100 --end_index 100 --gumbel_samples 1000
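The transfer evaluation can be pictured as follows: the distribution was optimized against the surrogate model, and hard adversarial sentences drawn from it are sent to the target model until one changes its prediction or the sample budget (`--gumbel_samples 1000` above) runs out. This is a sketch under assumptions: we assume an example counts as successfully attacked if any drawn sample fools the target, `target_predict` stands in for the real target classifier, and `sample_tokens` and `transfer_attack` are hypothetical names. Hard sentences are drawn with the Gumbel-max trick, which samples exactly from softmax(Θ_i) at each position.

```python
import math
import random


def sample_tokens(theta, rng):
    """Draw one hard token sequence via the Gumbel-max trick (argmax of
    Gumbel-perturbed logits), i.e. sample token i from softmax(theta[i])."""
    seq = []
    for row in theta:
        best_j, best_v = 0, -math.inf
        for j, logit in enumerate(row):
            u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
            v = logit - math.log(-math.log(u))
            if v > best_v:
                best_j, best_v = j, v
        seq.append(best_j)
    return seq


def transfer_attack(theta, target_predict, true_label, num_samples=1000, seed=0):
    """Query the target with up to num_samples draws; return (tokens, queries).
    tokens is None if no draw changed the target's prediction."""
    rng = random.Random(seed)
    for n in range(1, num_samples + 1):
        tokens = sample_tokens(theta, rng)
        if target_predict(tokens) != true_label:
            return tokens, n
    return None, num_samples


if __name__ == "__main__":
    theta = [[1.0, 0.0, -1.0], [0.0, 2.0, 0.0]]

    def toy_target(tokens):
        # Hypothetical classifier: predicts label 1 only if token 2 appears.
        return 1 if 2 in tokens else 0

    adv, queries = transfer_attack(theta, toy_target, true_label=0)
    print(adv, queries)
```

Returning the query count alongside the adversarial tokens makes it straightforward to report both success rate and average number of queries per example.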
Please cite [1] if you find the resources in this repository useful.
[1] C. Guo*, A. Sablayrolles*, H. Jégou, D. Kiela. Gradient-based Adversarial Attacks against Text Transformers. EMNLP 2021. (*Equal contribution.)
@article{guo2021gradientbased,
title={Gradient-based Adversarial Attacks against Text Transformers},
author={Guo, Chuan and Sablayrolles, Alexandre and Jégou, Hervé and Kiela, Douwe},
journal={arXiv preprint arXiv:2104.13733},
year={2021}
}
See the CONTRIBUTING file for how to help out.
This project is CC-BY-NC 4.0 licensed, as found in the LICENSE file.