Skip to content

Commit bb06f73

Browse files
committed
Refactor to simplify code in GenerativeChatRepository
1 parent d7d3889 commit bb06f73

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

feature/chat/data/src/main/java/io/jja08111/gemini/feature/chat/data/repository/GenerativeChatRepository.kt

+6-10
Original file line numberDiff line numberDiff line change
@@ -121,18 +121,16 @@ class GenerativeChatRepository @Inject constructor(
121121
return runCatching {
122122
val model = generativeModel ?: throw NotJoinedException()
123123
val roomId = joinedRoomId ?: throw NotJoinedException()
124-
val messageGroupStream = chatLocalDataSource.getMessageGroupStream(roomId)
125-
val messageGroups = messageGroupStream.first()
124+
val messageGroups = chatLocalDataSource.getMessageGroupsBy(roomId)
126125
val lastMessageGroup = messageGroups.lastOrNull() ?: throw EmptyMessageGroupsException()
127126
val lastPrompt = lastMessageGroup.prompt
128-
val lastPromptId = lastPrompt.id
129127
val responseTextBuilders = List(CANDIDATE_COUNT) { ResponseTextBuilder() }
130128

131129
chatLocalDataSource.insertResponsesAndRemoveError(
132130
newResponseIds = responseTextBuilders.map { it.id },
133131
errorResponseId = lastMessageGroup.selectedResponse.id,
134132
roomId = roomId,
135-
promptId = lastPromptId,
133+
promptId = lastPrompt.id,
136134
)
137135

138136
return model.generateTextMessageStream(
@@ -141,7 +139,7 @@ class GenerativeChatRepository @Inject constructor(
141139
history = messageGroups
142140
.dropLast(1)
143141
.flatMap(MessageGroup::toContents),
144-
promptId = lastPromptId,
142+
promptId = lastPrompt.id,
145143
responseTextBuilders = responseTextBuilders,
146144
)
147145
}
@@ -151,26 +149,24 @@ class GenerativeChatRepository @Inject constructor(
151149
return runCatching {
152150
val model = generativeModel ?: throw NotJoinedException()
153151
val roomId = joinedRoomId ?: throw NotJoinedException()
154-
val messageGroupStream = chatLocalDataSource.getMessageGroupStream(roomId)
155-
val messageGroups = messageGroupStream.first()
152+
val messageGroups = chatLocalDataSource.getMessageGroupsBy(roomId)
156153
val messageGroup = messageGroups.firstOrNull {
157154
it.selectedResponse.id == responseId
158155
} ?: throw EmptyMessageGroupsException()
159156
val prompt = messageGroup.prompt
160-
val promptId = prompt.id
161157
val responseTextBuilders = List(CANDIDATE_COUNT) { ResponseTextBuilder() }
162158

163159
chatLocalDataSource.insertAndUnselectOldResponses(
164160
newResponseIds = responseTextBuilders.map { it.id },
165161
roomId = roomId,
166-
promptId = promptId,
162+
promptId = prompt.id,
167163
)
168164

169165
return model.generateTextMessageStream(
170166
message = prompt.text,
171167
images = prompt.images.map { promptImageLocalDataSource.loadImage(it.path) },
172168
history = messageGroups.flatMap(MessageGroup::toContents),
173-
promptId = promptId,
169+
promptId = prompt.id,
174170
responseTextBuilders = responseTextBuilders,
175171
)
176172
}

feature/chat/data/src/main/java/io/jja08111/gemini/feature/chat/data/source/ChatLocalDataSource.kt

+6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import kotlinx.coroutines.async
2222
import kotlinx.coroutines.awaitAll
2323
import kotlinx.coroutines.coroutineScope
2424
import kotlinx.coroutines.flow.Flow
25+
import kotlinx.coroutines.flow.first
2526
import kotlinx.coroutines.flow.map
2627
import kotlinx.coroutines.flow.mapLatest
2728
import java.time.LocalDateTime
@@ -45,6 +46,11 @@ class ChatLocalDataSource @Inject constructor(
4546
return messageDao.getPromptWithResponsesAndImages(roomId).mapLatest(::convertToMessageGroups)
4647
}
4748

49+
suspend fun getMessageGroupsBy(roomId: String): List<MessageGroup> {
50+
val messageGroupStream = getMessageGroupStream(roomId)
51+
return messageGroupStream.first()
52+
}
53+
4854
suspend fun updateResponseContentPartials(partials: List<ModelResponseContentPartial>) {
4955
messageDao.updateAll(partials)
5056
}

0 commit comments

Comments
 (0)