diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index 3468acaed4c..32a6e9c24dc 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -86,6 +86,7 @@ from .table_managers.event_types import EventTypeManager from .table_managers.state_attributes import StateAttributesManager from .table_managers.states import StatesManager from .table_managers.states_meta import StatesMetaManager +from .table_managers.statistics_meta import StatisticsMetaManager from .tasks import ( AdjustLRUSizeTask, AdjustStatisticsTask, @@ -172,6 +173,7 @@ class Recorder(threading.Thread): threading.Thread.__init__(self, name="Recorder") self.hass = hass + self.thread_id: int | None = None self.auto_purge = auto_purge self.auto_repack = auto_repack self.keep_days = keep_days @@ -208,6 +210,7 @@ class Recorder(threading.Thread): self.state_attributes_manager = StateAttributesManager( self, exclude_attributes_by_domain ) + self.statistics_meta_manager = StatisticsMetaManager(self) self.event_session: Session | None = None self._get_session: Callable[[], Session] | None = None self._completed_first_database_setup: bool | None = None @@ -613,6 +616,7 @@ class Recorder(threading.Thread): def run(self) -> None: """Start processing events to save.""" + self.thread_id = threading.get_ident() setup_result = self._setup_recorder() if not setup_result: @@ -668,7 +672,7 @@ class Recorder(threading.Thread): "Database Migration Failed", "recorder_database_migration", ) - self._activate_and_set_db_ready() + self.hass.add_job(self.async_set_db_ready) self._shutdown() return @@ -687,7 +691,14 @@ class Recorder(threading.Thread): def _activate_and_set_db_ready(self) -> None: """Activate the table managers or schedule migrations and mark the db as ready.""" - with session_scope(session=self.get_session()) as session: + with session_scope(session=self.get_session(), read_only=True) as session: + # Prime the statistics meta manager as soon as possible + # so that a burst of frontend queries cannot cause a + # thundering herd of database lookups for the statistics + # metadata when many statistics graphs render at once. 
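
The priming comment above is the motivation for the whole manager: pay one bulk query up front instead of one query per frontend graph. Below is a minimal, self-contained sketch of that trade-off; every name in it is illustrative (not a recorder API) and the database is simulated by a hard-coded dict.

```python
from __future__ import annotations

from collections import Counter


class MetaCache:
    """Toy metadata cache; the 'database' is a hard-coded dict."""

    def __init__(self) -> None:
        self._cache: dict[str, int] = {}
        self.db_hits: Counter[str] = Counter()

    def _fetch_all_from_db(self) -> dict[str, int]:
        self.db_hits["bulk"] += 1
        return {"sensor.energy": 1, "sensor.power": 2}  # stand-in rows

    def load(self) -> None:
        """Prime the cache once, before any consumers arrive."""
        self._cache.update(self._fetch_all_from_db())

    def get(self, statistic_id: str) -> int | None:
        if statistic_id not in self._cache:
            self.db_hits["single"] += 1  # every miss is another DB round trip
        return self._cache.get(statistic_id)


cache = MetaCache()
cache.load()
for _ in range(100):  # e.g. 100 statistics graphs rendering at once
    cache.get("sensor.energy")
assert cache.db_hits["bulk"] == 1
assert cache.db_hits["single"] == 0
```
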
+ if self.schema_version >= 23: + self.statistics_meta_manager.load(session) + if ( self.schema_version < 36 or session.execute(has_events_context_ids_to_migrate()).scalar() @@ -758,10 +769,11 @@ class Recorder(threading.Thread): non_state_change_events.append(event_) assert self.event_session is not None - self.event_data_manager.load(non_state_change_events, self.event_session) - self.event_type_manager.load(non_state_change_events, self.event_session) - self.states_meta_manager.load(state_change_events, self.event_session) - self.state_attributes_manager.load(state_change_events, self.event_session) + session = self.event_session + self.event_data_manager.load(non_state_change_events, session) + self.event_type_manager.load(non_state_change_events, session) + self.states_meta_manager.load(state_change_events, session) + self.state_attributes_manager.load(state_change_events, session) def _guarded_process_one_task_or_recover(self, task: RecorderTask) -> None: """Process a task, guarding against exceptions to ensure the loop does not collapse.""" @@ -1077,6 +1089,7 @@ class Recorder(threading.Thread): self.event_data_manager.reset() self.event_type_manager.reset() self.states_meta_manager.reset() + self.statistics_meta_manager.reset() if not self.event_session: return diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index d594118ea54..08f5f21b896 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -873,7 +873,7 @@ def _apply_update( # noqa: C901 # There may be duplicated statistics_meta entries, delete duplicates # and try again with session_scope(session=session_maker()) as session: - delete_statistics_meta_duplicates(session) + delete_statistics_meta_duplicates(instance, session) _create_index( session_maker, "statistics_meta", "ix_statistics_meta_statistic_id" ) diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index 645b7d4f042..36874869bf9 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -21,7 +21,7 @@ from sqlalchemy.engine import Engine from sqlalchemy.engine.row import Row from sqlalchemy.exc import OperationalError, SQLAlchemyError, StatementError from sqlalchemy.orm.session import Session -from sqlalchemy.sql.expression import literal_column, true +from sqlalchemy.sql.expression import literal_column from sqlalchemy.sql.lambdas import StatementLambdaElement import voluptuous as vol @@ -132,16 +132,6 @@ QUERY_STATISTICS_SUMMARY_SUM = ( .label("rownum"), ) -QUERY_STATISTIC_META = ( - StatisticsMeta.id, - StatisticsMeta.statistic_id, - StatisticsMeta.source, - StatisticsMeta.unit_of_measurement, - StatisticsMeta.has_mean, - StatisticsMeta.has_sum, - StatisticsMeta.name, -) - STATISTIC_UNIT_TO_UNIT_CONVERTER: dict[str | None, type[BaseUnitConverter]] = { **{unit: DataRateConverter for unit in DataRateConverter.VALID_UNITS}, @@ -373,56 +363,6 @@ def get_start_time() -> datetime: return last_period -def _update_or_add_metadata( - session: Session, - new_metadata: StatisticMetaData, - old_metadata_dict: dict[str, tuple[int, StatisticMetaData]], -) -> int: - """Get metadata_id for a statistic_id. - - If the statistic_id is previously unknown, add it. If it's already known, update - metadata if needed. - - Updating metadata source is not possible. 
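
Passing `instance` into `delete_statistics_meta_duplicates` exists so the new manager's cache can be reset and reloaded after the bulk delete performed during migration. A toy sketch of why that is necessary; the class and method names are hypothetical and a dict stands in for the table:

```python
from __future__ import annotations


class CachedTable:
    """Illustrative cached table; not a recorder class."""

    def __init__(self, rows: dict[str, int]) -> None:
        self._rows = rows  # stand-in for the statistics_meta table
        self._cache: dict[str, int] = {}

    def load(self) -> None:
        self._cache.update(self._rows)

    def reset(self) -> None:
        self._cache = {}

    def get(self, key: str) -> int | None:
        return self._cache.get(key)

    def delete_duplicates_bulk(self, keys: list[str]) -> None:
        # A raw bulk DELETE does not go through the cache...
        for key in keys:
            self._rows.pop(key, None)
        # ...so the cache must be rebuilt, or get() keeps returning
        # metadata ids for rows that no longer exist.
        self.reset()
        self.load()


table = CachedTable({"sensor.dup": 1, "sensor.keep": 2})
table.load()
table.delete_duplicates_bulk(["sensor.dup"])
assert table.get("sensor.dup") is None
assert table.get("sensor.keep") == 2
```
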
- """ - statistic_id = new_metadata["statistic_id"] - if statistic_id not in old_metadata_dict: - meta = StatisticsMeta.from_meta(new_metadata) - session.add(meta) - session.flush() # Flush to get the metadata id assigned - _LOGGER.debug( - "Added new statistics metadata for %s, new_metadata: %s", - statistic_id, - new_metadata, - ) - return meta.id - - metadata_id, old_metadata = old_metadata_dict[statistic_id] - if ( - old_metadata["has_mean"] != new_metadata["has_mean"] - or old_metadata["has_sum"] != new_metadata["has_sum"] - or old_metadata["name"] != new_metadata["name"] - or old_metadata["unit_of_measurement"] != new_metadata["unit_of_measurement"] - ): - session.query(StatisticsMeta).filter_by(statistic_id=statistic_id).update( - { - StatisticsMeta.has_mean: new_metadata["has_mean"], - StatisticsMeta.has_sum: new_metadata["has_sum"], - StatisticsMeta.name: new_metadata["name"], - StatisticsMeta.unit_of_measurement: new_metadata["unit_of_measurement"], - }, - synchronize_session=False, - ) - _LOGGER.debug( - "Updated statistics metadata for %s, old_metadata: %s, new_metadata: %s", - statistic_id, - old_metadata, - new_metadata, - ) - - return metadata_id - - def _find_duplicates( session: Session, table: type[StatisticsBase] ) -> tuple[list[int], list[dict]]: @@ -642,13 +582,16 @@ def _delete_statistics_meta_duplicates(session: Session) -> int: return total_deleted_rows -def delete_statistics_meta_duplicates(session: Session) -> None: +def delete_statistics_meta_duplicates(instance: Recorder, session: Session) -> None: """Identify and delete duplicated statistics_meta. This is used when migrating from schema version 28 to schema version 29. """ deleted_statistics_rows = _delete_statistics_meta_duplicates(session) if deleted_statistics_rows: + statistics_meta_manager = instance.statistics_meta_manager + statistics_meta_manager.reset() + statistics_meta_manager.load(session) _LOGGER.info( "Deleted %s duplicated statistics_meta rows", deleted_statistics_rows ) @@ -750,6 +693,7 @@ def compile_statistics(instance: Recorder, start: datetime, fire_events: bool) - """ start = dt_util.as_utc(start) end = start + timedelta(minutes=5) + statistics_meta_manager = instance.statistics_meta_manager # Return if we already have 5-minute statistics for the requested period with session_scope( @@ -782,7 +726,7 @@ def compile_statistics(instance: Recorder, start: datetime, fire_events: bool) - # Insert collected statistics in the database for stats in platform_stats: - metadata_id = _update_or_add_metadata( + _, metadata_id = statistics_meta_manager.update_or_add( session, stats["meta"], current_metadata ) _insert_statistics( @@ -877,28 +821,8 @@ def _update_statistics( ) -def _generate_get_metadata_stmt( - statistic_ids: list[str] | None = None, - statistic_type: Literal["mean"] | Literal["sum"] | None = None, - statistic_source: str | None = None, -) -> StatementLambdaElement: - """Generate a statement to fetch metadata.""" - stmt = lambda_stmt(lambda: select(*QUERY_STATISTIC_META)) - if statistic_ids: - stmt += lambda q: q.where( - # https://github.com/python/mypy/issues/2608 - StatisticsMeta.statistic_id.in_(statistic_ids) # type:ignore[arg-type] - ) - if statistic_source is not None: - stmt += lambda q: q.where(StatisticsMeta.source == statistic_source) - if statistic_type == "mean": - stmt += lambda q: q.where(StatisticsMeta.has_mean == true()) - elif statistic_type == "sum": - stmt += lambda q: q.where(StatisticsMeta.has_sum == true()) - return stmt - - def get_metadata_with_session( + instance: 
Recorder, session: Session, *, statistic_ids: list[str] | None = None, @@ -908,31 +832,15 @@ def get_metadata_with_session( """Fetch meta data. Returns a dict of (metadata_id, StatisticMetaData) tuples indexed by statistic_id. - If statistic_ids is given, fetch metadata only for the listed statistics_ids. If statistic_type is given, fetch metadata only for statistic_ids supporting it. """ - - # Fetch metatadata from the database - stmt = _generate_get_metadata_stmt(statistic_ids, statistic_type, statistic_source) - result = execute_stmt_lambda_element(session, stmt) - if not result: - return {} - - return { - meta.statistic_id: ( - meta.id, - { - "has_mean": meta.has_mean, - "has_sum": meta.has_sum, - "name": meta.name, - "source": meta.source, - "statistic_id": meta.statistic_id, - "unit_of_measurement": meta.unit_of_measurement, - }, - ) - for meta in result - } + return instance.statistics_meta_manager.get_many( + session, + statistic_ids=statistic_ids, + statistic_type=statistic_type, + statistic_source=statistic_source, + ) def get_metadata( @@ -945,6 +853,7 @@ def get_metadata( """Return metadata for statistic_ids.""" with session_scope(hass=hass, read_only=True) as session: return get_metadata_with_session( + get_instance(hass), session, statistic_ids=statistic_ids, statistic_type=statistic_type, @@ -952,17 +861,10 @@ def get_metadata( ) -def _clear_statistics_with_session(session: Session, statistic_ids: list[str]) -> None: - """Clear statistics for a list of statistic_ids.""" - session.query(StatisticsMeta).filter( - StatisticsMeta.statistic_id.in_(statistic_ids) - ).delete(synchronize_session=False) - - def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None: """Clear statistics for a list of statistic_ids.""" with session_scope(session=instance.get_session()) as session: - _clear_statistics_with_session(session, statistic_ids) + instance.statistics_meta_manager.delete(session, statistic_ids) def update_statistics_metadata( @@ -972,20 +874,20 @@ def update_statistics_metadata( new_unit_of_measurement: str | None | UndefinedType, ) -> None: """Update statistics metadata for a statistic_id.""" + statistics_meta_manager = instance.statistics_meta_manager if new_unit_of_measurement is not UNDEFINED: with session_scope(session=instance.get_session()) as session: - session.query(StatisticsMeta).filter( - StatisticsMeta.statistic_id == statistic_id - ).update({StatisticsMeta.unit_of_measurement: new_unit_of_measurement}) - if new_statistic_id is not UNDEFINED: + statistics_meta_manager.update_unit_of_measurement( + session, statistic_id, new_unit_of_measurement + ) + if new_statistic_id is not UNDEFINED and new_statistic_id is not None: with session_scope( session=instance.get_session(), exception_filter=_filter_unique_constraint_integrity_error(instance), ) as session: - session.query(StatisticsMeta).filter( - (StatisticsMeta.statistic_id == statistic_id) - & (StatisticsMeta.source == DOMAIN) - ).update({StatisticsMeta.statistic_id: new_statistic_id}) + statistics_meta_manager.update_statistic_id( + session, DOMAIN, statistic_id, new_statistic_id + ) def list_statistic_ids( @@ -1004,7 +906,7 @@ def list_statistic_ids( # Query the database with session_scope(hass=hass, read_only=True) as session: - metadata = get_metadata_with_session( + metadata = get_instance(hass).statistics_meta_manager.get_many( session, statistic_type=statistic_type, statistic_ids=statistic_ids ) @@ -1609,11 +1511,13 @@ def statistic_during_period( with session_scope(hass=hass, read_only=True) 
as session: # Fetch metadata for the given statistic_id if not ( - metadata := get_metadata_with_session(session, statistic_ids=[statistic_id]) + metadata := get_instance(hass).statistics_meta_manager.get( + session, statistic_id + ) ): return result - metadata_id = metadata[statistic_id][0] + metadata_id = metadata[0] oldest_stat = _first_statistic(session, Statistics, metadata_id) oldest_5_min_stat = None @@ -1724,7 +1628,7 @@ def statistic_during_period( else: result["change"] = None - state_unit = unit = metadata[statistic_id][1]["unit_of_measurement"] + state_unit = unit = metadata[1]["unit_of_measurement"] if state := hass.states.get(statistic_id): state_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) convert = _get_statistic_to_display_unit_converter(unit, state_unit, units) @@ -1749,7 +1653,9 @@ def _statistics_during_period_with_session( """ metadata = None # Fetch metadata for the given (or all) statistic_ids - metadata = get_metadata_with_session(session, statistic_ids=statistic_ids) + metadata = get_instance(hass).statistics_meta_manager.get_many( + session, statistic_ids=statistic_ids + ) if not metadata: return {} @@ -1885,7 +1791,9 @@ def _get_last_statistics( statistic_ids = [statistic_id] with session_scope(hass=hass, read_only=True) as session: # Fetch metadata for the given statistic_id - metadata = get_metadata_with_session(session, statistic_ids=statistic_ids) + metadata = get_instance(hass).statistics_meta_manager.get_many( + session, statistic_ids=statistic_ids + ) if not metadata: return {} metadata_id = metadata[statistic_id][0] @@ -1973,7 +1881,9 @@ def get_latest_short_term_statistics( with session_scope(hass=hass, read_only=True) as session: # Fetch metadata for the given statistic_ids if not metadata: - metadata = get_metadata_with_session(session, statistic_ids=statistic_ids) + metadata = get_instance(hass).statistics_meta_manager.get_many( + session, statistic_ids=statistic_ids + ) if not metadata: return {} metadata_ids = [ @@ -2318,16 +2228,20 @@ def _filter_unique_constraint_integrity_error( def _import_statistics_with_session( + instance: Recorder, session: Session, metadata: StatisticMetaData, statistics: Iterable[StatisticData], table: type[StatisticsBase], ) -> bool: """Import statistics to the database.""" - old_metadata_dict = get_metadata_with_session( + statistics_meta_manager = instance.statistics_meta_manager + old_metadata_dict = statistics_meta_manager.get_many( session, statistic_ids=[metadata["statistic_id"]] ) - metadata_id = _update_or_add_metadata(session, metadata, old_metadata_dict) + _, metadata_id = statistics_meta_manager.update_or_add( + session, metadata, old_metadata_dict + ) for stat in statistics: if stat_id := _statistics_exists(session, table, metadata_id, stat["start"]): _update_statistics(session, table, stat_id, stat) @@ -2350,7 +2264,9 @@ def import_statistics( session=instance.get_session(), exception_filter=_filter_unique_constraint_integrity_error(instance), ) as session: - return _import_statistics_with_session(session, metadata, statistics, table) + return _import_statistics_with_session( + instance, session, metadata, statistics, table + ) @retryable_database_job("adjust_statistics") @@ -2364,7 +2280,9 @@ def adjust_statistics( """Process an add_statistics job.""" with session_scope(session=instance.get_session()) as session: - metadata = get_metadata_with_session(session, statistic_ids=[statistic_id]) + metadata = instance.statistics_meta_manager.get_many( + session, statistic_ids=[statistic_id] + ) if 
statistic_id not in metadata: return True @@ -2423,10 +2341,9 @@ def change_statistics_unit( old_unit: str, ) -> None: """Change statistics unit for a statistic_id.""" + statistics_meta_manager = instance.statistics_meta_manager with session_scope(session=instance.get_session()) as session: - metadata = get_metadata_with_session(session, statistic_ids=[statistic_id]).get( - statistic_id - ) + metadata = statistics_meta_manager.get(session, statistic_id) # Guard against the statistics being removed or updated before the # change_statistics_unit job executes @@ -2447,9 +2364,10 @@ def change_statistics_unit( ) for table in tables: _change_statistics_unit_for_table(session, table, metadata_id, convert) - session.query(StatisticsMeta).filter( - StatisticsMeta.statistic_id == statistic_id - ).update({StatisticsMeta.unit_of_measurement: new_unit}) + + statistics_meta_manager.update_unit_of_measurement( + session, statistic_id, new_unit + ) @callback @@ -2495,16 +2413,19 @@ def _validate_db_schema_utf8( "statistic_id": statistic_id, "unit_of_measurement": None, } + statistics_meta_manager = instance.statistics_meta_manager # Try inserting some metadata which needs utfmb4 support try: with session_scope(session=session_maker()) as session: - old_metadata_dict = get_metadata_with_session( + old_metadata_dict = statistics_meta_manager.get_many( session, statistic_ids=[statistic_id] ) try: - _update_or_add_metadata(session, metadata, old_metadata_dict) - _clear_statistics_with_session(session, statistic_ids=[statistic_id]) + statistics_meta_manager.update_or_add( + session, metadata, old_metadata_dict + ) + statistics_meta_manager.delete(session, statistic_ids=[statistic_id]) except OperationalError as err: if err.orig and err.orig.args[0] == 1366: _LOGGER.debug( @@ -2524,6 +2445,7 @@ def _validate_db_schema( ) -> set[str]: """Do some basic checks for common schema errors caused by manual migration.""" schema_errors: set[str] = set() + statistics_meta_manager = instance.statistics_meta_manager # Wrong precision is only an issue for MySQL / MariaDB / PostgreSQL if instance.dialect_name not in ( @@ -2586,7 +2508,9 @@ def _validate_db_schema( try: with session_scope(session=session_maker()) as session: for table in tables: - _import_statistics_with_session(session, metadata, (statistics,), table) + _import_statistics_with_session( + instance, session, metadata, (statistics,), table + ) stored_statistics = _statistics_during_period_with_session( hass, session, @@ -2625,7 +2549,7 @@ def _validate_db_schema( table.__tablename__, "µs precision", ) - _clear_statistics_with_session(session, statistic_ids=[statistic_id]) + statistics_meta_manager.delete(session, statistic_ids=[statistic_id]) except Exception as exc: # pylint: disable=broad-except _LOGGER.exception("Error when validating DB schema: %s", exc) diff --git a/homeassistant/components/recorder/table_managers/event_data.py b/homeassistant/components/recorder/table_managers/event_data.py index a99b25fe0b4..4c661e3dc29 100644 --- a/homeassistant/components/recorder/table_managers/event_data.py +++ b/homeassistant/components/recorder/table_managers/event_data.py @@ -14,7 +14,7 @@ from . 
import BaseLRUTableManager from ..const import SQLITE_MAX_BIND_VARS from ..db_schema import EventData from ..queries import get_shared_event_datas -from ..util import chunked +from ..util import chunked, execute_stmt_lambda_element if TYPE_CHECKING: from ..core import Recorder @@ -96,8 +96,8 @@ class EventDataManager(BaseLRUTableManager[EventData]): results: dict[str, int | None] = {} with session.no_autoflush: for hashs_chunk in chunked(hashes, SQLITE_MAX_BIND_VARS): - for data_id, shared_data in session.execute( - get_shared_event_datas(hashs_chunk) + for data_id, shared_data in execute_stmt_lambda_element( + session, get_shared_event_datas(hashs_chunk) ): results[shared_data] = self._id_map[shared_data] = cast( int, data_id diff --git a/homeassistant/components/recorder/table_managers/event_types.py b/homeassistant/components/recorder/table_managers/event_types.py index 3cb3d9fad97..5b77e9116c7 100644 --- a/homeassistant/components/recorder/table_managers/event_types.py +++ b/homeassistant/components/recorder/table_managers/event_types.py @@ -12,7 +12,7 @@ from . import BaseLRUTableManager from ..const import SQLITE_MAX_BIND_VARS from ..db_schema import EventTypes from ..queries import find_event_type_ids -from ..util import chunked +from ..util import chunked, execute_stmt_lambda_element if TYPE_CHECKING: from ..core import Recorder @@ -68,8 +68,8 @@ class EventTypeManager(BaseLRUTableManager[EventTypes]): with session.no_autoflush: for missing_chunk in chunked(missing, SQLITE_MAX_BIND_VARS): - for event_type_id, event_type in session.execute( - find_event_type_ids(missing_chunk) + for event_type_id, event_type in execute_stmt_lambda_element( + session, find_event_type_ids(missing_chunk) ): results[event_type] = self._id_map[event_type] = cast( int, event_type_id diff --git a/homeassistant/components/recorder/table_managers/state_attributes.py b/homeassistant/components/recorder/table_managers/state_attributes.py index 7489a6f165d..51c626bd366 100644 --- a/homeassistant/components/recorder/table_managers/state_attributes.py +++ b/homeassistant/components/recorder/table_managers/state_attributes.py @@ -15,7 +15,7 @@ from . import BaseLRUTableManager from ..const import SQLITE_MAX_BIND_VARS from ..db_schema import StateAttributes from ..queries import get_shared_attributes -from ..util import chunked +from ..util import chunked, execute_stmt_lambda_element if TYPE_CHECKING: from ..core import Recorder @@ -113,8 +113,8 @@ class StateAttributesManager(BaseLRUTableManager[StateAttributes]): results: dict[str, int | None] = {} with session.no_autoflush: for hashs_chunk in chunked(hashes, SQLITE_MAX_BIND_VARS): - for attributes_id, shared_attrs in session.execute( - get_shared_attributes(hashs_chunk) + for attributes_id, shared_attrs in execute_stmt_lambda_element( + session, get_shared_attributes(hashs_chunk) ): results[shared_attrs] = self._id_map[shared_attrs] = cast( int, attributes_id diff --git a/homeassistant/components/recorder/table_managers/states_meta.py b/homeassistant/components/recorder/table_managers/states_meta.py index b8b763aae33..76b748d4697 100644 --- a/homeassistant/components/recorder/table_managers/states_meta.py +++ b/homeassistant/components/recorder/table_managers/states_meta.py @@ -12,7 +12,7 @@ from . 
import BaseLRUTableManager
 from ..const import SQLITE_MAX_BIND_VARS
 from ..db_schema import StatesMeta
 from ..queries import find_all_states_metadata_ids, find_states_metadata_ids
-from ..util import chunked
+from ..util import chunked, execute_stmt_lambda_element
 
 if TYPE_CHECKING:
     from ..core import Recorder
@@ -98,8 +98,8 @@ class StatesMetaManager(BaseLRUTableManager[StatesMeta]):
 
         with session.no_autoflush:
             for missing_chunk in chunked(missing, SQLITE_MAX_BIND_VARS):
-                for metadata_id, entity_id in session.execute(
-                    find_states_metadata_ids(missing_chunk)
+                for metadata_id, entity_id in execute_stmt_lambda_element(
+                    session, find_states_metadata_ids(missing_chunk)
                 ):
                     metadata_id = cast(int, metadata_id)
                     results[entity_id] = metadata_id
diff --git a/homeassistant/components/recorder/table_managers/statistics_meta.py b/homeassistant/components/recorder/table_managers/statistics_meta.py
new file mode 100644
index 00000000000..93417b43253
--- /dev/null
+++ b/homeassistant/components/recorder/table_managers/statistics_meta.py
@@ -0,0 +1,322 @@
+"""Support managing StatisticsMeta."""
+from __future__ import annotations
+
+import logging
+import threading
+from typing import TYPE_CHECKING, Literal, cast
+
+from lru import LRU  # pylint: disable=no-name-in-module
+from sqlalchemy import lambda_stmt, select
+from sqlalchemy.orm.session import Session
+from sqlalchemy.sql.expression import true
+from sqlalchemy.sql.lambdas import StatementLambdaElement
+
+from ..db_schema import StatisticsMeta
+from ..models import StatisticMetaData
+from ..util import execute_stmt_lambda_element
+
+if TYPE_CHECKING:
+    from ..core import Recorder
+
+CACHE_SIZE = 8192
+
+_LOGGER = logging.getLogger(__name__)
+
+QUERY_STATISTIC_META = (
+    StatisticsMeta.id,
+    StatisticsMeta.statistic_id,
+    StatisticsMeta.source,
+    StatisticsMeta.unit_of_measurement,
+    StatisticsMeta.has_mean,
+    StatisticsMeta.has_sum,
+    StatisticsMeta.name,
+)
+
+
+def _generate_get_metadata_stmt(
+    statistic_ids: list[str] | None = None,
+    statistic_type: Literal["mean"] | Literal["sum"] | None = None,
+    statistic_source: str | None = None,
+) -> StatementLambdaElement:
+    """Generate a statement to fetch metadata."""
+    stmt = lambda_stmt(lambda: select(*QUERY_STATISTIC_META))
+    if statistic_ids:
+        stmt += lambda q: q.where(
+            # https://github.com/python/mypy/issues/2608
+            StatisticsMeta.statistic_id.in_(statistic_ids)  # type:ignore[arg-type]
+        )
+    if statistic_source is not None:
+        stmt += lambda q: q.where(StatisticsMeta.source == statistic_source)
+    if statistic_type == "mean":
+        stmt += lambda q: q.where(StatisticsMeta.has_mean == true())
+    elif statistic_type == "sum":
+        stmt += lambda q: q.where(StatisticsMeta.has_sum == true())
+    return stmt
+
+
+def _statistics_meta_to_id_statistics_metadata(
+    meta: StatisticsMeta,
+) -> tuple[int, StatisticMetaData]:
+    """Convert a StatisticsMeta row to a (metadata_id, StatisticMetaData) tuple."""
+    return (
+        meta.id,
+        {
+            "has_mean": meta.has_mean,  # type: ignore[typeddict-item]
+            "has_sum": meta.has_sum,  # type: ignore[typeddict-item]
+            "name": meta.name,
+            "source": meta.source,  # type: ignore[typeddict-item]
+            "statistic_id": meta.statistic_id,  # type: ignore[typeddict-item]
+            "unit_of_measurement": meta.unit_of_measurement,
+        },
+    )
+
+
+class StatisticsMetaManager:
+    """Manage the StatisticsMeta table."""
+
+    def __init__(self, recorder: Recorder) -> None:
+        """Initialize the statistics meta manager."""
+        self.recorder = recorder
+        self._stat_id_to_id_meta: dict[str, tuple[int, StatisticMetaData]] = LRU(
CACHE_SIZE + ) + + def _clear_cache(self, statistic_ids: list[str]) -> None: + """Clear the cache.""" + for statistic_id in statistic_ids: + self._stat_id_to_id_meta.pop(statistic_id, None) + + def _get_from_database( + self, + session: Session, + statistic_ids: list[str] | None = None, + statistic_type: Literal["mean"] | Literal["sum"] | None = None, + statistic_source: str | None = None, + ) -> dict[str, tuple[int, StatisticMetaData]]: + """Fetch meta data and process it into results and/or cache.""" + # Only update the cache if we are in the recorder thread and there are no + # new objects that are not yet committed to the database in the session. + update_cache = ( + not session.new + and not session.dirty + and self.recorder.thread_id == threading.get_ident() + ) + results: dict[str, tuple[int, StatisticMetaData]] = {} + with session.no_autoflush: + stat_id_to_id_meta = self._stat_id_to_id_meta + for row in execute_stmt_lambda_element( + session, + _generate_get_metadata_stmt( + statistic_ids, statistic_type, statistic_source + ), + ): + statistics_meta = cast(StatisticsMeta, row) + id_meta = _statistics_meta_to_id_statistics_metadata(statistics_meta) + statistic_id = cast(str, statistics_meta.statistic_id) + results[statistic_id] = id_meta + if update_cache: + stat_id_to_id_meta[statistic_id] = id_meta + return results + + def _assert_in_recorder_thread(self) -> None: + """Assert that we are in the recorder thread.""" + if self.recorder.thread_id != threading.get_ident(): + raise RuntimeError("Detected unsafe call not in recorder thread") + + def _add_metadata( + self, session: Session, statistic_id: str, new_metadata: StatisticMetaData + ) -> int: + """Add metadata to the database. + + This call is not thread-safe and must be called from the + recorder thread. + """ + self._assert_in_recorder_thread() + meta = StatisticsMeta.from_meta(new_metadata) + session.add(meta) + # Flush to assign an ID + session.flush() + _LOGGER.debug( + "Added new statistics metadata for %s, new_metadata: %s", + statistic_id, + new_metadata, + ) + return meta.id + + def _update_metadata( + self, + session: Session, + statistic_id: str, + new_metadata: StatisticMetaData, + old_metadata_dict: dict[str, tuple[int, StatisticMetaData]], + ) -> tuple[bool, int]: + """Update metadata in the database. + + This call is not thread-safe and must be called from the + recorder thread. + """ + metadata_id, old_metadata = old_metadata_dict[statistic_id] + if not ( + old_metadata["has_mean"] != new_metadata["has_mean"] + or old_metadata["has_sum"] != new_metadata["has_sum"] + or old_metadata["name"] != new_metadata["name"] + or old_metadata["unit_of_measurement"] + != new_metadata["unit_of_measurement"] + ): + return False, metadata_id + + self._assert_in_recorder_thread() + session.query(StatisticsMeta).filter_by(statistic_id=statistic_id).update( + { + StatisticsMeta.has_mean: new_metadata["has_mean"], + StatisticsMeta.has_sum: new_metadata["has_sum"], + StatisticsMeta.name: new_metadata["name"], + StatisticsMeta.unit_of_measurement: new_metadata["unit_of_measurement"], + }, + synchronize_session=False, + ) + self._clear_cache([statistic_id]) + _LOGGER.debug( + "Updated statistics metadata for %s, old_metadata: %s, new_metadata: %s", + statistic_id, + old_metadata, + new_metadata, + ) + return True, metadata_id + + def load(self, session: Session) -> None: + """Load the statistic_id to metadata_id mapping into memory. + + This call is not thread-safe and must be called from the + recorder thread. 
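
`_assert_in_recorder_thread` above pairs with the `Recorder.thread_id` assignment added in `run()`. Below is a reduced, standalone model of the same guard; the class and variable names are illustrative only, but the error message mirrors the one raised here:

```python
from __future__ import annotations

import threading


class SingleThreadGuard:
    """Record the owning thread once; reject calls from any other thread."""

    def __init__(self) -> None:
        self.thread_id = threading.get_ident()

    def assert_owner(self) -> None:
        if self.thread_id != threading.get_ident():
            raise RuntimeError("Detected unsafe call not in recorder thread")


guard = SingleThreadGuard()
guard.assert_owner()  # ok: called from the thread that created the guard

errors: list[BaseException] = []


def _worker() -> None:
    try:
        guard.assert_owner()
    except RuntimeError as err:
        errors.append(err)


thread = threading.Thread(target=_worker)
thread.start()
thread.join()
assert len(errors) == 1  # the cross-thread call was rejected
```
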
+        """
+        self.get_many(session)
+
+    def get(
+        self, session: Session, statistic_id: str
+    ) -> tuple[int, StatisticMetaData] | None:
+        """Resolve statistic_id to the metadata_id."""
+        return self.get_many(session, [statistic_id]).get(statistic_id)
+
+    def get_many(
+        self,
+        session: Session,
+        statistic_ids: list[str] | None = None,
+        statistic_type: Literal["mean"] | Literal["sum"] | None = None,
+        statistic_source: str | None = None,
+    ) -> dict[str, tuple[int, StatisticMetaData]]:
+        """Fetch meta data.
+
+        Returns a dict of (metadata_id, StatisticMetaData) tuples indexed by statistic_id.
+
+        If statistic_ids is given, fetch metadata only for the listed statistic_ids.
+        If statistic_type is given, fetch metadata only for statistic_ids supporting it.
+        """
+        if statistic_ids is None:
+            # Fetch metadata from the database
+            return self._get_from_database(
+                session,
+                statistic_type=statistic_type,
+                statistic_source=statistic_source,
+            )
+
+        if statistic_type is not None or statistic_source is not None:
+            # Filtering the listed statistic_ids by type or source was
+            # originally implemented but never used, so the code was
+            # ripped out to reduce the maintenance burden.
+            raise ValueError(
+                "Providing statistic_type and statistic_source is mutually exclusive with statistic_ids"
+            )
+
+        results: dict[str, tuple[int, StatisticMetaData]] = {}
+        missing_statistic_id: list[str] = []
+
+        for statistic_id in statistic_ids:
+            if id_meta := self._stat_id_to_id_meta.get(statistic_id):
+                results[statistic_id] = id_meta
+            else:
+                missing_statistic_id.append(statistic_id)
+
+        if not missing_statistic_id:
+            return results
+
+        # Fetch metadata from the database
+        return results | self._get_from_database(
+            session, statistic_ids=missing_statistic_id
+        )
+
+    def update_or_add(
+        self,
+        session: Session,
+        new_metadata: StatisticMetaData,
+        old_metadata_dict: dict[str, tuple[int, StatisticMetaData]],
+    ) -> tuple[bool, int]:
+        """Get metadata_id for a statistic_id.
+
+        If the statistic_id is previously unknown, add it. If it's already known, update
+        metadata if needed.
+
+        Updating metadata source is not possible.
+
+        Returns a tuple of (updated, metadata_id).
+
+        updated is True if the metadata was updated, False if it was not.
+
+        This call is not thread-safe and must be called from the
+        recorder thread.
+        """
+        statistic_id = new_metadata["statistic_id"]
+        if statistic_id not in old_metadata_dict:
+            return True, self._add_metadata(session, statistic_id, new_metadata)
+        return self._update_metadata(
+            session, statistic_id, new_metadata, old_metadata_dict
+        )
+
+    def update_unit_of_measurement(
+        self, session: Session, statistic_id: str, new_unit: str | None
+    ) -> None:
+        """Update the unit of measurement for a statistic_id.
+
+        This call is not thread-safe and must be called from the
+        recorder thread.
+        """
+        self._assert_in_recorder_thread()
+        session.query(StatisticsMeta).filter(
+            StatisticsMeta.statistic_id == statistic_id
+        ).update({StatisticsMeta.unit_of_measurement: new_unit})
+        self._clear_cache([statistic_id])
+
+    def update_statistic_id(
+        self,
+        session: Session,
+        source: str,
+        old_statistic_id: str,
+        new_statistic_id: str,
+    ) -> None:
+        """Rename a statistic_id.
+
+        This call is not thread-safe and must be called from the
+        recorder thread.
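
`get_many` above resolves what it can from the LRU and only queries the database for the misses, merging the two with the `|` dict-union operator. A self-contained sketch of that read path, with a fake database callable standing in for `_get_from_database` and plain ints for the metadata values:

```python
from __future__ import annotations

from collections.abc import Callable


def get_many(
    cache: dict[str, int],
    fetch_from_db: Callable[[list[str]], dict[str, int]],
    statistic_ids: list[str],
) -> dict[str, int]:
    """Serve cache hits, fetch only the misses, merge with dict-union."""
    results: dict[str, int] = {}
    missing: list[str] = []
    for statistic_id in statistic_ids:
        if (metadata_id := cache.get(statistic_id)) is not None:
            results[statistic_id] = metadata_id
        else:
            missing.append(statistic_id)
    if not missing:
        return results
    # PEP 584 dict union: cached hits first, database results layered on top.
    return results | fetch_from_db(missing)


db_calls: list[list[str]] = []


def fake_db(ids: list[str]) -> dict[str, int]:
    db_calls.append(ids)
    return {statistic_id: 42 for statistic_id in ids}


cache = {"sensor.energy": 1}
out = get_many(cache, fake_db, ["sensor.energy", "sensor.power"])
assert out == {"sensor.energy": 1, "sensor.power": 42}
assert db_calls == [["sensor.power"]]  # only the miss hit the database
```
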
+        """
+        self._assert_in_recorder_thread()
+        session.query(StatisticsMeta).filter(
+            (StatisticsMeta.statistic_id == old_statistic_id)
+            & (StatisticsMeta.source == source)
+        ).update({StatisticsMeta.statistic_id: new_statistic_id})
+        self._clear_cache([old_statistic_id, new_statistic_id])
+
+    def delete(self, session: Session, statistic_ids: list[str]) -> None:
+        """Clear statistics for a list of statistic_ids.
+
+        This call is not thread-safe and must be called from the
+        recorder thread.
+        """
+        self._assert_in_recorder_thread()
+        session.query(StatisticsMeta).filter(
+            StatisticsMeta.statistic_id.in_(statistic_ids)
+        ).delete(synchronize_session=False)
+        self._clear_cache(statistic_ids)
+
+    def reset(self) -> None:
+        """Reset the cache."""
+        self._stat_id_to_id_meta = LRU(CACHE_SIZE)
diff --git a/homeassistant/components/sensor/recorder.py b/homeassistant/components/sensor/recorder.py
index bd4facbea17..8d5af155fd7 100644
--- a/homeassistant/components/sensor/recorder.py
+++ b/homeassistant/components/sensor/recorder.py
@@ -145,31 +145,36 @@ def _parse_float(state: str) -> float:
     return fstate
 
 
+def _float_or_none(state: str) -> float | None:
+    """Return a float or None."""
+    try:
+        return _parse_float(state)
+    except (ValueError, TypeError):
+        return None
+
+
+def _entity_history_to_float_and_state(
+    entity_history: Iterable[State],
+) -> list[tuple[float, State]]:
+    """Return a list of (float, state) tuples for the given entity."""
+    return [
+        (fstate, state)
+        for state in entity_history
+        if (fstate := _float_or_none(state.state)) is not None
+    ]
+
+
 def _normalize_states(
     hass: HomeAssistant,
-    session: Session,
     old_metadatas: dict[str, tuple[int, StatisticMetaData]],
-    entity_history: Iterable[State],
+    fstates: list[tuple[float, State]],
     entity_id: str,
 ) -> tuple[str | None, list[tuple[float, State]]]:
     """Normalize units."""
-    old_metadata = old_metadatas[entity_id][1] if entity_id in old_metadatas else None
     state_unit: str | None = None
-
-    fstates: list[tuple[float, State]] = []
-    for state in entity_history:
-        try:
-            fstate = _parse_float(state.state)
-        except (ValueError, TypeError):  # TypeError to guard for NULL state in DB
-            continue
-        fstates.append((fstate, state))
-
-    if not fstates:
-        return None, fstates
-
-    state_unit = fstates[0][1].attributes.get(ATTR_UNIT_OF_MEASUREMENT)
-    statistics_unit: str | None
+    state_unit = fstates[0][1].attributes.get(ATTR_UNIT_OF_MEASUREMENT)
+    old_metadata = old_metadatas[entity_id][1] if entity_id in old_metadatas else None
     if not old_metadata:
         # We've not seen this sensor before, the first valid state determines the unit
         # used for statistics
@@ -379,7 +384,15 @@ def compile_statistics(
 
     Note: This will query the database and must not be run in the event loop
     """
-    with recorder_util.session_scope(hass=hass) as session:
+    # This is called from within the recorder's statistics compilation,
+    # so there is already an active session. That session must never be
+    # committed here, since a commit would put it out of sync with the
+    # recorder's statistics session, so we mark this session read only.
+    #
+    # If we ever need to write to the database from this function we
+    # will need to refactor the recorder statistics to use a single
+    # session. 
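
The comment above is why the session below is opened with `read_only=True`. As a toy model of the semantics this assumes of `session_scope` — commit on successful exit unless the scope is read only — here is a sketch with a fake session object (the real recorder helper also handles rollback, errors, and exception filters):

```python
from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager


class FakeSession:
    """Toy session that only counts commits and closes."""

    def __init__(self) -> None:
        self.commits = 0
        self.closed = False

    def commit(self) -> None:
        self.commits += 1

    def close(self) -> None:
        self.closed = True


@contextmanager
def session_scope(
    session: FakeSession, read_only: bool = False
) -> Iterator[FakeSession]:
    try:
        yield session
        if not read_only:
            session.commit()  # writers commit on a successful exit
    finally:
        session.close()


session = FakeSession()
with session_scope(session, read_only=True):
    pass  # queries only; nothing to persist
assert session.commits == 0 and session.closed
```
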
+    with recorder_util.session_scope(hass=hass, read_only=True) as session:
         compiled = _compile_statistics(hass, session, start, end)
     return compiled
 
@@ -395,10 +408,6 @@ def _compile_statistics(  # noqa: C901
     sensor_states = _get_sensor_states(hass)
     wanted_statistics = _wanted_statistics(sensor_states)
 
-    old_metadatas = statistics.get_metadata_with_session(
-        session, statistic_ids=[i.entity_id for i in sensor_states]
-    )
-
     # Get history between start and end
     entities_full_history = [
         i.entity_id for i in sensor_states if "sum" in wanted_statistics[i.entity_id]
@@ -427,34 +436,41 @@ def _compile_statistics(  # noqa: C901
             entity_ids=entities_significant_history,
         )
         history_list = {**history_list, **_history_list}
-    # If there are no recent state changes, the sensor's state may already be pruned
-    # from the recorder. Get the state from the state machine instead.
-    for _state in sensor_states:
-        if _state.entity_id not in history_list:
-            history_list[_state.entity_id] = [_state]
-
-    to_process = []
-    to_query = []
+    entities_with_float_states: dict[str, list[tuple[float, State]]] = {}
     for _state in sensor_states:
         entity_id = _state.entity_id
-        if entity_id not in history_list:
+        # If there are no recent state changes, the sensor's state may already be pruned
+        # from the recorder. Get the state from the state machine instead.
+        if not (entity_history := history_list.get(entity_id, [_state])):
            continue
+        if not (float_states := _entity_history_to_float_and_state(entity_history)):
+            continue
+        entities_with_float_states[entity_id] = float_states
 
-        entity_history = history_list[entity_id]
-        statistics_unit, fstates = _normalize_states(
+    # Only look up metadata for entities that have valid float states;
+    # looking up other statistic_ids would only produce cache misses,
+    # since they are not in the metadata table and we are not going to
+    # work with them anyway. 
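
The float filtering that feeds this lookup relies on `_parse_float` (shown in the hunk above) rejecting inf/NaN, with the walrus operator keeping the comprehension to a single pass. A runnable condensation of that step, using bare strings instead of `State` objects:

```python
from __future__ import annotations

import math


def _parse_float(state: str) -> float:
    """Parse a float and reject inf/nan, mirroring the recorder helper."""
    fstate = float(state)
    if math.isnan(fstate) or math.isinf(fstate):
        raise ValueError
    return fstate


def _float_or_none(state: str) -> float | None:
    try:
        return _parse_float(state)
    except (ValueError, TypeError):  # TypeError guards NULL states from the DB
        return None


history = ["1.0", "unavailable", "nan", None, "2.5"]
float_states = [
    (fstate, state)
    for state in history
    if (fstate := _float_or_none(state)) is not None  # walrus: parse once, keep one pass
]
assert float_states == [(1.0, "1.0"), (2.5, "2.5")]
```
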
+ old_metadatas = statistics.get_metadata_with_session( + get_instance(hass), session, statistic_ids=list(entities_with_float_states) + ) + to_process: list[tuple[str, str | None, str, list[tuple[float, State]]]] = [] + to_query: list[str] = [] + for _state in sensor_states: + entity_id = _state.entity_id + if not (maybe_float_states := entities_with_float_states.get(entity_id)): + continue + statistics_unit, valid_float_states = _normalize_states( hass, - session, old_metadatas, - entity_history, + maybe_float_states, entity_id, ) - - if not fstates: + if not valid_float_states: continue - - state_class = _state.attributes[ATTR_STATE_CLASS] - - to_process.append((entity_id, statistics_unit, state_class, fstates)) + state_class: str = _state.attributes[ATTR_STATE_CLASS] + to_process.append((entity_id, statistics_unit, state_class, valid_float_states)) if "sum" in wanted_statistics[entity_id]: to_query.append(entity_id) @@ -465,7 +481,7 @@ def _compile_statistics( # noqa: C901 entity_id, statistics_unit, state_class, - fstates, + valid_float_states, ) in to_process: # Check metadata if old_metadata := old_metadatas.get(entity_id): @@ -507,20 +523,20 @@ def _compile_statistics( # noqa: C901 if "max" in wanted_statistics[entity_id]: stat["max"] = max( *itertools.islice( - zip(*fstates), # type: ignore[typeddict-item] + zip(*valid_float_states), # type: ignore[typeddict-item] 1, ) ) if "min" in wanted_statistics[entity_id]: stat["min"] = min( *itertools.islice( - zip(*fstates), # type: ignore[typeddict-item] + zip(*valid_float_states), # type: ignore[typeddict-item] 1, ) ) if "mean" in wanted_statistics[entity_id]: - stat["mean"] = _time_weighted_average(fstates, start, end) + stat["mean"] = _time_weighted_average(valid_float_states, start, end) if "sum" in wanted_statistics[entity_id]: last_reset = old_last_reset = None @@ -535,7 +551,7 @@ def _compile_statistics( # noqa: C901 new_state = old_state = last_stat["state"] _sum = last_stat["sum"] or 0.0 - for fstate, state in fstates: + for fstate, state in valid_float_states: reset = False if ( state_class != SensorStateClass.TOTAL_INCREASING diff --git a/tests/components/recorder/table_managers/__init__.py b/tests/components/recorder/table_managers/__init__.py new file mode 100644 index 00000000000..52685ec18fa --- /dev/null +++ b/tests/components/recorder/table_managers/__init__.py @@ -0,0 +1 @@ +"""Tests for the recorder table managers.""" diff --git a/tests/components/recorder/table_managers/test_statistics_meta.py b/tests/components/recorder/table_managers/test_statistics_meta.py new file mode 100644 index 00000000000..8ec3f9367d6 --- /dev/null +++ b/tests/components/recorder/table_managers/test_statistics_meta.py @@ -0,0 +1,53 @@ +"""The tests for the Recorder component.""" +from __future__ import annotations + +import pytest + +from homeassistant.components import recorder +from homeassistant.components.recorder.util import session_scope +from homeassistant.core import HomeAssistant + +from tests.typing import RecorderInstanceGenerator + + +async def test_passing_mutually_exclusive_options_to_get_many( + async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant +) -> None: + """Test passing mutually exclusive options to get_many.""" + instance = await async_setup_recorder_instance( + hass, {recorder.CONF_COMMIT_INTERVAL: 0} + ) + with session_scope(session=instance.get_session()) as session: + with pytest.raises(ValueError): + instance.statistics_meta_manager.get_many( + session, + statistic_ids=["light.kitchen"], + 
statistic_type="mean", + ) + with pytest.raises(ValueError): + instance.statistics_meta_manager.get_many( + session, statistic_ids=["light.kitchen"], statistic_source="sensor" + ) + assert ( + instance.statistics_meta_manager.get_many( + session, + statistic_ids=["light.kitchen"], + ) + == {} + ) + + +async def test_unsafe_calls_to_statistics_meta_manager( + async_setup_recorder_instance: RecorderInstanceGenerator, hass: HomeAssistant +) -> None: + """Test we raise when trying to call non-threadsafe functions on statistics_meta_manager.""" + instance = await async_setup_recorder_instance( + hass, {recorder.CONF_COMMIT_INTERVAL: 0} + ) + with session_scope(session=instance.get_session()) as session, pytest.raises( + RuntimeError, match="Detected unsafe call not in recorder thread" + ): + instance.statistics_meta_manager.delete( + session, + statistic_ids=["light.kitchen"], + ) diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index e3b9145dc8b..8c1d8ef00aa 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -564,32 +564,6 @@ def _add_entities(hass, entity_ids): return states -def _add_events(hass, events): - with session_scope(hass=hass) as session: - session.query(Events).delete(synchronize_session=False) - for event_type in events: - hass.bus.fire(event_type) - wait_recording_done(hass) - - with session_scope(hass=hass) as session: - events = [] - for event, event_data, event_types in ( - session.query(Events, EventData, EventTypes) - .outerjoin(EventTypes, (Events.event_type_id == EventTypes.event_type_id)) - .outerjoin(EventData, Events.data_id == EventData.data_id) - ): - event = cast(Events, event) - event_data = cast(EventData, event_data) - event_types = cast(EventTypes, event_types) - - native_event = event.to_native() - if event_data: - native_event.data = event_data.to_native() - native_event.event_type = event_types.event_type - events.append(native_event) - return events - - def _state_with_context(hass, entity_id): # We don't restore context unless we need it by joining the # events table on the event_id for state_changed events @@ -646,25 +620,53 @@ def test_saving_state_incl_entities( assert _state_with_context(hass, "test2.recorder").as_dict() == states[0].as_dict() -def test_saving_event_exclude_event_type( - hass_recorder: Callable[..., HomeAssistant] +async def test_saving_event_exclude_event_type( + async_setup_recorder_instance: RecorderInstanceGenerator, + hass: HomeAssistant, ) -> None: """Test saving and restoring an event.""" - hass = hass_recorder( - { - "exclude": { - "event_types": [ - "service_registered", - "homeassistant_start", - "component_loaded", - "core_config_updated", - "homeassistant_started", - "test", - ] - } + config = { + "exclude": { + "event_types": [ + "service_registered", + "homeassistant_start", + "component_loaded", + "core_config_updated", + "homeassistant_started", + "test", + ] } - ) - events = _add_events(hass, ["test", "test2"]) + } + instance = await async_setup_recorder_instance(hass, config) + events = ["test", "test2"] + for event_type in events: + hass.bus.async_fire(event_type) + + await async_wait_recording_done(hass) + + def _get_events(hass: HomeAssistant, event_types: list[str]) -> list[Event]: + with session_scope(hass=hass) as session: + events = [] + for event, event_data, event_types in ( + session.query(Events, EventData, EventTypes) + .outerjoin( + EventTypes, (Events.event_type_id == EventTypes.event_type_id) + ) + 
.outerjoin(EventData, Events.data_id == EventData.data_id) + .where(EventTypes.event_type.in_(event_types)) + ): + event = cast(Events, event) + event_data = cast(EventData, event_data) + event_types = cast(EventTypes, event_types) + + native_event = event.to_native() + if event_data: + native_event.data = event_data.to_native() + native_event.event_type = event_types.event_type + events.append(native_event) + return events + + events = await instance.async_add_executor_job(_get_events, hass, ["test", "test2"]) assert len(events) == 1 assert events[0].event_type == "test2" diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py index 7c064a03edf..46d2c92e463 100644 --- a/tests/components/recorder/test_statistics.py +++ b/tests/components/recorder/test_statistics.py @@ -22,12 +22,10 @@ from homeassistant.components.recorder.models import ( ) from homeassistant.components.recorder.statistics import ( STATISTIC_UNIT_TO_UNIT_CONVERTER, - _generate_get_metadata_stmt, _generate_max_mean_min_statistic_in_sub_period_stmt, _generate_statistics_at_time_stmt, _generate_statistics_during_period_stmt, _statistics_during_period_with_session, - _update_or_add_metadata, async_add_external_statistics, async_import_statistics, delete_statistics_duplicates, @@ -38,6 +36,10 @@ from homeassistant.components.recorder.statistics import ( get_metadata, list_statistic_ids, ) +from homeassistant.components.recorder.table_managers.statistics_meta import ( + StatisticsMetaManager, + _generate_get_metadata_stmt, +) from homeassistant.components.recorder.util import session_scope from homeassistant.components.sensor import UNIT_CONVERTERS from homeassistant.const import UnitOfTemperature @@ -1520,7 +1522,8 @@ def test_delete_metadata_duplicates_no_duplicates( hass = hass_recorder() wait_recording_done(hass) with session_scope(hass=hass) as session: - delete_statistics_meta_duplicates(session) + instance = recorder.get_instance(hass) + delete_statistics_meta_duplicates(instance, session) assert "duplicated statistics_meta rows" not in caplog.text @@ -1562,9 +1565,9 @@ async def test_validate_db_schema_fix_utf8_issue( with patch( "homeassistant.components.recorder.core.Recorder.dialect_name", "mysql" ), patch( - "homeassistant.components.recorder.statistics._update_or_add_metadata", + "homeassistant.components.recorder.table_managers.statistics_meta.StatisticsMetaManager.update_or_add", + wraps=StatisticsMetaManager.update_or_add, side_effect=[utf8_error, DEFAULT, DEFAULT], - wraps=_update_or_add_metadata, ): await async_setup_recorder_instance(hass) await async_wait_recording_done(hass)
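
The final test change keeps working because `unittest.mock` lets `side_effect` mix exception instances with the `DEFAULT` sentinel: the first call raises the injected error, and later calls fall through to the callable given via `wraps`. A minimal standalone demonstration; the wrapped function is a stand-in, not the recorder method:

```python
from unittest.mock import DEFAULT, MagicMock


def update_or_add(value: int) -> int:
    """Stand-in for the real method being wrapped."""
    return value * 2


boom = RuntimeError("injected failure on the first call")
mocked = MagicMock(side_effect=[boom, DEFAULT, DEFAULT], wraps=update_or_add)

try:
    mocked(1)
except RuntimeError:
    pass  # first call raises the injected error
assert mocked(2) == 4  # DEFAULT falls through to the wrapped callable
assert mocked(3) == 6
```
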
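
More broadly, several hunks in this patch switch raw `session.execute` calls to the recorder's `execute_stmt_lambda_element` helper, always fed by a `lambda_stmt`-built statement such as `_generate_get_metadata_stmt`. Below is a self-contained sketch of the underlying SQLAlchemy pattern, using a toy table (not the recorder schema) and plain `session.execute` standing in for the recorder helper:

```python
from sqlalchemy import String, create_engine, lambda_stmt, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class EventTypes(Base):  # toy stand-in, not the recorder table definition
    __tablename__ = "event_types"
    event_type_id: Mapped[int] = mapped_column(primary_key=True)
    event_type: Mapped[str] = mapped_column(String(64))


def find_event_type_ids(event_types: list[str]):
    # The lambda's source location is the cache key, so the SQL string is
    # compiled once and reused; only the bound values change per call.
    return lambda_stmt(
        lambda: select(EventTypes.event_type_id, EventTypes.event_type).where(
            EventTypes.event_type.in_(event_types)
        )
    )


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(EventTypes(event_type="call_service"))
    session.flush()
    rows = session.execute(find_event_type_ids(["call_service"])).all()
    assert [row.event_type for row in rows] == ["call_service"]
```
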