diff --git a/matter_server/server/__main__.py b/matter_server/server/__main__.py index cfe14d5d..c740fe70 100644 --- a/matter_server/server/__main__.py +++ b/matter_server/server/__main__.py @@ -13,6 +13,8 @@ DEFAULT_VENDOR_ID = 0xFFF1 DEFAULT_FABRIC_ID = 1 DEFAULT_PORT = 5580 +# Default to None to bind to all addresses on both IPv4 and IPv6 +DEFAULT_LISTEN_ADDRESS = None DEFAULT_STORAGE_PATH = os.path.join(Path.home(), ".matter_server") # Get parsed passed in arguments. @@ -45,6 +47,13 @@ default=DEFAULT_PORT, help=f"TCP Port to run the websocket server, defaults to {DEFAULT_PORT}", ) +parser.add_argument( + "--listen-address", + type=str, + action="append", + default=DEFAULT_LISTEN_ADDRESS, + help="IP address to bind the websocket server to, defaults to any IPv4 and IPv6 address.", +) parser.add_argument( "--log-level", type=str, @@ -95,6 +104,7 @@ def main() -> None: int(args.vendorid), int(args.fabricid), int(args.port), + args.listen_address, args.primary_interface, ) diff --git a/matter_server/server/helpers/custom_web_runner.py b/matter_server/server/helpers/custom_web_runner.py new file mode 100644 index 00000000..20f954e8 --- /dev/null +++ b/matter_server/server/helpers/custom_web_runner.py @@ -0,0 +1,65 @@ +"""Multiple host capable aiohttp Site.""" +from __future__ import annotations + +import asyncio +from ssl import SSLContext + +from aiohttp import web +from yarl import URL + + +class MultiHostTCPSite(web.BaseSite): + """Multiple host capable aiohttp Site. + + Vanilla TCPSite accepts only str as host. However, the underlying asyncio's + create_server() implementation does take a list of strings to bind to multiple + host IP's. To support multiple server_host entries (e.g. to enable dual-stack + explicitly), we would like to pass an array of strings. + """ + + __slots__ = ("_host", "_port", "_reuse_address", "_reuse_port", "_hosturl") + + def __init__( + self, + runner: web.BaseRunner, + host: None | str | list[str], + port: int, + *, + ssl_context: SSLContext | None = None, + backlog: int = 128, + reuse_address: bool | None = None, + reuse_port: bool | None = None, + ) -> None: + """Initialize HomeAssistantTCPSite.""" + super().__init__( + runner, + ssl_context=ssl_context, + backlog=backlog, + ) + self._host = host + self._port = port + self._reuse_address = reuse_address + self._reuse_port = reuse_port + + @property + def name(self) -> str: + """Return server URL.""" + scheme = "https" if self._ssl_context else "http" + host = self._host[0] if isinstance(self._host, list) else "0.0.0.0" + return str(URL.build(scheme=scheme, host=host, port=self._port)) + + async def start(self) -> None: + """Start server.""" + await super().start() + loop = asyncio.get_running_loop() + server = self._runner.server + assert server is not None + self._server = await loop.create_server( + server, + self._host, + self._port, + ssl=self._ssl_context, + backlog=self._backlog, + reuse_address=self._reuse_address, + reuse_port=self._reuse_port, + ) diff --git a/matter_server/server/server.py b/matter_server/server/server.py index 06c5d802..c19f3530 100644 --- a/matter_server/server/server.py +++ b/matter_server/server/server.py @@ -9,6 +9,8 @@ from aiohttp import web +from matter_server.server.helpers.custom_web_runner import MultiHostTCPSite + from ..common.const import SCHEMA_VERSION from ..common.errors import VersionMismatch from ..common.helpers.api import APICommandHandler, api_command @@ -54,7 +56,7 @@ class MatterServer: """Serve Matter stack over WebSockets.""" _runner: web.AppRunner | None = None - _http: web.TCPSite | None = None + _http: MultiHostTCPSite | None = None def __init__( self, @@ -62,13 +64,15 @@ def __init__( vendor_id: int, fabric_id: int, port: int, - primary_interface: str | None, + listen_addresses: list[str] | None = None, + primary_interface: str | None = None, ) -> None: """Initialize the Matter Server.""" self.storage_path = storage_path self.vendor_id = vendor_id self.fabric_id = fabric_id self.port = port + self.listen_addresses = listen_addresses self.primary_interface = primary_interface self.logger = logging.getLogger(__name__) self.app = web.Application() @@ -102,8 +106,9 @@ async def start(self) -> None: self.app.router.add_route("GET", "/", self._handle_info) self._runner = web.AppRunner(self.app, access_log=None) await self._runner.setup() - # set host to None to bind to all addresses on both IPv4 and IPv6 - self._http = web.TCPSite(self._runner, host=None, port=self.port) + self._http = MultiHostTCPSite( + self._runner, host=self.listen_addresses, port=self.port + ) await self._http.start() self.logger.debug("Webserver initialized.") diff --git a/tests/server/test_server.py b/tests/server/test_server.py index 974898c9..8df54785 100644 --- a/tests/server/test_server.py +++ b/tests/server/test_server.py @@ -13,7 +13,7 @@ pytestmark = pytest.mark.usefixtures( "application", "app_runner", - "tcp_site", + "multi_host_tcp_site", "chip_native", "chip_logging", "chip_stack", @@ -38,11 +38,11 @@ def app_runner_fixture() -> Generator[MagicMock, None, None]: yield app_runner -@pytest.fixture(name="tcp_site") -def tcp_site_fixture() -> Generator[MagicMock, None, None]: +@pytest.fixture(name="multi_host_tcp_site") +def multi_host_tcp_site_fixture() -> Generator[MagicMock, None, None]: """Return a mocked tcp site.""" - with patch("matter_server.server.server.web.TCPSite", autospec=True) as tcp_site: - yield tcp_site + with patch("matter_server.server.server.MultiHostTCPSite", autospec=True) as multi_host_tcp_site: + yield multi_host_tcp_site @pytest.fixture(name="chip_native") @@ -108,7 +108,7 @@ async def server_fixture() -> AsyncGenerator[MatterServer, None]: async def test_server_start( application: MagicMock, app_runner: MagicMock, - tcp_site: MagicMock, + multi_host_tcp_site: MagicMock, server: MatterServer, storage_controller: MagicMock, ) -> None: @@ -123,13 +123,14 @@ async def test_server_start( assert add_route.call_args_list[1][0][1] == "/" assert app_runner.call_count == 1 assert app_runner.return_value.setup.call_count == 1 - assert tcp_site.call_count == 1 - assert tcp_site.return_value.start.call_count == 1 + assert multi_host_tcp_site.call_count == 1 + assert multi_host_tcp_site.return_value.start.call_count == 1 assert storage_controller.return_value.start.call_count == 1 assert server.storage_path == "test_storage_path" assert server.vendor_id == 1234 assert server.fabric_id == 5678 assert server.port == 5580 + assert server.listen_addresses == None assert APICommand.SERVER_INFO in server.command_handlers assert APICommand.SERVER_DIAGNOSTICS in server.command_handlers assert APICommand.GET_NODES in server.command_handlers