Minor improvements of recorder typing (#80165)

* Minor improvements of recorder typing

* Only allow specifying statistic_ids as lists
This commit is contained in:
Erik Montnemery 2022-10-12 14:59:10 +02:00 committed by GitHub
parent 83557ef762
commit 577f7904b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 28 additions and 30 deletions

View File

@@ -610,8 +610,12 @@ class Recorder(threading.Thread):
# wait for startup to complete. If its not live, we need to continue
# on.
self.hass.add_job(self.async_set_db_ready)
# If shutdown happened before Home Assistant finished starting
# We wait to start a live migration until startup has finished
# since it can be cpu intensive and we do not want it to compete
# with startup which is also cpu intensive
if self._wait_startup_or_shutdown() is SHUTDOWN_TASK:
# Shutdown happened before Home Assistant finished starting
self.migration_in_progress = False
# Make sure we cleanly close the run if
# we restart before startup finishes
@@ -619,9 +623,6 @@ class Recorder(threading.Thread):
self.hass.add_job(self.async_set_db_ready)
return
# We wait to start the migration until startup has finished
# since it can be cpu intensive and we do not want it to compete
# with startup which is also cpu intensive
if not schema_is_current:
if self._migrate_schema_and_setup_run(current_version):
self.schema_version = SCHEMA_VERSION

View File

@@ -1,9 +1,11 @@
"""Schema migration helpers."""
from __future__ import annotations
from collections.abc import Callable, Iterable
import contextlib
from datetime import timedelta
import logging
from typing import Any, cast
from typing import TYPE_CHECKING, cast
import sqlalchemy
from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text
@@ -40,6 +42,9 @@ from .statistics import (
)
from .util import session_scope
if TYPE_CHECKING:
from . import Recorder
LIVE_MIGRATION_MIN_SCHEMA_VERSION = 0
_LOGGER = logging.getLogger(__name__)
@@ -86,7 +91,7 @@ def live_migration(current_version: int) -> bool:
def migrate_schema(
instance: Any,
instance: Recorder,
hass: HomeAssistant,
engine: Engine,
session_maker: Callable[[], Session],

View File

@@ -582,9 +582,7 @@ def _compile_hourly_statistics_summary_mean_stmt(
)
def compile_hourly_statistics(
instance: Recorder, session: Session, start: datetime
) -> None:
def _compile_hourly_statistics(session: Session, start: datetime) -> None:
"""Compile hourly statistics.
This will summarize 5-minute statistics for one hour:
@@ -700,7 +698,7 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool:
if start.minute == 55:
# A full hour is ready, summarize it
compile_hourly_statistics(instance, session, start)
_compile_hourly_statistics(session, start)
session.add(StatisticsRuns(start=start))
@@ -776,7 +774,7 @@ def _update_statistics(
def _generate_get_metadata_stmt(
statistic_ids: list[str] | tuple[str] | None = None,
statistic_ids: list[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None,
) -> StatementLambdaElement:
@@ -794,10 +792,9 @@ def _generate_get_metadata_stmt(
def get_metadata_with_session(
hass: HomeAssistant,
session: Session,
*,
statistic_ids: list[str] | tuple[str] | None = None,
statistic_ids: list[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None,
) -> dict[str, tuple[int, StatisticMetaData]]:
@@ -834,14 +831,13 @@ def get_metadata_with_session(
def get_metadata(
hass: HomeAssistant,
*,
statistic_ids: list[str] | tuple[str] | None = None,
statistic_ids: list[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None,
) -> dict[str, tuple[int, StatisticMetaData]]:
"""Return metadata for statistic_ids."""
with session_scope(hass=hass) as session:
return get_metadata_with_session(
hass,
session,
statistic_ids=statistic_ids,
statistic_type=statistic_type,
@@ -882,7 +878,7 @@ def update_statistics_metadata(
def list_statistic_ids(
hass: HomeAssistant,
statistic_ids: list[str] | tuple[str] | None = None,
statistic_ids: list[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
) -> list[dict]:
"""Return all statistic_ids (or filtered one) and unit of measurement.
@@ -896,7 +892,7 @@ def list_statistic_ids(
# Query the database
with session_scope(hass=hass) as session:
metadata = get_metadata_with_session(
hass, session, statistic_type=statistic_type, statistic_ids=statistic_ids
session, statistic_type=statistic_type, statistic_ids=statistic_ids
)
result = {
@@ -1105,7 +1101,7 @@ def statistics_during_period(
metadata = None
with session_scope(hass=hass) as session:
# Fetch metadata for the given (or all) statistic_ids
metadata = get_metadata_with_session(hass, session, statistic_ids=statistic_ids)
metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
if not metadata:
return {}
@@ -1196,7 +1192,7 @@ def _get_last_statistics(
statistic_ids = [statistic_id]
with session_scope(hass=hass) as session:
# Fetch metadata for the given statistic_id
metadata = get_metadata_with_session(hass, session, statistic_ids=statistic_ids)
metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
if not metadata:
return {}
metadata_id = metadata[statistic_id][0]
@@ -1280,9 +1276,7 @@ def get_latest_short_term_statistics(
with session_scope(hass=hass) as session:
# Fetch metadata for the given statistic_ids
if not metadata:
metadata = get_metadata_with_session(
hass, session, statistic_ids=statistic_ids
)
metadata = get_metadata_with_session(session, statistic_ids=statistic_ids)
if not metadata:
return {}
metadata_ids = [
@@ -1565,7 +1559,7 @@ def import_statistics(
exception_filter=_filter_unique_constraint_integrity_error(instance),
) as session:
old_metadata_dict = get_metadata_with_session(
instance.hass, session, statistic_ids=[metadata["statistic_id"]]
session, statistic_ids=[metadata["statistic_id"]]
)
metadata_id = _update_or_add_metadata(session, metadata, old_metadata_dict)
for stat in statistics:
@@ -1590,9 +1584,7 @@ def adjust_statistics(
"""Process an add_statistics job."""
with session_scope(session=instance.get_session()) as session:
metadata = get_metadata_with_session(
instance.hass, session, statistic_ids=(statistic_id,)
)
metadata = get_metadata_with_session(session, statistic_ids=[statistic_id])
if statistic_id not in metadata:
return True
@@ -1652,9 +1644,9 @@ def change_statistics_unit(
) -> None:
"""Change statistics unit for a statistic_id."""
with session_scope(session=instance.get_session()) as session:
metadata = get_metadata_with_session(
instance.hass, session, statistic_ids=(statistic_id,)
).get(statistic_id)
metadata = get_metadata_with_session(session, statistic_ids=[statistic_id]).get(
statistic_id
)
# Guard against the statistics being removed or updated before the
# change_statistics_unit job executes

View File

@@ -387,7 +387,7 @@ def _compile_statistics( # noqa: C901
sensor_states = _get_sensor_states(hass)
wanted_statistics = _wanted_statistics(sensor_states)
old_metadatas = statistics.get_metadata_with_session(
hass, session, statistic_ids=[i.entity_id for i in sensor_states]
session, statistic_ids=[i.entity_id for i in sensor_states]
)
# Get history between start and end