mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 09:17:10 +00:00
Use StaticPool for recorder and NullPool for all other threads with sqlite3 (#49693)
This commit is contained in:
parent
d9714e6b79
commit
b27e9e376d
@ -43,6 +43,7 @@ import homeassistant.util.dt as dt_util
|
|||||||
from . import migration, purge
|
from . import migration, purge
|
||||||
from .const import CONF_DB_INTEGRITY_CHECK, DATA_INSTANCE, DOMAIN, SQLITE_URL_PREFIX
|
from .const import CONF_DB_INTEGRITY_CHECK, DATA_INSTANCE, DOMAIN, SQLITE_URL_PREFIX
|
||||||
from .models import Base, Events, RecorderRuns, States
|
from .models import Base, Events, RecorderRuns, States
|
||||||
|
from .pool import RecorderPool
|
||||||
from .util import (
|
from .util import (
|
||||||
dburl_to_path,
|
dburl_to_path,
|
||||||
end_incomplete_runs,
|
end_incomplete_runs,
|
||||||
@ -783,6 +784,8 @@ class Recorder(threading.Thread):
|
|||||||
kwargs["connect_args"] = {"check_same_thread": False}
|
kwargs["connect_args"] = {"check_same_thread": False}
|
||||||
kwargs["poolclass"] = StaticPool
|
kwargs["poolclass"] = StaticPool
|
||||||
kwargs["pool_reset_on_return"] = None
|
kwargs["pool_reset_on_return"] = None
|
||||||
|
elif self.db_url.startswith(SQLITE_URL_PREFIX):
|
||||||
|
kwargs["poolclass"] = RecorderPool
|
||||||
else:
|
else:
|
||||||
kwargs["echo"] = False
|
kwargs["echo"] = False
|
||||||
|
|
||||||
|
34
homeassistant/components/recorder/pool.py
Normal file
34
homeassistant/components/recorder/pool.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
"""A pool for sqlite connections."""
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from sqlalchemy.pool import NullPool, StaticPool
|
||||||
|
|
||||||
|
|
||||||
|
class RecorderPool(StaticPool, NullPool):
|
||||||
|
"""A hybird of NullPool and StaticPool.
|
||||||
|
|
||||||
|
When called from the creating thread acts like StaticPool
|
||||||
|
When called from any other thread, acts like NullPool
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kw): # pylint: disable=super-init-not-called
|
||||||
|
"""Create the pool."""
|
||||||
|
self._tid = threading.current_thread().ident
|
||||||
|
StaticPool.__init__(self, *args, **kw)
|
||||||
|
|
||||||
|
def _do_return_conn(self, conn):
|
||||||
|
if threading.current_thread().ident == self._tid:
|
||||||
|
return super()._do_return_conn(conn)
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def dispose(self):
|
||||||
|
"""Dispose of the connection."""
|
||||||
|
if threading.current_thread().ident == self._tid:
|
||||||
|
return super().dispose()
|
||||||
|
|
||||||
|
def _do_get(self):
|
||||||
|
if threading.current_thread().ident == self._tid:
|
||||||
|
return super()._do_get()
|
||||||
|
return super( # pylint: disable=bad-super-call
|
||||||
|
NullPool, self
|
||||||
|
)._create_connection()
|
@ -1,9 +1,10 @@
|
|||||||
"""The tests for the Recorder component."""
|
"""The tests for the Recorder component."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
import sqlite3
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from sqlalchemy.exc import OperationalError, SQLAlchemyError
|
from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError
|
||||||
|
|
||||||
from homeassistant.components import recorder
|
from homeassistant.components import recorder
|
||||||
from homeassistant.components.recorder import (
|
from homeassistant.components.recorder import (
|
||||||
@ -885,6 +886,9 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog):
|
|||||||
|
|
||||||
hass.states.async_set("test.lost", "on", {})
|
hass.states.async_set("test.lost", "on", {})
|
||||||
|
|
||||||
|
sqlite3_exception = DatabaseError("statement", {}, [])
|
||||||
|
sqlite3_exception.__cause__ = sqlite3.DatabaseError()
|
||||||
|
|
||||||
with patch.object(
|
with patch.object(
|
||||||
hass.data[DATA_INSTANCE].event_session,
|
hass.data[DATA_INSTANCE].event_session,
|
||||||
"close",
|
"close",
|
||||||
@ -894,11 +898,16 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog):
|
|||||||
await hass.async_add_executor_job(corrupt_db_file, test_db_file)
|
await hass.async_add_executor_job(corrupt_db_file, test_db_file)
|
||||||
await async_wait_recording_done_without_instance(hass)
|
await async_wait_recording_done_without_instance(hass)
|
||||||
|
|
||||||
# This state will not be recorded because
|
with patch.object(
|
||||||
# the database corruption will be discovered
|
hass.data[DATA_INSTANCE].event_session,
|
||||||
# and we will have to rollback to recover
|
"commit",
|
||||||
hass.states.async_set("test.one", "off", {})
|
side_effect=[sqlite3_exception, None],
|
||||||
await async_wait_recording_done_without_instance(hass)
|
):
|
||||||
|
# This state will not be recorded because
|
||||||
|
# the database corruption will be discovered
|
||||||
|
# and we will have to rollback to recover
|
||||||
|
hass.states.async_set("test.one", "off", {})
|
||||||
|
await async_wait_recording_done_without_instance(hass)
|
||||||
|
|
||||||
assert "Unrecoverable sqlite3 database corruption detected" in caplog.text
|
assert "Unrecoverable sqlite3 database corruption detected" in caplog.text
|
||||||
assert "The system will rename the corrupt database file" in caplog.text
|
assert "The system will rename the corrupt database file" in caplog.text
|
||||||
|
34
tests/components/recorder/test_pool.py
Normal file
34
tests/components/recorder/test_pool.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
"""Test pool."""
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from homeassistant.components.recorder.pool import RecorderPool
|
||||||
|
|
||||||
|
|
||||||
|
def test_recorder_pool():
|
||||||
|
"""Test RecorderPool gives the same connection in the creating thread."""
|
||||||
|
|
||||||
|
engine = create_engine("sqlite://", poolclass=RecorderPool)
|
||||||
|
get_session = sessionmaker(bind=engine)
|
||||||
|
|
||||||
|
connections = []
|
||||||
|
|
||||||
|
def _get_connection_twice():
|
||||||
|
session = get_session()
|
||||||
|
connections.append(session.connection().connection.connection)
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
session = get_session()
|
||||||
|
connections.append(session.connection().connection.connection)
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
_get_connection_twice()
|
||||||
|
assert connections[0] == connections[1]
|
||||||
|
|
||||||
|
new_thread = threading.Thread(target=_get_connection_twice)
|
||||||
|
new_thread.start()
|
||||||
|
new_thread.join()
|
||||||
|
|
||||||
|
assert connections[2] != connections[3]
|
Loading…
x
Reference in New Issue
Block a user