diff --git a/homeassistant/components/usb/__init__.py b/homeassistant/components/usb/__init__.py index b688d821db4..ec65143b984 100644 --- a/homeassistant/components/usb/__init__.py +++ b/homeassistant/components/usb/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import Coroutine, Sequence +from collections.abc import Callable, Coroutine, Sequence import dataclasses from datetime import datetime, timedelta import fnmatch @@ -48,12 +48,15 @@ if TYPE_CHECKING: _LOGGER = logging.getLogger(__name__) +PORT_EVENT_CALLBACK_TYPE = Callable[[set[USBDevice], set[USBDevice]], None] + POLLING_MONITOR_SCAN_PERIOD = timedelta(seconds=5) REQUEST_SCAN_COOLDOWN = 10 # 10 second cooldown __all__ = [ "USBCallbackMatcher", "async_is_plugged_in", + "async_register_port_event_callback", "async_register_scan_request_callback", ] @@ -85,6 +88,15 @@ def async_register_initial_scan_callback( return discovery.async_register_initial_scan_callback(callback) +@hass_callback +def async_register_port_event_callback( + hass: HomeAssistant, callback: PORT_EVENT_CALLBACK_TYPE +) -> CALLBACK_TYPE: + """Register to receive a callback when a USB device is connected or disconnected.""" + discovery: USBDiscovery = hass.data[DOMAIN] + return discovery.async_register_port_event_callback(callback) + + @hass_callback def async_is_plugged_in(hass: HomeAssistant, matcher: USBCallbackMatcher) -> bool: """Return True is a USB device is present.""" @@ -108,8 +120,25 @@ def async_is_plugged_in(hass: HomeAssistant, matcher: USBCallbackMatcher) -> boo usb_discovery: USBDiscovery = hass.data[DOMAIN] return any( - _is_matching(USBDevice(*device_tuple), matcher) - for device_tuple in usb_discovery.seen + _is_matching( + USBDevice( + device=device, + vid=vid, + pid=pid, + serial_number=serial_number, + manufacturer=manufacturer, + description=description, + ), + matcher, + ) + for ( + device, + vid, + pid, + serial_number, + manufacturer, + description, + ) in usb_discovery.seen ) @@ -229,6 +258,8 @@ class USBDiscovery: self._request_callbacks: list[CALLBACK_TYPE] = [] self.initial_scan_done = False self._initial_scan_callbacks: list[CALLBACK_TYPE] = [] + self._port_event_callbacks: set[PORT_EVENT_CALLBACK_TYPE] = set() + self._last_processed_devices: set[USBDevice] = set() async def async_setup(self) -> None: """Set up USB Discovery.""" @@ -324,20 +355,23 @@ class USBDiscovery: return None observer = MonitorObserver( - monitor, callback=self._device_discovered, name="usb-observer" + monitor, callback=self._device_event, name="usb-observer" ) observer.start() return observer - def _device_discovered(self, device: Device) -> None: - """Call when the observer discovers a new usb tty device.""" - if device.action != "add": + def _device_event(self, device: Device) -> None: + """Call when the observer receives a USB device event.""" + if device.action not in ("add", "remove"): return - _LOGGER.debug( - "Discovered Device at path: %s, triggering scan serial", - device.device_path, + + _LOGGER.info( + "Received a udev device event %r for %s, triggering scan", + device.action, + device.device_node, ) + self.hass.create_task(self._async_scan()) @hass_callback @@ -374,6 +408,20 @@ class USBDiscovery: return _async_remove_callback + @hass_callback + def async_register_port_event_callback( + self, + callback: PORT_EVENT_CALLBACK_TYPE, + ) -> CALLBACK_TYPE: + """Register a port event callback.""" + self._port_event_callbacks.add(callback) + + @hass_callback + def _async_remove_callback() -> None: + self._port_event_callbacks.discard(callback) + + return _async_remove_callback + async def _async_process_discovered_usb_device(self, device: USBDevice) -> None: """Process a USB discovery.""" _LOGGER.debug("Discovered USB Device: %s", device) @@ -418,11 +466,11 @@ class USBDiscovery: async def _async_process_ports(self, ports: Sequence[ListPortInfo]) -> None: """Process each discovered port.""" - usb_devices = [ + usb_devices = { usb_device_from_port(port) for port in ports if port.vid is not None or port.pid is not None - ] + } # CP2102N chips create *two* serial ports on macOS: `/dev/cu.usbserial-` and # `/dev/cu.SLAB_USBtoUART*`. The former does not work and we should ignore them. @@ -433,7 +481,7 @@ class USBDiscovery: if dev.device.startswith("/dev/cu.SLAB_USBtoUART") } - usb_devices = [ + usb_devices = { dev for dev in usb_devices if dev.serial_number not in silabs_serials @@ -441,7 +489,22 @@ class USBDiscovery: dev.serial_number in silabs_serials and dev.device.startswith("/dev/cu.SLAB_USBtoUART") ) - ] + } + + added_devices = usb_devices - self._last_processed_devices + removed_devices = self._last_processed_devices - usb_devices + self._last_processed_devices = usb_devices + + _LOGGER.debug( + "Added devices: %r, removed devices: %r", added_devices, removed_devices + ) + + if added_devices or removed_devices: + for callback in self._port_event_callbacks.copy(): + try: + callback(added_devices, removed_devices) + except Exception: + _LOGGER.exception("Error in USB port event callback") for usb_device in usb_devices: await self._async_process_discovered_usb_device(usb_device) diff --git a/homeassistant/components/usb/models.py b/homeassistant/components/usb/models.py index efc5b11c26e..11eccd9cd9b 100644 --- a/homeassistant/components/usb/models.py +++ b/homeassistant/components/usb/models.py @@ -5,7 +5,7 @@ from __future__ import annotations from dataclasses import dataclass -@dataclass +@dataclass(slots=True, frozen=True, kw_only=True) class USBDevice: """A usb device.""" diff --git a/tests/components/usb/test_init.py b/tests/components/usb/test_init.py index f4002c81e40..8f8ed672374 100644 --- a/tests/components/usb/test_init.py +++ b/tests/components/usb/test_init.py @@ -2,6 +2,7 @@ import asyncio from datetime import timedelta +import logging import os from typing import Any from unittest.mock import MagicMock, Mock, call, patch, sentinel @@ -9,6 +10,7 @@ from unittest.mock import MagicMock, Mock, call, patch, sentinel import pytest from homeassistant.components import usb +from homeassistant.components.usb.utils import usb_device_from_port from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP from homeassistant.core import HomeAssistant from homeassistant.helpers.service_info.usb import UsbServiceInfo @@ -80,7 +82,7 @@ async def test_observer_discovery( async def _mock_monitor_observer_callback(callback): await hass.async_add_executor_job( - callback, MagicMock(action="create", device_path="/dev/new") + callback, MagicMock(action="add", device_path="/dev/new") ) def _create_mock_monitor_observer(monitor, callback, name): @@ -1235,3 +1237,165 @@ def test_deprecated_constants( replacement, "2026.2", ) + + +@patch("homeassistant.components.usb.REQUEST_SCAN_COOLDOWN", 0) +async def test_register_port_event_callback( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator +) -> None: + """Test the registration of a port event callback.""" + + port1 = Mock( + device=slae_sh_device.device, + vid=12345, + pid=12345, + serial_number=slae_sh_device.serial_number, + manufacturer=slae_sh_device.manufacturer, + description=slae_sh_device.description, + ) + + port2 = Mock( + device=conbee_device.device, + vid=12346, + pid=12346, + serial_number=conbee_device.serial_number, + manufacturer=conbee_device.manufacturer, + description=conbee_device.description, + ) + + port1_usb = usb_device_from_port(port1) + port2_usb = usb_device_from_port(port2) + + ws_client = await hass_ws_client(hass) + + mock_callback1 = Mock() + mock_callback2 = Mock() + + # Start off with no ports + with ( + patch("pyudev.Context", side_effect=ImportError), + patch("homeassistant.components.usb.comports", return_value=[]), + ): + assert await async_setup_component(hass, "usb", {"usb": {}}) + + _cancel1 = usb.async_register_port_event_callback(hass, mock_callback1) + cancel2 = usb.async_register_port_event_callback(hass, mock_callback2) + + assert mock_callback1.mock_calls == [] + assert mock_callback2.mock_calls == [] + + # Add two new ports + with patch("homeassistant.components.usb.comports", return_value=[port1, port2]): + await ws_client.send_json({"id": 1, "type": "usb/scan"}) + response = await ws_client.receive_json() + assert response["success"] + + assert mock_callback1.mock_calls == [call({port1_usb, port2_usb}, set())] + assert mock_callback2.mock_calls == [call({port1_usb, port2_usb}, set())] + + # Cancel the second callback + cancel2() + cancel2() + + mock_callback1.reset_mock() + mock_callback2.reset_mock() + + # Remove port 2 + with patch("homeassistant.components.usb.comports", return_value=[port1]): + await ws_client.send_json({"id": 2, "type": "usb/scan"}) + response = await ws_client.receive_json() + assert response["success"] + await hass.async_block_till_done() + + assert mock_callback1.mock_calls == [call(set(), {port2_usb})] + assert mock_callback2.mock_calls == [] # The second callback was unregistered + + mock_callback1.reset_mock() + mock_callback2.reset_mock() + + # Keep port 2 removed + with patch("homeassistant.components.usb.comports", return_value=[port1]): + await ws_client.send_json({"id": 3, "type": "usb/scan"}) + response = await ws_client.receive_json() + assert response["success"] + await hass.async_block_till_done() + + # Nothing changed so no callback is called + assert mock_callback1.mock_calls == [] + assert mock_callback2.mock_calls == [] + + # Unplug one and plug in the other + with patch("homeassistant.components.usb.comports", return_value=[port2]): + await ws_client.send_json({"id": 4, "type": "usb/scan"}) + response = await ws_client.receive_json() + assert response["success"] + await hass.async_block_till_done() + + assert mock_callback1.mock_calls == [call({port2_usb}, {port1_usb})] + assert mock_callback2.mock_calls == [] + + +@patch("homeassistant.components.usb.REQUEST_SCAN_COOLDOWN", 0) +async def test_register_port_event_callback_failure( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test port event callback failure handling.""" + + port1 = Mock( + device=slae_sh_device.device, + vid=12345, + pid=12345, + serial_number=slae_sh_device.serial_number, + manufacturer=slae_sh_device.manufacturer, + description=slae_sh_device.description, + ) + + port2 = Mock( + device=conbee_device.device, + vid=12346, + pid=12346, + serial_number=conbee_device.serial_number, + manufacturer=conbee_device.manufacturer, + description=conbee_device.description, + ) + + port1_usb = usb_device_from_port(port1) + port2_usb = usb_device_from_port(port2) + + ws_client = await hass_ws_client(hass) + + mock_callback1 = Mock(side_effect=RuntimeError("Failure 1")) + mock_callback2 = Mock(side_effect=RuntimeError("Failure 2")) + + # Start off with no ports + with ( + patch("pyudev.Context", side_effect=ImportError), + patch("homeassistant.components.usb.comports", return_value=[]), + ): + assert await async_setup_component(hass, "usb", {"usb": {}}) + + usb.async_register_port_event_callback(hass, mock_callback1) + usb.async_register_port_event_callback(hass, mock_callback2) + + assert mock_callback1.mock_calls == [] + assert mock_callback2.mock_calls == [] + + # Add two new ports + with ( + patch("homeassistant.components.usb.comports", return_value=[port1, port2]), + caplog.at_level(logging.ERROR, logger="homeassistant.components.usb"), + ): + await ws_client.send_json({"id": 1, "type": "usb/scan"}) + response = await ws_client.receive_json() + assert response["success"] + await hass.async_block_till_done() + + # Both were called even though they raised exceptions + assert mock_callback1.mock_calls == [call({port1_usb, port2_usb}, set())] + assert mock_callback2.mock_calls == [call({port1_usb, port2_usb}, set())] + + assert caplog.text.count("Error in USB port event callback") == 2 + assert "Failure 1" in caplog.text + assert "Failure 2" in caplog.text