From c2fefe03b2dc800f42de695f0b73a8f26621d882 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 11 Jul 2022 17:14:00 +0200 Subject: [PATCH] Add support for subscribing to bluetooth callbacks by address (#74773) --- .../components/bluetooth/__init__.py | 119 +++++++- homeassistant/components/bluetooth/models.py | 50 +++- tests/components/bluetooth/test_init.py | 268 +++++++++++++++++- 3 files changed, 415 insertions(+), 22 deletions(-) diff --git a/homeassistant/components/bluetooth/__init__.py b/homeassistant/components/bluetooth/__init__.py index cf7f1884869..26c1af3fb92 100644 --- a/homeassistant/components/bluetooth/__init__.py +++ b/homeassistant/components/bluetooth/__init__.py @@ -8,7 +8,7 @@ import fnmatch from functools import cached_property import logging import platform -from typing import Final +from typing import Final, TypedDict from bleak import BleakError from bleak.backends.device import MANUFACTURERS, BLEDevice @@ -26,7 +26,11 @@ from homeassistant.core import ( from homeassistant.data_entry_flow import BaseServiceInfo from homeassistant.helpers import discovery_flow from homeassistant.helpers.typing import ConfigType -from homeassistant.loader import BluetoothMatcher, async_get_bluetooth +from homeassistant.loader import ( + BluetoothMatcher, + BluetoothMatcherOptional, + async_get_bluetooth, +) from . import models from .const import DOMAIN @@ -38,6 +42,19 @@ _LOGGER = logging.getLogger(__name__) MAX_REMEMBER_ADDRESSES: Final = 2048 +class BluetoothCallbackMatcherOptional(TypedDict, total=False): + """Matcher for the bluetooth integration for callback optional fields.""" + + address: str + + +class BluetoothCallbackMatcher( + BluetoothMatcherOptional, + BluetoothCallbackMatcherOptional, +): + """Callback matcher for the bluetooth integration.""" + + class BluetoothScanningMode(Enum): """The mode of scanning for bluetooth devices.""" @@ -50,6 +67,7 @@ SCANNING_MODE_TO_BLEAK = { BluetoothScanningMode.PASSIVE: "passive", } +ADDRESS: Final = "address" LOCAL_NAME: Final = "local_name" SERVICE_UUID: Final = "service_uuid" MANUFACTURER_ID: Final = "manufacturer_id" @@ -102,11 +120,34 @@ BluetoothChange = Enum("BluetoothChange", "ADVERTISEMENT") BluetoothCallback = Callable[[BluetoothServiceInfo, BluetoothChange], None] +@hass_callback +def async_discovered_service_info( + hass: HomeAssistant, +) -> list[BluetoothServiceInfo]: + """Return the discovered devices list.""" + if DOMAIN not in hass.data: + return [] + manager: BluetoothManager = hass.data[DOMAIN] + return manager.async_discovered_service_info() + + +@hass_callback +def async_address_present( + hass: HomeAssistant, + address: str, +) -> bool: + """Check if an address is present in the bluetooth device list.""" + if DOMAIN not in hass.data: + return False + manager: BluetoothManager = hass.data[DOMAIN] + return manager.async_address_present(address) + + @hass_callback def async_register_callback( hass: HomeAssistant, callback: BluetoothCallback, - match_dict: BluetoothMatcher | None, + match_dict: BluetoothCallbackMatcher | None, ) -> Callable[[], None]: """Register to receive a callback on bluetooth change. @@ -128,9 +169,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: def _ble_device_matches( - matcher: BluetoothMatcher, device: BLEDevice, advertisement_data: AdvertisementData + matcher: BluetoothCallbackMatcher | BluetoothMatcher, + device: BLEDevice, + advertisement_data: AdvertisementData, ) -> bool: """Check if a ble device and advertisement_data matches the matcher.""" + if ( + matcher_address := matcher.get(ADDRESS) + ) is not None and device.address != matcher_address: + return False + if ( matcher_local_name := matcher.get(LOCAL_NAME) ) is not None and not fnmatch.fnmatch( @@ -192,7 +240,9 @@ class BluetoothManager: self._integration_matchers = integration_matchers self.scanner: HaBleakScanner | None = None self._cancel_device_detected: CALLBACK_TYPE | None = None - self._callbacks: list[tuple[BluetoothCallback, BluetoothMatcher | None]] = [] + self._callbacks: list[ + tuple[BluetoothCallback, BluetoothCallbackMatcher | None] + ] = [] # Some devices use a random address so we need to use # an LRU to avoid memory issues. self._matched: LRU = LRU(MAX_REMEMBER_ADDRESSES) @@ -227,14 +277,22 @@ class BluetoothManager: ) -> None: """Handle a detected device.""" matched_domains: set[str] | None = None - if device.address not in self._matched: + match_key = (device.address, bool(advertisement_data.manufacturer_data)) + match_key_has_mfr_data = (device.address, True) + + # If we matched without manufacturer_data, we need to do it again + # since we may think the device is unsupported otherwise + if ( + match_key_has_mfr_data not in self._matched + and match_key not in self._matched + ): matched_domains = { matcher["domain"] for matcher in self._integration_matchers if _ble_device_matches(matcher, device, advertisement_data) } if matched_domains: - self._matched[device.address] = True + self._matched[match_key] = True _LOGGER.debug( "Device detected: %s with advertisement_data: %s matched domains: %s", device, @@ -275,18 +333,61 @@ class BluetoothManager: @hass_callback def async_register_callback( - self, callback: BluetoothCallback, match_dict: BluetoothMatcher | None = None + self, + callback: BluetoothCallback, + matcher: BluetoothCallbackMatcher | None = None, ) -> Callable[[], None]: """Register a callback.""" - callback_entry = (callback, match_dict) + callback_entry = (callback, matcher) self._callbacks.append(callback_entry) @hass_callback def _async_remove_callback() -> None: self._callbacks.remove(callback_entry) + # If we have history for the subscriber, we can trigger the callback + # immediately with the last packet so the subscriber can see the + # device. + if ( + matcher + and (address := matcher.get(ADDRESS)) + and models.HA_BLEAK_SCANNER + and (device_adv_data := models.HA_BLEAK_SCANNER.history.get(address)) + ): + try: + callback( + BluetoothServiceInfo.from_advertisement(*device_adv_data), + BluetoothChange.ADVERTISEMENT, + ) + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Error in bluetooth callback") + return _async_remove_callback + @hass_callback + def async_address_present(self, address: str) -> bool: + """Return if the address is present.""" + return bool( + models.HA_BLEAK_SCANNER + and any( + device.address == address + for device in models.HA_BLEAK_SCANNER.discovered_devices + ) + ) + + @hass_callback + def async_discovered_service_info(self) -> list[BluetoothServiceInfo]: + """Return if the address is present.""" + if models.HA_BLEAK_SCANNER: + discovered = models.HA_BLEAK_SCANNER.discovered_devices + history = models.HA_BLEAK_SCANNER.history + return [ + BluetoothServiceInfo.from_advertisement(*history[device.address]) + for device in discovered + if device.address in history + ] + return [] + async def async_stop(self, event: Event) -> None: """Stop bluetooth discovery.""" if self._cancel_device_detected: diff --git a/homeassistant/components/bluetooth/models.py b/homeassistant/components/bluetooth/models.py index a2651c587f7..43d4d0cb923 100644 --- a/homeassistant/components/bluetooth/models.py +++ b/homeassistant/components/bluetooth/models.py @@ -8,7 +8,11 @@ from typing import Any, Final, cast from bleak import BleakScanner from bleak.backends.device import BLEDevice -from bleak.backends.scanner import AdvertisementData, AdvertisementDataCallback +from bleak.backends.scanner import ( + AdvertisementData, + AdvertisementDataCallback, + BaseBleakScanner, +) from lru import LRU # pylint: disable=no-name-in-module from homeassistant.core import CALLBACK_TYPE, callback as hass_callback @@ -52,7 +56,7 @@ class HaBleakScanner(BleakScanner): # type: ignore[misc] self._callbacks: list[ tuple[AdvertisementDataCallback, dict[str, set[str]]] ] = [] - self._history: LRU = LRU(MAX_HISTORY_SIZE) + self.history: LRU = LRU(MAX_HISTORY_SIZE) super().__init__(*args, **kwargs) @hass_callback @@ -70,7 +74,7 @@ class HaBleakScanner(BleakScanner): # type: ignore[misc] # Replay the history since otherwise we miss devices # that were already discovered before the callback was registered # or we are in passive mode - for device, advertisement_data in self._history.values(): + for device, advertisement_data in self.history.values(): _dispatch_callback(callback, filters, device, advertisement_data) return _remove_callback @@ -83,31 +87,46 @@ class HaBleakScanner(BleakScanner): # type: ignore[misc] Here we get the actual callback from bleak and dispatch it to all the wrapped HaBleakScannerWrapper classes """ - self._history[device.address] = (device, advertisement_data) + self.history[device.address] = (device, advertisement_data) for callback_filters in self._callbacks: _dispatch_callback(*callback_filters, device, advertisement_data) -class HaBleakScannerWrapper(BleakScanner): # type: ignore[misc] +class HaBleakScannerWrapper(BaseBleakScanner): # type: ignore[misc] """A wrapper that uses the single instance.""" def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize the BleakScanner.""" self._detection_cancel: CALLBACK_TYPE | None = None self._mapped_filters: dict[str, set[str]] = {} - if "filters" in kwargs: - self._mapped_filters = {k: set(v) for k, v in kwargs["filters"].items()} - if "service_uuids" in kwargs: - self._mapped_filters[FILTER_UUIDS] = set(kwargs["service_uuids"]) + self._adv_data_callback: AdvertisementDataCallback | None = None + self._map_filters(*args, **kwargs) super().__init__(*args, **kwargs) async def stop(self, *args: Any, **kwargs: Any) -> None: """Stop scanning for devices.""" - return async def start(self, *args: Any, **kwargs: Any) -> None: """Start scanning for devices.""" - return + + def _map_filters(self, *args: Any, **kwargs: Any) -> bool: + """Map the filters.""" + mapped_filters = {} + if filters := kwargs.get("filters"): + if FILTER_UUIDS not in filters: + _LOGGER.warning("Only %s filters are supported", FILTER_UUIDS) + mapped_filters = {k: set(v) for k, v in filters.items()} + if service_uuids := kwargs.get("service_uuids"): + mapped_filters[FILTER_UUIDS] = set(service_uuids) + if mapped_filters == self._mapped_filters: + return False + self._mapped_filters = mapped_filters + return True + + def set_scanning_filter(self, *args: Any, **kwargs: Any) -> None: + """Set the filters to use.""" + if self._map_filters(*args, **kwargs): + self._setup_detection_callback() def _cancel_callback(self) -> None: """Cancel callback.""" @@ -127,8 +146,15 @@ class HaBleakScannerWrapper(BleakScanner): # type: ignore[misc] This method takes the callback and registers it with the long running scanner. """ + self._adv_data_callback = callback + self._setup_detection_callback() + + def _setup_detection_callback(self) -> None: + """Set up the detection callback.""" + if self._adv_data_callback is None: + return self._cancel_callback() - super().register_detection_callback(callback) + super().register_detection_callback(self._adv_data_callback) assert HA_BLEAK_SCANNER is not None self._detection_cancel = HA_BLEAK_SCANNER.async_register_callback( self._callback, self._mapped_filters diff --git a/tests/components/bluetooth/test_init.py b/tests/components/bluetooth/test_init.py index 1e0647df01c..f43ef4737f6 100644 --- a/tests/components/bluetooth/test_init.py +++ b/tests/components/bluetooth/test_init.py @@ -78,6 +78,30 @@ async def test_setup_and_stop_no_bluetooth(hass, caplog): assert "Could not create bluetooth scanner" in caplog.text +async def test_calling_async_discovered_devices_no_bluetooth(hass, caplog): + """Test we fail gracefully when asking for discovered devices and there is no blueooth.""" + mock_bt = [] + with patch( + "homeassistant.components.bluetooth.HaBleakScanner", side_effect=BleakError + ) as mock_ha_bleak_scanner, patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt + ), patch.object( + hass.config_entries.flow, "async_init" + ): + assert await async_setup_component( + hass, bluetooth.DOMAIN, {bluetooth.DOMAIN: {}} + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) + await hass.async_block_till_done() + assert len(mock_ha_bleak_scanner.mock_calls) == 1 + assert "Could not create bluetooth scanner" in caplog.text + assert not bluetooth.async_discovered_service_info(hass) + assert not bluetooth.async_address_present(hass, "aa:bb:bb:dd:ee:ff") + + async def test_discovery_match_by_service_uuid(hass, mock_bleak_scanner_start): """Test bluetooth discovery match by service_uuid.""" mock_bt = [ @@ -207,8 +231,47 @@ async def test_discovery_match_by_manufacturer_id_and_first_byte( assert len(mock_config_flow.mock_calls) == 0 +async def test_async_discovered_device_api(hass, mock_bleak_scanner_start): + """Test the async_discovered_device_api.""" + mock_bt = [] + with patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt + ), patch( + "bleak.BleakScanner.discovered_devices", # Must patch before we setup + [MagicMock(address="44:44:33:11:23:45")], + ): + assert not bluetooth.async_discovered_service_info(hass) + assert not bluetooth.async_address_present(hass, "44:44:22:22:11:22") + + assert await async_setup_component( + hass, bluetooth.DOMAIN, {bluetooth.DOMAIN: {}} + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + assert len(mock_bleak_scanner_start.mock_calls) == 1 + + assert not bluetooth.async_discovered_service_info(hass) + + wrong_device = BLEDevice("44:44:33:11:23:42", "wrong_name") + wrong_adv = AdvertisementData(local_name="wrong_name", service_uuids=[]) + models.HA_BLEAK_SCANNER._callback(wrong_device, wrong_adv) + switchbot_device = BLEDevice("44:44:33:11:23:45", "wohand") + switchbot_adv = AdvertisementData(local_name="wohand", service_uuids=[]) + models.HA_BLEAK_SCANNER._callback(switchbot_device, switchbot_adv) + await hass.async_block_till_done() + + service_infos = bluetooth.async_discovered_service_info(hass) + assert len(service_infos) == 1 + # wrong_name should not appear because bleak no longer sees it + assert service_infos[0].name == "wohand" + + assert bluetooth.async_address_present(hass, "44:44:33:11:23:42") is False + assert bluetooth.async_address_present(hass, "44:44:33:11:23:45") is True + + async def test_register_callbacks(hass, mock_bleak_scanner_start): - """Test configured options for a device are loaded via config entry.""" + """Test registering a callback.""" mock_bt = [] callbacks = [] @@ -284,6 +347,92 @@ async def test_register_callbacks(hass, mock_bleak_scanner_start): assert service_info.manufacturer_id is None +async def test_register_callback_by_address(hass, mock_bleak_scanner_start): + """Test registering a callback by address.""" + mock_bt = [] + callbacks = [] + + def _fake_subscriber( + service_info: BluetoothServiceInfo, change: BluetoothChange + ) -> None: + """Fake subscriber for the BleakScanner.""" + callbacks.append((service_info, change)) + if len(callbacks) >= 3: + raise ValueError + + with patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt + ), patch.object(hass.config_entries.flow, "async_init"): + assert await async_setup_component( + hass, bluetooth.DOMAIN, {bluetooth.DOMAIN: {}} + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + cancel = bluetooth.async_register_callback( + hass, + _fake_subscriber, + {"address": "44:44:33:11:23:45"}, + ) + + assert len(mock_bleak_scanner_start.mock_calls) == 1 + + switchbot_device = BLEDevice("44:44:33:11:23:45", "wohand") + switchbot_adv = AdvertisementData( + local_name="wohand", + service_uuids=["cba20d00-224d-11e6-9fb8-0002a5d5c51b"], + manufacturer_data={89: b"\xd8.\xad\xcd\r\x85"}, + service_data={"00000d00-0000-1000-8000-00805f9b34fb": b"H\x10c"}, + ) + + models.HA_BLEAK_SCANNER._callback(switchbot_device, switchbot_adv) + + empty_device = BLEDevice("11:22:33:44:55:66", "empty") + empty_adv = AdvertisementData(local_name="empty") + + models.HA_BLEAK_SCANNER._callback(empty_device, empty_adv) + await hass.async_block_till_done() + + empty_device = BLEDevice("11:22:33:44:55:66", "empty") + empty_adv = AdvertisementData(local_name="empty") + + # 3rd callback raises ValueError but is still tracked + models.HA_BLEAK_SCANNER._callback(empty_device, empty_adv) + await hass.async_block_till_done() + + cancel() + + # 4th callback should not be tracked since we canceled + models.HA_BLEAK_SCANNER._callback(empty_device, empty_adv) + await hass.async_block_till_done() + + # Now register again with a callback that fails to + # make sure we do not perm fail + cancel = bluetooth.async_register_callback( + hass, + _fake_subscriber, + {"address": "44:44:33:11:23:45"}, + ) + cancel() + + # Now register again, since the 3rd callback + # should fail but we should still record it + cancel = bluetooth.async_register_callback( + hass, + _fake_subscriber, + {"address": "44:44:33:11:23:45"}, + ) + cancel() + + assert len(callbacks) == 3 + + for idx in range(3): + service_info: BluetoothServiceInfo = callbacks[idx][0] + assert service_info.name == "wohand" + assert service_info.manufacturer == "Nordic Semiconductor ASA" + assert service_info.manufacturer_id == 89 + + async def test_wrapped_instance_with_filter(hass, mock_bleak_scanner_start): """Test consumers can use the wrapped instance with a filter as if it was normal BleakScanner.""" with patch( @@ -438,3 +587,120 @@ async def test_wrapped_instance_with_broken_callbacks(hass, mock_bleak_scanner_s models.HA_BLEAK_SCANNER._callback(switchbot_device, switchbot_adv) await hass.async_block_till_done() assert len(detected) == 1 + + +async def test_wrapped_instance_changes_uuids(hass, mock_bleak_scanner_start): + """Test consumers can use the wrapped instance can change the uuids later.""" + with patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=[] + ), patch.object(hass.config_entries.flow, "async_init"): + assert await async_setup_component( + hass, bluetooth.DOMAIN, {bluetooth.DOMAIN: {}} + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + detected = [] + + def _device_detected( + device: BLEDevice, advertisement_data: AdvertisementData + ) -> None: + """Handle a detected device.""" + detected.append((device, advertisement_data)) + + switchbot_device = BLEDevice("44:44:33:11:23:45", "wohand") + switchbot_adv = AdvertisementData( + local_name="wohand", + service_uuids=["cba20d00-224d-11e6-9fb8-0002a5d5c51b"], + manufacturer_data={89: b"\xd8.\xad\xcd\r\x85"}, + service_data={"00000d00-0000-1000-8000-00805f9b34fb": b"H\x10c"}, + ) + empty_device = BLEDevice("11:22:33:44:55:66", "empty") + empty_adv = AdvertisementData(local_name="empty") + + assert models.HA_BLEAK_SCANNER is not None + scanner = models.HaBleakScannerWrapper() + scanner.set_scanning_filter(service_uuids=["cba20d00-224d-11e6-9fb8-0002a5d5c51b"]) + scanner.register_detection_callback(_device_detected) + + type(models.HA_BLEAK_SCANNER).discovered_devices = [MagicMock()] + for _ in range(2): + models.HA_BLEAK_SCANNER._callback(switchbot_device, switchbot_adv) + await hass.async_block_till_done() + + assert len(detected) == 2 + + # The UUIDs list we created in the wrapped scanner with should be respected + # and we should not get another callback + models.HA_BLEAK_SCANNER._callback(empty_device, empty_adv) + assert len(detected) == 2 + + +async def test_wrapped_instance_changes_filters(hass, mock_bleak_scanner_start): + """Test consumers can use the wrapped instance can change the filter later.""" + with patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=[] + ), patch.object(hass.config_entries.flow, "async_init"): + assert await async_setup_component( + hass, bluetooth.DOMAIN, {bluetooth.DOMAIN: {}} + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + detected = [] + + def _device_detected( + device: BLEDevice, advertisement_data: AdvertisementData + ) -> None: + """Handle a detected device.""" + detected.append((device, advertisement_data)) + + switchbot_device = BLEDevice("44:44:33:11:23:42", "wohand") + switchbot_adv = AdvertisementData( + local_name="wohand", + service_uuids=["cba20d00-224d-11e6-9fb8-0002a5d5c51b"], + manufacturer_data={89: b"\xd8.\xad\xcd\r\x85"}, + service_data={"00000d00-0000-1000-8000-00805f9b34fb": b"H\x10c"}, + ) + empty_device = BLEDevice("11:22:33:44:55:62", "empty") + empty_adv = AdvertisementData(local_name="empty") + + assert models.HA_BLEAK_SCANNER is not None + scanner = models.HaBleakScannerWrapper() + scanner.set_scanning_filter( + filters={"UUIDs": ["cba20d00-224d-11e6-9fb8-0002a5d5c51b"]} + ) + scanner.register_detection_callback(_device_detected) + + type(models.HA_BLEAK_SCANNER).discovered_devices = [MagicMock()] + for _ in range(2): + models.HA_BLEAK_SCANNER._callback(switchbot_device, switchbot_adv) + await hass.async_block_till_done() + + assert len(detected) == 2 + + # The UUIDs list we created in the wrapped scanner with should be respected + # and we should not get another callback + models.HA_BLEAK_SCANNER._callback(empty_device, empty_adv) + assert len(detected) == 2 + + +async def test_wrapped_instance_unsupported_filter( + hass, mock_bleak_scanner_start, caplog +): + """Test we want when their filter is ineffective.""" + with patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=[] + ), patch.object(hass.config_entries.flow, "async_init"): + assert await async_setup_component( + hass, bluetooth.DOMAIN, {bluetooth.DOMAIN: {}} + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + assert models.HA_BLEAK_SCANNER is not None + scanner = models.HaBleakScannerWrapper() + scanner.set_scanning_filter( + filters={"unsupported": ["cba20d00-224d-11e6-9fb8-0002a5d5c51b"]} + ) + assert "Only UUIDs filters are supported" in caplog.text