# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ MagpieTTS Inference and Evaluation Script. This script provides a clean CLI for running MagpieTTS inference with optional evaluation. It decouples inference and evaluation into separate modules for better maintainability. Example usage: # Inference only (from .nemo file) - default behavior python examples/tts/magpietts_inference.py \\ --nemo_files /path/to/model.nemo \\ --datasets libritts_test_clean \\ --out_dir /path/to/output \\ --codecmodel_path /path/to/codec.nemo # Inference with evaluation (from checkpoint) python examples/tts/magpietts_inference.py \\ --hparams_files /path/to/hparams.yaml \\ --checkpoint_files /path/to/model.ckpt \\ --datasets libritts_test_clean,vctk \\ --out_dir /path/to/output \\ --codecmodel_path /path/to/codec.nemo \\ --run_evaluation \\ --num_repeats 3 """ from __future__ import annotations import argparse import copy import json import logging import os import shutil from pathlib import Path from typing import List, Optional, Tuple import numpy as np # Import dataset configuration import nemo.collections.tts.modules.magpietts_inference.evalset_config as evalset_config from nemo.collections.asr.parts.utils.manifest_utils import read_manifest # Import the modular components from nemo.collections.tts.modules.magpietts_inference.evaluation import ( DEFAULT_VIOLIN_METRICS, STANDARD_METRIC_KEYS, EvaluationConfig, compute_mean_with_confidence_interval, evaluate_generated_audio_dir, ) from nemo.collections.tts.modules.magpietts_inference.inference import InferenceConfig, MagpieInferenceRunner from nemo.collections.tts.modules.magpietts_inference.utils import ( ModelLoadConfig, get_experiment_name_from_checkpoint_path, load_magpie_model, ) from nemo.collections.tts.modules.magpietts_inference.visualization import create_combined_box_plot, create_violin_plot # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', ) logger = logging.getLogger(__name__) # Default evaluation datasets EVALUATION_DATASETS = ( "riva_hard_digits,riva_hard_letters,riva_hard_money,riva_hard_short,vctk,libritts_seen,libritts_test_clean" ) def parse_layer_list(layer_str: Optional[str]) -> Optional[List[int]]: """Parse a comma-separated list of layer indices.""" if layer_str is None: return None return [int(l.strip()) for l in layer_str.split(",")] def write_csv_header_if_needed(csv_path: str, header: str) -> None: """Write CSV header if file doesn't exist.""" if not os.path.exists(csv_path): with open(csv_path, "w") as f: f.write(header + "\n") def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, metrics: dict) -> None: """Append metrics to a CSV file.""" values = [ checkpoint_name, dataset, metrics.get('cer_filewise_avg', ''), metrics.get('wer_filewise_avg', ''), metrics.get('cer_cumulative', ''), metrics.get('wer_cumulative', ''), metrics.get('ssim_pred_gt_avg', ''), metrics.get('ssim_pred_context_avg', ''), metrics.get('ssim_gt_context_avg', ''), metrics.get('ssim_pred_gt_avg_alternate', ''), metrics.get('ssim_pred_context_avg_alternate', ''), metrics.get('ssim_gt_context_avg_alternate', ''), metrics.get('cer_gt_audio_cumulative', ''), metrics.get('wer_gt_audio_cumulative', ''), metrics.get('utmosv2_avg', ''), metrics.get('total_gen_audio_seconds', ''), ] with open(csv_path, "a") as f: f.write(",".join(str(v) for v in values) + "\n") logger.info(f"Metrics appended to: {csv_path}") def create_formatted_metrics_mean_ci(metrics_mean_ci: dict) -> dict: """Create formatted metrics mean CI.""" for k, v in metrics_mean_ci.items(): if isinstance(v, list): mean, ci = float(v[0]), float(v[1]) logging.info(f"Metric {k}: {mean:.4f} ± {ci:.4f}") metrics_mean_ci[k] = f"{mean:.4f} ± {ci:.4f}" return metrics_mean_ci def run_inference_and_evaluation( model_config: ModelLoadConfig, inference_config: InferenceConfig, eval_config: EvaluationConfig, datasets: List[str], out_dir: str, num_repeats: int = 1, confidence_level: float = 0.95, violin_plot_metrics: Optional[List[str]] = None, log_exp_name: bool = False, clean_up_disk: bool = False, skip_evaluation: bool = False, ) -> Tuple[Optional[float], Optional[float]]: """Run inference and optional evaluation on specified datasets. Args: model_config: Configuration for loading the model. inference_config: Configuration for inference. eval_config: Configuration for evaluation. datasets: List of dataset names to evaluate. out_dir: Output directory for results. num_repeats: Number of times to repeat inference (for CI estimation). confidence_level: Confidence level for CI calculation. violin_plot_metrics: Metrics to include in violin plots. log_exp_name: Whether to include experiment name in output paths. clean_up_disk: Whether to clean up output directory after completion. skip_evaluation: Whether to skip evaluation (inference only mode). Returns: Tuple of (mean CER across datasets, mean SSIM across datasets). """ if violin_plot_metrics is None: violin_plot_metrics = list(DEFAULT_VIOLIN_METRICS) # Remove UTMOSv2 from plots if disabled if not eval_config.with_utmosv2 and 'utmosv2' in violin_plot_metrics: violin_plot_metrics.remove('utmosv2') # Load model model, checkpoint_name = load_magpie_model(model_config) # Add experiment name prefix if requested if log_exp_name and model_config.checkpoint_file: exp_name = get_experiment_name_from_checkpoint_path(model_config.checkpoint_file) checkpoint_name = f"{exp_name}__{checkpoint_name}" # Build full checkpoint identifier full_checkpoint_name = f"{checkpoint_name}_{inference_config.build_identifier()}_SV_{eval_config.sv_model}" # Create inference runner runner = MagpieInferenceRunner(model, inference_config) # Tracking metrics across datasets dataset_meta_info = evalset_config.dataset_meta_info ssim_per_dataset = [] cer_per_dataset = [] all_datasets_filewise_metrics = {} # CSV headers csv_header = ( "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative," "wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg," "ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate," "ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative," "utmosv2_avg,total_gen_audio_seconds" ) for dataset in datasets: logger.info(f"Processing dataset: {dataset}") if dataset not in dataset_meta_info: logger.warning(f"Dataset '{dataset}' not found in evalset_config, skipping.") continue meta = dataset_meta_info[dataset] manifest_records = read_manifest(meta['manifest_path']) language = meta.get('whisper_language', 'en') # Prepare dataset metadata (remove evaluation-specific keys) dataset_meta_for_dl = copy.deepcopy(meta) for key in ["whisper_language", "load_cached_codes_if_available"]: dataset_meta_for_dl.pop(key, None) # Setup output directories eval_dir = os.path.join(out_dir, f"{full_checkpoint_name}_{dataset}") audio_dir = os.path.join(eval_dir, "audio") os.makedirs(eval_dir, exist_ok=True) # Setup CSV files per_run_csv = os.path.join(eval_dir, "all_experiment_metrics.csv") write_csv_header_if_needed(per_run_csv, csv_header) metrics_all_repeats = [] filewise_metrics_all_repeats = [] for repeat_idx in range(num_repeats): logger.info(f"Repeat {repeat_idx + 1}/{num_repeats} for dataset {dataset}") repeat_audio_dir = os.path.join(audio_dir, f"repeat_{repeat_idx}") os.makedirs(repeat_audio_dir, exist_ok=True) # Create dataset and run inference test_dataset = runner.create_dataset({dataset: dataset_meta_for_dl}) if len(test_dataset) != len(manifest_records): raise ValueError( f"Dataset length mismatch: {len(test_dataset)} vs {len(manifest_records)} manifest records" ) rtf_metrics_list, generated_paths = runner.run_inference_on_dataset( dataset=test_dataset, output_dir=repeat_audio_dir, manifest_records=manifest_records, audio_base_dir=meta['audio_dir'], save_cross_attention_maps=True, save_context_audio=(repeat_idx == 0), # Only save context audio once ) # Compute mean RTF metrics mean_rtf = runner.compute_mean_rtf_metrics(rtf_metrics_list) with open(os.path.join(eval_dir, f"{dataset}_rtf_metrics_{repeat_idx}.json"), "w") as f: json.dump(mean_rtf, f, indent=4) if skip_evaluation: logger.info("Skipping evaluation as requested.") continue # Run evaluation eval_config_for_dataset = EvaluationConfig( sv_model=eval_config.sv_model, asr_model_name=eval_config.asr_model_name, language=language, with_utmosv2=eval_config.with_utmosv2, ) metrics, filewise_metrics = evaluate_generated_audio_dir( manifest_path=meta['manifest_path'], audio_dir=meta['audio_dir'], generated_audio_dir=repeat_audio_dir, config=eval_config_for_dataset, ) metrics_all_repeats.append(metrics) filewise_metrics_all_repeats.extend(filewise_metrics) # Save metrics with open(os.path.join(eval_dir, f"{dataset}_metrics_{repeat_idx}.json"), "w") as f: json.dump(metrics, f, indent=4) with open(os.path.join(eval_dir, f"{dataset}_filewise_metrics_{repeat_idx}.json"), "w") as f: json.dump(filewise_metrics, f, indent=4) # Append to per-run CSV append_metrics_to_csv(per_run_csv, full_checkpoint_name, dataset, metrics) # Create violin plot for this repeat violin_path = Path(eval_dir) / f"{dataset}_violin_{repeat_idx}.png" create_violin_plot(filewise_metrics, violin_plot_metrics, violin_path) if skip_evaluation or not metrics_all_repeats: continue # Store for combined plot all_datasets_filewise_metrics[dataset] = filewise_metrics_all_repeats # Compute mean with confidence interval across repeats metrics_mean_ci = compute_mean_with_confidence_interval( metrics_all_repeats, STANDARD_METRIC_KEYS, confidence=confidence_level, ) formatted_metrics_mean_ci = create_formatted_metrics_mean_ci(metrics_mean_ci) # Write to aggregated CSV ci_csv = os.path.join(out_dir, "all_experiment_metrics_with_ci.csv") write_csv_header_if_needed(ci_csv, csv_header) append_metrics_to_csv(ci_csv, full_checkpoint_name, dataset, formatted_metrics_mean_ci) # Track per-dataset means ssim_values = [m['ssim_pred_context_avg'] for m in metrics_all_repeats] cer_values = [m['cer_cumulative'] for m in metrics_all_repeats] ssim_per_dataset.append(np.mean(ssim_values)) cer_per_dataset.append(np.mean(cer_values)) # Create combined plot if we have multiple datasets if len(all_datasets_filewise_metrics) > 1: combined_plot_path = os.path.join(out_dir, f"{full_checkpoint_name}_combined_violin_plot.png") create_combined_box_plot(all_datasets_filewise_metrics, violin_plot_metrics, combined_plot_path) # Clean up if requested if clean_up_disk: logger.info(f"Cleaning up output directory: {out_dir}") shutil.rmtree(out_dir) # Return averaged metrics if ssim_per_dataset and cer_per_dataset: return np.mean(cer_per_dataset), np.mean(ssim_per_dataset) return None, None def create_argument_parser() -> argparse.ArgumentParser: """Create the CLI argument parser.""" parser = argparse.ArgumentParser( description='MagpieTTS Inference and Evaluation', formatter_class=argparse.RawDescriptionHelpFormatter, epilog=__doc__, ) # Model loading arguments model_group = parser.add_argument_group('Model Loading') model_group.add_argument( '--hparams_files', type=str, default=None, help='Comma-separated paths to hparams.yaml files (use with --checkpoint_files)', ) model_group.add_argument( '--checkpoint_files', type=str, default=None, help='Comma-separated paths to .ckpt files (use with --hparams_files)', ) model_group.add_argument( '--nemo_files', type=str, default=None, help='Comma-separated paths to .nemo files (alternative to hparams + checkpoint)', ) model_group.add_argument( '--codecmodel_path', type=str, required=True, help='Path to the audio codec model', ) model_group.add_argument( '--hparams_file_from_wandb', action='store_true', help='Set if hparams file was exported from wandb', ) model_group.add_argument( '--legacy_codebooks', action='store_true', help='Use legacy codebook indices (for old checkpoints)', ) model_group.add_argument( '--legacy_text_conditioning', action='store_true', help='Use legacy text conditioning (for old checkpoints)', ) # Dataset and output arguments data_group = parser.add_argument_group('Dataset and Output') data_group.add_argument( '--datasets', type=str, default=None, help=f'Comma-separated dataset names (default: {EVALUATION_DATASETS})', ) data_group.add_argument( '--out_dir', type=str, required=True, help='Output directory for generated audio and metrics', ) data_group.add_argument( '--log_exp_name', action='store_true', help='Include experiment name in output folder name', ) data_group.add_argument( '--clean_up_disk', action='store_true', help='Delete output directory after completion', ) # Inference arguments infer_group = parser.add_argument_group('Inference Parameters') infer_group.add_argument('--temperature', type=float, default=0.6) infer_group.add_argument('--topk', type=int, default=80) infer_group.add_argument('--batch_size', type=int, default=32) infer_group.add_argument('--use_cfg', action='store_true', help='Enable classifier-free guidance') infer_group.add_argument('--cfg_scale', type=float, default=2.5) # Attention prior arguments prior_group = parser.add_argument_group('Attention Prior') prior_group.add_argument('--apply_attention_prior', action='store_true') prior_group.add_argument('--attention_prior_epsilon', type=float, default=0.1) prior_group.add_argument('--attention_prior_lookahead_window', type=int, default=5) prior_group.add_argument( '--estimate_alignment_from_layers', type=str, default=None, help='Comma-separated layer indices for alignment estimation', ) prior_group.add_argument( '--apply_prior_to_layers', type=str, default=None, help='Comma-separated layer indices to apply prior', ) prior_group.add_argument('--start_prior_after_n_audio_steps', type=int, default=0) # Local transformer / MaskGit arguments lt_group = parser.add_argument_group('Local Transformer / MaskGit') lt_group.add_argument('--use_local_transformer', action='store_true') lt_group.add_argument('--maskgit_n_steps', type=int, default=3) lt_group.add_argument('--maskgit_noise_scale', type=float, default=0.0) lt_group.add_argument('--maskgit_fixed_schedule', type=int, nargs='+', default=None) lt_group.add_argument( '--maskgit_sampling_type', default=None, choices=["default", "causal", "purity_causal", "purity_default"], ) # EOS detection eos_group = parser.add_argument_group('EOS Detection') eos_group.add_argument( '--eos_detection_method', type=str, default="argmax_or_multinomial_any", choices=[ "argmax_any", "argmax_or_multinomial_any", "argmax_all", "argmax_or_multinomial_all", "argmax_zero_cb", "argmax_or_multinomial_zero_cb", ], ) eos_group.add_argument('--ignore_finished_sentence_tracking', action='store_true') # Evaluation arguments eval_group = parser.add_argument_group('Evaluation') eval_group.add_argument( '--run_evaluation', action='store_true', help='Run evaluation after inference (default: False, inference only)', ) eval_group.add_argument('--sv_model', type=str, default="titanet", choices=["titanet", "wavlm"]) eval_group.add_argument('--asr_model_name', type=str, default="nvidia/parakeet-tdt-1.1b") eval_group.add_argument('--num_repeats', type=int, default=1) eval_group.add_argument('--confidence_level', type=float, default=0.95) eval_group.add_argument('--disable_utmosv2', action='store_true') eval_group.add_argument( '--violin_plot_metrics', type=str, nargs='*', default=['cer', 'pred_context_ssim', 'utmosv2'], ) # Quality targets (for CI/CD) target_group = parser.add_argument_group('Quality Targets') target_group.add_argument('--cer_target', type=float, default=None) target_group.add_argument('--ssim_target', type=float, default=None) return parser def main(): """Main entry point.""" parser = create_argument_parser() args = parser.parse_args() # Set default datasets if not provided if args.datasets is None: args.datasets = EVALUATION_DATASETS datasets = args.datasets.split(",") # Determine mode and validate has_checkpoint_mode = ( args.hparams_files is not None and args.checkpoint_files is not None and args.hparams_files != "null" and args.checkpoint_files != "null" ) has_nemo_mode = args.nemo_files is not None and args.nemo_files != "null" if not has_checkpoint_mode and not has_nemo_mode: parser.error("You must provide either:\n" " 1. --hparams_files and --checkpoint_files\n" " 2. --nemo_files") # Build configurations inference_config = InferenceConfig( temperature=args.temperature, topk=args.topk, batch_size=args.batch_size, use_cfg=args.use_cfg, cfg_scale=args.cfg_scale, apply_attention_prior=args.apply_attention_prior, attention_prior_epsilon=args.attention_prior_epsilon, attention_prior_lookahead_window=args.attention_prior_lookahead_window, estimate_alignment_from_layers=parse_layer_list(args.estimate_alignment_from_layers), apply_prior_to_layers=parse_layer_list(args.apply_prior_to_layers), start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, use_local_transformer=args.use_local_transformer, maskgit_n_steps=args.maskgit_n_steps, maskgit_noise_scale=args.maskgit_noise_scale, maskgit_fixed_schedule=args.maskgit_fixed_schedule, maskgit_sampling_type=args.maskgit_sampling_type, eos_detection_method=args.eos_detection_method, ignore_finished_sentence_tracking=args.ignore_finished_sentence_tracking, ) eval_config = EvaluationConfig( sv_model=args.sv_model, asr_model_name=args.asr_model_name, with_utmosv2=not args.disable_utmosv2, ) cer, ssim = None, None # Run for each model (checkpoint or nemo) if has_checkpoint_mode: hparam_files = args.hparams_files.split(",") checkpoint_files = args.checkpoint_files.split(",") if len(hparam_files) != len(checkpoint_files): parser.error("Number of hparams_files must match number of checkpoint_files") for hparams_file, checkpoint_file in zip(hparam_files, checkpoint_files): logger.info(f"Processing checkpoint: {checkpoint_file}") model_config = ModelLoadConfig( hparams_file=hparams_file, checkpoint_file=checkpoint_file, codecmodel_path=args.codecmodel_path, legacy_codebooks=args.legacy_codebooks, legacy_text_conditioning=args.legacy_text_conditioning, hparams_from_wandb=args.hparams_file_from_wandb, ) cer, ssim = run_inference_and_evaluation( model_config=model_config, inference_config=inference_config, eval_config=eval_config, datasets=datasets, out_dir=args.out_dir, num_repeats=args.num_repeats, confidence_level=args.confidence_level, violin_plot_metrics=args.violin_plot_metrics, log_exp_name=args.log_exp_name, clean_up_disk=args.clean_up_disk, skip_evaluation=not args.run_evaluation, ) else: # nemo mode for nemo_file in args.nemo_files.split(","): logger.info(f"Processing NeMo file: {nemo_file}") model_config = ModelLoadConfig( nemo_file=nemo_file, codecmodel_path=args.codecmodel_path, legacy_codebooks=args.legacy_codebooks, legacy_text_conditioning=args.legacy_text_conditioning, ) cer, ssim = run_inference_and_evaluation( model_config=model_config, inference_config=inference_config, eval_config=eval_config, datasets=datasets, out_dir=args.out_dir, num_repeats=args.num_repeats, confidence_level=args.confidence_level, violin_plot_metrics=args.violin_plot_metrics, log_exp_name=args.log_exp_name, clean_up_disk=args.clean_up_disk, skip_evaluation=not args.run_evaluation, ) # Check quality targets if cer is not None and args.cer_target is not None: if cer > args.cer_target: raise ValueError(f"CER {cer:.4f} exceeds target {args.cer_target:.4f}") logger.info(f"CER {cer:.4f} meets target {args.cer_target:.4f}") if ssim is not None and args.ssim_target is not None: if ssim < args.ssim_target: raise ValueError(f"SSIM {ssim:.4f} below target {args.ssim_target:.4f}") logger.info(f"SSIM {ssim:.4f} meets target {args.ssim_target:.4f}") logger.info("Inference and evaluation completed successfully.") if __name__ == '__main__': main()