ThreadAbort Copilot commited on
Commit
668baf3
·
unverified ·
1 Parent(s): a27c268

Update src/model/embedding.rs

Browse files

Co-authored-by: Copilot <[email protected]>

Files changed (1) hide show
  1. src/model/embedding.rs +4 -1
src/model/embedding.rs CHANGED
@@ -136,9 +136,12 @@ impl EmotionEncoder {
136
  .map_err(|e| Error::ModelLoading(format!("Missing emotion_matrix: {}", e)))?;
137
 
138
  let shape = tensor.shape();
139
- let data: Vec<f32> = tensor.data().chunks(4).map(|b| {
140
  f32::from_le_bytes([b[0], b[1], b[2], b[3]])
141
  }).collect();
 
 
 
142
 
143
  let emotion_matrix = Array2::from_shape_vec((shape[0], shape[1]), data)
144
  .map_err(|e| Error::ModelLoading(format!("Shape mismatch: {}", e)))?;
 
136
  .map_err(|e| Error::ModelLoading(format!("Missing emotion_matrix: {}", e)))?;
137
 
138
  let shape = tensor.shape();
139
+ let mut data: Vec<f32> = tensor.data().chunks_exact(4).map(|b| {
140
  f32::from_le_bytes([b[0], b[1], b[2], b[3]])
141
  }).collect();
142
+ if !tensor.data().chunks_exact(4).remainder().is_empty() {
143
+ return Err(Error::ModelLoading("Tensor data length is not a multiple of 4".to_string()));
144
+ }
145
 
146
  let emotion_matrix = Array2::from_shape_vec((shape[0], shape[1]), data)
147
  .map_err(|e| Error::ModelLoading(format!("Shape mismatch: {}", e)))?;