Allow locking the SQLite database during backup (#60874)

* Allow setting CONF_DB_URL

This is useful for tests which need a custom DB path.

* Introduce a write_lock_db helper to lock the SQLite database
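
A minimal standalone sketch of the underlying SQLite technique, using plain sqlite3 instead of the recorder's SQLAlchemy engine (the file name is illustrative):

    import sqlite3

    conn = sqlite3.connect("example.db")
    conn.isolation_level = None  # autocommit; transactions are issued explicitly
    # Flush the write-ahead log so the copied file is complete and minimal.
    conn.execute("PRAGMA wal_checkpoint(TRUNCATE);")
    # BEGIN IMMEDIATE takes the write lock right away; readers keep working.
    conn.execute("BEGIN IMMEDIATE;")
    # ... copy the database file here ...
    conn.execute("END;")  # release the write lock
    conn.close()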

* Introduce a WebSocket API which allows locking the database during backup
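
The resulting message flow, with payloads as exercised by the new tests (client stands for an authenticated admin WebSocket client as provided by the test helpers; message ids are arbitrary):

    # Lock the database before copying the SQLite file:
    await client.send_json({"id": 1, "type": "backup/start"})
    response = await client.receive_json()
    assert response["success"]

    # ... copy the database file ...

    # Release the write lock once the copy is done:
    await client.send_json({"id": 2, "type": "backup/end"})
    response = await client.receive_json()
    assert response["success"]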

* Fix isort

* Avoid mutable default arguments

* Address pylint issues

* Avoid holding executor thread

* Set unlock event in case a timeout occurs

This makes sure the database is left unlocked even in case of a race
condition.
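
The pattern, as implemented in lock_database() in the diff below:

    try:
        await asyncio.wait_for(database_locked.wait(), timeout=DB_LOCK_TIMEOUT)
    except asyncio.TimeoutError as err:
        # The recorder thread may still take the lock after wait_for gives up;
        # setting the unlock event guarantees it is released again either way.
        task.database_unlock.set()
        raise TimeoutError(
            f"Could not lock database within {DB_LOCK_TIMEOUT} seconds."
        ) from err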

* Add more unit tests

* Address new pylint errors

* Lower timeout to speed up tests

* Introduce queue overflow test

* Unlock database if necessary

This makes sure the test run completes even if locking unexpectedly
succeeds (and the test therefore fails).
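
The defensive unlock, as it appears in test_database_lock_timeout below (instance is the recorder instance):

    with patch.object(recorder, "DB_LOCK_TIMEOUT", 0.1):
        try:
            with pytest.raises(TimeoutError):
                await instance.lock_database()
        finally:
            # If locking unexpectedly succeeded, release it so a failing
            # test does not leave the database write-locked.
            instance.unlock_database()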

* Make DB_LOCK_TIMEOUT a global

There is no good reason for this to be an argument. The recorder needs
to pick a sensible value.
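
The values the recorder picks in this commit:

    DB_LOCK_TIMEOUT = 30  # seconds lock_database() waits for the write lock
    DB_LOCK_QUEUE_CHECK_TIMEOUT = 1  # seconds between queue-backlog checks while locked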

* Add WebSocket timeout test

* Test lock_database() return

* Update homeassistant/components/recorder/__init__.py

Co-authored-by: Erik Montnemery <erik@montnemery.com>

* Fix format

Co-authored-by: J. Nick Koston <nick@koston.org>
Co-authored-by: Erik Montnemery <erik@montnemery.com>
Stefan Agner 2021-12-07 13:16:24 +01:00 committed by GitHub
parent 4eeee79517
commit f0006b92be
7 changed files with 310 additions and 11 deletions

homeassistant/components/recorder/__init__.py

@@ -4,6 +4,7 @@ from __future__ import annotations
 import asyncio
 from collections.abc import Callable, Iterable
 import concurrent.futures
+from dataclasses import dataclass
 from datetime import datetime, timedelta
 import logging
 import queue
@@ -76,6 +77,7 @@ from .util import (
     session_scope,
     setup_connection_for_dialect,
     validate_or_move_away_sqlite_database,
+    write_lock_db,
 )

 _LOGGER = logging.getLogger(__name__)
@@ -123,6 +125,9 @@ KEEPALIVE_TIME = 30
 # States and Events objects
 EXPIRE_AFTER_COMMITS = 120

+DB_LOCK_TIMEOUT = 30
+DB_LOCK_QUEUE_CHECK_TIMEOUT = 1
+
 CONF_AUTO_PURGE = "auto_purge"
 CONF_DB_URL = "db_url"
 CONF_DB_MAX_RETRIES = "db_max_retries"
@@ -370,6 +375,15 @@ class WaitTask:
     """An object to insert into the recorder queue to tell it set the _queue_watch event."""


+@dataclass
+class DatabaseLockTask:
+    """An object to insert into the recorder queue to prevent writes to the database."""
+
+    database_locked: asyncio.Event
+    database_unlock: threading.Event
+    queue_overflow: bool
+
+
 class Recorder(threading.Thread):
     """A threaded recorder class."""
@@ -419,6 +433,7 @@ class Recorder(threading.Thread):
         self.migration_in_progress = False
         self._queue_watcher = None
         self._db_supports_row_number = True
+        self._database_lock_task: DatabaseLockTask | None = None

         self.enabled = True
@@ -687,6 +702,8 @@ class Recorder(threading.Thread):
     def _process_one_event_or_recover(self, event):
         """Process an event, reconnect, or recover a malformed database."""
         try:
+            if self._process_one_task(event):
+                return
             self._process_one_event(event)
             return
         except exc.DatabaseError as err:
@@ -788,34 +805,63 @@ class Recorder(threading.Thread):
         # Schedule a new statistics task if this one didn't finish
         self.queue.put(ExternalStatisticsTask(metadata, stats))

-    def _process_one_event(self, event):
+    def _lock_database(self, task: DatabaseLockTask):
+        @callback
+        def _async_set_database_locked(task: DatabaseLockTask):
+            task.database_locked.set()
+
+        with write_lock_db(self):
+            # Notify that lock is being held, wait until database can be used again.
+            self.hass.add_job(_async_set_database_locked, task)
+            while not task.database_unlock.wait(timeout=DB_LOCK_QUEUE_CHECK_TIMEOUT):
+                if self.queue.qsize() > MAX_QUEUE_BACKLOG * 0.9:
+                    _LOGGER.warning(
+                        "Database queue backlog reached more than 90% of maximum queue "
+                        "length while waiting for backup to finish; recorder will now "
+                        "resume writing to database. The backup can not be trusted and "
+                        "must be restarted"
+                    )
+                    task.queue_overflow = True
+                    break
+        _LOGGER.info(
+            "Database queue backlog reached %d entries during backup",
+            self.queue.qsize(),
+        )
+
+    def _process_one_task(self, event) -> bool:
         """Process one event."""
         if isinstance(event, PurgeTask):
             self._run_purge(event.purge_before, event.repack, event.apply_filter)
-            return
+            return True
         if isinstance(event, PurgeEntitiesTask):
             self._run_purge_entities(event.entity_filter)
-            return
+            return True
         if isinstance(event, PerodicCleanupTask):
             perodic_db_cleanups(self)
-            return
+            return True
         if isinstance(event, StatisticsTask):
             self._run_statistics(event.start)
-            return
+            return True
         if isinstance(event, ClearStatisticsTask):
             statistics.clear_statistics(self, event.statistic_ids)
-            return
+            return True
         if isinstance(event, UpdateStatisticsMetadataTask):
             statistics.update_statistics_metadata(
                 self, event.statistic_id, event.unit_of_measurement
             )
-            return
+            return True
         if isinstance(event, ExternalStatisticsTask):
             self._run_external_statistics(event.metadata, event.statistics)
-            return
+            return True
         if isinstance(event, WaitTask):
             self._queue_watch.set()
-            return
+            return True
+        if isinstance(event, DatabaseLockTask):
+            self._lock_database(event)
+            return True
+        return False
+
+    def _process_one_event(self, event):
         if event.event_type == EVENT_TIME_CHANGED:
             self._keepalive_count += 1
             if self._keepalive_count >= KEEPALIVE_TIME:
@@ -982,6 +1028,42 @@ class Recorder(threading.Thread):
         self.queue.put(WaitTask())
         self._queue_watch.wait()

+    async def lock_database(self) -> bool:
+        """Lock database so it can be backed up safely."""
+        if self._database_lock_task:
+            _LOGGER.warning("Database already locked")
+            return False
+
+        database_locked = asyncio.Event()
+        task = DatabaseLockTask(database_locked, threading.Event(), False)
+        self.queue.put(task)
+        try:
+            await asyncio.wait_for(database_locked.wait(), timeout=DB_LOCK_TIMEOUT)
+        except asyncio.TimeoutError as err:
+            task.database_unlock.set()
+            raise TimeoutError(
+                f"Could not lock database within {DB_LOCK_TIMEOUT} seconds."
+            ) from err
+        self._database_lock_task = task
+        return True
+
+    @callback
+    def unlock_database(self) -> bool:
+        """Unlock database.
+
+        Returns true if database lock has been held throughout the process.
+        """
+        if not self._database_lock_task:
+            _LOGGER.warning("Database currently not locked")
+            return False
+
+        self._database_lock_task.database_unlock.set()
+        success = not self._database_lock_task.queue_overflow
+        self._database_lock_task = None
+        return success
+
     def _setup_connection(self):
         """Ensure database is ready to fly."""
         kwargs = {}

homeassistant/components/recorder/util.py

@@ -457,6 +457,25 @@ def perodic_db_cleanups(instance: Recorder):
         connection.execute(text("PRAGMA wal_checkpoint(TRUNCATE);"))


+@contextmanager
+def write_lock_db(instance: Recorder):
+    """Lock database for writes."""
+    if instance.engine.dialect.name == "sqlite":
+        with instance.engine.connect() as connection:
+            # Execute sqlite to create a wal checkpoint
+            # This is optional but makes sure the backup is going to be minimal
+            connection.execute(text("PRAGMA wal_checkpoint(TRUNCATE)"))
+            # Create write lock
+            _LOGGER.debug("Lock database")
+            connection.execute(text("BEGIN IMMEDIATE;"))
+            try:
+                yield
+            finally:
+                _LOGGER.debug("Unlock database")
+                connection.execute(text("END;"))
+
+
 def async_migration_in_progress(hass: HomeAssistant) -> bool:
     """Determine is a migration is in progress.

homeassistant/components/recorder/websocket_api.py

@@ -1,6 +1,7 @@
 """The Energy websocket API."""
 from __future__ import annotations

+import logging
 from typing import TYPE_CHECKING

 import voluptuous as vol
@@ -15,6 +16,8 @@ from .util import async_migration_in_progress
 if TYPE_CHECKING:
     from . import Recorder

+_LOGGER: logging.Logger = logging.getLogger(__package__)
+

 @callback
 def async_setup(hass: HomeAssistant) -> None:
@@ -23,6 +26,8 @@ def async_setup(hass: HomeAssistant) -> None:
     websocket_api.async_register_command(hass, ws_clear_statistics)
     websocket_api.async_register_command(hass, ws_update_statistics_metadata)
     websocket_api.async_register_command(hass, ws_info)
+    websocket_api.async_register_command(hass, ws_backup_start)
+    websocket_api.async_register_command(hass, ws_backup_end)


 @websocket_api.websocket_command(
@@ -106,3 +111,38 @@ def ws_info(
         "thread_running": thread_alive,
     }
     connection.send_result(msg["id"], recorder_info)
+
+
+@websocket_api.require_admin
+@websocket_api.websocket_command({vol.Required("type"): "backup/start"})
+@websocket_api.async_response
+async def ws_backup_start(
+    hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
+) -> None:
+    """Backup start notification."""
+    _LOGGER.info("Backup start notification, locking database for writes")
+    instance: Recorder = hass.data[DATA_INSTANCE]
+    try:
+        await instance.lock_database()
+    except TimeoutError as err:
+        connection.send_error(msg["id"], "timeout_error", str(err))
+        return
+    connection.send_result(msg["id"])
+
+
+@websocket_api.require_admin
+@websocket_api.websocket_command({vol.Required("type"): "backup/end"})
+@websocket_api.async_response
+async def ws_backup_end(
+    hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
+) -> None:
+    """Backup end notification."""
+    instance: Recorder = hass.data[DATA_INSTANCE]
+    _LOGGER.info("Backup end notification, releasing write lock")
+    if not instance.unlock_database():
+        connection.send_error(
+            msg["id"], "database_unlock_failed", "Failed to unlock database."
+        )
+        return
+    connection.send_result(msg["id"])

tests/common.py

@@ -902,7 +902,8 @@ def init_recorder_component(hass, add_config=None):

 async def async_init_recorder_component(hass, add_config=None):
     """Initialize the recorder asynchronously."""
-    config = dict(add_config) if add_config else {}
-    config[recorder.CONF_DB_URL] = "sqlite://"
+    config = add_config or {}
+    if recorder.CONF_DB_URL not in config:
+        config[recorder.CONF_DB_URL] = "sqlite://"

     with patch("homeassistant.components.recorder.migration.migrate_schema"):

tests/components/recorder/test_init.py

@@ -1,5 +1,6 @@
 """The tests for the Recorder component."""
 # pylint: disable=protected-access
+import asyncio
 from datetime import datetime, timedelta
 import sqlite3
 from unittest.mock import patch
@@ -1134,3 +1135,81 @@ def test_entity_id_filter(hass_recorder):
             db_events = list(session.query(Events).filter_by(event_type="hello"))
             # Keep referring idx + 1, as no new events are being added
             assert len(db_events) == idx + 1, data
+
+
+async def test_database_lock_and_unlock(hass: HomeAssistant, tmp_path):
+    """Test that events fired while the database is locked are written once it is unlocked."""
+    # Use file DB, in memory DB cannot do write locks.
+    config = {recorder.CONF_DB_URL: "sqlite:///" + str(tmp_path / "pytest.db")}
+    await async_init_recorder_component(hass, config)
+    await hass.async_block_till_done()
+
+    instance: Recorder = hass.data[DATA_INSTANCE]
+
+    assert await instance.lock_database()
+    assert not await instance.lock_database()
+
+    event_type = "EVENT_TEST"
+    event_data = {"test_attr": 5, "test_attr_10": "nice"}
+    hass.bus.fire(event_type, event_data)
+    task = asyncio.create_task(async_wait_recording_done(hass, instance))
+
+    # Recording can't be finished while lock is held
+    with pytest.raises(asyncio.TimeoutError):
+        await asyncio.wait_for(asyncio.shield(task), timeout=1)
+
+    with session_scope(hass=hass) as session:
+        db_events = list(session.query(Events).filter_by(event_type=event_type))
+        assert len(db_events) == 0
+
+    assert instance.unlock_database()
+    await task
+
+    with session_scope(hass=hass) as session:
+        db_events = list(session.query(Events).filter_by(event_type=event_type))
+        assert len(db_events) == 1
+
+
+async def test_database_lock_and_overflow(hass: HomeAssistant, tmp_path):
+    """Test that overflowing the queue while the database is locked unlocks the database."""
+    # Use file DB, in memory DB cannot do write locks.
+    config = {recorder.CONF_DB_URL: "sqlite:///" + str(tmp_path / "pytest.db")}
+    await async_init_recorder_component(hass, config)
+    await hass.async_block_till_done()
+
+    instance: Recorder = hass.data[DATA_INSTANCE]
+
+    with patch.object(recorder, "MAX_QUEUE_BACKLOG", 1), patch.object(
+        recorder, "DB_LOCK_QUEUE_CHECK_TIMEOUT", 0.1
+    ):
+        await instance.lock_database()
+
+        event_type = "EVENT_TEST"
+        event_data = {"test_attr": 5, "test_attr_10": "nice"}
+        hass.bus.fire(event_type, event_data)
+
+        # Check that this causes the queue to overflow and write succeeds
+        # even before unlocking.
+        await async_wait_recording_done(hass, instance)
+
+        with session_scope(hass=hass) as session:
+            db_events = list(session.query(Events).filter_by(event_type=event_type))
+            assert len(db_events) == 1
+
+        assert not instance.unlock_database()
+
+
+async def test_database_lock_timeout(hass):
+    """Test locking database timeout when recorder stopped."""
+    await async_init_recorder_component(hass)
+
+    hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
+    await hass.async_block_till_done()
+
+    instance: Recorder = hass.data[DATA_INSTANCE]
+
+    with patch.object(recorder, "DB_LOCK_TIMEOUT", 0.1):
+        try:
+            with pytest.raises(TimeoutError):
+                await instance.lock_database()
+        finally:
+            instance.unlock_database()

tests/components/recorder/test_util.py

@@ -8,6 +8,7 @@ import pytest
 from sqlalchemy import text
 from sqlalchemy.sql.elements import TextClause

+from homeassistant.components import recorder
 from homeassistant.components.recorder import run_information_with_session, util
 from homeassistant.components.recorder.const import DATA_INSTANCE, SQLITE_URL_PREFIX
 from homeassistant.components.recorder.models import RecorderRuns
@@ -556,3 +557,21 @@ def test_perodic_db_cleanups(hass_recorder):
     ][0]
     assert isinstance(text_obj, TextClause)
     assert str(text_obj) == "PRAGMA wal_checkpoint(TRUNCATE);"
+
+
+async def test_write_lock_db(hass, tmp_path):
+    """Test database write lock."""
+    from sqlalchemy.exc import OperationalError
+
+    # Use file DB, in memory DB cannot do write locks.
+    config = {recorder.CONF_DB_URL: "sqlite:///" + str(tmp_path / "pytest.db")}
+    await async_init_recorder_component(hass, config)
+    await hass.async_block_till_done()
+    instance = hass.data[DATA_INSTANCE]
+
+    with util.write_lock_db(instance):
+        # Database should be locked now, try writing SQL command
+        with instance.engine.connect() as connection:
+            with pytest.raises(OperationalError):
+                connection.execute(text("DROP TABLE events;"))

tests/components/recorder/test_websocket_api.py

@@ -358,3 +358,62 @@ async def test_recorder_info_migration_queue_exhausted(hass, hass_ws_client):
     assert response["result"]["migration_in_progress"] is False
     assert response["result"]["recording"] is True
     assert response["result"]["thread_running"] is True
+
+
+async def test_backup_start_no_recorder(hass, hass_ws_client):
+    """Test backup start when the recorder is not present."""
+    client = await hass_ws_client()
+
+    await client.send_json({"id": 1, "type": "backup/start"})
+    response = await client.receive_json()
+    assert not response["success"]
+    assert response["error"]["code"] == "unknown_command"
+
+
+async def test_backup_start_timeout(hass, hass_ws_client):
+    """Test backup start timing out while acquiring the database lock."""
+    client = await hass_ws_client()
+    await async_init_recorder_component(hass)
+
+    # Ensure there are no queued events
+    await async_wait_recording_done_without_instance(hass)
+
+    with patch.object(recorder, "DB_LOCK_TIMEOUT", 0):
+        try:
+            await client.send_json({"id": 1, "type": "backup/start"})
+            response = await client.receive_json()
+            assert not response["success"]
+            assert response["error"]["code"] == "timeout_error"
+        finally:
+            await client.send_json({"id": 2, "type": "backup/end"})
+
+
+async def test_backup_end(hass, hass_ws_client):
+    """Test backup start followed by backup end."""
+    client = await hass_ws_client()
+    await async_init_recorder_component(hass)

+    # Ensure there are no queued events
+    await async_wait_recording_done_without_instance(hass)
+
+    await client.send_json({"id": 1, "type": "backup/start"})
+    response = await client.receive_json()
+    assert response["success"]
+
+    await client.send_json({"id": 2, "type": "backup/end"})
+    response = await client.receive_json()
+    assert response["success"]
+
+
+async def test_backup_end_without_start(hass, hass_ws_client):
+    """Test backup end without a preceding backup start."""
+    client = await hass_ws_client()
+    await async_init_recorder_component(hass)
+
+    # Ensure there are no queued events
+    await async_wait_recording_done_without_instance(hass)
+
+    await client.send_json({"id": 1, "type": "backup/end"})
+    response = await client.receive_json()
+    assert not response["success"]
+    assert response["error"]["code"] == "database_unlock_failed"