From b8eefdef92ce1aaed3bd08da67c3a38db2ca6af4 Mon Sep 17 00:00:00 2001 From: Aman Rusia Date: Thu, 26 Dec 2024 20:59:52 +0530 Subject: [PATCH 1/3] Optimisations in repo context --- src/wcgw/client/repo_ops/path_prob.py | 45 +++++--- src/wcgw/client/repo_ops/repo_context.py | 126 ++++++++++++++++++----- 2 files changed, 126 insertions(+), 45 deletions(-) diff --git a/src/wcgw/client/repo_ops/path_prob.py b/src/wcgw/client/repo_ops/path_prob.py index 14383e2..35eaf94 100644 --- a/src/wcgw/client/repo_ops/path_prob.py +++ b/src/wcgw/client/repo_ops/path_prob.py @@ -20,28 +20,39 @@ def __init__(self, model_path: str, vocab_path: str) -> None: self.encoder = tokenizers.Tokenizer.from_file(model_path) - def tokenize(self, text: str) -> List[str]: - """Tokenize text using the vocabulary.""" - return self.encoder.encode(text).tokens # type: ignore[no-any-return] + def tokenize_batch(self, texts: List[str]) -> List[List[str]]: + """Tokenize multiple texts at once.""" + encodings = self.encoder.encode_batch(texts) + return [encoding.tokens for encoding in encodings] # type: ignore[no-any-return] def detokenize(self, tokens: List[str]) -> str: """Convert tokens back to text, handling special tokens.""" return self.encoder.decode(tokens) # type: ignore[no-any-return] + def calculate_path_probabilities_batch( + self, paths: List[str] + ) -> List[Tuple[float, List[str], List[str]]]: + """Calculate log probability for multiple paths at once.""" + # Batch tokenize all paths + all_tokens = self.tokenize_batch(paths) + + results = [] + for tokens in all_tokens: + # Calculate sum of log probabilities for each path + log_prob_sum = 0.0 + unknown_tokens = [] + for token in tokens: + if token in self.vocab_probs: + log_prob_sum += self.vocab_probs[token] + else: + unknown_tokens.append(token) + + results.append((log_prob_sum, tokens, unknown_tokens)) + + return results + def calculate_path_probability( self, path: str ) -> Tuple[float, List[str], List[str]]: - """Calculate log probability for a given path.""" - # Tokenize the path - tokens = self.tokenize(path) - - # Calculate sum of log probabilities - log_prob_sum = 0.0 - unknown_tokens = [] - for token in tokens: - if token in self.vocab_probs: - log_prob_sum += self.vocab_probs[token] - else: - unknown_tokens.append(token) - - return log_prob_sum, tokens, unknown_tokens + """Calculate log probability for a single path.""" + return self.calculate_path_probabilities_batch([path])[0] diff --git a/src/wcgw/client/repo_ops/repo_context.py b/src/wcgw/client/repo_ops/repo_context.py index 281f98b..4c2f0e7 100644 --- a/src/wcgw/client/repo_ops/repo_context.py +++ b/src/wcgw/client/repo_ops/repo_context.py @@ -1,4 +1,6 @@ -from pathlib import Path +import os +from collections import deque +from pathlib import Path # Still needed for other parts from typing import Optional from pygit2 import GitError, Repository @@ -22,32 +24,71 @@ def find_ancestor_with_git(path: Path) -> Optional[Repository]: return None -def get_all_files_max_depth( - folder: Path, +MAX_FILES_CHECK = 100_000 + + +def _get_all_files_max_depth( + abs_folder: str, max_depth: int, - rel_to: str, repo: Optional[Repository], - current_depth: int, -) -> list[str]: - if current_depth > max_depth: - return [] - + current_depth: int = 0, + rel_path_prefix: str = "", + files_found: int = 0, +) -> tuple[list[str], int]: + """BFS implementation using deque that maintains relative paths during traversal. + Returns (files_list, total_files_found) to track file count.""" all_files = [] - for child in folder.iterdir(): - rel_path = str(child.relative_to(rel_to)) - if repo and repo.path_is_ignored(rel_path): + # Queue stores: (folder_path, depth, rel_path_prefix) + queue = deque([(abs_folder, current_depth, rel_path_prefix)]) + + while queue and files_found < MAX_FILES_CHECK: + current_folder, depth, prefix = queue.popleft() + + if depth > max_depth: continue - if child.is_file(): - all_files.append(rel_path) - elif child.is_dir(): - all_files.extend( - get_all_files_max_depth( - child, max_depth, rel_to, repo, current_depth + 1 - ) - ) + try: + entries = list(os.scandir(current_folder)) + except PermissionError: + continue + except OSError: + continue + # Split into files and folders with single scan + files = [] + folders = [] + for entry in entries: + is_file = entry.is_file(follow_symlinks=False) + name = entry.name + rel_path = f"{prefix}{name}" if prefix else name + + if repo and repo.path_is_ignored(rel_path): + continue + + if is_file: + files.append(rel_path) + else: + folders.append((entry.path, rel_path)) + + # Process files first (maintain priority) + chunk = files[: min(10_000, MAX_FILES_CHECK - files_found)] + all_files.extend(chunk) + files_found += len(chunk) + + # Add folders to queue for BFS traversal + for folder_path, folder_rel_path in folders: + next_prefix = f"{folder_rel_path}/" + queue.append((folder_path, depth + 1, next_prefix)) - return all_files + return all_files, files_found + + +def get_all_files_max_depth( + abs_folder: str, + max_depth: int, + repo: Optional[Repository], +) -> list[str]: + """Public interface that expects absolute paths.""" + return _get_all_files_max_depth(abs_folder, max_depth, repo)[0] def get_repo_context(file_or_repo_path: str, max_files: int) -> tuple[str, Path]: @@ -63,13 +104,16 @@ def get_repo_context(file_or_repo_path: str, max_files: int) -> tuple[str, Path] else: context_dir = file_or_repo_path_ - all_files = get_all_files_max_depth(context_dir, 10, str(context_dir), repo, 0) + all_files = get_all_files_max_depth(str(context_dir), 10, repo) + + # Calculate probabilities in batch + path_scores = PATH_SCORER.calculate_path_probabilities_batch(all_files) - sorted_files = sorted( - all_files, - key=lambda x: PATH_SCORER.calculate_path_probability(x)[0], - reverse=True, - ) + # Create list of (path, score) tuples and sort by score + path_with_scores = list(zip(all_files, (score[0] for score in path_scores))) + sorted_files = [ + path for path, _ in sorted(path_with_scores, key=lambda x: x[1], reverse=True) + ] top_files = sorted_files[:max_files] @@ -82,7 +126,33 @@ def get_repo_context(file_or_repo_path: str, max_files: int) -> tuple[str, Path] if __name__ == "__main__": + import cProfile + import pstats import sys + from line_profiler import LineProfiler + folder = sys.argv[1] - print(get_repo_context(folder, 200)[0]) + + # Profile using cProfile for overall function statistics + profiler = cProfile.Profile() + profiler.enable() + result = get_repo_context(folder, 200)[0] + profiler.disable() + + # Print cProfile stats + stats = pstats.Stats(profiler) + stats.sort_stats("cumulative") + print("\n=== Function-level profiling ===") + stats.print_stats(20) # Print top 20 functions + + # Profile using line_profiler for line-by-line statistics + lp = LineProfiler() + lp_wrapper = lp(get_repo_context) + lp_wrapper(folder, 200) + + print("\n=== Line-by-line profiling ===") + lp.print_stats() + + print("\n=== Result ===") + print(result) From e4e1a7a4eb5f041b3bd6ed64fae5610ae8fed7f4 Mon Sep 17 00:00:00 2001 From: Aman Rusia Date: Thu, 26 Dec 2024 21:26:09 +0530 Subject: [PATCH 2/3] Optimisations --- src/wcgw/client/repo_ops/display_tree.py | 2 ++ src/wcgw/client/repo_ops/repo_context.py | 36 +++++++++--------------- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/src/wcgw/client/repo_ops/display_tree.py b/src/wcgw/client/repo_ops/display_tree.py index bcc7b54..9d59bcc 100644 --- a/src/wcgw/client/repo_ops/display_tree.py +++ b/src/wcgw/client/repo_ops/display_tree.py @@ -49,6 +49,8 @@ def expand(self, rel_path: str) -> None: while str(current) >= str(self.root): if current not in self.expanded_dirs: self.expanded_dirs[current] = self._list_directory(current) + if current == current.parent: + break current = current.parent def _list_directory(self, dir_path: Path) -> List[Path]: diff --git a/src/wcgw/client/repo_ops/repo_context.py b/src/wcgw/client/repo_ops/repo_context.py index 4c2f0e7..88dccf1 100644 --- a/src/wcgw/client/repo_ops/repo_context.py +++ b/src/wcgw/client/repo_ops/repo_context.py @@ -24,24 +24,21 @@ def find_ancestor_with_git(path: Path) -> Optional[Repository]: return None -MAX_FILES_CHECK = 100_000 +MAX_ENTRIES_CHECK = 100_000 -def _get_all_files_max_depth( +def get_all_files_max_depth( abs_folder: str, max_depth: int, repo: Optional[Repository], - current_depth: int = 0, - rel_path_prefix: str = "", - files_found: int = 0, -) -> tuple[list[str], int]: +) -> list[str]: """BFS implementation using deque that maintains relative paths during traversal. Returns (files_list, total_files_found) to track file count.""" all_files = [] # Queue stores: (folder_path, depth, rel_path_prefix) - queue = deque([(abs_folder, current_depth, rel_path_prefix)]) - - while queue and files_found < MAX_FILES_CHECK: + queue = deque([(abs_folder, 0, "")]) + entries_check = 0 + while queue and entries_check < MAX_ENTRIES_CHECK: current_folder, depth, prefix = queue.popleft() if depth > max_depth: @@ -57,7 +54,11 @@ def _get_all_files_max_depth( files = [] folders = [] for entry in entries: - is_file = entry.is_file(follow_symlinks=False) + entries_check += 1 + try: + is_file = entry.is_file(follow_symlinks=False) + except OSError: + continue name = entry.name rel_path = f"{prefix}{name}" if prefix else name @@ -70,25 +71,15 @@ def _get_all_files_max_depth( folders.append((entry.path, rel_path)) # Process files first (maintain priority) - chunk = files[: min(10_000, MAX_FILES_CHECK - files_found)] + chunk = files[: min(10_000, max(0, MAX_ENTRIES_CHECK - entries_check))] all_files.extend(chunk) - files_found += len(chunk) # Add folders to queue for BFS traversal for folder_path, folder_rel_path in folders: next_prefix = f"{folder_rel_path}/" queue.append((folder_path, depth + 1, next_prefix)) - return all_files, files_found - - -def get_all_files_max_depth( - abs_folder: str, - max_depth: int, - repo: Optional[Repository], -) -> list[str]: - """Public interface that expects absolute paths.""" - return _get_all_files_max_depth(abs_folder, max_depth, repo)[0] + return all_files def get_repo_context(file_or_repo_path: str, max_files: int) -> tuple[str, Path]: @@ -118,7 +109,6 @@ def get_repo_context(file_or_repo_path: str, max_files: int) -> tuple[str, Path] top_files = sorted_files[:max_files] directory_printer = DirectoryTree(context_dir, max_files=max_files) - for file in top_files: directory_printer.expand(file) From 7ebf9e1d153ad5952c95402df29de34e0ce6564d Mon Sep 17 00:00:00 2001 From: Aman Rusia Date: Thu, 26 Dec 2024 22:36:15 +0530 Subject: [PATCH 3/3] fixes and improvements --- src/wcgw/client/repo_ops/display_tree.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/wcgw/client/repo_ops/display_tree.py b/src/wcgw/client/repo_ops/display_tree.py index 9d59bcc..139c7ca 100644 --- a/src/wcgw/client/repo_ops/display_tree.py +++ b/src/wcgw/client/repo_ops/display_tree.py @@ -1,7 +1,6 @@ import io -from collections import defaultdict from pathlib import Path -from typing import Dict, List, Set +from typing import List, Set class DirectoryTree: @@ -16,7 +15,7 @@ def __init__(self, root: Path, max_files: int = 10): self.root = root self.max_files = max_files self.expanded_files: Set[Path] = set() - self.expanded_dirs: Dict[Path, List[Path]] = defaultdict(list) + self.expanded_dirs = set[Path]() if not self.root.exists(): raise ValueError(f"Root path {root} does not exist") @@ -48,7 +47,7 @@ def expand(self, rel_path: str) -> None: current = abs_path.parent while str(current) >= str(self.root): if current not in self.expanded_dirs: - self.expanded_dirs[current] = self._list_directory(current) + self.expanded_dirs.add(current) if current == current.parent: break current = current.parent