Skip to content

Class to facilitate embedding generated documents to ChromaDb #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,10 @@ __pycache__
# Ignore everything in the generated directory
/generated/*

# Ignore the locally persisted chromadb
.chroma_db/
**/chroma-embeddings.parquet

# Don't ignore .gitkeep files in the generated directory
!/generated/.gitkeep

5 changes: 5 additions & 0 deletions constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import os


EXTENSION_TO_SKIP = [".png",".jpg",".jpeg",".gif",".bmp",".svg",".ico",".tif",".tiff"]
DEFAULT_DIR = "generated"
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_DIR_FULL_PATH = os.path.join(ROOT_DIR, DEFAULT_DIR)
DEFAULT_MODEL = "gpt-3.5-turbo" # we recommend 'gpt-4' if you have it # gpt3.5 is going to be worse at generating code so we strongly recommend gpt4. i know most people dont have access, we are working on a hosted version
DEFAULT_MAX_TOKENS = 2000 # i wonder how to tweak this properly. we dont want it to be max length as it encourages verbosity of code. but too short and code also truncates suddenly.
Empty file added src/__init__.py
Empty file.
71 changes: 71 additions & 0 deletions src/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os

from termcolor import colored
import chromadb

from constants import ROOT_DIR
from src.traversal import traverse_dir


class Embeddings():

EXCLUDE_PATTERNS = ["*.png","*.jpg","*.jpeg","*.gif","*.bmp","*.svg","*.ico","*.tif","*.tiff"]

def __init__(self, debug=False):
self.debug = debug
self.GENERATED_FILES_COLLECTION_NAME = "generated_files"

DB_DIR = os.path.join(ROOT_DIR, ".chroma_db")
# Create the embedding database directory if it doesn't exist
if not os.path.exists(DB_DIR):
os.makedirs(DB_DIR)

self.client = chromadb.Client(
chromadb.config.Settings(
chroma_db_impl="duckdb+parquet",
persist_directory=DB_DIR,
)
)

self.ensure_generated_files_collection_exists()

def ensure_generated_files_collection_exists(self):
# Create the generated files collection if it doesn't exist
self.generated_files_collection = self.client.get_or_create_collection(self.GENERATED_FILES_COLLECTION_NAME)


def persist_generated_file_contents(self, reset=False):
if reset:
self.generated_files_collection.reset()
self.ensure_generated_files_collection_exists()

# Iterate over all files in the generated directory
file_paths_list = []
file_contents_list = []
metadatas_list = []
for file_path in traverse_dir("generated", exclude_patterns=self.EXCLUDE_PATTERNS):
file_paths_list.append(file_path)
if self.debug:
print("embedding: " + colored(file_path, 'green'))
# Read the file
with open(file_path, "r") as file:
file_contents_list.append(file.read())
# Get the filename
filename = os.path.basename(file_path)
# Get the extension
extension = filename.split(".")[-1]
metadatas_list.append({
"filename": filename,
"extension": extension,
})

# Upsert the file into the database
self.generated_files_collection.upsert(
documents=file_contents_list,
metadatas=metadatas_list,
ids=file_paths_list,
)
if self.debug:
print(colored("persisted embeddings for %s files." % len(file_paths_list), 'yellow'))


23 changes: 23 additions & 0 deletions src/traversal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import os
import fnmatch

def traverse_dir(root_dir, include_patterns=None, exclude_patterns=None):
"""
# Example usage:
root_dir = '/path/to/directory'
include_patterns = ['*.txt', '*.py'] # Include files matching these patterns
exclude_patterns = ['exclude_dir1/*', 'exclude_dir2/*'] # Exclude directories matching these patterns

for file_path in traverse_dir(root_dir, include_patterns=include_patterns, exclude_patterns=exclude_patterns):
print(file_path)
"""
for root, dirs, files in os.walk(root_dir):
if exclude_patterns and any(fnmatch.fnmatch(root, pattern) for pattern in exclude_patterns):
continue
for file in files:
file_path = os.path.join(root, file)
if include_patterns and not any(fnmatch.fnmatch(file_path, pattern) for pattern in include_patterns):
continue
yield file_path