From 61d5ed1dcfb2554ed7e4c259421da2b069c7bce2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 26 Aug 2022 22:07:51 -0500 Subject: [PATCH] Index bluetooth matchers to resolve performance concerns with many adapters/remotes (#77372) --- .../components/bluetooth/__init__.py | 1 + homeassistant/components/bluetooth/manager.py | 63 +- homeassistant/components/bluetooth/match.py | 271 +++++++- tests/components/bluetooth/test_init.py | 618 +++++++++++++++++- 4 files changed, 865 insertions(+), 88 deletions(-) diff --git a/homeassistant/components/bluetooth/__init__.py b/homeassistant/components/bluetooth/__init__.py index 208bbe6952b..632635f7dbc 100644 --- a/homeassistant/components/bluetooth/__init__.py +++ b/homeassistant/components/bluetooth/__init__.py @@ -209,6 +209,7 @@ async def async_get_adapter_from_address( async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the bluetooth integration.""" integration_matcher = IntegrationMatcher(await async_get_bluetooth(hass)) + integration_matcher.async_setup() manager = BluetoothManager(hass, integration_matcher) manager.async_setup() hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, manager.async_stop) diff --git a/homeassistant/components/bluetooth/manager.py b/homeassistant/components/bluetooth/manager.py index 2fff99c830c..be5038a6d31 100644 --- a/homeassistant/components/bluetooth/manager.py +++ b/homeassistant/components/bluetooth/manager.py @@ -27,8 +27,11 @@ from .const import ( ) from .match import ( ADDRESS, + CALLBACK, CONNECTABLE, BluetoothCallbackMatcher, + BluetoothCallbackMatcherIndex, + BluetoothCallbackMatcherWithCallback, IntegrationMatcher, ble_device_matches, ) @@ -132,12 +135,7 @@ class BluetoothManager: self._connectable_unavailable_callbacks: dict[ str, list[Callable[[str], None]] ] = {} - self._callbacks: list[ - tuple[BluetoothCallback, BluetoothCallbackMatcher | None] - ] = [] - self._connectable_callbacks: list[ - tuple[BluetoothCallback, BluetoothCallbackMatcher | None] - ] = [] + self._callback_index = BluetoothCallbackMatcherIndex() self._bleak_callbacks: list[ tuple[AdvertisementDataCallback, dict[str, set[str]]] ] = [] @@ -255,7 +253,7 @@ class BluetoothManager: device = service_info.device connectable = service_info.connectable address = device.address - all_history = self._get_history_by_type(connectable) + all_history = self._connectable_history if connectable else self._history old_service_info = all_history.get(address) if old_service_info and _prefer_previous_adv(old_service_info, service_info): return @@ -281,24 +279,13 @@ class BluetoothManager: matched_domains, ) - if ( - not matched_domains - and not self._callbacks - and not self._connectable_callbacks - ): - return + for match in self._callback_index.match_callbacks(service_info): + callback = match[CALLBACK] + try: + callback(service_info, BluetoothChange.ADVERTISEMENT) + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Error in bluetooth callback") - for connectable_callback in (True, False): - for callback, matcher in self._get_callbacks_by_type(connectable_callback): - if matcher and not ble_device_matches(matcher, service_info): - continue - try: - callback(service_info, BluetoothChange.ADVERTISEMENT) - except Exception: # pylint: disable=broad-except - _LOGGER.exception("Error in bluetooth callback") - - if not matched_domains: - return for domain in matched_domains: discovery_flow.async_create_flow( self.hass, @@ -330,28 +317,30 @@ class BluetoothManager: matcher: BluetoothCallbackMatcher | None, ) -> Callable[[], None]: """Register a callback.""" + callback_matcher = BluetoothCallbackMatcherWithCallback(callback=callback) if not matcher: - matcher = BluetoothCallbackMatcher(connectable=True) - if CONNECTABLE not in matcher: - matcher[CONNECTABLE] = True - connectable = matcher[CONNECTABLE] + callback_matcher[CONNECTABLE] = True + else: + # We could write out every item in the typed dict here + # but that would be a bit inefficient and verbose. + callback_matcher.update(matcher) # type: ignore[typeddict-item] + callback_matcher[CONNECTABLE] = matcher.get(CONNECTABLE, True) - callback_entry = (callback, matcher) - callbacks = self._get_callbacks_by_type(connectable) - callbacks.append(callback_entry) + connectable = callback_matcher[CONNECTABLE] + self._callback_index.add_with_address(callback_matcher) @hass_callback def _async_remove_callback() -> None: - callbacks.remove(callback_entry) + self._callback_index.remove_with_address(callback_matcher) # If we have history for the subscriber, we can trigger the callback # immediately with the last packet so the subscriber can see the # device. all_history = self._get_history_by_type(connectable) if ( - (address := matcher.get(ADDRESS)) + (address := callback_matcher.get(ADDRESS)) and (service_info := all_history.get(address)) - and ble_device_matches(matcher, service_info) + and ble_device_matches(callback_matcher, service_info) ): try: callback(service_info, BluetoothChange.ADVERTISEMENT) @@ -407,12 +396,6 @@ class BluetoothManager: """Return the history by type.""" return self._connectable_history if connectable else self._history - def _get_callbacks_by_type( - self, connectable: bool - ) -> list[tuple[BluetoothCallback, BluetoothCallbackMatcher | None]]: - """Return the callbacks by type.""" - return self._connectable_callbacks if connectable else self._callbacks - def async_register_scanner( self, scanner: BaseHaScanner, connectable: bool ) -> CALLBACK_TYPE: diff --git a/homeassistant/components/bluetooth/match.py b/homeassistant/components/bluetooth/match.py index 4a0aa8ee995..e9f535200b5 100644 --- a/homeassistant/components/bluetooth/match.py +++ b/homeassistant/components/bluetooth/match.py @@ -5,13 +5,14 @@ from dataclasses import dataclass from fnmatch import translate from functools import lru_cache import re -from typing import TYPE_CHECKING, Final, TypedDict +from typing import TYPE_CHECKING, Final, Generic, TypedDict, TypeVar from lru import LRU # pylint: disable=no-name-in-module +from homeassistant.core import callback from homeassistant.loader import BluetoothMatcher, BluetoothMatcherOptional -from .models import BluetoothServiceInfoBleak +from .models import BluetoothCallback, BluetoothServiceInfoBleak if TYPE_CHECKING: from collections.abc import MutableMapping @@ -21,7 +22,8 @@ if TYPE_CHECKING: MAX_REMEMBER_ADDRESSES: Final = 2048 - +CALLBACK: Final = "callback" +DOMAIN: Final = "domain" ADDRESS: Final = "address" CONNECTABLE: Final = "connectable" LOCAL_NAME: Final = "local_name" @@ -30,6 +32,8 @@ SERVICE_DATA_UUID: Final = "service_data_uuid" MANUFACTURER_ID: Final = "manufacturer_id" MANUFACTURER_DATA_START: Final = "manufacturer_data_start" +LOCAL_NAME_MIN_MATCH_LENGTH = 3 + class BluetoothCallbackMatcherOptional(TypedDict, total=False): """Matcher for the bluetooth integration for callback optional fields.""" @@ -44,6 +48,19 @@ class BluetoothCallbackMatcher( """Callback matcher for the bluetooth integration.""" +class _BluetoothCallbackMatcherWithCallback(TypedDict): + """Callback for the bluetooth integration.""" + + callback: BluetoothCallback + + +class BluetoothCallbackMatcherWithCallback( + _BluetoothCallbackMatcherWithCallback, + BluetoothCallbackMatcher, +): + """Callback matcher for the bluetooth integration that stores the callback.""" + + @dataclass(frozen=False) class IntegrationMatchHistory: """Track which fields have been seen.""" @@ -86,23 +103,26 @@ class IntegrationMatcher: self._matched_connectable: MutableMapping[str, IntegrationMatchHistory] = LRU( MAX_REMEMBER_ADDRESSES ) + self._index = BluetoothMatcherIndex() + + @callback + def async_setup(self) -> None: + """Set up the matcher.""" + for matcher in self._integration_matchers: + self._index.add(matcher) + self._index.build() def async_clear_address(self, address: str) -> None: """Clear the history matches for a set of domains.""" self._matched.pop(address, None) self._matched_connectable.pop(address, None) - def _get_matched_by_type( - self, connectable: bool - ) -> MutableMapping[str, IntegrationMatchHistory]: - """Return the matches by type.""" - return self._matched_connectable if connectable else self._matched - def match_domains(self, service_info: BluetoothServiceInfoBleak) -> set[str]: """Return the domains that are matched.""" device = service_info.device advertisement_data = service_info.advertisement - matched = self._get_matched_by_type(service_info.connectable) + connectable = service_info.connectable + matched = self._matched_connectable if connectable else self._matched matched_domains: set[str] = set() if (previous_match := matched.get(device.address)) and seen_all_fields( previous_match, advertisement_data @@ -110,9 +130,7 @@ class IntegrationMatcher: # We have seen all fields so we can skip the rest of the matchers return matched_domains matched_domains = { - matcher["domain"] - for matcher in self._integration_matchers - if ble_device_matches(matcher, service_info) + matcher[DOMAIN] for matcher in self._index.match(service_info) } if not matched_domains: return matched_domains @@ -131,14 +149,209 @@ class IntegrationMatcher: return matched_domains +_T = TypeVar("_T", BluetoothMatcher, BluetoothCallbackMatcherWithCallback) + + +class BluetoothMatcherIndexBase(Generic[_T]): + """Bluetooth matcher base for the bluetooth integration. + + The indexer puts each matcher in the bucket that it is most + likely to match. This allows us to only check the service infos + against each bucket to see if we should match against the data. + + This is optimized for cases were no service infos will be matched in + any bucket and we can quickly reject the service info as not matching. + """ + + def __init__(self) -> None: + """Initialize the matcher index.""" + self.local_name: dict[str, list[_T]] = {} + self.service_uuid: dict[str, list[_T]] = {} + self.service_data_uuid: dict[str, list[_T]] = {} + self.manufacturer_id: dict[int, list[_T]] = {} + self.service_uuid_set: set[str] = set() + self.service_data_uuid_set: set[str] = set() + self.manufacturer_id_set: set[int] = set() + + def add(self, matcher: _T) -> None: + """Add a matcher to the index. + + Matchers must end up only in one bucket. + + We put them in the bucket that they are most likely to match. + """ + if LOCAL_NAME in matcher: + self.local_name.setdefault( + _local_name_to_index_key(matcher[LOCAL_NAME]), [] + ).append(matcher) + return + + if SERVICE_UUID in matcher: + self.service_uuid.setdefault(matcher[SERVICE_UUID], []).append(matcher) + return + + if SERVICE_DATA_UUID in matcher: + self.service_data_uuid.setdefault(matcher[SERVICE_DATA_UUID], []).append( + matcher + ) + return + + if MANUFACTURER_ID in matcher: + self.manufacturer_id.setdefault(matcher[MANUFACTURER_ID], []).append( + matcher + ) + return + + def remove(self, matcher: _T) -> None: + """Remove a matcher from the index. + + Matchers only end up in one bucket, so once we have + removed one, we are done. + """ + if LOCAL_NAME in matcher: + self.local_name[_local_name_to_index_key(matcher[LOCAL_NAME])].remove( + matcher + ) + return + + if SERVICE_UUID in matcher: + self.service_uuid[matcher[SERVICE_UUID]].remove(matcher) + return + + if SERVICE_DATA_UUID in matcher: + self.service_data_uuid[matcher[SERVICE_DATA_UUID]].remove(matcher) + return + + if MANUFACTURER_ID in matcher: + self.manufacturer_id[matcher[MANUFACTURER_ID]].remove(matcher) + return + + def build(self) -> None: + """Rebuild the index sets.""" + self.service_uuid_set = set(self.service_uuid) + self.service_data_uuid_set = set(self.service_data_uuid) + self.manufacturer_id_set = set(self.manufacturer_id) + + def match(self, service_info: BluetoothServiceInfoBleak) -> list[_T]: + """Check for a match.""" + matches = [] + if len(service_info.name) >= LOCAL_NAME_MIN_MATCH_LENGTH: + for matcher in self.local_name.get( + service_info.name[:LOCAL_NAME_MIN_MATCH_LENGTH], [] + ): + if ble_device_matches(matcher, service_info): + matches.append(matcher) + + for service_data_uuid in self.service_data_uuid_set.intersection( + service_info.service_data + ): + for matcher in self.service_data_uuid[service_data_uuid]: + if ble_device_matches(matcher, service_info): + matches.append(matcher) + + for manufacturer_id in self.manufacturer_id_set.intersection( + service_info.manufacturer_data + ): + for matcher in self.manufacturer_id[manufacturer_id]: + if ble_device_matches(matcher, service_info): + matches.append(matcher) + + for service_uuid in self.service_uuid_set.intersection( + service_info.service_uuids + ): + for matcher in self.service_uuid[service_uuid]: + if ble_device_matches(matcher, service_info): + matches.append(matcher) + + return matches + + +class BluetoothMatcherIndex(BluetoothMatcherIndexBase[BluetoothMatcher]): + """Bluetooth matcher for the bluetooth integration.""" + + +class BluetoothCallbackMatcherIndex( + BluetoothMatcherIndexBase[BluetoothCallbackMatcherWithCallback] +): + """Bluetooth matcher for the bluetooth integration that supports matching on addresses.""" + + def __init__(self) -> None: + """Initialize the matcher index.""" + super().__init__() + self.address: dict[str, list[BluetoothCallbackMatcherWithCallback]] = {} + + def add_with_address(self, matcher: BluetoothCallbackMatcherWithCallback) -> None: + """Add a matcher to the index. + + Matchers must end up only in one bucket. + + We put them in the bucket that they are most likely to match. + """ + if ADDRESS in matcher: + self.address.setdefault(matcher[ADDRESS], []).append(matcher) + return + + super().add(matcher) + self.build() + + def remove_with_address( + self, matcher: BluetoothCallbackMatcherWithCallback + ) -> None: + """Remove a matcher from the index. + + Matchers only end up in one bucket, so once we have + removed one, we are done. + """ + if ADDRESS in matcher: + self.address[matcher[ADDRESS]].remove(matcher) + return + + super().remove(matcher) + self.build() + + def match_callbacks( + self, service_info: BluetoothServiceInfoBleak + ) -> list[BluetoothCallbackMatcherWithCallback]: + """Check for a match.""" + matches = self.match(service_info) + for matcher in self.address.get(service_info.address, []): + if ble_device_matches(matcher, service_info): + matches.append(matcher) + return matches + + +def _local_name_to_index_key(local_name: str) -> str: + """Convert a local name to an index. + + We check the local name matchers here and raise a ValueError + if they try to setup a matcher that will is overly broad + as would match too many devices and cause a performance hit. + """ + if len(local_name) < LOCAL_NAME_MIN_MATCH_LENGTH: + raise ValueError( + "Local name matchers must be at least " + f"{LOCAL_NAME_MIN_MATCH_LENGTH} characters long ({local_name})" + ) + match_part = local_name[:LOCAL_NAME_MIN_MATCH_LENGTH] + if "*" in match_part or "[" in match_part: + raise ValueError( + "Local name matchers may not have patterns in the first " + f"{LOCAL_NAME_MIN_MATCH_LENGTH} characters because they " + f"would match too broadly ({local_name})" + ) + return match_part + + def ble_device_matches( - matcher: BluetoothCallbackMatcher | BluetoothMatcher, + matcher: BluetoothMatcherOptional, service_info: BluetoothServiceInfoBleak, ) -> bool: """Check if a ble device and advertisement_data matches the matcher.""" device = service_info.device - if (address := matcher.get(ADDRESS)) is not None and device.address != address: - return False + + # Do don't check address here since all callers already + # check the address and we don't want to double check + # since it would result in an unreachable reject case. if matcher.get(CONNECTABLE, True) and not service_info.connectable: return False @@ -146,28 +359,26 @@ def ble_device_matches( advertisement_data = service_info.advertisement if ( service_uuid := matcher.get(SERVICE_UUID) - ) is not None and service_uuid not in advertisement_data.service_uuids: + ) and service_uuid not in advertisement_data.service_uuids: return False if ( service_data_uuid := matcher.get(SERVICE_DATA_UUID) - ) is not None and service_data_uuid not in advertisement_data.service_data: + ) and service_data_uuid not in advertisement_data.service_data: return False - if ( - manfacturer_id := matcher.get(MANUFACTURER_ID) - ) is not None and manfacturer_id not in advertisement_data.manufacturer_data: - return False - - if (manufacturer_data_start := matcher.get(MANUFACTURER_DATA_START)) is not None: - manufacturer_data_start_bytes = bytearray(manufacturer_data_start) - if not any( - manufacturer_data.startswith(manufacturer_data_start_bytes) - for manufacturer_data in advertisement_data.manufacturer_data.values() - ): + if manfacturer_id := matcher.get(MANUFACTURER_ID): + if manfacturer_id not in advertisement_data.manufacturer_data: return False + if manufacturer_data_start := matcher.get(MANUFACTURER_DATA_START): + manufacturer_data_start_bytes = bytearray(manufacturer_data_start) + if not any( + manufacturer_data.startswith(manufacturer_data_start_bytes) + for manufacturer_data in advertisement_data.manufacturer_data.values() + ): + return False - if (local_name := matcher.get(LOCAL_NAME)) is not None and ( + if (local_name := matcher.get(LOCAL_NAME)) and ( (device_name := advertisement_data.local_name or device.name) is None or not _memorized_fnmatch( device_name, diff --git a/tests/components/bluetooth/test_init.py b/tests/components/bluetooth/test_init.py index 9b958e2fade..a005a71f048 100644 --- a/tests/components/bluetooth/test_init.py +++ b/tests/components/bluetooth/test_init.py @@ -26,6 +26,14 @@ from homeassistant.components.bluetooth.const import ( SOURCE_LOCAL, UNAVAILABLE_TRACK_SECONDS, ) +from homeassistant.components.bluetooth.match import ( + ADDRESS, + CONNECTABLE, + LOCAL_NAME, + MANUFACTURER_ID, + SERVICE_DATA_UUID, + SERVICE_UUID, +) from homeassistant.config_entries import ConfigEntryState from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP from homeassistant.core import HomeAssistant, callback @@ -987,8 +995,6 @@ async def test_register_callbacks(hass, mock_bleak_scanner_start, enable_bluetoo ) -> None: """Fake subscriber for the BleakScanner.""" callbacks.append((service_info, change)) - if len(callbacks) >= 3: - raise ValueError with patch( "homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt @@ -1001,7 +1007,7 @@ async def test_register_callbacks(hass, mock_bleak_scanner_start, enable_bluetoo cancel = bluetooth.async_register_callback( hass, _fake_subscriber, - {"service_uuids": {"cba20d00-224d-11e6-9fb8-0002a5d5c51b"}}, + {SERVICE_UUID: "cba20d00-224d-11e6-9fb8-0002a5d5c51b"}, BluetoothScanningMode.ACTIVE, ) @@ -1026,17 +1032,15 @@ async def test_register_callbacks(hass, mock_bleak_scanner_start, enable_bluetoo empty_device = BLEDevice("11:22:33:44:55:66", "empty") empty_adv = AdvertisementData(local_name="empty") - # 3rd callback raises ValueError but is still tracked inject_advertisement(hass, empty_device, empty_adv) await hass.async_block_till_done() cancel() - # 4th callback should not be tracked since we canceled inject_advertisement(hass, empty_device, empty_adv) await hass.async_block_till_done() - assert len(callbacks) == 3 + assert len(callbacks) == 1 service_info: BluetoothServiceInfo = callbacks[0][0] assert service_info.name == "wohand" @@ -1044,17 +1048,63 @@ async def test_register_callbacks(hass, mock_bleak_scanner_start, enable_bluetoo assert service_info.manufacturer == "Nordic Semiconductor ASA" assert service_info.manufacturer_id == 89 - service_info: BluetoothServiceInfo = callbacks[1][0] - assert service_info.name == "empty" - assert service_info.source == SOURCE_LOCAL - assert service_info.manufacturer is None - assert service_info.manufacturer_id is None - service_info: BluetoothServiceInfo = callbacks[2][0] - assert service_info.name == "empty" +async def test_register_callbacks_raises_exception( + hass, mock_bleak_scanner_start, enable_bluetooth, caplog +): + """Test registering a callback that raises ValueError.""" + mock_bt = [] + callbacks = [] + + def _fake_subscriber( + service_info: BluetoothServiceInfo, + change: BluetoothChange, + ) -> None: + """Fake subscriber for the BleakScanner.""" + callbacks.append((service_info, change)) + raise ValueError + + with patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt + ), patch.object(hass.config_entries.flow, "async_init"): + await async_setup_with_default_adapter(hass) + + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + cancel = bluetooth.async_register_callback( + hass, + _fake_subscriber, + {SERVICE_UUID: "cba20d00-224d-11e6-9fb8-0002a5d5c51b"}, + BluetoothScanningMode.ACTIVE, + ) + + assert len(mock_bleak_scanner_start.mock_calls) == 1 + + switchbot_device = BLEDevice("44:44:33:11:23:45", "wohand") + switchbot_adv = AdvertisementData( + local_name="wohand", + service_uuids=["cba20d00-224d-11e6-9fb8-0002a5d5c51b"], + manufacturer_data={89: b"\xd8.\xad\xcd\r\x85"}, + service_data={"00000d00-0000-1000-8000-00805f9b34fb": b"H\x10c"}, + ) + + inject_advertisement(hass, switchbot_device, switchbot_adv) + + cancel() + + inject_advertisement(hass, switchbot_device, switchbot_adv) + await hass.async_block_till_done() + + assert len(callbacks) == 1 + + service_info: BluetoothServiceInfo = callbacks[0][0] + assert service_info.name == "wohand" assert service_info.source == SOURCE_LOCAL - assert service_info.manufacturer is None - assert service_info.manufacturer_id is None + assert service_info.manufacturer == "Nordic Semiconductor ASA" + assert service_info.manufacturer_id == 89 + + assert "ValueError" in caplog.text async def test_register_callback_by_address( @@ -1124,7 +1174,7 @@ async def test_register_callback_by_address( cancel = bluetooth.async_register_callback( hass, _fake_subscriber, - {"address": "44:44:33:11:23:45"}, + {ADDRESS: "44:44:33:11:23:45"}, BluetoothScanningMode.ACTIVE, ) cancel() @@ -1134,7 +1184,7 @@ async def test_register_callback_by_address( cancel = bluetooth.async_register_callback( hass, _fake_subscriber, - {"address": "44:44:33:11:23:45"}, + {ADDRESS: "44:44:33:11:23:45"}, BluetoothScanningMode.ACTIVE, ) cancel() @@ -1148,6 +1198,537 @@ async def test_register_callback_by_address( assert service_info.manufacturer_id == 89 +async def test_register_callback_by_address_connectable_only( + hass, mock_bleak_scanner_start, enable_bluetooth +): + """Test registering a callback by address connectable only.""" + mock_bt = [] + connectable_callbacks = [] + non_connectable_callbacks = [] + + def _fake_connectable_subscriber( + service_info: BluetoothServiceInfo, change: BluetoothChange + ) -> None: + """Fake subscriber for the BleakScanner.""" + connectable_callbacks.append((service_info, change)) + + def _fake_non_connectable_subscriber( + service_info: BluetoothServiceInfo, change: BluetoothChange + ) -> None: + """Fake subscriber for the BleakScanner.""" + non_connectable_callbacks.append((service_info, change)) + + with patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt + ): + await async_setup_with_default_adapter(hass) + + with patch.object(hass.config_entries.flow, "async_init"): + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + cancel = bluetooth.async_register_callback( + hass, + _fake_connectable_subscriber, + {ADDRESS: "44:44:33:11:23:45", CONNECTABLE: True}, + BluetoothScanningMode.ACTIVE, + ) + cancel2 = bluetooth.async_register_callback( + hass, + _fake_non_connectable_subscriber, + {ADDRESS: "44:44:33:11:23:45", CONNECTABLE: False}, + BluetoothScanningMode.ACTIVE, + ) + + assert len(mock_bleak_scanner_start.mock_calls) == 1 + + switchbot_device = BLEDevice("44:44:33:11:23:45", "wohand") + switchbot_adv = AdvertisementData( + local_name="wohand", + service_uuids=["cba20d00-224d-11e6-9fb8-0002a5d5c51b"], + manufacturer_data={89: b"\xd8.\xad\xcd\r\x85"}, + service_data={"00000d00-0000-1000-8000-00805f9b34fb": b"H\x10c"}, + ) + + inject_advertisement_with_time_and_source_connectable( + hass, switchbot_device, switchbot_adv, time.monotonic(), "test", False + ) + inject_advertisement_with_time_and_source_connectable( + hass, switchbot_device, switchbot_adv, time.monotonic(), "test", True + ) + + cancel() + cancel2() + + assert len(connectable_callbacks) == 1 + # Non connectable will take either a connectable + # or non-connectable device + assert len(non_connectable_callbacks) == 2 + + +async def test_register_callback_by_manufacturer_id( + hass, mock_bleak_scanner_start, enable_bluetooth +): + """Test registering a callback by manufacturer_id.""" + mock_bt = [] + callbacks = [] + + def _fake_subscriber( + service_info: BluetoothServiceInfo, change: BluetoothChange + ) -> None: + """Fake subscriber for the BleakScanner.""" + callbacks.append((service_info, change)) + + with patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt + ): + await async_setup_with_default_adapter(hass) + + with patch.object(hass.config_entries.flow, "async_init"): + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + cancel = bluetooth.async_register_callback( + hass, + _fake_subscriber, + {MANUFACTURER_ID: 76}, + BluetoothScanningMode.ACTIVE, + ) + + assert len(mock_bleak_scanner_start.mock_calls) == 1 + + apple_device = BLEDevice("44:44:33:11:23:45", "apple") + apple_adv = AdvertisementData( + local_name="apple", + manufacturer_data={76: b"\xd8.\xad\xcd\r\x85"}, + ) + + inject_advertisement(hass, apple_device, apple_adv) + + empty_device = BLEDevice("11:22:33:44:55:66", "empty") + empty_adv = AdvertisementData(local_name="empty") + + inject_advertisement(hass, empty_device, empty_adv) + await hass.async_block_till_done() + + cancel() + + assert len(callbacks) == 1 + + service_info: BluetoothServiceInfo = callbacks[0][0] + assert service_info.name == "apple" + assert service_info.manufacturer == "Apple, Inc." + assert service_info.manufacturer_id == 76 + + +async def test_register_callback_by_address_connectable_manufacturer_id( + hass, mock_bleak_scanner_start, enable_bluetooth +): + """Test registering a callback by address, manufacturer_id, and connectable.""" + mock_bt = [] + callbacks = [] + + def _fake_subscriber( + service_info: BluetoothServiceInfo, change: BluetoothChange + ) -> None: + """Fake subscriber for the BleakScanner.""" + callbacks.append((service_info, change)) + + with patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt + ): + await async_setup_with_default_adapter(hass) + + with patch.object(hass.config_entries.flow, "async_init"): + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + cancel = bluetooth.async_register_callback( + hass, + _fake_subscriber, + {MANUFACTURER_ID: 76, CONNECTABLE: False, ADDRESS: "44:44:33:11:23:45"}, + BluetoothScanningMode.ACTIVE, + ) + + assert len(mock_bleak_scanner_start.mock_calls) == 1 + + apple_device = BLEDevice("44:44:33:11:23:45", "apple") + apple_adv = AdvertisementData( + local_name="apple", + manufacturer_data={76: b"\xd8.\xad\xcd\r\x85"}, + ) + + inject_advertisement(hass, apple_device, apple_adv) + + apple_device_wrong_address = BLEDevice("44:44:33:11:23:46", "apple") + + inject_advertisement(hass, apple_device_wrong_address, apple_adv) + await hass.async_block_till_done() + + cancel() + + assert len(callbacks) == 1 + + service_info: BluetoothServiceInfo = callbacks[0][0] + assert service_info.name == "apple" + assert service_info.manufacturer == "Apple, Inc." + assert service_info.manufacturer_id == 76 + + +async def test_register_callback_by_manufacturer_id_and_address( + hass, mock_bleak_scanner_start, enable_bluetooth +): + """Test registering a callback by manufacturer_id and address.""" + mock_bt = [] + callbacks = [] + + def _fake_subscriber( + service_info: BluetoothServiceInfo, change: BluetoothChange + ) -> None: + """Fake subscriber for the BleakScanner.""" + callbacks.append((service_info, change)) + + with patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt + ): + await async_setup_with_default_adapter(hass) + + with patch.object(hass.config_entries.flow, "async_init"): + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + cancel = bluetooth.async_register_callback( + hass, + _fake_subscriber, + {MANUFACTURER_ID: 76, ADDRESS: "44:44:33:11:23:45"}, + BluetoothScanningMode.ACTIVE, + ) + + assert len(mock_bleak_scanner_start.mock_calls) == 1 + + apple_device = BLEDevice("44:44:33:11:23:45", "apple") + apple_adv = AdvertisementData( + local_name="apple", + manufacturer_data={76: b"\xd8.\xad\xcd\r\x85"}, + ) + + inject_advertisement(hass, apple_device, apple_adv) + + yale_device = BLEDevice("44:44:33:11:23:45", "apple") + yale_adv = AdvertisementData( + local_name="yale", + manufacturer_data={465: b"\xd8.\xad\xcd\r\x85"}, + ) + + inject_advertisement(hass, yale_device, yale_adv) + await hass.async_block_till_done() + + other_apple_device = BLEDevice("44:44:33:11:23:22", "apple") + other_apple_adv = AdvertisementData( + local_name="apple", + manufacturer_data={76: b"\xd8.\xad\xcd\r\x85"}, + ) + inject_advertisement(hass, other_apple_device, other_apple_adv) + + cancel() + + assert len(callbacks) == 1 + + service_info: BluetoothServiceInfo = callbacks[0][0] + assert service_info.name == "apple" + assert service_info.manufacturer == "Apple, Inc." + assert service_info.manufacturer_id == 76 + + +async def test_register_callback_by_service_uuid_and_address( + hass, mock_bleak_scanner_start, enable_bluetooth +): + """Test registering a callback by service_uuid and address.""" + mock_bt = [] + callbacks = [] + + def _fake_subscriber( + service_info: BluetoothServiceInfo, change: BluetoothChange + ) -> None: + """Fake subscriber for the BleakScanner.""" + callbacks.append((service_info, change)) + + with patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt + ): + await async_setup_with_default_adapter(hass) + + with patch.object(hass.config_entries.flow, "async_init"): + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + cancel = bluetooth.async_register_callback( + hass, + _fake_subscriber, + { + SERVICE_UUID: "cba20d00-224d-11e6-9fb8-0002a5d5c51b", + ADDRESS: "44:44:33:11:23:45", + }, + BluetoothScanningMode.ACTIVE, + ) + + assert len(mock_bleak_scanner_start.mock_calls) == 1 + + switchbot_dev = BLEDevice("44:44:33:11:23:45", "switchbot") + switchbot_adv = AdvertisementData( + local_name="switchbot", + service_uuids=["cba20d00-224d-11e6-9fb8-0002a5d5c51b"], + ) + + inject_advertisement(hass, switchbot_dev, switchbot_adv) + + switchbot_missing_service_uuid_dev = BLEDevice("44:44:33:11:23:45", "switchbot") + switchbot_missing_service_uuid_adv = AdvertisementData( + local_name="switchbot", + ) + + inject_advertisement( + hass, switchbot_missing_service_uuid_dev, switchbot_missing_service_uuid_adv + ) + await hass.async_block_till_done() + + service_uuid_wrong_address_dev = BLEDevice("44:44:33:11:23:22", "switchbot2") + service_uuid_wrong_address_adv = AdvertisementData( + local_name="switchbot2", + service_uuids=["cba20d00-224d-11e6-9fb8-0002a5d5c51b"], + ) + inject_advertisement( + hass, service_uuid_wrong_address_dev, service_uuid_wrong_address_adv + ) + + cancel() + + assert len(callbacks) == 1 + + service_info: BluetoothServiceInfo = callbacks[0][0] + assert service_info.name == "switchbot" + + +async def test_register_callback_by_service_data_uuid_and_address( + hass, mock_bleak_scanner_start, enable_bluetooth +): + """Test registering a callback by service_data_uuid and address.""" + mock_bt = [] + callbacks = [] + + def _fake_subscriber( + service_info: BluetoothServiceInfo, change: BluetoothChange + ) -> None: + """Fake subscriber for the BleakScanner.""" + callbacks.append((service_info, change)) + + with patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt + ): + await async_setup_with_default_adapter(hass) + + with patch.object(hass.config_entries.flow, "async_init"): + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + cancel = bluetooth.async_register_callback( + hass, + _fake_subscriber, + { + SERVICE_DATA_UUID: "cba20d00-224d-11e6-9fb8-0002a5d5c51b", + ADDRESS: "44:44:33:11:23:45", + }, + BluetoothScanningMode.ACTIVE, + ) + + assert len(mock_bleak_scanner_start.mock_calls) == 1 + + switchbot_dev = BLEDevice("44:44:33:11:23:45", "switchbot") + switchbot_adv = AdvertisementData( + local_name="switchbot", + service_data={"cba20d00-224d-11e6-9fb8-0002a5d5c51b": b"x"}, + ) + + inject_advertisement(hass, switchbot_dev, switchbot_adv) + + switchbot_missing_service_uuid_dev = BLEDevice("44:44:33:11:23:45", "switchbot") + switchbot_missing_service_uuid_adv = AdvertisementData( + local_name="switchbot", + ) + + inject_advertisement( + hass, switchbot_missing_service_uuid_dev, switchbot_missing_service_uuid_adv + ) + await hass.async_block_till_done() + + service_uuid_wrong_address_dev = BLEDevice("44:44:33:11:23:22", "switchbot2") + service_uuid_wrong_address_adv = AdvertisementData( + local_name="switchbot2", + service_data={"cba20d00-224d-11e6-9fb8-0002a5d5c51b": b"x"}, + ) + inject_advertisement( + hass, service_uuid_wrong_address_dev, service_uuid_wrong_address_adv + ) + + cancel() + + assert len(callbacks) == 1 + + service_info: BluetoothServiceInfo = callbacks[0][0] + assert service_info.name == "switchbot" + + +async def test_register_callback_by_local_name( + hass, mock_bleak_scanner_start, enable_bluetooth +): + """Test registering a callback by local_name.""" + mock_bt = [] + callbacks = [] + + def _fake_subscriber( + service_info: BluetoothServiceInfo, change: BluetoothChange + ) -> None: + """Fake subscriber for the BleakScanner.""" + callbacks.append((service_info, change)) + + with patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt + ): + await async_setup_with_default_adapter(hass) + + with patch.object(hass.config_entries.flow, "async_init"): + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + cancel = bluetooth.async_register_callback( + hass, + _fake_subscriber, + {LOCAL_NAME: "apple"}, + BluetoothScanningMode.ACTIVE, + ) + + assert len(mock_bleak_scanner_start.mock_calls) == 1 + + apple_device = BLEDevice("44:44:33:11:23:45", "apple") + apple_adv = AdvertisementData( + local_name="apple", + manufacturer_data={76: b"\xd8.\xad\xcd\r\x85"}, + ) + + inject_advertisement(hass, apple_device, apple_adv) + + empty_device = BLEDevice("11:22:33:44:55:66", "empty") + empty_adv = AdvertisementData(local_name="empty") + + inject_advertisement(hass, empty_device, empty_adv) + + apple_device_2 = BLEDevice("44:44:33:11:23:45", "apple") + apple_adv_2 = AdvertisementData( + local_name="apple2", + manufacturer_data={76: b"\xd8.\xad\xcd\r\x85"}, + ) + inject_advertisement(hass, apple_device_2, apple_adv_2) + + await hass.async_block_till_done() + + cancel() + + assert len(callbacks) == 1 + + service_info: BluetoothServiceInfo = callbacks[0][0] + assert service_info.name == "apple" + assert service_info.manufacturer == "Apple, Inc." + assert service_info.manufacturer_id == 76 + + +async def test_register_callback_by_local_name_overly_broad( + hass, mock_bleak_scanner_start, enable_bluetooth, caplog +): + """Test registering a callback by local_name that is too broad.""" + mock_bt = [] + + def _fake_subscriber( + service_info: BluetoothServiceInfo, change: BluetoothChange + ) -> None: + """Fake subscriber for the BleakScanner.""" + + with patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt + ): + await async_setup_with_default_adapter(hass) + + with pytest.raises(ValueError): + bluetooth.async_register_callback( + hass, + _fake_subscriber, + {LOCAL_NAME: "a"}, + BluetoothScanningMode.ACTIVE, + ) + + with pytest.raises(ValueError): + bluetooth.async_register_callback( + hass, + _fake_subscriber, + {LOCAL_NAME: "ab*"}, + BluetoothScanningMode.ACTIVE, + ) + + +async def test_register_callback_by_service_data_uuid( + hass, mock_bleak_scanner_start, enable_bluetooth +): + """Test registering a callback by service_data_uuid.""" + mock_bt = [] + callbacks = [] + + def _fake_subscriber( + service_info: BluetoothServiceInfo, change: BluetoothChange + ) -> None: + """Fake subscriber for the BleakScanner.""" + callbacks.append((service_info, change)) + + with patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt + ): + await async_setup_with_default_adapter(hass) + + with patch.object(hass.config_entries.flow, "async_init"): + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + cancel = bluetooth.async_register_callback( + hass, + _fake_subscriber, + {SERVICE_DATA_UUID: "0000fe95-0000-1000-8000-00805f9b34fb"}, + BluetoothScanningMode.ACTIVE, + ) + + assert len(mock_bleak_scanner_start.mock_calls) == 1 + + apple_device = BLEDevice("44:44:33:11:23:45", "xiaomi") + apple_adv = AdvertisementData( + local_name="xiaomi", + service_data={ + "0000fe95-0000-1000-8000-00805f9b34fb": b"\xd8.\xad\xcd\r\x85" + }, + ) + + inject_advertisement(hass, apple_device, apple_adv) + + empty_device = BLEDevice("11:22:33:44:55:66", "empty") + empty_adv = AdvertisementData(local_name="empty") + + inject_advertisement(hass, empty_device, empty_adv) + await hass.async_block_till_done() + + cancel() + + assert len(callbacks) == 1 + + service_info: BluetoothServiceInfo = callbacks[0][0] + assert service_info.name == "xiaomi" + + async def test_register_callback_survives_reload( hass, mock_bleak_scanner_start, enable_bluetooth ): @@ -1169,7 +1750,7 @@ async def test_register_callback_survives_reload( hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) await hass.async_block_till_done() - bluetooth.async_register_callback( + cancel = bluetooth.async_register_callback( hass, _fake_subscriber, {"address": "44:44:33:11:23:45"}, @@ -1203,6 +1784,7 @@ async def test_register_callback_survives_reload( assert service_info.name == "wohand" assert service_info.manufacturer == "Nordic Semiconductor ASA" assert service_info.manufacturer_id == 89 + cancel() async def test_process_advertisements_bail_on_good_advertisement(