Update src/model/embedding.rs
Browse filesCo-authored-by: Copilot <[email protected]>
- src/model/embedding.rs +4 -1
src/model/embedding.rs
CHANGED
|
@@ -136,9 +136,12 @@ impl EmotionEncoder {
|
|
| 136 |
.map_err(|e| Error::ModelLoading(format!("Missing emotion_matrix: {}", e)))?;
|
| 137 |
|
| 138 |
let shape = tensor.shape();
|
| 139 |
-
let data: Vec<f32> = tensor.data().
|
| 140 |
f32::from_le_bytes([b[0], b[1], b[2], b[3]])
|
| 141 |
}).collect();
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
let emotion_matrix = Array2::from_shape_vec((shape[0], shape[1]), data)
|
| 144 |
.map_err(|e| Error::ModelLoading(format!("Shape mismatch: {}", e)))?;
|
|
|
|
| 136 |
.map_err(|e| Error::ModelLoading(format!("Missing emotion_matrix: {}", e)))?;
|
| 137 |
|
| 138 |
let shape = tensor.shape();
|
| 139 |
+
let mut data: Vec<f32> = tensor.data().chunks_exact(4).map(|b| {
|
| 140 |
f32::from_le_bytes([b[0], b[1], b[2], b[3]])
|
| 141 |
}).collect();
|
| 142 |
+
if !tensor.data().chunks_exact(4).remainder().is_empty() {
|
| 143 |
+
return Err(Error::ModelLoading("Tensor data length is not a multiple of 4".to_string()));
|
| 144 |
+
}
|
| 145 |
|
| 146 |
let emotion_matrix = Array2::from_shape_vec((shape[0], shape[1]), data)
|
| 147 |
.map_err(|e| Error::ModelLoading(format!("Shape mismatch: {}", e)))?;
|