# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""

import fnmatch
import json
import logging
import os
import shutil
import sys
import tarfile
import tempfile
from contextlib import contextmanager
from functools import partial, wraps
from hashlib import sha256
from typing import Optional
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile

import boto3
import requests
from botocore.config import Config
from botocore.exceptions import ClientError
from filelock import FileLock
from tqdm.auto import tqdm

# from examples import __version__
__version__ = "0.1"


logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

try:
    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
    if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
        import torch

        _torch_available = True  # pylint: disable=invalid-name
        logger.info("PyTorch version {} available.".format(torch.__version__))
    else:
        logger.info("Disabling PyTorch because USE_TF is set")
        _torch_available = False
except ImportError:
    _torch_available = False  # pylint: disable=invalid-name


try:
    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()

    if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
        import tensorflow as tf

        assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
        _tf_available = True  # pylint: disable=invalid-name
        logger.info("TensorFlow version {} available.".format(tf.__version__))
    else:
        logger.info("Disabling Tensorflow because USE_TORCH is set")
        _tf_available = False
except (ImportError, AssertionError):
    _tf_available = False  # pylint: disable=invalid-name
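
# Both env vars default to "AUTO": whichever backend imports successfully is marked available.
# A minimal sketch of forcing a single backend from the shell (the script name is hypothetical):
#
#   USE_TF=1 python run_model.py      # import TensorFlow only, skip PyTorch
#   USE_TORCH=1 python run_model.py   # import PyTorch only, skip TensorFlow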

try:
    from torch.hub import _get_torch_home

    torch_cache_home = _get_torch_home()
except ImportError:
    torch_cache_home = os.path.expanduser(
        os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
    )
default_cache_path = os.path.join(torch_cache_home, "transformers")

try:
    from pathlib import Path

    PYTORCH_PRETRAINED_BERT_CACHE = Path(
        os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
    )
except (AttributeError, ImportError):
    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
        "PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
    )

PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility
TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility

WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
CONFIG_NAME = "config.json"
MODEL_CARD_NAME = "modelcard.json"


MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]]
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"


def is_torch_available():
    return _torch_available


def is_tf_available():
    return _tf_available


def add_start_docstrings(*docstr):
    def docstring_decorator(fn):
        fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
        return fn

    return docstring_decorator


def add_start_docstrings_to_callable(*docstr):
    def docstring_decorator(fn):
        class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0])
        intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name)
        note = r"""

    .. note::
        Although the recipe for the forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this since the former takes care of running the
        pre and post processing steps while the latter silently ignores them.
        """
        fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
        return fn

    return docstring_decorator


def add_end_docstrings(*docstr):
    def docstring_decorator(fn):
        fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr)
        return fn

    return docstring_decorator
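
# A minimal usage sketch for the three docstring decorators above (names are illustrative):
#
#   @add_start_docstrings("Shared introduction. ", "More shared text. ")
#   def forward(self, input_ids):
#       """Function-specific details."""
#
# add_start_docstrings and add_start_docstrings_to_callable prepend the shared strings to the
# function's own docstring; add_end_docstrings appends them instead.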


def is_remote_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https", "s3")
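
# For example, is_remote_url("https://example.com/model.bin") is True, while
# is_remote_url("/local/path/model.bin") is False (its parsed scheme is empty).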


def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
    endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
    if postfix is None:
        return "/".join((endpoint, identifier))
    else:
        return "/".join((endpoint, identifier, postfix))
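
# For example, hf_bucket_url("bert-base-uncased", postfix=CONFIG_NAME) yields
# "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased/config.json"
# (the model identifier is illustrative); with cdn=True the CloudFront prefix is used instead.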


def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    If the url ends with .h5 (Keras HDF5 weights), '.h5' is appended to the name
    so that TF 2.0 can identify it as an HDF5 file
    (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
    """
    url_bytes = url.encode("utf-8")
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode("utf-8")
        etag_hash = sha256(etag_bytes)
        filename += "." + etag_hash.hexdigest()

    if url.endswith(".h5"):
        filename += ".h5"

    return filename
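
# A sketch of the resulting cache filename (hex digests abbreviated and hypothetical):
#
#   url_to_filename("https://.../tf_model.h5", etag='"abc123"')
#   # -> "d1f9...<sha256 of url>.4e2a...<sha256 of etag>.h5"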


def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + ".json"
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata["url"]
    etag = metadata["etag"]

    return url, etag
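
# The ".json" metadata read here is the sidecar written by get_from_cache below, e.g.
# {"url": "https://.../pytorch_model.bin", "etag": "\"abc123\""} (values illustrative).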


def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent=None,
    extract_compressed_file=False,
    force_extract=False,
    local_files_only=False,
) -> Optional[str]:
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.

    Args:
        cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
        force_download: if True, re-download the file even if it's already cached in the cache dir.
        resume_download: if True, resume the download if an incompletely received file is found.
        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
        extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
            file in a folder alongside the archive.
        force_extract: if True when extract_compressed_file is True and the archive was already extracted,
            re-extract the archive and override the folder where it was extracted.

    Return:
        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
        Local path (string) otherwise
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

    if extract_compressed_file:
        if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
            return output_path

        # Path where we extract compressed archives
        # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
        output_dir, output_file = os.path.split(output_path)
        output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
        output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

        if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
            return output_path_extracted

        # Prevent parallel extractions
        lock_path = output_path + ".lock"
        with FileLock(lock_path):
            shutil.rmtree(output_path_extracted, ignore_errors=True)
            os.makedirs(output_path_extracted)
            if is_zipfile(output_path):
                with ZipFile(output_path, "r") as zip_file:
                    zip_file.extractall(output_path_extracted)
            elif tarfile.is_tarfile(output_path):
                tar_file = tarfile.open(output_path)
                tar_file.extractall(output_path_extracted)
                tar_file.close()
            else:
                raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

        return output_path_extracted

    return output_path
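
# A minimal usage sketch (the identifier and local archive are illustrative):
#
#   config_url = hf_bucket_url("bert-base-uncased", postfix=CONFIG_NAME)
#   local_config = cached_path(config_url)  # downloads on the first call, reuses the cache afterwards
#   extracted_dir = cached_path("./model.zip", extract_compressed_file=True)  # -> "./model-zip-extracted"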


def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path
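
# For example, split_s3_path("s3://my-bucket/models/weights.bin") returns
# ("my-bucket", "models/weights.bin") (bucket and key names illustrative).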


def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) == 404:
                raise EnvironmentError("file {} not found".format(url))
            else:
                raise

    return wrapper


@s3_request
def s3_etag(url, proxies=None):
    """Check ETag on S3 object."""
    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
    bucket_name, s3_path = split_s3_path(url)
    s3_object = s3_resource.Object(bucket_name, s3_path)
    return s3_object.e_tag


@s3_request
def s3_get(url, temp_file, proxies=None):
    """Pull a file directly from S3."""
    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
    ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
    if is_torch_available():
        ua += "; torch/{}".format(torch.__version__)
    if is_tf_available():
        ua += "; tensorflow/{}".format(tf.__version__)
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    headers = {"user-agent": ua}
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)
    response = requests.get(url, stream=True, proxies=proxies, headers=headers)
    if response.status_code == 416:  # Range not satisfiable
        return
    content_length = response.headers.get("Content-Length")
    total = resume_size + int(content_length) if content_length is not None else None
    progress = tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        initial=resume_size,
        desc="Downloading",
        disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
    )
    for chunk in response.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()
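
# A sketch of how get_from_cache below resumes an interrupted download (paths illustrative):
#
#   resume_size = os.stat(incomplete_path).st_size if os.path.exists(incomplete_path) else 0
#   with open(incomplete_path, "a+b") as f:
#       http_get(url, f, resume_size=resume_size)  # sends "Range: bytes=<resume_size>-" when resuming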


def get_from_cache(
    url,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent=None,
    local_files_only=False,
) -> Optional[str]:
    """
    Given a URL, look for the corresponding file in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    Return:
        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
        Local path (string) otherwise
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    etag = None
    if not local_files_only:
        # Get eTag to add to filename, if it exists.
        if url.startswith("s3://"):
            etag = s3_etag(url, proxies=proxies)
        else:
            try:
                response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
                if response.status_code == 200:
                    etag = response.headers.get("ETag")
            except (EnvironmentError, requests.exceptions.Timeout):
                # etag is already None
                pass

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                # If files cannot be found and local_files_only=True,
                # the files might have been found if local_files_only had been False.
                # Notify the user about that.
                if local_files_only:
                    raise ValueError(
                        "Cannot find the requested files in the cached path and outgoing traffic has been"
                        " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                        " to False."
                    )
                return None

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager():
                with open(incomplete_path, "a+b") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                if resume_download:
                    logger.warning('Warning: resumable downloads are not implemented for "s3://" urls')
                s3_get(url, temp_file, proxies=proxies)
            else:
                http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)

        logger.info("storing %s in cache at %s", url, cache_path)
        os.replace(temp_file.name, cache_path)

        logger.info("creating metadata file for %s", cache_path)
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
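
# For a given URL, the cache directory therefore ends up holding three entries (names illustrative):
#   <sha256(url)>.<sha256(etag)>          the downloaded file itself
#   <sha256(url)>.<sha256(etag)>.json     metadata sidecar: {"url": ..., "etag": ...}
#   <sha256(url)>.<sha256(etag)>.lock     FileLock used to serialize concurrent downloads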