From 19dd0124cc32cdbd146137f76c170c87a8476da8 Mon Sep 17 00:00:00 2001 From: Shreyash Gupta <48386323+shreyashkgupta@users.noreply.github.com> Date: Tue, 23 Jul 2024 00:22:03 +0530 Subject: [PATCH] feat: Ask a question functionality --- .flake8 | 2 +- examples/ask_question/main.py | 54 +++++ kaizen/llms/prompts/ask_question_prompts.py | 62 ++++++ kaizen/reviewer/ask_question.py | 208 ++++++++++++++++++++ 4 files changed, 325 insertions(+), 1 deletion(-) create mode 100644 examples/ask_question/main.py create mode 100644 kaizen/llms/prompts/ask_question_prompts.py create mode 100644 kaizen/reviewer/ask_question.py diff --git a/.flake8 b/.flake8 index 4786b644..2821ef82 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,3 @@ [flake8] -exclude = docs/* +exclude = docs/*, venv/* ignore = E501, W503, E203, F541, W293, W291, E266, F601 \ No newline at end of file diff --git a/examples/ask_question/main.py b/examples/ask_question/main.py new file mode 100644 index 00000000..94f89a80 --- /dev/null +++ b/examples/ask_question/main.py @@ -0,0 +1,54 @@ +from kaizen.reviewer.ask_question import QuestionAnswer +from kaizen.llms.provider import LLMProvider +from github_app.github_helper.utils import get_diff_text, get_pr_files +import json +import logging + +logging.basicConfig(level="DEBUG") + +# PR details +pr_diff = "https://github.com/Cloud-Code-AI/kaizen/pull/335.patch" +pr_files = "https://api.github.com/repos/Cloud-Code-AI/kaizen/pulls/335/files" +pr_title = "feat: updated the prompt to provide solution" + +# Fetch PR data +diff_text = get_diff_text(pr_diff, "") +pr_files = get_pr_files(pr_files, "") + +# Initialize QuestionAnswer +qa = QuestionAnswer(llm_provider=LLMProvider()) + +# Example questions +questions = [ + "What are the main changes in this pull request?", + "Are there any potential performance implications in these changes?", + "Does this PR introduce any new dependencies?", +] + +# Ask questions about the PR +for question in questions: + print(f"\n----- Question: {question} -----") + + answer_output = qa.ask_pull_request( + diff_text=diff_text, + pull_request_title=pr_title, + pull_request_desc="", + question=question, + pull_request_files=pr_files, + user="kaizen/example", + ) + + print(f"Answer: {answer_output.answer}") + print(f"Model: {answer_output.model_name}") + print(f"Usage: {json.dumps(answer_output.usage, indent=2)}") + print(f"Cost: {json.dumps(answer_output.cost, indent=2)}") + +# Check if a specific question's prompt is within token limit +sample_question = "What are the coding style changes in this PR?" +is_within_limit = qa.is_ask_question_prompt_within_limit( + diff_text=diff_text, + pull_request_title=pr_title, + pull_request_desc="", + question=sample_question, +) +print(f"\nIs the prompt for '{sample_question}' within token limit? {is_within_limit}") diff --git a/kaizen/llms/prompts/ask_question_prompts.py b/kaizen/llms/prompts/ask_question_prompts.py new file mode 100644 index 00000000..a229137e --- /dev/null +++ b/kaizen/llms/prompts/ask_question_prompts.py @@ -0,0 +1,62 @@ +ANSWER_QUESTION_SYSTEM_PROMPT = """ +ou are an AI assistant specializing in software development and code review. Your role is to answer questions about pull requests accurately and comprehensively. When responding to questions: + +1. Analyze the provided code changes, pull request title, and description thoroughly. +2. Provide clear, concise, and relevant answers based on the information given. +3. If applicable, refer to specific code snippets or changes to support your answers. +4. Consider various aspects such as code quality, performance implications, potential bugs, and adherence to best practices. +5. Offer insights into the overall impact of the changes on the codebase. +6. If a question cannot be fully answered with the given information, state this clearly and provide the best possible answer based on available data. +7. Maintain a neutral, professional tone in your responses. +8. Do not ask for additional information or clarification; work with what is provided. +9. If relevant, suggest improvements or alternatives, but always in the context of answering the specific question asked. + +Your goal is to provide valuable insights that help developers and reviewers better understand the pull request and its implications. +""" + +ANSWER_QUESTION_PROMPT = """ +As an experienced software engineer, answer the following question about the given pull request. Use the provided information to give an accurate and helpful response. + +INFORMATION: + +Pull Request Title: {PULL_REQUEST_TITLE} +Pull Request Description: {PULL_REQUEST_DESC} + +PATCH DATA: +```{CODE_DIFF}``` + +QUESTION: +{QUESTION} + +Please provide a concise and informative answer to the question, based on the pull request information and code changes. +""" + +FILE_ANSWER_QUESTION_PROMPT = """ +As an experienced software engineer, answer the following question about the given pull request. Use the provided information to give an accurate and helpful response. + +INFORMATION: + +Pull Request Title: {PULL_REQUEST_TITLE} +Pull Request Description: {PULL_REQUEST_DESC} + +FILE PATCH: +```{FILE_PATCH}``` + +QUESTION: +{QUESTION} + +Please provide a concise and informative answer to the question, based on the pull request information and code changes. +""" + +SUMMARIZE_ANSWER_PROMPT = """ +As an experienced software engineer, analyze and summarize the following responses related to a question about a pull request. +Each response corresponds to a different file or chunk of the pull request. + +QUESTION: +{QUESTION} + +Responses: +{RESPONSES} + +Please provide a concise and informative summary that addresses the original question based on all the given responses. +""" diff --git a/kaizen/reviewer/ask_question.py b/kaizen/reviewer/ask_question.py new file mode 100644 index 00000000..cd9dd32c --- /dev/null +++ b/kaizen/reviewer/ask_question.py @@ -0,0 +1,208 @@ +from typing import Optional, List, Dict, Generator +from dataclasses import dataclass +import logging +from kaizen.helpers import parser +from kaizen.llms.provider import LLMProvider +from kaizen.llms.prompts.ask_question_prompts import ( + ANSWER_QUESTION_SYSTEM_PROMPT, + ANSWER_QUESTION_PROMPT, + FILE_ANSWER_QUESTION_PROMPT, + SUMMARIZE_ANSWER_PROMPT, +) + + +@dataclass +class AnswerOutput: + answer: str + usage: Dict[str, int] + model_name: str + cost: Dict[str, float] + + +class QuestionAnswer: + def __init__(self, llm_provider: LLMProvider): + self.logger = logging.getLogger(__name__) + self.provider = llm_provider + self.provider.system_prompt = ANSWER_QUESTION_SYSTEM_PROMPT + self.total_usage = { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + } + + def is_ask_question_prompt_within_limit( + self, + diff_text: str, + pull_request_title: str, + pull_request_desc: str, + question: str, + ) -> bool: + prompt = ANSWER_QUESTION_PROMPT.format( + PULL_REQUEST_TITLE=pull_request_title, + PULL_REQUEST_DESC=pull_request_desc, + CODE_DIFF=parser.patch_to_combined_chunks(diff_text), + QUESTION=question, + ) + return self.provider.is_inside_token_limit(PROMPT=prompt) + + def ask_pull_request( + self, + diff_text: str, + pull_request_title: str, + pull_request_desc: str, + question: str, + pull_request_files: List[Dict], + user: Optional[str] = None, + ) -> AnswerOutput: + prompt = ANSWER_QUESTION_PROMPT.format( + PULL_REQUEST_TITLE=pull_request_title, + PULL_REQUEST_DESC=pull_request_desc, + CODE_DIFF=parser.patch_to_combined_chunks(diff_text), + QUESTION=question, + ) + self.total_usage = { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + } + if not diff_text and not pull_request_files: + raise Exception("Both diff_text and pull_request_files are empty!") + + if diff_text and self.provider.is_inside_token_limit(PROMPT=prompt): + resp = self._process_full_diff_qa(prompt, user) + + else: + resp = self._process_files_qa( + pull_request_files, + pull_request_title, + pull_request_desc, + question, + user, + ) + + prompt_cost, completion_cost = self.provider.get_usage_cost( + total_usage=self.total_usage + ) + + return AnswerOutput( + answer=resp, + usage=self.total_usage, + model_name=self.provider.model, + cost={"prompt_cost": prompt_cost, "completion_cost": completion_cost}, + ) + + def _process_full_diff_qa( + self, + prompt: str, + user: Optional[str], + ) -> str: + self.logger.debug("Processing directly from diff") + resp, usage = self.provider.chat_completion(prompt, user=user) + self.total_usage = self.provider.update_usage(self.total_usage, usage) + return resp + + def _process_files_qa( + self, + pull_request_files: List[Dict], + pull_request_title: str, + pull_request_desc: str, + question: str, + user: Optional[str], + ) -> str: + self.logger.debug("Processing based on files") + responses = [] + for answer in self._process_files_generator_qa( + pull_request_files, + pull_request_title, + pull_request_desc, + question, + user, + ): + responses.append(answer) + ## summarize responses + return self._summarize_responses(question, responses) + + def _process_files_generator_qa( + self, + pull_request_files: List[Dict], + pull_request_title: str, + pull_request_desc: str, + question: str, + user: Optional[str], + ) -> Generator[str, None, None]: + combined_diff_data = "" + available_tokens = self.provider.available_tokens(FILE_ANSWER_QUESTION_PROMPT) + + for file in pull_request_files: + patch_details = file.get("patch") + filename = file.get("filename", "") + + if ( + filename.split(".")[-1] not in parser.EXCLUDED_FILETYPES + and patch_details is not None + ): + temp_prompt = ( + combined_diff_data + + f"\n---->\nFile Name: {filename}\nPatch Details: {parser.patch_to_combined_chunks(patch_details)}" + ) + + if available_tokens - self.provider.get_token_count(temp_prompt) > 0: + combined_diff_data = temp_prompt + continue + + yield self._process_file_chunk_qa( + combined_diff_data, + pull_request_title, + pull_request_desc, + question, + user, + ) + combined_diff_data = ( + f"\n---->\nFile Name: {filename}\nPatch Details: {patch_details}" + ) + + if combined_diff_data: + yield self._process_file_chunk_qa( + combined_diff_data, + pull_request_title, + pull_request_desc, + question, + user, + ) + + def _process_file_chunk_qa( + self, + diff_data: str, + pull_request_title: str, + pull_request_desc: str, + question: str, + user: Optional[str], + ) -> str: + if not diff_data: + return "" + prompt = FILE_ANSWER_QUESTION_PROMPT.format( + PULL_REQUEST_TITLE=pull_request_title, + PULL_REQUEST_DESC=pull_request_desc, + FILE_PATCH=diff_data, + QUESTION=question, + ) + resp, usage = self.provider.chat_completion(prompt, user=user) + self.total_usage = self.provider.update_usage(self.total_usage, usage) + return resp + + def _summarize_responses(self, question: str, responses: List[str]) -> str: + if len(responses) == 1: + return responses[0] + + formatted_responses = "\n\n".join( + f"Response for file/chunk {i + 1}:\n{response}" + for i, response in enumerate(responses) + ) + summary_prompt = SUMMARIZE_ANSWER_PROMPT.format( + QUESTION=question, RESPONSES=formatted_responses + ) + + summarized_answer, usage = self.provider.chat_completion(summary_prompt) + self.total_usage = self.provider.update_usage(self.total_usage, usage) + + return summarized_answer