diff --git a/homeassistant/components/demo/__init__.py b/homeassistant/components/demo/__init__.py index 7ed989903e5..4d0ef03c564 100644 --- a/homeassistant/components/demo/__init__.py +++ b/homeassistant/components/demo/__init__.py @@ -295,6 +295,7 @@ async def _insert_statistics(hass: HomeAssistant) -> None: metadata: StatisticMetaData = { "source": DOMAIN, "name": "Outdoor temperature", + "state_unit_of_measurement": TEMP_CELSIUS, "statistic_id": f"{DOMAIN}:temperature_outdoor", "unit_of_measurement": TEMP_CELSIUS, "has_mean": True, @@ -308,6 +309,7 @@ async def _insert_statistics(hass: HomeAssistant) -> None: metadata = { "source": DOMAIN, "name": "Energy consumption 1", + "state_unit_of_measurement": ENERGY_KILO_WATT_HOUR, "statistic_id": f"{DOMAIN}:energy_consumption_kwh", "unit_of_measurement": ENERGY_KILO_WATT_HOUR, "has_mean": False, @@ -320,6 +322,7 @@ async def _insert_statistics(hass: HomeAssistant) -> None: metadata = { "source": DOMAIN, "name": "Energy consumption 2", + "state_unit_of_measurement": ENERGY_MEGA_WATT_HOUR, "statistic_id": f"{DOMAIN}:energy_consumption_mwh", "unit_of_measurement": ENERGY_MEGA_WATT_HOUR, "has_mean": False, @@ -334,6 +337,7 @@ async def _insert_statistics(hass: HomeAssistant) -> None: metadata = { "source": DOMAIN, "name": "Gas consumption 1", + "state_unit_of_measurement": VOLUME_CUBIC_METERS, "statistic_id": f"{DOMAIN}:gas_consumption_m3", "unit_of_measurement": VOLUME_CUBIC_METERS, "has_mean": False, @@ -348,6 +352,7 @@ async def _insert_statistics(hass: HomeAssistant) -> None: metadata = { "source": DOMAIN, "name": "Gas consumption 2", + "state_unit_of_measurement": VOLUME_CUBIC_FEET, "statistic_id": f"{DOMAIN}:gas_consumption_ft3", "unit_of_measurement": VOLUME_CUBIC_FEET, "has_mean": False, diff --git a/homeassistant/components/recorder/db_schema.py b/homeassistant/components/recorder/db_schema.py index 40c0453ea0b..363604d525b 100644 --- a/homeassistant/components/recorder/db_schema.py +++ b/homeassistant/components/recorder/db_schema.py @@ -53,7 +53,7 @@ from .models import StatisticData, StatisticMetaData, process_timestamp # pylint: disable=invalid-name Base = declarative_base() -SCHEMA_VERSION = 29 +SCHEMA_VERSION = 30 _StatisticsBaseSelfT = TypeVar("_StatisticsBaseSelfT", bound="StatisticsBase") @@ -494,6 +494,7 @@ class StatisticsMeta(Base): # type: ignore[misc,valid-type] id = Column(Integer, Identity(), primary_key=True) statistic_id = Column(String(255), index=True, unique=True) source = Column(String(32)) + state_unit_of_measurement = Column(String(255)) unit_of_measurement = Column(String(255)) has_mean = Column(Boolean) has_sum = Column(Boolean) diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index ab9b93de5e5..e2169727382 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -747,6 +747,25 @@ def _apply_update( # noqa: C901 _create_index( session_maker, "statistics_meta", "ix_statistics_meta_statistic_id" ) + elif new_version == 30: + _add_columns( + session_maker, + "statistics_meta", + ["state_unit_of_measurement VARCHAR(255)"], + ) + # When querying the database, be careful to only explicitly query for columns + # which were present in schema version 30. If querying the table, SQLAlchemy + # will refer to future columns. + with session_scope(session=session_maker()) as session: + for statistics_meta in session.query( + StatisticsMeta.id, StatisticsMeta.unit_of_measurement + ): + session.query(StatisticsMeta).filter_by(id=statistics_meta.id).update( + { + StatisticsMeta.state_unit_of_measurement: statistics_meta.unit_of_measurement, + }, + synchronize_session=False, + ) else: raise ValueError(f"No schema migration defined for version {new_version}") diff --git a/homeassistant/components/recorder/models.py b/homeassistant/components/recorder/models.py index 2004c3ec30d..78ebaabc0fd 100644 --- a/homeassistant/components/recorder/models.py +++ b/homeassistant/components/recorder/models.py @@ -64,6 +64,7 @@ class StatisticMetaData(TypedDict): has_sum: bool name: str | None source: str + state_unit_of_measurement: str | None statistic_id: str unit_of_measurement: str | None diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index a1ab58ee011..8585ca37fac 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -12,7 +12,7 @@ import logging import os import re from statistics import mean -from typing import TYPE_CHECKING, Any, Literal, overload +from typing import TYPE_CHECKING, Any, Literal from sqlalchemy import bindparam, func, lambda_stmt, select from sqlalchemy.engine.row import Row @@ -24,6 +24,9 @@ from sqlalchemy.sql.selectable import Subquery import voluptuous as vol from homeassistant.const import ( + ENERGY_KILO_WATT_HOUR, + POWER_KILO_WATT, + POWER_WATT, PRESSURE_PA, TEMP_CELSIUS, VOLUME_CUBIC_FEET, @@ -115,6 +118,7 @@ QUERY_STATISTIC_META = [ StatisticsMeta.id, StatisticsMeta.statistic_id, StatisticsMeta.source, + StatisticsMeta.state_unit_of_measurement, StatisticsMeta.unit_of_measurement, StatisticsMeta.has_mean, StatisticsMeta.has_sum, @@ -127,24 +131,49 @@ QUERY_STATISTIC_META_ID = [ ] -# Convert pressure, temperature and volume statistics from the normalized unit used for -# statistics to the unit configured by the user -STATISTIC_UNIT_TO_DISPLAY_UNIT_CONVERSIONS = { - PRESSURE_PA: lambda x, units: pressure_util.convert( - x, PRESSURE_PA, units.pressure_unit - ) - if x is not None - else None, - TEMP_CELSIUS: lambda x, units: temperature_util.convert( - x, TEMP_CELSIUS, units.temperature_unit - ) - if x is not None - else None, - VOLUME_CUBIC_METERS: lambda x, units: volume_util.convert( - x, VOLUME_CUBIC_METERS, _configured_unit(VOLUME_CUBIC_METERS, units) - ) - if x is not None - else None, +def _convert_power(value: float | None, state_unit: str, _: UnitSystem) -> float | None: + """Convert power in W to to_unit.""" + if value is None: + return None + if state_unit == POWER_KILO_WATT: + return value / 1000 + return value + + +def _convert_pressure( + value: float | None, state_unit: str, _: UnitSystem +) -> float | None: + """Convert pressure in Pa to to_unit.""" + if value is None: + return None + return pressure_util.convert(value, PRESSURE_PA, state_unit) + + +def _convert_temperature( + value: float | None, state_unit: str, _: UnitSystem +) -> float | None: + """Convert temperature in °C to to_unit.""" + if value is None: + return None + return temperature_util.convert(value, TEMP_CELSIUS, state_unit) + + +def _convert_volume(value: float | None, _: str, units: UnitSystem) -> float | None: + """Convert volume in m³ to ft³ or m³.""" + if value is None: + return None + return volume_util.convert(value, VOLUME_CUBIC_METERS, _volume_unit(units)) + + +# Convert power, pressure, temperature and volume statistics from the normalized unit +# used for statistics to the unit configured by the user +STATISTIC_UNIT_TO_DISPLAY_UNIT_CONVERSIONS: dict[ + str, Callable[[float | None, str, UnitSystem], float | None] +] = { + POWER_WATT: _convert_power, + PRESSURE_PA: _convert_pressure, + TEMP_CELSIUS: _convert_temperature, + VOLUME_CUBIC_METERS: _convert_volume, } # Convert volume statistics from the display unit configured by the user @@ -154,7 +183,7 @@ DISPLAY_UNIT_TO_STATISTIC_UNIT_CONVERSIONS: dict[ str, Callable[[float, UnitSystem], float] ] = { VOLUME_CUBIC_FEET: lambda x, units: volume_util.convert( - x, _configured_unit(VOLUME_CUBIC_METERS, units), VOLUME_CUBIC_METERS + x, _volume_unit(units), VOLUME_CUBIC_METERS ), } @@ -268,6 +297,8 @@ def _update_or_add_metadata( old_metadata["has_mean"] != new_metadata["has_mean"] or old_metadata["has_sum"] != new_metadata["has_sum"] or old_metadata["name"] != new_metadata["name"] + or old_metadata["state_unit_of_measurement"] + != new_metadata["state_unit_of_measurement"] or old_metadata["unit_of_measurement"] != new_metadata["unit_of_measurement"] ): session.query(StatisticsMeta).filter_by(statistic_id=statistic_id).update( @@ -275,6 +306,9 @@ def _update_or_add_metadata( StatisticsMeta.has_mean: new_metadata["has_mean"], StatisticsMeta.has_sum: new_metadata["has_sum"], StatisticsMeta.name: new_metadata["name"], + StatisticsMeta.state_unit_of_measurement: new_metadata[ + "state_unit_of_measurement" + ], StatisticsMeta.unit_of_measurement: new_metadata["unit_of_measurement"], }, synchronize_session=False, @@ -737,12 +771,13 @@ def get_metadata_with_session( meta["statistic_id"]: ( meta["id"], { - "source": meta["source"], - "statistic_id": meta["statistic_id"], - "unit_of_measurement": meta["unit_of_measurement"], "has_mean": meta["has_mean"], "has_sum": meta["has_sum"], "name": meta["name"], + "source": meta["source"], + "state_unit_of_measurement": meta["state_unit_of_measurement"], + "statistic_id": meta["statistic_id"], + "unit_of_measurement": meta["unit_of_measurement"], }, ) for meta in result @@ -767,27 +802,26 @@ def get_metadata( ) -@overload -def _configured_unit(unit: None, units: UnitSystem) -> None: - ... +def _volume_unit(units: UnitSystem) -> str: + """Return the preferred volume unit according to unit system.""" + if units.is_metric: + return VOLUME_CUBIC_METERS + return VOLUME_CUBIC_FEET -@overload -def _configured_unit(unit: str, units: UnitSystem) -> str: - ... +def _configured_unit( + unit: str | None, state_unit: str | None, units: UnitSystem +) -> str | None: + """Return the pressure and temperature units configured by the user. - -def _configured_unit(unit: str | None, units: UnitSystem) -> str | None: - """Return the pressure and temperature units configured by the user.""" - if unit == PRESSURE_PA: - return units.pressure_unit - if unit == TEMP_CELSIUS: - return units.temperature_unit + Energy and volume is normalized for the energy dashboard. + For other units, display in the unit of the source. + """ + if unit == ENERGY_KILO_WATT_HOUR: + return ENERGY_KILO_WATT_HOUR if unit == VOLUME_CUBIC_METERS: - if units.is_metric: - return VOLUME_CUBIC_METERS - return VOLUME_CUBIC_FEET - return unit + return _volume_unit(units) + return state_unit def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None: @@ -834,10 +868,10 @@ def list_statistic_ids( """ result = {} - def _display_unit(hass: HomeAssistant, unit: str | None) -> str | None: - if unit is None: - return None - return _configured_unit(unit, hass.config.units) + def _display_unit( + hass: HomeAssistant, statistic_unit: str | None, state_unit: str | None + ) -> str | None: + return _configured_unit(statistic_unit, state_unit, hass.config.units) # Query the database with session_scope(hass=hass) as session: @@ -852,7 +886,7 @@ def list_statistic_ids( "name": meta["name"], "source": meta["source"], "display_unit_of_measurement": _display_unit( - hass, meta["unit_of_measurement"] + hass, meta["unit_of_measurement"], meta["state_unit_of_measurement"] ), "unit_of_measurement": meta["unit_of_measurement"], } @@ -876,7 +910,7 @@ def list_statistic_ids( "name": meta["name"], "source": meta["source"], "display_unit_of_measurement": _display_unit( - hass, meta["unit_of_measurement"] + hass, meta["unit_of_measurement"], meta["state_unit_of_measurement"] ), "unit_of_measurement": meta["unit_of_measurement"], } @@ -1295,7 +1329,7 @@ def _sorted_statistics_to_dict( need_stat_at_start_time: set[int] = set() stats_at_start_time = {} - def no_conversion(val: Any, _: Any) -> float | None: + def no_conversion(val: Any, _unit: str | None, _units: Any) -> float | None: """Return x.""" return val # type: ignore[no-any-return] @@ -1321,10 +1355,13 @@ def _sorted_statistics_to_dict( # Append all statistic entries, and optionally do unit conversion for meta_id, group in groupby(stats, lambda stat: stat.metadata_id): # type: ignore[no-any-return] unit = metadata[meta_id]["unit_of_measurement"] + state_unit = metadata[meta_id]["state_unit_of_measurement"] statistic_id = metadata[meta_id]["statistic_id"] - convert: Callable[[Any, Any], float | None] - if convert_units: - convert = STATISTIC_UNIT_TO_DISPLAY_UNIT_CONVERSIONS.get(unit, lambda x, units: x) # type: ignore[arg-type,no-any-return] + convert: Callable[[Any, Any, Any], float | None] + if unit is not None and convert_units: + convert = STATISTIC_UNIT_TO_DISPLAY_UNIT_CONVERSIONS.get( + unit, no_conversion + ) else: convert = no_conversion ent_results = result[meta_id] @@ -1336,14 +1373,14 @@ def _sorted_statistics_to_dict( "statistic_id": statistic_id, "start": start if start_time_as_datetime else start.isoformat(), "end": end.isoformat(), - "mean": convert(db_state.mean, units), - "min": convert(db_state.min, units), - "max": convert(db_state.max, units), + "mean": convert(db_state.mean, state_unit, units), + "min": convert(db_state.min, state_unit, units), + "max": convert(db_state.max, state_unit, units), "last_reset": process_timestamp_to_utc_isoformat( db_state.last_reset ), - "state": convert(db_state.state, units), - "sum": convert(db_state.sum, units), + "state": convert(db_state.state, state_unit, units), + "sum": convert(db_state.sum, state_unit, units), } ) @@ -1531,7 +1568,7 @@ def adjust_statistics( units = instance.hass.config.units statistic_unit = metadata[statistic_id][1]["unit_of_measurement"] - display_unit = _configured_unit(statistic_unit, units) + display_unit = _configured_unit(statistic_unit, None, units) convert = DISPLAY_UNIT_TO_STATISTIC_UNIT_CONVERSIONS.get(display_unit, lambda x, units: x) # type: ignore[arg-type] sum_adjustment = convert(sum_adjustment, units) diff --git a/homeassistant/components/recorder/websocket_api.py b/homeassistant/components/recorder/websocket_api.py index 70552bca67e..c625620f4c0 100644 --- a/homeassistant/components/recorder/websocket_api.py +++ b/homeassistant/components/recorder/websocket_api.py @@ -219,7 +219,10 @@ async def ws_get_statistics_metadata( def ws_update_statistics_metadata( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict ) -> None: - """Update statistics metadata for a statistic_id.""" + """Update statistics metadata for a statistic_id. + + Only the normalized unit of measurement can be updated. + """ get_instance(hass).async_update_statistics_metadata( msg["statistic_id"], new_unit_of_measurement=msg["unit_of_measurement"] ) @@ -286,6 +289,7 @@ def ws_import_statistics( """Adjust sum statistics.""" metadata = msg["metadata"] stats = msg["stats"] + metadata["state_unit_of_measurement"] = metadata["unit_of_measurement"] if valid_entity_id(metadata["statistic_id"]): async_import_statistics(hass, metadata, stats) diff --git a/homeassistant/components/sensor/recorder.py b/homeassistant/components/sensor/recorder.py index b2542d98738..dfbcc67f80d 100644 --- a/homeassistant/components/sensor/recorder.py +++ b/homeassistant/components/sensor/recorder.py @@ -87,7 +87,7 @@ UNIT_CONVERSIONS: dict[str, dict[str, Callable]] = { ENERGY_MEGA_WATT_HOUR: lambda x: x * 1000, ENERGY_WATT_HOUR: lambda x: x / 1000, }, - # Convert power W + # Convert power to W SensorDeviceClass.POWER: { POWER_WATT: lambda x: x, POWER_KILO_WATT: lambda x: x * 1000, @@ -202,9 +202,9 @@ def _normalize_states( entity_history: Iterable[State], device_class: str | None, entity_id: str, -) -> tuple[str | None, list[tuple[float, State]]]: +) -> tuple[str | None, str | None, list[tuple[float, State]]]: """Normalize units.""" - unit = None + state_unit = None if device_class not in UNIT_CONVERSIONS: # We're not normalizing this device class, return the state as they are @@ -238,9 +238,9 @@ def _normalize_states( extra, LINK_DEV_STATISTICS, ) - return None, [] - unit = fstates[0][1].attributes.get(ATTR_UNIT_OF_MEASUREMENT) - return unit, fstates + return None, None, [] + state_unit = fstates[0][1].attributes.get(ATTR_UNIT_OF_MEASUREMENT) + return state_unit, state_unit, fstates fstates = [] @@ -249,9 +249,9 @@ def _normalize_states( fstate = _parse_float(state.state) except ValueError: continue - unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) + state_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) # Exclude unsupported units from statistics - if unit not in UNIT_CONVERSIONS[device_class]: + if state_unit not in UNIT_CONVERSIONS[device_class]: if WARN_UNSUPPORTED_UNIT not in hass.data: hass.data[WARN_UNSUPPORTED_UNIT] = set() if entity_id not in hass.data[WARN_UNSUPPORTED_UNIT]: @@ -259,14 +259,14 @@ def _normalize_states( _LOGGER.warning( "%s has unit %s which is unsupported for device_class %s", entity_id, - unit, + state_unit, device_class, ) continue - fstates.append((UNIT_CONVERSIONS[device_class][unit](fstate), state)) + fstates.append((UNIT_CONVERSIONS[device_class][state_unit](fstate), state)) - return DEVICE_CLASS_UNITS[device_class], fstates + return DEVICE_CLASS_UNITS[device_class], state_unit, fstates def _suggest_report_issue(hass: HomeAssistant, entity_id: str) -> str: @@ -455,7 +455,7 @@ def _compile_statistics( # noqa: C901 device_class = _state.attributes.get(ATTR_DEVICE_CLASS) entity_history = history_list[entity_id] - unit, fstates = _normalize_states( + normalized_unit, state_unit, fstates = _normalize_states( hass, session, old_metadatas, @@ -469,7 +469,9 @@ def _compile_statistics( # noqa: C901 state_class = _state.attributes[ATTR_STATE_CLASS] - to_process.append((entity_id, unit, state_class, fstates)) + to_process.append( + (entity_id, normalized_unit, state_unit, state_class, fstates) + ) if "sum" in wanted_statistics[entity_id]: to_query.append(entity_id) @@ -478,13 +480,14 @@ def _compile_statistics( # noqa: C901 ) for ( # pylint: disable=too-many-nested-blocks entity_id, - unit, + normalized_unit, + state_unit, state_class, fstates, ) in to_process: # Check metadata if old_metadata := old_metadatas.get(entity_id): - if old_metadata[1]["unit_of_measurement"] != unit: + if old_metadata[1]["unit_of_measurement"] != normalized_unit: if WARN_UNSTABLE_UNIT not in hass.data: hass.data[WARN_UNSTABLE_UNIT] = set() if entity_id not in hass.data[WARN_UNSTABLE_UNIT]: @@ -496,7 +499,7 @@ def _compile_statistics( # noqa: C901 "Go to %s to fix this", "normalized " if device_class in DEVICE_CLASS_UNITS else "", entity_id, - unit, + normalized_unit, old_metadata[1]["unit_of_measurement"], old_metadata[1]["unit_of_measurement"], LINK_DEV_STATISTICS, @@ -509,8 +512,9 @@ def _compile_statistics( # noqa: C901 "has_sum": "sum" in wanted_statistics[entity_id], "name": None, "source": RECORDER_DOMAIN, + "state_unit_of_measurement": state_unit, "statistic_id": entity_id, - "unit_of_measurement": unit, + "unit_of_measurement": normalized_unit, } # Make calculations @@ -627,7 +631,7 @@ def list_statistic_ids( for state in entities: state_class = state.attributes[ATTR_STATE_CLASS] device_class = state.attributes.get(ATTR_DEVICE_CLASS) - native_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) + state_unit = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) provided_statistics = DEFAULT_STATISTICS[state_class] if statistic_type is not None and statistic_type not in provided_statistics: @@ -649,12 +653,13 @@ def list_statistic_ids( "has_sum": "sum" in provided_statistics, "name": None, "source": RECORDER_DOMAIN, + "state_unit_of_measurement": state_unit, "statistic_id": state.entity_id, - "unit_of_measurement": native_unit, + "unit_of_measurement": state_unit, } continue - if native_unit not in UNIT_CONVERSIONS[device_class]: + if state_unit not in UNIT_CONVERSIONS[device_class]: continue statistics_unit = DEVICE_CLASS_UNITS[device_class] @@ -663,6 +668,7 @@ def list_statistic_ids( "has_sum": "sum" in provided_statistics, "name": None, "source": RECORDER_DOMAIN, + "state_unit_of_measurement": state_unit, "statistic_id": state.entity_id, "unit_of_measurement": statistics_unit, } diff --git a/homeassistant/components/tibber/sensor.py b/homeassistant/components/tibber/sensor.py index ca0c253590f..93fdba107ed 100644 --- a/homeassistant/components/tibber/sensor.py +++ b/homeassistant/components/tibber/sensor.py @@ -642,6 +642,7 @@ class TibberDataCoordinator(DataUpdateCoordinator): has_sum=True, name=f"{home.name} {sensor_type}", source=TIBBER_DOMAIN, + state_unit_of_measurement=unit, statistic_id=statistic_id, unit_of_measurement=unit, ) diff --git a/tests/components/recorder/db_schema_29.py b/tests/components/recorder/db_schema_29.py new file mode 100644 index 00000000000..54aa4b2b13c --- /dev/null +++ b/tests/components/recorder/db_schema_29.py @@ -0,0 +1,616 @@ +"""Models for SQLAlchemy. + +This file contains the model definitions for schema version 28. +It is used to test the schema migration logic. +""" +from __future__ import annotations + +from collections.abc import Callable +from datetime import datetime, timedelta +import logging +from typing import Any, TypeVar, cast + +import ciso8601 +from fnvhash import fnv1a_32 +from sqlalchemy import ( + JSON, + BigInteger, + Boolean, + Column, + DateTime, + Float, + ForeignKey, + Identity, + Index, + Integer, + SmallInteger, + String, + Text, + distinct, + type_coerce, +) +from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite +from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.orm import aliased, declarative_base, relationship +from sqlalchemy.orm.session import Session + +from homeassistant.components.recorder.const import ALL_DOMAIN_EXCLUDE_ATTRS +from homeassistant.components.recorder.models import ( + StatisticData, + StatisticMetaData, + process_timestamp, +) +from homeassistant.const import ( + MAX_LENGTH_EVENT_CONTEXT_ID, + MAX_LENGTH_EVENT_EVENT_TYPE, + MAX_LENGTH_EVENT_ORIGIN, + MAX_LENGTH_STATE_ENTITY_ID, + MAX_LENGTH_STATE_STATE, +) +from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id +from homeassistant.helpers.json import ( + JSON_DECODE_EXCEPTIONS, + JSON_DUMP, + json_bytes, + json_loads, +) +import homeassistant.util.dt as dt_util + +# SQLAlchemy Schema +# pylint: disable=invalid-name +Base = declarative_base() + +SCHEMA_VERSION = 29 + +_StatisticsBaseSelfT = TypeVar("_StatisticsBaseSelfT", bound="StatisticsBase") + +_LOGGER = logging.getLogger(__name__) + +TABLE_EVENTS = "events" +TABLE_EVENT_DATA = "event_data" +TABLE_STATES = "states" +TABLE_STATE_ATTRIBUTES = "state_attributes" +TABLE_RECORDER_RUNS = "recorder_runs" +TABLE_SCHEMA_CHANGES = "schema_changes" +TABLE_STATISTICS = "statistics" +TABLE_STATISTICS_META = "statistics_meta" +TABLE_STATISTICS_RUNS = "statistics_runs" +TABLE_STATISTICS_SHORT_TERM = "statistics_short_term" + +ALL_TABLES = [ + TABLE_STATES, + TABLE_STATE_ATTRIBUTES, + TABLE_EVENTS, + TABLE_EVENT_DATA, + TABLE_RECORDER_RUNS, + TABLE_SCHEMA_CHANGES, + TABLE_STATISTICS, + TABLE_STATISTICS_META, + TABLE_STATISTICS_RUNS, + TABLE_STATISTICS_SHORT_TERM, +] + +TABLES_TO_CHECK = [ + TABLE_STATES, + TABLE_EVENTS, + TABLE_RECORDER_RUNS, + TABLE_SCHEMA_CHANGES, +] + +LAST_UPDATED_INDEX = "ix_states_last_updated" +ENTITY_ID_LAST_UPDATED_INDEX = "ix_states_entity_id_last_updated" +EVENTS_CONTEXT_ID_INDEX = "ix_events_context_id" +STATES_CONTEXT_ID_INDEX = "ix_states_context_id" + + +class FAST_PYSQLITE_DATETIME(sqlite.DATETIME): # type: ignore[misc] + """Use ciso8601 to parse datetimes instead of sqlalchemy built-in regex.""" + + def result_processor(self, dialect, coltype): # type: ignore[no-untyped-def] + """Offload the datetime parsing to ciso8601.""" + return lambda value: None if value is None else ciso8601.parse_datetime(value) + + +JSON_VARIENT_CAST = Text().with_variant( + postgresql.JSON(none_as_null=True), "postgresql" +) +JSONB_VARIENT_CAST = Text().with_variant( + postgresql.JSONB(none_as_null=True), "postgresql" +) +DATETIME_TYPE = ( + DateTime(timezone=True) + .with_variant(mysql.DATETIME(timezone=True, fsp=6), "mysql") + .with_variant(FAST_PYSQLITE_DATETIME(), "sqlite") +) +DOUBLE_TYPE = ( + Float() + .with_variant(mysql.DOUBLE(asdecimal=False), "mysql") + .with_variant(oracle.DOUBLE_PRECISION(), "oracle") + .with_variant(postgresql.DOUBLE_PRECISION(), "postgresql") +) + + +class JSONLiteral(JSON): # type: ignore[misc] + """Teach SA how to literalize json.""" + + def literal_processor(self, dialect: str) -> Callable[[Any], str]: + """Processor to convert a value to JSON.""" + + def process(value: Any) -> str: + """Dump json.""" + return JSON_DUMP(value) + + return process + + +EVENT_ORIGIN_ORDER = [EventOrigin.local, EventOrigin.remote] +EVENT_ORIGIN_TO_IDX = {origin: idx for idx, origin in enumerate(EVENT_ORIGIN_ORDER)} + + +class Events(Base): # type: ignore[misc,valid-type] + """Event history data.""" + + __table_args__ = ( + # Used for fetching events at a specific time + # see logbook + Index("ix_events_event_type_time_fired", "event_type", "time_fired"), + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_EVENTS + event_id = Column(Integer, Identity(), primary_key=True) + event_type = Column(String(MAX_LENGTH_EVENT_EVENT_TYPE)) + event_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql")) + origin = Column(String(MAX_LENGTH_EVENT_ORIGIN)) # no longer used for new rows + origin_idx = Column(SmallInteger) + time_fired = Column(DATETIME_TYPE, index=True) + context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True) + context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID)) + context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID)) + data_id = Column(Integer, ForeignKey("event_data.data_id"), index=True) + event_data_rel = relationship("EventData") + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + @staticmethod + def from_event(event: Event) -> Events: + """Create an event database object from a native event.""" + return Events( + event_type=event.event_type, + event_data=None, + origin_idx=EVENT_ORIGIN_TO_IDX.get(event.origin), + time_fired=event.time_fired, + context_id=event.context.id, + context_user_id=event.context.user_id, + context_parent_id=event.context.parent_id, + ) + + def to_native(self, validate_entity_id: bool = True) -> Event | None: + """Convert to a native HA Event.""" + context = Context( + id=self.context_id, + user_id=self.context_user_id, + parent_id=self.context_parent_id, + ) + try: + return Event( + self.event_type, + json_loads(self.event_data) if self.event_data else {}, + EventOrigin(self.origin) + if self.origin + else EVENT_ORIGIN_ORDER[self.origin_idx], + process_timestamp(self.time_fired), + context=context, + ) + except JSON_DECODE_EXCEPTIONS: + # When json_loads fails + _LOGGER.exception("Error converting to event: %s", self) + return None + + +class EventData(Base): # type: ignore[misc,valid-type] + """Event data history.""" + + __table_args__ = ( + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_EVENT_DATA + data_id = Column(Integer, Identity(), primary_key=True) + hash = Column(BigInteger, index=True) + # Note that this is not named attributes to avoid confusion with the states table + shared_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql")) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + @staticmethod + def from_event(event: Event) -> EventData: + """Create object from an event.""" + shared_data = json_bytes(event.data) + return EventData( + shared_data=shared_data.decode("utf-8"), + hash=EventData.hash_shared_data_bytes(shared_data), + ) + + @staticmethod + def shared_data_bytes_from_event(event: Event) -> bytes: + """Create shared_data from an event.""" + return json_bytes(event.data) + + @staticmethod + def hash_shared_data_bytes(shared_data_bytes: bytes) -> int: + """Return the hash of json encoded shared data.""" + return cast(int, fnv1a_32(shared_data_bytes)) + + def to_native(self) -> dict[str, Any]: + """Convert to an HA state object.""" + try: + return cast(dict[str, Any], json_loads(self.shared_data)) + except JSON_DECODE_EXCEPTIONS: + _LOGGER.exception("Error converting row to event data: %s", self) + return {} + + +class States(Base): # type: ignore[misc,valid-type] + """State change history.""" + + __table_args__ = ( + # Used for fetching the state of entities at a specific time + # (get_states in history.py) + Index(ENTITY_ID_LAST_UPDATED_INDEX, "entity_id", "last_updated"), + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_STATES + state_id = Column(Integer, Identity(), primary_key=True) + entity_id = Column(String(MAX_LENGTH_STATE_ENTITY_ID)) + state = Column(String(MAX_LENGTH_STATE_STATE)) + attributes = Column( + Text().with_variant(mysql.LONGTEXT, "mysql") + ) # no longer used for new rows + event_id = Column( # no longer used for new rows + Integer, ForeignKey("events.event_id", ondelete="CASCADE"), index=True + ) + last_changed = Column(DATETIME_TYPE) + last_updated = Column(DATETIME_TYPE, default=dt_util.utcnow, index=True) + old_state_id = Column(Integer, ForeignKey("states.state_id"), index=True) + attributes_id = Column( + Integer, ForeignKey("state_attributes.attributes_id"), index=True + ) + context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True) + context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID)) + context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID)) + origin_idx = Column(SmallInteger) # 0 is local, 1 is remote + old_state = relationship("States", remote_side=[state_id]) + state_attributes = relationship("StateAttributes") + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + @staticmethod + def from_event(event: Event) -> States: + """Create object from a state_changed event.""" + entity_id = event.data["entity_id"] + state: State | None = event.data.get("new_state") + dbstate = States( + entity_id=entity_id, + attributes=None, + context_id=event.context.id, + context_user_id=event.context.user_id, + context_parent_id=event.context.parent_id, + origin_idx=EVENT_ORIGIN_TO_IDX.get(event.origin), + ) + + # None state means the state was removed from the state machine + if state is None: + dbstate.state = "" + dbstate.last_updated = event.time_fired + dbstate.last_changed = None + return dbstate + + dbstate.state = state.state + dbstate.last_updated = state.last_updated + if state.last_updated == state.last_changed: + dbstate.last_changed = None + else: + dbstate.last_changed = state.last_changed + + return dbstate + + def to_native(self, validate_entity_id: bool = True) -> State | None: + """Convert to an HA state object.""" + context = Context( + id=self.context_id, + user_id=self.context_user_id, + parent_id=self.context_parent_id, + ) + try: + attrs = json_loads(self.attributes) if self.attributes else {} + except JSON_DECODE_EXCEPTIONS: + # When json_loads fails + _LOGGER.exception("Error converting row to state: %s", self) + return None + if self.last_changed is None or self.last_changed == self.last_updated: + last_changed = last_updated = process_timestamp(self.last_updated) + else: + last_updated = process_timestamp(self.last_updated) + last_changed = process_timestamp(self.last_changed) + return State( + self.entity_id, + self.state, + # Join the state_attributes table on attributes_id to get the attributes + # for newer states + attrs, + last_changed, + last_updated, + context=context, + validate_entity_id=validate_entity_id, + ) + + +class StateAttributes(Base): # type: ignore[misc,valid-type] + """State attribute change history.""" + + __table_args__ = ( + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_STATE_ATTRIBUTES + attributes_id = Column(Integer, Identity(), primary_key=True) + hash = Column(BigInteger, index=True) + # Note that this is not named attributes to avoid confusion with the states table + shared_attrs = Column(Text().with_variant(mysql.LONGTEXT, "mysql")) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + @staticmethod + def from_event(event: Event) -> StateAttributes: + """Create object from a state_changed event.""" + state: State | None = event.data.get("new_state") + # None state means the state was removed from the state machine + attr_bytes = b"{}" if state is None else json_bytes(state.attributes) + dbstate = StateAttributes(shared_attrs=attr_bytes.decode("utf-8")) + dbstate.hash = StateAttributes.hash_shared_attrs_bytes(attr_bytes) + return dbstate + + @staticmethod + def shared_attrs_bytes_from_event( + event: Event, exclude_attrs_by_domain: dict[str, set[str]] + ) -> bytes: + """Create shared_attrs from a state_changed event.""" + state: State | None = event.data.get("new_state") + # None state means the state was removed from the state machine + if state is None: + return b"{}" + domain = split_entity_id(state.entity_id)[0] + exclude_attrs = ( + exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS + ) + return json_bytes( + {k: v for k, v in state.attributes.items() if k not in exclude_attrs} + ) + + @staticmethod + def hash_shared_attrs_bytes(shared_attrs_bytes: bytes) -> int: + """Return the hash of json encoded shared attributes.""" + return cast(int, fnv1a_32(shared_attrs_bytes)) + + def to_native(self) -> dict[str, Any]: + """Convert to an HA state object.""" + try: + return cast(dict[str, Any], json_loads(self.shared_attrs)) + except JSON_DECODE_EXCEPTIONS: + # When json_loads fails + _LOGGER.exception("Error converting row to state attributes: %s", self) + return {} + + +class StatisticsBase: + """Statistics base class.""" + + id = Column(Integer, Identity(), primary_key=True) + created = Column(DATETIME_TYPE, default=dt_util.utcnow) + + @declared_attr # type: ignore[misc] + def metadata_id(self) -> Column: + """Define the metadata_id column for sub classes.""" + return Column( + Integer, + ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"), + index=True, + ) + + start = Column(DATETIME_TYPE, index=True) + mean = Column(DOUBLE_TYPE) + min = Column(DOUBLE_TYPE) + max = Column(DOUBLE_TYPE) + last_reset = Column(DATETIME_TYPE) + state = Column(DOUBLE_TYPE) + sum = Column(DOUBLE_TYPE) + + @classmethod + def from_stats( + cls: type[_StatisticsBaseSelfT], metadata_id: int, stats: StatisticData + ) -> _StatisticsBaseSelfT: + """Create object from a statistics.""" + return cls( # type: ignore[call-arg,misc] + metadata_id=metadata_id, + **stats, + ) + + +class Statistics(Base, StatisticsBase): # type: ignore[misc,valid-type] + """Long term statistics.""" + + duration = timedelta(hours=1) + + __table_args__ = ( + # Used for fetching statistics for a certain entity at a specific time + Index("ix_statistics_statistic_id_start", "metadata_id", "start", unique=True), + ) + __tablename__ = TABLE_STATISTICS + + +class StatisticsShortTerm(Base, StatisticsBase): # type: ignore[misc,valid-type] + """Short term statistics.""" + + duration = timedelta(minutes=5) + + __table_args__ = ( + # Used for fetching statistics for a certain entity at a specific time + Index( + "ix_statistics_short_term_statistic_id_start", + "metadata_id", + "start", + unique=True, + ), + ) + __tablename__ = TABLE_STATISTICS_SHORT_TERM + + +class StatisticsMeta(Base): # type: ignore[misc,valid-type] + """Statistics meta data.""" + + __table_args__ = ( + {"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"}, + ) + __tablename__ = TABLE_STATISTICS_META + id = Column(Integer, Identity(), primary_key=True) + statistic_id = Column(String(255), index=True, unique=True) + source = Column(String(32)) + unit_of_measurement = Column(String(255)) + has_mean = Column(Boolean) + has_sum = Column(Boolean) + name = Column(String(255)) + + @staticmethod + def from_meta(meta: StatisticMetaData) -> StatisticsMeta: + """Create object from meta data.""" + return StatisticsMeta(**meta) + + +class RecorderRuns(Base): # type: ignore[misc,valid-type] + """Representation of recorder run.""" + + __table_args__ = (Index("ix_recorder_runs_start_end", "start", "end"),) + __tablename__ = TABLE_RECORDER_RUNS + run_id = Column(Integer, Identity(), primary_key=True) + start = Column(DateTime(timezone=True), default=dt_util.utcnow) + end = Column(DateTime(timezone=True)) + closed_incorrect = Column(Boolean, default=False) + created = Column(DateTime(timezone=True), default=dt_util.utcnow) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + end = ( + f"'{self.end.isoformat(sep=' ', timespec='seconds')}'" if self.end else None + ) + return ( + f"" + ) + + def entity_ids(self, point_in_time: datetime | None = None) -> list[str]: + """Return the entity ids that existed in this run. + + Specify point_in_time if you want to know which existed at that point + in time inside the run. + """ + session = Session.object_session(self) + + assert session is not None, "RecorderRuns need to be persisted" + + query = session.query(distinct(States.entity_id)).filter( + States.last_updated >= self.start + ) + + if point_in_time is not None: + query = query.filter(States.last_updated < point_in_time) + elif self.end is not None: + query = query.filter(States.last_updated < self.end) + + return [row[0] for row in query] + + def to_native(self, validate_entity_id: bool = True) -> RecorderRuns: + """Return self, native format is this model.""" + return self + + +class SchemaChanges(Base): # type: ignore[misc,valid-type] + """Representation of schema version changes.""" + + __tablename__ = TABLE_SCHEMA_CHANGES + change_id = Column(Integer, Identity(), primary_key=True) + schema_version = Column(Integer) + changed = Column(DateTime(timezone=True), default=dt_util.utcnow) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + +class StatisticsRuns(Base): # type: ignore[misc,valid-type] + """Representation of statistics run.""" + + __tablename__ = TABLE_STATISTICS_RUNS + run_id = Column(Integer, Identity(), primary_key=True) + start = Column(DateTime(timezone=True), index=True) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + +EVENT_DATA_JSON = type_coerce( + EventData.shared_data.cast(JSONB_VARIENT_CAST), JSONLiteral(none_as_null=True) +) +OLD_FORMAT_EVENT_DATA_JSON = type_coerce( + Events.event_data.cast(JSONB_VARIENT_CAST), JSONLiteral(none_as_null=True) +) + +SHARED_ATTRS_JSON = type_coerce( + StateAttributes.shared_attrs.cast(JSON_VARIENT_CAST), JSON(none_as_null=True) +) +OLD_FORMAT_ATTRS_JSON = type_coerce( + States.attributes.cast(JSON_VARIENT_CAST), JSON(none_as_null=True) +) + +ENTITY_ID_IN_EVENT: Column = EVENT_DATA_JSON["entity_id"] +OLD_ENTITY_ID_IN_EVENT: Column = OLD_FORMAT_EVENT_DATA_JSON["entity_id"] +DEVICE_ID_IN_EVENT: Column = EVENT_DATA_JSON["device_id"] +OLD_STATE = aliased(States, name="old_state") diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py index c5b4774ab34..cbba4dab26b 100644 --- a/tests/components/recorder/test_migrate.py +++ b/tests/components/recorder/test_migrate.py @@ -21,18 +21,23 @@ from sqlalchemy.pool import StaticPool from homeassistant.bootstrap import async_setup_component from homeassistant.components import persistent_notification as pn, recorder from homeassistant.components.recorder import db_schema, migration +from homeassistant.components.recorder.const import SQLITE_URL_PREFIX from homeassistant.components.recorder.db_schema import ( SCHEMA_VERSION, RecorderRuns, States, ) +from homeassistant.components.recorder.statistics import get_start_time from homeassistant.components.recorder.util import session_scope from homeassistant.helpers import recorder as recorder_helper +from homeassistant.setup import setup_component import homeassistant.util.dt as dt_util -from .common import async_wait_recording_done, create_engine_test +from .common import async_wait_recording_done, create_engine_test, wait_recording_done -from tests.common import async_fire_time_changed +from tests.common import async_fire_time_changed, get_test_home_assistant + +ORIG_TZ = dt_util.DEFAULT_TIME_ZONE def _get_native_states(hass, entity_id): @@ -358,6 +363,114 @@ async def test_schema_migrate(hass, start_version, live): assert recorder.util.async_migration_in_progress(hass) is not True +def test_set_state_unit(caplog, tmpdir): + """Test state unit column is initialized.""" + + def _create_engine_29(*args, **kwargs): + """Test version of create_engine that initializes with old schema. + + This simulates an existing db with the old schema. + """ + module = "tests.components.recorder.db_schema_29" + importlib.import_module(module) + old_db_schema = sys.modules[module] + engine = create_engine(*args, **kwargs) + old_db_schema.Base.metadata.create_all(engine) + with Session(engine) as session: + session.add(recorder.db_schema.StatisticsRuns(start=get_start_time())) + session.add( + recorder.db_schema.SchemaChanges( + schema_version=old_db_schema.SCHEMA_VERSION + ) + ) + session.commit() + return engine + + test_db_file = tmpdir.mkdir("sqlite").join("test_run_info.db") + dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}" + + module = "tests.components.recorder.db_schema_29" + importlib.import_module(module) + old_db_schema = sys.modules[module] + + external_energy_metadata_1 = { + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": "test", + "statistic_id": "test:total_energy_import_tariff_1", + "unit_of_measurement": "kWh", + } + external_co2_metadata = { + "has_mean": True, + "has_sum": False, + "name": "Fossil percentage", + "source": "test", + "statistic_id": "test:fossil_percentage", + "unit_of_measurement": "%", + } + + # Create some statistics_meta with schema version 29 + with patch.object(recorder, "db_schema", old_db_schema), patch.object( + recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION + ), patch( + "homeassistant.components.recorder.core.create_engine", new=_create_engine_29 + ): + hass = get_test_home_assistant() + recorder_helper.async_initialize_recorder(hass) + setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + wait_recording_done(hass) + wait_recording_done(hass) + + with session_scope(hass=hass) as session: + session.add( + recorder.db_schema.StatisticsMeta.from_meta(external_energy_metadata_1) + ) + session.add( + recorder.db_schema.StatisticsMeta.from_meta(external_co2_metadata) + ) + + with session_scope(hass=hass) as session: + tmp = session.query(recorder.db_schema.StatisticsMeta).all() + assert len(tmp) == 2 + assert tmp[0].id == 1 + assert tmp[0].statistic_id == "test:total_energy_import_tariff_1" + assert tmp[0].unit_of_measurement == "kWh" + assert not hasattr(tmp[0], "state_unit_of_measurement") + assert tmp[1].id == 2 + assert tmp[1].statistic_id == "test:fossil_percentage" + assert tmp[1].unit_of_measurement == "%" + assert not hasattr(tmp[1], "state_unit_of_measurement") + + hass.stop() + dt_util.DEFAULT_TIME_ZONE = ORIG_TZ + + # Test that the state_unit column is initialized during migration from schema 28 + hass = get_test_home_assistant() + recorder_helper.async_initialize_recorder(hass) + setup_component(hass, "recorder", {"recorder": {"db_url": dburl}}) + hass.start() + wait_recording_done(hass) + wait_recording_done(hass) + + with session_scope(hass=hass) as session: + tmp = session.query(recorder.db_schema.StatisticsMeta).all() + assert len(tmp) == 2 + assert tmp[0].id == 1 + assert tmp[0].statistic_id == "test:total_energy_import_tariff_1" + assert tmp[0].unit_of_measurement == "kWh" + assert hasattr(tmp[0], "state_unit_of_measurement") + assert tmp[0].state_unit_of_measurement == "kWh" + assert tmp[1].id == 2 + assert tmp[1].statistic_id == "test:fossil_percentage" + assert hasattr(tmp[1], "state_unit_of_measurement") + assert tmp[1].state_unit_of_measurement == "%" + assert tmp[1].state_unit_of_measurement == "%" + + hass.stop() + dt_util.DEFAULT_TIME_ZONE = ORIG_TZ + + def test_invalid_update(hass): """Test that an invalid new version raises an exception.""" with pytest.raises(ValueError): diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py index 970a7feac61..beb7cef2fb9 100644 --- a/tests/components/recorder/test_statistics.py +++ b/tests/components/recorder/test_statistics.py @@ -35,10 +35,9 @@ from homeassistant.helpers import recorder as recorder_helper from homeassistant.setup import setup_component import homeassistant.util.dt as dt_util -from .common import async_wait_recording_done, do_adhoc_statistics +from .common import async_wait_recording_done, do_adhoc_statistics, wait_recording_done from tests.common import get_test_home_assistant, mock_registry -from tests.components.recorder.common import wait_recording_done ORIG_TZ = dt_util.DEFAULT_TIME_ZONE @@ -157,11 +156,12 @@ def mock_sensor_statistics(): """Generate fake statistics.""" return { "meta": { - "statistic_id": entity_id, - "unit_of_measurement": "dogs", "has_mean": True, "has_sum": False, "name": None, + "state_unit_of_measurement": "dogs", + "statistic_id": entity_id, + "unit_of_measurement": "dogs", }, "stat": {"start": start}, } @@ -488,6 +488,7 @@ async def test_import_statistics( "has_sum": True, "name": "Total imported energy", "source": source, + "state_unit_of_measurement": "kWh", "statistic_id": statistic_id, "unit_of_measurement": "kWh", } @@ -542,6 +543,7 @@ async def test_import_statistics( "has_sum": True, "name": "Total imported energy", "source": source, + "state_unit_of_measurement": "kWh", "statistic_id": statistic_id, "unit_of_measurement": "kWh", }, @@ -601,7 +603,7 @@ async def test_import_statistics( ] } - # Update the previously inserted statistics + rename + # Update the previously inserted statistics + rename and change unit external_statistics = { "start": period1, "max": 1, @@ -612,6 +614,7 @@ async def test_import_statistics( "sum": 5, } external_metadata["name"] = "Total imported energy renamed" + external_metadata["state_unit_of_measurement"] = "MWh" import_fn(hass, external_metadata, (external_statistics,)) await async_wait_recording_done(hass) statistic_ids = list_statistic_ids(hass) @@ -635,6 +638,7 @@ async def test_import_statistics( "has_sum": True, "name": "Total imported energy renamed", "source": source, + "state_unit_of_measurement": "MWh", "statistic_id": statistic_id, "unit_of_measurement": "kWh", }, @@ -1051,6 +1055,7 @@ def test_duplicate_statistics_handle_integrity_error(hass_recorder, caplog): "has_sum": True, "name": "Total imported energy", "source": "test", + "state_unit_of_measurement": "kWh", "statistic_id": "test:total_energy_import_tariff_1", "unit_of_measurement": "kWh", } diff --git a/tests/components/recorder/test_websocket_api.py b/tests/components/recorder/test_websocket_api.py index cdec26be26d..1e0633248bd 100644 --- a/tests/components/recorder/test_websocket_api.py +++ b/tests/components/recorder/test_websocket_api.py @@ -30,21 +30,26 @@ from .common import ( from tests.common import async_fire_time_changed -POWER_SENSOR_ATTRIBUTES = { +POWER_SENSOR_KW_ATTRIBUTES = { "device_class": "power", "state_class": "measurement", "unit_of_measurement": "kW", } -PRESSURE_SENSOR_ATTRIBUTES = { +PRESSURE_SENSOR_HPA_ATTRIBUTES = { "device_class": "pressure", "state_class": "measurement", "unit_of_measurement": "hPa", } -TEMPERATURE_SENSOR_ATTRIBUTES = { +TEMPERATURE_SENSOR_C_ATTRIBUTES = { "device_class": "temperature", "state_class": "measurement", "unit_of_measurement": "°C", } +TEMPERATURE_SENSOR_F_ATTRIBUTES = { + "device_class": "temperature", + "state_class": "measurement", + "unit_of_measurement": "°F", +} ENERGY_SENSOR_ATTRIBUTES = { "device_class": "energy", "state_class": "total", @@ -60,12 +65,14 @@ GAS_SENSOR_ATTRIBUTES = { @pytest.mark.parametrize( "units, attributes, state, value", [ - (IMPERIAL_SYSTEM, POWER_SENSOR_ATTRIBUTES, 10, 10000), - (METRIC_SYSTEM, POWER_SENSOR_ATTRIBUTES, 10, 10000), - (IMPERIAL_SYSTEM, TEMPERATURE_SENSOR_ATTRIBUTES, 10, 50), - (METRIC_SYSTEM, TEMPERATURE_SENSOR_ATTRIBUTES, 10, 10), - (IMPERIAL_SYSTEM, PRESSURE_SENSOR_ATTRIBUTES, 1000, 14.503774389728312), - (METRIC_SYSTEM, PRESSURE_SENSOR_ATTRIBUTES, 1000, 100000), + (IMPERIAL_SYSTEM, POWER_SENSOR_KW_ATTRIBUTES, 10, 10), + (METRIC_SYSTEM, POWER_SENSOR_KW_ATTRIBUTES, 10, 10), + (IMPERIAL_SYSTEM, TEMPERATURE_SENSOR_C_ATTRIBUTES, 10, 10), + (METRIC_SYSTEM, TEMPERATURE_SENSOR_C_ATTRIBUTES, 10, 10), + (IMPERIAL_SYSTEM, TEMPERATURE_SENSOR_F_ATTRIBUTES, 10, 10), + (METRIC_SYSTEM, TEMPERATURE_SENSOR_F_ATTRIBUTES, 10, 10), + (IMPERIAL_SYSTEM, PRESSURE_SENSOR_HPA_ATTRIBUTES, 1000, 1000), + (METRIC_SYSTEM, PRESSURE_SENSOR_HPA_ATTRIBUTES, 1000, 1000), ], ) async def test_statistics_during_period( @@ -129,12 +136,12 @@ async def test_statistics_during_period( @pytest.mark.parametrize( "units, attributes, state, value", [ - (IMPERIAL_SYSTEM, POWER_SENSOR_ATTRIBUTES, 10, 10000), - (METRIC_SYSTEM, POWER_SENSOR_ATTRIBUTES, 10, 10000), - (IMPERIAL_SYSTEM, TEMPERATURE_SENSOR_ATTRIBUTES, 10, 50), - (METRIC_SYSTEM, TEMPERATURE_SENSOR_ATTRIBUTES, 10, 10), - (IMPERIAL_SYSTEM, PRESSURE_SENSOR_ATTRIBUTES, 1000, 14.503774389728312), - (METRIC_SYSTEM, PRESSURE_SENSOR_ATTRIBUTES, 1000, 100000), + (IMPERIAL_SYSTEM, POWER_SENSOR_KW_ATTRIBUTES, 10, 10), + (METRIC_SYSTEM, POWER_SENSOR_KW_ATTRIBUTES, 10, 10), + (IMPERIAL_SYSTEM, TEMPERATURE_SENSOR_C_ATTRIBUTES, 10, 10), + (METRIC_SYSTEM, TEMPERATURE_SENSOR_C_ATTRIBUTES, 10, 10), + (IMPERIAL_SYSTEM, PRESSURE_SENSOR_HPA_ATTRIBUTES, 1000, 1000), + (METRIC_SYSTEM, PRESSURE_SENSOR_HPA_ATTRIBUTES, 1000, 1000), ], ) async def test_statistics_during_period_in_the_past( @@ -302,12 +309,14 @@ async def test_statistics_during_period_bad_end_time( @pytest.mark.parametrize( "units, attributes, display_unit, statistics_unit", [ - (IMPERIAL_SYSTEM, POWER_SENSOR_ATTRIBUTES, "W", "W"), - (METRIC_SYSTEM, POWER_SENSOR_ATTRIBUTES, "W", "W"), - (IMPERIAL_SYSTEM, TEMPERATURE_SENSOR_ATTRIBUTES, "°F", "°C"), - (METRIC_SYSTEM, TEMPERATURE_SENSOR_ATTRIBUTES, "°C", "°C"), - (IMPERIAL_SYSTEM, PRESSURE_SENSOR_ATTRIBUTES, "psi", "Pa"), - (METRIC_SYSTEM, PRESSURE_SENSOR_ATTRIBUTES, "Pa", "Pa"), + (IMPERIAL_SYSTEM, POWER_SENSOR_KW_ATTRIBUTES, "kW", "W"), + (METRIC_SYSTEM, POWER_SENSOR_KW_ATTRIBUTES, "kW", "W"), + (IMPERIAL_SYSTEM, TEMPERATURE_SENSOR_C_ATTRIBUTES, "°C", "°C"), + (METRIC_SYSTEM, TEMPERATURE_SENSOR_C_ATTRIBUTES, "°C", "°C"), + (IMPERIAL_SYSTEM, TEMPERATURE_SENSOR_F_ATTRIBUTES, "°F", "°C"), + (METRIC_SYSTEM, TEMPERATURE_SENSOR_F_ATTRIBUTES, "°F", "°C"), + (IMPERIAL_SYSTEM, PRESSURE_SENSOR_HPA_ATTRIBUTES, "hPa", "Pa"), + (METRIC_SYSTEM, PRESSURE_SENSOR_HPA_ATTRIBUTES, "hPa", "Pa"), ], ) async def test_list_statistic_ids( @@ -429,9 +438,9 @@ async def test_clear_statistics(hass, hass_ws_client, recorder_mock): now = dt_util.utcnow() units = METRIC_SYSTEM - attributes = POWER_SENSOR_ATTRIBUTES + attributes = POWER_SENSOR_KW_ATTRIBUTES state = 10 - value = 10000 + value = 10 hass.config.units = units await async_setup_component(hass, "sensor", {}) @@ -555,7 +564,7 @@ async def test_update_statistics_metadata( now = dt_util.utcnow() units = METRIC_SYSTEM - attributes = POWER_SENSOR_ATTRIBUTES + attributes = POWER_SENSOR_KW_ATTRIBUTES state = 10 hass.config.units = units @@ -575,7 +584,7 @@ async def test_update_statistics_metadata( assert response["result"] == [ { "statistic_id": "sensor.test", - "display_unit_of_measurement": "W", + "display_unit_of_measurement": "kW", "has_mean": True, "has_sum": False, "name": None, @@ -602,7 +611,7 @@ async def test_update_statistics_metadata( assert response["result"] == [ { "statistic_id": "sensor.test", - "display_unit_of_measurement": new_unit, + "display_unit_of_measurement": "kW", "has_mean": True, "has_sum": False, "name": None, @@ -1016,6 +1025,7 @@ async def test_import_statistics( "has_sum": True, "name": "Total imported energy", "source": source, + "state_unit_of_measurement": "kWh", "statistic_id": statistic_id, "unit_of_measurement": "kWh", }, diff --git a/tests/components/sensor/test_recorder.py b/tests/components/sensor/test_recorder.py index e7421e6a616..b20b270ee69 100644 --- a/tests/components/sensor/test_recorder.py +++ b/tests/components/sensor/test_recorder.py @@ -84,12 +84,12 @@ def set_time_zone(): ("humidity", "%", "%", "%", 13.050847, -10, 30), ("humidity", None, None, None, 13.050847, -10, 30), ("pressure", "Pa", "Pa", "Pa", 13.050847, -10, 30), - ("pressure", "hPa", "Pa", "Pa", 1305.0847, -1000, 3000), - ("pressure", "mbar", "Pa", "Pa", 1305.0847, -1000, 3000), - ("pressure", "inHg", "Pa", "Pa", 44195.25, -33863.89, 101591.67), - ("pressure", "psi", "Pa", "Pa", 89982.42, -68947.57, 206842.71), + ("pressure", "hPa", "hPa", "Pa", 13.050847, -10, 30), + ("pressure", "mbar", "mbar", "Pa", 13.050847, -10, 30), + ("pressure", "inHg", "inHg", "Pa", 13.050847, -10, 30), + ("pressure", "psi", "psi", "Pa", 13.050847, -10, 30), ("temperature", "°C", "°C", "°C", 13.050847, -10, 30), - ("temperature", "°F", "°C", "°C", -10.52731, -23.33333, -1.111111), + ("temperature", "°F", "°F", "°C", 13.050847, -10, 30), ], ) def test_compile_hourly_statistics( @@ -1513,12 +1513,12 @@ def test_compile_hourly_energy_statistics_multiple(hass_recorder, caplog): ("humidity", "%", 30), ("humidity", None, 30), ("pressure", "Pa", 30), - ("pressure", "hPa", 3000), - ("pressure", "mbar", 3000), - ("pressure", "inHg", 101591.67), - ("pressure", "psi", 206842.71), + ("pressure", "hPa", 30), + ("pressure", "mbar", 30), + ("pressure", "inHg", 30), + ("pressure", "psi", 30), ("temperature", "°C", 30), - ("temperature", "°F", -1.111111), + ("temperature", "°F", 30), ], ) def test_compile_hourly_statistics_unchanged( @@ -1600,12 +1600,12 @@ def test_compile_hourly_statistics_partially_unavailable(hass_recorder, caplog): ("humidity", "%", 30), ("humidity", None, 30), ("pressure", "Pa", 30), - ("pressure", "hPa", 3000), - ("pressure", "mbar", 3000), - ("pressure", "inHg", 101591.67), - ("pressure", "psi", 206842.71), + ("pressure", "hPa", 30), + ("pressure", "mbar", 30), + ("pressure", "inHg", 30), + ("pressure", "psi", 30), ("temperature", "°C", 30), - ("temperature", "°F", -1.111111), + ("temperature", "°F", 30), ], ) def test_compile_hourly_statistics_unavailable( @@ -1685,12 +1685,12 @@ def test_compile_hourly_statistics_fails(hass_recorder, caplog): ("measurement", "gas", "m³", "m³", "m³", "mean"), ("measurement", "gas", "ft³", "m³", "m³", "mean"), ("measurement", "pressure", "Pa", "Pa", "Pa", "mean"), - ("measurement", "pressure", "hPa", "Pa", "Pa", "mean"), - ("measurement", "pressure", "mbar", "Pa", "Pa", "mean"), - ("measurement", "pressure", "inHg", "Pa", "Pa", "mean"), - ("measurement", "pressure", "psi", "Pa", "Pa", "mean"), + ("measurement", "pressure", "hPa", "hPa", "Pa", "mean"), + ("measurement", "pressure", "mbar", "mbar", "Pa", "mean"), + ("measurement", "pressure", "inHg", "inHg", "Pa", "mean"), + ("measurement", "pressure", "psi", "psi", "Pa", "mean"), ("measurement", "temperature", "°C", "°C", "°C", "mean"), - ("measurement", "temperature", "°F", "°C", "°C", "mean"), + ("measurement", "temperature", "°F", "°F", "°C", "mean"), ], ) def test_list_statistic_ids( @@ -2162,13 +2162,21 @@ def test_compile_hourly_statistics_changing_device_class_1( @pytest.mark.parametrize( - "device_class,state_unit,statistic_unit,mean,min,max", + "device_class,state_unit,display_unit,statistic_unit,mean,min,max", [ - ("power", "kW", "W", 13050.847, -10000, 30000), + ("power", "kW", "kW", "W", 13.050847, -10, 30), ], ) def test_compile_hourly_statistics_changing_device_class_2( - hass_recorder, caplog, device_class, state_unit, statistic_unit, mean, min, max + hass_recorder, + caplog, + device_class, + state_unit, + display_unit, + statistic_unit, + mean, + min, + max, ): """Test compiling hourly statistics where device class changes from one hour to the next.""" zero = dt_util.utcnow() @@ -2191,7 +2199,7 @@ def test_compile_hourly_statistics_changing_device_class_2( assert statistic_ids == [ { "statistic_id": "sensor.test1", - "display_unit_of_measurement": statistic_unit, + "display_unit_of_measurement": display_unit, "has_mean": True, "has_sum": False, "name": None, @@ -2240,7 +2248,7 @@ def test_compile_hourly_statistics_changing_device_class_2( assert statistic_ids == [ { "statistic_id": "sensor.test1", - "display_unit_of_measurement": statistic_unit, + "display_unit_of_measurement": display_unit, "has_mean": True, "has_sum": False, "name": None, @@ -2325,6 +2333,7 @@ def test_compile_hourly_statistics_changing_statistics( "has_sum": False, "name": None, "source": "recorder", + "state_unit_of_measurement": None, "statistic_id": "sensor.test1", "unit_of_measurement": None, }, @@ -2360,6 +2369,7 @@ def test_compile_hourly_statistics_changing_statistics( "has_sum": True, "name": None, "source": "recorder", + "state_unit_of_measurement": None, "statistic_id": "sensor.test1", "unit_of_measurement": None, },