From 14d631eeaa81e9ed87819049542e0694fa8b1eaf Mon Sep 17 00:00:00 2001
From: Googlefan
Date: Wed, 6 Nov 2024 10:43:41 +0000
Subject: [PATCH] wip: max loaded models

---
 sbv2_api/src/main.rs      |  5 ++-
 sbv2_bindings/src/sbv2.rs | 26 ++++++++---
 sbv2_core/src/main.rs     |  3 ++
 sbv2_core/src/tts.rs      | 90 ++++++++++++++++++++++++++++++++++-----
 4 files changed, 108 insertions(+), 16 deletions(-)

diff --git a/sbv2_api/src/main.rs b/sbv2_api/src/main.rs
index 3b95281..10e936a 100644
--- a/sbv2_api/src/main.rs
+++ b/sbv2_api/src/main.rs
@@ -69,7 +69,7 @@ async fn synthesize(
 ) -> AppResult {
     log::debug!("processing request: text={text}, ident={ident}, sdp_ratio={sdp_ratio}, length_scale={length_scale}");
     let buffer = {
-        let tts_model = state.tts_model.lock().await;
+        let mut tts_model = state.tts_model.lock().await;
         tts_model.easy_synthesize(
             &ident,
             &text,
@@ -94,6 +94,9 @@ impl AppState {
         let mut tts_model = TTSModelHolder::new(
             &fs::read(env::var("BERT_MODEL_PATH")?).await?,
             &fs::read(env::var("TOKENIZER_PATH")?).await?,
+            env::var("HOLDER_MAX_LOADED_MODELS")
+                .ok()
+                .and_then(|x| x.parse().ok()),
         )?;
         let models = env::var("MODELS_PATH").unwrap_or("models".to_string());
         let mut f = fs::read_dir(&models).await?;
diff --git a/sbv2_bindings/src/sbv2.rs b/sbv2_bindings/src/sbv2.rs
index 378e5e3..0ecc414 100644
--- a/sbv2_bindings/src/sbv2.rs
+++ b/sbv2_bindings/src/sbv2.rs
@@ -23,10 +23,15 @@ pub struct TTSModel {
 
 #[pymethods]
 impl TTSModel {
+    #[pyo3(signature = (bert_model_bytes, tokenizer_bytes, max_loaded_models=None))]
     #[new]
-    fn new(bert_model_bytes: Vec<u8>, tokenizer_bytes: Vec<u8>) -> anyhow::Result<Self> {
+    fn new(
+        bert_model_bytes: Vec<u8>,
+        tokenizer_bytes: Vec<u8>,
+        max_loaded_models: Option<usize>,
+    ) -> anyhow::Result<Self> {
         Ok(Self {
-            model: TTSModelHolder::new(bert_model_bytes, tokenizer_bytes)?,
+            model: TTSModelHolder::new(bert_model_bytes, tokenizer_bytes, max_loaded_models)?,
         })
     }
 
@@ -38,10 +43,21 @@ impl TTSModel {
     ///     Path to the BERT model
     /// tokenizer_path : str
     ///     Path to the tokenizer
+    /// max_loaded_models: int | None
+    ///     Number of models kept in VRAM at the same time
+    #[pyo3(signature = (bert_model_path, tokenizer_path, max_loaded_models=None))]
     #[staticmethod]
-    fn from_path(bert_model_path: String, tokenizer_path: String) -> anyhow::Result<Self> {
+    fn from_path(
+        bert_model_path: String,
+        tokenizer_path: String,
+        max_loaded_models: Option<usize>,
+    ) -> anyhow::Result<Self> {
         Ok(Self {
-            model: TTSModelHolder::new(fs::read(bert_model_path)?, fs::read(tokenizer_path)?)?,
+            model: TTSModelHolder::new(
+                fs::read(bert_model_path)?,
+                fs::read(tokenizer_path)?,
+                max_loaded_models,
+            )?,
         })
     }
 
@@ -121,7 +137,7 @@ impl TTSModel {
     /// voice_data : bytes
     ///     Audio data
     fn synthesize<'p>(
-        &'p self,
+        &'p mut self,
         py: Python<'p>,
         text: String,
         ident: String,
diff --git a/sbv2_core/src/main.rs b/sbv2_core/src/main.rs
index 3700cd4..3a4cbdd 100644
--- a/sbv2_core/src/main.rs
+++ b/sbv2_core/src/main.rs
@@ -11,6 +11,9 @@ fn main_inner() -> anyhow::Result<()> {
     let mut tts_holder = tts::TTSModelHolder::new(
         &fs::read(env::var("BERT_MODEL_PATH")?)?,
         &fs::read(env::var("TOKENIZER_PATH")?)?,
+        env::var("HOLDER_MAX_LOADED_MODELS")
+            .ok()
+            .and_then(|x| x.parse().ok()),
     )?;
 
     tts_holder.load_sbv2file(ident, fs::read(env::var("MODEL_PATH")?)?)?;
diff --git a/sbv2_core/src/tts.rs b/sbv2_core/src/tts.rs
index 29c4da6..d248e31 100644
--- a/sbv2_core/src/tts.rs
+++ b/sbv2_core/src/tts.rs
@@ -24,9 +24,10 @@ where
 }
 
 pub struct TTSModel {
-    vits2: Session,
+    vits2: Option<Session>,
     style_vectors: Array2<f32>,
     ident: TTSIdent,
+    bytes: Option<Vec<u8>>,
 }
 
 /// High-level Style-Bert-VITS2's API
@@ -35,6 +36,7 @@ pub struct TTSModelHolder {
     bert: Session,
     models: Vec<TTSModel>,
     jtalk: jtalk::JTalk,
+    max_loaded_models: Option<usize>,
 }
 
 impl TTSModelHolder {
@@ -43,9 +45,13 @@ impl TTSModelHolder {
     /// # Examples
     ///
     /// ```rs
-    /// let mut tts_holder = TTSModelHolder::new(std::fs::read("deberta.onnx")?, std::fs::read("tokenizer.json")?)?;
+    /// let mut tts_holder = TTSModelHolder::new(std::fs::read("deberta.onnx")?, std::fs::read("tokenizer.json")?, None)?;
     /// ```
-    pub fn new<P: AsRef<[u8]>>(bert_model_bytes: P, tokenizer_bytes: P) -> Result<Self> {
+    pub fn new<P: AsRef<[u8]>>(
+        bert_model_bytes: P,
+        tokenizer_bytes: P,
+        max_loaded_models: Option<usize>,
+    ) -> Result<Self> {
         let bert = model::load_model(bert_model_bytes, true)?;
         let jtalk = jtalk::JTalk::new()?;
         let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?;
@@ -54,6 +60,7 @@ impl TTSModelHolder {
             models: vec![],
             jtalk,
             tokenizer,
+            max_loaded_models,
         })
     }
 
@@ -94,10 +101,25 @@ impl TTSModelHolder {
     ) -> Result<()> {
         let ident = ident.into();
         if self.find_model(ident.clone()).is_err() {
+            let mut load = true;
+            if let Some(max) = self.max_loaded_models {
+                if self.models.iter().filter(|x| x.vits2.is_some()).count() >= max {
+                    load = false;
+                }
+            }
             self.models.push(TTSModel {
-                vits2: model::load_model(vits2_bytes, false)?,
+                vits2: if load {
+                    Some(model::load_model(&vits2_bytes, false)?)
+                } else {
+                    None
+                },
                 style_vectors: style::load_style(style_vectors_bytes)?,
                 ident,
+                bytes: if self.max_loaded_models.is_some() {
+                    Some(vits2_bytes.as_ref().to_vec())
+                } else {
+                    None
+                },
             })
         }
         Ok(())
     }
@@ -145,6 +167,42 @@ impl TTSModelHolder {
             .find(|m| m.ident == ident)
             .ok_or(Error::ModelNotFoundError(ident.to_string()))
     }
+    fn find_and_load_model<I: Into<TTSIdent>>(&mut self, ident: I) -> Result<bool> {
+        let ident = ident.into();
+        let (bytes, style_vectors) = {
+            let model = self
+                .models
+                .iter()
+                .find(|m| m.ident == ident)
+                .ok_or(Error::ModelNotFoundError(ident.to_string()))?;
+            if model.vits2.is_some() {
+                return Ok(true);
+            }
+            (model.bytes.clone().unwrap(), model.style_vectors.clone())
+        };
+        self.unload(ident.clone());
+        let s = model::load_model(&bytes, false)?;
+        if let Some(max) = self.max_loaded_models {
+            if self.models.iter().filter(|x| x.vits2.is_some()).count() >= max {
+                self.unload(self.models.first().unwrap().ident.clone());
+            }
+        }
+        self.models.push(TTSModel {
+            bytes: Some(bytes.to_vec()),
+            vits2: Some(s),
+            style_vectors,
+            ident: ident.clone(),
+        });
+        let model = self
+            .models
+            .iter()
+            .find(|m| m.ident == ident)
+            .ok_or(Error::ModelNotFoundError(ident.to_string()))?;
+        if model.vits2.is_some() {
+            return Ok(true);
+        }
+        Err(Error::ModelNotFoundError(ident.to_string()))
+    }
 
     /// Get style vector by style id and weight
     ///
@@ -167,12 +225,18 @@ impl TTSModelHolder {
     /// let audio = tts_holder.easy_synthesize("tsukuyomi", "こんにちは", 0, SynthesizeOptions::default())?;
     /// ```
     pub fn easy_synthesize<I: Into<TTSIdent> + Copy>(
-        &self,
+        &mut self,
         ident: I,
         text: &str,
         style_id: i32,
         options: SynthesizeOptions,
     ) -> Result<Vec<u8>> {
+        self.find_and_load_model(ident)?;
+        let vits2 = &self
+            .find_model(ident)?
+            .vits2
+            .as_ref()
+            .ok_or(Error::ModelNotFoundError(ident.into().to_string()))?;
         let style_vector = self.get_style_vector(ident, style_id, options.style_weight)?;
         let audio_array = if options.split_sentences {
             let texts: Vec<&str> = text.split('\n').collect();
@@ -183,7 +247,7 @@ impl TTSModelHolder {
                 }
                 let (bert_ori, phones, tones, lang_ids) = self.parse_text(t)?;
                 let audio = model::synthesize(
-                    &self.find_model(ident)?.vits2,
+                    &vits2,
                     bert_ori.to_owned(),
                     phones,
                     tones,
@@ -204,7 +268,7 @@ impl TTSModelHolder {
         } else {
             let (bert_ori, phones, tones, lang_ids) = self.parse_text(text)?;
             model::synthesize(
-                &self.find_model(ident)?.vits2,
+                &vits2,
                 bert_ori.to_owned(),
                 phones,
                 tones,
@@ -222,8 +286,8 @@ impl TTSModelHolder {
     /// # Note
     /// This function is for low-level usage, use `easy_synthesize` for high-level usage.
     #[allow(clippy::too_many_arguments)]
-    pub fn synthesize<I: Into<TTSIdent>>(
-        &self,
+    pub fn synthesize<I: Into<TTSIdent> + Copy>(
+        &mut self,
         ident: I,
         bert_ori: Array2<f32>,
         phones: Array1<i64>,
@@ -233,8 +297,14 @@ impl TTSModelHolder {
         sdp_ratio: f32,
         length_scale: f32,
     ) -> Result<Vec<u8>> {
+        self.find_and_load_model(ident)?;
+        let vits2 = &self
+            .find_model(ident)?
+            .vits2
+            .as_ref()
+            .ok_or(Error::ModelNotFoundError(ident.into().to_string()))?;
         let audio_array = model::synthesize(
-            &self.find_model(ident)?.vits2,
+            &vits2,
             bert_ori.to_owned(),
             phones,
             tones,
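
Usage sketch (not part of the patch): how the new `max_loaded_models` cap is
meant to be driven from Rust, based on the API changes above. The module path
`sbv2_core::tts`, the file names, and the "tsukuyomi" ident are assumptions
used only for illustration.

    use sbv2_core::tts::{SynthesizeOptions, TTSModelHolder};

    fn main() -> anyhow::Result<()> {
        // Some(2): keep at most two VITS2 sessions loaded at once; models
        // registered beyond the cap keep only their bytes and are loaded
        // on demand inside easy_synthesize/synthesize.
        let mut holder = TTSModelHolder::new(
            std::fs::read("deberta.onnx")?,
            std::fs::read("tokenizer.json")?,
            Some(2),
        )?;
        holder.load_sbv2file("tsukuyomi", std::fs::read("tsukuyomi.sbv2")?)?;
        // easy_synthesize now takes &mut self because it may have to load
        // (or swap in) the requested model before synthesis.
        let audio = holder.easy_synthesize("tsukuyomi", "こんにちは", 0, SynthesizeOptions::default())?;
        std::fs::write("output.wav", audio)?;
        Ok(())
    }

The sbv2_api server and the sbv2_core CLI read the same cap from the
HOLDER_MAX_LOADED_MODELS environment variable; leaving it unset keeps the
previous behaviour of loading every registered model eagerly.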