Skip to content

Commit

Permalink
feat: add support for client-tools (Python SDK)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisjoecodes committed Dec 17, 2024
1 parent d1f0878 commit e5bacd3
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 7 deletions.
61 changes: 54 additions & 7 deletions src/elevenlabs/conversational_ai/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,44 @@ def interrupt(self):
"""
pass


class ClientTools:
    """Registry that maps client tool names to the callables handling them.

    The registry dict is guarded by a lock so tools can be registered from one
    thread while another thread dispatches tool calls. The lock protects only
    the lookup table: handlers themselves run *outside* the lock.
    """

    def __init__(self):
        # tool name -> handler callable
        self.tools = {}
        # Guards every read and write of ``self.tools``.
        self.lock = threading.Lock()

    def register(self, tool_name, handler):
        """Register ``handler`` under ``tool_name``, replacing any existing entry.

        Args:
            tool_name: Key under which the handler is stored.
            handler: Callable invoked with the tool parameters.

        Raises:
            ValueError: If ``handler`` is not callable.
        """
        # Validation touches no shared state, so do it before taking the lock.
        if not callable(handler):
            raise ValueError("Handler must be callable")
        with self.lock:
            self.tools[tool_name] = handler

    def handle(self, tool_name, parameters):
        """Invoke the handler registered for ``tool_name`` and return its result.

        Args:
            tool_name: Name of a previously registered tool.
            parameters: Value passed as the single argument to the handler.

        Returns:
            Whatever the handler returns.

        Raises:
            ValueError: If no handler is registered under ``tool_name``.
        """
        with self.lock:
            try:
                handler = self.tools[tool_name]
            except KeyError:
                raise ValueError(f"Tool '{tool_name}' is not registered.") from None
        # Call the handler only after the lock is released: threading.Lock is
        # not reentrant, so a handler that calls register()/handle() on this
        # registry would otherwise deadlock, and a slow handler would block
        # every other thread that needs the registry.
        return handler(parameters)


class ConversationConfig:
    """Optional settings bundle for a Conversation.

    Attributes:
        extra_body: Dict of extra body data; defaults to an empty dict.
        conversation_config_override: Dict of configuration overrides;
            defaults to an empty dict.
    """

    def __init__(
        self,
        extra_body: Optional[dict] = None,
        conversation_config_override: Optional[dict] = None,
    ):
        # Any falsy argument (None or an empty dict) is swapped for a fresh
        # dict so the attributes are always safe to read as mappings.
        if not extra_body:
            extra_body = {}
        if not conversation_config_override:
            conversation_config_override = {}

        self.extra_body = extra_body
        self.conversation_config_override = conversation_config_override



class Conversation:
client: BaseElevenLabs
agent_id: str
requires_auth: bool
config: ConversationConfig
audio_interface: AudioInterface
client_tools: Optional[ClientTools]
callback_agent_response: Optional[Callable[[str], None]]
callback_agent_response_correction: Optional[Callable[[str, str], None]]
callback_user_transcript: Optional[Callable[[str], None]]
Expand All @@ -86,7 +108,7 @@ def __init__(
requires_auth: bool,
audio_interface: AudioInterface,
config: Optional[ConversationConfig] = None,

client_tools: Optional[ClientTools] = None,
callback_agent_response: Optional[Callable[[str], None]] = None,
callback_agent_response_correction: Optional[Callable[[str, str], None]] = None,
callback_user_transcript: Optional[Callable[[str], None]] = None,
Expand All @@ -101,6 +123,7 @@ def __init__(
agent_id: The ID of the agent to converse with.
requires_auth: Whether the agent requires authentication.
audio_interface: The audio interface to use for input and output.
client_tools: The client tools to use for the conversation.
callback_agent_response: Callback for agent responses.
callback_agent_response_correction: Callback for agent response corrections.
First argument is the original response (previously given to
Expand All @@ -112,10 +135,10 @@ def __init__(
self.client = client
self.agent_id = agent_id
self.requires_auth = requires_auth

self.audio_interface = audio_interface
self.callback_agent_response = callback_agent_response
self.config = config or ConversationConfig()
self.client_tools = client_tools or ClientTools()
self.callback_agent_response_correction = callback_agent_response_correction
self.callback_user_transcript = callback_user_transcript
self.callback_latency_measurement = callback_latency_measurement
Expand Down Expand Up @@ -151,10 +174,10 @@ def _run(self, ws_url: str):
with connect(ws_url) as ws:
ws.send(
json.dumps(
{
"type": "conversation_initiation_client_data",
"custom_llm_extra_body": self.config.extra_body,
"conversation_config_override": self.config.conversation_config_override,
{
"type": "conversation_initiation_client_data",
"custom_llm_extra_body": self.config.extra_body,
"conversation_config_override": self.config.conversation_config_override,
}
)
)
Expand Down Expand Up @@ -220,6 +243,30 @@ def _handle_message(self, message, ws):
)
if self.callback_latency_measurement and event["ping_ms"]:
self.callback_latency_measurement(int(event["ping_ms"]))
elif message["type"] == "client_tool_call":
"""Handle tool calls from the server."""
tool_call = message.get("client_tool_call", {})
tool_name = tool_call.get("tool_name")
parameters = tool_call.get("parameters", {})

try:
result = (
self.client_tools.handle(tool_name, parameters) or f"Client tool: {tool_name} called successfully."
)
response = {
"type": "client_tool_result",
"tool_call_id": tool_call["tool_call_id"],
"result": result,
"is_error": False,
}
except Exception as e:
response = {
"type": "client_tool_result",
"tool_call_id": tool_call["tool_call_id"],
"result": str(e),
"is_error": True,
}
ws.send(json.dumps(response))
else:
pass # Ignore all other message types.

Expand Down
71 changes: 71 additions & 0 deletions tests/e2e_test_convai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os
import time

import pytest
from elevenlabs import ElevenLabs
from elevenlabs.conversational_ai.conversation import Conversation, ClientTools
from elevenlabs.conversational_ai.default_audio_interface import DefaultAudioInterface


@pytest.mark.skipif(os.getenv("CI") == "true", reason="Skip live conversation test in CI environment")
def test_live_conversation():
    """Hold a 30-second live conversation with an agent using real audio I/O.

    Needs ELEVENLABS_API_KEY in the environment and a real agent id in place
    of the AGENT_ID placeholder. Skipped when the CI environment variable is
    "true".

    Raises:
        ValueError: If ELEVENLABS_API_KEY is unset or empty.
    """
    # The API key is taken from the environment rather than hard-coded.
    key = os.getenv("ELEVENLABS_API_KEY")
    if not key:
        raise ValueError("Please set ELEVENLABS_API_KEY environment variable")

    elevenlabs_client = ElevenLabs(api_key=key)

    AGENT_ID = "<insert-testing-agent-id>"

    # Callbacks that echo conversation events to stdout.
    def print_agent_response(text: str):
        print(f"Agent: {text}")

    def print_user_transcript(text: str):
        print(f"You: {text}")

    def print_latency(ms: int):
        print(f"Latency: {ms}ms")

    # Client tool registered under the name "test"; it reports the parameters
    # it was called with and returns them in a string.
    def echo_tool(parameters):
        print("Tool called with parameters:", parameters)
        return f"Handled parameters: {parameters}"

    tool_registry = ClientTools()
    tool_registry.register("test", echo_tool)

    session = Conversation(
        client=elevenlabs_client,
        agent_id=AGENT_ID,
        requires_auth=False,
        audio_interface=DefaultAudioInterface(),
        client_tools=tool_registry,
        callback_agent_response=print_agent_response,
        callback_user_transcript=print_user_transcript,
        callback_latency_measurement=print_latency,
    )

    # Start the session, leave it running for 30 seconds, then stop it and
    # wait for the session to end.
    session.start_session()
    time.sleep(30)
    session.end_session()
    session.wait_for_session_end()

    # Print the conversation id for reference.
    print(f"Conversation ID: {session._conversation_id}")


# Allow running this live test directly as a script, without the pytest runner.
if __name__ == "__main__":
    test_live_conversation()

0 comments on commit e5bacd3

Please sign in to comment.