diff --git a/Cargo.lock b/Cargo.lock
index 09c78a8..605c26f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1950,7 +1950,7 @@ dependencies = [
 
 [[package]]
 name = "hopper_rust"
-version = "0.4.8"
+version = "0.4.9"
 dependencies = [
  "anyhow",
  "approx 0.5.1",
diff --git a/Cargo.toml b/Cargo.toml
index 6f3eeb3..0f940cc 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,7 +8,7 @@ name = "hopper_rust"
 publish = false
 readme = "README.md"
 repository = "https://github.com/dmweis/hopper_rust"
-version = "0.4.8"
+version = "0.4.9"
 
 [package.metadata.deb]
 assets = [
diff --git a/src/openai.rs b/src/openai.rs
index 02ed6bc..1e25fd1 100644
--- a/src/openai.rs
+++ b/src/openai.rs
@@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize};
 use serde_json::json;
 use std::{
     collections::HashMap,
-    sync::{atomic::AtomicU8, Arc},
+    sync::{atomic::AtomicU8, Arc, Mutex},
 };
 use tokio::select;
 use tracing::info;
@@ -101,6 +101,16 @@ pub async fn start_openai_controller(
         Arc::new(FaceDisplayFuncCallback),
     )?;
 
+    let voice_provider_arc = Arc::new(Mutex::new(VoiceProvider::Default));
+
+    chat_gpt_conversation.add_function::<SwitchVoiceFuncArgs>(
+        "switch_voice_provider",
+        "Switch voice provider",
+        Arc::new(SwitchVoiceFuncCallback {
+            voice_provider: voice_provider_arc.clone(),
+        }),
+    )?;
+
     let simple_text_command_subscriber = zenoh_session
         .declare_subscriber(HOPPER_OPENAI_COMMAND_SUBSCRIBER)
         .res()
@@ -144,12 +154,15 @@ pub async fn start_openai_controller(
                 info!("Received new zenoh text command");
                 let text_command_msg = text_command_msg?;
                 let text_command: String = text_command_msg.value.try_into()?;
-                process_simple_text_command(&text_command, chat_gpt_conversation.clone(), client.clone(), zenoh_session.clone()).await?;
+                let voice_provider = *voice_provider_arc.lock().unwrap();
+
+                process_simple_text_command(&text_command, chat_gpt_conversation.clone(), client.clone(), zenoh_session.clone(), voice_provider).await?;
             }
             text_command = receiver.recv() => {
                 if let Some(text_command) = text_command {
                     info!("Received new text command");
-                    process_simple_text_command(&text_command, chat_gpt_conversation.clone(), client.clone(), zenoh_session.clone()).await?;
+                    let voice_provider = *voice_provider_arc.lock().unwrap();
+                    process_simple_text_command(&text_command, chat_gpt_conversation.clone(), client.clone(), zenoh_session.clone(), voice_provider).await?;
                 }
             }
             wake_word_detection = wake_word_detection_subscriber.recv_async() => {
@@ -182,7 +195,8 @@ pub async fn start_openai_controller(
                     let wake_word_transcript: AudioTranscript = serde_json::from_str(&wake_word_transcript)?;
                     if wake_word_transcript.wake_word.to_lowercase().contains("hopper") {
                         info!("Received new text command");
-                        process_simple_text_command(&wake_word_transcript.transcript, chat_gpt_conversation.clone(), client.clone(), zenoh_session.clone()).await?;
+                        let voice_provider = *voice_provider_arc.lock().unwrap();
+                        process_simple_text_command(&wake_word_transcript.transcript, chat_gpt_conversation.clone(), client.clone(), zenoh_session.clone(), voice_provider).await?;
                     }
                 }
             }
@@ -205,6 +219,7 @@ async fn process_simple_text_command(
     mut conversation: ChatGptConversation,
     open_ai_client: Client,
     zenoh_session: Arc,
+    voice_provider: VoiceProvider,
 ) -> anyhow::Result<()> {
     info!("Received hopper command {:?}", text_command);
 
@@ -229,7 +244,7 @@ async fn process_simple_text_command(
         info!("Assistant response form ChatGPT: {:?}", response);
 
         tokio::spawn(async move {
-            if let Err(err) = speak_with_face_animation(&response).await {
+            if let Err(err) = speak_with_face_animation(&response, voice_provider).await {
                 tracing::error!("Failed to speak with face animation: {}", err);
             }
         });
@@ -253,12 +268,24 @@ async fn process_simple_text_command(
     Ok(())
 }
 
-async fn speak_with_face_animation(message: &str) -> anyhow::Result<()> {
-    IocContainer::global_instance()
-        .service::()?
-        .say_azure_with_style(message, crate::speech::AzureVoiceStyle::Cheerful)
-        // .say_eleven_with_default_voice(message)
-        .await?;
+async fn speak_with_face_animation(
+    message: &str,
+    voice_provider: VoiceProvider,
+) -> anyhow::Result<()> {
+    match voice_provider {
+        VoiceProvider::Default => {
+            IocContainer::global_instance()
+                .service::()?
+                .say_azure_with_style(message, crate::speech::AzureVoiceStyle::Cheerful)
+                .await?;
+        }
+        VoiceProvider::Expensive => {
+            IocContainer::global_instance()
+                .service::()?
+                .say_eleven_with_default_voice(message)
+                .await?;
+        }
+    }
 
     IocContainer::global_instance()
         .service::()?
@@ -566,6 +593,20 @@ impl AsyncCallback for HopperBodyPoseFuncCallback {
             HopperBodyPose::Sitting => "ground",
         };
 
+        // stop high fives in case we are sitting down or folding
+        match hopper_body_pose_func.body_pose {
+            HopperBodyPose::Folded | HopperBodyPose::Sitting => {
+                IocContainer::global_instance()
+                    .service::()?
+                    .set_active(false);
+
+                IocContainer::global_instance()
+                    .service::()?
+                    .set_active(false);
+            }
+            _ => (),
+        }
+
         self.zenoh_session
             .put(STANCE_SUBSCRIBER, message)
             .res()
@@ -707,3 +748,38 @@ impl AsyncCallback for FaceDisplayFuncCallback {
         Ok(result)
     }
 }
+
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
+pub struct SwitchVoiceFuncArgs {
+    /// TTS voice provider
+    pub voice_provider: VoiceProvider,
+}
+
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
+#[serde(rename_all = "snake_case")]
+pub enum VoiceProvider {
+    /// default voice provider Microsoft Azure
+    #[default]
+    Default,
+    /// expensive voice provider form Eleven Labs
+    /// Should be used carefully
+    Expensive,
+}
+
+struct SwitchVoiceFuncCallback {
+    voice_provider: Arc<Mutex<VoiceProvider>>,
+}
+
+#[async_trait]
+impl AsyncCallback for SwitchVoiceFuncCallback {
+    async fn call(&self, args: &str) -> anyhow::Result {
+        let switch_voice: SwitchVoiceFuncArgs = serde_json::from_str(args)?;
+
+        let mut voice_provider = self.voice_provider.lock().unwrap();
+
+        *voice_provider = switch_voice.voice_provider;
+
+        let result = json!({});
+        Ok(result)
+    }
+}
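
Note on the new function arguments: because of serde's `rename_all = "snake_case"`, the model is expected to pass `"default"` or `"expensive"` as the `voice_provider` value when it calls `switch_voice_provider`. A minimal, self-contained sketch of that round trip, mirroring the types added above (the `JsonSchema` derive and the `AsyncCallback` plumbing are omitted here; only serde and serde_json are assumed):

use serde::{Deserialize, Serialize};

// Mirrors the enum added in src/openai.rs; JsonSchema derive omitted in this sketch.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum VoiceProvider {
    #[default]
    Default,   // Azure TTS
    Expensive, // Eleven Labs TTS
}

// Mirrors the argument struct the callback deserializes.
#[derive(Debug, Deserialize)]
struct SwitchVoiceFuncArgs {
    voice_provider: VoiceProvider,
}

fn main() {
    // Arguments as the model would send them for "switch_voice_provider".
    let args = r#"{ "voice_provider": "expensive" }"#;
    let parsed: SwitchVoiceFuncArgs =
        serde_json::from_str(args).expect("valid function arguments");
    assert_eq!(parsed.voice_provider, VoiceProvider::Expensive);
}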