Improve recorder and worker thread matching in RecorderPool (#116886)

* Improve recorder and worker thread matching in RecorderPool

Previously we would look at the names of the threads. This
was brittle because other integrations may name their
threads Recorder or DbWorker. Instead we now use explicit
thread IDs, which ensures there will never be a conflict.

* fix

* fixes

* fixes
J. Nick Koston 2024-05-05 15:25:10 -05:00 committed by GitHub
parent ee031f4850
commit 6339c63176
5 changed files with 65 additions and 23 deletions
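
In short, the check moves from thread names to thread IDs. A hedged sketch of the before/after logic (the helper names here are illustrative; the real diff follows):

import threading

DB_WORKER_PREFIX = "DbWorker"

# Before: identify recorder/worker threads by name. Brittle, because any
# integration may also name a thread "Recorder" or give it the "DbWorker"
# prefix.
def is_recorder_or_dbworker_by_name() -> bool:
    name = threading.current_thread().name
    return name == "Recorder" or name.startswith(DB_WORKER_PREFIX)

# After: the recorder thread and each DB worker add threading.get_ident()
# to a shared set at startup; set membership cannot collide with threads
# started by other integrations.
recorder_and_worker_thread_ids: set[int] = set()

def is_recorder_or_dbworker_by_id() -> bool:
    return threading.get_ident() in recorder_and_worker_thread_ids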

homeassistant/components/recorder/core.py

@@ -187,6 +187,7 @@ class Recorder(threading.Thread):
         self.hass = hass
         self.thread_id: int | None = None
+        self.recorder_and_worker_thread_ids: set[int] = set()
         self.auto_purge = auto_purge
         self.auto_repack = auto_repack
         self.keep_days = keep_days
@@ -294,6 +295,7 @@ class Recorder(threading.Thread):
     def async_start_executor(self) -> None:
         """Start the executor."""
         self._db_executor = DBInterruptibleThreadPoolExecutor(
+            self.recorder_and_worker_thread_ids,
             thread_name_prefix=DB_WORKER_PREFIX,
             max_workers=MAX_DB_EXECUTOR_WORKERS,
             shutdown_hook=self._shutdown_pool,
@@ -717,7 +719,10 @@ class Recorder(threading.Thread):
     def _run(self) -> None:
         """Start processing events to save."""
-        self.thread_id = threading.get_ident()
+        thread_id = threading.get_ident()
+        self.thread_id = thread_id
+        self.recorder_and_worker_thread_ids.add(thread_id)

         setup_result = self._setup_recorder()
         if not setup_result:
@@ -1411,6 +1416,9 @@ class Recorder(threading.Thread):
             kwargs["pool_reset_on_return"] = None
         elif self.db_url.startswith(SQLITE_URL_PREFIX):
             kwargs["poolclass"] = RecorderPool
+            kwargs["recorder_and_worker_thread_ids"] = (
+                self.recorder_and_worker_thread_ids
+            )
         elif self.db_url.startswith(
             (
                 MARIADB_URL_PREFIX,

homeassistant/components/recorder/executor.py

@@ -12,9 +12,13 @@ from homeassistant.util.executor import InterruptibleThreadPoolExecutor


 def _worker_with_shutdown_hook(
-    shutdown_hook: Callable[[], None], *args: Any, **kwargs: Any
+    shutdown_hook: Callable[[], None],
+    recorder_and_worker_thread_ids: set[int],
+    *args: Any,
+    **kwargs: Any,
 ) -> None:
     """Create a worker that calls a function after its finished."""
+    recorder_and_worker_thread_ids.add(threading.get_ident())
     _worker(*args, **kwargs)
     shutdown_hook()
@@ -22,9 +26,12 @@ def _worker_with_shutdown_hook(
 class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
     """A database instance that will not deadlock on shutdown."""

-    def __init__(self, *args: Any, **kwargs: Any) -> None:
+    def __init__(
+        self, recorder_and_worker_thread_ids: set[int], *args: Any, **kwargs: Any
+    ) -> None:
         """Init the executor with a shutdown hook support."""
         self._shutdown_hook: Callable[[], None] = kwargs.pop("shutdown_hook")
+        self.recorder_and_worker_thread_ids = recorder_and_worker_thread_ids
         super().__init__(*args, **kwargs)

     def _adjust_thread_count(self) -> None:
@@ -54,6 +61,7 @@ class DBInterruptibleThreadPoolExecutor(InterruptibleThreadPoolExecutor):
                 target=_worker_with_shutdown_hook,
                 args=(
                     self._shutdown_hook,
+                    self.recorder_and_worker_thread_ids,
                     weakref.ref(self, weakref_cb),
                     self._work_queue,
                     self._initializer,
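
The ordering here matters: the wrapper registers the worker's thread ID before entering the work loop, so the pool recognizes the thread from its first query. A minimal standalone sketch of this shared-mutable-set design (the names below are illustrative, not the real executor code):

import threading

shared_ids: set[int] = set()  # one set instance, shared with the pool

def worker_entry() -> None:
    # Mirrors _worker_with_shutdown_hook: register before any DB work.
    shared_ids.add(threading.get_ident())
    # ... the executor's work loop would run here ...

t = threading.Thread(target=worker_entry, name="DbWorker_0")
t.start()
t.join()
assert t.ident in shared_ids  # the pool sees the worker's ID immediately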

homeassistant/components/recorder/pool.py

@@ -16,8 +16,6 @@ from sqlalchemy.pool import (
 from homeassistant.helpers.frame import report
 from homeassistant.util.loop import check_loop

-from .const import DB_WORKER_PREFIX
-
 _LOGGER = logging.getLogger(__name__)

 # For debugging the MutexPool
@@ -31,7 +29,7 @@ ADVISE_MSG = (
 )


-class RecorderPool(SingletonThreadPool, NullPool):  # type: ignore[misc]
+class RecorderPool(SingletonThreadPool, NullPool):
     """A hybrid of NullPool and SingletonThreadPool.

     When called from the creating thread or db executor acts like SingletonThreadPool
@@ -39,29 +37,44 @@ class RecorderPool(SingletonThreadPool, NullPool):  # type: ignore[misc]
     """

     def __init__(  # pylint: disable=super-init-not-called
-        self, *args: Any, **kw: Any
+        self,
+        creator: Any,
+        recorder_and_worker_thread_ids: set[int] | None = None,
+        **kw: Any,
     ) -> None:
         """Create the pool."""
         kw["pool_size"] = POOL_SIZE
-        SingletonThreadPool.__init__(self, *args, **kw)
+        assert (
+            recorder_and_worker_thread_ids is not None
+        ), "recorder_and_worker_thread_ids is required"
+        self.recorder_and_worker_thread_ids = recorder_and_worker_thread_ids
+        SingletonThreadPool.__init__(self, creator, **kw)

-    @property
-    def recorder_or_dbworker(self) -> bool:
-        """Check if the thread is a recorder or dbworker thread."""
-        thread_name = threading.current_thread().name
-        return bool(
-            thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX)
-        )
+    def recreate(self) -> "RecorderPool":
+        """Recreate the pool."""
+        self.logger.info("Pool recreating")
+        return self.__class__(
+            self._creator,
+            pool_size=self.size,
+            recycle=self._recycle,
+            echo=self.echo,
+            pre_ping=self._pre_ping,
+            logging_name=self._orig_logging_name,
+            reset_on_return=self._reset_on_return,
+            _dispatch=self.dispatch,
+            dialect=self._dialect,
+            recorder_and_worker_thread_ids=self.recorder_and_worker_thread_ids,
+        )

     def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
-        if self.recorder_or_dbworker:
+        if threading.get_ident() in self.recorder_and_worker_thread_ids:
             return super()._do_return_conn(record)
         record.close()

     def shutdown(self) -> None:
         """Close the connection."""
         if (
-            self.recorder_or_dbworker
+            threading.get_ident() in self.recorder_and_worker_thread_ids
             and self._conn
             and hasattr(self._conn, "current")
             and (conn := self._conn.current())
@@ -70,11 +83,11 @@ class RecorderPool(SingletonThreadPool, NullPool):  # type: ignore[misc]

     def dispose(self) -> None:
         """Dispose of the connection."""
-        if self.recorder_or_dbworker:
+        if threading.get_ident() in self.recorder_and_worker_thread_ids:
             super().dispose()

     def _do_get(self) -> ConnectionPoolEntry:
-        if self.recorder_or_dbworker:
+        if threading.get_ident() in self.recorder_and_worker_thread_ids:
             return super()._do_get()
         check_loop(
             self._do_get_db_connection_protected,
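
A note on the wiring: SQLAlchemy's create_engine() forwards keyword arguments whose names match the pool class constructor, so the shared set passed as recorder_and_worker_thread_ids reaches RecorderPool.__init__. A minimal sketch mirroring the tests below:

import threading
from sqlalchemy import create_engine
from homeassistant.components.recorder.pool import RecorderPool

thread_ids: set[int] = set()
thread_ids.add(threading.get_ident())  # register the creating thread

# The kwarg is routed to RecorderPool.__init__ by create_engine().
engine = create_engine(
    "sqlite://",
    poolclass=RecorderPool,
    recorder_and_worker_thread_ids=thread_ids,
)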

tests/components/recorder/test_init.py

@@ -14,6 +14,7 @@ from unittest.mock import MagicMock, Mock, patch
 from freezegun.api import FrozenDateTimeFactory
 import pytest
 from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError
+from sqlalchemy.pool import QueuePool

 from homeassistant.components import recorder
 from homeassistant.components.recorder import (
@@ -30,7 +31,6 @@ from homeassistant.components.recorder import (
     db_schema,
     get_instance,
     migration,
-    pool,
     statistics,
 )
 from homeassistant.components.recorder.const import (
@@ -2265,7 +2265,7 @@ async def test_connect_args_priority(hass: HomeAssistant, config_url) -> None:
         def engine_created(*args): ...

         def get_dialect_pool_class(self, *args):
-            return pool.RecorderPool
+            return QueuePool

         def initialize(*args): ...

tests/components/recorder/test_pool.py

@@ -12,20 +12,32 @@ from homeassistant.components.recorder.pool import RecorderPool

 async def test_recorder_pool_called_from_event_loop() -> None:
     """Test we raise an exception when calling from the event loop."""
-    engine = create_engine("sqlite://", poolclass=RecorderPool)
+    recorder_and_worker_thread_ids: set[int] = set()
+    engine = create_engine(
+        "sqlite://",
+        poolclass=RecorderPool,
+        recorder_and_worker_thread_ids=recorder_and_worker_thread_ids,
+    )
     with pytest.raises(RuntimeError):
         sessionmaker(bind=engine)().connection()


 def test_recorder_pool(caplog: pytest.LogCaptureFixture) -> None:
     """Test RecorderPool gives the same connection in the creating thread."""
-    engine = create_engine("sqlite://", poolclass=RecorderPool)
+    recorder_and_worker_thread_ids: set[int] = set()
+    engine = create_engine(
+        "sqlite://",
+        poolclass=RecorderPool,
+        recorder_and_worker_thread_ids=recorder_and_worker_thread_ids,
+    )
     get_session = sessionmaker(bind=engine)
     shutdown = False
     connections = []
+    add_thread = False

     def _get_connection_twice():
+        if add_thread:
+            recorder_and_worker_thread_ids.add(threading.get_ident())
         session = get_session()
         connections.append(session.connection().connection.driver_connection)
         session.close()
@@ -44,6 +56,7 @@ def test_recorder_pool(caplog: pytest.LogCaptureFixture) -> None:
     assert "accesses the database without the database executor" in caplog.text
     assert connections[0] != connections[1]

+    add_thread = True
     caplog.clear()
     new_thread = threading.Thread(target=_get_connection_twice, name=DB_WORKER_PREFIX)
     new_thread.start()