Merge pull request fastai#669 from stas00/nlp
arxiv notebook improvements/fixes
sgugger committed Jul 31, 2018
2 parents f1317af + 7aa087b commit 2712d01
Showing 2 changed files with 131 additions and 7 deletions.
134 changes: 128 additions & 6 deletions courses/dl1/lang_model-arxiv.ipynb
@@ -24,7 +24,8 @@
"from fastai.nlp import *\n",
"from fastai.lm_rnn import *\n",
"\n",
"import dill as pickle"
"import dill as pickle\n",
"import random"
]
},
{
@@ -54,6 +55,114 @@
"### Data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import os, requests, time\n",
"import feedparser\n",
"import pandas as pd\n",
"\n",
"\n",
"class GetArXiv(object):\n",
" def __init__(self, pickle_path, categories=list()):\n",
" \"\"\"\n",
" :param pickle_path (str): path to pickle data file to save/load\n",
" :param pickle_name (str): file name to save pickle to path\n",
" :param categories (list): arXiv categories to query\n",
" \"\"\"\n",
" if os.path.isdir(pickle_path):\n",
" pickle_path = f\"{pickle_path}{'' if pickle_path[-1] == '/' else '/'}all_arxiv.pkl\"\n",
" if len(categories) < 1:\n",
" categories = ['cs*', 'cond-mat.dis-nn', 'q-bio.NC', 'stat.CO', 'stat.ML']\n",
" # categories += ['cs.CV', 'cs.AI', 'cs.LG', 'cs.CL']\n",
"\n",
" self.categories = categories\n",
" self.pickle_path = pickle_path\n",
" self.base_url = 'http:https://export.arxiv.org/api/query'\n",
"\n",
" @staticmethod\n",
" def build_qs(categories):\n",
" \"\"\"Build query string from categories\"\"\"\n",
" return '+OR+'.join(['cat:'+c for c in categories])\n",
"\n",
" @staticmethod\n",
" def get_entry_dict(entry):\n",
" \"\"\"Return a dictionary with the items we want from a feedparser entry\"\"\"\n",
" try:\n",
" return dict(title=entry['title'], authors=[a['name'] for a in entry['authors']],\n",
" published=pd.Timestamp(entry['published']), summary=entry['summary'],\n",
" link=entry['link'], category=entry['category'])\n",
" except KeyError:\n",
" print('Missing keys in row: {}'.format(entry))\n",
" return None\n",
"\n",
" @staticmethod\n",
" def strip_version(link):\n",
" \"\"\"Strip version number from arXiv paper link\"\"\"\n",
" return link[:-2]\n",
"\n",
" def fetch_updated_data(self, max_retry=5, pg_offset=0, pg_size=1000, wait_time=15):\n",
" \"\"\"\n",
" Get new papers from arXiv server\n",
" :param max_retry: max number of time to retry request\n",
" :param pg_offset: number of pages to offset\n",
" :param pg_size: num abstracts to fetch per request\n",
" :param wait_time: num seconds to wait between requests\n",
" \"\"\"\n",
" i, retry = pg_offset, 0\n",
" df = pd.DataFrame()\n",
" past_links = []\n",
" if os.path.isfile(self.pickle_path):\n",
" df = pd.read_pickle(self.pickle_path)\n",
" df.reset_index()\n",
" if len(df) > 0: past_links = df.link.apply(self.strip_version)\n",
"\n",
" while True:\n",
" params = dict(search_query=self.build_qs(self.categories),\n",
" sortBy='submittedDate', start=pg_size*i, max_results=pg_size)\n",
" response = requests.get(self.base_url, params='&'.join([f'{k}={v}' for k, v in params.items()]))\n",
" entries = feedparser.parse(response.text).entries\n",
" if len(entries) < 1:\n",
" if retry < max_retry:\n",
" retry += 1\n",
" time.sleep(wait_time)\n",
" continue\n",
" break\n",
"\n",
" results_df = pd.DataFrame([self.get_entry_dict(e) for e in entries])\n",
" max_date = results_df.published.max().date()\n",
" new_links = ~results_df.link.apply(self.strip_version).isin(past_links)\n",
" print(f'{i}. Fetched {len(results_df)} abstracts published {max_date} and earlier')\n",
" if not new_links.any():\n",
" break\n",
"\n",
" df = pd.concat((df, results_df.loc[new_links]), ignore_index=True)\n",
" i += 1\n",
" retry = 0\n",
" time.sleep(wait_time)\n",
"\n",
" print(f'Downloaded {len(df)-len(past_links)} new abstracts')\n",
" df.sort_values('published', ascending=False).groupby('link').first().reset_index()\n",
" df.to_pickle(self.pickle_path)\n",
" return df\n",
"\n",
" @classmethod\n",
" def load(cls, pickle_path):\n",
" \"\"\"Load data from pickle and remove duplicates\"\"\"\n",
" return pd.read_pickle(cls(pickle_path).pickle_path)\n",
"\n",
" @classmethod\n",
" def update(cls, pickle_path, categories=list(), **kwargs):\n",
" \"\"\"\n",
" Update arXiv data pickle with the latest abstracts\n",
" \"\"\"\n",
" cls(pickle_path, categories).fetch_updated_data(**kwargs)\n",
" return True"
]
},
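A minimal usage sketch for GetArXiv, just a sketch; it assumes an existing data/arxiv/ directory and the data/arxiv/all_arxiv.pkl path that a later cell assigns to ALL_ARXIV:

    # sketch: build or refresh the local pickle (slow on the first run), then read it back
    GetArXiv.update('data/arxiv/all_arxiv.pkl')           # runs fetch_updated_data() under the hood
    df_all = GetArXiv.load('data/arxiv/all_arxiv.pkl')    # a plain pd.read_pickle of the same file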
{
"cell_type": "code",
"execution_count": 55,
@@ -62,10 +171,16 @@
},
"outputs": [],
"source": [
"PATH='/data2/datasets/part1/arxiv/'\n",
"PATH='data/arxiv/'\n",
"\n",
"ALL_ARXIV = f'{PATH}all_arxiv.pkl'\n",
"\n",
"# all_arxiv.pkl: if arxiv hasn't been downloaded yet, it'll take some time to get it - go get some coffee\n",
"if not os.path.exists(ALL_ARXIV): GetArXiv.update(ALL_ARXIV)\n",
"\n",
"# arxiv.csv: see dl1/nlp-arxiv.ipynb to get this one\n",
"df_mb = pd.read_csv(f'{PATH}arxiv.csv')\n",
"df_all = pd.read_pickle(f'{PATH}all_arxiv.pickle')"
"df_all = pd.read_pickle(ALL_ARXIV)"
]
},
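A quick sanity check one could run at this point, a sketch only; the expected column names follow get_entry_dict above:

    # sketch: confirm both datasets loaded and that the pickle carries the fields built in get_entry_dict
    print(f'{len(df_all)} abstracts in all_arxiv.pkl, {len(df_mb)} rows in arxiv.csv')
    print(sorted(df_all.columns))  # expect: authors, category, link, published, summary, title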
{
@@ -146,6 +261,9 @@
"source": [
"from spacy.symbols import ORTH\n",
"\n",
"# install the 'en' model if the next line of code fails by running:\n",
"#python -m spacy download en # default English model (~50MB)\n",
"#python -m spacy download en_core_web_md # larger English model (~1GB)\n",
"my_tok = spacy.load('en')\n",
"\n",
"my_tok.tokenizer.add_special_case('<SUMM>', [{ORTH: '<SUMM>'}])\n",
@@ -167,7 +285,7 @@
"source": [
"TEXT = data.Field(lower=True, tokenize=my_spacy_tok)\n",
"FILES = dict(train='trn', validation='val', test='val')\n",
"md = LanguageModelData(f'{PATH}all/', TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)\n",
"md = LanguageModelData.from_text_files(f'{PATH}all/', TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)\n",
"pickle.dump(TEXT, open(f'{PATH}models/TEXT.pkl','wb'))"
]
},
@@ -854,7 +972,9 @@
" fields = [('text', text_field), ('label', label_field)]\n",
" examples = []\n",
" for label in ['yes', 'no']:\n",
" for fname in glob(os.path.join(path, label, '*.txt')):\n",
" fnames = glob(os.path.join(path, label, '*.txt'));\n",
" assert fnames, f\"can't find 'yes.txt' or 'no.txt' under {path}/{label}\"\n",
" for fname in fnames:\n",
" with open(fname, 'r') as f: text = f.readline()\n",
" examples.append(data.Example.fromlist([text, label], fields))\n",
" super().__init__(examples, fields, **kwargs)\n",
@@ -944,8 +1064,10 @@
},
"outputs": [],
"source": [
"# this notebook has a mess of some things going under 'all/' others not, so a little hack here\n",
"!ln -sf ../all/models/adam3_20_enc.h5 {PATH}models/adam3_20_enc.h5\n",
"m3.load_encoder(f'adam3_20_enc')\n",
"lrs=np.array([1e-4,1e-3,1e-2])"
"lrs=np.array([1e-4,1e-3,1e-3,1e-2,3e-2])"
]
},
{
4 changes: 3 additions & 1 deletion courses/dl1/nlp-arxiv.ipynb
@@ -175,7 +175,9 @@
}
],
"source": [
"PATH='/data2/datasets/part1/arxiv/arxiv.csv'\n",
"PATH='data/arxiv/arxiv.csv'\n",
"\n",
"# You can download a similar to Jeremy's original arxiv.csv here: https://drive.google.com/file/d/0B34BjUTAgwm6SzdPWDAtVG1vWVU/. It comes from this article https://hackernoon.com/building-brundage-bot-10252facf3d1 and github https://github.com/amauboussin/arxiv-twitterbot, just rename it to arxiv.csv\n",
"\n",
"df = pd.read_csv(PATH)\n",
"df.head()"
