Minor improvements of recorder typing (#80165)

* Minor improvements of recorder typing

* Only allow specifying statistic_ids as lists
This commit is contained in:
Erik Montnemery 2022-10-12 14:59:10 +02:00 committed by GitHub
parent 83557ef762
commit 577f7904b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 28 additions and 30 deletions

View File

@ -610,8 +610,12 @@ class Recorder(threading.Thread):
# wait for startup to complete. If its not live, we need to continue # wait for startup to complete. If its not live, we need to continue
# on. # on.
self.hass.add_job(self.async_set_db_ready) self.hass.add_job(self.async_set_db_ready)
# If shutdown happened before Home Assistant finished starting
# We wait to start a live migration until startup has finished
# since it can be cpu intensive and we do not want it to compete
# with startup which is also cpu intensive
if self._wait_startup_or_shutdown() is SHUTDOWN_TASK: if self._wait_startup_or_shutdown() is SHUTDOWN_TASK:
# Shutdown happened before Home Assistant finished starting
self.migration_in_progress = False self.migration_in_progress = False
# Make sure we cleanly close the run if # Make sure we cleanly close the run if
# we restart before startup finishes # we restart before startup finishes
@ -619,9 +623,6 @@ class Recorder(threading.Thread):
self.hass.add_job(self.async_set_db_ready) self.hass.add_job(self.async_set_db_ready)
return return
# We wait to start the migration until startup has finished
# since it can be cpu intensive and we do not want it to compete
# with startup which is also cpu intensive
if not schema_is_current: if not schema_is_current:
if self._migrate_schema_and_setup_run(current_version): if self._migrate_schema_and_setup_run(current_version):
self.schema_version = SCHEMA_VERSION self.schema_version = SCHEMA_VERSION

View File

@ -1,9 +1,11 @@
"""Schema migration helpers.""" """Schema migration helpers."""
from __future__ import annotations
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
import contextlib import contextlib
from datetime import timedelta from datetime import timedelta
import logging import logging
from typing import Any, cast from typing import TYPE_CHECKING, cast
import sqlalchemy import sqlalchemy
from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text
@ -40,6 +42,9 @@ from .statistics import (
) )
from .util import session_scope from .util import session_scope
if TYPE_CHECKING:
from . import Recorder
LIVE_MIGRATION_MIN_SCHEMA_VERSION = 0 LIVE_MIGRATION_MIN_SCHEMA_VERSION = 0
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -86,7 +91,7 @@ def live_migration(current_version: int) -> bool:
def migrate_schema( def migrate_schema(
instance: Any, instance: Recorder,
hass: HomeAssistant, hass: HomeAssistant,
engine: Engine, engine: Engine,
session_maker: Callable[[], Session], session_maker: Callable[[], Session],

View File

@ -582,9 +582,7 @@ def _compile_hourly_statistics_summary_mean_stmt(
) )
def compile_hourly_statistics( def _compile_hourly_statistics(session: Session, start: datetime) -> None:
instance: Recorder, session: Session, start: datetime
) -> None:
"""Compile hourly statistics. """Compile hourly statistics.
This will summarize 5-minute statistics for one hour: This will summarize 5-minute statistics for one hour:
@ -700,7 +698,7 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool:
if start.minute == 55: if start.minute == 55:
# A full hour is ready, summarize it # A full hour is ready, summarize it
compile_hourly_statistics(instance, session, start) _compile_hourly_statistics(session, start)
session.add(StatisticsRuns(start=start)) session.add(StatisticsRuns(start=start))
@ -776,7 +774,7 @@ def _update_statistics(
def _generate_get_metadata_stmt( def _generate_get_metadata_stmt(
statistic_ids: list[str] | tuple[str] | None = None, statistic_ids: list[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None, statistic_source: str | None = None,
) -> StatementLambdaElement: ) -> StatementLambdaElement:
@ -794,10 +792,9 @@ def _generate_get_metadata_stmt(
def get_metadata_with_session( def get_metadata_with_session(
hass: HomeAssistant,
session: Session, session: Session,
*, *,
statistic_ids: list[str] | tuple[str] | None = None, statistic_ids: list[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None, statistic_source: str | None = None,
) -> dict[str, tuple[int, StatisticMetaData]]: ) -> dict[str, tuple[int, StatisticMetaData]]:
@ -834,14 +831,13 @@ def get_metadata_with_session(
def get_metadata( def get_metadata(
hass: HomeAssistant, hass: HomeAssistant,
*, *,
statistic_ids: list[str] | tuple[str] | None = None, statistic_ids: list[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None, statistic_source: str | None = None,
) -> dict[str, tuple[int, StatisticMetaData]]: ) -> dict[str, tuple[int, StatisticMetaData]]:
"""Return metadata for statistic_ids.""" """Return metadata for statistic_ids."""
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
return get_metadata_with_session( return get_metadata_with_session(
hass,
session, session,
statistic_ids=statistic_ids, statistic_ids=statistic_ids,
statistic_type=statistic_type, statistic_type=statistic_type,
@ -882,7 +878,7 @@ def update_statistics_metadata(
def list_statistic_ids( def list_statistic_ids(
hass: HomeAssistant, hass: HomeAssistant,
statistic_ids: list[str] | tuple[str] | None = None, statistic_ids: list[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None,
) -> list[dict]: ) -> list[dict]:
"""Return all statistic_ids (or filtered one) and unit of measurement. """Return all statistic_ids (or filtered one) and unit of measurement.
@ -896,7 +892,7 @@ def list_statistic_ids(
# Query the database # Query the database
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
metadata = get_metadata_with_session( metadata = get_metadata_with_session(
hass, session, statistic_type=statistic_type, statistic_ids=statistic_ids session, statistic_type=statistic_type, statistic_ids=statistic_ids
) )
result = { result = {
@ -1105,7 +1101,7 @@ def statistics_during_period(
metadata = None metadata = None
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
# Fetch metadata for the given (or all) statistic_ids # Fetch metadata for the given (or all) statistic_ids
metadata = get_metadata_with_session(hass, session, statistic_ids=statistic_ids) metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
if not metadata: if not metadata:
return {} return {}
@ -1196,7 +1192,7 @@ def _get_last_statistics(
statistic_ids = [statistic_id] statistic_ids = [statistic_id]
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
# Fetch metadata for the given statistic_id # Fetch metadata for the given statistic_id
metadata = get_metadata_with_session(hass, session, statistic_ids=statistic_ids) metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
if not metadata: if not metadata:
return {} return {}
metadata_id = metadata[statistic_id][0] metadata_id = metadata[statistic_id][0]
@ -1280,9 +1276,7 @@ def get_latest_short_term_statistics(
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
# Fetch metadata for the given statistic_ids # Fetch metadata for the given statistic_ids
if not metadata: if not metadata:
metadata = get_metadata_with_session( metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
hass, session, statistic_ids=statistic_ids
)
if not metadata: if not metadata:
return {} return {}
metadata_ids = [ metadata_ids = [
@ -1565,7 +1559,7 @@ def import_statistics(
exception_filter=_filter_unique_constraint_integrity_error(instance), exception_filter=_filter_unique_constraint_integrity_error(instance),
) as session: ) as session:
old_metadata_dict = get_metadata_with_session( old_metadata_dict = get_metadata_with_session(
instance.hass, session, statistic_ids=[metadata["statistic_id"]] session, statistic_ids=[metadata["statistic_id"]]
) )
metadata_id = _update_or_add_metadata(session, metadata, old_metadata_dict) metadata_id = _update_or_add_metadata(session, metadata, old_metadata_dict)
for stat in statistics: for stat in statistics:
@ -1590,9 +1584,7 @@ def adjust_statistics(
"""Process an add_statistics job.""" """Process an add_statistics job."""
with session_scope(session=instance.get_session()) as session: with session_scope(session=instance.get_session()) as session:
metadata = get_metadata_with_session( metadata = get_metadata_with_session(session, statistic_ids=[statistic_id])
instance.hass, session, statistic_ids=(statistic_id,)
)
if statistic_id not in metadata: if statistic_id not in metadata:
return True return True
@ -1652,9 +1644,9 @@ def change_statistics_unit(
) -> None: ) -> None:
"""Change statistics unit for a statistic_id.""" """Change statistics unit for a statistic_id."""
with session_scope(session=instance.get_session()) as session: with session_scope(session=instance.get_session()) as session:
metadata = get_metadata_with_session( metadata = get_metadata_with_session(session, statistic_ids=[statistic_id]).get(
instance.hass, session, statistic_ids=(statistic_id,) statistic_id
).get(statistic_id) )
# Guard against the statistics being removed or updated before the # Guard against the statistics being removed or updated before the
# change_statistics_unit job executes # change_statistics_unit job executes

View File

@ -387,7 +387,7 @@ def _compile_statistics( # noqa: C901
sensor_states = _get_sensor_states(hass) sensor_states = _get_sensor_states(hass)
wanted_statistics = _wanted_statistics(sensor_states) wanted_statistics = _wanted_statistics(sensor_states)
old_metadatas = statistics.get_metadata_with_session( old_metadatas = statistics.get_metadata_with_session(
hass, session, statistic_ids=[i.entity_id for i in sensor_states] session, statistic_ids=[i.entity_id for i in sensor_states]
) )
# Get history between start and end # Get history between start and end