Spaces:
Runtime error
Runtime error
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import argparse | |
| import json | |
| import os | |
| from nemo.collections.asr.metrics.der import evaluate_der | |
| from nemo.collections.asr.parts.utils.diarization_utils import OfflineDiarWithASR | |
| from nemo.collections.asr.parts.utils.manifest_utils import read_file | |
| from nemo.collections.asr.parts.utils.speaker_utils import ( | |
| get_uniqname_from_filepath, | |
| labels_to_pyannote_object, | |
| rttm_to_labels, | |
| ) | |
| """ | |
| Evaluation script for diarization with ASR. | |
| Calculates Diarization Error Rate (DER) with RTTM files, and WER and cpWER with CTM (or hypothesis JSON) files. | |
| In the output ctm_eval.csv file in the output folder, | |
| session-level DER, WER, cpWER and speaker counting accuracies are evaluated. | |
| - Evaluation mode | |
| diar_eval_mode == "full": | |
| DIHARD challenge-style evaluation, the strictest way of evaluating diarization | |
| (collar, ignore_overlap) = (0.0, False) | |
| diar_eval_mode == "fair": | |
| Evaluation setup used in VoxSRC challenge | |
| (collar, ignore_overlap) = (0.25, False) | |
| diar_eval_mode == "forgiving": | |
| Traditional evaluation setup | |
| (collar, ignore_overlap) = (0.25, True) | |
| diar_eval_mode == "all": | |
| Compute all three modes (default) | |
| Use CTM files to calculate WER and cpWER | |
| ``` | |
| python eval_diar_with_asr.py \ | |
| --hyp_rttm_list="/path/to/hypothesis_rttm_filepaths.list" \ | |
| --ref_rttm_list="/path/to/reference_rttm_filepaths.list" \ | |
| --hyp_ctm_list="/path/to/hypothesis_ctm_filepaths.list" \ | |
| --ref_ctm_list="/path/to/reference_ctm_filepaths.list" \ | |
| --root_path="/path/to/output/directory" | |
| ``` | |
| Use .json files to calculate WER and cpWER | |
| ``` | |
| python eval_diar_with_asr.py \ | |
| --hyp_rttm_list="/path/to/hypothesis_rttm_filepaths.list" \ | |
| --ref_rttm_list="/path/to/reference_rttm_filepaths.list" \ | |
| --hyp_json_list="/path/to/hypothesis_json_filepaths.list" \ | |
| --ref_ctm_list="/path/to/reference_ctm_filepaths.list" \ | |
| --root_path="/path/to/output/directory" | |
| ``` | |
| Only use RTTMs to calculate DER | |
| ``` | |
| python eval_diar_with_asr.py \ | |
| --hyp_rttm_list="/path/to/hypothesis_rttm_filepaths.list" \ | |
| --ref_rttm_list="/path/to/reference_rttm_filepaths.list" \ | |
| --root_path="/path/to/output/directory" | |
| ``` | |
| """ | |
def get_pyannote_objs_from_rttms(rttm_file_path_list):
    """Generate PyAnnote objects from a list of RTTM file paths.

    Args:
        rttm_file_path_list: Iterable of RTTM file paths, or ``None``. Surrounding
            whitespace (e.g. a trailing newline) is stripped from each entry.

    Returns:
        A list of ``[uniq_id, annotation]`` pairs in input order, one per RTTM file
        that exists on disk. ``uniq_id`` comes from ``get_uniqname_from_filepath``;
        ``annotation`` comes from ``labels_to_pyannote_object`` applied to the labels
        parsed by ``rttm_to_labels``.

    Entries that are ``None``, blank, or point to a non-existent file are skipped.
    NOTE(review): skipping shortens the returned list; if hypothesis and reference
    lists are later paired positionally this could misalign them — confirm.
    """
    pyannote_obj_list = []
    if rttm_file_path_list is None:
        return pyannote_obj_list
    for rttm_file in rttm_file_path_list:
        # Test for None before calling .strip(), which would raise AttributeError on None.
        if rttm_file is None:
            continue
        rttm_file = rttm_file.strip()
        if not rttm_file or not os.path.exists(rttm_file):
            continue
        uniq_id = get_uniqname_from_filepath(rttm_file)
        ref_labels = rttm_to_labels(rttm_file)
        reference = labels_to_pyannote_object(ref_labels, uniq_name=uniq_id)
        pyannote_obj_list.append([uniq_id, reference])
    return pyannote_obj_list
def make_meta_dict(hyp_rttm_list, ref_rttm_list):
    """Create a temporary `audio_rttm_map_dict` for evaluation.

    Args:
        hyp_rttm_list: List of hypothesis RTTM paths, or ``None``. When given, entry
            ``k`` is paired with reference entry ``k`` (positional pairing), so both
            lists must have the same length and a matching order.
        ref_rttm_list: List of reference RTTM paths.

    Returns:
        Dict mapping each reference ``uniq_id`` (from ``get_uniqname_from_filepath``)
        to ``{"rttm_filepath": <ref path>}``. When ``hyp_rttm_list`` is given, the
        entry also has ``"hyp_rttm_filepath": <hyp path>``. Paths are stripped of
        surrounding whitespace.

    Raises:
        ValueError: If ``hyp_rttm_list`` is given and its length differs from
            ``ref_rttm_list``'s. Positional pairing would otherwise either hit an
            opaque IndexError or silently ignore extra hypotheses.
    """
    if hyp_rttm_list is not None and len(hyp_rttm_list) != len(ref_rttm_list):
        raise ValueError(
            f"Number of hypothesis RTTM files ({len(hyp_rttm_list)}) does not match "
            f"number of reference RTTM files ({len(ref_rttm_list)})."
        )
    hyp_entries = hyp_rttm_list if hyp_rttm_list is not None else [None] * len(ref_rttm_list)
    meta_dict = {}
    for ref_rttm_file, hyp_rttm_file in zip(ref_rttm_list, hyp_entries):
        # Strip first so the uniq_id and the stored path are derived from the same string.
        ref_rttm_file = ref_rttm_file.strip()
        uniq_id = get_uniqname_from_filepath(ref_rttm_file)
        meta_dict[uniq_id] = {"rttm_filepath": ref_rttm_file}
        if hyp_rttm_file is not None:
            meta_dict[uniq_id]["hyp_rttm_filepath"] = hyp_rttm_file.strip()
    return meta_dict
def make_trans_info_dict(hyp_json_list_path):
    """Create `trans_info_dict` from hypothesis `.json` files.

    Args:
        hyp_json_list_path: Despite its name, an iterable of JSON file *paths*
            (``main`` passes the list already read by ``read_file_path``), not the
            path of a list file. Each entry is stripped, and blank entries are skipped.

    Returns:
        Dict mapping each file's ``uniq_id`` (from ``get_uniqname_from_filepath``)
        to its parsed JSON content. If two files share a ``uniq_id``, the later one wins.
    """
    trans_info_dict = {}
    for json_file in hyp_json_list_path:
        json_file = json_file.strip()
        if not json_file:
            # A blank line in the list file (e.g. a trailing newline) would make open("") fail.
            continue
        # JSON text is UTF-8 (RFC 8259); don't rely on the platform's locale encoding.
        with open(json_file, encoding="utf-8") as jsf:
            json_data = json.load(jsf)
        uniq_id = get_uniqname_from_filepath(json_file)
        trans_info_dict[uniq_id] = json_data
    return trans_info_dict
def read_file_path(list_path):
    """Read a file-list file and return its entries as clean, sorted paths.

    Each entry returned by ``read_file`` is stripped of surrounding whitespace
    (including the line break), and blank entries are dropped. Without this, a
    trailing empty line would become a bogus ``""`` path that sorts first and
    shifts every positional hypothesis/reference pairing done downstream.

    Args:
        list_path: Path to a text file listing one file path per line.

    Returns:
        Sorted list of non-empty, stripped paths.
    """
    stripped_entries = (entry.strip() for entry in read_file(list_path))
    return sorted(entry for entry in stripped_entries if entry)
def main(
    hyp_rttm_list_path: str,
    ref_rttm_list_path: str,
    hyp_ctm_list_path: str,
    ref_ctm_list_path: str,
    hyp_json_list_path: str,
    diar_eval_mode: str = "all",
    root_path: str = "./",
):
    """Evaluate diarization (DER) and, when possible, transcription (WER / cpWER).

    Every ``*_list_path`` argument names a text file listing one file path per line.
    A falsy value (``None`` or ``""``) means that input was not provided.

    DER is always computed from the hypothesis/reference RTTM lists. WER and cpWER
    are computed only when a reference CTM list is given. The hypothesis then comes
    from the CTM list if one was given, otherwise from the JSON list.

    Args:
        hyp_rttm_list_path: List file of hypothesis RTTM paths.
        ref_rttm_list_path: List file of reference RTTM paths.
        hyp_ctm_list_path: List file of hypothesis CTM paths, or falsy.
        ref_ctm_list_path: List file of reference CTM paths, or falsy.
        hyp_json_list_path: List file of hypothesis JSON paths, or falsy.
        diar_eval_mode: Forwarded to ``evaluate_der``.
        root_path: Output directory handed to the result gatherer and CSV writer.

    Raises:
        ValueError: If reference CTMs are given but neither hypothesis CTMs nor
            hypothesis JSON files are.
    """

    def _load_optional(list_path):
        # A falsy path means "not provided" and maps to None, not to an empty list.
        return read_file_path(list_path) if list_path else None

    hyp_rttm_list = _load_optional(hyp_rttm_list_path)
    ref_rttm_list = _load_optional(ref_rttm_list_path)
    hyp_ctm_list = _load_optional(hyp_ctm_list_path)
    ref_ctm_list = _load_optional(ref_ctm_list_path)
    hyp_json_list = _load_optional(hyp_json_list_path)

    audio_rttm_map_dict = make_meta_dict(hyp_rttm_list, ref_rttm_list)
    trans_info_dict = None
    if hyp_json_list:
        trans_info_dict = make_trans_info_dict(hyp_json_list)

    # Diarization: session-level DER and speaker-counting results.
    hypothesis_objs = get_pyannote_objs_from_rttms(hyp_rttm_list)
    reference_objs = get_pyannote_objs_from_rttms(ref_rttm_list)
    diar_score = evaluate_der(
        audio_rttm_map_dict=audio_rttm_map_dict,
        all_reference=reference_objs,
        all_hypothesis=hypothesis_objs,
        diar_eval_mode=diar_eval_mode,
    )
    der_results = OfflineDiarWithASR.gather_eval_results(
        diar_score=diar_score,
        audio_rttm_map_dict=audio_rttm_map_dict,
        trans_info_dict=trans_info_dict,
        root_path=root_path,
    )

    # Transcription: WER / cpWER, only when reference CTMs exist.
    wer_results = {}
    if ref_ctm_list is not None:
        if hyp_ctm_list is not None:
            hyp_source = {"hyp_trans_info_dict": None, "hyp_ctm_file_list": hyp_ctm_list}
        elif hyp_json_list is not None:
            hyp_source = {"hyp_trans_info_dict": trans_info_dict, "hyp_ctm_file_list": None}
        else:
            raise ValueError("Hypothesis information is not provided in the correct format.")
        wer_results = OfflineDiarWithASR.evaluate(
            audio_file_list=hyp_rttm_list,
            ref_ctm_file_list=ref_ctm_list,
            **hyp_source,
        )

    # Report averages, then persist per-session detail under root_path.
    OfflineDiarWithASR.print_errors(der_results=der_results, wer_results=wer_results)
    csv_columns = OfflineDiarWithASR.get_csv_columns()
    OfflineDiarWithASR.write_session_level_result_in_csv(
        der_results=der_results,
        wer_results=wer_results,
        root_path=root_path,
        csv_columns=csv_columns,
    )
if __name__ == "__main__":
    # Command-line entry point: parse the list-file paths and evaluation options, then run main().
    parser = argparse.ArgumentParser()
    # Required arguments carry no default: argparse never uses a default when required=True.
    parser.add_argument(
        "--hyp_rttm_list", help="path to the filelist of hypothesis RTTM files", type=str, required=True
    )
    parser.add_argument(
        "--ref_rttm_list", help="path to the filelist of reference RTTM files", type=str, required=True
    )
    parser.add_argument(
        "--hyp_ctm_list", help="path to the filelist of hypothesis CTM files", type=str, required=False, default=None
    )
    parser.add_argument(
        "--ref_ctm_list", help="path to the filelist of reference CTM files", type=str, required=False, default=None
    )
    parser.add_argument(
        "--hyp_json_list",
        help="(Optional) path to the filelist of hypothesis JSON files",
        type=str,
        required=False,
        default=None,
    )
    parser.add_argument(
        "--diar_eval_mode",
        help='evaluation mode: "all", "full", "fair", "forgiving"',
        type=str,
        required=False,
        default="all",
        # Reject typos at parse time instead of handing an unknown mode to evaluate_der.
        choices=["all", "full", "fair", "forgiving"],
    )
    parser.add_argument(
        "--root_path", help='directory for saving result files', type=str, required=False, default="./"
    )
    args = parser.parse_args()
    # Pass by keyword so a future reordering of main()'s parameters cannot silently swap inputs.
    main(
        hyp_rttm_list_path=args.hyp_rttm_list,
        ref_rttm_list_path=args.ref_rttm_list,
        hyp_ctm_list_path=args.hyp_ctm_list,
        ref_ctm_list_path=args.ref_ctm_list,
        hyp_json_list_path=args.hyp_json_list,
        diar_eval_mode=args.diar_eval_mode,
        root_path=args.root_path,
    )