Skip to content

Commit 03309d9

Browse files
Support specifying which IP addresses to listen on (#526)
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
1 parent 1384e6f commit 03309d9

File tree

4 files changed

+93
-12
lines changed

4 files changed

+93
-12
lines changed

matter_server/server/__main__.py

+10
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
DEFAULT_VENDOR_ID = 0xFFF1
1414
DEFAULT_FABRIC_ID = 1
1515
DEFAULT_PORT = 5580
16+
# Default to None to bind to all addresses on both IPv4 and IPv6
17+
DEFAULT_LISTEN_ADDRESS = None
1618
DEFAULT_STORAGE_PATH = os.path.join(Path.home(), ".matter_server")
1719

1820
# Get parsed passed in arguments.
@@ -45,6 +47,13 @@
4547
default=DEFAULT_PORT,
4648
help=f"TCP Port to run the websocket server, defaults to {DEFAULT_PORT}",
4749
)
50+
parser.add_argument(
51+
"--listen-address",
52+
type=str,
53+
action="append",
54+
default=DEFAULT_LISTEN_ADDRESS,
55+
help="IP address to bind the websocket server to, defaults to any IPv4 and IPv6 address.",
56+
)
4857
parser.add_argument(
4958
"--log-level",
5059
type=str,
@@ -95,6 +104,7 @@ def main() -> None:
95104
int(args.vendorid),
96105
int(args.fabricid),
97106
int(args.port),
107+
args.listen_address,
98108
args.primary_interface,
99109
)
100110

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Multiple host capable aiohttp Site."""
2+
from __future__ import annotations
3+
4+
import asyncio
5+
from ssl import SSLContext
6+
7+
from aiohttp import web
8+
from yarl import URL
9+
10+
11+
class MultiHostTCPSite(web.BaseSite):
12+
"""Multiple host capable aiohttp Site.
13+
14+
Vanilla TCPSite accepts only str as host. However, the underlying asyncio's
15+
create_server() implementation does take a list of strings to bind to multiple
16+
host IP's. To support multiple server_host entries (e.g. to enable dual-stack
17+
explicitly), we would like to pass an array of strings.
18+
"""
19+
20+
__slots__ = ("_host", "_port", "_reuse_address", "_reuse_port", "_hosturl")
21+
22+
def __init__(
23+
self,
24+
runner: web.BaseRunner,
25+
host: None | str | list[str],
26+
port: int,
27+
*,
28+
ssl_context: SSLContext | None = None,
29+
backlog: int = 128,
30+
reuse_address: bool | None = None,
31+
reuse_port: bool | None = None,
32+
) -> None:
33+
"""Initialize HomeAssistantTCPSite."""
34+
super().__init__(
35+
runner,
36+
ssl_context=ssl_context,
37+
backlog=backlog,
38+
)
39+
self._host = host
40+
self._port = port
41+
self._reuse_address = reuse_address
42+
self._reuse_port = reuse_port
43+
44+
@property
45+
def name(self) -> str:
46+
"""Return server URL."""
47+
scheme = "https" if self._ssl_context else "http"
48+
host = self._host[0] if isinstance(self._host, list) else "0.0.0.0"
49+
return str(URL.build(scheme=scheme, host=host, port=self._port))
50+
51+
async def start(self) -> None:
52+
"""Start server."""
53+
await super().start()
54+
loop = asyncio.get_running_loop()
55+
server = self._runner.server
56+
assert server is not None
57+
self._server = await loop.create_server(
58+
server,
59+
self._host,
60+
self._port,
61+
ssl=self._ssl_context,
62+
backlog=self._backlog,
63+
reuse_address=self._reuse_address,
64+
reuse_port=self._reuse_port,
65+
)

matter_server/server/server.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
from aiohttp import web
1111

12+
from matter_server.server.helpers.custom_web_runner import MultiHostTCPSite
13+
1214
from ..common.const import SCHEMA_VERSION
1315
from ..common.errors import VersionMismatch
1416
from ..common.helpers.api import APICommandHandler, api_command
@@ -54,21 +56,23 @@ class MatterServer:
5456
"""Serve Matter stack over WebSockets."""
5557

5658
_runner: web.AppRunner | None = None
57-
_http: web.TCPSite | None = None
59+
_http: MultiHostTCPSite | None = None
5860

5961
def __init__(
6062
self,
6163
storage_path: str,
6264
vendor_id: int,
6365
fabric_id: int,
6466
port: int,
65-
primary_interface: str | None,
67+
listen_addresses: list[str] | None = None,
68+
primary_interface: str | None = None,
6669
) -> None:
6770
"""Initialize the Matter Server."""
6871
self.storage_path = storage_path
6972
self.vendor_id = vendor_id
7073
self.fabric_id = fabric_id
7174
self.port = port
75+
self.listen_addresses = listen_addresses
7276
self.primary_interface = primary_interface
7377
self.logger = logging.getLogger(__name__)
7478
self.app = web.Application()
@@ -102,8 +106,9 @@ async def start(self) -> None:
102106
self.app.router.add_route("GET", "/", self._handle_info)
103107
self._runner = web.AppRunner(self.app, access_log=None)
104108
await self._runner.setup()
105-
# set host to None to bind to all addresses on both IPv4 and IPv6
106-
self._http = web.TCPSite(self._runner, host=None, port=self.port)
109+
self._http = MultiHostTCPSite(
110+
self._runner, host=self.listen_addresses, port=self.port
111+
)
107112
await self._http.start()
108113
self.logger.debug("Webserver initialized.")
109114

tests/server/test_server.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
pytestmark = pytest.mark.usefixtures(
1414
"application",
1515
"app_runner",
16-
"tcp_site",
16+
"multi_host_tcp_site",
1717
"chip_native",
1818
"chip_logging",
1919
"chip_stack",
@@ -38,11 +38,11 @@ def app_runner_fixture() -> Generator[MagicMock, None, None]:
3838
yield app_runner
3939

4040

41-
@pytest.fixture(name="tcp_site")
42-
def tcp_site_fixture() -> Generator[MagicMock, None, None]:
41+
@pytest.fixture(name="multi_host_tcp_site")
42+
def multi_host_tcp_site_fixture() -> Generator[MagicMock, None, None]:
4343
"""Return a mocked tcp site."""
44-
with patch("matter_server.server.server.web.TCPSite", autospec=True) as tcp_site:
45-
yield tcp_site
44+
with patch("matter_server.server.server.MultiHostTCPSite", autospec=True) as multi_host_tcp_site:
45+
yield multi_host_tcp_site
4646

4747

4848
@pytest.fixture(name="chip_native")
@@ -108,7 +108,7 @@ async def server_fixture() -> AsyncGenerator[MatterServer, None]:
108108
async def test_server_start(
109109
application: MagicMock,
110110
app_runner: MagicMock,
111-
tcp_site: MagicMock,
111+
multi_host_tcp_site: MagicMock,
112112
server: MatterServer,
113113
storage_controller: MagicMock,
114114
) -> None:
@@ -123,13 +123,14 @@ async def test_server_start(
123123
assert add_route.call_args_list[1][0][1] == "/"
124124
assert app_runner.call_count == 1
125125
assert app_runner.return_value.setup.call_count == 1
126-
assert tcp_site.call_count == 1
127-
assert tcp_site.return_value.start.call_count == 1
126+
assert multi_host_tcp_site.call_count == 1
127+
assert multi_host_tcp_site.return_value.start.call_count == 1
128128
assert storage_controller.return_value.start.call_count == 1
129129
assert server.storage_path == "test_storage_path"
130130
assert server.vendor_id == 1234
131131
assert server.fabric_id == 5678
132132
assert server.port == 5580
133+
assert server.listen_addresses == None
133134
assert APICommand.SERVER_INFO in server.command_handlers
134135
assert APICommand.SERVER_DIAGNOSTICS in server.command_handlers
135136
assert APICommand.GET_NODES in server.command_handlers

0 commit comments

Comments
 (0)