diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py
index ceec7ee9eed..915e6b45181 100644
--- a/homeassistant/components/recorder/__init__.py
+++ b/homeassistant/components/recorder/__init__.py
@@ -5,6 +5,7 @@ import concurrent.futures
 from datetime import datetime
 import logging
 import queue
+import sqlite3
 import threading
 import time
 from typing import Any, Callable, List, Optional
@@ -37,7 +38,12 @@ import homeassistant.util.dt as dt_util
 from . import migration, purge
 from .const import CONF_DB_INTEGRITY_CHECK, DATA_INSTANCE, DOMAIN, SQLITE_URL_PREFIX
 from .models import Base, Events, RecorderRuns, States
-from .util import session_scope, validate_or_move_away_sqlite_database
+from .util import (
+    dburl_to_path,
+    move_away_broken_database,
+    session_scope,
+    validate_or_move_away_sqlite_database,
+)
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -247,7 +253,7 @@ class Recorder(threading.Thread):
         self._pending_expunge = []
         self.event_session = None
         self.get_session = None
-        self._completed_database_setup = False
+        self._completed_database_setup = None
 
     @callback
     def async_initialize(self):
@@ -278,39 +284,8 @@ class Recorder(threading.Thread):
 
     def run(self):
         """Start processing events to save."""
-        tries = 1
-        connected = False
-        while not connected and tries <= self.db_max_retries:
-            if tries != 1:
-                time.sleep(self.db_retry_wait)
-            try:
-                self._setup_connection()
-                migration.migrate_schema(self)
-                self._setup_run()
-                connected = True
-                _LOGGER.debug("Connected to recorder database")
-            except Exception as err:  # pylint: disable=broad-except
-                _LOGGER.error(
-                    "Error during connection setup: %s (retrying in %s seconds)",
-                    err,
-                    self.db_retry_wait,
-                )
-                tries += 1
-
-        if not connected:
-
-            @callback
-            def connection_failed():
-                """Connect failed tasks."""
-                self.async_db_ready.set_result(False)
-                persistent_notification.async_create(
-                    self.hass,
-                    "The recorder could not start, please check the log",
-                    "Recorder",
-                )
-
-            self.hass.add_job(connection_failed)
+        if not self._setup_recorder():
             return
 
         shutdown_task = object()
@@ -346,15 +321,11 @@ class Recorder(threading.Thread):
             self.hass.add_job(register)
         result = hass_started.result()
 
-        self.event_session = self.get_session()
-        self.event_session.expire_on_commit = False
-
         # If shutdown happened before Home Assistant finished starting
         if result is shutdown_task:
             # Make sure we cleanly close the run if
             # we restart before startup finishes
-            self._close_run()
-            self._close_connection()
+            self._shutdown()
             return
 
         # Start periodic purge
@@ -370,175 +341,180 @@ class Recorder(threading.Thread):
                 async_purge, hour=4, minute=12, second=0
             )
 
+        _LOGGER.debug("Recorder processing the queue")
         # Use a session for the event read loop
         # with a commit every time the event time
         # has changed. This reduces the disk io.
         while True:
             event = self.queue.get()
+
             if event is None:
-                self._close_run()
-                self._close_connection()
+                self._shutdown()
                 return
-            if isinstance(event, PurgeTask):
-                # Schedule a new purge task if this one didn't finish
-                if not purge.purge_old_data(self, event.keep_days, event.repack):
-                    self.queue.put(PurgeTask(event.keep_days, event.repack))
-                continue
-            if isinstance(event, WaitTask):
-                self._queue_watch.set()
-                continue
-            if event.event_type == EVENT_TIME_CHANGED:
-                self._keepalive_count += 1
-                if self._keepalive_count >= KEEPALIVE_TIME:
-                    self._keepalive_count = 0
-                    self._send_keep_alive()
-                if self.commit_interval:
-                    self._timechanges_seen += 1
-                    if self._timechanges_seen >= self.commit_interval:
-                        self._timechanges_seen = 0
-                        self._commit_event_session_or_retry()
-                continue
+
+            self._process_one_event(event)
+
+    def _setup_recorder(self) -> bool:
+        """Create schema and connect to the database."""
+        tries = 1
+
+        while tries <= self.db_max_retries:
             try:
-                if event.event_type == EVENT_STATE_CHANGED:
-                    dbevent = Events.from_event(event, event_data="{}")
-                else:
-                    dbevent = Events.from_event(event)
-                dbevent.created = event.time_fired
-                self.event_session.add(dbevent)
-            except (TypeError, ValueError):
-                _LOGGER.warning("Event is not JSON serializable: %s", event)
+                self._setup_connection()
+                migration.migrate_schema(self)
+                self._setup_run()
             except Exception as err:  # pylint: disable=broad-except
-                # Must catch the exception to prevent the loop from collapsing
-                _LOGGER.exception("Error adding event: %s", err)
+                _LOGGER.error(
+                    "Error during connection setup to %s: %s (retrying in %s seconds)",
+                    self.db_url,
+                    err,
+                    self.db_retry_wait,
+                )
+            else:
+                _LOGGER.debug("Connected to recorder database")
+                self._open_event_session()
+                return True
 
-            if dbevent and event.event_type == EVENT_STATE_CHANGED:
-                try:
-                    dbstate = States.from_event(event)
-                    has_new_state = event.data.get("new_state")
-                    if dbstate.entity_id in self._old_states:
-                        old_state = self._old_states.pop(dbstate.entity_id)
-                        if old_state.state_id:
-                            dbstate.old_state_id = old_state.state_id
-                        else:
-                            dbstate.old_state = old_state
-                    if not has_new_state:
-                        dbstate.state = None
-                    dbstate.event = dbevent
-                    dbstate.created = event.time_fired
-                    self.event_session.add(dbstate)
-                    if has_new_state:
-                        self._old_states[dbstate.entity_id] = dbstate
-                        self._pending_expunge.append(dbstate)
-                except (TypeError, ValueError):
-                    _LOGGER.warning(
-                        "State is not JSON serializable: %s",
-                        event.data.get("new_state"),
-                    )
-                except Exception as err:  # pylint: disable=broad-except
-                    # Must catch the exception to prevent the loop from collapsing
-                    _LOGGER.exception("Error adding state change: %s", err)
+            tries += 1
+            time.sleep(self.db_retry_wait)
 
-            # If they do not have a commit interval
-            # than we commit right away
-            if not self.commit_interval:
-                self._commit_event_session_or_retry()
+        @callback
+        def connection_failed():
+            """Connect failed tasks."""
+            self.async_db_ready.set_result(False)
+            persistent_notification.async_create(
+                self.hass,
+                "The recorder could not start, please check the log",
+                "Recorder",
+            )
+
+        self.hass.add_job(connection_failed)
+        return False
+
+    def _process_one_event(self, event):
+        """Process one event."""
+        if isinstance(event, PurgeTask):
+            # Schedule a new purge task if this one didn't finish
+            if not purge.purge_old_data(self, event.keep_days, event.repack):
+                self.queue.put(PurgeTask(event.keep_days, event.repack))
+            return
+        if isinstance(event, WaitTask):
+            self._queue_watch.set()
+            return
+        if event.event_type == EVENT_TIME_CHANGED:
+            self._keepalive_count += 1
+            if self._keepalive_count >= KEEPALIVE_TIME:
+                self._keepalive_count = 0
+                self._send_keep_alive()
+            if self.commit_interval:
+                self._timechanges_seen += 1
+                if self._timechanges_seen >= self.commit_interval:
+                    self._timechanges_seen = 0
+                    self._commit_event_session_or_recover()
+            return
 
-    def _send_keep_alive(self):
         try:
-            _LOGGER.debug("Sending keepalive")
-            self.event_session.connection().scalar(select([1]))
+            if event.event_type == EVENT_STATE_CHANGED:
+                dbevent = Events.from_event(event, event_data="{}")
+            else:
+                dbevent = Events.from_event(event)
+            dbevent.created = event.time_fired
+            self.event_session.add(dbevent)
+        except (TypeError, ValueError):
+            _LOGGER.warning("Event is not JSON serializable: %s", event)
             return
         except Exception as err:  # pylint: disable=broad-except
             # Must catch the exception to prevent the loop from collapsing
-            _LOGGER.error(
-                "Error in database connectivity during keepalive: %s",
-                err,
-            )
-            self._reopen_event_session()
+            _LOGGER.exception("Error adding event: %s", err)
+            return
+
+        if event.event_type == EVENT_STATE_CHANGED:
+            try:
+                dbstate = States.from_event(event)
+                has_new_state = event.data.get("new_state")
+                if dbstate.entity_id in self._old_states:
+                    old_state = self._old_states.pop(dbstate.entity_id)
+                    if old_state.state_id:
+                        dbstate.old_state_id = old_state.state_id
+                    else:
+                        dbstate.old_state = old_state
+                if not has_new_state:
+                    dbstate.state = None
+                dbstate.event = dbevent
+                dbstate.created = event.time_fired
+                self.event_session.add(dbstate)
+                if has_new_state:
+                    self._old_states[dbstate.entity_id] = dbstate
+                    self._pending_expunge.append(dbstate)
+            except (TypeError, ValueError):
+                _LOGGER.warning(
+                    "State is not JSON serializable: %s",
+                    event.data.get("new_state"),
+                )
+            except Exception as err:  # pylint: disable=broad-except
+                # Must catch the exception to prevent the loop from collapsing
+                _LOGGER.exception("Error adding state change: %s", err)
+
+        # If they do not have a commit interval
+        # then we commit right away
+        if not self.commit_interval:
+            self._commit_event_session_or_recover()
+
+    def _commit_event_session_or_recover(self):
+        """Commit changes to the database and recover if the database fails when possible."""
+        try:
+            self._commit_event_session_or_retry()
+            return
+        except exc.DatabaseError as err:
+            if isinstance(err.__cause__, sqlite3.DatabaseError):
+                _LOGGER.exception(
+                    "Unrecoverable sqlite3 database corruption detected: %s", err
+                )
+                self._handle_sqlite_corruption()
+                return
+            _LOGGER.exception("Unexpected error saving events: %s", err)
+        except Exception as err:  # pylint: disable=broad-except
+            # Must catch the exception to prevent the loop from collapsing
+            _LOGGER.exception("Unexpected error saving events: %s", err)
+
+        self._reopen_event_session()
+        return
 
     def _commit_event_session_or_retry(self):
         tries = 1
         while tries <= self.db_max_retries:
-            if tries != 1:
-                time.sleep(self.db_retry_wait)
-
             try:
                 self._commit_event_session()
                 return
             except (exc.InternalError, exc.OperationalError) as err:
                 if err.connection_invalidated:
-                    _LOGGER.error(
-                        "Database connection invalidated: %s. "
-                        "(retrying in %s seconds)",
-                        err,
-                        self.db_retry_wait,
-                    )
+                    message = "Database connection invalidated"
                 else:
-                    _LOGGER.error(
-                        "Error in database connectivity during commit: %s. "
-                        "(retrying in %s seconds)",
-                        err,
-                        self.db_retry_wait,
-                    )
+                    message = "Error in database connectivity during commit"
+                _LOGGER.error(
+                    "%s: Error executing query: %s. (retrying in %s seconds)",
(retrying in %s seconds)", + message, + err, + self.db_retry_wait, + ) + if tries == self.db_max_retries: + raise + tries += 1 - - except Exception as err: # pylint: disable=broad-except - # Must catch the exception to prevent the loop from collapsing - _LOGGER.exception("Error saving events: %s", err) - return - - _LOGGER.error( - "Error in database update. Could not save " "after %d tries. Giving up", - tries, - ) - self._reopen_event_session() - - def _reopen_event_session(self): - try: - self.event_session.rollback() - except Exception as err: # pylint: disable=broad-except - # Must catch the exception to prevent the loop from collapsing - _LOGGER.exception("Error while rolling back event session: %s", err) - - try: - self.event_session.close() - except Exception as err: # pylint: disable=broad-except - # Must catch the exception to prevent the loop from collapsing - _LOGGER.exception("Error while closing event session: %s", err) - - try: - self.event_session = self.get_session() - self.event_session.expire_on_commit = False - except Exception as err: # pylint: disable=broad-except - # Must catch the exception to prevent the loop from collapsing - _LOGGER.exception("Error while creating new event session: %s", err) + time.sleep(self.db_retry_wait) def _commit_event_session(self): self._commits_without_expire += 1 - try: - if self._pending_expunge: - self.event_session.flush() - for dbstate in self._pending_expunge: - # Expunge the state so its not expired - # until we use it later for dbstate.old_state - if dbstate in self.event_session: - self.event_session.expunge(dbstate) - self._pending_expunge = [] - self.event_session.commit() - except exc.IntegrityError as err: - _LOGGER.error( - "Integrity error executing query (database likely deleted out from under us): %s", - err, - ) - self.event_session.rollback() - self._old_states = {} - raise - except Exception as err: - _LOGGER.error("Error executing query: %s", err) - self.event_session.rollback() - raise + if self._pending_expunge: + self.event_session.flush() + for dbstate in self._pending_expunge: + # Expunge the state so its not expired + # until we use it later for dbstate.old_state + if dbstate in self.event_session: + self.event_session.expunge(dbstate) + self._pending_expunge = [] + self.event_session.commit() # Expire is an expensive operation (frequently more expensive # than the flush and commit itself) so we only @@ -547,6 +523,47 @@ class Recorder(threading.Thread): self._commits_without_expire = 0 self.event_session.expire_all() + def _handle_sqlite_corruption(self): + """Handle the sqlite3 database being corrupt.""" + self._close_connection() + move_away_broken_database(dburl_to_path(self.db_url)) + self._setup_recorder() + + def _reopen_event_session(self): + """Rollback the event session and reopen it after a failure.""" + self._old_states = {} + + try: + self.event_session.rollback() + self.event_session.close() + except Exception as err: # pylint: disable=broad-except + # Must catch the exception to prevent the loop from collapsing + _LOGGER.exception( + "Error while rolling back and closing the event session: %s", err + ) + + self._open_event_session() + + def _open_event_session(self): + """Open the event session.""" + try: + self.event_session = self.get_session() + self.event_session.expire_on_commit = False + except Exception as err: # pylint: disable=broad-except + _LOGGER.exception("Error while creating new event session: %s", err) + + def _send_keep_alive(self): + try: + _LOGGER.debug("Sending keepalive") + 
diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py
index 41bca335a56..b945386de82 100644
--- a/homeassistant/components/recorder/util.py
+++ b/homeassistant/components/recorder/util.py
@@ -112,19 +112,24 @@ def execute(qry, to_native=False, validate_entity_ids=True):
 
 def validate_or_move_away_sqlite_database(dburl: str, db_integrity_check: bool) -> bool:
     """Ensure that the database is valid or move it away."""
-    dbpath = dburl[len(SQLITE_URL_PREFIX) :]
+    dbpath = dburl_to_path(dburl)
 
     if not os.path.exists(dbpath):
         # Database does not exist yet, this is OK
         return True
 
     if not validate_sqlite_database(dbpath, db_integrity_check):
-        _move_away_broken_database(dbpath)
+        move_away_broken_database(dbpath)
        return False
 
     return True
 
 
+def dburl_to_path(dburl):
+    """Convert the db url into a filesystem path."""
+    return dburl[len(SQLITE_URL_PREFIX) :]
+
+
 def last_run_was_recently_clean(cursor):
     """Verify the last recorder run was recently clean."""
 
@@ -208,7 +213,7 @@ def run_checks_on_open_db(dbpath, cursor, db_integrity_check):
         cursor.execute("PRAGMA QUICK_CHECK")
 
 
-def _move_away_broken_database(dbfile: str) -> None:
+def move_away_broken_database(dbfile: str) -> None:
     """Move away a broken sqlite3 database."""
     isotime = dt_util.utcnow().isoformat()
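
Only the signature of move_away_broken_database changes in the hunk above (the
leading underscore is dropped so the recorder module can import it); its body is
elided past the isotime line. For readers following along, a sketch of the pattern
it implements: rename the database and its sqlite3 sidecar files with a timestamped
suffix so a fresh database can be created on the next setup attempt. The exact
suffix and sidecar list here are assumptions, not taken from this diff:

import os
from datetime import datetime, timezone


def move_away_broken_database_sketch(dbfile: str) -> None:
    """Rename a broken sqlite3 database and its sidecar files."""
    isotime = datetime.now(timezone.utc).isoformat()
    for postfix in ("", "-wal", "-shm"):  # write-ahead log and shared-memory files
        path = f"{dbfile}{postfix}"
        if os.path.exists(path):
            os.rename(path, f"{path}.corrupt.{isotime}")
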
diff --git a/tests/components/recorder/common.py b/tests/components/recorder/common.py
index 1d0e6dbbfa0..d2b731777e2 100644
--- a/tests/components/recorder/common.py
+++ b/tests/components/recorder/common.py
@@ -10,14 +10,27 @@ from tests.common import fire_time_changed
 
 def wait_recording_done(hass):
     """Block till recording is done."""
+    hass.block_till_done()
     trigger_db_commit(hass)
     hass.block_till_done()
     hass.data[recorder.DATA_INSTANCE].block_till_done()
     hass.block_till_done()
 
 
+async def async_wait_recording_done(hass):
+    """Block till recording is done."""
+    await hass.loop.run_in_executor(None, wait_recording_done, hass)
+
+
 def trigger_db_commit(hass):
     """Force the recorder to commit."""
     for _ in range(recorder.DEFAULT_COMMIT_INTERVAL):
         # We only commit on time change
         fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=1))
+
+
+def corrupt_db_file(test_db_file):
+    """Corrupt an sqlite3 database file."""
+    with open(test_db_file, "w+") as fhandle:
+        fhandle.seek(200)
+        fhandle.write("I am a corrupt db" * 100)
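
The rewritten corrupt_db_file replaces the append-based helper that is removed from
test_util.py at the bottom of this patch: opening with mode "w+" truncates the file
first, so the 100-byte sqlite3 header is destroyed outright instead of being left
intact with junk appended after it. A quick standalone check that this reliably
trips the driver (file name illustrative):

import sqlite3

# Create a valid database, then corrupt it the way the test helper does.
conn = sqlite3.connect("demo.db")
conn.execute("CREATE TABLE t (x INTEGER)")
conn.commit()
conn.close()

with open("demo.db", "w+") as fhandle:  # "w+" truncates: the header is gone
    fhandle.seek(200)
    fhandle.write("I am a corrupt db" * 100)

try:
    sqlite3.connect("demo.db").execute("PRAGMA quick_check")
except sqlite3.DatabaseError as err:
    print(err)  # "file is not a database"
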
diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py
index 3b71648166e..ca25fe10284 100644
--- a/tests/components/recorder/test_init.py
+++ b/tests/components/recorder/test_init.py
@@ -6,6 +6,7 @@ from unittest.mock import patch
 from sqlalchemy.exc import OperationalError
 
 from homeassistant.components.recorder import (
+    CONF_DB_URL,
     CONFIG_SCHEMA,
     DOMAIN,
     Recorder,
@@ -13,7 +14,7 @@ from homeassistant.components.recorder import (
     run_information_from_instance,
     run_information_with_session,
 )
-from homeassistant.components.recorder.const import DATA_INSTANCE
+from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX
 from homeassistant.components.recorder.models import Events, RecorderRuns, States
 from homeassistant.components.recorder.util import session_scope
 from homeassistant.const import (
@@ -26,7 +27,7 @@ from homeassistant.core import Context, CoreState, callback
 from homeassistant.setup import async_setup_component
 from homeassistant.util import dt as dt_util
 
-from .common import wait_recording_done
+from .common import async_wait_recording_done, corrupt_db_file, wait_recording_done
 
 from tests.common import (
     async_init_recorder_component,
@@ -519,3 +520,52 @@ def test_run_information(hass_recorder):
 
 class CannotSerializeMe:
     """A class that the JSONEncoder cannot serialize."""
+
+
+async def test_database_corruption_while_running(hass, tmpdir, caplog):
+    """Test we can recover from sqlite3 db corruption."""
+
+    def _create_tmpdir_for_test_db():
+        return tmpdir.mkdir("sqlite").join("test.db")
+
+    test_db_file = await hass.async_add_executor_job(_create_tmpdir_for_test_db)
+    dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}"
+
+    assert await async_setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: dburl}})
+    await hass.async_block_till_done()
+    caplog.clear()
+
+    hass.states.async_set("test.lost", "on", {})
+
+    await async_wait_recording_done(hass)
+    await hass.async_add_executor_job(corrupt_db_file, test_db_file)
+    await async_wait_recording_done(hass)
+
+    # This state will not be recorded because
+    # the database corruption will be discovered
+    # and we will have to rollback to recover
+    hass.states.async_set("test.one", "off", {})
+    await async_wait_recording_done(hass)
+
+    assert "Unrecoverable sqlite3 database corruption detected" in caplog.text
+    assert "The system will rename the corrupt database file" in caplog.text
+    assert "Connected to recorder database" in caplog.text
+
+    # This state should go into the new database
+    hass.states.async_set("test.two", "on", {})
+    await async_wait_recording_done(hass)
+
+    def _get_last_state():
+        with session_scope(hass=hass) as session:
+            db_states = list(session.query(States))
+            assert len(db_states) == 1
+            assert db_states[0].event_id > 0
+            return db_states[0].to_native()
+
+    state = await hass.async_add_executor_job(_get_last_state)
+    assert state.entity_id == "test.two"
+    assert state.state == "on"
+
+    hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
+    await hass.async_block_till_done()
+    hass.stop()
diff --git a/tests/components/recorder/test_util.py b/tests/components/recorder/test_util.py
index 38df1285008..f1d55999ae4 100644
--- a/tests/components/recorder/test_util.py
+++ b/tests/components/recorder/test_util.py
@@ -10,7 +10,7 @@ from homeassistant.components.recorder import util
 from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX
 from homeassistant.util import dt as dt_util
 
-from .common import wait_recording_done
+from .common import corrupt_db_file, wait_recording_done
 
 from tests.common import get_test_home_assistant, init_recorder_component
 
@@ -90,7 +90,7 @@ def test_validate_or_move_away_sqlite_database_with_integrity_check(
         util.validate_or_move_away_sqlite_database(dburl, db_integrity_check) is False
     )
 
-    _corrupt_db_file(test_db_file)
+    corrupt_db_file(test_db_file)
 
     assert util.validate_sqlite_database(dburl, db_integrity_check) is False
 
@@ -127,7 +127,7 @@ def test_validate_or_move_away_sqlite_database_without_integrity_check(
         util.validate_or_move_away_sqlite_database(dburl, db_integrity_check) is False
     )
 
-    _corrupt_db_file(test_db_file)
+    corrupt_db_file(test_db_file)
 
     assert util.validate_sqlite_database(dburl, db_integrity_check) is False
 
@@ -150,7 +150,7 @@ def test_last_run_was_recently_clean(hass_recorder):
 
     assert util.last_run_was_recently_clean(cursor) is False
 
-    hass.data[DATA_INSTANCE]._close_run()
+    hass.data[DATA_INSTANCE]._shutdown()
     wait_recording_done(hass)
 
     assert util.last_run_was_recently_clean(cursor) is True
 
@@ -244,10 +244,3 @@ def test_combined_checks(hass_recorder, caplog):
     caplog.clear()
     with pytest.raises(sqlite3.DatabaseError):
         util.run_checks_on_open_db("fake_db_path", cursor, True)
-
-
-def _corrupt_db_file(test_db_file):
-    """Corrupt an sqlite3 database file."""
-    f = open(test_db_file, "a")
-    f.write("I am a corrupt db")
-    f.close()
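
The async_wait_recording_done helper added to common.py and used throughout the new
test follows the usual pattern for driving blocking, thread-affine code (here,
waiting for the recorder thread to drain its queue) from an async test: push it to
the executor so the event loop stays responsive. A generic standalone sketch of the
same shape:

import asyncio
import time


def blocking_wait() -> None:
    """Stand-in for wait_recording_done(hass); blocks the calling thread."""
    time.sleep(0.1)


async def main() -> None:
    loop = asyncio.get_running_loop()
    # Same shape as: await hass.loop.run_in_executor(None, wait_recording_done, hass)
    await loop.run_in_executor(None, blocking_wait)
    print("recording done; the event loop was never blocked")


asyncio.run(main())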