Spaces:
Runtime error
Runtime error
| # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| This script creates a manifest file for diarization training. If you specify `pairwise_rttm_output_folder`, the script generates | |
| a two-speaker subset of the original RTTM files. For example, an RTTM file with 4 speakers will obtain 6 different pairs and | |
| 6 RTTM files with two speakers in each RTTM file. | |
| Args: | |
| --input_manifest_path: input json file name | |
| --output_manifest_path: output manifest_file name | |
| --pairwise_rttm_output_folder: Save two-speaker pair RTTM files | |
| --window: Window length for segmentation | |
| --shift: Shift length for segmentation | |
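| --decimals: Rounding decimals | |
| --step_count: Number of unit segments to create per utterance (also used in the default output manifest name) | |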
| """ | |
| import argparse | |
| import copy | |
| import itertools | |
| import os | |
| import random | |
| from tqdm import tqdm | |
| from nemo.collections.asr.parts.utils.manifest_utils import ( | |
| get_input_manifest_dict, | |
| get_subsegment_dict, | |
| rreplace, | |
| write_truncated_subsegments, | |
| ) | |
| from nemo.collections.asr.parts.utils.speaker_utils import ( | |
| audio_rttm_map, | |
| rttm_to_labels, | |
| segments_manifest_to_subsegments_manifest, | |
| write_rttm2manifest, | |
| ) | |
| from nemo.utils import logging | |
# Seed Python's global `random` module once, at import time, so that any later use of it in this
# process is repeatable. Nothing in this file draws from `random` directly; presumably the imported
# NeMo helpers do (TODO confirm).
random.seed(a=42)
def labels_to_rttmfile(labels, uniq_id, filename, out_rttm_dir):
    """
    Write `labels` as an RTTM file named `<filename>.rttm` inside `out_rttm_dir`.

    Args:
        labels (iterable of str):
            Entries of the form "start end speaker" (whitespace separated, surrounding whitespace ignored).
        uniq_id (str):
            Recording id written into the second field of every RTTM line.
        filename (str):
            Output file stem, without the `.rttm` extension.
        out_rttm_dir (str):
            Directory in which the RTTM file is created (it must already exist).

    Returns:
        str: Path of the written RTTM file.
    """
    rttm_path = os.path.join(out_rttm_dir, filename + '.rttm')
    with open(rttm_path, 'w') as rttm_file:
        for label in labels:
            onset_str, offset_str, speaker = label.strip().split()
            onset = float(onset_str)
            # RTTM stores (onset, duration) rather than (start, end).
            dur = float(offset_str) - onset
            rttm_file.write(
                'SPEAKER {} 1 {:.3f} {:.3f} <NA> <NA> {} <NA> <NA>\n'.format(uniq_id, onset, dur, speaker)
            )
    return rttm_path
def _unique_speakers(rttm_labels, max_speakers):
    """Return distinct speaker ids (3rd space-separated field) in first-appearance order, at most `max_speakers` (None = no cap)."""
    speakers = []
    for entry in rttm_labels:
        if max_speakers is not None and len(speakers) >= max_speakers:
            break
        speaker = entry.split(' ')[2]
        if speaker not in speakers:
            speakers.append(speaker)
    return speakers


def split_into_pairwise_rttm(audio_rttm_map, input_manifest_path, output_dir):
    """
    Create pairwise RTTM files and save them to `output_dir`. For every recording, each pair of speakers
    found in its RTTM gets its own RTTM file containing only that pair's segments.

    Args:
        audio_rttm_map (dict):
            A dictionary with keys of uniq id, which is being used to map audio files and corresponding rttm files.
        input_manifest_path (str):
            Path of the input manifest file. Each entry must provide `audio_filepath`, `num_speakers`
            and `rttm_filepath`.
        output_dir (str):
            Path to the directory where the new RTTM files are saved. It is created if missing.

    Returns:
        tuple(dict, dict):
            rttm_split_manifest_dict: deep copies of the manifest entries keyed by `<uniq_id>.<spk_a>_<spk_b>`,
                with `rttm_filepath` pointing at the pairwise RTTM file.
            split_audio_rttm_map: deep copies of `audio_rttm_map[uniq_id]` under the same keys, with
                `rttm_filepath` updated the same way.
        Recordings with fewer than two speakers yield no pairs and so do not appear in either dict.
    """
    input_manifest_dict = get_input_manifest_dict(input_manifest_path)
    rttm_split_manifest_dict = {}
    split_audio_rttm_map = {}

    # labels_to_rttmfile opens files directly inside output_dir, so make sure the directory exists.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    logging.info("Creating split RTTM files.")
    for uniq_id, line in tqdm(input_manifest_dict.items(), total=len(input_manifest_dict)):
        audiopath = line['audio_filepath']
        num_speakers = line['num_speakers']
        rttm = rttm_to_labels(line['rttm_filepath'])

        # Previously an unbounded scan that raised IndexError when the manifest claimed more speakers
        # than the RTTM contains; now it stops at the end of the RTTM and warns instead.
        speakers = _unique_speakers(rttm, num_speakers)
        if num_speakers is not None and len(speakers) < num_speakers:
            logging.warning(
                f"{uniq_id}: manifest lists {num_speakers} speakers but the RTTM contains only {len(speakers)}."
            )

        base_fn = audiopath.split('/')[-1].replace('.wav', '')
        for pair in itertools.combinations(speakers, 2):
            target_rttm = [entry for entry in rttm if entry.split(' ')[2] in pair]
            pair_string = f".{pair[0]}_{pair[1]}"
            uniq_id_pair = uniq_id + pair_string
            filename = base_fn + pair_string
            # Use the path labels_to_rttmfile actually wrote instead of re-deriving it by string
            # concatenation, which was only correct when output_dir ended with '/'.
            rttm_path = labels_to_rttmfile(target_rttm, base_fn, filename, output_dir)

            line_mod = copy.deepcopy(line)
            line_mod['rttm_filepath'] = rttm_path
            meta = copy.deepcopy(audio_rttm_map[uniq_id])
            meta['rttm_filepath'] = rttm_path
            rttm_split_manifest_dict[uniq_id_pair] = line_mod
            split_audio_rttm_map[uniq_id_pair] = meta
    return rttm_split_manifest_dict, split_audio_rttm_map
def main(input_manifest_path, output_manifest_path, pairwise_rttm_output_folder, window, shift, step_count, decimals):
    """
    Build a subsegment-level diarization training manifest from `input_manifest_path`.

    Args:
        input_manifest_path (str): Input manifest; must contain `.json`.
        output_manifest_path (str or None): Output manifest; must contain `.json`. When empty/None it defaults
            to the input path with its last `.json` replaced by `.<step_count>seg.json`.
        pairwise_rttm_output_folder (str or None): When given, RTTMs are first split into two-speaker pairs
            written to this folder and those pairs are used instead of the original RTTMs.
        window (float): Window length for segmentation.
        shift (float): Shift length for segmentation.
        step_count (int or str): Number of unit segments per utterance; converted with `int()`.
        decimals (int): Rounding decimals.

    Raises:
        ValueError: if a manifest path does not contain `.json`, or `step_count` is not an integer.
    """
    if '.json' not in input_manifest_path:
        raise ValueError("input_manifest_path file should be .json file format")
    if output_manifest_path and '.json' not in output_manifest_path:
        raise ValueError("output_manifest_path file should be .json file format")

    # Convert once, before any use, so the default output name and the downstream writer see the
    # same integer and an invalid value fails before any files are produced.
    step_count = int(step_count)
    if not output_manifest_path:
        output_manifest_path = rreplace(input_manifest_path, '.json', f'.{step_count}seg.json')

    if pairwise_rttm_output_folder is not None:
        if not pairwise_rttm_output_folder.endswith('/'):
            pairwise_rttm_output_folder = f"{pairwise_rttm_output_folder}/"
        org_audio_rttm_map = audio_rttm_map(input_manifest_path)
        input_manifest_dict, rttm_map = split_into_pairwise_rttm(
            audio_rttm_map=org_audio_rttm_map,
            input_manifest_path=input_manifest_path,
            output_dir=pairwise_rttm_output_folder,
        )
    else:
        input_manifest_dict = get_input_manifest_dict(input_manifest_path)
        rttm_map = audio_rttm_map(input_manifest_path)

    # Intermediate manifests written next to the input manifest and removed at the end.
    segment_manifest_path = rreplace(input_manifest_path, '.json', '_seg.json')
    subsegment_manifest_path = rreplace(input_manifest_path, '.json', '_subseg.json')

    # TODO: consider exposing this as a CLI option.
    min_subsegment_duration = 0.05

    try:
        segments_manifest_file = write_rttm2manifest(rttm_map, segment_manifest_path, decimals)
        logging.info("Creating subsegments.")
        segments_manifest_to_subsegments_manifest(
            segments_manifest_file=segments_manifest_file,
            subsegments_manifest_file=subsegment_manifest_path,
            window=window,
            shift=shift,
            min_subsegment_duration=min_subsegment_duration,
            include_uniq_id=True,
        )
        subsegments_dict = get_subsegment_dict(subsegment_manifest_path, window, shift, decimals)
        write_truncated_subsegments(input_manifest_dict, subsegments_dict, output_manifest_path, step_count, decimals)
    finally:
        # Scratch files: clean them up even if a step above failed (previously they leaked on error).
        for tmp_path in (segment_manifest_path, subsegment_manifest_path):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_manifest_path", help="input json file name", type=str, required=True)
    parser.add_argument(
        "--output_manifest_path", help="output manifest_file name", type=str, default=None, required=False
    )
    parser.add_argument(
        "--pairwise_rttm_output_folder",
        help="Save two-speaker pair RTTM files",
        type=str,
        default=None,
        required=False,
    )
    parser.add_argument("--window", help="Window length for segmentation", type=float, required=True)
    parser.add_argument("--shift", help="Shift length for segmentation", type=float, required=True)
    parser.add_argument("--decimals", help="Rounding decimals", type=int, default=3, required=False)
    # type=int so a non-integer value is rejected by argparse with a usage message, like the other
    # numeric options, instead of surfacing later as a ValueError inside main().
    parser.add_argument(
        "--step_count",
        help="Number of the unit segments you want to create per utterance",
        type=int,
        required=True,
    )
    args = parser.parse_args()
    main(
        args.input_manifest_path,
        args.output_manifest_path,
        args.pairwise_rttm_output_folder,
        args.window,
        args.shift,
        args.step_count,
        args.decimals,
    )