From eb77f8db8559dba95e5e36c8a9314f89e1ae82b1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 4 May 2022 12:22:50 -0500 Subject: [PATCH] Complete strict typing for recorder (#71274) * Complete strict typing for recorder * update tests * Update tests/components/recorder/test_migrate.py Co-authored-by: Martin Hjelmare * Update tests/components/recorder/test_migrate.py Co-authored-by: Martin Hjelmare * Remove the asserts * remove ignore comments Co-authored-by: Martin Hjelmare --- .strict-typing | 18 +- homeassistant/components/recorder/core.py | 22 +- .../components/recorder/migration.py | 208 ++++++++++-------- homeassistant/components/recorder/purge.py | 4 +- .../components/recorder/statistics.py | 16 +- homeassistant/components/recorder/tasks.py | 2 - mypy.ini | 178 +-------------- tests/components/recorder/test_init.py | 2 + tests/components/recorder/test_migrate.py | 23 +- tests/components/recorder/test_statistics.py | 2 +- 10 files changed, 166 insertions(+), 309 deletions(-) diff --git a/.strict-typing b/.strict-typing index f42bd4a4ab1..67efdfa7953 100644 --- a/.strict-typing +++ b/.strict-typing @@ -177,23 +177,7 @@ homeassistant.components.pure_energie.* homeassistant.components.rainmachine.* homeassistant.components.rdw.* homeassistant.components.recollect_waste.* -homeassistant.components.recorder -homeassistant.components.recorder.const -homeassistant.components.recorder.core -homeassistant.components.recorder.backup -homeassistant.components.recorder.executor -homeassistant.components.recorder.history -homeassistant.components.recorder.models -homeassistant.components.recorder.pool -homeassistant.components.recorder.purge -homeassistant.components.recorder.repack -homeassistant.components.recorder.run_history -homeassistant.components.recorder.services -homeassistant.components.recorder.statistics -homeassistant.components.recorder.system_health -homeassistant.components.recorder.tasks -homeassistant.components.recorder.util -homeassistant.components.recorder.websocket_api +homeassistant.components.recorder.* homeassistant.components.remote.* homeassistant.components.renault.* homeassistant.components.ridwell.* diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index 84509a1bd53..af368b909b7 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -171,7 +171,7 @@ class Recorder(threading.Thread): self._pending_event_data: dict[str, EventData] = {} self._pending_expunge: list[States] = [] self.event_session: Session | None = None - self.get_session: Callable[[], Session] | None = None + self._get_session: Callable[[], Session] | None = None self._completed_first_database_setup: bool | None = None self.async_migration_event = asyncio.Event() self.migration_in_progress = False @@ -205,6 +205,12 @@ class Recorder(threading.Thread): """Return if the recorder is recording.""" return self._event_listener is not None + def get_session(self) -> Session: + """Get a new sqlalchemy session.""" + if self._get_session is None: + raise RuntimeError("The database connection has not been established") + return self._get_session() + def queue_task(self, task: RecorderTask) -> None: """Add a task to the recorder queue.""" self._queue.put(task) @@ -459,7 +465,7 @@ class Recorder(threading.Thread): @callback def _async_setup_periodic_tasks(self) -> None: """Prepare periodic tasks.""" - if self.hass.is_stopping or not self.get_session: + if self.hass.is_stopping or not self._get_session: # Home Assistant is shutting down return @@ -591,7 +597,7 @@ class Recorder(threading.Thread): while tries <= self.db_max_retries: try: self._setup_connection() - return migration.get_schema_version(self) + return migration.get_schema_version(self.get_session) except Exception as err: # pylint: disable=broad-except _LOGGER.exception( "Error during connection setup: %s (retrying in %s seconds)", @@ -619,7 +625,9 @@ class Recorder(threading.Thread): self.hass.add_job(self._async_migration_started) try: - migration.migrate_schema(self, current_version) + migration.migrate_schema( + self.hass, self.engine, self.get_session, current_version + ) except exc.DatabaseError as err: if self._handle_database_error(err): return True @@ -896,7 +904,6 @@ class Recorder(threading.Thread): def _open_event_session(self) -> None: """Open the event session.""" - assert self.get_session is not None self.event_session = self.get_session() self.event_session.expire_on_commit = False @@ -1011,7 +1018,7 @@ class Recorder(threading.Thread): sqlalchemy_event.listen(self.engine, "connect", setup_recorder_connection) Base.metadata.create_all(self.engine) - self.get_session = scoped_session(sessionmaker(bind=self.engine, future=True)) + self._get_session = scoped_session(sessionmaker(bind=self.engine, future=True)) _LOGGER.debug("Connected to recorder database") def _close_connection(self) -> None: @@ -1019,11 +1026,10 @@ class Recorder(threading.Thread): assert self.engine is not None self.engine.dispose() self.engine = None - self.get_session = None + self._get_session = None def _setup_run(self) -> None: """Log the start of the current run and schedule any needed jobs.""" - assert self.get_session is not None with session_scope(session=self.get_session()) as session: end_incomplete_runs(session, self.run_history.recording_start) self.run_history.start(session) diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index 7835f5320b9..b38bb89b5b9 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -1,11 +1,13 @@ """Schema migration helpers.""" +from collections.abc import Callable, Iterable import contextlib from datetime import timedelta import logging -from typing import Any +from typing import cast import sqlalchemy from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text +from sqlalchemy.engine import Engine from sqlalchemy.exc import ( DatabaseError, InternalError, @@ -13,9 +15,12 @@ from sqlalchemy.exc import ( ProgrammingError, SQLAlchemyError, ) +from sqlalchemy.orm.session import Session from sqlalchemy.schema import AddConstraint, DropConstraint from sqlalchemy.sql.expression import true +from homeassistant.core import HomeAssistant + from .models import ( SCHEMA_VERSION, TABLE_STATES, @@ -33,7 +38,7 @@ from .util import session_scope _LOGGER = logging.getLogger(__name__) -def raise_if_exception_missing_str(ex, match_substrs): +def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str]) -> None: """Raise an exception if the exception and cause do not contain the match substrs.""" lower_ex_strs = [str(ex).lower(), str(ex.__cause__).lower()] for str_sub in match_substrs: @@ -44,10 +49,9 @@ def raise_if_exception_missing_str(ex, match_substrs): raise ex -def get_schema_version(instance: Any) -> int: +def get_schema_version(session_maker: Callable[[], Session]) -> int: """Get the schema version.""" - assert instance.get_session is not None - with session_scope(session=instance.get_session()) as session: + with session_scope(session=session_maker()) as session: res = ( session.query(SchemaChanges) .order_by(SchemaChanges.change_id.desc()) @@ -61,7 +65,7 @@ def get_schema_version(instance: Any) -> int: "No schema version found. Inspected version: %s", current_version ) - return current_version + return cast(int, current_version) def schema_is_current(current_version: int) -> bool: @@ -69,21 +73,27 @@ def schema_is_current(current_version: int) -> bool: return current_version == SCHEMA_VERSION -def migrate_schema(instance: Any, current_version: int) -> None: +def migrate_schema( + hass: HomeAssistant, + engine: Engine, + session_maker: Callable[[], Session], + current_version: int, +) -> None: """Check if the schema needs to be upgraded.""" - assert instance.get_session is not None _LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version) for version in range(current_version, SCHEMA_VERSION): new_version = version + 1 _LOGGER.info("Upgrading recorder db schema to version %s", new_version) - _apply_update(instance, new_version, current_version) - with session_scope(session=instance.get_session()) as session: + _apply_update(hass, engine, session_maker, new_version, current_version) + with session_scope(session=session_maker()) as session: session.add(SchemaChanges(schema_version=new_version)) _LOGGER.info("Upgrade to version %s done", new_version) -def _create_index(instance, table_name, index_name): +def _create_index( + session_maker: Callable[[], Session], table_name: str, index_name: str +) -> None: """Create an index for the specified table. The index name should match the name given for the index @@ -104,7 +114,7 @@ def _create_index(instance, table_name, index_name): "be patient!", index_name, ) - with session_scope(session=instance.get_session()) as session: + with session_scope(session=session_maker()) as session: try: connection = session.connection() index.create(connection) @@ -117,7 +127,9 @@ def _create_index(instance, table_name, index_name): _LOGGER.debug("Finished creating %s", index_name) -def _drop_index(instance, table_name, index_name): +def _drop_index( + session_maker: Callable[[], Session], table_name: str, index_name: str +) -> None: """Drop an index from a specified table. There is no universal way to do something like `DROP INDEX IF EXISTS` @@ -132,7 +144,7 @@ def _drop_index(instance, table_name, index_name): success = False # Engines like DB2/Oracle - with session_scope(session=instance.get_session()) as session: + with session_scope(session=session_maker()) as session: try: connection = session.connection() connection.execute(text(f"DROP INDEX {index_name}")) @@ -143,7 +155,7 @@ def _drop_index(instance, table_name, index_name): # Engines like SQLite, SQL Server if not success: - with session_scope(session=instance.get_session()) as session: + with session_scope(session=session_maker()) as session: try: connection = session.connection() connection.execute( @@ -160,7 +172,7 @@ def _drop_index(instance, table_name, index_name): if not success: # Engines like MySQL, MS Access - with session_scope(session=instance.get_session()) as session: + with session_scope(session=session_maker()) as session: try: connection = session.connection() connection.execute( @@ -194,7 +206,9 @@ def _drop_index(instance, table_name, index_name): ) -def _add_columns(instance, table_name, columns_def): +def _add_columns( + session_maker: Callable[[], Session], table_name: str, columns_def: list[str] +) -> None: """Add columns to a table.""" _LOGGER.warning( "Adding columns %s to table %s. Note: this can take several " @@ -206,7 +220,7 @@ def _add_columns(instance, table_name, columns_def): columns_def = [f"ADD {col_def}" for col_def in columns_def] - with session_scope(session=instance.get_session()) as session: + with session_scope(session=session_maker()) as session: try: connection = session.connection() connection.execute( @@ -223,7 +237,7 @@ def _add_columns(instance, table_name, columns_def): _LOGGER.info("Unable to use quick column add. Adding 1 by 1") for column_def in columns_def: - with session_scope(session=instance.get_session()) as session: + with session_scope(session=session_maker()) as session: try: connection = session.connection() connection.execute( @@ -242,7 +256,12 @@ def _add_columns(instance, table_name, columns_def): ) -def _modify_columns(instance, engine, table_name, columns_def): +def _modify_columns( + session_maker: Callable[[], Session], + engine: Engine, + table_name: str, + columns_def: list[str], +) -> None: """Modify columns in a table.""" if engine.dialect.name == "sqlite": _LOGGER.debug( @@ -274,7 +293,7 @@ def _modify_columns(instance, engine, table_name, columns_def): else: columns_def = [f"MODIFY {col_def}" for col_def in columns_def] - with session_scope(session=instance.get_session()) as session: + with session_scope(session=session_maker()) as session: try: connection = session.connection() connection.execute( @@ -289,7 +308,7 @@ def _modify_columns(instance, engine, table_name, columns_def): _LOGGER.info("Unable to use quick column modify. Modifying 1 by 1") for column_def in columns_def: - with session_scope(session=instance.get_session()) as session: + with session_scope(session=session_maker()) as session: try: connection = session.connection() connection.execute( @@ -305,7 +324,9 @@ def _modify_columns(instance, engine, table_name, columns_def): ) -def _update_states_table_with_foreign_key_options(instance, engine): +def _update_states_table_with_foreign_key_options( + session_maker: Callable[[], Session], engine: Engine +) -> None: """Add the options to foreign key constraints.""" inspector = sqlalchemy.inspect(engine) alters = [] @@ -333,7 +354,7 @@ def _update_states_table_with_foreign_key_options(instance, engine): ) for alter in alters: - with session_scope(session=instance.get_session()) as session: + with session_scope(session=session_maker()) as session: try: connection = session.connection() connection.execute(DropConstraint(alter["old_fk"])) @@ -346,7 +367,9 @@ def _update_states_table_with_foreign_key_options(instance, engine): ) -def _drop_foreign_key_constraints(instance, engine, table, columns): +def _drop_foreign_key_constraints( + session_maker: Callable[[], Session], engine: Engine, table: str, columns: list[str] +) -> None: """Drop foreign key constraints for a table on specific columns.""" inspector = sqlalchemy.inspect(engine) drops = [] @@ -364,7 +387,7 @@ def _drop_foreign_key_constraints(instance, engine, table, columns): ) for drop in drops: - with session_scope(session=instance.get_session()) as session: + with session_scope(session=session_maker()) as session: try: connection = session.connection() connection.execute(DropConstraint(drop)) @@ -376,19 +399,24 @@ def _drop_foreign_key_constraints(instance, engine, table, columns): ) -def _apply_update(instance, new_version, old_version): # noqa: C901 +def _apply_update( # noqa: C901 + hass: HomeAssistant, + engine: Engine, + session_maker: Callable[[], Session], + new_version: int, + old_version: int, +) -> None: """Perform operations to bring schema up to date.""" - engine = instance.engine dialect = engine.dialect.name big_int = "INTEGER(20)" if dialect == "mysql" else "INTEGER" if new_version == 1: - _create_index(instance, "events", "ix_events_time_fired") + _create_index(session_maker, "events", "ix_events_time_fired") elif new_version == 2: # Create compound start/end index for recorder_runs - _create_index(instance, "recorder_runs", "ix_recorder_runs_start_end") + _create_index(session_maker, "recorder_runs", "ix_recorder_runs_start_end") # Create indexes for states - _create_index(instance, "states", "ix_states_last_updated") + _create_index(session_maker, "states", "ix_states_last_updated") elif new_version == 3: # There used to be a new index here, but it was removed in version 4. pass @@ -398,41 +426,41 @@ def _apply_update(instance, new_version, old_version): # noqa: C901 if old_version == 3: # Remove index that was added in version 3 - _drop_index(instance, "states", "ix_states_created_domain") + _drop_index(session_maker, "states", "ix_states_created_domain") if old_version == 2: # Remove index that was added in version 2 - _drop_index(instance, "states", "ix_states_entity_id_created") + _drop_index(session_maker, "states", "ix_states_entity_id_created") # Remove indexes that were added in version 0 - _drop_index(instance, "states", "states__state_changes") - _drop_index(instance, "states", "states__significant_changes") - _drop_index(instance, "states", "ix_states_entity_id_created") + _drop_index(session_maker, "states", "states__state_changes") + _drop_index(session_maker, "states", "states__significant_changes") + _drop_index(session_maker, "states", "ix_states_entity_id_created") - _create_index(instance, "states", "ix_states_entity_id_last_updated") + _create_index(session_maker, "states", "ix_states_entity_id_last_updated") elif new_version == 5: # Create supporting index for States.event_id foreign key - _create_index(instance, "states", "ix_states_event_id") + _create_index(session_maker, "states", "ix_states_event_id") elif new_version == 6: _add_columns( - instance, + session_maker, "events", ["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"], ) - _create_index(instance, "events", "ix_events_context_id") - _create_index(instance, "events", "ix_events_context_user_id") + _create_index(session_maker, "events", "ix_events_context_id") + _create_index(session_maker, "events", "ix_events_context_user_id") _add_columns( - instance, + session_maker, "states", ["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"], ) - _create_index(instance, "states", "ix_states_context_id") - _create_index(instance, "states", "ix_states_context_user_id") + _create_index(session_maker, "states", "ix_states_context_id") + _create_index(session_maker, "states", "ix_states_context_user_id") elif new_version == 7: - _create_index(instance, "states", "ix_states_entity_id") + _create_index(session_maker, "states", "ix_states_entity_id") elif new_version == 8: - _add_columns(instance, "events", ["context_parent_id CHARACTER(36)"]) - _add_columns(instance, "states", ["old_state_id INTEGER"]) - _create_index(instance, "events", "ix_events_context_parent_id") + _add_columns(session_maker, "events", ["context_parent_id CHARACTER(36)"]) + _add_columns(session_maker, "states", ["old_state_id INTEGER"]) + _create_index(session_maker, "events", "ix_events_context_parent_id") elif new_version == 9: # We now get the context from events with a join # since its always there on state_changed events @@ -443,35 +471,35 @@ def _apply_update(instance, new_version, old_version): # noqa: C901 # sqlalchemy alembic to make that work # # no longer dropping ix_states_context_id since its recreated in 28 - _drop_index(instance, "states", "ix_states_context_user_id") + _drop_index(session_maker, "states", "ix_states_context_user_id") # This index won't be there if they were not running # nightly but we don't treat that as a critical issue - _drop_index(instance, "states", "ix_states_context_parent_id") + _drop_index(session_maker, "states", "ix_states_context_parent_id") # Redundant keys on composite index: # We already have ix_states_entity_id_last_updated - _drop_index(instance, "states", "ix_states_entity_id") - _create_index(instance, "events", "ix_events_event_type_time_fired") - _drop_index(instance, "events", "ix_events_event_type") + _drop_index(session_maker, "states", "ix_states_entity_id") + _create_index(session_maker, "events", "ix_events_event_type_time_fired") + _drop_index(session_maker, "events", "ix_events_event_type") elif new_version == 10: # Now done in step 11 pass elif new_version == 11: - _create_index(instance, "states", "ix_states_old_state_id") - _update_states_table_with_foreign_key_options(instance, engine) + _create_index(session_maker, "states", "ix_states_old_state_id") + _update_states_table_with_foreign_key_options(session_maker, engine) elif new_version == 12: if engine.dialect.name == "mysql": - _modify_columns(instance, engine, "events", ["event_data LONGTEXT"]) - _modify_columns(instance, engine, "states", ["attributes LONGTEXT"]) + _modify_columns(session_maker, engine, "events", ["event_data LONGTEXT"]) + _modify_columns(session_maker, engine, "states", ["attributes LONGTEXT"]) elif new_version == 13: if engine.dialect.name == "mysql": _modify_columns( - instance, + session_maker, engine, "events", ["time_fired DATETIME(6)", "created DATETIME(6)"], ) _modify_columns( - instance, + session_maker, engine, "states", [ @@ -481,12 +509,14 @@ def _apply_update(instance, new_version, old_version): # noqa: C901 ], ) elif new_version == 14: - _modify_columns(instance, engine, "events", ["event_type VARCHAR(64)"]) + _modify_columns(session_maker, engine, "events", ["event_type VARCHAR(64)"]) elif new_version == 15: # This dropped the statistics table, done again in version 18. pass elif new_version == 16: - _drop_foreign_key_constraints(instance, engine, TABLE_STATES, ["old_state_id"]) + _drop_foreign_key_constraints( + session_maker, engine, TABLE_STATES, ["old_state_id"] + ) elif new_version == 17: # This dropped the statistics table, done again in version 18. pass @@ -511,13 +541,13 @@ def _apply_update(instance, new_version, old_version): # noqa: C901 elif new_version == 19: # This adds the statistic runs table, insert a fake run to prevent duplicating # statistics. - with session_scope(session=instance.get_session()) as session: + with session_scope(session=session_maker()) as session: session.add(StatisticsRuns(start=get_start_time())) elif new_version == 20: # This changed the precision of statistics from float to double if engine.dialect.name in ["mysql", "postgresql"]: _modify_columns( - instance, + session_maker, engine, "statistics", [ @@ -539,7 +569,7 @@ def _apply_update(instance, new_version, old_version): # noqa: C901 table, ) with contextlib.suppress(SQLAlchemyError): - with session_scope(session=instance.get_session()) as session: + with session_scope(session=session_maker()) as session: connection = session.connection() connection.execute( # Using LOCK=EXCLUSIVE to prevent the database from corrupting @@ -574,7 +604,7 @@ def _apply_update(instance, new_version, old_version): # noqa: C901 # Block 5-minute statistics for one hour from the last run, or it will overlap # with existing hourly statistics. Don't block on a database with no existing # statistics. - with session_scope(session=instance.get_session()) as session: + with session_scope(session=session_maker()) as session: if session.query(Statistics.id).count() and ( last_run_string := session.query( func.max(StatisticsRuns.start) @@ -590,7 +620,7 @@ def _apply_update(instance, new_version, old_version): # noqa: C901 # When querying the database, be careful to only explicitly query for columns # which were present in schema version 21. If querying the table, SQLAlchemy # will refer to future columns. - with session_scope(session=instance.get_session()) as session: + with session_scope(session=session_maker()) as session: for sum_statistic in session.query(StatisticsMeta.id).filter_by( has_sum=true() ): @@ -617,48 +647,52 @@ def _apply_update(instance, new_version, old_version): # noqa: C901 ) elif new_version == 23: # Add name column to StatisticsMeta - _add_columns(instance, "statistics_meta", ["name VARCHAR(255)"]) + _add_columns(session_maker, "statistics_meta", ["name VARCHAR(255)"]) elif new_version == 24: # Recreate statistics indices to block duplicated statistics - _drop_index(instance, "statistics", "ix_statistics_statistic_id_start") + _drop_index(session_maker, "statistics", "ix_statistics_statistic_id_start") _drop_index( - instance, + session_maker, "statistics_short_term", "ix_statistics_short_term_statistic_id_start", ) try: - _create_index(instance, "statistics", "ix_statistics_statistic_id_start") _create_index( - instance, + session_maker, "statistics", "ix_statistics_statistic_id_start" + ) + _create_index( + session_maker, "statistics_short_term", "ix_statistics_short_term_statistic_id_start", ) except DatabaseError: # There may be duplicated statistics entries, delete duplicated statistics # and try again - with session_scope(session=instance.get_session()) as session: - delete_duplicates(instance, session) - _create_index(instance, "statistics", "ix_statistics_statistic_id_start") + with session_scope(session=session_maker()) as session: + delete_duplicates(hass, session) _create_index( - instance, + session_maker, "statistics", "ix_statistics_statistic_id_start" + ) + _create_index( + session_maker, "statistics_short_term", "ix_statistics_short_term_statistic_id_start", ) elif new_version == 25: - _add_columns(instance, "states", [f"attributes_id {big_int}"]) - _create_index(instance, "states", "ix_states_attributes_id") + _add_columns(session_maker, "states", [f"attributes_id {big_int}"]) + _create_index(session_maker, "states", "ix_states_attributes_id") elif new_version == 26: - _create_index(instance, "statistics_runs", "ix_statistics_runs_start") + _create_index(session_maker, "statistics_runs", "ix_statistics_runs_start") elif new_version == 27: - _add_columns(instance, "events", [f"data_id {big_int}"]) - _create_index(instance, "events", "ix_events_data_id") + _add_columns(session_maker, "events", [f"data_id {big_int}"]) + _create_index(session_maker, "events", "ix_events_data_id") elif new_version == 28: - _add_columns(instance, "events", ["origin_idx INTEGER"]) + _add_columns(session_maker, "events", ["origin_idx INTEGER"]) # We never use the user_id or parent_id index - _drop_index(instance, "events", "ix_events_context_user_id") - _drop_index(instance, "events", "ix_events_context_parent_id") + _drop_index(session_maker, "events", "ix_events_context_user_id") + _drop_index(session_maker, "events", "ix_events_context_parent_id") _add_columns( - instance, + session_maker, "states", [ "origin_idx INTEGER", @@ -667,14 +701,14 @@ def _apply_update(instance, new_version, old_version): # noqa: C901 "context_parent_id VARCHAR(36)", ], ) - _create_index(instance, "states", "ix_states_context_id") + _create_index(session_maker, "states", "ix_states_context_id") # Once there are no longer any state_changed events # in the events table we can drop the index on states.event_id else: raise ValueError(f"No schema migration defined for version {new_version}") -def _inspect_schema_version(session): +def _inspect_schema_version(session: Session) -> int: """Determine the schema version by inspecting the db structure. When the schema version is not present in the db, either db was just @@ -696,4 +730,4 @@ def _inspect_schema_version(session): # Version 1 schema changes not found, this db needs to be migrated. current_version = SchemaChanges(schema_version=0) session.add(current_version) - return current_version.schema_version + return cast(int, current_version.schema_version) diff --git a/homeassistant/components/recorder/purge.py b/homeassistant/components/recorder/purge.py index 3a0e2e6e141..b2547d13e45 100644 --- a/homeassistant/components/recorder/purge.py +++ b/homeassistant/components/recorder/purge.py @@ -47,7 +47,7 @@ def purge_old_data( ) using_sqlite = instance.using_sqlite() - with session_scope(session=instance.get_session()) as session: # type: ignore[misc] + with session_scope(session=instance.get_session()) as session: # Purge a max of MAX_ROWS_TO_PURGE, based on the oldest states or events record ( event_ids, @@ -515,7 +515,7 @@ def _purge_filtered_events( def purge_entity_data(instance: Recorder, entity_filter: Callable[[str], bool]) -> bool: """Purge states and events of specified entities.""" using_sqlite = instance.using_sqlite() - with session_scope(session=instance.get_session()) as session: # type: ignore[misc] + with session_scope(session=instance.get_session()) as session: selected_entity_ids: list[str] = [ entity_id for (entity_id,) in session.query(distinct(States.entity_id)).all() diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index 1c993b32bb6..9104fb7e234 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -377,7 +377,7 @@ def _delete_duplicates_from_table( return (total_deleted_rows, all_non_identical_duplicates) -def delete_duplicates(instance: Recorder, session: Session) -> None: +def delete_duplicates(hass: HomeAssistant, session: Session) -> None: """Identify and delete duplicated statistics. A backup will be made of duplicated statistics before it is deleted. @@ -391,7 +391,7 @@ def delete_duplicates(instance: Recorder, session: Session) -> None: if non_identical_duplicates: isotime = dt_util.utcnow().isoformat() backup_file_name = f"deleted_statistics.{isotime}.json" - backup_path = instance.hass.config.path(STORAGE_DIR, backup_file_name) + backup_path = hass.config.path(STORAGE_DIR, backup_file_name) os.makedirs(os.path.dirname(backup_path), exist_ok=True) with open(backup_path, "w", encoding="utf8") as backup_file: @@ -551,7 +551,7 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool: end = start + timedelta(minutes=5) # Return if we already have 5-minute statistics for the requested period - with session_scope(session=instance.get_session()) as session: # type: ignore[misc] + with session_scope(session=instance.get_session()) as session: if session.query(StatisticsRuns).filter_by(start=start).first(): _LOGGER.debug("Statistics already compiled for %s-%s", start, end) return True @@ -578,7 +578,7 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool: # Insert collected statistics in the database with session_scope( - session=instance.get_session(), # type: ignore[misc] + session=instance.get_session(), exception_filter=_filter_unique_constraint_integrity_error(instance), ) as session: for stats in platform_stats: @@ -768,7 +768,7 @@ def _configured_unit(unit: str | None, units: UnitSystem) -> str | None: def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None: """Clear statistics for a list of statistic_ids.""" - with session_scope(session=instance.get_session()) as session: # type: ignore[misc] + with session_scope(session=instance.get_session()) as session: session.query(StatisticsMeta).filter( StatisticsMeta.statistic_id.in_(statistic_ids) ).delete(synchronize_session=False) @@ -778,7 +778,7 @@ def update_statistics_metadata( instance: Recorder, statistic_id: str, unit_of_measurement: str | None ) -> None: """Update statistics metadata for a statistic_id.""" - with session_scope(session=instance.get_session()) as session: # type: ignore[misc] + with session_scope(session=instance.get_session()) as session: session.query(StatisticsMeta).filter( StatisticsMeta.statistic_id == statistic_id ).update({StatisticsMeta.unit_of_measurement: unit_of_measurement}) @@ -1376,7 +1376,7 @@ def add_external_statistics( """Process an add_external_statistics job.""" with session_scope( - session=instance.get_session(), # type: ignore[misc] + session=instance.get_session(), exception_filter=_filter_unique_constraint_integrity_error(instance), ) as session: old_metadata_dict = get_metadata_with_session( @@ -1403,7 +1403,7 @@ def adjust_statistics( ) -> bool: """Process an add_statistics job.""" - with session_scope(session=instance.get_session()) as session: # type: ignore[misc] + with session_scope(session=instance.get_session()) as session: metadata = get_metadata_with_session( instance.hass, session, statistic_ids=(statistic_id,) ) diff --git a/homeassistant/components/recorder/tasks.py b/homeassistant/components/recorder/tasks.py index bed49e36f16..e12526b316a 100644 --- a/homeassistant/components/recorder/tasks.py +++ b/homeassistant/components/recorder/tasks.py @@ -65,8 +65,6 @@ class PurgeTask(RecorderTask): def run(self, instance: Recorder) -> None: """Purge the database.""" - assert instance.get_session is not None - if purge.purge_old_data( instance, self.purge_before, self.repack, self.apply_filter ): diff --git a/mypy.ini b/mypy.ini index 3d97a716955..81677b8d8ff 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1710,183 +1710,7 @@ no_implicit_optional = true warn_return_any = true warn_unreachable = true -[mypy-homeassistant.components.recorder] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true -warn_unreachable = true - -[mypy-homeassistant.components.recorder.const] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true -warn_unreachable = true - -[mypy-homeassistant.components.recorder.core] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true -warn_unreachable = true - -[mypy-homeassistant.components.recorder.backup] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true -warn_unreachable = true - -[mypy-homeassistant.components.recorder.executor] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true -warn_unreachable = true - -[mypy-homeassistant.components.recorder.history] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true -warn_unreachable = true - -[mypy-homeassistant.components.recorder.models] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true -warn_unreachable = true - -[mypy-homeassistant.components.recorder.pool] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true -warn_unreachable = true - -[mypy-homeassistant.components.recorder.purge] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true -warn_unreachable = true - -[mypy-homeassistant.components.recorder.repack] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true -warn_unreachable = true - -[mypy-homeassistant.components.recorder.run_history] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true -warn_unreachable = true - -[mypy-homeassistant.components.recorder.services] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true -warn_unreachable = true - -[mypy-homeassistant.components.recorder.statistics] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true -warn_unreachable = true - -[mypy-homeassistant.components.recorder.system_health] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true -warn_unreachable = true - -[mypy-homeassistant.components.recorder.tasks] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true -warn_unreachable = true - -[mypy-homeassistant.components.recorder.util] -check_untyped_defs = true -disallow_incomplete_defs = true -disallow_subclassing_any = true -disallow_untyped_calls = true -disallow_untyped_decorators = true -disallow_untyped_defs = true -no_implicit_optional = true -warn_return_any = true -warn_unreachable = true - -[mypy-homeassistant.components.recorder.websocket_api] +[mypy-homeassistant.components.recorder.*] check_untyped_defs = true disallow_incomplete_defs = true disallow_subclassing_any = true diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 18b74df0189..17287151bc1 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -138,6 +138,8 @@ async def test_shutdown_closes_connections(hass, recorder_mock): await hass.async_block_till_done() assert len(pool.shutdown.mock_calls) == 1 + with pytest.raises(RuntimeError): + assert instance.get_session() async def test_state_gets_saved_when_set_before_start_event( diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py index 0a95d174d66..fcc35938088 100644 --- a/tests/components/recorder/test_migrate.py +++ b/tests/components/recorder/test_migrate.py @@ -60,9 +60,12 @@ async def test_schema_update_calls(hass): await async_wait_recording_done(hass) assert recorder.util.async_migration_in_progress(hass) is False + instance = recorder.get_instance(hass) + engine = instance.engine + session_maker = instance.get_session update.assert_has_calls( [ - call(hass.data[DATA_INSTANCE], version + 1, 0) + call(hass, engine, session_maker, version + 1, 0) for version in range(0, models.SCHEMA_VERSION) ] ) @@ -327,10 +330,10 @@ async def test_schema_migrate(hass, start_version): assert recorder.util.async_migration_in_progress(hass) is not True -def test_invalid_update(): +def test_invalid_update(hass): """Test that an invalid new version raises an exception.""" with pytest.raises(ValueError): - migration._apply_update(Mock(), -1, 0) + migration._apply_update(hass, Mock(), Mock(), -1, 0) @pytest.mark.parametrize( @@ -351,7 +354,9 @@ def test_modify_column(engine_type, substr): instance.get_session = Mock(return_value=session) engine = Mock() engine.dialect.name = engine_type - migration._modify_columns(instance, engine, "events", ["event_type VARCHAR(64)"]) + migration._modify_columns( + instance.get_session, engine, "events", ["event_type VARCHAR(64)"] + ) if substr: assert substr in connection.execute.call_args[0][0].text else: @@ -365,8 +370,12 @@ def test_forgiving_add_column(): session.execute(text("CREATE TABLE hello (id int)")) instance = Mock() instance.get_session = Mock(return_value=session) - migration._add_columns(instance, "hello", ["context_id CHARACTER(36)"]) - migration._add_columns(instance, "hello", ["context_id CHARACTER(36)"]) + migration._add_columns( + instance.get_session, "hello", ["context_id CHARACTER(36)"] + ) + migration._add_columns( + instance.get_session, "hello", ["context_id CHARACTER(36)"] + ) def test_forgiving_add_index(): @@ -376,7 +385,7 @@ def test_forgiving_add_index(): with Session(engine) as session: instance = Mock() instance.get_session = Mock(return_value=session) - migration._create_index(instance, "states", "ix_states_context_id") + migration._create_index(instance.get_session, "states", "ix_states_context_id") @pytest.mark.parametrize( diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py index dc13f2abb6a..765364a7487 100644 --- a/tests/components/recorder/test_statistics.py +++ b/tests/components/recorder/test_statistics.py @@ -740,7 +740,7 @@ def test_delete_duplicates_no_duplicates(hass_recorder, caplog): hass = hass_recorder() wait_recording_done(hass) with session_scope(hass=hass) as session: - delete_duplicates(hass.data[DATA_INSTANCE], session) + delete_duplicates(hass, session) assert "duplicated statistics rows" not in caplog.text assert "Found non identical" not in caplog.text assert "Found duplicated" not in caplog.text