NeMo/external/get_modules.py
Tomasz Kornuta be289129ad
removed the # ==== showstoppers in all headers (#924)
Signed-off-by: Tomasz Kornuta <tkornuta@nvidia.com>
2020-07-27 12:55:51 -07:00

160 lines
5.8 KiB
Python

# -*- coding: utf-8 -*-
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Script responsible for generation of a JSON file containing list of modules of a given collection. """
import argparse
import importlib
import inspect
import json
import os
import nemo
from nemo.utils import logging
def process_member(name, obj, module_list):
""" Helper function processing the passed object and, if ok, adding a record to the module list.
Args:
name: name of the member
obj: member (class/function etc.)
module_list: list of modules that (probably) will be expanded.
"""
# It is not a class - skip it.
if not inspect.isclass(obj):
return
# Check inheritance - we know that all our datasets/modules/losses inherit from Serialization,
# Btw. Serialization is also required by this script.
if not issubclass(obj, nemo.core.Serialization):
return
logging.info(" * Processing `{}`".format(str(obj)))
module_list.append(
{
"name": name,
"cls": str(obj),
# Temporary solution: mockup arguments.
"arguments": [
"jasper",
"activation",
"feat_in",
"normalization_mode",
"residual_mode",
"norm_groups",
"conv_mask",
"frame_splicing",
"init_mode",
],
# Temporary solution: mockup input types.
"input_types": {
"audio_signal": "axes: (batch, dimension, time); elements_type: MelSpectrogramType",
"length": "axes: (batch,); elements_type: LengthType",
},
# Temporary solution: mockup output types.
"output_types": {
"encoder_output": "axes: (batch, dimension, time); elements_type: AcousticEncodedRepresentation"
},
}
)
def main():
""" Main function analysing the indicated NeMo collection and generating a JSON file with module descriptions. """
# Parse filename.
parser = argparse.ArgumentParser()
parser.add_argument('--collection', help='ID of the collection', type=str)
parser.add_argument('--filename', help='Name of the output JSON file', type=str, default="modules.json")
args = parser.parse_args()
# Get collections directory.
colletions_dir = os.path.dirname(nemo.collections.__file__)
logging.info('Analysing collections in `{}`'.format(colletions_dir))
# Generate list of NeMo collections - from the list of collection subfolders.
collections = {}
for sub_dir in os.listdir(colletions_dir):
# Skip cache.
if sub_dir == "__pycache__":
continue
# Check if it is a directory.
if os.path.isdir(os.path.join(colletions_dir, sub_dir)):
collections[sub_dir] = "nemo.collections." + sub_dir
# Check the collection.
if args.collection not in collections.keys():
logging.error("Coudn't process the incidated `{}` collection".format(args.collection))
logging.info(
"Please select one of the existing collections using `--collection [{}]`".format("|".join(collections))
)
exit(-1)
# Load the collection specification.
collection_spec = importlib.util.find_spec(collections[args.collection])
if collection_spec is None:
logging.error("Failed to load the `{}` collection".format(val))
# Import the module from the module specification.
collection = importlib.util.module_from_spec(collection_spec)
collection_spec.loader.exec_module(collection)
module_list = []
# Iterate over the packages in the indicated collection.
logging.info("Analysing the `{}` collection".format(args.collection))
try: # Datasets in dataset folder
logging.info("Analysing the 'data' package")
for name, obj in inspect.getmembers(collection.data):
process_member(name, obj, module_list)
except AttributeError as e:
logging.info(" * No datasets found")
try: # Datasets in dataset folder
logging.info("Analysing the 'datasets' package")
for name, obj in inspect.getmembers(collection.datasets):
process_member(name, obj, module_list)
except AttributeError as e:
logging.info(" * No datasets found")
try: # Modules
logging.info("Analysing the 'modules' package")
for name, obj in inspect.getmembers(collection.modules):
process_member(name, obj, module_list)
except AttributeError as e:
logging.info(" * No modules found")
try: # Losses
logging.info("Analysing the 'losses' package")
for name, obj in inspect.getmembers(collection.losses):
process_member(name, obj, module_list)
except AttributeError as e:
logging.info(" * No losses found")
# Add prefix - only for default name.
filename = args.filename if args.filename != "modules.json" else args.collection + "_" + args.filename
# Export to JSON.
with open(filename, 'w') as outfile:
json.dump(module_list, outfile)
logging.info(
'Finished analysis of the `{}` collection, results exported to `{}`.'.format(args.collection, filename)
)
if __name__ == '__main__':
main()