DeepLearningExamples/PyTorch/SpeechRecognition/QuartzNet/utils/download_utils.py
2021-09-14 06:03:36 -07:00

72 lines
2.3 KiB
Python

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python
import hashlib
import requests
import os
import tarfile
import tqdm
def download_file(url, dest_folder, fname, overwrite=False):
    """Download ``url`` to ``dest_folder/fname`` while showing a progress bar.

    The payload is streamed into ``<fname>.tmp`` and moved to its final name
    only after the whole body has been written, so an interrupted transfer
    does not leave a truncated file under the final name.

    Args:
        url: Address fetched with an HTTP GET request.
        dest_folder: Directory that receives the file; created if missing.
        fname: Name of the file inside ``dest_folder``.
        overwrite: When False (default) an existing file is kept and the
            download is skipped; when True it is replaced.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    fpath = os.path.join(dest_folder, fname)
    if os.path.isfile(fpath):
        if not overwrite:
            print("File exists, skipping download.")
            return
        print("Overwriting existing file")

    tmp_fpath = fpath + '.tmp'

    # dirname is '' when the target lives in the current directory, and
    # os.makedirs('') raises, so only create a directory when there is one.
    parent_dir = os.path.dirname(tmp_fpath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    chunk_size = 1024 * 1024  # 1MB

    # The context manager releases the connection even if writing fails.
    with requests.get(url, stream=True) as r:
        # Fail loudly instead of saving an HTML error page as the payload.
        r.raise_for_status()

        # Content-Length is optional (e.g. chunked transfer encoding); without
        # it the progress bar simply runs without a known total.
        content_length = r.headers.get('Content-Length')
        if content_length is None:
            total_chunks = None
        else:
            # Ceiling division so the final partial chunk is counted.
            total_chunks = (int(content_length) + chunk_size - 1) // chunk_size

        with open(tmp_fpath, 'wb') as fp:
            content_iterator = r.iter_content(chunk_size=chunk_size)
            chunks = tqdm.tqdm(content_iterator, total=total_chunks,
                               unit='MB', desc=fpath, leave=True)
            for chunk in chunks:
                fp.write(chunk)

    # os.replace overwrites an existing destination on every platform, whereas
    # os.rename fails on Windows when the destination already exists.
    os.replace(tmp_fpath, fpath)
def md5_checksum(fpath, target_hash):
    """Tell whether the MD5 hex digest of the file at ``fpath`` equals ``target_hash``.

    The file is consumed in 1 MiB pieces, so it is never held in memory whole.

    Args:
        fpath: Path of the file to hash.
        target_hash: Expected digest as a hexadecimal string; the comparison
            is an exact string match against ``hexdigest()`` output.

    Returns:
        True if the digests match, False otherwise.
    """
    block_size = 1024 * 1024
    digest = hashlib.md5()
    with open(fpath, "rb") as stream:
        while True:
            data = stream.read(block_size)
            if data == b"":
                # An empty read marks end of file.
                break
            digest.update(data)
    actual_hash = digest.hexdigest()
    return actual_hash == target_hash
def extract(fpath, dest_folder):
    """Unpack the tar archive at ``fpath`` into ``dest_folder`` with a progress bar.

    Args:
        fpath: Path of the archive; must end in ``.tar.gz`` or ``.tar``.
        dest_folder: Directory the members are extracted into.

    Raises:
        IOError: If ``fpath`` has an unsupported extension, or if a member
            would be written outside ``dest_folder``.
    """
    if fpath.endswith('.tar.gz'):
        mode = 'r:gz'
    elif fpath.endswith('.tar'):
        mode = 'r:'
    else:
        raise IOError('fpath has unknown extention: %s' % fpath)

    dest_root = os.path.realpath(dest_folder)

    with tarfile.open(fpath, mode) as tar:
        members = tar.getmembers()
        for member in tqdm.tqdm(iterable=members, total=len(members), leave=True):
            # Archives are typically downloaded, so member names are not
            # trusted: a name such as '../x' or an absolute path would land
            # outside dest_folder. Resolving right before each extraction also
            # follows any symlink that an earlier member already created.
            target = os.path.realpath(os.path.join(dest_root, member.name))
            try:
                inside = os.path.commonpath([dest_root, target]) == dest_root
            except ValueError:
                # commonpath raises when the paths cannot be compared
                # (e.g. different drives on Windows); treat that as unsafe.
                inside = False
            if not inside:
                raise IOError('archive member escapes destination folder: %s'
                              % member.name)
            tar.extract(path=dest_folder, member=member)