Fix async_scanner_devices_by_address unexpectedly combining Bluetooth scanners (#94990)

This commit is contained in:
J. Nick Koston 2023-06-22 16:50:21 +02:00 committed by GitHub
parent 6ec6369c27
commit 1459bf4011
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -246,9 +246,12 @@ class BluetoothManager:
self, address: str, connectable: bool self, address: str, connectable: bool
) -> list[BluetoothScannerDevice]: ) -> list[BluetoothScannerDevice]:
"""Get BluetoothScannerDevice by address.""" """Get BluetoothScannerDevice by address."""
scanners = self._get_scanners_by_type(True)
if not connectable: if not connectable:
scanners.extend(self._get_scanners_by_type(False)) scanners: Iterable[BaseHaScanner] = itertools.chain(
self._connectable_scanners, self._non_connectable_scanners
)
else:
scanners = self._connectable_scanners
return [ return [
BluetoothScannerDevice(scanner, *device_adv) BluetoothScannerDevice(scanner, *device_adv)
for scanner in scanners for scanner in scanners
@ -267,21 +270,19 @@ class BluetoothManager:
""" """
yield from itertools.chain.from_iterable( yield from itertools.chain.from_iterable(
scanner.discovered_devices_and_advertisement_data scanner.discovered_devices_and_advertisement_data
for scanner in self._get_scanners_by_type(True) for scanner in self._connectable_scanners
) )
if not connectable: if not connectable:
yield from itertools.chain.from_iterable( yield from itertools.chain.from_iterable(
scanner.discovered_devices_and_advertisement_data scanner.discovered_devices_and_advertisement_data
for scanner in self._get_scanners_by_type(False) for scanner in self._non_connectable_scanners
) )
@hass_callback @hass_callback
def async_discovered_devices(self, connectable: bool) -> list[BLEDevice]: def async_discovered_devices(self, connectable: bool) -> list[BLEDevice]:
"""Return all of combined best path to discovered from all the scanners.""" """Return all of combined best path to discovered from all the scanners."""
return [ histories = self._connectable_history if connectable else self._all_history
history.device return [history.device for history in histories.values()]
for history in self._get_history_by_type(connectable).values()
]
@hass_callback @hass_callback
def async_setup_unavailable_tracking(self) -> None: def async_setup_unavailable_tracking(self) -> None:
@ -303,7 +304,10 @@ class BluetoothManager:
intervals = tracker.intervals intervals = tracker.intervals
for connectable in (True, False): for connectable in (True, False):
unavailable_callbacks = self._get_unavailable_callbacks_by_type(connectable) if connectable:
unavailable_callbacks = self._connectable_unavailable_callbacks
else:
unavailable_callbacks = self._unavailable_callbacks
history = connectable_history if connectable else all_history history = connectable_history if connectable else all_history
disappeared = set(history).difference( disappeared = set(history).difference(
self._async_all_discovered_addresses(connectable) self._async_all_discovered_addresses(connectable)
@ -583,7 +587,10 @@ class BluetoothManager:
connectable: bool, connectable: bool,
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Register a callback.""" """Register a callback."""
unavailable_callbacks = self._get_unavailable_callbacks_by_type(connectable) if connectable:
unavailable_callbacks = self._connectable_unavailable_callbacks
else:
unavailable_callbacks = self._unavailable_callbacks
unavailable_callbacks.setdefault(address, []).append(callback) unavailable_callbacks.setdefault(address, []).append(callback)
@hass_callback @hass_callback
@ -620,13 +627,13 @@ class BluetoothManager:
# If we have history for the subscriber, we can trigger the callback # If we have history for the subscriber, we can trigger the callback
# immediately with the last packet so the subscriber can see the # immediately with the last packet so the subscriber can see the
# device. # device.
all_history = self._get_history_by_type(connectable) history = self._connectable_history if connectable else self._all_history
service_infos: Iterable[BluetoothServiceInfoBleak] = [] service_infos: Iterable[BluetoothServiceInfoBleak] = []
if address := callback_matcher.get(ADDRESS): if address := callback_matcher.get(ADDRESS):
if service_info := all_history.get(address): if service_info := history.get(address):
service_infos = [service_info] service_infos = [service_info]
else: else:
service_infos = all_history.values() service_infos = history.values()
for service_info in service_infos: for service_info in service_infos:
if ble_device_matches(callback_matcher, service_info): if ble_device_matches(callback_matcher, service_info):
@ -642,29 +649,32 @@ class BluetoothManager:
self, address: str, connectable: bool self, address: str, connectable: bool
) -> BLEDevice | None: ) -> BLEDevice | None:
"""Return the BLEDevice if present.""" """Return the BLEDevice if present."""
all_history = self._get_history_by_type(connectable) histories = self._connectable_history if connectable else self._all_history
if history := all_history.get(address): if history := histories.get(address):
return history.device return history.device
return None return None
@hass_callback @hass_callback
def async_address_present(self, address: str, connectable: bool) -> bool: def async_address_present(self, address: str, connectable: bool) -> bool:
"""Return if the address is present.""" """Return if the address is present."""
return address in self._get_history_by_type(connectable) histories = self._connectable_history if connectable else self._all_history
return address in histories
@hass_callback @hass_callback
def async_discovered_service_info( def async_discovered_service_info(
self, connectable: bool self, connectable: bool
) -> Iterable[BluetoothServiceInfoBleak]: ) -> Iterable[BluetoothServiceInfoBleak]:
"""Return all the discovered services info.""" """Return all the discovered services info."""
return self._get_history_by_type(connectable).values() histories = self._connectable_history if connectable else self._all_history
return histories.values()
@hass_callback @hass_callback
def async_last_service_info( def async_last_service_info(
self, address: str, connectable: bool self, address: str, connectable: bool
) -> BluetoothServiceInfoBleak | None: ) -> BluetoothServiceInfoBleak | None:
"""Return the last service info for an address.""" """Return the last service info for an address."""
return self._get_history_by_type(connectable).get(address) histories = self._connectable_history if connectable else self._all_history
return histories.get(address)
def _async_trigger_matching_discovery( def _async_trigger_matching_discovery(
self, service_info: BluetoothServiceInfoBleak self, service_info: BluetoothServiceInfoBleak
@ -688,26 +698,6 @@ class BluetoothManager:
if service_info := self._all_history.get(address): if service_info := self._all_history.get(address):
self._async_trigger_matching_discovery(service_info) self._async_trigger_matching_discovery(service_info)
def _get_scanners_by_type(self, connectable: bool) -> list[BaseHaScanner]:
"""Return the scanners by type."""
if connectable:
return self._connectable_scanners
return self._non_connectable_scanners
def _get_unavailable_callbacks_by_type(
self, connectable: bool
) -> dict[str, list[Callable[[BluetoothServiceInfoBleak], None]]]:
"""Return the unavailable callbacks by type."""
if connectable:
return self._connectable_unavailable_callbacks
return self._unavailable_callbacks
def _get_history_by_type(
self, connectable: bool
) -> dict[str, BluetoothServiceInfoBleak]:
"""Return the history by type."""
return self._connectable_history if connectable else self._all_history
def async_register_scanner( def async_register_scanner(
self, self,
scanner: BaseHaScanner, scanner: BaseHaScanner,
@ -716,7 +706,10 @@ class BluetoothManager:
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
"""Register a new scanner.""" """Register a new scanner."""
_LOGGER.debug("Registering scanner %s", scanner.name) _LOGGER.debug("Registering scanner %s", scanner.name)
scanners = self._get_scanners_by_type(connectable) if connectable:
scanners = self._connectable_scanners
else:
scanners = self._non_connectable_scanners
def _unregister_scanner() -> None: def _unregister_scanner() -> None:
_LOGGER.debug("Unregistering scanner %s", scanner.name) _LOGGER.debug("Unregistering scanner %s", scanner.name)