Implement a polling fallback for USB monitor (#130918)

This commit is contained in:
puddly 2025-01-16 16:14:53 -05:00 committed by GitHub
parent 762bc7b8d1
commit 9331b1572c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 103 additions and 25 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Coroutine, Sequence from collections.abc import Coroutine, Sequence
import dataclasses import dataclasses
from datetime import datetime, timedelta
import fnmatch import fnmatch
from functools import partial from functools import partial
import logging import logging
@ -33,6 +34,7 @@ from homeassistant.helpers.deprecation import (
check_if_deprecated_constant, check_if_deprecated_constant,
dir_with_deprecated_constants, dir_with_deprecated_constants,
) )
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.service_info.usb import UsbServiceInfo as _UsbServiceInfo from homeassistant.helpers.service_info.usb import UsbServiceInfo as _UsbServiceInfo
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import USBMatcher, async_get_usb from homeassistant.loader import USBMatcher, async_get_usb
@ -46,6 +48,7 @@ if TYPE_CHECKING:
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
POLLING_MONITOR_SCAN_PERIOD = timedelta(seconds=5)
REQUEST_SCAN_COOLDOWN = 10 # 10 second cooldown REQUEST_SCAN_COOLDOWN = 10 # 10 second cooldown
__all__ = [ __all__ = [
@ -229,7 +232,9 @@ class USBDiscovery:
async def async_setup(self) -> None: async def async_setup(self) -> None:
"""Set up USB Discovery.""" """Set up USB Discovery."""
await self._async_start_monitor() if await self._async_supports_monitoring():
await self._async_start_monitor()
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, self.async_start) self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, self.async_start)
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self.async_stop) self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self.async_stop)
@ -243,26 +248,54 @@ class USBDiscovery:
if self._request_debouncer: if self._request_debouncer:
self._request_debouncer.async_shutdown() self._request_debouncer.async_shutdown()
async def _async_start_monitor(self) -> None: async def _async_supports_monitoring(self) -> bool:
"""Start monitoring hardware with pyudev."""
if not sys.platform.startswith("linux"):
return
info = await system_info.async_get_system_info(self.hass) info = await system_info.async_get_system_info(self.hass)
if info.get("docker"): return not info.get("docker")
return
async def _async_start_monitor(self) -> None:
"""Start monitoring hardware."""
if not await self._async_start_monitor_udev():
_LOGGER.info(
"Falling back to periodic filesystem polling for development, libudev "
"is not present"
)
self._async_start_monitor_polling()
@hass_callback
def _async_start_monitor_polling(self) -> None:
"""Start monitoring hardware with polling (for development only!)."""
async def _scan(event_time: datetime) -> None:
await self._async_scan_serial()
stop_callback = async_track_time_interval(
self.hass, _scan, POLLING_MONITOR_SCAN_PERIOD
)
@hass_callback
def _stop_polling(event: Event) -> None:
stop_callback()
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _stop_polling)
async def _async_start_monitor_udev(self) -> bool:
"""Start monitoring hardware with pyudev. Returns True if successful."""
if not sys.platform.startswith("linux"):
return False
if not ( if not (
observer := await self.hass.async_add_executor_job( observer := await self.hass.async_add_executor_job(
self._get_monitor_observer self._get_monitor_observer
) )
): ):
return return False
def _stop_observer(event: Event) -> None: def _stop_observer(event: Event) -> None:
observer.stop() observer.stop()
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _stop_observer) self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _stop_observer)
self.observer_active = True self.observer_active = True
return True
def _get_monitor_observer(self) -> MonitorObserver | None: def _get_monitor_observer(self) -> MonitorObserver | None:
"""Get the monitor observer. """Get the monitor observer.

View File

@ -1,7 +1,8 @@
"""Tests for the USB Discovery integration.""" """Tests for the USB Discovery integration."""
import asyncio
from datetime import timedelta
import os import os
import sys
from typing import Any from typing import Any
from unittest.mock import MagicMock, Mock, call, patch, sentinel from unittest.mock import MagicMock, Mock, call, patch, sentinel
@ -59,10 +60,6 @@ def mock_venv():
yield yield
@pytest.mark.skipif(
not sys.platform.startswith("linux"),
reason="Only works on linux",
)
async def test_observer_discovery( async def test_observer_discovery(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, venv hass: HomeAssistant, hass_ws_client: WebSocketGenerator, venv
) -> None: ) -> None:
@ -93,6 +90,7 @@ async def test_observer_discovery(
return mock_observer return mock_observer
with ( with (
patch("sys.platform", "linux"),
patch("pyudev.Context"), patch("pyudev.Context"),
patch("pyudev.MonitorObserver", new=_create_mock_monitor_observer), patch("pyudev.MonitorObserver", new=_create_mock_monitor_observer),
patch("pyudev.Monitor.filter_by"), patch("pyudev.Monitor.filter_by"),
@ -115,10 +113,65 @@ async def test_observer_discovery(
assert mock_observer.mock_calls == [call.start(), call.__bool__(), call.stop()] assert mock_observer.mock_calls == [call.start(), call.__bool__(), call.stop()]
@pytest.mark.skipif( async def test_polling_discovery(
not sys.platform.startswith("linux"), hass: HomeAssistant, hass_ws_client: WebSocketGenerator, venv
reason="Only works on linux", ) -> None:
) """Test that polling can discover a device without raising an exception."""
new_usb = [{"domain": "test1", "vid": "3039"}]
mock_comports_found_device = asyncio.Event()
def get_comports() -> list:
nonlocal mock_comports
# Only "find" a device after a few invocations
if len(mock_comports.mock_calls) < 5:
return []
mock_comports_found_device.set()
return [
MagicMock(
device=slae_sh_device.device,
vid=12345,
pid=12345,
serial_number=slae_sh_device.serial_number,
manufacturer=slae_sh_device.manufacturer,
description=slae_sh_device.description,
)
]
with (
patch("sys.platform", "linux"),
patch(
"homeassistant.components.usb.USBDiscovery._get_monitor_observer",
return_value=None,
),
patch(
"homeassistant.components.usb.POLLING_MONITOR_SCAN_PERIOD",
timedelta(seconds=0.01),
),
patch("homeassistant.components.usb.async_get_usb", return_value=new_usb),
patch(
"homeassistant.components.usb.comports", side_effect=get_comports
) as mock_comports,
patch.object(hass.config_entries.flow, "async_init") as mock_config_flow,
):
assert await async_setup_component(hass, "usb", {"usb": {}})
await hass.async_block_till_done()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
# Wait until a new device is discovered after a few polling attempts
assert len(mock_config_flow.mock_calls) == 0
await mock_comports_found_device.wait()
await hass.async_block_till_done(wait_background_tasks=True)
assert len(mock_config_flow.mock_calls) == 1
assert mock_config_flow.mock_calls[0][1][0] == "test1"
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done()
async def test_removal_by_observer_before_started( async def test_removal_by_observer_before_started(
hass: HomeAssistant, operating_system hass: HomeAssistant, operating_system
) -> None: ) -> None:
@ -671,10 +724,6 @@ async def test_non_matching_discovered_by_scanner_after_started(
assert len(mock_config_flow.mock_calls) == 0 assert len(mock_config_flow.mock_calls) == 0
@pytest.mark.skipif(
not sys.platform.startswith("linux"),
reason="Only works on linux",
)
async def test_observer_on_wsl_fallback_without_throwing_exception( async def test_observer_on_wsl_fallback_without_throwing_exception(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, venv hass: HomeAssistant, hass_ws_client: WebSocketGenerator, venv
) -> None: ) -> None:
@ -713,10 +762,6 @@ async def test_observer_on_wsl_fallback_without_throwing_exception(
assert mock_config_flow.mock_calls[0][1][0] == "test1" assert mock_config_flow.mock_calls[0][1][0] == "test1"
@pytest.mark.skipif(
not sys.platform.startswith("linux"),
reason="Only works on linux",
)
async def test_not_discovered_by_observer_before_started_on_docker( async def test_not_discovered_by_observer_before_started_on_docker(
hass: HomeAssistant, docker hass: HomeAssistant, docker
) -> None: ) -> None: