@@ -1,6 +1,6 @@
 import functools
 import time
-from multiprocessing import get_context, Barrier, Process, Queue
+from multiprocessing import get_context
 from typing import Iterable, List, Optional, Tuple
 import itertools
 
@@ -65,10 +65,6 @@ def search_all(
     ):
         parallel = self.search_params.get("parallel", 1)
         top = self.search_params.get("top", None)
-
-        # Convert queries to a list to calculate its length
-        queries = list(queries)  # This allows us to calculate len(queries)
-
         # setup_search may require initialized client
         self.init_client(
             self.host, distance, self.connection_params, self.search_params
@@ -84,56 +80,31 @@ def search_all(
             print(f"Limiting queries to [0:{MAX_QUERIES - 1}]")
 
         if parallel == 1:
-            # Single-threaded execution
             start = time.perf_counter()
-
-            results = [search_one(query) for query in tqdm.tqdm(queries)]
-            total_time = time.perf_counter() - start
-
+            precisions, latencies = list(
+                zip(*[search_one(query) for query in tqdm.tqdm(used_queries)])
+            )
         else:
-            # Dynamically calculate chunk size
-            chunk_size = max(1, len(queries) // parallel)
-            query_chunks = list(chunked_iterable(queries, chunk_size))
+            ctx = get_context(self.get_mp_start_method())
 
-            # Function to be executed by each worker process
-            def worker_function(chunk, result_queue):
-                self.__class__.init_client(
+            with ctx.Pool(
+                processes=parallel,
+                initializer=self.__class__.init_client,
+                initargs=(
                     self.host,
                     distance,
                     self.connection_params,
                     self.search_params,
+                ),
+            ) as pool:
+                if parallel > 10:
+                    time.sleep(15)  # Wait for all processes to start
+                start = time.perf_counter()
+                precisions, latencies = list(
+                    zip(*pool.imap_unordered(search_one, iterable=tqdm.tqdm(used_queries)))
                 )
-                self.setup_search()
-                results = process_chunk(chunk, search_one)
-                result_queue.put(results)
-
-            # Create a queue to collect results
-            result_queue = Queue()
-
-            # Create and start worker processes
-            processes = []
-            for chunk in query_chunks:
-                process = Process(target=worker_function, args=(chunk, result_queue))
-                processes.append(process)
-                process.start()
-
-            # Start measuring time for the critical work
-            start = time.perf_counter()
 
-            # Collect results from all worker processes
-            results = []
-            for _ in processes:
-                results.extend(result_queue.get())
-
-            # Wait for all worker processes to finish
-            for process in processes:
-                process.join()
-
-            # Stop measuring time for the critical work
-            total_time = time.perf_counter() - start
-
-            # Extract precisions and latencies (outside the timed section)
-            precisions, latencies = zip(*results)
+        total_time = time.perf_counter() - start
 
         self.__class__.delete_client()
 
@@ -161,20 +132,3 @@ def post_search(self):
     @classmethod
     def delete_client(cls):
         pass
-
-
-def chunked_iterable(iterable, size):
-    """Yield successive chunks of a given size from an iterable."""
-    it = iter(iterable)
-    while chunk := list(itertools.islice(it, size)):
-        yield chunk
-
-
-def process_chunk(chunk, search_one):
-    """Process a chunk of queries using the search_one function."""
-    return [search_one(query) for query in chunk]
-
-
-def process_chunk_wrapper(chunk, search_one):
-    """Wrapper to process a chunk of queries."""
-    return process_chunk(chunk, search_one)
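
Note on the pattern: the refactor replaces the hand-rolled Process/Queue fan-out (and the now-unused chunked_iterable/process_chunk helpers) with a multiprocessing.Pool. The pool's initializer/initargs run init_client once per worker process, and imap_unordered streams queries to whichever worker is free, so results are collected as they complete rather than per pre-sized chunk. Below is a minimal, self-contained sketch of that pattern; the init_worker/run_query names and the fake client are illustrative, not from this codebase.

import time
from multiprocessing import get_context

# Per-process state; the real benchmark keeps a DB client here, set up by
# the pool initializer so each worker process gets its own instance.
_client = None


def init_worker(host):
    """Pool initializer: runs once inside each worker process."""
    global _client
    _client = f"connection to {host}"  # stand-in for a real client object


def run_query(query):
    """Runs in a worker process and uses the per-process _client."""
    assert _client is not None
    time.sleep(0.01)  # stand-in for the actual search call
    return float(query)


if __name__ == "__main__":
    # The diff chooses the start method via self.get_mp_start_method();
    # "spawn" is assumed here for portability.
    ctx = get_context("spawn")
    with ctx.Pool(
        processes=4,
        initializer=init_worker,
        initargs=("localhost",),
    ) as pool:
        start = time.perf_counter()
        # imap_unordered yields results as workers finish, in arbitrary
        # order, which is fine when only the aggregate timing matters.
        results = list(pool.imap_unordered(run_query, range(100)))
    total_time = time.perf_counter() - start
    print(f"{len(results)} queries in {total_time:.2f}s")

Compared with the removed Process/Queue version, the pool owns worker startup, result collection, and shutdown, and dynamic scheduling means one slow query no longer stalls an entire pre-computed chunk. It also lets total_time be computed at a single point after both branches, where the old code duplicated the measurement in each branch.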