14
14
15
15
from config import SpeechMethod
16
16
from google .cloud import texttospeech
17
+ from typing import Dict , List , Tuple
17
18
18
19
19
20
# Set OpenAI API Key
@@ -59,26 +60,38 @@ def transcribe_audio(audio_file: str) -> str:
59
60
return text_transcription
60
61
61
62
62
- def chat_complete (text_input : str ) -> str :
63
+ def chat_complete (
64
+ text_input : str , messages : List [Dict [str , str ]]
65
+ ) -> Tuple [str , List [Dict [str , str ]]]:
63
66
"""
64
67
Chat complete using OpenAI API. This is what generates stories.
65
68
66
69
Args:
67
70
text_input: Text to use as prompt for story generation
71
+ messages: List of previous messages
68
72
69
73
Returns:
70
74
str: Generated story
75
+ messages: Updated list of messages
71
76
"""
72
- global messages
77
+ # Init with prompt on first call
78
+ if not messages :
79
+ messages = [
80
+ {
81
+ "role" : "system" ,
82
+ "content" : config .INITIAL_PROMPT ,
83
+ }
84
+ ]
73
85
74
86
# Append to messages for chat completion
75
87
messages .append ({"role" : "user" , "content" : text_input })
76
88
77
89
# Fetch response from OpenAI
90
+ print ("Messages sent to call: " , messages )
78
91
response = openai .ChatCompletion .create (model = "gpt-3.5-turbo" , messages = messages )
79
92
80
93
# Extract and store message
81
- system_message = response ["choices" ][0 ]["message" ]
94
+ system_message = dict ( response ["choices" ][0 ]["message" ])
82
95
messages .append (system_message )
83
96
84
97
# Return message to display
@@ -93,7 +106,7 @@ def chat_complete(text_input: str) -> str:
93
106
for message in messages :
94
107
f .write (f"{ message ['role' ]} : { message ['content' ]} \n \n " )
95
108
96
- return display_message
109
+ return display_message , messages
97
110
98
111
99
112
def generate_image (text_input : str ) -> str :
@@ -161,6 +174,9 @@ def text_to_speech(input_text: str) -> str:
161
174
Gradio UI Definition
162
175
"""
163
176
with gr .Blocks (analytics_enabled = False , title = "Audio Storyteller" ) as ui :
177
+ # Session state box containing all user/system messages, hidden
178
+ messages = gr .State (list ())
179
+
164
180
with gr .Row ():
165
181
with gr .Column (scale = 1 ):
166
182
# Audio Input Box
@@ -190,7 +206,9 @@ def text_to_speech(input_text: str) -> str:
190
206
audio_input .change (transcribe_audio , audio_input , transcribed_input )
191
207
192
208
# Connect user trainput to story output
193
- transcribed_input .change (chat_complete , transcribed_input , story_msg )
209
+ transcribed_input .change (
210
+ chat_complete , [transcribed_input , messages ], [story_msg , messages ]
211
+ )
194
212
195
213
# Connect story output to image generation
196
214
story_msg .change (generate_image , story_msg , gen_image )
0 commit comments