Migrate States to use a table manager (#89769)

This commit is contained in:
J. Nick Koston 2023-03-15 16:19:43 -10:00 committed by GitHub
parent 4080d68489
commit 99d6b1fa57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 109 additions and 42 deletions

View File

@ -84,6 +84,7 @@ from .run_history import RunHistory
from .table_managers.event_data import EventDataManager
from .table_managers.event_types import EventTypeManager
from .table_managers.state_attributes import StateAttributesManager
from .table_managers.states import StatesManager
from .table_managers.states_meta import StatesMetaManager
from .tasks import (
AdjustLRUSizeTask,
@ -200,14 +201,13 @@ class Recorder(threading.Thread):
self.schema_version = 0
self._commits_without_expire = 0
self._old_states: dict[str | None, States] = {}
self.states_manager = StatesManager()
self.event_data_manager = EventDataManager(self)
self.event_type_manager = EventTypeManager(self)
self.states_meta_manager = StatesMetaManager(self)
self.state_attributes_manager = StateAttributesManager(
self, exclude_attributes_by_domain
)
self._pending_expunge: list[States] = []
self.event_session: Session | None = None
self._get_session: Callable[[], Session] | None = None
self._completed_first_database_setup: bool | None = None
@ -985,14 +985,13 @@ class Recorder(threading.Thread):
session.add(dbstate_attributes)
dbstate.state_attributes = dbstate_attributes
if old_state := self._old_states.pop(entity_id, None):
if old_state.state_id:
dbstate.old_state_id = old_state.state_id
else:
dbstate.old_state = old_state
states_manager = self.states_manager
if old_state := states_manager.pop_pending(entity_id):
dbstate.old_state = old_state
elif old_state_id := states_manager.pop_committed(entity_id):
dbstate.old_state_id = old_state_id
if event.data.get("new_state"):
self._old_states[entity_id] = dbstate
self._pending_expunge.append(dbstate)
states_manager.add_pending(entity_id, dbstate)
else:
dbstate.state = None
@ -1043,18 +1042,11 @@ class Recorder(threading.Thread):
self._commits_without_expire += 1
self.event_session.commit()
if self._pending_expunge:
for dbstate in self._pending_expunge:
# Expunge the state so its not expired
# until we use it later for dbstate.old_state
if dbstate in self.event_session:
self.event_session.expunge(dbstate)
self._pending_expunge = []
# We just committed the state attributes to the database
# and we now know the attributes_ids. We can save
# many selects for matching attributes by loading them
# into the LRU cache now.
# into the LRU or committed now.
self.states_manager.post_commit_pending()
self.state_attributes_manager.post_commit_pending()
self.event_data_manager.post_commit_pending()
self.event_type_manager.post_commit_pending()
@ -1080,7 +1072,7 @@ class Recorder(threading.Thread):
def _close_event_session(self) -> None:
"""Close the event session."""
self._old_states.clear()
self.states_manager.reset()
self.state_attributes_manager.reset()
self.event_data_manager.reset()
self.event_type_manager.reset()

View File

@ -459,24 +459,7 @@ def _purge_state_ids(instance: Recorder, session: Session, state_ids: set[int])
_LOGGER.debug("Deleted %s states", deleted_rows)
# Evict eny entries in the old_states cache referring to a purged state
_evict_purged_states_from_old_states_cache(instance, state_ids)
def _evict_purged_states_from_old_states_cache(
instance: Recorder, purged_state_ids: set[int]
) -> None:
"""Evict purged states from the old states cache."""
# Make a map from old_state_id to entity_id
old_states = instance._old_states # pylint: disable=protected-access
old_state_reversed = {
old_state.state_id: entity_id
for entity_id, old_state in old_states.items()
if old_state.state_id
}
# Evict any purged state from the old states cache
for purged_state_id in purged_state_ids.intersection(old_state_reversed):
old_states.pop(old_state_reversed[purged_state_id], None)
instance.states_manager.evict_purged_state_ids(state_ids)
def _purge_batch_attributes_ids(
@ -576,6 +559,7 @@ def _purge_old_entity_ids(instance: Recorder, session: Session) -> None:
# Evict any entries in the event_type cache referring to a purged state
instance.states_meta_manager.evict_purged(purge_entity_ids)
instance.states_manager.evict_purged_entity_ids(purge_entity_ids)
def _purge_filtered_data(instance: Recorder, session: Session) -> bool:

View File

@ -0,0 +1,91 @@
"""Support managing States."""
from __future__ import annotations
from ..db_schema import States
class StatesManager:
"""Manage the states table."""
def __init__(self) -> None:
"""Initialize the states manager for linking old_state_id."""
self._pending: dict[str, States] = {}
self._last_committed_id: dict[str, int] = {}
def pop_pending(self, entity_id: str) -> States | None:
"""Pop a pending state.
Pending states are states that are in the session but not yet committed.
This call is not thread-safe and must be called from the
recorder thread.
"""
return self._pending.pop(entity_id, None)
def pop_committed(self, entity_id: str) -> int | None:
"""Pop a committed state.
Committed states are states that have already been committed to the
database.
This call is not thread-safe and must be called from the
recorder thread.
"""
return self._last_committed_id.pop(entity_id, None)
def add_pending(self, entity_id: str, state: States) -> None:
"""Add a pending state.
Pending states are states that are in the session but not yet committed.
This call is not thread-safe and must be called from the
recorder thread.
"""
self._pending[entity_id] = state
def post_commit_pending(self) -> None:
"""Call after commit to load the state_id of the new States into committed.
This call is not thread-safe and must be called from the
recorder thread.
"""
for entity_id, db_states in self._pending.items():
self._last_committed_id[entity_id] = db_states.state_id
self._pending.clear()
def reset(self) -> None:
"""Reset after the database has been reset or changed.
This call is not thread-safe and must be called from the
recorder thread.
"""
self._last_committed_id.clear()
self._pending.clear()
def evict_purged_state_ids(self, purged_state_ids: set[int]) -> None:
"""Evict purged states from the committed states.
When we purge states we need to make sure the next call to record a state
does not link the old_state_id to the purged state.
"""
# Make a map from the committed state_id to the entity_id
last_committed_ids = self._last_committed_id
last_committed_ids_reversed = {
state_id: entity_id for entity_id, state_id in last_committed_ids.items()
}
# Evict any purged state from the old states cache
for purged_state_id in purged_state_ids.intersection(
last_committed_ids_reversed
):
last_committed_ids.pop(last_committed_ids_reversed[purged_state_id], None)
def evict_purged_entity_ids(self, purged_entity_ids: set[str]) -> None:
"""Evict purged entity_ids from the committed states.
When we purge states we need to make sure the next call to record a state
does not link the old_state_id to the purged state.
"""
last_committed_ids = self._last_committed_id
for entity_id in purged_entity_ids:
last_committed_ids.pop(entity_id, None)

View File

@ -82,7 +82,7 @@ async def test_purge_old_states(
events = session.query(Events).filter(Events.event_type == "state_changed")
assert events.count() == 0
assert "test.recorder2" in instance._old_states
assert "test.recorder2" in instance.states_manager._last_committed_id
purge_before = dt_util.utcnow() - timedelta(days=4)
@ -98,7 +98,7 @@ async def test_purge_old_states(
assert states.count() == 2
assert state_attributes.count() == 1
assert "test.recorder2" in instance._old_states
assert "test.recorder2" in instance.states_manager._last_committed_id
states_after_purge = list(session.query(States))
# Since these states are deleted in batches, we can't guarantee the order
@ -115,7 +115,7 @@ async def test_purge_old_states(
assert states.count() == 2
assert state_attributes.count() == 1
assert "test.recorder2" in instance._old_states
assert "test.recorder2" in instance.states_manager._last_committed_id
# run purge_old_data again
purge_before = dt_util.utcnow()
@ -130,7 +130,7 @@ async def test_purge_old_states(
assert states.count() == 0
assert state_attributes.count() == 0
assert "test.recorder2" not in instance._old_states
assert "test.recorder2" not in instance.states_manager._last_committed_id
# Add some more states
await _add_test_states(hass)
@ -144,7 +144,7 @@ async def test_purge_old_states(
events = session.query(Events).filter(Events.event_type == "state_changed")
assert events.count() == 0
assert "test.recorder2" in instance._old_states
assert "test.recorder2" in instance.states_manager._last_committed_id
state_attributes = session.query(StateAttributes)
assert state_attributes.count() == 3