diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index 4f1401aaee7..18acfabffaa 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -331,6 +331,8 @@ def _async_register_services(hass, instance): class RecorderTask(abc.ABC): """ABC for recorder tasks.""" + commit_before = True + @abc.abstractmethod def run(self, instance: Recorder) -> None: """Handle the task.""" @@ -439,6 +441,8 @@ class ExternalStatisticsTask(RecorderTask): class WaitTask(RecorderTask): """An object to insert into the recorder queue to tell it set the _queue_watch event.""" + commit_before = False + def run(self, instance: Recorder) -> None: """Handle the task.""" instance._queue_watch.set() # pylint: disable=[protected-access] @@ -461,6 +465,8 @@ class DatabaseLockTask(RecorderTask): class StopTask(RecorderTask): """An object to insert into the recorder queue to stop the event handler.""" + commit_before = False + def run(self, instance: Recorder) -> None: """Handle the task.""" instance.stop_requested = True @@ -471,6 +477,7 @@ class EventTask(RecorderTask): """An object to insert into the recorder queue to stop the event handler.""" event: bool + commit_before = False def run(self, instance: Recorder) -> None: """Handle the task.""" @@ -800,6 +807,10 @@ class Recorder(threading.Thread): def _process_one_task_or_recover(self, task: RecorderTask): """Process an event, reconnect, or recover a malformed database.""" try: + # If its not an event, commit everything + # that is pending before running the task + if task.commit_before: + self._commit_event_session_or_retry() return task.run(self) except exc.DatabaseError as err: if self._handle_database_error(err): @@ -955,7 +966,9 @@ class Recorder(threading.Thread): def _commit_event_session_or_retry(self): """Commit the event session if there is work to do.""" - if not self.event_session.new and not self.event_session.dirty: + if not self.event_session or ( + not self.event_session.new and not self.event_session.dirty + ): return tries = 1 while tries <= self.db_max_retries: diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index dc7881cfb42..6bdc8250afc 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -288,7 +288,7 @@ async def test_force_shutdown_with_queue_of_writes_that_generate_exceptions( await async_wait_recording_done(hass, instance) - with patch.object(instance, "db_retry_wait", 0.2), patch.object( + with patch.object(instance, "db_retry_wait", 0.05), patch.object( instance.event_session, "flush", side_effect=OperationalError(