use crate::{Error, Result};
use ndarray::{Array, Array1, Array2, IxDyn};
use std::collections::HashMap;
use std::path::Path;

use super::{OnnxSession, SamplingStrategy, sample_from_logits, apply_repetition_penalty};

/// Configuration for the GPT model.
#[derive(Debug, Clone)]
pub struct GptConfig {
    /// Number of transformer layers.
    pub num_layers: usize,
    /// Hidden dimension of each layer.
    pub hidden_size: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Maximum sequence length the model accepts.
    pub max_seq_len: usize,
    /// Vocabulary size, including the start and stop tokens.
    pub vocab_size: usize,
    /// Token ID that ends generation.
    pub stop_token: usize,
    /// Token ID that begins generation.
    pub start_token: usize,
}

impl Default for GptConfig {
    fn default() -> Self {
        Self {
            num_layers: 8,
            hidden_size: 512,
            num_heads: 8,
            max_seq_len: 250,
            vocab_size: 8194,
            stop_token: 8193,
            start_token: 8192,
        }
    }
}
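
// A minimal sketch of overriding individual fields with struct-update syntax
// while keeping the remaining defaults, as the tests at the bottom of this
// file also do. The values shown are illustrative, not recommended settings.
//
//     let config = GptConfig {
//         max_seq_len: 500,
//         ..Default::default()
//     };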

/// GPT-style autoregressive model backed by an ONNX session.
pub struct GptModel {
    session: OnnxSession,
    config: GptConfig,
}

impl GptModel {
    /// Loads the ONNX model at `path` with the given configuration.
    pub fn load<P: AsRef<Path>>(path: P, config: GptConfig) -> Result<Self> {
        let session = OnnxSession::load(path)?;
        Ok(Self { session, config })
    }

    /// Autoregressively generates tokens conditioned on semantic tokens and a
    /// speaker embedding. Generation stops at `stop_token` or after
    /// `max_length` steps; the returned sequence begins with `start_token`.
    pub fn generate(
        &self,
        semantic_tokens: &[i64],
        speaker_embedding: &Array1<f32>,
        max_length: usize,
        strategy: &SamplingStrategy,
        repetition_penalty: f32,
    ) -> Result<Vec<i64>> {
        let mut generated_tokens = vec![self.config.start_token as i64];
        let mut past_tokens = Vec::new();

        for _ in 0..max_length {
            // Tokens generated so far, shaped [1, seq_len].
            let input_tokens = Array::from_shape_vec(
                IxDyn(&[1, generated_tokens.len()]),
                generated_tokens.clone(),
            )?;

            // Speaker embedding, shaped [1, embedding_dim].
            let speaker_emb = speaker_embedding
                .clone()
                .into_shape(IxDyn(&[1, speaker_embedding.len()]))?;

            // Semantic conditioning tokens, shaped [1, num_semantic_tokens].
            let semantic_input = Array::from_shape_vec(
                IxDyn(&[1, semantic_tokens.len()]),
                semantic_tokens.to_vec(),
            )?;

            let mut inputs = HashMap::new();
            inputs.insert("input_ids".to_string(), input_tokens.mapv(|x| x as f32));
            inputs.insert("speaker_embedding".to_string(), speaker_emb);
            inputs.insert("semantic_tokens".to_string(), semantic_input.mapv(|x| x as f32));

            let outputs = self.session.run(inputs)?;

            let logits = outputs
                .get("logits")
                .ok_or_else(|| Error::Model("Missing logits output".into()))?;

            // Keep only the logits for the last position:
            // [1, seq_len, vocab_size] -> [vocab_size].
            let seq_len = logits.shape()[1];
            let vocab_size = logits.shape()[2];
            let last_logits: Vec<f32> = (0..vocab_size)
                .map(|i| logits[[0, seq_len - 1, i]])
                .collect();

            // Down-weight tokens that have already been generated.
            let mut logits_vec = last_logits;
            let past_usize: Vec<usize> = past_tokens.iter().map(|&x| x as usize).collect();
            apply_repetition_penalty(&mut logits_vec, &past_usize, repetition_penalty);

            let next_token = sample_from_logits(&logits_vec, strategy) as i64;

            if next_token == self.config.stop_token as i64 {
                break;
            }

            generated_tokens.push(next_token);
            past_tokens.push(next_token);
        }

        Ok(generated_tokens)
    }

    /// Variant intended for KV-cache reuse. Caching is not implemented here,
    /// so this currently delegates to [`generate`](Self::generate) and
    /// recomputes the full sequence at every step.
    pub fn generate_with_cache(
        &self,
        semantic_tokens: &[i64],
        speaker_embedding: &Array1<f32>,
        max_length: usize,
        strategy: &SamplingStrategy,
        repetition_penalty: f32,
    ) -> Result<Vec<i64>> {
        self.generate(
            semantic_tokens,
            speaker_embedding,
            max_length,
            strategy,
            repetition_penalty,
        )
    }

    /// Returns the model configuration.
    pub fn config(&self) -> &GptConfig {
        &self.config
    }

    /// Rough memory estimate in megabytes: counts `num_layers * hidden_size^2 * 4`
    /// parameters at four bytes (f32) each. With the default config this is
    /// 8,388,608 parameters, about 33.6 MB.
    pub fn estimate_memory_mb(&self) -> f32 {
        let params = self.config.num_layers
            * self.config.hidden_size
            * self.config.hidden_size
            * 4;
        (params * 4) as f32 / 1_000_000.0
    }
}
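
// Illustrative only: a minimal end-to-end sketch of driving `GptModel`. The
// model path, semantic token IDs, and embedding length below are placeholder
// assumptions, not values required by this crate.
#[allow(dead_code)]
fn example_generate_tokens() -> Result<Vec<i64>> {
    let model = GptModel::load("models/gpt.onnx", GptConfig::default())?;
    let semantic_tokens: Vec<i64> = vec![12, 57, 391]; // placeholder semantic IDs
    let speaker_embedding = Array1::<f32>::zeros(512); // placeholder embedding
    model.generate(
        &semantic_tokens,
        &speaker_embedding,
        model.config().max_seq_len,
        &SamplingStrategy::Greedy,
        1.1, // mild repetition penalty
    )
}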

/// Lightweight GPT stand-in with random weights, useful for tests and
/// experiments that do not require an ONNX runtime.
pub struct SimpleGptModel {
    config: GptConfig,
    /// Token embedding table, shape [vocab_size, hidden_size].
    token_embeddings: Array2<f32>,
    /// Positional embedding table, shape [max_seq_len, hidden_size].
    position_embeddings: Array2<f32>,
    /// Output projection, shape [hidden_size, vocab_size].
    output_projection: Array2<f32>,
}

impl SimpleGptModel {
    /// Creates a model with weights drawn uniformly from [-0.1, 0.1).
    pub fn new_random(config: GptConfig) -> Self {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        let token_embeddings = Array2::from_shape_fn(
            (config.vocab_size, config.hidden_size),
            |_| rng.gen_range(-0.1..0.1),
        );

        let position_embeddings = Array2::from_shape_fn(
            (config.max_seq_len, config.hidden_size),
            |_| rng.gen_range(-0.1..0.1),
        );

        let output_projection = Array2::from_shape_fn(
            (config.hidden_size, config.vocab_size),
            |_| rng.gen_range(-0.1..0.1),
        );

        Self {
            config,
            token_embeddings,
            position_embeddings,
            output_projection,
        }
    }

    /// Computes next-token logits with a bag-of-embeddings forward pass:
    /// token and position embeddings are summed into a single hidden vector,
    /// L2-normalized, and projected to vocabulary logits.
    pub fn forward(&self, tokens: &[i64]) -> Vec<f32> {
        // Accumulate token + position embeddings into one hidden vector.
        let mut hidden = vec![0.0f32; self.config.hidden_size];

        for (pos, &token) in tokens.iter().enumerate().take(self.config.max_seq_len) {
            // Clamp out-of-range IDs to the last vocabulary entry.
            let token_idx = (token as usize).min(self.config.vocab_size - 1);

            for i in 0..self.config.hidden_size {
                hidden[i] += self.token_embeddings[[token_idx, i]]
                    + self.position_embeddings[[pos, i]];
            }
        }

        // L2-normalize the hidden state to keep logits in a stable range.
        let norm: f32 = hidden.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-8 {
            for h in hidden.iter_mut() {
                *h /= norm;
            }
        }

        // Project the hidden state to vocabulary logits.
        let mut logits = vec![0.0f32; self.config.vocab_size];
        for (i, logit) in logits.iter_mut().enumerate() {
            for j in 0..self.config.hidden_size {
                *logit += hidden[j] * self.output_projection[[j, i]];
            }
        }

        logits
    }

    /// Generates tokens from `prompt` until `stop_token` is sampled,
    /// `max_length` new tokens have been produced, or the sequence reaches
    /// `max_seq_len`.
    pub fn generate(
        &self,
        prompt: &[i64],
        max_length: usize,
        strategy: &SamplingStrategy,
    ) -> Vec<i64> {
        let mut tokens = prompt.to_vec();

        for _ in 0..max_length {
            let logits = self.forward(&tokens);
            let next_token = sample_from_logits(&logits, strategy) as i64;

            if next_token == self.config.stop_token as i64 {
                break;
            }

            tokens.push(next_token);

            if tokens.len() >= self.config.max_seq_len {
                break;
            }
        }

        tokens
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpt_config_default() {
        let config = GptConfig::default();
        assert_eq!(config.num_layers, 8);
        assert_eq!(config.hidden_size, 512);
    }

    #[test]
    fn test_simple_gpt_forward() {
        let config = GptConfig {
            vocab_size: 100,
            hidden_size: 32,
            max_seq_len: 10,
            ..Default::default()
        };

        let model = SimpleGptModel::new_random(config);
        let tokens = vec![1i64, 2, 3];
        let logits = model.forward(&tokens);

        assert_eq!(logits.len(), 100);
    }

    #[test]
    fn test_simple_gpt_generate() {
        let config = GptConfig {
            vocab_size: 100,
            hidden_size: 32,
            max_seq_len: 20,
            stop_token: 99,
            ..Default::default()
        };

        let model = SimpleGptModel::new_random(config);
        let prompt = vec![1i64, 2, 3];
        let generated = model.generate(&prompt, 10, &SamplingStrategy::Greedy);

        assert!(generated.len() >= 3);
        assert!(generated.len() <= 20);
    }
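
    // Additional sketch of a test, assuming `SamplingStrategy::Greedy` picks
    // the argmax deterministically (the tests above use the same variant).
    // With fixed random weights, greedy generation from the same prompt
    // should then be repeatable.
    #[test]
    fn test_simple_gpt_greedy_is_deterministic() {
        let config = GptConfig {
            vocab_size: 100,
            hidden_size: 32,
            max_seq_len: 20,
            stop_token: 99,
            ..Default::default()
        };

        let model = SimpleGptModel::new_random(config);
        let prompt = vec![1i64, 2, 3];
        let first = model.generate(&prompt, 10, &SamplingStrategy::Greedy);
        let second = model.generate(&prompt, 10, &SamplingStrategy::Greedy);

        assert_eq!(first, second);
    }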
}