diff --git a/.gitignore b/.gitignore index 055644b..f94d284 100644 --- a/.gitignore +++ b/.gitignore @@ -178,3 +178,6 @@ data_map.txt .gptcache_data_map.txt dump.rdb faiss.index + +# test data +demo_data/test_api_config.yaml diff --git a/README.md b/README.md index c0933dd..bf41f61 100644 --- a/README.md +++ b/README.md @@ -43,16 +43,23 @@ pip install -r requirements.txt ### Configure API keys -``` -cp factcheck/config/secret_dict.template factcheck/config/secret_dict.py -``` -You can choose to export essential api key to the environment, or configure it in `factcheck/config/secret_dict.py`. +You can choose to export essential api key to the environment - Example: Export essential api key to the environment ```bash export SERPER_API_KEY=... # this is required in evidence retrieval if serper being used export OPENAI_API_KEY=... # this is required in all tasks export ANTHROPIC_API_KEY=... # this is required only if you want to replace openai with anthropic +export LOCAL_API_KEY=... # this is required only if you want to use local LLM +export LOCAL_API_URL=... # this is required only if you want to use local LLM +``` + +Alternatively, you can save the api information in a yaml file with the same key names as the environment variables and pass the path to the yaml file as an argument to the `check_response` method. + +See `demo_data\api_config.yaml` as an example of the api configuration file. +- Example: Pass the path to the api configuration file +```bash +python -m factcheck --modal string --input "MBZUAI is the first AI university in the world" --api_config demo_data/api_config.yaml ``` ### Test @@ -62,15 +69,15 @@ export ANTHROPIC_API_KEY=... # this is required only if you want to replace open To test the project, you can run the `factcheck.py` script: ```bash # String -python factcheck.py --modal string --input "MBZUAI is the first AI university in the world" +python -m factcheck --modal string --input "MBZUAI is the first AI university in the world" # Text -python factcheck.py --modal text --input demo_data/text.txt +python -m factcheck --modal text --input demo_data/text.txt # Speech -python factcheck.py --modal speech --input demo_data/speech.mp3 +python -m factcheck --modal speech --input demo_data/speech.mp3 # Image -python factcheck.py --modal image --input demo_data/image.webp +python -m factcheck --modal image --input demo_data/image.webp # Video -python factcheck.py --modal video --input demo_data/video.m4v +python -m factcheck --modal video --input demo_data/video.m4v ``` ## Usage @@ -79,19 +86,21 @@ The main interface of the Fact-check Pipeline is located in `factcheck/core/Fact Example usage: ```python -from factcheck.core.FactCheck import check_response +from factcheck import FactCheck + +factcheck_instance = FactCheck() # Example text text = "Your text here" # Run the fact-check pipeline -results = check_response(text) +results = factcheck_instance.check_response(text) print(results) ``` Web app usage: ```bash -python webapp.py +python webapp.py --api_config demo_data/api_config.yaml ```
@@ -106,6 +115,23 @@ We welcome contributions from the community! If you'd like to contribute, please 5. Open a pull request. +## Customize Your Experience + +### Custom Models +```bash +python -m factcheck --modal string --input "MBZUAI is the first AI university in the world" --api_config demo_data/api_config.yaml --model claude-3-opus-20240229 --prompt claude_prompt +``` + +### Custom Evidence Retrieval +```bash +python -m factcheck --modal string --input "MBZUAI is the first AI university in the world" --api_config demo_data/test_api_config.yaml --retriever google +``` + +### Custom Prompts +```bash +python -m factcheck --modal string --input "MBZUAI is the first AI university in the world" --api_config demo_data/test_api_config.yaml --prompt demo_data/sample_prompt.yaml +``` + ## Ready for More? 💪 **Join Our Journey to Innovation with the Supporter Edition** diff --git a/factcheck.py b/factcheck.py deleted file mode 100644 index c3d6716..0000000 --- a/factcheck.py +++ /dev/null @@ -1,28 +0,0 @@ -from factcheck.core.FactCheck import FactCheck -from factcheck.utils.multimodal import modal_normalization -import argparse -import json - - -def main(model: str, modal: str, input: str): - """factcheck - - Args: - model (str): gpt model used for factchecking - modal (str): input type, supported types are str, text file, speech, image, and video - input (str): input content or path to the file - """ - factcheck = FactCheck(default_model=model) - content = modal_normalization(modal, input) - res = factcheck.check_response(content) - print(json.dumps(res["step_info"], indent=4)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--model", type=str, default="gpt-4-0125-preview") - parser.add_argument("--modal", type=str, default="text") - parser.add_argument("--input", type=str, default="demo_data/text.txt") - args = parser.parse_args() - - main(args.model, args.modal, args.input) diff --git a/factcheck/__init__.py b/factcheck/__init__.py index e69de29..981d16c 100644 --- a/factcheck/__init__.py +++ b/factcheck/__init__.py @@ -0,0 +1,162 @@ +import time +import tiktoken + +from factcheck.utils.llmclient import CLIENTS, model2client +from factcheck.utils.prompt import prompt_mapper +from factcheck.utils.logger import CustomLogger +from factcheck.utils.api_config import load_api_config +from factcheck.core import ( + Decompose, + Checkworthy, + QueryGenerator, + retriever_mapper, + ClaimVerify, +) + +logger = CustomLogger(__name__).getlog() + + +class FactCheck: + def __init__( + self, + default_model: str = "gpt-4-0125-preview", + client: str = None, + prompt: str = "chatgpt_prompt", + retriever: str = "serper", + decompose_model: str = None, + checkworthy_model: str = None, + query_generator_model: str = None, + evidence_retrieval_model: str = None, + claim_verify_model: str = None, + api_config: dict = None, + ): + self.encoding = tiktoken.get_encoding("cl100k_base") + + self.prompt = prompt_mapper(prompt_name=prompt) + + # load configures for API + self.load_config(api_config=api_config) + + # llms for each step (sub-module) + step_models = { + "decompose_model": decompose_model, + "checkworthy_model": checkworthy_model, + "query_generator_model": query_generator_model, + "evidence_retrieval_model": evidence_retrieval_model, + "claim_verify_model": claim_verify_model, + } + + for key, _model_name in step_models.items(): + _model_name = default_model if _model_name is None else _model_name + print(f"== Init {key} with model: {_model_name}") + if client is not None: + logger.info(f"== Use specified client: {client}") + LLMClient = CLIENTS[client] + else: + logger.info("== Client is not specified, use model2client() to get the default llm client.") + LLMClient = model2client(_model_name) + setattr(self, key, LLMClient(model=_model_name, api_config=self.api_config)) + + # sub-modules + self.decomposer = Decompose(llm_client=self.decompose_model, prompt=self.prompt) + self.checkworthy = Checkworthy(llm_client=self.checkworthy_model, prompt=self.prompt) + self.query_generator = QueryGenerator(llm_client=self.query_generator_model, prompt=self.prompt) + self.evidence_crawler = retriever_mapper(retriever_name=retriever)(api_config=self.api_config) + self.claimverify = ClaimVerify(llm_client=self.claim_verify_model, prompt=self.prompt) + + logger.info("===Sub-modules Init Finished===") + + def load_config(self, api_config: dict) -> None: + # Load API config + self.api_config = load_api_config(api_config) + + def check_response(self, response: str): + st_time = time.time() + # step 1 + claims = self.decomposer.getclaims(doc=response) + for i, claim in enumerate(claims): + logger.info(f"== response claims {i}: {claim}") + + # step 2 + ( + checkworthy_claims, + pairwise_checkworthy, + ) = self.checkworthy.identify_checkworthiness(claims) + for i, claim in enumerate(checkworthy_claims): + logger.info(f"== Check-worthy claims {i}: {claim}") + + # Token count + num_raw_tokens = len(self.encoding.encode(response)) + num_checkworthy_tokens = len(self.encoding.encode(" ".join(checkworthy_claims))) + + api_data_dict = { + "response": response, + "token_count": { + "num_raw_tokens": num_raw_tokens, + "num_checkworthy_tokens": num_checkworthy_tokens, + }, + "step_info": { + "0_response": response, + "1_decompose": claims, + "2_checkworthy": checkworthy_claims, + "2_checkworthy_pairwise": pairwise_checkworthy, + "3_query_generator": {}, + "4_evidence_retrieve": {}, + "5_claim_verify": {}, + }, + } + # Special case, return + if num_checkworthy_tokens == 0: + api_data_dict["factuality"] = "Nothing to check." + logger.info("== State: Done! (Nothing to check.)") + return api_data_dict + + # step 3 + claim_query_dict = self.query_generator.generate_query(claims=checkworthy_claims) + for k, v in claim_query_dict.items(): + logger.info(f"== Claim: {k} --- Queries: {v}") + + step123_time = time.time() + + # step 4 + claim_evidence_dict = self.evidence_crawler.retrieve_evidence(claim_query_dict=claim_query_dict) + for claim, evidences in claim_evidence_dict.items(): + logger.info(f"== Claim: {claim}") + logger.info(f"== Evidence: {evidences}\n") + step4_time = time.time() + + # step 5 + claim_verify_dict = self.claimverify.verify_claims(claims_evidences_dict=claim_evidence_dict) + step5_time = time.time() + logger.info( + f"== State: Done! \n Total time: {step5_time-st_time:.2f}s. (create claims:{step123_time-st_time:.2f}s ||| retrieve:{step4_time-step123_time:.2f}s ||| verify:{step5_time-step4_time:.2f}s)" + ) + + api_data_dict["step_info"].update( + { + "3_query_generator": claim_query_dict, + "4_evidence_retrieve": claim_evidence_dict, + "5_claim_verify": claim_verify_dict, + } + ) + api_data_dict = self._post_process(api_data_dict, claim_verify_dict) + api_data_dict["step_info"] = api_data_dict["step_info"] + + return api_data_dict + + def _post_process(self, api_data_dict, claim_verify_dict: dict): + label_list = list() + api_claim_data_list = list() + for claim in api_data_dict["step_info"]["2_checkworthy"]: + api_claim_data = {} + claim_detail = claim_verify_dict.get(claim, {}) + curr_claim_label = claim_detail.get("factuality", False) + label_list.append(curr_claim_label) + api_claim_data["claim"] = claim + api_claim_data["factuality"] = curr_claim_label + api_claim_data["correction"] = claim_detail.get("correction", "") + api_claim_data["reference_url"] = claim_detail.get("url", "") + api_claim_data_list.append(api_claim_data) + api_data_dict["factuality"] = all(label_list) if label_list else True + api_data_dict["claims_details"] = api_claim_data_list + return api_data_dict diff --git a/factcheck/__main__.py b/factcheck/__main__.py new file mode 100644 index 0000000..4239281 --- /dev/null +++ b/factcheck/__main__.py @@ -0,0 +1,45 @@ +import json +import argparse + +from factcheck.utils.llmclient import CLIENTS +from factcheck.utils.multimodal import modal_normalization +from factcheck.utils.utils import load_yaml +from factcheck import FactCheck + + +def check(args): + """factcheck + + Args: + model (str): gpt model used for factchecking + modal (str): input type, supported types are str, text file, speech, image, and video + input (str): input content or path to the file + """ + # Load API config from yaml file + try: + api_config = load_yaml(args.api_config) + except Exception as e: + print(f"Error loading api config: {e}") + api_config = {} + + factcheck = FactCheck( + default_model=args.model, client=args.client, api_config=api_config, prompt=args.prompt, retriever=args.retriever + ) + + content = modal_normalization(args.modal, args.input) + res = factcheck.check_response(content) + print(json.dumps(res["step_info"], indent=4)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="gpt-4-0125-preview") + parser.add_argument("--client", type=str, default=None, choices=CLIENTS.keys()) + parser.add_argument("--prompt", type=str, default="chatgpt_prompt") + parser.add_argument("--retriever", type=str, default="serper") + parser.add_argument("--modal", type=str, default="text") + parser.add_argument("--input", type=str, default="demo_data/text.txt") + parser.add_argument("--api_config", type=str, default="factcheck/config/api_config.yaml") + args = parser.parse_args() + + check(args) diff --git a/factcheck/config/api_config.yaml b/factcheck/config/api_config.yaml new file mode 100644 index 0000000..148b7ac --- /dev/null +++ b/factcheck/config/api_config.yaml @@ -0,0 +1,8 @@ +SERPER_API_KEY: null + +OPENAI_API_KEY: null + +ANTHROPIC_API_KEY: null + +LOCAL_API_KEY: null +LOCAL_API_URL: null diff --git a/factcheck/config/sample_prompt.yaml b/factcheck/config/sample_prompt.yaml new file mode 100644 index 0000000..3db790e --- /dev/null +++ b/factcheck/config/sample_prompt.yaml @@ -0,0 +1,106 @@ +decompose_prompt: | + Your task is to decompose the text into atomic claims. + The answer should be a JSON with a single key "claims", with the value of a list of strings, where each string should be a context-independent claim, representing one fact. + Note that: + 1. Each claim should be concise (less than 15 words) and self-contained. + 2. Avoid vague references like 'he', 'she', 'it', 'this', 'the company', 'the man' and using complete names. + 3. Generate at least one claim for each single sentence in the texts. + + For example, + Text: Mary is a five-year old girl, she likes playing piano and she doesn't like cookies. + Output: + {{"claims": ["Mary is a five-year old girl.", "Mary likes playing piano.", "Mary doesn't like cookies."]}} + + Text: {doc} + Output: + +checkworthy_prompt: | + Your task is to evaluate each provided statement to determine if it presents information whose factuality can be objectively verified by humans, irrespective of the statement's current accuracy. Consider the following guidelines: + 1. Opinions versus Facts: Distinguish between opinions, which are subjective and not verifiable, and statements that assert factual information, even if broad or general. Focus on whether there's a factual claim that can be investigated. + 2. Clarity and Specificity: Statements must have clear and specific references to be verifiable (e.g., "he is a professor" is not verifiable without knowing who "he" is). + 3. Presence of Factual Information: Consider a statement verifiable if it includes factual elements that can be checked against evidence or reliable sources, even if the overall statement might be broad or incorrect. + Your response should be in JSON format, with each statement as a key and either "Yes" or "No" as the value, along with a brief rationale for your decision. + + For example, given these statements: + 1. Gary Smith is a distinguished professor of economics. + 2. He is a professor at MBZUAI. + 3. Obama is the president of the UK. + + The expected output is: + {{ + "Gary Smith is a distinguished professor of economics.": "Yes (The statement contains verifiable factual information about Gary Smith's professional title and field.)", + "He is a professor at MBZUAI.": "No (The statement cannot be verified due to the lack of clear reference to who 'he' is.)", + "Obama is the president of the UK.": "Yes (This statement contain verifiable information regarding the political leadership of a country.)" + }} + + For these statements: + {texts} + + The output should be: + + +qgen_prompt: | + Given a claim, your task is to create minimum number of questions need to be check to verify the correctness of the claim. Output in JSON format with a single key "Questions", the value is a list of questions. For example: + + Claim: Your nose switches back and forth between nostrils. When you sleep, you switch about every 45 minutes. This is to prevent a buildup of mucus. It’s called the nasal cycle. + Output: {{"Questions": ["Does your nose switch between nostrils?", "How often does your nostrils switch?", "Why does your nostril switch?", "What is nasal cycle?"]}} + + Claim: The Stanford Prison Experiment was conducted in the basement of Encina Hall, Stanford’s psychology building. + Output: + {{"Question":["Where was Stanford Prison Experiment was conducted?"]}} + + Claim: The Havel-Hakimi algorithm is an algorithm for converting the adjacency matrix of a graph into its adjacency list. It is named after Vaclav Havel and Samih Hakimi. + Output: + {{"Questions":["What does Havel-Hakimi algorithm do?", "Who are Havel-Hakimi algorithm named after?"]}} + + Claim: Social work is a profession that is based in the philosophical tradition of humanism. It is an intellectual discipline that has its roots in the 1800s. + Output: + {{"Questions":["What philosophical tradition is social work based on?", "What year does social work have its root in?"]}} + + Claim: {claim} + Output: + + +verify_prompt: | + Your task is to evaluate the accuracy of a provided statement using the accompanying evidence. Carefully review the evidence, noting that it may vary in detail and sometimes present conflicting information. Your judgment should be informed by this evidence, taking into account its relevance and reliability. + + Keep in mind that a lack of detail in the evidence does not necessarily indicate that the statement is inaccurate. When assessing the statement's factuality, distinguish between errors and areas where the evidence supports the statement. + + Please structure your response in JSON format, including the following four keys: + - "reasoning": explain the thought process behind your judgment. + - "error": none if the text is factual; otherwise, identify any specific inaccuracies in the statement. + - "correction": none if the text is factual; otherwise, provide corrections to any identified inaccuracies, using the evidence to support your corrections. + - "factuality": true if the given text is factual, false otherwise, indicating whether the statement is factual, or non-factual based on the evidence. + + For example: + Input: + [text]: MBZUAI is located in Abu Dhabi, United Arab Emirates. + [evidence]: Where is MBZUAI located?\nAnswer: Masdar City - Abu Dhabi - United Arab Emirates + + Output: + {{ + "reasoning": "The evidence confirms that MBZUAI is located in Masdar City, Abu Dhabi, United Arab Emirates, so the statement is factually correct", + "error": none, + "correction": none, + "factuality": true + }} + + + Input: + [text]: Copper reacts with ferrous sulfate (FeSO4). + [evidence]: Copper is less reactive metal. It has positive value of standard reduction potential. Metal with high standard reduction potential can not displace other metal with low standard reduction potential values. Hence copper can not displace iron from ferrous sulphate solution. So no change will take place. + + Output: + {{ + "reasoning": "The evidence provided confirms that copper cannot displace iron from ferrous sulphate solution, and no change will take place.", + "error": "Copper does not react with ferrous sulfate as stated in the text.", + "correction": "Copper does not react with ferrous sulfate as it cannot displace iron from ferrous sulfate solution.", + "factuality": false + }} + + + Input + [text]: {claim} + [evidences]: {evidence} + + Output: diff --git a/factcheck/config/secret_dict.template b/factcheck/config/secret_dict.template deleted file mode 100644 index 3d5aeb8..0000000 --- a/factcheck/config/secret_dict.template +++ /dev/null @@ -1,9 +0,0 @@ -import os - -serper_dict = {"default_key": None} -openai_dict = {"default_key": None} -anthropic_dict = {"default_key": None} - -serper_dict["key"] = os.environ.get("SERPER_API_KEY", serper_dict["default_key"]) -openai_dict["key"] = os.environ.get("OPENAI_API_KEY", openai_dict["default_key"]) -anthropic_dict["key"] = os.environ.get("ANTHROPIC_API_KEY", anthropic_dict["default_key"]) diff --git a/factcheck/core/CheckWorthy.py b/factcheck/core/CheckWorthy.py index ed67450..7e408b4 100644 --- a/factcheck/core/CheckWorthy.py +++ b/factcheck/core/CheckWorthy.py @@ -1,37 +1,41 @@ -from typing import List -from factcheck.utils.prompt import CHECKWORTHY_PROMPT -from factcheck.utils.GPTClient import GPTClient -from factcheck.config.CustomLogger import CustomLogger +from factcheck.utils.logger import CustomLogger logger = CustomLogger(__name__).getlog() class Checkworthy: - def __init__(self, model: str = "gpt-3.5-turbo"): + def __init__(self, llm_client, prompt): """Initialize the Checkworthy class Args: - model (str, optional): The version of the GPT model used for checkworthy classification. Defaults to "gpt-3.5-turbo". + llm_client (BaseClient): The LLM client used for identifying checkworthiness of claims. + prompt (BasePrompt): The prompt used for identifying checkworthiness of claims. """ - self.chatgpt_client = GPTClient(model=model) + self.llm_client = llm_client + self.prompt = prompt - def identify_checkworthiness(self, texts: List[str], num_retries: int = 3) -> List[str]: + def identify_checkworthiness(self, texts: list[str], num_retries: int = 3, prompt: str = None) -> list[str]: """Use GPT to identify whether candidate claims are worth fact checking. if gpt is unable to return correct checkworthy_claims, we assume all texts are checkworthy. Args: - texts (List[str]): a list of texts to identify whether they are worth fact checking + texts (list[str]): a list of texts to identify whether they are worth fact checking num_retries (int, optional): maximum attempts for GPT to identify checkworthy claims. Defaults to 3. Returns: - List[str]: a list of checkworthy claims, pairwise outputs + list[str]: a list of checkworthy claims, pairwise outputs """ checkworthy_claims = texts # TODO: better handle checkworthiness joint_texts = "\n".join([str(i + 1) + ". " + j for i, j in enumerate(texts)]) - user_input = CHECKWORTHY_PROMPT.format(texts=joint_texts) - messages = self.chatgpt_client.construct_message_list([user_input]) + + if prompt is None: + user_input = self.prompt.checkworthy_prompt.format(texts=joint_texts) + else: + user_input = prompt.format(texts=joint_texts) + + messages = self.llm_client.construct_message_list([user_input]) for i in range(num_retries): - response = self.chatgpt_client.multi_call(messages, num_retries=1, seed=42 + i) + response = self.llm_client.call(messages, num_retries=1, seed=42 + i) try: results = eval(response) valid_answer = list( @@ -45,6 +49,6 @@ def identify_checkworthiness(self, texts: List[str], num_retries: int = 3) -> Li assert len(valid_answer) == len(results) break except Exception as e: - logger.error(f"====== Error: {e}, the response is: {response}") + logger.error(f"====== Error: {e}, the LLM response is: {response}") logger.error(f"====== Our input is: {messages}") return checkworthy_claims, results diff --git a/factcheck/core/ClaimVerify.py b/factcheck/core/ClaimVerify.py index 17d13f1..b92e138 100644 --- a/factcheck/core/ClaimVerify.py +++ b/factcheck/core/ClaimVerify.py @@ -1,23 +1,23 @@ from __future__ import annotations import json -from factcheck.utils.prompt import VERIFY_PROMPT -from factcheck.utils.GPTClient import GPTClient -from factcheck.config.CustomLogger import CustomLogger +from factcheck.utils.logger import CustomLogger logger = CustomLogger(__name__).getlog() class ClaimVerify: - def __init__(self, model: str = "gpt-3.5-turbo"): + def __init__(self, llm_client, prompt): """Initialize the ClaimVerify class Args: - model (str, optional): The version of the GPT model used for claim verification. Defaults to "gpt-3.5-turbo". + llm_client (BaseClient): The LLM client used for verifying the factuality of claims. + prompt (BasePrompt): The prompt used for verifying the factuality of claims. """ - self.chatgpt_client = GPTClient(model=model) + self.llm_client = llm_client + self.prompt = prompt - def verify_claims(self, claims_evidences_dict): + def verify_claims(self, claims_evidences_dict, prompt: str = None) -> dict[str, any]: """Verify the factuality of the claims with respect to the given evidences Args: @@ -30,7 +30,7 @@ def verify_claims(self, claims_evidences_dict): claims = list(claims_evidences_dict.keys()) evidence_lists = list(claims_evidences_dict.values()) - results = self._verify_all_claims(claims, evidence_lists) + results = self._verify_all_claims(claims, evidence_lists, prompt=prompt) for claim, evidence_list, result in zip(claims, evidence_lists, results): result["claim"] = claim @@ -38,7 +38,13 @@ def verify_claims(self, claims_evidences_dict): claim_detail_dict[claim] = result return claim_detail_dict - def _verify_all_claims(self, claims: list[str], evidence_lists: list[list], num_retries=3) -> list[dict[str, any]]: + def _verify_all_claims( + self, + claims: list[str], + evidence_lists: list[list], + num_retries=3, + prompt: str = None, + ) -> list[dict[str, any]]: """Verify the factuality of the claims with respect to the given evidences Args: @@ -54,22 +60,26 @@ def _verify_all_claims(self, claims: list[str], evidence_lists: list[list], num_ # construct user inputs with respect to each claim and its evidences messages_list = [] for claim, evidences in zip(claims, evidence_lists): - user_input = VERIFY_PROMPT.format(claim=claim, evidence=evidences) + if prompt is None: + user_input = self.prompt.verify_prompt.format(claim=claim, evidence=evidences) + else: + user_input = prompt.format(claim=claim, evidence=evidences) + messages_list.append(user_input) while (attempts < num_retries) and (None in factual_results): _messages = [_message for _i, _message in enumerate(messages_list) if factual_results[_i] is None] _indices = [_i for _i, _message in enumerate(messages_list) if factual_results[_i] is None] - _message_list = self.chatgpt_client.construct_message_list(_messages) - _response_list = self.chatgpt_client.call_chatgpt_multiple_async(_message_list) + _message_list = self.llm_client.construct_message_list(_messages) + _response_list = self.llm_client.multi_call(_message_list) for _response, _index in zip(_response_list, _indices): try: _response_json = json.loads(_response) assert all(k in _response_json for k in ["reasoning", "error", "correction", "factuality"]) factual_results[_index] = _response_json except: # noqa: E722 - logger.info(f"Warning: ChatGPT response parse fail, retry {attempts}.") + logger.info(f"Warning: LLM response parse fail, retry {attempts}.") attempts += 1 _template_results = { diff --git a/factcheck/core/Decompose.py b/factcheck/core/Decompose.py index 72a270e..b0dca2b 100644 --- a/factcheck/core/Decompose.py +++ b/factcheck/core/Decompose.py @@ -1,20 +1,19 @@ -from factcheck.utils.prompt import SENTENCES_TO_CLAIMS_PROMPT -from factcheck.utils.GPTClient import GPTClient -from factcheck.config.CustomLogger import CustomLogger +from factcheck.utils.logger import CustomLogger import nltk logger = CustomLogger(__name__).getlog() class Decompose: - def __init__(self, model="gpt-3.5-turbo"): + def __init__(self, llm_client, prompt): """Initialize the Decompose class Args: - model (str, optional): The version of the GPT model used for claim decomposition. Defaults to "gpt-3.5-turbo". + llm_client (BaseClient): The LLM client used for decomposing documents into claims. + prompt (BasePrompt): The prompt used for fact checking. """ - self.getclaims = self.getclaimsfromgpt - self.chatgpt_client = GPTClient(model=model) + self.llm_client = llm_client + self.prompt = prompt self.doc2sent = self._nltk_doc2sent def _nltk_doc2sent(self, text: str): @@ -31,7 +30,7 @@ def _nltk_doc2sent(self, text: str): sentence_list = [s.strip() for s in sentences if len(s.strip()) >= 3] return sentence_list - def getclaimsfromgpt(self, doc: str, num_retries: int = 3): + def getclaims(self, doc: str, num_retries: int = 3, prompt: str = None): """Use GPT to decompose a document into claims Args: @@ -41,12 +40,14 @@ def getclaimsfromgpt(self, doc: str, num_retries: int = 3): Returns: list: a list of claims """ - prompt_text = SENTENCES_TO_CLAIMS_PROMPT - user_input = prompt_text.format(doc=doc).strip() + if prompt is None: + user_input = self.prompt.decompose_prompt.format(doc=doc).strip() + else: + user_input = prompt.format(doc=doc).strip() - messages = self.chatgpt_client.construct_message_list([user_input]) + messages = self.llm_client.construct_message_list([user_input]) for i in range(num_retries): - response = self.chatgpt_client.multi_call( + response = self.llm_client.call( messages=messages, num_retries=1, seed=42 + i, @@ -56,8 +57,8 @@ def getclaimsfromgpt(self, doc: str, num_retries: int = 3): if isinstance(claims, list) and len(claims) > 0: return claims except Exception as e: - logger.error(f"Parse chatgpt result error {e}, response is: {response}") - logger.error(f"Parse chatgpt result error, prompt is: {messages}") + logger.error(f"Parse LLM response error {e}, response is: {response}") + logger.error(f"Parse LLM response error, prompt is: {messages}") logger.info("It does not output a list of sentences correctly, return self.doc2sent_tool split results.") claims = self.doc2sent(doc) diff --git a/factcheck/core/FactCheck.py b/factcheck/core/FactCheck.py deleted file mode 100644 index ee17a65..0000000 --- a/factcheck/core/FactCheck.py +++ /dev/null @@ -1,134 +0,0 @@ -from __future__ import annotations - -import time -import tiktoken - -from factcheck.core.Decompose import Decompose -from factcheck.core.CheckWorthy import Checkworthy -from factcheck.core.QueryGenerator import QueryGenerator -from factcheck.core.Retriever import SerperEvidenceRetrieve -from factcheck.core.ClaimVerify import ClaimVerify - -from factcheck.config.CustomLogger import CustomLogger - -logger = CustomLogger(__name__).getlog() - - -class FactCheck: - def __init__( - self, - default_model: str = "gpt-4-0125-preview", - decompose_model: str = None, - checkworthy_model: str = None, - query_generator_model: str = None, - evidence_retrieval_model: str = None, - claim_verify_model: str = None, - ): - # for gpt token count - self.encoding = tiktoken.get_encoding("cl100k_base") - - # claim getter - self.decomposer = Decompose(model=default_model if decompose_model is None else decompose_model) - # checkworthy - self.checkworthy = Checkworthy(model=default_model if checkworthy_model is None else checkworthy_model) - self.query_generator = QueryGenerator( - model=(default_model if query_generator_model is None else query_generator_model) - ) - # evidences crawler - self.evidence_crawler = SerperEvidenceRetrieve( - model=(default_model if evidence_retrieval_model is None else evidence_retrieval_model) - ) - # verity claim with evidences - self.claimverify = ClaimVerify(model=default_model if claim_verify_model is None else claim_verify_model) - logger.info("===Sub-modules Init Finished===") - - def check_response(self, response: str): - st_time = time.time() - # step 1 - claims = self.decomposer.getclaimsfromgpt(doc=response) - for i, claim in enumerate(claims): - logger.info(f"== response claims {i}: {claim}") - - # step 2 - ( - checkworthy_claims, - pairwise_checkworthy, - ) = self.checkworthy.identify_checkworthiness(claims) - for i, claim in enumerate(checkworthy_claims): - logger.info(f"== Check-worthy claims {i}: {claim}") - - # Token count - num_raw_tokens = len(self.encoding.encode(response)) - num_checkworthy_tokens = len(self.encoding.encode(" ".join(checkworthy_claims))) - - api_data_dict = { - "response": response, - "token_count": { - "num_raw_tokens": num_raw_tokens, - "num_checkworthy_tokens": num_checkworthy_tokens, - }, - "step_info": { - "0_response": response, - "1_decompose": claims, - "2_checkworthy": checkworthy_claims, - "2_checkworthy_pairwise": pairwise_checkworthy, - "3_query_generator": {}, - "4_evidence_retrieve": {}, - "5_claim_verify": {}, - }, - } - # Special case, return - if num_checkworthy_tokens == 0: - api_data_dict["factuality"] = "Nothing to check." - logger.info("== State: Done! (Nothing to check.)") - return api_data_dict - - # step 3 - claim_query_dict = self.query_generator.generate_query(claims=checkworthy_claims) - for k, v in claim_query_dict.items(): - logger.info(f"== Claim: {k} --- Queries: {v}") - - step123_time = time.time() - - # step 4 - claim_evidence_dict = self.evidence_crawler.retrieve_evidence(claim_query_dict=claim_query_dict) - for claim, evidences in claim_evidence_dict.items(): - logger.info(f"== Claim: {claim}") - logger.info(f"== Evidence: {evidences}\n") - step4_time = time.time() - - # step 5 - claim_verify_dict = self.claimverify.verify_claims(claims_evidences_dict=claim_evidence_dict) - step5_time = time.time() - logger.info( - f"== State: Done! \n Total time: {step5_time-st_time:.2f}s. (create claims:{step123_time-st_time:.2f}s ||| retrieve:{step4_time-step123_time:.2f}s ||| verify:{step5_time-step4_time:.2f}s)" - ) - - api_data_dict["step_info"].update( - { - "3_query_generator": claim_query_dict, - "4_evidence_retrieve": claim_evidence_dict, - "5_claim_verify": claim_verify_dict, - } - ) - api_data_dict = self._post_process(api_data_dict, claim_verify_dict) - api_data_dict["step_info"] = api_data_dict["step_info"] - - return api_data_dict - - def _post_process(self, api_data_dict, claim_verify_dict: dict): - label_list = list() - api_claim_data_list = list() - for claim in api_data_dict["step_info"]["2_checkworthy"]: - api_claim_data = {} - claim_detail = claim_verify_dict.get(claim, {}) - curr_claim_label = claim_detail.get("factuality", False) - label_list.append(curr_claim_label) - api_claim_data["claim"] = claim - api_claim_data["factuality"] = curr_claim_label - api_claim_data["correction"] = claim_detail.get("correction", "") - api_claim_data["reference_url"] = claim_detail.get("url", "") - api_claim_data_list.append(api_claim_data) - api_data_dict["factuality"] = all(label_list) if label_list else True - api_data_dict["claims_details"] = api_claim_data_list - return api_data_dict diff --git a/factcheck/core/QueryGenerator.py b/factcheck/core/QueryGenerator.py index 3e3261c..30507e8 100644 --- a/factcheck/core/QueryGenerator.py +++ b/factcheck/core/QueryGenerator.py @@ -1,22 +1,21 @@ -from __future__ import annotations -from factcheck.utils.prompt import QGEN_PROMPT -from factcheck.utils.GPTClient import GPTClient -from factcheck.config.CustomLogger import CustomLogger +from factcheck.utils.logger import CustomLogger logger = CustomLogger(__name__).getlog() class QueryGenerator: - def __init__(self, model: str = "gpt-3.5-turbo") -> None: + def __init__(self, llm_client, prompt, max_query_per_claim: int = 5): """Initialize the QueryGenerator class Args: - model (str, optional): The version of the GPT model used for query generation. Defaults to "gpt-3.5-turbo". + llm_client (BaseClient): The LLM client used for generating questions. + prompt (BasePrompt): The prompt used for generating questions. """ - self.chatgpt_client = GPTClient(model=model) - self.max_query_per_claim = 5 + self.llm_client = llm_client + self.prompt = prompt + self.max_query_per_claim = max_query_per_claim - def generate_query(self, claims: list[str], generating_time: int = 3): + def generate_query(self, claims: list[str], generating_time: int = 3, prompt: str = None) -> dict[str, list[str]]: """Generate questions for the given claims Args: @@ -32,22 +31,25 @@ def generate_query(self, claims: list[str], generating_time: int = 3): # construct messages messages_list = [] for claim in claims: - user_input = QGEN_PROMPT.format(claim=claim) + if prompt is None: + user_input = self.prompt.qgen_prompt.format(claim=claim) + else: + user_input = prompt.format(claim=claim) messages_list.append(user_input) while (attempts < generating_time) and ([] in generated_questions): _messages = [_message for _i, _message in enumerate(messages_list) if generated_questions[_i] == []] _indices = [_i for _i, _message in enumerate(messages_list) if generated_questions[_i] == []] - _message_list = self.chatgpt_client.construct_message_list(_messages) - _response_list = self.chatgpt_client.call_chatgpt_multiple_async(_message_list) + _message_list = self.llm_client.construct_message_list(_messages) + _response_list = self.llm_client.multi_call(_message_list) for _response, _index in zip(_response_list, _indices): try: _questions = eval(_response)["Questions"] generated_questions[_index] = _questions except: # noqa: E722 - logger.info(f"Warning: ChatGPT response parse fail, retry {attempts}.") + logger.info(f"Warning: LLM response parse fail, retry {attempts}.") attempts += 1 # ensure that each claim has at least one question which is the claim itself diff --git a/factcheck/core/Retriever/EvidenceRetrieve.py b/factcheck/core/Retriever/EvidenceRetrieve.py index 4611e62..b5dd548 100644 --- a/factcheck/core/Retriever/EvidenceRetrieve.py +++ b/factcheck/core/Retriever/EvidenceRetrieve.py @@ -1,27 +1,15 @@ -from __future__ import annotations -from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ProcessPoolExecutor -from concurrent.futures import as_completed import os from copy import deepcopy from factcheck.utils.web_util import parse_response, crawl_web -from factcheck.config.CustomLogger import CustomLogger +from factcheck.utils.logger import CustomLogger logger = CustomLogger(__name__).getlog() class EvidenceRetrieve: - def __init__(self, model: str = "gpt-3.5-turbo") -> None: - """Initialize the EvidenceRetrieve class. - sentences_per_passage: Number of sentences for each passage. - sliding_distance: Sliding distance over the text. Allows the passages to have overlap. The sliding distance cannot be greater than the window size. - - Args: - model (str, optional): The version of the GPT model used for evidence retrieval. Defaults to "gpt-3.5-turbo". - - Returns: - None - """ + def __init__(self, api_config: dict = None): + """Initialize the EvidenceRetrieve class.""" import spacy self.tokenizer = spacy.load("en_core_web_sm", disable=["ner", "tagger", "lemmatizer"]) diff --git a/factcheck/core/Retriever/GoogleEvidenceRetrieve.py b/factcheck/core/Retriever/GoogleEvidenceRetrieve.py index 3e7de48..58c15f3 100644 --- a/factcheck/core/Retriever/GoogleEvidenceRetrieve.py +++ b/factcheck/core/Retriever/GoogleEvidenceRetrieve.py @@ -1,16 +1,14 @@ -from __future__ import annotations from concurrent.futures import ThreadPoolExecutor from factcheck.utils.web_util import common_web_request, crawl_google_web -from factcheck.core.Retriever.EvidenceRetrieve import EvidenceRetrieve - -from factcheck.config.CustomLogger import CustomLogger +from .EvidenceRetrieve import EvidenceRetrieve +from factcheck.utils.logger import CustomLogger logger = CustomLogger(__name__).getlog() class GoogleEvidenceRetrieve(EvidenceRetrieve): - def __init__(self) -> None: - super(GoogleEvidenceRetrieve, self).__init__() + def __init__(self, api_config: dict = None) -> None: + super(GoogleEvidenceRetrieve, self).__init__(api_config) self.num_web_pages = 10 def _get_query_urls(self, questions: list[str]): diff --git a/factcheck/core/Retriever/SerperEvidenceRetrieve.py b/factcheck/core/Retriever/SerperEvidenceRetrieve.py index 558dea1..97c354d 100644 --- a/factcheck/core/Retriever/SerperEvidenceRetrieve.py +++ b/factcheck/core/Retriever/SerperEvidenceRetrieve.py @@ -1,25 +1,20 @@ -from __future__ import annotations from concurrent.futures import ThreadPoolExecutor import json import requests import os import re import bs4 -from factcheck.config.secret_dict import serper_dict -from factcheck.config.CustomLogger import CustomLogger +from factcheck.utils.logger import CustomLogger from factcheck.utils.web_util import crawl_web logger = CustomLogger(__name__).getlog() class SerperEvidenceRetrieve: - def __init__(self, model: str = "gpt-3.5-turbo") -> None: - """Initialize the SerperEvidenceRetrieve class - - Args: - model (str, optional): The version of the GPT model used for evidence retrieval. Defaults to "gpt-3.5-turbo". - """ + def __init__(self, api_config: dict = None): + """Initialize the SerperEvidenceRetrieve class""" self.lang = "en" + self.serper_key = api_config["SERPER_API_KEY"] def retrieve_evidence(self, claim_query_dict, top_k: int = 5, snippet_extend_flag: bool = True): """Retrieve evidences for the given claims @@ -183,7 +178,7 @@ def _request_serper_api(self, questions): url = "https://google.serper.dev/search" headers = { - "X-API-KEY": serper_dict.get("key"), + "X-API-KEY": self.serper_key, "Content-Type": "application/json", } diff --git a/factcheck/core/Retriever/__init__.py b/factcheck/core/Retriever/__init__.py index 08e96ab..fd443af 100644 --- a/factcheck/core/Retriever/__init__.py +++ b/factcheck/core/Retriever/__init__.py @@ -1,2 +1,13 @@ -from factcheck.core.Retriever.GoogleEvidenceRetrieve import GoogleEvidenceRetrieve -from factcheck.core.Retriever.SerperEvidenceRetrieve import SerperEvidenceRetrieve +from .GoogleEvidenceRetrieve import GoogleEvidenceRetrieve +from .SerperEvidenceRetrieve import SerperEvidenceRetrieve + +retriever_map = { + "google": GoogleEvidenceRetrieve, + "serper": SerperEvidenceRetrieve, +} + + +def retriever_mapper(retriever_name: str): + if retriever_name not in retriever_map: + raise NotImplementedError(f"Retriever {retriever_name} not found!") + return retriever_map[retriever_name] diff --git a/factcheck/core/__init__.py b/factcheck/core/__init__.py new file mode 100644 index 0000000..69adff3 --- /dev/null +++ b/factcheck/core/__init__.py @@ -0,0 +1,5 @@ +from .Decompose import Decompose +from .CheckWorthy import Checkworthy +from .QueryGenerator import QueryGenerator +from .Retriever import retriever_mapper +from .ClaimVerify import ClaimVerify diff --git a/factcheck/utils/GPTClient.py b/factcheck/utils/GPTClient.py deleted file mode 100644 index 4041ad8..0000000 --- a/factcheck/utils/GPTClient.py +++ /dev/null @@ -1,215 +0,0 @@ -from __future__ import annotations - -import asyncio -from openai import OpenAI -from anthropic import Anthropic -from collections import deque -import time -from factcheck.config.secret_dict import openai_dict, anthropic_dict -from functools import partial - - -class APIClient: - def __init__(self, model: str) -> None: - self.model = model - if self.model.startswith("gpt"): - self.client = OpenAI(api_key=openai_dict["key"]) - elif self.model.startswith("claude"): - self.client = Anthropic(api_key=anthropic_dict["key"]) - else: - raise ValueError("Model not supported") - - def complete(self, messages: str, seed: int): - response = "" - if self.model.startswith("gpt"): - response = self._oai_call(messages, seed) - elif self.model.startswith("claude"): - response = self._anthropic_call(messages, seed) - return response - - def construct_message_list( - self, - prompt_list: list[str], - system_role: str = "You are a helpful assistant designed to output JSON.", - ): - if self.model.startswith("gpt"): - return self._oai_construct_message_list(prompt_list, system_role) - elif self.model.startswith("claude"): - return self._anthropic_construct_message_list(prompt_list, system_role) - - def _oai_call(self, messages: str, seed: int): - response = "" - try: - response = self.client.chat.completions.create( - response_format={"type": "json_object"}, - seed=seed, - model=self.model, - messages=messages, - ) - except Exception as e: - print(f"Error ChatGPTClient: {e}") - pass - return response - - def _oai_construct_message_list( - self, - prompt_list: list[str], - system_role: str = "You are a helpful assistant designed to output JSON.", - ): - messages_list = list() - for prompt in prompt_list: - messages = [ - {"role": "system", "content": system_role}, - {"role": "user", "content": prompt}, - ] - messages_list.append(messages) - return messages_list - - def _anthropic_call(self, messages: str, seed: int): - response = "" - try: - response = self.client.messages.create( - messages=messages, - model=self.model, - max_tokens=2048, - ) - except Exception as e: - print(f"Error ChatGPTClient: {e}") - pass - return response - - def _anthropic_construct_message_list( - self, - prompt_list: list[str], - system_role: str = "You are a helpful assistant designed to output JSON.", - ): - # system role is not used in this case - messages_list = list() - for prompt in prompt_list: - messages = [ - {"role": "user", "content": prompt}, - ] - messages_list.append(messages) - return messages_list - - -class GPTClient: - def __init__( - self, - model: str = None, - max_traffic_bytes=1000000, - max_requests_per_minute=200, - request_window=60, - ): - self.max_traffic_bytes = max_traffic_bytes - self.max_requests_per_minute = max_requests_per_minute - self.request_window = request_window - self.traffic_queue = deque() - self.total_traffic = 0 - self.model = model - self.client = APIClient(model=self.model) - - def set_model(self, model: str): - self.model = model - - def _call(self, messages: str, seed: int): - return self.client.complete(messages, seed) - - def multi_call(self, messages: str, num_retries=3, waiting_time=1, seed=42): - r = "" - for _ in range(num_retries): - response = self._call(messages[0], seed=seed) - try: - r = response.choices[0].message.content - break - except Exception as e: - print(f"{e}. Retrying...") - time.sleep(waiting_time) - return r - - def get_request_length(self, messages): - # TODO: check if we should return the len(menages) instead - return 1 - - async def call_chatgpt_async(self, messages: list, key: str = None, seed=42): - """Calls ChatGPT asynchronously, tracks traffic, and enforces rate limits.""" - while len(self.traffic_queue) >= self.max_requests_per_minute: - await asyncio.sleep(1) - self.expire_old_traffic() - - loop = asyncio.get_running_loop() - # TODO: support seed - response = await loop.run_in_executor(None, partial(self._call, messages, seed=seed)) - - self.total_traffic += self.get_request_length(messages) - self.traffic_queue.append((time.time(), self.get_request_length(messages))) - - result = response.choices[0].message.content - if key: - return key, result - else: - return result - - def call_chatgpt_multiple_async(self, messages_list, seed=42): - """Calls ChatGPT asynchronously for multiple prompts and returns a list of responses.""" - tasks = [self.call_chatgpt_async(messages=messages, seed=seed) for messages in messages_list] - asyncio.set_event_loop(asyncio.SelectorEventLoop()) - loop = asyncio.get_event_loop() - responses = loop.run_until_complete(asyncio.gather(*tasks)) - return responses - - def call_chatgpt_multiple_async_with_key(self, messages_dict): - """Calls ChatGPT asynchronously for multiple prompts and returns a list of responses.""" - tasks = [self.call_chatgpt_async(messages=messages, key=key) for key, messages in messages_dict.items()] - asyncio.set_event_loop(asyncio.SelectorEventLoop()) - loop = asyncio.get_event_loop() - responses = loop.run_until_complete(asyncio.gather(*tasks)) - return responses - - def expire_old_traffic(self): - """Expires traffic older than the request window.""" - current_time = time.time() - while self.traffic_queue and self.traffic_queue[0][0] + self.request_window < current_time: - self.total_traffic -= self.traffic_queue.popleft()[1] - - def construct_message_dict( - self, - prompt_list: list[str], - match_key_list: list[str], - system_role: str = "You are a helpful factcheck assistant designed to output JSON.", - ): - assert len(prompt_list) == len(match_key_list), "match_key_list length has to be equal to prompt_list length" - messages_dict = dict() - for key, prompt in zip(match_key_list, prompt_list): - messages = [ - {"role": "system", "content": system_role}, - {"role": "user", "content": prompt}, - ] - messages_dict[key] = messages - return messages_dict - - def construct_message_list( - self, - prompt_list: list[str], - system_role: str = "You are a helpful assistant designed to output JSON.", - ): - return self.client.construct_message_list(prompt_list, system_role) - - -def main(): - """Example usage.""" - client = GPTClient() - prompts = ["ping", "pong", "ping"] - messages_list = client.construct_message_list(prompts) - responses = client.call_chatgpt_multiple_async(messages_list) - print(responses) - - match_key_list = ["1", "2", "3"] - messages_dict = client.construct_message_dict(prompts, match_key_list) - responses = client.call_chatgpt_multiple_async_with_key(messages_dict=messages_dict) - for key, result in responses: - print(key, ": ", result) - - -if __name__ == "__main__": - main() diff --git a/factcheck/utils/api_config.py b/factcheck/utils/api_config.py new file mode 100644 index 0000000..ba9de9d --- /dev/null +++ b/factcheck/utils/api_config.py @@ -0,0 +1,29 @@ +import os + +# Define all keys for the API configuration +keys = [ + "SERPER_API_KEY", + "OPENAI_API_KEY", + "ANTHROPIC_API_KEY", + "LOCAL_API_KEY", + "LOCAL_API_URL", +] + + +def load_api_config(api_config: dict = None): + """Load API keys from environment variables or config file, config file take precedence + + Args: + api_config (dict, optional): _description_. Defaults to None. + """ + if api_config is None: + api_config = dict() + assert type(api_config) is dict, "api_config must be a dictionary." + + merged_config = {} + + for key in keys: + merged_config[key] = api_config.get(key, None) + if merged_config[key] is None: + merged_config[key] = os.environ.get(key, None) + return merged_config diff --git a/factcheck/utils/llmclient/__init__.py b/factcheck/utils/llmclient/__init__.py new file mode 100644 index 0000000..a7a6316 --- /dev/null +++ b/factcheck/utils/llmclient/__init__.py @@ -0,0 +1,23 @@ +from .gpt_client import GPTClient +from .claude_client import ClaudeClient +from .local_openai_client import LocalOpenAIClient + +# fmt: off +CLIENTS = { + "gpt": GPTClient, + "claude": ClaudeClient, + "local_openai": LocalOpenAIClient +} +# fmt: on + + +def model2client(model_name: str): + """If the client is not specified, use this function to map the model name to the corresponding client.""" + if model_name.startswith("gpt"): + return GPTClient + elif model_name.startswith("claude"): + return ClaudeClient + elif model_name.startswith("vicuna"): + return LocalOpenAIClient + else: + raise ValueError(f"Model {model_name} not supported.") diff --git a/factcheck/utils/llmclient/base.py b/factcheck/utils/llmclient/base.py new file mode 100644 index 0000000..65df1c9 --- /dev/null +++ b/factcheck/utils/llmclient/base.py @@ -0,0 +1,84 @@ +import time +import asyncio +from abc import abstractmethod +from functools import partial +from collections import deque + + +class BaseClient: + def __init__( + self, + model: str, + api_config: dict, + max_requests_per_minute: int, + request_window: int, + ) -> None: + self.model = model + self.api_config = api_config + self.max_requests_per_minute = max_requests_per_minute + self.request_window = request_window + self.traffic_queue = deque() + self.total_traffic = 0 + + @abstractmethod + def _call(self, messages: str): + """Internal function to call the API.""" + pass + + @abstractmethod + def construct_message_list(self, prompt_list: list[str]) -> list[str]: + """Construct a list of messages for the function self.multi_call.""" + raise NotImplementedError + + @abstractmethod + def get_request_length(self, messages): + """Get the length of the request. Used for tracking traffic.""" + raise NotImplementedError + + def call(self, messages: list[str], num_retries=3, waiting_time=1, **kwargs): + seed = kwargs.get("seed", 42) + assert type(seed) is int, "Seed must be an integer." + assert len(messages) == 1, "Only one message is allowed for this function." + + r = "" + for _ in range(num_retries): + try: + r = self._call(messages[0], seed=seed) + break + except Exception as e: + print(f"Error LLM Client call: {e} Retrying...") + time.sleep(waiting_time) + + if r == "": + raise ValueError("Failed to get response from LLM Client.") + return r + + def set_model(self, model: str): + self.model = model + + async def _async_call(self, messages: list, **kwargs): + """Calls ChatGPT asynchronously, tracks traffic, and enforces rate limits.""" + while len(self.traffic_queue) >= self.max_requests_per_minute: + await asyncio.sleep(1) + self._expire_old_traffic() + + loop = asyncio.get_running_loop() + response = await loop.run_in_executor(None, partial(self._call, messages, **kwargs)) + + self.total_traffic += self.get_request_length(messages) + self.traffic_queue.append((time.time(), self.get_request_length(messages))) + + return response + + def multi_call(self, messages_list, **kwargs): + tasks = [self._async_call(messages=messages, **kwargs) for messages in messages_list] + asyncio.set_event_loop(asyncio.SelectorEventLoop()) + loop = asyncio.get_event_loop() + responses = loop.run_until_complete(asyncio.gather(*tasks)) + return responses + + def _expire_old_traffic(self): + """Expires traffic older than the request window.""" + current_time = time.time() + while self.traffic_queue and self.traffic_queue[0][0] + self.request_window < current_time: + self.total_traffic -= self.traffic_queue.popleft()[1] diff --git a/factcheck/utils/llmclient/claude_client.py b/factcheck/utils/llmclient/claude_client.py new file mode 100644 index 0000000..4423650 --- /dev/null +++ b/factcheck/utils/llmclient/claude_client.py @@ -0,0 +1,42 @@ +import time +from anthropic import Anthropic +from .base import BaseClient + + +class ClaudeClient(BaseClient): + def __init__( + self, + model: str = "claude-3-opus-20240229", + api_config: dict = None, + max_requests_per_minute=200, + request_window=60, + ): + super().__init__(model, api_config, max_requests_per_minute, request_window) + self.client = Anthropic(api_key=self.api_config["ANTHROPIC_API_KEY"]) + + def _call(self, messages: str, **kwargs): + response = self.client.messages.create( + messages=messages, + model=self.model, + max_tokens=2048, + ) + return response.content[0].text + + def get_request_length(self, messages): + return 1 + + def construct_message_list( + self, + prompt_list: list[str], + system_role: str = None, + ): + if system_role is None: + Warning("system_role is not used in this case") + # system role is not used in this case + messages_list = list() + for prompt in prompt_list: + messages = [ + {"role": "user", "content": prompt}, + ] + messages_list.append(messages) + return messages_list diff --git a/factcheck/utils/llmclient/gpt_client.py b/factcheck/utils/llmclient/gpt_client.py new file mode 100644 index 0000000..58122e4 --- /dev/null +++ b/factcheck/utils/llmclient/gpt_client.py @@ -0,0 +1,46 @@ +import time +from openai import OpenAI +from .base import BaseClient + + +class GPTClient(BaseClient): + def __init__( + self, + model: str = "gpt-4-turbo", + api_config: dict = None, + max_requests_per_minute=200, + request_window=60, + ): + super().__init__(model, api_config, max_requests_per_minute, request_window) + self.client = OpenAI(api_key=self.api_config["OPENAI_API_KEY"]) + + def _call(self, messages: str, **kwargs): + seed = kwargs.get("seed", 42) # default seed is 42 + assert type(seed) is int, "Seed must be an integer." + + response = self.client.chat.completions.create( + response_format={"type": "json_object"}, + seed=seed, + model=self.model, + messages=messages, + ) + r = response.choices[0].message.content + return r + + def get_request_length(self, messages): + # TODO: check if we should return the len(menages) instead + return 1 + + def construct_message_list( + self, + prompt_list: list[str], + system_role: str = "You are a helpful assistant designed to output JSON.", + ): + messages_list = list() + for prompt in prompt_list: + messages = [ + {"role": "system", "content": system_role}, + {"role": "user", "content": prompt}, + ] + messages_list.append(messages) + return messages_list diff --git a/factcheck/utils/llmclient/local_openai_client.py b/factcheck/utils/llmclient/local_openai_client.py new file mode 100644 index 0000000..2c179fd --- /dev/null +++ b/factcheck/utils/llmclient/local_openai_client.py @@ -0,0 +1,53 @@ +import time +import openai +from openai import OpenAI +from .base import BaseClient + + +class LocalOpenAIClient(BaseClient): + """Support Local host LLM chatbot with OpenAI API. + see https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md for example usage. + """ + + def __init__( + self, + model: str = "", + api_config: dict = None, + max_requests_per_minute=200, + request_window=60, + ): + super().__init__(model, api_config, max_requests_per_minute, request_window) + + openai.api_key = api_config["LOCAL_API_KEY"] + openai.base_url = api_config["LOCAL_API_URL"] + + def _call(self, messages: str, **kwargs): + seed = kwargs.get("seed", 42) # default seed is 42 + assert type(seed) is int, "Seed must be an integer." + + response = openai.chat.completions.create( + response_format={"type": "json_object"}, + seed=seed, + model=self.model, + messages=messages, + ) + r = response.choices[0].message.content + return r + + def get_request_length(self, messages): + # TODO: check if we should return the len(menages) instead + return 1 + + def construct_message_list( + self, + prompt_list: list[str], + system_role: str = "You are a helpful assistant designed to output JSON.", + ): + messages_list = list() + for prompt in prompt_list: + messages = [ + {"role": "system", "content": system_role}, + {"role": "user", "content": prompt}, + ] + messages_list.append(messages) + return messages_list diff --git a/factcheck/config/CustomLogger.py b/factcheck/utils/logger.py similarity index 100% rename from factcheck/config/CustomLogger.py rename to factcheck/utils/logger.py diff --git a/factcheck/utils/multimodal.py b/factcheck/utils/multimodal.py index 4295d88..0f78049 100644 --- a/factcheck/utils/multimodal.py +++ b/factcheck/utils/multimodal.py @@ -1,22 +1,21 @@ -from factcheck.config.secret_dict import openai_dict from openai import OpenAI import cv2 import base64 import requests -from factcheck.config.CustomLogger import CustomLogger +from .logger import CustomLogger logger = CustomLogger(__name__).getlog() -def voice2text(input): +def voice2text(input, openai_key): # voice to input - client = OpenAI(api_key=openai_dict["key"]) + client = OpenAI(api_key=openai_key) audio_file = open(input, "rb") transcription = client.audio.transcriptions.create(model="whisper-1", file=audio_file) return transcription.text -def image2text(input): +def image2text(input, openai_key): # Function to encode the image def encode_image(image_path): with open(image_path, "rb") as image_file: @@ -27,7 +26,7 @@ def encode_image(image_path): headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {openai_dict['key']}", + "Authorization": f"Bearer {openai_key}", } payload = { @@ -51,7 +50,7 @@ def encode_image(image_path): return caption.json()["choices"][0]["message"]["content"] -def video2text(input): +def video2text(input, openai_key): # Read the video and convert it to pictures video = cv2.VideoCapture(input) @@ -66,7 +65,7 @@ def video2text(input): video.release() # Process the pictures with GPT4-V - client = OpenAI(api_key=openai_dict["key"]) + client = OpenAI(api_key=openai_key) PROMPT_MESSAGES = [ { "role": "user", @@ -86,7 +85,7 @@ def video2text(input): return result.choices[0].message.content -def modal_normalization(modal="text", input=None): +def modal_normalization(modal="text", input=None, openai_key=None): logger.info(f"== Processing: Modal: {modal}, Input: {input}") if modal == "string": response = str(input) @@ -94,11 +93,11 @@ def modal_normalization(modal="text", input=None): with open(input, "r") as f: response = f.read() elif modal == "speech": - response = voice2text(input) + response = voice2text(input, openai_key) elif modal == "image": - response = image2text(input) + response = image2text(input, openai_key) elif modal == "video": - response = video2text(input) + response = video2text(input, openai_key) else: raise NotImplementedError logger.info(f"== Processed: Modal: {modal}, Input: {input}") diff --git a/factcheck/utils/prompt/__init__.py b/factcheck/utils/prompt/__init__.py new file mode 100644 index 0000000..4f10919 --- /dev/null +++ b/factcheck/utils/prompt/__init__.py @@ -0,0 +1,17 @@ +from .chatgpt_prompt import ChatGPTPrompt +from .claude_prompt import ClaudePrompt +from .customized_prompt import CustomizedPrompt + +prompt_map = { + "chatgpt_prompt": ChatGPTPrompt, + "claude_prompt": ClaudePrompt, +} + + +def prompt_mapper(prompt_name: str): + if prompt_name in prompt_map: + return prompt_map[prompt_name]() + elif prompt_name.endswith("yaml") or prompt_name.endswith("json"): + return CustomizedPrompt(prompt_name) + else: + raise NotImplementedError(f"Prompt {prompt_name} not implemented.") diff --git a/factcheck/utils/prompt/base.py b/factcheck/utils/prompt/base.py new file mode 100644 index 0000000..1eab47f --- /dev/null +++ b/factcheck/utils/prompt/base.py @@ -0,0 +1,9 @@ +from dataclasses import dataclass + + +@dataclass +class BasePrompt: + decompose_prompt: str = None + checkworthy_prompt: str = None + qgen_prompt: str = None + verify_prompt: str = None diff --git a/factcheck/utils/prompt.py b/factcheck/utils/prompt/chatgpt_prompt.py similarity index 95% rename from factcheck/utils/prompt.py rename to factcheck/utils/prompt/chatgpt_prompt.py index 005fd4d..5213021 100644 --- a/factcheck/utils/prompt.py +++ b/factcheck/utils/prompt/chatgpt_prompt.py @@ -1,5 +1,21 @@ -# Used prompts -CHECKWORTHY_PROMPT = """ +decompose_prompt = """ +Your task is to decompose the text into atomic claims. +The answer should be a JSON with a single key "claims", with the value of a list of strings, where each string should be a context-independent claim, representing one fact. +Note that: +1. Each claim should be concise (less than 15 words) and self-contained. +2. Avoid vague references like 'he', 'she', 'it', 'this', 'the company', 'the man' and using complete names. +3. Generate at least one claim for each single sentence in the texts. + +For example, +Text: Mary is a five-year old girl, she likes playing piano and she doesn't like cookies. +Output: +{{"claims": ["Mary is a five-year old girl.", "Mary likes playing piano.", "Mary doesn't like cookies."]}} + +Text: {doc} +Output: +""" + +checkworthy_prompt = """ Your task is to evaluate each provided statement to determine if it presents information whose factuality can be objectively verified by humans, irrespective of the statement's current accuracy. Consider the following guidelines: 1. Opinions versus Facts: Distinguish between opinions, which are subjective and not verifiable, and statements that assert factual information, even if broad or general. Focus on whether there's a factual claim that can be investigated. 2. Clarity and Specificity: Statements must have clear and specific references to be verifiable (e.g., "he is a professor" is not verifiable without knowing who "he" is). @@ -24,7 +40,28 @@ The output should be: """ -VERIFY_PROMPT = """ +qgen_prompt = """Given a claim, your task is to create minimum number of questions need to be check to verify the correctness of the claim. Output in JSON format with a single key "Questions", the value is a list of questions. For example: + +Claim: Your nose switches back and forth between nostrils. When you sleep, you switch about every 45 minutes. This is to prevent a buildup of mucus. It’s called the nasal cycle. +Output: {{"Questions": ["Does your nose switch between nostrils?", "How often does your nostrils switch?", "Why does your nostril switch?", "What is nasal cycle?"]}} + +Claim: The Stanford Prison Experiment was conducted in the basement of Encina Hall, Stanford’s psychology building. +Output: +{{"Question":["Where was Stanford Prison Experiment was conducted?"]}} + +Claim: The Havel-Hakimi algorithm is an algorithm for converting the adjacency matrix of a graph into its adjacency list. It is named after Vaclav Havel and Samih Hakimi. +Output: +{{"Questions":["What does Havel-Hakimi algorithm do?", "Who are Havel-Hakimi algorithm named after?"]}} + +Claim: Social work is a profession that is based in the philosophical tradition of humanism. It is an intellectual discipline that has its roots in the 1800s. +Output: +{{"Questions":["What philosophical tradition is social work based on?", "What year does social work have its root in?"]}} + +Claim: {claim} +Output: +""" + +verify_prompt = """ Your task is to evaluate the accuracy of a provided statement using the accompanying evidence. Carefully review the evidence, noting that it may vary in detail and sometimes present conflicting information. Your judgment should be informed by this evidence, taking into account its relevance and reliability. Keep in mind that a lack of detail in the evidence does not necessarily indicate that the statement is inaccurate. When assessing the statement's factuality, distinguish between errors and areas where the evidence supports the statement. @@ -69,40 +106,9 @@ Output: """ -SENTENCES_TO_CLAIMS_PROMPT = """ -Your task is to decompose the text into atomic claims. -The answer should be a JSON with a single key "claims", with the value of a list of strings, where each string should be a context-independent claim, representing one fact. -Note that: -1. Each claim should be concise (less than 15 words) and self-contained. -2. Avoid vague references like 'he', 'she', 'it', 'this', 'the company', 'the man' and using complete names. -3. Generate at least one claim for each single sentence in the texts. - -For example, -Text: Mary is a five-year old girl, she likes playing piano and she doesn't like cookies. -Output: -{{"claims": ["Mary is a five-year old girl.", "Mary likes playing piano.", "Mary doesn't like cookies."]}} - -Text: {doc} -Output: -""" - -QGEN_PROMPT = """Given a claim, your task is to create minimum number of questions need to be check to verify the correctness of the claim. Output in JSON format with a single key "Questions", the value is a list of questions. For example: -Claim: Your nose switches back and forth between nostrils. When you sleep, you switch about every 45 minutes. This is to prevent a buildup of mucus. It’s called the nasal cycle. -Output: {{"Questions": ["Does your nose switch between nostrils?", "How often does your nostrils switch?", "Why does your nostril switch?", "What is nasal cycle?"]}} - -Claim: The Stanford Prison Experiment was conducted in the basement of Encina Hall, Stanford’s psychology building. -Output: -{{"Question":["Where was Stanford Prison Experiment was conducted?"]}} - -Claim: The Havel-Hakimi algorithm is an algorithm for converting the adjacency matrix of a graph into its adjacency list. It is named after Vaclav Havel and Samih Hakimi. -Output: -{{"Questions":["What does Havel-Hakimi algorithm do?", "Who are Havel-Hakimi algorithm named after?"]}} - -Claim: Social work is a profession that is based in the philosophical tradition of humanism. It is an intellectual discipline that has its roots in the 1800s. -Output: -{{"Questions":["What philosophical tradition is social work based on?", "What year does social work have its root in?"]}} - -Claim: {claim} -Output: -""" +class ChatGPTPrompt: + decompose_prompt = decompose_prompt + checkworthy_prompt = checkworthy_prompt + qgen_prompt = qgen_prompt + verify_prompt = verify_prompt diff --git a/factcheck/utils/prompt/claude_prompt.py b/factcheck/utils/prompt/claude_prompt.py new file mode 100644 index 0000000..7899948 --- /dev/null +++ b/factcheck/utils/prompt/claude_prompt.py @@ -0,0 +1,114 @@ +decompose_prompt = """ +Your task is to decompose the text into atomic claims. +The answer should be a JSON with a single key "claims", with the value of a list of strings, where each string should be a context-independent claim, representing one fact. +Note that: +1. Each claim should be concise (less than 15 words) and self-contained. +2. Avoid vague references like 'he', 'she', 'it', 'this', 'the company', 'the man' and using complete names. +3. Generate at least one claim for each single sentence in the texts. + +For example, +Text: Mary is a five-year old girl, she likes playing piano and she doesn't like cookies. +Output: +{{"claims": ["Mary is a five-year old girl.", "Mary likes playing piano.", "Mary doesn't like cookies."]}} + +Text: {doc} +Output: +""" + +checkworthy_prompt = """ +Your task is to evaluate each provided statement to determine if it presents information whose factuality can be objectively verified by humans, irrespective of the statement's current accuracy. Consider the following guidelines: +1. Opinions versus Facts: Distinguish between opinions, which are subjective and not verifiable, and statements that assert factual information, even if broad or general. Focus on whether there's a factual claim that can be investigated. +2. Clarity and Specificity: Statements must have clear and specific references to be verifiable (e.g., "he is a professor" is not verifiable without knowing who "he" is). +3. Presence of Factual Information: Consider a statement verifiable if it includes factual elements that can be checked against evidence or reliable sources, even if the overall statement might be broad or incorrect. +Your response should be in JSON format, with each statement as a key and either "Yes" or "No" as the value, along with a brief rationale for your decision. + +For example, given these statements: +1. Gary Smith is a distinguished professor of economics. +2. He is a professor at MBZUAI. +3. Obama is the president of the UK. + +The expected output is a JSON: +{{ + "Gary Smith is a distinguished professor of economics.": "Yes (The statement contains verifiable factual information about Gary Smith's professional title and field.)", + "He is a professor at MBZUAI.": "No (The statement cannot be verified due to the lack of clear reference to who 'he' is.)", + "Obama is the president of the UK.": "Yes (This statement contain verifiable information regarding the political leadership of a country.)" +}} + +For these statements: +{texts} + +The output should be a JSON: +""" + +qgen_prompt = """Given a claim, your task is to create minimum number of questions need to be check to verify the correctness of the claim. Output in JSON format with a single key "Questions", the value is a list of questions. For example: + +Claim: Your nose switches back and forth between nostrils. When you sleep, you switch about every 45 minutes. This is to prevent a buildup of mucus. It’s called the nasal cycle. +JSON Output: {{"Questions": ["Does your nose switch between nostrils?", "How often does your nostrils switch?", "Why does your nostril switch?", "What is nasal cycle?"]}} + +Claim: The Stanford Prison Experiment was conducted in the basement of Encina Hall, Stanford’s psychology building. +JSON Output: +{{"Question":["Where was Stanford Prison Experiment was conducted?"]}} + +Claim: The Havel-Hakimi algorithm is an algorithm for converting the adjacency matrix of a graph into its adjacency list. It is named after Vaclav Havel and Samih Hakimi. +JSON Output: +{{"Questions":["What does Havel-Hakimi algorithm do?", "Who are Havel-Hakimi algorithm named after?"]}} + +Claim: Social work is a profession that is based in the philosophical tradition of humanism. It is an intellectual discipline that has its roots in the 1800s. +Output: +{{"Questions":["What philosophical tradition is social work based on?", "What year does social work have its root in?"]}} + +Claim: {claim} +JSON Output: +""" + +verify_prompt = """ +Your task is to evaluate the accuracy of a provided statement using the accompanying evidence. Carefully review the evidence, noting that it may vary in detail and sometimes present conflicting information. Your judgment should be informed by this evidence, taking into account its relevance and reliability. + +Keep in mind that a lack of detail in the evidence does not necessarily indicate that the statement is inaccurate. When assessing the statement's factuality, distinguish between errors and areas where the evidence supports the statement. + +Please structure your response in JSON format, including the following four keys: +- "reasoning": explain the thought process behind your judgment. +- "error": none if the text is factual; otherwise, identify any specific inaccuracies in the statement. +- "correction": none if the text is factual; otherwise, provide corrections to any identified inaccuracies, using the evidence to support your corrections. +- "factuality": true if the given text is factual, false otherwise, indicating whether the statement is factual, or non-factual based on the evidence. + +For example: +Input: +[text]: MBZUAI is located in Abu Dhabi, United Arab Emirates. +[evidence]: Where is MBZUAI located?\nAnswer: Masdar City - Abu Dhabi - United Arab Emirates + +JSON Output: +{{ + "reasoning": "The evidence confirms that MBZUAI is located in Masdar City, Abu Dhabi, United Arab Emirates, so the statement is factually correct", + "error": none, + "correction": none, + "factuality": true +}} + + +Input: +[text]: Copper reacts with ferrous sulfate (FeSO4). +[evidence]: Copper is less reactive metal. It has positive value of standard reduction potential. Metal with high standard reduction potential can not displace other metal with low standard reduction potential values. Hence copper can not displace iron from ferrous sulphate solution. So no change will take place. + +JSON Output: +{{ + "reasoning": "The evidence provided confirms that copper cannot displace iron from ferrous sulphate solution, and no change will take place.", + "error": "Copper does not react with ferrous sulfate as stated in the text.", + "correction": "Copper does not react with ferrous sulfate as it cannot displace iron from ferrous sulfate solution.", + "factuality": false +}} + + +Input +[text]: {claim} +[evidences]: {evidence} + +JSON Output: +""" + + +class ClaudePrompt: + decompose_prompt = decompose_prompt + checkworthy_prompt = checkworthy_prompt + qgen_prompt = qgen_prompt + verify_prompt = verify_prompt diff --git a/factcheck/utils/prompt/customized_prompt.py b/factcheck/utils/prompt/customized_prompt.py new file mode 100644 index 0000000..8aa829c --- /dev/null +++ b/factcheck/utils/prompt/customized_prompt.py @@ -0,0 +1,33 @@ +import yaml +import json +from .base import BasePrompt + + +class CustomizedPrompt(BasePrompt): + def __init__(self, CustomizedPrompt): + if CustomizedPrompt.endswith("yaml"): + self.prompts = self.load_prompt_yaml(CustomizedPrompt) + elif CustomizedPrompt.endswith("json"): + self.prompts = self.load_prompt_json(CustomizedPrompt) + else: + raise NotImplementedError(f"File type of {CustomizedPrompt} not implemented.") + keys = [ + "decompose_prompt", + "checkworthy_prompt", + "qgen_prompt", + "verify_prompt", + ] + + for key in keys: + assert key in self.prompts, f"Key {key} not found in the prompt yaml file." + setattr(self, key, self.prompts[key]) + + def load_prompt_yaml(self, prompt_name): + # Load the prompt from a yaml file + with open(prompt_name, "r") as file: + return yaml.safe_load(file) + + def load_prompt_json(self, prompt_name): + # Load the prompt from a json file + with open(prompt_name, "r") as file: + return json.load(file) diff --git a/factcheck/utils/utils.py b/factcheck/utils/utils.py new file mode 100644 index 0000000..345f697 --- /dev/null +++ b/factcheck/utils/utils.py @@ -0,0 +1,6 @@ +import yaml + + +def load_yaml(filepath): + with open(filepath, "r") as file: + return yaml.safe_load(file) diff --git a/requirements.txt b/requirements.txt index 50a51d8..fc1f9cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ bs4 flask httpx nltk -openai +openai>=1.0.0 opencv-python pandas playwright diff --git a/templates/input.html b/templates/input.html index c60d69c..e6d8124 100644 --- a/templates/input.html +++ b/templates/input.html @@ -85,7 +85,39 @@ bottom: 0; border: #a9bee3 solid 2px; } + + #timer { + margin-left: 80%; + width: 18%; + padding: 10px 20px; + border: none; + border-radius: 4px; + } + + @@ -102,8 +134,9 @@