From 9b2e9b8746769291daf2cb2881b2259533c3e7ef Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 12 Apr 2023 12:09:15 -1000 Subject: [PATCH] Update typing on recorder pool for sqlalchemy 2.0 (#91244) --- homeassistant/components/recorder/pool.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/homeassistant/components/recorder/pool.py b/homeassistant/components/recorder/pool.py index 02ba7545f89..09b113f03eb 100644 --- a/homeassistant/components/recorder/pool.py +++ b/homeassistant/components/recorder/pool.py @@ -52,7 +52,7 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc] thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX) ) - def _do_return_conn(self, record: ConnectionPoolEntry) -> Any: + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: if self.recorder_or_dbworker: return super()._do_return_conn(record) record.close() @@ -72,8 +72,7 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc] if self.recorder_or_dbworker: super().dispose() - # Any can be switched out for ConnectionPoolEntry in the next version of sqlalchemy - def _do_get(self) -> Any: + def _do_get(self) -> ConnectionPoolEntry: if self.recorder_or_dbworker: return super()._do_get() check_loop( @@ -83,7 +82,7 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc] ) return self._do_get_db_connection_protected() - def _do_get_db_connection_protected(self) -> Any: + def _do_get_db_connection_protected(self) -> ConnectionPoolEntry: report( ( "accesses the database without the database executor; " @@ -106,7 +105,7 @@ class MutexPool(StaticPool): _reference_counter = 0 pool_lock: threading.RLock - def _do_return_conn(self, record: ConnectionPoolEntry) -> Any: + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: if DEBUG_MUTEX_POOL_TRACE: trace = traceback.extract_stack() trace_msg = "\n" + "".join(traceback.format_list(trace[:-1])) @@ -124,7 +123,7 @@ class MutexPool(StaticPool): ) MutexPool.pool_lock.release() - def _do_get(self) -> Any: + def _do_get(self) -> ConnectionPoolEntry: if DEBUG_MUTEX_POOL_TRACE: trace = traceback.extract_stack() trace_msg = "".join(traceback.format_list(trace[:-1]))