Add support for subscribing to bluetooth callbacks by address (#74773)

This commit is contained in:
J. Nick Koston 2022-07-11 17:14:00 +02:00 committed by GitHub
parent eb922b2a1f
commit c2fefe03b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 415 additions and 22 deletions

View File

@ -8,7 +8,7 @@ import fnmatch
from functools import cached_property
import logging
import platform
from typing import Final
from typing import Final, TypedDict
from bleak import BleakError
from bleak.backends.device import MANUFACTURERS, BLEDevice
@ -26,7 +26,11 @@ from homeassistant.core import (
from homeassistant.data_entry_flow import BaseServiceInfo
from homeassistant.helpers import discovery_flow
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import BluetoothMatcher, async_get_bluetooth
from homeassistant.loader import (
BluetoothMatcher,
BluetoothMatcherOptional,
async_get_bluetooth,
)
from . import models
from .const import DOMAIN
@ -38,6 +42,19 @@ _LOGGER = logging.getLogger(__name__)
MAX_REMEMBER_ADDRESSES: Final = 2048
class BluetoothCallbackMatcherOptional(TypedDict, total=False):
"""Matcher for the bluetooth integration for callback optional fields."""
address: str
class BluetoothCallbackMatcher(
BluetoothMatcherOptional,
BluetoothCallbackMatcherOptional,
):
"""Callback matcher for the bluetooth integration."""
class BluetoothScanningMode(Enum):
"""The mode of scanning for bluetooth devices."""
@ -50,6 +67,7 @@ SCANNING_MODE_TO_BLEAK = {
BluetoothScanningMode.PASSIVE: "passive",
}
ADDRESS: Final = "address"
LOCAL_NAME: Final = "local_name"
SERVICE_UUID: Final = "service_uuid"
MANUFACTURER_ID: Final = "manufacturer_id"
@ -102,11 +120,34 @@ BluetoothChange = Enum("BluetoothChange", "ADVERTISEMENT")
BluetoothCallback = Callable[[BluetoothServiceInfo, BluetoothChange], None]
@hass_callback
def async_discovered_service_info(
hass: HomeAssistant,
) -> list[BluetoothServiceInfo]:
"""Return the discovered devices list."""
if DOMAIN not in hass.data:
return []
manager: BluetoothManager = hass.data[DOMAIN]
return manager.async_discovered_service_info()
@hass_callback
def async_address_present(
hass: HomeAssistant,
address: str,
) -> bool:
"""Check if an address is present in the bluetooth device list."""
if DOMAIN not in hass.data:
return False
manager: BluetoothManager = hass.data[DOMAIN]
return manager.async_address_present(address)
@hass_callback
def async_register_callback(
hass: HomeAssistant,
callback: BluetoothCallback,
match_dict: BluetoothMatcher | None,
match_dict: BluetoothCallbackMatcher | None,
) -> Callable[[], None]:
"""Register to receive a callback on bluetooth change.
@ -128,9 +169,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
def _ble_device_matches(
matcher: BluetoothMatcher, device: BLEDevice, advertisement_data: AdvertisementData
matcher: BluetoothCallbackMatcher | BluetoothMatcher,
device: BLEDevice,
advertisement_data: AdvertisementData,
) -> bool:
"""Check if a ble device and advertisement_data matches the matcher."""
if (
matcher_address := matcher.get(ADDRESS)
) is not None and device.address != matcher_address:
return False
if (
matcher_local_name := matcher.get(LOCAL_NAME)
) is not None and not fnmatch.fnmatch(
@ -192,7 +240,9 @@ class BluetoothManager:
self._integration_matchers = integration_matchers
self.scanner: HaBleakScanner | None = None
self._cancel_device_detected: CALLBACK_TYPE | None = None
self._callbacks: list[tuple[BluetoothCallback, BluetoothMatcher | None]] = []
self._callbacks: list[
tuple[BluetoothCallback, BluetoothCallbackMatcher | None]
] = []
# Some devices use a random address so we need to use
# an LRU to avoid memory issues.
self._matched: LRU = LRU(MAX_REMEMBER_ADDRESSES)
@ -227,14 +277,22 @@ class BluetoothManager:
) -> None:
"""Handle a detected device."""
matched_domains: set[str] | None = None
if device.address not in self._matched:
match_key = (device.address, bool(advertisement_data.manufacturer_data))
match_key_has_mfr_data = (device.address, True)
# If we matched without manufacturer_data, we need to do it again
# since we may think the device is unsupported otherwise
if (
match_key_has_mfr_data not in self._matched
and match_key not in self._matched
):
matched_domains = {
matcher["domain"]
for matcher in self._integration_matchers
if _ble_device_matches(matcher, device, advertisement_data)
}
if matched_domains:
self._matched[device.address] = True
self._matched[match_key] = True
_LOGGER.debug(
"Device detected: %s with advertisement_data: %s matched domains: %s",
device,
@ -275,18 +333,61 @@ class BluetoothManager:
@hass_callback
def async_register_callback(
self, callback: BluetoothCallback, match_dict: BluetoothMatcher | None = None
self,
callback: BluetoothCallback,
matcher: BluetoothCallbackMatcher | None = None,
) -> Callable[[], None]:
"""Register a callback."""
callback_entry = (callback, match_dict)
callback_entry = (callback, matcher)
self._callbacks.append(callback_entry)
@hass_callback
def _async_remove_callback() -> None:
self._callbacks.remove(callback_entry)
# If we have history for the subscriber, we can trigger the callback
# immediately with the last packet so the subscriber can see the
# device.
if (
matcher
and (address := matcher.get(ADDRESS))
and models.HA_BLEAK_SCANNER
and (device_adv_data := models.HA_BLEAK_SCANNER.history.get(address))
):
try:
callback(
BluetoothServiceInfo.from_advertisement(*device_adv_data),
BluetoothChange.ADVERTISEMENT,
)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error in bluetooth callback")
return _async_remove_callback
@hass_callback
def async_address_present(self, address: str) -> bool:
"""Return if the address is present."""
return bool(
models.HA_BLEAK_SCANNER
and any(
device.address == address
for device in models.HA_BLEAK_SCANNER.discovered_devices
)
)
@hass_callback
def async_discovered_service_info(self) -> list[BluetoothServiceInfo]:
"""Return if the address is present."""
if models.HA_BLEAK_SCANNER:
discovered = models.HA_BLEAK_SCANNER.discovered_devices
history = models.HA_BLEAK_SCANNER.history
return [
BluetoothServiceInfo.from_advertisement(*history[device.address])
for device in discovered
if device.address in history
]
return []
async def async_stop(self, event: Event) -> None:
"""Stop bluetooth discovery."""
if self._cancel_device_detected:

View File

@ -8,7 +8,11 @@ from typing import Any, Final, cast
from bleak import BleakScanner
from bleak.backends.device import BLEDevice
from bleak.backends.scanner import AdvertisementData, AdvertisementDataCallback
from bleak.backends.scanner import (
AdvertisementData,
AdvertisementDataCallback,
BaseBleakScanner,
)
from lru import LRU # pylint: disable=no-name-in-module
from homeassistant.core import CALLBACK_TYPE, callback as hass_callback
@ -52,7 +56,7 @@ class HaBleakScanner(BleakScanner): # type: ignore[misc]
self._callbacks: list[
tuple[AdvertisementDataCallback, dict[str, set[str]]]
] = []
self._history: LRU = LRU(MAX_HISTORY_SIZE)
self.history: LRU = LRU(MAX_HISTORY_SIZE)
super().__init__(*args, **kwargs)
@hass_callback
@ -70,7 +74,7 @@ class HaBleakScanner(BleakScanner): # type: ignore[misc]
# Replay the history since otherwise we miss devices
# that were already discovered before the callback was registered
# or we are in passive mode
for device, advertisement_data in self._history.values():
for device, advertisement_data in self.history.values():
_dispatch_callback(callback, filters, device, advertisement_data)
return _remove_callback
@ -83,31 +87,46 @@ class HaBleakScanner(BleakScanner): # type: ignore[misc]
Here we get the actual callback from bleak and dispatch
it to all the wrapped HaBleakScannerWrapper classes
"""
self._history[device.address] = (device, advertisement_data)
self.history[device.address] = (device, advertisement_data)
for callback_filters in self._callbacks:
_dispatch_callback(*callback_filters, device, advertisement_data)
class HaBleakScannerWrapper(BleakScanner): # type: ignore[misc]
class HaBleakScannerWrapper(BaseBleakScanner): # type: ignore[misc]
"""A wrapper that uses the single instance."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize the BleakScanner."""
self._detection_cancel: CALLBACK_TYPE | None = None
self._mapped_filters: dict[str, set[str]] = {}
if "filters" in kwargs:
self._mapped_filters = {k: set(v) for k, v in kwargs["filters"].items()}
if "service_uuids" in kwargs:
self._mapped_filters[FILTER_UUIDS] = set(kwargs["service_uuids"])
self._adv_data_callback: AdvertisementDataCallback | None = None
self._map_filters(*args, **kwargs)
super().__init__(*args, **kwargs)
async def stop(self, *args: Any, **kwargs: Any) -> None:
"""Stop scanning for devices."""
return
async def start(self, *args: Any, **kwargs: Any) -> None:
"""Start scanning for devices."""
return
def _map_filters(self, *args: Any, **kwargs: Any) -> bool:
"""Map the filters."""
mapped_filters = {}
if filters := kwargs.get("filters"):
if FILTER_UUIDS not in filters:
_LOGGER.warning("Only %s filters are supported", FILTER_UUIDS)
mapped_filters = {k: set(v) for k, v in filters.items()}
if service_uuids := kwargs.get("service_uuids"):
mapped_filters[FILTER_UUIDS] = set(service_uuids)
if mapped_filters == self._mapped_filters:
return False
self._mapped_filters = mapped_filters
return True
def set_scanning_filter(self, *args: Any, **kwargs: Any) -> None:
"""Set the filters to use."""
if self._map_filters(*args, **kwargs):
self._setup_detection_callback()
def _cancel_callback(self) -> None:
"""Cancel callback."""
@ -127,8 +146,15 @@ class HaBleakScannerWrapper(BleakScanner): # type: ignore[misc]
This method takes the callback and registers it with the long running
scanner.
"""
self._adv_data_callback = callback
self._setup_detection_callback()
def _setup_detection_callback(self) -> None:
"""Set up the detection callback."""
if self._adv_data_callback is None:
return
self._cancel_callback()
super().register_detection_callback(callback)
super().register_detection_callback(self._adv_data_callback)
assert HA_BLEAK_SCANNER is not None
self._detection_cancel = HA_BLEAK_SCANNER.async_register_callback(
self._callback, self._mapped_filters

View File

@ -78,6 +78,30 @@ async def test_setup_and_stop_no_bluetooth(hass, caplog):
assert "Could not create bluetooth scanner" in caplog.text
async def test_calling_async_discovered_devices_no_bluetooth(hass, caplog):
"""Test we fail gracefully when asking for discovered devices and there is no blueooth."""
mock_bt = []
with patch(
"homeassistant.components.bluetooth.HaBleakScanner", side_effect=BleakError
) as mock_ha_bleak_scanner, patch(
"homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt
), patch.object(
hass.config_entries.flow, "async_init"
):
assert await async_setup_component(
hass, bluetooth.DOMAIN, {bluetooth.DOMAIN: {}}
)
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done()
assert len(mock_ha_bleak_scanner.mock_calls) == 1
assert "Could not create bluetooth scanner" in caplog.text
assert not bluetooth.async_discovered_service_info(hass)
assert not bluetooth.async_address_present(hass, "aa:bb:bb:dd:ee:ff")
async def test_discovery_match_by_service_uuid(hass, mock_bleak_scanner_start):
"""Test bluetooth discovery match by service_uuid."""
mock_bt = [
@ -207,8 +231,47 @@ async def test_discovery_match_by_manufacturer_id_and_first_byte(
assert len(mock_config_flow.mock_calls) == 0
async def test_async_discovered_device_api(hass, mock_bleak_scanner_start):
"""Test the async_discovered_device_api."""
mock_bt = []
with patch(
"homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt
), patch(
"bleak.BleakScanner.discovered_devices", # Must patch before we setup
[MagicMock(address="44:44:33:11:23:45")],
):
assert not bluetooth.async_discovered_service_info(hass)
assert not bluetooth.async_address_present(hass, "44:44:22:22:11:22")
assert await async_setup_component(
hass, bluetooth.DOMAIN, {bluetooth.DOMAIN: {}}
)
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
assert len(mock_bleak_scanner_start.mock_calls) == 1
assert not bluetooth.async_discovered_service_info(hass)
wrong_device = BLEDevice("44:44:33:11:23:42", "wrong_name")
wrong_adv = AdvertisementData(local_name="wrong_name", service_uuids=[])
models.HA_BLEAK_SCANNER._callback(wrong_device, wrong_adv)
switchbot_device = BLEDevice("44:44:33:11:23:45", "wohand")
switchbot_adv = AdvertisementData(local_name="wohand", service_uuids=[])
models.HA_BLEAK_SCANNER._callback(switchbot_device, switchbot_adv)
await hass.async_block_till_done()
service_infos = bluetooth.async_discovered_service_info(hass)
assert len(service_infos) == 1
# wrong_name should not appear because bleak no longer sees it
assert service_infos[0].name == "wohand"
assert bluetooth.async_address_present(hass, "44:44:33:11:23:42") is False
assert bluetooth.async_address_present(hass, "44:44:33:11:23:45") is True
async def test_register_callbacks(hass, mock_bleak_scanner_start):
"""Test configured options for a device are loaded via config entry."""
"""Test registering a callback."""
mock_bt = []
callbacks = []
@ -284,6 +347,92 @@ async def test_register_callbacks(hass, mock_bleak_scanner_start):
assert service_info.manufacturer_id is None
async def test_register_callback_by_address(hass, mock_bleak_scanner_start):
"""Test registering a callback by address."""
mock_bt = []
callbacks = []
def _fake_subscriber(
service_info: BluetoothServiceInfo, change: BluetoothChange
) -> None:
"""Fake subscriber for the BleakScanner."""
callbacks.append((service_info, change))
if len(callbacks) >= 3:
raise ValueError
with patch(
"homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt
), patch.object(hass.config_entries.flow, "async_init"):
assert await async_setup_component(
hass, bluetooth.DOMAIN, {bluetooth.DOMAIN: {}}
)
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
cancel = bluetooth.async_register_callback(
hass,
_fake_subscriber,
{"address": "44:44:33:11:23:45"},
)
assert len(mock_bleak_scanner_start.mock_calls) == 1
switchbot_device = BLEDevice("44:44:33:11:23:45", "wohand")
switchbot_adv = AdvertisementData(
local_name="wohand",
service_uuids=["cba20d00-224d-11e6-9fb8-0002a5d5c51b"],
manufacturer_data={89: b"\xd8.\xad\xcd\r\x85"},
service_data={"00000d00-0000-1000-8000-00805f9b34fb": b"H\x10c"},
)
models.HA_BLEAK_SCANNER._callback(switchbot_device, switchbot_adv)
empty_device = BLEDevice("11:22:33:44:55:66", "empty")
empty_adv = AdvertisementData(local_name="empty")
models.HA_BLEAK_SCANNER._callback(empty_device, empty_adv)
await hass.async_block_till_done()
empty_device = BLEDevice("11:22:33:44:55:66", "empty")
empty_adv = AdvertisementData(local_name="empty")
# 3rd callback raises ValueError but is still tracked
models.HA_BLEAK_SCANNER._callback(empty_device, empty_adv)
await hass.async_block_till_done()
cancel()
# 4th callback should not be tracked since we canceled
models.HA_BLEAK_SCANNER._callback(empty_device, empty_adv)
await hass.async_block_till_done()
# Now register again with a callback that fails to
# make sure we do not perm fail
cancel = bluetooth.async_register_callback(
hass,
_fake_subscriber,
{"address": "44:44:33:11:23:45"},
)
cancel()
# Now register again, since the 3rd callback
# should fail but we should still record it
cancel = bluetooth.async_register_callback(
hass,
_fake_subscriber,
{"address": "44:44:33:11:23:45"},
)
cancel()
assert len(callbacks) == 3
for idx in range(3):
service_info: BluetoothServiceInfo = callbacks[idx][0]
assert service_info.name == "wohand"
assert service_info.manufacturer == "Nordic Semiconductor ASA"
assert service_info.manufacturer_id == 89
async def test_wrapped_instance_with_filter(hass, mock_bleak_scanner_start):
"""Test consumers can use the wrapped instance with a filter as if it was normal BleakScanner."""
with patch(
@ -438,3 +587,120 @@ async def test_wrapped_instance_with_broken_callbacks(hass, mock_bleak_scanner_s
models.HA_BLEAK_SCANNER._callback(switchbot_device, switchbot_adv)
await hass.async_block_till_done()
assert len(detected) == 1
async def test_wrapped_instance_changes_uuids(hass, mock_bleak_scanner_start):
"""Test consumers can use the wrapped instance can change the uuids later."""
with patch(
"homeassistant.components.bluetooth.async_get_bluetooth", return_value=[]
), patch.object(hass.config_entries.flow, "async_init"):
assert await async_setup_component(
hass, bluetooth.DOMAIN, {bluetooth.DOMAIN: {}}
)
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
detected = []
def _device_detected(
device: BLEDevice, advertisement_data: AdvertisementData
) -> None:
"""Handle a detected device."""
detected.append((device, advertisement_data))
switchbot_device = BLEDevice("44:44:33:11:23:45", "wohand")
switchbot_adv = AdvertisementData(
local_name="wohand",
service_uuids=["cba20d00-224d-11e6-9fb8-0002a5d5c51b"],
manufacturer_data={89: b"\xd8.\xad\xcd\r\x85"},
service_data={"00000d00-0000-1000-8000-00805f9b34fb": b"H\x10c"},
)
empty_device = BLEDevice("11:22:33:44:55:66", "empty")
empty_adv = AdvertisementData(local_name="empty")
assert models.HA_BLEAK_SCANNER is not None
scanner = models.HaBleakScannerWrapper()
scanner.set_scanning_filter(service_uuids=["cba20d00-224d-11e6-9fb8-0002a5d5c51b"])
scanner.register_detection_callback(_device_detected)
type(models.HA_BLEAK_SCANNER).discovered_devices = [MagicMock()]
for _ in range(2):
models.HA_BLEAK_SCANNER._callback(switchbot_device, switchbot_adv)
await hass.async_block_till_done()
assert len(detected) == 2
# The UUIDs list we created in the wrapped scanner with should be respected
# and we should not get another callback
models.HA_BLEAK_SCANNER._callback(empty_device, empty_adv)
assert len(detected) == 2
async def test_wrapped_instance_changes_filters(hass, mock_bleak_scanner_start):
"""Test consumers can use the wrapped instance can change the filter later."""
with patch(
"homeassistant.components.bluetooth.async_get_bluetooth", return_value=[]
), patch.object(hass.config_entries.flow, "async_init"):
assert await async_setup_component(
hass, bluetooth.DOMAIN, {bluetooth.DOMAIN: {}}
)
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
detected = []
def _device_detected(
device: BLEDevice, advertisement_data: AdvertisementData
) -> None:
"""Handle a detected device."""
detected.append((device, advertisement_data))
switchbot_device = BLEDevice("44:44:33:11:23:42", "wohand")
switchbot_adv = AdvertisementData(
local_name="wohand",
service_uuids=["cba20d00-224d-11e6-9fb8-0002a5d5c51b"],
manufacturer_data={89: b"\xd8.\xad\xcd\r\x85"},
service_data={"00000d00-0000-1000-8000-00805f9b34fb": b"H\x10c"},
)
empty_device = BLEDevice("11:22:33:44:55:62", "empty")
empty_adv = AdvertisementData(local_name="empty")
assert models.HA_BLEAK_SCANNER is not None
scanner = models.HaBleakScannerWrapper()
scanner.set_scanning_filter(
filters={"UUIDs": ["cba20d00-224d-11e6-9fb8-0002a5d5c51b"]}
)
scanner.register_detection_callback(_device_detected)
type(models.HA_BLEAK_SCANNER).discovered_devices = [MagicMock()]
for _ in range(2):
models.HA_BLEAK_SCANNER._callback(switchbot_device, switchbot_adv)
await hass.async_block_till_done()
assert len(detected) == 2
# The UUIDs list we created in the wrapped scanner with should be respected
# and we should not get another callback
models.HA_BLEAK_SCANNER._callback(empty_device, empty_adv)
assert len(detected) == 2
async def test_wrapped_instance_unsupported_filter(
hass, mock_bleak_scanner_start, caplog
):
"""Test we want when their filter is ineffective."""
with patch(
"homeassistant.components.bluetooth.async_get_bluetooth", return_value=[]
), patch.object(hass.config_entries.flow, "async_init"):
assert await async_setup_component(
hass, bluetooth.DOMAIN, {bluetooth.DOMAIN: {}}
)
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
assert models.HA_BLEAK_SCANNER is not None
scanner = models.HaBleakScannerWrapper()
scanner.set_scanning_filter(
filters={"unsupported": ["cba20d00-224d-11e6-9fb8-0002a5d5c51b"]}
)
assert "Only UUIDs filters are supported" in caplog.text