 import tarfile
 import urllib.request
 from dataclasses import dataclass, field
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Union
 from urllib.request import build_opener, install_opener
-
+import boto3
+import botocore.exceptions
 from benchmark import DATASETS_DIR
 from dataset_reader.ann_compound_reader import AnnCompoundReader
 from dataset_reader.ann_h5_reader import AnnH5Reader
+from dataset_reader.ann_h5_multi_reader import AnnH5MultiReader
 from dataset_reader.base_reader import BaseReader
 from dataset_reader.json_reader import JSONReader
-from dataset_reader.sparse_reader import SparseReader
+from tqdm import tqdm
+from pathlib import Path


 # Needed for Cloudflare's firewall in ann-benchmarks
 # See https://github.com/erikbern/ann-benchmarks/pull/561

 class DatasetConfig:
     name: str
     type: str
-    path: str
-
-    link: Optional[str] = None
+    path: Union[str, Dict[str, List[Dict[str, str]]]]  # Can be a string or a dict for multi-file structure
+    link: Optional[Union[str, Dict[str, List[Dict[str, str]]]]] = None
     schema: Optional[Dict[str, str]] = field(default_factory=dict)
     # None in case of sparse vectors:
     vector_size: Optional[int] = None
@@ -35,57 +37,227 @@ class DatasetConfig:

 READER_TYPE = {
     "h5": AnnH5Reader,
+    "h5-multi": AnnH5MultiReader,
     "jsonl": JSONReader,
     "tar": AnnCompoundReader,
-    "sparse": SparseReader,
 }


+# Progress bar for urllib downloads
+def show_progress(block_num, block_size, total_size):
+    percent = round(block_num * block_size / total_size * 100, 2)
+    print(f"{percent} %", end="\r")
+
+
+# Progress handler for S3 downloads
+class S3Progress(tqdm):
+    def __init__(self, total_size):
+        super().__init__(
+            total=total_size, unit="B", unit_scale=True, desc="Downloading from S3"
+        )
+
+    def __call__(self, bytes_amount):
+        self.update(bytes_amount)
+
+
 class Dataset:
-    def __init__(self, config: dict):
+    def __init__(
+        self,
+        config: dict,
+        skip_upload: bool,
+        skip_search: bool,
+        upload_start_idx: int,
+        upload_end_idx: int,
+    ):
         self.config = DatasetConfig(**config)
+        self.skip_upload = skip_upload
+        self.skip_search = skip_search
+        self.upload_start_idx = upload_start_idx
+        self.upload_end_idx = upload_end_idx

     def download(self):
-        target_path = DATASETS_DIR / self.config.path
+        if isinstance(self.config.path, dict):  # Handle multi-file datasets
+            if self.skip_search is False:
+                # Download query files
+                for query in self.config.path.get("queries", []):
+                    self._download_file(query["path"], query["link"])
+            else:
+                print(
+                    f"skipping query file download given skip_search={self.skip_search}"
+                )
+            if self.skip_upload is False:
+                # Download data files
+                for data in self.config.path.get("data", []):
+                    start_idx = data["start_idx"]
+                    end_idx = data["end_idx"]
+                    data_path = data["path"]
+                    data_link = data["link"]
+                    if self.upload_start_idx >= end_idx:
+                        print(
+                            f"skipping download of {data_path} from {data_link} given {self.upload_start_idx} >= {end_idx}"
+                        )
+                        continue
+                    if self.upload_end_idx < start_idx:
+                        print(
+                            f"skipping download of {data_path} from {data_link} given {self.upload_end_idx} < {start_idx}"
+                        )
+                        continue
+                    self._download_file(data["path"], data["link"])
+            else:
+                print(
+                    f"skipping data/upload file download given skip_upload={self.skip_upload}"
+                )
+
+        else:  # Handle single-file datasets
+            target_path = DATASETS_DIR / self.config.path
+
+            if target_path.exists():
+                print(f"{target_path} already exists")
+                return
+
+            if self.config.link:
+                downloaded_withboto = False
+                if is_s3_link(self.config.link):
+                    print("Using boto3 to download from S3 (faster).")
+                    try:
+                        self._download_from_s3(self.config.link, target_path)
+                        downloaded_withboto = True
+                    except botocore.exceptions.NoCredentialsError:
+                        print("Credentials not found, downloading without boto3")
+                if not downloaded_withboto:
+                    print(f"Downloading from URL {self.config.link}...")
+                    tmp_path, _ = urllib.request.urlretrieve(
+                        self.config.link, None, show_progress
+                    )
+                    self._extract_or_move_file(tmp_path, target_path)

+    def _download_file(self, relative_path: str, url: str):
+        target_path = DATASETS_DIR / relative_path
         if target_path.exists():
             print(f"{target_path} already exists")
             return

-        if self.config.link:
-            print(f"Downloading {self.config.link}...")
-            tmp_path, _ = urllib.request.urlretrieve(self.config.link)
+        print(f"Downloading from {url} to {target_path}")
+        tmp_path, _ = urllib.request.urlretrieve(url, None, show_progress)
+        self._extract_or_move_file(tmp_path, target_path)

-            if self.config.link.endswith(".tgz") or self.config.link.endswith(
-                ".tar.gz"
-            ):
-                print(f"Extracting: {tmp_path} -> {target_path}")
-                (DATASETS_DIR / self.config.path).mkdir(exist_ok=True, parents=True)
-                file = tarfile.open(tmp_path)
+    def _extract_or_move_file(self, tmp_path, target_path):
+        if tmp_path.endswith(".tgz") or tmp_path.endswith(".tar.gz"):
+            print(f"Extracting: {tmp_path} -> {target_path}")
+            (DATASETS_DIR / self.config.path).mkdir(exist_ok=True, parents=True)
+            with tarfile.open(tmp_path) as file:
                 file.extractall(target_path)
-                file.close()
-                os.remove(tmp_path)
-            else:
-                print(f"Moving: {tmp_path} -> {target_path}")
-                (DATASETS_DIR / self.config.path).parent.mkdir(exist_ok=True)
-                shutil.copy2(tmp_path, target_path)
-                os.remove(tmp_path)
+            os.remove(tmp_path)
+        else:
+            print(f"Moving: {tmp_path} -> {target_path}")
+            Path(target_path).parent.mkdir(exist_ok=True)
+            shutil.copy2(tmp_path, target_path)
+            os.remove(tmp_path)
+
+    def _download_from_s3(self, link, target_path):
+        s3 = boto3.client("s3")
+        bucket_name, s3_key = parse_s3_url(link)
+        tmp_path = f"/tmp/{os.path.basename(s3_key)}"
+
+        print(
+            f"Downloading from S3: {link}... bucket_name={bucket_name}, s3_key={s3_key}"
+        )
+        object_info = s3.head_object(Bucket=bucket_name, Key=s3_key)
+        total_size = object_info["ContentLength"]
+
+        with open(tmp_path, "wb") as f:
+            progress = S3Progress(total_size)
+            s3.download_fileobj(bucket_name, s3_key, f, Callback=progress)
+
+        self._extract_or_move_file(tmp_path, target_path)

     def get_reader(self, normalize: bool) -> BaseReader:
         reader_class = READER_TYPE[self.config.type]
-        return reader_class(DATASETS_DIR / self.config.path, normalize=normalize)
+
+        if self.config.type == "h5-multi":
+            # For h5-multi, we need to pass both the data files and the query file
+            data_files = self.config.path["data"]
+            for data_file_dict in data_files:
+                data_file_dict["path"] = DATASETS_DIR / data_file_dict["path"]
+            query_file = DATASETS_DIR / self.config.path["queries"][0]["path"]
+            return reader_class(
+                data_files=data_files,
+                query_file=query_file,
+                normalize=normalize,
+                skip_upload=self.skip_upload,
+                skip_search=self.skip_search,
+            )
+        else:
+            # For single-file datasets
+            return reader_class(DATASETS_DIR / self.config.path, normalize=normalize)
+
+
+def is_s3_link(link):
+    return link.startswith("s3://") or "s3.amazonaws.com" in link
+
+
+def parse_s3_url(s3_url):
+    if s3_url.startswith("s3://"):
+        s3_parts = s3_url.replace("s3://", "").split("/", 1)
+        bucket_name = s3_parts[0]
+        s3_key = s3_parts[1] if len(s3_parts) > 1 else ""
+    else:
+        s3_parts = s3_url.replace("http://", "").replace("https://", "").split("/", 1)
+
+        if ".s3.amazonaws.com" in s3_parts[0]:
+            bucket_name = s3_parts[0].split(".s3.amazonaws.com")[0]
+            s3_key = s3_parts[1] if len(s3_parts) > 1 else ""
+        else:
+            bucket_name = s3_parts[0]
+            s3_key = s3_parts[1] if len(s3_parts) > 1 else ""
+
+    return bucket_name, s3_key


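For reference (not part of the patch), this is how the two S3 helpers above behave on the link formats used in this file; the "my-bucket" values are illustrative only, the redislabs URL is taken from the example below:

    # is_s3_link() accepts both s3:// URIs and virtual-hosted-style HTTP links
    assert is_s3_link("s3://my-bucket/some/key.hdf5")
    assert is_s3_link(
        "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5"
    )

    # parse_s3_url() splits either form into (bucket_name, s3_key)
    assert parse_s3_url("s3://my-bucket/some/key.hdf5") == ("my-bucket", "some/key.hdf5")
    assert parse_s3_url(
        "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5"
    ) == ("benchmarks.redislabs", "vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5")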
 if __name__ == "__main__":
-    dataset = Dataset(
+    dataset_s3_split = Dataset(
         {
-            "name": "glove-25-angular",
-            "vector_size": 25,
-            "distance": "Cosine",
-            "type": "h5",
-            "path": "glove-25-angular/glove-25-angular.hdf5",
-            "link": "http://ann-benchmarks.com/glove-25-angular.hdf5",
-        }
+            "name": "laion-img-emb-768d-1Billion-cosine",
+            "vector_size": 768,
+            "distance": "cosine",
+            "type": "h5-multi",
+            "path": {
+                "data": [
+                    {
+                        "file_number": 1,
+                        "path": "laion-1b/data/laion-img-emb-768d-1Billion-cosine-data-part1-0_to_10000000.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-data-part1-0_to_10000000.hdf5",
+                        "vector_range": "0-10000000",
+                        "file_size": "30.7 GB",
+                    },
+                    {
+                        "file_number": 2,
+                        "path": "laion-1b/data/laion-img-emb-768d-1Billion-cosine-data-part10-90000000_to_100000000.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-data-part10-90000000_to_100000000.hdf5",
+                        "vector_range": "90000000-100000000",
+                        "file_size": "30.7 GB",
+                    },
+                    {
+                        "file_number": 3,
+                        "path": "laion-1b/data/laion-img-emb-768d-1Billion-cosine-data-part100-990000000_to_1000000000.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-data-part100-990000000_to_1000000000.hdf5",
+                        "vector_range": "990000000-1000000000",
+                        "file_size": "30.7 GB",
+                    },
+                ],
+                "queries": [
+                    {
+                        "path": "laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5",
+                        "link": "http://benchmarks.redislabs.s3.amazonaws.com/vecsim/laion-1b/laion-img-emb-768d-1Billion-cosine-queries.hdf5",
+                        "file_size": "38.7 MB",
+                    },
+                ],
+            },
+        },
+        skip_upload=True,
+        skip_search=False,
+        upload_start_idx=0,  # required by __init__; unused here since skip_upload=True
+        upload_end_idx=0,  # required by __init__; unused here since skip_upload=True
     )

-    dataset.download()
+    dataset_s3_split.download()
+    reader = dataset_s3_split.get_reader(normalize=False)
+    print(reader)  # Outputs the AnnH5MultiReader instance
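As a sketch only (not part of the patch): with skip_upload=False, the new upload_start_idx / upload_end_idx arguments let a worker fetch only the data parts whose range overlaps the slice it will upload. Note that download() reads "start_idx" and "end_idx" from each "data" entry, so those keys would have to be present alongside the fields shown above; config_dict below is a hypothetical dict with that structure.

    # Hypothetical worker that uploads only vectors [0, 10_000_000); data parts whose
    # [start_idx, end_idx) range does not overlap this window are skipped by download().
    worker_dataset = Dataset(
        config_dict,        # same structure as above, plus "start_idx"/"end_idx" per data entry
        skip_upload=False,
        skip_search=True,   # this worker does not need the query file
        upload_start_idx=0,
        upload_end_idx=10_000_000,
    )
    worker_dataset.download()  # fetches only the overlapping data part(s)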