Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ Script responsible for generation of a JSON file containing list of modules of a given collection. """ | |
| import argparse | |
| import importlib | |
| import inspect | |
| import json | |
| import os | |
| import nemo | |
| from nemo.utils import logging | |
def process_member(name, obj, module_list):
    """ Append a description record for `obj` to `module_list` when it is a NeMo Serialization class.

    Members that are not classes, or are classes not derived from `nemo.core.Serialization`,
    are silently ignored.

    Args:
        name: name under which the member is exposed by its package.
        obj: the member itself (class, function, submodule, ...).
        module_list: list receiving the record; it is modified in place.
    """
    # Datasets/modules/losses all inherit from Serialization, which this script also relies on.
    # `and` short-circuits, so issubclass is only evaluated for actual classes.
    is_serializable_class = inspect.isclass(obj) and issubclass(obj, nemo.core.Serialization)
    if not is_serializable_class:
        return

    logging.info(" * Processing `{}`".format(str(obj)))

    # Temporary solution: arguments, input types and output types are hard-coded mockups,
    # identical for every member; they are NOT introspected from `obj`.
    mock_arguments = [
        "jasper",
        "activation",
        "feat_in",
        "normalization_mode",
        "residual_mode",
        "norm_groups",
        "conv_mask",
        "frame_splicing",
        "init_mode",
    ]
    mock_input_types = {
        "audio_signal": "axes: (batch, dimension, time); elements_type: MelSpectrogramType",
        "length": "axes: (batch,); elements_type: LengthType",
    }
    mock_output_types = {
        "encoder_output": "axes: (batch, dimension, time); elements_type: AcousticEncodedRepresentation"
    }

    record = {
        "name": name,
        "cls": str(obj),
        "arguments": mock_arguments,
        "input_types": mock_input_types,
        "output_types": mock_output_types,
    }
    module_list.append(record)
# Packages of a collection that are scanned for modules, each paired with the noun
# used in the log message when the collection does not expose that package.
_ANALYSED_PACKAGES = (
    ("data", "datasets"),
    ("datasets", "datasets"),
    ("modules", "modules"),
    ("losses", "losses"),
)

# Output filename used when `--filename` is not given; it is then prefixed with the collection ID.
_DEFAULT_FILENAME = "modules.json"


def _discover_collections(collections_dir):
    """ Map every collection subfolder of `collections_dir` to its importable module path. """
    collections = {}
    for sub_dir in sorted(os.listdir(collections_dir)):
        # Skip the bytecode cache.
        if sub_dir == "__pycache__":
            continue
        if os.path.isdir(os.path.join(collections_dir, sub_dir)):
            collections[sub_dir] = "nemo.collections." + sub_dir
    return collections


def _analyse_collection(collection):
    """ Run `process_member` over the members of the collection's known packages and return the records. """
    module_list = []
    for package_name, kind in _ANALYSED_PACKAGES:
        logging.info("Analysing the '{}' package".format(package_name))
        # Only the attribute lookup is guarded; errors raised while processing members propagate
        # instead of being misreported as a missing package.
        package = getattr(collection, package_name, None)
        if package is None:
            logging.info(" * No {} found".format(kind))
            continue
        for name, obj in inspect.getmembers(package):
            process_member(name, obj, module_list)
    return module_list


def main():
    """ Main function analysing the indicated NeMo collection and generating a JSON file with module descriptions.

    Command line arguments:
        --collection: ID of the collection, i.e. the name of a subfolder of the `nemo.collections` package.
        --filename: name of the output JSON file; when omitted, `<collection>_modules.json` is used.

    Exits with status -1 when the collection is unknown or its module cannot be located.
    """
    # `import importlib` alone does not guarantee that the `importlib.util` submodule is loaded.
    import importlib.util
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('--collection', help='ID of the collection', type=str)
    parser.add_argument(
        '--filename',
        help='Name of the output JSON file (default: <collection>_{})'.format(_DEFAULT_FILENAME),
        type=str,
        default=None,
    )
    args = parser.parse_args()

    # Get collections directory and the list of collections (one per subfolder).
    collections_dir = os.path.dirname(nemo.collections.__file__)
    logging.info('Analysing collections in `{}`'.format(collections_dir))
    collections = _discover_collections(collections_dir)

    # Check the collection.
    if args.collection not in collections:
        logging.error("Couldn't process the indicated `{}` collection".format(args.collection))
        logging.info(
            "Please select one of the existing collections using `--collection [{}]`".format("|".join(collections))
        )
        sys.exit(-1)

    module_name = collections[args.collection]
    # find_spec returns None when no finder can locate the module.
    if importlib.util.find_spec(module_name) is None:
        logging.error("Failed to load the `{}` collection".format(args.collection))
        sys.exit(-1)
    # import_module registers the package in sys.modules (unlike module_from_spec + exec_module),
    # so the package is initialised once and submodule imports resolve against this same object.
    collection = importlib.import_module(module_name)

    logging.info("Analysing the `{}` collection".format(args.collection))
    module_list = _analyse_collection(collection)

    # Prefix the collection ID only when the user did not choose a filename explicitly.
    filename = args.filename if args.filename is not None else args.collection + "_" + _DEFAULT_FILENAME

    # Export to JSON.
    with open(filename, 'w', encoding='utf-8') as outfile:
        json.dump(module_list, outfile)
    logging.info(
        'Finished analysis of the `{}` collection, results exported to `{}`.'.format(args.collection, filename)
    )
# Run the analysis only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()