Skip to content

Commit

Permalink
Add eleven labs
Browse files Browse the repository at this point in the history
  • Loading branch information
dmweis committed Nov 5, 2023
1 parent 78e7c04 commit 0ab8526
Show file tree
Hide file tree
Showing 9 changed files with 262 additions and 5 deletions.
4 changes: 3 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ name = "hopper_rust"
publish = false
readme = "README.md"
repository = "https://github.com/dmweis/hopper_rust"
version = "0.3.11"
version = "0.3.12"

[package.metadata.deb]
assets = [
Expand Down Expand Up @@ -60,6 +60,8 @@ rand = "0.8"
sha2 = "0.10"
thiserror = "1.0"
walkdir = "2.3.3"
bytes = "1.4"
reqwest = {version = "0.11", features = ["json"]}

# logging
tracing = {version = "0.1", features = ["log"]}
Expand Down
1 change: 1 addition & 0 deletions config/settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ base:
face_port: "/dev/hopper_face"
tts_service_config:
azure_api_key: ""
eleven_labs_api_key: ""
cache_dir_path: "/var/cache/hopper/audio_cache"
audio_repository_path: "/etc/hopper/audio/"
lidar:
Expand Down
5 changes: 4 additions & 1 deletion examples/astromech.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ struct Args {
async fn main() -> Result<()> {
let args = Args::parse();

let mut speech_service = SpeechService::new(String::from(""), None, Some(args.audio)).unwrap();
let mut speech_service =
SpeechService::new(String::from(""), String::from(""), None, Some(args.audio))
.await
.unwrap();

if let Some(text) = args.text {
speech_service.say_astromech(&text).await.unwrap();
Expand Down
7 changes: 7 additions & 0 deletions src/bin/hopper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,23 @@ async fn main() -> Result<()> {

let mut speech_service = SpeechService::new(
app_config.tts_service_config.azure_api_key,
app_config.tts_service_config.eleven_labs_api_key,
app_config.tts_service_config.cache_dir_path,
app_config.tts_service_config.audio_repository_path,
)
.await
.unwrap();

speech_service
.play_sound("hopper_sounds/windows_startup.wav")
.await
.unwrap();

speech_service
.say_eleven("Hopper ready", "Natasha")
.await
.unwrap();

ioc_container.register(TokioMutex::new(speech_service));

start_speech_controller(
Expand Down
1 change: 1 addition & 0 deletions src/configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub struct BaseConfig {
/// Text-to-speech section of the application configuration.
///
/// The fields are handed to `SpeechService::new` at startup.
#[derive(Deserialize, Debug, Clone)]
pub struct TtsServiceConfig {
/// Subscription key for the Azure speech service.
pub azure_api_key: String,
/// API key for the Eleven Labs service.
pub eleven_labs_api_key: String,
/// Directory used for the audio cache; `None` when not configured.
pub cache_dir_path: Option<String>,
/// Directory of the audio repository; `None` when not configured.
pub audio_repository_path: Option<String>,
}
Expand Down
172 changes: 172 additions & 0 deletions src/speech/eleven_labs_client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
use anyhow::Context;
use anyhow::Result;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// JSON body sent to the text-to-speech endpoint.
///
/// Optional fields that are `None` are left out of the serialized JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct TtsRequest {
/// Identifier of the model that will be used.
/// Defaults to "eleven_monolingual_v1"
#[serde(skip_serializing_if = "Option::is_none")]
pub model_id: Option<String>,
/// The text that will get converted into speech.
pub text: String,
/// Voice settings overriding stored settings for the given voice.
/// They are applied only on the given TTS request.
/// Defaults to None
#[serde(skip_serializing_if = "Option::is_none")]
pub voice_settings: Option<VoiceSettings>,
}

/// Tuning parameters for a voice.
///
/// Used both as a per-request override in `TtsRequest` and as the stored
/// settings reported on a `Voice`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct VoiceSettings {
/// Similarity boost value passed to the service.
pub similarity_boost: f64,
/// Stability value passed to the service.
pub stability: f64,
/// Optional style value; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub style: Option<f64>,
/// Optional speaker-boost flag; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub use_speaker_boost: Option<bool>,
}

/// Response body of the voices listing endpoint: a wrapper around the list
/// of voices.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Voices {
/// All voices returned by the service.
pub voices: Vec<Voice>,
}

impl Voices {
    /// Builds a lookup table mapping each voice's name to its voice id.
    ///
    /// When several voices share a name, the one listed last wins.
    pub fn name_to_id_table(&self) -> HashMap<String, String> {
        self.voices
            .iter()
            .map(|entry| (entry.name.clone(), entry.voice_id.clone()))
            .collect()
    }
}

/// A single voice entry as returned inside `Voices`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Voice {
/// Identifier of the voice; this is the value `Voices::name_to_id_table`
/// maps names to.
pub voice_id: String,
/// Name of the voice; used as the key in `Voices::name_to_id_table`.
pub name: String,
/// Samples attached to the voice, if the service reports any.
pub samples: Option<Vec<VoiceSample>>,
/// Category of the voice, if reported.
pub category: Option<String>,
/// Free-form key/value labels, if reported.
pub labels: Option<HashMap<String, String>>,
/// Description of the voice, if reported.
pub description: Option<String>,
/// URL of a preview, if reported.
pub preview_url: Option<String>,
/// Stored settings of the voice, if reported.
pub settings: Option<VoiceSettings>,
}

/// Metadata of one sample attached to a `Voice`.
///
/// Only `sample_id` is public; the remaining fields are private to this
/// module and exist so the payload (de)serializes.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct VoiceSample {
/// Identifier of the sample.
pub sample_id: String,
file_name: String,
mime_type: String,
// Optional in the payload.
size_bytes: Option<i64>,
hash: String,
}

/// Response body of the user subscription endpoint.
///
/// Only the character counters are public; the other fields are private to
/// this module and exist so the payload (de)serializes.
#[derive(Debug, Serialize, Deserialize)]
pub struct Subscription {
tier: String,
/// Characters counted against the limit so far (see `character_left`).
pub character_count: i64,
/// Character limit of the subscription (see `character_left`).
pub character_limit: i64,
can_extend_character_limit: bool,
allowed_to_extend_character_limit: bool,
// NOTE(review): presumably a Unix timestamp, judging by the name - confirm.
next_character_count_reset_unix: i64,
voice_limit: i64,
professional_voice_limit: i64,
can_extend_voice_limit: bool,
can_use_instant_voice_cloning: bool,
can_use_professional_voice_cloning: bool,
currency: Option<String>,
status: String,
// Optional in the payload.
next_invoice: Option<Invoice>,
}

impl Subscription {
    /// Returns how many characters remain: `character_limit` minus
    /// `character_count`.
    #[allow(dead_code)]
    pub fn character_left(&self) -> i64 {
        let Self {
            character_limit,
            character_count,
            ..
        } = self;
        character_limit - character_count
    }
}

/// Upcoming invoice details embedded in `Subscription::next_invoice`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Invoice {
// NOTE(review): presumably an amount in cents, judging by the name - confirm.
amount_due_cents: i64,
// NOTE(review): presumably a Unix timestamp, judging by the name - confirm.
next_payment_attempt_unix: i64,
}

/// HTTP client for the Eleven Labs API at `https://api.elevenlabs.io`.
#[derive(Debug, Clone)]
pub struct ElevenLabsTtsClient {
/// Underlying HTTP client used for every request.
client: reqwest::Client,
/// API key sent in the `xi-api-key` header of every request.
api_key: String,
}

impl ElevenLabsTtsClient {
pub fn new(api_key: String) -> Self {
ElevenLabsTtsClient {
client: reqwest::Client::new(),
api_key,
}
}

pub async fn tts(&self, text: &str, voice_id: &str) -> Result<Bytes> {
let url = format!("https://api.elevenlabs.io/v1/text-to-speech/{}", voice_id);

let body = TtsRequest {
text: text.to_owned(),
model_id: Some(String::from("eleven_multilingual_v1")),
voice_settings: Some(VoiceSettings {
similarity_boost: 0.5,
stability: 0.5,
style: None,
use_speaker_boost: None,
}),
};

let resp = self
.client
.post(url)
.header("xi-api-key", self.api_key.clone())
.header("accept", "audio/mpeg")
.header("Content-Type", "application/json")
.json(&body)
.send()
.await?;
resp.error_for_status_ref()
.context("Request failed with status")?;
let data = resp.bytes().await?;
Ok(data)
}

pub async fn voices(&self) -> Result<Voices> {
let resp = self
.client
.get("https://api.elevenlabs.io/v1/voices")
.header("xi-api-key", self.api_key.clone())
.header("accept", "application/json")
.send()
.await?;
resp.error_for_status_ref()
.context("Request failed with status")?;
let data = resp.json::<Voices>().await?;

Ok(data)
}

#[allow(dead_code)]
pub async fn get_subscription_info(&self) -> Result<Subscription> {
let resp = self
.client
.get("https://api.elevenlabs.io/v1/user/subscription")
.header("xi-api-key", self.api_key.clone())
.header("accept", "application/json")
.send()
.await?;
resp.error_for_status_ref()
.context("Request failed with status")?;
let data = resp.json::<Subscription>().await?;

Ok(data)
}
}
3 changes: 3 additions & 0 deletions src/speech/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ mod audio_repository;
#[cfg(feature = "audio")]
mod speech_service;

#[cfg(feature = "audio")]
mod eleven_labs_client;

#[cfg(feature = "audio")]
pub use speech_service::SpeechService;

Expand Down
70 changes: 68 additions & 2 deletions src/speech/speech_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::audio_cache::AudioCache;
use super::audio_repository::AudioRepository;
use super::AzureVoiceStyle;
use crate::error::{HopperError, HopperResult};
use anyhow::Context;
use sha2::{Digest, Sha256};
use std::{fs::File, io::Cursor, thread};
use tokio::sync::mpsc::{channel, Receiver, Sender};
Expand Down Expand Up @@ -32,6 +33,15 @@ fn hash_azure_tts(
format!("{}-{:x}", voice.name, hashed)
}

/// Derives the audio-cache key for an Eleven Labs utterance.
///
/// The key is `eleven-` followed by the lowercase hex SHA-256 of the text,
/// the voice id and the big-endian bytes of `AZURE_FORMAT_VERSION`, fed to
/// the hasher in that order.
fn hash_eleven_labs_tts(text: &str, voice_id: &str) -> String {
    let digest = Sha256::new()
        .chain_update(text)
        .chain_update(voice_id)
        .chain_update(AZURE_FORMAT_VERSION.to_be_bytes())
        .finalize();
    format!("eleven-{:x}", digest)
}

enum AudioPlayerCommand {
Play(Box<dyn Playable>),
Pause,
Expand Down Expand Up @@ -128,6 +138,8 @@ fn create_player() -> Sender<AudioPlayerCommand> {

pub struct SpeechService {
azure_speech_client: azure_tts::VoiceService,
eleven_labs_client: super::eleven_labs_client::ElevenLabsTtsClient,
voice_name_to_voice_id_table: std::collections::HashMap<String, String>,
audio_cache: Option<AudioCache>,
audio_repository: Option<AudioRepository>,
azure_voice: azure_tts::VoiceSettings,
Expand All @@ -140,15 +152,25 @@ pub trait Playable: std::io::Read + std::io::Seek + Send + Sync {}
impl Playable for Cursor<Vec<u8>> {}
impl Playable for File {}

/// Eleven Labs voice id used by `say_eleven_with_default_voice`
/// (the voice named "Freya").
const DEFAULT_ELEVEN_LABS_VOICE_ID: &str = "jsCqWAovK2LkecY7zXl4";

impl SpeechService {
pub fn new(
pub async fn new(
azure_subscription_key: String,
eleven_labs_api_key: String,
cache_dir_path: Option<String>,
audio_repository_path: Option<String>,
) -> HopperResult<SpeechService> {
) -> anyhow::Result<SpeechService> {
let azure_speech_client =
azure_tts::VoiceService::new(&azure_subscription_key, azure_tts::Region::uksouth);

let eleven_labs_client =
super::eleven_labs_client::ElevenLabsTtsClient::new(eleven_labs_api_key);

let voices = eleven_labs_client.voices().await?;
let voice_name_to_voice_id_table = voices.name_to_id_table();

let audio_cache = match cache_dir_path {
Some(path) => Some(AudioCache::new(path)?),
None => None,
Expand All @@ -163,6 +185,8 @@ impl SpeechService {

Ok(SpeechService {
azure_speech_client,
eleven_labs_client,
voice_name_to_voice_id_table,
audio_cache,
audio_repository,
azure_voice: azure_tts::EnUsVoices::SaraNeural.to_voice_settings(),
Expand Down Expand Up @@ -321,6 +345,48 @@ impl SpeechService {
.await
}

/// Speaks `text` through Eleven Labs using `DEFAULT_ELEVEN_LABS_VOICE_ID`.
///
/// Any error from `say_eleven_with_voice_id` is returned unchanged.
pub async fn say_eleven_with_default_voice(&mut self, text: &str) -> anyhow::Result<()> {
    self.say_eleven_with_voice_id(text, DEFAULT_ELEVEN_LABS_VOICE_ID)
        .await
}

/// Speaks `text` through Eleven Labs using the voice called `voice_name`.
///
/// The name is resolved to a voice id through the table built when the
/// service was created. Returns an "Unknown voice" error when the name is
/// not in that table.
pub async fn say_eleven(&mut self, text: &str, voice_name: &str) -> anyhow::Result<()> {
    // Take an owned copy of the id so the table borrow ends before `self`
    // is borrowed mutably below.
    let voice_id = self
        .voice_name_to_voice_id_table
        .get(voice_name)
        .cloned()
        .context("Unknown voice")?;
    info!("Using voice id {} for voice {}", voice_id, voice_name);
    self.say_eleven_with_voice_id(text, &voice_id).await
}

/// Speaks `text` through Eleven Labs using the voice with id `voice_id`.
///
/// When an audio cache is configured, the cache is consulted first under
/// the key from `hash_eleven_labs_tts`. On a miss the audio is requested
/// from the service and stored in the cache. Without a cache the audio is
/// always requested. The resulting sound is then handed to `play`.
pub async fn say_eleven_with_voice_id(
    &mut self,
    text: &str,
    voice_id: &str,
) -> anyhow::Result<()> {
    let sound: Box<dyn Playable> = match self.audio_cache {
        None => {
            let audio = self.eleven_labs_client.tts(text, voice_id).await?;
            Box::new(Cursor::new(audio.to_vec()))
        }
        Some(ref audio_cache) => {
            let file_key = hash_eleven_labs_tts(text, voice_id);
            match audio_cache.get(&file_key) {
                Some(cached) => {
                    info!("Using cached value with key {}", file_key);
                    cached
                }
                None => {
                    info!("Writing new file with key {}", file_key);
                    let audio = self.eleven_labs_client.tts(text, voice_id).await?;
                    let fresh: Box<dyn Playable> = Box::new(Cursor::new(audio.to_vec()));
                    audio_cache.set(&file_key, audio.to_vec())?;
                    fresh
                }
            }
        }
    };
    self.play(sound).await;
    Ok(())
}

pub async fn pause(&self) {
self.audio_sender
.send(AudioPlayerCommand::Pause)
Expand Down

0 comments on commit 0ab8526

Please sign in to comment.