Skip to content

Commit 07e7957

Browse files
committed
Enhance error handling in ChatRepository
1 parent 4597abd commit 07e7957

File tree

12 files changed

+238
-170
lines changed

12 files changed

+238
-170
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package io.jja08111.gemini.feature.chat.data.exception
2+
3+
class EmptyContentException : Exception("Message and images are empty")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package io.jja08111.gemini.feature.chat.data.exception
2+
3+
class EmptyMessageGroupsOnRegenerationException : Exception(
4+
"Message group list is empty when regenerating response",
5+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package io.jja08111.gemini.feature.chat.data.exception
2+
3+
class MessageInsertionException(cause: Throwable? = null) : Exception(cause)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package io.jja08111.gemini.feature.chat.data.exception
2+
3+
class NotJoinedException : Exception("Must join room before sending message or doing other actions")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package io.jja08111.gemini.feature.chat.data.exception
2+
3+
class RoomInsertionException(cause: Throwable? = null) : Exception(cause)

feature/chat/data/src/main/java/io/jja08111/gemini/feature/chat/data/repository/GenerativeChatRepository.kt

+98-98
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ import io.github.jja08111.core.common.di.IoDispatcher
99
import io.github.jja08111.core.common.image.BitmapCreator
1010
import io.jja08111.gemini.database.entity.ModelResponseStateEntity
1111
import io.jja08111.gemini.feature.chat.data.BuildConfig
12+
import io.jja08111.gemini.feature.chat.data.exception.EmptyContentException
13+
import io.jja08111.gemini.feature.chat.data.exception.EmptyMessageGroupsOnRegenerationException
14+
import io.jja08111.gemini.feature.chat.data.exception.NotJoinedException
1215
import io.jja08111.gemini.feature.chat.data.extension.toContents
1316
import io.jja08111.gemini.feature.chat.data.extension.toResponseContentPartials
1417
import io.jja08111.gemini.feature.chat.data.model.AttachedImage
@@ -27,7 +30,6 @@ import kotlinx.coroutines.flow.Flow
2730
import kotlinx.coroutines.flow.catch
2831
import kotlinx.coroutines.flow.collect
2932
import kotlinx.coroutines.flow.first
30-
import kotlinx.coroutines.flow.flowOf
3133
import kotlinx.coroutines.flow.onCompletion
3234
import kotlinx.coroutines.flow.onEach
3335
import kotlinx.coroutines.launch
@@ -61,116 +63,118 @@ class GenerativeChatRepository @Inject constructor(
6163
candidateCount = CANDIDATE_COUNT
6264
},
6365
)
64-
return try {
65-
chatLocalDataSource.getMessageGroupStream(roomId)
66-
} catch (e: Exception) {
67-
flowOf(emptyList())
68-
}
66+
return chatLocalDataSource.getMessageGroupStream(roomId)
6967
}
7068

7169
override suspend fun sendMessage(
7270
message: String,
7371
images: List<AttachedImage>,
7472
onRoomCreated: (Flow<List<MessageGroup>>) -> Unit,
7573
): Result<Unit> {
76-
val roomId = joinedRoomId ?: throwNotJoinedError()
77-
val model = generativeModel ?: throwNotJoinedError()
78-
val messageGroupStream = chatLocalDataSource.getMessageGroupStream(roomId)
79-
val messageGroups = messageGroupStream.first()
80-
val isNewChat = messageGroups.isEmpty()
81-
82-
if (isNewChat) {
83-
val title = when {
84-
message.isNotEmpty() -> message
85-
images.isNotEmpty() -> "Image question"
86-
else -> error("Message is empty with null image")
74+
return runCatching {
75+
val roomId = joinedRoomId ?: throw NotJoinedException()
76+
val model = generativeModel ?: throw NotJoinedException()
77+
val messageGroupStream = chatLocalDataSource.getMessageGroupStream(roomId)
78+
val messageGroups = messageGroupStream.first()
79+
val isNewChat = messageGroups.isEmpty()
80+
81+
if (isNewChat) {
82+
val title = when {
83+
message.isNotEmpty() -> message
84+
images.isNotEmpty() -> "Image question"
85+
else -> throw EmptyContentException()
86+
}
87+
chatLocalDataSource.insertRoom(roomId = roomId, title = title)
88+
onRoomCreated(messageGroupStream)
8789
}
88-
chatLocalDataSource.insertRoom(roomId = roomId, title = title)
89-
onRoomCreated(messageGroupStream)
90-
}
9190

92-
val promptId = createId()
93-
val responseTextBuilders = List(CANDIDATE_COUNT) { ResponseTextBuilder() }
94-
val parentModelResponseId = messageGroups.lastOrNull()?.selectedResponse?.id
95-
val imageBitmaps = images.map {
96-
when (it) {
97-
is AttachedImage.Bitmap -> it.bitmap
98-
is AttachedImage.Uri -> bitmapCreator.create(it.uri)
91+
val promptId = createId()
92+
val responseTextBuilders = List(CANDIDATE_COUNT) { ResponseTextBuilder() }
93+
val parentModelResponseId = messageGroups.lastOrNull()?.selectedResponse?.id
94+
val imageBitmaps = images.map {
95+
when (it) {
96+
is AttachedImage.Bitmap -> it.bitmap
97+
is AttachedImage.Uri -> bitmapCreator.create(it.uri)
98+
}
9999
}
100-
}
101100

102-
chatLocalDataSource.insertInitialMessageGroup(
103-
prompt = message,
104-
imageBitmaps = imageBitmaps,
105-
roomId = roomId,
106-
promptId = promptId,
107-
responseIds = responseTextBuilders.map { it.id },
108-
parentModelResponseId = parentModelResponseId,
109-
)
110-
111-
return model.generateTextMessageStream(
112-
message = message,
113-
images = imageBitmaps,
114-
history = messageGroups.flatMap(MessageGroup::toContents),
115-
promptId = promptId,
116-
responseTextBuilders = responseTextBuilders,
117-
)
101+
chatLocalDataSource.insertInitialMessageGroup(
102+
prompt = message,
103+
imageBitmaps = imageBitmaps,
104+
roomId = roomId,
105+
promptId = promptId,
106+
responseIds = responseTextBuilders.map { it.id },
107+
parentModelResponseId = parentModelResponseId,
108+
)
109+
110+
return model.generateTextMessageStream(
111+
message = message,
112+
images = imageBitmaps,
113+
history = messageGroups.flatMap(MessageGroup::toContents),
114+
promptId = promptId,
115+
responseTextBuilders = responseTextBuilders,
116+
)
117+
}
118118
}
119119

120120
override suspend fun regenerateOnError(): Result<Unit> {
121-
val model = generativeModel ?: throwNotJoinedError()
122-
val roomId = joinedRoomId ?: throwNotJoinedError()
123-
val messageGroupStream = chatLocalDataSource.getMessageGroupStream(roomId)
124-
val messageGroups = messageGroupStream.first()
125-
val lastMessageGroup =
126-
messageGroups.lastOrNull() ?: error("Message group list is empty when regenerating response")
127-
val lastPrompt = lastMessageGroup.prompt
128-
val lastPromptId = lastPrompt.id
129-
val responseTextBuilders = List(CANDIDATE_COUNT) { ResponseTextBuilder() }
130-
131-
chatLocalDataSource.insertResponsesAndRemoveError(
132-
newResponseIds = responseTextBuilders.map { it.id },
133-
errorResponseId = lastMessageGroup.selectedResponse.id,
134-
roomId = roomId,
135-
promptId = lastPromptId,
136-
)
137-
138-
return model.generateTextMessageStream(
139-
message = lastPrompt.text,
140-
images = lastPrompt.images.map { promptImageLocalDataSource.loadImage(it.path) },
141-
history = messageGroups
142-
.dropLast(1)
143-
.flatMap(MessageGroup::toContents),
144-
promptId = lastPromptId,
145-
responseTextBuilders = responseTextBuilders,
146-
)
121+
return runCatching {
122+
val model = generativeModel ?: throw NotJoinedException()
123+
val roomId = joinedRoomId ?: throw NotJoinedException()
124+
val messageGroupStream = chatLocalDataSource.getMessageGroupStream(roomId)
125+
val messageGroups = messageGroupStream.first()
126+
val lastMessageGroup =
127+
messageGroups.lastOrNull() ?: throw EmptyMessageGroupsOnRegenerationException()
128+
val lastPrompt = lastMessageGroup.prompt
129+
val lastPromptId = lastPrompt.id
130+
val responseTextBuilders = List(CANDIDATE_COUNT) { ResponseTextBuilder() }
131+
132+
chatLocalDataSource.insertResponsesAndRemoveError(
133+
newResponseIds = responseTextBuilders.map { it.id },
134+
errorResponseId = lastMessageGroup.selectedResponse.id,
135+
roomId = roomId,
136+
promptId = lastPromptId,
137+
)
138+
139+
return model.generateTextMessageStream(
140+
message = lastPrompt.text,
141+
images = lastPrompt.images.map { promptImageLocalDataSource.loadImage(it.path) },
142+
history = messageGroups
143+
.dropLast(1)
144+
.flatMap(MessageGroup::toContents),
145+
promptId = lastPromptId,
146+
responseTextBuilders = responseTextBuilders,
147+
)
148+
}
147149
}
148150

149151
override suspend fun regenerateResponse(responseId: String): Result<Unit> {
150-
val model = generativeModel ?: throwNotJoinedError()
151-
val roomId = joinedRoomId ?: throwNotJoinedError()
152-
val messageGroupStream = chatLocalDataSource.getMessageGroupStream(roomId)
153-
val messageGroups = messageGroupStream.first()
154-
val messageGroup = messageGroups.firstOrNull {
155-
it.selectedResponse.id == responseId
156-
} ?: error("Message group list is empty when regenerating response")
157-
val prompt = messageGroup.prompt
158-
val promptId = prompt.id
159-
val responseTextBuilders = List(CANDIDATE_COUNT) { ResponseTextBuilder() }
160-
161-
chatLocalDataSource.insertAndUnselectOldResponses(
162-
newResponseIds = responseTextBuilders.map { it.id },
163-
roomId = roomId,
164-
promptId = promptId,
165-
)
166-
167-
return model.generateTextMessageStream(
168-
message = prompt.text,
169-
images = prompt.images.map { promptImageLocalDataSource.loadImage(it.path) },
170-
history = messageGroups.flatMap(MessageGroup::toContents),
171-
promptId = promptId,
172-
responseTextBuilders = responseTextBuilders,
173-
)
152+
return runCatching {
153+
val model = generativeModel ?: throw NotJoinedException()
154+
val roomId = joinedRoomId ?: throw NotJoinedException()
155+
val messageGroupStream = chatLocalDataSource.getMessageGroupStream(roomId)
156+
val messageGroups = messageGroupStream.first()
157+
val messageGroup = messageGroups.firstOrNull {
158+
it.selectedResponse.id == responseId
159+
} ?: throw EmptyMessageGroupsOnRegenerationException()
160+
val prompt = messageGroup.prompt
161+
val promptId = prompt.id
162+
val responseTextBuilders = List(CANDIDATE_COUNT) { ResponseTextBuilder() }
163+
164+
chatLocalDataSource.insertAndUnselectOldResponses(
165+
newResponseIds = responseTextBuilders.map { it.id },
166+
roomId = roomId,
167+
promptId = promptId,
168+
)
169+
170+
return model.generateTextMessageStream(
171+
message = prompt.text,
172+
images = prompt.images.map { promptImageLocalDataSource.loadImage(it.path) },
173+
history = messageGroups.flatMap(MessageGroup::toContents),
174+
promptId = promptId,
175+
responseTextBuilders = responseTextBuilders,
176+
)
177+
}
174178
}
175179

176180
private suspend fun GenerativeModel.generateTextMessageStream(
@@ -211,10 +215,6 @@ class GenerativeChatRepository @Inject constructor(
211215
}
212216
}
213217

214-
private fun throwNotJoinedError(): Nothing {
215-
error("Must call join function before usage")
216-
}
217-
218218
override fun exit() {
219219
joinedRoomId = null
220220
generativeModel = null

feature/chat/data/src/main/java/io/jja08111/gemini/feature/chat/data/source/ChatLocalDataSource.kt

+60-42
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import io.jja08111.gemini.database.entity.RoomEntity
1111
import io.jja08111.gemini.database.entity.partial.ModelResponseContentPartial
1212
import io.jja08111.gemini.database.entity.relation.PromptWithImages
1313
import io.jja08111.gemini.database.extension.toDomain
14+
import io.jja08111.gemini.feature.chat.data.exception.MessageInsertionException
15+
import io.jja08111.gemini.feature.chat.data.exception.RoomInsertionException
1416
import io.jja08111.gemini.feature.chat.data.extension.convertToMessageGroups
1517
import io.jja08111.gemini.model.MessageGroup
1618
import io.jja08111.gemini.model.ModelResponse
@@ -79,31 +81,39 @@ class ChatLocalDataSource @Inject constructor(
7981
)
8082
}
8183

82-
messageDao.insert(
83-
prompt = PromptEntity(
84-
id = promptId,
85-
roomId = roomId,
86-
parentModelResponseId = parentModelResponseId,
87-
text = prompt,
88-
createdAt = now,
89-
),
90-
images = images,
91-
modelResponses = responseIds.mapIndexed { index, id ->
92-
ModelResponseEntity(
93-
id = id,
84+
try {
85+
messageDao.insert(
86+
prompt = PromptEntity(
87+
id = promptId,
9488
roomId = roomId,
95-
parentPromptId = promptId,
96-
text = "",
89+
parentModelResponseId = parentModelResponseId,
90+
text = prompt,
9791
createdAt = now,
98-
state = ModelResponseStateEntity.Generating,
99-
selected = index == 0,
100-
)
101-
},
102-
)
92+
),
93+
images = images,
94+
modelResponses = responseIds.mapIndexed { index, id ->
95+
ModelResponseEntity(
96+
id = id,
97+
roomId = roomId,
98+
parentPromptId = promptId,
99+
text = "",
100+
createdAt = now,
101+
state = ModelResponseStateEntity.Generating,
102+
selected = index == 0,
103+
)
104+
},
105+
)
106+
} catch (e: Exception) {
107+
throw MessageInsertionException(cause = e)
108+
}
103109
}
104110

105111
suspend fun insertRoom(roomId: String, title: String) {
106-
roomDao.insert(RoomEntity(id = roomId, createdAt = LocalDateTime.now(), title = title))
112+
try {
113+
roomDao.insert(RoomEntity(id = roomId, createdAt = LocalDateTime.now(), title = title))
114+
} catch (e: Exception) {
115+
throw RoomInsertionException(cause = e)
116+
}
107117
}
108118

109119
suspend fun insertResponsesAndRemoveError(
@@ -112,37 +122,45 @@ class ChatLocalDataSource @Inject constructor(
112122
roomId: String,
113123
promptId: String,
114124
) {
115-
val responses = newResponseIds.mapIndexed { index, id ->
116-
ModelResponseEntity(
117-
id = id,
118-
roomId = roomId,
119-
parentPromptId = promptId,
120-
text = "",
121-
createdAt = LocalDateTime.now(),
122-
state = ModelResponseStateEntity.Generating,
123-
selected = index == 0,
124-
)
125+
try {
126+
val responses = newResponseIds.mapIndexed { index, id ->
127+
ModelResponseEntity(
128+
id = id,
129+
roomId = roomId,
130+
parentPromptId = promptId,
131+
text = "",
132+
createdAt = LocalDateTime.now(),
133+
state = ModelResponseStateEntity.Generating,
134+
selected = index == 0,
135+
)
136+
}
137+
messageDao.insertAndRemove(modelResponses = responses, responseIdToRemove = errorResponseId)
138+
} catch (e: Exception) {
139+
throw MessageInsertionException(cause = e)
125140
}
126-
messageDao.insertAndRemove(modelResponses = responses, responseIdToRemove = errorResponseId)
127141
}
128142

129143
suspend fun insertAndUnselectOldResponses(
130144
newResponseIds: List<String>,
131145
roomId: String,
132146
promptId: String,
133147
) {
134-
val responses = newResponseIds.mapIndexed { index, id ->
135-
ModelResponseEntity(
136-
id = id,
137-
roomId = roomId,
138-
parentPromptId = promptId,
139-
text = "",
140-
createdAt = LocalDateTime.now(),
141-
state = ModelResponseStateEntity.Generating,
142-
selected = index == 0,
143-
)
148+
try {
149+
val responses = newResponseIds.mapIndexed { index, id ->
150+
ModelResponseEntity(
151+
id = id,
152+
roomId = roomId,
153+
parentPromptId = promptId,
154+
text = "",
155+
createdAt = LocalDateTime.now(),
156+
state = ModelResponseStateEntity.Generating,
157+
selected = index == 0,
158+
)
159+
}
160+
messageDao.insertAndUnselectOldResponses(modelResponses = responses)
161+
} catch (e: Exception) {
162+
throw MessageInsertionException(cause = e)
144163
}
145-
messageDao.insertAndUnselectOldResponses(modelResponses = responses)
146164
}
147165

148166
fun getPromptBy(promptId: String): Flow<Prompt> {

0 commit comments

Comments
 (0)