From fac9df6098172cf73411bd4b23106497f1fcc9d4 Mon Sep 17 00:00:00 2001 From: Lyonnet Date: Mon, 30 Dec 2024 15:01:53 +0800 Subject: [PATCH] Fix typing for headers class --- curl_cffi/requests/headers.py | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/curl_cffi/requests/headers.py b/curl_cffi/requests/headers.py index 0c14b04..2425c2c 100644 --- a/curl_cffi/requests/headers.py +++ b/curl_cffi/requests/headers.py @@ -48,25 +48,15 @@ def to_str(value: Union[str, bytes], encoding: str = "utf-8") -> str: return value if isinstance(value, str) else value.decode(encoding) -def to_bytes_or_str_or_none(value: Optional[str], match_type_of: AnyStr) -> Optional[AnyStr]: - if value is None: - return value - - if isinstance(match_type_of, str): - return value - - return value.encode() - - SENSITIVE_HEADERS = {"authorization", "proxy-authorization"} def obfuscate_sensitive_headers( - items: Iterable[Tuple[AnyStr, AnyStr]], + items: Iterable[Tuple[AnyStr, Optional[AnyStr]]], ) -> Iterator[Tuple[AnyStr, Optional[AnyStr]]]: for k, v in items: if to_str(k.lower()) in SENSITIVE_HEADERS: - v = to_bytes_or_str_or_none("[secure]", match_type_of=v) + v = b"[secure]" if isinstance(v, bytes) else "[secure]" # type: ignore yield k, v @@ -110,7 +100,7 @@ class Headers(MutableMapping[str, Optional[str]]): def __init__(self, headers: Optional[HeaderTypes] = None, encoding: Optional[str] = None): if not headers: - self._list = [] # type: List[Tuple[bytes, bytes, bytes]] + self._list = [] # type: List[Tuple[bytes, bytes, Optional[bytes]]] elif isinstance(headers, Headers): self._list = list(headers._list) elif isinstance(headers, AbcMapping): @@ -187,7 +177,7 @@ def values(self) -> ValuesView[Optional[str]]: values_dict: Dict[str, str] = {} for _, key, value in self._list: str_key = key.decode(self.encoding) - str_value = value.decode(self.encoding) if value is not None else value + str_value = value.decode(self.encoding) if value is not None else "None" if str_key in values_dict: values_dict[str_key] += f", {str_value}" else: @@ -202,7 +192,7 @@ def items(self) -> ItemsView[str, Optional[str]]: values_dict: Dict[str, str] = {} for _, key, value in self._list: str_key = key.decode(self.encoding) - str_value = value.decode(self.encoding) if value is not None else value + str_value = value.decode(self.encoding) if value is not None else "None" if str_key in values_dict: values_dict[str_key] += f", {str_value}" else: @@ -249,7 +239,7 @@ def get_list(self, key: str, split_commas: bool = False) -> List[Optional[str]]: split_values = [] for value in values: - split_values.extend([item.strip() for item in value.split(",")]) + split_values.extend([item.strip() for item in value.split(",")]) # type: ignore return split_values def update(self, headers: Optional[HeaderTypes] = None) -> None: # type: ignore @@ -280,7 +270,7 @@ def __getitem__(self, key: str) -> Optional[str]: return None if items: - return ", ".join(items) + return ", ".join([str(item) for item in items]) raise KeyError(key)