1
1
"""Item crud client."""
2
2
3
+ import json
3
4
import re
4
5
from typing import Any , Dict , List , Optional , Set , Union
5
6
from urllib .parse import unquote_plus , urljoin
14
15
from pygeofilter .parsers .cql2_text import parse as parse_cql2_text
15
16
from pypgstac .hydration import hydrate
16
17
from stac_fastapi .api .models import JSONResponse
17
- from stac_fastapi .types .core import AsyncBaseCoreClient
18
+ from stac_fastapi .types .core import AsyncBaseCoreClient , Relations
18
19
from stac_fastapi .types .errors import InvalidQueryParameter , NotFoundError
19
20
from stac_fastapi .types .requests import get_base_url
20
21
from stac_fastapi .types .rfc3339 import DateTimeType
21
22
from stac_fastapi .types .stac import Collection , Collections , Item , ItemCollection
22
- from stac_pydantic .links import Relations
23
23
from stac_pydantic .shared import BBox , MimeTypes
24
24
25
25
from stac_fastapi .pgstac .config import Settings
39
39
class CoreCrudClient (AsyncBaseCoreClient ):
40
40
"""Client for core endpoints defined by stac."""
41
41
42
- async def all_collections (self , request : Request , ** kwargs ) -> Collections :
43
- """Read all collections from the database."""
42
+ async def all_collections ( # noqa: C901
43
+ self ,
44
+ request : Request ,
45
+ # Extensions
46
+ bbox : Optional [BBox ] = None ,
47
+ datetime : Optional [DateTimeType ] = None ,
48
+ limit : Optional [int ] = None ,
49
+ query : Optional [str ] = None ,
50
+ token : Optional [str ] = None ,
51
+ fields : Optional [List [str ]] = None ,
52
+ sortby : Optional [str ] = None ,
53
+ filter : Optional [str ] = None ,
54
+ filter_lang : Optional [str ] = None ,
55
+ ** kwargs ,
56
+ ) -> Collections :
57
+ """Cross catalog search (GET).
58
+
59
+ Called with `GET /collections`.
60
+
61
+ Returns:
62
+ Collections which match the search criteria, returns all
63
+ collections by default.
64
+ """
44
65
base_url = get_base_url (request )
45
66
67
+ # Parse request parameters
68
+ base_args = {
69
+ "bbox" : bbox ,
70
+ "limit" : limit ,
71
+ "token" : token ,
72
+ "query" : orjson .loads (unquote_plus (query )) if query else query ,
73
+ }
74
+
75
+ clean_args = clean_search_args (
76
+ base_args = base_args ,
77
+ datetime = datetime ,
78
+ fields = fields ,
79
+ sortby = sortby ,
80
+ filter_query = filter ,
81
+ filter_lang = filter_lang ,
82
+ )
83
+
46
84
async with request .app .state .get_connection (request , "r" ) as conn :
47
- collections = await conn .fetchval (
48
- """
49
- SELECT * FROM all_collections();
85
+ q , p = render (
50
86
"""
87
+ SELECT * FROM collection_search(:req::text::jsonb);
88
+ """ ,
89
+ req = json .dumps (clean_args ),
51
90
)
91
+ collections_result : Collections = await conn .fetchval (q , * p )
92
+
93
+ next : Optional [str ] = None
94
+ prev : Optional [str ] = None
95
+
96
+ if links := collections_result .get ("links" ):
97
+ next = collections_result ["links" ].pop ("next" )
98
+ prev = collections_result ["links" ].pop ("prev" )
99
+
52
100
linked_collections : List [Collection ] = []
101
+ collections = collections_result ["collections" ]
53
102
if collections is not None and len (collections ) > 0 :
54
103
for c in collections :
55
104
coll = Collection (** c )
@@ -71,25 +120,16 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
71
120
72
121
linked_collections .append (coll )
73
122
74
- links = [
75
- {
76
- "rel" : Relations .root .value ,
77
- "type" : MimeTypes .json ,
78
- "href" : base_url ,
79
- },
80
- {
81
- "rel" : Relations .parent .value ,
82
- "type" : MimeTypes .json ,
83
- "href" : base_url ,
84
- },
85
- {
86
- "rel" : Relations .self .value ,
87
- "type" : MimeTypes .json ,
88
- "href" : urljoin (base_url , "collections" ),
89
- },
90
- ]
91
- collection_list = Collections (collections = linked_collections or [], links = links )
92
- return collection_list
123
+ links = await PagingLinks (
124
+ request = request ,
125
+ next = next ,
126
+ prev = prev ,
127
+ ).get_links ()
128
+
129
+ return Collections (
130
+ collections = linked_collections or [],
131
+ links = links ,
132
+ )
93
133
94
134
async def get_collection (
95
135
self , collection_id : str , request : Request , ** kwargs
@@ -386,7 +426,7 @@ async def post_search(
386
426
387
427
return ItemCollection (** item_collection )
388
428
389
- async def get_search ( # noqa: C901
429
+ async def get_search (
390
430
self ,
391
431
request : Request ,
392
432
collections : Optional [List [str ]] = None ,
@@ -421,51 +461,15 @@ async def get_search( # noqa: C901
421
461
"query" : orjson .loads (unquote_plus (query )) if query else query ,
422
462
}
423
463
424
- if filter :
425
- if filter_lang == "cql2-text" :
426
- filter = to_cql2 (parse_cql2_text (filter ))
427
- filter_lang = "cql2-json"
428
-
429
- base_args ["filter" ] = orjson .loads (filter )
430
- base_args ["filter-lang" ] = filter_lang
431
-
432
- if datetime :
433
- base_args ["datetime" ] = format_datetime_range (datetime )
434
-
435
- if intersects :
436
- base_args ["intersects" ] = orjson .loads (unquote_plus (intersects ))
437
-
438
- if sortby :
439
- # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
440
- sort_param = []
441
- for sort in sortby :
442
- sortparts = re .match (r"^([+-]?)(.*)$" , sort )
443
- if sortparts :
444
- sort_param .append (
445
- {
446
- "field" : sortparts .group (2 ).strip (),
447
- "direction" : "desc" if sortparts .group (1 ) == "-" else "asc" ,
448
- }
449
- )
450
- base_args ["sortby" ] = sort_param
451
-
452
- if fields :
453
- includes = set ()
454
- excludes = set ()
455
- for field in fields :
456
- if field [0 ] == "-" :
457
- excludes .add (field [1 :])
458
- elif field [0 ] == "+" :
459
- includes .add (field [1 :])
460
- else :
461
- includes .add (field )
462
- base_args ["fields" ] = {"include" : includes , "exclude" : excludes }
463
-
464
- # Remove None values from dict
465
- clean = {}
466
- for k , v in base_args .items ():
467
- if v is not None and v != []:
468
- clean [k ] = v
464
+ clean = clean_search_args (
465
+ base_args = base_args ,
466
+ intersects = intersects ,
467
+ datetime = datetime ,
468
+ fields = fields ,
469
+ sortby = sortby ,
470
+ filter_query = filter ,
471
+ filter_lang = filter_lang ,
472
+ )
469
473
470
474
# Do the request
471
475
try :
@@ -476,3 +480,62 @@ async def get_search( # noqa: C901
476
480
) from e
477
481
478
482
return await self .post_search (search_request , request = request )
483
+
484
+
485
+ def clean_search_args ( # noqa: C901
486
+ base_args : Dict [str , Any ],
487
+ intersects : Optional [str ] = None ,
488
+ datetime : Optional [DateTimeType ] = None ,
489
+ fields : Optional [List [str ]] = None ,
490
+ sortby : Optional [str ] = None ,
491
+ filter_query : Optional [str ] = None ,
492
+ filter_lang : Optional [str ] = None ,
493
+ ) -> Dict [str , Any ]:
494
+ """Clean up search arguments to match format expected by pgstac"""
495
+ if filter_query :
496
+ if filter_lang == "cql2-text" :
497
+ filter_query = to_cql2 (parse_cql2_text (filter_query ))
498
+ filter_lang = "cql2-json"
499
+
500
+ base_args ["filter" ] = orjson .loads (filter_query )
501
+ base_args ["filter_lang" ] = filter_lang
502
+
503
+ if datetime :
504
+ base_args ["datetime" ] = format_datetime_range (datetime )
505
+
506
+ if intersects :
507
+ base_args ["intersects" ] = orjson .loads (unquote_plus (intersects ))
508
+
509
+ if sortby :
510
+ # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
511
+ sort_param = []
512
+ for sort in sortby :
513
+ sortparts = re .match (r"^([+-]?)(.*)$" , sort )
514
+ if sortparts :
515
+ sort_param .append (
516
+ {
517
+ "field" : sortparts .group (2 ).strip (),
518
+ "direction" : "desc" if sortparts .group (1 ) == "-" else "asc" ,
519
+ }
520
+ )
521
+ base_args ["sortby" ] = sort_param
522
+
523
+ if fields :
524
+ includes = set ()
525
+ excludes = set ()
526
+ for field in fields :
527
+ if field [0 ] == "-" :
528
+ excludes .add (field [1 :])
529
+ elif field [0 ] == "+" :
530
+ includes .add (field [1 :])
531
+ else :
532
+ includes .add (field )
533
+ base_args ["fields" ] = {"include" : includes , "exclude" : excludes }
534
+
535
+ # Remove None values from dict
536
+ clean = {}
537
+ for k , v in base_args .items ():
538
+ if v is not None and v != []:
539
+ clean [k ] = v
540
+
541
+ return clean
0 commit comments