Skip to content

Commit 9a3797a

Browse files
add collection search extension (#136)
* add collection-search extension * define collections_get_request_model * add test for additional extensions with collection-search * pass collections_get_request_model to StacApi * do not pass collection extensions to post_request_model * do not pass collection extensions to get_request_model * Do not add extensions to collection-search extension * use CollectionSearchExtension.from_extensions() * keep extensions and collection_search_extension separate * update tests * filter -> filter_query * add collection_get_request_model to client * add collection_request_model to the client * recycle collections_get_request_model in client * drop print statement * simplify * remove unused * clean up control flow for extension-specific logic * add link to PR in changelog --------- Co-authored-by: vincentsarago <vincent.sarago@gmail.com>
1 parent c89360c commit 9a3797a

File tree

9 files changed

+269
-91
lines changed

9 files changed

+269
-91
lines changed

.dockerignore

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,7 @@ coverage.xml
99
*.log
1010
.git
1111
.envrc
12+
*egg-info
1213

13-
venv
14+
venv
15+
env

.github/workflows/cicd.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ jobs:
4747
runs-on: ubuntu-latest
4848
services:
4949
pgstac:
50-
image: ghcr.io/stac-utils/pgstac:v0.7.10
50+
image: ghcr.io/stac-utils/pgstac:v0.8.6
5151
env:
5252
POSTGRES_USER: username
5353
POSTGRES_PASSWORD: password

CHANGES.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
## [Unreleased]
44

5-
- Fix Docker compose file, so example data can be loaded into database (author @zstatmanweil, https://github.com/stac-utils/stac-fastapi-pgstac/pull/142)
5+
- Fix Docker compose file, so example data can be loaded into database (author @zstatmanweil, <https://github.com/stac-utils/stac-fastapi-pgstac/pull/142>)
6+
- Add collection search extension ([#139](https://github.com/stac-utils/stac-fastapi-pgstac/pull/139))
67

78
- Fix `filter` extension implementation in `CoreCrudClient`
89

setup.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
"orjson",
1111
"pydantic",
1212
"stac_pydantic==3.1.*",
13-
"stac-fastapi.api~=3.0",
14-
"stac-fastapi.extensions~=3.0",
15-
"stac-fastapi.types~=3.0",
13+
"stac-fastapi.api~=3.0.2",
14+
"stac-fastapi.extensions~=3.0.2",
15+
"stac-fastapi.types~=3.0.2",
1616
"asyncpg",
1717
"buildpg",
1818
"brotli_asgi",

stac_fastapi/pgstac/app.py

+28-11
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from fastapi.responses import ORJSONResponse
1111
from stac_fastapi.api.app import StacApi
1212
from stac_fastapi.api.models import (
13+
EmptyRequest,
1314
ItemCollectionUri,
1415
create_get_request_model,
1516
create_post_request_model,
@@ -22,6 +23,7 @@
2223
TokenPaginationExtension,
2324
TransactionExtension,
2425
)
26+
from stac_fastapi.extensions.core.collection_search import CollectionSearchExtension
2527
from stac_fastapi.extensions.third_party import BulkTransactionExtension
2628

2729
from stac_fastapi.pgstac.config import Settings
@@ -47,34 +49,49 @@
4749
"bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()),
4850
}
4951

50-
if enabled_extensions := os.getenv("ENABLED_EXTENSIONS"):
51-
extensions = [
52-
extensions_map[extension_name] for extension_name in enabled_extensions.split(",")
53-
]
54-
else:
55-
extensions = list(extensions_map.values())
52+
enabled_extensions = (
53+
os.environ["ENABLED_EXTENSIONS"].split(",")
54+
if "ENABLED_EXTENSIONS" in os.environ
55+
else list(extensions_map.keys()) + ["collection_search"]
56+
)
57+
extensions = [
58+
extension for key, extension in extensions_map.items() if key in enabled_extensions
59+
]
5660

57-
if any(isinstance(ext, TokenPaginationExtension) for ext in extensions):
58-
items_get_request_model = create_request_model(
61+
items_get_request_model = (
62+
create_request_model(
5963
model_name="ItemCollectionUri",
6064
base_model=ItemCollectionUri,
6165
mixins=[TokenPaginationExtension().GET],
6266
request_type="GET",
6367
)
64-
else:
65-
items_get_request_model = ItemCollectionUri
68+
if any(isinstance(ext, TokenPaginationExtension) for ext in extensions)
69+
else ItemCollectionUri
70+
)
71+
72+
collection_search_extension = (
73+
CollectionSearchExtension.from_extensions(extensions)
74+
if "collection_search" in enabled_extensions
75+
else None
76+
)
77+
collections_get_request_model = (
78+
collection_search_extension.GET if collection_search_extension else EmptyRequest
79+
)
6680

6781
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
6882
get_request_model = create_get_request_model(extensions)
6983

7084
api = StacApi(
7185
settings=settings,
72-
extensions=extensions,
86+
extensions=extensions + [collection_search_extension]
87+
if collection_search_extension
88+
else extensions,
7389
client=CoreCrudClient(post_request_model=post_request_model), # type: ignore
7490
response_class=ORJSONResponse,
7591
items_get_request_model=items_get_request_model,
7692
search_get_request_model=get_request_model,
7793
search_post_request_model=post_request_model,
94+
collections_get_request_model=collections_get_request_model,
7895
)
7996
app = api.app
8097

stac_fastapi/pgstac/core.py

+135-72
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Item crud client."""
22

3+
import json
34
import re
45
from typing import Any, Dict, List, Optional, Set, Union
56
from urllib.parse import unquote_plus, urljoin
@@ -14,12 +15,11 @@
1415
from pygeofilter.parsers.cql2_text import parse as parse_cql2_text
1516
from pypgstac.hydration import hydrate
1617
from stac_fastapi.api.models import JSONResponse
17-
from stac_fastapi.types.core import AsyncBaseCoreClient
18+
from stac_fastapi.types.core import AsyncBaseCoreClient, Relations
1819
from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError
1920
from stac_fastapi.types.requests import get_base_url
2021
from stac_fastapi.types.rfc3339 import DateTimeType
2122
from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection
22-
from stac_pydantic.links import Relations
2323
from stac_pydantic.shared import BBox, MimeTypes
2424

2525
from stac_fastapi.pgstac.config import Settings
@@ -39,17 +39,66 @@
3939
class CoreCrudClient(AsyncBaseCoreClient):
4040
"""Client for core endpoints defined by stac."""
4141

42-
async def all_collections(self, request: Request, **kwargs) -> Collections:
43-
"""Read all collections from the database."""
42+
async def all_collections( # noqa: C901
43+
self,
44+
request: Request,
45+
# Extensions
46+
bbox: Optional[BBox] = None,
47+
datetime: Optional[DateTimeType] = None,
48+
limit: Optional[int] = None,
49+
query: Optional[str] = None,
50+
token: Optional[str] = None,
51+
fields: Optional[List[str]] = None,
52+
sortby: Optional[str] = None,
53+
filter: Optional[str] = None,
54+
filter_lang: Optional[str] = None,
55+
**kwargs,
56+
) -> Collections:
57+
"""Cross catalog search (GET).
58+
59+
Called with `GET /collections`.
60+
61+
Returns:
62+
Collections which match the search criteria, returns all
63+
collections by default.
64+
"""
4465
base_url = get_base_url(request)
4566

67+
# Parse request parameters
68+
base_args = {
69+
"bbox": bbox,
70+
"limit": limit,
71+
"token": token,
72+
"query": orjson.loads(unquote_plus(query)) if query else query,
73+
}
74+
75+
clean_args = clean_search_args(
76+
base_args=base_args,
77+
datetime=datetime,
78+
fields=fields,
79+
sortby=sortby,
80+
filter_query=filter,
81+
filter_lang=filter_lang,
82+
)
83+
4684
async with request.app.state.get_connection(request, "r") as conn:
47-
collections = await conn.fetchval(
48-
"""
49-
SELECT * FROM all_collections();
85+
q, p = render(
5086
"""
87+
SELECT * FROM collection_search(:req::text::jsonb);
88+
""",
89+
req=json.dumps(clean_args),
5190
)
91+
collections_result: Collections = await conn.fetchval(q, *p)
92+
93+
next: Optional[str] = None
94+
prev: Optional[str] = None
95+
96+
if links := collections_result.get("links"):
97+
next = collections_result["links"].pop("next")
98+
prev = collections_result["links"].pop("prev")
99+
52100
linked_collections: List[Collection] = []
101+
collections = collections_result["collections"]
53102
if collections is not None and len(collections) > 0:
54103
for c in collections:
55104
coll = Collection(**c)
@@ -71,25 +120,16 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
71120

72121
linked_collections.append(coll)
73122

74-
links = [
75-
{
76-
"rel": Relations.root.value,
77-
"type": MimeTypes.json,
78-
"href": base_url,
79-
},
80-
{
81-
"rel": Relations.parent.value,
82-
"type": MimeTypes.json,
83-
"href": base_url,
84-
},
85-
{
86-
"rel": Relations.self.value,
87-
"type": MimeTypes.json,
88-
"href": urljoin(base_url, "collections"),
89-
},
90-
]
91-
collection_list = Collections(collections=linked_collections or [], links=links)
92-
return collection_list
123+
links = await PagingLinks(
124+
request=request,
125+
next=next,
126+
prev=prev,
127+
).get_links()
128+
129+
return Collections(
130+
collections=linked_collections or [],
131+
links=links,
132+
)
93133

94134
async def get_collection(
95135
self, collection_id: str, request: Request, **kwargs
@@ -386,7 +426,7 @@ async def post_search(
386426

387427
return ItemCollection(**item_collection)
388428

389-
async def get_search( # noqa: C901
429+
async def get_search(
390430
self,
391431
request: Request,
392432
collections: Optional[List[str]] = None,
@@ -421,51 +461,15 @@ async def get_search( # noqa: C901
421461
"query": orjson.loads(unquote_plus(query)) if query else query,
422462
}
423463

424-
if filter:
425-
if filter_lang == "cql2-text":
426-
filter = to_cql2(parse_cql2_text(filter))
427-
filter_lang = "cql2-json"
428-
429-
base_args["filter"] = orjson.loads(filter)
430-
base_args["filter-lang"] = filter_lang
431-
432-
if datetime:
433-
base_args["datetime"] = format_datetime_range(datetime)
434-
435-
if intersects:
436-
base_args["intersects"] = orjson.loads(unquote_plus(intersects))
437-
438-
if sortby:
439-
# https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
440-
sort_param = []
441-
for sort in sortby:
442-
sortparts = re.match(r"^([+-]?)(.*)$", sort)
443-
if sortparts:
444-
sort_param.append(
445-
{
446-
"field": sortparts.group(2).strip(),
447-
"direction": "desc" if sortparts.group(1) == "-" else "asc",
448-
}
449-
)
450-
base_args["sortby"] = sort_param
451-
452-
if fields:
453-
includes = set()
454-
excludes = set()
455-
for field in fields:
456-
if field[0] == "-":
457-
excludes.add(field[1:])
458-
elif field[0] == "+":
459-
includes.add(field[1:])
460-
else:
461-
includes.add(field)
462-
base_args["fields"] = {"include": includes, "exclude": excludes}
463-
464-
# Remove None values from dict
465-
clean = {}
466-
for k, v in base_args.items():
467-
if v is not None and v != []:
468-
clean[k] = v
464+
clean = clean_search_args(
465+
base_args=base_args,
466+
intersects=intersects,
467+
datetime=datetime,
468+
fields=fields,
469+
sortby=sortby,
470+
filter_query=filter,
471+
filter_lang=filter_lang,
472+
)
469473

470474
# Do the request
471475
try:
@@ -476,3 +480,62 @@ async def get_search( # noqa: C901
476480
) from e
477481

478482
return await self.post_search(search_request, request=request)
483+
484+
485+
def clean_search_args(  # noqa: C901
    base_args: Dict[str, Any],
    intersects: Optional[str] = None,
    datetime: Optional["DateTimeType"] = None,
    fields: Optional[List[str]] = None,
    sortby: Optional[Union[str, List[str]]] = None,
    filter_query: Optional[str] = None,
    filter_lang: Optional[str] = None,
) -> Dict[str, Any]:
    """Clean up search arguments to match the format expected by pgstac.

    Args:
        base_args: Arguments that need no conversion (e.g. ``bbox``, ``limit``,
            ``token``, ``query``). The mapping is copied, never modified.
        intersects: URL-encoded JSON geometry; decoded and parsed into
            ``"intersects"``.
        datetime: Datetime or datetime range; passed through
            ``format_datetime_range`` and stored under ``"datetime"``.
        fields: Field names, each optionally prefixed with ``+`` (include) or
            ``-`` (exclude). Unprefixed names are included. Empty entries are
            ignored.
        sortby: Sort keys, each optionally prefixed with ``+`` (ascending) or
            ``-`` (descending). Either a list of keys or a single
            comma-separated string (the HTTP GET form of the sort extension).
        filter_query: Filter expression. When ``filter_lang`` is
            ``"cql2-text"`` it is converted to CQL2 JSON first; the result is
            parsed as JSON and stored under ``"filter"``.
        filter_lang: Language of ``filter_query``; stored under
            ``"filter_lang"`` (rewritten to ``"cql2-json"`` after a text
            conversion).

    Returns:
        A new dict holding every argument whose value is neither ``None`` nor
        an empty list. ``"fields"`` is a dict of sorted ``include`` /
        ``exclude`` lists so that the result can be serialised with
        ``json.dumps``.
    """
    # Work on a copy so the caller's dict is left untouched.
    args: Dict[str, Any] = dict(base_args)

    if filter_query:
        if filter_lang == "cql2-text":
            filter_query = to_cql2(parse_cql2_text(filter_query))
            filter_lang = "cql2-json"

        args["filter"] = orjson.loads(filter_query)
        args["filter_lang"] = filter_lang

    if datetime:
        args["datetime"] = format_datetime_range(datetime)

    if intersects:
        args["intersects"] = orjson.loads(unquote_plus(intersects))

    if sortby:
        # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
        if isinstance(sortby, str):
            # A bare string would otherwise be iterated character by character.
            sortby = [part.strip() for part in sortby.split(",") if part.strip()]

        sort_param = []
        for sort in sortby:
            sortparts = re.match(r"^([+-]?)(.*)$", sort)
            if sortparts:
                sort_param.append(
                    {
                        "field": sortparts.group(2).strip(),
                        "direction": "desc" if sortparts.group(1) == "-" else "asc",
                    }
                )
        args["sortby"] = sort_param

    if fields:
        includes = set()
        excludes = set()
        for field in fields:
            if not field:
                # Skip empty entries (e.g. from "a,,b") instead of failing on field[0].
                continue
            if field.startswith("-"):
                excludes.add(field[1:])
            elif field.startswith("+"):
                includes.add(field[1:])
            else:
                includes.add(field)
        # Sets are not JSON serialisable; emit deterministic, de-duplicated lists.
        args["fields"] = {"include": sorted(includes), "exclude": sorted(excludes)}

    # Remove None values (and empty lists) from dict
    return {k: v for k, v in args.items() if v is not None and v != []}

0 commit comments

Comments
 (0)