Move esphome gatt services cache to be per device (#81265)

This commit is contained in:
J. Nick Koston 2022-10-30 18:02:54 -05:00 committed by GitHub
parent 11d7e1e45f
commit c8a3392471
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 24 deletions

View File

@ -265,9 +265,9 @@ class ESPHomeClient(BaseBleakClient):
A :py:class:`bleak.backends.service.BleakGATTServiceCollection` with this device's services tree.
"""
address_as_int = self._address_as_int
domain_data = self.domain_data
entry_data = self.entry_data
if dangerous_use_bleak_cache and (
cached_services := domain_data.get_gatt_services_cache(address_as_int)
cached_services := entry_data.get_gatt_services_cache(address_as_int)
):
_LOGGER.debug(
"Cached services hit for %s - %s",
@ -311,7 +311,7 @@ class ESPHomeClient(BaseBleakClient):
self._ble_device.name,
self._ble_device.address,
)
domain_data.set_gatt_services_cache(address_as_int, services)
entry_data.set_gatt_services_cache(address_as_int, services)
return services
def _resolve_characteristic(

View File

@ -1,13 +1,9 @@
"""Support for esphome domain data."""
from __future__ import annotations
from collections.abc import MutableMapping
from dataclasses import dataclass, field
from typing import TypeVar, cast
from bleak.backends.service import BleakGATTServiceCollection
from lru import LRU # pylint: disable=no-name-in-module
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import JSONEncoder
@ -17,7 +13,6 @@ from .entry_data import RuntimeEntryData
STORAGE_VERSION = 1
DOMAIN = "esphome"
MAX_CACHED_SERVICES = 128
_DomainDataSelfT = TypeVar("_DomainDataSelfT", bound="DomainData")
@ -29,21 +24,6 @@ class DomainData:
_entry_datas: dict[str, RuntimeEntryData] = field(default_factory=dict)
_stores: dict[str, Store] = field(default_factory=dict)
_entry_by_unique_id: dict[str, ConfigEntry] = field(default_factory=dict)
_gatt_services_cache: MutableMapping[int, BleakGATTServiceCollection] = field(
default_factory=lambda: LRU(MAX_CACHED_SERVICES) # type: ignore[no-any-return]
)
def get_gatt_services_cache(
self, address: int
) -> BleakGATTServiceCollection | None:
"""Get the BleakGATTServiceCollection for the given address."""
return self._gatt_services_cache.get(address)
def set_gatt_services_cache(
self, address: int, services: BleakGATTServiceCollection
) -> None:
"""Set the BleakGATTServiceCollection for the given address."""
self._gatt_services_cache[address] = services
def get_by_unique_id(self, unique_id: str) -> ConfigEntry:
"""Get the config entry by its unique ID."""

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from collections.abc import Callable, MutableMapping
from dataclasses import dataclass, field
import logging
from typing import Any, cast
@ -30,6 +30,8 @@ from aioesphomeapi import (
UserService,
)
from aioesphomeapi.model import ButtonInfo
from bleak.backends.service import BleakGATTServiceCollection
from lru import LRU # pylint: disable=no-name-in-module
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
@ -57,6 +59,7 @@ INFO_TYPE_TO_PLATFORM: dict[type[EntityInfo], str] = {
SwitchInfo: Platform.SWITCH,
TextSensorInfo: Platform.SENSOR,
}
MAX_CACHED_SERVICES = 128
@dataclass
@ -92,6 +95,21 @@ class RuntimeEntryData:
_ble_connection_free_futures: list[asyncio.Future[int]] = field(
default_factory=list
)
_gatt_services_cache: MutableMapping[int, BleakGATTServiceCollection] = field(
default_factory=lambda: LRU(MAX_CACHED_SERVICES) # type: ignore[no-any-return]
)
def get_gatt_services_cache(
self, address: int
) -> BleakGATTServiceCollection | None:
"""Get the BleakGATTServiceCollection for the given address."""
return self._gatt_services_cache.get(address)
def set_gatt_services_cache(
self, address: int, services: BleakGATTServiceCollection
) -> None:
"""Set the BleakGATTServiceCollection for the given address."""
self._gatt_services_cache[address] = services
@callback
def async_update_ble_connection_limits(self, free: int, limit: int) -> None: