Commit

Merge pull request #123 from Googlefan256/main
tuna2134 authored Nov 6, 2024
2 parents 380daf4 + 14d631e commit 2d557fb
Showing 4 changed files with 108 additions and 16 deletions.
5 changes: 4 additions & 1 deletion sbv2_api/src/main.rs
@@ -69,7 +69,7 @@ async fn synthesize(
) -> AppResult<impl IntoResponse> {
log::debug!("processing request: text={text}, ident={ident}, sdp_ratio={sdp_ratio}, length_scale={length_scale}");
let buffer = {
let tts_model = state.tts_model.lock().await;
let mut tts_model = state.tts_model.lock().await;
tts_model.easy_synthesize(
&ident,
&text,
@@ -94,6 +94,9 @@ impl AppState {
let mut tts_model = TTSModelHolder::new(
&fs::read(env::var("BERT_MODEL_PATH")?).await?,
&fs::read(env::var("TOKENIZER_PATH")?).await?,
env::var("HOLDER_MAX_LOADED_MODElS")
.ok()
.and_then(|x| x.parse().ok()),
)?;
let models = env::var("MODELS_PATH").unwrap_or("models".to_string());
let mut f = fs::read_dir(&models).await?;
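
A note on the new configuration knob: both the API server above and the CLI in sbv2_core/src/main.rs below read the cap from the HOLDER_MAX_LOADED_MODElS environment variable (the lowercase "l" is as committed) through the same `.ok().and_then(|x| x.parse().ok())` chain. The snippet below is a minimal standalone sketch of that parsing behavior, not repository code: an unset or non-numeric value becomes None, which TTSModelHolder treats as no cap.

```rust
use std::env;

fn main() {
    // Same pattern as the commit: a missing or unparsable value becomes None.
    // The variable name is copied exactly as it appears in the diff.
    let max_loaded_models: Option<usize> = env::var("HOLDER_MAX_LOADED_MODElS")
        .ok()
        .and_then(|x| x.parse().ok());

    match max_loaded_models {
        Some(n) => println!("keep at most {n} models loaded"),
        None => println!("no cap on loaded models"),
    }
}
```
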
26 changes: 21 additions & 5 deletions sbv2_bindings/src/sbv2.rs
@@ -23,10 +23,15 @@ pub struct TTSModel {

#[pymethods]
impl TTSModel {
#[pyo3(signature = (bert_model_bytes, tokenizer_bytes, max_loaded_models=None))]
#[new]
fn new(bert_model_bytes: Vec<u8>, tokenizer_bytes: Vec<u8>) -> anyhow::Result<Self> {
fn new(
bert_model_bytes: Vec<u8>,
tokenizer_bytes: Vec<u8>,
max_loaded_models: Option<usize>,
) -> anyhow::Result<Self> {
Ok(Self {
model: TTSModelHolder::new(bert_model_bytes, tokenizer_bytes)?,
model: TTSModelHolder::new(bert_model_bytes, tokenizer_bytes, max_loaded_models)?,
})
}

@@ -38,10 +43,21 @@ impl TTSModel {
/// Path to the BERT model
/// tokenizer_path : str
/// Path to the tokenizer
/// max_loaded_models: int | None
/// Number of models kept in VRAM at the same time
#[pyo3(signature = (bert_model_path, tokenizer_path, max_loaded_models=None))]
#[staticmethod]
fn from_path(bert_model_path: String, tokenizer_path: String) -> anyhow::Result<Self> {
fn from_path(
bert_model_path: String,
tokenizer_path: String,
max_loaded_models: Option<usize>,
) -> anyhow::Result<Self> {
Ok(Self {
model: TTSModelHolder::new(fs::read(bert_model_path)?, fs::read(tokenizer_path)?)?,
model: TTSModelHolder::new(
fs::read(bert_model_path)?,
fs::read(tokenizer_path)?,
max_loaded_models,
)?,
})
}

@@ -121,7 +137,7 @@ impl TTSModel {
/// voice_data : bytes
/// Audio data
fn synthesize<'p>(
&'p self,
&'p mut self,
py: Python<'p>,
text: String,
ident: String,
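
The binding changes rely on pyo3's optional-argument handling: a trailing `max_loaded_models=None` in the `#[pyo3(signature = ...)]` attribute becomes an optional keyword argument in Python and an `Option<usize>` in Rust. The sketch below shows that pattern in isolation with a hypothetical `Holder` class; it is not part of this repository.

```rust
use pyo3::prelude::*;

#[pyclass]
struct Holder {
    max_loaded_models: Option<usize>,
}

#[pymethods]
impl Holder {
    // From Python, `Holder()` and `Holder(max_loaded_models=2)` are both valid;
    // omitting the argument arrives as `None` on the Rust side.
    #[pyo3(signature = (max_loaded_models=None))]
    #[new]
    fn new(max_loaded_models: Option<usize>) -> Self {
        Self { max_loaded_models }
    }

    fn describe(&self) -> String {
        match self.max_loaded_models {
            Some(n) => format!("at most {n} models stay loaded"),
            None => "no limit on loaded models".to_string(),
        }
    }
}
```
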
3 changes: 3 additions & 0 deletions sbv2_core/src/main.rs
@@ -11,6 +11,9 @@ fn main_inner() -> anyhow::Result<()> {
let mut tts_holder = tts::TTSModelHolder::new(
&fs::read(env::var("BERT_MODEL_PATH")?)?,
&fs::read(env::var("TOKENIZER_PATH")?)?,
env::var("HOLDER_MAX_LOADED_MODElS")
.ok()
.and_then(|x| x.parse().ok()),
)?;
tts_holder.load_sbv2file(ident, fs::read(env::var("MODEL_PATH")?)?)?;

90 changes: 80 additions & 10 deletions sbv2_core/src/tts.rs
@@ -24,9 +24,10 @@ where
}

pub struct TTSModel {
vits2: Session,
vits2: Option<Session>,
style_vectors: Array2<f32>,
ident: TTSIdent,
bytes: Option<Vec<u8>>,
}

/// High-level Style-Bert-VITS2's API
@@ -35,6 +36,7 @@ pub struct TTSModelHolder {
bert: Session,
models: Vec<TTSModel>,
jtalk: jtalk::JTalk,
max_loaded_models: Option<usize>,
}

impl TTSModelHolder {
@@ -43,9 +45,13 @@ impl TTSModelHolder {
/// # Examples
///
/// ```rs
/// let mut tts_holder = TTSModelHolder::new(std::fs::read("deberta.onnx")?, std::fs::read("tokenizer.json")?)?;
/// let mut tts_holder = TTSModelHolder::new(std::fs::read("deberta.onnx")?, std::fs::read("tokenizer.json")?, None)?;
/// ```
pub fn new<P: AsRef<[u8]>>(bert_model_bytes: P, tokenizer_bytes: P) -> Result<Self> {
pub fn new<P: AsRef<[u8]>>(
bert_model_bytes: P,
tokenizer_bytes: P,
max_loaded_models: Option<usize>,
) -> Result<Self> {
let bert = model::load_model(bert_model_bytes, true)?;
let jtalk = jtalk::JTalk::new()?;
let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?;
@@ -54,6 +60,7 @@
models: vec![],
jtalk,
tokenizer,
max_loaded_models,
})
}

@@ -94,10 +101,25 @@
) -> Result<()> {
let ident = ident.into();
if self.find_model(ident.clone()).is_err() {
let mut load = true;
if let Some(max) = self.max_loaded_models {
if self.models.iter().filter(|x| x.vits2.is_some()).count() >= max {
load = false;
}
}
self.models.push(TTSModel {
vits2: model::load_model(vits2_bytes, false)?,
vits2: if load {
Some(model::load_model(&vits2_bytes, false)?)
} else {
None
},
style_vectors: style::load_style(style_vectors_bytes)?,
ident,
bytes: if self.max_loaded_models.is_some() {
Some(vits2_bytes.as_ref().to_vec())
} else {
None
},
})
}
Ok(())
@@ -145,6 +167,42 @@
.find(|m| m.ident == ident)
.ok_or(Error::ModelNotFoundError(ident.to_string()))
}
fn find_and_load_model<I: Into<TTSIdent>>(&mut self, ident: I) -> Result<bool> {
let ident = ident.into();
let (bytes, style_vectors) = {
let model = self
.models
.iter()
.find(|m| m.ident == ident)
.ok_or(Error::ModelNotFoundError(ident.to_string()))?;
if model.vits2.is_some() {
return Ok(true);
}
(model.bytes.clone().unwrap(), model.style_vectors.clone())
};
self.unload(ident.clone());
let s = model::load_model(&bytes, false)?;
if let Some(max) = self.max_loaded_models {
if self.models.iter().filter(|x| x.vits2.is_some()).count() >= max {
self.unload(self.models.first().unwrap().ident.clone());
}
}
self.models.push(TTSModel {
bytes: Some(bytes.to_vec()),
vits2: Some(s),
style_vectors,
ident: ident.clone(),
});
let model = self
.models
.iter()
.find(|m| m.ident == ident)
.ok_or(Error::ModelNotFoundError(ident.to_string()))?;
if model.vits2.is_some() {
return Ok(true);
}
Err(Error::ModelNotFoundError(ident.to_string()))
}

/// Get style vector by style id and weight
///
@@ -167,12 +225,18 @@
/// let audio = tts_holder.easy_synthesize("tsukuyomi", "こんにちは", 0, SynthesizeOptions::default())?;
/// ```
pub fn easy_synthesize<I: Into<TTSIdent> + Copy>(
&self,
&mut self,
ident: I,
text: &str,
style_id: i32,
options: SynthesizeOptions,
) -> Result<Vec<u8>> {
self.find_and_load_model(ident)?;
let vits2 = &self
.find_model(ident)?
.vits2
.as_ref()
.ok_or(Error::ModelNotFoundError(ident.into().to_string()))?;
let style_vector = self.get_style_vector(ident, style_id, options.style_weight)?;
let audio_array = if options.split_sentences {
let texts: Vec<&str> = text.split('\n').collect();
@@ -183,7 +247,7 @@
}
let (bert_ori, phones, tones, lang_ids) = self.parse_text(t)?;
let audio = model::synthesize(
&self.find_model(ident)?.vits2,
&vits2,
bert_ori.to_owned(),
phones,
tones,
@@ -204,7 +268,7 @@
} else {
let (bert_ori, phones, tones, lang_ids) = self.parse_text(text)?;
model::synthesize(
&self.find_model(ident)?.vits2,
&vits2,
bert_ori.to_owned(),
phones,
tones,
@@ -222,8 +286,8 @@
/// # Note
/// This function is for low-level usage, use `easy_synthesize` for high-level usage.
#[allow(clippy::too_many_arguments)]
pub fn synthesize<I: Into<TTSIdent>>(
&self,
pub fn synthesize<I: Into<TTSIdent> + Copy>(
&mut self,
ident: I,
bert_ori: Array2<f32>,
phones: Array1<i64>,
@@ -233,8 +297,14 @@
sdp_ratio: f32,
length_scale: f32,
) -> Result<Vec<u8>> {
self.find_and_load_model(ident)?;
let vits2 = &self
.find_model(ident)?
.vits2
.as_ref()
.ok_or(Error::ModelNotFoundError(ident.into().to_string()))?;
let audio_array = model::synthesize(
&self.find_model(ident)?.vits2,
&vits2,
bert_ori.to_owned(),
phones,
tones,
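
Taken together, the commit turns TTSModelHolder into a small model cache: when a cap is set, each registered model keeps its raw bytes, at most max_loaded_models sessions stay resident, and find_and_load_model drops the entry at the front of the list before rebuilding the requested one. The sketch below illustrates the intended effect; crate paths, file names, and identifiers are assumptions for illustration, and conversion of the crate's errors into anyhow::Error is assumed.

```rust
use std::fs;

use sbv2_core::tts::{SynthesizeOptions, TTSModelHolder};

fn main() -> anyhow::Result<()> {
    // Allow only one VITS2 session in memory at a time.
    let mut holder = TTSModelHolder::new(
        &fs::read("deberta.onnx")?,
        &fs::read("tokenizer.json")?,
        Some(1),
    )?;

    // Hypothetical model files and identifiers.
    holder.load_sbv2file("amitaro", fs::read("amitaro.sbv2")?)?;
    holder.load_sbv2file("tsukuyomi", fs::read("tsukuyomi.sbv2")?)?; // stored, not loaded yet

    // Each call runs find_and_load_model first: the second call evicts the
    // "amitaro" entry and rebuilds "tsukuyomi" from its retained bytes.
    let a = holder.easy_synthesize("amitaro", "こんにちは", 0, SynthesizeOptions::default())?;
    let b = holder.easy_synthesize("tsukuyomi", "こんにちは", 0, SynthesizeOptions::default())?;

    fs::write("amitaro.wav", a)?;
    fs::write("tsukuyomi.wav", b)?;
    Ok(())
}
```

The design trades reload latency (an evicted session must be rebuilt from its bytes on the next request) for a bounded number of sessions held in VRAM.
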
