diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index d046fe30d5c..d68b993444b 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -16,6 +16,7 @@ from typing import Any, TypeVar from lru import LRU # pylint: disable=no-name-in-module from sqlalchemy import create_engine, event as sqlalchemy_event, exc, func, select +from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm.session import Session @@ -33,7 +34,14 @@ from homeassistant.const import ( EVENT_TIME_CHANGED, MATCH_ALL, ) -from homeassistant.core import CoreState, HomeAssistant, ServiceCall, callback +from homeassistant.core import ( + CALLBACK_TYPE, + CoreState, + Event, + HomeAssistant, + ServiceCall, + callback, +) import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entityfilter import ( INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA, @@ -306,7 +314,7 @@ async def _process_recorder_platform(hass, domain, platform): @callback -def _async_register_services(hass, instance): +def _async_register_services(hass: HomeAssistant, instance: Recorder) -> None: """Register recorder services.""" async def async_handle_purge_service(service: ServiceCall) -> None: @@ -524,9 +532,9 @@ class StopTask(RecorderTask): @dataclass class EventTask(RecorderTask): - """An object to insert into the recorder queue to stop the event handler.""" + """An event to be processed.""" - event: bool + event: Event commit_before = False def run(self, instance: Recorder) -> None: @@ -567,7 +575,7 @@ class Recorder(threading.Thread): self.async_db_ready: asyncio.Future = asyncio.Future() self.async_recorder_ready = asyncio.Event() self._queue_watch = threading.Event() - self.engine: Any = None + self.engine: Engine | None = None self.run_info: Any = None self.entity_filter = entity_filter @@ -580,13 +588,13 @@ class Recorder(threading.Thread): self._state_attributes_ids: LRU = LRU(STATE_ATTRIBUTES_ID_CACHE_SIZE) self._pending_state_attributes: dict[str, StateAttributes] = {} self._pending_expunge: list[States] = [] - self.event_session = None - self.get_session = None - self._completed_first_database_setup = None - self._event_listener = None + self.event_session: Session | None = None + self.get_session: Callable[[], Session] | None = None + self._completed_first_database_setup: bool | None = None + self._event_listener: CALLBACK_TYPE | None = None self.async_migration_event = asyncio.Event() self.migration_in_progress = False - self._queue_watcher = None + self._queue_watcher: CALLBACK_TYPE | None = None self._db_supports_row_number = True self._database_lock_task: DatabaseLockTask | None = None self._db_executor: DBInterruptibleThreadPoolExecutor | None = None @@ -651,7 +659,7 @@ class Recorder(threading.Thread): self._async_stop_queue_watcher_and_event_listener() @callback - def _async_stop_queue_watcher_and_event_listener(self): + def _async_stop_queue_watcher_and_event_listener(self) -> None: """Stop watching the queue and listening for events.""" if self._queue_watcher: self._queue_watcher() @@ -661,7 +669,7 @@ class Recorder(threading.Thread): self._event_listener = None @callback - def _async_event_filter(self, event) -> bool: + def _async_event_filter(self, event: Event) -> bool: """Filter events.""" if event.event_type in self.exclude_t: return False @@ -702,7 +710,9 @@ class Recorder(threading.Thread): self.queue.put(StatisticsTask(start)) @callback - def async_register(self, shutdown_task, hass_started): + def async_register( + self, shutdown_task: object, hass_started: concurrent.futures.Future + ) -> None: """Post connection initialize.""" def _empty_queue(event): @@ -746,7 +756,7 @@ class Recorder(threading.Thread): self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, async_hass_started) @callback - def async_connection_failed(self): + def async_connection_failed(self) -> None: """Connect failed tasks.""" self.async_db_ready.set_result(False) persistent_notification.async_create( @@ -757,19 +767,19 @@ class Recorder(threading.Thread): self._async_stop_queue_watcher_and_event_listener() @callback - def async_connection_success(self): + def async_connection_success(self) -> None: """Connect success tasks.""" self.async_db_ready.set_result(True) self.async_start_executor() @callback - def _async_recorder_ready(self): + def _async_recorder_ready(self) -> None: """Finish start and mark recorder ready.""" self._async_setup_periodic_tasks() self.async_recorder_ready.set() @callback - def async_nightly_tasks(self, now): + def async_nightly_tasks(self, now: datetime) -> None: """Trigger the purge.""" if self.auto_purge: # Purge will schedule the perodic cleanups @@ -781,7 +791,7 @@ class Recorder(threading.Thread): self.queue.put(PerodicCleanupTask()) @callback - def async_periodic_statistics(self, now): + def async_periodic_statistics(self, now: datetime) -> None: """Trigger the hourly statistics run.""" start = statistics.get_start_time() self.queue.put(StatisticsTask(start)) @@ -807,7 +817,7 @@ class Recorder(threading.Thread): self.queue.put(ExternalStatisticsTask(metadata, stats)) @callback - def _async_setup_periodic_tasks(self): + def _async_setup_periodic_tasks(self) -> None: """Prepare periodic tasks.""" if self.hass.is_stopping or not self.get_session: # Home Assistant is shutting down @@ -823,10 +833,10 @@ class Recorder(threading.Thread): self.hass, self.async_periodic_statistics, minute=range(0, 60, 5), second=10 ) - def run(self): + def run(self) -> None: """Start processing events to save.""" shutdown_task = object() - hass_started = concurrent.futures.Future() + hass_started: concurrent.futures.Future = concurrent.futures.Future() self.hass.add_job(self.async_register, shutdown_task, hass_started) @@ -875,7 +885,7 @@ class Recorder(threading.Thread): self.hass.add_job(self._async_recorder_ready) self._run_event_loop() - def _run_event_loop(self): + def _run_event_loop(self) -> None: """Run the event loop for the recorder.""" # Use a session for the event read loop # with a commit every time the event time @@ -890,7 +900,7 @@ class Recorder(threading.Thread): self._shutdown() - def _process_one_task_or_recover(self, task: RecorderTask): + def _process_one_task_or_recover(self, task: RecorderTask) -> None: """Process an event, reconnect, or recover a malformed database.""" try: # If its not an event, commit everything @@ -931,11 +941,11 @@ class Recorder(threading.Thread): return None @callback - def _async_migration_started(self): + def _async_migration_started(self) -> None: """Set the migration started event.""" self.async_migration_event.set() - def _migrate_schema_and_setup_run(self, current_version) -> bool: + def _migrate_schema_and_setup_run(self, current_version: int) -> bool: """Migrate schema to the latest version.""" persistent_notification.create( self.hass, @@ -962,7 +972,7 @@ class Recorder(threading.Thread): self.migration_in_progress = False persistent_notification.dismiss(self.hass, "recorder_database_migration") - def _lock_database(self, task: DatabaseLockTask): + def _lock_database(self, task: DatabaseLockTask) -> None: @callback def _async_set_database_locked(task: DatabaseLockTask): task.database_locked.set() @@ -985,7 +995,7 @@ class Recorder(threading.Thread): self.queue.qsize(), ) - def _process_one_event(self, event): + def _process_one_event(self, event: Event) -> None: if event.event_type == EVENT_TIME_CHANGED: self._keepalive_count += 1 if self._keepalive_count >= KEEPALIVE_TIME: @@ -1000,6 +1010,7 @@ class Recorder(threading.Thread): if not self.enabled: return + assert self.event_session is not None try: if event.event_type == EVENT_STATE_CHANGED: @@ -1071,7 +1082,7 @@ class Recorder(threading.Thread): if not self.commit_interval: self._commit_event_session_or_retry() - def _handle_database_error(self, err): + def _handle_database_error(self, err: Exception) -> bool: """Handle a database error that may result in moving away the corrupt db.""" if isinstance(err.__cause__, sqlite3.DatabaseError): _LOGGER.exception( @@ -1081,7 +1092,7 @@ class Recorder(threading.Thread): return True return False - def _commit_event_session_or_retry(self): + def _commit_event_session_or_retry(self) -> None: """Commit the event session if there is work to do.""" if not self.event_session or ( not self.event_session.new and not self.event_session.dirty @@ -1105,7 +1116,8 @@ class Recorder(threading.Thread): tries += 1 time.sleep(self.db_retry_wait) - def _commit_event_session(self): + def _commit_event_session(self) -> None: + assert self.event_session is not None self._commits_without_expire += 1 if self._pending_expunge: @@ -1120,7 +1132,7 @@ class Recorder(threading.Thread): # We just committed the state attributes to the database # and we now know the attributes_ids. We can save - # a many selects for matching attributes by loading them + # many selects for matching attributes by loading them # into the LRU cache now. for state_attr in self._pending_state_attributes.values(): self._state_attributes_ids[ @@ -1135,7 +1147,7 @@ class Recorder(threading.Thread): self._commits_without_expire = 0 self.event_session.expire_all() - def _handle_sqlite_corruption(self): + def _handle_sqlite_corruption(self) -> None: """Handle the sqlite3 database being corrupt.""" self._close_event_session() self._close_connection() @@ -1143,7 +1155,7 @@ class Recorder(threading.Thread): self._setup_recorder() self._setup_run() - def _close_event_session(self): + def _close_event_session(self) -> None: """Close the event session.""" self._old_states = {} self._state_attributes_ids = {} @@ -1160,27 +1172,29 @@ class Recorder(threading.Thread): "Error while rolling back and closing the event session: %s", err ) - def _reopen_event_session(self): + def _reopen_event_session(self) -> None: """Rollback the event session and reopen it after a failure.""" self._close_event_session() self._open_event_session() - def _open_event_session(self): + def _open_event_session(self) -> None: """Open the event session.""" + assert self.get_session is not None self.event_session = self.get_session() self.event_session.expire_on_commit = False - def _send_keep_alive(self): + def _send_keep_alive(self) -> None: """Send a keep alive to keep the db connection open.""" + assert self.event_session is not None _LOGGER.debug("Sending keepalive") self.event_session.connection().scalar(select([1])) @callback - def event_listener(self, event): + def event_listener(self, event: Event) -> None: """Listen for new events and put them in the process queue.""" self.queue.put(EventTask(event)) - def block_till_done(self): + def block_till_done(self) -> None: """Block till all events processed. This is only called in tests. @@ -1244,9 +1258,9 @@ class Recorder(threading.Thread): return success - def _setup_connection(self): + def _setup_connection(self) -> None: """Ensure database is ready to fly.""" - kwargs = {} + kwargs: dict[str, Any] = {} self._completed_first_database_setup = False def setup_recorder_connection(dbapi_connection, connection_record): @@ -1280,20 +1294,22 @@ class Recorder(threading.Thread): _LOGGER.debug("Connected to recorder database") @property - def _using_file_sqlite(self): + def _using_file_sqlite(self) -> bool: """Short version to check if we are using sqlite3 as a file.""" return self.db_url != SQLITE_URL_PREFIX and self.db_url.startswith( SQLITE_URL_PREFIX ) - def _close_connection(self): + def _close_connection(self) -> None: """Close the connection.""" + assert self.engine is not None self.engine.dispose() self.engine = None self.get_session = None - def _setup_run(self): + def _setup_run(self) -> None: """Log the start of the current run and schedule any needed jobs.""" + assert self.get_session is not None with session_scope(session=self.get_session()) as session: start = self.recording_start end_incomplete_runs(session, start) @@ -1324,7 +1340,7 @@ class Recorder(threading.Thread): self.queue.put(StatisticsTask(start)) start = end - def _end_session(self): + def _end_session(self) -> None: """End the recorder session.""" if self.event_session is None: return @@ -1338,7 +1354,7 @@ class Recorder(threading.Thread): self.run_info = None - def _shutdown(self): + def _shutdown(self) -> None: """Save end time for current run.""" self.hass.add_job(self._async_stop_queue_watcher_and_event_listener) self._stop_executor() @@ -1346,6 +1362,6 @@ class Recorder(threading.Thread): self._close_connection() @property - def recording(self): + def recording(self) -> bool: """Return if the recorder is recording.""" return self._event_listener is not None diff --git a/homeassistant/components/recorder/repack.py b/homeassistant/components/recorder/repack.py index 68d7d5954c9..95df0681ddb 100644 --- a/homeassistant/components/recorder/repack.py +++ b/homeassistant/components/recorder/repack.py @@ -12,15 +12,17 @@ _LOGGER = logging.getLogger(__name__) def repack_database(instance: Recorder) -> None: """Repack based on engine type.""" + assert instance.engine is not None + dialect_name = instance.engine.dialect.name # Execute sqlite command to free up space on disk - if instance.engine.dialect.name == "sqlite": + if dialect_name == "sqlite": _LOGGER.debug("Vacuuming SQL DB to free space") instance.engine.execute("VACUUM") return # Execute postgresql vacuum command to free up space on disk - if instance.engine.dialect.name == "postgresql": + if dialect_name == "postgresql": _LOGGER.debug("Vacuuming SQL DB to free space") with instance.engine.connect().execution_options( isolation_level="AUTOCOMMIT" @@ -29,7 +31,7 @@ def repack_database(instance: Recorder) -> None: return # Optimize mysql / mariadb tables to free up space on disk - if instance.engine.dialect.name == "mysql": + if dialect_name == "mysql": _LOGGER.debug("Optimizing SQL DB to free space") instance.engine.execute("OPTIMIZE TABLE states, events, recorder_runs") return diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index b27a08f489c..50e987a5533 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -1247,19 +1247,19 @@ def _filter_unique_constraint_integrity_error( if not isinstance(err, StatementError): return False + assert instance.engine is not None + dialect_name = instance.engine.dialect.name + ignore = False - if ( - instance.engine.dialect.name == "sqlite" - and "UNIQUE constraint failed" in str(err) - ): + if dialect_name == "sqlite" and "UNIQUE constraint failed" in str(err): ignore = True if ( - instance.engine.dialect.name == "postgresql" + dialect_name == "postgresql" and hasattr(err.orig, "pgcode") and err.orig.pgcode == "23505" ): ignore = True - if instance.engine.dialect.name == "mysql" and hasattr(err.orig, "args"): + if dialect_name == "mysql" and hasattr(err.orig, "args"): with contextlib.suppress(TypeError): if err.orig.args[0] == 1062: ignore = True