Skip to content

Commit

Permalink
fixed typing
Browse files Browse the repository at this point in the history
  • Loading branch information
cloutierMat committed Oct 9, 2024
1 parent 32d552e commit 5b844ed
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 19 deletions.
7 changes: 5 additions & 2 deletions moto/apigateway/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,7 +1166,10 @@ def create_from_cloudformation_json( # type: ignore[misc]

def add_child(self, path: str, parent_id: Optional[str] = None) -> Resource:
child_id = create_apigw_id(
self.account_id, self.region_name, "resource", parent_id + "." + path
self.account_id,
self.region_name,
"resource",
(parent_id or "") + "." + path,
)
child = Resource(
resource_id=child_id,
Expand Down Expand Up @@ -2158,7 +2161,7 @@ def create_api_key(self, payload: Dict[str, Any]) -> ApiKey:
if api_key.value == payload["value"]:
raise ApiKeyAlreadyExists()
api_key_id = create_apigw_id(
self.account_id, self.region_name, "api_key", payload.get("name")
self.account_id, self.region_name, "api_key", payload["name"]
)
key = ApiKey(api_key_id=api_key_id, **payload)
self.keys[key.id] = key
Expand Down
2 changes: 1 addition & 1 deletion moto/apigateway/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from moto.utilities.id_generator import generate_str_id


def create_apigw_id(account_id, region, resource, name) -> str:
def create_apigw_id(account_id: str, region: str, resource: str, name: str) -> str:
return generate_str_id(
account_id,
region,
Expand Down
52 changes: 36 additions & 16 deletions moto/utilities/id_generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import threading
from typing import Callable
from typing import Any, Callable

from moto.moto_api._internal import mock_random

Expand All @@ -11,38 +11,50 @@ class MotoIdManager:
use the `id_manager` instance created below."""

_custom_ids: dict[str, str]
_id_sources: [IdSource]
_id_sources: list[IdSource]

_lock: threading.RLock

def __init__(self):
def __init__(self) -> None:
self._custom_ids = {}
self._lock = threading.RLock()
self._id_sources = []

self.add_id_source(self.get_custom_id)

def get_custom_id(self, account_id, region, service, resource, name) -> str | None:
def get_custom_id(
self, account_id: str, region: str, service: str, resource: str, name: str
) -> str | None:
# retrieves a custom_id for a resource. Returns None
return self._custom_ids.get(
".".join([account_id, region, service, resource, name])
)

def set_custom_id(self, account_id, region, service, resource, name, custom_id):
def set_custom_id(
self,
account_id: str,
region: str,
service: str,
resource: str,
name: str,
custom_id: str,
) -> None:
# sets a custom_id for a resource
with self._lock:
self._custom_ids[
".".join([account_id, region, service, resource, name])
] = custom_id

def unset_custom_id(self, account_id, region, service, resource, name):
def unset_custom_id(
self, account_id: str, region: str, service: str, resource: str, name: str
) -> None:
# removes a set custom_id for a resource
with self._lock:
self._custom_ids.pop(
".".join([account_id, region, service, resource, name]), None
)

def add_id_source(self, id_source: IdSource):
def add_id_source(self, id_source: IdSource) -> None:
self._id_sources.append(id_source)

def find_id_from_sources(
Expand All @@ -51,14 +63,22 @@ def find_id_from_sources(
for id_source in self._id_sources:
if found_id := id_source(account_id, region, service, resource, name):
return found_id
return None


id_manager = MotoIdManager()


def moto_id(fn):
def moto_id(fn: Callable[..., str]) -> Callable[..., str]:
# Decorator for helping in creation of static ids within Moto.
def _wrapper(account_id, region, service, resource, name, **kwargs):
def _wrapper(
account_id: str,
region: str,
service: str,
resource: str,
name: str,
**kwargs: dict[str, Any],
) -> str:
if found_id := id_manager.find_id_from_sources(
account_id, region, service, resource, name
):
Expand All @@ -69,14 +89,14 @@ def _wrapper(account_id, region, service, resource, name, **kwargs):


@moto_id
def generate_str_id(
account_id,
region,
service,
resource,
name,
def generate_str_id( # type: ignore
account_id: str,
region: str,
service: str,
resource: str,
name: str,
length: int = 20,
include_digits: bool = True,
lower_case: bool = False,
):
) -> str:
return mock_random.get_random_string(length, include_digits, lower_case)

0 comments on commit 5b844ed

Please sign in to comment.