From 2fe8e953092f1e3a653d8200284d5a990c398d66 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Sat, 26 Nov 2022 19:00:40 +0100 Subject: [PATCH] Add helper to calculate statistic period start and end (#82493) * Add helper to calculate statistic period start and end * Don't parse values in resolve_period * Add specific test for resolve_period * Improve typing * Move to recorder/util.py * Extract period schema --- homeassistant/components/recorder/models.py | 33 ++++++- homeassistant/components/recorder/util.py | 84 ++++++++++++++++- .../components/recorder/websocket_api.py | 94 +++---------------- tests/components/recorder/test_util.py | 81 +++++++++++++++- 4 files changed, 206 insertions(+), 86 deletions(-) diff --git a/homeassistant/components/recorder/models.py b/homeassistant/components/recorder/models.py index 3ab8b890838..48b45b4da2e 100644 --- a/homeassistant/components/recorder/models.py +++ b/homeassistant/components/recorder/models.py @@ -1,9 +1,9 @@ """Models for Recorder.""" from __future__ import annotations -from datetime import datetime +from datetime import datetime, timedelta import logging -from typing import Any, TypedDict, overload +from typing import Any, Literal, TypedDict, overload from sqlalchemy.engine.row import Row @@ -284,3 +284,32 @@ def row_to_compressed_state( row_changed_changed ) return comp_state + + +class CalendarStatisticPeriod(TypedDict, total=False): + """Statistic period definition.""" + + period: Literal["hour", "day", "week", "month", "year"] + offset: int + + +class FixedStatisticPeriod(TypedDict, total=False): + """Statistic period definition.""" + + end_time: datetime + start_time: datetime + + +class RollingWindowStatisticPeriod(TypedDict, total=False): + """Statistic period definition.""" + + duration: timedelta + offset: timedelta + + +class StatisticPeriod(TypedDict, total=False): + """Statistic period definition.""" + + calendar: CalendarStatisticPeriod + fixed_period: FixedStatisticPeriod + rolling_window: RollingWindowStatisticPeriod diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py index 8ee9a4e0401..a52a067b975 100644 --- a/homeassistant/components/recorder/util.py +++ b/homeassistant/components/recorder/util.py @@ -24,8 +24,10 @@ from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session from sqlalchemy.sql.lambdas import StatementLambdaElement from typing_extensions import Concatenate, ParamSpec +import voluptuous as vol from homeassistant.core import HomeAssistant +from homeassistant.helpers import config_validation as cv import homeassistant.util.dt as dt_util from .const import DATA_INSTANCE, SQLITE_URL_PREFIX, SupportedDialect @@ -35,7 +37,7 @@ from .db_schema import ( TABLES_TO_CHECK, RecorderRuns, ) -from .models import UnsupportedDialect, process_timestamp +from .models import StatisticPeriod, UnsupportedDialect, process_timestamp if TYPE_CHECKING: from . import Recorder @@ -604,3 +606,83 @@ def get_instance(hass: HomeAssistant) -> Recorder: """Get the recorder instance.""" instance: Recorder = hass.data[DATA_INSTANCE] return instance + + +PERIOD_SCHEMA = vol.Schema( + { + vol.Exclusive("calendar", "period"): vol.Schema( + { + vol.Required("period"): vol.Any("hour", "day", "week", "month", "year"), + vol.Optional("offset"): int, + } + ), + vol.Exclusive("fixed_period", "period"): vol.Schema( + { + vol.Optional("start_time"): vol.All(cv.datetime, dt_util.as_utc), + vol.Optional("end_time"): vol.All(cv.datetime, dt_util.as_utc), + } + ), + vol.Exclusive("rolling_window", "period"): vol.Schema( + { + vol.Required("duration"): cv.time_period_dict, + vol.Optional("offset"): cv.time_period_dict, + } + ), + } +) + + +def resolve_period( + period_def: StatisticPeriod, +) -> tuple[datetime | None, datetime | None]: + """Return start and end datetimes for a statistic period definition.""" + start_time = None + end_time = None + + if "calendar" in period_def: + calendar_period = period_def["calendar"]["period"] + start_of_day = dt_util.start_of_local_day() + cal_offset = period_def["calendar"].get("offset", 0) + if calendar_period == "hour": + start_time = dt_util.now().replace(minute=0, second=0, microsecond=0) + start_time += timedelta(hours=cal_offset) + end_time = start_time + timedelta(hours=1) + elif calendar_period == "day": + start_time = start_of_day + start_time += timedelta(days=cal_offset) + end_time = start_time + timedelta(days=1) + elif calendar_period == "week": + start_time = start_of_day - timedelta(days=start_of_day.weekday()) + start_time += timedelta(days=cal_offset * 7) + end_time = start_time + timedelta(weeks=1) + elif calendar_period == "month": + start_time = start_of_day.replace(day=28) + # This works for up to 48 months of offset + start_time = (start_time + timedelta(days=cal_offset * 31)).replace(day=1) + end_time = (start_time + timedelta(days=31)).replace(day=1) + else: # calendar_period = "year" + start_time = start_of_day.replace(month=12, day=31) + # This works for 100+ years of offset + start_time = (start_time + timedelta(days=cal_offset * 366)).replace( + month=1, day=1 + ) + end_time = (start_time + timedelta(days=365)).replace(day=1) + + start_time = dt_util.as_utc(start_time) + end_time = dt_util.as_utc(end_time) + + elif "fixed_period" in period_def: + start_time = period_def["fixed_period"].get("start_time") + end_time = period_def["fixed_period"].get("end_time") + + elif "rolling_window" in period_def: + duration = period_def["rolling_window"]["duration"] + now = dt_util.utcnow() + start_time = now - duration + end_time = start_time + duration + + if offset := period_def["rolling_window"].get("offset"): + start_time += offset + end_time += offset + + return (start_time, end_time) diff --git a/homeassistant/components/recorder/websocket_api.py b/homeassistant/components/recorder/websocket_api.py index a9fe1589654..35879bfc076 100644 --- a/homeassistant/components/recorder/websocket_api.py +++ b/homeassistant/components/recorder/websocket_api.py @@ -1,9 +1,9 @@ """The Recorder websocket API.""" from __future__ import annotations -from datetime import datetime as dt, timedelta +from datetime import datetime as dt import logging -from typing import Any, Literal +from typing import Any, Literal, cast import voluptuous as vol @@ -26,6 +26,7 @@ from homeassistant.util.unit_conversion import ( ) from .const import MAX_QUEUE_BACKLOG +from .models import StatisticPeriod from .statistics import ( STATISTIC_UNIT_TO_UNIT_CONVERTER, async_add_external_statistics, @@ -36,7 +37,13 @@ from .statistics import ( statistics_during_period, validate_statistics, ) -from .util import async_migration_in_progress, async_migration_is_live, get_instance +from .util import ( + PERIOD_SCHEMA, + async_migration_in_progress, + async_migration_is_live, + get_instance, + resolve_period, +) _LOGGER: logging.Logger = logging.getLogger(__package__) @@ -82,24 +89,6 @@ def _ws_get_statistic_during_period( @websocket_api.websocket_command( { vol.Required("type"): "recorder/statistic_during_period", - vol.Exclusive("calendar", "period"): vol.Schema( - { - vol.Required("period"): vol.Any("hour", "day", "week", "month", "year"), - vol.Optional("offset"): int, - } - ), - vol.Exclusive("fixed_period", "period"): vol.Schema( - { - vol.Optional("start_time"): str, - vol.Optional("end_time"): str, - } - ), - vol.Exclusive("rolling_window", "period"): vol.Schema( - { - vol.Required("duration"): cv.time_period_dict, - vol.Optional("offset"): cv.time_period_dict, - } - ), vol.Optional("statistic_id"): str, vol.Optional("types"): vol.All( [vol.Any("max", "mean", "min", "change")], vol.Coerce(set) @@ -116,6 +105,7 @@ def _ws_get_statistic_during_period( vol.Optional("volume"): vol.In(VolumeConverter.VALID_UNITS), } ), + **PERIOD_SCHEMA.schema, } ) @websocket_api.async_response @@ -128,67 +118,7 @@ async def ws_get_statistic_during_period( if "offset" in msg and "duration" not in msg: raise HomeAssistantError - start_time = None - end_time = None - - if "calendar" in msg: - calendar_period = msg["calendar"]["period"] - start_of_day = dt_util.start_of_local_day() - offset = msg["calendar"].get("offset", 0) - if calendar_period == "hour": - start_time = dt_util.now().replace(minute=0, second=0, microsecond=0) - start_time += timedelta(hours=offset) - end_time = start_time + timedelta(hours=1) - elif calendar_period == "day": - start_time = start_of_day - start_time += timedelta(days=offset) - end_time = start_time + timedelta(days=1) - elif calendar_period == "week": - start_time = start_of_day - timedelta(days=start_of_day.weekday()) - start_time += timedelta(days=offset * 7) - end_time = start_time + timedelta(weeks=1) - elif calendar_period == "month": - start_time = start_of_day.replace(day=28) - # This works for up to 48 months of offset - start_time = (start_time + timedelta(days=offset * 31)).replace(day=1) - end_time = (start_time + timedelta(days=31)).replace(day=1) - else: # calendar_period = "year" - start_time = start_of_day.replace(month=12, day=31) - # This works for 100+ years of offset - start_time = (start_time + timedelta(days=offset * 366)).replace( - month=1, day=1 - ) - end_time = (start_time + timedelta(days=365)).replace(day=1) - - start_time = dt_util.as_utc(start_time) - end_time = dt_util.as_utc(end_time) - - elif "fixed_period" in msg: - if start_time_str := msg["fixed_period"].get("start_time"): - if start_time := dt_util.parse_datetime(start_time_str): - start_time = dt_util.as_utc(start_time) - else: - connection.send_error( - msg["id"], "invalid_start_time", "Invalid start_time" - ) - return - - if end_time_str := msg["fixed_period"].get("end_time"): - if end_time := dt_util.parse_datetime(end_time_str): - end_time = dt_util.as_utc(end_time) - else: - connection.send_error(msg["id"], "invalid_end_time", "Invalid end_time") - return - - elif "rolling_window" in msg: - duration = msg["rolling_window"]["duration"] - now = dt_util.utcnow() - start_time = now - duration - end_time = start_time + duration - - if offset := msg["rolling_window"].get("offset"): - start_time += offset - end_time += offset + start_time, end_time = resolve_period(cast(StatisticPeriod, msg)) connection.send_message( await get_instance(hass).async_add_executor_job( diff --git a/tests/components/recorder/test_util.py b/tests/components/recorder/test_util.py index 9000379c17d..54f9bb9b9f2 100644 --- a/tests/components/recorder/test_util.py +++ b/tests/components/recorder/test_util.py @@ -1,9 +1,10 @@ """Test util methods.""" -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import os import sqlite3 from unittest.mock import MagicMock, Mock, patch +from freezegun import freeze_time import pytest from sqlalchemy import text from sqlalchemy.engine.result import ChunkedIteratorResult @@ -19,6 +20,7 @@ from homeassistant.components.recorder.models import UnsupportedDialect from homeassistant.components.recorder.util import ( end_incomplete_runs, is_second_sunday, + resolve_period, session_scope, ) from homeassistant.const import EVENT_HOMEASSISTANT_STOP @@ -776,3 +778,80 @@ def test_execute_stmt_lambda_element(hass_recorder): with patch.object(session, "execute", MockExecutor): rows = util.execute_stmt_lambda_element(session, stmt, now, tomorrow) assert rows == ["mock_row"] + + +@freeze_time(datetime(2022, 10, 21, 7, 25, tzinfo=timezone.utc)) +async def test_resolve_period(hass): + """Test statistic_during_period.""" + + now = dt_util.utcnow() + + start_t, end_t = resolve_period({"calendar": {"period": "hour"}}) + assert start_t.isoformat() == "2022-10-21T07:00:00+00:00" + assert end_t.isoformat() == "2022-10-21T08:00:00+00:00" + + start_t, end_t = resolve_period({"calendar": {"period": "hour"}}) + assert start_t.isoformat() == "2022-10-21T07:00:00+00:00" + assert end_t.isoformat() == "2022-10-21T08:00:00+00:00" + + start_t, end_t = resolve_period({"calendar": {"period": "hour", "offset": -1}}) + assert start_t.isoformat() == "2022-10-21T06:00:00+00:00" + assert end_t.isoformat() == "2022-10-21T07:00:00+00:00" + + start_t, end_t = resolve_period({"calendar": {"period": "day"}}) + assert start_t.isoformat() == "2022-10-21T07:00:00+00:00" + assert end_t.isoformat() == "2022-10-22T07:00:00+00:00" + + start_t, end_t = resolve_period({"calendar": {"period": "day", "offset": -1}}) + assert start_t.isoformat() == "2022-10-20T07:00:00+00:00" + assert end_t.isoformat() == "2022-10-21T07:00:00+00:00" + + start_t, end_t = resolve_period({"calendar": {"period": "week"}}) + assert start_t.isoformat() == "2022-10-17T07:00:00+00:00" + assert end_t.isoformat() == "2022-10-24T07:00:00+00:00" + + start_t, end_t = resolve_period({"calendar": {"period": "week", "offset": -1}}) + assert start_t.isoformat() == "2022-10-10T07:00:00+00:00" + assert end_t.isoformat() == "2022-10-17T07:00:00+00:00" + + start_t, end_t = resolve_period({"calendar": {"period": "month"}}) + assert start_t.isoformat() == "2022-10-01T07:00:00+00:00" + assert end_t.isoformat() == "2022-11-01T07:00:00+00:00" + + start_t, end_t = resolve_period({"calendar": {"period": "month", "offset": -1}}) + assert start_t.isoformat() == "2022-09-01T07:00:00+00:00" + assert end_t.isoformat() == "2022-10-01T07:00:00+00:00" + + start_t, end_t = resolve_period({"calendar": {"period": "year"}}) + assert start_t.isoformat() == "2022-01-01T08:00:00+00:00" + assert end_t.isoformat() == "2023-01-01T08:00:00+00:00" + + start_t, end_t = resolve_period({"calendar": {"period": "year", "offset": -1}}) + assert start_t.isoformat() == "2021-01-01T08:00:00+00:00" + assert end_t.isoformat() == "2022-01-01T08:00:00+00:00" + + # Fixed period + assert resolve_period({}) == (None, None) + + assert resolve_period({"fixed_period": {"end_time": now}}) == (None, now) + + assert resolve_period({"fixed_period": {"start_time": now}}) == (now, None) + + assert resolve_period({"fixed_period": {"end_time": now, "start_time": now}}) == ( + now, + now, + ) + + # Rolling window + assert resolve_period( + {"rolling_window": {"duration": timedelta(hours=1, minutes=25)}} + ) == (now - timedelta(hours=1, minutes=25), now) + + assert resolve_period( + { + "rolling_window": { + "duration": timedelta(hours=1), + "offset": timedelta(minutes=-25), + } + } + ) == (now - timedelta(hours=1, minutes=25), now - timedelta(minutes=25))