Skip to content

Commit 7cea8bd

Browse files
authored
add session state (#2)
* add session state * pin requirements * readme update
1 parent 4124f27 commit 7cea8bd

File tree

3 files changed

+28
-11
lines changed

3 files changed

+28
-11
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
---
44

5+
**NOTE: This does NOT work on mobile iOS devices. See [Gradio Issue #2987](https://github.com/gradio-app/gradio/issues/2987) for details.**
6+
57
This is a Gradio UI application that takes in a request for a story from the microphone
68
and speaks an interactive Choose-Your-Own-Adventure style children's story. It leverages:
79

requirements.txt

+3-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
black
2-
flake8
3-
google-cloud-texttospeech
4-
gradio
5-
openai
6-
pre-commit
1+
google-cloud-texttospeech==2.14.1
2+
gradio==3.20.1
3+
openai==0.27.1

storyteller.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from config import SpeechMethod
1616
from google.cloud import texttospeech
17+
from typing import Dict, List, Tuple
1718

1819

1920
# Set OpenAI API Key
@@ -59,26 +60,38 @@ def transcribe_audio(audio_file: str) -> str:
5960
return text_transcription
6061

6162

62-
def chat_complete(text_input: str) -> str:
63+
def chat_complete(
64+
text_input: str, messages: List[Dict[str, str]]
65+
) -> Tuple[str, List[Dict[str, str]]]:
6366
"""
6467
Chat complete using OpenAI API. This is what generates stories.
6568
6669
Args:
6770
text_input: Text to use as prompt for story generation
71+
messages: List of previous messages
6872
6973
Returns:
7074
str: Generated story
75+
messages: Updated list of messages
7176
"""
72-
global messages
77+
# Init with prompt on first call
78+
if not messages:
79+
messages = [
80+
{
81+
"role": "system",
82+
"content": config.INITIAL_PROMPT,
83+
}
84+
]
7385

7486
# Append to messages for chat completion
7587
messages.append({"role": "user", "content": text_input})
7688

7789
# Fetch response from OpenAI
90+
print("Messages sent to call: ", messages)
7891
response = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages)
7992

8093
# Extract and store message
81-
system_message = response["choices"][0]["message"]
94+
system_message = dict(response["choices"][0]["message"])
8295
messages.append(system_message)
8396

8497
# Return message to display
@@ -93,7 +106,7 @@ def chat_complete(text_input: str) -> str:
93106
for message in messages:
94107
f.write(f"{message['role']}: {message['content']}\n\n")
95108

96-
return display_message
109+
return display_message, messages
97110

98111

99112
def generate_image(text_input: str) -> str:
@@ -161,6 +174,9 @@ def text_to_speech(input_text: str) -> str:
161174
Gradio UI Definition
162175
"""
163176
with gr.Blocks(analytics_enabled=False, title="Audio Storyteller") as ui:
177+
# Session state box containing all user/system messages, hidden
178+
messages = gr.State(list())
179+
164180
with gr.Row():
165181
with gr.Column(scale=1):
166182
# Audio Input Box
@@ -190,7 +206,9 @@ def text_to_speech(input_text: str) -> str:
190206
audio_input.change(transcribe_audio, audio_input, transcribed_input)
191207

192208
# Connect user input to story output
193-
transcribed_input.change(chat_complete, transcribed_input, story_msg)
209+
transcribed_input.change(
210+
chat_complete, [transcribed_input, messages], [story_msg, messages]
211+
)
194212

195213
# Connect story output to image generation
196214
story_msg.change(generate_image, story_msg, gen_image)

0 commit comments

Comments
 (0)