From 1c4a785fb379fe3867d199204a84c1e08bc72608 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 26 Apr 2022 10:04:58 -1000 Subject: [PATCH] Prevent autoflush from happening during attrs lookup (#70768) --- homeassistant/components/recorder/__init__.py | 52 +++++++++++++++---- tests/components/recorder/test_init.py | 2 + tests/components/recorder/test_migrate.py | 4 +- 3 files changed, 48 insertions(+), 10 deletions(-) diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index 9c90ca30e8d..9aeed3b6d12 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -14,10 +14,19 @@ import time from typing import Any, TypeVar, cast from lru import LRU # pylint: disable=no-name-in-module -from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select +from sqlalchemy import ( + bindparam, + create_engine, + event as sqlalchemy_event, + exc, + func, + select, +) from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext import baked from sqlalchemy.orm import scoped_session, sessionmaker +from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session import voluptuous as vol @@ -279,6 +288,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: entity_filter=entity_filter, exclude_t=exclude_t, exclude_attributes_by_domain=exclude_attributes_by_domain, + bakery=baked.bakery(), ) instance.async_initialize() instance.async_register() @@ -600,6 +610,7 @@ class Recorder(threading.Thread): entity_filter: Callable[[str], bool], exclude_t: list[str], exclude_attributes_by_domain: dict[str, set[str]], + bakery: baked.bakery, ) -> None: """Initialize the recorder.""" threading.Thread.__init__(self, name="Recorder") @@ -628,6 +639,8 @@ class Recorder(threading.Thread): self._state_attributes_ids: LRU = LRU(STATE_ATTRIBUTES_ID_CACHE_SIZE) self._pending_state_attributes: 
dict[str, StateAttributes] = {} self._pending_expunge: list[States] = [] + self._bakery = bakery + self._find_shared_attr_query: Query | None = None self.event_session: Session | None = None self.get_session: Callable[[], Session] | None = None self._completed_first_database_setup: bool | None = None @@ -1118,6 +1131,32 @@ class Recorder(threading.Thread): if not self.commit_interval: self._commit_event_session_or_retry() + def _find_shared_attr_in_db(self, attr_hash: int, shared_attrs: str) -> int | None: + """Find shared attributes in the db from the hash and shared_attrs.""" + # + # Avoid the event session being flushed since it will + # commit all the pending events and states to the database. + # + # The lookup has already checked to see if the data is cached + # or going to be written in the next commit so there is no + # need to flush before checking the database. + # + assert self.event_session is not None + if self._find_shared_attr_query is None: + self._find_shared_attr_query = self._bakery( + lambda session: session.query(StateAttributes.attributes_id) + .filter(StateAttributes.hash == bindparam("attr_hash")) + .filter(StateAttributes.shared_attrs == bindparam("shared_attrs")) + ) + with self.event_session.no_autoflush: + if ( + attributes := self._find_shared_attr_query(self.event_session) + .params(attr_hash=attr_hash, shared_attrs=shared_attrs) + .first() + ): + return cast(int, attributes[0]) + return None + def _process_event_into_session(self, event: Event) -> None: assert self.event_session is not None @@ -1157,14 +1196,9 @@ class Recorder(threading.Thread): else: attr_hash = StateAttributes.hash_shared_attrs(shared_attrs) # Matching attributes found in the database - if ( - attributes := self.event_session.query(StateAttributes.attributes_id) - .filter(StateAttributes.hash == attr_hash) - .filter(StateAttributes.shared_attrs == shared_attrs) - .first() - ): - dbstate.attributes_id = attributes[0] - self._state_attributes_ids[shared_attrs] = 
attributes[0] + if attributes_id := self._find_shared_attr_in_db(attr_hash, shared_attrs): + dbstate.attributes_id = attributes_id + self._state_attributes_ids[shared_attrs] = attributes_id # No matching attributes found, save them in the DB else: dbstate_attributes = StateAttributes( diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 945523be778..78a3f808d03 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -10,6 +10,7 @@ from unittest.mock import Mock, patch import pytest from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError +from sqlalchemy.ext import baked from homeassistant.components import recorder from homeassistant.components.recorder import ( @@ -79,6 +80,7 @@ def _default_recorder(hass): entity_filter=CONFIG_SCHEMA({DOMAIN: {}}), exclude_t=[], exclude_attributes_by_domain={}, + bakery=baked.bakery(), ) diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py index 41b71aa59c0..6b963941263 100644 --- a/tests/components/recorder/test_migrate.py +++ b/tests/components/recorder/test_migrate.py @@ -181,7 +181,9 @@ async def test_events_during_migration_are_queued(hass): True, ), patch("homeassistant.components.recorder.create_engine", new=create_engine_test): await async_setup_component( - hass, "recorder", {"recorder": {"db_url": "sqlite://"}} + hass, + "recorder", + {"recorder": {"db_url": "sqlite://", "commit_interval": 0}}, ) hass.states.async_set("my.entity", "on", {}) hass.states.async_set("my.entity", "off", {})