|
|
|
|
|
|
|
|
use crate::{Error, Result}; |
|
|
use ndarray::{Array, IxDyn}; |
|
|
use std::collections::HashMap; |
|
|
use std::path::{Path, PathBuf}; |
|
|
use std::sync::{Arc, RwLock}; |
|
|
|
|
|
|
|
|
pub struct OnnxSession { |
|
|
input_names: Vec<String>, |
|
|
output_names: Vec<String>, |
|
|
} |
|
|
|
|
|
impl OnnxSession { |
|
|
|
|
|
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { |
|
|
let path = path.as_ref(); |
|
|
if !path.exists() { |
|
|
return Err(Error::FileNotFound(path.display().to_string())); |
|
|
} |
|
|
|
|
|
|
|
|
log::info!("Loading ONNX model from: {}", path.display()); |
|
|
|
|
|
Ok(Self { |
|
|
input_names: vec!["input".to_string()], |
|
|
output_names: vec!["output".to_string()], |
|
|
}) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn run( |
|
|
&self, |
|
|
_inputs: HashMap<String, Array<f32, IxDyn>>, |
|
|
) -> Result<HashMap<String, Array<f32, IxDyn>>> { |
|
|
|
|
|
let mut result = HashMap::new(); |
|
|
for name in &self.output_names { |
|
|
let dummy = Array::zeros(IxDyn(&[1, 1])); |
|
|
result.insert(name.clone(), dummy); |
|
|
} |
|
|
Ok(result) |
|
|
} |
|
|
|
|
|
|
|
|
pub fn run_i64( |
|
|
&self, |
|
|
_inputs: HashMap<String, Array<i64, IxDyn>>, |
|
|
) -> Result<HashMap<String, Array<f32, IxDyn>>> { |
|
|
let mut result = HashMap::new(); |
|
|
for name in &self.output_names { |
|
|
let dummy = Array::zeros(IxDyn(&[1, 1])); |
|
|
result.insert(name.clone(), dummy); |
|
|
} |
|
|
Ok(result) |
|
|
} |
|
|
|
|
|
pub fn input_names(&self) -> &[String] { |
|
|
&self.input_names |
|
|
} |
|
|
|
|
|
pub fn output_names(&self) -> &[String] { |
|
|
&self.output_names |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub struct ModelCache { |
|
|
sessions: RwLock<HashMap<String, Arc<OnnxSession>>>, |
|
|
model_dir: PathBuf, |
|
|
} |
|
|
|
|
|
impl ModelCache { |
|
|
pub fn new<P: AsRef<Path>>(model_dir: P) -> Self { |
|
|
Self { |
|
|
sessions: RwLock::new(HashMap::new()), |
|
|
model_dir: model_dir.as_ref().to_path_buf(), |
|
|
} |
|
|
} |
|
|
|
|
|
pub fn get_or_load(&self, name: &str) -> Result<Arc<OnnxSession>> { |
|
|
{ |
|
|
let cache = self.sessions.read().unwrap(); |
|
|
if let Some(session) = cache.get(name) { |
|
|
return Ok(Arc::clone(session)); |
|
|
} |
|
|
} |
|
|
|
|
|
let model_path = self.model_dir.join(format!("{}.onnx", name)); |
|
|
let session = OnnxSession::load(&model_path)?; |
|
|
let session = Arc::new(session); |
|
|
|
|
|
{ |
|
|
let mut cache = self.sessions.write().unwrap(); |
|
|
cache.insert(name.to_string(), Arc::clone(&session)); |
|
|
} |
|
|
|
|
|
Ok(session) |
|
|
} |
|
|
|
|
|
pub fn preload(&self, model_names: &[&str]) -> Result<()> { |
|
|
for name in model_names { |
|
|
self.get_or_load(name)?; |
|
|
} |
|
|
Ok(()) |
|
|
} |
|
|
|
|
|
pub fn clear(&self) { |
|
|
let mut cache = self.sessions.write().unwrap(); |
|
|
cache.clear(); |
|
|
} |
|
|
|
|
|
pub fn is_cached(&self, name: &str) -> bool { |
|
|
let cache = self.sessions.read().unwrap(); |
|
|
cache.contains_key(name) |
|
|
} |
|
|
|
|
|
pub fn cached_models(&self) -> Vec<String> { |
|
|
let cache = self.sessions.read().unwrap(); |
|
|
cache.keys().cloned().collect() |
|
|
} |
|
|
} |
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a per-process scratch directory under the system temp dir, so parallel test
    /// runs and unrelated files cannot interfere with the assertions.
    fn scratch_dir(tag: &str) -> PathBuf {
        std::env::temp_dir().join(format!("model_cache_{}_{}", tag, std::process::id()))
    }

    /// Loads `name` from `cache`, panicking with a readable message on failure.
    ///
    /// Avoids `unwrap()`, which would require `Error: Debug`.
    fn must_load(cache: &ModelCache, name: &str) -> Arc<OnnxSession> {
        match cache.get_or_load(name) {
            Ok(session) => session,
            Err(_) => panic!("model `{}` should have loaded", name),
        }
    }

    #[test]
    fn test_model_cache_creation() {
        let cache = ModelCache::new("/tmp/models");
        assert!(cache.cached_models().is_empty());
    }

    #[test]
    fn test_missing_model_is_error_and_not_cached() {
        let cache = ModelCache::new(scratch_dir("missing"));

        assert!(matches!(
            cache.get_or_load("absent"),
            Err(Error::FileNotFound(_))
        ));
        assert!(!cache.is_cached("absent"));

        assert!(cache.preload(&["absent"]).is_err());
        assert!(cache.cached_models().is_empty());
    }

    #[test]
    fn test_cached_session_is_shared_and_clearable() {
        // Relies on `OnnxSession::load` currently only checking that the file is present;
        // revisit once it parses real ONNX content.
        let dir = scratch_dir("shared");
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(dir.join("encoder.onnx"), b"").unwrap();

        let cache = ModelCache::new(&dir);
        let first = must_load(&cache, "encoder");
        let second = must_load(&cache, "encoder");
        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(cache.cached_models(), vec!["encoder".to_string()]);

        cache.clear();
        assert!(!cache.is_cached("encoder"));

        std::fs::remove_dir_all(&dir).unwrap();
    }
}
|
|
|