Prevent autoflush from happening during attrs lookup (#70768)

This commit is contained in:
J. Nick Koston 2022-04-26 10:04:58 -10:00 committed by GitHub
parent f073f17040
commit 1c4a785fb3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 48 additions and 10 deletions

View File

@ -14,10 +14,19 @@ import time
from typing import Any, TypeVar, cast from typing import Any, TypeVar, cast
from lru import LRU # pylint: disable=no-name-in-module from lru import LRU # pylint: disable=no-name-in-module
from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select from sqlalchemy import (
bindparam,
create_engine,
event as sqlalchemy_event,
exc,
func,
select,
)
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext import baked
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
import voluptuous as vol import voluptuous as vol
@ -279,6 +288,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
entity_filter=entity_filter, entity_filter=entity_filter,
exclude_t=exclude_t, exclude_t=exclude_t,
exclude_attributes_by_domain=exclude_attributes_by_domain, exclude_attributes_by_domain=exclude_attributes_by_domain,
bakery=baked.bakery(),
) )
instance.async_initialize() instance.async_initialize()
instance.async_register() instance.async_register()
@ -600,6 +610,7 @@ class Recorder(threading.Thread):
entity_filter: Callable[[str], bool], entity_filter: Callable[[str], bool],
exclude_t: list[str], exclude_t: list[str],
exclude_attributes_by_domain: dict[str, set[str]], exclude_attributes_by_domain: dict[str, set[str]],
bakery: baked.bakery,
) -> None: ) -> None:
"""Initialize the recorder.""" """Initialize the recorder."""
threading.Thread.__init__(self, name="Recorder") threading.Thread.__init__(self, name="Recorder")
@ -628,6 +639,8 @@ class Recorder(threading.Thread):
self._state_attributes_ids: LRU = LRU(STATE_ATTRIBUTES_ID_CACHE_SIZE) self._state_attributes_ids: LRU = LRU(STATE_ATTRIBUTES_ID_CACHE_SIZE)
self._pending_state_attributes: dict[str, StateAttributes] = {} self._pending_state_attributes: dict[str, StateAttributes] = {}
self._pending_expunge: list[States] = [] self._pending_expunge: list[States] = []
self._bakery = bakery
self._find_shared_attr_query: Query | None = None
self.event_session: Session | None = None self.event_session: Session | None = None
self.get_session: Callable[[], Session] | None = None self.get_session: Callable[[], Session] | None = None
self._completed_first_database_setup: bool | None = None self._completed_first_database_setup: bool | None = None
@ -1118,6 +1131,32 @@ class Recorder(threading.Thread):
if not self.commit_interval: if not self.commit_interval:
self._commit_event_session_or_retry() self._commit_event_session_or_retry()
def _find_shared_attr_in_db(self, attr_hash: int, shared_attrs: str) -> int | None:
"""Find shared attributes in the db from the hash and shared_attrs."""
#
# Avoid the event session being flushed since it will
# commit all the pending events and states to the database.
#
# The lookup has already have checked to see if the data is cached
# or going to be written in the next commit so there is no
# need to flush before checking the database.
#
assert self.event_session is not None
if self._find_shared_attr_query is None:
self._find_shared_attr_query = self._bakery(
lambda session: session.query(StateAttributes.attributes_id)
.filter(StateAttributes.hash == bindparam("attr_hash"))
.filter(StateAttributes.shared_attrs == bindparam("shared_attrs"))
)
with self.event_session.no_autoflush:
if (
attributes := self._find_shared_attr_query(self.event_session)
.params(attr_hash=attr_hash, shared_attrs=shared_attrs)
.first()
):
return cast(int, attributes[0])
return None
def _process_event_into_session(self, event: Event) -> None: def _process_event_into_session(self, event: Event) -> None:
assert self.event_session is not None assert self.event_session is not None
@ -1157,14 +1196,9 @@ class Recorder(threading.Thread):
else: else:
attr_hash = StateAttributes.hash_shared_attrs(shared_attrs) attr_hash = StateAttributes.hash_shared_attrs(shared_attrs)
# Matching attributes found in the database # Matching attributes found in the database
if ( if attributes_id := self._find_shared_attr_in_db(attr_hash, shared_attrs):
attributes := self.event_session.query(StateAttributes.attributes_id) dbstate.attributes_id = attributes_id
.filter(StateAttributes.hash == attr_hash) self._state_attributes_ids[shared_attrs] = attributes_id
.filter(StateAttributes.shared_attrs == shared_attrs)
.first()
):
dbstate.attributes_id = attributes[0]
self._state_attributes_ids[shared_attrs] = attributes[0]
# No matching attributes found, save them in the DB # No matching attributes found, save them in the DB
else: else:
dbstate_attributes = StateAttributes( dbstate_attributes = StateAttributes(

View File

@ -10,6 +10,7 @@ from unittest.mock import Mock, patch
import pytest import pytest
from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError
from sqlalchemy.ext import baked
from homeassistant.components import recorder from homeassistant.components import recorder
from homeassistant.components.recorder import ( from homeassistant.components.recorder import (
@ -79,6 +80,7 @@ def _default_recorder(hass):
entity_filter=CONFIG_SCHEMA({DOMAIN: {}}), entity_filter=CONFIG_SCHEMA({DOMAIN: {}}),
exclude_t=[], exclude_t=[],
exclude_attributes_by_domain={}, exclude_attributes_by_domain={},
bakery=baked.bakery(),
) )

View File

@ -181,7 +181,9 @@ async def test_events_during_migration_are_queued(hass):
True, True,
), patch("homeassistant.components.recorder.create_engine", new=create_engine_test): ), patch("homeassistant.components.recorder.create_engine", new=create_engine_test):
await async_setup_component( await async_setup_component(
hass, "recorder", {"recorder": {"db_url": "sqlite://"}} hass,
"recorder",
{"recorder": {"db_url": "sqlite://", "commit_interval": 0}},
) )
hass.states.async_set("my.entity", "on", {}) hass.states.async_set("my.entity", "on", {})
hass.states.async_set("my.entity", "off", {}) hass.states.async_set("my.entity", "off", {})