Update typing on recorder pool for sqlalchemy 2.0 (#91244)

This commit is contained in:
J. Nick Koston 2023-04-12 12:09:15 -10:00 committed by GitHub
parent d483ad820c
commit 9b2e9b8746
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -52,7 +52,7 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX) thread_name == "Recorder" or thread_name.startswith(DB_WORKER_PREFIX)
) )
def _do_return_conn(self, record: ConnectionPoolEntry) -> Any: def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
if self.recorder_or_dbworker: if self.recorder_or_dbworker:
return super()._do_return_conn(record) return super()._do_return_conn(record)
record.close() record.close()
@ -72,8 +72,7 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
if self.recorder_or_dbworker: if self.recorder_or_dbworker:
super().dispose() super().dispose()
# Any can be switched out for ConnectionPoolEntry in the next version of sqlalchemy def _do_get(self) -> ConnectionPoolEntry:
def _do_get(self) -> Any:
if self.recorder_or_dbworker: if self.recorder_or_dbworker:
return super()._do_get() return super()._do_get()
check_loop( check_loop(
@ -83,7 +82,7 @@ class RecorderPool(SingletonThreadPool, NullPool): # type: ignore[misc]
) )
return self._do_get_db_connection_protected() return self._do_get_db_connection_protected()
def _do_get_db_connection_protected(self) -> Any: def _do_get_db_connection_protected(self) -> ConnectionPoolEntry:
report( report(
( (
"accesses the database without the database executor; " "accesses the database without the database executor; "
@ -106,7 +105,7 @@ class MutexPool(StaticPool):
_reference_counter = 0 _reference_counter = 0
pool_lock: threading.RLock pool_lock: threading.RLock
def _do_return_conn(self, record: ConnectionPoolEntry) -> Any: def _do_return_conn(self, record: ConnectionPoolEntry) -> None:
if DEBUG_MUTEX_POOL_TRACE: if DEBUG_MUTEX_POOL_TRACE:
trace = traceback.extract_stack() trace = traceback.extract_stack()
trace_msg = "\n" + "".join(traceback.format_list(trace[:-1])) trace_msg = "\n" + "".join(traceback.format_list(trace[:-1]))
@ -124,7 +123,7 @@ class MutexPool(StaticPool):
) )
MutexPool.pool_lock.release() MutexPool.pool_lock.release()
def _do_get(self) -> Any: def _do_get(self) -> ConnectionPoolEntry:
if DEBUG_MUTEX_POOL_TRACE: if DEBUG_MUTEX_POOL_TRACE:
trace = traceback.extract_stack() trace = traceback.extract_stack()
trace_msg = "".join(traceback.format_list(trace[:-1])) trace_msg = "".join(traceback.format_list(trace[:-1]))