USB device add/remove callbacks (#131224)

This commit is contained in:
puddly 2025-01-16 16:53:15 -05:00 committed by GitHub
parent eb651a8a71
commit 9b66ba61a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 243 additions and 16 deletions

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from collections.abc import Coroutine, Sequence
from collections.abc import Callable, Coroutine, Sequence
import dataclasses
from datetime import datetime, timedelta
import fnmatch
@ -48,12 +48,15 @@ if TYPE_CHECKING:
_LOGGER = logging.getLogger(__name__)
PORT_EVENT_CALLBACK_TYPE = Callable[[set[USBDevice], set[USBDevice]], None]
POLLING_MONITOR_SCAN_PERIOD = timedelta(seconds=5)
REQUEST_SCAN_COOLDOWN = 10 # 10 second cooldown
__all__ = [
"USBCallbackMatcher",
"async_is_plugged_in",
"async_register_port_event_callback",
"async_register_scan_request_callback",
]
@ -85,6 +88,15 @@ def async_register_initial_scan_callback(
return discovery.async_register_initial_scan_callback(callback)
@hass_callback
def async_register_port_event_callback(
hass: HomeAssistant, callback: PORT_EVENT_CALLBACK_TYPE
) -> CALLBACK_TYPE:
"""Register to receive a callback when a USB device is connected or disconnected."""
discovery: USBDiscovery = hass.data[DOMAIN]
return discovery.async_register_port_event_callback(callback)
@hass_callback
def async_is_plugged_in(hass: HomeAssistant, matcher: USBCallbackMatcher) -> bool:
"""Return True is a USB device is present."""
@ -108,8 +120,25 @@ def async_is_plugged_in(hass: HomeAssistant, matcher: USBCallbackMatcher) -> boo
usb_discovery: USBDiscovery = hass.data[DOMAIN]
return any(
_is_matching(USBDevice(*device_tuple), matcher)
for device_tuple in usb_discovery.seen
_is_matching(
USBDevice(
device=device,
vid=vid,
pid=pid,
serial_number=serial_number,
manufacturer=manufacturer,
description=description,
),
matcher,
)
for (
device,
vid,
pid,
serial_number,
manufacturer,
description,
) in usb_discovery.seen
)
@ -229,6 +258,8 @@ class USBDiscovery:
self._request_callbacks: list[CALLBACK_TYPE] = []
self.initial_scan_done = False
self._initial_scan_callbacks: list[CALLBACK_TYPE] = []
self._port_event_callbacks: set[PORT_EVENT_CALLBACK_TYPE] = set()
self._last_processed_devices: set[USBDevice] = set()
async def async_setup(self) -> None:
"""Set up USB Discovery."""
@ -324,20 +355,23 @@ class USBDiscovery:
return None
observer = MonitorObserver(
monitor, callback=self._device_discovered, name="usb-observer"
monitor, callback=self._device_event, name="usb-observer"
)
observer.start()
return observer
def _device_discovered(self, device: Device) -> None:
"""Call when the observer discovers a new usb tty device."""
if device.action != "add":
def _device_event(self, device: Device) -> None:
"""Call when the observer receives a USB device event."""
if device.action not in ("add", "remove"):
return
_LOGGER.debug(
"Discovered Device at path: %s, triggering scan serial",
device.device_path,
_LOGGER.info(
"Received a udev device event %r for %s, triggering scan",
device.action,
device.device_node,
)
self.hass.create_task(self._async_scan())
@hass_callback
@ -374,6 +408,20 @@ class USBDiscovery:
return _async_remove_callback
@hass_callback
def async_register_port_event_callback(
self,
callback: PORT_EVENT_CALLBACK_TYPE,
) -> CALLBACK_TYPE:
"""Register a port event callback."""
self._port_event_callbacks.add(callback)
@hass_callback
def _async_remove_callback() -> None:
self._port_event_callbacks.discard(callback)
return _async_remove_callback
async def _async_process_discovered_usb_device(self, device: USBDevice) -> None:
"""Process a USB discovery."""
_LOGGER.debug("Discovered USB Device: %s", device)
@ -418,11 +466,11 @@ class USBDiscovery:
async def _async_process_ports(self, ports: Sequence[ListPortInfo]) -> None:
"""Process each discovered port."""
usb_devices = [
usb_devices = {
usb_device_from_port(port)
for port in ports
if port.vid is not None or port.pid is not None
]
}
# CP2102N chips create *two* serial ports on macOS: `/dev/cu.usbserial-` and
# `/dev/cu.SLAB_USBtoUART*`. The former does not work and we should ignore them.
@ -433,7 +481,7 @@ class USBDiscovery:
if dev.device.startswith("/dev/cu.SLAB_USBtoUART")
}
usb_devices = [
usb_devices = {
dev
for dev in usb_devices
if dev.serial_number not in silabs_serials
@ -441,7 +489,22 @@ class USBDiscovery:
dev.serial_number in silabs_serials
and dev.device.startswith("/dev/cu.SLAB_USBtoUART")
)
]
}
added_devices = usb_devices - self._last_processed_devices
removed_devices = self._last_processed_devices - usb_devices
self._last_processed_devices = usb_devices
_LOGGER.debug(
"Added devices: %r, removed devices: %r", added_devices, removed_devices
)
if added_devices or removed_devices:
for callback in self._port_event_callbacks.copy():
try:
callback(added_devices, removed_devices)
except Exception:
_LOGGER.exception("Error in USB port event callback")
for usb_device in usb_devices:
await self._async_process_discovered_usb_device(usb_device)

View File

@ -5,7 +5,7 @@ from __future__ import annotations
from dataclasses import dataclass
@dataclass
@dataclass(slots=True, frozen=True, kw_only=True)
class USBDevice:
"""A usb device."""

View File

@ -2,6 +2,7 @@
import asyncio
from datetime import timedelta
import logging
import os
from typing import Any
from unittest.mock import MagicMock, Mock, call, patch, sentinel
@ -9,6 +10,7 @@ from unittest.mock import MagicMock, Mock, call, patch, sentinel
import pytest
from homeassistant.components import usb
from homeassistant.components.usb.utils import usb_device_from_port
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import HomeAssistant
from homeassistant.helpers.service_info.usb import UsbServiceInfo
@ -80,7 +82,7 @@ async def test_observer_discovery(
async def _mock_monitor_observer_callback(callback):
await hass.async_add_executor_job(
callback, MagicMock(action="create", device_path="/dev/new")
callback, MagicMock(action="add", device_path="/dev/new")
)
def _create_mock_monitor_observer(monitor, callback, name):
@ -1235,3 +1237,165 @@ def test_deprecated_constants(
replacement,
"2026.2",
)
@patch("homeassistant.components.usb.REQUEST_SCAN_COOLDOWN", 0)
async def test_register_port_event_callback(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
) -> None:
"""Test the registration of a port event callback."""
port1 = Mock(
device=slae_sh_device.device,
vid=12345,
pid=12345,
serial_number=slae_sh_device.serial_number,
manufacturer=slae_sh_device.manufacturer,
description=slae_sh_device.description,
)
port2 = Mock(
device=conbee_device.device,
vid=12346,
pid=12346,
serial_number=conbee_device.serial_number,
manufacturer=conbee_device.manufacturer,
description=conbee_device.description,
)
port1_usb = usb_device_from_port(port1)
port2_usb = usb_device_from_port(port2)
ws_client = await hass_ws_client(hass)
mock_callback1 = Mock()
mock_callback2 = Mock()
# Start off with no ports
with (
patch("pyudev.Context", side_effect=ImportError),
patch("homeassistant.components.usb.comports", return_value=[]),
):
assert await async_setup_component(hass, "usb", {"usb": {}})
_cancel1 = usb.async_register_port_event_callback(hass, mock_callback1)
cancel2 = usb.async_register_port_event_callback(hass, mock_callback2)
assert mock_callback1.mock_calls == []
assert mock_callback2.mock_calls == []
# Add two new ports
with patch("homeassistant.components.usb.comports", return_value=[port1, port2]):
await ws_client.send_json({"id": 1, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
assert mock_callback1.mock_calls == [call({port1_usb, port2_usb}, set())]
assert mock_callback2.mock_calls == [call({port1_usb, port2_usb}, set())]
# Cancel the second callback
cancel2()
cancel2()
mock_callback1.reset_mock()
mock_callback2.reset_mock()
# Remove port 2
with patch("homeassistant.components.usb.comports", return_value=[port1]):
await ws_client.send_json({"id": 2, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done()
assert mock_callback1.mock_calls == [call(set(), {port2_usb})]
assert mock_callback2.mock_calls == [] # The second callback was unregistered
mock_callback1.reset_mock()
mock_callback2.reset_mock()
# Keep port 2 removed
with patch("homeassistant.components.usb.comports", return_value=[port1]):
await ws_client.send_json({"id": 3, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done()
# Nothing changed so no callback is called
assert mock_callback1.mock_calls == []
assert mock_callback2.mock_calls == []
# Unplug one and plug in the other
with patch("homeassistant.components.usb.comports", return_value=[port2]):
await ws_client.send_json({"id": 4, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done()
assert mock_callback1.mock_calls == [call({port2_usb}, {port1_usb})]
assert mock_callback2.mock_calls == []
@patch("homeassistant.components.usb.REQUEST_SCAN_COOLDOWN", 0)
async def test_register_port_event_callback_failure(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test port event callback failure handling."""
port1 = Mock(
device=slae_sh_device.device,
vid=12345,
pid=12345,
serial_number=slae_sh_device.serial_number,
manufacturer=slae_sh_device.manufacturer,
description=slae_sh_device.description,
)
port2 = Mock(
device=conbee_device.device,
vid=12346,
pid=12346,
serial_number=conbee_device.serial_number,
manufacturer=conbee_device.manufacturer,
description=conbee_device.description,
)
port1_usb = usb_device_from_port(port1)
port2_usb = usb_device_from_port(port2)
ws_client = await hass_ws_client(hass)
mock_callback1 = Mock(side_effect=RuntimeError("Failure 1"))
mock_callback2 = Mock(side_effect=RuntimeError("Failure 2"))
# Start off with no ports
with (
patch("pyudev.Context", side_effect=ImportError),
patch("homeassistant.components.usb.comports", return_value=[]),
):
assert await async_setup_component(hass, "usb", {"usb": {}})
usb.async_register_port_event_callback(hass, mock_callback1)
usb.async_register_port_event_callback(hass, mock_callback2)
assert mock_callback1.mock_calls == []
assert mock_callback2.mock_calls == []
# Add two new ports
with (
patch("homeassistant.components.usb.comports", return_value=[port1, port2]),
caplog.at_level(logging.ERROR, logger="homeassistant.components.usb"),
):
await ws_client.send_json({"id": 1, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done()
# Both were called even though they raised exceptions
assert mock_callback1.mock_calls == [call({port1_usb, port2_usb}, set())]
assert mock_callback2.mock_calls == [call({port1_usb, port2_usb}, set())]
assert caplog.text.count("Error in USB port event callback") == 2
assert "Failure 1" in caplog.text
assert "Failure 2" in caplog.text