Skip to content

Commit

Permalink
fix: dealing with empty tool_calls (#3514)
Browse files Browse the repository at this point in the history
When streaming the answers, we sometimes receive chunks with empty
`tool_calls`. That was causing messages in previous chunks to be yielded
again, so appearing to the user multiple times

# Description

Please include a summary of the changes and the related issue. Please
also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
  • Loading branch information
jacopo-chevallard authored Dec 10, 2024
1 parent 6450a49 commit e2f6389
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions core/quivr_core/rag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,18 @@ def parse_chunk_response(
"""
rolling_msg += raw_chunk

if not supports_func_calling or not rolling_msg.tool_calls:
tool_calls = rolling_msg.tool_calls

if not supports_func_calling or not tool_calls:
new_content = raw_chunk.content # Just the new chunk's content
full_content = rolling_msg.content # The full accumulated content
return rolling_msg, new_content, full_content

current_answers = get_answers_from_tool_calls(rolling_msg.tool_calls)
current_answers = get_answers_from_tool_calls(tool_calls)
full_answer = "\n\n".join(current_answers)
if not full_answer:
full_answer = previous_content

new_content = full_answer[len(previous_content) :]

return rolling_msg, new_content, full_answer
Expand All @@ -111,8 +116,12 @@ def parse_chunk_response(
def get_answers_from_tool_calls(tool_calls):
answers = []
for tool_call in tool_calls:
if tool_call.get("name") == "cited_answer" and "args" in tool_call:
answers.append(tool_call["args"].get("answer", ""))
if tool_call.get("name") == "cited_answer":
args = tool_call.get("args", {})
if isinstance(args, dict):
answers.append(args.get("answer", ""))
else:
logger.warning(f"Expected dict for tool_call args, got {type(args)}")
return answers


Expand Down

0 comments on commit e2f6389

Please sign in to comment.