Add WebSocket API to zeroconf to observe discovery (#143540)

* Add WebSocket API to zeroconf to observe discovery

* Add WebSocket API to zeroconf to observe discovery

* increase timeout

* cover

* cover

* cover

* cover

* cover

* cover

* fix lasting side effects

* cleanup merge

* format
This commit is contained in:
J. Nick Koston 2025-04-25 15:18:09 -10:00 committed by GitHub
parent 34d17ca458
commit 4e7d396e5b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 439 additions and 44 deletions

View File

@ -36,6 +36,7 @@ from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import async_get_homekit, async_get_zeroconf, bind_hass
from homeassistant.setup import async_when_setup_or_start
from . import websocket_api
from .const import DOMAIN, ZEROCONF_TYPE
from .discovery import ( # noqa: F401
DATA_DISCOVERY,
@ -198,6 +199,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
)
await discovery.async_setup()
hass.data[DATA_DISCOVERY] = discovery
websocket_api.async_setup(hass)
async def _async_zeroconf_hass_start(hass: HomeAssistant, comp: str) -> None:
"""Expose Home Assistant on zeroconf when it starts.

View File

@ -2,9 +2,10 @@
from __future__ import annotations
from collections.abc import Callable
import contextlib
from fnmatch import translate
from functools import lru_cache
from functools import lru_cache, partial
from ipaddress import IPv4Address, IPv6Address
import logging
import re
@ -190,6 +191,26 @@ class ZeroconfDiscovery:
self.homekit_model_lookups = homekit_model_lookups
self.homekit_model_matchers = homekit_model_matchers
self.async_service_browser: AsyncServiceBrowser | None = None
self._service_update_listeners: set[Callable[[AsyncServiceInfo], None]] = set()
self._service_removed_listeners: set[Callable[[str], None]] = set()
@callback
def async_register_service_update_listener(
self,
listener: Callable[[AsyncServiceInfo], None],
) -> Callable[[], None]:
"""Register a service update listener."""
self._service_update_listeners.add(listener)
return partial(self._service_update_listeners.remove, listener)
@callback
def async_register_service_removed_listener(
self,
listener: Callable[[str], None],
) -> Callable[[], None]:
"""Register a service removed listener."""
self._service_removed_listeners.add(listener)
return partial(self._service_removed_listeners.remove, listener)
async def async_setup(self) -> None:
"""Start discovery."""
@ -258,6 +279,8 @@ class ZeroconfDiscovery:
if state_change is ServiceStateChange.Removed:
self._async_dismiss_discoveries(name)
for listener in self._service_removed_listeners:
listener(name)
return
self._async_service_update(zeroconf, service_type, name)
@ -304,6 +327,8 @@ class ZeroconfDiscovery:
self, async_service_info: AsyncServiceInfo, service_type: str, name: str
) -> None:
"""Process a zeroconf update."""
for listener in self._service_update_listeners:
listener(async_service_info)
info = info_from_service(async_service_info)
if not info:
# Prevent the browser thread from collapsing

View File

@ -0,0 +1,163 @@
"""The zeroconf integration websocket apis."""
from __future__ import annotations
import asyncio
from collections.abc import Callable
from functools import partial
from itertools import chain
import logging
from typing import Any, cast
import voluptuous as vol
from zeroconf import BadTypeInNameException, DNSPointer, Zeroconf, current_time_millis
from zeroconf.asyncio import AsyncServiceInfo, IPVersion
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.json import json_bytes
from .const import DOMAIN, REQUEST_TIMEOUT
from .discovery import DATA_DISCOVERY, ZeroconfDiscovery
from .models import HaAsyncZeroconf
_LOGGER = logging.getLogger(__name__)
CLASS_IN = 1
TYPE_PTR = 12
@callback
def async_setup(hass: HomeAssistant) -> None:
"""Set up the zeroconf websocket API."""
websocket_api.async_register_command(hass, ws_subscribe_discovery)
def serialize_service_info(service_info: AsyncServiceInfo) -> dict[str, Any]:
"""Serialize an AsyncServiceInfo object."""
return {
"name": service_info.name,
"type": service_info.type,
"port": service_info.port,
"properties": service_info.decoded_properties,
"ip_addresses": [
str(ip) for ip in service_info.ip_addresses_by_version(IPVersion.All)
],
}
class _DiscoverySubscription:
"""Class to hold and manage the subscription data."""
def __init__(
self,
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
ws_msg_id: int,
aiozc: HaAsyncZeroconf,
discovery: ZeroconfDiscovery,
) -> None:
"""Initialize the subscription data."""
self.hass = hass
self.discovery = discovery
self.aiozc = aiozc
self.ws_msg_id = ws_msg_id
self.connection = connection
@callback
def _async_unsubscribe(
self, cancel_callbacks: tuple[Callable[[], None], ...]
) -> None:
"""Unsubscribe the callback."""
for cancel_callback in cancel_callbacks:
cancel_callback()
async def async_start(self) -> None:
"""Start the subscription."""
connection = self.connection
listeners = (
self.discovery.async_register_service_update_listener(
self._async_on_update
),
self.discovery.async_register_service_removed_listener(
self._async_on_remove
),
)
connection.subscriptions[self.ws_msg_id] = partial(
self._async_unsubscribe, listeners
)
self.connection.send_message(
json_bytes(websocket_api.result_message(self.ws_msg_id))
)
await self._async_update_from_cache()
async def _async_update_from_cache(self) -> None:
"""Load the records from the cache."""
tasks: list[asyncio.Task[None]] = []
now = current_time_millis()
for record in self._async_get_ptr_records(self.aiozc.zeroconf):
try:
info = AsyncServiceInfo(record.name, record.alias)
except BadTypeInNameException as ex:
_LOGGER.debug(
"Ignoring record with bad type in name: %s: %s", record.alias, ex
)
continue
if info.load_from_cache(self.aiozc.zeroconf, now):
self._async_on_update(info)
else:
tasks.append(
self.hass.async_create_background_task(
self._async_handle_service(info),
f"zeroconf resolve {record.alias}",
),
)
if tasks:
await asyncio.gather(*tasks)
def _async_get_ptr_records(self, zc: Zeroconf) -> list[DNSPointer]:
"""Return all PTR records for the HAP type."""
return cast(
list[DNSPointer],
list(
chain.from_iterable(
zc.cache.async_all_by_details(zc_type, TYPE_PTR, CLASS_IN)
for zc_type in self.discovery.zeroconf_types
)
),
)
async def _async_handle_service(self, info: AsyncServiceInfo) -> None:
"""Add a device that became visible via zeroconf."""
await info.async_request(self.aiozc.zeroconf, REQUEST_TIMEOUT)
self._async_on_update(info)
def _async_event_message(self, message: dict[str, Any]) -> None:
self.connection.send_message(
json_bytes(websocket_api.event_message(self.ws_msg_id, message))
)
def _async_on_update(self, info: AsyncServiceInfo) -> None:
if info.type in self.discovery.zeroconf_types:
self._async_event_message({"add": [serialize_service_info(info)]})
def _async_on_remove(self, name: str) -> None:
self._async_event_message({"remove": [{"name": name}]})
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required("type"): "zeroconf/subscribe_discovery",
}
)
@websocket_api.async_response
async def ws_subscribe_discovery(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle subscribe advertisements websocket command."""
discovery = hass.data[DATA_DISCOVERY]
aiozc: HaAsyncZeroconf = hass.data[DOMAIN]
await _DiscoverySubscription(
hass, connection, msg["id"], aiozc, discovery
).async_start()

View File

@ -3,7 +3,6 @@
from unittest.mock import Mock, patch
import pytest
import zeroconf
from homeassistant.components.zeroconf import async_get_instance
from homeassistant.components.zeroconf.usage import install_multiple_zeroconf_catcher
@ -15,6 +14,16 @@ from tests.common import extract_stack_to_frame
DOMAIN = "zeroconf"
class MockZeroconf:
"""Mock Zeroconf class."""
def __init__(self, *args, **kwargs) -> None:
"""Initialize the mock."""
def __new__(cls, *args, **kwargs) -> "MockZeroconf":
"""Return the shared instance."""
@pytest.mark.usefixtures("mock_async_zeroconf", "mock_zeroconf")
async def test_multiple_zeroconf_instances(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
@ -24,12 +33,13 @@ async def test_multiple_zeroconf_instances(
zeroconf_instance = await async_get_instance(hass)
install_multiple_zeroconf_catcher(zeroconf_instance)
with patch("zeroconf.Zeroconf", MockZeroconf):
install_multiple_zeroconf_catcher(zeroconf_instance)
new_zeroconf_instance = zeroconf.Zeroconf()
assert new_zeroconf_instance == zeroconf_instance
new_zeroconf_instance = MockZeroconf()
assert new_zeroconf_instance == zeroconf_instance
assert "Zeroconf" in caplog.text
assert "Zeroconf" in caplog.text
@pytest.mark.usefixtures("mock_async_zeroconf", "mock_zeroconf")
@ -41,44 +51,45 @@ async def test_multiple_zeroconf_instances_gives_shared(
zeroconf_instance = await async_get_instance(hass)
install_multiple_zeroconf_catcher(zeroconf_instance)
with patch("zeroconf.Zeroconf", MockZeroconf):
install_multiple_zeroconf_catcher(zeroconf_instance)
correct_frame = Mock(
filename="/config/custom_components/burncpu/light.py",
lineno="23",
line="self.light.is_on",
)
with (
patch(
"homeassistant.helpers.frame.linecache.getline",
return_value=correct_frame.line,
),
patch(
"homeassistant.helpers.frame.get_current_frame",
return_value=extract_stack_to_frame(
[
Mock(
filename="/home/dev/homeassistant/core.py",
lineno="23",
line="do_something()",
),
correct_frame,
Mock(
filename="/home/dev/homeassistant/components/zeroconf/usage.py",
lineno="23",
line="self.light.is_on",
),
Mock(
filename="/home/dev/mdns/lights.py",
lineno="2",
line="something()",
),
]
correct_frame = Mock(
filename="/config/custom_components/burncpu/light.py",
lineno="23",
line="self.light.is_on",
)
with (
patch(
"homeassistant.helpers.frame.linecache.getline",
return_value=correct_frame.line,
),
),
):
assert zeroconf.Zeroconf() == zeroconf_instance
patch(
"homeassistant.helpers.frame.get_current_frame",
return_value=extract_stack_to_frame(
[
Mock(
filename="/home/dev/homeassistant/core.py",
lineno="23",
line="do_something()",
),
correct_frame,
Mock(
filename="/home/dev/homeassistant/components/zeroconf/usage.py",
lineno="23",
line="self.light.is_on",
),
Mock(
filename="/home/dev/mdns/lights.py",
lineno="2",
line="something()",
),
]
),
),
):
assert MockZeroconf() == zeroconf_instance
assert "custom_components/burncpu/light.py" in caplog.text
assert "23" in caplog.text
assert "self.light.is_on" in caplog.text
assert "custom_components/burncpu/light.py" in caplog.text
assert "23" in caplog.text
assert "self.light.is_on" in caplog.text

View File

@ -0,0 +1,194 @@
"""The tests for the zeroconf WebSocket API."""
import asyncio
import socket
from unittest.mock import patch
from zeroconf import (
DNSAddress,
DNSPointer,
DNSService,
DNSText,
RecordUpdate,
const,
current_time_millis,
)
from homeassistant.components.zeroconf import DOMAIN, async_get_async_instance
from homeassistant.core import HomeAssistant
from homeassistant.generated import zeroconf as zc_gen
from homeassistant.setup import async_setup_component
from tests.typing import WebSocketGenerator
async def test_subscribe_discovery(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test zeroconf subscribe_discovery."""
instance = await async_get_async_instance(hass)
instance.zeroconf.cache.async_add_records(
[
DNSPointer(
"_fakeservice._tcp.local.",
const._TYPE_PTR,
const._CLASS_IN,
const._DNS_OTHER_TTL,
"wrong._wrongservice._tcp.local.",
),
DNSPointer(
"_fakeservice._tcp.local.",
const._TYPE_PTR,
const._CLASS_IN,
const._DNS_OTHER_TTL,
"foo2._fakeservice._tcp.local.",
),
DNSService(
"foo2._fakeservice._tcp.local.",
const._TYPE_SRV,
const._CLASS_IN,
const._DNS_OTHER_TTL,
0,
0,
1234,
"foo2.local.",
),
DNSAddress(
"foo2.local.",
const._TYPE_A,
const._CLASS_IN,
const._DNS_HOST_TTL,
socket.inet_aton("127.0.0.1"),
),
DNSText(
"foo2.local.",
const._TYPE_TXT,
const._CLASS_IN,
const._DNS_HOST_TTL,
b"\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5"
b"\x05c#=12\x04s#=1",
),
DNSPointer(
"_fakeservice._tcp.local.",
const._TYPE_PTR,
const._CLASS_IN,
const._DNS_OTHER_TTL,
"foo3._fakeservice._tcp.local.",
),
DNSService(
"foo3._fakeservice._tcp.local.",
const._TYPE_SRV,
const._CLASS_IN,
const._DNS_OTHER_TTL,
0,
0,
1234,
"foo3.local.",
),
DNSText(
"foo3.local.",
const._TYPE_TXT,
const._CLASS_IN,
const._DNS_HOST_TTL,
b"\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5"
b"\x05c#=12\x04s#=1",
),
]
)
with patch.dict(
zc_gen.ZEROCONF,
{"_fakeservice._tcp.local.": []},
clear=True,
):
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
await hass.async_block_till_done()
client = await hass_ws_client()
await client.send_json(
{
"id": 1,
"type": "zeroconf/subscribe_discovery",
}
)
async with asyncio.timeout(1):
response = await client.receive_json()
assert response["success"]
async with asyncio.timeout(1):
response = await client.receive_json()
assert response["event"] == {
"add": [
{
"ip_addresses": ["127.0.0.1"],
"name": "foo2._fakeservice._tcp.local.",
"port": 1234,
"properties": {},
"type": "_fakeservice._tcp.local.",
}
]
}
# now late inject the address record
records = [
DNSAddress(
"foo3.local.",
const._TYPE_A,
const._CLASS_IN,
const._DNS_HOST_TTL,
socket.inet_aton("127.0.0.1"),
),
]
instance.zeroconf.cache.async_add_records(records)
instance.zeroconf.record_manager.async_updates(
current_time_millis(),
[RecordUpdate(record, None) for record in records],
)
# Now for the add
async with asyncio.timeout(1):
response = await client.receive_json()
assert response["event"] == {
"add": [
{
"ip_addresses": ["127.0.0.1"],
"name": "foo3._fakeservice._tcp.local.",
"port": 1234,
"properties": {},
"type": "_fakeservice._tcp.local.",
}
]
}
# Now for the update
async with asyncio.timeout(1):
response = await client.receive_json()
assert response["event"] == {
"add": [
{
"ip_addresses": ["127.0.0.1"],
"name": "foo3._fakeservice._tcp.local.",
"port": 1234,
"properties": {},
"type": "_fakeservice._tcp.local.",
}
]
}
# now move time forward and remove the record
future = current_time_millis() + (4500 * 1000)
records = instance.zeroconf.cache.async_expire(future)
record_updates = [RecordUpdate(record, record) for record in records]
instance.zeroconf.record_manager.async_updates(future, record_updates)
instance.zeroconf.record_manager.async_updates_complete(True)
removes: set[str] = set()
for _ in range(3):
async with asyncio.timeout(1):
response = await client.receive_json()
assert "remove" in response["event"]
removes.add(next(iter(response["event"]["remove"]))["name"])
assert len(removes) == 3
assert removes == {
"foo2._fakeservice._tcp.local.",
"foo3._fakeservice._tcp.local.",
"wrong._wrongservice._tcp.local.",
}