Skip to content

Commit

Permalink
Fixed errors in transformation notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
amrzv committed Aug 6, 2022
1 parent a5d639e commit e4f0392
Showing 1 changed file with 38 additions and 26 deletions.
64 changes: 38 additions & 26 deletions notebooks/Write_a_sample_transformation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1366,7 +1366,7 @@
},
"source": [
"!pip install -r requirements.txt --quiet\n",
"!pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz\n"
"!pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz"
],
"execution_count": 5,
"outputs": []
Expand All @@ -1380,6 +1380,16 @@
"## Load modules"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import nltk\n",
"nltk.download('omw-1.4')"
]
},
{
"cell_type": "code",
"metadata": {
Expand All @@ -1390,13 +1400,13 @@
"outputId": "991e3d98-7e75-4129-a41a-2a04e1ffbd94"
},
"source": [
"from transformations.butter_fingers_perturbation.transformation import ButterFingersPerturbation\n",
"from transformations.change_person_named_entities.transformation import ChangePersonNamedEntities\n",
"from transformations.replace_numerical_values.transformation import ReplaceNumericalValues\n",
"from interfaces.SentenceOperation import SentenceOperation\n",
"from interfaces.QuestionAnswerOperation import QuestionAnswerOperation\n",
"from evaluation.evaluation_engine import evaluate, execute_model\n",
"from tasks.TaskTypes import TaskType"
"from nlaugmenter.transformations.butter_fingers_perturbation.transformation import ButterFingersPerturbation\n",
"from nlaugmenter.transformations.change_person_named_entities.transformation import ChangePersonNamedEntities\n",
"from nlaugmenter.transformations.replace_numerical_values.transformation import ReplaceNumericalValues\n",
"from nlaugmenter.interfaces.SentenceOperation import SentenceOperation\n",
"from nlaugmenter.interfaces.QuestionAnswerOperation import QuestionAnswerOperation\n",
"from nlaugmenter.evaluation.evaluation_engine import evaluate, execute_model\n",
"from nlaugmenter.tasks.TaskTypes import TaskType"
],
"execution_count": null,
"outputs": [
Expand Down Expand Up @@ -1728,24 +1738,25 @@
"import torch\n",
"from transformers import T5ForConditionalGeneration, AutoTokenizer\n",
"\n",
"\n",
"class MySecondTransformation(QuestionAnswerOperation):\n",
" tasks = [TaskType.QUESTION_ANSWERING, TaskType.QUESTION_GENERATION]\n",
" languages = [\"en\"]\n",
"\n",
" def __init__(self, max_outputs=5):\n",
" super().__init__()\n",
" model_name=\"prithivida/parrot_paraphraser_on_T5\"\n",
" self.tokenizer = AutoTokenizer.from_pretrained(model_name) \n",
" model_name = \"prithivida/parrot_paraphraser_on_T5\"\n",
" self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
" self.model = T5ForConditionalGeneration.from_pretrained(model_name)\n",
" self.max_outputs = max_outputs\n",
"\n",
" def generate(self, context, question, answers): # Note that the choice of inputs for 'generate' is consistent with those in QuestionAnswerOperation\n",
" \n",
" def generate(self, context, question, answers): # Note that the choice of inputs for 'generate' is consistent with those in QuestionAnswerOperation\n",
"\n",
" # Let's call the HF model to generate a paraphrase for the question\n",
" paraphrase_input = question\n",
" batch = self.tokenizer([paraphrase_input],truncation=True,padding='longest',max_length=60, return_tensors=\"pt\")\n",
" translated = self.model.generate(**batch,max_length=60,num_beams=10, num_return_sequences=self.max_outputs, temperature=1.5)\n",
" paraphrased_questions = self.tokenizer.batch_decode(translated, skip_special_tokens=True) \n",
" batch = self.tokenizer([paraphrase_input], truncation=True, padding='longest', max_length=60, return_tensors=\"pt\")\n",
" translated = self.model.generate(**batch, max_length=60, num_beams=10, num_return_sequences=self.max_outputs, temperature=1.5)\n",
" paraphrased_questions = self.tokenizer.batch_decode(translated, skip_special_tokens=True)\n",
"\n",
" # context = \"Apply your own logic here\"\n",
" # answers = \"And here too :)\"\n",
Expand Down Expand Up @@ -1941,7 +1952,7 @@
},
"source": [
"t4.generate(context=\"Mumbai, Bengaluru, New Delhi are among the many famous places in India.\", \n",
" question=\"What are the famous places we should not miss in India?\", \n",
" question=\"What are the famous places we should not miss in India?\",\n",
" answers=[\"Mumbai\", \"Bengaluru\", \"Delhi\", \"New Delhi\"])"
],
"execution_count": null,
Expand Down Expand Up @@ -2022,8 +2033,8 @@
"id": "WfUvpkSN0BKB"
},
"source": [
"from filters.keywords import TextContainsKeywordsFilter\n",
"from filters.length import TextLengthFilter, SentenceAndTargetLengthFilter"
"from nlaugmenter.filters.keywords import TextContainsKeywordsFilter\n",
"from nlaugmenter.filters.length import TextLengthFilter, SentenceAndTargetLengthFilter"
],
"execution_count": null,
"outputs": []
Expand Down Expand Up @@ -2134,7 +2145,7 @@
"outputId": "066fd81f-ac9f-400d-d14d-be26dabdc84b"
},
"source": [
"f2.filter(\"That show is going to take place in front of immensely massive crowds.\", \n",
"f2.filter(\"That show is going to take place in front of immensely massive crowds.\",\n",
" \"Large crowds would attend the show.\")"
],
"execution_count": null,
Expand Down Expand Up @@ -2163,7 +2174,7 @@
"outputId": "5f17c054-a00f-4aa2-dc7a-b19b4e719a0d"
},
"source": [
"f2.filter(\"The film was nominated for the Academy Award for Best Art Direction.\", \n",
"f2.filter(\"The film was nominated for the Academy Award for Best Art Direction.\",\n",
" \"The movie was a nominee for the Academy Award for Best Art Direction.\")"
],
"execution_count": null,
Expand Down Expand Up @@ -2201,25 +2212,26 @@
"source": [
"import spacy\n",
"\n",
"\n",
"class LowLexicalOverlapFilter(QuestionAnswerOperation):\n",
" tasks = [TaskType.QUESTION_ANSWERING, TaskType.QUESTION_GENERATION]\n",
" languages = [\"en\"]\n",
" \n",
"\n",
" def __init__(self, threshold=3):\n",
" super().__init__()\n",
" self.nlp = spacy.load(\"en_core_web_sm\")\n",
" self.threshold = threshold\n",
"\n",
" def filter(self, context, question, answers): \n",
" # Note that the only difference between a filter and a transformation is this method! \n",
" def filter(self, context, question, answers):\n",
" # Note that the only difference between a filter and a transformation is this method!\n",
" # The inputs remain the same!\n",
" \n",
"\n",
" question_tokenized = self.nlp(question, disable=[\"parser\", \"tagger\", \"ner\"])\n",
" context_tokenized = self.nlp(context, disable=[\"parser\", \"tagger\", \"ner\"])\n",
" \n",
"\n",
" q_tokens = set([t.text for t in question_tokenized])\n",
" c_tokens = set([t.text for t in context_tokenized])\n",
" \n",
"\n",
" low_lexical_overlap = len(q_tokens.intersection(c_tokens)) > self.threshold\n",
" return low_lexical_overlap"
],
Expand Down

0 comments on commit e4f0392

Please sign in to comment.