Commit 7e0489c

added infer for classification, detached xlnet tokenizer from pytorch_transformers
plkmo committed Sep 24, 2019
1 parent 9743108 commit 7e0489c
Showing 7 changed files with 1,124 additions and 34 deletions.
classification/models/XLNet/file_utils.py (80 changes: 57 additions & 23 deletions)
@@ -9,17 +9,18 @@
 import json
 import logging
 import os
+import six
 import shutil
 import tempfile
 import fnmatch
 from functools import wraps
 from hashlib import sha256
 import sys
 from io import open

 import boto3
-import requests
+from botocore.config import Config
 from botocore.exceptions import ClientError
+import requests
 from tqdm import tqdm

 try:
@@ -39,13 +40,43 @@
 try:
     from pathlib import Path
     PYTORCH_PRETRAINED_BERT_CACHE = Path(
-        os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))
+        os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
 except (AttributeError, ImportError):
-    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
-                                              default_cache_path)
+    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE',
+                                              os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
+                                                        default_cache_path))
+
+PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility

 WEIGHTS_NAME = "pytorch_model.bin"
+TF_WEIGHTS_NAME = 'model.ckpt'
 CONFIG_NAME = "config.json"

 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

+if not six.PY2:
+    def add_start_docstrings(*docstr):
+        def docstring_decorator(fn):
+            fn.__doc__ = ''.join(docstr) + fn.__doc__
+            return fn
+        return docstring_decorator
+
+    def add_end_docstrings(*docstr):
+        def docstring_decorator(fn):
+            fn.__doc__ = fn.__doc__ + ''.join(docstr)
+            return fn
+        return docstring_decorator
+else:
+    # Not possible to update class docstrings on python2
+    def add_start_docstrings(*docstr):
+        def docstring_decorator(fn):
+            return fn
+        return docstring_decorator
+
+    def add_end_docstrings(*docstr):
+        def docstring_decorator(fn):
+            return fn
+        return docstring_decorator

 def url_to_filename(url, etag=None):
     """
@@ -71,7 +102,7 @@ def filename_to_url(filename, cache_dir=None):
     Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
     """
     if cache_dir is None:
-        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+        cache_dir = PYTORCH_TRANSFORMERS_CACHE
     if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
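The add_start_docstrings helper added above simply concatenates text onto a function's __doc__ on Python 3 (and is a no-op on Python 2, where docstrings cannot be rewritten). A minimal sketch of how such a decorator is applied; the decorated function here is hypothetical, not part of this commit:

    def add_start_docstrings(*docstr):
        def docstring_decorator(fn):
            fn.__doc__ = ''.join(docstr) + fn.__doc__
            return fn
        return docstring_decorator

    @add_start_docstrings("Shared XLNet preamble.\n\n")
    def encode(text):
        """Tokenize `text` and return tokens (hypothetical helper)."""
        return text.split()

    print(encode.__doc__)  # the preamble followed by the original docstring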

@@ -91,15 +122,18 @@ def filename_to_url(filename, cache_dir=None):
     return url, etag


-def cached_path(url_or_filename, cache_dir=None):
+def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
     """
     Given something that might be a URL (or might be a local path),
     determine which. If it's a URL, download the file and cache it, and
     return the path to the cached file. If it's already a local path,
     make sure the file exists and then return the path.
+    Args:
+        cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
+        force_download: if True, re-dowload the file even if it's already cached in the cache dir.
     """
     if cache_dir is None:
-        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+        cache_dir = PYTORCH_TRANSFORMERS_CACHE
     if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
         url_or_filename = str(url_or_filename)
     if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
@@ -109,7 +143,7 @@ def cached_path(url_or_filename, cache_dir=None):

     if parsed.scheme in ('http', 'https', 's3'):
         # URL, so get it from the cache (downloading if necessary)
-        return get_from_cache(url_or_filename, cache_dir)
+        return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
     elif os.path.exists(url_or_filename):
         # File, and it exists.
         return url_or_filename
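The updated cached_path threads force_download and proxies through to get_from_cache. A hedged usage sketch; the URL and proxy address are illustrative, and the import assumes file_utils.py is importable as a top-level module:

    from file_utils import cached_path

    local_path = cached_path(
        "https://example.com/pytorch_model.bin",     # illustrative URL
        force_download=False,                        # True bypasses an existing cache entry
        proxies={"https": "http://127.0.0.1:3128"},  # optional requests-style proxy map
    )
    print(local_path)  # path of the cached copy in the resolved cache directory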
@@ -154,24 +188,24 @@ def wrapper(url, *args, **kwargs):


 @s3_request
-def s3_etag(url):
+def s3_etag(url, proxies=None):
     """Check ETag on S3 object."""
-    s3_resource = boto3.resource("s3")
+    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
     bucket_name, s3_path = split_s3_path(url)
     s3_object = s3_resource.Object(bucket_name, s3_path)
     return s3_object.e_tag


 @s3_request
-def s3_get(url, temp_file):
+def s3_get(url, temp_file, proxies=None):
     """Pull a file directly from S3."""
-    s3_resource = boto3.resource("s3")
+    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
     bucket_name, s3_path = split_s3_path(url)
     s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


-def http_get(url, temp_file):
-    req = requests.get(url, stream=True)
+def http_get(url, temp_file, proxies=None):
+    req = requests.get(url, stream=True, proxies=proxies)
     content_length = req.headers.get('Content-Length')
     total = int(content_length) if content_length is not None else None
     progress = tqdm(unit="B", total=total)
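Both S3 helpers now route traffic through an optional requests-style proxies dict via botocore's Config, and http_get forwards the same dict to requests.get. A standalone sketch of the S3 side; the bucket, key, and proxy address are hypothetical:

    import boto3
    from botocore.config import Config

    s3 = boto3.resource("s3", config=Config(proxies={"https": "http://127.0.0.1:3128"}))
    print(s3.Object("some-bucket", "models/spiece.model").e_tag)  # metadata fetch via the proxy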
@@ -182,13 +216,13 @@ def http_get(url, temp_file):
     progress.close()


-def get_from_cache(url, cache_dir=None):
+def get_from_cache(url, cache_dir=None, force_download=False, proxies=None):
     """
     Given a URL, look for the corresponding dataset in the local cache.
     If it's not there, download it. Then return the path to the cached file.
     """
     if cache_dir is None:
-        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+        cache_dir = PYTORCH_TRANSFORMERS_CACHE
     if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
     if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
@@ -199,10 +233,10 @@ def get_from_cache(url, cache_dir=None):

     # Get eTag to add to filename, if it exists.
     if url.startswith("s3://"):
-        etag = s3_etag(url)
+        etag = s3_etag(url, proxies=proxies)
     else:
         try:
-            response = requests.head(url, allow_redirects=True)
+            response = requests.head(url, allow_redirects=True, proxies=proxies)
             if response.status_code != 200:
                 etag = None
             else:
@@ -225,17 +259,17 @@ def get_from_cache(url, cache_dir=None):
     if matching_files:
         cache_path = os.path.join(cache_dir, matching_files[-1])

-    if not os.path.exists(cache_path):
+    if not os.path.exists(cache_path) or force_download:
         # Download to temporary file, then copy to cache dir once finished.
         # Otherwise you get corrupt cache entries if the download gets interrupted.
         with tempfile.NamedTemporaryFile() as temp_file:
-            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
+            logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)

             # GET file object
             if url.startswith("s3://"):
-                s3_get(url, temp_file)
+                s3_get(url, temp_file, proxies=proxies)
             else:
-                http_get(url, temp_file)
+                http_get(url, temp_file, proxies=proxies)

             # we are copying the file before closing it, so flush to avoid truncation
             temp_file.flush()
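Since the module resolves its cache directory once at import time, preferring PYTORCH_TRANSFORMERS_CACHE over the legacy PYTORCH_PRETRAINED_BERT_CACHE, the cache location can be redirected from the environment. A sketch under the assumption that file_utils.py is importable; the directory is arbitrary:

    import os
    os.environ["PYTORCH_TRANSFORMERS_CACHE"] = "/tmp/xlnet_cache"  # must be set before the import

    import file_utils
    print(file_utils.PYTORCH_TRANSFORMERS_CACHE)  # /tmp/xlnet_cache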
classification/models/XLNet/preprocessing_funcs.py (2 changes: 1 addition & 1 deletion)
@@ -7,7 +7,7 @@
 import os
 import pickle
 import pandas as pd
-from pytorch_transformers import XLNetTokenizer
+from .tokenization_xlnet import XLNetTokenizer
 import logging

 logging.basicConfig(format='%(asctime)s [%(levelname)s]: %(message)s', \
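With the tokenizer detached from pytorch_transformers, preprocessing imports XLNetTokenizer from the vendored tokenization_xlnet module instead of the installed library. A hedged sketch of the resulting call pattern; the pretrained name follows the upstream convention, and the absolute-import form assumes the module is on the path:

    from tokenization_xlnet import XLNetTokenizer

    tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
    token_ids = tokenizer.encode("The weather is nice today.")
    print(token_ids)  # SentencePiece token ids for the sentence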
(Diffs for the remaining 5 changed files are not shown.)
