Commit any pending changes before running non-EventTasks in the recorder (#68287)

This commit is contained in:
J. Nick Koston 2022-03-17 18:33:22 -10:00 committed by GitHub
parent 490c921763
commit 95f20500ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 2 deletions

View File

@ -331,6 +331,8 @@ def _async_register_services(hass, instance):
class RecorderTask(abc.ABC): class RecorderTask(abc.ABC):
"""ABC for recorder tasks.""" """ABC for recorder tasks."""
commit_before = True
@abc.abstractmethod @abc.abstractmethod
def run(self, instance: Recorder) -> None: def run(self, instance: Recorder) -> None:
"""Handle the task.""" """Handle the task."""
@ -439,6 +441,8 @@ class ExternalStatisticsTask(RecorderTask):
class WaitTask(RecorderTask): class WaitTask(RecorderTask):
"""An object to insert into the recorder queue to tell it set the _queue_watch event.""" """An object to insert into the recorder queue to tell it set the _queue_watch event."""
commit_before = False
def run(self, instance: Recorder) -> None: def run(self, instance: Recorder) -> None:
"""Handle the task.""" """Handle the task."""
instance._queue_watch.set() # pylint: disable=[protected-access] instance._queue_watch.set() # pylint: disable=[protected-access]
@ -461,6 +465,8 @@ class DatabaseLockTask(RecorderTask):
class StopTask(RecorderTask): class StopTask(RecorderTask):
"""An object to insert into the recorder queue to stop the event handler.""" """An object to insert into the recorder queue to stop the event handler."""
commit_before = False
def run(self, instance: Recorder) -> None: def run(self, instance: Recorder) -> None:
"""Handle the task.""" """Handle the task."""
instance.stop_requested = True instance.stop_requested = True
@ -471,6 +477,7 @@ class EventTask(RecorderTask):
"""An object to insert into the recorder queue to stop the event handler.""" """An object to insert into the recorder queue to stop the event handler."""
event: bool event: bool
commit_before = False
def run(self, instance: Recorder) -> None: def run(self, instance: Recorder) -> None:
"""Handle the task.""" """Handle the task."""
@ -800,6 +807,10 @@ class Recorder(threading.Thread):
def _process_one_task_or_recover(self, task: RecorderTask): def _process_one_task_or_recover(self, task: RecorderTask):
"""Process an event, reconnect, or recover a malformed database.""" """Process an event, reconnect, or recover a malformed database."""
try: try:
# If its not an event, commit everything
# that is pending before running the task
if task.commit_before:
self._commit_event_session_or_retry()
return task.run(self) return task.run(self)
except exc.DatabaseError as err: except exc.DatabaseError as err:
if self._handle_database_error(err): if self._handle_database_error(err):
@ -955,7 +966,9 @@ class Recorder(threading.Thread):
def _commit_event_session_or_retry(self): def _commit_event_session_or_retry(self):
"""Commit the event session if there is work to do.""" """Commit the event session if there is work to do."""
if not self.event_session.new and not self.event_session.dirty: if not self.event_session or (
not self.event_session.new and not self.event_session.dirty
):
return return
tries = 1 tries = 1
while tries <= self.db_max_retries: while tries <= self.db_max_retries:

View File

@ -288,7 +288,7 @@ async def test_force_shutdown_with_queue_of_writes_that_generate_exceptions(
await async_wait_recording_done(hass, instance) await async_wait_recording_done(hass, instance)
with patch.object(instance, "db_retry_wait", 0.2), patch.object( with patch.object(instance, "db_retry_wait", 0.05), patch.object(
instance.event_session, instance.event_session,
"flush", "flush",
side_effect=OperationalError( side_effect=OperationalError(