From f869ce9d062f42658a12239e55793d044786c955 Mon Sep 17 00:00:00 2001
From: Erik Montnemery
Date: Tue, 29 Nov 2022 10:16:08 +0100
Subject: [PATCH] Validate common statistics DB schema errors on start (#79707)

* Validate common statistics db schema errors on start

* Fix test

* Add tests

* Adjust tests

* Disable statistics schema validation in tests

* Update after rebase
---
 homeassistant/components/recorder/core.py    |  12 +-
 .../components/recorder/migration.py         |  43 +-
 .../components/recorder/statistics.py        | 406 +++++++++++++++---
 tests/components/recorder/test_statistics.py | 197 ++++++++-
 tests/conftest.py                            |  31 ++
 5 files changed, 602 insertions(+), 87 deletions(-)

diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py
index 61b75783783..a79724f765a 100644
--- a/homeassistant/components/recorder/core.py
+++ b/homeassistant/components/recorder/core.py
@@ -591,16 +591,14 @@ class Recorder(threading.Thread):
             self.hass.add_job(self.async_connection_failed)
             return
 
-        schema_status = migration.validate_db_schema(self.hass, self.get_session)
+        schema_status = migration.validate_db_schema(self.hass, self, self.get_session)
         if schema_status is None:
             # Give up if we could not validate the schema
             self.hass.add_job(self.async_connection_failed)
             return
         self.schema_version = schema_status.current_version
 
-        schema_is_valid = migration.schema_is_valid(schema_status)
-
-        if schema_is_valid:
+        if schema_status.valid:
             self._setup_run()
         else:
             self.migration_in_progress = True
@@ -608,8 +606,8 @@ class Recorder(threading.Thread):
 
         self.hass.add_job(self.async_connection_success)
 
-        if self.migration_is_live or schema_is_valid:
-            # If the migrate is live or the schema is current, we need to
+        if self.migration_is_live or schema_status.valid:
+            # If the migration is live or the schema is valid, we need to
             # wait for startup to complete. If its not live, we need to continue
             # on.
             self.hass.add_job(self.async_set_db_ready)
@@ -626,7 +624,7 @@ class Recorder(threading.Thread):
                 self.hass.add_job(self.async_set_db_ready)
                 return
 
-        if not schema_is_valid:
+        if not schema_status.valid:
             if self._migrate_schema_and_setup_run(schema_status):
                 self.schema_version = SCHEMA_VERSION
                 if not self._event_listener:
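A note on the core.py hunks above: the boolean helper schema_is_valid() is removed, and validity now travels with the status object itself. Below is a minimal, self-contained sketch of the new control flow, assuming only what the patch shows — the dataclass fields come from migration.py further down, the version number is illustrative, and the Recorder internals are stubbed with prints:

    from dataclasses import dataclass


    @dataclass
    class SchemaValidationStatus:
        """Shape of the status object added in migration.py (see below)."""

        current_version: int
        statistics_schema_errors: set[str]
        valid: bool


    def run(schema_status: SchemaValidationStatus) -> None:
        # "valid" is no longer just "version is current": a current schema
        # that has statistics schema errors is reported as not valid, which
        # routes startup through the migration path where the errors are
        # corrected.
        if schema_status.valid:
            print("schema OK - normal startup")
        else:
            print("migration path - will also correct statistics schema errors")


    run(SchemaValidationStatus(30, {"statistics_meta.4-byte UTF-8"}, valid=False))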
diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py
index 22a3b382c7d..cf56c39a885 100644
--- a/homeassistant/components/recorder/migration.py
+++ b/homeassistant/components/recorder/migration.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from collections.abc import Callable, Iterable
 import contextlib
-from dataclasses import dataclass
+from dataclasses import dataclass, replace as dataclass_replace
 from datetime import timedelta
 import logging
 from typing import TYPE_CHECKING
@@ -37,9 +37,11 @@ from .db_schema import (
 )
 from .models import process_timestamp
 from .statistics import (
+    correct_db_schema as statistics_correct_db_schema,
     delete_statistics_duplicates,
     delete_statistics_meta_duplicates,
     get_start_time,
+    validate_db_schema as statistics_validate_db_schema,
 )
 from .util import session_scope
@@ -83,6 +85,8 @@ class SchemaValidationStatus:
     """Store schema validation status."""
 
     current_version: int
+    statistics_schema_errors: set[str]
+    valid: bool
@@ -90,13 +94,8 @@ def _schema_is_current(current_version: int) -> bool:
     """Check if the schema is current."""
     return current_version == SCHEMA_VERSION
 
 
-def schema_is_valid(schema_status: SchemaValidationStatus) -> bool:
-    """Check if the schema is valid."""
-    return _schema_is_current(schema_status.current_version)
-
-
 def validate_db_schema(
-    hass: HomeAssistant, session_maker: Callable[[], Session]
+    hass: HomeAssistant, engine: Engine, session_maker: Callable[[], Session]
 ) -> SchemaValidationStatus | None:
     """Check if the schema is valid.
@@ -104,11 +103,20 @@ def validate_db_schema(
     errors caused by manual migration between database engines, for example
     importing an SQLite database to MariaDB.
     """
+    schema_errors: set[str] = set()
+
     current_version = get_schema_version(session_maker)
     if current_version is None:
         return None
 
-    return SchemaValidationStatus(current_version)
+    if is_current := _schema_is_current(current_version):
+        # We can only check for further errors if the schema is current, because
+        # columns may otherwise not exist, etc.
+        schema_errors |= statistics_validate_db_schema(hass, engine, session_maker)
+
+    valid = is_current and not schema_errors
+
+    return SchemaValidationStatus(current_version, schema_errors, valid)
 
 
 def live_migration(schema_status: SchemaValidationStatus) -> bool:
@@ -125,10 +133,18 @@ def migrate_schema(
 ) -> None:
     """Check if the schema needs to be upgraded."""
     current_version = schema_status.current_version
-    _LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
+    if current_version != SCHEMA_VERSION:
+        _LOGGER.warning(
+            "Database is about to upgrade from schema version: %s to: %s",
+            current_version,
+            SCHEMA_VERSION,
+        )
     db_ready = False
     for version in range(current_version, SCHEMA_VERSION):
-        if live_migration(SchemaValidationStatus(version)) and not db_ready:
+        if (
+            live_migration(dataclass_replace(schema_status, current_version=version))
+            and not db_ready
+        ):
             db_ready = True
             instance.migration_is_live = True
             hass.add_job(instance.async_set_db_ready)
@@ -140,6 +156,13 @@ def migrate_schema(
 
         _LOGGER.info("Upgrade to version %s done", new_version)
 
+    if schema_errors := schema_status.statistics_schema_errors:
+        _LOGGER.warning(
+            "Database is about to correct DB schema errors: %s",
+            ", ".join(sorted(schema_errors)),
+        )
+        statistics_correct_db_schema(engine, session_maker, schema_errors)
+
 
 def _create_index(
     session_maker: Callable[[], Session], table_name: str, index_name: str
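Why dataclasses.replace in the loop above: SchemaValidationStatus now carries three fields, so migrate_schema can no longer construct a bare SchemaValidationStatus(version) to probe live_migration() per intermediate version. Copying the real status with one field swapped keeps the other fields intact. A self-contained sketch of the pattern (Status is a stand-in, not the real class):

    from dataclasses import dataclass, replace as dataclass_replace


    @dataclass
    class Status:
        """Stand-in for SchemaValidationStatus; not the real class."""

        current_version: int
        statistics_schema_errors: set[str]
        valid: bool


    status = Status(current_version=25, statistics_schema_errors=set(), valid=False)
    for version in range(status.current_version, 30):
        # replace() copies the dataclass with one field swapped, so every
        # other field (e.g. the collected schema errors) is carried along
        # and the original status object is never mutated.
        probe = dataclass_replace(status, current_version=version)
        assert probe is not status
        assert probe.statistics_schema_errors is status.statistics_schema_errors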
diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py
index 6c117b9698d..303a925f9a0 100644
--- a/homeassistant/components/recorder/statistics.py
+++ b/homeassistant/components/recorder/statistics.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 from collections import defaultdict
-from collections.abc import Callable, Iterable
+from collections.abc import Callable, Iterable, Mapping
 import contextlib
 import dataclasses
 from datetime import datetime, timedelta
@@ -15,9 +15,10 @@ import re
 from statistics import mean
 from typing import TYPE_CHECKING, Any, Literal
 
-from sqlalchemy import bindparam, func, lambda_stmt, select
+from sqlalchemy import bindparam, func, lambda_stmt, select, text
+from sqlalchemy.engine import Engine
 from sqlalchemy.engine.row import Row
-from sqlalchemy.exc import SQLAlchemyError, StatementError
+from sqlalchemy.exc import OperationalError, SQLAlchemyError, StatementError
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import literal_column, true
 from sqlalchemy.sql.lambdas import StatementLambdaElement
@@ -874,12 +875,17 @@ def get_metadata(
     )
 
 
+def _clear_statistics_with_session(session: Session, statistic_ids: list[str]) -> None:
+    """Clear statistics for a list of statistic_ids."""
+    session.query(StatisticsMeta).filter(
+        StatisticsMeta.statistic_id.in_(statistic_ids)
+    ).delete(synchronize_session=False)
+
+
 def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None:
     """Clear statistics for a list of statistic_ids."""
     with session_scope(session=instance.get_session()) as session:
-        session.query(StatisticsMeta).filter(
-            StatisticsMeta.statistic_id.in_(statistic_ids)
-        ).delete(synchronize_session=False)
+        _clear_statistics_with_session(session, statistic_ids)
 
 
 def update_statistics_metadata(
@@ -1562,6 +1568,78 @@ def statistic_during_period(
     return {key: convert(value) for key, value in result.items()}
 
 
+def _statistics_during_period_with_session(
+    hass: HomeAssistant,
+    session: Session,
+    start_time: datetime,
+    end_time: datetime | None,
+    statistic_ids: list[str] | None,
+    period: Literal["5minute", "day", "hour", "week", "month"],
+    units: dict[str, str] | None,
+    types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
+) -> dict[str, list[dict[str, Any]]]:
+    """Return statistic data points during UTC period start_time - end_time.
+
+    If end_time is omitted, returns statistics newer than or equal to start_time.
+    If statistic_ids is omitted, returns statistics for all statistics ids.
+    """
+    metadata = None
+    # Fetch metadata for the given (or all) statistic_ids
+    metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
+    if not metadata:
+        return {}
+
+    metadata_ids = None
+    if statistic_ids is not None:
+        metadata_ids = [metadata_id for metadata_id, _ in metadata.values()]
+
+    table: type[Statistics | StatisticsShortTerm] = (
+        Statistics if period != "5minute" else StatisticsShortTerm
+    )
+    stmt = _statistics_during_period_stmt(
+        start_time, end_time, metadata_ids, table, types
+    )
+    stats = execute_stmt_lambda_element(session, stmt)
+
+    if not stats:
+        return {}
+    # Return statistics combined with metadata
+    if period not in ("day", "week", "month"):
+        return _sorted_statistics_to_dict(
+            hass,
+            session,
+            stats,
+            statistic_ids,
+            metadata,
+            True,
+            table,
+            start_time,
+            units,
+            types,
+        )
+
+    result = _sorted_statistics_to_dict(
+        hass,
+        session,
+        stats,
+        statistic_ids,
+        metadata,
+        True,
+        table,
+        start_time,
+        units,
+        types,
+    )
+
+    if period == "day":
+        return _reduce_statistics_per_day(result, types)
+
+    if period == "week":
+        return _reduce_statistics_per_week(result, types)
+
+    return _reduce_statistics_per_month(result, types)
+
+
 def statistics_during_period(
     hass: HomeAssistant,
     start_time: datetime,
@@ -1576,63 +1654,18 @@
     If end_time is omitted, returns statistics newer than or equal to start_time.
     If statistic_ids is omitted, returns statistics for all statistics ids.
     """
-    metadata = None
     with session_scope(hass=hass) as session:
-        # Fetch metadata for the given (or all) statistic_ids
-        metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
-        if not metadata:
-            return {}
-
-        metadata_ids = None
-        if statistic_ids is not None:
-            metadata_ids = [metadata_id for metadata_id, _ in metadata.values()]
-
-        table: type[Statistics | StatisticsShortTerm] = (
-            Statistics if period != "5minute" else StatisticsShortTerm
-        )
-        stmt = _statistics_during_period_stmt(
-            start_time, end_time, metadata_ids, table, types
-        )
-        stats = execute_stmt_lambda_element(session, stmt)
-
-        if not stats:
-            return {}
-        # Return statistics combined with metadata
-        if period not in ("day", "week", "month"):
-            return _sorted_statistics_to_dict(
-                hass,
-                session,
-                stats,
-                statistic_ids,
-                metadata,
-                True,
-                table,
-                start_time,
-                units,
-                types,
-            )
-
-        result = _sorted_statistics_to_dict(
+        return _statistics_during_period_with_session(
             hass,
             session,
-            stats,
-            statistic_ids,
-            metadata,
-            True,
-            table,
             start_time,
+            end_time,
+            statistic_ids,
+            period,
             units,
             types,
         )
 
-        if period == "day":
-            return _reduce_statistics_per_day(result, types)
-
-        if period == "week":
-            return _reduce_statistics_per_week(result, types)
-
-        return _reduce_statistics_per_month(result, types)
-
 
 def _get_last_statistics_stmt(
     metadata_id: int,
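The refactor above is the enabling move for the rest of the patch: statistics_during_period keeps its public signature but becomes a thin wrapper that opens the session and delegates to a *_with_session worker, so the schema-validation code added below can run the same query inside a session it already holds. A generic sketch of the pattern, with illustrative names rather than the real recorder API:

    from contextlib import contextmanager


    @contextmanager
    def session_scope():
        """Toy stand-in for recorder.util.session_scope."""
        session = object()  # a real SQLAlchemy Session in the actual code
        yield session


    def _query_with_session(session, statistic_id: str) -> list[dict]:
        """Worker: runs against a caller-supplied session."""
        return [{"statistic_id": statistic_id, "session": session}]


    def query(statistic_id: str) -> list[dict]:
        """Public API: opens its own session, then delegates to the worker."""
        with session_scope() as session:
            return _query_with_session(session, statistic_id)

Validation code can then call the worker directly with its probe session, keeping the probe rows and their cleanup inside a single transaction.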
@@ -2047,6 +2080,26 @@ def _filter_unique_constraint_integrity_error(
     return _filter_unique_constraint_integrity_error
 
 
+def _import_statistics_with_session(
+    session: Session,
+    metadata: StatisticMetaData,
+    statistics: Iterable[StatisticData],
+    table: type[Statistics | StatisticsShortTerm],
+) -> bool:
+    """Import statistics to the database."""
+    old_metadata_dict = get_metadata_with_session(
+        session, statistic_ids=[metadata["statistic_id"]]
+    )
+    metadata_id = _update_or_add_metadata(session, metadata, old_metadata_dict)
+    for stat in statistics:
+        if stat_id := _statistics_exists(session, table, metadata_id, stat["start"]):
+            _update_statistics(session, table, stat_id, stat)
+        else:
+            _insert_statistics(session, table, metadata_id, stat)
+
+    return True
+
+
 @retryable_database_job("statistics")
 def import_statistics(
     instance: Recorder,
@@ -2060,19 +2113,7 @@ def import_statistics(
         session=instance.get_session(),
         exception_filter=_filter_unique_constraint_integrity_error(instance),
     ) as session:
-        old_metadata_dict = get_metadata_with_session(
-            session, statistic_ids=[metadata["statistic_id"]]
-        )
-        metadata_id = _update_or_add_metadata(session, metadata, old_metadata_dict)
-        for stat in statistics:
-            if stat_id := _statistics_exists(
-                session, table, metadata_id, stat["start"]
-            ):
-                _update_statistics(session, table, stat_id, stat)
-            else:
-                _insert_statistics(session, table, metadata_id, stat)
-
-        return True
+        return _import_statistics_with_session(session, metadata, statistics, table)
 
 
 @retryable_database_job("adjust_statistics")
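The validation hunk that follows probes the schema by round-tripping two sentinel values: 1.000000000000001, which survives a 64-bit double but collapses to 1.0 in a 32-bit float, and a timestamp with microsecond=1, which is lost if datetime columns have only second precision. A quick way to convince yourself of the float half — standard library only, not part of the patch:

    import struct

    precise_number = 1.000000000000001

    # Round-trip through a 32-bit float ("f"): the 15th decimal is lost, so
    # a REAL/FLOAT column would store plain 1.0 and the schema check below
    # would see stored != expected.
    as_float32 = struct.unpack("f", struct.pack("f", precise_number))[0]
    assert as_float32 == 1.0

    # A 64-bit double ("d") keeps the value intact, which is why the fix
    # migrates the columns to DOUBLE PRECISION.
    as_float64 = struct.unpack("d", struct.pack("d", precise_number))[0]
    assert as_float64 == precise_number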
@@ -2189,3 +2230,232 @@ async_change_statistics_unit(
         new_unit_of_measurement=new_unit_of_measurement,
         old_unit_of_measurement=old_unit_of_measurement,
     )
+
+
+def _validate_db_schema_utf8(
+    instance: Recorder, session_maker: Callable[[], Session]
+) -> set[str]:
+    """Do some basic checks for common schema errors caused by manual migration."""
+    schema_errors: set[str] = set()
+
+    # Lack of full utf8 support is only an issue for MySQL / MariaDB
+    if instance.dialect_name != SupportedDialect.MYSQL:
+        return schema_errors
+
+    # This name can't be represented unless 4-byte UTF-8 unicode is supported
+    utf8_name = "𓆚𓃗"
+    statistic_id = f"{DOMAIN}.db_test"
+
+    metadata: StatisticMetaData = {
+        "has_mean": True,
+        "has_sum": True,
+        "name": utf8_name,
+        "source": DOMAIN,
+        "statistic_id": statistic_id,
+        "unit_of_measurement": None,
+    }
+
+    # Try inserting some metadata which needs utf8mb4 support
+    try:
+        with session_scope(session=session_maker()) as session:
+            old_metadata_dict = get_metadata_with_session(
+                session, statistic_ids=[statistic_id]
+            )
+            try:
+                _update_or_add_metadata(session, metadata, old_metadata_dict)
+                _clear_statistics_with_session(session, statistic_ids=[statistic_id])
+            except OperationalError as err:
+                if err.orig and err.orig.args[0] == 1366:
+                    _LOGGER.debug(
+                        "Database table statistics_meta does not support 4-byte UTF-8"
+                    )
+                    schema_errors.add("statistics_meta.4-byte UTF-8")
+                    session.rollback()
+                else:
+                    raise
+    except Exception as exc:  # pylint: disable=broad-except
+        _LOGGER.exception("Error when validating DB schema: %s", exc)
+
+    return schema_errors
+
+
+def _validate_db_schema(
+    hass: HomeAssistant, instance: Recorder, session_maker: Callable[[], Session]
+) -> set[str]:
+    """Do some basic checks for common schema errors caused by manual migration."""
+    schema_errors: set[str] = set()
+
+    # Wrong precision is only an issue for MySQL / MariaDB / PostgreSQL
+    if instance.dialect_name not in (
+        SupportedDialect.MYSQL,
+        SupportedDialect.POSTGRESQL,
+    ):
+        return schema_errors
+
+    # This number can't be accurately represented as a 32-bit float
+    precise_number = 1.000000000000001
+    # This time can't be accurately represented unless datetimes have µs precision
+    precise_time = datetime(2020, 10, 6, microsecond=1, tzinfo=dt_util.UTC)
+
+    start_time = datetime(2020, 10, 6, tzinfo=dt_util.UTC)
+    statistic_id = f"{DOMAIN}.db_test"
+
+    metadata: StatisticMetaData = {
+        "has_mean": True,
+        "has_sum": True,
+        "name": None,
+        "source": DOMAIN,
+        "statistic_id": statistic_id,
+        "unit_of_measurement": None,
+    }
+    statistics: StatisticData = {
+        "last_reset": precise_time,
+        "max": precise_number,
+        "mean": precise_number,
+        "min": precise_number,
+        "start": precise_time,
+        "state": precise_number,
+        "sum": precise_number,
+    }
+
+    def check_columns(
+        schema_errors: set[str],
+        stored: Mapping,
+        expected: Mapping,
+        columns: tuple[str, ...],
+        table_name: str,
+        supports: str,
+    ) -> None:
+        for column in columns:
+            if stored[column] != expected[column]:
+                schema_errors.add(f"{table_name}.{supports}")
+                _LOGGER.debug(
+                    "Column %s in database table %s does not support %s (%s != %s)",
+                    column,
+                    table_name,
+                    supports,
+                    stored[column],
+                    expected[column],
+                )
+
+    # Insert / adjust a test statistics row in each of the tables
+    tables: tuple[type[Statistics | StatisticsShortTerm], ...] = (
+        Statistics,
+        StatisticsShortTerm,
+    )
+    try:
+        with session_scope(session=session_maker()) as session:
+            for table in tables:
+                _import_statistics_with_session(session, metadata, (statistics,), table)
+                stored_statistics = _statistics_during_period_with_session(
+                    hass,
+                    session,
+                    start_time,
+                    None,
+                    [statistic_id],
+                    "hour" if table == Statistics else "5minute",
+                    None,
+                    {"last_reset", "max", "mean", "min", "state", "sum"},
+                )
+                if not (stored_statistic := stored_statistics.get(statistic_id)):
+                    _LOGGER.warning(
+                        "Schema validation failed for table: %s", table.__tablename__
+                    )
+                    continue
+
+                check_columns(
+                    schema_errors,
+                    stored_statistic[0],
+                    statistics,
+                    ("max", "mean", "min", "state", "sum"),
+                    table.__tablename__,
+                    "double precision",
+                )
+                assert statistics["last_reset"]
+                check_columns(
+                    schema_errors,
+                    stored_statistic[0],
+                    {
+                        "last_reset": statistics["last_reset"],
+                        "start": statistics["start"],
+                    },
+                    ("start", "last_reset"),
+                    table.__tablename__,
+                    "µs precision",
+                )
+            _clear_statistics_with_session(session, statistic_ids=[statistic_id])
+    except Exception as exc:  # pylint: disable=broad-except
+        _LOGGER.exception("Error when validating DB schema: %s", exc)
+
+    return schema_errors
+
+
+def validate_db_schema(
+    hass: HomeAssistant, instance: Recorder, session_maker: Callable[[], Session]
+) -> set[str]:
+    """Do some basic checks for common schema errors caused by manual migration."""
+    schema_errors: set[str] = set()
+    schema_errors |= _validate_db_schema_utf8(instance, session_maker)
+    schema_errors |= _validate_db_schema(hass, instance, session_maker)
+    if schema_errors:
+        _LOGGER.debug(
+            "Detected statistics schema errors: %s", ", ".join(sorted(schema_errors))
+        )
+    return schema_errors
+
+
+def correct_db_schema(
+    engine: Engine, session_maker: Callable[[], Session], schema_errors: set[str]
+) -> None:
+    """Correct issues detected by validate_db_schema."""
+    from .migration import _modify_columns  # pylint: disable=import-outside-toplevel
+
+    if "statistics_meta.4-byte UTF-8" in schema_errors:
+        # Attempt to convert the table to utf8mb4
+        _LOGGER.warning(
+            "Updating character set and collation of table %s to utf8mb4. "
+            "Note: this can take several minutes on large databases and slow "
+            "computers. Please be patient!",
+            "statistics_meta",
+        )
+        with contextlib.suppress(SQLAlchemyError):
+            with session_scope(session=session_maker()) as session:
+                connection = session.connection()
+                connection.execute(
+                    # Using LOCK=EXCLUSIVE to prevent the database from corrupting
+                    # https://github.com/home-assistant/core/issues/56104
+                    text(
+                        "ALTER TABLE statistics_meta CONVERT TO "
+                        "CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci, LOCK=EXCLUSIVE"
+                    )
+                )
+
+    tables: tuple[type[Statistics | StatisticsShortTerm], ...] = (
+        Statistics,
+        StatisticsShortTerm,
+    )
+    for table in tables:
+        if f"{table.__tablename__}.double precision" in schema_errors:
+            # Attempt to convert float columns to double precision
+            _modify_columns(
+                session_maker,
+                engine,
+                table.__tablename__,
+                [
+                    "mean DOUBLE PRECISION",
+                    "min DOUBLE PRECISION",
+                    "max DOUBLE PRECISION",
+                    "state DOUBLE PRECISION",
+                    "sum DOUBLE PRECISION",
+                ],
+            )
+        if f"{table.__tablename__}.µs precision" in schema_errors:
+            # Attempt to convert datetime columns to µs precision
+            _modify_columns(
+                session_maker,
+                engine,
+                table.__tablename__,
+                [
+                    "last_reset DATETIME(6)",
+                    "start DATETIME(6)",
+                ],
+            )
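One detail of _validate_db_schema_utf8 above that is easy to miss: SQLAlchemy wraps the MySQL driver's exception, so the MySQL error code 1366 ("Incorrect string value", raised when a non-utf8mb4 table receives a 4-byte character such as 𓆚) has to be read from err.orig. A sketch of the same check in isolation — FakeDBAPIError is a stand-in for a real driver exception, not part of the patch:

    from sqlalchemy.exc import OperationalError


    class FakeDBAPIError(Exception):
        """Stand-in for a MySQL driver exception carrying (errno, message)."""


    def is_missing_utf8mb4(err: OperationalError) -> bool:
        # err.orig is the underlying DB-API exception; args[0] is the MySQL
        # error number. 1366 means the column rejected the character.
        return err.orig is not None and err.orig.args[0] == 1366


    wrapped = OperationalError(
        "INSERT ...", {}, orig=FakeDBAPIError(1366, "Incorrect string value")
    )
    assert is_missing_utf8mb4(wrapped)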
diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py
index 8950365fd95..376087fdb1e 100644
--- a/tests/components/recorder/test_statistics.py
+++ b/tests/components/recorder/test_statistics.py
@@ -1,13 +1,14 @@
 """The tests for sensor recorder platform."""
 # pylint: disable=protected-access,invalid-name
-from datetime import timedelta
+from datetime import datetime, timedelta
 import importlib
 import sys
-from unittest.mock import patch, sentinel
+from unittest.mock import ANY, DEFAULT, MagicMock, patch, sentinel
 
 import pytest
 from pytest import approx
 from sqlalchemy import create_engine
+from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import Session
 
 from homeassistant.components import recorder
@@ -16,6 +17,8 @@ from homeassistant.components.recorder.const import SQLITE_URL_PREFIX
 from homeassistant.components.recorder.db_schema import StatisticsShortTerm
 from homeassistant.components.recorder.models import process_timestamp
 from homeassistant.components.recorder.statistics import (
+    _statistics_during_period_with_session,
+    _update_or_add_metadata,
     async_add_external_statistics,
     async_import_statistics,
     delete_statistics_duplicates,
@@ -1475,6 +1478,196 @@ def test_delete_metadata_duplicates_no_duplicates(hass_recorder, caplog):
     assert "duplicated statistics_meta rows" not in caplog.text
 
 
+@pytest.mark.parametrize("enable_statistics_table_validation", [True])
+@pytest.mark.parametrize("db_engine", ("mysql", "postgresql"))
+async def test_validate_db_schema(
+    async_setup_recorder_instance, hass, caplog, db_engine
+):
+    """Test validating DB schema with MySQL and PostgreSQL.
+
+    Note: The test uses SQLite, the purpose is only to exercise the code.
+    """
+    with patch(
+        "homeassistant.components.recorder.core.Recorder.dialect_name", db_engine
+    ):
+        await async_setup_recorder_instance(hass)
+        await async_wait_recording_done(hass)
+    assert "Schema validation failed" not in caplog.text
+    assert "Detected statistics schema errors" not in caplog.text
+    assert "Database is about to correct DB schema errors" not in caplog.text
+
+
+@pytest.mark.parametrize("enable_statistics_table_validation", [True])
+async def test_validate_db_schema_fix_utf8_issue(
+    async_setup_recorder_instance, hass, caplog
+):
+    """Test validating DB schema with MySQL.
+
+    Note: The test uses SQLite, the purpose is only to exercise the code.
+    """
+    orig_error = MagicMock()
+    orig_error.args = [1366]
+    utf8_error = OperationalError("", "", orig=orig_error)
+    with patch(
+        "homeassistant.components.recorder.core.Recorder.dialect_name", "mysql"
+    ), patch(
+        "homeassistant.components.recorder.statistics._update_or_add_metadata",
+        side_effect=[utf8_error, DEFAULT, DEFAULT],
+        wraps=_update_or_add_metadata,
+    ):
+        await async_setup_recorder_instance(hass)
+        await async_wait_recording_done(hass)
+
+    assert "Schema validation failed" not in caplog.text
+    assert (
+        "Database is about to correct DB schema errors: statistics_meta.4-byte UTF-8"
+        in caplog.text
+    )
+    assert (
+        "Updating character set and collation of table statistics_meta to utf8mb4"
+        in caplog.text
+    )
+
+
+@pytest.mark.parametrize("enable_statistics_table_validation", [True])
+@pytest.mark.parametrize("db_engine", ("mysql", "postgresql"))
+@pytest.mark.parametrize(
+    "table, replace_index", (("statistics", 0), ("statistics_short_term", 1))
+)
+@pytest.mark.parametrize(
+    "column, value",
+    (("max", 1.0), ("mean", 1.0), ("min", 1.0), ("state", 1.0), ("sum", 1.0)),
+)
+async def test_validate_db_schema_fix_float_issue(
+    async_setup_recorder_instance,
+    hass,
+    caplog,
+    db_engine,
+    table,
+    replace_index,
+    column,
+    value,
+):
+    """Test validating DB schema with MySQL and PostgreSQL.
+
+    Note: The test uses SQLite, the purpose is only to exercise the code.
+    """
+    orig_error = MagicMock()
+    orig_error.args = [1366]
+    precise_number = 1.000000000000001
+    precise_time = datetime(2020, 10, 6, microsecond=1, tzinfo=dt_util.UTC)
+    statistics = {
+        "recorder.db_test": [
+            {
+                "last_reset": precise_time,
+                "max": precise_number,
+                "mean": precise_number,
+                "min": precise_number,
+                "start": precise_time,
+                "state": precise_number,
+                "sum": precise_number,
+            }
+        ]
+    }
+    statistics["recorder.db_test"][0][column] = value
+    fake_statistics = [DEFAULT, DEFAULT]
+    fake_statistics[replace_index] = statistics
+
+    with patch(
+        "homeassistant.components.recorder.core.Recorder.dialect_name", db_engine
+    ), patch(
+        "homeassistant.components.recorder.statistics._statistics_during_period_with_session",
+        side_effect=fake_statistics,
+        wraps=_statistics_during_period_with_session,
+    ), patch(
+        "homeassistant.components.recorder.migration._modify_columns"
+    ) as modify_columns_mock:
+        await async_setup_recorder_instance(hass)
+        await async_wait_recording_done(hass)
+
+    assert "Schema validation failed" not in caplog.text
+    assert (
+        f"Database is about to correct DB schema errors: {table}.double precision"
+        in caplog.text
+    )
+    modification = [
+        "mean DOUBLE PRECISION",
+        "min DOUBLE PRECISION",
+        "max DOUBLE PRECISION",
+        "state DOUBLE PRECISION",
+        "sum DOUBLE PRECISION",
+    ]
+    modify_columns_mock.assert_called_once_with(ANY, ANY, table, modification)
+
+
+@pytest.mark.parametrize("enable_statistics_table_validation", [True])
+@pytest.mark.parametrize("db_engine", ("mysql", "postgresql"))
+@pytest.mark.parametrize(
+    "table, replace_index", (("statistics", 0), ("statistics_short_term", 1))
+)
+@pytest.mark.parametrize(
+    "column, value",
+    (
+        ("last_reset", "2020-10-06T00:00:00+00:00"),
+        ("start", "2020-10-06T00:00:00+00:00"),
+    ),
+)
+async def test_validate_db_schema_fix_statistics_datetime_issue(
+    async_setup_recorder_instance,
+    hass,
+    caplog,
+    db_engine,
+    table,
+    replace_index,
+    column,
+    value,
+):
+    """Test validating DB schema with MySQL and PostgreSQL.
+
+    Note: The test uses SQLite, the purpose is only to exercise the code.
+    """
+    orig_error = MagicMock()
+    orig_error.args = [1366]
+    precise_number = 1.000000000000001
+    precise_time = datetime(2020, 10, 6, microsecond=1, tzinfo=dt_util.UTC)
+    statistics = {
+        "recorder.db_test": [
+            {
+                "last_reset": precise_time,
+                "max": precise_number,
+                "mean": precise_number,
+                "min": precise_number,
+                "start": precise_time,
+                "state": precise_number,
+                "sum": precise_number,
+            }
+        ]
+    }
+    statistics["recorder.db_test"][0][column] = value
+    fake_statistics = [DEFAULT, DEFAULT]
+    fake_statistics[replace_index] = statistics
+
+    with patch(
+        "homeassistant.components.recorder.core.Recorder.dialect_name", db_engine
+    ), patch(
+        "homeassistant.components.recorder.statistics._statistics_during_period_with_session",
+        side_effect=fake_statistics,
+        wraps=_statistics_during_period_with_session,
+    ), patch(
+        "homeassistant.components.recorder.migration._modify_columns"
+    ) as modify_columns_mock:
+        await async_setup_recorder_instance(hass)
+        await async_wait_recording_done(hass)
+
+    assert "Schema validation failed" not in caplog.text
+    assert (
+        f"Database is about to correct DB schema errors: {table}.µs precision"
+        in caplog.text
+    )
+    modification = ["last_reset DATETIME(6)", "start DATETIME(6)"]
+    modify_columns_mock.assert_called_once_with(ANY, ANY, table, modification)
+
+
 def record_states(hass):
     """Record some test states.
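The tests above lean on two mock tricks worth noting: a MagicMock stands in for the DB-API exception (so err.orig.args[0] == 1366 matches without a real MySQL server), and side_effect=[..., DEFAULT, ...] combined with wraps=... makes a patched function fail or return fake data on selected calls while delegating the rest to the real implementation. A stripped-down illustration — real_fn is hypothetical, not part of the test suite:

    from unittest.mock import DEFAULT, MagicMock, patch


    def real_fn() -> str:
        return "real result"


    orig_error = MagicMock()
    orig_error.args = [1366]  # looks like MySQL "Incorrect string value"

    with patch(f"{__name__}.real_fn", side_effect=[ValueError, DEFAULT], wraps=real_fn):
        try:
            real_fn()  # first call: side_effect yields ValueError, which is raised
        except ValueError:
            pass
        # second call: DEFAULT falls through to the wrapped real function
        assert real_fn() == "real result"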
diff --git a/tests/conftest.py b/tests/conftest.py
index a508b8c32ef..b6638968182 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,6 +5,7 @@ import asyncio
 from collections.abc import AsyncGenerator, Callable, Generator
 from contextlib import asynccontextmanager
 import functools
+import itertools
 from json import JSONDecoder, loads
 import logging
 import sqlite3
@@ -860,6 +861,16 @@ def enable_statistics():
     return False
 
 
+@pytest.fixture
+def enable_statistics_table_validation():
+    """Fixture to control enabling of recorder's statistics table validation.
+
+    To enable statistics table validation, tests can be marked with:
+    @pytest.mark.parametrize("enable_statistics_table_validation", [True])
+    """
+    return False
+
+
 @pytest.fixture
 def enable_nightly_purge():
     """Fixture to control enabling of recorder's nightly purge job.
@@ -902,6 +913,7 @@ def hass_recorder(
     recorder_db_url,
     enable_nightly_purge,
     enable_statistics,
+    enable_statistics_table_validation,
     hass_storage,
 ):
     """Home Assistant fixture with in-memory recorder."""
@@ -910,6 +922,11 @@ def hass_recorder(
     hass = get_test_home_assistant()
     nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None
     stats = recorder.Recorder.async_periodic_statistics if enable_statistics else None
+    stats_validate = (
+        recorder.statistics.validate_db_schema
+        if enable_statistics_table_validation
+        else itertools.repeat(set())
+    )
     with patch(
         "homeassistant.components.recorder.Recorder.async_nightly_tasks",
         side_effect=nightly,
@@ -918,6 +935,10 @@ def hass_recorder(
         "homeassistant.components.recorder.Recorder.async_periodic_statistics",
         side_effect=stats,
         autospec=True,
+    ), patch(
+        "homeassistant.components.recorder.migration.statistics_validate_db_schema",
+        side_effect=stats_validate,
+        autospec=True,
     ):
 
         def setup_recorder(config=None):
@@ -962,12 +983,18 @@ async def async_setup_recorder_instance(
     hass_fixture_setup,
     enable_nightly_purge,
     enable_statistics,
+    enable_statistics_table_validation,
 ) -> AsyncGenerator[SetupRecorderInstanceT, None]:
     """Yield callable to setup recorder instance."""
     assert not hass_fixture_setup
 
     nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None
     stats = recorder.Recorder.async_periodic_statistics if enable_statistics else None
+    stats_validate = (
+        recorder.statistics.validate_db_schema
+        if enable_statistics_table_validation
+        else itertools.repeat(set())
+    )
     with patch(
         "homeassistant.components.recorder.Recorder.async_nightly_tasks",
         side_effect=nightly,
@@ -976,6 +1003,10 @@ async def async_setup_recorder_instance(
         "homeassistant.components.recorder.Recorder.async_periodic_statistics",
         side_effect=stats,
         autospec=True,
+    ), patch(
+        "homeassistant.components.recorder.migration.statistics_validate_db_schema",
+        side_effect=stats_validate,
+        autospec=True,
     ):

         async def async_setup_recorder(
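A closing note on the conftest change: the patched statistics_validate_db_schema is called once per recorder start, so the disabled path needs a side_effect that can keep yielding an empty error set for as many starts as a test performs — hence itertools.repeat(set()) rather than a one-shot list. A minimal illustration of why repeat() satisfies the side_effect contract (toy mock target, not the real fixture):

    import itertools
    from unittest.mock import MagicMock

    # side_effect accepts any iterator: each call to the mock consumes one
    # item. itertools.repeat(set()) never runs dry, so every recorder start
    # sees "no statistics schema errors" when validation is disabled.
    validate = MagicMock(side_effect=itertools.repeat(set()))

    assert validate("first start") == set()
    assert validate("second start") == set()  # still works on later calls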