|
|
|
|
|
|
|
|
use crate::{Error, Result}; |
|
|
use ndarray::{Array1, Array2, Array, IxDyn}; |
|
|
use std::collections::HashMap; |
|
|
use std::path::Path; |
|
|
|
|
|
use super::OnnxSession; |
|
|
|
|
|
|
|
|
pub struct SpeakerEncoder { |
|
|
session: Option<OnnxSession>, |
|
|
embedding_dim: usize, |
|
|
} |
|
|
|
|
|
impl SpeakerEncoder { |
|
|
|
|
|
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { |
|
|
let session = OnnxSession::load(path)?; |
|
|
Ok(Self { |
|
|
session: Some(session), |
|
|
embedding_dim: 192, |
|
|
}) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn new_placeholder(embedding_dim: usize) -> Self { |
|
|
Self { |
|
|
session: None, |
|
|
embedding_dim, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn encode(&self, mel_spectrogram: &Array2<f32>) -> Result<Array1<f32>> { |
|
|
if let Some(ref session) = self.session { |
|
|
|
|
|
let input = mel_spectrogram |
|
|
.clone() |
|
|
.into_shape(IxDyn(&[1, mel_spectrogram.nrows(), mel_spectrogram.ncols()]))?; |
|
|
|
|
|
let mut inputs = HashMap::new(); |
|
|
inputs.insert("mel".to_string(), input); |
|
|
|
|
|
let outputs = session.run(inputs)?; |
|
|
|
|
|
let embedding = outputs |
|
|
.get("embedding") |
|
|
.ok_or_else(|| Error::Model("Missing embedding output".into()))?; |
|
|
|
|
|
|
|
|
let flat: Vec<f32> = embedding.iter().cloned().collect(); |
|
|
Ok(Array1::from_vec(flat)) |
|
|
} else { |
|
|
|
|
|
Ok(Array1::from_vec(vec![0.0f32; self.embedding_dim])) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn encode_audio(&self, audio_path: &str) -> Result<Array1<f32>> { |
|
|
use crate::audio::{compute_mel_from_file, AudioConfig}; |
|
|
|
|
|
let config = AudioConfig::default(); |
|
|
let mel = compute_mel_from_file(audio_path, &config)?; |
|
|
self.encode(&mel) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn embedding_dim(&self) -> usize { |
|
|
self.embedding_dim |
|
|
} |
|
|
|
|
|
|
|
|
pub fn normalize_embedding(&self, embedding: &Array1<f32>) -> Array1<f32> { |
|
|
let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt(); |
|
|
if norm > 1e-8 { |
|
|
embedding / norm |
|
|
} else { |
|
|
embedding.clone() |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn cosine_similarity(&self, emb1: &Array1<f32>, emb2: &Array1<f32>) -> f32 { |
|
|
let norm1 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt(); |
|
|
let norm2 = emb2.iter().map(|x| x * x).sum::<f32>().sqrt(); |
|
|
|
|
|
if norm1 < 1e-8 || norm2 < 1e-8 { |
|
|
return 0.0; |
|
|
} |
|
|
|
|
|
let dot: f32 = emb1.iter().zip(emb2.iter()).map(|(a, b)| a * b).sum(); |
|
|
dot / (norm1 * norm2) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub struct EmotionEncoder { |
|
|
|
|
|
emotion_matrix: Array2<f32>, |
|
|
|
|
|
num_dims: usize, |
|
|
|
|
|
dim_sizes: Vec<usize>, |
|
|
} |
|
|
|
|
|
impl EmotionEncoder { |
|
|
|
|
|
pub fn new(num_dims: usize, dim_sizes: Vec<usize>, embedding_dim: usize) -> Self { |
|
|
let total_emotions: usize = dim_sizes.iter().sum(); |
|
|
let emotion_matrix = Array2::zeros((total_emotions, embedding_dim)); |
|
|
|
|
|
Self { |
|
|
emotion_matrix, |
|
|
num_dims, |
|
|
dim_sizes, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { |
|
|
let path = path.as_ref(); |
|
|
if !path.exists() { |
|
|
return Err(Error::FileNotFound(path.display().to_string())); |
|
|
} |
|
|
|
|
|
|
|
|
let file_data = std::fs::read(path)?; |
|
|
let tensors = safetensors::SafeTensors::deserialize(&file_data) |
|
|
.map_err(|e| Error::ModelLoading(format!("Failed to load safetensors: {}", e)))?; |
|
|
|
|
|
|
|
|
let tensor = tensors |
|
|
.tensor("emotion_matrix") |
|
|
.map_err(|e| Error::ModelLoading(format!("Missing emotion_matrix: {}", e)))?; |
|
|
|
|
|
let shape = tensor.shape(); |
|
|
let data: Vec<f32> = tensor.data().chunks_exact(4).map(|b| { |
|
|
f32::from_le_bytes([b[0], b[1], b[2], b[3]]) |
|
|
}).collect(); |
|
|
if !tensor.data().chunks_exact(4).remainder().is_empty() { |
|
|
return Err(Error::ModelLoading("Tensor data length is not a multiple of 4".to_string())); |
|
|
} |
|
|
|
|
|
let emotion_matrix = Array2::from_shape_vec((shape[0], shape[1]), data) |
|
|
.map_err(|e| Error::ModelLoading(format!("Shape mismatch: {}", e)))?; |
|
|
|
|
|
|
|
|
let num_dims = 8; |
|
|
let dim_sizes = vec![5, 6, 8, 6, 5, 4, 7, 6]; |
|
|
|
|
|
Ok(Self { |
|
|
emotion_matrix, |
|
|
num_dims, |
|
|
dim_sizes, |
|
|
}) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn encode(&self, emotion_vector: &[f32]) -> Result<Array1<f32>> { |
|
|
if emotion_vector.len() != self.num_dims { |
|
|
return Err(Error::ShapeMismatch { |
|
|
expected: format!("{} dimensions", self.num_dims), |
|
|
actual: format!("{} dimensions", emotion_vector.len()), |
|
|
}); |
|
|
} |
|
|
|
|
|
let embedding_dim = self.emotion_matrix.ncols(); |
|
|
let mut embedding = vec![0.0f32; embedding_dim]; |
|
|
|
|
|
let mut offset = 0; |
|
|
for (WIN_LENGTH, (&value, &dim_size)) in emotion_vector.iter().zip(self.dim_sizes.iter()).enumerate() { |
|
|
|
|
|
let continuous_idx = value * (dim_size - 1) as f32; |
|
|
let lower_idx = continuous_idx.floor() as usize; |
|
|
let upper_idx = (lower_idx + 1).min(dim_size - 1); |
|
|
let alpha = continuous_idx - lower_idx as f32; |
|
|
|
|
|
|
|
|
for i in 0..embedding_dim { |
|
|
let lower_val = self.emotion_matrix[[offset + lower_idx, i]]; |
|
|
let upper_val = self.emotion_matrix[[offset + upper_idx, i]]; |
|
|
embedding[i] += lower_val * (1.0 - alpha) + upper_val * alpha; |
|
|
} |
|
|
|
|
|
offset += dim_size; |
|
|
} |
|
|
|
|
|
|
|
|
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt(); |
|
|
if norm > 1e-8 { |
|
|
for e in embedding.iter_mut() { |
|
|
*e /= norm; |
|
|
} |
|
|
} |
|
|
|
|
|
Ok(Array1::from_vec(embedding)) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn neutral(&self) -> Vec<f32> { |
|
|
vec![0.5f32; self.num_dims] |
|
|
} |
|
|
|
|
|
|
|
|
pub fn preset(&self, name: &str) -> Vec<f32> { |
|
|
match name { |
|
|
"happy" => vec![0.9, 0.7, 0.6, 0.5, 0.5, 0.5, 0.5, 0.5], |
|
|
"sad" => vec![0.2, 0.3, 0.4, 0.5, 0.6, 0.5, 0.5, 0.5], |
|
|
"angry" => vec![0.8, 0.9, 0.7, 0.5, 0.3, 0.5, 0.5, 0.5], |
|
|
"fearful" => vec![0.3, 0.4, 0.8, 0.5, 0.7, 0.5, 0.5, 0.5], |
|
|
"surprised" => vec![0.7, 0.8, 0.7, 0.5, 0.5, 0.5, 0.5, 0.5], |
|
|
"neutral" | _ => self.neutral(), |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn interpolate(&self, emot1: &[f32], emot2: &[f32], alpha: f32) -> Vec<f32> { |
|
|
emot1 |
|
|
.iter() |
|
|
.zip(emot2.iter()) |
|
|
.map(|(&a, &b)| a * (1.0 - alpha) + b * alpha) |
|
|
.collect() |
|
|
} |
|
|
|
|
|
|
|
|
pub fn apply_strength(&self, emotion: &[f32], strength: f32) -> Vec<f32> { |
|
|
let neutral = self.neutral(); |
|
|
self.interpolate(&neutral, emotion, strength) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub struct SemanticEncoder { |
|
|
session: Option<OnnxSession>, |
|
|
embedding_dim: usize, |
|
|
} |
|
|
|
|
|
impl SemanticEncoder { |
|
|
|
|
|
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { |
|
|
let session = OnnxSession::load(path)?; |
|
|
Ok(Self { |
|
|
session: Some(session), |
|
|
embedding_dim: 1024, |
|
|
}) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn new_placeholder() -> Self { |
|
|
Self { |
|
|
session: None, |
|
|
embedding_dim: 1024, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn encode(&self, audio: &[f32], sample_rate: u32) -> Result<Vec<i64>> { |
|
|
if let Some(ref session) = self.session { |
|
|
let input = Array::from_shape_vec( |
|
|
IxDyn(&[1, audio.len()]), |
|
|
audio.to_vec(), |
|
|
)?; |
|
|
|
|
|
let mut inputs = HashMap::new(); |
|
|
inputs.insert("audio".to_string(), input); |
|
|
|
|
|
let outputs = session.run(inputs)?; |
|
|
|
|
|
let codes = outputs |
|
|
.get("codes") |
|
|
.ok_or_else(|| Error::Model("Missing codes output".into()))?; |
|
|
|
|
|
Ok(codes.iter().map(|&x| x as i64).collect()) |
|
|
} else { |
|
|
|
|
|
let num_codes = audio.len() / (sample_rate as usize / 50); |
|
|
Ok(vec![0i64; num_codes.max(1)]) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// A placeholder speaker encoder reports the dimension it was given.
    #[test]
    fn test_speaker_encoder_placeholder() {
        let dim = SpeakerEncoder::new_placeholder(192).embedding_dim();
        assert_eq!(dim, 192);
    }

    /// The neutral vector has one 0.5 entry per dimension.
    #[test]
    fn test_emotion_encoder() {
        let encoder = EmotionEncoder::new(8, vec![5, 6, 8, 6, 5, 4, 7, 6], 256);
        let neutral = encoder.neutral();

        assert_eq!(neutral.len(), 8);
        for component in neutral.iter() {
            assert!((component - 0.5).abs() < 1e-6);
        }
    }

    /// The "happy" preset has 8 entries and a first entry above neutral.
    #[test]
    fn test_emotion_presets() {
        let encoder = EmotionEncoder::new(8, vec![5, 6, 8, 6, 5, 4, 7, 6], 256);
        let happy = encoder.preset("happy");

        assert_eq!(happy.len(), 8);
        let first = happy[0];
        assert!(first > 0.5);
    }

    /// Interpolating half-way yields the element-wise mean of both presets.
    #[test]
    fn test_emotion_interpolation() {
        let encoder = EmotionEncoder::new(8, vec![5, 6, 8, 6, 5, 4, 7, 6], 256);
        let happy = encoder.preset("happy");
        let sad = encoder.preset("sad");
        let mid = encoder.interpolate(&happy, &sad, 0.5);

        for idx in 0..8 {
            let expected = (happy[idx] + sad[idx]) / 2.0;
            let deviation = (mid[idx] - expected).abs();
            assert!(deviation < 1e-6);
        }
    }

    /// Identical vectors have similarity 1; orthogonal vectors have 0.
    #[test]
    fn test_cosine_similarity() {
        let encoder = SpeakerEncoder::new_placeholder(3);
        let x_axis = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let x_axis_copy = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let y_axis = Array1::from_vec(vec![0.0, 1.0, 0.0]);

        let same = encoder.cosine_similarity(&x_axis, &x_axis_copy);
        assert!((same - 1.0).abs() < 1e-6);

        let orthogonal = encoder.cosine_similarity(&x_axis, &y_axis);
        assert!(orthogonal.abs() < 1e-6);
    }
}
|
|
|