mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Improve recorder and worker thread matching in RecorderPool (#116886)
* Improve recorder and worker thread matching in RecorderPool Previously we would look at the name of the threads. This was a brittle if because other integrations may name their thread Recorder or DbWorker. Instead we now use explict thread ids which ensures there will never be a conflict * fix * fixes * fixes
This commit is contained in:
parent
ee031f4850
commit
6339c63176
@ -187,6 +187,7 @@ class Recorder(threading.Thread):
|
|||||||
|
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.thread_id: int | None = None
|
self.thread_id: int | None = None
|
||||||
|
self.recorder_and_worker_thread_ids: set[int] = set()
|
||||||
self.auto_purge = auto_purge
|
self.auto_purge = auto_purge
|
||||||
self.auto_repack = auto_repack
|
self.auto_repack = auto_repack
|
||||||
self.keep_days = keep_days
|
self.keep_days = keep_days
|
||||||
@ -294,6 +295,7 @@ class Recorder(threading.Thread):
|
|||||||
def async_start_executor(self) -> None:
|
def async_start_executor(self) -> None:
|
||||||
"""Start the executor."""
|
"""Start the executor."""
|
||||||
self._db_executor = DBInterruptibleThreadPoolExecutor(
|
self._db_executor = DBInterruptibleThreadPoolExecutor(
|
||||||
|
self.recorder_and_worker_thread_ids,
|
||||||
thread_name_prefix=DB_WORKER_PREFIX,
|
thread_name_prefix=DB_WORKER_PREFIX,
|
||||||
max_workers=MAX_DB_EXECUTOR_WORKERS,
|
max_workers=MAX_DB_EXECUTOR_WORKERS,
|
||||||
shutdown_hook=self._shutdown_pool,
|
shutdown_hook=self._shutdown_pool,
|
||||||
@ -717,7 +719,10 @@ class Recorder(threading.Thread):
|
|||||||
|
|
||||||
def _run(self) -> None:
|
def _run(self) -> None:
|
||||||
"""Start processing events to save."""
|
"""Start processing events to save."""
|
||||||
self.thread_id = threading.get_ident()
|
thread_id = threading.get_ident()
|
||||||
|
self.thread_id = thread_id
|
||||||
|
self.recorder_and_worker_thread_ids.add(thread_id)
|
||||||
|
|
||||||
setup_result = self._setup_recorder()
|
setup_result = self._setup_recorder()
|
||||||
|
|
||||||
if not setup_result:
|
if not setup_result:
|
||||||
@ -1411,6 +1416,9 @@ class Recorder(threading.Thread):
|
|||||||
kwargs["pool_reset_on_return"] = None
|
kwargs["pool_reset_on_return"] = None
|
||||||
elif self.db_url.startswith(SQLITE_URL_PREFIX):
|
elif self.db_url.startswith(SQLITE_URL_PREFIX):
|
||||||
kwargs["poolclass"] = RecorderPool
|
kwargs["poolclass"] = RecorderPool
|
||||||
|
kwargs["recorder_and_worker_thread_ids"] = (
|
||||||
|
self.recorder_and_worker_thread_ids
|
||||||
|
)
|
||||||
elif self.db_url.startswith(
|
elif self.db_url.startswith(
|
||||||
(
|
(
|
||||||
MARIADB_URL_PREFIX,
|
MARIADB_URL_PREFIX,
|
||||||
|
@ -12,9 +12,13 @@ from homeassistant.util.executor import InterruptibleThreadPoolExecutor
|
|||||||
|
|
||||||
|
|
||||||
def _worker_with_shutdown_hook(
|
def _worker_with_shutdown_hook(
|
||||||
shutdown_hook: Callable[[], None], *args: Any, **kwargs: Any
|
shutdown_hook: Callable[[], None],
|
||||||
|
recorder_and_worker_thread_ids: set[int],
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create a worker that calls a function after its finished."""
|
"""Create a worker that calls a function after its finished."""
|
||||||
|
recorder_and_worker_thread_ids.add(threading.get_ident())
|
||||||
_worker(*args, **kwargs)
|
_worker(*args, **kwargs)
|
||||||
shutdown_hook()
|
shutdown_hook()
|
||||||
|
|
||||||
@ -22,9 +26,12 @@ def _worker_with_shutdown_hook(
|
|||||||
class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
|
class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
|
||||||
"""A database instance that will not deadlock on shutdown."""
|
"""A database instance that will not deadlock on shutdown."""
|
||||||
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
def __init__(
|
||||||
|
self, recorder_and_worker_thread_ids: set[int], *args: Any, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
"""Init the executor with a shutdown hook support."""
|
"""Init the executor with a shutdown hook support."""
|
||||||
self._shutdown_hook: Callable[[], None] = kwargs.pop("shutdown_hook")
|
self._shutdown_hook: Callable[[], None] = kwargs.pop("shutdown_hook")
|
||||||
|
self.recorder_and_worker_thread_ids = recorder_and_worker_thread_ids
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def _adjust_thread_count(self) -> None:
|
def _adjust_thread_count(self) -> None:
|
||||||
@ -54,6 +61,7 @@ class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
|
|||||||
target=_worker_with_shutdown_hook,
|
target=_worker_with_shutdown_hook,
|
||||||
args=(
|
args=(
|
||||||
self._shutdown_hook,
|
self._shutdown_hook,
|
||||||
|
self.recorder_and_worker_thread_ids,
|
||||||
weakref.ref(self, weakref_cb),
|
weakref.ref(self, weakref_cb),
|
||||||
self._work_queue,
|
self._work_queue,
|
||||||
self._initializer,
|
self._initializer,
|
||||||
|
@ -16,8 +16,6 @@ from sqlalchemy.pool import (
|
|||||||
from homeassistant.helpers.frame import report
|
from homeassistant.helpers.frame import report
|
||||||
from homeassistant.util.loop import check_loop
|
from homeassistant.util.loop import check_loop
|
||||||
|
|
||||||
from .const import DB_WORKER_PREFIX
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
# For debugging the MutexPool
|
# For debugging the MutexPool
|
||||||
@ -31,7 +29,7 @@ ADVISE_MSG = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
|
class RecorderPool(SingletonThreadPool, NullPool):
|
||||||
"""A hybrid of NullPool and SingletonThreadPool.
|
"""A hybrid of NullPool and SingletonThreadPool.
|
||||||
|
|
||||||
When called from the creating thread or db executor acts like SingletonThreadPool
|
When called from the creating thread or db executor acts like SingletonThreadPool
|
||||||
@ -39,29 +37,44 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__( # pylint: disable=super-init-not-called
|
def __init__( # pylint: disable=super-init-not-called
|
||||||
self, *args: Any, **kw: Any
|
self,
|
||||||
|
creator: Any,
|
||||||
|
recorder_and_worker_thread_ids: set[int] | None = None,
|
||||||
|
**kw: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create the pool."""
|
"""Create the pool."""
|
||||||
kw["pool_size"] = POOL_SIZE
|
kw["pool_size"] = POOL_SIZE
|
||||||
SingletonThreadPool.__init__(self, *args, **kw)
|
assert (
|
||||||
|
recorder_and_worker_thread_ids is not None
|
||||||
|
), "recorder_and_worker_thread_ids is required"
|
||||||
|
self.recorder_and_worker_thread_ids = recorder_and_worker_thread_ids
|
||||||
|
SingletonThreadPool.__init__(self, creator, **kw)
|
||||||
|
|
||||||
@property
|
def recreate(self) -> "RecorderPool":
|
||||||
def recorder_or_dbworker(self) -> bool:
|
"""Recreate the pool."""
|
||||||
"""Check if the thread is a recorder or dbworker thread."""
|
self.logger.info("Pool recreating")
|
||||||
thread_name = threading.current_thread().name
|
return self.__class__(
|
||||||
return bool(
|
self._creator,
|
||||||
thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX)
|
pool_size=self.size,
|
||||||
|
recycle=self._recycle,
|
||||||
|
echo=self.echo,
|
||||||
|
pre_ping=self._pre_ping,
|
||||||
|
logging_name=self._orig_logging_name,
|
||||||
|
reset_on_return=self._reset_on_return,
|
||||||
|
_dispatch=self.dispatch,
|
||||||
|
dialect=self._dialect,
|
||||||
|
recorder_and_worker_thread_ids=self.recorder_and_worker_thread_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
|
def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
|
||||||
if self.recorder_or_dbworker:
|
if threading.get_ident() in self.recorder_and_worker_thread_ids:
|
||||||
return super()._do_return_conn(record)
|
return super()._do_return_conn(record)
|
||||||
record.close()
|
record.close()
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
def shutdown(self) -> None:
|
||||||
"""Close the connection."""
|
"""Close the connection."""
|
||||||
if (
|
if (
|
||||||
self.recorder_or_dbworker
|
threading.get_ident() in self.recorder_and_worker_thread_ids
|
||||||
and self._conn
|
and self._conn
|
||||||
and hasattr(self._conn, "current")
|
and hasattr(self._conn, "current")
|
||||||
and (conn := self._conn.current())
|
and (conn := self._conn.current())
|
||||||
@ -70,11 +83,11 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
|
|||||||
|
|
||||||
def dispose(self) -> None:
|
def dispose(self) -> None:
|
||||||
"""Dispose of the connection."""
|
"""Dispose of the connection."""
|
||||||
if self.recorder_or_dbworker:
|
if threading.get_ident() in self.recorder_and_worker_thread_ids:
|
||||||
super().dispose()
|
super().dispose()
|
||||||
|
|
||||||
def _do_get(self) -> ConnectionPoolEntry:
|
def _do_get(self) -> ConnectionPoolEntry:
|
||||||
if self.recorder_or_dbworker:
|
if threading.get_ident() in self.recorder_and_worker_thread_ids:
|
||||||
return super()._do_get()
|
return super()._do_get()
|
||||||
check_loop(
|
check_loop(
|
||||||
self._do_get_db_connection_protected,
|
self._do_get_db_connection_protected,
|
||||||
|
@ -14,6 +14,7 @@ from unittest.mock import MagicMock, Mock, patch
|
|||||||
from freezegun.api import FrozenDateTimeFactory
|
from freezegun.api import FrozenDateTimeFactory
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError
|
from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError
|
||||||
|
from sqlalchemy.pool import QueuePool
|
||||||
|
|
||||||
from homeassistant.components import recorder
|
from homeassistant.components import recorder
|
||||||
from homeassistant.components.recorder import (
|
from homeassistant.components.recorder import (
|
||||||
@ -30,7 +31,6 @@ from homeassistant.components.recorder import (
|
|||||||
db_schema,
|
db_schema,
|
||||||
get_instance,
|
get_instance,
|
||||||
migration,
|
migration,
|
||||||
pool,
|
|
||||||
statistics,
|
statistics,
|
||||||
)
|
)
|
||||||
from homeassistant.components.recorder.const import (
|
from homeassistant.components.recorder.const import (
|
||||||
@ -2265,7 +2265,7 @@ async def test_connect_args_priority(hass: HomeAssistant, config_url) -> None:
|
|||||||
def engine_created(*args): ...
|
def engine_created(*args): ...
|
||||||
|
|
||||||
def get_dialect_pool_class(self, *args):
|
def get_dialect_pool_class(self, *args):
|
||||||
return pool.RecorderPool
|
return QueuePool
|
||||||
|
|
||||||
def initialize(*args): ...
|
def initialize(*args): ...
|
||||||
|
|
||||||
|
@ -12,20 +12,32 @@ from homeassistant.components.recorder.pool import RecorderPool
|
|||||||
|
|
||||||
async def test_recorder_pool_called_from_event_loop() -> None:
|
async def test_recorder_pool_called_from_event_loop() -> None:
|
||||||
"""Test we raise an exception when calling from the event loop."""
|
"""Test we raise an exception when calling from the event loop."""
|
||||||
engine = create_engine("sqlite://", poolclass=RecorderPool)
|
recorder_and_worker_thread_ids: set[int] = set()
|
||||||
|
engine = create_engine(
|
||||||
|
"sqlite://",
|
||||||
|
poolclass=RecorderPool,
|
||||||
|
recorder_and_worker_thread_ids=recorder_and_worker_thread_ids,
|
||||||
|
)
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
sessionmaker(bind=engine)().connection()
|
sessionmaker(bind=engine)().connection()
|
||||||
|
|
||||||
|
|
||||||
def test_recorder_pool(caplog: pytest.LogCaptureFixture) -> None:
|
def test_recorder_pool(caplog: pytest.LogCaptureFixture) -> None:
|
||||||
"""Test RecorderPool gives the same connection in the creating thread."""
|
"""Test RecorderPool gives the same connection in the creating thread."""
|
||||||
|
recorder_and_worker_thread_ids: set[int] = set()
|
||||||
engine = create_engine("sqlite://", poolclass=RecorderPool)
|
engine = create_engine(
|
||||||
|
"sqlite://",
|
||||||
|
poolclass=RecorderPool,
|
||||||
|
recorder_and_worker_thread_ids=recorder_and_worker_thread_ids,
|
||||||
|
)
|
||||||
get_session = sessionmaker(bind=engine)
|
get_session = sessionmaker(bind=engine)
|
||||||
shutdown = False
|
shutdown = False
|
||||||
connections = []
|
connections = []
|
||||||
|
add_thread = False
|
||||||
|
|
||||||
def _get_connection_twice():
|
def _get_connection_twice():
|
||||||
|
if add_thread:
|
||||||
|
recorder_and_worker_thread_ids.add(threading.get_ident())
|
||||||
session = get_session()
|
session = get_session()
|
||||||
connections.append(session.connection().connection.driver_connection)
|
connections.append(session.connection().connection.driver_connection)
|
||||||
session.close()
|
session.close()
|
||||||
@ -44,6 +56,7 @@ def test_recorder_pool(caplog: pytest.LogCaptureFixture) -> None:
|
|||||||
assert "accesses the database without the database executor" in caplog.text
|
assert "accesses the database without the database executor" in caplog.text
|
||||||
assert connections[0] != connections[1]
|
assert connections[0] != connections[1]
|
||||||
|
|
||||||
|
add_thread = True
|
||||||
caplog.clear()
|
caplog.clear()
|
||||||
new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX)
|
new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX)
|
||||||
new_thread.start()
|
new_thread.start()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user