Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support specifying which IP addresses to listen on #526

Merged
merged 14 commits into from
Feb 6, 2024
10 changes: 10 additions & 0 deletions matter_server/server/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
DEFAULT_VENDOR_ID = 0xFFF1
DEFAULT_FABRIC_ID = 1
DEFAULT_PORT = 5580
# Default to None to bind to all addresses on both IPv4 and IPv6
DEFAULT_LISTEN_ADDRESS = None
DEFAULT_STORAGE_PATH = os.path.join(Path.home(), ".matter_server")

# Get parsed passed in arguments.
Expand Down Expand Up @@ -45,6 +47,13 @@
default=DEFAULT_PORT,
help=f"TCP Port to run the websocket server, defaults to {DEFAULT_PORT}",
)
parser.add_argument(
"--listen-address",
type=str,
action="append",
default=DEFAULT_LISTEN_ADDRESS,
help="IP address to bind the websocket server to, defaults to any IPv4 and IPv6 address.",
)
parser.add_argument(
"--log-level",
type=str,
Expand Down Expand Up @@ -95,6 +104,7 @@ def main() -> None:
int(args.vendorid),
int(args.fabricid),
int(args.port),
args.listen_address,
args.primary_interface,
)

Expand Down
65 changes: 65 additions & 0 deletions matter_server/server/helpers/custom_web_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Multiple host capable aiohttp Site."""
from __future__ import annotations

import asyncio
from ssl import SSLContext

from aiohttp import web
from yarl import URL


class MultiHostTCPSite(web.BaseSite):
"""Multiple host capable aiohttp Site.

Vanilla TCPSite accepts only str as host. However, the underlying asyncio's
create_server() implementation does take a list of strings to bind to multiple
host IP's. To support multiple server_host entries (e.g. to enable dual-stack
explicitly), we would like to pass an array of strings.
"""

__slots__ = ("_host", "_port", "_reuse_address", "_reuse_port", "_hosturl")

def __init__(
self,
runner: web.BaseRunner,
host: None | str | list[str],
port: int,
*,
ssl_context: SSLContext | None = None,
backlog: int = 128,
reuse_address: bool | None = None,
reuse_port: bool | None = None,
) -> None:
"""Initialize HomeAssistantTCPSite."""
super().__init__(
runner,
ssl_context=ssl_context,
backlog=backlog,
)
self._host = host
self._port = port
self._reuse_address = reuse_address
self._reuse_port = reuse_port

@property
def name(self) -> str:
"""Return server URL."""
scheme = "https" if self._ssl_context else "http"
host = self._host[0] if isinstance(self._host, list) else "0.0.0.0"
return str(URL.build(scheme=scheme, host=host, port=self._port))

async def start(self) -> None:
"""Start server."""
await super().start()
loop = asyncio.get_running_loop()
server = self._runner.server
assert server is not None
self._server = await loop.create_server(
server,
self._host,
self._port,
ssl=self._ssl_context,
backlog=self._backlog,
reuse_address=self._reuse_address,
reuse_port=self._reuse_port,
)
13 changes: 9 additions & 4 deletions matter_server/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from aiohttp import web

from matter_server.server.helpers.custom_web_runner import MultiHostTCPSite

from ..common.const import SCHEMA_VERSION
from ..common.errors import VersionMismatch
from ..common.helpers.api import APICommandHandler, api_command
Expand Down Expand Up @@ -54,21 +56,23 @@ class MatterServer:
"""Serve Matter stack over WebSockets."""

_runner: web.AppRunner | None = None
_http: web.TCPSite | None = None
_http: MultiHostTCPSite | None = None

def __init__(
self,
storage_path: str,
vendor_id: int,
fabric_id: int,
port: int,
primary_interface: str | None,
listen_addresses: list[str] | None = None,
primary_interface: str | None = None,
) -> None:
"""Initialize the Matter Server."""
self.storage_path = storage_path
self.vendor_id = vendor_id
self.fabric_id = fabric_id
self.port = port
self.listen_addresses = listen_addresses
self.primary_interface = primary_interface
self.logger = logging.getLogger(__name__)
self.app = web.Application()
Expand Down Expand Up @@ -102,8 +106,9 @@ async def start(self) -> None:
self.app.router.add_route("GET", "/", self._handle_info)
self._runner = web.AppRunner(self.app, access_log=None)
await self._runner.setup()
# set host to None to bind to all addresses on both IPv4 and IPv6
self._http = web.TCPSite(self._runner, host=None, port=self.port)
self._http = MultiHostTCPSite(
self._runner, host=self.listen_addresses, port=self.port
)
await self._http.start()
self.logger.debug("Webserver initialized.")

Expand Down
17 changes: 9 additions & 8 deletions tests/server/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
pytestmark = pytest.mark.usefixtures(
"application",
"app_runner",
"tcp_site",
"multi_host_tcp_site",
"chip_native",
"chip_logging",
"chip_stack",
Expand All @@ -38,11 +38,11 @@ def app_runner_fixture() -> Generator[MagicMock, None, None]:
yield app_runner


@pytest.fixture(name="tcp_site")
def tcp_site_fixture() -> Generator[MagicMock, None, None]:
@pytest.fixture(name="multi_host_tcp_site")
def multi_host_tcp_site_fixture() -> Generator[MagicMock, None, None]:
"""Return a mocked tcp site."""
with patch("matter_server.server.server.web.TCPSite", autospec=True) as tcp_site:
yield tcp_site
with patch("matter_server.server.server.MultiHostTCPSite", autospec=True) as multi_host_tcp_site:
yield multi_host_tcp_site


@pytest.fixture(name="chip_native")
Expand Down Expand Up @@ -108,7 +108,7 @@ async def server_fixture() -> AsyncGenerator[MatterServer, None]:
async def test_server_start(
application: MagicMock,
app_runner: MagicMock,
tcp_site: MagicMock,
multi_host_tcp_site: MagicMock,
server: MatterServer,
storage_controller: MagicMock,
) -> None:
Expand All @@ -123,13 +123,14 @@ async def test_server_start(
assert add_route.call_args_list[1][0][1] == "/"
assert app_runner.call_count == 1
assert app_runner.return_value.setup.call_count == 1
assert tcp_site.call_count == 1
assert tcp_site.return_value.start.call_count == 1
assert multi_host_tcp_site.call_count == 1
assert multi_host_tcp_site.return_value.start.call_count == 1
assert storage_controller.return_value.start.call_count == 1
assert server.storage_path == "test_storage_path"
assert server.vendor_id == 1234
assert server.fabric_id == 5678
assert server.port == 5580
assert server.listen_addresses == None
assert APICommand.SERVER_INFO in server.command_handlers
assert APICommand.SERVER_DIAGNOSTICS in server.command_handlers
assert APICommand.GET_NODES in server.command_handlers
Expand Down
Loading