Move esphome gatt services cache to be per device (#81265)

This commit is contained in:
J. Nick Koston 2022-10-30 18:02:54 -05:00 committed by GitHub
parent 11d7e1e45f
commit c8a3392471
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 24 deletions

View File

@ -265,9 +265,9 @@ class ESPHomeClient(BaseBleakClient):
A :py:class:`bleak.backends.service.BleakGATTServiceCollection` with this device's services tree. A :py:class:`bleak.backends.service.BleakGATTServiceCollection` with this device's services tree.
""" """
address_as_int = self._address_as_int address_as_int = self._address_as_int
domain_data = self.domain_data entry_data = self.entry_data
if dangerous_use_bleak_cache and ( if dangerous_use_bleak_cache and (
cached_services := domain_data.get_gatt_services_cache(address_as_int) cached_services := entry_data.get_gatt_services_cache(address_as_int)
): ):
_LOGGER.debug( _LOGGER.debug(
"Cached services hit for %s - %s", "Cached services hit for %s - %s",
@ -311,7 +311,7 @@ class ESPHomeClient(BaseBleakClient):
self._ble_device.name, self._ble_device.name,
self._ble_device.address, self._ble_device.address,
) )
domain_data.set_gatt_services_cache(address_as_int, services) entry_data.set_gatt_services_cache(address_as_int, services)
return services return services
def _resolve_characteristic( def _resolve_characteristic(

View File

@ -1,13 +1,9 @@
"""Support for esphome domain data.""" """Support for esphome domain data."""
from __future__ import annotations from __future__ import annotations
from collections.abc import MutableMapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TypeVar, cast from typing import TypeVar, cast
from bleak.backends.service import BleakGATTServiceCollection
from lru import LRU # pylint: disable=no-name-in-module
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import JSONEncoder from homeassistant.helpers.json import JSONEncoder
@ -17,7 +13,6 @@ from .entry_data import RuntimeEntryData
STORAGE_VERSION = 1 STORAGE_VERSION = 1
DOMAIN = "esphome" DOMAIN = "esphome"
MAX_CACHED_SERVICES = 128
_DomainDataSelfT = TypeVar("_DomainDataSelfT", bound="DomainData") _DomainDataSelfT = TypeVar("_DomainDataSelfT", bound="DomainData")
@ -29,21 +24,6 @@ class DomainData:
_entry_datas: dict[str, RuntimeEntryData] = field(default_factory=dict) _entry_datas: dict[str, RuntimeEntryData] = field(default_factory=dict)
_stores: dict[str, Store] = field(default_factory=dict) _stores: dict[str, Store] = field(default_factory=dict)
_entry_by_unique_id: dict[str, ConfigEntry] = field(default_factory=dict) _entry_by_unique_id: dict[str, ConfigEntry] = field(default_factory=dict)
_gatt_services_cache: MutableMapping[int, BleakGATTServiceCollection] = field(
default_factory=lambda: LRU(MAX_CACHED_SERVICES) # type: ignore[no-any-return]
)
def get_gatt_services_cache(
self, address: int
) -> BleakGATTServiceCollection | None:
"""Get the BleakGATTServiceCollection for the given address."""
return self._gatt_services_cache.get(address)
def set_gatt_services_cache(
self, address: int, services: BleakGATTServiceCollection
) -> None:
"""Set the BleakGATTServiceCollection for the given address."""
self._gatt_services_cache[address] = services
def get_by_unique_id(self, unique_id: str) -> ConfigEntry: def get_by_unique_id(self, unique_id: str) -> ConfigEntry:
"""Get the config entry by its unique ID.""" """Get the config entry by its unique ID."""

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable from collections.abc import Callable, MutableMapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
import logging import logging
from typing import Any, cast from typing import Any, cast
@ -30,6 +30,8 @@ from aioesphomeapi import (
UserService, UserService,
) )
from aioesphomeapi.model import ButtonInfo from aioesphomeapi.model import ButtonInfo
from bleak.backends.service import BleakGATTServiceCollection
from lru import LRU # pylint: disable=no-name-in-module
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform from homeassistant.const import Platform
@ -57,6 +59,7 @@ INFO_TYPE_TO_PLATFORM: dict[type[EntityInfo], str] = {
SwitchInfo: Platform.SWITCH, SwitchInfo: Platform.SWITCH,
TextSensorInfo: Platform.SENSOR, TextSensorInfo: Platform.SENSOR,
} }
MAX_CACHED_SERVICES = 128
@dataclass @dataclass
@ -92,6 +95,21 @@ class RuntimeEntryData:
_ble_connection_free_futures: list[asyncio.Future[int]] = field( _ble_connection_free_futures: list[asyncio.Future[int]] = field(
default_factory=list default_factory=list
) )
_gatt_services_cache: MutableMapping[int, BleakGATTServiceCollection] = field(
default_factory=lambda: LRU(MAX_CACHED_SERVICES) # type: ignore[no-any-return]
)
def get_gatt_services_cache(
self, address: int
) -> BleakGATTServiceCollection | None:
"""Get the BleakGATTServiceCollection for the given address."""
return self._gatt_services_cache.get(address)
def set_gatt_services_cache(
self, address: int, services: BleakGATTServiceCollection
) -> None:
"""Set the BleakGATTServiceCollection for the given address."""
self._gatt_services_cache[address] = services
@callback @callback
def async_update_ble_connection_limits(self, free: int, limit: int) -> None: def async_update_ble_connection_limits(self, free: int, limit: int) -> None: