# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """ | |
| MagpieTTS Inference and Evaluation Script. | |
| This script provides a clean CLI for running MagpieTTS inference with optional evaluation. | |
| It decouples inference and evaluation into separate modules for better maintainability. | |
| Example usage: | |
| # Inference only (from .nemo file) - default behavior | |
| python examples/tts/magpietts_inference.py \\ | |
| --nemo_files /path/to/model.nemo \\ | |
| --datasets libritts_test_clean \\ | |
| --out_dir /path/to/output \\ | |
| --codecmodel_path /path/to/codec.nemo | |
| # Inference with evaluation (from checkpoint) | |
| python examples/tts/magpietts_inference.py \\ | |
| --hparams_files /path/to/hparams.yaml \\ | |
| --checkpoint_files /path/to/model.ckpt \\ | |
| --datasets libritts_test_clean,vctk \\ | |
| --out_dir /path/to/output \\ | |
| --codecmodel_path /path/to/codec.nemo \\ | |
| --run_evaluation \\ | |
| --num_repeats 3 | |
| """ | |
from __future__ import annotations

import argparse
import copy
import json
import logging
import os
import shutil
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np

# Import dataset configuration
import nemo.collections.tts.modules.magpietts_inference.evalset_config as evalset_config
from nemo.collections.asr.parts.utils.manifest_utils import read_manifest

# Import the modular components
from nemo.collections.tts.modules.magpietts_inference.evaluation import (
    DEFAULT_VIOLIN_METRICS,
    STANDARD_METRIC_KEYS,
    EvaluationConfig,
    compute_mean_with_confidence_interval,
    evaluate_generated_audio_dir,
)
from nemo.collections.tts.modules.magpietts_inference.inference import InferenceConfig, MagpieInferenceRunner
from nemo.collections.tts.modules.magpietts_inference.utils import (
    ModelLoadConfig,
    get_experiment_name_from_checkpoint_path,
    load_magpie_model,
)
from nemo.collections.tts.modules.magpietts_inference.visualization import create_combined_box_plot, create_violin_plot

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
# Default evaluation datasets
EVALUATION_DATASETS = (
    "riva_hard_digits,riva_hard_letters,riva_hard_money,riva_hard_short,vctk,libritts_seen,libritts_test_clean"
)

def parse_layer_list(layer_str: Optional[str]) -> Optional[List[int]]:
    """Parse a comma-separated list of layer indices."""
    if layer_str is None:
        return None
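    # e.g. "0,3,5" -> [0, 3, 5]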
    return [int(layer.strip()) for layer in layer_str.split(",")]

def write_csv_header_if_needed(csv_path: str, header: str) -> None:
    """Write CSV header if file doesn't exist."""
    if not os.path.exists(csv_path):
        with open(csv_path, "w") as f:
            f.write(header + "\n")

def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, metrics: dict) -> None:
    """Append metrics to a CSV file."""
    values = [
        checkpoint_name,
        dataset,
        metrics.get('cer_filewise_avg', ''),
        metrics.get('wer_filewise_avg', ''),
        metrics.get('cer_cumulative', ''),
        metrics.get('wer_cumulative', ''),
        metrics.get('ssim_pred_gt_avg', ''),
        metrics.get('ssim_pred_context_avg', ''),
        metrics.get('ssim_gt_context_avg', ''),
        metrics.get('ssim_pred_gt_avg_alternate', ''),
        metrics.get('ssim_pred_context_avg_alternate', ''),
        metrics.get('ssim_gt_context_avg_alternate', ''),
        metrics.get('cer_gt_audio_cumulative', ''),
        metrics.get('wer_gt_audio_cumulative', ''),
        metrics.get('utmosv2_avg', ''),
        metrics.get('total_gen_audio_seconds', ''),
    ]
    with open(csv_path, "a") as f:
        f.write(",".join(str(v) for v in values) + "\n")
    logger.info(f"Metrics appended to: {csv_path}")

def create_formatted_metrics_mean_ci(metrics_mean_ci: dict) -> dict:
    """Format each [mean, ci] metric pair as a 'mean ± ci' string (in place) and log it."""
    for k, v in metrics_mean_ci.items():
        if isinstance(v, list):
            mean, ci = float(v[0]), float(v[1])
            logger.info(f"Metric {k}: {mean:.4f} ± {ci:.4f}")
            metrics_mean_ci[k] = f"{mean:.4f} ± {ci:.4f}"
    return metrics_mean_ci

def run_inference_and_evaluation(
    model_config: ModelLoadConfig,
    inference_config: InferenceConfig,
    eval_config: EvaluationConfig,
    datasets: List[str],
    out_dir: str,
    num_repeats: int = 1,
    confidence_level: float = 0.95,
    violin_plot_metrics: Optional[List[str]] = None,
    log_exp_name: bool = False,
    clean_up_disk: bool = False,
    skip_evaluation: bool = False,
) -> Tuple[Optional[float], Optional[float]]:
    """Run inference and optional evaluation on specified datasets.

    Args:
        model_config: Configuration for loading the model.
        inference_config: Configuration for inference.
        eval_config: Configuration for evaluation.
        datasets: List of dataset names to evaluate.
        out_dir: Output directory for results.
        num_repeats: Number of times to repeat inference (for CI estimation).
        confidence_level: Confidence level for CI calculation.
        violin_plot_metrics: Metrics to include in violin plots.
        log_exp_name: Whether to include experiment name in output paths.
        clean_up_disk: Whether to clean up output directory after completion.
        skip_evaluation: Whether to skip evaluation (inference only mode).

    Returns:
        Tuple of (mean CER across datasets, mean SSIM across datasets).
    """
    if violin_plot_metrics is None:
        violin_plot_metrics = list(DEFAULT_VIOLIN_METRICS)

    # Remove UTMOSv2 from plots if disabled
    if not eval_config.with_utmosv2 and 'utmosv2' in violin_plot_metrics:
        violin_plot_metrics.remove('utmosv2')

    # Load model
    model, checkpoint_name = load_magpie_model(model_config)

    # Add experiment name prefix if requested
    if log_exp_name and model_config.checkpoint_file:
        exp_name = get_experiment_name_from_checkpoint_path(model_config.checkpoint_file)
        checkpoint_name = f"{exp_name}__{checkpoint_name}"

    # Build full checkpoint identifier
    full_checkpoint_name = f"{checkpoint_name}_{inference_config.build_identifier()}_SV_{eval_config.sv_model}"

    # Create inference runner
    runner = MagpieInferenceRunner(model, inference_config)

    # Tracking metrics across datasets
    dataset_meta_info = evalset_config.dataset_meta_info
    ssim_per_dataset = []
    cer_per_dataset = []
    all_datasets_filewise_metrics = {}

    # CSV headers
    csv_header = (
        "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative,"
        "wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,"
        "ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,"
        "ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative,"
        "utmosv2_avg,total_gen_audio_seconds"
    )
    for dataset in datasets:
        logger.info(f"Processing dataset: {dataset}")
        if dataset not in dataset_meta_info:
            logger.warning(f"Dataset '{dataset}' not found in evalset_config, skipping.")
            continue

        meta = dataset_meta_info[dataset]
        manifest_records = read_manifest(meta['manifest_path'])
        language = meta.get('whisper_language', 'en')

        # Prepare dataset metadata (remove evaluation-specific keys)
        dataset_meta_for_dl = copy.deepcopy(meta)
        for key in ["whisper_language", "load_cached_codes_if_available"]:
            dataset_meta_for_dl.pop(key, None)

        # Setup output directories
        eval_dir = os.path.join(out_dir, f"{full_checkpoint_name}_{dataset}")
        audio_dir = os.path.join(eval_dir, "audio")
        os.makedirs(eval_dir, exist_ok=True)

        # Setup CSV files
        per_run_csv = os.path.join(eval_dir, "all_experiment_metrics.csv")
        write_csv_header_if_needed(per_run_csv, csv_header)

        metrics_all_repeats = []
        filewise_metrics_all_repeats = []

        for repeat_idx in range(num_repeats):
            logger.info(f"Repeat {repeat_idx + 1}/{num_repeats} for dataset {dataset}")
            repeat_audio_dir = os.path.join(audio_dir, f"repeat_{repeat_idx}")
            os.makedirs(repeat_audio_dir, exist_ok=True)
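            # Each repeat writes into its own audio subdirectory so repeats can be evaluated independently.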
            # Create dataset and run inference
            test_dataset = runner.create_dataset({dataset: dataset_meta_for_dl})
            if len(test_dataset) != len(manifest_records):
                raise ValueError(
                    f"Dataset length mismatch: {len(test_dataset)} vs {len(manifest_records)} manifest records"
                )

            rtf_metrics_list, generated_paths = runner.run_inference_on_dataset(
                dataset=test_dataset,
                output_dir=repeat_audio_dir,
                manifest_records=manifest_records,
                audio_base_dir=meta['audio_dir'],
                save_cross_attention_maps=True,
                save_context_audio=(repeat_idx == 0),  # Only save context audio once
            )
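            # Note: generated_paths is not consumed below; evaluation re-reads the generated audio from repeat_audio_dir.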
            # Compute mean RTF metrics
            mean_rtf = runner.compute_mean_rtf_metrics(rtf_metrics_list)
            with open(os.path.join(eval_dir, f"{dataset}_rtf_metrics_{repeat_idx}.json"), "w") as f:
                json.dump(mean_rtf, f, indent=4)

            if skip_evaluation:
                logger.info("Skipping evaluation as requested.")
                continue

            # Run evaluation
            eval_config_for_dataset = EvaluationConfig(
                sv_model=eval_config.sv_model,
                asr_model_name=eval_config.asr_model_name,
                language=language,
                with_utmosv2=eval_config.with_utmosv2,
            )
            metrics, filewise_metrics = evaluate_generated_audio_dir(
                manifest_path=meta['manifest_path'],
                audio_dir=meta['audio_dir'],
                generated_audio_dir=repeat_audio_dir,
                config=eval_config_for_dataset,
            )
            metrics_all_repeats.append(metrics)
            filewise_metrics_all_repeats.extend(filewise_metrics)

            # Save metrics
            with open(os.path.join(eval_dir, f"{dataset}_metrics_{repeat_idx}.json"), "w") as f:
                json.dump(metrics, f, indent=4)
            with open(os.path.join(eval_dir, f"{dataset}_filewise_metrics_{repeat_idx}.json"), "w") as f:
                json.dump(filewise_metrics, f, indent=4)

            # Append to per-run CSV
            append_metrics_to_csv(per_run_csv, full_checkpoint_name, dataset, metrics)

            # Create violin plot for this repeat
            violin_path = Path(eval_dir) / f"{dataset}_violin_{repeat_idx}.png"
            create_violin_plot(filewise_metrics, violin_plot_metrics, violin_path)
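        # Aggregate across repeats: mean ± CI per metric, then per-dataset CER/SSIM summaries.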
        if skip_evaluation or not metrics_all_repeats:
            continue

        # Store for combined plot
        all_datasets_filewise_metrics[dataset] = filewise_metrics_all_repeats

        # Compute mean with confidence interval across repeats
        metrics_mean_ci = compute_mean_with_confidence_interval(
            metrics_all_repeats,
            STANDARD_METRIC_KEYS,
            confidence=confidence_level,
        )
        formatted_metrics_mean_ci = create_formatted_metrics_mean_ci(metrics_mean_ci)

        # Write to aggregated CSV
        ci_csv = os.path.join(out_dir, "all_experiment_metrics_with_ci.csv")
        write_csv_header_if_needed(ci_csv, csv_header)
        append_metrics_to_csv(ci_csv, full_checkpoint_name, dataset, formatted_metrics_mean_ci)

        # Track per-dataset means
        ssim_values = [m['ssim_pred_context_avg'] for m in metrics_all_repeats]
        cer_values = [m['cer_cumulative'] for m in metrics_all_repeats]
        ssim_per_dataset.append(np.mean(ssim_values))
        cer_per_dataset.append(np.mean(cer_values))

    # Create combined plot if we have multiple datasets
    if len(all_datasets_filewise_metrics) > 1:
        combined_plot_path = os.path.join(out_dir, f"{full_checkpoint_name}_combined_violin_plot.png")
        create_combined_box_plot(all_datasets_filewise_metrics, violin_plot_metrics, combined_plot_path)

    # Clean up if requested
    if clean_up_disk:
        logger.info(f"Cleaning up output directory: {out_dir}")
        shutil.rmtree(out_dir)

    # Return averaged metrics
    if ssim_per_dataset and cer_per_dataset:
        return np.mean(cer_per_dataset), np.mean(ssim_per_dataset)
    return None, None

def create_argument_parser() -> argparse.ArgumentParser:
    """Create the CLI argument parser."""
    parser = argparse.ArgumentParser(
        description='MagpieTTS Inference and Evaluation',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # Model loading arguments
    model_group = parser.add_argument_group('Model Loading')
    model_group.add_argument(
        '--hparams_files',
        type=str,
        default=None,
        help='Comma-separated paths to hparams.yaml files (use with --checkpoint_files)',
    )
    model_group.add_argument(
        '--checkpoint_files',
        type=str,
        default=None,
        help='Comma-separated paths to .ckpt files (use with --hparams_files)',
    )
    model_group.add_argument(
        '--nemo_files',
        type=str,
        default=None,
        help='Comma-separated paths to .nemo files (alternative to hparams + checkpoint)',
    )
    model_group.add_argument(
        '--codecmodel_path',
        type=str,
        required=True,
        help='Path to the audio codec model',
    )
    model_group.add_argument(
        '--hparams_file_from_wandb',
        action='store_true',
        help='Set if hparams file was exported from wandb',
    )
    model_group.add_argument(
        '--legacy_codebooks',
        action='store_true',
        help='Use legacy codebook indices (for old checkpoints)',
    )
    model_group.add_argument(
        '--legacy_text_conditioning',
        action='store_true',
        help='Use legacy text conditioning (for old checkpoints)',
    )

    # Dataset and output arguments
    data_group = parser.add_argument_group('Dataset and Output')
    data_group.add_argument(
        '--datasets',
        type=str,
        default=None,
        help=f'Comma-separated dataset names (default: {EVALUATION_DATASETS})',
    )
    data_group.add_argument(
        '--out_dir',
        type=str,
        required=True,
        help='Output directory for generated audio and metrics',
    )
    data_group.add_argument(
        '--log_exp_name',
        action='store_true',
        help='Include experiment name in output folder name',
    )
    data_group.add_argument(
        '--clean_up_disk',
        action='store_true',
        help='Delete output directory after completion',
    )

    # Inference arguments
    infer_group = parser.add_argument_group('Inference Parameters')
    infer_group.add_argument('--temperature', type=float, default=0.6)
    infer_group.add_argument('--topk', type=int, default=80)
    infer_group.add_argument('--batch_size', type=int, default=32)
    infer_group.add_argument('--use_cfg', action='store_true', help='Enable classifier-free guidance')
    infer_group.add_argument('--cfg_scale', type=float, default=2.5)
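    # Note: --cfg_scale is only expected to take effect when --use_cfg is enabled.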
    # Attention prior arguments
    prior_group = parser.add_argument_group('Attention Prior')
    prior_group.add_argument('--apply_attention_prior', action='store_true')
    prior_group.add_argument('--attention_prior_epsilon', type=float, default=0.1)
    prior_group.add_argument('--attention_prior_lookahead_window', type=int, default=5)
    prior_group.add_argument(
        '--estimate_alignment_from_layers',
        type=str,
        default=None,
        help='Comma-separated layer indices for alignment estimation',
    )
    prior_group.add_argument(
        '--apply_prior_to_layers',
        type=str,
        default=None,
        help='Comma-separated layer indices to apply prior',
    )
    prior_group.add_argument('--start_prior_after_n_audio_steps', type=int, default=0)

    # Local transformer / MaskGit arguments
    lt_group = parser.add_argument_group('Local Transformer / MaskGit')
    lt_group.add_argument('--use_local_transformer', action='store_true')
    lt_group.add_argument('--maskgit_n_steps', type=int, default=3)
    lt_group.add_argument('--maskgit_noise_scale', type=float, default=0.0)
    lt_group.add_argument('--maskgit_fixed_schedule', type=int, nargs='+', default=None)
    lt_group.add_argument(
        '--maskgit_sampling_type',
        default=None,
        choices=["default", "causal", "purity_causal", "purity_default"],
    )

    # EOS detection
    eos_group = parser.add_argument_group('EOS Detection')
    eos_group.add_argument(
        '--eos_detection_method',
        type=str,
        default="argmax_or_multinomial_any",
        choices=[
            "argmax_any",
            "argmax_or_multinomial_any",
            "argmax_all",
            "argmax_or_multinomial_all",
            "argmax_zero_cb",
            "argmax_or_multinomial_zero_cb",
        ],
    )
    eos_group.add_argument('--ignore_finished_sentence_tracking', action='store_true')

    # Evaluation arguments
    eval_group = parser.add_argument_group('Evaluation')
    eval_group.add_argument(
        '--run_evaluation',
        action='store_true',
        help='Run evaluation after inference (default: False, inference only)',
    )
    eval_group.add_argument('--sv_model', type=str, default="titanet", choices=["titanet", "wavlm"])
    eval_group.add_argument('--asr_model_name', type=str, default="nvidia/parakeet-tdt-1.1b")
    eval_group.add_argument('--num_repeats', type=int, default=1)
    eval_group.add_argument('--confidence_level', type=float, default=0.95)
    eval_group.add_argument('--disable_utmosv2', action='store_true')
    eval_group.add_argument(
        '--violin_plot_metrics',
        type=str,
        nargs='*',
        default=['cer', 'pred_context_ssim', 'utmosv2'],
    )

    # Quality targets (for CI/CD)
    target_group = parser.add_argument_group('Quality Targets')
    target_group.add_argument('--cer_target', type=float, default=None)
    target_group.add_argument('--ssim_target', type=float, default=None)

    return parser

def main():
    """Main entry point."""
    parser = create_argument_parser()
    args = parser.parse_args()

    # Set default datasets if not provided
    if args.datasets is None:
        args.datasets = EVALUATION_DATASETS
    datasets = args.datasets.split(",")

    # Determine mode and validate
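    # The literal string "null" is treated the same as an unset flag (presumably so wrapper scripts can always pass these arguments).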
    has_checkpoint_mode = (
        args.hparams_files is not None
        and args.checkpoint_files is not None
        and args.hparams_files != "null"
        and args.checkpoint_files != "null"
    )
    has_nemo_mode = args.nemo_files is not None and args.nemo_files != "null"
    if not has_checkpoint_mode and not has_nemo_mode:
        parser.error(
            "You must provide either:\n"
            " 1. --hparams_files and --checkpoint_files\n"
            " 2. --nemo_files"
        )

    # Build configurations
    inference_config = InferenceConfig(
        temperature=args.temperature,
        topk=args.topk,
        batch_size=args.batch_size,
        use_cfg=args.use_cfg,
        cfg_scale=args.cfg_scale,
        apply_attention_prior=args.apply_attention_prior,
        attention_prior_epsilon=args.attention_prior_epsilon,
        attention_prior_lookahead_window=args.attention_prior_lookahead_window,
        estimate_alignment_from_layers=parse_layer_list(args.estimate_alignment_from_layers),
        apply_prior_to_layers=parse_layer_list(args.apply_prior_to_layers),
        start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps,
        use_local_transformer=args.use_local_transformer,
        maskgit_n_steps=args.maskgit_n_steps,
        maskgit_noise_scale=args.maskgit_noise_scale,
        maskgit_fixed_schedule=args.maskgit_fixed_schedule,
        maskgit_sampling_type=args.maskgit_sampling_type,
        eos_detection_method=args.eos_detection_method,
        ignore_finished_sentence_tracking=args.ignore_finished_sentence_tracking,
    )
    eval_config = EvaluationConfig(
        sv_model=args.sv_model,
        asr_model_name=args.asr_model_name,
        with_utmosv2=not args.disable_utmosv2,
    )
    cer, ssim = None, None

    # Run for each model (checkpoint or nemo)
    if has_checkpoint_mode:
        hparam_files = args.hparams_files.split(",")
        checkpoint_files = args.checkpoint_files.split(",")
        if len(hparam_files) != len(checkpoint_files):
            parser.error("Number of hparams_files must match number of checkpoint_files")

        for hparams_file, checkpoint_file in zip(hparam_files, checkpoint_files):
            logger.info(f"Processing checkpoint: {checkpoint_file}")
            model_config = ModelLoadConfig(
                hparams_file=hparams_file,
                checkpoint_file=checkpoint_file,
                codecmodel_path=args.codecmodel_path,
                legacy_codebooks=args.legacy_codebooks,
                legacy_text_conditioning=args.legacy_text_conditioning,
                hparams_from_wandb=args.hparams_file_from_wandb,
            )
            cer, ssim = run_inference_and_evaluation(
                model_config=model_config,
                inference_config=inference_config,
                eval_config=eval_config,
                datasets=datasets,
                out_dir=args.out_dir,
                num_repeats=args.num_repeats,
                confidence_level=args.confidence_level,
                violin_plot_metrics=args.violin_plot_metrics,
                log_exp_name=args.log_exp_name,
                clean_up_disk=args.clean_up_disk,
                skip_evaluation=not args.run_evaluation,
            )
    else:  # nemo mode
        for nemo_file in args.nemo_files.split(","):
            logger.info(f"Processing NeMo file: {nemo_file}")
            model_config = ModelLoadConfig(
                nemo_file=nemo_file,
                codecmodel_path=args.codecmodel_path,
                legacy_codebooks=args.legacy_codebooks,
                legacy_text_conditioning=args.legacy_text_conditioning,
            )
            cer, ssim = run_inference_and_evaluation(
                model_config=model_config,
                inference_config=inference_config,
                eval_config=eval_config,
                datasets=datasets,
                out_dir=args.out_dir,
                num_repeats=args.num_repeats,
                confidence_level=args.confidence_level,
                violin_plot_metrics=args.violin_plot_metrics,
                log_exp_name=args.log_exp_name,
                clean_up_disk=args.clean_up_disk,
                skip_evaluation=not args.run_evaluation,
            )

    # Check quality targets
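    # Note: when multiple checkpoints or .nemo files are given, cer/ssim reflect the last model processed.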
    if cer is not None and args.cer_target is not None:
        if cer > args.cer_target:
            raise ValueError(f"CER {cer:.4f} exceeds target {args.cer_target:.4f}")
        logger.info(f"CER {cer:.4f} meets target {args.cer_target:.4f}")
    if ssim is not None and args.ssim_target is not None:
        if ssim < args.ssim_target:
            raise ValueError(f"SSIM {ssim:.4f} below target {args.ssim_target:.4f}")
        logger.info(f"SSIM {ssim:.4f} meets target {args.ssim_target:.4f}")

    logger.info("Inference and evaluation completed successfully.")


if __name__ == '__main__':
    main()