@@ -10,6 +10,7 @@
     AsyncIterable,
     Callable,
     Coroutine,
+    DefaultDict,
     Deque,
     Dict,
     Iterable,
@@ -43,12 +44,14 @@ def __init__(
         rules: List[Rule] = None,
         groups: Dict[Callable, Selector] = None,
         save_rules: Dict[str, Any] = None,
+        events: Optional[DefaultDict] = None,
         has_async: bool = False,
         scraper: Optional["ScraperAbstract"] = None,
     ) -> None:
         self.rules: List[Rule] = rules or []
         self.groups: Dict[Callable, Selector] = groups or {}
         self.save_rules: Dict[str, Any] = save_rules or {"json": save_json}
+        self.events: DefaultDict = events or collections.defaultdict(list)
         self.has_async = has_async
         self.scraper = scraper
         self.adblock = Adblocker()
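For context, the `collections.defaultdict(list)` default means handlers can be appended under any event name without first initializing the key. A minimal sketch of the registry this creates (the handler bodies are illustrative, not from this commit):

    import collections

    events = collections.defaultdict(list)  # event name -> list of callbacks
    events["startup"].append(lambda: print("db ready"))
    events["startup"].append(lambda: print("cache warmed"))
    # Looking up an unregistered event name yields an empty list, so a
    # dispatch loop like `for func in events["pre-setup"]` is a safe no-op.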
@@ -129,10 +132,8 @@ def wrapper(func: Callable) -> Union[Callable, Coroutine]:
                 navigate=navigate,
                 priority=priority,
             )
-            if not self.scraper:
-                self.rules.append(rule)
-            else:
-                self.scraper.rules.append(rule)
+            rules = self.scraper.rules if self.scraper else self.rules
+            rules.append(rule)
             return func

         return wrapper
@@ -198,10 +199,61 @@ def wrapper(func: Callable) -> Callable:
             if asyncio.iscoroutinefunction(func):
                 self.has_async = True

-            if not self.scraper:
-                self.save_rules[format] = func
-            else:
-                self.scraper.save_rules[format] = func
+            save_rules = self.scraper.save_rules if self.scraper else self.save_rules
+            save_rules[format] = func
+            return func
+
+        return wrapper
+
+    def startup(self) -> Callable:
+        """
+        Decorator to register a function to startup events.
+
+        Startup events are executed before any actual scraping happens, for example, to set up databases.
+        """
+
+        def wrapper(func: Callable) -> Callable:
+            if asyncio.iscoroutinefunction(func):
+                self.has_async = True
+
+            events = self.scraper.events if self.scraper else self.events
+            events["startup"].append(func)
+            return func
+
+        return wrapper
+
+    def pre_setup(self) -> Callable:
+        """
+        Decorator to register a function to pre-setup events.
+
+        Pre-setup events are executed after a page is loaded (or an HTTP response is received, in the case of HTTPX)
+        and before running the setup functions.
+        """
+
+        def wrapper(func: Callable) -> Callable:
+            if asyncio.iscoroutinefunction(func):
+                self.has_async = True
+
+            events = self.scraper.events if self.scraper else self.events
+            events["pre-setup"].append(func)
+            return func
+
+        return wrapper
+
+    def post_setup(self) -> Callable:
+        """
+        Decorator to register a function to post-setup events.
+
+        Post-setup events are executed after running the setup functions and before the actual web scraping happens.
+        This is useful when "page clean-ups" are done in the setup functions.
+        """
+
+        def wrapper(func: Callable) -> Callable:
+            if asyncio.iscoroutinefunction(func):
+                self.has_async = True
+
+            events = self.scraper.events if self.scraper else self.events
+            events["post-setup"].append(func)
             return func

         return wrapper
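A hedged usage sketch of the three new decorators, assuming `app` is an object exposing the `ScraperBase` API above; the name `app`, the handler bodies, and the `page` parameter (suggested by the pre-setup docstring but not shown in this diff) are illustrative:

    @app.startup()
    def create_table():
        # Runs once before any scraping starts, e.g. to prepare a database.
        ...

    @app.pre_setup()
    def before_setup(page):
        # Runs after each page load, before the registered setup functions.
        ...

    @app.post_setup()
    def after_setup(page):
        # Runs after the setup functions, before data extraction begins.
        ...

Each decorator appends the wrapped function to `events["startup"]`, `events["pre-setup"]`, or `events["post-setup"]` respectively, delegating to the inner scraper's registry when one is attached.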
@@ -220,16 +272,47 @@ def iter_urls(self) -> Iterable[str]:
             except IndexError:
                 pass

+    def _update_rule_groups(self) -> Iterable[Rule]:
+        for rule in self.rules:
+            if rule.group:
+                yield rule
+            elif rule.handler in self.groups:
+                yield rule._replace(group=self.groups[rule.handler])
+            else:
+                yield rule._replace(group=Selector(selector=":root"))
+
+    def initialize_scraper(self, urls: Sequence[str]) -> None:
+        self.rules = [rule for rule in self._update_rule_groups()]
+        self.urls = collections.deque(urls)
+        self.allowed_domains = {urlparse(url).netloc for url in urls}
+        self.event_startup()
+
+    def event_startup(self) -> None:
+        """
+        Run all startup events.
+        """
+        loop = None
+        if self.has_async:
+            loop = asyncio.get_event_loop()
+
+        for func in self.events["startup"]:
+            if asyncio.iscoroutinefunction(func):
+                assert loop is not None
+                loop.run_until_complete(func())
+            else:
+                func()
+

 class ScraperAbstract(ScraperBase):
     def __init__(
         self,
         rules: List[Rule] = None,
         groups: Dict[Callable, Selector] = None,
         save_rules: Dict[str, Any] = None,
+        events: Optional[DefaultDict] = None,
         has_async: bool = False,
     ) -> None:
-        super(ScraperAbstract, self).__init__(rules, groups, save_rules, has_async)
+        super(ScraperAbstract, self).__init__(rules, groups, save_rules, events, has_async)
         self.collected_data: List[ScrapedData] = []

     @abstractmethod
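As a quick illustration of the `allowed_domains` set comprehension in `initialize_scraper` (the URLs here are made up):

    from urllib.parse import urlparse

    urls = ["https://example.com/page/1", "https://example.com/page/2"]
    print({urlparse(url).netloc for url in urls})  # {'example.com'}

Deduplicating through a set means one entry per host, regardless of how many start URLs point at it.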
@@ -248,18 +331,6 @@ def navigate(self) -> bool:
     async def navigate_async(self) -> bool:
         raise NotImplementedError  # pragma: no cover

-    def update_rule_groups(self) -> None:
-        self.rules = [rule for rule in self._update_rule_groups()]
-
-    def _update_rule_groups(self) -> Iterable[Rule]:
-        for rule in self.rules:
-            if rule.group:
-                yield rule
-            elif rule.handler in self.groups:
-                yield rule._replace(group=self.groups[rule.handler])
-            else:
-                yield rule._replace(group=Selector(selector=":root"))
-
     @abstractmethod
     def collect_elements(self) -> Iterable[Tuple[str, int, int, int, Any, Callable]]:
         """
@@ -276,6 +347,34 @@ async def collect_elements_async(self) -> AsyncIterable[Tuple[str, int, int, int, Any, Callable]]:
         yield "", 0, 0, 0, 0, str  # HACK: mypy does not identify this as AsyncIterable  # pragma: no cover
         raise NotImplementedError  # pragma: no cover

+    def event_pre_setup(self, *args: Any) -> None:
+        """
+        Run all pre-setup events.
+        """
+        for func in self.events["pre-setup"]:
+            func(*args)
+
+    def event_post_setup(self, *args: Any) -> None:
+        """
+        Run all post-setup events.
+        """
+        for func in self.events["post-setup"]:
+            func(*args)
+
+    async def event_pre_setup_async(self, *args: Any) -> None:
+        """
+        Run all pre-setup events.
+        """
+        for func in self.events["pre-setup"]:
+            await func(*args)
+
+    async def event_post_setup_async(self, *args: Any) -> None:
+        """
+        Run all post-setup events.
+        """
+        for func in self.events["post-setup"]:
+            await func(*args)
+
     def extract_all(self, page_number: int, **kwargs: Any) -> Iterable[ScrapedData]:
         """
         Extracts all the data using the registered handler functions.
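Tying the async dispatchers back to registration, a sketch of how an async handler would flow through this machinery, assuming the same illustrative `app` object as above:

    @app.pre_setup()
    async def dismiss_banner(page):
        # `page` is an assumption; the dispatchers only forward *args.
        ...

    # Registering a coroutine function flips has_async to True, so the
    # scraper is expected to call event_pre_setup_async(), which awaits
    # each registered callback in order, rather than event_pre_setup().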