From 0ab8526c61f4f930e0fcb0b4613f818c93872c20 Mon Sep 17 00:00:00 2001 From: David Weis Date: Sun, 5 Nov 2023 18:29:20 +0000 Subject: [PATCH] Add eleven labs --- Cargo.lock | 4 +- Cargo.toml | 4 +- config/settings.yaml | 1 + examples/astromech.rs | 5 +- src/bin/hopper.rs | 7 ++ src/configuration.rs | 1 + src/speech/eleven_labs_client.rs | 172 +++++++++++++++++++++++++++++++ src/speech/mod.rs | 3 + src/speech/speech_service.rs | 70 ++++++++++++- 9 files changed, 262 insertions(+), 5 deletions(-) create mode 100644 src/speech/eleven_labs_client.rs diff --git a/Cargo.lock b/Cargo.lock index 864ccc3..e0b95a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1836,13 +1836,14 @@ dependencies = [ [[package]] name = "hopper_rust" -version = "0.3.11" +version = "0.3.12" dependencies = [ "anyhow", "approx 0.5.1", "async-trait", "azure_tts", "bitflags 2.3.3", + "bytes", "chrono", "clap 4.3.19", "config", @@ -1862,6 +1863,7 @@ dependencies = [ "prost-reflect-build", "prost-types", "rand", + "reqwest", "rodio", "rplidar_driver", "serde", diff --git a/Cargo.toml b/Cargo.toml index 11e98aa..941bd32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ name = "hopper_rust" publish = false readme = "README.md" repository = "https://github.com/dmweis/hopper_rust" -version = "0.3.11" +version = "0.3.12" [package.metadata.deb] assets = [ @@ -60,6 +60,8 @@ rand = "0.8" sha2 = "0.10" thiserror = "1.0" walkdir = "2.3.3" +bytes = "1.4" +reqwest = {version = "0.11", features = ["json"]} # logging tracing = {version = "0.1", features = ["log"]} diff --git a/config/settings.yaml b/config/settings.yaml index b183520..54e1ca6 100644 --- a/config/settings.yaml +++ b/config/settings.yaml @@ -3,6 +3,7 @@ base: face_port: "/dev/hopper_face" tts_service_config: azure_api_key: "" + eleven_labs_api_key: "" cache_dir_path: "/var/cache/hopper/audio_cache" audio_repository_path: "/etc/hopper/audio/" lidar: diff --git a/examples/astromech.rs b/examples/astromech.rs index 8d24bff..b7e1f4f 100644 --- a/examples/astromech.rs +++ b/examples/astromech.rs @@ -17,7 +17,10 @@ struct Args { async fn main() -> Result<()> { let args = Args::parse(); - let mut speech_service = SpeechService::new(String::from(""), None, Some(args.audio)).unwrap(); + let mut speech_service = + SpeechService::new(String::from(""), String::from(""), None, Some(args.audio)) + .await + .unwrap(); if let Some(text) = args.text { speech_service.say_astromech(&text).await.unwrap(); diff --git a/src/bin/hopper.rs b/src/bin/hopper.rs index 0b0d78a..03a4876 100644 --- a/src/bin/hopper.rs +++ b/src/bin/hopper.rs @@ -85,9 +85,11 @@ async fn main() -> Result<()> { let mut speech_service = SpeechService::new( app_config.tts_service_config.azure_api_key, + app_config.tts_service_config.eleven_labs_api_key, app_config.tts_service_config.cache_dir_path, app_config.tts_service_config.audio_repository_path, ) + .await .unwrap(); speech_service @@ -95,6 +97,11 @@ async fn main() -> Result<()> { .await .unwrap(); + speech_service + .say_eleven("Hopper ready", "Natasha") + .await + .unwrap(); + ioc_container.register(TokioMutex::new(speech_service)); start_speech_controller( diff --git a/src/configuration.rs b/src/configuration.rs index 9efde0f..e2461db 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -48,6 +48,7 @@ pub struct BaseConfig { #[derive(Deserialize, Debug, Clone)] pub struct TtsServiceConfig { pub azure_api_key: String, + pub eleven_labs_api_key: String, pub cache_dir_path: Option, pub audio_repository_path: Option, } diff --git a/src/speech/eleven_labs_client.rs b/src/speech/eleven_labs_client.rs new file mode 100644 index 0000000..34c1e07 --- /dev/null +++ b/src/speech/eleven_labs_client.rs @@ -0,0 +1,172 @@ +use anyhow::Context; +use anyhow::Result; +use bytes::Bytes; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Serialize, Deserialize)] +pub struct TtsRequest { + /// Identifier of the model that will be used. + /// Defaults to "eleven_monolingual_v1" + #[serde(skip_serializing_if = "Option::is_none")] + pub model_id: Option, + /// The text that will get converted into speech. + pub text: String, + /// Voice settings overriding stored setttings for the given voice. + /// They are applied only on the given TTS request. + /// Defaults to None + #[serde(skip_serializing_if = "Option::is_none")] + pub voice_settings: Option, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct VoiceSettings { + pub similarity_boost: f64, + pub stability: f64, + #[serde(skip_serializing_if = "Option::is_none")] + pub style: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub use_speaker_boost: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Voices { + pub voices: Vec, +} + +impl Voices { + pub fn name_to_id_table(&self) -> HashMap { + let mut table = HashMap::new(); + for voice in &self.voices { + table.insert(voice.name.clone(), voice.voice_id.clone()); + } + table + } +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct Voice { + pub voice_id: String, + pub name: String, + pub samples: Option>, + pub category: Option, + pub labels: Option>, + pub description: Option, + pub preview_url: Option, + pub settings: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct VoiceSample { + pub sample_id: String, + file_name: String, + mime_type: String, + size_bytes: Option, + hash: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Subscription { + tier: String, + pub character_count: i64, + pub character_limit: i64, + can_extend_character_limit: bool, + allowed_to_extend_character_limit: bool, + next_character_count_reset_unix: i64, + voice_limit: i64, + professional_voice_limit: i64, + can_extend_voice_limit: bool, + can_use_instant_voice_cloning: bool, + can_use_professional_voice_cloning: bool, + currency: Option, + status: String, + next_invoice: Option, +} + +impl Subscription { + #[allow(dead_code)] + pub fn character_left(&self) -> i64 { + self.character_limit - self.character_count + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Invoice { + amount_due_cents: i64, + next_payment_attempt_unix: i64, +} + +#[derive(Debug, Clone)] +pub struct ElevenLabsTtsClient { + client: reqwest::Client, + api_key: String, +} + +impl ElevenLabsTtsClient { + pub fn new(api_key: String) -> Self { + ElevenLabsTtsClient { + client: reqwest::Client::new(), + api_key, + } + } + + pub async fn tts(&self, text: &str, voice_id: &str) -> Result { + let url = format!("https://api.elevenlabs.io/v1/text-to-speech/{}", voice_id); + + let body = TtsRequest { + text: text.to_owned(), + model_id: Some(String::from("eleven_multilingual_v1")), + voice_settings: Some(VoiceSettings { + similarity_boost: 0.5, + stability: 0.5, + style: None, + use_speaker_boost: None, + }), + }; + + let resp = self + .client + .post(url) + .header("xi-api-key", self.api_key.clone()) + .header("accept", "audio/mpeg") + .header("Content-Type", "application/json") + .json(&body) + .send() + .await?; + resp.error_for_status_ref() + .context("Request failed with status")?; + let data = resp.bytes().await?; + Ok(data) + } + + pub async fn voices(&self) -> Result { + let resp = self + .client + .get("https://api.elevenlabs.io/v1/voices") + .header("xi-api-key", self.api_key.clone()) + .header("accept", "application/json") + .send() + .await?; + resp.error_for_status_ref() + .context("Request failed with status")?; + let data = resp.json::().await?; + + Ok(data) + } + + #[allow(dead_code)] + pub async fn get_subscription_info(&self) -> Result { + let resp = self + .client + .get("https://api.elevenlabs.io/v1/user/subscription") + .header("xi-api-key", self.api_key.clone()) + .header("accept", "application/json") + .send() + .await?; + resp.error_for_status_ref() + .context("Request failed with status")?; + let data = resp.json::().await?; + + Ok(data) + } +} diff --git a/src/speech/mod.rs b/src/speech/mod.rs index 6afdd02..61dcf85 100644 --- a/src/speech/mod.rs +++ b/src/speech/mod.rs @@ -7,6 +7,9 @@ mod audio_repository; #[cfg(feature = "audio")] mod speech_service; +#[cfg(feature = "audio")] +mod eleven_labs_client; + #[cfg(feature = "audio")] pub use speech_service::SpeechService; diff --git a/src/speech/speech_service.rs b/src/speech/speech_service.rs index 6519909..8dae344 100644 --- a/src/speech/speech_service.rs +++ b/src/speech/speech_service.rs @@ -2,6 +2,7 @@ use super::audio_cache::AudioCache; use super::audio_repository::AudioRepository; use super::AzureVoiceStyle; use crate::error::{HopperError, HopperResult}; +use anyhow::Context; use sha2::{Digest, Sha256}; use std::{fs::File, io::Cursor, thread}; use tokio::sync::mpsc::{channel, Receiver, Sender}; @@ -32,6 +33,15 @@ fn hash_azure_tts( format!("{}-{:x}", voice.name, hashed) } +fn hash_eleven_labs_tts(text: &str, voice_id: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(text); + hasher.update(voice_id); + hasher.update(AZURE_FORMAT_VERSION.to_be_bytes()); + let hashed = hasher.finalize(); + format!("eleven-{:x}", hashed) +} + enum AudioPlayerCommand { Play(Box), Pause, @@ -128,6 +138,8 @@ fn create_player() -> Sender { pub struct SpeechService { azure_speech_client: azure_tts::VoiceService, + eleven_labs_client: super::eleven_labs_client::ElevenLabsTtsClient, + voice_name_to_voice_id_table: std::collections::HashMap, audio_cache: Option, audio_repository: Option, azure_voice: azure_tts::VoiceSettings, @@ -140,15 +152,25 @@ pub trait Playable: std::io::Read + std::io::Seek + Send + Sync {} impl Playable for Cursor> {} impl Playable for File {} +/// voice Freya +const DEFAULT_ELEVEN_LABS_VOICE_ID: &str = "jsCqWAovK2LkecY7zXl4"; + impl SpeechService { - pub fn new( + pub async fn new( azure_subscription_key: String, + eleven_labs_api_key: String, cache_dir_path: Option, audio_repository_path: Option, - ) -> HopperResult { + ) -> anyhow::Result { let azure_speech_client = azure_tts::VoiceService::new(&azure_subscription_key, azure_tts::Region::uksouth); + let eleven_labs_client = + super::eleven_labs_client::ElevenLabsTtsClient::new(eleven_labs_api_key); + + let voices = eleven_labs_client.voices().await?; + let voice_name_to_voice_id_table = voices.name_to_id_table(); + let audio_cache = match cache_dir_path { Some(path) => Some(AudioCache::new(path)?), None => None, @@ -163,6 +185,8 @@ impl SpeechService { Ok(SpeechService { azure_speech_client, + eleven_labs_client, + voice_name_to_voice_id_table, audio_cache, audio_repository, azure_voice: azure_tts::EnUsVoices::SaraNeural.to_voice_settings(), @@ -321,6 +345,48 @@ impl SpeechService { .await } + pub async fn say_eleven_with_default_voice(&mut self, text: &str) -> anyhow::Result<()> { + self.say_eleven_with_voice_id(text, DEFAULT_ELEVEN_LABS_VOICE_ID) + .await?; + Ok(()) + } + + pub async fn say_eleven(&mut self, text: &str, voice_name: &str) -> anyhow::Result<()> { + let voice_id = self + .voice_name_to_voice_id_table + .get(voice_name) + .context("Unknown voice")? + .clone(); + info!("Using voice id {} for voice {}", voice_id, voice_name); + self.say_eleven_with_voice_id(text, &voice_id).await?; + Ok(()) + } + + pub async fn say_eleven_with_voice_id( + &mut self, + text: &str, + voice_id: &str, + ) -> anyhow::Result<()> { + let sound: Box = if let Some(ref audio_cache) = self.audio_cache { + let file_key = hash_eleven_labs_tts(text, voice_id); + if let Some(file) = audio_cache.get(&file_key) { + info!("Using cached value with key {}", file_key); + file + } else { + info!("Writing new file with key {}", file_key); + let data = self.eleven_labs_client.tts(text, voice_id).await?; + let sound: Box = Box::new(Cursor::new(data.to_vec())); + audio_cache.set(&file_key, data.to_vec())?; + sound + } + } else { + let data = self.eleven_labs_client.tts(text, voice_id).await?; + Box::new(Cursor::new(data.to_vec())) + }; + self.play(sound).await; + Ok(()) + } + pub async fn pause(&self) { self.audio_sender .send(AudioPlayerCommand::Pause)