diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index 68c5f16f387..ba960dcb93d 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -19,7 +19,6 @@ from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm.session import Session -from sqlalchemy.pool import StaticPool import voluptuous as vol from homeassistant.components import persistent_notification @@ -79,7 +78,7 @@ from .models import ( StatisticsRuns, process_timestamp, ) -from .pool import POOL_SIZE, RecorderPool +from .pool import POOL_SIZE, MutexPool, RecorderPool from .util import ( dburl_to_path, end_incomplete_runs, @@ -1405,7 +1404,8 @@ class Recorder(threading.Thread): if self.db_url == SQLITE_URL_PREFIX or ":memory:" in self.db_url: kwargs["connect_args"] = {"check_same_thread": False} - kwargs["poolclass"] = StaticPool + kwargs["poolclass"] = MutexPool + MutexPool.pool_lock = threading.RLock() kwargs["pool_reset_on_return"] = None elif self.db_url.startswith(SQLITE_URL_PREFIX): kwargs["poolclass"] = RecorderPool diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index 26234be0502..16b7d03b618 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -56,7 +56,7 @@ def get_schema_version(instance: Any) -> int: current_version = getattr(res, "schema_version", None) if current_version is None: - current_version = _inspect_schema_version(instance.engine, session) + current_version = _inspect_schema_version(session) _LOGGER.debug( "No schema version found. Inspected version: %s", current_version ) @@ -651,7 +651,7 @@ def _apply_update(instance, new_version, old_version): # noqa: C901 raise ValueError(f"No schema migration defined for version {new_version}") -def _inspect_schema_version(engine, session): +def _inspect_schema_version(session): """Determine the schema version by inspecting the db structure. When the schema version is not present in the db, either db was just @@ -660,7 +660,7 @@ def _inspect_schema_version(engine, session): version 1 are present to make the determination. Eventually this logic can be removed and we can assume a new db is being created. """ - inspector = sqlalchemy.inspect(engine) + inspector = sqlalchemy.inspect(session.connection()) indexes = inspector.get_indexes("events") for index in indexes: diff --git a/homeassistant/components/recorder/pool.py b/homeassistant/components/recorder/pool.py index 633c084ade4..d1a0cb28639 100644 --- a/homeassistant/components/recorder/pool.py +++ b/homeassistant/components/recorder/pool.py @@ -1,13 +1,22 @@ """A pool for sqlite connections.""" +import logging import threading +import traceback from typing import Any -from sqlalchemy.pool import NullPool, SingletonThreadPool +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.pool import NullPool, SingletonThreadPool, StaticPool from homeassistant.helpers.frame import report from .const import DB_WORKER_PREFIX +_LOGGER = logging.getLogger(__name__) + +# For debugging the MutexPool +DEBUG_MUTEX_POOL = True +DEBUG_MUTEX_POOL_TRACE = False + POOL_SIZE = 5 @@ -63,3 +72,56 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc] return super( # pylint: disable=bad-super-call NullPool, self )._create_connection() + + +class MutexPool(StaticPool): # type: ignore[misc] + """A pool which prevents concurrent accesses from multiple threads. + + This is used in tests to prevent unsafe concurrent accesses to in-memory SQLite + databases. + """ + + _reference_counter = 0 + pool_lock: threading.RLock + + def _do_return_conn(self, conn: Any) -> None: + if DEBUG_MUTEX_POOL_TRACE: + trace = traceback.extract_stack() + trace_msg = "\n" + "".join(traceback.format_list(trace[:-1])) + else: + trace_msg = "" + + super()._do_return_conn(conn) + if DEBUG_MUTEX_POOL: + self._reference_counter -= 1 + _LOGGER.debug( + "%s return conn ctr: %s%s", + threading.current_thread().name, + self._reference_counter, + trace_msg, + ) + MutexPool.pool_lock.release() + + def _do_get(self) -> Any: + + if DEBUG_MUTEX_POOL_TRACE: + trace = traceback.extract_stack() + trace_msg = "".join(traceback.format_list(trace[:-1])) + else: + trace_msg = "" + + if DEBUG_MUTEX_POOL: + _LOGGER.debug("%s wait conn%s", threading.current_thread().name, trace_msg) + # pylint: disable-next=consider-using-with + got_lock = MutexPool.pool_lock.acquire(timeout=1) + if not got_lock: + raise SQLAlchemyError + conn = super()._do_get() + if DEBUG_MUTEX_POOL: + self._reference_counter += 1 + _LOGGER.debug( + "%s get conn: ctr: %s", + threading.current_thread().name, + self._reference_counter, + ) + return conn diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 8272ba79a18..a4bc1bbacaf 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -84,12 +84,18 @@ def _default_recorder(hass): ) -async def test_shutdown_before_startup_finishes(hass): +async def test_shutdown_before_startup_finishes(hass, tmp_path): """Test shutdown before recorder starts is clean.""" + # On-disk database because this test does not play nice with the + # MutexPool + config = { + recorder.CONF_DB_URL: "sqlite:///" + str(tmp_path / "pytest.db"), + recorder.CONF_COMMIT_INTERVAL: 1, + } hass.state = CoreState.not_running - await async_init_recorder_component(hass) + await async_init_recorder_component(hass, config) await hass.data[DATA_INSTANCE].async_db_ready await hass.async_block_till_done() diff --git a/tests/components/recorder/test_purge.py b/tests/components/recorder/test_purge.py index c591f6e5242..43e195e3a21 100644 --- a/tests/components/recorder/test_purge.py +++ b/tests/components/recorder/test_purge.py @@ -276,6 +276,18 @@ async def test_purge_method( caplog, ): """Test purge method.""" + + def assert_recorder_runs_equal(run1, run2): + assert run1.run_id == run2.run_id + assert run1.start == run2.start + assert run1.end == run2.end + assert run1.closed_incorrect == run2.closed_incorrect + assert run1.created == run2.created + + def assert_statistic_runs_equal(run1, run2): + assert run1.run_id == run2.run_id + assert run1.start == run2.start + instance = await async_setup_recorder_instance(hass) service_data = {"keep_days": 4} @@ -306,27 +318,44 @@ async def test_purge_method( assert statistics_runs.count() == 7 statistic_runs_before_purge = statistics_runs.all() - await hass.async_block_till_done() - await async_wait_purge_done(hass, instance) + for itm in runs_before_purge: + session.expunge(itm) + for itm in statistic_runs_before_purge: + session.expunge(itm) - # run purge method - no service data, use defaults - await hass.services.async_call("recorder", "purge") - await hass.async_block_till_done() + await hass.async_block_till_done() + await async_wait_purge_done(hass, instance) - # Small wait for recorder thread - await async_wait_purge_done(hass, instance) + # run purge method - no service data, use defaults + await hass.services.async_call("recorder", "purge") + await hass.async_block_till_done() + + # Small wait for recorder thread + await async_wait_purge_done(hass, instance) + + with session_scope(hass=hass) as session: + states = session.query(States) + events = session.query(Events).filter(Events.event_type.like("EVENT_TEST%")) + statistics = session.query(StatisticsShortTerm) # only purged old states, events and statistics assert states.count() == 4 assert events.count() == 4 assert statistics.count() == 4 - # run purge method - correct service data - await hass.services.async_call("recorder", "purge", service_data=service_data) - await hass.async_block_till_done() + # run purge method - correct service data + await hass.services.async_call("recorder", "purge", service_data=service_data) + await hass.async_block_till_done() - # Small wait for recorder thread - await async_wait_purge_done(hass, instance) + # Small wait for recorder thread + await async_wait_purge_done(hass, instance) + + with session_scope(hass=hass) as session: + states = session.query(States) + events = session.query(Events).filter(Events.event_type.like("EVENT_TEST%")) + statistics = session.query(StatisticsShortTerm) + recorder_runs = session.query(RecorderRuns) + statistics_runs = session.query(StatisticsRuns) # we should only have 2 states, events and statistics left after purging assert states.count() == 2 @@ -335,24 +364,24 @@ async def test_purge_method( # now we should only have 3 recorder runs left runs = recorder_runs.all() - assert runs[0] == runs_before_purge[0] - assert runs[1] == runs_before_purge[5] - assert runs[2] == runs_before_purge[6] + assert_recorder_runs_equal(runs[0], runs_before_purge[0]) + assert_recorder_runs_equal(runs[1], runs_before_purge[5]) + assert_recorder_runs_equal(runs[2], runs_before_purge[6]) # now we should only have 3 statistics runs left runs = statistics_runs.all() - assert runs[0] == statistic_runs_before_purge[0] - assert runs[1] == statistic_runs_before_purge[5] - assert runs[2] == statistic_runs_before_purge[6] + assert_statistic_runs_equal(runs[0], statistic_runs_before_purge[0]) + assert_statistic_runs_equal(runs[1], statistic_runs_before_purge[5]) + assert_statistic_runs_equal(runs[2], statistic_runs_before_purge[6]) assert "EVENT_TEST_PURGE" not in (event.event_type for event in events.all()) - # run purge method - correct service data, with repack - service_data["repack"] = True - await hass.services.async_call("recorder", "purge", service_data=service_data) - await hass.async_block_till_done() - await async_wait_purge_done(hass, instance) - assert "Vacuuming SQL DB to free space" in caplog.text + # run purge method - correct service data, with repack + service_data["repack"] = True + await hass.services.async_call("recorder", "purge", service_data=service_data) + await hass.async_block_till_done() + await async_wait_purge_done(hass, instance) + assert "Vacuuming SQL DB to free space" in caplog.text async def test_purge_edge_case( @@ -408,15 +437,18 @@ async def test_purge_edge_case( events = session.query(Events).filter(Events.event_type == "EVENT_TEST_PURGE") assert events.count() == 1 - await hass.services.async_call( - recorder.DOMAIN, recorder.SERVICE_PURGE, service_data - ) - await hass.async_block_till_done() + await hass.services.async_call( + recorder.DOMAIN, recorder.SERVICE_PURGE, service_data + ) + await hass.async_block_till_done() - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) + with session_scope(hass=hass) as session: + states = session.query(States) assert states.count() == 0 + events = session.query(Events).filter(Events.event_type == "EVENT_TEST_PURGE") assert events.count() == 0 @@ -514,11 +546,12 @@ async def test_purge_cutoff_date( assert events.filter(Events.event_type == "PURGE").count() == rows - 1 assert events.filter(Events.event_type == "KEEP").count() == 1 - instance.queue.put(PurgeTask(cutoff, repack=False, apply_filter=False)) - await hass.async_block_till_done() - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) + instance.queue.put(PurgeTask(cutoff, repack=False, apply_filter=False)) + await hass.async_block_till_done() + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) + with session_scope(hass=hass) as session: states = session.query(States) state_attributes = session.query(StateAttributes) events = session.query(Events) @@ -543,21 +576,25 @@ async def test_purge_cutoff_date( assert events.filter(Events.event_type == "PURGE").count() == 0 assert events.filter(Events.event_type == "KEEP").count() == 1 - # Make sure we can purge everything - instance.queue.put( - PurgeTask(dt_util.utcnow(), repack=False, apply_filter=False) - ) - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) + # Make sure we can purge everything + instance.queue.put(PurgeTask(dt_util.utcnow(), repack=False, apply_filter=False)) + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) + + with session_scope(hass=hass) as session: + states = session.query(States) + state_attributes = session.query(StateAttributes) assert states.count() == 0 assert state_attributes.count() == 0 - # Make sure we can purge everything when the db is already empty - instance.queue.put( - PurgeTask(dt_util.utcnow(), repack=False, apply_filter=False) - ) - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) + # Make sure we can purge everything when the db is already empty + instance.queue.put(PurgeTask(dt_util.utcnow(), repack=False, apply_filter=False)) + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) + + with session_scope(hass=hass) as session: + states = session.query(States) + state_attributes = session.query(StateAttributes) assert states.count() == 0 assert state_attributes.count() == 0 @@ -667,34 +704,46 @@ async def test_purge_filtered_states( assert events_state_changed.count() == 70 assert events_keep.count() == 1 - # Normal purge doesn't remove excluded entities - await hass.services.async_call( - recorder.DOMAIN, recorder.SERVICE_PURGE, service_data - ) - await hass.async_block_till_done() + # Normal purge doesn't remove excluded entities + await hass.services.async_call( + recorder.DOMAIN, recorder.SERVICE_PURGE, service_data + ) + await hass.async_block_till_done() - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) + with session_scope(hass=hass) as session: + states = session.query(States) assert states.count() == 74 + events_state_changed = session.query(Events).filter( + Events.event_type == EVENT_STATE_CHANGED + ) assert events_state_changed.count() == 70 + events_keep = session.query(Events).filter(Events.event_type == "EVENT_KEEP") assert events_keep.count() == 1 - # Test with 'apply_filter' = True - service_data["apply_filter"] = True - await hass.services.async_call( - recorder.DOMAIN, recorder.SERVICE_PURGE, service_data - ) - await hass.async_block_till_done() + # Test with 'apply_filter' = True + service_data["apply_filter"] = True + await hass.services.async_call( + recorder.DOMAIN, recorder.SERVICE_PURGE, service_data + ) + await hass.async_block_till_done() - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) + with session_scope(hass=hass) as session: + states = session.query(States) assert states.count() == 13 + events_state_changed = session.query(Events).filter( + Events.event_type == EVENT_STATE_CHANGED + ) assert events_state_changed.count() == 10 + events_keep = session.query(Events).filter(Events.event_type == "EVENT_KEEP") assert events_keep.count() == 1 states_sensor_excluded = session.query(States).filter( @@ -713,25 +762,29 @@ async def test_purge_filtered_states( assert session.query(StateAttributes).count() == 11 - # Do it again to make sure nothing changes - await hass.services.async_call( - recorder.DOMAIN, recorder.SERVICE_PURGE, service_data - ) - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) + # Do it again to make sure nothing changes + await hass.services.async_call( + recorder.DOMAIN, recorder.SERVICE_PURGE, service_data + ) + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) + + with session_scope(hass=hass) as session: final_keep_state = session.query(States).get(74) assert final_keep_state.old_state_id == 62 # should have been kept assert final_keep_state.attributes_id == 71 assert session.query(StateAttributes).count() == 11 - # Finally make sure we can delete them all except for the ones missing an event_id - service_data = {"keep_days": 0} - await hass.services.async_call( - recorder.DOMAIN, recorder.SERVICE_PURGE, service_data - ) - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) + # Finally make sure we can delete them all except for the ones missing an event_id + service_data = {"keep_days": 0} + await hass.services.async_call( + recorder.DOMAIN, recorder.SERVICE_PURGE, service_data + ) + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) + + with session_scope(hass=hass) as session: remaining = list(session.query(States)) for state in remaining: assert state.event_id is None @@ -771,22 +824,27 @@ async def test_purge_filtered_states_to_empty( assert states.count() == 60 assert state_attributes.count() == 60 - # Test with 'apply_filter' = True - service_data["apply_filter"] = True - await hass.services.async_call( - recorder.DOMAIN, recorder.SERVICE_PURGE, service_data - ) - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) + # Test with 'apply_filter' = True + service_data["apply_filter"] = True + await hass.services.async_call( + recorder.DOMAIN, recorder.SERVICE_PURGE, service_data + ) + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) + + with session_scope(hass=hass) as session: + states = session.query(States) + state_attributes = session.query(StateAttributes) assert states.count() == 0 assert state_attributes.count() == 0 - # Do it again to make sure nothing changes - await hass.services.async_call( - recorder.DOMAIN, recorder.SERVICE_PURGE, service_data - ) - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) + # Do it again to make sure nothing changes + # Why do we do this? Should we check the end result? + await hass.services.async_call( + recorder.DOMAIN, recorder.SERVICE_PURGE, service_data + ) + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) async def test_purge_without_state_attributes_filtered_states_to_empty( @@ -834,22 +892,27 @@ async def test_purge_without_state_attributes_filtered_states_to_empty( assert states.count() == 1 assert state_attributes.count() == 0 - # Test with 'apply_filter' = True - service_data["apply_filter"] = True - await hass.services.async_call( - recorder.DOMAIN, recorder.SERVICE_PURGE, service_data - ) - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) + # Test with 'apply_filter' = True + service_data["apply_filter"] = True + await hass.services.async_call( + recorder.DOMAIN, recorder.SERVICE_PURGE, service_data + ) + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) + + with session_scope(hass=hass) as session: + states = session.query(States) + state_attributes = session.query(StateAttributes) assert states.count() == 0 assert state_attributes.count() == 0 - # Do it again to make sure nothing changes - await hass.services.async_call( - recorder.DOMAIN, recorder.SERVICE_PURGE, service_data - ) - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) + # Do it again to make sure nothing changes + # Why do we do this? Should we check the end result? + await hass.services.async_call( + recorder.DOMAIN, recorder.SERVICE_PURGE, service_data + ) + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) async def test_purge_filtered_events( @@ -901,32 +964,44 @@ async def test_purge_filtered_events( assert events_keep.count() == 10 assert states.count() == 10 - # Normal purge doesn't remove excluded events - await hass.services.async_call( - recorder.DOMAIN, recorder.SERVICE_PURGE, service_data + # Normal purge doesn't remove excluded events + await hass.services.async_call( + recorder.DOMAIN, recorder.SERVICE_PURGE, service_data + ) + await hass.async_block_till_done() + + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) + + with session_scope(hass=hass) as session: + events_purge = session.query(Events).filter(Events.event_type == "EVENT_PURGE") + events_keep = session.query(Events).filter( + Events.event_type == EVENT_STATE_CHANGED ) - await hass.async_block_till_done() - - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) - + states = session.query(States) assert events_purge.count() == 60 assert events_keep.count() == 10 assert states.count() == 10 - # Test with 'apply_filter' = True - service_data["apply_filter"] = True - await hass.services.async_call( - recorder.DOMAIN, recorder.SERVICE_PURGE, service_data + # Test with 'apply_filter' = True + service_data["apply_filter"] = True + await hass.services.async_call( + recorder.DOMAIN, recorder.SERVICE_PURGE, service_data + ) + await hass.async_block_till_done() + + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) + + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) + + with session_scope(hass=hass) as session: + events_purge = session.query(Events).filter(Events.event_type == "EVENT_PURGE") + events_keep = session.query(Events).filter( + Events.event_type == EVENT_STATE_CHANGED ) - await hass.async_block_till_done() - - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) - - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) - + states = session.query(States) assert events_purge.count() == 0 assert events_keep.count() == 10 assert states.count() == 10 @@ -1010,16 +1085,23 @@ async def test_purge_filtered_events_state_changed( assert events_purge.count() == 60 assert states.count() == 63 - await hass.services.async_call( - recorder.DOMAIN, recorder.SERVICE_PURGE, service_data + await hass.services.async_call( + recorder.DOMAIN, recorder.SERVICE_PURGE, service_data + ) + await hass.async_block_till_done() + + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) + + await async_recorder_block_till_done(hass, instance) + await async_wait_purge_done(hass, instance) + + with session_scope(hass=hass) as session: + events_keep = session.query(Events).filter(Events.event_type == "EVENT_KEEP") + events_purge = session.query(Events).filter( + Events.event_type == EVENT_STATE_CHANGED ) - await hass.async_block_till_done() - - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) - - await async_recorder_block_till_done(hass, instance) - await async_wait_purge_done(hass, instance) + states = session.query(States) assert events_keep.count() == 10 assert events_purge.count() == 0 @@ -1104,9 +1186,10 @@ async def test_purge_entities( states = session.query(States) assert states.count() == 190 - await _purge_entities( - hass, "sensor.purge_entity", "purge_domain", "*purge_glob" - ) + await _purge_entities(hass, "sensor.purge_entity", "purge_domain", "*purge_glob") + + with session_scope(hass=hass) as session: + states = session.query(States) assert states.count() == 10 states_sensor_kept = session.query(States).filter( @@ -1121,13 +1204,22 @@ async def test_purge_entities( states = session.query(States) assert states.count() == 190 - await _purge_entities(hass, "sensor.purge_entity", [], []) + await _purge_entities(hass, "sensor.purge_entity", [], []) + + with session_scope(hass=hass) as session: + states = session.query(States) assert states.count() == 130 - await _purge_entities(hass, [], "purge_domain", []) + await _purge_entities(hass, [], "purge_domain", []) + + with session_scope(hass=hass) as session: + states = session.query(States) assert states.count() == 70 - await _purge_entities(hass, [], [], "*purge_glob") + await _purge_entities(hass, [], [], "*purge_glob") + + with session_scope(hass=hass) as session: + states = session.query(States) assert states.count() == 10 states_sensor_kept = session.query(States).filter( @@ -1142,7 +1234,10 @@ async def test_purge_entities( states = session.query(States) assert states.count() == 190 - await _purge_entities(hass, [], [], []) + await _purge_entities(hass, [], [], []) + + with session_scope(hass=hass) as session: + states = session.query(States) assert states.count() == 0 diff --git a/tests/components/recorder/test_util.py b/tests/components/recorder/test_util.py index 5ecf6f892ad..db702b3a3e3 100644 --- a/tests/components/recorder/test_util.py +++ b/tests/components/recorder/test_util.py @@ -17,11 +17,12 @@ from homeassistant.components.recorder.util import ( is_second_sunday, session_scope, ) +from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.util import dt as dt_util from .common import corrupt_db_file -from tests.common import async_init_recorder_component +from tests.common import async_init_recorder_component, async_test_home_assistant def test_session_scope_not_setup(hass_recorder): @@ -95,36 +96,66 @@ def test_validate_or_move_away_sqlite_database(hass, tmpdir, caplog): assert util.validate_or_move_away_sqlite_database(dburl) is True -async def test_last_run_was_recently_clean(hass): +async def test_last_run_was_recently_clean(hass, tmp_path): """Test we can check if the last recorder run was recently clean.""" - await async_init_recorder_component(hass, {recorder.CONF_COMMIT_INTERVAL: 1}) + config = { + recorder.CONF_DB_URL: "sqlite:///" + str(tmp_path / "pytest.db"), + recorder.CONF_COMMIT_INTERVAL: 1, + } + hass = await async_test_home_assistant(None) + + return_values = [] + real_last_run_was_recently_clean = util.last_run_was_recently_clean + + def _last_run_was_recently_clean(cursor): + return_values.append(real_last_run_was_recently_clean(cursor)) + return return_values[-1] + + # Test last_run_was_recently_clean is not called on new DB + with patch( + "homeassistant.components.recorder.util.last_run_was_recently_clean", + wraps=_last_run_was_recently_clean, + ) as last_run_was_recently_clean_mock: + await async_init_recorder_component(hass, config) + await hass.async_block_till_done() + last_run_was_recently_clean_mock.assert_not_called() + + # Restart HA, last_run_was_recently_clean should return True + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) await hass.async_block_till_done() + await hass.async_stop() - cursor = hass.data[DATA_INSTANCE].engine.raw_connection().cursor() + with patch( + "homeassistant.components.recorder.util.last_run_was_recently_clean", + wraps=_last_run_was_recently_clean, + ) as last_run_was_recently_clean_mock: + hass = await async_test_home_assistant(None) + await async_init_recorder_component(hass, config) + last_run_was_recently_clean_mock.assert_called_once() + assert return_values[-1] is True - assert ( - await hass.async_add_executor_job(util.last_run_was_recently_clean, cursor) - is False - ) - - await hass.async_add_executor_job(hass.data[DATA_INSTANCE]._end_session) + # Restart HA with a long downtime, last_run_was_recently_clean should return False + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) await hass.async_block_till_done() - - assert ( - await hass.async_add_executor_job(util.last_run_was_recently_clean, cursor) - is True - ) + await hass.async_stop() thirty_min_future_time = dt_util.utcnow() + timedelta(minutes=30) with patch( + "homeassistant.components.recorder.util.last_run_was_recently_clean", + wraps=_last_run_was_recently_clean, + ) as last_run_was_recently_clean_mock, patch( "homeassistant.components.recorder.dt_util.utcnow", return_value=thirty_min_future_time, ): - assert ( - await hass.async_add_executor_job(util.last_run_was_recently_clean, cursor) - is False - ) + hass = await async_test_home_assistant(None) + await async_init_recorder_component(hass, config) + last_run_was_recently_clean_mock.assert_called_once() + assert return_values[-1] is False + + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) + await hass.async_block_till_done() + await hass.async_stop() @pytest.mark.parametrize(