Revert "Move esphome gatt services cache to be per device" #81265 (#83793)

This commit is contained in:
J. Nick Koston 2022-12-11 17:50:18 -10:00 committed by GitHub
parent 531873fb4d
commit 95641fa780
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 50 additions and 46 deletions

View File

@ -230,12 +230,14 @@ class ESPHomeClient(BaseBleakClient):
Boolean representing connection status. Boolean representing connection status.
""" """
await self._wait_for_free_connection_slot(CONNECT_FREE_SLOT_TIMEOUT) await self._wait_for_free_connection_slot(CONNECT_FREE_SLOT_TIMEOUT)
domain_data = self.domain_data
entry_data = self.entry_data entry_data = self.entry_data
self._mtu = entry_data.get_gatt_mtu_cache(self._address_as_int)
self._mtu = domain_data.get_gatt_mtu_cache(self._address_as_int)
has_cache = bool( has_cache = bool(
dangerous_use_bleak_cache dangerous_use_bleak_cache
and self._connection_version >= MIN_BLUETOOTH_PROXY_VERSION_HAS_CACHE and self._connection_version >= MIN_BLUETOOTH_PROXY_VERSION_HAS_CACHE
and entry_data.get_gatt_services_cache(self._address_as_int) and domain_data.get_gatt_services_cache(self._address_as_int)
and self._mtu and self._mtu
) )
connected_future: asyncio.Future[bool] = asyncio.Future() connected_future: asyncio.Future[bool] = asyncio.Future()
@ -257,7 +259,7 @@ class ESPHomeClient(BaseBleakClient):
self._is_connected = True self._is_connected = True
if not self._mtu: if not self._mtu:
self._mtu = mtu self._mtu = mtu
entry_data.set_gatt_mtu_cache(self._address_as_int, mtu) domain_data.set_gatt_mtu_cache(self._address_as_int, mtu)
else: else:
self._async_ble_device_disconnected() self._async_ble_device_disconnected()
@ -392,14 +394,14 @@ class ESPHomeClient(BaseBleakClient):
A :py:class:`bleak.backends.service.BleakGATTServiceCollection` with this device's services tree. A :py:class:`bleak.backends.service.BleakGATTServiceCollection` with this device's services tree.
""" """
address_as_int = self._address_as_int address_as_int = self._address_as_int
entry_data = self.entry_data domain_data = self.domain_data
# If the connection version >= 3, we must use the cache # If the connection version >= 3, we must use the cache
# because the esp has already wiped the services list to # because the esp has already wiped the services list to
# save memory. # save memory.
if ( if (
self._connection_version >= MIN_BLUETOOTH_PROXY_VERSION_HAS_CACHE self._connection_version >= MIN_BLUETOOTH_PROXY_VERSION_HAS_CACHE
or dangerous_use_bleak_cache or dangerous_use_bleak_cache
) and (cached_services := entry_data.get_gatt_services_cache(address_as_int)): ) and (cached_services := domain_data.get_gatt_services_cache(address_as_int)):
_LOGGER.debug( _LOGGER.debug(
"%s: %s - %s: Cached services hit", "%s: %s - %s: Cached services hit",
self._source_name, self._source_name,
@ -458,7 +460,7 @@ class ESPHomeClient(BaseBleakClient):
self._ble_device.name, self._ble_device.name,
self._ble_device.address, self._ble_device.address,
) )
entry_data.set_gatt_services_cache(address_as_int, services) domain_data.set_gatt_services_cache(address_as_int, services)
return services return services
def _resolve_characteristic( def _resolve_characteristic(
@ -475,8 +477,8 @@ class ESPHomeClient(BaseBleakClient):
async def clear_cache(self) -> None: async def clear_cache(self) -> None:
"""Clear the GATT cache.""" """Clear the GATT cache."""
self.entry_data.clear_gatt_services_cache(self._address_as_int) self.domain_data.clear_gatt_services_cache(self._address_as_int)
self.entry_data.clear_gatt_mtu_cache(self._address_as_int) self.domain_data.clear_gatt_mtu_cache(self._address_as_int)
@verify_connected @verify_connected
@api_error_as_bleak_error @api_error_as_bleak_error

View File

@ -1,9 +1,13 @@
"""Support for esphome domain data.""" """Support for esphome domain data."""
from __future__ import annotations from __future__ import annotations
from collections.abc import MutableMapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TypeVar, cast from typing import TypeVar, cast
from bleak.backends.service import BleakGATTServiceCollection
from lru import LRU # pylint: disable=no-name-in-module
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import JSONEncoder from homeassistant.helpers.json import JSONEncoder
@ -13,6 +17,7 @@ from .entry_data import RuntimeEntryData
STORAGE_VERSION = 1 STORAGE_VERSION = 1
DOMAIN = "esphome" DOMAIN = "esphome"
MAX_CACHED_SERVICES = 128
_DomainDataSelfT = TypeVar("_DomainDataSelfT", bound="DomainData") _DomainDataSelfT = TypeVar("_DomainDataSelfT", bound="DomainData")
@ -23,6 +28,40 @@ class DomainData:
_entry_datas: dict[str, RuntimeEntryData] = field(default_factory=dict) _entry_datas: dict[str, RuntimeEntryData] = field(default_factory=dict)
_stores: dict[str, Store] = field(default_factory=dict) _stores: dict[str, Store] = field(default_factory=dict)
_gatt_services_cache: MutableMapping[int, BleakGATTServiceCollection] = field(
default_factory=lambda: LRU(MAX_CACHED_SERVICES) # type: ignore[no-any-return]
)
_gatt_mtu_cache: MutableMapping[int, int] = field(
default_factory=lambda: LRU(MAX_CACHED_SERVICES) # type: ignore[no-any-return]
)
def get_gatt_services_cache(
self, address: int
) -> BleakGATTServiceCollection | None:
"""Get the BleakGATTServiceCollection for the given address."""
return self._gatt_services_cache.get(address)
def set_gatt_services_cache(
self, address: int, services: BleakGATTServiceCollection
) -> None:
"""Set the BleakGATTServiceCollection for the given address."""
self._gatt_services_cache[address] = services
def clear_gatt_services_cache(self, address: int) -> None:
"""Clear the BleakGATTServiceCollection for the given address."""
self._gatt_services_cache.pop(address, None)
def get_gatt_mtu_cache(self, address: int) -> int | None:
"""Get the mtu cache for the given address."""
return self._gatt_mtu_cache.get(address)
def set_gatt_mtu_cache(self, address: int, mtu: int) -> None:
"""Set the mtu cache for the given address."""
self._gatt_mtu_cache[address] = mtu
def clear_gatt_mtu_cache(self, address: int) -> None:
"""Clear the mtu cache for the given address."""
self._gatt_mtu_cache.pop(address, None)
def get_entry_data(self, entry: ConfigEntry) -> RuntimeEntryData: def get_entry_data(self, entry: ConfigEntry) -> RuntimeEntryData:
"""Return the runtime entry data associated with this config entry. """Return the runtime entry data associated with this config entry.

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, MutableMapping from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
import logging import logging
from typing import Any, cast from typing import Any, cast
@ -30,8 +30,6 @@ from aioesphomeapi import (
UserService, UserService,
) )
from aioesphomeapi.model import ButtonInfo from aioesphomeapi.model import ButtonInfo
from bleak.backends.service import BleakGATTServiceCollection
from lru import LRU # pylint: disable=no-name-in-module
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform from homeassistant.const import Platform
@ -59,7 +57,6 @@ INFO_TYPE_TO_PLATFORM: dict[type[EntityInfo], str] = {
SwitchInfo: Platform.SWITCH, SwitchInfo: Platform.SWITCH,
TextSensorInfo: Platform.SENSOR, TextSensorInfo: Platform.SENSOR,
} }
MAX_CACHED_SERVICES = 128
@dataclass @dataclass
@ -95,46 +92,12 @@ class RuntimeEntryData:
_ble_connection_free_futures: list[asyncio.Future[int]] = field( _ble_connection_free_futures: list[asyncio.Future[int]] = field(
default_factory=list default_factory=list
) )
_gatt_services_cache: MutableMapping[int, BleakGATTServiceCollection] = field(
default_factory=lambda: LRU(MAX_CACHED_SERVICES) # type: ignore[no-any-return]
)
_gatt_mtu_cache: MutableMapping[int, int] = field(
default_factory=lambda: LRU(MAX_CACHED_SERVICES) # type: ignore[no-any-return]
)
@property @property
def name(self) -> str: def name(self) -> str:
"""Return the name of the device.""" """Return the name of the device."""
return self.device_info.name if self.device_info else self.entry_id return self.device_info.name if self.device_info else self.entry_id
def get_gatt_services_cache(
self, address: int
) -> BleakGATTServiceCollection | None:
"""Get the BleakGATTServiceCollection for the given address."""
return self._gatt_services_cache.get(address)
def set_gatt_services_cache(
self, address: int, services: BleakGATTServiceCollection
) -> None:
"""Set the BleakGATTServiceCollection for the given address."""
self._gatt_services_cache[address] = services
def clear_gatt_services_cache(self, address: int) -> None:
"""Clear the BleakGATTServiceCollection for the given address."""
self._gatt_services_cache.pop(address, None)
def get_gatt_mtu_cache(self, address: int) -> int | None:
"""Get the mtu cache for the given address."""
return self._gatt_mtu_cache.get(address)
def set_gatt_mtu_cache(self, address: int, mtu: int) -> None:
"""Set the mtu cache for the given address."""
self._gatt_mtu_cache[address] = mtu
def clear_gatt_mtu_cache(self, address: int) -> None:
"""Clear the mtu cache for the given address."""
self._gatt_mtu_cache.pop(address, None)
@callback @callback
def async_update_ble_connection_limits(self, free: int, limit: int) -> None: def async_update_ble_connection_limits(self, free: int, limit: int) -> None:
"""Update the BLE connection limits.""" """Update the BLE connection limits."""