NeMo/tests/conftest.py
Somshubra Majumdar d43d3ab85e
Update conftest.py to auto untar local data (#2655)
Signed-off-by: smajumdar <titu1994@gmail.com>
2021-08-13 14:03:13 -07:00

214 lines
7.3 KiB
Python

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tarfile
import urllib.request
from os import mkdir
from os.path import dirname, exists, getsize, join
from pathlib import Path
from shutil import rmtree
import pytest
# Location of the test-data archive. These variables should probably go to the
# main NeMo configuration file (config.yaml).
# File name of the test-data archive.
__TEST_DATA_FILENAME = "test_data.tar.gz"
# Base URL the archive is fetched from; the file name above is appended to it.
__TEST_DATA_URL = "https://github.com/NVIDIA/NeMo/releases/download/v1.0.0rc1/"
# Sub-folder (relative to this file) holding the archive and its extracted contents.
__TEST_DATA_SUBDIR = ".data"
def pytest_addoption(parser):
    """
    Register the additional command-line arguments accepted by pytest.

    Options:
        --cpu: use CPU during testing (DEFAULT: GPU)
        --use_local_test_data: use local test data/skip downloading from URL/GitHub (DEFAULT: False)
        --with_downloads: activate tests which download models from the cloud (DEFAULT: False)
        --relax_numba_compat: relax the numba compatibility checks (see note below)
        --tn_cache_dir: directory with .far grammars for CPU TN/ITN tests (DEFAULT: None)

    Args:
        parser: the option parser handed over by pytest; only its `addoption` method is used.
    """
    parser.addoption(
        '--cpu', action='store_true', help="pass that argument to use CPU during testing (DEFAULT: False = GPU)"
    )
    parser.addoption(
        '--use_local_test_data',
        action='store_true',
        help="pass that argument to use local test data/skip downloading from URL/GitHub (DEFAULT: False)",
    )
    parser.addoption(
        '--with_downloads',
        action='store_true',
        help="pass this argument to activate tests which download models from the cloud.",
    )
    # NOTE: `store_false` is deliberate. The stored value is True unless the flag is
    # passed, so it reads as "strict" rather than "relaxed".
    parser.addoption(
        '--relax_numba_compat',
        action='store_false',
        help="numba compatibility checks will be relaxed to just availability of cuda, "
        "without cuda compatibility matrix check",
    )
    parser.addoption(
        '--tn_cache_dir',
        type=str,
        default=None,
        help="path to a directory with .far grammars for CPU TN/ITN tests, (DEFAULT: None, i.e. no cache)",
    )
@pytest.fixture
def device(request):
    """ Fixture naming the device under test: "CPU" when `--cpu` was passed, "GPU" otherwise. """
    use_cpu = request.config.getoption("--cpu")
    return "CPU" if use_cpu else "GPU"
@pytest.fixture(autouse=True)
def run_only_on_device_fixture(request, device):
    """ Skip a test marked `run_only_on` when the marker's first argument is not the active device. """
    marker = request.node.get_closest_marker('run_only_on')
    if not marker:
        # Unmarked tests run on every device.
        return
    required_device = marker.args[0]
    if required_device != device:
        pytest.skip('skipped on this device: {}'.format(device))
@pytest.fixture(autouse=True)
def downloads_weights(request, device):
    """ Skip a test marked `with_downloads` unless the `--with_downloads` option was given. """
    download_marker = request.node.get_closest_marker('with_downloads')
    if download_marker and not request.config.getoption("--with_downloads"):
        pytest.skip(
            'To run this test, pass --with_downloads option. It will download (and cache) models from cloud.'
        )
@pytest.fixture(autouse=True)
def cleanup_local_folder():
    """
    Refuse to run when an experiment/log folder already exists in the working directory,
    then remove any such folder once the test has finished.
    """
    experiment_dirs = ('./lightning_logs', './NeMo_experiments', './nemo_experiments')

    # Asserts in fixture are not recommended, but I'd rather stop users from deleting expensive training runs
    for folder in experiment_dirs:
        assert not Path(folder).exists()

    yield

    # Teardown: drop whatever the test left behind.
    for folder in experiment_dirs:
        if Path(folder).exists():
            rmtree(folder, ignore_errors=True)
@pytest.fixture
def test_data_dir():
    """ Fixture returning the path of the test-data folder that sits next to this file. """
    return join(dirname(__file__), __TEST_DATA_SUBDIR)
def extract_data_from_tar(test_dir, test_data_archive, url=None):
    """
    Recreate `test_dir` from scratch and unpack the test archive into it.

    Args:
        test_dir: folder that is wiped, recreated and then populated with the archive contents.
        test_data_archive: path of the tar archive to extract (and the download target when `url` is set).
        url: when not None, the archive is first downloaded from this URL to `test_data_archive`.
            When None, an already present local archive is extracted; if that archive is stored
            inside `test_dir` it is preserved while the folder is wiped.
    """
    # Function-scope imports keep this fix independent of the file-level import block.
    import shutil
    import tempfile

    data_root = Path(test_dir).resolve()
    archive_path = Path(test_data_archive).resolve()

    # A local archive stored inside the folder about to be wiped must survive the wipe;
    # otherwise there is nothing left to extract afterwards.
    preserve_archive = url is None and archive_path.is_file() and data_root in archive_path.parents

    if preserve_archive:
        # Stash next to `test_dir` so that the moves are (normally) cheap same-filesystem renames.
        with tempfile.TemporaryDirectory(dir=str(data_root.parent)) as stash_dir:
            stashed = shutil.move(str(archive_path), stash_dir)
            try:
                rmtree(test_dir)
                mkdir(test_dir)
            finally:
                # Put the archive back even if the wipe failed half-way.
                archive_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.move(stashed, str(archive_path))
    else:
        # Remove .data folder.
        if exists(test_dir):
            rmtree(test_dir)
        # Create one .data folder.
        mkdir(test_dir)

    # Download (if required)
    if url is not None:
        urllib.request.urlretrieve(url, test_data_archive)

    # Extract tar
    print("Extracting the `{}` test archive, please wait...".format(test_data_archive))
    # NOTE(review): `extractall` trusts the member paths stored in the archive; only use it
    # with archives from a trusted source.
    with tarfile.open(test_data_archive) as tar:
        tar.extractall(path=test_dir)
def pytest_configure(config):
    """
    Initial configuration of conftest.

    Registers the `run_only_on` marker, makes sure the test data is available
    (see `_prepare_test_data`), applies the numba compatibility strictness and
    sets the cache directory used by the TN/ITN tests.

    Args:
        config: the pytest config object; `config.option` carries the values of the
            options registered in `pytest_addoption`.
    """
    config.addinivalue_line(
        "markers", "run_only_on(device): runs the test only on a given device [CPU | GPU]",
    )

    # Kept in a helper so that its early exits cannot skip the configuration below.
    _prepare_test_data(config)

    if config.option.relax_numba_compat is not None:
        from nemo.core.utils import numba_utils

        print("Setting numba compat :", config.option.relax_numba_compat)
        numba_utils.set_numba_compat_strictness(strict=config.option.relax_numba_compat)

    # Set cache directory for TN/ITN tests
    from .nemo_text_processing.utils import set_cache_dir

    set_cache_dir(config.option.tn_cache_dir)


def _prepare_test_data(config):
    """
    Make sure the test data is available in the test-data folder next to this file.

    With `--use_local_test_data` the local archive is required and gets extracted.
    Otherwise the size of the local archive is compared with the size reported by the
    remote URL; the archive is downloaded and extracted when the sizes differ. If the
    remote cannot be reached, the data already present locally is used as is.
    The test session is aborted (`pytest.exit`) when no usable archive can be found.
    """
    # Test dir and archive filepath.
    test_dir = join(dirname(__file__), __TEST_DATA_SUBDIR)
    test_data_archive = join(test_dir, __TEST_DATA_FILENAME)

    # Get size of local test_data archive (-1 marks a missing/unreadable file).
    try:
        test_data_local_size = getsize(test_data_archive)
    except OSError:
        test_data_local_size = -1

    if config.option.use_local_test_data:
        if test_data_local_size == -1:
            pytest.exit("Test data `{}` is not present in the system".format(test_data_archive))
        print(
            "Using the local `{}` test archive ({}B) found in the `{}` folder.".format(
                __TEST_DATA_FILENAME, test_data_local_size, test_dir
            )
        )
        # untar local test data
        extract_data_from_tar(test_dir, test_data_archive)
        return

    # Get size of remote test_data archive.
    url = __TEST_DATA_URL + __TEST_DATA_FILENAME
    try:
        # Only the headers are needed; the context manager closes the connection again.
        with urllib.request.urlopen(url) as response:
            content_length = response.info()["Content-Length"]
    except Exception:
        # Best effort: any failure to reach the remote archive falls back to local data.
        # (`Exception` rather than a bare `except` so Ctrl-C still interrupts.)
        if test_data_local_size == -1:
            pytest.exit("Test data not present in the system and cannot access the '{}' URL".format(url))
        print(
            "Cannot access the '{}' URL, using the test data ({}B) found in the `{}` folder.".format(
                url, test_data_local_size, test_dir
            )
        )
        return

    # A response without Content-Length cannot confirm the local copy; treat it as a mismatch.
    test_data_remote_size = int(content_length) if content_length is not None else None

    # Compare sizes.
    if test_data_local_size != test_data_remote_size:
        print(
            "Downloading the `{}` test archive from `{}`, please wait...".format(
                __TEST_DATA_FILENAME, __TEST_DATA_URL
            )
        )
        extract_data_from_tar(test_dir, test_data_archive, url=url)
    else:
        print(
            "A valid `{}` test archive ({}B) found in the `{}` folder.".format(
                __TEST_DATA_FILENAME, test_data_local_size, test_dir
            )
        )