Migrate StateAttributes to use a table manager (#89760)

commit e379aa23bd (parent ccab45520b)
Author: J. Nick Koston, 2023-03-15 15:26:29 -10:00 (committed by GitHub)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
9 changed files with 274 additions and 227 deletions
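
This change moves the StateAttributes id cache, the pending-attributes map, attribute serialization, and purge-time cache eviction out of the Recorder core and purge helpers into a dedicated StateAttributesManager. The shared plumbing now lives in new BaseTableManager / BaseLRUTableManager base classes, which EventDataManager, EventTypeManager, and StatesMetaManager are refactored to reuse as well.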

File: homeassistant/components/recorder/core.py

@@ -11,10 +11,9 @@ import queue
 import sqlite3
 import threading
 import time
-from typing import Any, TypeVar, cast
+from typing import Any, TypeVar

 import async_timeout
-from lru import LRU  # pylint: disable=no-name-in-module
 from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select
 from sqlalchemy.engine import Engine
 from sqlalchemy.exc import SQLAlchemyError
@@ -30,7 +29,6 @@ from homeassistant.const import (
     MATCH_ALL,
 )
 from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback
-from homeassistant.helpers.entity import entity_sources
 from homeassistant.helpers.event import (
     async_track_time_change,
     async_track_time_interval,
@@ -40,7 +38,6 @@ from homeassistant.helpers.start import async_at_started
 from homeassistant.helpers.typing import UNDEFINED, UndefinedType
 import homeassistant.util.dt as dt_util
 from homeassistant.util.enum import try_parse_enum
-from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS

 from . import migration, statistics
 from .const import (
@@ -52,7 +49,6 @@ from .const import (
     MAX_QUEUE_BACKLOG,
     MYSQLDB_PYMYSQL_URL_PREFIX,
     MYSQLDB_URL_PREFIX,
-    SQLITE_MAX_BIND_VARS,
     SQLITE_URL_PREFIX,
     SupportedDialect,
 )
@@ -79,8 +75,6 @@ from .models import (
 )
 from .pool import POOL_SIZE, MutexPool, RecorderPool
 from .queries import (
-    find_shared_attributes_id,
-    get_shared_attributes,
     has_entity_ids_to_migrate,
     has_event_type_to_migrate,
     has_events_context_ids_to_migrate,
@@ -89,6 +83,7 @@ from .queries import (
 from .run_history import RunHistory
 from .table_managers.event_data import EventDataManager
 from .table_managers.event_types import EventTypeManager
+from .table_managers.state_attributes import StateAttributesManager
 from .table_managers.states_meta import StatesMetaManager
 from .tasks import (
     AdjustLRUSizeTask,
@@ -115,7 +110,6 @@ from .tasks import (
 )
 from .util import (
     build_mysqldb_conv,
-    chunked,
     dburl_to_path,
     end_incomplete_runs,
     is_second_sunday,
@@ -136,15 +130,6 @@ DEFAULT_URL = "sqlite:///{hass_config_path}"
 # States and Events objects
 EXPIRE_AFTER_COMMITS = 120

-# The number of attribute ids to cache in memory
-#
-# Based on:
-# - The number of overlapping attributes
-# - How frequently states with overlapping attributes will change
-# - How much memory our low end hardware has
-STATE_ATTRIBUTES_ID_CACHE_SIZE = 2048
-
 SHUTDOWN_TASK = object()

 COMMIT_TASK = CommitTask()
@@ -206,7 +191,6 @@ class Recorder(threading.Thread):
         self._queue_watch = threading.Event()
         self.engine: Engine | None = None
         self.run_history = RunHistory()
-        self._entity_sources = entity_sources(hass)

         # The entity_filter is exposed on the recorder instance so that
         # it can be used to see if an entity is being recorded and is called
@@ -217,11 +201,12 @@
         self.schema_version = 0
         self._commits_without_expire = 0
         self._old_states: dict[str | None, States] = {}
-        self._state_attributes_ids: LRU = LRU(STATE_ATTRIBUTES_ID_CACHE_SIZE)
         self.event_data_manager = EventDataManager(self)
         self.event_type_manager = EventTypeManager(self)
         self.states_meta_manager = StatesMetaManager(self)
-        self._pending_state_attributes: dict[str, StateAttributes] = {}
+        self.state_attributes_manager = StateAttributesManager(
+            self, exclude_attributes_by_domain
+        )
         self._pending_expunge: list[States] = []
         self.event_session: Session | None = None
         self._get_session: Callable[[], Session] | None = None
@@ -231,7 +216,6 @@
         self.migration_is_live = False
         self._database_lock_task: DatabaseLockTask | None = None
         self._db_executor: DBInterruptibleThreadPoolExecutor | None = None
-        self._exclude_attributes_by_domain = exclude_attributes_by_domain

         self._event_listener: CALLBACK_TYPE | None = None
         self._queue_watcher: CALLBACK_TYPE | None = None
@@ -507,11 +491,9 @@
         If the number of entities has increased, increase the size of the LRU
         cache to avoid thrashing.
         """
-        state_attributes_lru = self._state_attributes_ids
-        current_size = state_attributes_lru.get_size()
         new_size = self.hass.states.async_entity_ids_count() * 2
-        if new_size > current_size:
-            state_attributes_lru.set_size(new_size)
+        self.state_attributes_manager.adjust_lru_size(new_size)
+        self.states_meta_manager.adjust_lru_size(new_size)

     @callback
     def async_periodic_statistics(self) -> None:
@@ -776,33 +758,10 @@
                 non_state_change_events.append(event_)

         assert self.event_session is not None
-        self._pre_process_state_change_events(state_change_events)
         self.event_data_manager.load(non_state_change_events, self.event_session)
         self.event_type_manager.load(non_state_change_events, self.event_session)
         self.states_meta_manager.load(state_change_events, self.event_session)
+        self.state_attributes_manager.load(state_change_events, self.event_session)

-    def _pre_process_state_change_events(self, events: list[Event]) -> None:
-        """Load startup state attributes from the database.
-
-        Since the _state_attributes_ids cache is empty at startup
-        we restore it from the database to avoid having to look up
-        the attributes in the database for every state change
-        until its primed.
-        """
-        assert self.event_session is not None
-        if hashes := {
-            StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes)
-            for event in events
-            if (
-                shared_attrs_bytes := self._serialize_state_attributes_from_event(event)
-            )
-        }:
-            with self.event_session.no_autoflush:
-                for hash_chunk in chunked(hashes, SQLITE_MAX_BIND_VARS):
-                    for id_, shared_attrs in self.event_session.execute(
-                        get_shared_attributes(hash_chunk)
-                    ).fetchall():
-                        self._state_attributes_ids[shared_attrs] = id_
-
     def _guarded_process_one_task_or_recover(self, task: RecorderTask) -> None:
         """Process a task, guarding against exceptions to ensure the loop does not collapse."""
@@ -932,24 +891,6 @@
         if not self.commit_interval:
             self._commit_event_session_or_retry()

-    def _find_shared_attr_in_db(self, attr_hash: int, shared_attrs: str) -> int | None:
-        """Find shared attributes in the db from the hash and shared_attrs."""
-        #
-        # Avoid the event session being flushed since it will
-        # commit all the pending events and states to the database.
-        #
-        # The lookup has already have checked to see if the data is cached
-        # or going to be written in the next commit so there is no
-        # need to flush before checking the database.
-        #
-        assert self.event_session is not None
-        with self.event_session.no_autoflush:
-            if attributes_id := self.event_session.execute(
-                find_shared_attributes_id(attr_hash, shared_attrs)
-            ).first():
-                return cast(int, attributes_id[0])
-        return None
-
     def _process_non_state_changed_event_into_session(self, event: Event) -> None:
         """Process any event into the session except state changed."""
         session = self.event_session
@@ -996,67 +937,53 @@

         session.add(dbevent)

-    def _serialize_state_attributes_from_event(self, event: Event) -> bytes | None:
-        """Serialize state changed event data."""
-        try:
-            return StateAttributes.shared_attrs_bytes_from_event(
-                event,
-                self._entity_sources,
-                self._exclude_attributes_by_domain,
-                self.dialect_name,
-            )
-        except JSON_ENCODE_EXCEPTIONS as ex:
-            _LOGGER.warning(
-                "State is not JSON serializable: %s: %s",
-                event.data.get("new_state"),
-                ex,
-            )
-        return None
-
     def _process_state_changed_event_into_session(self, event: Event) -> None:
         """Process a state_changed event into the session."""
+        state_attributes_manager = self.state_attributes_manager
         dbstate = States.from_event(event)
         if (entity_id := dbstate.entity_id) is None or not (
-            shared_attrs_bytes := self._serialize_state_attributes_from_event(event)
+            shared_attrs_bytes := state_attributes_manager.serialize_from_event(event)
         ):
             return

         assert self.event_session is not None
-        event_session = self.event_session
+        session = self.event_session
         # Map the entity_id to the StatesMeta table
         states_meta_manager = self.states_meta_manager
         if pending_states_meta := states_meta_manager.get_pending(entity_id):
             dbstate.states_meta_rel = pending_states_meta
-        elif metadata_id := states_meta_manager.get(entity_id, event_session, True):
+        elif metadata_id := states_meta_manager.get(entity_id, session, True):
             dbstate.metadata_id = metadata_id
         else:
             states_meta = StatesMeta(entity_id=entity_id)
             states_meta_manager.add_pending(states_meta)
-            event_session.add(states_meta)
+            session.add(states_meta)
             dbstate.states_meta_rel = states_meta

+        # Map the event data to the StateAttributes table
         shared_attrs = shared_attrs_bytes.decode("utf-8")
         dbstate.attributes = None
         # Matching attributes found in the pending commit
-        if pending_attributes := self._pending_state_attributes.get(shared_attrs):
-            dbstate.state_attributes = pending_attributes
+        if pending_event_data := state_attributes_manager.get_pending(shared_attrs):
+            dbstate.state_attributes = pending_event_data
         # Matching attributes id found in the cache
-        elif attributes_id := self._state_attributes_ids.get(shared_attrs):
+        elif (
+            attributes_id := state_attributes_manager.get_from_cache(shared_attrs)
+        ) or (
+            (hash_ := StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes))
+            and (
+                attributes_id := state_attributes_manager.get(
+                    shared_attrs, hash_, session
+                )
+            )
+        ):
             dbstate.attributes_id = attributes_id
         else:
-            attr_hash = StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes)
-            # Matching attributes found in the database
-            if attributes_id := self._find_shared_attr_in_db(attr_hash, shared_attrs):
-                dbstate.attributes_id = attributes_id
-                self._state_attributes_ids[shared_attrs] = attributes_id
-            # No matching attributes found, save them in the DB
-            else:
-                dbstate_attributes = StateAttributes(
-                    shared_attrs=shared_attrs, hash=attr_hash
-                )
-                dbstate.state_attributes = dbstate_attributes
-                self._pending_state_attributes[shared_attrs] = dbstate_attributes
-                self.event_session.add(dbstate_attributes)
+            # No matching attributes found, save them in the DB
+            dbstate_attributes = StateAttributes(shared_attrs=shared_attrs, hash=hash_)
+            state_attributes_manager.add_pending(dbstate_attributes)
+            session.add(dbstate_attributes)
+            dbstate.state_attributes = dbstate_attributes

         if old_state := self._old_states.pop(entity_id, None):
             if old_state.state_id:
@@ -1128,11 +1055,7 @@
             # and we now know the attributes_ids. We can save
             # many selects for matching attributes by loading them
             # into the LRU cache now.
-            for state_attr in self._pending_state_attributes.values():
-                self._state_attributes_ids[
-                    state_attr.shared_attrs
-                ] = state_attr.attributes_id
-            self._pending_state_attributes = {}
+            self.state_attributes_manager.post_commit_pending()
             self.event_data_manager.post_commit_pending()
             self.event_type_manager.post_commit_pending()
             self.states_meta_manager.post_commit_pending()
@@ -1158,8 +1081,7 @@
     def _close_event_session(self) -> None:
         """Close the event session."""
         self._old_states.clear()
-        self._state_attributes_ids.clear()
-        self._pending_state_attributes.clear()
+        self.state_attributes_manager.reset()
         self.event_data_manager.reset()
         self.event_type_manager.reset()
         self.states_meta_manager.reset()
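
Taken together, the core.py changes collapse the recorder's inline attribute bookkeeping into a single lookup cascade on the new manager: pending rows queued in the current commit batch, then the LRU id cache, then a database lookup by hash, and only then a brand-new row. A minimal, standalone sketch of that ordering, using invented stand-ins (FakeManager, resolve_attributes) rather than the real recorder API:

from dataclasses import dataclass, field

@dataclass
class FakeManager:
    pending: dict[str, object] = field(default_factory=dict)  # rows awaiting commit
    id_map: dict[str, int] = field(default_factory=dict)      # shared_attrs -> attributes_id
    db: dict[int, int] = field(default_factory=dict)          # pretend table: hash -> attributes_id

def resolve_attributes(mgr: FakeManager, shared: str, hash_: int) -> tuple[str, object]:
    # 1. A StateAttributes row for this payload is already pending in this batch.
    if (pending := mgr.pending.get(shared)) is not None:
        return ("pending", pending)
    # 2. The id was cached by load(), a prior commit, or a prior DB lookup.
    if (attr_id := mgr.id_map.get(shared)) is not None:
        return ("cached", attr_id)
    # 3. Fall back to the database by hash; like get(), prime the cache on a hit.
    if (attr_id := mgr.db.get(hash_)) is not None:
        mgr.id_map[shared] = attr_id
        return ("database", attr_id)
    # 4. Nothing matched: queue a new row for the next commit.
    mgr.pending[shared] = row = object()
    return ("new", row)

mgr = FakeManager(db={hash('{"unit": "W"}'): 7})
assert resolve_attributes(mgr, '{"unit": "W"}', hash('{"unit": "W"}'))[0] == "database"
assert resolve_attributes(mgr, '{"unit": "W"}', hash('{"unit": "W"}'))[0] == "cached"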

File: homeassistant/components/recorder/purge.py

@@ -479,28 +479,6 @@
         old_states.pop(old_state_reversed[purged_state_id], None)


-def _evict_purged_attributes_from_attributes_cache(
-    instance: Recorder, purged_attributes_ids: set[int]
-) -> None:
-    """Evict purged attribute ids from the attribute ids cache."""
-    # Make a map from attributes_id to the attributes json
-    state_attributes_ids = (
-        instance._state_attributes_ids  # pylint: disable=protected-access
-    )
-    state_attributes_ids_reversed = {
-        attributes_id: attributes
-        for attributes, attributes_id in state_attributes_ids.items()
-    }
-
-    # Evict any purged attributes from the state_attributes_ids cache
-    for purged_attribute_id in purged_attributes_ids.intersection(
-        state_attributes_ids_reversed
-    ):
-        state_attributes_ids.pop(
-            state_attributes_ids_reversed[purged_attribute_id], None
-        )
-
-
 def _purge_batch_attributes_ids(
     instance: Recorder, session: Session, attributes_ids: set[int]
 ) -> None:
@@ -512,7 +490,7 @@
     _LOGGER.debug("Deleted %s attribute states", deleted_rows)

     # Evict any entries in the state_attributes_ids cache referring to a purged state
-    _evict_purged_attributes_from_attributes_cache(instance, attributes_ids)
+    instance.state_attributes_manager.evict_purged(attributes_ids)


 def _purge_batch_data_ids(
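
Eviction runs in the opposite direction from lookup: purging yields attributes_ids, but the cache is keyed by the shared_attrs string, so the manager's evict_purged() first builds a reversed map. A self-contained illustration of the pattern, with a plain dict standing in for the LRU and invented keys and ids:

id_map = {'{"unit": "W"}': 11, '{"unit": "kWh"}': 12, '{"unit": "V"}': 13}
purged_ids = {12, 13, 99}  # 99 was never cached; intersection() skips it

# Reverse the map once, then drop every cached entry whose id was purged.
reversed_map = {attr_id: shared for shared, attr_id in id_map.items()}
for attr_id in purged_ids.intersection(reversed_map):
    id_map.pop(reversed_map[attr_id], None)

assert id_map == {'{"unit": "W"}': 11}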

File: homeassistant/components/recorder/queries.py

@@ -74,17 +74,6 @@ def find_states_metadata_ids(entity_ids: Iterable[str]) -> StatementLambdaElement:
     )


-def find_shared_attributes_id(
-    data_hash: int, shared_attrs: str
-) -> StatementLambdaElement:
-    """Find an attributes_id by hash and shared_attrs."""
-    return lambda_stmt(
-        lambda: select(StateAttributes.attributes_id)
-        .filter(StateAttributes.hash == data_hash)
-        .filter(StateAttributes.shared_attrs == shared_attrs)
-    )
-
-
 def _state_attrs_exist(attr: int | None) -> Select:
     """Check if a state attributes id exists in the states table."""
     # https://github.com/sqlalchemy/sqlalchemy/issues/9189

File: homeassistant/components/recorder/table_managers/__init__.py

@@ -1,15 +1,75 @@
 """Managers for each table."""
-from typing import TYPE_CHECKING
+from collections.abc import MutableMapping
+from typing import TYPE_CHECKING, Generic, TypeVar
+
+from lru import LRU  # pylint: disable=no-name-in-module

 if TYPE_CHECKING:
     from ..core import Recorder

+_DataT = TypeVar("_DataT")


-class BaseTableManager:
+class BaseTableManager(Generic[_DataT]):
     """Base class for table managers."""

     def __init__(self, recorder: "Recorder") -> None:
-        """Initialize the table manager."""
+        """Initialize the table manager.
+
+        The table manager is responsible for managing the id mappings
+        for a table. When data is committed to the database, the
+        manager will move the data from the pending to the id map.
+        """
         self.active = False
         self.recorder = recorder
+        self._pending: dict[str, _DataT] = {}
+        self._id_map: MutableMapping[str, int] = {}
+
+    def get_from_cache(self, data: str) -> int | None:
+        """Resolve data to the id without accessing the underlying database.
+
+        This call is not thread-safe and must be called from the
+        recorder thread.
+        """
+        return self._id_map.get(data)
+
+    def get_pending(self, shared_data: str) -> _DataT | None:
+        """Get pending data that have not be assigned ids yet.
+
+        This call is not thread-safe and must be called from the
+        recorder thread.
+        """
+        return self._pending.get(shared_data)
+
+    def reset(self) -> None:
+        """Reset after the database has been reset or changed.
+
+        This call is not thread-safe and must be called from the
+        recorder thread.
+        """
+        self._id_map.clear()
+        self._pending.clear()
+
+
+class BaseLRUTableManager(BaseTableManager[_DataT]):
+    """Base class for LRU table managers."""
+
+    def __init__(self, recorder: "Recorder", lru_size: int) -> None:
+        """Initialize the LRU table manager.
+
+        We keep track of the most recently used items
+        and evict the least recently used items when the cache is full.
+        """
+        super().__init__(recorder)
+        self._id_map: MutableMapping[str, int] = LRU(lru_size)
+
+    def adjust_lru_size(self, new_size: int) -> None:
+        """Adjust the LRU cache size.
+
+        This call is not thread-safe and must be called from the
+        recorder thread.
+        """
+        lru: LRU = self._id_map
+        if new_size > lru.get_size():
+            lru.set_size(new_size)
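
The reason BaseTableManager is parameterized over _DataT is that each concrete manager keeps its _pending map typed to its own ORM row class while sharing the id-map plumbing. A standalone sketch of the pattern, assuming invented Row and RowManager classes (the real managers bind _DataT to EventData, EventTypes, StateAttributes, or StatesMeta):

from __future__ import annotations

from dataclasses import dataclass
from typing import Generic, TypeVar

_DataT = TypeVar("_DataT")

@dataclass
class Row:
    data_id: int | None = None  # filled in by the database at commit time

class BaseManager(Generic[_DataT]):
    def __init__(self) -> None:
        self._pending: dict[str, _DataT] = {}  # typed per subclass via _DataT
        self._id_map: dict[str, int] = {}

class RowManager(BaseManager[Row]):
    def add_pending(self, key: str, row: Row) -> None:
        self._pending[key] = row  # statically typed as Row, not Any

    def post_commit_pending(self) -> None:
        # After a commit the database has assigned ids; promote them to the id map.
        for key, row in self._pending.items():
            assert row.data_id is not None
            self._id_map[key] = row.data_id
        self._pending.clear()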

File: homeassistant/components/recorder/table_managers/event_data.py

@@ -5,13 +5,12 @@ from collections.abc import Iterable
 import logging
 from typing import TYPE_CHECKING, cast

-from lru import LRU  # pylint: disable=no-name-in-module
 from sqlalchemy.orm.session import Session

 from homeassistant.core import Event
 from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS

-from . import BaseTableManager
+from . import BaseLRUTableManager
 from ..const import SQLITE_MAX_BIND_VARS
 from ..db_schema import EventData
 from ..queries import get_shared_event_datas
@@ -26,14 +25,12 @@ CACHE_SIZE = 2048
 _LOGGER = logging.getLogger(__name__)


-class EventDataManager(BaseTableManager):
+class EventDataManager(BaseLRUTableManager[EventData]):
     """Manage the EventData table."""

     def __init__(self, recorder: Recorder) -> None:
         """Initialize the event type manager."""
-        self._id_map: dict[str, int] = LRU(CACHE_SIZE)
-        self._pending: dict[str, EventData] = {}
-        super().__init__(recorder)
+        super().__init__(recorder, CACHE_SIZE)
         self.active = True  # always active

     def serialize_from_event(self, event: Event) -> bytes | None:
@@ -67,14 +64,6 @@
         """
         return self.get_many(((shared_data, data_hash),), session)[shared_data]

-    def get_from_cache(self, shared_data: str) -> int | None:
-        """Resolve shared_data to the data_id without accessing the underlying database.
-
-        This call is not thread-safe and must be called from the
-        recorder thread.
-        """
-        return self._id_map.get(shared_data)
-
     def get_many(
         self, shared_data_data_hashs: Iterable[tuple[str, int]], session: Session
     ) -> dict[str, int | None]:
@@ -116,14 +105,6 @@

         return results

-    def get_pending(self, shared_data: str) -> EventData | None:
-        """Get pending EventData that have not be assigned ids yet.
-
-        This call is not thread-safe and must be called from the
-        recorder thread.
-        """
-        return self._pending.get(shared_data)
-
     def add_pending(self, db_event_data: EventData) -> None:
         """Add a pending EventData that will be committed at the next interval.
@@ -144,15 +125,6 @@
             self._id_map[shared_data] = db_event_data.data_id
         self._pending.clear()

-    def reset(self) -> None:
-        """Reset the event manager after the database has been reset or changed.
-
-        This call is not thread-safe and must be called from the
-        recorder thread.
-        """
-        self._id_map.clear()
-        self._pending.clear()
-
     def evict_purged(self, data_ids: set[int]) -> None:
         """Evict purged data_ids from the cache when they are no longer used.

File: homeassistant/components/recorder/table_managers/event_types.py

@@ -4,12 +4,11 @@ from __future__ import annotations
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, cast

-from lru import LRU  # pylint: disable=no-name-in-module
 from sqlalchemy.orm.session import Session

 from homeassistant.core import Event

-from . import BaseTableManager
+from . import BaseLRUTableManager
 from ..const import SQLITE_MAX_BIND_VARS
 from ..db_schema import EventTypes
 from ..queries import find_event_type_ids
@@ -22,14 +21,12 @@
 CACHE_SIZE = 2048


-class EventTypeManager(BaseTableManager):
+class EventTypeManager(BaseLRUTableManager[EventTypes]):
     """Manage the EventTypes table."""

     def __init__(self, recorder: Recorder) -> None:
         """Initialize the event type manager."""
-        self._id_map: dict[str, int] = LRU(CACHE_SIZE)
-        self._pending: dict[str, EventTypes] = {}
-        super().__init__(recorder)
+        super().__init__(recorder, CACHE_SIZE)

     def load(self, events: list[Event], session: Session) -> None:
         """Load the event_type to event_type_ids mapping into memory.
@@ -80,14 +77,6 @@

         return results

-    def get_pending(self, event_type: str) -> EventTypes | None:
-        """Get pending EventTypes that have not be assigned ids yet.
-
-        This call is not thread-safe and must be called from the
-        recorder thread.
-        """
-        return self._pending.get(event_type)
-
     def add_pending(self, db_event_type: EventTypes) -> None:
         """Add a pending EventTypes that will be committed at the next interval.
@@ -108,15 +97,6 @@
             self._id_map[event_type] = db_event_types.event_type_id
         self._pending.clear()

-    def reset(self) -> None:
-        """Reset the event manager after the database has been reset or changed.
-
-        This call is not thread-safe and must be called from the
-        recorder thread.
-        """
-        self._id_map.clear()
-        self._pending.clear()
-
     def evict_purged(self, event_types: Iterable[str]) -> None:
         """Evict purged event_types from the cache when they are no longer used.

File: homeassistant/components/recorder/table_managers/state_attributes.py (new file)

@@ -0,0 +1,160 @@
+"""Support managing StateAttributes."""
+from __future__ import annotations
+
+from collections.abc import Iterable
+import logging
+from typing import TYPE_CHECKING, cast
+
+from sqlalchemy.orm.session import Session
+
+from homeassistant.core import Event
+from homeassistant.helpers.entity import entity_sources
+from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS
+
+from . import BaseLRUTableManager
+from ..const import SQLITE_MAX_BIND_VARS
+from ..db_schema import StateAttributes
+from ..queries import get_shared_attributes
+from ..util import chunked
+
+if TYPE_CHECKING:
+    from ..core import Recorder
+
+# The number of attribute ids to cache in memory
+#
+# Based on:
+# - The number of overlapping attributes
+# - How frequently states with overlapping attributes will change
+# - How much memory our low end hardware has
+CACHE_SIZE = 2048
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class StateAttributesManager(BaseLRUTableManager[StateAttributes]):
+    """Manage the StateAttributes table."""
+
+    def __init__(
+        self, recorder: Recorder, exclude_attributes_by_domain: dict[str, set[str]]
+    ) -> None:
+        """Initialize the event type manager."""
+        super().__init__(recorder, CACHE_SIZE)
+        self.active = True  # always active
+        self._exclude_attributes_by_domain = exclude_attributes_by_domain
+        self._entity_sources = entity_sources(recorder.hass)
+
+    def serialize_from_event(self, event: Event) -> bytes | None:
+        """Serialize event data."""
+        try:
+            return StateAttributes.shared_attrs_bytes_from_event(
+                event,
+                self._entity_sources,
+                self._exclude_attributes_by_domain,
+                self.recorder.dialect_name,
+            )
+        except JSON_ENCODE_EXCEPTIONS as ex:
+            _LOGGER.warning(
+                "State is not JSON serializable: %s: %s",
+                event.data.get("new_state"),
+                ex,
+            )
+        return None
+
+    def load(self, events: list[Event], session: Session) -> None:
+        """Load the shared_attrs to attributes_ids mapping into memory from events.
+
+        This call is not thread-safe and must be called from the
+        recorder thread.
+        """
+        if hashes := {
+            StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes)
+            for event in events
+            if (shared_attrs_bytes := self.serialize_from_event(event))
+        }:
+            self._load_from_hashes(hashes, session)
+
+    def get(self, shared_attr: str, data_hash: int, session: Session) -> int | None:
+        """Resolve shared_attrs to the attributes_id.
+
+        This call is not thread-safe and must be called from the
+        recorder thread.
+        """
+        return self.get_many(((shared_attr, data_hash),), session)[shared_attr]
+
+    def get_many(
+        self, shared_attrs_data_hashes: Iterable[tuple[str, int]], session: Session
+    ) -> dict[str, int | None]:
+        """Resolve shared_attrs to attributes_ids.
+
+        This call is not thread-safe and must be called from the
+        recorder thread.
+        """
+        results: dict[str, int | None] = {}
+        missing_hashes: set[int] = set()
+        for shared_attrs, data_hash in shared_attrs_data_hashes:
+            if (attributes_id := self._id_map.get(shared_attrs)) is None:
+                missing_hashes.add(data_hash)
+
+            results[shared_attrs] = attributes_id
+
+        if not missing_hashes:
+            return results
+
+        return results | self._load_from_hashes(missing_hashes, session)
+
+    def _load_from_hashes(
+        self, hashes: Iterable[int], session: Session
+    ) -> dict[str, int | None]:
+        """Load the shared_attrs to attributes_ids mapping into memory from a list of hashes.
+
+        This call is not thread-safe and must be called from the
+        recorder thread.
+        """
+        results: dict[str, int | None] = {}
+        with session.no_autoflush:
+            for hashs_chunk in chunked(hashes, SQLITE_MAX_BIND_VARS):
+                for attributes_id, shared_attrs in session.execute(
+                    get_shared_attributes(hashs_chunk)
+                ):
+                    results[shared_attrs] = self._id_map[shared_attrs] = cast(
+                        int, attributes_id
+                    )
+        return results
+
+    def add_pending(self, db_state_attributes: StateAttributes) -> None:
+        """Add a pending StateAttributes that will be committed at the next interval.
+
+        This call is not thread-safe and must be called from the
+        recorder thread.
+        """
+        assert db_state_attributes.shared_attrs is not None
+        shared_attrs: str = db_state_attributes.shared_attrs
+        self._pending[shared_attrs] = db_state_attributes
+
+    def post_commit_pending(self) -> None:
+        """Call after commit to load the attributes_ids of the new StateAttributes into the LRU.
+
+        This call is not thread-safe and must be called from the
+        recorder thread.
+        """
+        for shared_attrs, db_state_attributes in self._pending.items():
+            self._id_map[shared_attrs] = db_state_attributes.attributes_id
+        self._pending.clear()
+
+    def evict_purged(self, attributes_ids: set[int]) -> None:
+        """Evict purged attributes_ids from the cache when they are no longer used.
+
+        This call is not thread-safe and must be called from the
+        recorder thread.
+        """
+        id_map = self._id_map
+        state_attributes_ids_reversed = {
+            attributes_id: shared_attrs
+            for shared_attrs, attributes_id in id_map.items()
+        }
+        # Evict any purged data from the cache
+        for purged_attributes_id in attributes_ids.intersection(
+            state_attributes_ids_reversed
+        ):
+            id_map.pop(state_attributes_ids_reversed[purged_attributes_id], None)
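
One detail worth noting in _load_from_hashes() is the batching: SQLite caps the number of bind variables per statement, so the hashes are queried in chunks of SQLITE_MAX_BIND_VARS. A sketch of the idea, where chunked is a local stand-in for the recorder's util helper and 998 is an assumed value in the spirit of SQLITE_MAX_BIND_VARS (the real constant lives in the recorder's const module):

from collections.abc import Iterable, Iterator
from itertools import islice

def chunked(iterable: Iterable[int], chunk_size: int) -> Iterator[tuple[int, ...]]:
    # Yield successive fixed-size tuples from any iterable.
    it = iter(iterable)
    while chunk := tuple(islice(it, chunk_size)):
        yield chunk

hashes = set(range(2500))
batches = [len(batch) for batch in chunked(hashes, 998)]
# Each SELECT ... WHERE hash IN (...) then carries at most 998 bind variables.
assert sum(batches) == 2500 and max(batches) <= 998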

File: homeassistant/components/recorder/table_managers/states_meta.py

@@ -4,12 +4,11 @@ from __future__ import annotations
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, cast

-from lru import LRU  # pylint: disable=no-name-in-module
 from sqlalchemy.orm.session import Session

 from homeassistant.core import Event

-from . import BaseTableManager
+from . import BaseLRUTableManager
 from ..const import SQLITE_MAX_BIND_VARS
 from ..db_schema import StatesMeta
 from ..queries import find_all_states_metadata_ids, find_states_metadata_ids
@@ -21,15 +20,13 @@
 CACHE_SIZE = 8192


-class StatesMetaManager(BaseTableManager):
+class StatesMetaManager(BaseLRUTableManager[StatesMeta]):
     """Manage the StatesMeta table."""

     def __init__(self, recorder: Recorder) -> None:
         """Initialize the states meta manager."""
-        self._id_map: dict[str, int] = LRU(CACHE_SIZE)
-        self._pending: dict[str, StatesMeta] = {}
         self._did_first_load = False
-        super().__init__(recorder)
+        super().__init__(recorder, CACHE_SIZE)

     def load(self, events: list[Event], session: Session) -> None:
         """Load the entity_id to metadata_id mapping into memory.
@@ -112,14 +109,6 @@

         return results

-    def get_pending(self, entity_id: str) -> StatesMeta | None:
-        """Get pending StatesMeta that have not be assigned ids yet.
-
-        This call is not thread-safe and must be called from the
-        recorder thread.
-        """
-        return self._pending.get(entity_id)
-
     def add_pending(self, db_states_meta: StatesMeta) -> None:
         """Add a pending StatesMeta that will be committed at the next interval.
@@ -140,15 +129,6 @@
             self._id_map[entity_id] = db_states_meta.metadata_id
         self._pending.clear()

-    def reset(self) -> None:
-        """Reset the states meta manager after the database has been reset or changed.
-
-        This call is not thread-safe and must be called from the
-        recorder thread.
-        """
-        self._id_map.clear()
-        self._pending.clear()
-
     def evict_purged(self, entity_ids: Iterable[str]) -> None:
         """Evict purged event_types from the cache when they are no longer used.

File: tests/components/recorder/test_init.py

@@ -1872,9 +1872,11 @@
     assert all(event.data_id == first_data_id for event in events)


-# Patch STATE_ATTRIBUTES_ID_CACHE_SIZE since otherwise
+# Patch CACHE_SIZE since otherwise
 # the CI can fail because the test takes too long to run
-@patch("homeassistant.components.recorder.core.STATE_ATTRIBUTES_ID_CACHE_SIZE", 5)
+@patch(
+    "homeassistant.components.recorder.table_managers.state_attributes.CACHE_SIZE", 5
+)
 def test_deduplication_state_attributes_inside_commit_interval(
     hass_recorder: Callable[..., HomeAssistant], caplog: pytest.LogCaptureFixture
 ) -> None:
@@ -2159,4 +2161,8 @@
     async_fire_time_changed(hass, dt_util.utcnow() + timedelta(minutes=10))
     await async_wait_recording_done(hass)

-    assert recorder_mock._state_attributes_ids.get_size() == mock_entity_count * 2
+    assert (
+        recorder_mock.state_attributes_manager._id_map.get_size()
+        == mock_entity_count * 2
+    )
+    assert recorder_mock.states_meta_manager._id_map.get_size() == mock_entity_count * 2