Skip to content

Commit

Permalink
Merge pull request #138 from tuna2134/aivmx
Browse files Browse the repository at this point in the history
support aivmx
  • Loading branch information
tuna2134 authored Nov 20, 2024
2 parents aa7fc2e + db09b73 commit a7fbfa2
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 4 deletions.
87 changes: 87 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion convert/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def get_style_vector(style_id, weight):
)


def forward(x, x_len, sid, tone, lang, bert, style, length_scale, sdp_ratio):
def forward(x, x_len, sid, tone, lang, bert, style, length_scale, sdp_ratio, noise_scale, noise_scale_w):
return model.infer(
x,
x_len,
Expand All @@ -105,6 +105,8 @@ def forward(x, x_len, sid, tone, lang, bert, style, length_scale, sdp_ratio):
style,
sdp_ratio=sdp_ratio,
length_scale=length_scale,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
)


Expand All @@ -122,6 +124,8 @@ def forward(x, x_len, sid, tone, lang, bert, style, length_scale, sdp_ratio):
style_vec_tensor,
torch.tensor(1.0),
torch.tensor(0.0),
torch.tensor(0.6777),
torch.tensor(0.8),
),
f"../models/model_{out_name}.onnx",
verbose=True,
Expand All @@ -144,6 +148,8 @@ def forward(x, x_len, sid, tone, lang, bert, style, length_scale, sdp_ratio):
"style_vec",
"length_scale",
"sdp_ratio",
"noise_scale",
"noise_scale_w"
],
output_names=["output"],
)
Expand Down
2 changes: 1 addition & 1 deletion sbv2_api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ axum = "0.7.5"
dotenvy.workspace = true
env_logger.workspace = true
log = "0.4.22"
sbv2_core = { version = "0.2.0-alpha2", path = "../sbv2_core" }
sbv2_core = { version = "0.2.0-alpha2", path = "../sbv2_core", features = ["aivmx"] }
serde = { version = "1.0.210", features = ["derive"] }
tokio = { version = "1.40.0", features = ["full"] }
utoipa = { version = "5.0.0", features = ["axum_extras"] }
Expand Down
14 changes: 14 additions & 0 deletions sbv2_api/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,20 @@ impl AppState {
log::warn!("Error loading {entry}: {e}");
};
log::info!("Loaded: {entry}");
} else if name.ends_with(".aivmx") {
let entry = &name[..name.len() - 6];
log::info!("Try loading: {entry}");
let aivmx_bytes = match fs::read(format!("{models}/{entry}.aivmx")).await {
Ok(b) => b,
Err(e) => {
log::warn!("Error loading aivmx bytes from file {entry}: {e}");
continue;
}
};
if let Err(e) = tts_model.load_aivmx(entry, aivmx_bytes) {
log::error!("Error loading {entry}: {e}");
}
log::info!("Loaded: {entry}");
}
}
for entry in entries {
Expand Down
6 changes: 5 additions & 1 deletion sbv2_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ documentation = "https://docs.rs/sbv2_core"

[dependencies]
anyhow.workspace = true
base64 = { version = "0.22.1", optional = true }
dotenvy.workspace = true
env_logger.workspace = true
hound = "3.5.1"
jpreprocess = { version = "0.10.0", features = ["naist-jdic"] }
ndarray.workspace = true
npyz = { version = "0.8.3", optional = true }
num_cpus = "1.16.0"
once_cell.workspace = true
ort = { git = "https://github.com/pykeio/ort.git", version = "2.0.0-rc.8", optional = true }
Expand All @@ -35,4 +37,6 @@ directml = ["ort/directml", "std"]
tensorrt = ["ort/tensorrt", "std"]
coreml = ["ort/coreml", "std"]
default = ["std"]
no_std = ["tokenizers/unstable_wasm"]
no_std = ["tokenizers/unstable_wasm"]
aivmx = ["npyz", "base64"]
base64 = ["dep:base64"]
3 changes: 3 additions & 0 deletions sbv2_core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ pub enum Error {
HoundError(#[from] hound::Error),
#[error("model not found error")]
ModelNotFoundError(String),
#[cfg(feature = "base64")]
#[error("base64 error")]
Base64Error(#[from] base64::DecodeError),
#[error("other")]
OtherError(String),
}
Expand Down
10 changes: 9 additions & 1 deletion sbv2_core/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,22 @@ fn main_inner() -> anyhow::Result<()> {
.ok()
.and_then(|x| x.parse().ok()),
)?;
tts_holder.load_sbv2file(ident, fs::read(env::var("MODEL_PATH")?)?)?;
#[cfg(not(feature = "aivmx"))]
{
tts_holder.load_sbv2file(ident, fs::read(env::var("MODEL_PATH")?)?)?;
}
#[cfg(feature = "aivmx")]
{
tts_holder.load_aivmx(ident, fs::read(env::var("MODEL_PATH")?)?)?;
}

let audio =
tts_holder.easy_synthesize(ident, &text, 0, 0, tts::SynthesizeOptions::default())?;
fs::write("output.wav", audio)?;

Ok(())
}

#[cfg(not(feature = "std"))]
fn main_inner() -> anyhow::Result<()> {
Ok(())
Expand Down
4 changes: 4 additions & 0 deletions sbv2_core/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ pub fn synthesize(
style_vector: Array1<f32>,
sdp_ratio: f32,
length_scale: f32,
noise_scale: f32,
noise_scale_w: f32,
) -> Result<Array3<f32>> {
let bert = bert_ori.insert_axis(Axis(0));
let x_tst_lengths: Array1<i64> = array![x_tst.shape()[0] as i64];
Expand All @@ -75,6 +77,8 @@ pub fn synthesize(
"style_vec" => style_vector,
"sdp_ratio" => array![sdp_ratio],
"length_scale" => array![length_scale],
"noise_scale" => array![noise_scale],
"noise_scale_w" => array![noise_scale_w]
}?)?;

let audio_array = outputs["output"]
Expand Down
57 changes: 57 additions & 0 deletions sbv2_core/src/tts.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
use crate::error::{Error, Result};
use crate::{jtalk, model, style, tokenizer, tts_util};
#[cfg(feature = "aivmx")]
use base64::prelude::{Engine as _, BASE64_STANDARD};
#[cfg(feature = "aivmx")]
use ndarray::ShapeBuilder;
use ndarray::{concatenate, Array1, Array2, Array3, Axis};
use ort::Session;
#[cfg(feature = "aivmx")]
use std::io::Cursor;
use tokenizers::Tokenizer;

#[derive(PartialEq, Eq, Clone)]
Expand Down Expand Up @@ -69,6 +75,53 @@ impl TTSModelHolder {
self.models.iter().map(|m| m.ident.to_string()).collect()
}

#[cfg(feature = "aivmx")]
pub fn load_aivmx<I: Into<TTSIdent>, P: AsRef<[u8]>>(
&mut self,
ident: I,
aivmx_bytes: P,
) -> Result<()> {
let ident = ident.into();
if self.find_model(ident.clone()).is_err() {
let mut load = true;
if let Some(max) = self.max_loaded_models {
if self.models.iter().filter(|x| x.vits2.is_some()).count() >= max {
load = false;
}
}
let model = model::load_model(&aivmx_bytes, false)?;
let metadata = model.metadata()?;
if let Some(aivm_style_vectors) = metadata.custom("aivm_style_vectors")? {
let aivm_style_vectors = BASE64_STANDARD.decode(aivm_style_vectors)?;
let style_vectors = Cursor::new(&aivm_style_vectors);
let reader = npyz::NpyFile::new(style_vectors)?;
let style_vectors = {
let shape = reader.shape().to_vec();
let order = reader.order();
let data = reader.into_vec::<f32>()?;
let shape = match shape[..] {
[i1, i2] => [i1 as usize, i2 as usize],
_ => panic!("expected 2D array"),
};
let true_shape = shape.set_f(order == npyz::Order::Fortran);
ndarray::Array2::from_shape_vec(true_shape, data)?
};
drop(metadata);
self.models.push(TTSModel {
vits2: if load { Some(model) } else { None },
bytes: if self.max_loaded_models.is_some() {
Some(aivmx_bytes.as_ref().to_vec())
} else {
None
},
ident,
style_vectors,
})
}
}
Ok(())
}

/// Load a .sbv2 file binary
///
/// # Examples
Expand Down Expand Up @@ -257,6 +310,8 @@ impl TTSModelHolder {
style_vector.clone(),
options.sdp_ratio,
options.length_scale,
0.677,
0.8,
)?;
audios.push(audio.clone());
if i != texts.len() - 1 {
Expand All @@ -279,6 +334,8 @@ impl TTSModelHolder {
style_vector,
options.sdp_ratio,
options.length_scale,
0.677,
0.8,
)?
};
tts_util::array_to_vec(audio_array)
Expand Down

0 comments on commit a7fbfa2

Please sign in to comment.