diff --git a/homeassistant/components/bluetooth/manager.py b/homeassistant/components/bluetooth/manager.py index 533867496bf..0dcf11fd1e2 100644 --- a/homeassistant/components/bluetooth/manager.py +++ b/homeassistant/components/bluetooth/manager.py @@ -410,11 +410,11 @@ class BluetoothManager: callback_matcher[CONNECTABLE] = matcher.get(CONNECTABLE, True) connectable = callback_matcher[CONNECTABLE] - self._callback_index.add_with_address(callback_matcher) + self._callback_index.add_callback_matcher(callback_matcher) @hass_callback def _async_remove_callback() -> None: - self._callback_index.remove_with_address(callback_matcher) + self._callback_index.remove_callback_matcher(callback_matcher) # If we have history for the subscriber, we can trigger the callback # immediately with the last packet so the subscriber can see the diff --git a/homeassistant/components/bluetooth/match.py b/homeassistant/components/bluetooth/match.py index dd1c9c1fa3c..1a59ee6fe4c 100644 --- a/homeassistant/components/bluetooth/match.py +++ b/homeassistant/components/bluetooth/match.py @@ -173,7 +173,7 @@ class BluetoothMatcherIndexBase(Generic[_T]): self.service_data_uuid_set: set[str] = set() self.manufacturer_id_set: set[int] = set() - def add(self, matcher: _T) -> None: + def add(self, matcher: _T) -> bool: """Add a matcher to the index. Matchers must end up only in one bucket. @@ -185,26 +185,28 @@ class BluetoothMatcherIndexBase(Generic[_T]): self.local_name.setdefault( _local_name_to_index_key(matcher[LOCAL_NAME]), [] ).append(matcher) - return + return True # Manufacturer data is 2nd cheapest since its all ints if MANUFACTURER_ID in matcher: self.manufacturer_id.setdefault(matcher[MANUFACTURER_ID], []).append( matcher ) - return + return True if SERVICE_UUID in matcher: self.service_uuid.setdefault(matcher[SERVICE_UUID], []).append(matcher) - return + return True if SERVICE_DATA_UUID in matcher: self.service_data_uuid.setdefault(matcher[SERVICE_DATA_UUID], []).append( matcher ) - return + return True - def remove(self, matcher: _T) -> None: + return False + + def remove(self, matcher: _T) -> bool: """Remove a matcher from the index. Matchers only end up in one bucket, so once we have @@ -214,19 +216,21 @@ class BluetoothMatcherIndexBase(Generic[_T]): self.local_name[_local_name_to_index_key(matcher[LOCAL_NAME])].remove( matcher ) - return + return True if MANUFACTURER_ID in matcher: self.manufacturer_id[matcher[MANUFACTURER_ID]].remove(matcher) - return + return True if SERVICE_UUID in matcher: self.service_uuid[matcher[SERVICE_UUID]].remove(matcher) - return + return True if SERVICE_DATA_UUID in matcher: self.service_data_uuid[matcher[SERVICE_DATA_UUID]].remove(matcher) - return + return True + + return False def build(self) -> None: """Rebuild the index sets.""" @@ -284,8 +288,11 @@ class BluetoothCallbackMatcherIndex( """Initialize the matcher index.""" super().__init__() self.address: dict[str, list[BluetoothCallbackMatcherWithCallback]] = {} + self.connectable: list[BluetoothCallbackMatcherWithCallback] = [] - def add_with_address(self, matcher: BluetoothCallbackMatcherWithCallback) -> None: + def add_callback_matcher( + self, matcher: BluetoothCallbackMatcherWithCallback + ) -> None: """Add a matcher to the index. Matchers must end up only in one bucket. @@ -296,10 +303,15 @@ class BluetoothCallbackMatcherIndex( self.address.setdefault(matcher[ADDRESS], []).append(matcher) return - super().add(matcher) - self.build() + if super().add(matcher): + self.build() + return - def remove_with_address( + if CONNECTABLE in matcher: + self.connectable.append(matcher) + return + + def remove_callback_matcher( self, matcher: BluetoothCallbackMatcherWithCallback ) -> None: """Remove a matcher from the index. @@ -311,8 +323,13 @@ class BluetoothCallbackMatcherIndex( self.address[matcher[ADDRESS]].remove(matcher) return - super().remove(matcher) - self.build() + if super().remove(matcher): + self.build() + return + + if CONNECTABLE in matcher: + self.connectable.remove(matcher) + return def match_callbacks( self, service_info: BluetoothServiceInfoBleak @@ -322,6 +339,9 @@ class BluetoothCallbackMatcherIndex( for matcher in self.address.get(service_info.address, []): if ble_device_matches(matcher, service_info): matches.append(matcher) + for matcher in self.connectable: + if ble_device_matches(matcher, service_info): + matches.append(matcher) return matches @@ -355,7 +375,6 @@ def ble_device_matches( # Don't check address here since all callers already # check the address and we don't want to double check # since it would result in an unreachable reject case. - if matcher.get(CONNECTABLE, True) and not service_info.connectable: return False diff --git a/tests/components/bluetooth/test_init.py b/tests/components/bluetooth/test_init.py index 1c3c58bc7ab..32feb3d7b0f 100644 --- a/tests/components/bluetooth/test_init.py +++ b/tests/components/bluetooth/test_init.py @@ -1327,6 +1327,61 @@ async def test_register_callback_by_manufacturer_id( assert service_info.manufacturer_id == 21 +async def test_register_callback_by_connectable( + hass, mock_bleak_scanner_start, enable_bluetooth +): + """Test registering a callback by connectable.""" + mock_bt = [] + callbacks = [] + + def _fake_subscriber( + service_info: BluetoothServiceInfo, change: BluetoothChange + ) -> None: + """Fake subscriber for the BleakScanner.""" + callbacks.append((service_info, change)) + + with patch( + "homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt + ): + await async_setup_with_default_adapter(hass) + + with patch.object(hass.config_entries.flow, "async_init"): + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + cancel = bluetooth.async_register_callback( + hass, + _fake_subscriber, + {CONNECTABLE: False}, + BluetoothScanningMode.ACTIVE, + ) + + assert len(mock_bleak_scanner_start.mock_calls) == 1 + + apple_device = BLEDevice("44:44:33:11:23:45", "rtx") + apple_adv = AdvertisementData( + local_name="rtx", + manufacturer_data={7676: b"\xd8.\xad\xcd\r\x85"}, + ) + + inject_advertisement(hass, apple_device, apple_adv) + + empty_device = BLEDevice("11:22:33:44:55:66", "empty") + empty_adv = AdvertisementData(local_name="empty") + + inject_advertisement(hass, empty_device, empty_adv) + await hass.async_block_till_done() + + cancel() + + assert len(callbacks) == 2 + + service_info: BluetoothServiceInfo = callbacks[0][0] + assert service_info.name == "rtx" + service_info: BluetoothServiceInfo = callbacks[1][0] + assert service_info.name == "empty" + + async def test_not_filtering_wanted_apple_devices( hass, mock_bleak_scanner_start, enable_bluetooth ):