mirror of
https://github.com/home-assistant/core.git
synced 2025-07-10 06:47:09 +00:00
Add WebSocket API to zeroconf to observe discovery (#143540)
* Add WebSocket API to zeroconf to observe discovery * Add WebSocket API to zeroconf to observe discovery * increase timeout * cover * cover * cover * cover * cover * cover * fix lasting side effects * cleanup merge * format
This commit is contained in:
parent
34d17ca458
commit
4e7d396e5b
@ -36,6 +36,7 @@ from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import async_get_homekit, async_get_zeroconf, bind_hass
|
||||
from homeassistant.setup import async_when_setup_or_start
|
||||
|
||||
from . import websocket_api
|
||||
from .const import DOMAIN, ZEROCONF_TYPE
|
||||
from .discovery import ( # noqa: F401
|
||||
DATA_DISCOVERY,
|
||||
@ -198,6 +199,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
)
|
||||
await discovery.async_setup()
|
||||
hass.data[DATA_DISCOVERY] = discovery
|
||||
websocket_api.async_setup(hass)
|
||||
|
||||
async def _async_zeroconf_hass_start(hass: HomeAssistant, comp: str) -> None:
|
||||
"""Expose Home Assistant on zeroconf when it starts.
|
||||
|
@ -2,9 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import contextlib
|
||||
from fnmatch import translate
|
||||
from functools import lru_cache
|
||||
from functools import lru_cache, partial
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
import logging
|
||||
import re
|
||||
@ -190,6 +191,26 @@ class ZeroconfDiscovery:
|
||||
self.homekit_model_lookups = homekit_model_lookups
|
||||
self.homekit_model_matchers = homekit_model_matchers
|
||||
self.async_service_browser: AsyncServiceBrowser | None = None
|
||||
self._service_update_listeners: set[Callable[[AsyncServiceInfo], None]] = set()
|
||||
self._service_removed_listeners: set[Callable[[str], None]] = set()
|
||||
|
||||
@callback
|
||||
def async_register_service_update_listener(
|
||||
self,
|
||||
listener: Callable[[AsyncServiceInfo], None],
|
||||
) -> Callable[[], None]:
|
||||
"""Register a service update listener."""
|
||||
self._service_update_listeners.add(listener)
|
||||
return partial(self._service_update_listeners.remove, listener)
|
||||
|
||||
@callback
|
||||
def async_register_service_removed_listener(
|
||||
self,
|
||||
listener: Callable[[str], None],
|
||||
) -> Callable[[], None]:
|
||||
"""Register a service removed listener."""
|
||||
self._service_removed_listeners.add(listener)
|
||||
return partial(self._service_removed_listeners.remove, listener)
|
||||
|
||||
async def async_setup(self) -> None:
|
||||
"""Start discovery."""
|
||||
@ -258,6 +279,8 @@ class ZeroconfDiscovery:
|
||||
|
||||
if state_change is ServiceStateChange.Removed:
|
||||
self._async_dismiss_discoveries(name)
|
||||
for listener in self._service_removed_listeners:
|
||||
listener(name)
|
||||
return
|
||||
|
||||
self._async_service_update(zeroconf, service_type, name)
|
||||
@ -304,6 +327,8 @@ class ZeroconfDiscovery:
|
||||
self, async_service_info: AsyncServiceInfo, service_type: str, name: str
|
||||
) -> None:
|
||||
"""Process a zeroconf update."""
|
||||
for listener in self._service_update_listeners:
|
||||
listener(async_service_info)
|
||||
info = info_from_service(async_service_info)
|
||||
if not info:
|
||||
# Prevent the browser thread from collapsing
|
||||
|
163
homeassistant/components/zeroconf/websocket_api.py
Normal file
163
homeassistant/components/zeroconf/websocket_api.py
Normal file
@ -0,0 +1,163 @@
|
||||
"""The zeroconf integration websocket apis."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
import voluptuous as vol
|
||||
from zeroconf import BadTypeInNameException, DNSPointer, Zeroconf, current_time_millis
|
||||
from zeroconf.asyncio import AsyncServiceInfo, IPVersion
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.json import json_bytes
|
||||
|
||||
from .const import DOMAIN, REQUEST_TIMEOUT
|
||||
from .discovery import DATA_DISCOVERY, ZeroconfDiscovery
|
||||
from .models import HaAsyncZeroconf
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
CLASS_IN = 1
|
||||
TYPE_PTR = 12
|
||||
|
||||
|
||||
@callback
|
||||
def async_setup(hass: HomeAssistant) -> None:
|
||||
"""Set up the zeroconf websocket API."""
|
||||
websocket_api.async_register_command(hass, ws_subscribe_discovery)
|
||||
|
||||
|
||||
def serialize_service_info(service_info: AsyncServiceInfo) -> dict[str, Any]:
|
||||
"""Serialize an AsyncServiceInfo object."""
|
||||
return {
|
||||
"name": service_info.name,
|
||||
"type": service_info.type,
|
||||
"port": service_info.port,
|
||||
"properties": service_info.decoded_properties,
|
||||
"ip_addresses": [
|
||||
str(ip) for ip in service_info.ip_addresses_by_version(IPVersion.All)
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class _DiscoverySubscription:
|
||||
"""Class to hold and manage the subscription data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
ws_msg_id: int,
|
||||
aiozc: HaAsyncZeroconf,
|
||||
discovery: ZeroconfDiscovery,
|
||||
) -> None:
|
||||
"""Initialize the subscription data."""
|
||||
self.hass = hass
|
||||
self.discovery = discovery
|
||||
self.aiozc = aiozc
|
||||
self.ws_msg_id = ws_msg_id
|
||||
self.connection = connection
|
||||
|
||||
@callback
|
||||
def _async_unsubscribe(
|
||||
self, cancel_callbacks: tuple[Callable[[], None], ...]
|
||||
) -> None:
|
||||
"""Unsubscribe the callback."""
|
||||
for cancel_callback in cancel_callbacks:
|
||||
cancel_callback()
|
||||
|
||||
async def async_start(self) -> None:
|
||||
"""Start the subscription."""
|
||||
connection = self.connection
|
||||
listeners = (
|
||||
self.discovery.async_register_service_update_listener(
|
||||
self._async_on_update
|
||||
),
|
||||
self.discovery.async_register_service_removed_listener(
|
||||
self._async_on_remove
|
||||
),
|
||||
)
|
||||
connection.subscriptions[self.ws_msg_id] = partial(
|
||||
self._async_unsubscribe, listeners
|
||||
)
|
||||
self.connection.send_message(
|
||||
json_bytes(websocket_api.result_message(self.ws_msg_id))
|
||||
)
|
||||
await self._async_update_from_cache()
|
||||
|
||||
async def _async_update_from_cache(self) -> None:
|
||||
"""Load the records from the cache."""
|
||||
tasks: list[asyncio.Task[None]] = []
|
||||
now = current_time_millis()
|
||||
for record in self._async_get_ptr_records(self.aiozc.zeroconf):
|
||||
try:
|
||||
info = AsyncServiceInfo(record.name, record.alias)
|
||||
except BadTypeInNameException as ex:
|
||||
_LOGGER.debug(
|
||||
"Ignoring record with bad type in name: %s: %s", record.alias, ex
|
||||
)
|
||||
continue
|
||||
if info.load_from_cache(self.aiozc.zeroconf, now):
|
||||
self._async_on_update(info)
|
||||
else:
|
||||
tasks.append(
|
||||
self.hass.async_create_background_task(
|
||||
self._async_handle_service(info),
|
||||
f"zeroconf resolve {record.alias}",
|
||||
),
|
||||
)
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def _async_get_ptr_records(self, zc: Zeroconf) -> list[DNSPointer]:
|
||||
"""Return all PTR records for the HAP type."""
|
||||
return cast(
|
||||
list[DNSPointer],
|
||||
list(
|
||||
chain.from_iterable(
|
||||
zc.cache.async_all_by_details(zc_type, TYPE_PTR, CLASS_IN)
|
||||
for zc_type in self.discovery.zeroconf_types
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
async def _async_handle_service(self, info: AsyncServiceInfo) -> None:
|
||||
"""Add a device that became visible via zeroconf."""
|
||||
await info.async_request(self.aiozc.zeroconf, REQUEST_TIMEOUT)
|
||||
self._async_on_update(info)
|
||||
|
||||
def _async_event_message(self, message: dict[str, Any]) -> None:
|
||||
self.connection.send_message(
|
||||
json_bytes(websocket_api.event_message(self.ws_msg_id, message))
|
||||
)
|
||||
|
||||
def _async_on_update(self, info: AsyncServiceInfo) -> None:
|
||||
if info.type in self.discovery.zeroconf_types:
|
||||
self._async_event_message({"add": [serialize_service_info(info)]})
|
||||
|
||||
def _async_on_remove(self, name: str) -> None:
|
||||
self._async_event_message({"remove": [{"name": name}]})
|
||||
|
||||
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "zeroconf/subscribe_discovery",
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def ws_subscribe_discovery(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle subscribe advertisements websocket command."""
|
||||
discovery = hass.data[DATA_DISCOVERY]
|
||||
aiozc: HaAsyncZeroconf = hass.data[DOMAIN]
|
||||
await _DiscoverySubscription(
|
||||
hass, connection, msg["id"], aiozc, discovery
|
||||
).async_start()
|
@ -3,7 +3,6 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import zeroconf
|
||||
|
||||
from homeassistant.components.zeroconf import async_get_instance
|
||||
from homeassistant.components.zeroconf.usage import install_multiple_zeroconf_catcher
|
||||
@ -15,6 +14,16 @@ from tests.common import extract_stack_to_frame
|
||||
DOMAIN = "zeroconf"
|
||||
|
||||
|
||||
class MockZeroconf:
|
||||
"""Mock Zeroconf class."""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
"""Initialize the mock."""
|
||||
|
||||
def __new__(cls, *args, **kwargs) -> "MockZeroconf":
|
||||
"""Return the shared instance."""
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_async_zeroconf", "mock_zeroconf")
|
||||
async def test_multiple_zeroconf_instances(
|
||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||
@ -24,9 +33,10 @@ async def test_multiple_zeroconf_instances(
|
||||
|
||||
zeroconf_instance = await async_get_instance(hass)
|
||||
|
||||
with patch("zeroconf.Zeroconf", MockZeroconf):
|
||||
install_multiple_zeroconf_catcher(zeroconf_instance)
|
||||
|
||||
new_zeroconf_instance = zeroconf.Zeroconf()
|
||||
new_zeroconf_instance = MockZeroconf()
|
||||
assert new_zeroconf_instance == zeroconf_instance
|
||||
|
||||
assert "Zeroconf" in caplog.text
|
||||
@ -41,6 +51,7 @@ async def test_multiple_zeroconf_instances_gives_shared(
|
||||
|
||||
zeroconf_instance = await async_get_instance(hass)
|
||||
|
||||
with patch("zeroconf.Zeroconf", MockZeroconf):
|
||||
install_multiple_zeroconf_catcher(zeroconf_instance)
|
||||
|
||||
correct_frame = Mock(
|
||||
@ -77,7 +88,7 @@ async def test_multiple_zeroconf_instances_gives_shared(
|
||||
),
|
||||
),
|
||||
):
|
||||
assert zeroconf.Zeroconf() == zeroconf_instance
|
||||
assert MockZeroconf() == zeroconf_instance
|
||||
|
||||
assert "custom_components/burncpu/light.py" in caplog.text
|
||||
assert "23" in caplog.text
|
||||
|
194
tests/components/zeroconf/test_websocket_api.py
Normal file
194
tests/components/zeroconf/test_websocket_api.py
Normal file
@ -0,0 +1,194 @@
|
||||
"""The tests for the zeroconf WebSocket API."""
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
from unittest.mock import patch
|
||||
|
||||
from zeroconf import (
|
||||
DNSAddress,
|
||||
DNSPointer,
|
||||
DNSService,
|
||||
DNSText,
|
||||
RecordUpdate,
|
||||
const,
|
||||
current_time_millis,
|
||||
)
|
||||
|
||||
from homeassistant.components.zeroconf import DOMAIN, async_get_async_instance
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.generated import zeroconf as zc_gen
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
||||
async def test_subscribe_discovery(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test zeroconf subscribe_discovery."""
|
||||
instance = await async_get_async_instance(hass)
|
||||
instance.zeroconf.cache.async_add_records(
|
||||
[
|
||||
DNSPointer(
|
||||
"_fakeservice._tcp.local.",
|
||||
const._TYPE_PTR,
|
||||
const._CLASS_IN,
|
||||
const._DNS_OTHER_TTL,
|
||||
"wrong._wrongservice._tcp.local.",
|
||||
),
|
||||
DNSPointer(
|
||||
"_fakeservice._tcp.local.",
|
||||
const._TYPE_PTR,
|
||||
const._CLASS_IN,
|
||||
const._DNS_OTHER_TTL,
|
||||
"foo2._fakeservice._tcp.local.",
|
||||
),
|
||||
DNSService(
|
||||
"foo2._fakeservice._tcp.local.",
|
||||
const._TYPE_SRV,
|
||||
const._CLASS_IN,
|
||||
const._DNS_OTHER_TTL,
|
||||
0,
|
||||
0,
|
||||
1234,
|
||||
"foo2.local.",
|
||||
),
|
||||
DNSAddress(
|
||||
"foo2.local.",
|
||||
const._TYPE_A,
|
||||
const._CLASS_IN,
|
||||
const._DNS_HOST_TTL,
|
||||
socket.inet_aton("127.0.0.1"),
|
||||
),
|
||||
DNSText(
|
||||
"foo2.local.",
|
||||
const._TYPE_TXT,
|
||||
const._CLASS_IN,
|
||||
const._DNS_HOST_TTL,
|
||||
b"\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5"
|
||||
b"\x05c#=12\x04s#=1",
|
||||
),
|
||||
DNSPointer(
|
||||
"_fakeservice._tcp.local.",
|
||||
const._TYPE_PTR,
|
||||
const._CLASS_IN,
|
||||
const._DNS_OTHER_TTL,
|
||||
"foo3._fakeservice._tcp.local.",
|
||||
),
|
||||
DNSService(
|
||||
"foo3._fakeservice._tcp.local.",
|
||||
const._TYPE_SRV,
|
||||
const._CLASS_IN,
|
||||
const._DNS_OTHER_TTL,
|
||||
0,
|
||||
0,
|
||||
1234,
|
||||
"foo3.local.",
|
||||
),
|
||||
DNSText(
|
||||
"foo3.local.",
|
||||
const._TYPE_TXT,
|
||||
const._CLASS_IN,
|
||||
const._DNS_HOST_TTL,
|
||||
b"\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5"
|
||||
b"\x05c#=12\x04s#=1",
|
||||
),
|
||||
]
|
||||
)
|
||||
with patch.dict(
|
||||
zc_gen.ZEROCONF,
|
||||
{"_fakeservice._tcp.local.": []},
|
||||
clear=True,
|
||||
):
|
||||
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
|
||||
await hass.async_block_till_done()
|
||||
client = await hass_ws_client()
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 1,
|
||||
"type": "zeroconf/subscribe_discovery",
|
||||
}
|
||||
)
|
||||
async with asyncio.timeout(1):
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
response = await client.receive_json()
|
||||
assert response["event"] == {
|
||||
"add": [
|
||||
{
|
||||
"ip_addresses": ["127.0.0.1"],
|
||||
"name": "foo2._fakeservice._tcp.local.",
|
||||
"port": 1234,
|
||||
"properties": {},
|
||||
"type": "_fakeservice._tcp.local.",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# now late inject the address record
|
||||
records = [
|
||||
DNSAddress(
|
||||
"foo3.local.",
|
||||
const._TYPE_A,
|
||||
const._CLASS_IN,
|
||||
const._DNS_HOST_TTL,
|
||||
socket.inet_aton("127.0.0.1"),
|
||||
),
|
||||
]
|
||||
instance.zeroconf.cache.async_add_records(records)
|
||||
instance.zeroconf.record_manager.async_updates(
|
||||
current_time_millis(),
|
||||
[RecordUpdate(record, None) for record in records],
|
||||
)
|
||||
# Now for the add
|
||||
async with asyncio.timeout(1):
|
||||
response = await client.receive_json()
|
||||
assert response["event"] == {
|
||||
"add": [
|
||||
{
|
||||
"ip_addresses": ["127.0.0.1"],
|
||||
"name": "foo3._fakeservice._tcp.local.",
|
||||
"port": 1234,
|
||||
"properties": {},
|
||||
"type": "_fakeservice._tcp.local.",
|
||||
}
|
||||
]
|
||||
}
|
||||
# Now for the update
|
||||
async with asyncio.timeout(1):
|
||||
response = await client.receive_json()
|
||||
assert response["event"] == {
|
||||
"add": [
|
||||
{
|
||||
"ip_addresses": ["127.0.0.1"],
|
||||
"name": "foo3._fakeservice._tcp.local.",
|
||||
"port": 1234,
|
||||
"properties": {},
|
||||
"type": "_fakeservice._tcp.local.",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# now move time forward and remove the record
|
||||
future = current_time_millis() + (4500 * 1000)
|
||||
records = instance.zeroconf.cache.async_expire(future)
|
||||
record_updates = [RecordUpdate(record, record) for record in records]
|
||||
instance.zeroconf.record_manager.async_updates(future, record_updates)
|
||||
instance.zeroconf.record_manager.async_updates_complete(True)
|
||||
|
||||
removes: set[str] = set()
|
||||
for _ in range(3):
|
||||
async with asyncio.timeout(1):
|
||||
response = await client.receive_json()
|
||||
assert "remove" in response["event"]
|
||||
removes.add(next(iter(response["event"]["remove"]))["name"])
|
||||
|
||||
assert len(removes) == 3
|
||||
assert removes == {
|
||||
"foo2._fakeservice._tcp.local.",
|
||||
"foo3._fakeservice._tcp.local.",
|
||||
"wrong._wrongservice._tcp.local.",
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user