diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 3f65d35..0a0aef1 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -60,6 +60,11 @@ jobs:
       - name: Install dependencies
         run: poetry install --extras "all"

+      - name: Run Lint
+        run: |
+          poetry run ruff check ./graphrag_sdk/
+          poetry run black ./graphrag_sdk/ --check
+
       - name: Wait for Ollama to be ready
         run: |
           until curl -s http://localhost:11434; do
diff --git a/examples/movies/demo-movies.ipynb b/examples/movies/demo-movies.ipynb
index 6b0d55c..92ecddc 100644
--- a/examples/movies/demo-movies.ipynb
+++ b/examples/movies/demo-movies.ipynb
@@ -34,7 +34,7 @@
     "\n",
     "# Load environment variables\n",
     "load_dotenv()\n",
-    "logging.disable(logging.CRITICAL)\n"
+    "logging.disable(logging.CRITICAL)"
    ]
   },
   {
@@ -57,7 +57,7 @@
     "falkor_host = \"\"\n",
     "falkor_port = None\n",
     "falkor_username = \"\"\n",
-    "falkor_password = \"\"\n"
+    "falkor_password = \"\""
    ]
   },
   {
@@ -75,12 +75,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "urls = [\"https://www.rottentomatoes.com/m/side_by_side_2012\",\n",
-    "\"https://www.rottentomatoes.com/m/matrix\",\n",
-    "\"https://www.rottentomatoes.com/m/matrix_revolutions\",\n",
-    "\"https://www.rottentomatoes.com/m/matrix_reloaded\",\n",
-    "\"https://www.rottentomatoes.com/m/speed_1994\",\n",
-    "\"https://www.rottentomatoes.com/m/john_wick_chapter_4\"]\n",
+    "urls = [\n",
+    "    \"https://www.rottentomatoes.com/m/side_by_side_2012\",\n",
+    "    \"https://www.rottentomatoes.com/m/matrix\",\n",
+    "    \"https://www.rottentomatoes.com/m/matrix_revolutions\",\n",
+    "    \"https://www.rottentomatoes.com/m/matrix_reloaded\",\n",
+    "    \"https://www.rottentomatoes.com/m/speed_1994\",\n",
+    "    \"https://www.rottentomatoes.com/m/john_wick_chapter_4\",\n",
+    "]\n",
     "\n",
     "sources = [URL(url) for url in urls]"
    ]
@@ -136,7 +138,7 @@
    "    host=falkor_host,\n",
    "    port=falkor_port,\n",
    "    username=falkor_username,\n",
-    "    password=falkor_password\n",
+    "    password=falkor_password,\n",
    ")\n",
    "kg.process_sources(sources)"
   ]
@@ -190,7 +192,9 @@
    "answer = chat.send_message(\"Who is the director of the movie Side by Side?\")\n",
    "print(f\"Q: {answer['question']} \\nA: {answer['response']}\\n\")\n",
    "\n",
-    "answer = chat.send_message(\"Order the directors that you mentioned in all of our conversation by lexical order.\")\n",
+    "answer = chat.send_message(\n",
+    "    \"Order the directors that you mentioned in all of our conversation by lexical order.\"\n",
+    ")\n",
    "print(f\"Q: {answer['question']} \\nA: {answer['response']}\\n\")"
   ]
  }
diff --git a/examples/trip/demo_orchestrator_trip.ipynb b/examples/trip/demo_orchestrator_trip.ipynb
index 989197f..fbb662d 100644
--- a/examples/trip/demo_orchestrator_trip.ipynb
+++ b/examples/trip/demo_orchestrator_trip.ipynb
@@ -41,7 +41,13 @@
    "from graphrag_sdk.agents.kg_agent import KGAgent\n",
    "from graphrag_sdk.models.openai import OpenAiGenerativeModel\n",
    "from graphrag_sdk import (\n",
-    "    Ontology, Entity, Relation, Attribute, AttributeType, KnowledgeGraph, KnowledgeGraphModelConfig\n",
+    "    Ontology,\n",
+    "    Entity,\n",
+    "    Relation,\n",
+    "    Attribute,\n",
+    "    AttributeType,\n",
+    "    KnowledgeGraph,\n",
+    "    KnowledgeGraphModelConfig,\n",
    ")\n",
    "\n",
    "# Load environment variables\n",
@@ -402,7 +408,9 @@
    "orchestrator.register_agent(attractions_agent)\n",
    "\n",
    "# Query the orchestrator\n",
-    "runner = orchestrator.ask(\"Create a two-day itinerary for a trip to Rome. Please don't ask me any questions. Just provide the best itinerary you can.\")"
+    "runner = orchestrator.ask(\n",
+    "    \"Create a two-day itinerary for a trip to Rome. Please don't ask me any questions. Just provide the best itinerary you can.\"\n",
+    ")"
   ]
  },
  {
@@ -464,7 +472,9 @@
    }
   ],
   "source": [
-    "runner = orchestrator.ask(\"Please tell me only the name of the restaurant for dinner at the first day that you mention in the itinerary of the trip\")\n",
+    "runner = orchestrator.ask(\n",
+    "    \"Please tell me only the name of the restaurant for dinner at the first day that you mention in the itinerary of the trip\"\n",
+    ")\n",
    "print(runner.output)"
   ]
  },
@@ -510,7 +520,9 @@
    }
   ],
   "source": [
-    "runner = orchestrator.ask(\"Can you change this restaurant to another one and give me the updated itinerary?\")\n",
+    "runner = orchestrator.ask(\n",
+    "    \"Can you change this restaurant to another one and give me the updated itinerary?\"\n",
+    ")\n",
    "print(runner.output)"
   ]
  },
@@ -528,7 +540,9 @@
    }
   ],
   "source": [
-    "runner = orchestrator.ask(\"Please tell me attraction on the morning at the first day that you mention in the itinerary of the trip\")\n",
+    "runner = orchestrator.ask(\n",
+    "    \"Please tell me attraction on the morning at the first day that you mention in the itinerary of the trip\"\n",
+    ")\n",
    "print(runner.output)"
   ]
  }
diff --git a/graphrag_sdk/__init__.py b/graphrag_sdk/__init__.py
index 0f50159..e233157 100644
--- a/graphrag_sdk/__init__.py
+++ b/graphrag_sdk/__init__.py
@@ -37,4 +37,4 @@
     "Relation",
     "Attribute",
     "AttributeType",
-]
\ No newline at end of file
+]
diff --git a/graphrag_sdk/agents/__init__.py b/graphrag_sdk/agents/__init__.py
index be1ac3e..a6df24a 100644
--- a/graphrag_sdk/agents/__init__.py
+++ b/graphrag_sdk/agents/__init__.py
@@ -1,3 +1,3 @@
 from .agent import Agent

-__all__ = ['Agent']
+__all__ = ["Agent"]
diff --git a/graphrag_sdk/agents/kg_agent.py b/graphrag_sdk/agents/kg_agent.py
index ee5d8bc..8551b5d 100644
--- a/graphrag_sdk/agents/kg_agent.py
+++ b/graphrag_sdk/agents/kg_agent.py
@@ -126,7 +126,7 @@ def run(self, params: dict) -> str:
         """
         output = self.chat_session.send_message(params["prompt"])

-        return output['response']
+        return output["response"]

     def __repr__(self) -> str:
         """
diff --git a/graphrag_sdk/attribute.py b/graphrag_sdk/attribute.py
index 522706d..fc0cfce 100644
--- a/graphrag_sdk/attribute.py
+++ b/graphrag_sdk/attribute.py
@@ -58,18 +58,18 @@ def from_string(txt: str) -> "AttributeType":
         raise ValueError(f"Invalid attribute type: {txt}")

 class Attribute:
-    """ Represents an attribute of an entity or relation in the ontology.
-
-    Args:
-        name (str): The name of the attribute.
-        attr_type (AttributeType): The type of the attribute.
-        unique (bool): Whether the attribute is unique.
-        required (bool): Whether the attribute is required.
-
-    Examples:
-        >>> attr = Attribute("name", AttributeType.STRING, True, True)
-        >>> print(attr)
-        name: "string!*"
+    """Represents an attribute of an entity or relation in the ontology.
+
+    Args:
+        name (str): The name of the attribute.
+        attr_type (AttributeType): The type of the attribute.
+        unique (bool): Whether the attribute is unique.
+        required (bool): Whether the attribute is required.
+
+    Examples:
+        >>> attr = Attribute("name", AttributeType.STRING, True, True)
+        >>> print(attr)
+        name: "string!*"
     """

     def __init__(
@@ -168,4 +168,4 @@ def __str__(self) -> str:
         Returns:
             str: A string representation of the Attribute object.
         """
-        return f"{self.name}: \"{self.type}{'!' if self.unique else ''}{'*' if self.required else ''}\""
+        return f'{self.name}: "{self.type}{"!" if self.unique else ""}{"*" if self.required else ""}"'
diff --git a/graphrag_sdk/chat_session.py b/graphrag_sdk/chat_session.py
index e013cde..8b7b409 100644
--- a/graphrag_sdk/chat_session.py
+++ b/graphrag_sdk/chat_session.py
@@ -25,9 +25,17 @@ class ChatSession:
         >>> chat_session.send_message("What is the capital of France?")
     """

-    def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology, graph: Graph,
-                 cypher_system_instruction: str, qa_system_instruction: str,
-                 cypher_gen_prompt: str, qa_prompt: str, cypher_gen_prompt_history: str):
+    def __init__(
+        self,
+        model_config: KnowledgeGraphModelConfig,
+        ontology: Ontology,
+        graph: Graph,
+        cypher_system_instruction: str,
+        qa_system_instruction: str,
+        cypher_gen_prompt: str,
+        qa_prompt: str,
+        cypher_gen_prompt_history: str,
+    ):
         """
         Initializes a new ChatSession object.

@@ -73,9 +81,9 @@ def send_message(self, message: str) -> dict:

         Returns:
             dict: The response to the message in the following format:
-                {"question": message,
-                "response": answer,
-                "context": context,
+                {"question": message,
+                "response": answer,
+                "context": context,
                 "cypher": cypher}
         """
         cypher_step = GraphQueryGenerationStep(
@@ -84,7 +92,7 @@ def send_message(self, message: str) -> dict:
             ontology=self.ontology,
             last_answer=self.last_answer,
             cypher_prompt=self.cypher_prompt,
-            cypher_prompt_with_history=self.cypher_prompt_with_history
+            cypher_prompt_with_history=self.cypher_prompt_with_history,
         )

         (context, cypher) = cypher_step.run(message)
@@ -94,8 +102,8 @@ def send_message(self, message: str) -> dict:
                 "question": message,
                 "response": "I am sorry, I could not find the answer to your question",
                 "context": None,
-                "cypher": None
-                }
+                "cypher": None,
+            }

         qa_step = QAStep(
             chat_session=self.qa_chat_session,
@@ -104,7 +112,7 @@ def send_message(self, message: str) -> dict:

         answer = qa_step.run(message, cypher, context)
         self.last_answer = answer
-        
+
         return {
             "question": message,
             "response": answer,
diff --git a/graphrag_sdk/document_loaders/jsonl.py b/graphrag_sdk/document_loaders/jsonl.py
index 1742b7c..1c35a31 100644
--- a/graphrag_sdk/document_loaders/jsonl.py
+++ b/graphrag_sdk/document_loaders/jsonl.py
@@ -18,10 +18,6 @@ def load(self) -> Iterator[Document]:
         num_documents = num_rows // self.rows_per_document
         for i in range(num_documents):
             content = "\n".join(
-                rows[
-                    i
-                    * self.rows_per_document : (i + 1)
-                    * self.rows_per_document
-                ]
+                rows[i * self.rows_per_document : (i + 1) * self.rows_per_document]
             )
             yield Document(content, f"{self.path}#{i}")
diff --git a/graphrag_sdk/document_loaders/pdf.py b/graphrag_sdk/document_loaders/pdf.py
index c500002..7ad93b2 100644
--- a/graphrag_sdk/document_loaders/pdf.py
+++ b/graphrag_sdk/document_loaders/pdf.py
@@ -1,7 +1,8 @@
 from typing import Iterator
 from graphrag_sdk.document import Document

-class PDFLoader():
+
+class PDFLoader:
     """
     Load PDF
     """
@@ -15,11 +16,11 @@ def __init__(self, path: str) -> None:
         """

         try:
-            import pypdf
-        except ImportError:
-            raise ImportError(
-                "pypdf package not found, please install it with " "`pip install pypdf`"
-            )
+            __import__("pypdf")
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "pypdf package not found, please install it with `pip install pypdf`"
+            )

         self.path = path

@@ -30,8 +31,8 @@ def load(self) -> Iterator[Document]:
         Returns:
             Iterator[Document]: document iterator
         """
-        
-        from pypdf import PdfReader # pylint: disable=import-outside-toplevel 
+
+        from pypdf import PdfReader  # pylint: disable=import-outside-toplevel

         reader = PdfReader(self.path)
         yield from [
diff --git a/graphrag_sdk/document_loaders/text.py b/graphrag_sdk/document_loaders/text.py
index 79c3b26..a60ebe3 100644
--- a/graphrag_sdk/document_loaders/text.py
+++ b/graphrag_sdk/document_loaders/text.py
@@ -1,7 +1,8 @@
 from typing import Iterator
 from graphrag_sdk.document import Document

-class TextLoader():
+
+class TextLoader:
     """
     Load Text
     """
diff --git a/graphrag_sdk/document_loaders/url.py b/graphrag_sdk/document_loaders/url.py
index 5a39e10..8ac6dd4 100644
--- a/graphrag_sdk/document_loaders/url.py
+++ b/graphrag_sdk/document_loaders/url.py
@@ -4,7 +4,8 @@
 from bs4 import BeautifulSoup
 from graphrag_sdk.document import Document

-class URLLoader():
+
+class URLLoader:
     """
     Load URL
     """
@@ -21,7 +22,7 @@ def __init__(self, url: str) -> None:

     def _download(self) -> str:
         try:
-            response = requests.get(self.url, headers={'User-Agent': 'Mozilla/5.0'})
+            response = requests.get(self.url, headers={"User-Agent": "Mozilla/5.0"})
             response.raise_for_status()  # Raise an HTTPError for bad responses (4xx and 5xx)
             return response.text
         except requests.exceptions.RequestException as e:
@@ -39,12 +40,12 @@ def load(self) -> Iterator[Document]:
         content = self._download()

         # extract text from HTML, populate content
-        soup = BeautifulSoup(content, 'html.parser')
+        soup = BeautifulSoup(content, "html.parser")

         # Extract text from the HTML
         content = soup.get_text()

         # Remove extra newlines
-        content = re.sub(r'\n{2,}', '\n', content)
+        content = re.sub(r"\n{2,}", "\n", content)

         yield Document(content, self.url)
\ No newline at end of file
diff --git a/graphrag_sdk/fixtures/regex.py b/graphrag_sdk/fixtures/regex.py
index 3933692..759d12b 100644
--- a/graphrag_sdk/fixtures/regex.py
+++ b/graphrag_sdk/fixtures/regex.py
@@ -4,4 +4,4 @@

 NODE_LABEL_REGEX = r"\(.+:(.*?)\)"

-NODE_REGEX = r"\(.*?\)"
\ No newline at end of file
+NODE_REGEX = r"\(.*?\)"
diff --git a/graphrag_sdk/helpers.py b/graphrag_sdk/helpers.py
index 8c7693f..c6ab635 100644
--- a/graphrag_sdk/helpers.py
+++ b/graphrag_sdk/helpers.py
@@ -84,12 +84,12 @@ def stringify_falkordb_response(response: Union[list, str]) -> str:
     elif not isinstance(response[0], list):
         data = str(response).strip()
     else:
-        for l, _ in enumerate(response):
-            if not isinstance(response[l], list):
-                response[l] = str(response[l])
+        for line, _ in enumerate(response):
+            if not isinstance(response[line], list):
+                response[line] = str(response[line])
             else:
-                for i, __ in enumerate(response[l]):
-                    response[l][i] = str(response[l][i])
+                for i, __ in enumerate(response[line]):
+                    response[line][i] = str(response[line][i])
         data = str(response).strip()

     return data
@@ -194,10 +194,10 @@ def validate_cypher_relations_exist(cypher: str, ontology: Ontology) -> list[str
     for relation in relation_labels:
         for label in relation.split("|"):
             max_idx = min(
-                label.index("*") if "*" in label else len(label),
-                label.index("{") if "{" in label else len(label),
-                label.index("]") if "]" in label else len(label),
-            )
+                label.index("*") if "*" in label else len(label),
+                label.index("{") if "{" in label else len(label),
+                label.index("]") if "]" in label else len(label),
+            )
             label = label[:max_idx]
             if label not in [relation.label for relation in ontology.relations]:
                 not_found_relation_labels.append(label)
diff --git a/graphrag_sdk/kg.py b/graphrag_sdk/kg.py
index 752bbeb..7c14fdb 100644
--- a/graphrag_sdk/kg.py
+++ b/graphrag_sdk/kg.py
@@ -83,7 +83,10 @@ def __init__(
             cypher_system_instruction = CYPHER_GEN_SYSTEM
         else:
             if "{ontology}" not in cypher_system_instruction:
-                warnings.warn("Cypher system instruction should contain {ontology}", category=UserWarning)
+                warnings.warn(
+                    "Cypher system instruction should contain {ontology}",
+                    category=UserWarning,
+                )

         if qa_system_instruction is None:
             qa_system_instruction = GRAPH_QA_SYSTEM
@@ -106,9 +109,14 @@ def __init__(
             cypher_gen_prompt_history = CYPHER_GEN_PROMPT_WITH_HISTORY
         else:
             if "{question}" not in cypher_gen_prompt_history:
-                raise Exception("Cypher generation prompt with history should contain {question}")
+                raise Exception(
+                    "Cypher generation prompt with history should contain {question}"
+                )
             if "{last_answer}" not in cypher_gen_prompt_history:
-                warnings.warn("Cypher generation prompt with history should contain {last_answer}", category=UserWarning)
+                warnings.warn(
+                    "Cypher generation prompt with history should contain {last_answer}",
+                    category=UserWarning,
+                )

         # Assign the validated values
         self.cypher_system_instruction = cypher_system_instruction
diff --git a/graphrag_sdk/model_config.py b/graphrag_sdk/model_config.py
index 2e2e584..a699850 100644
--- a/graphrag_sdk/model_config.py
+++ b/graphrag_sdk/model_config.py
@@ -47,7 +47,7 @@ def with_model(model: GenerativeModel) -> "KnowledgeGraphModelConfig":
             cypher_generation=model,
             qa=model,
         )
-    
+
     @staticmethod
     def from_json(json: dict) -> "KnowledgeGraphModelConfig":
         """
@@ -65,7 +65,7 @@ def from_json(json: dict) -> "KnowledgeGraphModelConfig":
             GenerativeModel.from_json(json["cypher_generation"]),
             GenerativeModel.from_json(json["qa"]),
         )
-    
+
     def to_json(self) -> dict:
         """
         Converts the model configuration to a JSON dictionary.
diff --git a/graphrag_sdk/models/__init__.py b/graphrag_sdk/models/__init__.py
index ba10467..1032afe 100644
--- a/graphrag_sdk/models/__init__.py
+++ b/graphrag_sdk/models/__init__.py
@@ -13,4 +13,4 @@
     "GenerativeModel",
     "GenerativeModelChatSession",
     "GenerativeModelConfig",
-]
\ No newline at end of file
+]
diff --git a/graphrag_sdk/models/azure_openai.py b/graphrag_sdk/models/azure_openai.py
index 253a333..fbc2051 100644
--- a/graphrag_sdk/models/azure_openai.py
+++ b/graphrag_sdk/models/azure_openai.py
@@ -35,12 +35,12 @@ def __init__(
         self.model_name = model_name
         self.generation_config = generation_config or GenerativeModelConfig()
         self.system_instruction = system_instruction
-        
+
         # Credentials
         self.api_key = os.getenv("AZURE_OPENAI_API_KEY")
         self.azure_endpoint = os.getenv("AZURE_ENDPOINT")
         self.api_version = os.getenv("AZURE_API_VERSION")
-        
+
         if not self.api_key or not self.azure_endpoint or not self.api_version:
             raise ValueError(
                 "Missing credentials in the environment: AZURE_OPENAI_API_KEY, AZURE_ENDPOINT, or AZURE_API_VERSION."
@@ -144,7 +144,6 @@ class AzureOpenAiChatSession(GenerativeModelChatSession):
     A chat session for interacting with the Azure OpenAI model, maintaining conversation history.
     """
-
     _history = []

     def __init__(self, model: AzureOpenAiGenerativeModel, system_instruction: Optional[str] = None):
@@ -162,7 +161,9 @@ def __init__(self, model: AzureOpenAiGenerativeModel, system_instruction: Option
             else []
         )

-    def send_message(self, message: str, output_method: OutputMethod = OutputMethod.DEFAULT) -> GenerationResponse:
+    def send_message(
+        self, message: str, output_method: OutputMethod = OutputMethod.DEFAULT
+    ) -> GenerationResponse:
         """
         Send a message in the chat session and receive the model's response.

@@ -205,16 +206,16 @@ def _adjust_generation_config(self, output_method: OutputMethod):
         """
         config = self._model.generation_config.to_json()
         if output_method == OutputMethod.JSON:
-            config['temperature'] = 0
-            config['response_format'] = { "type": "json_object" }
-        
+            config["temperature"] = 0
+            config["response_format"] = {"type": "json_object"}
+
         return config
-    
+
     def delete_last_message(self):
         """
         Deletes the last message exchange (user message and assistant response) from the chat history.
         Preserves the system message if present.
-        
+
         Example:
             Before:
             [
@@ -226,7 +227,7 @@ def delete_last_message(self):
             [
                 {"role": "system", "content": "System message"},
             ]
-        
+
         Note:
             Does nothing if the chat history is empty or contains only a system message.
         """
         # Keep at least the system message if present
@@ -237,7 +238,7 @@ def delete_last_message(self):
         else:
             # Reset to initial state with just system message if present
             self._history = (
-            [{"role": "system", "content": self._model.system_instruction}]
-            if self._model.system_instruction is not None
-            else []
-        )
+                [{"role": "system", "content": self._model.system_instruction}]
+                if self._model.system_instruction is not None
+                else []
+            )
diff --git a/graphrag_sdk/models/gemini.py b/graphrag_sdk/models/gemini.py
index 31e4c11..7f4eac7 100644
--- a/graphrag_sdk/models/gemini.py
+++ b/graphrag_sdk/models/gemini.py
@@ -87,7 +87,8 @@ def parse_generate_content_response(
                 == protos.Candidate.FinishReason.MAX_TOKENS
                 else (
                     FinishReason.STOP
-                    if response.candidates[0].finish_reason == protos.Candidate.FinishReason.STOP
+                    if response.candidates[0].finish_reason
+                    == protos.Candidate.FinishReason.STOP
                     else FinishReason.OTHER
                 )
             ),
@@ -166,17 +167,14 @@ def _adjust_generation_config(self, output_method: OutputMethod) -> dict:
             dict: The configuration settings for generation.
         """
         if output_method == OutputMethod.JSON:
-            return {
-                "response_mime_type": "application/json",
-                "temperature": 0
-            }
+            return {"response_mime_type": "application/json", "temperature": 0}
         return self._model._generation_config
-    
+
     def delete_last_message(self):
         """
         Deletes the last message exchange (user message and assistant response) from the chat history.
         Preserves the system message if present.
-        
+
         Example:
             Before:
             [
@@ -193,4 +191,3 @@ def delete_last_message(self):
             self._chat_session.history.pop()
         else:
             self._chat_session.history = []
-
diff --git a/graphrag_sdk/models/litellm.py b/graphrag_sdk/models/litellm.py
index 723f78f..d59eb53 100644
--- a/graphrag_sdk/models/litellm.py
+++ b/graphrag_sdk/models/litellm.py
@@ -12,7 +12,8 @@
 )

 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
+logger.setLevel(logging.INFO)
+

 class LiteModel(GenerativeModel):
     """
@@ -27,7 +28,7 @@ def __init__(
     ):
         """
         Initialize the LiteModel with the required parameters.
-        
+
         LiteLLM model_name format: <provider>/<model_name>
         Examples:
         - openai/gpt-4o
@@ -42,21 +43,20 @@ def __init__(
         """
         env_val = validate_environment(model_name)
-        if not env_val['keys_in_environment']:
+        if not env_val["keys_in_environment"]:
             raise ValueError(f"Missing {env_val['missing_keys']} in the environment.")
         self.model_name, provider, _, _ = litellm_utils.get_llm_provider(model_name)
         self.model = model_name
-        
+
         if provider == "ollama":
             self.ollama_client = Client()
             self.check_and_pull_model()

         if not self.check_valid_key(model_name):
             raise ValueError(f"Invalid keys for model {model_name}.")
-
         self.generation_config = generation_config or GenerativeModelConfig()
         self.system_instruction = system_instruction
-    
+
     def check_valid_key(self, model: str):
         """
         Checks if the environment key is valid for a specific model by making a litellm.completion call with max_tokens=10
@@ -69,13 +69,11 @@ def check_valid_key(self, model: str):
         """
         messages = [{"role": "user", "content": "Hey, how's it going?"}]
         try:
-            completion(
-                model=model, messages=messages, max_tokens=10
-            )
+            completion(model=model, messages=messages, max_tokens=10)
             return True
-        except:
+        except:  # noqa: E722
             return False
-    
+
     def check_and_pull_model(self) -> None:
         """
         Checks if the specified model is available locally, and pulls it if not.
@@ -89,7 +87,9 @@ def check_and_pull_model(self) -> None:
         """
         # Get the list of available models
         response = self.ollama_client.list()  # This returns a dictionary
-        available_models = [model['name'] for model in response['models']]  # Extract model names
+        available_models = [
+            model["name"] for model in response["models"]
+        ]  # Extract model names

         # Check if the model is already pulled
         if self.model_name in available_models:
@@ -191,7 +191,9 @@ def __init__(self, model: LiteModel, system_instruction: Optional[str] = None):
             else []
         )

-    def send_message(self, message: str, output_method: OutputMethod = OutputMethod.DEFAULT) -> GenerationResponse:
+    def send_message(
+        self, message: str, output_method: OutputMethod = OutputMethod.DEFAULT
+    ) -> GenerationResponse:
         """
         Send a message in the chat session and receive the model's response.

@@ -208,14 +210,16 @@ def send_message(self, message: str, output_method: OutputMethod = OutputMethod.
             response = completion(
                 model=self._model.model,
                 messages=self._chat_history,
-                **generation_config
+                **generation_config,
             )
         except Exception as e:
-            raise ValueError(f"Error during completion request, please check the credentials - {e}")
+            raise ValueError(
+                f"Error during completion request, please check the credentials - {e}"
+            )
         content = self._model.parse_generate_content_response(response)
         self._chat_history.append({"role": "assistant", "content": content.text})
         return content
-    
+
     def _adjust_generation_config(self, output_method: OutputMethod):
         """
         Adjust the generation configuration based on the specified output method.
@@ -228,9 +232,9 @@ def _adjust_generation_config(self, output_method: OutputMethod):
         """
         config = self._model.generation_config.to_json()
         if output_method == OutputMethod.JSON:
-            config['temperature'] = 0
-            config['response_format'] = { "type": "json_object" }
-        
+            config["temperature"] = 0
+            config["response_format"] = {"type": "json_object"}
+
         return config

     def get_chat_history(self) -> list[dict]:
@@ -246,7 +250,7 @@ def delete_last_message(self):
         """
         Deletes the last message exchange (user message and assistant response) from the chat history.
         Preserves the system message if present.
-        
+
         Example:
             Before:
             [
@@ -269,7 +273,7 @@ def delete_last_message(self):
         else:
             # Reset to initial state with just system message if present
             self._chat_history = (
-            [{"role": "system", "content": self._model.system_instruction}]
-            if self._model.system_instruction is not None
-            else []
-        )
+                [{"role": "system", "content": self._model.system_instruction}]
+                if self._model.system_instruction is not None
+                else []
+            )
diff --git a/graphrag_sdk/models/model.py b/graphrag_sdk/models/model.py
index bef5943..c490469 100644
--- a/graphrag_sdk/models/model.py
+++ b/graphrag_sdk/models/model.py
@@ -8,9 +8,11 @@ class FinishReason:
     STOP = "STOP"
     OTHER = "OTHER"

+
 class OutputMethod(Enum):
-    JSON = 'json'
-    DEFAULT = 'default'
+    JSON = "json"
+    DEFAULT = "default"
+

 class GenerativeModelConfig:
     """
@@ -67,7 +69,6 @@ def from_json(json: dict) -> "GenerativeModelConfig":


 class GenerationResponse:
-
     def __init__(self, text: str, finish_reason: FinishReason):
         self.text = text
         self.finish_reason = finish_reason
@@ -88,7 +89,9 @@ def __init__(self, model: "GenerativeModel"):
         self.model = model

     @abstractmethod
-    def send_message(self, message: str, output_method: OutputMethod = OutputMethod.DEFAULT) -> GenerationResponse:
+    def send_message(
+        self, message: str, output_method: OutputMethod = OutputMethod.DEFAULT
+    ) -> GenerationResponse:
         pass

diff --git a/graphrag_sdk/models/ollama.py b/graphrag_sdk/models/ollama.py
index 2d72ba0..5fba55b 100644
--- a/graphrag_sdk/models/ollama.py
+++ b/graphrag_sdk/models/ollama.py
@@ -11,7 +11,8 @@
 )

 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
+logger.setLevel(logging.INFO)
+

 class OllamaGenerativeModel(GenerativeModel):
     """
@@ -60,7 +61,9 @@ def check_and_pull_model(self) -> None:
         """
         # Get the list of available models
         response = self.client.list()  # This returns a dictionary
-        available_models = [model['name'] for model in response['models']]  # Extract model names
+        available_models = [
+            model["name"] for model in response["models"]
+        ]  # Extract model names

         # Check if the model is already pulled
         if self.model_name in available_models:
@@ -96,9 +99,8 @@ def parse_generate_content_response(self, response: any) -> GenerationResponse:
             GenerationResponse: Parsed response containing the generated text.
         """
         return GenerationResponse(
-            text=response["message"]["content"],
-            finish_reason=FinishReason.STOP
-        )
+            text=response["message"]["content"], finish_reason=FinishReason.STOP
+        )

     def to_json(self) -> dict:
         """
@@ -132,6 +134,7 @@ def from_json(json: dict) -> "GenerativeModel":
             system_instruction=json["system_instruction"],
         )

+
 class OllamaChatSession(GenerativeModelChatSession):
     """
     A chat session for interacting with the Ollama model, maintaining conversation history.
@@ -151,7 +154,7 @@ def __init__(self, model: OllamaGenerativeModel, system_instruction: Optional[st
             if system_instruction is not None
             else []
         )
-    
+
     def get_chat_history(self) -> list[dict]:
         """
         Retrieve the conversation history for the current chat session.
@@ -161,7 +164,9 @@ def get_chat_history(self) -> list[dict]:
         """
         return self._chat_history.copy()

-    def send_message(self, message: str, output_method: OutputMethod = OutputMethod.DEFAULT) -> GenerationResponse:
+    def send_message(
+        self, message: str, output_method: OutputMethod = OutputMethod.DEFAULT
+    ) -> GenerationResponse:
         """
         Send a message in the chat session and receive the model's response.

@@ -177,12 +182,12 @@ def send_message(self, message: str, output_method: OutputMethod = OutputMethod.
         response = self._model.client.chat(
             model=self._model.model_name,
             messages=self._chat_history,
-            options=Options(**generation_config)
+            options=Options(**generation_config),
         )
         content = self._model.parse_generate_content_response(response)
         self._chat_history.append({"role": "assistant", "content": content.text})
         return content
-    
+
     def _adjust_generation_config(self, output_method: OutputMethod) -> dict:
         """
         Adjust the generation configuration based on the specified output method.
@@ -195,16 +200,16 @@ def _adjust_generation_config(self, output_method: OutputMethod) -> dict:
         """
         config = self._model.generation_config.to_json()
         if output_method == OutputMethod.JSON:
-            config['temperature'] = 0
-            config['format'] = 'json'
-        
+            config["temperature"] = 0
+            config["format"] = "json"
+
         return config
-    
+
     def delete_last_message(self):
         """
         Deletes the last message exchange (user message and assistant response) from the chat history.
         Preserves the system message if present.
-        
+
         Example:
             Before:
             [
@@ -227,7 +232,7 @@ def delete_last_message(self):
         else:
             # Reset to initial state with just system message if present
             self._chat_history = (
-            [{"role": "system", "content": self._model.system_instruction}]
-            if self._model.system_instruction is not None
-            else []
-        )
\ No newline at end of file
+                [{"role": "system", "content": self._model.system_instruction}]
+                if self._model.system_instruction is not None
+                else []
+            )
diff --git a/graphrag_sdk/models/openai.py b/graphrag_sdk/models/openai.py
index d87ac60..aca0919 100644
--- a/graphrag_sdk/models/openai.py
+++ b/graphrag_sdk/models/openai.py
@@ -158,16 +158,16 @@ def _adjust_generation_config(self, output_method: OutputMethod):
         """
         config = self._model.generation_config.to_json()
         if output_method == OutputMethod.JSON:
-            config['temperature'] = 0
-            config['response_format'] = { "type": "json_object" }
-        
+            config["temperature"] = 0
+            config["response_format"] = {"type": "json_object"}
+
         return config
-    
+
     def delete_last_message(self):
         """
         Deletes the last message exchange (user message and assistant response) from the chat history.
         Preserves the system message if present.
-        
+
         Example:
             Before:
             [
@@ -190,7 +190,7 @@ def delete_last_message(self):
         else:
             # Reset to initial state with just system message if present
             self._history = (
-            [{"role": "system", "content": self._model.system_instruction}]
-            if self._model.system_instruction is not None
-            else []
-        )
+                [{"role": "system", "content": self._model.system_instruction}]
+                if self._model.system_instruction is not None
+                else []
+            )
diff --git a/graphrag_sdk/ontology.py b/graphrag_sdk/ontology.py
index a771fe0..d640d47 100644
--- a/graphrag_sdk/ontology.py
+++ b/graphrag_sdk/ontology.py
@@ -318,7 +318,7 @@ def validate_entities(self) -> bool:
                 f"""
 *** WARNING ***
 The following entities do not have unique attributes:
-{', '.join(entities_without_unique_attributes)}
+{", ".join(entities_without_unique_attributes)}
 """
             )
             return False
diff --git a/graphrag_sdk/orchestrator/__init__.py b/graphrag_sdk/orchestrator/__init__.py
index 3b34d10..25372e6 100644
--- a/graphrag_sdk/orchestrator/__init__.py
+++ b/graphrag_sdk/orchestrator/__init__.py
@@ -1,10 +1,5 @@
 from .orchestrator import Orchestrator
 from .orchestrator_runner import OrchestratorRunner
 from .execution_plan import ExecutionPlan
-from .step import StepResult, PlanStep, StepBlockType
-
-__all__ = [
-    'Orchestrator',
-    'ExecutionPlan',
-    'OrchestratorRunner'
-]
+
+__all__ = ["Orchestrator", "ExecutionPlan", "OrchestratorRunner"]
diff --git a/graphrag_sdk/orchestrator/step.py b/graphrag_sdk/orchestrator/step.py
index 6c1ce11..de3d762 100644
--- a/graphrag_sdk/orchestrator/step.py
+++ b/graphrag_sdk/orchestrator/step.py
@@ -76,12 +76,13 @@ def from_json(json: Union[dict, str]) -> "PlanStep":
         """
         json = json if isinstance(json, dict) else loads(json)
         from graphrag_sdk.orchestrator.steps import PLAN_STEP_TYPE_MAP
+
         block = StepBlockType.from_str(json["block"])
         step_type = PLAN_STEP_TYPE_MAP[block]

         if step_type is None:
             raise ValueError(f"Unknown step block type: {block}")
-        
+
         return step_type.from_json(json)

     @abstractmethod
diff --git a/graphrag_sdk/orchestrator/step_result.py b/graphrag_sdk/orchestrator/step_result.py
index 575b4b0..de67111 100644
--- a/graphrag_sdk/orchestrator/step_result.py
+++ b/graphrag_sdk/orchestrator/step_result.py
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod


-class StepResult(ABC): 
+class StepResult(ABC):
     @property
     @abstractmethod
     def output(self) -> str:
diff --git a/graphrag_sdk/orchestrator/steps/__init__.py b/graphrag_sdk/orchestrator/steps/__init__.py
index 3197453..3fe75e4 100644
--- a/graphrag_sdk/orchestrator/steps/__init__.py
+++ b/graphrag_sdk/orchestrator/steps/__init__.py
@@ -17,5 +17,5 @@
     "ParallelStep",
     "SummaryStep",
     "UserInputStep",
-    "PLAN_STEP_TYPE_MAP"
+    "PLAN_STEP_TYPE_MAP",
 ]
diff --git a/graphrag_sdk/orchestrator/steps/parallel.py b/graphrag_sdk/orchestrator/steps/parallel.py
index 24e6cac..52f6584 100644
--- a/graphrag_sdk/orchestrator/steps/parallel.py
+++ b/graphrag_sdk/orchestrator/steps/parallel.py
@@ -4,6 +4,7 @@
 from concurrent.futures import ThreadPoolExecutor, wait
 from graphrag_sdk.orchestrator.step_result import StepResult
 from graphrag_sdk.orchestrator.orchestrator_runner import OrchestratorRunner
+from graphrag_sdk.orchestrator.step import PlanStep


 class ParallelStepResult(StepResult):
@@ -43,10 +44,7 @@ def from_json(json: dict) -> "ParallelStepResult":
             ParallelStepResult: An instance of ParallelStepResult.
         """
         return ParallelStepResult(
-            [
-                StepResult.from_json(result)
-                for result in json["results"]
-            ]
+            [StepResult.from_json(result) for result in json["results"]]
         )

     def __str__(self) -> str:
@@ -89,7 +87,7 @@ def from_json(json: dict) -> "ParallelProperties":
         """
         return ParallelProperties(
             [
-                graphrag_sdk.orchestrator.step.PlanStep.from_json(step)
+                PlanStep.from_json(step)
                 for step in (json if isinstance(json, list) else json["steps"])
             ]
         )
@@ -102,10 +100,10 @@ def to_json(self) -> dict:
         Returns:
             dict: A dictionary representation of the parallel properties.
         """
         return {"steps": [step.to_json() for step in self.steps]}
-    
+
     def __str__(self) -> str:
         return f"ParallelProperties(steps={self.steps})"
-    
+
     def __repr__(self) -> str:
         return str(self)
diff --git a/graphrag_sdk/orchestrator/steps/user_input.py b/graphrag_sdk/orchestrator/steps/user_input.py
index 5106b45..9315ee9 100644
--- a/graphrag_sdk/orchestrator/steps/user_input.py
+++ b/graphrag_sdk/orchestrator/steps/user_input.py
@@ -96,10 +96,10 @@ def to_json(self) -> dict:
         return {
             "question": self.question,
         }
-    
+
     def __str__(self) -> str:
         return f"UserInputProperties(question={self.question})"
-    
+
     def __repr__(self) -> str:
         return str(self)
diff --git a/graphrag_sdk/relation.py b/graphrag_sdk/relation.py
index f8bd0ff..154bb5f 100644
--- a/graphrag_sdk/relation.py
+++ b/graphrag_sdk/relation.py
@@ -59,9 +59,9 @@ def to_json(self) -> dict:
     def __str__(self) -> str:
         """
         Returns a string representation of the Relation object.
-        
+
         The string representation includes the label of the Relation object.
-        
+
         Returns:
             str: The string representation of the Relation object.
         """
diff --git a/graphrag_sdk/steps/create_ontology_step.py b/graphrag_sdk/steps/create_ontology_step.py
index 3b31f1e..9686a54 100644
--- a/graphrag_sdk/steps/create_ontology_step.py
+++ b/graphrag_sdk/steps/create_ontology_step.py
@@ -5,7 +5,6 @@
 from threading import Lock
 from typing import Optional
 from graphrag_sdk.steps.Step import Step
-from graphrag_sdk.document import Document
 from graphrag_sdk.ontology import Ontology
 from graphrag_sdk.helpers import extract_json
 from ratelimit import limits, sleep_and_retry
@@ -92,9 +91,12 @@ def run(self, boundaries: Optional[str] = None):
         """
         tasks: list[Future[Ontology]] = []

-        with tqdm(total=len(self.sources) + 1, desc="Process Documents", disable=self.hide_progress) as pbar:
+        with tqdm(
+            total=len(self.sources) + 1,
+            desc="Process Documents",
+            disable=self.hide_progress,
+        ) as pbar:
             with ThreadPoolExecutor(max_workers=self.config["max_workers"]) as executor:
-
                 # Process each source document in parallel
                 for source in self.sources:
                     task = executor.submit(
@@ -112,14 +114,16 @@ def run(self, boundaries: Optional[str] = None):
                     with self.counter_lock:
                         pbar.n = self.process_files
                     pbar.refresh()
-                
+
                 # Validate the ontology
                 if len(self.ontology.entities) == 0:
                     raise Exception("Failed to create ontology")
-                
+
                 # Finalize the ontology
-                task_fin = executor.submit(self._fix_ontology, self._create_chat(), self.ontology)
-                
+                task_fin = executor.submit(
+                    self._fix_ontology, self._create_chat(), self.ontology
+                )
+
                 # Wait for the final task to be completed
                 while not task_fin.done():
                     time.sleep(RENDER_STEP_SIZE)
@@ -151,12 +155,16 @@ def _process_source(
         """
         try:
             document = next(source.load())
-            
+
             text = document.content[: self.config["max_input_tokens"]]

             user_message = CREATE_ONTOLOGY_PROMPT.format(
-                text = text,
-                boundaries = BOUNDARIES_PREFIX.format(user_boundaries=boundaries) if boundaries is not None else "",
+                text=text,
+                boundaries=(
+                    BOUNDARIES_PREFIX.format(user_boundaries=boundaries)
+                    if boundaries is not None
+                    else ""
+                ),
             )

             responses: list[GenerationResponse] = []
@@ -166,7 +174,10 @@ def _process_source(

         logger.debug(f"Model response: {responses[response_idx]}")

-        while responses[response_idx].finish_reason == FinishReason.MAX_TOKENS and response_idx < retries:
+        while (
+            responses[response_idx].finish_reason == FinishReason.MAX_TOKENS
+            and response_idx < retries
+        ):
             response_idx += 1
             responses.append(self._call_model(chat_session, "continue"))

@@ -181,7 +192,7 @@ def _process_source(
             data = json.loads(extract_json(combined_text))
         except json.decoder.JSONDecodeError as e:
             logger.debug(f"Error extracting JSON: {e}")
-            logger.debug(f"Prompting model to fix JSON")
+            logger.debug("Prompting model to fix JSON")
             json_fix_response = self._call_model(
                 self._create_chat(),
                 FIX_JSON_PROMPT.format(json=combined_text, error=str(e)),
@@ -249,7 +260,7 @@ def _fix_ontology(self, chat_session: GenerativeModelChatSession, o: Ontology):
             data = json.loads(extract_json(combined_text))
         except json.decoder.JSONDecodeError as e:
             logger.debug(f"Error extracting JSON: {e}")
-            logger.debug(f"Prompting model to fix JSON")
+            logger.debug("Prompting model to fix JSON")
             json_fix_response = self._call_model(
                 self._create_chat(),
                 FIX_JSON_PROMPT.format(json=combined_text, error=str(e)),
@@ -263,7 +274,7 @@ def _fix_ontology(self, chat_session: GenerativeModelChatSession, o: Ontology):

         if data is None:
             return o
-        
+
         try:
             new_ontology = Ontology.from_json(data)
         except Exception as e:
diff --git a/graphrag_sdk/steps/extract_data_step.py b/graphrag_sdk/steps/extract_data_step.py
index 6e8a122..ae6237e 100644
--- a/graphrag_sdk/steps/extract_data_step.py
+++ b/graphrag_sdk/steps/extract_data_step.py
@@ -8,7 +8,6 @@
 from threading import Lock
 from typing import Optional
 from graphrag_sdk.steps.Step import Step
-from graphrag_sdk.document import Document
 from ratelimit import limits, sleep_and_retry
 from graphrag_sdk.source import AbstractSource
 from graphrag_sdk.models.model import OutputMethod
@@ -188,14 +187,27 @@ def _process_document(
         responses: list[GenerationResponse] = []
         response_idx = 0

-        responses.append(self._call_model(chat_session, user_message, output_method=OutputMethod.JSON))
+        responses.append(
+            self._call_model(
+                chat_session, user_message, output_method=OutputMethod.JSON
+            )
+        )

         _task_logger.debug(f"Model response: {responses[response_idx].text}")

-        while responses[response_idx].finish_reason == FinishReason.MAX_TOKENS and response_idx < retries:
+        while (
+            responses[response_idx].finish_reason == FinishReason.MAX_TOKENS
+            and response_idx < retries
+        ):
             _task_logger.debug("Asking model to continue")
             response_idx += 1
-            responses.append(self._call_model(chat_session, COMPLETE_DATA_EXTRACTION, output_method=OutputMethod.JSON))
+            responses.append(
+                self._call_model(
+                    chat_session,
+                    COMPLETE_DATA_EXTRACTION,
+                    output_method=OutputMethod.JSON,
+                )
+            )
             _task_logger.debug(
                 f"Model response after continue: {responses[response_idx].text}"
             )
@@ -215,7 +227,7 @@ def _process_document(
             data = json.loads(extract_json(last_respond))
         except Exception as e:
             _task_logger.debug(f"Error extracting JSON: {e}")
-            _task_logger.debug(f"Prompting model to fix JSON")
+            _task_logger.debug("Prompting model to fix JSON")
             json_fix_response = self._call_model(
                 self._create_chat(),
                 FIX_JSON_PROMPT.format(json=last_respond, error=str(e)),
@@ -229,7 +241,7 @@ def _process_document(
                 f"Invalid data format. Missing entities or relations. {data}"
             )
             raise Exception(
-                f"Invalid data format. Missing 'entities' or 'relations' in JSON."
+                "Invalid data format. Missing 'entities' or 'relations' in JSON."
             )
         for entity in data["entities"]:
             try:
@@ -244,7 +256,7 @@ def _process_document(
             except Exception as e:
                 _task_logger.error(f"Error creating relation: {e}")
                 continue
-            
+
         except Exception as e:
             logger.exception(f"Task id: {task_id} failed - {e}")
             raise e
diff --git a/graphrag_sdk/steps/graph_query_step.py b/graphrag_sdk/steps/graph_query_step.py
index 0edffb2..a886e7e 100644
--- a/graphrag_sdk/steps/graph_query_step.py
+++ b/graphrag_sdk/steps/graph_query_step.py
@@ -66,10 +66,12 @@ def run(self, question: str, retries: Optional[int] = 10):
         for i in range(retries):
             try:
                 cypher_prompt = (
-                    (self.cypher_prompt.format(question=question)
+                    self.cypher_prompt.format(question=question)
                     if self.last_answer is None
-                    else self.cypher_prompt_with_history.format(question=question, last_answer=self.last_answer))
-                )
+                    else self.cypher_prompt_with_history.format(
+                        question=question, last_answer=self.last_answer
+                    )
+                )
                 logger.debug(f"Cypher Prompt: {cypher_prompt}")
                 cypher_statement_response = self.chat_session.send_message(
                     cypher_prompt,
diff --git a/poetry.lock b/poetry.lock
index 46b98eb..1670b9c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -359,6 +359,52 @@ charset-normalizer = ["charset-normalizer"]
 html5lib = ["html5lib"]
 lxml = ["lxml"]

+[[package]]
+name = "black"
+version = "24.10.0"
+description = "The uncompromising code formatter."
+optional = false
+python-versions = ">=3.9"
+files = [
+    {file = "black-24.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e6668650ea4b685440857138e5fe40cde4d652633b1bdffc62933d0db4ed9812"},
+    {file = "black-24.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1c536fcf674217e87b8cc3657b81809d3c085d7bf3ef262ead700da345bfa6ea"},
+    {file = "black-24.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:649fff99a20bd06c6f727d2a27f401331dc0cc861fb69cde910fe95b01b5928f"},
+    {file = "black-24.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:fe4d6476887de70546212c99ac9bd803d90b42fc4767f058a0baa895013fbb3e"},
+    {file = "black-24.10.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5a2221696a8224e335c28816a9d331a6c2ae15a2ee34ec857dcf3e45dbfa99ad"},
+    {file = "black-24.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f9da3333530dbcecc1be13e69c250ed8dfa67f43c4005fb537bb426e19200d50"},
+    {file = "black-24.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4007b1393d902b48b36958a216c20c4482f601569d19ed1df294a496eb366392"},
+    {file = "black-24.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:394d4ddc64782e51153eadcaaca95144ac4c35e27ef9b0a42e121ae7e57a9175"},
+    {file = "black-24.10.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b5e39e0fae001df40f95bd8cc36b9165c5e2ea88900167bddf258bacef9bbdc3"},
+    {file = "black-24.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d37d422772111794b26757c5b55a3eade028aa3fde43121ab7b673d050949d65"},
+    {file = "black-24.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:14b3502784f09ce2443830e3133dacf2c0110d45191ed470ecb04d0f5f6fcb0f"},
+    {file = "black-24.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:30d2c30dc5139211dda799758559d1b049f7f14c580c409d6ad925b74a4208a8"},
+    {file = "black-24.10.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1cbacacb19e922a1d75ef2b6ccaefcd6e93a2c05ede32f06a21386a04cedb981"},
+    {file = "black-24.10.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1f93102e0c5bb3907451063e08b9876dbeac810e7da5a8bfb7aeb5a9ef89066b"},
+    {file = "black-24.10.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ddacb691cdcdf77b96f549cf9591701d8db36b2f19519373d60d31746068dbf2"},
+    {file = "black-24.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:680359d932801c76d2e9c9068d05c6b107f2584b2a5b88831c83962eb9984c1b"},
+    {file = "black-24.10.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:17374989640fbca88b6a448129cd1745c5eb8d9547b464f281b251dd00155ccd"},
+    {file = "black-24.10.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:63f626344343083322233f175aaf372d326de8436f5928c042639a4afbbf1d3f"},
+    {file = "black-24.10.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ccfa1d0cb6200857f1923b602f978386a3a2758a65b52e0950299ea014be6800"},
+    {file = "black-24.10.0-cp39-cp39-win_amd64.whl", hash = "sha256:2cd9c95431d94adc56600710f8813ee27eea544dd118d45896bb734e9d7a0dc7"},
+    {file = "black-24.10.0-py3-none-any.whl", hash = "sha256:3bb2b7a1f7b685f85b11fed1ef10f8a9148bceb49853e47a294a3dd963c1dd7d"},
+    {file = "black-24.10.0.tar.gz", hash = "sha256:846ea64c97afe3bc677b761787993be4991810ecc7a4a937816dd6bddedc4875"},
+]
+
+[package.dependencies]
+click = ">=8.0.0"
+mypy-extensions = ">=0.4.3"
+packaging = ">=22.0"
+pathspec = ">=0.9.0"
+platformdirs = ">=2"
+tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
+typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""}
+
+[package.extras]
+colorama = ["colorama (>=0.4.3)"]
+d = ["aiohttp (>=3.10)"]
+jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
+uvloop = ["uvloop (>=0.15.2)"]
+
 [[package]]
 name = "bleach"
 version = "6.2.0"
@@ -597,7 +643,7 @@ files = [
 name = "click"
 version = "8.1.8"
 description = "Composable command line interface toolkit"
-optional = true
+optional = false
 python-versions = ">=3.7"
 files = [
     {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"},
@@ -990,16 +1036,16 @@ files = [
 google-auth = ">=2.14.1,<3.0.dev0"
 googleapis-common-protos = ">=1.56.2,<2.0.dev0"
 grpcio = [
-    {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
     {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
 ]
 grpcio-status = [
-    {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
     {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
 ]
 proto-plus = [
-    {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
     {version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""},
+    {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
 ]
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
 requests = ">=2.18.0,<3.0.0.dev0"
@@ -1179,8 +1225,8 @@ google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extr
 google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev"
 grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev"
 proto-plus = [
-    {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
     {version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""},
+    {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
 ]
 protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev"
@@ -2356,6 +2402,17 @@ files = [
 [package.dependencies]
 typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""}

+[[package]]
+name = "mypy-extensions"
+version = "1.0.0"
+description = "Type system extensions for programs checked with the mypy type checker."
+optional = false
+python-versions = ">=3.5"
+files = [
+    {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"},
+    {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"},
+]
+
 [[package]]
 name = "nbclient"
 version = "0.10.2"
@@ -2655,9 +2712,9 @@ files = [

 [package.dependencies]
 numpy = [
-    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
-    {version = ">=1.23.2", markers = "python_version == \"3.11\""},
     {version = ">=1.22.4", markers = "python_version < \"3.11\""},
+    {version = ">=1.23.2", markers = "python_version == \"3.11\""},
+    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
 ]
 python-dateutil = ">=2.8.2"
 pytz = ">=2020.1"
@@ -2714,6 +2771,17 @@ files = [
 qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"]
 testing = ["docopt", "pytest"]

+[[package]]
+name = "pathspec"
+version = "0.12.1"
+description = "Utility library for gitignore style pattern matching of file paths."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"},
+    {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"},
+]
+
 [[package]]
 name = "pexpect"
 version = "4.9.0"
@@ -3819,6 +3887,33 @@ files = [
 [package.dependencies]
 pyasn1 = ">=0.1.3"

+[[package]]
+name = "ruff"
+version = "0.9.1"
+description = "An extremely fast Python linter and code formatter, written in Rust."
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "ruff-0.9.1-py3-none-linux_armv6l.whl", hash = "sha256:84330dda7abcc270e6055551aca93fdde1b0685fc4fd358f26410f9349cf1743"},
+    {file = "ruff-0.9.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:3cae39ba5d137054b0e5b472aee3b78a7c884e61591b100aeb544bcd1fc38d4f"},
+    {file = "ruff-0.9.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:50c647ff96f4ba288db0ad87048257753733763b409b2faf2ea78b45c8bb7fcb"},
+    {file = "ruff-0.9.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0c8b149e9c7353cace7d698e1656ffcf1e36e50f8ea3b5d5f7f87ff9986a7ca"},
+    {file = "ruff-0.9.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:beb3298604540c884d8b282fe7625651378e1986c25df51dec5b2f60cafc31ce"},
+    {file = "ruff-0.9.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:39d0174ccc45c439093971cc06ed3ac4dc545f5e8bdacf9f067adf879544d969"},
+    {file = "ruff-0.9.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:69572926c0f0c9912288915214ca9b2809525ea263603370b9e00bed2ba56dbd"},
+    {file = "ruff-0.9.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:937267afce0c9170d6d29f01fcd1f4378172dec6760a9f4dface48cdabf9610a"},
+    {file = "ruff-0.9.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:186c2313de946f2c22bdf5954b8dd083e124bcfb685732cfb0beae0c47233d9b"},
+    {file = "ruff-0.9.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f94942a3bb767675d9a051867c036655fe9f6c8a491539156a6f7e6b5f31831"},
+    {file = "ruff-0.9.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:728d791b769cc28c05f12c280f99e8896932e9833fef1dd8756a6af2261fd1ab"},
+    {file = "ruff-0.9.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2f312c86fb40c5c02b44a29a750ee3b21002bd813b5233facdaf63a51d9a85e1"},
+    {file = "ruff-0.9.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ae017c3a29bee341ba584f3823f805abbe5fe9cd97f87ed07ecbf533c4c88366"},
+    {file = "ruff-0.9.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5dc40a378a0e21b4cfe2b8a0f1812a6572fc7b230ef12cd9fac9161aa91d807f"},
+    {file = "ruff-0.9.1-py3-none-win32.whl", hash = "sha256:46ebf5cc106cf7e7378ca3c28ce4293b61b449cd121b98699be727d40b79ba72"},
+    {file = "ruff-0.9.1-py3-none-win_amd64.whl", hash = "sha256:342a824b46ddbcdddd3abfbb332fa7fcaac5488bf18073e841236aadf4ad5c19"},
+    {file = "ruff-0.9.1-py3-none-win_arm64.whl", hash = "sha256:1cd76c7f9c679e6e8f2af8f778367dca82b95009bc7b1a85a47f1521ae524fa7"},
+    {file = "ruff-0.9.1.tar.gz", hash = "sha256:fd2b25ecaf907d6458fa842675382c8597b3c746a2dde6717fe3415425df0c17"},
+]
+
 [[package]]
 name = "send2trash"
 version = "1.8.3"
@@ -4652,4 +4747,4 @@ vertexai = ["vertexai"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9.0"
-content-hash = "d53e56537233e2b2c26621582df70c2adbba414af24a25ccecfa8d430e692b95"
+content-hash = "d4e5b51292ec888a9dae2d7d9e8bcc8122222a568e42a5505b64fa02b2544413"
diff --git a/pyproject.toml b/pyproject.toml
index 325c77b..4a28881 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,8 @@ sphinx = "^7.3.7"
 sphinx-rtd-theme = "^2.0.0"
 pandas = "^2.2.2"
 jupyter = "^1.0.0"
+ruff = "^0.9.1"
+black = "^24.10.0"

 [build-system]
 requires = ["poetry-core"]
diff --git a/tests/test_auto_create_ontology.py b/tests/test_auto_create_ontology.py
index 4a9ec6d..945a94d 100644
--- a/tests/test_auto_create_ontology.py
+++ b/tests/test_auto_create_ontology.py
@@ -1,4 +1,5 @@
 from dotenv import load_dotenv
+
 load_dotenv()
 from graphrag_sdk.ontology import Ontology
 import unittest
@@ -19,7 +20,6 @@ class TestAutoDetectOntology(unittest.TestCase):
     """

     def test_auto_detect_ontology(self):
-
         file_path = "tests/data/madoff.txt"

         sources = [Source(file_path)]
diff --git a/tests/test_helper_validate_cypher.py b/tests/test_helper_validate_cypher.py
index 85429d3..5fead53 100644
--- a/tests/test_helper_validate_cypher.py
+++ b/tests/test_helper_validate_cypher.py
@@ -28,7 +28,6 @@ class TestValidateCypher1(unittest.TestCase):

     @classmethod
     def setUpClass(cls):
-
         cls._ontology = Ontology()

         cls._ontology.add_entity(
@@ -55,19 +54,16 @@ def setUpClass(cls):
         )

     def test_validate_cypher_entities_exist(self):
-
         errors = validate_cypher_entities_exist(self.cypher, self._ontology)

         assert len(errors) == 0

     def test_validate_cypher_relations_exist(self):
-
         errors = validate_cypher_relations_exist(self.cypher, self._ontology)

         assert len(errors) == 0

     def test_validate_cypher_relation_directions(self):
-
         errors = validate_cypher_relation_directions(self.cypher, self._ontology)

         assert len(errors) == 0
@@ -80,7 +76,7 @@ def test_validate_cypher(self):

 class TestValidateCypher2(unittest.TestCase):
     """
-    Test a cypher query with the wrong relation direction 
+    Test a cypher query with the wrong relation direction
     """

     cypher = """
@@ -89,7 +85,6 @@ class TestValidateCypher2(unittest.TestCase):

     @classmethod
     def setUpClass(cls):
-
         cls._ontology = Ontology([], [])

         cls._ontology.add_entity(
@@ -116,19 +111,16 @@ def setUpClass(cls):
         )

     def test_validate_cypher_entities_exist(self):
-
         errors = validate_cypher_entities_exist(self.cypher, self._ontology)

         assert len(errors) == 0

     def test_validate_cypher_relations_exist(self):
-
         errors = validate_cypher_relations_exist(self.cypher, self._ontology)

         assert len(errors) == 0

     def test_validate_cypher_relation_directions(self):
-
         errors = validate_cypher_relation_directions(self.cypher, self._ontology)

         assert len(errors) == 1
@@ -151,7 +143,6 @@ class TestValidateCypher3(unittest.TestCase):

     @classmethod
     def setUpClass(cls):
-
         cls._ontology = Ontology([], [])

         cls._ontology.add_entity(
@@ -201,19 +192,16 @@ def setUpClass(cls):
         )

     def test_validate_cypher_nodes_exist(self):
-
         errors = validate_cypher_entities_exist(self.cypher, self._ontology)

         assert len(errors) == 0

     def test_validate_cypher_edges_exist(self):
-
         errors = validate_cypher_relations_exist(self.cypher, self._ontology)

         assert len(errors) == 0

     def test_validate_cypher_edge_directions(self):
-
         errors = validate_cypher_relation_directions(self.cypher, self._ontology)

         assert len(errors) == 0
diff --git a/tests/test_kg.py b/tests/test_kg.py
index 3398eae..b508f63 100644
--- a/tests/test_kg.py
+++ b/tests/test_kg.py
@@ -21,7 +21,6 @@ class TestKG(unittest.TestCase):

     @classmethod
     def setUpClass(cls):
-
         cls.ontology = Ontology([], [])

         cls.ontology.add_entity(
diff --git a/tests/test_kg_gemini.py b/tests/test_kg_gemini.py
index adf63e2..24671b3 100644
--- a/tests/test_kg_gemini.py
+++ b/tests/test_kg_gemini.py
@@ -24,7 +24,6 @@ class TestKGGemini(unittest.TestCase):

     @classmethod
     def setUpClass(cls):
-
         cls.ontology = Ontology([], [])

         cls.ontology.add_entity(
@@ -84,14 +83,14 @@ def test_kg_creation(self):
         sources = [Source(file_path)]

         self.kg.process_sources(sources)
-        
+
         chat = self.kg.chat_session()

         answer = chat.send_message("How many actors acted in a movie?")
-        answer = answer['response']
+        answer = answer["response"]

         logger.info(f"Answer: {answer}")

-        actors_count = re.findall(r'\d+', answer)
+        actors_count = re.findall(r"\d+", answer)
         num_actors = 0 if len(actors_count) == 0 else int(actors_count[0])

         assert num_actors > 10, "The number of actors found should be greater than 10"
diff --git a/tests/test_kg_litellm_openai.py b/tests/test_kg_litellm_openai.py
index 253d7bf..d78c0ba 100644
--- a/tests/test_kg_litellm_openai.py
+++ b/tests/test_kg_litellm_openai.py
@@ -1,7 +1,6 @@
 import re
 import logging
 import unittest
-from falkordb import FalkorDB
 from dotenv import load_dotenv
 from graphrag_sdk.entity import Entity
 from graphrag_sdk.source import Source_FromRawText
@@ -16,6 +15,7 @@
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)

+
 class TestKGLiteLLM(unittest.TestCase):
     """
     Test Knowledge Graph
@@ -23,7 +23,6 @@ class TestKGLiteLLM(unittest.TestCase):

     @classmethod
     def setUpClass(cls):
-
         cls.ontology = Ontology([], [])

         cls.ontology.add_entity(
@@ -76,7 +75,6 @@ def setUpClass(cls):
         )

     def test_kg_creation(self):
-
         file_path = "tests/data/madoff.txt"
         with open(file_path) as f:
             string = f.read()
@@ -87,11 +85,11 @@ def test_kg_creation(self):

         chat = self.kg.chat_session()
         answer = chat.send_message("How many actors acted in a movie?")
-        answer = answer['response']
+        answer = answer["response"]

         logger.info(f"Answer: {answer}")

-        actors_count = re.findall(r'\d+', answer)
+        actors_count = re.findall(r"\d+", answer)
         num_actors = 0 if len(actors_count) == 0 else int(actors_count[0])

-        assert num_actors > 10, "The number of actors found should be greater than 10"
\ No newline at end of file
+        assert num_actors > 10, "The number of actors found should be greater than 10"
diff --git a/tests/test_kg_ollama.py b/tests/test_kg_ollama.py
index 06ec03e..770943f 100644
--- a/tests/test_kg_ollama.py
+++ b/tests/test_kg_ollama.py
@@ -10,7 +10,11 @@
 from graphrag_sdk.attribute import Attribute, AttributeType
 from graphrag_sdk.models.ollama import OllamaGenerativeModel
 from graphrag_sdk.models.openai import OpenAiGenerativeModel
-from graphrag_sdk import KnowledgeGraph, KnowledgeGraphModelConfig, GenerativeModelConfig
+from graphrag_sdk import (
+    KnowledgeGraph,
+    KnowledgeGraphModelConfig,
+    GenerativeModelConfig,
+)

 load_dotenv()

@@ -18,6 +22,7 @@
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)

+
 class TestKGOllama(unittest.TestCase):
     """
     Test Knowledge Graph
@@ -25,7 +30,6 @@ class TestKGOllama(unittest.TestCase):

     @classmethod
     def setUpClass(cls):
-
         cls.ontology = Ontology([], [])

         cls.ontology.add_entity(
@@ -72,17 +76,23 @@ def setUpClass(cls):

         cls.graph_name = "IMDB_ollama"

-        model_ollama = OllamaGenerativeModel(model_name="llama3:8b", generation_config=GenerativeModelConfig(temperature=0))
+        model_ollama = OllamaGenerativeModel(
+            model_name="llama3:8b",
+            generation_config=GenerativeModelConfig(temperature=0),
+        )
         model_openai = OpenAiGenerativeModel(model_name="gpt-3.5-turbo")

         cls.kg = KnowledgeGraph(
             name=cls.graph_name,
             ontology=cls.ontology,
-            model_config=KnowledgeGraphModelConfig(extract_data=model_openai, cypher_generation=model_ollama, qa=model_ollama),
+            model_config=KnowledgeGraphModelConfig(
+                extract_data=model_openai,
+                cypher_generation=model_ollama,
+                qa=model_ollama,
+            ),
         )

     def test_kg_creation(self):
-
         file_path = "tests/data/madoff.txt"

         sources = [Source(file_path)]
@@ -91,17 +101,16 @@ def test_kg_creation(self):

         chat = self.kg.chat_session()
         answer = chat.send_message("How many actors acted in a movie?")
-        answer = answer['response']
+        answer = answer["response"]

         logger.info(f"Answer: {answer}")

-        actors_count = re.findall(r'\d+', answer)
+        actors_count = re.findall(r"\d+", answer)
         num_actors = 0 if len(actors_count) == 0 else int(actors_count[0])

         assert num_actors > 5, "The number of actors found should be greater than 5"

     def test_kg_delete(self):
-
         self.kg.delete()

         db = FalkorDB()
diff --git a/tests/test_kg_openai.py b/tests/test_kg_openai.py
index 785a28d..68c3aae 100644
--- a/tests/test_kg_openai.py
+++ b/tests/test_kg_openai.py
@@ -16,6 +16,7 @@
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)

+
 class TestKGOpenAI(unittest.TestCase):
     """
     Test Knowledge Graph
@@ -23,7 +24,6 @@ class TestKGOpenAI(unittest.TestCase):

     @classmethod
     def setUpClass(cls):
-
         cls.ontology = Ontology([], [])

         cls.ontology.add_entity(
@@ -76,7 +76,6 @@ def setUpClass(cls):
         )

     def test_kg_creation(self):
-
         file_path = "tests/data/madoff.txt"

         sources = [Source(file_path)]
@@ -85,11 +84,11 @@ def test_kg_creation(self):

         chat = self.kg.chat_session()
         answer = chat.send_message("How many actors acted in a movie?")
-        answer = answer['response']
+        answer = answer["response"]

         logger.info(f"Answer: {answer}")

-        actors_count = re.findall(r'\d+', answer)
+        actors_count = re.findall(r"\d+", answer)
         num_actors = 0 if len(actors_count) == 0 else int(actors_count[0])

         assert num_actors > 10, "The number of actors found should be greater than 10"
diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py
index 02ae597..fbdb92c 100644
--- a/tests/test_multi_agent.py
+++ b/tests/test_multi_agent.py
@@ -17,10 +17,8 @@


 class TestMultiAgent(unittest.TestCase):
-
     @classmethod
     def setUpClass(cls):
-
         cls.restaurants_ontology = Ontology()
         cls.restaurants_ontology.add_entity(
             Entity(
@@ -306,7 +304,6 @@ def import_data(
         )

     def test_multi_agent(self):
-
         response = self.orchestrator.ask(
             "Write me a two-day itinerary for a trip to Rome. Do not ask any questions to me, just provide your best itinerary."
         )
@@ -314,7 +311,6 @@ def test_multi_agent(self):
         print(response)

         assert response is not None
-
         assert (
             "itinerary" in response.output.lower() or "day" in response.output.lower()
         ), f"Response should contain the 'itinerary' or 'day' string: {response.output}"