mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 17:27:52 +00:00
Use StaticPool for recorder and NullPool for all other threads with sqlite3 (#49693)
This commit is contained in:
parent
d9714e6b79
commit
b27e9e376d
@ -43,6 +43,7 @@ import homeassistant.util.dt as dt_util
|
||||
from . import migration, purge
|
||||
from .const import CONF_DB_INTEGRITY_CHECK, DATA_INSTANCE, DOMAIN, SQLITE_URL_PREFIX
|
||||
from .models import Base, Events, RecorderRuns, States
|
||||
from .pool import RecorderPool
|
||||
from .util import (
|
||||
dburl_to_path,
|
||||
end_incomplete_runs,
|
||||
@ -783,6 +784,8 @@ class Recorder(threading.Thread):
|
||||
kwargs["connect_args"] = {"check_same_thread": False}
|
||||
kwargs["poolclass"] = StaticPool
|
||||
kwargs["pool_reset_on_return"] = None
|
||||
elif self.db_url.startswith(SQLITE_URL_PREFIX):
|
||||
kwargs["poolclass"] = RecorderPool
|
||||
else:
|
||||
kwargs["echo"] = False
|
||||
|
||||
|
34
homeassistant/components/recorder/pool.py
Normal file
34
homeassistant/components/recorder/pool.py
Normal file
@ -0,0 +1,34 @@
|
||||
"""A pool for sqlite connections."""
|
||||
import threading
|
||||
|
||||
from sqlalchemy.pool import NullPool, StaticPool
|
||||
|
||||
|
||||
class RecorderPool(StaticPool, NullPool):
|
||||
"""A hybird of NullPool and StaticPool.
|
||||
|
||||
When called from the creating thread acts like StaticPool
|
||||
When called from any other thread, acts like NullPool
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kw): # pylint: disable=super-init-not-called
|
||||
"""Create the pool."""
|
||||
self._tid = threading.current_thread().ident
|
||||
StaticPool.__init__(self, *args, **kw)
|
||||
|
||||
def _do_return_conn(self, conn):
|
||||
if threading.current_thread().ident == self._tid:
|
||||
return super()._do_return_conn(conn)
|
||||
conn.close()
|
||||
|
||||
def dispose(self):
|
||||
"""Dispose of the connection."""
|
||||
if threading.current_thread().ident == self._tid:
|
||||
return super().dispose()
|
||||
|
||||
def _do_get(self):
|
||||
if threading.current_thread().ident == self._tid:
|
||||
return super()._do_get()
|
||||
return super( # pylint: disable=bad-super-call
|
||||
NullPool, self
|
||||
)._create_connection()
|
@ -1,9 +1,10 @@
|
||||
"""The tests for the Recorder component."""
|
||||
# pylint: disable=protected-access
|
||||
from datetime import datetime, timedelta
|
||||
import sqlite3
|
||||
from unittest.mock import patch
|
||||
|
||||
from sqlalchemy.exc import OperationalError, SQLAlchemyError
|
||||
from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError
|
||||
|
||||
from homeassistant.components import recorder
|
||||
from homeassistant.components.recorder import (
|
||||
@ -885,6 +886,9 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog):
|
||||
|
||||
hass.states.async_set("test.lost", "on", {})
|
||||
|
||||
sqlite3_exception = DatabaseError("statement", {}, [])
|
||||
sqlite3_exception.__cause__ = sqlite3.DatabaseError()
|
||||
|
||||
with patch.object(
|
||||
hass.data[DATA_INSTANCE].event_session,
|
||||
"close",
|
||||
@ -894,11 +898,16 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog):
|
||||
await hass.async_add_executor_job(corrupt_db_file, test_db_file)
|
||||
await async_wait_recording_done_without_instance(hass)
|
||||
|
||||
# This state will not be recorded because
|
||||
# the database corruption will be discovered
|
||||
# and we will have to rollback to recover
|
||||
hass.states.async_set("test.one", "off", {})
|
||||
await async_wait_recording_done_without_instance(hass)
|
||||
with patch.object(
|
||||
hass.data[DATA_INSTANCE].event_session,
|
||||
"commit",
|
||||
side_effect=[sqlite3_exception, None],
|
||||
):
|
||||
# This state will not be recorded because
|
||||
# the database corruption will be discovered
|
||||
# and we will have to rollback to recover
|
||||
hass.states.async_set("test.one", "off", {})
|
||||
await async_wait_recording_done_without_instance(hass)
|
||||
|
||||
assert "Unrecoverable sqlite3 database corruption detected" in caplog.text
|
||||
assert "The system will rename the corrupt database file" in caplog.text
|
||||
|
34
tests/components/recorder/test_pool.py
Normal file
34
tests/components/recorder/test_pool.py
Normal file
@ -0,0 +1,34 @@
|
||||
"""Test pool."""
|
||||
import threading
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from homeassistant.components.recorder.pool import RecorderPool
|
||||
|
||||
|
||||
def test_recorder_pool():
|
||||
"""Test RecorderPool gives the same connection in the creating thread."""
|
||||
|
||||
engine = create_engine("sqlite://", poolclass=RecorderPool)
|
||||
get_session = sessionmaker(bind=engine)
|
||||
|
||||
connections = []
|
||||
|
||||
def _get_connection_twice():
|
||||
session = get_session()
|
||||
connections.append(session.connection().connection.connection)
|
||||
session.close()
|
||||
|
||||
session = get_session()
|
||||
connections.append(session.connection().connection.connection)
|
||||
session.close()
|
||||
|
||||
_get_connection_twice()
|
||||
assert connections[0] == connections[1]
|
||||
|
||||
new_thread = threading.Thread(target=_get_connection_twice)
|
||||
new_thread.start()
|
||||
new_thread.join()
|
||||
|
||||
assert connections[2] != connections[3]
|
Loading…
x
Reference in New Issue
Block a user