Skip to content

Commit

Permalink
Fix typing for headers class
Browse files Browse the repository at this point in the history
  • Loading branch information
lexiforest committed Dec 30, 2024
1 parent 9e8e90d commit fac9df6
Showing 1 changed file with 7 additions and 17 deletions.
24 changes: 7 additions & 17 deletions curl_cffi/requests/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,25 +48,15 @@ def to_str(value: Union[str, bytes], encoding: str = "utf-8") -> str:
return value if isinstance(value, str) else value.decode(encoding)


def to_bytes_or_str_or_none(value: Optional[str], match_type_of: AnyStr) -> Optional[AnyStr]:
if value is None:
return value

if isinstance(match_type_of, str):
return value

return value.encode()


SENSITIVE_HEADERS = {"authorization", "proxy-authorization"}


def obfuscate_sensitive_headers(
items: Iterable[Tuple[AnyStr, AnyStr]],
items: Iterable[Tuple[AnyStr, Optional[AnyStr]]],
) -> Iterator[Tuple[AnyStr, Optional[AnyStr]]]:
for k, v in items:
if to_str(k.lower()) in SENSITIVE_HEADERS:
v = to_bytes_or_str_or_none("[secure]", match_type_of=v)
v = b"[secure]" if isinstance(v, bytes) else "[secure]" # type: ignore
yield k, v


Expand Down Expand Up @@ -110,7 +100,7 @@ class Headers(MutableMapping[str, Optional[str]]):

def __init__(self, headers: Optional[HeaderTypes] = None, encoding: Optional[str] = None):
if not headers:
self._list = [] # type: List[Tuple[bytes, bytes, bytes]]
self._list = [] # type: List[Tuple[bytes, bytes, Optional[bytes]]]
elif isinstance(headers, Headers):
self._list = list(headers._list)
elif isinstance(headers, AbcMapping):
Expand Down Expand Up @@ -187,7 +177,7 @@ def values(self) -> ValuesView[Optional[str]]:
values_dict: Dict[str, str] = {}
for _, key, value in self._list:
str_key = key.decode(self.encoding)
str_value = value.decode(self.encoding) if value is not None else value
str_value = value.decode(self.encoding) if value is not None else "None"
if str_key in values_dict:
values_dict[str_key] += f", {str_value}"
else:
Expand All @@ -202,7 +192,7 @@ def items(self) -> ItemsView[str, Optional[str]]:
values_dict: Dict[str, str] = {}
for _, key, value in self._list:
str_key = key.decode(self.encoding)
str_value = value.decode(self.encoding) if value is not None else value
str_value = value.decode(self.encoding) if value is not None else "None"
if str_key in values_dict:
values_dict[str_key] += f", {str_value}"
else:
Expand Down Expand Up @@ -249,7 +239,7 @@ def get_list(self, key: str, split_commas: bool = False) -> List[Optional[str]]:

split_values = []
for value in values:
split_values.extend([item.strip() for item in value.split(",")])
split_values.extend([item.strip() for item in value.split(",")]) # type: ignore
return split_values

def update(self, headers: Optional[HeaderTypes] = None) -> None: # type: ignore
Expand Down Expand Up @@ -280,7 +270,7 @@ def __getitem__(self, key: str) -> Optional[str]:
return None

if items:
return ", ".join(items)
return ", ".join([str(item) for item in items])

raise KeyError(key)

Expand Down

0 comments on commit fac9df6

Please sign in to comment.