Skip to content

Commit

Permalink
Voice can toggle voice provider
Browse files Browse the repository at this point in the history
  • Loading branch information
dmweis committed Nov 22, 2023
1 parent 7009c8f commit bfaa0c5
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ name = "hopper_rust"
publish = false
readme = "README.md"
repository = "https://github.com/dmweis/hopper_rust"
version = "0.4.8"
version = "0.4.9"

[package.metadata.deb]
assets = [
Expand Down
98 changes: 87 additions & 11 deletions src/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{
collections::HashMap,
sync::{atomic::AtomicU8, Arc},
sync::{atomic::AtomicU8, Arc, Mutex},
};
use tokio::select;
use tracing::info;
Expand Down Expand Up @@ -101,6 +101,16 @@ pub async fn start_openai_controller(
Arc::new(FaceDisplayFuncCallback),
)?;

let voice_provider_arc = Arc::new(Mutex::new(VoiceProvider::Default));

chat_gpt_conversation.add_function::<SwitchVoiceFuncArgs>(
"switch_voice_provider",
"Switch voice provider",
Arc::new(SwitchVoiceFuncCallback {
voice_provider: voice_provider_arc.clone(),
}),
)?;

let simple_text_command_subscriber = zenoh_session
.declare_subscriber(HOPPER_OPENAI_COMMAND_SUBSCRIBER)
.res()
Expand Down Expand Up @@ -144,12 +154,15 @@ pub async fn start_openai_controller(
info!("Received new zenoh text command");
let text_command_msg = text_command_msg?;
let text_command: String = text_command_msg.value.try_into()?;
process_simple_text_command(&text_command, chat_gpt_conversation.clone(), client.clone(), zenoh_session.clone()).await?;
let voice_provider = *voice_provider_arc.lock().unwrap();

process_simple_text_command(&text_command, chat_gpt_conversation.clone(), client.clone(), zenoh_session.clone(), voice_provider).await?;
}
text_command = receiver.recv() => {
if let Some(text_command) = text_command {
info!("Received new text command");
process_simple_text_command(&text_command, chat_gpt_conversation.clone(), client.clone(), zenoh_session.clone()).await?;
let voice_provider = *voice_provider_arc.lock().unwrap();
process_simple_text_command(&text_command, chat_gpt_conversation.clone(), client.clone(), zenoh_session.clone(), voice_provider).await?;
}
}
wake_word_detection = wake_word_detection_subscriber.recv_async() => {
Expand Down Expand Up @@ -182,7 +195,8 @@ pub async fn start_openai_controller(
let wake_word_transcript: AudioTranscript = serde_json::from_str(&wake_word_transcript)?;
if wake_word_transcript.wake_word.to_lowercase().contains("hopper") {
info!("Received new text command");
process_simple_text_command(&wake_word_transcript.transcript, chat_gpt_conversation.clone(), client.clone(), zenoh_session.clone()).await?;
let voice_provider = *voice_provider_arc.lock().unwrap();
process_simple_text_command(&wake_word_transcript.transcript, chat_gpt_conversation.clone(), client.clone(), zenoh_session.clone(), voice_provider).await?;
}
}
}
Expand All @@ -205,6 +219,7 @@ async fn process_simple_text_command(
mut conversation: ChatGptConversation,
open_ai_client: Client<OpenAIConfig>,
zenoh_session: Arc<zenoh::Session>,
voice_provider: VoiceProvider,
) -> anyhow::Result<()> {
info!("Received hopper command {:?}", text_command);

Expand All @@ -229,7 +244,7 @@ async fn process_simple_text_command(
info!("Assistant response form ChatGPT: {:?}", response);

tokio::spawn(async move {
if let Err(err) = speak_with_face_animation(&response).await {
if let Err(err) = speak_with_face_animation(&response, voice_provider).await {
tracing::error!("Failed to speak with face animation: {}", err);
}
});
Expand All @@ -253,12 +268,24 @@ async fn process_simple_text_command(
Ok(())
}

async fn speak_with_face_animation(message: &str) -> anyhow::Result<()> {
IocContainer::global_instance()
.service::<SpeechService>()?
.say_azure_with_style(message, crate::speech::AzureVoiceStyle::Cheerful)
// .say_eleven_with_default_voice(message)
.await?;
async fn speak_with_face_animation(
message: &str,
voice_provider: VoiceProvider,
) -> anyhow::Result<()> {
match voice_provider {
VoiceProvider::Default => {
IocContainer::global_instance()
.service::<SpeechService>()?
.say_azure_with_style(message, crate::speech::AzureVoiceStyle::Cheerful)
.await?;
}
VoiceProvider::Expensive => {
IocContainer::global_instance()
.service::<SpeechService>()?
.say_eleven_with_default_voice(message)
.await?;
}
}

IocContainer::global_instance()
.service::<crate::face::FaceController>()?
Expand Down Expand Up @@ -566,6 +593,20 @@ impl AsyncCallback for HopperBodyPoseFuncCallback {
HopperBodyPose::Sitting => "ground",
};

// stop high fives in case we are sitting down or folding
match hopper_body_pose_func.body_pose {
HopperBodyPose::Folded | HopperBodyPose::Sitting => {
IocContainer::global_instance()
.service::<HighFiveServiceController>()?
.set_active(false);

IocContainer::global_instance()
.service::<LidarServiceController>()?
.set_active(false);
}
_ => (),
}

self.zenoh_session
.put(STANCE_SUBSCRIBER, message)
.res()
Expand Down Expand Up @@ -707,3 +748,38 @@ impl AsyncCallback for FaceDisplayFuncCallback {
Ok(result)
}
}

// Argument payload for the `switch_voice_provider` chat function. It is
// deserialized from the JSON argument string in `SwitchVoiceFuncCallback::call`.
// NOTE(review): with the `JsonSchema` derive, `///` text is presumably emitted as
// schema descriptions — confirm before rewording the doc comment on the field.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SwitchVoiceFuncArgs {
/// TTS voice provider
pub voice_provider: VoiceProvider,
}

// Selects which text-to-speech backend `speak_with_face_animation` uses:
// `Default` goes through `say_azure_with_style`, `Expensive` through
// `say_eleven_with_default_voice`.
// The type is `Copy`, so the controller loop reads it out of the shared mutex by
// dereferencing the guard. Variants are (de)serialized in snake_case.
// NOTE(review): the `///` text on the variants is presumably surfaced as schema
// descriptions by the `JsonSchema` derive — confirm before rewording.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum VoiceProvider {
/// default voice provider Microsoft Azure
#[default]
Default,
/// expensive voice provider from Eleven Labs
/// Should be used carefully
Expensive,
}

/// Callback registered for the `switch_voice_provider` chat function.
///
/// Holds a clone of the `Arc<Mutex<VoiceProvider>>` created in
/// `start_openai_controller`, so a value written here is the one the controller
/// loop reads before processing each text command.
struct SwitchVoiceFuncCallback {
voice_provider: Arc<Mutex<VoiceProvider>>,
}

#[async_trait]
impl AsyncCallback for SwitchVoiceFuncCallback {
    /// Parses `args` as JSON-encoded `SwitchVoiceFuncArgs` and stores the
    /// requested provider in the shared state.
    ///
    /// # Errors
    /// Returns an error if `args` does not deserialize into `SwitchVoiceFuncArgs`.
    ///
    /// # Panics
    /// Panics if the shared mutex is poisoned.
    async fn call(&self, args: &str) -> anyhow::Result<serde_json::Value> {
        let requested = serde_json::from_str::<SwitchVoiceFuncArgs>(args)?.voice_provider;

        // The guard is a temporary here, so the lock is released as soon as the
        // assignment completes.
        *self.voice_provider.lock().unwrap() = requested;

        // The function has nothing to report back, so reply with an empty object.
        Ok(json!({}))
    }
}

0 comments on commit bfaa0c5

Please sign in to comment.