From 4e7d396e5bd7258e7b5d5377b07f20ec5b789d96 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 25 Apr 2025 15:18:09 -1000 Subject: [PATCH] Add WebSocket API to zeroconf to observe discovery (#143540) * Add WebSocket API to zeroconf to observe discovery * Add WebSocket API to zeroconf to observe discovery * increase timeout * cover * cover * cover * cover * cover * cover * fix lasting side effects * cleanup merge * format --- homeassistant/components/zeroconf/__init__.py | 2 + .../components/zeroconf/discovery.py | 27 ++- .../components/zeroconf/websocket_api.py | 163 +++++++++++++++ tests/components/zeroconf/test_usage.py | 97 +++++---- .../components/zeroconf/test_websocket_api.py | 194 ++++++++++++++++++ 5 files changed, 439 insertions(+), 44 deletions(-) create mode 100644 homeassistant/components/zeroconf/websocket_api.py create mode 100644 tests/components/zeroconf/test_websocket_api.py diff --git a/homeassistant/components/zeroconf/__init__.py b/homeassistant/components/zeroconf/__init__.py index 383276d645f..311c42ee18e 100644 --- a/homeassistant/components/zeroconf/__init__.py +++ b/homeassistant/components/zeroconf/__init__.py @@ -36,6 +36,7 @@ from homeassistant.helpers.typing import ConfigType from homeassistant.loader import async_get_homekit, async_get_zeroconf, bind_hass from homeassistant.setup import async_when_setup_or_start +from . import websocket_api from .const import DOMAIN, ZEROCONF_TYPE from .discovery import ( # noqa: F401 DATA_DISCOVERY, @@ -198,6 +199,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: ) await discovery.async_setup() hass.data[DATA_DISCOVERY] = discovery + websocket_api.async_setup(hass) async def _async_zeroconf_hass_start(hass: HomeAssistant, comp: str) -> None: """Expose Home Assistant on zeroconf when it starts. diff --git a/homeassistant/components/zeroconf/discovery.py b/homeassistant/components/zeroconf/discovery.py index 0ea0e4c1619..e9b4508caee 100644 --- a/homeassistant/components/zeroconf/discovery.py +++ b/homeassistant/components/zeroconf/discovery.py @@ -2,9 +2,10 @@ from __future__ import annotations +from collections.abc import Callable import contextlib from fnmatch import translate -from functools import lru_cache +from functools import lru_cache, partial from ipaddress import IPv4Address, IPv6Address import logging import re @@ -190,6 +191,26 @@ class ZeroconfDiscovery: self.homekit_model_lookups = homekit_model_lookups self.homekit_model_matchers = homekit_model_matchers self.async_service_browser: AsyncServiceBrowser | None = None + self._service_update_listeners: set[Callable[[AsyncServiceInfo], None]] = set() + self._service_removed_listeners: set[Callable[[str], None]] = set() + + @callback + def async_register_service_update_listener( + self, + listener: Callable[[AsyncServiceInfo], None], + ) -> Callable[[], None]: + """Register a service update listener.""" + self._service_update_listeners.add(listener) + return partial(self._service_update_listeners.remove, listener) + + @callback + def async_register_service_removed_listener( + self, + listener: Callable[[str], None], + ) -> Callable[[], None]: + """Register a service removed listener.""" + self._service_removed_listeners.add(listener) + return partial(self._service_removed_listeners.remove, listener) async def async_setup(self) -> None: """Start discovery.""" @@ -258,6 +279,8 @@ class ZeroconfDiscovery: if state_change is ServiceStateChange.Removed: self._async_dismiss_discoveries(name) + for listener in self._service_removed_listeners: + listener(name) return self._async_service_update(zeroconf, service_type, name) @@ -304,6 +327,8 @@ class ZeroconfDiscovery: self, async_service_info: AsyncServiceInfo, service_type: str, name: str ) -> None: """Process a zeroconf update.""" + for listener in self._service_update_listeners: + listener(async_service_info) info = info_from_service(async_service_info) if not info: # Prevent the browser thread from collapsing diff --git a/homeassistant/components/zeroconf/websocket_api.py b/homeassistant/components/zeroconf/websocket_api.py new file mode 100644 index 00000000000..3a1881e6f4e --- /dev/null +++ b/homeassistant/components/zeroconf/websocket_api.py @@ -0,0 +1,163 @@ +"""The zeroconf integration websocket apis.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Callable +from functools import partial +from itertools import chain +import logging +from typing import Any, cast + +import voluptuous as vol +from zeroconf import BadTypeInNameException, DNSPointer, Zeroconf, current_time_millis +from zeroconf.asyncio import AsyncServiceInfo, IPVersion + +from homeassistant.components import websocket_api +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.json import json_bytes + +from .const import DOMAIN, REQUEST_TIMEOUT +from .discovery import DATA_DISCOVERY, ZeroconfDiscovery +from .models import HaAsyncZeroconf + +_LOGGER = logging.getLogger(__name__) +CLASS_IN = 1 +TYPE_PTR = 12 + + +@callback +def async_setup(hass: HomeAssistant) -> None: + """Set up the zeroconf websocket API.""" + websocket_api.async_register_command(hass, ws_subscribe_discovery) + + +def serialize_service_info(service_info: AsyncServiceInfo) -> dict[str, Any]: + """Serialize an AsyncServiceInfo object.""" + return { + "name": service_info.name, + "type": service_info.type, + "port": service_info.port, + "properties": service_info.decoded_properties, + "ip_addresses": [ + str(ip) for ip in service_info.ip_addresses_by_version(IPVersion.All) + ], + } + + +class _DiscoverySubscription: + """Class to hold and manage the subscription data.""" + + def __init__( + self, + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + ws_msg_id: int, + aiozc: HaAsyncZeroconf, + discovery: ZeroconfDiscovery, + ) -> None: + """Initialize the subscription data.""" + self.hass = hass + self.discovery = discovery + self.aiozc = aiozc + self.ws_msg_id = ws_msg_id + self.connection = connection + + @callback + def _async_unsubscribe( + self, cancel_callbacks: tuple[Callable[[], None], ...] + ) -> None: + """Unsubscribe the callback.""" + for cancel_callback in cancel_callbacks: + cancel_callback() + + async def async_start(self) -> None: + """Start the subscription.""" + connection = self.connection + listeners = ( + self.discovery.async_register_service_update_listener( + self._async_on_update + ), + self.discovery.async_register_service_removed_listener( + self._async_on_remove + ), + ) + connection.subscriptions[self.ws_msg_id] = partial( + self._async_unsubscribe, listeners + ) + self.connection.send_message( + json_bytes(websocket_api.result_message(self.ws_msg_id)) + ) + await self._async_update_from_cache() + + async def _async_update_from_cache(self) -> None: + """Load the records from the cache.""" + tasks: list[asyncio.Task[None]] = [] + now = current_time_millis() + for record in self._async_get_ptr_records(self.aiozc.zeroconf): + try: + info = AsyncServiceInfo(record.name, record.alias) + except BadTypeInNameException as ex: + _LOGGER.debug( + "Ignoring record with bad type in name: %s: %s", record.alias, ex + ) + continue + if info.load_from_cache(self.aiozc.zeroconf, now): + self._async_on_update(info) + else: + tasks.append( + self.hass.async_create_background_task( + self._async_handle_service(info), + f"zeroconf resolve {record.alias}", + ), + ) + + if tasks: + await asyncio.gather(*tasks) + + def _async_get_ptr_records(self, zc: Zeroconf) -> list[DNSPointer]: + """Return all PTR records for the HAP type.""" + return cast( + list[DNSPointer], + list( + chain.from_iterable( + zc.cache.async_all_by_details(zc_type, TYPE_PTR, CLASS_IN) + for zc_type in self.discovery.zeroconf_types + ) + ), + ) + + async def _async_handle_service(self, info: AsyncServiceInfo) -> None: + """Add a device that became visible via zeroconf.""" + await info.async_request(self.aiozc.zeroconf, REQUEST_TIMEOUT) + self._async_on_update(info) + + def _async_event_message(self, message: dict[str, Any]) -> None: + self.connection.send_message( + json_bytes(websocket_api.event_message(self.ws_msg_id, message)) + ) + + def _async_on_update(self, info: AsyncServiceInfo) -> None: + if info.type in self.discovery.zeroconf_types: + self._async_event_message({"add": [serialize_service_info(info)]}) + + def _async_on_remove(self, name: str) -> None: + self._async_event_message({"remove": [{"name": name}]}) + + +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required("type"): "zeroconf/subscribe_discovery", + } +) +@websocket_api.async_response +async def ws_subscribe_discovery( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] +) -> None: + """Handle subscribe advertisements websocket command.""" + discovery = hass.data[DATA_DISCOVERY] + aiozc: HaAsyncZeroconf = hass.data[DOMAIN] + await _DiscoverySubscription( + hass, connection, msg["id"], aiozc, discovery + ).async_start() diff --git a/tests/components/zeroconf/test_usage.py b/tests/components/zeroconf/test_usage.py index e79f2319915..2e186bc39d0 100644 --- a/tests/components/zeroconf/test_usage.py +++ b/tests/components/zeroconf/test_usage.py @@ -3,7 +3,6 @@ from unittest.mock import Mock, patch import pytest -import zeroconf from homeassistant.components.zeroconf import async_get_instance from homeassistant.components.zeroconf.usage import install_multiple_zeroconf_catcher @@ -15,6 +14,16 @@ from tests.common import extract_stack_to_frame DOMAIN = "zeroconf" +class MockZeroconf: + """Mock Zeroconf class.""" + + def __init__(self, *args, **kwargs) -> None: + """Initialize the mock.""" + + def __new__(cls, *args, **kwargs) -> "MockZeroconf": + """Return the shared instance.""" + + @pytest.mark.usefixtures("mock_async_zeroconf", "mock_zeroconf") async def test_multiple_zeroconf_instances( hass: HomeAssistant, caplog: pytest.LogCaptureFixture @@ -24,12 +33,13 @@ async def test_multiple_zeroconf_instances( zeroconf_instance = await async_get_instance(hass) - install_multiple_zeroconf_catcher(zeroconf_instance) + with patch("zeroconf.Zeroconf", MockZeroconf): + install_multiple_zeroconf_catcher(zeroconf_instance) - new_zeroconf_instance = zeroconf.Zeroconf() - assert new_zeroconf_instance == zeroconf_instance + new_zeroconf_instance = MockZeroconf() + assert new_zeroconf_instance == zeroconf_instance - assert "Zeroconf" in caplog.text + assert "Zeroconf" in caplog.text @pytest.mark.usefixtures("mock_async_zeroconf", "mock_zeroconf") @@ -41,44 +51,45 @@ async def test_multiple_zeroconf_instances_gives_shared( zeroconf_instance = await async_get_instance(hass) - install_multiple_zeroconf_catcher(zeroconf_instance) + with patch("zeroconf.Zeroconf", MockZeroconf): + install_multiple_zeroconf_catcher(zeroconf_instance) - correct_frame = Mock( - filename="/config/custom_components/burncpu/light.py", - lineno="23", - line="self.light.is_on", - ) - with ( - patch( - "homeassistant.helpers.frame.linecache.getline", - return_value=correct_frame.line, - ), - patch( - "homeassistant.helpers.frame.get_current_frame", - return_value=extract_stack_to_frame( - [ - Mock( - filename="/home/dev/homeassistant/core.py", - lineno="23", - line="do_something()", - ), - correct_frame, - Mock( - filename="/home/dev/homeassistant/components/zeroconf/usage.py", - lineno="23", - line="self.light.is_on", - ), - Mock( - filename="/home/dev/mdns/lights.py", - lineno="2", - line="something()", - ), - ] + correct_frame = Mock( + filename="/config/custom_components/burncpu/light.py", + lineno="23", + line="self.light.is_on", + ) + with ( + patch( + "homeassistant.helpers.frame.linecache.getline", + return_value=correct_frame.line, ), - ), - ): - assert zeroconf.Zeroconf() == zeroconf_instance + patch( + "homeassistant.helpers.frame.get_current_frame", + return_value=extract_stack_to_frame( + [ + Mock( + filename="/home/dev/homeassistant/core.py", + lineno="23", + line="do_something()", + ), + correct_frame, + Mock( + filename="/home/dev/homeassistant/components/zeroconf/usage.py", + lineno="23", + line="self.light.is_on", + ), + Mock( + filename="/home/dev/mdns/lights.py", + lineno="2", + line="something()", + ), + ] + ), + ), + ): + assert MockZeroconf() == zeroconf_instance - assert "custom_components/burncpu/light.py" in caplog.text - assert "23" in caplog.text - assert "self.light.is_on" in caplog.text + assert "custom_components/burncpu/light.py" in caplog.text + assert "23" in caplog.text + assert "self.light.is_on" in caplog.text diff --git a/tests/components/zeroconf/test_websocket_api.py b/tests/components/zeroconf/test_websocket_api.py new file mode 100644 index 00000000000..9677b3e34fd --- /dev/null +++ b/tests/components/zeroconf/test_websocket_api.py @@ -0,0 +1,194 @@ +"""The tests for the zeroconf WebSocket API.""" + +import asyncio +import socket +from unittest.mock import patch + +from zeroconf import ( + DNSAddress, + DNSPointer, + DNSService, + DNSText, + RecordUpdate, + const, + current_time_millis, +) + +from homeassistant.components.zeroconf import DOMAIN, async_get_async_instance +from homeassistant.core import HomeAssistant +from homeassistant.generated import zeroconf as zc_gen +from homeassistant.setup import async_setup_component + +from tests.typing import WebSocketGenerator + + +async def test_subscribe_discovery( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test zeroconf subscribe_discovery.""" + instance = await async_get_async_instance(hass) + instance.zeroconf.cache.async_add_records( + [ + DNSPointer( + "_fakeservice._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN, + const._DNS_OTHER_TTL, + "wrong._wrongservice._tcp.local.", + ), + DNSPointer( + "_fakeservice._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN, + const._DNS_OTHER_TTL, + "foo2._fakeservice._tcp.local.", + ), + DNSService( + "foo2._fakeservice._tcp.local.", + const._TYPE_SRV, + const._CLASS_IN, + const._DNS_OTHER_TTL, + 0, + 0, + 1234, + "foo2.local.", + ), + DNSAddress( + "foo2.local.", + const._TYPE_A, + const._CLASS_IN, + const._DNS_HOST_TTL, + socket.inet_aton("127.0.0.1"), + ), + DNSText( + "foo2.local.", + const._TYPE_TXT, + const._CLASS_IN, + const._DNS_HOST_TTL, + b"\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5" + b"\x05c#=12\x04s#=1", + ), + DNSPointer( + "_fakeservice._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN, + const._DNS_OTHER_TTL, + "foo3._fakeservice._tcp.local.", + ), + DNSService( + "foo3._fakeservice._tcp.local.", + const._TYPE_SRV, + const._CLASS_IN, + const._DNS_OTHER_TTL, + 0, + 0, + 1234, + "foo3.local.", + ), + DNSText( + "foo3.local.", + const._TYPE_TXT, + const._CLASS_IN, + const._DNS_HOST_TTL, + b"\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5" + b"\x05c#=12\x04s#=1", + ), + ] + ) + with patch.dict( + zc_gen.ZEROCONF, + {"_fakeservice._tcp.local.": []}, + clear=True, + ): + assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}}) + await hass.async_block_till_done() + client = await hass_ws_client() + await client.send_json( + { + "id": 1, + "type": "zeroconf/subscribe_discovery", + } + ) + async with asyncio.timeout(1): + response = await client.receive_json() + assert response["success"] + + async with asyncio.timeout(1): + response = await client.receive_json() + assert response["event"] == { + "add": [ + { + "ip_addresses": ["127.0.0.1"], + "name": "foo2._fakeservice._tcp.local.", + "port": 1234, + "properties": {}, + "type": "_fakeservice._tcp.local.", + } + ] + } + + # now late inject the address record + records = [ + DNSAddress( + "foo3.local.", + const._TYPE_A, + const._CLASS_IN, + const._DNS_HOST_TTL, + socket.inet_aton("127.0.0.1"), + ), + ] + instance.zeroconf.cache.async_add_records(records) + instance.zeroconf.record_manager.async_updates( + current_time_millis(), + [RecordUpdate(record, None) for record in records], + ) + # Now for the add + async with asyncio.timeout(1): + response = await client.receive_json() + assert response["event"] == { + "add": [ + { + "ip_addresses": ["127.0.0.1"], + "name": "foo3._fakeservice._tcp.local.", + "port": 1234, + "properties": {}, + "type": "_fakeservice._tcp.local.", + } + ] + } + # Now for the update + async with asyncio.timeout(1): + response = await client.receive_json() + assert response["event"] == { + "add": [ + { + "ip_addresses": ["127.0.0.1"], + "name": "foo3._fakeservice._tcp.local.", + "port": 1234, + "properties": {}, + "type": "_fakeservice._tcp.local.", + } + ] + } + + # now move time forward and remove the record + future = current_time_millis() + (4500 * 1000) + records = instance.zeroconf.cache.async_expire(future) + record_updates = [RecordUpdate(record, record) for record in records] + instance.zeroconf.record_manager.async_updates(future, record_updates) + instance.zeroconf.record_manager.async_updates_complete(True) + + removes: set[str] = set() + for _ in range(3): + async with asyncio.timeout(1): + response = await client.receive_json() + assert "remove" in response["event"] + removes.add(next(iter(response["event"]["remove"]))["name"]) + + assert len(removes) == 3 + assert removes == { + "foo2._fakeservice._tcp.local.", + "foo3._fakeservice._tcp.local.", + "wrong._wrongservice._tcp.local.", + }